using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using NodePipeline.Engine.CodeGeneration.Abstractions.Models;

namespace NodePipeline.Engine.CodeGeneration.Abstractions;

/// <summary>
/// Helpers shared by the node code generators: rendering node-name expressions as C# source text
/// and discovering node classes in a compilation.
/// </summary>
public static class NodeGeneratorHelper
{
    /// <summary>Metadata name of the interface whose implementing classes <see cref="GetNodeSymbols"/> reports.</summary>
    private const string NodeInterfaceMetadataName = "NodePipeline.Abstractions.Interfaces.Nodes.INode";

    /// <summary>
    /// Builds the <c>case</c> label used to match a node by name in generated <c>switch</c> code.
    /// </summary>
    /// <param name="node">The node whose name is matched.</param>
    /// <returns>
    /// <c>case "Type":</c> when the name comes from an attribute; otherwise
    /// <c>case var t when t == NodeNamePrefixSettings.TrimNodeType("Type"):</c>, because the trimmed name
    /// is a runtime call and cannot appear as a constant pattern.
    /// </returns>
    public static string GetNodeNameCaseString(NodeDescriptor node)
    {
        var nodeName = GetNodeName(node);
        var label = node.IsNameDefinedInAttribute
            ? nodeName
            : $"var t when t == {nodeName}";
        return $"case {label}:";
    }

    /// <summary>
    /// Renders the C# expression that evaluates to the node's name in generated code.
    /// </summary>
    /// <param name="node">The node whose name is rendered.</param>
    /// <param name="addSemicolon">When <c>true</c>, a trailing <c>;</c> is appended so the text can be used as a statement tail.</param>
    /// <returns>
    /// The quoted literal <c>"Type"</c> when the name comes from an attribute; otherwise
    /// <c>NodeNamePrefixSettings.TrimNodeType("Type")</c>. The optional semicolon is applied to both forms.
    /// </returns>
    public static string GetNodeName(NodeDescriptor node, bool addSemicolon = false)
    {
        var nodeNameString = node.IsNameDefinedInAttribute
            ? $"\"{node.Type}\""
            : $"NodeNamePrefixSettings.TrimNodeType(\"{node.Type}\")";

        // Previously the semicolon was only appended in the TrimNodeType branch, so attribute-named
        // nodes produced an unterminated statement when callers asked for one.
        return addSemicolon ? nodeNameString + ";" : nodeNameString;
    }

    /// <summary>
    /// Creates a pipeline step that collects every class implementing
    /// <c>NodePipeline.Abstractions.Interfaces.Nodes.INode</c>. The search covers the compilation's own assembly
    /// and all assemblies referenced by its source module, including nested types.
    /// </summary>
    /// <param name="context">The generator initialization context providing the compilation.</param>
    /// <returns>
    /// A provider yielding the matching class symbols. The array is empty when the <c>INode</c> interface
    /// cannot be resolved in the compilation.
    /// </returns>
    /// <remarks>
    /// NOTE(review): abstract classes and open generic class definitions are also returned, since only
    /// <see cref="TypeKind.Class"/> is checked. Confirm consumers filter these out.
    /// NOTE(review): symbols are bound to a single <see cref="Compilation"/>, and <see cref="ImmutableArray{T}"/>
    /// compares by reference. This value is therefore presumably never equal across runs, so downstream
    /// steps are not cached. Consider projecting to an equatable model.
    /// </remarks>
    public static IncrementalValueProvider<ImmutableArray<INamedTypeSymbol>> GetNodeSymbols(
        IncrementalGeneratorInitializationContext context)
    {
        return context.CompilationProvider.Select((compilation, cancellationToken) =>
        {
            var results = ImmutableArray.CreateBuilder<INamedTypeSymbol>();

            var nodeInterface = compilation.GetTypeByMetadataName(NodeInterfaceMetadataName);
            if (nodeInterface is null)
            {
                return results.ToImmutable();
            }

            // Defensive de-duplication in case the same type is reachable more than once.
            var visited = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);

            foreach (var assemblySymbol in compilation.SourceModule.ReferencedAssemblySymbols.Concat([
                         compilation.Assembly
                     ]))
            {
                // Walking every referenced assembly can be expensive; stop promptly when the driver cancels.
                cancellationToken.ThrowIfCancellationRequested();

                foreach (var type in GetAllTypesFromAssembly(assemblySymbol))
                {
                    if (visited.Add(type)
                        && type.TypeKind == TypeKind.Class
                        && type.AllInterfaces.Contains(nodeInterface, SymbolEqualityComparer.Default))
                    {
                        results.Add(type);
                    }
                }
            }

            return results.ToImmutable();
        });
    }

    /// <summary>Enumerates every named type in <paramref name="assembly"/>, including nested types.</summary>
    private static IEnumerable<INamedTypeSymbol> GetAllTypesFromAssembly(IAssemblySymbol assembly)
    {
        foreach (var ns in GetNamespaces(assembly.GlobalNamespace))
        {
            foreach (var type in ns.GetTypeMembers())
            {
                foreach (var nested in GetNested(type))
                {
                    yield return nested;
                }
            }
        }
    }

    /// <summary>Yields <paramref name="ns"/> followed by all of its descendant namespaces, depth-first.</summary>
    private static IEnumerable<INamespaceSymbol> GetNamespaces(INamespaceSymbol ns)
    {
        yield return ns;

        foreach (var child in ns.GetNamespaceMembers())
        {
            foreach (var nested in GetNamespaces(child))
            {
                yield return nested;
            }
        }
    }

    /// <summary>Yields <paramref name="type"/> followed by all of its nested types, depth-first.</summary>
    private static IEnumerable<INamedTypeSymbol> GetNested(INamedTypeSymbol type)
    {
        yield return type;

        foreach (var nested in type.GetTypeMembers())
        {
            foreach (var deeper in GetNested(nested))
            {
                yield return deeper;
            }
        }
    }
}