Source code for tapestry.table

import re
import keyword
import inspect
import types

from enum import Enum
from pydantic import BaseModel
from datetime import date, datetime
from surrealdb import Geometry, RecordID
from dataclasses import dataclass, field

# from pydantic_core import PydanticUndefined
from typing import get_origin, get_args, Literal, Any

from .field import Field
from .utils import flatten_type
from .tokenizer import Text, Tokenizer, FrenchTokenizer


[docs] @dataclass(frozen=True) class Indexed: """Opt-in index annotation for model fields. Use with ``typing.Annotated`` to request a database index on a field:: from typing import Annotated from tapestry import Node, Indexed class Person(Node): ext_id: Annotated[str | None, Indexed(unique=True)] = None party: Annotated[str | None, Indexed()] = None Attributes: unique: If True, the index enforces uniqueness. """ unique: bool = False
# This file is bloated, should be re-writen
[docs] class Reference[H]: """ Type annotation for defining foreign key references to other tables. Reference fields create relationships between tables in SurrealDB by storing RecordIDs that point to records in other tables. This is similar to foreign keys in traditional databases but leverages SurrealDB's graph capabilities. Type Parameters: H: The Node or Edge class that this field references Example: >>> from tapestry import Node, Reference >>> >>> class Person(Node): ... name: str ... email: str ... >>> class Article(Node): ... title: str ... author: Reference[Person] # References a Person record ... editor: Reference[Person] | None # Optional reference ... >>> # When creating an article >>> person = await Person(name="John Doe", email="john@example.com").create(db) >>> article = Article( ... title="Introduction to SurrealDB", ... author=person # Can assign the Person instance directly ... ) >>> await article.create(db) >>> >>> # The author field will store the Person's RecordID >>> print(article.author) # RecordID(table='person', id='...') Notes: - Reference fields store RecordIDs, not the full referenced object - Can reference any Node or Edge subclass - Supports optional references with Union[Reference[T], None] - When querying, references can be traversed using graph operators - Unlike Edge tables, References are unidirectional and don't create separate relation tables See Also: - Edge: For bidirectional relationships with properties - Node: Base class for tables that can be referenced """ pass
[docs] @dataclass(frozen=True) class Table: name: str base_class: type model_class: type fields: tuple[Field[Any], ...] name_to_type: dict[str, type] = field(default_factory=dict) relation: Link | None = None index: dict[str, str] = field(default_factory=dict) tokenizers: tuple[Tokenizer, ...] = () record_link_fields: tuple[str, ...] = () indexed_fields: dict[str, bool] = field(default_factory=dict) def __post_init__(self): indices = {} tokenizers = [] record_links = [] opt_in: dict[str, bool] = {} for field in self.fields: if field.name in ("id", "is_stub_"): continue if field.field_type: for tp in tuple(flatten_type(field.field_type, base_type=Text)): if inspect.isclass(tp) and issubclass(tp, Text): tokenizer = (get_args(tp) or (FrenchTokenizer,))[0] tokenizers.append(tokenizer) indices[field.name] = tokenizer.name for tp in tuple(flatten_type(field.field_type, base_type=Reference)): origin = get_origin(tp) if origin is Reference or (tp is Reference and not origin): args = get_args(tp) if not args or not issubclass(args[0], self.base_class): raise TypeError("References must point to a Node class") # Detect record link fields (foreign keys) if self._is_record_link_type(field.field_type): record_links.append(field.name) # Detect opt-in Indexed metadata from Pydantic FieldInfo if hasattr(self.model_class, "model_fields"): for fname, finfo in self.model_class.model_fields.items(): if hasattr(finfo, "metadata") and finfo.metadata: for meta in finfo.metadata: if isinstance(meta, Indexed): opt_in[fname] = meta.unique break object.__setattr__(self, "index", indices) object.__setattr__(self, "tokenizers", tuple(tokenizers)) object.__setattr__(self, "name_to_type", {f.name: f.field_type for f in self.fields}) object.__setattr__(self, "record_link_fields", tuple(record_links)) object.__setattr__(self, "indexed_fields", opt_in) def _is_record_link_type(self, tp: Any) -> bool: """Check whether *tp* resolves to a direct record link. Returns ``True`` when the type is (or contains via Union) a subclass of ``self.base_class``, but does **not** recurse into generic containers like ``list[...]`` or ``set[...]``. """ if tp is type(None): return False try: if inspect.isclass(tp) and issubclass(tp, self.base_class): return True except TypeError: return False origin = get_origin(tp) args = get_args(tp) if origin is types.UnionType or (origin is not None and hasattr(origin, "__name__") and origin.__name__ == "UnionType"): return any(self._is_record_link_type(a) for a in (args or ())) return False # -------------------- helpers to map python type -> surreal type ---------- @staticmethod def _quote_ident(name: str) -> str: if keyword.iskeyword(name) or not name.isidentifier(): return f'"{name}"' return name @staticmethod def _enum_to_literal_union(tp: type[Enum]) -> str: # parts = [] # for m in tp: # v = m.value # if isinstance(v, str): # parts.append(f'"{v}"') # else: # parts.append(f'"{m.name}"') # return " | ".join(parts) return " | ".join(f'"{m.name}"' for m in tp) @staticmethod def _literal_to_union(tp: Any) -> str: # Literal[...] -> produce Surreal literal union args = get_args(tp) parts = [] for a in args: if a is None: parts.append("NONE") elif isinstance(a, str): parts.append(f'"{a}"') elif isinstance(a, bool): parts.append("true" if a else "false") else: parts.append(str(a)) return " | ".join(parts) def _pytype_to_surreal(self, field_name: str, tp: Any, _skip_serializer_check: bool = False) -> str: """Map Python typing annotation tp -> Surreal type expression string.""" origin = get_origin(tp) or getattr(tp, "__origin__", None) args = get_args(tp) or getattr(tp, "__args__", ()) # Check for Pydantic model subclass first (before checking custom types) if inspect.isclass(tp) and issubclass(tp, self.base_class): return f"record<{tp.__name__.lower()}>" # Check for Pydantic model subclass first (before checking custom types) if inspect.isclass(tp) and issubclass(tp, Reference): print(tp.__annotations__) assert False # should I really make a REFERENCE here ? # return f"record<{tp.__name__.lower()}> REFERENCE" return f"record<{tp.__name__.lower()}> REFERENCE" if inspect.isclass(tp) and issubclass(tp, Text): return "string" # Check for field serializers on the parent class (skip if recursing from serializer) if not _skip_serializer_check and inspect.isclass(self.model_class) and issubclass(self.model_class, BaseModel): # Look for field_serializer method (format: serialize__{field_name}) serializer_name = f"serialize__{field_name}" if hasattr(self.model_class, serializer_name): serializer_method = getattr(self.model_class, serializer_name) if callable(serializer_method): sig = inspect.signature(serializer_method) if sig.return_annotation != inspect.Signature.empty: # Skip serializer check on recursive call to avoid infinite loop return self._pytype_to_surreal(field_name, sig.return_annotation, _skip_serializer_check=True) # Check if the type itself is a Pydantic BaseModel with a model_serializer if ( not _skip_serializer_check and inspect.isclass(tp) and issubclass(tp, BaseModel) and not issubclass(tp, self.base_class) ): # Look for model_serializer method for name in dir(tp): if not name.startswith("_"): attr = getattr(tp, name) if callable(attr) and hasattr(attr, "__func__"): # Check if this method has a model_serializer decorator # by checking if it has the right signature (self) -> ReturnType try: sig = inspect.signature(attr) # model_serializer methods typically have (self) -> ReturnType signature params = list(sig.parameters.values()) if ( len(params) == 1 and params[0].name == "self" and sig.return_annotation != inspect.Signature.empty ): # This looks like a model_serializer, use its return type # Skip serializer check on recursive call to avoid infinite loop return self._pytype_to_surreal(field_name, sig.return_annotation, _skip_serializer_check=True) except (ValueError, TypeError): continue if hasattr(tp, "__get_pydantic_core_schema__"): # Try to infer the underlying type from the core schema try: # Get the core schema (Pydantic v2) # Create a simple handler that returns the schema for basic types class SimpleHandler: def __call__(self, source_type): # Return a simple any schema for unknown types from pydantic_core import core_schema return core_schema.any_schema() def generate_schema(self, source_type): from pydantic_core import core_schema return core_schema.any_schema() @property def handler(self): return self handler = SimpleHandler() # Try to get schema schema = tp.__get_pydantic_core_schema__(tp, handler) # Analyze the schema to determine the underlying type return self._analyze_core_schema(schema) except Exception: pass # # Fallback: Check if it's a known SurrealDB type # # Check by class name or direct type comparison # if tp is Geometry or (hasattr(tp, '__name__') and 'Geometry' in tp.__name__): # return 'geometry<any>' # if tp is RecordID or (hasattr(tp, '__name__') and 'RecordID' in tp.__name__): # # Try to extract table name from type arguments if available # if hasattr(tp, '__args__') and tp.__args__: # table_name = tp.__args__[0] # if isinstance(table_name, str): # return f'record<{table_name}>' # return 'record<any>' # For other custom types, try to get the python type if available origin_attr = getattr(tp, "__origin__", None) if origin_attr is not None: return self._pytype_to_surreal(field_name, origin_attr) # Default fallback for custom Pydantic types return "any" # Optional / Union[..., None] if origin is types.UnionType or (origin is not None and hasattr(origin, "__name__") and origin.__name__ == "UnionType"): # if Optional[...] (Union with None) if any(a is type(None) for a in args): non_none = tuple(a for a in args if a is not type(None)) match non_none: case (): return "NONE" case (s,): inner = self._pytype_to_surreal(field_name, s, _skip_serializer_check=True) # Don't wrap 'any' in option<> since 'any' already includes None if inner == "any": return "any" return f"option<{inner}>" case many: others = " | ".join( self._pytype_to_surreal(field_name, some, _skip_serializer_check=True) for some in many ) # Don't wrap 'any' in option<> since 'any' already includes None if others == "any": return "any" return f"option<{others}>" # generic union -> join pieces union_parts = [self._pytype_to_surreal(field_name, a, _skip_serializer_check=True) for a in args] # Deduplicate and simplify unique_parts = [] seen = set() for part in union_parts: if part not in seen: seen.add(part) unique_parts.append(part) # If 'any' is in the union, just return 'any' if "any" in unique_parts: return "any" # If only one type after deduplication, return it directly if len(unique_parts) == 1: return unique_parts[0] return " | ".join(unique_parts) # Literal if origin is Literal: return self._literal_to_union(tp) # List / list[...] -> array<...> if origin is list: inner = args[0] if args else Any return f"array<{self._pytype_to_surreal(field_name, inner, _skip_serializer_check=True)}>" # Set / set[...] -> set<...> if origin is set: inner = args[0] if args else Any return f"set<{self._pytype_to_surreal(field_name, inner, _skip_serializer_check=True)}>" # Enum classes if inspect.isclass(tp) and issubclass(tp, Enum): return self._enum_to_literal_union(tp) # SurrealDB-specific types (check without needing __get_pydantic_core_schema__) if tp is Geometry: return "geometry<any>" if tp is RecordID: return "record<any>" # primitives if tp is str: return "string" if tp in (int, float): return "number" if tp is bool: return "bool" if tp in (date, datetime): return "datetime" # fallback return "any" @classmethod def _analyze_core_schema(cls, schema) -> str: """Analyze a Pydantic core schema object to determine the Surreal type.""" if schema is None: return "any" # Handle dict-based schemas if isinstance(schema, dict): schema_type = schema.get("type") # Check for function validators with wrapped schemas if schema_type in ("function-after", "function-before", "function-wrap", "with-info"): # Look for the wrapped/inner schema if "schema" in schema: return cls._analyze_core_schema(schema["schema"]) elif "inner" in schema: return cls._analyze_core_schema(schema["inner"]) # Basic type mappings if schema_type == "str": return "string" elif schema_type == "int": return "number" elif schema_type == "float": return "number" elif schema_type == "bool": return "bool" elif schema_type == "datetime": return "datetime" elif schema_type == "date": return "datetime" elif schema_type == "bytes": return "bytes" elif schema_type == "list": items_schema = schema.get("items_schema") if items_schema: inner = cls._analyze_core_schema(items_schema) return f"array<{inner}>" return "array<any>" elif schema_type == "dict": return "object" elif schema_type == "set": items_schema = schema.get("items_schema") if items_schema: inner = cls._analyze_core_schema(items_schema) return f"array<{inner}>" return "array<any>" elif schema_type == "tuple": return "array<any>" elif schema_type in ("union", "chain"): key = "choices" if schema_type == "union" else "steps" match sorted({cls._analyze_core_schema(step) for step in schema.get(key, [])}): case []: return "any" case [one]: return one case many: # filter any, otherwise it does not makes any sense return " | ".join(m for m in many if m != "any") elif schema_type == "nullable": inner = schema.get("schema") if inner: return f"option<{cls._analyze_core_schema(inner)}>" return "option<any>" elif schema_type == "none": return "NONE" elif schema_type == "literal": expected = schema.get("expected", []) if expected: value = expected[0] if value is None: return "NONE" elif isinstance(value, str): return f'"{value}"' elif isinstance(value, bool): return "true" if value else "false" else: return str(value) elif schema_type == "any": return "any" # Handle core_schema objects with attributes if hasattr(schema, "__dict__"): attrs = schema.__dict__ if "type" in attrs: # Reconstruct as dict and recurse return cls._analyze_core_schema(attrs) return "any" # -------------------- schema generation ---------------------------------
[docs] @staticmethod def target_name(tp): if inspect.isclass(tp) and issubclass(tp, BaseModel): return tp.__name__.lower() # fallback: use type __name__ or str(tp) return getattr(tp, "__name__", str(tp)).lower()
[docs] def generate_table_sql(self) -> str: stmts: list[str] = [] if self.relation: in_tbl = self._pytype_to_surreal("_in", self.relation._in) out_tbl = self._pytype_to_surreal("_out", self.relation._out) in_tbl = re.sub(r"record<([^>]+)>", r"\1", in_tbl) out_tbl = re.sub(r"record<([^>]+)>", r"\1", out_tbl) stmts.append(f"DEFINE TABLE IF NOT EXISTS {self.name} TYPE RELATION IN {in_tbl} OUT {out_tbl};") if self.relation.symetric: stmts.append(f"DEFINE FIELD key ON TABLE {self.name} VALUE <string>array::sort([in, out]);") stmts.append(f"DEFINE INDEX only_one_link ON TABLE {self.name} FIELDS key UNIQUE;") else: stmts.append(f"DEFINE TABLE IF NOT EXISTS {self.name} SCHEMAFULL;") for ft in self.fields: if ft.name == "id": # only define if not None continue if ft.name == "is_stub_": continue surreal_t = self._pytype_to_surreal(ft.name, ft.field_type) stmts.append(f"DEFINE FIELD {self._quote_ident(ft.name)} ON {self.name} TYPE {surreal_t};") # -- Indexes ----------------------------------------------------------- # Auto-index edge in/out fields if self.relation is not None: stmts.append(f"DEFINE INDEX idx_{self.name}_in ON TABLE {self.name} FIELDS in;") stmts.append(f"DEFINE INDEX idx_{self.name}_out ON TABLE {self.name} FIELDS out;") # Auto-index record link fields (foreign keys) for field_name in self.record_link_fields: stmts.append(f"DEFINE INDEX idx_{self.name}_{field_name} ON TABLE {self.name} FIELDS {field_name};") # Opt-in indexed fields for field_name, unique in self.indexed_fields.items(): unique_suffix = " UNIQUE" if unique else "" stmts.append(f"DEFINE INDEX idx_{self.name}_{field_name} ON TABLE {self.name} FIELDS {field_name}{unique_suffix};") # Fulltext search indexes (preserve existing naming convention) for field, analyzer in self.index.items(): stmts.append( f"DEFINE INDEX {field}_search ON TABLE {self.name} FIELDS {field} FULLTEXT ANALYZER {analyzer} BM25 HIGHLIGHTS;" ) return "\n".join(stmts)