NodePipeline/NodePipeline.Engine.CodeGeneration.Abstractions/NodeGeneratorHelper.cs
2026-01-02 20:55:25 +03:00

78 lines
2.9 KiB
C#

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using NodePipeline.Engine.CodeGeneration.Abstractions.Models;
namespace NodePipeline.Engine.CodeGeneration.Abstractions;
/// <summary>
/// Helpers shared by the node source generators: formatting node-name expressions for emitted code
/// and discovering classes that implement <c>NodePipeline.Abstractions.Interfaces.Nodes.INode</c>.
/// </summary>
public static class NodeGeneratorHelper
{
    /// <summary>
    /// Builds a <c>case</c> label that matches the given node's name in a generated <c>switch</c>.
    /// </summary>
    /// <param name="node">The node whose name the label should match.</param>
    /// <returns>
    /// <c>case "Type":</c> when the name is defined in an attribute; otherwise
    /// <c>case var t when t == NodeNamePrefixSettings.TrimNodeType("Type"):</c>. The guard form is
    /// needed because a method call is not a constant expression and cannot be a constant pattern.
    /// </returns>
    public static string GetNodeNameCaseString(NodeDescriptor node)
    {
        // Reuse GetNodeName so the label and plain name expressions are always formatted identically.
        var nameExpression = GetNodeName(node);
        return node.IsNameDefinedInAttribute
            ? $"case {nameExpression}:"
            : $"case var t when t == {nameExpression}:";
    }

    /// <summary>
    /// Builds the C# expression that evaluates to the node's name in generated code.
    /// </summary>
    /// <param name="node">The node whose name expression is produced.</param>
    /// <param name="addSemicolon">When <see langword="true"/>, a trailing <c>;</c> is appended.</param>
    /// <returns>
    /// A quoted string literal when the name is defined in an attribute; otherwise a call to
    /// <c>NodeNamePrefixSettings.TrimNodeType</c> with the node type as a string literal.
    /// </returns>
    public static string GetNodeName(NodeDescriptor node, bool addSemicolon = false)
    {
        var nameExpression = node.IsNameDefinedInAttribute
            ? $"\"{node.Type}\""
            : $"NodeNamePrefixSettings.TrimNodeType(\"{node.Type}\")";

        // The terminator applies to both forms, so callers get a complete statement regardless of
        // how the node's name was declared.
        return addSemicolon ? nameExpression + ";" : nameExpression;
    }

    /// <summary>
    /// Creates a pipeline step that collects every class implementing
    /// <c>NodePipeline.Abstractions.Interfaces.Nodes.INode</c>. The search covers the compilation's
    /// referenced assemblies and the compilation's own assembly.
    /// </summary>
    /// <param name="context">The generator initialization context whose compilation is inspected.</param>
    /// <returns>
    /// A provider yielding the matching class symbols. The array is empty when the <c>INode</c>
    /// interface cannot be resolved in the compilation.
    /// </returns>
    /// <remarks>
    /// Abstract and generic classes are included as long as their kind is <see cref="TypeKind.Class"/>.
    /// NOTE(review): the step is driven by <c>CompilationProvider</c> and returns symbols. Symbols are
    /// generally not comparable across compilations, so downstream caching is unlikely to be effective.
    /// Consider projecting to equatable models if generation cost matters.
    /// </remarks>
    public static IncrementalValueProvider<ImmutableArray<INamedTypeSymbol>> GetNodeSymbols(
        IncrementalGeneratorInitializationContext context)
    {
        return context.CompilationProvider.Select(static (compilation, _) =>
        {
            var results = ImmutableArray.CreateBuilder<INamedTypeSymbol>();
            var nodeInterface = compilation
                .GetTypeByMetadataName("NodePipeline.Abstractions.Interfaces.Nodes.INode");
            if (nodeInterface is null)
            {
                return results.ToImmutable();
            }

            var visited = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);
            var assemblies = compilation.SourceModule.ReferencedAssemblySymbols.Concat([
                compilation.Assembly
            ]);
            foreach (var assemblySymbol in assemblies)
            {
                foreach (var type in GetAllTypesFromAssembly(assemblySymbol))
                {
                    // Symbols are compared with SymbolEqualityComparer, consistent with 'visited',
                    // rather than with ISymbol's default Equals.
                    if (visited.Add(type)
                        && type.TypeKind == TypeKind.Class
                        && type.AllInterfaces.Contains(nodeInterface, SymbolEqualityComparer.Default))
                    {
                        results.Add(type);
                    }
                }
            }

            return results.ToImmutable();
        });
    }

    /// <summary>
    /// Enumerates every named type declared in <paramref name="assembly"/>, including nested types.
    /// It walks all namespaces starting from the global namespace.
    /// </summary>
    private static IEnumerable<INamedTypeSymbol> GetAllTypesFromAssembly(IAssemblySymbol assembly)
    {
        foreach (var ns in GetNamespaces(assembly.GlobalNamespace))
        {
            foreach (var type in ns.GetTypeMembers())
            {
                foreach (var nested in GetNested(type))
                {
                    yield return nested;
                }
            }
        }
    }

    /// <summary>
    /// Yields <paramref name="ns"/> followed by all of its descendant namespaces (depth-first, pre-order).
    /// </summary>
    private static IEnumerable<INamespaceSymbol> GetNamespaces(INamespaceSymbol ns)
    {
        yield return ns;
        foreach (var child in ns.GetNamespaceMembers())
        {
            foreach (var nested in GetNamespaces(child))
            {
                yield return nested;
            }
        }
    }

    /// <summary>
    /// Yields <paramref name="type"/> followed by every type nested inside it at any depth
    /// (depth-first, pre-order).
    /// </summary>
    private static IEnumerable<INamedTypeSymbol> GetNested(INamedTypeSymbol type)
    {
        yield return type;
        foreach (var nested in type.GetTypeMembers())
        {
            foreach (var deeper in GetNested(nested))
            {
                yield return deeper;
            }
        }
    }
}