NodePipeline/NodePipeline.Engine.CodeGeneration.Abstractions/NodeModelBuilder.cs
2026-01-02 20:55:25 +03:00

389 lines
16 KiB
C#

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Globalization;
using System.Linq;
using Microsoft.CodeAnalysis;
using NodePipeline.Abstractions.Interfaces.Nodes;
using NodePipeline.Engine.Abstractions;
using NodePipeline.Engine.CodeGeneration.Abstractions.Models;
// using NodePipeline.Engine.CodeGeneration.Models;
namespace NodePipeline.Engine.CodeGeneration.Abstractions;
/// <summary>
/// Translates Roslyn symbols for node types into a <see cref="NodesModel"/> consumed by the node pipeline
/// code generator, collecting diagnostics for invalid field default values along the way.
/// </summary>
public static class NodeModelBuilder
{
    // One shared descriptor for every default-value problem; the specific explanation is passed as the
    // single message argument instead of building a new descriptor (with the text as its format) per call.
    private static readonly DiagnosticDescriptor InvalidDefaultValueDescriptor = new(
        "NP0001",
        "Invalid default value",
        "{0}",
        "NodePipeline",
        DiagnosticSeverity.Error,
        true);

    /// <summary>
    /// Builds the node model for the supplied node type symbols.
    /// </summary>
    /// <param name="nodes">
    /// Node types to describe. A default (uninitialized) array is treated like an empty one.
    /// </param>
    /// <returns>
    /// The model containing one <see cref="NodeDescriptor"/> per node type, and the diagnostics produced
    /// while validating field default values.
    /// </returns>
    public static (NodesModel Model, List<Diagnostic> Diagnostics) Build(ImmutableArray<INamedTypeSymbol> nodes)
    {
        var items = new List<NodeDescriptor>();
        var diagnostics = new List<Diagnostic>();
        // Enumerating a default ImmutableArray throws, so return an empty model instead.
        if (nodes.IsDefaultOrEmpty) return (new NodesModel(items), diagnostics);

        foreach (var node in nodes)
        {
            var fields = new List<NodeFieldDescriptor>();
            foreach (var prop in node.GetMembers().OfType<IPropertySymbol>())
            {
                // Fields are generic "NodeField..." properties whose first type argument is the value type;
                // non-generic types that merely share the name prefix are skipped rather than crashing later.
                if (prop.Type is not INamedTypeSymbol { TypeArguments.Length: > 0 } nf ||
                    !nf.Name.StartsWith("NodeField", StringComparison.Ordinal)) continue;
                var nodeFieldAttr = FindAttribute(prop.GetAttributes(), "NodeFieldAttribute");
                if (nodeFieldAttr is null) continue;
                var descriptor = CreateFieldDescriptor(nodeFieldAttr, nf, prop, out var diagnostic);
                if (diagnostic != null) diagnostics.Add(diagnostic);
                fields.Add(descriptor);
            }

            var nodeValidators = GetValidators(node.GetAttributes(), "HasNodeValidatorAttribute");
            // Fully qualified (global::Namespace.Type), not assembly-qualified.
            var fullyQualifiedName = node.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
            var typeNameShort = GetTypeNameShort(node, fullyQualifiedName);
            var nodeType = GetNodeType(node, typeNameShort, out var isNameDefinedInAttribute);
            var hasParameterlessConstructor = HasParameterlessInstanceConstructor(node);
            items.Add(new NodeDescriptor(nodeType, isNameDefinedInAttribute, typeNameShort, fullyQualifiedName,
                fields, hasParameterlessConstructor, nodeValidators));
        }

        return (new NodesModel(items), diagnostics);
    }

    /// <summary>
    /// Returns whether <paramref name="type"/> has an instance constructor without parameters.
    /// </summary>
    /// <remarks>
    /// <see cref="INamedTypeSymbol.Constructors"/> also lists the static constructor, which takes no
    /// parameters; checking it would wrongly mark any type with a static constructor as default-constructible.
    /// NOTE(review): unlike validator types, accessibility is not checked here (a private parameterless
    /// constructor still counts) — confirm this matches how generated code instantiates nodes.
    /// </remarks>
    private static bool HasParameterlessInstanceConstructor(INamedTypeSymbol type) =>
        type.InstanceConstructors.Any(c => c.Parameters.Length == 0);

    /// <summary>
    /// Returns the first attribute whose class name equals <paramref name="name"/>, or <c>null</c>.
    /// </summary>
    private static AttributeData? FindAttribute(ImmutableArray<AttributeData> attributes, string name) =>
        attributes.FirstOrDefault(a => a.AttributeClass?.Name == name);

    /// <summary>
    /// Returns whether any attribute's class name equals <paramref name="name"/>.
    /// </summary>
    private static bool HasAttribute(ImmutableArray<AttributeData> attributes, string name) =>
        FindAttribute(attributes, name) is not null;

    /// <summary>
    /// Collects the validator types referenced by attributes named <paramref name="attributeName"/>.
    /// </summary>
    /// <param name="attributes">Attributes of a node type or of a field property.</param>
    /// <param name="attributeName">Attribute class name, e.g. "HasValidatorAttribute".</param>
    /// <returns>
    /// A descriptor per attribute whose first constructor argument is a type, recording the type's fully
    /// qualified name and whether it has a public parameterless constructor.
    /// </returns>
    private static HashSet<ValidatorDescriptor> GetValidators(ImmutableArray<AttributeData> attributes,
        string attributeName)
    {
        var result = new HashSet<ValidatorDescriptor>();
        foreach (var attr in attributes)
        {
            if (attr.AttributeClass?.Name != attributeName || attr.ConstructorArguments.Length == 0) continue;
            if (attr.ConstructorArguments[0].Value is not INamedTypeSymbol validatorType) continue;
            var validatorTypeName = validatorType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
            var hasParameterlessCtor = validatorType.InstanceConstructors.Any(c =>
                c.Parameters.Length == 0 && c.DeclaredAccessibility == Accessibility.Public);
            result.Add(new ValidatorDescriptor(validatorTypeName, hasParameterlessCtor));
        }

        return result;
    }

    /// <summary>
    /// Reads the field's default value from a HasDefaultValue/DefaultValue attribute and validates it.
    /// </summary>
    /// <param name="propAttributes">Attributes of the field property.</param>
    /// <param name="fieldType">The NodeField's type argument.</param>
    /// <param name="prop">Property used as the diagnostic location.</param>
    /// <param name="diagnostic">Set when the value is incompatible with <paramref name="fieldType"/>.</param>
    /// <returns>The default value, or <c>null</c> when absent or invalid.</returns>
    private static object? GetFieldDefaultValue(ImmutableArray<AttributeData> propAttributes,
        ITypeSymbol fieldType, IPropertySymbol prop, out Diagnostic? diagnostic)
    {
        diagnostic = null;
        var defaultValueAttr = propAttributes.FirstOrDefault(a =>
            a.AttributeClass?.Name is "HasDefaultValueAttribute" or "DefaultValueAttribute");
        if (defaultValueAttr is not { ConstructorArguments.Length: > 0 }) return null;

        var args = defaultValueAttr.ConstructorArguments;
        var value = args[0].Value;
        // Optional second argument naming the value's type.
        var typeArg = args.Length > 1 ? args[1] : (TypedConstant?)null;
        diagnostic = ValidateDefaultValueCompatibility(fieldType, value, typeArg, prop);
        return diagnostic is null ? value : null;
    }

    /// <summary>
    /// Checks that an attribute-supplied default value can be assigned to a field of
    /// <paramref name="fieldType"/>.
    /// </summary>
    /// <param name="fieldType">The NodeField's type argument.</param>
    /// <param name="value">Raw default value from the attribute; <c>null</c> is always accepted.</param>
    /// <param name="attributeTypeArg">
    /// Optional attribute argument naming the value's type; when it names a type, that type is compared to
    /// the field type instead of inspecting <paramref name="value"/>.
    /// </param>
    /// <param name="prop">Property the diagnostic is reported on.</param>
    /// <returns>A diagnostic describing the mismatch, or <c>null</c> when the value is compatible.</returns>
    private static Diagnostic? ValidateDefaultValueCompatibility(ITypeSymbol fieldType, object? value,
        TypedConstant? attributeTypeArg, IPropertySymbol prop)
    {
        if (value is null) return null;

        // NodeField<T?> over a value type accepts the same defaults as NodeField<T>.
        var targetType = UnwrapNullableValueType(fieldType);

        if (attributeTypeArg is { Value: INamedTypeSymbol attrType })
        {
            var typesMatch = SymbolEqualityComparer.Default.Equals(fieldType, attrType) ||
                             SymbolEqualityComparer.Default.Equals(targetType, attrType);
            return typesMatch
                ? null
                : CreateInvalidDefaultValueDiagnostic(prop,
                    $"Default value type '{attrType.Name}' does not match field type '{fieldType.Name}'.");
        }

        switch (value)
        {
            case int:
                return IsIntegerType(targetType)
                    ? null
                    : CreateInvalidDefaultValueDiagnostic(prop,
                        $"Integer default value is not compatible with field type '{fieldType.Name}'.");
            case decimal:
                return IsDecimalCompatible(targetType)
                    ? null
                    : CreateInvalidDefaultValueDiagnostic(prop,
                        $"Decimal default value is not compatible with field type '{fieldType.Name}'.");
            case string:
                return targetType.SpecialType == SpecialType.System_String
                    ? null
                    : CreateInvalidDefaultValueDiagnostic(prop,
                        $"String default value is not compatible with field type '{fieldType.Name}'.");
        }

        // Other primitives (bool, char, long, double, ...) must match the field type exactly. They are
        // compared by SpecialType because ToDisplayString renders primitives as C# keywords ("bool"), which
        // never equal the CLR FullName ("System.Boolean").
        var valueType = value.GetType().FullName;
        var valueSpecialType = GetSpecialType(value);
        var isCompatible = valueSpecialType != SpecialType.None
            ? valueSpecialType == targetType.SpecialType
            : string.Equals(targetType.ToDisplayString(), valueType, StringComparison.Ordinal);
        return isCompatible
            ? null
            : CreateInvalidDefaultValueDiagnostic(prop,
                $"Default value type '{valueType}' does not match field type '{fieldType.Name}'.");
    }

    /// <summary>
    /// Returns <c>T</c> for <c>Nullable&lt;T&gt;</c>; any other type is returned unchanged.
    /// </summary>
    private static ITypeSymbol UnwrapNullableValueType(ITypeSymbol type) =>
        type is INamedTypeSymbol { OriginalDefinition.SpecialType: SpecialType.System_Nullable_T } nullable
            ? nullable.TypeArguments[0]
            : type;

    /// <summary>
    /// Maps a boxed primitive or string to its Roslyn <see cref="SpecialType"/>;
    /// <see cref="SpecialType.None"/> for anything else.
    /// </summary>
    private static SpecialType GetSpecialType(object value) => value switch
    {
        bool => SpecialType.System_Boolean,
        char => SpecialType.System_Char,
        sbyte => SpecialType.System_SByte,
        byte => SpecialType.System_Byte,
        short => SpecialType.System_Int16,
        ushort => SpecialType.System_UInt16,
        int => SpecialType.System_Int32,
        uint => SpecialType.System_UInt32,
        long => SpecialType.System_Int64,
        ulong => SpecialType.System_UInt64,
        float => SpecialType.System_Single,
        double => SpecialType.System_Double,
        decimal => SpecialType.System_Decimal,
        string => SpecialType.System_String,
        _ => SpecialType.None
    };

    /// <summary>Returns whether <paramref name="type"/> is a built-in integral type (8 to 64 bits).</summary>
    private static bool IsIntegerType(ITypeSymbol type)
    {
        return type.SpecialType is
            SpecialType.System_Int32 or
            SpecialType.System_Int64 or
            SpecialType.System_Int16 or
            SpecialType.System_Byte or
            SpecialType.System_UInt32 or
            SpecialType.System_UInt64 or
            SpecialType.System_UInt16 or
            SpecialType.System_SByte;
    }

    /// <summary>Returns whether <paramref name="type"/> is decimal, double or float.</summary>
    private static bool IsDecimalCompatible(ITypeSymbol type)
    {
        return type.SpecialType is
            SpecialType.System_Decimal or
            SpecialType.System_Double or
            SpecialType.System_Single;
    }

    /// <summary>
    /// Creates an NP0001 error located at <paramref name="prop"/> with <paramref name="message"/> as its text.
    /// </summary>
    private static Diagnostic CreateInvalidDefaultValueDiagnostic(IPropertySymbol prop, string message) =>
        Diagnostic.Create(InvalidDefaultValueDescriptor, prop.Locations.FirstOrDefault(), message);

    /// <summary>
    /// Returns whether a value of <paramref name="typeArg"/> can be created without arguments: value types
    /// and enums always; abstract types and interfaces never; other types only with a public
    /// parameterless constructor.
    /// </summary>
    private static bool CanFieldDefaultValueBeInitialized(ITypeSymbol typeArg)
    {
        if (typeArg.IsValueType || typeArg.TypeKind == TypeKind.Enum) return true;
        if (typeArg.IsAbstract || typeArg.TypeKind == TypeKind.Interface) return false;
        return typeArg.GetMembers()
            .OfType<IMethodSymbol>()
            .Any(m => m.MethodKind == MethodKind.Constructor &&
                      m.Parameters.Length == 0 &&
                      m.DeclaredAccessibility == Accessibility.Public);
    }

    /// <summary>
    /// Reads an attribute argument as <see cref="decimal"/>, converting with the invariant culture so the
    /// result does not depend on the build machine's locale (relevant if a bound is ever a string).
    /// </summary>
    /// <remarks>
    /// NOTE(review): Convert throws OverflowException for values outside the decimal range
    /// (e.g. double.MaxValue), which would surface as a generator failure.
    /// </remarks>
    private static decimal ReadDecimal(TypedConstant constant) =>
        constant.Value as decimal? ?? Convert.ToDecimal(constant.Value, CultureInfo.InvariantCulture);

    /// <summary>
    /// Reads an attribute argument as <see cref="int"/>, converting with the invariant culture.
    /// </summary>
    private static int ReadInt32(TypedConstant constant) =>
        constant.Value as int? ?? Convert.ToInt32(constant.Value, CultureInfo.InvariantCulture);

    /// <summary>
    /// Computes numeric bounds from HasValueBetween, HasMinValue and HasMaxValue attributes.
    /// HasMinValue/HasMaxValue can only tighten a range set by HasValueBetween.
    /// </summary>
    /// <returns>The inclusive-or-exclusive semantics are decided by the consumer; <c>null</c> means unbounded.</returns>
    private static (decimal? min, decimal? max) GetNumericBounds(ImmutableArray<AttributeData> propAttributes)
    {
        decimal? numberMinBound = null;
        decimal? numberMaxBound = null;

        if (FindAttribute(propAttributes, "HasValueBetweenAttribute") is { ConstructorArguments.Length: >= 2 } between)
        {
            numberMinBound = ReadDecimal(between.ConstructorArguments[0]);
            numberMaxBound = ReadDecimal(between.ConstructorArguments[1]);
        }

        if (FindAttribute(propAttributes, "HasMinValueAttribute") is { ConstructorArguments.Length: > 0 } minAttr)
        {
            var minVal = ReadDecimal(minAttr.ConstructorArguments[0]);
            if (!numberMinBound.HasValue || minVal > numberMinBound.Value) numberMinBound = minVal;
        }

        if (FindAttribute(propAttributes, "HasMaxValueAttribute") is { ConstructorArguments.Length: > 0 } maxAttr)
        {
            var maxVal = ReadDecimal(maxAttr.ConstructorArguments[0]);
            if (!numberMaxBound.HasValue || maxVal < numberMaxBound.Value) numberMaxBound = maxVal;
        }

        return (numberMinBound, numberMaxBound);
    }

    /// <summary>
    /// Computes string length bounds from HasLengthBetween, HasMinLength and HasMaxLength attributes.
    /// HasMinLength/HasMaxLength can only tighten a range set by HasLengthBetween.
    /// </summary>
    /// <returns><c>null</c> for a side means no bound on that side.</returns>
    private static (int? min, int? max) GetStringLengthBounds(ImmutableArray<AttributeData> propAttributes)
    {
        int? stringMinLength = null;
        int? stringMaxLength = null;

        if (FindAttribute(propAttributes, "HasLengthBetweenAttribute") is { ConstructorArguments.Length: >= 2 } between)
        {
            stringMinLength = ReadInt32(between.ConstructorArguments[0]);
            stringMaxLength = ReadInt32(between.ConstructorArguments[1]);
        }

        if (FindAttribute(propAttributes, "HasMinLengthAttribute") is { ConstructorArguments.Length: > 0 } minAttr)
        {
            var minLenVal = ReadInt32(minAttr.ConstructorArguments[0]);
            if (!stringMinLength.HasValue || minLenVal > stringMinLength.Value) stringMinLength = minLenVal;
        }

        if (FindAttribute(propAttributes, "HasMaxLengthAttribute") is { ConstructorArguments.Length: > 0 } maxAttr)
        {
            var maxLenVal = ReadInt32(maxAttr.ConstructorArguments[0]);
            if (!stringMaxLength.HasValue || maxLenVal < stringMaxLength.Value) stringMaxLength = maxLenVal;
        }

        return (stringMinLength, stringMaxLength);
    }

    /// <summary>
    /// Gathers validation and default-value metadata from the field property's attributes.
    /// </summary>
    /// <param name="typeArg">The NodeField's type argument.</param>
    /// <param name="prop">The field property.</param>
    /// <param name="fieldDir">Direction parsed from the NodeField attribute.</param>
    /// <param name="diagnostic">Set when the declared default value is invalid.</param>
    private static NodeFieldDescriptor.NodeFieldMetaData CreateNodeFieldMetaData(ITypeSymbol typeArg,
        IPropertySymbol prop, FieldDirection fieldDir, out Diagnostic? diagnostic)
    {
        var propAttributes = prop.GetAttributes();
        var fieldValidators = GetValidators(propAttributes, "HasValidatorAttribute");
        var defaultValue = GetFieldDefaultValue(propAttributes, typeArg, prop, out diagnostic);
        var canDefaultValueBeInitialized = CanFieldDefaultValueBeInitialized(typeArg);
        var isRequired = HasAttribute(propAttributes, "RequiredAttribute");
        // Honoured only on Input fields; ignored for every other direction.
        var disallowNullableOutput = fieldDir == FieldDirection.Input &&
                                     HasAttribute(propAttributes, "DisallowNullableOutputAttribute");
        var (numberMinBound, numberMaxBound) = GetNumericBounds(propAttributes);
        var (stringMinLength, stringMaxLength) = GetStringLengthBounds(propAttributes);
        return new NodeFieldDescriptor.NodeFieldMetaData(isRequired, disallowNullableOutput,
            fieldValidators, defaultValue, canDefaultValueBeInitialized,
            numberMinBound, numberMaxBound, stringMinLength, stringMaxLength);
    }

    /// <summary>
    /// Builds the descriptor for one NodeField property.
    /// </summary>
    /// <param name="nodeFieldAttr">
    /// The property's NodeField attribute. The first constructor argument (if any) is read as the direction,
    /// the second (if any) as the parameter name.
    /// </param>
    /// <param name="nf">The property's NodeField type; must have at least one type argument.</param>
    /// <param name="prop">The field property.</param>
    /// <param name="diagnostic">Set when the declared default value is invalid.</param>
    private static NodeFieldDescriptor CreateFieldDescriptor(AttributeData nodeFieldAttr, INamedTypeSymbol nf,
        IPropertySymbol prop, out Diagnostic? diagnostic)
    {
        var args = nodeFieldAttr.ConstructorArguments;
        var directionArg = args.Length > 0 ? args[0].Value?.ToString() : null;
        var paramNameArg = args.Length > 1 ? args[1].Value?.ToString() : null;
        // Missing or unparsable directions fall back to Parameter.
        var fieldDir = Enum.TryParse(directionArg, out FieldDirection dir) ? dir : FieldDirection.Parameter;

        var typeArg = nf.TypeArguments[0];
        var valueType = TrimSuffix(typeArg.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), "?");
        // NOTE(review): Annotated presumably covers both nullable value types and annotated reference types
        // despite the name — confirm. TODO: consider NotAnnotated
        var isNullableValueType = typeArg.NullableAnnotation == NullableAnnotation.Annotated;
        var isValueReferenceType = !typeArg.IsValueType && typeArg.TypeKind != TypeKind.Enum;
        var nodeFieldMetaData = CreateNodeFieldMetaData(typeArg, prop, fieldDir, out diagnostic);
        var paramName = string.IsNullOrWhiteSpace(paramNameArg) ? prop.Name : paramNameArg!;
        return new NodeFieldDescriptor(prop.Name, paramName, valueType,
            isNullableValueType, isValueReferenceType, typeArg.TypeKind == TypeKind.Enum, fieldDir,
            nodeFieldMetaData);
    }

    /// <summary>
    /// Returns <paramref name="fullyQualifiedName"/> with the "global::" prefix removed.
    /// </summary>
    private static string GetTypeNameShort(INamedTypeSymbol node, string fullyQualifiedName)
    {
        // NOTE(review): fullyQualifiedName presumably always starts with "global::", so the assembly display
        // name does not look like it can ever be its prefix and this first trim appears to be a no-op.
        // Kept unchanged because the resulting names are observable; confirm the intent.
        var assemblyName = node.ContainingAssembly.Identity.GetDisplayName();
        var fullName = TrimPrefix(fullyQualifiedName, assemblyName);
        return TrimPrefix(fullName, "global::");
    }

    /// <summary>
    /// Determines the node type name: the first NodeType attribute argument when it is non-blank, otherwise
    /// <paramref name="typeNameShort"/> with the first matching configured prefix and a trailing "Node"
    /// removed.
    /// </summary>
    /// <param name="isNameDefinedInAttribute">True when the name came from the NodeType attribute.</param>
    private static string GetNodeType(INamedTypeSymbol node, string typeNameShort, out bool isNameDefinedInAttribute)
    {
        const string suffix = "Node";
        var typeAttr = FindAttribute(node.GetAttributes(), "NodeTypeAttribute");
        var nameFromAttribute = typeAttr is { ConstructorArguments.Length: > 0 }
            ? typeAttr.ConstructorArguments[0].Value?.ToString()
            : null;
        if (!string.IsNullOrWhiteSpace(nameFromAttribute))
        {
            isNameDefinedInAttribute = true;
            return nameFromAttribute!;
        }

        isNameDefinedInAttribute = false;
        foreach (var prefix in NodeNamePrefixSettings.Prefixes)
        {
            if (typeNameShort.StartsWith(prefix, StringComparison.Ordinal))
                return TrimSuffix(typeNameShort.Substring(prefix.Length), suffix);
        }

        return TrimSuffix(typeNameShort, suffix);
    }

    /// <summary>
    /// Removes <paramref name="suffix"/> from the end of <paramref name="name"/> (ordinal) when present.
    /// Null/empty inputs are returned unchanged.
    /// </summary>
    private static string TrimSuffix(string name, string suffix) =>
        !string.IsNullOrEmpty(name) && !string.IsNullOrEmpty(suffix) &&
        name.EndsWith(suffix, StringComparison.Ordinal)
            ? name.Substring(0, name.Length - suffix.Length)
            : name;

    /// <summary>
    /// Removes <paramref name="prefix"/> from the start of <paramref name="name"/> (ordinal) when present.
    /// Null/empty inputs are returned unchanged.
    /// </summary>
    private static string TrimPrefix(string name, string prefix) =>
        !string.IsNullOrEmpty(name) && !string.IsNullOrEmpty(prefix) &&
        name.StartsWith(prefix, StringComparison.Ordinal)
            ? name.Substring(prefix.Length)
            : name;

    /// <summary>
    /// Result of <see cref="Build"/>: the described node types.
    /// </summary>
    public sealed class NodesModel(IReadOnlyList<NodeDescriptor> nodes)
    {
        /// <summary>The node descriptors, in input order.</summary>
        public IReadOnlyList<NodeDescriptor> Nodes { get; } = nodes ?? throw new ArgumentNullException(nameof(nodes));
    }
}