78 lines
2.9 KiB
C#
78 lines
2.9 KiB
C#
using System.Collections.Generic;
|
|
using System.Collections.Immutable;
|
|
using System.Linq;
|
|
using Microsoft.CodeAnalysis;
|
|
using NodePipeline.Engine.CodeGeneration.Abstractions.Models;
|
|
|
|
namespace NodePipeline.Engine.CodeGeneration.Abstractions;
|
|
|
|
/// <summary>
/// Helpers shared by the node code generators: building the C# source fragments that
/// identify a node by name, and discovering node types in a compilation.
/// </summary>
public static class NodeGeneratorHelper
{
    /// <summary>Metadata name of the interface a class must implement to be collected as a node.</summary>
    private const string NodeInterfaceMetadataName = "NodePipeline.Abstractions.Interfaces.Nodes.INode";

    /// <summary>
    /// Builds the <c>case</c> label (as generated source text) that matches the given node's name.
    /// </summary>
    /// <param name="node">Descriptor of the node the label is generated for.</param>
    /// <returns>
    /// <c>case "Type":</c> when the name is defined in an attribute; otherwise
    /// <c>case var t when t == NodeNamePrefixSettings.TrimNodeType("Type"):</c>.
    /// </returns>
    public static string GetNodeNameCaseString(NodeDescriptor node)
    {
        var pattern = node.IsNameDefinedInAttribute
            ? QuotedTypeLiteral(node)
            : $"var t when t == {TrimNodeTypeCall(node)}";

        return $"case {pattern}:";
    }

    /// <summary>
    /// Builds the expression (as generated source text) that evaluates to the given node's name.
    /// </summary>
    /// <param name="node">Descriptor of the node the expression is generated for.</param>
    /// <param name="addSemicolon">
    /// When <c>true</c>, a trailing <c>;</c> is appended to the
    /// <c>NodeNamePrefixSettings.TrimNodeType(...)</c> form.
    /// </param>
    /// <returns>
    /// The quoted type literal when the name is defined in an attribute; otherwise a
    /// <c>NodeNamePrefixSettings.TrimNodeType("Type")</c> call, optionally followed by <c>;</c>.
    /// </returns>
    public static string GetNodeName(NodeDescriptor node, bool addSemicolon = false)
    {
        if (node.IsNameDefinedInAttribute)
        {
            // NOTE(review): addSemicolon is intentionally NOT applied here to preserve the
            // existing output; the asymmetry looks accidental - confirm against callers
            // before changing it.
            return QuotedTypeLiteral(node);
        }

        var call = TrimNodeTypeCall(node);
        return addSemicolon ? call + ";" : call;
    }

    /// <summary>
    /// Creates a provider that yields every class implementing the node interface, searched
    /// across the compilation's own assembly and all assemblies referenced by its source module.
    /// </summary>
    /// <param name="context">The incremental generator initialization context.</param>
    /// <returns>
    /// A provider of the matching class symbols; the array is empty when the node interface
    /// cannot be resolved in the compilation.
    /// </returns>
    public static IncrementalValueProvider<ImmutableArray<INamedTypeSymbol>> GetNodeSymbols(
        IncrementalGeneratorInitializationContext context)
    {
        return context.CompilationProvider.Select((compilation, cancellationToken) =>
        {
            var nodeInterface = compilation.GetTypeByMetadataName(NodeInterfaceMetadataName);
            if (nodeInterface is null)
            {
                return ImmutableArray<INamedTypeSymbol>.Empty;
            }

            var results = ImmutableArray.CreateBuilder<INamedTypeSymbol>();
            var visited = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);

            // Referenced assemblies first, then the assembly being compiled.
            var assemblies = compilation.SourceModule.ReferencedAssemblySymbols.Concat([compilation.Assembly]);
            foreach (var assemblySymbol in assemblies)
            {
                // Walking every type of every reference is expensive; bail out promptly
                // when the generator driver cancels this run.
                cancellationToken.ThrowIfCancellationRequested();

                foreach (var type in GetAllTypesFromAssembly(assemblySymbol))
                {
                    // Symbols must be compared through SymbolEqualityComparer, matching
                    // the comparer used by the visited set above.
                    if (visited.Add(type)
                        && type.TypeKind == TypeKind.Class
                        && type.AllInterfaces.Contains(nodeInterface, SymbolEqualityComparer.Default))
                    {
                        results.Add(type);
                    }
                }
            }

            return results.ToImmutable();
        });
    }

    /// <summary>Returns the node's type wrapped in double quotes, i.e. a C# string literal in generated code.</summary>
    private static string QuotedTypeLiteral(NodeDescriptor node)
    {
        return $"\"{node.Type}\"";
    }

    /// <summary>Returns the generated-code call that trims the node's type name via <c>NodeNamePrefixSettings</c>.</summary>
    private static string TrimNodeTypeCall(NodeDescriptor node)
    {
        return $"NodeNamePrefixSettings.TrimNodeType({QuotedTypeLiteral(node)})";
    }

    /// <summary>
    /// Enumerates every named type in the assembly: the types of each namespace, each followed
    /// by its nested types at any depth.
    /// </summary>
    private static IEnumerable<INamedTypeSymbol> GetAllTypesFromAssembly(IAssemblySymbol assembly)
    {
        foreach (var ns in GetNamespaces(assembly.GlobalNamespace))
        {
            foreach (var type in ns.GetTypeMembers())
            {
                foreach (var nested in GetNested(type))
                {
                    yield return nested;
                }
            }
        }
    }

    /// <summary>Enumerates the given namespace followed by all of its descendant namespaces, depth-first.</summary>
    private static IEnumerable<INamespaceSymbol> GetNamespaces(INamespaceSymbol ns)
    {
        yield return ns;

        foreach (var child in ns.GetNamespaceMembers())
        {
            foreach (var nested in GetNamespaces(child))
            {
                yield return nested;
            }
        }
    }

    /// <summary>Enumerates the given type followed by all types nested inside it, depth-first.</summary>
    private static IEnumerable<INamedTypeSymbol> GetNested(INamedTypeSymbol type)
    {
        yield return type;

        foreach (var nested in type.GetTypeMembers())
        {
            foreach (var deeper in GetNested(nested))
            {
                yield return deeper;
            }
        }
    }
}