Source code for tapestry.field

from __future__ import annotations

from enum import Enum
from typing import Any, TYPE_CHECKING, NoReturn
from dataclasses import dataclass
from pydantic.fields import FieldInfo
from surrealdb import RecordID


# Imported only under TYPE_CHECKING so annotations resolve without circular imports at runtime
if TYPE_CHECKING:
    from .node import Node
    from .edge import Edge


[docs] class Field[T]: """Represents a field in a SurrealDB model with type-safe comparisons."""
[docs] def __init__(self, name: str, model_class: str, field_type: type[T]): self.name = name self.model_class = model_class self.field_type = field_type
def __repr__(self): return f"{self.model_class}.{self.name}" # Comparison operators with proper type hints def __eq__(self, other: T) -> Condition: # type: ignore[override] return Condition(self, "==", other) def __ne__(self, other: T) -> Condition: # type: ignore[override] return Condition(self, "!=", other) def __lt__(self, other: T) -> Condition: return Condition(self, "<", other) def __le__(self, other: T) -> Condition: return Condition(self, "<=", other) def __gt__(self, other: T) -> Condition: return Condition(self, ">", other) def __ge__(self, other: T) -> Condition: return Condition(self, ">=", other)
[docs] def __matmul__(self, other: T) -> Condition: """Fulltext search operator using @ symbol.""" return Condition(self, "@", other)
[docs] class ComputedFieldDescriptor[T](Field[T]): """Descriptor for computed fields that delegates to Pydantic's property."""
[docs] def __init__(self, name: str, model_class: str, field_type: type[T], original_descriptor: property | None = None): super().__init__(name, model_class, field_type) self.original_descriptor = original_descriptor
def __get__(self, instance, owner): if instance is None: # When accessed on the class, return the Field for query building return self # When accessed on an instance, delegate to Pydantic's descriptor if self.original_descriptor: return self.original_descriptor.__get__(instance, owner) raise AttributeError(f"No descriptor for {self.name}") def __set__(self, instance, value): if self.original_descriptor and hasattr(self.original_descriptor, "__set__"): return self.original_descriptor.__set__(instance, value) raise AttributeError(f"can't set attribute {self.name}") def __delete__(self, instance): if self.original_descriptor and hasattr(self.original_descriptor, "__delete__"): return self.original_descriptor.__delete__(instance) raise AttributeError(f"can't delete attribute {self.name}")
[docs] def __matmul__(self, other: T) -> Condition: """Fulltext search operator using @ symbol.""" return Condition(self, "@", other)
[docs] class NestedField[T](Field[T]): """Field representing a nested path like role.entity."""
[docs] def __init__(self, path: str, parent_class: str, field_type: type[T]): self.path = path self.parent_class = parent_class self.field_type = field_type # Use the full path as the name super().__init__(path, parent_class, field_type)
def __repr__(self): return f"{self.parent_class}.{self.path}"
class NestedFieldDescriptor:
    """Descriptor for nested field access in queries.

    For every entry in ``nested_class.model_fields`` an attribute of the same
    name is set on this object, holding a ``NestedField`` whose path is
    ``"<field_name>.<nested field name>"``. The object itself can also be
    compared (``==``, ``!=``) or searched (``@``) as a whole.

    Attributes:
        field_name: Name of the attribute on the parent that holds the nested
            object.
        parent_class: Name of the parent class.
        nested_class: Class of the nested object.
    """

    def __init__(self, field_name: str, parent_class: str, nested_class: type):
        """Store the identifying data and expose one attribute per nested field."""
        self.field_name = field_name
        self.parent_class = parent_class
        self.nested_class = nested_class

        # Create field descriptors for the nested class.
        # NOTE(review): a nested field called "field_name", "parent_class" or
        # "nested_class" would overwrite the attributes set above; confirm
        # that such names cannot occur.
        if hasattr(nested_class, "model_fields"):
            for nested_field_name, nested_field_info in nested_class.model_fields.items():
                nested_field_type = (
                    nested_field_info.annotation
                    if isinstance(nested_field_info, FieldInfo)
                    else nested_field_info.return_type
                )
                # Create a custom field that represents the nested path
                nested_field = NestedField(
                    f"{field_name}.{nested_field_name}", parent_class, nested_field_type
                )
                setattr(self, nested_field_name, nested_field)

    # Overriding __eq__ makes Python set __hash__ to None; restore identity
    # hashing so the descriptor can be stored in sets and used as a dict key.
    __hash__ = object.__hash__

    def _as_field(self) -> Field:
        """Build a plain ``Field`` standing for the nested object as a whole."""
        return Field(self.field_name, self.parent_class, self.nested_class)

    def __eq__(self, other) -> Condition:  # type: ignore[override]
        """Allow direct comparison on the nested field."""
        return Condition(self._as_field(), "==", other)

    def __ne__(self, other) -> Condition:  # type: ignore[override]
        """Allow direct inequality comparison on the nested field."""
        return Condition(self._as_field(), "!=", other)

    def __matmul__(self, other) -> Condition:
        """Allow fulltext search on the nested field."""
        return Condition(self._as_field(), "@", other)
class Direction(Enum):
    """Graph traversal direction.

    Each member's value is the arrow text for that direction.
    """

    # Follow edges away from the current record.
    FORWARD = "->"
    # Follow edges pointing at the current record.
    BACKWARD = "<-"
    # Follow edges either way.
    BIDIRECTIONAL = "<->"
@dataclass
class Condition:
    """Represents a query condition of the form ``left <operator> right``.

    Conditions are combined with ``&`` and ``|`` and negated with ``~``.

    Attributes:
        left: Left operand (typically a ``Field``).
        operator: Operator symbol such as ``"=="``, ``"!="``, ``"<"`` or
            ``"@"``.
        right: Right operand.
    """

    left: Any
    operator: str
    right: Any

    def __bool__(self) -> NoReturn:
        """Refuse truth-value testing.

        A condition is a query fragment, not a boolean; using it with
        ``and``/``or``/``not``/``if`` is almost certainly a mistake.

        Raises:
            TypeError: Always.
        """
        raise TypeError(
            "Condition has no truth value; combine conditions with '&', '|' "
            "and '~' instead of 'and', 'or' and 'not'"
        )

    def __and__(self, other: Condition | LogicalCondition) -> LogicalCondition:
        """Combine conditions with AND."""
        return LogicalCondition(self, "AND", other)

    def __or__(self, other: Condition | LogicalCondition) -> LogicalCondition:
        """Combine conditions with OR."""
        return LogicalCondition(self, "OR", other)

    def __invert__(self) -> NotCondition:
        """Negate the condition."""
        return NotCondition(self)

    def to_surreal(self) -> str:
        """Convert to SurrealQL syntax.

        Returns:
            ``"<left> <op> <right>"`` with both operands formatted by
            ``_format_value``. ``=``/``==`` and ``!=`` stay symbolic when the
            right side is a record reference and become ``IS``/``IS NOT``
            otherwise; ``@`` becomes ``@0@``; any other operator is emitted
            unchanged.
        """
        left_str = self._format_value(self.left)
        right_str = self._format_value(self.right)

        if self.operator in ("=", "=="):
            # Use = for RecordID comparisons, IS for other values
            op = "=" if self._is_record_reference(self.right) else "IS"
        elif self.operator == "!=":
            op = "!=" if self._is_record_reference(self.right) else "IS NOT"
        elif self.operator == "@":
            # Fulltext search operator - SurrealDB uses @N@ syntax where N is
            # the index position. We default to @0@ for the first fulltext
            # index.
            op = "@0@"
        else:
            op = self.operator

        return f"{left_str} {op} {right_str}"

    @staticmethod
    def _is_record_reference(value: Any) -> bool:
        """Return True for a ``RecordID`` or an object whose ``id`` is one."""
        return isinstance(value, RecordID) or (
            hasattr(value, "id") and isinstance(value.id, RecordID)
        )

    @staticmethod
    def _quote_string(text: str) -> str:
        """Wrap *text* in double quotes, escaping backslashes and quotes."""
        escaped = text.replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def _format_value(self, value: Any) -> str:
        """Format a value for SurrealQL.

        NOTE(review): a string that contains ``.`` and does not start with a
        double quote is emitted verbatim as a field reference, so literals
        such as e-mail addresses are not quoted. Confirm this heuristic is
        intended before passing untrusted strings.
        """
        # bool is a subclass of int, so it has to be tested first; otherwise
        # True/False would be rendered as "True"/"False".
        if isinstance(value, bool):
            return "true" if value else "false"
        if value is None:
            return "NULL"
        if isinstance(value, (int, float)):
            return str(value)
        if isinstance(value, Field):
            return value.name
        if isinstance(value, RecordID):
            return f"{value.table_name}:{value.id}"
        if hasattr(value, "id") and isinstance(value.id, RecordID):
            return f"{value.id.table_name}:{value.id.id}"
        if isinstance(value, str):
            # Check if it's a field reference or a literal string
            if "." in value and not value.startswith('"'):
                return value  # Field reference
            return self._quote_string(value)  # String literal
        if isinstance(value, Enum):
            # Use the enum's value if it's a string, otherwise use its name
            text = value.value if isinstance(value.value, str) else value.name
            return self._quote_string(text)
        return str(value)
@dataclass
class LogicalCondition:
    """Represents a logical combination of conditions.

    Attributes:
        left: Left-hand condition.
        operator: Joining keyword placed between the two sides when rendered
            (``"AND"`` or ``"OR"`` when built through ``&`` / ``|``).
        right: Right-hand condition.
    """

    left: Condition | LogicalCondition | NotCondition
    operator: str
    right: Condition | LogicalCondition | NotCondition

    def __and__(self, other: Condition | LogicalCondition | NotCondition) -> LogicalCondition:
        """Combine with another condition using AND."""
        return LogicalCondition(self, "AND", other)

    def __or__(self, other: Condition | LogicalCondition | NotCondition) -> LogicalCondition:
        """Combine with another condition using OR."""
        return LogicalCondition(self, "OR", other)

    def __invert__(self) -> NotCondition:
        """Negate the whole combination, mirroring ``Condition.__invert__``."""
        return NotCondition(self)

    def to_surreal(self) -> str:
        """Convert to SurrealQL syntax.

        Returns:
            Both sides rendered with their own ``to_surreal`` and joined by
            the operator, wrapped in parentheses.
        """
        left_str = self.left.to_surreal()
        right_str = self.right.to_surreal()
        return f"({left_str} {self.operator} {right_str})"
@dataclass
class NotCondition:
    """Represents a negated condition.

    Attributes:
        condition: The condition being negated.
    """

    condition: Condition | LogicalCondition

    def __and__(self, other: Condition | LogicalCondition | NotCondition) -> LogicalCondition:
        """Combine with another condition using AND.

        Without this, ``~a & b`` fails because ``~`` binds tighter than ``&``
        and the negated left operand has no ``&`` support.
        """
        return LogicalCondition(self, "AND", other)

    def __or__(self, other: Condition | LogicalCondition | NotCondition) -> LogicalCondition:
        """Combine with another condition using OR."""
        return LogicalCondition(self, "OR", other)

    def __invert__(self) -> Condition | LogicalCondition:
        """Negate again; a double negation yields the wrapped condition."""
        return self.condition

    def to_surreal(self) -> str:
        """Convert to SurrealQL syntax."""
        return f"NOT ({self.condition.to_surreal()})"
@dataclass
class Traversal:
    """Represents a graph traversal operation.

    ``recurse``, ``select`` and ``where`` mutate the traversal in place and
    return it, so calls can be chained.

    Attributes:
        direction: Which way the edge is followed.
        target: Node/edge class or table name to traverse to.
        recursion_depth: Depth of the traversal; values above 1 wrap the step
            in a recursion expression.
        select_fields: Field selection appended after the step, if any.
        where_condition: Optional filter applied to the step.
    """

    direction: Direction
    target: type[Node] | type[Edge] | str
    recursion_depth: int = 1
    select_fields: str | None = None
    where_condition: Condition | LogicalCondition | None = None

    def recurse(self, depth: int) -> Traversal:
        """Set recursion depth for this traversal."""
        self.recursion_depth = depth
        return self

    def select(self, fields: str = "*") -> Traversal:
        """Select specific fields from the traversal result."""
        self.select_fields = fields
        return self

    def where(self, condition: Condition | LogicalCondition) -> Traversal:
        """Add a WHERE condition to this traversal."""
        self.where_condition = condition
        return self

    def to_surreal(self) -> str:
        """Convert to SurrealQL syntax.

        Returns:
            The arrow for the direction followed by the target name and the
            optional ``[WHERE ...]`` filter, wrapped in ``.{depth}(...)`` when
            the recursion depth exceeds 1, then followed by
            ``.<select_fields>`` if set, or ``.*`` when the target is a
            ``Node`` subclass.
        """
        # Import here to avoid circular imports at module level
        from .node import Node

        target_name = self._get_target_name()

        # Build the basic traversal with optional WHERE clause
        where_clause = (
            f"[WHERE {self.where_condition.to_surreal()} ]"
            if self.where_condition is not None
            else ""
        )

        if self.direction == Direction.BACKWARD:
            arrow = "<-"
        elif self.direction == Direction.FORWARD:
            arrow = "->"
        else:  # BIDIRECTIONAL
            arrow = "<->"
        result = f"{arrow}{target_name}{where_clause}"

        # Add recursion if needed
        if self.recursion_depth > 1:
            result = f".{{{self.recursion_depth}}}({result})"

        # Check if target is a Node class to determine field selection.
        # Only TypeError is tolerated: a bare ``except`` would also hide
        # KeyboardInterrupt/SystemExit and unrelated bugs.
        is_node = False
        if isinstance(self.target, type):
            try:
                is_node = issubclass(self.target, Node)
            except TypeError:
                is_node = False

        # Add field selection or default .* for nodes
        if self.select_fields:
            result += f".{self.select_fields}"
        elif is_node:
            result += ".*"

        return result

    def _get_target_name(self) -> str:
        """Get the name of the target table/edge.

        Classes map to their lower-cased class name; anything else is
        converted with ``str``.
        """
        if isinstance(self.target, type):
            return self.target.__name__.lower()
        return str(self.target)