from __future__ import annotations
from collections.abc import Mapping
import inspect
import types
from enum import Enum
from copy import deepcopy
from itertools import chain
from queue import LifoQueue
from surrealdb import RecordID
from typing import Any, ClassVar, Unpack, get_origin, get_args
from pydantic_core.core_schema import ValidationInfo
from pydantic.fields import FieldInfo, ComputedFieldInfo
from pydantic import BaseModel, field_validator, model_serializer, Field, ConfigDict
from .table import Table
from .utils import convert_types, replace_type
from .field import Field as CustomField, NestedFieldDescriptor, ComputedFieldDescriptor
# Registry base class -------------------------------------------------------
class Base(BaseModel):
    """
    Base class for creating SurrealDB table models with Pydantic.

    All table models should inherit from this class. Defining a subclass
    builds a ``Table`` description for it, registers that table on ``Base``
    and installs query descriptors on the class for every field.

    Attributes:
        id (RecordID | None): The SurrealDB record ID. ``None`` for records
            that have not been created in the database yet.

    Class Attributes:
        _registry: Tuple of all registered tables.
        _tokenizers: Set of tokenizer definitions collected from the tables.
        _to_create: Per-subclass LIFO queue intended for chaining object
            creations.
        _discarded: ``True`` when the subclass was declared with
            ``discarded=True``; such classes are not registered.
        child_classes: Dictionary mapping table names to their model classes.

    Example:
        >>> from tapestry import Base
        >>> from datetime import date
        >>>
        >>> class Person(Base):
        ...     first_name: str
        ...     last_name: str
        ...     date_of_birth: date
        ...
        >>> # Person is automatically registered and can generate SurrealDB schema
        >>> schema = Base.generate_schema()

    Notes:
        - Subclasses are registered upon definition unless ``discarded=True``.
        - The table name is the lower-cased class name.
        - Enum fields accept member names and are converted to enum members.
        - Nested ``Base`` instances are serialized as their record ID.
    """

    id: RecordID | None = Field(exclude=False, default=None)

    model_config: ClassVar[ConfigDict] = ConfigDict(
        arbitrary_types_allowed=True,
        use_enum_values=False,  # would be great to have a 'use_enum_names'
        validate_assignment=True,
    )

    _registry: ClassVar[tuple[Table, ...]] = ()
    _tokenizers: ClassVar[set[str]] = set()
    _to_create: ClassVar[LifoQueue[Base]]
    _discarded: ClassVar[bool] = False
    child_classes: ClassVar[dict[str, type[Base]]] = {}

    # ``field_validator`` must be the outermost decorator, otherwise Pydantic
    # never sees the validator. ``mode="before"`` is required so that raw
    # (non-RecordID) values reach the conversion below instead of being
    # rejected by the ``RecordID | None`` type check first.
    @field_validator("id", mode="before")
    @classmethod
    def validate_record_id(cls, v: Any) -> RecordID | None:
        """
        Validate and convert record ID values.

        Args:
            v: The value to validate (``None``, a ``RecordID``, or a raw
                identifier such as a string).

        Returns:
            RecordID | None: ``None`` and ``RecordID`` instances are returned
            unchanged; any other value is wrapped in a ``RecordID`` whose
            table name is the lower-cased name of the model class.
        """
        if v is None or isinstance(v, RecordID):
            return v
        return RecordID(cls.__name__.lower(), v)

    @model_serializer(mode="wrap")
    def _serialize(self, serializer, info):
        """
        Custom serializer for handling record references.

        Without a serialization context the default serializer is used. With
        a context, the first object serialized (``context["root"]`` truthy) is
        dumped in full and the flag is cleared; every object serialized after
        that is replaced by its record ID.

        Args:
            serializer: The default Pydantic serializer.
            info: Serialization info; ``info.context`` is expected to be a
                mutable mapping when provided.

        Returns:
            The serialized object, or just its ID when it is a reference.

        Raises:
            Exception: If a referenced record has no ID yet (it has not been
                created).
        """
        context = info.context
        if not context:
            return serializer(self)
        if context.get("root"):
            # Consume the flag so that nested models are emitted as references.
            context["root"] = False
            return serializer(self)
        if self.id is None:
            raise Exception(f"You should create your record before referencing it: create {self}")
        return self.id

    @field_validator("*", mode="before")
    @classmethod
    def validate_enums(cls, v: Any, info: ValidationInfo) -> Any:
        """
        Convert enum member names to enum instances for enum-typed fields.

        Args:
            v: The value to validate.
            info: Validation info; ``info.field_name`` selects the field type
                from the model's table description.

        Returns:
            The enum member for enum fields, otherwise ``v`` unchanged.

        Raises:
            ValueError: If the field is an enum and ``v`` is neither a member
                nor a valid member name. Pydantic reports this as a
                validation error.
        """
        # Base itself and discarded subclasses have no table description.
        table = getattr(cls, "_table", None)
        if table is None:
            return v
        constructor = table.name_to_type.get(info.field_name)
        if inspect.isclass(constructor) and issubclass(constructor, Enum):
            if isinstance(v, constructor):
                return v
            try:
                return constructor[v]
            except (KeyError, TypeError) as exc:
                raise ValueError(
                    f"{v!r} is not a valid member name of {constructor.__name__}"
                ) from exc
        # NOTE: RecordID references are passed through unchanged; expanding
        # them into stub model instances was prototyped here but is disabled.
        return v

    def __init_subclass__(cls, **kwargs: Unpack[ConfigDict]):
        """
        Initialise per-subclass state.

        Consumes the optional ``discarded`` class keyword and gives every
        subclass its own creation queue before delegating to the parent hook.
        """
        discarded = kwargs.pop("discarded", False)
        # Create a unique queue for each subclass; it is meant to be used to
        # chain several object creations.
        cls._to_create = LifoQueue()
        cls._discarded = bool(discarded)
        return super().__init_subclass__(**kwargs)

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs):
        """
        Build and register the table description once Pydantic has finished
        constructing the subclass.

        For every model field and computed field a descriptor is created and
        installed as a class attribute; fields whose type is a ``Base``
        subclass (directly or inside a union) get a ``NestedFieldDescriptor``
        instead. Field annotations are then rewritten with ``replace_type`` so
        ``Base`` positions also accept ``RecordID``, the table is registered
        and the model is rebuilt.

        Args:
            **kwargs: Class keywords; ``relation`` is consumed and forwarded
                to ``Table``.
        """
        # Local import: the module-level typing import does not expose Union.
        from typing import Union

        relation = kwargs.pop("relation", None)
        super().__pydantic_init_subclass__(**kwargs)
        if cls._discarded:
            return

        table_name = cls.__name__.lower()
        # ``X | None`` has origin types.UnionType, ``Optional[X]`` / ``Union[...]``
        # has origin typing.Union; both must be recognised.
        union_origins = (types.UnionType, Union)

        model_fields_copy: dict[str, FieldInfo | ComputedFieldInfo] = dict(deepcopy(cls.model_fields))
        model_fields_copy.update(cls.model_computed_fields)  # type: ignore[arg-type]
        model_fields: Mapping[str, FieldInfo | ComputedFieldInfo] = model_fields_copy

        fields: list[CustomField[Any]] = []
        for field_name, field_info in model_fields.items():
            field_type = field_info.annotation if isinstance(field_info, FieldInfo) else field_info.return_type
            if isinstance(field_info, ComputedFieldInfo):
                field_descriptor = ComputedFieldDescriptor[field_type](
                    field_name, table_name, field_type or Any, getattr(cls, field_name, None)
                )
            else:
                field_descriptor = CustomField[field_type](field_name, table_name, field_type or Any)
            fields.append(field_descriptor)

            nested_class = None
            try:
                # Direct Base subclass.
                if field_type and hasattr(field_type, "__mro__") and Base in field_type.__mro__:
                    nested_class = field_type
                # Union containing a Base subclass (e.g. Optional[Model]).
                elif get_origin(field_type) in union_origins:
                    for arg in get_args(field_type):
                        if arg is not type(None) and hasattr(arg, "__mro__") and Base in arg.__mro__:
                            nested_class = arg
                            break
                if nested_class:
                    # A nested descriptor allows chaining into the referenced model.
                    nested_descriptor = NestedFieldDescriptor(field_name, table_name, nested_class)
                    setattr(cls, field_name, nested_descriptor)
                else:
                    setattr(cls, field_name, field_descriptor)
            except (TypeError, AttributeError):
                # For other complex types, just use the plain field descriptor.
                setattr(cls, field_name, field_descriptor)

            # Allow a RecordID wherever the annotation names Base.
            if field_name in cls.model_fields:
                cls.model_fields[field_name].annotation = replace_type(
                    cls.model_fields[field_name].annotation, Base, Base | RecordID
                )

        table = Table(name=table_name, base_class=Base, model_class=cls, fields=tuple(fields), relation=relation)
        for tokenizer in table.tokenizers:
            Base._tokenizers.add(tokenizer.define())
        cls._table = table
        # _registry is kept as immutable as possible: data is only appended.
        Base.add_table(table, cls)
        # Rebuild so the rewritten annotations take effect.
        cls.model_rebuild(force=True)

    @classmethod
    def registered_tables(cls) -> set[str]:
        """
        Get the names of all registered tables.

        Returns:
            set[str]: Set of table names that have been registered.

        Example:
            >>> tables = Base.registered_tables()
            >>> print(tables)
            {'person', 'entity', 'role', ...}
        """
        return {table.name for table in cls._registry}

    @classmethod
    def add_table(cls, table: Table, child_class: type[Base]):
        """
        Register a new table in the registry.

        Args:
            table: The Table definition to register.
            child_class: The model class associated with the table.

        Notes:
            This is called automatically when subclasses are defined.
            Users typically don't need to call this directly.
        """
        # Rebind to a new tuple rather than mutating the existing one.
        cls._registry = (*cls._registry, table)
        cls.child_classes[table.name] = child_class

    @classmethod
    def registered_models(cls) -> list[Table]:
        """
        Get all registered table definitions.

        Returns:
            list[Table]: A new list of the Table objects that have been
            registered, in registration order.

        Example:
            >>> models = Base.registered_models()
            >>> for model in models:
            ...     print(f"Table: {model.name}")
        """
        return list(cls._registry)

    @classmethod
    def generate_schema(cls) -> str:
        """
        Generate the complete SurrealQL schema for all registered tables.

        The result is the tokenizer definitions followed by the SQL produced
        by each registered table, separated by blank lines.

        Returns:
            str: Complete SurrealQL schema definition.

        Example:
            >>> async with AsyncSurreal(url) as db:
            ...     await db.signin({"username": "root", "password": "root"})
            ...     await db.use("mydb", "myns")
            ...     schema = Base.generate_schema()
            ...     await db.query(schema)
        """
        blocks = [table.generate_table_sql() for table in cls._registry]
        # _tokenizers is a set: sort it so the generated schema is
        # reproducible from one run to the next.
        return "\n\n".join(chain(sorted(cls._tokenizers), blocks))

    @classmethod
    def deserialize_record(cls, data: dict) -> Any:
        """
        Deserialize a SurrealDB record into a Pydantic model instance.

        The model class is looked up from the table name of the record's
        ``id``.

        Args:
            data: Dictionary containing the record data from SurrealDB. It is
                not modified.

        Returns:
            An instance of the matching model class. The data is returned
            without validation when it is not a dict, has no ``RecordID``
            under ``id``, or no model is registered for its table. Edge
            records come back with ``in``/``out`` renamed in either case.

        Notes:
            - Edge records (having both ``in`` and ``out``) get those keys
              renamed to ``in_`` and ``out_``.
            - Data is validated with ``model_validate``.
        """
        if not isinstance(data, dict):
            return data
        if "in" in data and "out" in data:
            # Rename on a shallow copy so the caller's dict stays untouched.
            data = {**data, "in_": data["in"], "out_": data["out"]}
            del data["in"], data["out"]
        # Get the model class from the record ID.
        record_id = data.get("id")
        if not record_id or not isinstance(record_id, RecordID):
            return data
        model_class = cls.child_classes.get(record_id.table_name)
        if not model_class:
            return data
        # here maybe specify if the users wants related records as ids or full objects
        # NOTE(review): the context value is a placeholder string; no
        # validator in this class reads it.
        return model_class.model_validate(data, context="could this parameter be of any use ?")

    @classmethod
    def deserialize_response(cls, response: Any) -> Any:
        """
        Deserialize a complete SurrealDB response.

        Lists are processed recursively (including nested lists), dicts are
        passed to ``deserialize_record`` and any other value is returned
        unchanged.

        Args:
            response: The response from SurrealDB (list, dict, or primitive).

        Returns:
            The response with records converted to model instances.

        Example:
            >>> result = await db.select("person")
            >>> people = Base.deserialize_response(result)
            >>> # people is now a list of Person instances
        """
        if isinstance(response, list):
            return [cls.deserialize_response(item) for item in response]
        if isinstance(response, dict):
            return cls.deserialize_record(response)
        return response

    def db_dump(self) -> dict[str, Any]:
        """
        Serialize the model instance for database insertion/update.

        The instance is dumped with a ``{"root": True}`` context so that
        nested models are emitted as record IDs, the ``id`` key is removed and
        the result is passed through ``convert_types``.

        Returns:
            dict[str, Any]: Dictionary ready for database operations.

        Raises:
            Exception: If a nested model has not been created yet (no ID).

        Example:
            >>> person = Person(first_name="John", last_name="Doe")
            >>> data = person.db_dump()
            >>> await db.create("person", data)
        """
        serialized = self.model_dump(context={"root": True})
        serialized.pop("id", None)
        # convert_types presumably maps Python values to SurrealDB-compatible
        # ones (TODO confirm against .utils).
        return convert_types(serialized)