uni_pydantic

uni-pydantic: Pydantic-based OGM for Uni Graph Database.

This package provides a type-safe Object-Graph Mapping layer on top of the Uni graph database, using Pydantic v2 for model definitions.

Example:

    >>> from uni_db import Uni
    >>> from uni_pydantic import UniNode, UniSession, Field, Relationship, Vector
    >>>
    >>> class Person(UniNode):
    ...     name: str
    ...     age: int | None = None
    ...     email: str = Field(unique=True)
    ...     embedding: Vector[1536]
    ...     friends: list["Person"] = Relationship("FRIEND_OF", direction="both")
    >>>
    >>> db = Uni("./my_graph")
    >>> session = UniSession(db)
    >>> session.register(Person)
    >>> session.sync_schema()
    >>>
    >>> alice = Person(name="Alice", age=30, email="alice@example.com")
    >>> session.add(alice)
    >>> session.commit()
    >>>
    >>> # Query with type safety
    >>> adults = session.query(Person).filter(Person.age >= 18).all()

  1# SPDX-License-Identifier: Apache-2.0
  2# Copyright 2024-2026 Dragonscale Team
  3
  4"""
  5uni-pydantic: Pydantic-based OGM for Uni Graph Database.
  6
  7This package provides a type-safe Object-Graph Mapping layer on top of
  8the Uni graph database, using Pydantic v2 for model definitions.
  9
 10Example:
 11    >>> from uni_db import Uni
 12    >>> from uni_pydantic import UniNode, UniSession, Field, Relationship, Vector
 13    >>>
 14    >>> class Person(UniNode):
 15    ...     name: str
 16    ...     age: int | None = None
 17    ...     email: str = Field(unique=True)
 18    ...     embedding: Vector[1536]
 19    ...     friends: list["Person"] = Relationship("FRIEND_OF", direction="both")
 20    >>>
 21    >>> db = Uni("./my_graph")
 22    >>> session = UniSession(db)
 23    >>> session.register(Person)
 24    >>> session.sync_schema()
 25    >>>
 26    >>> alice = Person(name="Alice", age=30, email="alice@example.com")
 27    >>> session.add(alice)
 28    >>> session.commit()
 29    >>>
 30    >>> # Query with type safety
 31    >>> adults = session.query(Person).filter(Person.age >= 18).all()
 32"""
 33
 34__version__ = "1.1.0"
 35
 36# Base classes
 37# Async support
 38from .async_query import AsyncQueryBuilder
 39from .async_session import AsyncUniSession, AsyncUniTransaction
 40from .base import UniEdge, UniNode
 41
 42# Database wrappers
 43from .database import AsyncUniDatabase, UniDatabase
 44
 45# Exceptions
 46from .exceptions import (
 47    BulkLoadError,
 48    CypherInjectionError,
 49    LazyLoadError,
 50    NotPersisted,
 51    NotRegisteredError,
 52    NotTrackedError,
 53    QueryError,
 54    RelationshipError,
 55    SchemaError,
 56    SessionError,
 57    TransactionError,
 58    TypeMappingError,
 59    UniPydanticError,
 60    ValidationError,
 61)
 62
 63# Field configuration
 64from .fields import (
 65    Direction,
 66    Field,
 67    FieldConfig,
 68    IndexType,
 69    Relationship,
 70    RelationshipConfig,
 71    RelationshipDescriptor,
 72    VectorMetric,
 73    get_field_config,
 74)
 75
 76# Lifecycle hooks
 77from .hooks import (
 78    after_create,
 79    after_delete,
 80    after_load,
 81    after_update,
 82    before_create,
 83    before_delete,
 84    before_load,
 85    before_update,
 86)
 87
 88# Query builder
 89from .query import (
 90    FilterExpr,
 91    FilterOp,
 92    ModelProxy,
 93    OrderByClause,
 94    PropertyProxy,
 95    QueryBuilder,
 96    TraversalStep,
 97    VectorSearchConfig,
 98)
 99
100# Schema generation
101from .schema import (
102    DatabaseSchema,
103    EdgeTypeSchema,
104    LabelSchema,
105    PropertySchema,
106    SchemaGenerator,
107    generate_schema,
108)
109
110# Session management
111from .session import UniSession, UniTransaction
112
113# Type utilities
114from .types import (
115    DATETIME_TYPES,
116    Btic,
117    Vector,
118    db_to_python_value,
119    get_vector_dimensions,
120    is_list_type,
121    is_optional,
122    python_to_db_value,
123    python_type_to_uni,
124    uni_to_python_type,
125    unwrap_annotated,
126)
127
128__all__ = [
129    # Version
130    "__version__",
131    # Base classes
132    "UniNode",
133    "UniEdge",
134    # Session
135    "UniSession",
136    "UniTransaction",
137    # Async Session
138    "AsyncUniSession",
139    "AsyncUniTransaction",
140    # Fields
141    "Field",
142    "FieldConfig",
143    "Relationship",
144    "RelationshipConfig",
145    "RelationshipDescriptor",
146    "get_field_config",
147    "IndexType",
148    "Direction",
149    "VectorMetric",
150    # Types
151    "Btic",
152    "Vector",
153    "python_type_to_uni",
154    "uni_to_python_type",
155    "get_vector_dimensions",
156    "is_optional",
157    "is_list_type",
158    "unwrap_annotated",
159    "python_to_db_value",
160    "db_to_python_value",
161    "DATETIME_TYPES",
162    # Query
163    "QueryBuilder",
164    "AsyncQueryBuilder",
165    "FilterExpr",
166    "FilterOp",
167    "PropertyProxy",
168    "ModelProxy",
169    "OrderByClause",
170    "TraversalStep",
171    "VectorSearchConfig",
172    # Schema
173    "SchemaGenerator",
174    "DatabaseSchema",
175    "LabelSchema",
176    "EdgeTypeSchema",
177    "PropertySchema",
178    "generate_schema",
179    # Database
180    "UniDatabase",
181    "AsyncUniDatabase",
182    # Hooks
183    "before_create",
184    "after_create",
185    "before_update",
186    "after_update",
187    "before_delete",
188    "after_delete",
189    "before_load",
190    "after_load",
191    # Exceptions
192    "UniPydanticError",
193    "SchemaError",
194    "TypeMappingError",
195    "ValidationError",
196    "SessionError",
197    "NotRegisteredError",
198    "NotPersisted",
199    "NotTrackedError",
200    "TransactionError",
201    "QueryError",
202    "RelationshipError",
203    "LazyLoadError",
204    "BulkLoadError",
205    "CypherInjectionError",
206]
__version__ = '1.1.0'
class UniNode(BaseModel, metaclass=UniModelMeta):
    """
    Base class for graph node models.

    Subclass this to define your node types. Each UniNode subclass
    represents a vertex label in the graph database.

    Attributes:
        __label__: The vertex label name. Defaults to the class name.
        __relationships__: Dictionary of relationship configurations.

    Private Attributes:
        _vid: The vertex ID assigned by the database.
        _uid: The unique identifier (content-addressed hash).
        _session: Reference to the owning session.
        _dirty: Set of modified field names.

    Example:
        >>> class Person(UniNode):
        ...     __label__ = "Person"
        ...
        ...     name: str
        ...     age: int | None = None
        ...     email: str = Field(unique=True)
        ...
        ...     friends: list["Person"] = Relationship("FRIEND_OF", direction="both")
    """

    model_config = ConfigDict(
        # Reject undeclared fields so typos surface at construction time
        extra="forbid",
        # Validate on assignment for dirty tracking
        validate_assignment=True,
        # Allow arbitrary types (for Vector, etc.)
        arbitrary_types_allowed=True,
        # Use enum values
        use_enum_values=True,
    )

    # Class-level configuration
    __label__: ClassVar[str] = ""
    __relationships__: ClassVar[dict[str, RelationshipConfig]] = {}

    # Private attributes for session tracking
    _vid: int | None = PrivateAttr(default=None)
    _uid: str | None = PrivateAttr(default=None)
    _session: UniSession | None = PrivateAttr(default=None)
    _dirty: set[str] = PrivateAttr(default_factory=set)
    _is_new: bool = PrivateAttr(default=True)

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Set default label to class name if not specified
        if not cls.__label__:
            cls.__label__ = cls.__name__

    def model_post_init(self, __context: Any) -> None:
        """Clear dirty tracking after construction."""
        super().model_post_init(__context)
        # Construction assigns every field, which __setattr__ records;
        # a freshly built model should start with no pending changes.
        self._dirty = set()

    @property
    def vid(self) -> int | None:
        """The vertex ID assigned by the database."""
        return self._vid

    @property
    def uid(self) -> str | None:
        """The unique identifier (content-addressed hash)."""
        return self._uid

    @property
    def is_persisted(self) -> bool:
        """Whether this node has been saved to the database."""
        return self._vid is not None

    @property
    def is_dirty(self) -> bool:
        """Whether this node has unsaved changes."""
        return bool(self._dirty)

    def __setattr__(self, name: str, value: Any) -> None:
        # Track dirty fields (but not private attributes). The hasattr
        # guard skips tracking while _dirty itself is not yet initialized.
        if not name.startswith("_") and hasattr(self, "_dirty"):
            self._dirty.add(name)
        super().__setattr__(name, value)

    def _mark_clean(self) -> None:
        """Mark all fields as clean (called after commit)."""
        self._dirty.clear()
        self._is_new = False

    def _attach_session(
        self, session: UniSession, vid: int, uid: str | None = None
    ) -> None:
        """Attach this node to a session with its database IDs."""
        self._session = session
        self._vid = vid
        self._uid = uid
        self._is_new = False

    @classmethod
    def get_property_fields(cls) -> dict[str, FieldInfo]:
        """Get all property fields (excluding relationships)."""
        return {
            name: info
            for name, info in cls.model_fields.items()
            if name not in cls.__relationships__
        }

    @classmethod
    def get_relationship_fields(cls) -> dict[str, RelationshipConfig]:
        """Get all relationship field configurations."""
        return cls.__relationships__

    def to_properties(self) -> dict[str, Any]:
        """Convert to a property dictionary for database storage.

        Uses python_to_db_value() for type conversion. Includes None
        explicitly so null-outs work.
        """
        return _model_to_properties(self, self.get_property_fields())

    @classmethod
    def from_properties(
        cls,
        props: dict[str, Any],
        *,
        vid: int | None = None,
        uid: str | None = None,
        session: UniSession | None = None,
    ) -> UniNode:
        """Create an instance from a property dictionary.

        Accepts _id (string->int vid) and _label from uni-db node dicts.
        Does not mutate the input dict.

        Args:
            props: Raw property mapping, possibly including _id/_label.
            vid: Explicit vertex ID; overrides any _id found in props.
            uid: Content-addressed unique identifier, if known.
            session: Owning session to attach, if any.
        """
        data = dict(props)

        raw_id = data.pop("_id", None)
        # An explicitly passed vid wins over the _id embedded in props.
        if raw_id is not None and vid is None:
            vid = int(raw_id) if not isinstance(raw_id, int) else raw_id
        data.pop("_label", None)

        converted = _convert_db_values(data, cls)

        instance = cls.model_validate(converted)
        if vid is not None:
            instance._vid = vid
        if uid is not None:
            instance._uid = uid
        if session is not None:
            instance._session = session
        # A node with a vid came from the database; without one it is new.
        instance._is_new = vid is None
        instance._dirty = set()
        return instance

    def __repr__(self) -> str:
        # `is not None` keeps this consistent with is_persisted and avoids
        # mislabelling a legitimate vid of 0 as "unsaved".
        vid_str = f"vid={self._vid}" if self._vid is not None else "unsaved"
        return f"{self.__class__.__name__}({vid_str}, {super().__repr__()})"

Base class for graph node models.

Subclass this to define your node types. Each UniNode subclass represents a vertex label in the graph database.

Attributes: __label__: The vertex label name. Defaults to the class name. __relationships__: Dictionary of relationship configurations.

Private Attributes: _vid: The vertex ID assigned by the database. _uid: The unique identifier (content-addressed hash). _session: Reference to the owning session. _dirty: Set of modified field names.

Example:

    >>> class Person(UniNode):
    ...     __label__ = "Person"
    ...
    ...     name: str
    ...     age: int | None = None
    ...     email: str = Field(unique=True)
    ...
    ...     friends: list["Person"] = Relationship("FRIEND_OF", direction="both")

vid: int | None
220    @property
221    def vid(self) -> int | None:
222        """The vertex ID assigned by the database."""
223        return self._vid

The vertex ID assigned by the database.

uid: str | None
225    @property
226    def uid(self) -> str | None:
227        """The unique identifier (content-addressed hash)."""
228        return self._uid

The unique identifier (content-addressed hash).

is_persisted: bool
230    @property
231    def is_persisted(self) -> bool:
232        """Whether this node has been saved to the database."""
233        return self._vid is not None

Whether this node has been saved to the database.

is_dirty: bool
235    @property
236    def is_dirty(self) -> bool:
237        """Whether this node has unsaved changes."""
238        return bool(self._dirty)

Whether this node has unsaved changes.

@classmethod
def get_property_fields(cls) -> dict[str, pydantic.fields.FieldInfo]:
260    @classmethod
261    def get_property_fields(cls) -> dict[str, FieldInfo]:
262        """Get all property fields (excluding relationships)."""
263        return {
264            name: info
265            for name, info in cls.model_fields.items()
266            if name not in cls.__relationships__
267        }

Get all property fields (excluding relationships).

@classmethod
def get_relationship_fields(cls) -> dict[str, RelationshipConfig]:
269    @classmethod
270    def get_relationship_fields(cls) -> dict[str, RelationshipConfig]:
271        """Get all relationship field configurations."""
272        return cls.__relationships__

Get all relationship field configurations.

def to_properties(self) -> dict[str, typing.Any]:
274    def to_properties(self) -> dict[str, Any]:
275        """Convert to a property dictionary for database storage.
276
277        Uses python_to_db_value() for type conversion. Includes None
278        explicitly so null-outs work.
279        """
280        return _model_to_properties(self, self.get_property_fields())

Convert to a property dictionary for database storage.

Uses python_to_db_value() for type conversion. Includes None explicitly so null-outs work.

@classmethod
def from_properties( cls, props: dict[str, typing.Any], *, vid: int | None = None, uid: str | None = None, session: UniSession | None = None) -> UniNode:
282    @classmethod
283    def from_properties(
284        cls,
285        props: dict[str, Any],
286        *,
287        vid: int | None = None,
288        uid: str | None = None,
289        session: UniSession | None = None,
290    ) -> UniNode:
291        """Create an instance from a property dictionary.
292
293        Accepts _id (string->int vid) and _label from uni-db node dicts.
294        Does not mutate the input dict.
295        """
296        data = dict(props)
297
298        raw_id = data.pop("_id", None)
299        if raw_id is not None and vid is None:
300            vid = int(raw_id) if not isinstance(raw_id, int) else raw_id
301        data.pop("_label", None)
302
303        converted = _convert_db_values(data, cls)
304
305        instance = cls.model_validate(converted)
306        if vid is not None:
307            instance._vid = vid
308        if uid is not None:
309            instance._uid = uid
310        if session is not None:
311            instance._session = session
312        instance._is_new = vid is None
313        instance._dirty = set()
314        return instance

Create an instance from a property dictionary.

Accepts _id (string->int vid) and _label from uni-db node dicts. Does not mutate the input dict.

class UniEdge(BaseModel, metaclass=UniModelMeta):
    """
    Base class for graph edge models with properties.

    Subclass this to define edge types with typed properties.
    Edges represent relationships between nodes.

    Attributes:
        __edge_type__: The edge type name.
        __from__: The source node type(s).
        __to__: The target node type(s).

    Private Attributes:
        _eid: The edge ID assigned by the database.
        _src_vid: The source vertex ID.
        _dst_vid: The destination vertex ID.
        _session: Reference to the owning session.

    Example:
        >>> class FriendshipEdge(UniEdge):
        ...     __edge_type__ = "FRIEND_OF"
        ...     __from__ = Person
        ...     __to__ = Person
        ...
        ...     since: date
        ...     strength: float = 1.0
    """

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        arbitrary_types_allowed=True,
        use_enum_values=True,
    )

    # Class-level configuration
    __edge_type__: ClassVar[str] = ""
    __from__: ClassVar[type[UniNode] | tuple[type[UniNode], ...] | None] = None
    __to__: ClassVar[type[UniNode] | tuple[type[UniNode], ...] | None] = None
    __relationships__: ClassVar[dict[str, RelationshipConfig]] = {}

    # Private attributes
    _eid: int | None = PrivateAttr(default=None)
    _src_vid: int | None = PrivateAttr(default=None)
    _dst_vid: int | None = PrivateAttr(default=None)
    _session: UniSession | None = PrivateAttr(default=None)
    _is_new: bool = PrivateAttr(default=True)

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Set default edge type to class name if not specified
        if not cls.__edge_type__:
            cls.__edge_type__ = cls.__name__

    @property
    def eid(self) -> int | None:
        """The edge ID assigned by the database."""
        return self._eid

    @property
    def src_vid(self) -> int | None:
        """The source vertex ID."""
        return self._src_vid

    @property
    def dst_vid(self) -> int | None:
        """The destination vertex ID."""
        return self._dst_vid

    @property
    def is_persisted(self) -> bool:
        """Whether this edge has been saved to the database."""
        return self._eid is not None

    def _attach(
        self,
        session: UniSession,
        eid: int,
        src_vid: int,
        dst_vid: int,
    ) -> None:
        """Attach this edge to a session with its database IDs."""
        self._session = session
        self._eid = eid
        self._src_vid = src_vid
        self._dst_vid = dst_vid
        self._is_new = False

    @classmethod
    def get_from_labels(cls) -> list[str]:
        """Get the source label names."""
        if cls.__from__ is None:
            return []
        # __from__ may be a single node class or a tuple of them.
        if isinstance(cls.__from__, tuple):
            return [n.__label__ for n in cls.__from__]
        return [cls.__from__.__label__]

    @classmethod
    def get_to_labels(cls) -> list[str]:
        """Get the target label names."""
        if cls.__to__ is None:
            return []
        # __to__ may be a single node class or a tuple of them.
        if isinstance(cls.__to__, tuple):
            return [n.__label__ for n in cls.__to__]
        return [cls.__to__.__label__]

    @classmethod
    def get_property_fields(cls) -> dict[str, FieldInfo]:
        """Get all property fields."""
        return dict(cls.model_fields)

    def to_properties(self) -> dict[str, Any]:
        """Convert to a property dictionary for database storage."""
        return _model_to_properties(self, self.get_property_fields())

    @classmethod
    def from_properties(
        cls,
        props: dict[str, Any],
        *,
        eid: int | None = None,
        src_vid: int | None = None,
        dst_vid: int | None = None,
        session: UniSession | None = None,
    ) -> UniEdge:
        """Create an instance from a property dictionary.

        Accepts _id, _type, _src, _dst from uni-db edge dicts.
        Does not mutate the input dict.

        Args:
            props: Raw property mapping, possibly including _id/_type/_src/_dst.
            eid: Explicit edge ID; overrides any _id found in props.
            src_vid: Explicit source vertex ID; overrides _src.
            dst_vid: Explicit destination vertex ID; overrides _dst.
            session: Owning session to attach, if any.
        """
        data = dict(props)

        # Explicitly passed IDs win over the metadata embedded in props.
        raw_id = data.pop("_id", None)
        if raw_id is not None and eid is None:
            eid = int(raw_id) if not isinstance(raw_id, int) else raw_id
        data.pop("_type", None)
        raw_src = data.pop("_src", None)
        if raw_src is not None and src_vid is None:
            src_vid = int(raw_src) if not isinstance(raw_src, int) else raw_src
        raw_dst = data.pop("_dst", None)
        if raw_dst is not None and dst_vid is None:
            dst_vid = int(raw_dst) if not isinstance(raw_dst, int) else raw_dst

        converted = _convert_db_values(data, cls)

        instance = cls.model_validate(converted)
        if eid is not None:
            instance._eid = eid
        if src_vid is not None:
            instance._src_vid = src_vid
        if dst_vid is not None:
            instance._dst_vid = dst_vid
        if session is not None:
            instance._session = session
        # An edge with an eid came from the database; without one it is new.
        instance._is_new = eid is None
        return instance

    @classmethod
    def from_edge_result(
        cls,
        data: dict[str, Any],
        *,
        session: UniSession | None = None,
    ) -> UniEdge:
        """Create an instance from a uni-db edge result dict.

        Convenience method that handles _id, _type, _src, _dst keys.
        """
        return cls.from_properties(data, session=session)

    def __repr__(self) -> str:
        # `is not None` keeps this consistent with is_persisted and avoids
        # mislabelling a legitimate eid of 0 as "unsaved".
        eid_str = f"eid={self._eid}" if self._eid is not None else "unsaved"
        return f"{self.__class__.__name__}({eid_str}, {super().__repr__()})"

Base class for graph edge models with properties.

Subclass this to define edge types with typed properties. Edges represent relationships between nodes.

Attributes: __edge_type__: The edge type name. __from__: The source node type(s). __to__: The target node type(s).

Private Attributes: _eid: The edge ID assigned by the database. _src_vid: The source vertex ID. _dst_vid: The destination vertex ID. _session: Reference to the owning session.

Example:

    >>> class FriendshipEdge(UniEdge):
    ...     __edge_type__ = "FRIEND_OF"
    ...     __from__ = Person
    ...     __to__ = Person
    ...
    ...     since: date
    ...     strength: float = 1.0

eid: int | None
375    @property
376    def eid(self) -> int | None:
377        """The edge ID assigned by the database."""
378        return self._eid

The edge ID assigned by the database.

src_vid: int | None
380    @property
381    def src_vid(self) -> int | None:
382        """The source vertex ID."""
383        return self._src_vid

The source vertex ID.

dst_vid: int | None
385    @property
386    def dst_vid(self) -> int | None:
387        """The destination vertex ID."""
388        return self._dst_vid

The destination vertex ID.

is_persisted: bool
390    @property
391    def is_persisted(self) -> bool:
392        """Whether this edge has been saved to the database."""
393        return self._eid is not None

Whether this edge has been saved to the database.

@classmethod
def get_from_labels(cls) -> list[str]:
409    @classmethod
410    def get_from_labels(cls) -> list[str]:
411        """Get the source label names."""
412        if cls.__from__ is None:
413            return []
414        if isinstance(cls.__from__, tuple):
415            return [n.__label__ for n in cls.__from__]
416        return [cls.__from__.__label__]

Get the source label names.

@classmethod
def get_to_labels(cls) -> list[str]:
418    @classmethod
419    def get_to_labels(cls) -> list[str]:
420        """Get the target label names."""
421        if cls.__to__ is None:
422            return []
423        if isinstance(cls.__to__, tuple):
424            return [n.__label__ for n in cls.__to__]
425        return [cls.__to__.__label__]

Get the target label names.

@classmethod
def get_property_fields(cls) -> dict[str, pydantic.fields.FieldInfo]:
427    @classmethod
428    def get_property_fields(cls) -> dict[str, FieldInfo]:
429        """Get all property fields."""
430        return dict(cls.model_fields)

Get all property fields.

def to_properties(self) -> dict[str, typing.Any]:
432    def to_properties(self) -> dict[str, Any]:
433        """Convert to a property dictionary for database storage."""
434        return _model_to_properties(self, self.get_property_fields())

Convert to a property dictionary for database storage.

@classmethod
def from_properties( cls, props: dict[str, typing.Any], *, eid: int | None = None, src_vid: int | None = None, dst_vid: int | None = None, session: UniSession | None = None) -> UniEdge:
436    @classmethod
437    def from_properties(
438        cls,
439        props: dict[str, Any],
440        *,
441        eid: int | None = None,
442        src_vid: int | None = None,
443        dst_vid: int | None = None,
444        session: UniSession | None = None,
445    ) -> UniEdge:
446        """Create an instance from a property dictionary.
447
448        Accepts _id, _type, _src, _dst from uni-db edge dicts.
449        Does not mutate the input dict.
450        """
451        data = dict(props)
452
453        raw_id = data.pop("_id", None)
454        if raw_id is not None and eid is None:
455            eid = int(raw_id) if not isinstance(raw_id, int) else raw_id
456        data.pop("_type", None)
457        raw_src = data.pop("_src", None)
458        if raw_src is not None and src_vid is None:
459            src_vid = int(raw_src) if not isinstance(raw_src, int) else raw_src
460        raw_dst = data.pop("_dst", None)
461        if raw_dst is not None and dst_vid is None:
462            dst_vid = int(raw_dst) if not isinstance(raw_dst, int) else raw_dst
463
464        converted = _convert_db_values(data, cls)
465
466        instance = cls.model_validate(converted)
467        if eid is not None:
468            instance._eid = eid
469        if src_vid is not None:
470            instance._src_vid = src_vid
471        if dst_vid is not None:
472            instance._dst_vid = dst_vid
473        if session is not None:
474            instance._session = session
475        instance._is_new = eid is None
476        return instance

Create an instance from a property dictionary.

Accepts _id, _type, _src, _dst from uni-db edge dicts. Does not mutate the input dict.

@classmethod
def from_edge_result( cls, data: dict[str, typing.Any], *, session: UniSession | None = None) -> UniEdge:
478    @classmethod
479    def from_edge_result(
480        cls,
481        data: dict[str, Any],
482        *,
483        session: UniSession | None = None,
484    ) -> UniEdge:
485        """Create an instance from a uni-db edge result dict.
486
487        Convenience method that handles _id, _type, _src, _dst keys.
488        """
489        return cls.from_properties(data, session=session)

Create an instance from a uni-db edge result dict.

Convenience method that handles _id, _type, _src, _dst keys.

class UniSession:
class UniSession:
    """
    Session for interacting with the graph database using Pydantic models.

    The session manages model registration, schema synchronization,
    and provides CRUD operations and query building.

    Example:
        >>> from uni_db import Uni
        >>> from uni_pydantic import UniSession
        >>>
        >>> db = Uni("./my_graph")
        >>> session = UniSession(db)
        >>> session.register(Person, Company)
        >>> session.sync_schema()
        >>>
        >>> alice = Person(name="Alice", age=30)
        >>> session.add(alice)
        >>> session.commit()
    """

    def __init__(self, db: uni_db.Uni) -> None:
        self._db = db
        self._db_session = db.session()
        self._schema_gen = SchemaGenerator()
        # Weak values so entities no longer referenced elsewhere can be
        # garbage collected instead of being pinned by the session.
        self._identity_map: WeakValueDictionary[tuple[str, int], UniNode] = (
            WeakValueDictionary()
        )
        self._pending_new: list[UniNode] = []
        self._pending_delete: list[UniNode] = []

    def __enter__(self) -> UniSession:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def close(self) -> None:
        """Close the session and clear all pending state."""
        self._pending_new.clear()
        self._pending_delete.clear()

    @property
    def db(self) -> uni_db.Uni:
        """Access the underlying uni_db.Uni for low-level operations."""
        return self._db

    def locy(
        self, program: str, params: dict[str, Any] | None = None
    ) -> uni_db.LocyResult:
        """
        Evaluate a Locy program and return derived facts, stats, and warnings.

        Delegates to the underlying ``uni_db.Session.locy()``.
        """
        return self._db_session.locy(program, params)

    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
        """
        Register model classes with the session.

        Registered models can be used for schema generation and queries.

        Args:
            *models: UniNode or UniEdge subclasses to register.
        """
        self._schema_gen.register(*models)

    def sync_schema(self) -> None:
        """
        Synchronize database schema with registered models.

        Creates labels, edge types, properties, and indexes as needed.
        This is additive-only; it won't remove existing schema elements.
        """
        self._schema_gen.apply_to_database(self._db)

    def query(self, model: type[NodeT]) -> QueryBuilder[NodeT]:
        """
        Create a query builder for the given model.

        Args:
            model: The UniNode subclass to query.

        Returns:
            A QueryBuilder for constructing queries.
        """
        return QueryBuilder(self, model)

    def add(self, entity: UniNode) -> None:
        """
        Add a new entity to be persisted.

        The entity will be inserted on the next commit().

        Raises:
            SessionError: If the entity is already persisted.
        """
        if entity.is_persisted:
            raise SessionError(f"Entity {entity!r} is already persisted")
        entity._session = self
        self._pending_new.append(entity)

    def add_all(self, entities: Sequence[UniNode]) -> None:
        """Add multiple entities to be persisted."""
        for entity in entities:
            self.add(entity)

    def delete(self, entity: UniNode) -> None:
        """Mark an entity for deletion."""
        if not entity.is_persisted:
            raise NotPersisted(entity)
        self._pending_delete.append(entity)

    def get(
        self,
        model: type[NodeT],
        vid: int | None = None,
        uid: str | None = None,
        **kwargs: Any,
    ) -> NodeT | None:
        """
        Get an entity by ID or unique properties.

        Args:
            model: The model type to retrieve.
            vid: Vertex ID to look up.
            uid: Unique ID to look up.
            **kwargs: Property equality filters.

        Returns:
            The model instance or None if not found.

        Raises:
            ValueError: If neither vid, uid, nor property filters are given.
        """
        # Check identity map first to preserve object identity per (label, vid).
        if vid is not None:
            cached = self._identity_map.get((model.__label__, vid))
            if cached is not None:
                return cached  # type: ignore[return-value]

        # Build query (label comes from the model class, not user input).
        label = model.__label__
        params: dict[str, Any] = {}

        if vid is not None:
            cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
            params["vid"] = vid
        elif uid is not None:
            cypher = f"MATCH (n:{label}) WHERE n._uid = $uid RETURN {_NODE_RETURN}"
            params["uid"] = uid
        elif kwargs:
            # Validate property names before interpolating them into Cypher.
            for k in kwargs:
                _validate_property(k, model)
            conditions = [f"n.{k} = ${k}" for k in kwargs]
            cypher = f"MATCH (n:{label}) WHERE {' AND '.join(conditions)} RETURN {_NODE_RETURN} LIMIT 1"
            params.update(kwargs)
        else:
            raise ValueError("Must provide vid, uid, or property filters")

        results = self._db_session.query(cypher, params)
        if not results:
            return None

        node_data = _row_to_node_dict(results[0].to_dict())
        if node_data is None:
            return None
        return self._result_to_model(node_data, model)

    def refresh(self, entity: UniNode) -> None:
        """Refresh an entity's properties from the database.

        Raises:
            NotPersisted: If the entity was never persisted.
            SessionError: If the entity no longer exists in the database.
        """
        if not entity.is_persisted:
            raise NotPersisted(entity)

        label = entity.__class__.__label__
        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
        results = self._db_session.query(cypher, {"vid": entity._vid})

        if not results:
            raise SessionError(f"Entity with vid={entity._vid} no longer exists")

        # Update properties
        props = _row_to_node_dict(results[0].to_dict())
        if props is None:
            raise SessionError(f"Entity with vid={entity._vid} no longer exists")
        try:
            hints = get_type_hints(type(entity))
        except Exception:
            # Best-effort: unresolved forward refs fall back to raw DB values.
            hints = {}

        for field_name in entity.get_property_fields():
            if field_name in props:
                value = props[field_name]
                if field_name in hints:
                    value = db_to_python_value(value, hints[field_name])
                setattr(entity, field_name, value)

        entity._mark_clean()

    def commit(self) -> None:
        """
        Commit all pending changes to the database.

        This persists new entities, updates dirty entities,
        and deletes marked entities.
        """
        # Insert new entities
        for entity in self._pending_new:
            self._create_node(entity)

        # Update dirty entities in identity map (snapshot: weak dict may shrink)
        for (label, vid), entity in list(self._identity_map.items()):
            if entity.is_dirty and entity.is_persisted:
                self._update_node(entity)

        # Delete marked entities
        for entity in self._pending_delete:
            self._delete_node(entity)

        # Flush to storage
        self._db.flush()

        # Clear pending lists
        self._pending_new.clear()
        self._pending_delete.clear()

    def rollback(self) -> None:
        """Discard all pending changes."""
        # Clear pending new — detach entities
        for entity in self._pending_new:
            entity._session = None
        self._pending_new.clear()

        # Clear pending deletes
        self._pending_delete.clear()

        # Invalidate dirty identity map entries by reloading from the database
        for entity in list(self._identity_map.values()):
            if entity.is_dirty:
                self.refresh(entity)

    @contextmanager
    def transaction(self) -> Iterator[UniTransaction]:
        """Create a transaction context."""
        tx = UniTransaction(self)
        with tx:
            yield tx

    def begin(self) -> UniTransaction:
        """Begin a new transaction."""
        tx = UniTransaction(self)
        tx._tx = self._db_session.tx()
        return tx

    def cypher(
        self,
        query: str,
        params: dict[str, Any] | None = None,
        result_type: type[NodeT] | None = None,
    ) -> list[NodeT] | list[dict[str, Any]]:
        """
        Execute a raw Cypher query.

        Args:
            query: Cypher query string.
            params: Query parameters.
            result_type: Optional model type for result mapping.

        Returns:
            List of results (model instances if result_type provided).
        """
        results = self._db_session.query(query, params)

        if result_type is None:
            return [r.to_dict() for r in results]

        # Map results to model instances
        mapped = []
        for raw_row in results:
            row = raw_row.to_dict()
            # Try to find node data in the row (keys are unused here).
            for value in row.values():
                if isinstance(value, dict):
                    # Check for _id/_label keys (uni-db node dict)
                    if "_id" in value and "_label" in value:
                        instance = self._result_to_model(value, result_type)
                        if instance:
                            mapped.append(instance)
                            break
                    # Also check if _label matches registered model
                    elif "_label" in value:
                        label = value["_label"]
                        if label in self._schema_gen._node_models:
                            model = self._schema_gen._node_models[label]
                            instance = self._result_to_model(value, model)
                            if instance:
                                mapped.append(instance)
                                break
            else:
                # No node-like dict matched: try the first column
                first_value = next(iter(row.values()), None)
                if isinstance(first_value, dict):
                    instance = self._result_to_model(first_value, result_type)
                    if instance:
                        mapped.append(instance)

        return mapped

    @staticmethod
    def _validate_edge_endpoints(
        source: UniNode, target: UniNode
    ) -> tuple[int, int, str, str]:
        """Validate that both endpoints are persisted and return (src_vid, dst_vid, src_label, dst_label)."""
        if not source.is_persisted:
            raise NotPersisted(source)
        if not target.is_persisted:
            raise NotPersisted(target)
        return (
            source._vid,
            target._vid,
            source.__class__.__label__,
            target.__class__.__label__,
        )

    @staticmethod
    def _normalize_edge_properties(
        properties: dict[str, Any] | UniEdge | None,
    ) -> dict[str, Any]:
        """Normalize edge properties from dict, UniEdge, or None."""
        if isinstance(properties, UniEdge):
            return properties.to_properties()
        if properties:
            return properties
        return {}

    def create_edge(
        self,
        source: UniNode,
        edge_type: str,
        target: UniNode,
        properties: dict[str, Any] | UniEdge | None = None,
    ) -> None:
        """Create an edge between two nodes."""
        src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
            source, target
        )
        props = self._normalize_edge_properties(properties)

        # Build CREATE edge query with labels (required by Cypher implementation)
        props_str = ", ".join(f"{k}: ${k}" for k in props)
        if props_str:
            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type} {{{props_str}}}]->(b)"
        else:
            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type}]->(b)"

        # NOTE(review): a property literally named "src" or "dst" would collide
        # with the endpoint parameters here — confirm upstream validation.
        params = {"src": src_vid, "dst": dst_vid, **props}
        with self._db_session.tx() as tx:
            tx.execute(cypher, params)
            tx.commit()

    def delete_edge(
        self,
        source: UniNode,
        edge_type: str,
        target: UniNode,
    ) -> int:
        """Delete edges between two nodes. Returns the number of deleted edges."""
        src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
            source, target
        )
        cypher = (
            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
            f"WHERE a._vid = $src AND b._vid = $dst "
            f"DELETE r RETURN count(r) as count"
        )
        with self._db_session.tx() as tx:
            results = tx.query(cypher, {"src": src_vid, "dst": dst_vid})
            tx.commit()
        return cast(int, results[0]["count"]) if results else 0

    def update_edge(
        self,
        source: UniNode,
        edge_type: str,
        target: UniNode,
        properties: dict[str, Any],
    ) -> int:
        """Update properties on edges between two nodes. Returns the number of updated edges."""
        src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
            source, target
        )
        set_parts = [f"r.{k} = ${k}" for k in properties]
        params: dict[str, Any] = {"src": src_vid, "dst": dst_vid, **properties}
        cypher = (
            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
            f"WHERE a._vid = $src AND b._vid = $dst "
            f"SET {', '.join(set_parts)} "
            f"RETURN count(r) as count"
        )
        with self._db_session.tx() as tx:
            results = tx.query(cypher, params)
            tx.commit()
        return cast(int, results[0]["count"]) if results else 0

    def get_edge(
        self,
        source: UniNode,
        edge_type: str,
        target: UniNode,
        edge_model: type[EdgeT] | None = None,
    ) -> list[dict[str, Any]] | list[EdgeT]:
        """Get edges between two nodes. Returns dicts or edge model instances."""
        src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
            source, target
        )
        cypher = (
            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
            f"WHERE a._vid = $src AND b._vid = $dst "
            f"RETURN properties(r) AS _props, id(r) AS _eid"
        )
        results = self._db_session.query(cypher, {"src": src_vid, "dst": dst_vid})
        rows = [r.to_dict() for r in results]

        if edge_model is None:
            edge_dicts: list[dict[str, Any]] = []
            for row in rows:
                props = row.get("_props", {})
                if isinstance(props, dict):
                    edge_dict = dict(props)
                    edge_dict["_eid"] = row.get("_eid")
                    edge_dicts.append(edge_dict)
            return edge_dicts

        edges = []
        for row in rows:
            r_data = row.get("_props", {})
            if isinstance(r_data, dict):
                edge = edge_model.from_properties(
                    r_data,
                    src_vid=src_vid,
                    dst_vid=dst_vid,
                    session=self,
                )
                edges.append(edge)
        return edges

    def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
        """
        Bulk-add entities using bulk_writer for performance.

        Groups entities by label and uses db.bulk_writer().
        Returns VIDs and attaches sessions.

        Args:
            entities: Sequence of UniNode instances to bulk-insert.

        Returns:
            List of assigned vertex IDs.

        Raises:
            BulkLoadError: If bulk insertion fails.
        """
        if not entities:
            return []

        # Group by label
        by_label: dict[str, list[UniNode]] = {}
        for entity in entities:
            label = entity.__class__.__label__
            by_label.setdefault(label, []).append(entity)

        all_vids: list[int] = []
        try:
            for label, group in by_label.items():
                # Run before_create hooks
                for entity in group:
                    run_hooks(entity, _BEFORE_CREATE)

                # Convert to property dicts
                prop_dicts = [e.to_properties() for e in group]

                # Bulk insert via transaction
                tx = self._db_session.tx()
                with tx.bulk_writer().build() as bw:
                    vids = bw.insert_vertices(label, prop_dicts)
                    bw.commit()
                tx.commit()

                # Attach sessions and record VIDs
                for entity, vid in zip(group, vids):
                    entity._attach_session(self, vid)
                    self._identity_map[(label, vid)] = entity
                    run_hooks(entity, _AFTER_CREATE)
                    entity._mark_clean()

                all_vids.extend(vids)
        except Exception as e:
            raise BulkLoadError(f"Bulk insert failed: {e}") from e

        return all_vids

    def explain(self, cypher: str) -> uni_db.ExplainOutput:
        """Get the query execution plan without running it."""
        return self._db_session.explain(cypher)

    def profile(self, cypher: str) -> tuple[uni_db.QueryResult, uni_db.ProfileOutput]:
        """Run the query with profiling and return results + stats."""
        return self._db_session.profile(cypher)

    def save_schema(self, path: str) -> None:
        """Save the database schema to a file."""
        self._db.save_schema(path)

    def load_schema(self, path: str) -> None:
        """Load a database schema from a file."""
        self._db.load_schema(path)

    # -------------------------------------------------------------------------
    # Internal Methods
    # -------------------------------------------------------------------------

    def _create_node(self, entity: UniNode) -> None:
        """Create a node in the database."""
        # Run before_create hooks
        run_hooks(entity, _BEFORE_CREATE)

        label = entity.__class__.__label__
        props = entity.to_properties()

        # Build CREATE query
        props_str = ", ".join(f"{k}: ${k}" for k in props)
        cypher = f"CREATE (n:{label} {{{props_str}}}) RETURN id(n) as vid"

        with self._db_session.tx() as tx:
            results = tx.query(cypher, props)
            tx.commit()
        if results:
            vid = results[0]["vid"]
            entity._attach_session(self, vid)

            # Add to identity map
            self._identity_map[(label, vid)] = entity

        # Run after_create hooks
        run_hooks(entity, _AFTER_CREATE)
        entity._mark_clean()

    def _create_node_in_tx(self, entity: UniNode, tx: uni_db.Transaction) -> None:
        """Create a node within a transaction."""
        run_hooks(entity, _BEFORE_CREATE)

        label = entity.__class__.__label__
        props = entity.to_properties()

        props_str = ", ".join(f"{k}: ${k}" for k in props)
        cypher = f"CREATE (n:{label} {{{props_str}}}) RETURN id(n) as vid"

        results = tx.query(cypher, props)
        if results:
            vid = results[0]["vid"]
            entity._attach_session(self, vid)
            self._identity_map[(label, vid)] = entity

        run_hooks(entity, _AFTER_CREATE)

    def _create_edge_in_tx(
        self,
        source: UniNode,
        edge_type: str,
        target: UniNode,
        properties: UniEdge | None,
        tx: uni_db.Transaction,
    ) -> None:
        """Create an edge within a transaction."""
        props = properties.to_properties() if properties else {}
        src_label = source.__class__.__label__
        dst_label = target.__class__.__label__

        props_str = ", ".join(f"{k}: ${k}" for k in props)
        if props_str:
            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[:{edge_type} {{{props_str}}}]->(b)"
        else:
            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[:{edge_type}]->(b)"

        params = {"src": source._vid, "dst": target._vid, **props}
        tx.query(cypher, params)

    def _update_node(self, entity: UniNode) -> None:
        """Update a node in the database."""
        run_hooks(entity, _BEFORE_UPDATE)

        label = entity.__class__.__label__

        # Convert dirty prop values via python_to_db_value
        try:
            hints = get_type_hints(type(entity))
        except Exception:
            hints = {}

        dirty_props = {}
        for name in entity._dirty:
            value = getattr(entity, name)
            if name in hints:
                value = python_to_db_value(value, hints[name])
            dirty_props[name] = value

        if not dirty_props:
            return

        set_clause = ", ".join(f"n.{k} = ${k}" for k in dirty_props)
        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid SET {set_clause}"
        params = {"vid": entity._vid, **dirty_props}

        with self._db_session.tx() as tx:
            tx.execute(cypher, params)
            tx.commit()

        run_hooks(entity, _AFTER_UPDATE)
        entity._mark_clean()

    def _delete_node(self, entity: UniNode) -> None:
        """Delete a node from the database."""
        run_hooks(entity, _BEFORE_DELETE)

        label = entity.__class__.__label__
        vid = entity._vid

        # DETACH DELETE to also remove connected edges
        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid DETACH DELETE n"
        with self._db_session.tx() as tx:
            tx.execute(cypher, {"vid": vid})
            tx.commit()

        # Remove from identity map
        if vid is not None and (label, vid) in self._identity_map:
            del self._identity_map[(label, vid)]

        # Clear entity IDs
        entity._vid = None
        entity._uid = None
        entity._session = None

        run_hooks(entity, _AFTER_DELETE)

    def _result_to_model(
        self,
        data: dict[str, Any],
        model: type[NodeT],
    ) -> NodeT | None:
        """Convert a query result row to a model instance.

        Does not mutate the input dict.
        """
        if not data:
            return None

        # Work on a copy
        data = dict(data)

        # Run before_load hooks
        data = run_class_hooks(model, _BEFORE_LOAD, data) or data

        # Extract _id → vid (uni-db returns _id as string or int)
        vid = data.pop("_id", None)
        if vid is None:
            vid = data.pop("_vid", None)
        if vid is None:
            vid = data.pop("vid", None)
        if vid is not None and not isinstance(vid, int):
            vid = int(vid)

        # Remove _label (informational)
        data.pop("_label", None)

        try:
            instance = cast(
                NodeT,
                model.from_properties(
                    data,
                    vid=vid,
                    session=self,
                ),
            )
        except Exception:
            # If validation fails, return None
            return None

        # Add to identity map if we have a vid
        if vid is not None:
            existing = self._identity_map.get((model.__label__, vid))
            if existing is not None:
                return cast(NodeT, existing)
            self._identity_map[(model.__label__, vid)] = instance

        # Run after_load hooks
        run_hooks(instance, _AFTER_LOAD)

        return instance

    def _load_relationship(
        self,
        entity: UniNode,
        descriptor: RelationshipDescriptor[Any],
    ) -> list[UniNode] | UniNode | None:
        """Load a relationship for an entity."""
        if not entity.is_persisted:
            raise NotPersisted(entity)

        config = descriptor.config
        label = entity.__class__.__label__
        pattern = _edge_pattern(config.edge_type, config.direction)

        cypher = (
            f"MATCH (a:{label}){pattern}(b) WHERE id(a) = $vid "
            f"RETURN properties(b) AS _props, id(b) AS _vid, labels(b) AS _labels"
        )
        results = self._db_session.query(cypher, {"vid": entity._vid})

        nodes = []
        for raw_row in results:
            row = raw_row.to_dict()
            node_data = _row_to_node_dict(row)
            if node_data is None:
                continue
            # Try to find the model for this node
            node_label = node_data.get("_label")
            if node_label and node_label in self._schema_gen._node_models:
                model = self._schema_gen._node_models[node_label]
                instance = self._result_to_model(node_data, model)
                if instance:
                    nodes.append(instance)

        if not descriptor.is_list:
            return nodes[0] if nodes else None
        return nodes

    def _eager_load_relationships(
        self,
        entities: list[NodeT],
        relationships: list[str],
    ) -> None:
        """Eager load relationships for a list of entities."""
        if not entities:
            return

        model = type(entities[0])
        rel_configs = model.get_relationship_fields()

        for rel_name in relationships:
            if rel_name not in rel_configs:
                continue

            config = rel_configs[rel_name]
            label = model.__label__
            vids = [e._vid for e in entities if e._vid is not None]

            if not vids:
                continue

            pattern = _edge_pattern(config.edge_type, config.direction)
            cypher = (
                f"MATCH (a:{label}){pattern}(b) WHERE id(a) IN $vids "
                f"RETURN id(a) as src_vid, properties(b) AS _props, id(b) AS _vid, labels(b) AS _labels"
            )
            results = self._db_session.query(cypher, {"vids": vids})

            # Group results by source vid
            by_source: dict[int, list[Any]] = {}
            for raw_row in results:
                row = raw_row.to_dict()
                src_vid = row["src_vid"]
                node_data = _row_to_node_dict(row)
                if node_data is None:
                    continue
                by_source.setdefault(src_vid, []).append(node_data)

            # Set cached values on entities
            for entity in entities:
                if entity._vid in by_source:
                    related = by_source[entity._vid]
                    cache_attr = f"_rel_cache_{rel_name}"
                    setattr(entity, cache_attr, related)

Session for interacting with the graph database using Pydantic models.

The session manages model registration, schema synchronization, and provides CRUD operations and query building.

Example:

from uni_db import Uni
from uni_pydantic import UniSession

db = Uni("./my_graph")
session = UniSession(db)
session.register(Person, Company)
session.sync_schema()

alice = Person(name="Alice", age=30)
session.add(alice)
session.commit()

UniSession(db: Uni)
181    def __init__(self, db: uni_db.Uni) -> None:
182        self._db = db
183        self._db_session = db.session()
184        self._schema_gen = SchemaGenerator()
185        self._identity_map: WeakValueDictionary[tuple[str, int], UniNode] = (
186            WeakValueDictionary()
187        )
188        self._pending_new: list[UniNode] = []
189        self._pending_delete: list[UniNode] = []
def close(self) -> None:
202    def close(self) -> None:
203        """Close the session and clear all pending state."""
204        self._pending_new.clear()
205        self._pending_delete.clear()

Close the session and clear all pending state.

db: Uni
207    @property
208    def db(self) -> uni_db.Uni:
209        """Access the underlying uni_db.Uni for low-level operations."""
210        return self._db

Access the underlying uni_db.Uni for low-level operations.

def locy( self, program: str, params: dict[str, typing.Any] | None = None) -> LocyResult:
212    def locy(
213        self, program: str, params: dict[str, Any] | None = None
214    ) -> uni_db.LocyResult:
215        """
216        Evaluate a Locy program and return derived facts, stats, and warnings.
217
218        Delegates to the underlying ``uni_db.Session.locy()``.
219        """
220        return self._db_session.locy(program, params)

Evaluate a Locy program and return derived facts, stats, and warnings.

Delegates to the underlying uni_db.Session.locy().

def register( self, *models: type[UniNode] | type[UniEdge]) -> None:
222    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
223        """
224        Register model classes with the session.
225
226        Registered models can be used for schema generation and queries.
227
228        Args:
229            *models: UniNode or UniEdge subclasses to register.
230        """
231        self._schema_gen.register(*models)

Register model classes with the session.

Registered models can be used for schema generation and queries.

Args: *models: UniNode or UniEdge subclasses to register.

def sync_schema(self) -> None:
233    def sync_schema(self) -> None:
234        """
235        Synchronize database schema with registered models.
236
237        Creates labels, edge types, properties, and indexes as needed.
238        This is additive-only; it won't remove existing schema elements.
239        """
240        self._schema_gen.apply_to_database(self._db)

Synchronize database schema with registered models.

Creates labels, edge types, properties, and indexes as needed. This is additive-only; it won't remove existing schema elements.

def query(self, model: type[NodeT]) -> QueryBuilder[NodeT]:
    """
    Start a fluent query over *model*.

    Args:
        model: The UniNode subclass to query.

    Returns:
        A fresh QueryBuilder bound to this session and *model*.
    """
    return QueryBuilder(self, model)

Create a query builder for the given model.

Args: model: The UniNode subclass to query.

Returns: A QueryBuilder for constructing queries.

def add(self, entity: UniNode) -> None:
    """
    Stage a new entity for insertion on the next commit().

    Raises:
        SessionError: If the entity is already persisted.
    """
    if entity.is_persisted:
        raise SessionError(f"Entity {entity!r} is already persisted")
    # Attach to this session, then queue for insertion.
    entity._session = self
    self._pending_new.append(entity)

Add a new entity to be persisted.

The entity will be inserted on the next commit().

def add_all(self, entities: Sequence[UniNode]) -> None:
    """Stage every entity in *entities* for insertion (see add())."""
    for item in entities:
        self.add(item)

Add multiple entities to be persisted.

def delete(self, entity: UniNode) -> None:
    """
    Stage a persisted entity for deletion on the next commit().

    Raises:
        NotPersisted: If the entity was never persisted.
    """
    if not entity.is_persisted:
        raise NotPersisted(entity)
    self._pending_delete.append(entity)

Mark an entity for deletion.

def get(
    self,
    model: type[NodeT],
    vid: int | None = None,
    uid: str | None = None,
    **kwargs: Any,
) -> NodeT | None:
    """
    Fetch a single entity by vertex ID, unique ID, or property equality.

    Args:
        model: The model class to fetch.
        vid: Vertex ID to look up.
        uid: Unique ID to look up.
        **kwargs: property == value filters.

    Returns:
        The matching instance, or None when nothing matches.

    Raises:
        ValueError: If no lookup criterion is supplied.
    """
    label = model.__label__

    # Serve repeat vid lookups straight from the identity map.
    if vid is not None:
        hit = self._identity_map.get((label, vid))
        if hit is not None:
            return hit  # type: ignore[return-value]

    params: dict[str, Any] = {}
    if vid is not None:
        params["vid"] = vid
        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
    elif uid is not None:
        params["uid"] = uid
        cypher = f"MATCH (n:{label}) WHERE n._uid = $uid RETURN {_NODE_RETURN}"
    elif kwargs:
        # Reject unknown property names before they reach the query text.
        for prop in kwargs:
            _validate_property(prop, model)
        conditions = [f"n.{k} = ${k}" for k in kwargs]
        cypher = f"MATCH (n:{label}) WHERE {' AND '.join(conditions)} RETURN {_NODE_RETURN} LIMIT 1"
        params.update(kwargs)
    else:
        raise ValueError("Must provide vid, uid, or property filters")

    rows = self._db_session.query(cypher, params)
    if not rows:
        return None

    node_data = _row_to_node_dict(rows[0].to_dict())
    return None if node_data is None else self._result_to_model(node_data, model)

Get an entity by ID or unique properties.

Args: model: The model type to retrieve. vid: Vertex ID to look up. uid: Unique ID to look up. **kwargs: Property equality filters.

Returns: The model instance or None if not found.

def refresh(self, entity: UniNode) -> None:
    """
    Reload *entity*'s properties from the database, discarding local edits.

    Raises:
        NotPersisted: If the entity has no vertex ID yet.
        SessionError: If the backing row no longer exists.
    """
    if not entity.is_persisted:
        raise NotPersisted(entity)

    model = entity.__class__
    cypher = f"MATCH (n:{model.__label__}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
    rows = self._db_session.query(cypher, {"vid": entity._vid})
    if not rows:
        raise SessionError(f"Entity with vid={entity._vid} no longer exists")

    props = _row_to_node_dict(rows[0].to_dict())
    if props is None:
        raise SessionError(f"Entity with vid={entity._vid} no longer exists")

    # Type hints drive db->python conversion; unresolvable hints are tolerated.
    try:
        hints = get_type_hints(model)
    except Exception:
        hints = {}

    for name in entity.get_property_fields():
        if name not in props:
            continue
        value = props[name]
        if name in hints:
            value = db_to_python_value(value, hints[name])
        setattr(entity, name, value)

    entity._mark_clean()

Refresh an entity's properties from the database.

def commit(self) -> None:
    """
    Flush all staged work to the database.

    Order: insert staged entities, write dirty tracked entities, delete
    staged deletions, flush storage, then clear the staging lists.
    """
    for node in self._pending_new:
        self._create_node(node)

    # Persist modifications on anything the identity map still tracks.
    for (_label, _vid), node in list(self._identity_map.items()):
        if node.is_dirty and node.is_persisted:
            self._update_node(node)

    for node in self._pending_delete:
        self._delete_node(node)

    self._db.flush()

    self._pending_new.clear()
    self._pending_delete.clear()

Commit all pending changes to the database.

This persists new entities, updates dirty entities, and deletes marked entities.

def rollback(self) -> None:
    """Discard staged inserts/deletes and reload any dirty tracked entities."""
    # Detach never-persisted entities from this session.
    for node in self._pending_new:
        node._session = None
    self._pending_new.clear()

    self._pending_delete.clear()

    # Dirty tracked entities are reset from the database.
    for node in list(self._identity_map.values()):
        if node.is_dirty:
            self.refresh(node)

Discard all pending changes.

@contextmanager
def transaction(self) -> Iterator[UniTransaction]:
    """Yield a UniTransaction that auto-commits on success and rolls back on error."""
    tx = UniTransaction(self)
    with tx:
        yield tx

Create a transaction context.

def begin(self) -> UniTransaction:
    """Open and return an explicit UniTransaction; the caller commits or rolls back."""
    tx = UniTransaction(self)
    tx._tx = self._db_session.tx()
    return tx

Begin a new transaction.

def cypher(
    self,
    query: str,
    params: dict[str, Any] | None = None,
    result_type: type[NodeT] | None = None,
) -> list[NodeT] | list[dict[str, Any]]:
    """
    Run a raw Cypher query.

    Args:
        query: Cypher query text.
        params: Bound parameters.
        result_type: When given, rows are mapped to model instances.

    Returns:
        Plain row dicts, or model instances when *result_type* is set.
    """
    results = self._db_session.query(query, params)
    if result_type is None:
        return [r.to_dict() for r in results]

    mapped = []
    for raw_row in results:
        row = raw_row.to_dict()
        matched = False
        for value in row.values():
            if not isinstance(value, dict):
                continue
            if "_id" in value and "_label" in value:
                # Full uni-db node dict: map to the requested type.
                instance = self._result_to_model(value, result_type)
                if instance:
                    mapped.append(instance)
                    matched = True
                    break
            elif "_label" in value:
                # Label-only dict: map via the registered model for that label.
                label = value["_label"]
                if label in self._schema_gen._node_models:
                    instance = self._result_to_model(
                        value, self._schema_gen._node_models[label]
                    )
                    if instance:
                        mapped.append(instance)
                        matched = True
                        break
        if not matched:
            # No column matched: fall back to mapping the first column.
            first_value = next(iter(row.values()), None)
            if isinstance(first_value, dict):
                instance = self._result_to_model(first_value, result_type)
                if instance:
                    mapped.append(instance)

    return mapped

Execute a raw Cypher query.

Args: query: Cypher query string. params: Query parameters. result_type: Optional model type for result mapping.

Returns: List of results (model instances if result_type provided).

def create_edge(
    self,
    source: UniNode,
    edge_type: str,
    target: UniNode,
    properties: dict[str, Any] | UniEdge | None = None,
) -> None:
    """
    Create an edge of *edge_type* from *source* to *target*.

    Args:
        source: Persisted source node.
        edge_type: Edge type name.
        target: Persisted target node.
        properties: Optional edge properties (dict or UniEdge).

    Raises:
        ValueError: If a property name is not a plain identifier (it
            would otherwise be interpolated into the query text).
    """
    src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
        source, target
    )
    props = self._normalize_edge_properties(properties)

    # Property names are interpolated into the Cypher text; refuse
    # anything that is not a plain identifier (injection guard).
    for k in props:
        if not k.isidentifier():
            raise ValueError(f"Invalid edge property name: {k!r}")

    # Bind properties under a "p_" prefix so a property named "src" or
    # "dst" cannot clobber the endpoint parameters (bug fix: previously
    # **props was merged last and overwrote $src/$dst).
    props_str = ", ".join(f"{k}: $p_{k}" for k in props)
    if props_str:
        cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type} {{{props_str}}}]->(b)"
    else:
        cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type}]->(b)"

    params = {"src": src_vid, "dst": dst_vid, **{f"p_{k}": v for k, v in props.items()}}
    with self._db_session.tx() as tx:
        tx.execute(cypher, params)
        tx.commit()

Create an edge between two nodes.

def delete_edge(
    self,
    source: UniNode,
    edge_type: str,
    target: UniNode,
) -> int:
    """Remove every *edge_type* edge from *source* to *target*; return how many were removed."""
    src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
        source, target
    )
    cypher = (
        f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
        f"WHERE a._vid = $src AND b._vid = $dst "
        f"DELETE r RETURN count(r) as count"
    )
    with self._db_session.tx() as tx:
        results = tx.query(cypher, {"src": src_vid, "dst": dst_vid})
        tx.commit()
    if not results:
        return 0
    return cast(int, results[0]["count"])

Delete edges between two nodes. Returns the number of deleted edges.

def update_edge(
    self,
    source: UniNode,
    edge_type: str,
    target: UniNode,
    properties: dict[str, Any],
) -> int:
    """
    Set *properties* on every *edge_type* edge from *source* to *target*.

    Returns:
        Number of updated edges.

    Raises:
        ValueError: If a property name is not a plain identifier.
    """
    src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
        source, target
    )
    # Property names are interpolated into the query text; guard against injection.
    for k in properties:
        if not k.isidentifier():
            raise ValueError(f"Invalid edge property name: {k!r}")

    # Bind properties under a "p_" prefix so names like "src"/"dst" cannot
    # clobber the endpoint parameters (bug fix: previously **properties was
    # merged into params and could overwrite $src/$dst).
    set_parts = [f"r.{k} = $p_{k}" for k in properties]
    params: dict[str, Any] = {
        "src": src_vid,
        "dst": dst_vid,
        **{f"p_{k}": v for k, v in properties.items()},
    }
    cypher = (
        f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
        f"WHERE a._vid = $src AND b._vid = $dst "
        f"SET {', '.join(set_parts)} "
        f"RETURN count(r) as count"
    )
    with self._db_session.tx() as tx:
        results = tx.query(cypher, params)
        tx.commit()
    return cast(int, results[0]["count"]) if results else 0

Update properties on edges between two nodes. Returns the number of updated edges.

def get_edge(
    self,
    source: UniNode,
    edge_type: str,
    target: UniNode,
    edge_model: type[EdgeT] | None = None,
) -> list[dict[str, Any]] | list[EdgeT]:
    """
    Fetch all *edge_type* edges from *source* to *target*.

    Returns plain property dicts (each carrying an ``_eid`` key) by
    default, or instances of *edge_model* when one is supplied.
    """
    src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
        source, target
    )
    cypher = (
        f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
        f"WHERE a._vid = $src AND b._vid = $dst "
        f"RETURN properties(r) AS _props, id(r) AS _eid"
    )
    rows = [
        r.to_dict()
        for r in self._db_session.query(cypher, {"src": src_vid, "dst": dst_vid})
    ]

    if edge_model is None:
        out: list[dict[str, Any]] = []
        for row in rows:
            props = row.get("_props", {})
            if not isinstance(props, dict):
                continue
            record = dict(props)
            record["_eid"] = row.get("_eid")
            out.append(record)
        return out

    edges: list[EdgeT] = []
    for row in rows:
        r_data = row.get("_props", {})
        if isinstance(r_data, dict):
            edges.append(
                edge_model.from_properties(
                    r_data, src_vid=src_vid, dst_vid=dst_vid, session=self
                )
            )
    return edges

Get edges between two nodes. Returns dicts or edge model instances.

def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
    """
    Bulk-insert entities via the bulk writer for performance.

    Entities are grouped by label and each group is written through
    ``tx.bulk_writer()``. Lifecycle hooks run and sessions are attached
    exactly as in add()/commit().

    Args:
        entities: UniNode instances to insert.

    Returns:
        Vertex IDs assigned to the inserted entities.

    Raises:
        BulkLoadError: If any insertion step fails.
    """
    if not entities:
        return []

    # Group entities by label so each label gets one bulk write.
    by_label: dict[str, list[UniNode]] = {}
    for entity in entities:
        by_label.setdefault(entity.__class__.__label__, []).append(entity)

    all_vids: list[int] = []
    try:
        for label, group in by_label.items():
            for entity in group:
                run_hooks(entity, _BEFORE_CREATE)

            prop_dicts = [e.to_properties() for e in group]

            # Use the transaction as a context manager so it is not leaked
            # if the bulk write raises (bug fix: the previous code created
            # the tx outside any with/try and never closed it on failure).
            # Matches the "with self._db_session.tx() as tx" pattern used
            # elsewhere in this session.
            with self._db_session.tx() as tx:
                with tx.bulk_writer().build() as bw:
                    vids = bw.insert_vertices(label, prop_dicts)
                    bw.commit()
                tx.commit()

            # Attach sessions, populate the identity map, and record VIDs.
            for entity, vid in zip(group, vids):
                entity._attach_session(self, vid)
                self._identity_map[(label, vid)] = entity
                run_hooks(entity, _AFTER_CREATE)
                entity._mark_clean()

            all_vids.extend(vids)
    except Exception as e:
        raise BulkLoadError(f"Bulk insert failed: {e}") from e

    return all_vids

Bulk-add entities using bulk_writer for performance.

Groups entities by label and uses db.bulk_writer(). Returns VIDs and attaches sessions.

Args: entities: Sequence of UniNode instances to bulk-insert.

Returns: List of assigned vertex IDs.

Raises: BulkLoadError: If bulk insertion fails.

def explain(self, cypher: str) -> uni_db.ExplainOutput:
    """Return the execution plan for *cypher* without executing the query."""
    return self._db_session.explain(cypher)

Get the query execution plan without running it.

def profile(self, cypher: str) -> tuple[uni_db.QueryResult, uni_db.ProfileOutput]:
    """Execute *cypher* with profiling enabled and return (results, stats)."""
    return self._db_session.profile(cypher)

Run the query with profiling and return results + stats.

def save_schema(self, path: str) -> None:
    """Write the current database schema to *path*."""
    self._db.save_schema(path)

Save the database schema to a file.

def load_schema(self, path: str) -> None:
    """Read a database schema from *path* and apply it."""
    self._db.load_schema(path)

Load a database schema from a file.

class UniTransaction:
    """
    Atomic unit of work with commit/rollback semantics.

    Example:
        >>> with session.transaction() as tx:
        ...     alice = Person(name="Alice")
        ...     tx.add(alice)
        ...     # Auto-commits on success, rolls back on exception
    """

    def __init__(self, session: UniSession) -> None:
        # Owning session and the low-level transaction handle (opened lazily).
        self._session = session
        self._tx: uni_db.Transaction | None = None
        # Work queued until commit().
        self._pending_nodes: list[UniNode] = []
        self._pending_edges: list[tuple[UniNode, str, UniNode, UniEdge | None]] = []
        # Terminal-state flags; a transaction ends exactly once.
        self._committed = False
        self._rolled_back = False

    def __enter__(self) -> UniTransaction:
        self._tx = self._session._db_session.tx()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # On error: roll back and let the exception propagate.
        if exc_type is not None:
            self.rollback()
            return
        # On clean exit: commit unless the caller already resolved the tx.
        if not (self._committed or self._rolled_back):
            self.commit()

    def add(self, entity: UniNode) -> None:
        """Queue a node for creation when the transaction commits."""
        self._pending_nodes.append(entity)

    def create_edge(
        self,
        source: UniNode,
        edge_type: str,
        target: UniNode,
        properties: UniEdge | None = None,
        **kwargs: Any,
    ) -> None:
        """Queue an edge between two already-persisted nodes.

        NOTE(review): **kwargs is accepted but currently ignored — confirm
        whether callers expect kwargs to become edge properties.

        Raises:
            NotPersisted: If either endpoint has not been persisted yet.
        """
        for endpoint in (source, target):
            if not endpoint.is_persisted:
                raise NotPersisted(endpoint)
        self._pending_edges.append((source, edge_type, target, properties))

    def commit(self) -> None:
        """Apply queued nodes/edges and commit the underlying transaction.

        Raises:
            TransactionError: If the transaction is already resolved, was
                never started, or the commit itself fails (after rollback).
        """
        if self._committed:
            raise TransactionError("Transaction already committed")
        if self._rolled_back:
            raise TransactionError("Transaction already rolled back")
        if self._tx is None:
            raise TransactionError("Transaction not started")

        try:
            for node in self._pending_nodes:
                self._session._create_node_in_tx(node, self._tx)

            for src, etype, dst, props in self._pending_edges:
                self._session._create_edge_in_tx(src, etype, dst, props, self._tx)

            self._tx.commit()
            self._committed = True

            # Only mark clean once the low-level commit has succeeded.
            for node in self._pending_nodes:
                node._mark_clean()

        except Exception as e:
            self.rollback()
            raise TransactionError(f"Commit failed: {e}") from e

    def rollback(self) -> None:
        """Discard queued work and roll back the underlying transaction (idempotent)."""
        if self._rolled_back:
            return
        if self._tx is not None:
            self._tx.rollback()
        self._rolled_back = True
        self._pending_nodes.clear()
        self._pending_edges.clear()

Transaction context for atomic operations.

Provides commit/rollback semantics for a group of operations.

Example:

>>> with session.transaction() as tx:
...     alice = Person(name="Alice")
...     tx.add(alice)
...     # Auto-commits on success, rolls back on exception

def __init__(self, session: UniSession) -> None:
    # Owning session and the low-level transaction handle (opened lazily).
    self._session = session
    self._tx: uni_db.Transaction | None = None
    # Work queued until commit().
    self._pending_nodes: list[UniNode] = []
    self._pending_edges: list[tuple[UniNode, str, UniNode, UniEdge | None]] = []
    # Terminal-state flags; a transaction ends exactly once.
    self._committed = False
    self._rolled_back = False
def add(self, entity: UniNode) -> None:
    """Queue a node for creation when the transaction commits."""
    self._pending_nodes.append(entity)

Add a node to be created in this transaction.

def create_edge(
    self,
    source: UniNode,
    edge_type: str,
    target: UniNode,
    properties: UniEdge | None = None,
    **kwargs: Any,
) -> None:
    """Queue an edge between two already-persisted nodes.

    Raises:
        NotPersisted: If either endpoint has not been persisted yet.
    """
    for endpoint in (source, target):
        if not endpoint.is_persisted:
            raise NotPersisted(endpoint)
    self._pending_edges.append((source, edge_type, target, properties))

Create an edge between two nodes in this transaction.

def commit(self) -> None:
    """Apply queued nodes/edges and commit the underlying transaction.

    Raises:
        TransactionError: If the transaction is already resolved, was
            never started, or the commit itself fails (after rollback).
    """
    if self._committed:
        raise TransactionError("Transaction already committed")
    if self._rolled_back:
        raise TransactionError("Transaction already rolled back")
    if self._tx is None:
        raise TransactionError("Transaction not started")

    try:
        for node in self._pending_nodes:
            self._session._create_node_in_tx(node, self._tx)

        for src, etype, dst, props in self._pending_edges:
            self._session._create_edge_in_tx(src, etype, dst, props, self._tx)

        self._tx.commit()
        self._committed = True

        # Only mark clean once the low-level commit has succeeded.
        for node in self._pending_nodes:
            node._mark_clean()

    except Exception as e:
        self.rollback()
        raise TransactionError(f"Commit failed: {e}") from e

Commit the transaction.

def rollback(self) -> None:
    """Discard queued work and roll back the underlying transaction (idempotent)."""
    if self._rolled_back:
        return
    if self._tx is not None:
        self._tx.rollback()
    self._rolled_back = True
    self._pending_nodes.clear()
    self._pending_edges.clear()

Rollback the transaction.

class AsyncUniSession:
    """
    Async session for interacting with the graph database.

    Mirrors UniSession with async methods. Uses AsyncUni.

    Example:
        >>> from uni_db import AsyncUni
        >>> from uni_pydantic import AsyncUniSession
        >>>
        >>> db = await AsyncUni.open("./my_graph")
        >>> async with AsyncUniSession(db) as session:
        ...     session.register(Person)
        ...     await session.sync_schema()
        ...     alice = Person(name="Alice", age=30)
        ...     session.add(alice)
        ...     await session.commit()
    """

    def __init__(self, db: uni_db.AsyncUni) -> None:
        self._db = db
        self._db_session = db.session()
        self._schema_gen = SchemaGenerator()
        # Identity map keyed by (label, vid); weak values so the session
        # never keeps entities alive on its own.
        self._identity_map: WeakValueDictionary[tuple[str, int], UniNode] = (
            WeakValueDictionary()
        )
        self._pending_new: list[UniNode] = []
        self._pending_delete: list[UniNode] = []

    async def __aenter__(self) -> AsyncUniSession:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # NOTE: exiting the context does not commit; pending changes are
        # simply discarded by close().
        self.close()

    def close(self) -> None:
        """Close the session and clear pending state."""
        self._pending_new.clear()
        self._pending_delete.clear()

    @property
    def db(self) -> uni_db.AsyncUni:
        """Access the underlying uni_db.AsyncUni for low-level operations."""
        return self._db

    async def locy(self, program: str, params: dict[str, Any] | None = None) -> Any:
        """
        Evaluate a Locy program and return derived facts, stats, and warnings.

        Delegates to the underlying ``uni_db.AsyncSession.locy()``.
        """
        return await self._db_session.locy(program, params)

    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
        """Register model classes with the session (sync)."""
        self._schema_gen.register(*models)

    async def sync_schema(self) -> None:
        """Synchronize database schema with registered models."""
        await self._schema_gen.async_apply_to_database(self._db)

    def query(self, model: type[NodeT]) -> AsyncQueryBuilder[NodeT]:
        """Create an async query builder for the given model."""
        return AsyncQueryBuilder(self, model)

    def add(self, entity: UniNode) -> None:
        """Add a new entity to be persisted (sync — just collects)."""
        if entity.is_persisted:
            raise SessionError(f"Entity {entity!r} is already persisted")
        entity._session = self
        self._pending_new.append(entity)

    def add_all(self, entities: Sequence[UniNode]) -> None:
        """Add multiple entities (sync — just collects)."""
        for entity in entities:
            self.add(entity)

    def delete(self, entity: UniNode) -> None:
        """Mark an entity for deletion (sync — just collects)."""
        if not entity.is_persisted:
            raise NotPersisted(entity)
        self._pending_delete.append(entity)

    async def get(
        self,
        model: type[NodeT],
        vid: int | None = None,
        uid: str | None = None,
        **kwargs: Any,
    ) -> NodeT | None:
        """Get an entity by ID or unique properties.

        Lookup priority: ``vid`` (checked against the identity map first),
        then ``uid``, then arbitrary property filters.  Returns ``None``
        when nothing matches.

        Raises:
            ValueError: if neither vid, uid nor property filters are given.
        """
        if vid is not None:
            cached = self._identity_map.get((model.__label__, vid))
            if cached is not None:
                return cached  # type: ignore[return-value]

        label = model.__label__
        params: dict[str, Any] = {}

        if vid is not None:
            cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
            params["vid"] = vid
        elif uid is not None:
            cypher = f"MATCH (n:{label}) WHERE n._uid = $uid RETURN {_NODE_RETURN}"
            params["uid"] = uid
        elif kwargs:
            # Validate filter names against the model before interpolating
            # them into the query text.
            for k in kwargs:
                _validate_property(k, model)
            conditions = [f"n.{k} = ${k}" for k in kwargs]
            cypher = f"MATCH (n:{label}) WHERE {' AND '.join(conditions)} RETURN {_NODE_RETURN} LIMIT 1"
            params.update(kwargs)
        else:
            raise ValueError("Must provide vid, uid, or property filters")

        results = await self._db_session.query(cypher, params)
        if not results:
            return None

        node_data = _row_to_node_dict(results[0].to_dict())
        if node_data is None:
            return None
        return self._result_to_model(node_data, model)

    async def refresh(self, entity: UniNode) -> None:
        """Refresh an entity's properties from the database.

        Raises:
            NotPersisted: if the entity has never been persisted.
            SessionError: if the row no longer exists in the database.
        """
        if not entity.is_persisted:
            raise NotPersisted(entity)

        label = entity.__class__.__label__
        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
        results = await self._db_session.query(cypher, {"vid": entity._vid})

        if not results:
            raise SessionError(f"Entity with vid={entity._vid} no longer exists")

        props = _row_to_node_dict(results[0].to_dict())
        if props is None:
            raise SessionError(f"Entity with vid={entity._vid} no longer exists")
        try:
            hints = get_type_hints(type(entity))
        except Exception:
            # Unresolvable forward refs etc. — fall back to raw DB values.
            hints = {}

        for field_name in entity.get_property_fields():
            if field_name in props:
                value = props[field_name]
                if field_name in hints:
                    value = db_to_python_value(value, hints[field_name])
                setattr(entity, field_name, value)

        entity._mark_clean()

    async def commit(self) -> None:
        """Commit all pending changes.

        Order: create new entities, flush dirty updates for everything in
        the identity map, then apply deletions, and finally flush the
        underlying database.
        """
        for entity in self._pending_new:
            await self._create_node(entity)

        # Snapshot the weak map before iterating — entries may be collected
        # or mutated while we await.
        for entity in list(self._identity_map.values()):
            if entity.is_dirty and entity.is_persisted:
                await self._update_node(entity)

        for entity in self._pending_delete:
            await self._delete_node(entity)

        await self._db.flush()
        self._pending_new.clear()
        self._pending_delete.clear()

    async def rollback(self) -> None:
        """Discard all pending changes.

        New entities are detached from the session; dirty persisted
        entities are re-read from the database.
        """
        for entity in self._pending_new:
            entity._session = None
        self._pending_new.clear()
        self._pending_delete.clear()
        for entity in list(self._identity_map.values()):
            if entity.is_dirty:
                await self.refresh(entity)

    async def transaction(self) -> AsyncUniTransaction:
        """Create an async transaction. Use as `async with session.transaction() as tx:`."""
        return AsyncUniTransaction(self)

    async def cypher(
        self,
        query: str,
        params: dict[str, Any] | None = None,
        result_type: type[NodeT] | None = None,
    ) -> list[NodeT] | list[dict[str, Any]]:
        """Execute a raw Cypher query.

        With no ``result_type`` the raw row dicts are returned.  Otherwise
        each row is scanned for a node-shaped dict value and mapped to a
        model instance (at most one instance per row).
        """
        results = await self._db_session.query(query, params)

        if result_type is None:
            return [r.to_dict() for r in results]

        mapped = []
        for raw_row in results:
            row = raw_row.to_dict()
            for key, value in row.items():
                if isinstance(value, dict):
                    if "_id" in value and "_label" in value:
                        instance = self._result_to_model(value, result_type)
                        if instance:
                            mapped.append(instance)
                            break
                    elif "_label" in value:
                        # Resolve the concrete model from the label when it
                        # differs from result_type (polymorphic results).
                        label = value["_label"]
                        if label in self._schema_gen._node_models:
                            model = self._schema_gen._node_models[label]
                            instance = self._result_to_model(value, model)
                            if instance:
                                mapped.append(instance)
                                break
            else:
                # for/else: no labelled dict matched — fall back to the
                # first dict-valued column, if any.
                first_value = next(iter(row.values()), None)
                if isinstance(first_value, dict):
                    instance = self._result_to_model(first_value, result_type)
                    if instance:
                        mapped.append(instance)

        return mapped

    async def create_edge(
        self,
        source: UniNode,
        edge_type: str,
        target: UniNode,
        properties: dict[str, Any] | UniEdge | None = None,
    ) -> None:
        """Create an edge between two nodes."""
        # Endpoint validation and property normalization are shared with
        # the sync session.
        src_vid, dst_vid, src_label, dst_label = UniSession._validate_edge_endpoints(
            source, target
        )
        props = UniSession._normalize_edge_properties(properties)

        props_str = ", ".join(f"{k}: ${k}" for k in props)
        if props_str:
            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type} {{{props_str}}}]->(b)"
        else:
            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type}]->(b)"

        async with await self._db_session.tx() as tx:
            await tx.execute(cypher, {"src": src_vid, "dst": dst_vid, **props})
            await tx.commit()

    async def delete_edge(
        self, source: UniNode, edge_type: str, target: UniNode
    ) -> int:
        """Delete edges between two nodes. Returns the number of deleted edges."""
        src_vid, dst_vid, src_label, dst_label = UniSession._validate_edge_endpoints(
            source, target
        )
        cypher = (
            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
            f"WHERE a._vid = $src AND b._vid = $dst "
            f"DELETE r RETURN count(r) as count"
        )
        async with await self._db_session.tx() as tx:
            results = await tx.query(cypher, {"src": src_vid, "dst": dst_vid})
            await tx.commit()
        return cast(int, results[0]["count"]) if results else 0

    async def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
        """Bulk-add entities using bulk_writer.

        Entities are grouped by label and inserted per label.  Lifecycle
        hooks run exactly as for individual creates.  Returns the new
        vertex ids in insertion order.

        Raises:
            BulkLoadError: if any bulk insert fails.
        """
        if not entities:
            return []

        # Group by label: the bulk writer inserts one label at a time.
        by_label: dict[str, list[UniNode]] = {}
        for entity in entities:
            label = entity.__class__.__label__
            if label not in by_label:
                by_label[label] = []
            by_label[label].append(entity)

        all_vids: list[int] = []
        try:
            for label, group in by_label.items():
                for entity in group:
                    run_hooks(entity, _BEFORE_CREATE)
                prop_dicts = [e.to_properties() for e in group]
                tx = await self._db_session.tx()
                async with await tx.bulk_writer().build() as bw:
                    vids = await bw.insert_vertices(label, prop_dicts)
                    await bw.commit()
                await tx.commit()
                for entity, vid in zip(group, vids):
                    entity._attach_session(self, vid)
                    self._identity_map[(label, vid)] = entity
                    run_hooks(entity, _AFTER_CREATE)
                    entity._mark_clean()
                all_vids.extend(vids)
        except Exception as e:
            raise BulkLoadError(f"Bulk insert failed: {e}") from e

        return all_vids

    async def explain(self, cypher: str) -> Any:
        """Get the query execution plan."""
        return await self._db_session.explain(cypher)

    async def profile(self, cypher: str) -> Any:
        """Run the query with profiling and return results + stats."""
        return await self._db_session.profile(cypher)

    async def save_schema(self, path: str) -> None:
        """Save the database schema to a file."""
        await self._db.save_schema(path)

    async def load_schema(self, path: str) -> None:
        """Load a database schema from a file."""
        await self._db.load_schema(path)

    # ---- Internal methods ----

    async def _create_node(self, entity: UniNode) -> None:
        # Persist one entity in its own transaction, running lifecycle hooks.
        run_hooks(entity, _BEFORE_CREATE)
        label = entity.__class__.__label__
        props = entity.to_properties()
        props_str = ", ".join(f"{k}: ${k}" for k in props)
        cypher = f"CREATE (n:{label} {{{props_str}}}) RETURN id(n) as vid"
        async with await self._db_session.tx() as tx:
            results = await tx.query(cypher, props)
            await tx.commit()
        if results:
            vid = results[0]["vid"]
            entity._attach_session(self, vid)
            self._identity_map[(label, vid)] = entity
        run_hooks(entity, _AFTER_CREATE)
        entity._mark_clean()

    async def _create_node_in_tx(
        self, entity: UniNode, tx: uni_db.AsyncTransaction
    ) -> None:
        # Like _create_node but inside a caller-owned transaction; the
        # caller is responsible for commit and for marking entities clean.
        run_hooks(entity, _BEFORE_CREATE)
        label = entity.__class__.__label__
        props = entity.to_properties()
        props_str = ", ".join(f"{k}: ${k}" for k in props)
        cypher = f"CREATE (n:{label} {{{props_str}}}) RETURN id(n) as vid"
        results = await tx.query(cypher, props)
        if results:
            vid = results[0]["vid"]
            entity._attach_session(self, vid)
            self._identity_map[(label, vid)] = entity
        run_hooks(entity, _AFTER_CREATE)

    async def _create_edge_in_tx(
        self,
        source: UniNode,
        edge_type: str,
        target: UniNode,
        properties: UniEdge | None,
        tx: uni_db.AsyncTransaction,
    ) -> None:
        # Create an edge inside a caller-owned transaction.
        props = properties.to_properties() if properties else {}
        src_label = source.__class__.__label__
        dst_label = target.__class__.__label__
        props_str = ", ".join(f"{k}: ${k}" for k in props)
        if props_str:
            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[:{edge_type} {{{props_str}}}]->(b)"
        else:
            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[:{edge_type}]->(b)"
        params = {"src": source._vid, "dst": target._vid, **props}
        await tx.query(cypher, params)

    async def _update_node(self, entity: UniNode) -> None:
        # Write only the entity's dirty fields back to the database.
        run_hooks(entity, _BEFORE_UPDATE)
        label = entity.__class__.__label__
        try:
            hints = get_type_hints(type(entity))
        except Exception:
            hints = {}
        dirty_props = {}
        for name in entity._dirty:
            value = getattr(entity, name)
            if name in hints:
                value = python_to_db_value(value, hints[name])
            dirty_props[name] = value
        if not dirty_props:
            return
        set_clause = ", ".join(f"n.{k} = ${k}" for k in dirty_props)
        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid SET {set_clause}"
        params = {"vid": entity._vid, **dirty_props}
        async with await self._db_session.tx() as tx:
            await tx.execute(cypher, params)
            await tx.commit()
        run_hooks(entity, _AFTER_UPDATE)
        entity._mark_clean()

    async def _delete_node(self, entity: UniNode) -> None:
        # DETACH DELETE the node, evict it from the identity map, and
        # detach the entity from this session.
        run_hooks(entity, _BEFORE_DELETE)
        label = entity.__class__.__label__
        vid = entity._vid
        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid DETACH DELETE n"
        async with await self._db_session.tx() as tx:
            await tx.execute(cypher, {"vid": vid})
            await tx.commit()
        if vid is not None and (label, vid) in self._identity_map:
            del self._identity_map[(label, vid)]
        entity._vid = None
        entity._uid = None
        entity._session = None
        run_hooks(entity, _AFTER_DELETE)

    def _result_to_model(
        self,
        data: dict[str, Any],
        model: type[NodeT],
    ) -> NodeT | None:
        """Convert a query result row to a model instance (sync — pure dict processing)."""
        if not data:
            return None

        data = dict(data)
        data = run_class_hooks(model, _BEFORE_LOAD, data) or data

        # The vertex id may arrive under any of these keys depending on
        # how the row was produced.
        vid = data.pop("_id", None)
        if vid is None:
            vid = data.pop("_vid", None)
        if vid is None:
            vid = data.pop("vid", None)
        if vid is not None and not isinstance(vid, int):
            vid = int(vid)
        data.pop("_label", None)

        try:
            instance = cast(
                NodeT,
                model.from_properties(data, vid=vid, session=self),
            )
        except Exception:
            # Validation failure — treat the row as unmappable.
            return None

        if vid is not None:
            # Identity-map wins: an already-tracked instance is returned
            # instead of the freshly built one.
            existing = self._identity_map.get((model.__label__, vid))
            if existing is not None:
                return cast(NodeT, existing)
            self._identity_map[(model.__label__, vid)] = instance

        run_hooks(instance, _AFTER_LOAD)
        return instance

    def _load_relationship(
        self,
        entity: UniNode,
        descriptor: RelationshipDescriptor[Any],
    ) -> list[UniNode] | UniNode | None:
        """Sync relationship loading — raises error for async session.
        Use _async_load_relationship instead."""
        raise SessionError(
            "Cannot synchronously load relationships in an async session. "
            "Use eager_load() or access relationships via async queries."
        )

    async def _async_eager_load_relationships(
        self,
        entities: list[NodeT],
        relationships: list[str],
    ) -> None:
        """Eager load relationships for a list of entities (async)."""
        if not entities:
            return

        model = type(entities[0])
        rel_configs = model.get_relationship_fields()

        for rel_name in relationships:
            if rel_name not in rel_configs:
                continue

            config = rel_configs[rel_name]
            label = model.__label__
            vids = [e._vid for e in entities if e._vid is not None]

            if not vids:
                continue

            # One query per relationship for the whole entity batch.
            pattern = _edge_pattern(config.edge_type, config.direction)
            cypher = (
                f"MATCH (a:{label}){pattern}(b) WHERE id(a) IN $vids "
                f"RETURN id(a) as src_vid, properties(b) AS _props, id(b) AS _vid, labels(b) AS _labels"
            )
            results = await self._db_session.query(cypher, {"vids": vids})

            # Bucket related node dicts by their source vertex id.
            by_source: dict[int, list[Any]] = {}
            for raw_row in results:
                row = raw_row.to_dict()
                src_vid = row["src_vid"]
                node_data = _row_to_node_dict(row)
                if node_data is None:
                    continue
                if src_vid not in by_source:
                    by_source[src_vid] = []
                by_source[src_vid].append(node_data)

            for entity in entities:
                if entity._vid in by_source:
                    related = by_source[entity._vid]
                    cache_attr = f"_rel_cache_{rel_name}"
                    setattr(entity, cache_attr, related)

Async session for interacting with the graph database.

Mirrors UniSession with async methods. Uses AsyncUni.

Example:

from uni_db import AsyncUni from uni_pydantic import AsyncUniSession

db = await AsyncUni.open("./my_graph") async with AsyncUniSession(db) as session: ... session.register(Person) ... await session.sync_schema() ... alice = Person(name="Alice", age=30) ... session.add(alice) ... await session.commit()

AsyncUniSession(db: AsyncUni)
153    def __init__(self, db: uni_db.AsyncUni) -> None:
154        self._db = db
155        self._db_session = db.session()
156        self._schema_gen = SchemaGenerator()
157        self._identity_map: WeakValueDictionary[tuple[str, int], UniNode] = (
158            WeakValueDictionary()
159        )
160        self._pending_new: list[UniNode] = []
161        self._pending_delete: list[UniNode] = []
def close(self) -> None:
174    def close(self) -> None:
175        """Close the session and clear pending state."""
176        self._pending_new.clear()
177        self._pending_delete.clear()

Close the session and clear pending state.

db: AsyncUni
179    @property
180    def db(self) -> uni_db.AsyncUni:
181        """Access the underlying uni_db.AsyncUni for low-level operations."""
182        return self._db

Access the underlying uni_db.AsyncUni for low-level operations.

async def locy(self, program: str, params: dict[str, typing.Any] | None = None) -> Any:
184    async def locy(self, program: str, params: dict[str, Any] | None = None) -> Any:
185        """
186        Evaluate a Locy program and return derived facts, stats, and warnings.
187
188        Delegates to the underlying ``uni_db.AsyncSession.locy()``.
189        """
190        return await self._db_session.locy(program, params)

Evaluate a Locy program and return derived facts, stats, and warnings.

Delegates to the underlying uni_db.AsyncSession.locy().

def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
192    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
193        """Register model classes with the session (sync)."""
194        self._schema_gen.register(*models)

Register model classes with the session (sync).

async def sync_schema(self) -> None:
196    async def sync_schema(self) -> None:
197        """Synchronize database schema with registered models."""
198        await self._schema_gen.async_apply_to_database(self._db)

Synchronize database schema with registered models.

def query( self, model: type[~NodeT]) -> AsyncQueryBuilder[~NodeT]:
200    def query(self, model: type[NodeT]) -> AsyncQueryBuilder[NodeT]:
201        """Create an async query builder for the given model."""
202        return AsyncQueryBuilder(self, model)

Create an async query builder for the given model.

def add(self, entity: UniNode) -> None:
204    def add(self, entity: UniNode) -> None:
205        """Add a new entity to be persisted (sync — just collects)."""
206        if entity.is_persisted:
207            raise SessionError(f"Entity {entity!r} is already persisted")
208        entity._session = self
209        self._pending_new.append(entity)

Add a new entity to be persisted (sync — just collects).

def add_all(self, entities: Sequence[UniNode]) -> None:
211    def add_all(self, entities: Sequence[UniNode]) -> None:
212        """Add multiple entities (sync — just collects)."""
213        for entity in entities:
214            self.add(entity)

Add multiple entities (sync — just collects).

def delete(self, entity: UniNode) -> None:
216    def delete(self, entity: UniNode) -> None:
217        """Mark an entity for deletion (sync — just collects)."""
218        if not entity.is_persisted:
219            raise NotPersisted(entity)
220        self._pending_delete.append(entity)

Mark an entity for deletion (sync — just collects).

async def get(self, model: type[~NodeT], vid: int | None = None, uid: str | None = None, **kwargs: Any) -> Optional[~NodeT]:
222    async def get(
223        self,
224        model: type[NodeT],
225        vid: int | None = None,
226        uid: str | None = None,
227        **kwargs: Any,
228    ) -> NodeT | None:
229        """Get an entity by ID or unique properties."""
230        if vid is not None:
231            cached = self._identity_map.get((model.__label__, vid))
232            if cached is not None:
233                return cached  # type: ignore[return-value]
234
235        label = model.__label__
236        params: dict[str, Any] = {}
237
238        if vid is not None:
239            cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
240            params["vid"] = vid
241        elif uid is not None:
242            cypher = f"MATCH (n:{label}) WHERE n._uid = $uid RETURN {_NODE_RETURN}"
243            params["uid"] = uid
244        elif kwargs:
245            for k in kwargs:
246                _validate_property(k, model)
247            conditions = [f"n.{k} = ${k}" for k in kwargs]
248            cypher = f"MATCH (n:{label}) WHERE {' AND '.join(conditions)} RETURN {_NODE_RETURN} LIMIT 1"
249            params.update(kwargs)
250        else:
251            raise ValueError("Must provide vid, uid, or property filters")
252
253        results = await self._db_session.query(cypher, params)
254        if not results:
255            return None
256
257        node_data = _row_to_node_dict(results[0].to_dict())
258        if node_data is None:
259            return None
260        return self._result_to_model(node_data, model)

Get an entity by ID or unique properties.

async def refresh(self, entity: UniNode) -> None:
262    async def refresh(self, entity: UniNode) -> None:
263        """Refresh an entity's properties from the database."""
264        if not entity.is_persisted:
265            raise NotPersisted(entity)
266
267        label = entity.__class__.__label__
268        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
269        results = await self._db_session.query(cypher, {"vid": entity._vid})
270
271        if not results:
272            raise SessionError(f"Entity with vid={entity._vid} no longer exists")
273
274        props = _row_to_node_dict(results[0].to_dict())
275        if props is None:
276            raise SessionError(f"Entity with vid={entity._vid} no longer exists")
277        try:
278            hints = get_type_hints(type(entity))
279        except Exception:
280            hints = {}
281
282        for field_name in entity.get_property_fields():
283            if field_name in props:
284                value = props[field_name]
285                if field_name in hints:
286                    value = db_to_python_value(value, hints[field_name])
287                setattr(entity, field_name, value)
288
289        entity._mark_clean()

Refresh an entity's properties from the database.

async def commit(self) -> None:
291    async def commit(self) -> None:
292        """Commit all pending changes."""
293        for entity in self._pending_new:
294            await self._create_node(entity)
295
296        for (label, vid), entity in list(self._identity_map.items()):
297            if entity.is_dirty and entity.is_persisted:
298                await self._update_node(entity)
299
300        for entity in self._pending_delete:
301            await self._delete_node(entity)
302
303        await self._db.flush()
304        self._pending_new.clear()
305        self._pending_delete.clear()

Commit all pending changes.

async def rollback(self) -> None:
307    async def rollback(self) -> None:
308        """Discard all pending changes."""
309        for entity in self._pending_new:
310            entity._session = None
311        self._pending_new.clear()
312        self._pending_delete.clear()
313        for entity in list(self._identity_map.values()):
314            if entity.is_dirty:
315                await self.refresh(entity)

Discard all pending changes.

async def transaction(self) -> AsyncUniTransaction:
317    async def transaction(self) -> AsyncUniTransaction:
318        """Create an async transaction. Use as `async with session.transaction() as tx:`."""
319        return AsyncUniTransaction(self)

Create an async transaction. Use as async with session.transaction() as tx:.

async def cypher(self, query: str, params: dict[str, typing.Any] | None = None, result_type: type[~NodeT] | None = None) -> list[~NodeT] | list[dict[str, typing.Any]]:
321    async def cypher(
322        self,
323        query: str,
324        params: dict[str, Any] | None = None,
325        result_type: type[NodeT] | None = None,
326    ) -> list[NodeT] | list[dict[str, Any]]:
327        """Execute a raw Cypher query."""
328        results = await self._db_session.query(query, params)
329
330        if result_type is None:
331            return [r.to_dict() for r in results]
332
333        mapped = []
334        for raw_row in results:
335            row = raw_row.to_dict()
336            for key, value in row.items():
337                if isinstance(value, dict):
338                    if "_id" in value and "_label" in value:
339                        instance = self._result_to_model(value, result_type)
340                        if instance:
341                            mapped.append(instance)
342                            break
343                    elif "_label" in value:
344                        label = value["_label"]
345                        if label in self._schema_gen._node_models:
346                            model = self._schema_gen._node_models[label]
347                            instance = self._result_to_model(value, model)
348                            if instance:
349                                mapped.append(instance)
350                                break
351            else:
352                first_value = next(iter(row.values()), None)
353                if isinstance(first_value, dict):
354                    instance = self._result_to_model(first_value, result_type)
355                    if instance:
356                        mapped.append(instance)
357
358        return mapped

Execute a raw Cypher query.

async def create_edge(self, source: UniNode, edge_type: str, target: UniNode, properties: dict[str, typing.Any] | UniEdge | None = None) -> None:
360    async def create_edge(
361        self,
362        source: UniNode,
363        edge_type: str,
364        target: UniNode,
365        properties: dict[str, Any] | UniEdge | None = None,
366    ) -> None:
367        """Create an edge between two nodes."""
368        src_vid, dst_vid, src_label, dst_label = UniSession._validate_edge_endpoints(
369            source, target
370        )
371        props = UniSession._normalize_edge_properties(properties)
372
373        props_str = ", ".join(f"{k}: ${k}" for k in props)
374        if props_str:
375            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type} {{{props_str}}}]->(b)"
376        else:
377            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type}]->(b)"
378
379        async with await self._db_session.tx() as tx:
380            await tx.execute(cypher, {"src": src_vid, "dst": dst_vid, **props})
381            await tx.commit()

Create an edge between two nodes.

async def delete_edge( self, source: UniNode, edge_type: str, target: UniNode) -> int:
383    async def delete_edge(
384        self, source: UniNode, edge_type: str, target: UniNode
385    ) -> int:
386        """Delete edges between two nodes. Returns the number of deleted edges."""
387        src_vid, dst_vid, src_label, dst_label = UniSession._validate_edge_endpoints(
388            source, target
389        )
390        cypher = (
391            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
392            f"WHERE a._vid = $src AND b._vid = $dst "
393            f"DELETE r RETURN count(r) as count"
394        )
395        async with await self._db_session.tx() as tx:
396            results = await tx.query(cypher, {"src": src_vid, "dst": dst_vid})
397            await tx.commit()
398        return cast(int, results[0]["count"]) if results else 0

Delete edges between two nodes. Returns the number of deleted edges.

async def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
400    async def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
401        """Bulk-add entities using bulk_writer."""
402        if not entities:
403            return []
404
405        by_label: dict[str, list[UniNode]] = {}
406        for entity in entities:
407            label = entity.__class__.__label__
408            if label not in by_label:
409                by_label[label] = []
410            by_label[label].append(entity)
411
412        all_vids: list[int] = []
413        try:
414            for label, group in by_label.items():
415                for entity in group:
416                    run_hooks(entity, _BEFORE_CREATE)
417                prop_dicts = [e.to_properties() for e in group]
418                tx = await self._db_session.tx()
419                async with await tx.bulk_writer().build() as bw:
420                    vids = await bw.insert_vertices(label, prop_dicts)
421                    await bw.commit()
422                await tx.commit()
423                for entity, vid in zip(group, vids):
424                    entity._attach_session(self, vid)
425                    self._identity_map[(label, vid)] = entity
426                    run_hooks(entity, _AFTER_CREATE)
427                    entity._mark_clean()
428                all_vids.extend(vids)
429        except Exception as e:
430            raise BulkLoadError(f"Bulk insert failed: {e}") from e
431
432        return all_vids

Bulk-add entities using bulk_writer.

async def explain(self, cypher: str) -> Any:
434    async def explain(self, cypher: str) -> Any:
435        """Get the query execution plan."""
436        return await self._db_session.explain(cypher)

Get the query execution plan.

async def profile(self, cypher: str) -> Any:
438    async def profile(self, cypher: str) -> Any:
439        """Run the query with profiling and return results + stats."""
440        return await self._db_session.profile(cypher)

Run the query with profiling and return results + stats.

async def save_schema(self, path: str) -> None:
442    async def save_schema(self, path: str) -> None:
443        """Save the database schema to a file."""
444        await self._db.save_schema(path)

Save the database schema to a file.

async def load_schema(self, path: str) -> None:
446    async def load_schema(self, path: str) -> None:
447        """Load a database schema from a file."""
448        await self._db.load_schema(path)

Load a database schema from a file.

class AsyncUniTransaction:
 54class AsyncUniTransaction:
 55    """Async transaction context for atomic operations."""
 56
 57    def __init__(self, session: AsyncUniSession) -> None:
 58        self._session = session
 59        self._tx: uni_db.AsyncTransaction | None = None
 60        self._pending_nodes: list[UniNode] = []
 61        self._pending_edges: list[tuple[UniNode, str, UniNode, UniEdge | None]] = []
 62        self._committed = False
 63        self._rolled_back = False
 64
 65    async def __aenter__(self) -> AsyncUniTransaction:
 66        self._tx = await self._session._db_session.tx()
 67        return self
 68
 69    async def __aexit__(
 70        self,
 71        exc_type: type[BaseException] | None,
 72        exc_val: BaseException | None,
 73        exc_tb: TracebackType | None,
 74    ) -> None:
 75        if exc_type is not None:
 76            await self.rollback()
 77            return
 78        if not self._committed and not self._rolled_back:
 79            await self.commit()
 80
 81    def add(self, entity: UniNode) -> None:
 82        """Add a node to be created in this transaction (sync — just collects)."""
 83        self._pending_nodes.append(entity)
 84
 85    def create_edge(
 86        self,
 87        source: UniNode,
 88        edge_type: str,
 89        target: UniNode,
 90        properties: UniEdge | None = None,
 91    ) -> None:
 92        """Create an edge between two nodes in this transaction (sync — just collects)."""
 93        if not source.is_persisted:
 94            raise NotPersisted(source)
 95        if not target.is_persisted:
 96            raise NotPersisted(target)
 97        self._pending_edges.append((source, edge_type, target, properties))
 98
 99    async def commit(self) -> None:
100        """Commit the transaction."""
101        if self._committed:
102            raise TransactionError("Transaction already committed")
103        if self._rolled_back:
104            raise TransactionError("Transaction already rolled back")
105        if self._tx is None:
106            raise TransactionError("Transaction not started")
107
108        try:
109            for node in self._pending_nodes:
110                await self._session._create_node_in_tx(node, self._tx)
111            for source, edge_type, target, props in self._pending_edges:
112                await self._session._create_edge_in_tx(
113                    source, edge_type, target, props, self._tx
114                )
115            await self._tx.commit()
116            self._committed = True
117            for node in self._pending_nodes:
118                node._mark_clean()
119        except Exception as e:
120            await self.rollback()
121            raise TransactionError(f"Commit failed: {e}") from e
122
123    async def rollback(self) -> None:
124        """Rollback the transaction."""
125        if self._rolled_back:
126            return
127        if self._tx is not None:
128            await self._tx.rollback()
129        self._rolled_back = True
130        self._pending_nodes.clear()
131        self._pending_edges.clear()

Async transaction context for atomic operations.

AsyncUniTransaction(session: AsyncUniSession)
57    def __init__(self, session: AsyncUniSession) -> None:
58        self._session = session
59        self._tx: uni_db.AsyncTransaction | None = None
60        self._pending_nodes: list[UniNode] = []
61        self._pending_edges: list[tuple[UniNode, str, UniNode, UniEdge | None]] = []
62        self._committed = False
63        self._rolled_back = False
def add(self, entity: UniNode) -> None:
81    def add(self, entity: UniNode) -> None:
82        """Add a node to be created in this transaction (sync — just collects)."""
83        self._pending_nodes.append(entity)

Add a node to be created in this transaction (sync — just collects).

def create_edge( self, source: UniNode, edge_type: str, target: UniNode, properties: UniEdge | None = None) -> None:
85    def create_edge(
86        self,
87        source: UniNode,
88        edge_type: str,
89        target: UniNode,
90        properties: UniEdge | None = None,
91    ) -> None:
92        """Create an edge between two nodes in this transaction (sync — just collects)."""
93        if not source.is_persisted:
94            raise NotPersisted(source)
95        if not target.is_persisted:
96            raise NotPersisted(target)
97        self._pending_edges.append((source, edge_type, target, properties))

Create an edge between two nodes in this transaction (sync — just collects).

async def commit(self) -> None:
 99    async def commit(self) -> None:
100        """Commit the transaction."""
101        if self._committed:
102            raise TransactionError("Transaction already committed")
103        if self._rolled_back:
104            raise TransactionError("Transaction already rolled back")
105        if self._tx is None:
106            raise TransactionError("Transaction not started")
107
108        try:
109            for node in self._pending_nodes:
110                await self._session._create_node_in_tx(node, self._tx)
111            for source, edge_type, target, props in self._pending_edges:
112                await self._session._create_edge_in_tx(
113                    source, edge_type, target, props, self._tx
114                )
115            await self._tx.commit()
116            self._committed = True
117            for node in self._pending_nodes:
118                node._mark_clean()
119        except Exception as e:
120            await self.rollback()
121            raise TransactionError(f"Commit failed: {e}") from e

Commit the transaction.

async def rollback(self) -> None:
123    async def rollback(self) -> None:
124        """Rollback the transaction."""
125        if self._rolled_back:
126            return
127        if self._tx is not None:
128            await self._tx.rollback()
129        self._rolled_back = True
130        self._pending_nodes.clear()
131        self._pending_edges.clear()

Rollback the transaction.

def Field( default: Any = Ellipsis, *, default_factory: Callable[[], typing.Any] | None = None, alias: str | None = None, title: str | None = None, description: str | None = None, examples: list[typing.Any] | None = None, exclude: bool = False, json_schema_extra: dict[str, typing.Any] | None = None, index: Optional[Literal['btree', 'hash', 'fulltext', 'vector']] = None, unique: bool = False, tokenizer: str | None = None, metric: Optional[Literal['l2', 'cosine', 'dot']] = None, generated: str | None = None) -> Any:
 69def Field(
 70    default: Any = ...,
 71    *,
 72    default_factory: Callable[[], Any] | None = None,
 73    alias: str | None = None,
 74    title: str | None = None,
 75    description: str | None = None,
 76    examples: list[Any] | None = None,
 77    exclude: bool = False,
 78    json_schema_extra: dict[str, Any] | None = None,
 79    # Uni-specific options
 80    index: IndexType | None = None,
 81    unique: bool = False,
 82    tokenizer: str | None = None,
 83    metric: VectorMetric | None = None,
 84    generated: str | None = None,
 85) -> Any:
 86    """
 87    Create a field with uni-pydantic configuration.
 88
 89    This extends Pydantic's Field with graph database options.
 90
 91    Args:
 92        default: Default value for the field.
 93        default_factory: Factory function for default value.
 94        alias: Field alias for serialization.
 95        title: Human-readable title.
 96        description: Field description.
 97        examples: Example values.
 98        exclude: Exclude from serialization.
 99        json_schema_extra: Extra JSON schema properties.
100        index: Index type ("btree", "hash", "fulltext", "vector").
101        unique: Whether to create a unique constraint.
102        tokenizer: Tokenizer for fulltext index (default: "standard").
103        metric: Distance metric for vector index ("l2", "cosine", "dot").
104        generated: Expression for generated/computed property.
105
106    Returns:
107        A Pydantic FieldInfo with uni-pydantic metadata attached.
108
109    Examples:
110        >>> class Person(UniNode):
111        ...     name: str = Field(index="btree")
112        ...     email: str = Field(unique=True)
113        ...     bio: str = Field(index="fulltext", tokenizer="standard")
114        ...     embedding: Vector[768] = Field(metric="cosine")
115    """
116    # Default tokenizer for fulltext indexes
117    if index == "fulltext" and tokenizer is None:
118        tokenizer = "standard"
119
120    # Store uni config in json_schema_extra
121    uni_config = FieldConfig(
122        index=index,
123        unique=unique,
124        tokenizer=tokenizer,
125        metric=metric,
126        generated=generated,
127        default=default,
128        default_factory=default_factory,
129        alias=alias,
130        title=title,
131        description=description,
132        examples=examples,
133        exclude=exclude,
134        json_schema_extra=json_schema_extra,
135    )
136
137    # Merge uni config into json_schema_extra
138    extra = json_schema_extra or {}
139    extra["uni_config"] = uni_config
140
141    # Create Pydantic FieldInfo
142    from pydantic.fields import FieldInfo as PydanticFieldInfo
143
144    if default_factory is not None:
145        return PydanticFieldInfo(
146            default_factory=default_factory,
147            alias=alias,
148            title=title,
149            description=description,
150            examples=examples,
151            exclude=exclude,
152            json_schema_extra=extra,
153        )
154    elif default is not ...:
155        return PydanticFieldInfo(
156            default=default,
157            alias=alias,
158            title=title,
159            description=description,
160            examples=examples,
161            exclude=exclude,
162            json_schema_extra=extra,
163        )
164    else:
165        return PydanticFieldInfo(
166            alias=alias,
167            title=title,
168            description=description,
169            examples=examples,
170            exclude=exclude,
171            json_schema_extra=extra,
172        )

Create a field with uni-pydantic configuration.

This extends Pydantic's Field with graph database options.

Args: default: Default value for the field. default_factory: Factory function for default value. alias: Field alias for serialization. title: Human-readable title. description: Field description. examples: Example values. exclude: Exclude from serialization. json_schema_extra: Extra JSON schema properties. index: Index type ("btree", "hash", "fulltext", "vector"). unique: Whether to create a unique constraint. tokenizer: Tokenizer for fulltext index (default: "standard"). metric: Distance metric for vector index ("l2", "cosine", "dot"). generated: Expression for generated/computed property.

Returns: A Pydantic FieldInfo with uni-pydantic metadata attached.

Examples:

class Person(UniNode): ... name: str = Field(index="btree") ... email: str = Field(unique=True) ... bio: str = Field(index="fulltext", tokenizer="standard") ... embedding: Vector[768] = Field(metric="cosine")

@dataclass
class FieldConfig:
@dataclass
class FieldConfig:
    """Configuration for a uni-pydantic field.

    Bundles the Uni-specific index options together with the plain Pydantic
    field options so that Field() can stash everything in one object inside
    json_schema_extra.
    """

    # Index configuration
    index: IndexType | None = None
    unique: bool = False

    # Fulltext index options
    tokenizer: str | None = None

    # Vector index options
    metric: VectorMetric | None = None

    # Generated/computed property expression
    generated: str | None = None

    # Pydantic field options (passed through). Ellipsis is Pydantic's
    # "no default" sentinel; it is immutable, so it can be a plain default
    # (no default_factory indirection needed).
    default: Any = ...
    default_factory: Callable[[], Any] | None = None
    alias: str | None = None
    title: str | None = None
    description: str | None = None
    examples: list[Any] | None = None
    exclude: bool = False
    json_schema_extra: dict[str, Any] | None = None

Configuration for a uni-pydantic field.

FieldConfig( index: Optional[Literal['btree', 'hash', 'fulltext', 'vector']] = None, unique: bool = False, tokenizer: str | None = None, metric: Optional[Literal['l2', 'cosine', 'dot']] = None, generated: str | None = None, default: Any = <factory>, default_factory: Callable[[], typing.Any] | None = None, alias: str | None = None, title: str | None = None, description: str | None = None, examples: list[typing.Any] | None = None, exclude: bool = False, json_schema_extra: dict[str, typing.Any] | None = None)
index: Optional[Literal['btree', 'hash', 'fulltext', 'vector']] = None
unique: bool = False
tokenizer: str | None = None
metric: Optional[Literal['l2', 'cosine', 'dot']] = None
generated: str | None = None
default: Any
default_factory: Callable[[], typing.Any] | None = None
alias: str | None = None
title: str | None = None
description: str | None = None
examples: list[typing.Any] | None = None
exclude: bool = False
json_schema_extra: dict[str, typing.Any] | None = None
def Relationship( edge_type: str, *, direction: Literal['outgoing', 'incoming', 'both'] = 'outgoing', edge_model: type[UniEdge] | None = None, eager: bool = False, cascade_delete: bool = False) -> Any:
def Relationship(
    edge_type: str,
    *,
    direction: Direction = "outgoing",
    edge_model: type[UniEdge] | None = None,
    eager: bool = False,
    cascade_delete: bool = False,
) -> Any:
    """
    Declare a relationship field pointing at another node type.

    Relationships load lazily on first access by default; pass eager=True
    or use query.eager_load() to fetch them together with the parent query.

    Args:
        edge_type: Edge type name, e.g. "FRIEND_OF" or "WORKS_AT".
        direction: "outgoing" (default) to follow edges from this node,
            "incoming" to follow edges to it, or "both".
        edge_model: Optional UniEdge subclass describing typed edge properties.
        eager: Eager-load this relationship by default.
        cascade_delete: Delete related edges when this node is deleted.

    Returns:
        A RelationshipDescriptor marker processed during model creation.

    Examples:
        >>> class Person(UniNode):
        ...     follows: list["Person"] = Relationship("FOLLOWS")
        ...     followers: list["Person"] = Relationship("FOLLOWS", direction="incoming")
        ...     manager: "Person | None" = Relationship("REPORTS_TO")
        ...     friendships: list[tuple["Person", FriendshipEdge]] = Relationship(
        ...         "FRIEND_OF",
        ...         edge_model=FriendshipEdge
        ...     )
    """
    # The marker is swapped for a real descriptor by the model metaclass.
    return _RelationshipMarker(
        RelationshipConfig(
            edge_type=edge_type,
            direction=direction,
            edge_model=edge_model,
            eager=eager,
            cascade_delete=cascade_delete,
        )
    )

Declare a relationship to another node type.

Relationships are lazy-loaded by default. Use eager=True or query.eager_load() to load them with the parent query.

Args: edge_type: The edge type name (e.g., "FRIEND_OF", "WORKS_AT"). direction: Relationship direction: - "outgoing": Follow edges from this node (default) - "incoming": Follow edges to this node - "both": Follow edges in both directions edge_model: Optional UniEdge subclass for typed edge properties. eager: Whether to eager-load this relationship by default. cascade_delete: Whether to delete related edges when this node is deleted.

Returns: A RelationshipDescriptor that will be processed during model creation.

Examples:

class Person(UniNode): ... # Outgoing relationship (default) ... follows: list["Person"] = Relationship("FOLLOWS") ... ... # Incoming relationship ... followers: list["Person"] = Relationship("FOLLOWS", direction="incoming") ... ... # Single optional relationship ... manager: "Person | None" = Relationship("REPORTS_TO") ... ... # Relationship with edge properties ... friendships: list[tuple["Person", FriendshipEdge]] = Relationship( ... "FRIEND_OF", ... edge_model=FriendshipEdge ... )

@dataclass
class RelationshipConfig:
@dataclass
class RelationshipConfig:
    """Configuration for a relationship field, as captured by Relationship()."""

    # Edge type name in the graph (e.g. "FRIEND_OF").
    edge_type: str
    # Traversal direction: "outgoing", "incoming", or "both".
    direction: Direction = "outgoing"
    # Optional UniEdge subclass for typed edge properties.
    edge_model: type[UniEdge] | None = None
    # Whether to eager-load this relationship by default.
    eager: bool = False
    # Whether related edges are deleted when the owning node is deleted.
    cascade_delete: bool = False

Configuration for a relationship field.

RelationshipConfig( edge_type: str, direction: Literal['outgoing', 'incoming', 'both'] = 'outgoing', edge_model: type[UniEdge] | None = None, eager: bool = False, cascade_delete: bool = False)
edge_type: str
direction: Literal['outgoing', 'incoming', 'both'] = 'outgoing'
edge_model: type[UniEdge] | None = None
eager: bool = False
cascade_delete: bool = False
class RelationshipDescriptor(typing.Generic[~NodeT]):
class RelationshipDescriptor(Generic[NodeT]):
    """
    Descriptor for relationship fields that enables lazy loading.

    When accessed on an instance, it returns the related nodes (loading
    them via the attached session on first access and caching the result
    on the instance). When accessed on the class, it returns the
    descriptor itself for query building.
    """

    def __init__(
        self,
        config: RelationshipConfig,
        field_name: str,
        target_type: type[NodeT] | str | None = None,
        is_list: bool = True,
    ) -> None:
        self.config = config
        self.field_name = field_name
        # May be a forward-reference string until models are resolved.
        self.target_type = target_type
        # True for list-valued relationships, False for single/optional ones.
        self.is_list = is_list
        # Per-instance attribute name used to cache the loaded value.
        self._cache_attr = f"_rel_cache_{field_name}"

    def __set_name__(self, owner: type, name: str) -> None:
        # Keep field_name/cache attr in sync with the attribute the
        # descriptor was actually assigned to on the owning class.
        self.field_name = name
        self._cache_attr = f"_rel_cache_{name}"

    @overload
    def __get__(
        self, obj: None, objtype: type[NodeT]
    ) -> RelationshipDescriptor[NodeT]: ...

    @overload
    def __get__(
        self, obj: NodeT, objtype: type[NodeT] | None = None
    ) -> list[NodeT] | NodeT | None: ...

    def __get__(
        self, obj: NodeT | None, objtype: type[NodeT] | None = None
    ) -> RelationshipDescriptor[NodeT] | list[NodeT] | NodeT | None:
        if obj is None:
            # Class-level access returns the descriptor
            return self

        # Instance-level access - check cache first (set here or by __set__
        # during eager loading) before touching the session.
        if hasattr(obj, self._cache_attr):
            cached = getattr(obj, self._cache_attr)
            return cast("list[NodeT] | NodeT | None", cached)

        # Check if we have a session for lazy loading
        session = getattr(obj, "_session", None)
        if session is None:
            from .exceptions import LazyLoadError

            raise LazyLoadError(
                self.field_name,
                "No session attached. Use session.get() or enable eager loading.",
            )

        # Lazy load the relationship
        result = session._load_relationship(obj, self)

        # Cache the result so subsequent accesses skip the database.
        setattr(obj, self._cache_attr, result)
        return cast("list[NodeT] | NodeT | None", result)

    def __set__(self, obj: NodeT, value: list[NodeT] | NodeT | None) -> None:
        # Allow setting the cached value (e.g. during eager loading);
        # assignment never writes through to the database.
        setattr(obj, self._cache_attr, value)

    def __repr__(self) -> str:
        return f"Relationship({self.config.edge_type!r}, direction={self.config.direction!r})"

Descriptor for relationship fields that enables lazy loading.

When accessed on an instance, it returns the related nodes. When accessed on the class, it returns the descriptor for query building.

RelationshipDescriptor( config: RelationshipConfig, field_name: str, target_type: type[~NodeT] | str | None = None, is_list: bool = True)
204    def __init__(
205        self,
206        config: RelationshipConfig,
207        field_name: str,
208        target_type: type[NodeT] | str | None = None,
209        is_list: bool = True,
210    ) -> None:
211        self.config = config
212        self.field_name = field_name
213        self.target_type = target_type
214        self.is_list = is_list
215        self._cache_attr = f"_rel_cache_{field_name}"
config
field_name
target_type
is_list
def get_field_config( field_info: pydantic.fields.FieldInfo) -> FieldConfig | None:
def get_field_config(field_info: FieldInfo) -> FieldConfig | None:
    """Return the FieldConfig stashed in ``field_info``, or None if absent."""
    extra = field_info.json_schema_extra
    if not isinstance(extra, dict):
        return None
    candidate = extra.get("uni_config")
    return candidate if isinstance(candidate, FieldConfig) else None

Extract uni-pydantic config from a Pydantic FieldInfo.

IndexType = typing.Literal['btree', 'hash', 'fulltext', 'vector']
Direction = typing.Literal['outgoing', 'incoming', 'both']
VectorMetric = typing.Literal['l2', 'cosine', 'dot']
class Btic:
class Btic:
    """A BTIC temporal interval value for Uni graph database.

    Construct from an ISO 8601-inspired string literal::

        Btic("1985")
        Btic("1985-03/2024-06")
        Btic("~1985")           # approximate certainty
        Btic("2020-03/")        # ongoing (unbounded hi)

    Use as a Pydantic model field type::

        class Event(UniNode):
            when: Btic
    """

    def __init__(self, value: str | object) -> None:
        # The native BTIC type lives in uni_db; without it we cannot parse.
        if _PyBtic is None:
            raise ImportError("uni_db is required for Btic type")
        if isinstance(value, str):
            self._inner = _PyBtic(value)
        elif isinstance(value, _PyBtic):
            # _PyBtic is non-None here (guard above), so no extra check needed.
            self._inner = value
        elif isinstance(value, Btic):
            self._inner = value._inner
        else:
            raise TypeError(f"Expected str or Btic, got {type(value)}")

    @property
    def lo(self) -> int:
        """Lower bound in milliseconds since epoch."""
        return self._inner.lo

    @property
    def hi(self) -> int:
        """Upper bound in milliseconds since epoch."""
        return self._inner.hi

    @property
    def meta(self) -> int:
        """Raw 64-bit metadata word."""
        return self._inner.meta

    @property
    def lo_granularity(self) -> str:
        """Lower bound granularity name."""
        return self._inner.lo_granularity

    @property
    def hi_granularity(self) -> str:
        """Upper bound granularity name."""
        return self._inner.hi_granularity

    @property
    def lo_certainty(self) -> str:
        """Lower bound certainty name."""
        return self._inner.lo_certainty

    @property
    def hi_certainty(self) -> str:
        """Upper bound certainty name."""
        return self._inner.hi_certainty

    @property
    def duration_ms(self) -> int | None:
        """Duration in milliseconds, or None if unbounded."""
        return self._inner.duration_ms

    @property
    def is_instant(self) -> bool:
        """True if the interval is exactly 1 millisecond wide."""
        return self._inner.is_instant

    @property
    def is_unbounded(self) -> bool:
        """True if either bound is infinite."""
        return self._inner.is_unbounded

    @property
    def is_finite(self) -> bool:
        """True if both bounds are finite."""
        return self._inner.is_finite

    def __repr__(self) -> str:
        return f'Btic("{self._inner}")'

    def __str__(self) -> str:
        return str(self._inner)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Btic):
            return self._inner == other._inner
        # NotImplemented (rather than False) lets the other operand's
        # reflected __eq__ participate, per the Python data model.
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self._inner)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Make Btic compatible with Pydantic v2."""

        def validate_btic(v: Any) -> Btic:
            if isinstance(v, Btic):
                return v
            if isinstance(v, str):
                return Btic(v)
            if _PyBtic is not None and isinstance(v, _PyBtic):
                return Btic(v)
            raise TypeError(f"Expected str or Btic, got {type(v)}")

        # Validate with the function above; serialize back to the string form.
        return core_schema.no_info_plain_validator_function(
            validate_btic,
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda v: str(v._inner) if isinstance(v, Btic) else str(v),
                info_arg=False,
            ),
        )

A BTIC temporal interval value for Uni graph database.

Construct from an ISO 8601-inspired string literal::

Btic("1985")
Btic("1985-03/2024-06")
Btic("~1985")           # approximate certainty
Btic("2020-03/")        # ongoing (unbounded hi)

Use as a Pydantic model field type::

class Event(UniNode):
    when: Btic
Btic(value: str | object)
166    def __init__(self, value: str | object) -> None:
167        if _PyBtic is None:
168            raise ImportError("uni_db is required for Btic type")
169        if isinstance(value, str):
170            self._inner = _PyBtic(value)
171        elif _PyBtic is not None and isinstance(value, _PyBtic):
172            self._inner = value
173        elif isinstance(value, Btic):
174            self._inner = value._inner
175        else:
176            raise TypeError(f"Expected str or Btic, got {type(value)}")
lo: int
178    @property
179    def lo(self) -> int:
180        """Lower bound in milliseconds since epoch."""
181        return self._inner.lo

Lower bound in milliseconds since epoch.

hi: int
183    @property
184    def hi(self) -> int:
185        """Upper bound in milliseconds since epoch."""
186        return self._inner.hi

Upper bound in milliseconds since epoch.

meta: int
188    @property
189    def meta(self) -> int:
190        """Raw 64-bit metadata word."""
191        return self._inner.meta

Raw 64-bit metadata word.

lo_granularity: str
193    @property
194    def lo_granularity(self) -> str:
195        """Lower bound granularity name."""
196        return self._inner.lo_granularity

Lower bound granularity name.

hi_granularity: str
198    @property
199    def hi_granularity(self) -> str:
200        """Upper bound granularity name."""
201        return self._inner.hi_granularity

Upper bound granularity name.

lo_certainty: str
203    @property
204    def lo_certainty(self) -> str:
205        """Lower bound certainty name."""
206        return self._inner.lo_certainty

Lower bound certainty name.

hi_certainty: str
208    @property
209    def hi_certainty(self) -> str:
210        """Upper bound certainty name."""
211        return self._inner.hi_certainty

Upper bound certainty name.

duration_ms: int | None
213    @property
214    def duration_ms(self) -> int | None:
215        """Duration in milliseconds, or None if unbounded."""
216        return self._inner.duration_ms

Duration in milliseconds, or None if unbounded.

is_instant: bool
218    @property
219    def is_instant(self) -> bool:
220        """True if the interval is exactly 1 millisecond wide."""
221        return self._inner.is_instant

True if the interval is exactly 1 millisecond wide.

is_unbounded: bool
223    @property
224    def is_unbounded(self) -> bool:
225        """True if either bound is infinite."""
226        return self._inner.is_unbounded

True if either bound is infinite.

is_finite: bool
228    @property
229    def is_finite(self) -> bool:
230        """True if both bounds are finite."""
231        return self._inner.is_finite

True if both bounds are finite.

class Vector(typing.Generic[~N]):
 61class Vector(Generic[N], metaclass=VectorMeta):
 62    """
 63    A vector type with fixed dimensions for embeddings.
 64
 65    Usage:
 66        embedding: Vector[1536]  # 1536-dimensional vector
 67
 68    At runtime, vectors are stored as list[float].
 69    """
 70
 71    __dimensions__: int = 0
 72    __origin__: type | None = None
 73
 74    def __init__(self, values: list[float]) -> None:
 75        expected = self.__class__.__dimensions__
 76        if expected > 0 and len(values) != expected:
 77            raise ValueError(f"Vector expects {expected} dimensions, got {len(values)}")
 78        self._values = values
 79
 80    @property
 81    def values(self) -> list[float]:
 82        return self._values
 83
 84    def __repr__(self) -> str:
 85        dims = self.__class__.__dimensions__
 86        return (
 87            f"Vector[{dims}]({self._values[:3]}...)"
 88            if len(self._values) > 3
 89            else f"Vector[{dims}]({self._values})"
 90        )
 91
 92    def __eq__(self, other: object) -> bool:
 93        if isinstance(other, Vector):
 94            return self._values == other._values
 95        if isinstance(other, list):
 96            return self._values == other
 97        return False
 98
 99    def __len__(self) -> int:
100        return len(self._values)
101
102    def __iter__(self):  # type: ignore[no-untyped-def]
103        return iter(self._values)
104
105    @classmethod
106    def __get_pydantic_core_schema__(
107        cls, source_type: Any, handler: GetCoreSchemaHandler
108    ) -> CoreSchema:
109        """Make Vector compatible with Pydantic v2."""
110        dimensions = getattr(source_type, "__dimensions__", 0)
111        vec_cls = source_type if dimensions > 0 else cls
112
113        def validate_vector(v: Any) -> Vector:  # type: ignore[type-arg]
114            if isinstance(v, Vector):
115                if dimensions > 0 and len(v) != dimensions:
116                    raise ValueError(
117                        f"Vector expects {dimensions} dimensions, got {len(v)}"
118                    )
119                return v
120            if isinstance(v, list):
121                if dimensions > 0 and len(v) != dimensions:
122                    raise ValueError(
123                        f"Vector expects {dimensions} dimensions, got {len(v)}"
124                    )
125                return vec_cls([float(x) for x in v])
126            raise TypeError(f"Expected list or Vector, got {type(v)}")
127
128        return core_schema.no_info_plain_validator_function(
129            validate_vector,
130            serialization=core_schema.plain_serializer_function_ser_schema(
131                lambda v: v.values if isinstance(v, Vector) else list(v),
132                info_arg=False,
133            ),
134        )

A vector type with fixed dimensions for embeddings.

Usage: `embedding: Vector[1536]` — declares a 1536-dimensional vector field.

At runtime, vectors are stored as list[float].

Vector(values: list[float])
74    def __init__(self, values: list[float]) -> None:
75        expected = self.__class__.__dimensions__
76        if expected > 0 and len(values) != expected:
77            raise ValueError(f"Vector expects {expected} dimensions, got {len(values)}")
78        self._values = values
values: list[float]
80    @property
81    def values(self) -> list[float]:
82        return self._values
def python_type_to_uni(type_hint: Any, *, nullable: bool = False) -> tuple[str, bool]:
def python_type_to_uni(type_hint: Any, *, nullable: bool = False) -> tuple[str, bool]:
    """
    Convert a Python type hint to a Uni DataType string.

    Args:
        type_hint: The Python type hint to convert.
        nullable: Whether the field is explicitly nullable.

    Returns:
        Tuple of (uni_data_type, is_nullable)

    Raises:
        TypeMappingError: If the type cannot be mapped.
    """
    # Strip any Annotated metadata before inspecting the hint.
    type_hint, _ = unwrap_annotated(type_hint)

    # T | None: map the inner type and force nullability on.
    optional, inner = is_optional(type_hint)
    if optional:
        mapped, _ = python_type_to_uni(inner)
        return mapped, True

    # Vector[N] → "vector:N"
    dimensions = get_vector_dimensions(type_hint)
    if dimensions is not None:
        return f"vector:{dimensions}", nullable

    # list[...] → "list:<elem>" for scalar elements, JSON for anything else.
    listish, element = is_list_type(type_hint)
    if listish:
        if element in (str, int, float, bool):
            return f"list:{TYPE_MAP.get(element, 'string')}", nullable
        return "json", nullable

    # Direct scalar mapping.
    if type_hint in TYPE_MAP:
        return TYPE_MAP[type_hint], nullable

    # Generic dict types are stored as JSON.
    if get_origin(type_hint) is dict:
        return "json", nullable

    # A bare string is an unresolved forward reference.
    if isinstance(type_hint, str):
        raise TypeMappingError(
            type_hint,
            f"Cannot resolve forward reference {type_hint!r}. "
            "Ensure the referenced class is defined before schema sync.",
        )

    raise TypeMappingError(type_hint)

Convert a Python type hint to a Uni DataType string.

Args: `type_hint` — the Python type hint to convert; `nullable` — whether the field is explicitly nullable.

Returns: Tuple of (uni_data_type, is_nullable)

Raises: TypeMappingError: If the type cannot be mapped.

def uni_to_python_type(uni_type: str) -> type:
def uni_to_python_type(uni_type: str) -> type:
    """
    Convert a Uni DataType string to a Python type.

    Args:
        uni_type: The Uni data type string.

    Returns:
        The corresponding Python type.
    """
    # Parameterized types first: both vectors and lists surface in Python
    # as plain lists (vectors are list[float]).
    if uni_type.startswith(("vector:", "list:")):
        return list

    # Manually constructed so that "string" maps to str (a generated
    # reverse of TYPE_MAP would let bytes overwrite it).
    reverse_map: dict[str, type] = {
        "string": str,
        "int64": int,
        "float64": float,
        "bool": bool,
        "datetime": datetime,
        "date": date,
        "time": time,
        "duration": timedelta,
        "json": dict,
        "btic": Btic,
    }
    return reverse_map.get(uni_type.lower(), str)

Convert a Uni DataType string to a Python type.

Args: uni_type: The Uni data type string.

Returns: The corresponding Python type.

def get_vector_dimensions(type_hint: Any) -> int | None:
137def get_vector_dimensions(type_hint: Any) -> int | None:
138    """Extract vector dimensions from a Vector[N] type hint."""
139    if hasattr(type_hint, "__dimensions__"):
140        dims: int = type_hint.__dimensions__
141        return dims
142    origin = get_origin(type_hint)
143    if origin is Vector:
144        args = get_args(type_hint)
145        if args and isinstance(args[0], int):
146            return args[0]
147    return None

Extract vector dimensions from a Vector[N] type hint.

def is_optional(type_hint: Any) -> tuple[bool, typing.Any]:
271def is_optional(type_hint: Any) -> tuple[bool, Any]:
272    """
273    Check if a type hint is Optional (T | None).
274
275    Returns:
276        Tuple of (is_optional, inner_type)
277    """
278    origin = get_origin(type_hint)
279
280    # Handle Union types (including T | None which is Union[T, None])
281    if origin is Union:
282        args = get_args(type_hint)
283        non_none_args = [arg for arg in args if arg is not type(None)]
284        if len(non_none_args) == 1 and type(None) in args:
285            return True, non_none_args[0]
286
287    # Python 3.10+ uses types.UnionType for X | Y syntax
288    if isinstance(type_hint, types.UnionType):
289        args = get_args(type_hint)
290        non_none_args = [arg for arg in args if arg is not type(None)]
291        if len(non_none_args) == 1 and type(None) in args:
292            return True, non_none_args[0]
293
294    return False, type_hint

Check if a type hint is Optional (T | None).

Returns: Tuple of (is_optional, inner_type)

def is_list_type(type_hint: Any) -> tuple[bool, typing.Any | None]:
297def is_list_type(type_hint: Any) -> tuple[bool, Any | None]:
298    """
299    Check if a type hint is a list type.
300
301    Returns:
302        Tuple of (is_list, element_type)
303    """
304    origin = get_origin(type_hint)
305    if origin is list:
306        args = get_args(type_hint)
307        return True, args[0] if args else Any
308    return False, None

Check if a type hint is a list type.

Returns: Tuple of (is_list, element_type)

def unwrap_annotated(type_hint: Any) -> tuple[typing.Any, tuple[typing.Any, ...]]:
def unwrap_annotated(type_hint: Any) -> tuple[Any, tuple[Any, ...]]:
    """
    Unwrap an Annotated type.

    Returns:
        Tuple of (base_type, metadata_tuple)
    """
    if get_origin(type_hint) is not Annotated:
        # Not Annotated: pass through with empty metadata.
        return type_hint, ()
    base, *metadata = get_args(type_hint)
    return base, tuple(metadata)

Unwrap an Annotated type.

Returns: Tuple of (base_type, metadata_tuple)

def python_to_db_value(value: Any, type_hint: Any) -> Any:
def python_to_db_value(value: Any, type_hint: Any) -> Any:
    """Convert a Python value to a database-compatible value.

    Passes datetime/date/time/timedelta through to the Rust layer which
    converts them to proper Value::Temporal types. Converts Vector to
    list[float] and passes through everything else.
    """
    if value is None:
        return None

    # Vectors are persisted as their raw float payload.
    if isinstance(value, Vector):
        return value.values

    # Btic is unwrapped to the Rust PyBtic for py_object_to_value.
    if isinstance(value, Btic):
        return value._inner

    # Temporal values (datetime/date/time/timedelta) and everything else
    # pass through; the Rust py_object_to_value handles the conversion to
    # Value::Temporal with proper type information.
    return value

Convert a Python value to a database-compatible value.

Passes datetime/date/time/timedelta through to the Rust layer which converts them to proper Value::Temporal types. Converts Vector to list[float] and passes through everything else.

def db_to_python_value(value: Any, type_hint: Any) -> Any:
def db_to_python_value(value: Any, type_hint: Any) -> Any:
    """Convert a database value back to a Python value.

    The Rust layer now returns proper Python datetime/date/time objects
    via Value::Temporal, so in most cases values pass through directly.
    """
    if value is None:
        return None

    # Strip Optional, then Annotated, to reach the concrete target type.
    _, type_hint = is_optional(type_hint)
    type_hint, _ = unwrap_annotated(type_hint)

    # Value already has the requested temporal type → pass through.
    for temporal in (datetime, date, time, timedelta):
        if type_hint is temporal and isinstance(value, temporal):
            return value

    # Btic — wrap the Rust PyBtic in the pydantic-facing wrapper.
    if type_hint is Btic and _PyBtic is not None and isinstance(value, _PyBtic):
        return Btic(value)

    # Arrow deserialization may surface a datetime as a struct dict.
    if type_hint is datetime and isinstance(value, dict):
        nanos = value.get("nanos_since_epoch")
        if nanos is None:
            return None
        # NOTE(review): fromtimestamp interprets the epoch offset in the
        # local timezone — confirm this matches the Rust serialization side.
        return datetime.fromtimestamp(nanos / 1_000_000_000)

    # Vector fields come back as list[float]; rebuild the typed Vector.
    wanted = get_vector_dimensions(type_hint)
    if wanted is not None and isinstance(value, list):
        return Vector[wanted](value)

    return value

Convert a database value back to a Python value.

The Rust layer now returns proper Python datetime/date/time objects via Value::Temporal, so in most cases values pass through directly.

DATETIME_TYPES = {<class 'datetime.date'>, <class 'datetime.datetime'>, <class 'datetime.timedelta'>, <class 'datetime.time'>, <class 'Btic'>}
class QueryBuilder(uni_pydantic.query._QueryBuilderBase[~NodeT]):
class QueryBuilder(_QueryBuilderBase[NodeT]):
    """
    Immutable, type-safe query builder for graph queries.

    Every builder method hands back a **new** QueryBuilder; the receiver
    is never mutated. The fluent API builds Cypher queries with type
    checking and IDE autocomplete support.

    Example:
        >>> adults = (
        ...     session.query(Person)
        ...     .filter(Person.age >= 18)
        ...     .order_by(Person.name)
        ...     .limit(10)
        ...     .all()
        ... )
    """

    def __init__(self, session: UniSession, model: type[NodeT]) -> None:
        self._init_state(session, model)

    def _execute_query(
        self, cypher: str, params: dict[str, Any]
    ) -> list[dict[str, Any]]:
        """Execute a query, using query_with if timeout/max_memory is set."""
        if self._timeout is None and self._max_memory is None:
            # Fast path: no execution limits configured.
            rows = self._session._db_session.query(cypher, params)
            return [row.to_dict() for row in rows]
        runner = self._session._db_session.query_with(cypher)
        if params:
            runner = runner.params(params)
        if self._timeout is not None:
            runner = runner.timeout(self._timeout)
        if self._max_memory is not None:
            runner = runner.max_memory(self._max_memory)
        return [row.to_dict() for row in runner.fetch_all()]

    def _run_in_tx(self, cypher: str, params: dict[str, Any]) -> int:
        # Shared transactional execution for delete/update; returns the
        # "count" column of the first row, or 0 when nothing matched.
        with self._session._db_session.tx() as tx:
            rows = tx.query(cypher, params)
            tx.commit()
        return rows[0].to_dict()["count"] if rows else 0

    def all(self) -> list[NodeT]:
        """Execute the query and return all results."""
        cypher, params = self._build_cypher()
        rows = self._execute_query(cypher, params)
        nodes = self._rows_to_instances(rows)
        if nodes and self._eager_load:
            self._session._eager_load_relationships(nodes, self._eager_load)
        return nodes

    def first(self) -> NodeT | None:
        """Execute the query and return the first result."""
        limited = self._clone()
        limited._limit = 1
        matches = limited.all()
        return matches[0] if matches else None

    def one(self) -> NodeT:
        """Execute the query and return exactly one result.

        Raises QueryError if no results or more than one result.
        """
        # Fetch at most two rows: enough to detect "more than one".
        probe = self._clone()
        probe._limit = 2
        matches = probe.all()
        if not matches:
            raise QueryError("Query returned no results")
        if len(matches) > 1:
            raise QueryError("Query returned more than one result")
        return matches[0]

    def count(self) -> int:
        """Execute the query and return the count of results."""
        cypher, params = self._build_count_cypher()
        rows = self._execute_query(cypher, params)
        if not rows:
            return 0
        return cast(int, rows[0]["count"])

    def exists(self) -> bool:
        """Check if any matching records exist."""
        cypher, params = self._build_exists_cypher()
        return bool(self._execute_query(cypher, params))

    def delete(self) -> int:
        """Delete all matching records (DETACH DELETE)."""
        cypher, params = self._build_delete_cypher()
        return self._run_in_tx(cypher, params)

    def update(self, **kwargs: Any) -> int:
        """Update all matching records."""
        cypher, params = self._build_update_cypher(**kwargs)
        return self._run_in_tx(cypher, params)

Immutable, type-safe query builder for graph queries.

Each method returns a new QueryBuilder instance. The original is never mutated. Provides a fluent API for building Cypher queries with type checking and IDE autocomplete support.

Example:

adults = (
    session.query(Person)
    .filter(Person.age >= 18)
    .order_by(Person.name)
    .limit(10)
    .all()
)

QueryBuilder(session: UniSession, model: type[~NodeT])
609    def __init__(self, session: UniSession, model: type[NodeT]) -> None:
610        self._init_state(session, model)
def all(self) -> list[~NodeT]:
629    def all(self) -> list[NodeT]:
630        """Execute the query and return all results."""
631        cypher, params = self._build_cypher()
632        results = self._execute_query(cypher, params)
633        instances = self._rows_to_instances(results)
634        if self._eager_load and instances:
635            self._session._eager_load_relationships(instances, self._eager_load)
636        return instances

Execute the query and return all results.

def first(self) -> Optional[~NodeT]:
638    def first(self) -> NodeT | None:
639        """Execute the query and return the first result."""
640        clone = self._clone()
641        clone._limit = 1
642        results = clone.all()
643        return results[0] if results else None

Execute the query and return the first result.

def one(self) -> ~NodeT:
645    def one(self) -> NodeT:
646        """Execute the query and return exactly one result.
647
648        Raises QueryError if no results or more than one result.
649        """
650        clone = self._clone()
651        clone._limit = 2
652        results = clone.all()
653        if not results:
654            raise QueryError("Query returned no results")
655        if len(results) > 1:
656            raise QueryError("Query returned more than one result")
657        return results[0]

Execute the query and return exactly one result.

Raises QueryError if no results or more than one result.

def count(self) -> int:
659    def count(self) -> int:
660        """Execute the query and return the count of results."""
661        cypher, params = self._build_count_cypher()
662        results = self._execute_query(cypher, params)
663        return cast(int, results[0]["count"]) if results else 0

Execute the query and return the count of results.

def exists(self) -> bool:
665    def exists(self) -> bool:
666        """Check if any matching records exist."""
667        cypher, params = self._build_exists_cypher()
668        results = self._execute_query(cypher, params)
669        return len(results) > 0

Check if any matching records exist.

def delete(self) -> int:
671    def delete(self) -> int:
672        """Delete all matching records (DETACH DELETE)."""
673        cypher, params = self._build_delete_cypher()
674        with self._session._db_session.tx() as tx:
675            results = tx.query(cypher, params)
676            tx.commit()
677        return results[0].to_dict()["count"] if results else 0

Delete all matching records (DETACH DELETE).

def update(self, **kwargs: Any) -> int:
679    def update(self, **kwargs: Any) -> int:
680        """Update all matching records."""
681        cypher, params = self._build_update_cypher(**kwargs)
682        with self._session._db_session.tx() as tx:
683            results = tx.query(cypher, params)
684            tx.commit()
685        return results[0].to_dict()["count"] if results else 0

Update all matching records.

class AsyncQueryBuilder(uni_pydantic.query._QueryBuilderBase[~NodeT]):
 26class AsyncQueryBuilder(_QueryBuilderBase[NodeT]):
 27    """
 28    Immutable, async query builder for graph queries.
 29
 30    Inherits all Cypher-building and immutable builder methods from
 31    ``_QueryBuilderBase``. Only the execution methods are async.
 32    """
 33
 34    def __init__(self, session: AsyncUniSession, model: type[NodeT]) -> None:
 35        self._init_state(session, model)
 36
 37    async def _execute_query(
 38        self, cypher: str, params: dict[str, Any]
 39    ) -> list[dict[str, Any]]:
 40        """Execute a query, using query_with if timeout/max_memory is set."""
 41        if self._timeout is not None or self._max_memory is not None:
 42            builder = self._session._db_session.query_with(cypher)
 43            if params:
 44                builder = builder.params(params)
 45            if self._timeout is not None:
 46                builder = builder.timeout(self._timeout)
 47            if self._max_memory is not None:
 48                builder = builder.max_memory(self._max_memory)
 49            result = await builder.fetch_all()
 50        else:
 51            result = await self._session._db_session.query(cypher, params)
 52        return [row.to_dict() for row in result]
 53
 54    async def all(self) -> list[NodeT]:
 55        """Execute the query and return all results."""
 56        cypher, params = self._build_cypher()
 57        results = await self._execute_query(cypher, params)
 58        instances = self._rows_to_instances(results)
 59        if self._eager_load and instances:
 60            await self._session._async_eager_load_relationships(
 61                instances, self._eager_load
 62            )
 63        return instances
 64
 65    async def first(self) -> NodeT | None:
 66        """Execute the query and return the first result."""
 67        clone = self._clone()
 68        clone._limit = 1
 69        results = await clone.all()
 70        return results[0] if results else None
 71
 72    async def one(self) -> NodeT:
 73        """Execute the query and return exactly one result.
 74
 75        Raises QueryError if no results or more than one result.
 76        """
 77        clone = self._clone()
 78        clone._limit = 2
 79        results = await clone.all()
 80        if not results:
 81            raise QueryError("Query returned no results")
 82        if len(results) > 1:
 83            raise QueryError("Query returned more than one result")
 84        return results[0]
 85
 86    async def count(self) -> int:
 87        """Execute the query and return the count of results."""
 88        cypher, params = self._build_count_cypher()
 89        results = await self._execute_query(cypher, params)
 90        return cast(int, results[0]["count"]) if results else 0
 91
 92    async def exists(self) -> bool:
 93        """Check if any matching records exist."""
 94        cypher, params = self._build_exists_cypher()
 95        results = await self._execute_query(cypher, params)
 96        return len(results) > 0
 97
 98    async def delete(self) -> int:
 99        """Delete all matching records (DETACH DELETE)."""
100        cypher, params = self._build_delete_cypher()
101        async with await self._session._db_session.tx() as tx:
102            results = await tx.query(cypher, params)
103            await tx.commit()
104        return results[0].to_dict()["count"] if results else 0
105
106    async def update(self, **kwargs: Any) -> int:
107        """Update all matching records."""
108        cypher, params = self._build_update_cypher(**kwargs)
109        async with await self._session._db_session.tx() as tx:
110            results = await tx.query(cypher, params)
111            await tx.commit()
112        return results[0].to_dict()["count"] if results else 0

Immutable, async query builder for graph queries.

Inherits all Cypher-building and immutable builder methods from _QueryBuilderBase. Only the execution methods are async.

AsyncQueryBuilder( session: AsyncUniSession, model: type[~NodeT])
34    def __init__(self, session: AsyncUniSession, model: type[NodeT]) -> None:
35        self._init_state(session, model)
async def all(self) -> list[~NodeT]:
54    async def all(self) -> list[NodeT]:
55        """Execute the query and return all results."""
56        cypher, params = self._build_cypher()
57        results = await self._execute_query(cypher, params)
58        instances = self._rows_to_instances(results)
59        if self._eager_load and instances:
60            await self._session._async_eager_load_relationships(
61                instances, self._eager_load
62            )
63        return instances

Execute the query and return all results.

async def first(self) -> Optional[~NodeT]:
65    async def first(self) -> NodeT | None:
66        """Execute the query and return the first result."""
67        clone = self._clone()
68        clone._limit = 1
69        results = await clone.all()
70        return results[0] if results else None

Execute the query and return the first result.

async def one(self) -> ~NodeT:
72    async def one(self) -> NodeT:
73        """Execute the query and return exactly one result.
74
75        Raises QueryError if no results or more than one result.
76        """
77        clone = self._clone()
78        clone._limit = 2
79        results = await clone.all()
80        if not results:
81            raise QueryError("Query returned no results")
82        if len(results) > 1:
83            raise QueryError("Query returned more than one result")
84        return results[0]

Execute the query and return exactly one result.

Raises QueryError if no results or more than one result.

async def count(self) -> int:
86    async def count(self) -> int:
87        """Execute the query and return the count of results."""
88        cypher, params = self._build_count_cypher()
89        results = await self._execute_query(cypher, params)
90        return cast(int, results[0]["count"]) if results else 0

Execute the query and return the count of results.

async def exists(self) -> bool:
92    async def exists(self) -> bool:
93        """Check if any matching records exist."""
94        cypher, params = self._build_exists_cypher()
95        results = await self._execute_query(cypher, params)
96        return len(results) > 0

Check if any matching records exist.

async def delete(self) -> int:
 98    async def delete(self) -> int:
 99        """Delete all matching records (DETACH DELETE)."""
100        cypher, params = self._build_delete_cypher()
101        async with await self._session._db_session.tx() as tx:
102            results = await tx.query(cypher, params)
103            await tx.commit()
104        return results[0].to_dict()["count"] if results else 0

Delete all matching records (DETACH DELETE).

async def update(self, **kwargs: Any) -> int:
106    async def update(self, **kwargs: Any) -> int:
107        """Update all matching records."""
108        cypher, params = self._build_update_cypher(**kwargs)
109        async with await self._session._db_session.tx() as tx:
110            results = await tx.query(cypher, params)
111            await tx.commit()
112        return results[0].to_dict()["count"] if results else 0

Update all matching records.

@dataclass
class FilterExpr:
116@dataclass
117class FilterExpr:
118    """A filter expression for a query."""
119
120    property_name: str
121    op: FilterOp
122    value: Any = None
123
124    def to_cypher(self, node_var: str, param_name: str) -> tuple[str, dict[str, Any]]:
125        """Convert to Cypher WHERE clause fragment."""
126        prop = f"{node_var}.{self.property_name}"
127
128        if self.op == FilterOp.IS_NULL:
129            return f"{prop} IS NULL", {}
130        elif self.op == FilterOp.IS_NOT_NULL:
131            return f"{prop} IS NOT NULL", {}
132        elif self.op == FilterOp.IN:
133            return f"{prop} IN ${param_name}", {param_name: self.value}
134        elif self.op == FilterOp.NOT_IN:
135            return f"NOT {prop} IN ${param_name}", {param_name: self.value}
136        elif self.op == FilterOp.LIKE:
137            return f"{prop} =~ ${param_name}", {param_name: self.value}
138        elif self.op == FilterOp.STARTS_WITH:
139            return f"{prop} STARTS WITH ${param_name}", {param_name: self.value}
140        elif self.op == FilterOp.ENDS_WITH:
141            return f"{prop} ENDS WITH ${param_name}", {param_name: self.value}
142        elif self.op == FilterOp.CONTAINS:
143            return f"{prop} CONTAINS ${param_name}", {param_name: self.value}
144        else:
145            return f"{prop} {self.op.value} ${param_name}", {param_name: self.value}

A filter expression for a query.

FilterExpr( property_name: str, op: FilterOp, value: Any = None)
property_name: str
op: FilterOp
value: Any = None
def to_cypher( self, node_var: str, param_name: str) -> tuple[str, dict[str, typing.Any]]:
124    def to_cypher(self, node_var: str, param_name: str) -> tuple[str, dict[str, Any]]:
125        """Convert to Cypher WHERE clause fragment."""
126        prop = f"{node_var}.{self.property_name}"
127
128        if self.op == FilterOp.IS_NULL:
129            return f"{prop} IS NULL", {}
130        elif self.op == FilterOp.IS_NOT_NULL:
131            return f"{prop} IS NOT NULL", {}
132        elif self.op == FilterOp.IN:
133            return f"{prop} IN ${param_name}", {param_name: self.value}
134        elif self.op == FilterOp.NOT_IN:
135            return f"NOT {prop} IN ${param_name}", {param_name: self.value}
136        elif self.op == FilterOp.LIKE:
137            return f"{prop} =~ ${param_name}", {param_name: self.value}
138        elif self.op == FilterOp.STARTS_WITH:
139            return f"{prop} STARTS WITH ${param_name}", {param_name: self.value}
140        elif self.op == FilterOp.ENDS_WITH:
141            return f"{prop} ENDS WITH ${param_name}", {param_name: self.value}
142        elif self.op == FilterOp.CONTAINS:
143            return f"{prop} CONTAINS ${param_name}", {param_name: self.value}
144        else:
145            return f"{prop} {self.op.value} ${param_name}", {param_name: self.value}

Convert to Cypher WHERE clause fragment.

class FilterOp(enum.Enum):
 97class FilterOp(Enum):
 98    """Filter operation types."""
 99
100    EQ = "="
101    NE = "<>"
102    LT = "<"
103    LE = "<="
104    GT = ">"
105    GE = ">="
106    IN = "IN"
107    NOT_IN = "NOT IN"
108    LIKE = "=~"
109    IS_NULL = "IS NULL"
110    IS_NOT_NULL = "IS NOT NULL"
111    STARTS_WITH = "STARTS WITH"
112    ENDS_WITH = "ENDS WITH"
113    CONTAINS = "CONTAINS"

Filter operation types.

EQ = <FilterOp.EQ: '='>
NE = <FilterOp.NE: '<>'>
LT = <FilterOp.LT: '<'>
LE = <FilterOp.LE: '<='>
GT = <FilterOp.GT: '>'>
GE = <FilterOp.GE: '>='>
IN = <FilterOp.IN: 'IN'>
NOT_IN = <FilterOp.NOT_IN: 'NOT IN'>
LIKE = <FilterOp.LIKE: '=~'>
IS_NULL = <FilterOp.IS_NULL: 'IS NULL'>
IS_NOT_NULL = <FilterOp.IS_NOT_NULL: 'IS NOT NULL'>
STARTS_WITH = <FilterOp.STARTS_WITH: 'STARTS WITH'>
ENDS_WITH = <FilterOp.ENDS_WITH: 'ENDS WITH'>
CONTAINS = <FilterOp.CONTAINS: 'CONTAINS'>
class PropertyProxy(typing.Generic[~T]):
148class PropertyProxy(Generic[T]):
149    """
150    Proxy for model properties that enables filter expressions.
151
152    Used in query builder to create type-safe filter conditions.
153
154    Example:
155        >>> query.filter(Person.age >= 18)
156        >>> query.filter(Person.name.starts_with("A"))
157    """
158
159    def __init__(self, property_name: str, model: type[UniNode]) -> None:
160        self._property_name = property_name
161        self._model = model
162
163    def __eq__(self, other: Any) -> FilterExpr:  # type: ignore[override]
164        return FilterExpr(self._property_name, FilterOp.EQ, other)
165
166    def __ne__(self, other: Any) -> FilterExpr:  # type: ignore[override]
167        return FilterExpr(self._property_name, FilterOp.NE, other)
168
169    def __lt__(self, other: Any) -> FilterExpr:
170        return FilterExpr(self._property_name, FilterOp.LT, other)
171
172    def __le__(self, other: Any) -> FilterExpr:
173        return FilterExpr(self._property_name, FilterOp.LE, other)
174
175    def __gt__(self, other: Any) -> FilterExpr:
176        return FilterExpr(self._property_name, FilterOp.GT, other)
177
178    def __ge__(self, other: Any) -> FilterExpr:
179        return FilterExpr(self._property_name, FilterOp.GE, other)
180
181    def in_(self, values: Sequence[T]) -> FilterExpr:
182        """Check if value is in a list."""
183        return FilterExpr(self._property_name, FilterOp.IN, list(values))
184
185    def not_in(self, values: Sequence[T]) -> FilterExpr:
186        """Check if value is not in a list."""
187        return FilterExpr(self._property_name, FilterOp.NOT_IN, list(values))
188
189    def like(self, pattern: str) -> FilterExpr:
190        """Match a regex pattern."""
191        return FilterExpr(self._property_name, FilterOp.LIKE, pattern)
192
193    def is_null(self) -> FilterExpr:
194        """Check if value is null."""
195        return FilterExpr(self._property_name, FilterOp.IS_NULL)
196
197    def is_not_null(self) -> FilterExpr:
198        """Check if value is not null."""
199        return FilterExpr(self._property_name, FilterOp.IS_NOT_NULL)
200
201    def starts_with(self, prefix: str) -> FilterExpr:
202        """Check if string starts with prefix."""
203        return FilterExpr(self._property_name, FilterOp.STARTS_WITH, prefix)
204
205    def ends_with(self, suffix: str) -> FilterExpr:
206        """Check if string ends with suffix."""
207        return FilterExpr(self._property_name, FilterOp.ENDS_WITH, suffix)
208
209    def contains(self, substring: str) -> FilterExpr:
210        """Check if string contains substring."""
211        return FilterExpr(self._property_name, FilterOp.CONTAINS, substring)

Proxy for model properties that enables filter expressions.

Used in query builder to create type-safe filter conditions.

Example:

query.filter(Person.age >= 18)
query.filter(Person.name.starts_with("A"))

PropertyProxy(property_name: str, model: type[UniNode])
159    def __init__(self, property_name: str, model: type[UniNode]) -> None:
160        self._property_name = property_name
161        self._model = model
def in_(self, values: Sequence[~T]) -> FilterExpr:
181    def in_(self, values: Sequence[T]) -> FilterExpr:
182        """Check if value is in a list."""
183        return FilterExpr(self._property_name, FilterOp.IN, list(values))

Check if value is in a list.

def not_in(self, values: Sequence[~T]) -> FilterExpr:
185    def not_in(self, values: Sequence[T]) -> FilterExpr:
186        """Check if value is not in a list."""
187        return FilterExpr(self._property_name, FilterOp.NOT_IN, list(values))

Check if value is not in a list.

def like(self, pattern: str) -> FilterExpr:
189    def like(self, pattern: str) -> FilterExpr:
190        """Match a regex pattern."""
191        return FilterExpr(self._property_name, FilterOp.LIKE, pattern)

Match a regex pattern.

def is_null(self) -> FilterExpr:
193    def is_null(self) -> FilterExpr:
194        """Check if value is null."""
195        return FilterExpr(self._property_name, FilterOp.IS_NULL)

Check if value is null.

def is_not_null(self) -> FilterExpr:
197    def is_not_null(self) -> FilterExpr:
198        """Check if value is not null."""
199        return FilterExpr(self._property_name, FilterOp.IS_NOT_NULL)

Check if value is not null.

def starts_with(self, prefix: str) -> FilterExpr:
201    def starts_with(self, prefix: str) -> FilterExpr:
202        """Check if string starts with prefix."""
203        return FilterExpr(self._property_name, FilterOp.STARTS_WITH, prefix)

Check if string starts with prefix.

def ends_with(self, suffix: str) -> FilterExpr:
205    def ends_with(self, suffix: str) -> FilterExpr:
206        """Check if string ends with suffix."""
207        return FilterExpr(self._property_name, FilterOp.ENDS_WITH, suffix)

Check if string ends with suffix.

def contains(self, substring: str) -> FilterExpr:
209    def contains(self, substring: str) -> FilterExpr:
210        """Check if string contains substring."""
211        return FilterExpr(self._property_name, FilterOp.CONTAINS, substring)

Check if string contains substring.

class ModelProxy(typing.Generic[~NodeT]):
214class ModelProxy(Generic[NodeT]):
215    """
216    Proxy for model classes that provides property proxies.
217
218    Enables type-safe property access in query filters.
219
220    Example:
221        >>> Person.name  # Returns PropertyProxy for 'name'
222        >>> query.filter(Person.age >= 18)
223    """
224
225    def __init__(self, model: type[NodeT]) -> None:
226        self._model = model
227
228    def __getattr__(self, name: str) -> PropertyProxy[Any]:
229        if name.startswith("_"):
230            raise AttributeError(name)
231        return PropertyProxy(name, self._model)

Proxy for model classes that provides property proxies.

Enables type-safe property access in query filters.

Example:

Person.name  # Returns PropertyProxy for 'name'
query.filter(Person.age >= 18)

ModelProxy(model: type[~NodeT])
225    def __init__(self, model: type[NodeT]) -> None:
226        self._model = model
@dataclass
class OrderByClause:
234@dataclass
235class OrderByClause:
236    """An ORDER BY clause."""
237
238    property_name: str
239    descending: bool = False

An ORDER BY clause.

OrderByClause(property_name: str, descending: bool = False)
property_name: str
descending: bool = False
@dataclass
class TraversalStep:
242@dataclass
243class TraversalStep:
244    """A relationship traversal step."""
245
246    edge_type: str
247    direction: Literal["outgoing", "incoming", "both"]
248    target_label: str | None = None

A relationship traversal step.

TraversalStep( edge_type: str, direction: Literal['outgoing', 'incoming', 'both'], target_label: str | None = None)
edge_type: str
direction: Literal['outgoing', 'incoming', 'both']
target_label: str | None = None
@dataclass
class VectorSearchConfig:
251@dataclass
252class VectorSearchConfig:
253    """Configuration for vector similarity search."""
254
255    property_name: str
256    query_vector: list[float]
257    k: int
258    threshold: float | None = None
259    pre_filter: str | None = None

Configuration for vector similarity search.

VectorSearchConfig( property_name: str, query_vector: list[float], k: int, threshold: float | None = None, pre_filter: str | None = None)
property_name: str
query_vector: list[float]
k: int
threshold: float | None = None
pre_filter: str | None = None
class SchemaGenerator:
 60class SchemaGenerator:
 61    """Generates Uni database schema from registered models."""
 62
 63    def __init__(self) -> None:
 64        self._node_models: dict[str, type[UniNode]] = {}
 65        self._edge_models: dict[str, type[UniEdge]] = {}
 66        self._schema: DatabaseSchema | None = None
 67
 68    def register_node(self, model: type[UniNode]) -> None:
 69        """Register a node model for schema generation."""
 70        label = model.__label__
 71        if not label:
 72            raise SchemaError(f"Model {model.__name__} has no __label__", model)
 73        self._node_models[label] = model
 74        self._schema = None  # Invalidate cached schema
 75
 76    def register_edge(self, model: type[UniEdge]) -> None:
 77        """Register an edge model for schema generation."""
 78        edge_type = model.__edge_type__
 79        if not edge_type:
 80            raise SchemaError(f"Model {model.__name__} has no __edge_type__", model)
 81        self._edge_models[edge_type] = model
 82        self._schema = None
 83
 84    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
 85        """Register multiple models."""
 86        for model in models:
 87            if issubclass(model, UniEdge):
 88                self.register_edge(model)
 89            elif issubclass(model, UniNode):
 90                self.register_node(model)
 91            else:
 92                raise SchemaError(
 93                    f"Model {model.__name__} must be a subclass of UniNode or UniEdge"
 94                )
 95
 96    def _generate_property_schema(
 97        self,
 98        model: type[UniNode] | type[UniEdge],
 99        field_name: str,
100    ) -> PropertySchema:
101        """Generate schema for a single property field."""
102        field_info = model.model_fields[field_name]
103
104        # Get type hints with forward refs resolved
105        try:
106            hints = get_type_hints(model)
107            type_hint = hints.get(field_name, field_info.annotation)
108        except Exception:
109            type_hint = field_info.annotation
110
111        # Check for nullability
112        is_nullable, inner_type = is_optional(type_hint)
113
114        # Get Uni data type
115        data_type, nullable = python_type_to_uni(type_hint, nullable=is_nullable)
116
117        # Check for vector dimensions
118        vec_dims = get_vector_dimensions(inner_type if is_nullable else type_hint)
119        if vec_dims:
120            data_type = f"vector:{vec_dims}"
121
122        # Get field config for index settings
123        config = get_field_config(field_info)
124        index_type = config.index if config else None
125        unique = config.unique if config else False
126        tokenizer = config.tokenizer if config else None
127        metric = config.metric if config else None
128
129        # Auto-create vector index for Vector fields (regardless of Field config)
130        if vec_dims and not index_type:
131            index_type = "vector"
132
133        return PropertySchema(
134            name=field_name,
135            data_type=data_type,
136            nullable=nullable,
137            index_type=index_type,
138            unique=unique,
139            tokenizer=tokenizer,
140            metric=metric,
141        )
142
143    def _generate_label_schema(self, model: type[UniNode]) -> LabelSchema:
144        """Generate schema for a node model."""
145        label = model.__label__
146
147        properties = {}
148        for field_name in model.get_property_fields():
149            prop_schema = self._generate_property_schema(model, field_name)
150            properties[field_name] = prop_schema
151
152        return LabelSchema(
153            name=label,
154            properties=properties,
155        )
156
157    def _generate_edge_type_schema(self, model: type[UniEdge]) -> EdgeTypeSchema:
158        """Generate schema for an edge model."""
159        edge_type = model.__edge_type__
160        from_labels = model.get_from_labels()
161        to_labels = model.get_to_labels()
162
163        # If from/to not specified, allow any labels
164        if not from_labels:
165            from_labels = list(self._node_models.keys())
166        if not to_labels:
167            to_labels = list(self._node_models.keys())
168
169        properties = {}
170        for field_name in model.get_property_fields():
171            prop_schema = self._generate_property_schema(model, field_name)
172            properties[field_name] = prop_schema
173
174        return EdgeTypeSchema(
175            name=edge_type,
176            from_labels=from_labels,
177            to_labels=to_labels,
178            properties=properties,
179        )
180
181    def generate(self) -> DatabaseSchema:
182        """Generate the complete database schema."""
183        if self._schema is not None:
184            return self._schema
185
186        schema = DatabaseSchema()
187
188        # Generate label schemas
189        for label, model in self._node_models.items():
190            schema.labels[label] = self._generate_label_schema(model)
191
192        # Generate edge type schemas
193        for edge_type_name, edge_model in self._edge_models.items():
194            schema.edge_types[edge_type_name] = self._generate_edge_type_schema(
195                edge_model
196            )
197
198        # Also generate labels from relationships in node models
199        for model in self._node_models.values():
200            for rel_name, rel_config in model.get_relationship_fields().items():
201                edge_type = rel_config.edge_type
202                if edge_type not in schema.edge_types:
203                    # Create a minimal edge type schema
204                    schema.edge_types[edge_type] = EdgeTypeSchema(
205                        name=edge_type,
206                        from_labels=list(self._node_models.keys()),
207                        to_labels=list(self._node_models.keys()),
208                    )
209
210        self._schema = schema
211        return schema
212
213    def apply_to_database(self, db: uni_db.Uni) -> None:
214        """Apply the generated schema to a database using SchemaBuilder.
215
216        Uses db.schema() for atomic schema application with additive-only
217        semantics. Creates labels, edge types, properties, and indexes.
218        """
219        schema = self.generate()
220
221        # Build the full schema using SchemaBuilder, skipping existing labels/edge types
222        builder = db.schema()
223        has_changes = False
224
225        for label, label_schema in schema.labels.items():
226            if db.label_exists(label):
227                continue  # Additive-only: skip existing labels
228            lb = builder.label(label)
229            for prop in label_schema.properties.values():
230                # Check for vector type
231                if prop.data_type.startswith("vector:"):
232                    dims = int(prop.data_type.split(":")[1])
233                    lb = lb.vector(prop.name, dims)
234                elif prop.nullable:
235                    lb = lb.property_nullable(prop.name, prop.data_type)
236                else:
237                    lb = lb.property(prop.name, prop.data_type)
238
239                # Add indexes (not vector — vector is handled by .vector())
240                if prop.index_type and prop.index_type in ("btree", "hash"):
241                    lb = lb.index(prop.name, prop.index_type)
242            builder = lb.done()
243            has_changes = True
244
245        for edge_type, edge_schema in schema.edge_types.items():
246            if db.edge_type_exists(edge_type):
247                continue  # Skip existing edge types
248            eb = builder.edge_type(
249                edge_type, edge_schema.from_labels, edge_schema.to_labels
250            )
251            for prop in edge_schema.properties.values():
252                if prop.nullable:
253                    eb = eb.property_nullable(prop.name, prop.data_type)
254                else:
255                    eb = eb.property(prop.name, prop.data_type)
256            builder = eb.done()
257            has_changes = True
258
259        if has_changes:
260            builder.apply()
261
262        # Create vector and fulltext indexes via schema builder
263        for label, label_schema in schema.labels.items():
264            for prop in label_schema.properties.values():
265                if prop.index_type == "vector":
266                    metric = prop.metric or "l2"
267                    try:
268                        db.schema().label(label).index(
269                            prop.name, {"type": "vector", "metric": metric}
270                        ).apply()
271                    except Exception:
272                        pass  # Index may already exist
273                elif prop.index_type == "fulltext":
274                    try:
275                        db.schema().label(label).index(prop.name, "fulltext").apply()
276                    except Exception:
277                        pass  # Index may already exist
278
279    async def async_apply_to_database(self, db: uni_db.AsyncUni) -> None:
280        """Apply the generated schema to an async database.
281
282        Async variant of apply_to_database using AsyncSchemaBuilder.
283        """
284        schema = self.generate()
285
286        # Build the full schema using AsyncSchemaBuilder, skipping existing labels/edge types
287        builder = db.schema()
288        has_changes = False
289
290        for label, label_schema in schema.labels.items():
291            if await db.label_exists(label):
292                continue
293            lb = builder.label(label)
294            for prop in label_schema.properties.values():
295                if prop.data_type.startswith("vector:"):
296                    dims = int(prop.data_type.split(":")[1])
297                    lb = lb.vector(prop.name, dims)
298                elif prop.nullable:
299                    lb = lb.property_nullable(prop.name, prop.data_type)
300                else:
301                    lb = lb.property(prop.name, prop.data_type)
302
303                if prop.index_type and prop.index_type in ("btree", "hash"):
304                    lb = lb.index(prop.name, prop.index_type)
305            builder = lb.done()
306            has_changes = True
307
308        for edge_type, edge_schema in schema.edge_types.items():
309            if await db.edge_type_exists(edge_type):
310                continue
311            eb = builder.edge_type(
312                edge_type, edge_schema.from_labels, edge_schema.to_labels
313            )
314            for prop in edge_schema.properties.values():
315                if prop.nullable:
316                    eb = eb.property_nullable(prop.name, prop.data_type)
317                else:
318                    eb = eb.property(prop.name, prop.data_type)
319            builder = eb.done()
320            has_changes = True
321
322        if has_changes:
323            await builder.apply()
324
325        # Create vector and fulltext indexes via schema builder
326        for label, label_schema in schema.labels.items():
327            for prop in label_schema.properties.values():
328                if prop.index_type == "vector":
329                    metric = prop.metric or "l2"
330                    try:
331                        await (
332                            db.schema()
333                            .label(label)
334                            .index(prop.name, {"type": "vector", "metric": metric})
335                            .apply()
336                        )
337                    except Exception:
338                        pass  # Index may already exist
339                elif prop.index_type == "fulltext":
340                    try:
341                        await (
342                            db.schema()
343                            .label(label)
344                            .index(prop.name, "fulltext")
345                            .apply()
346                        )
347                    except Exception:
348                        pass  # Index may already exist

Generates Uni database schema from registered models.

def register_node(self, model: type[UniNode]) -> None:
68    def register_node(self, model: type[UniNode]) -> None:
69        """Register a node model for schema generation."""
70        label = model.__label__
71        if not label:
72            raise SchemaError(f"Model {model.__name__} has no __label__", model)
73        self._node_models[label] = model
74        self._schema = None  # Invalidate cached schema

Register a node model for schema generation.

def register_edge(self, model: type[UniEdge]) -> None:
76    def register_edge(self, model: type[UniEdge]) -> None:
77        """Register an edge model for schema generation."""
78        edge_type = model.__edge_type__
79        if not edge_type:
80            raise SchemaError(f"Model {model.__name__} has no __edge_type__", model)
81        self._edge_models[edge_type] = model
82        self._schema = None

Register an edge model for schema generation.

def register( self, *models: type[UniNode] | type[UniEdge]) -> None:
84    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
85        """Register multiple models."""
86        for model in models:
87            if issubclass(model, UniEdge):
88                self.register_edge(model)
89            elif issubclass(model, UniNode):
90                self.register_node(model)
91            else:
92                raise SchemaError(
93                    f"Model {model.__name__} must be a subclass of UniNode or UniEdge"
94                )

Register multiple models.

def generate(self) -> DatabaseSchema:
181    def generate(self) -> DatabaseSchema:
182        """Generate the complete database schema."""
183        if self._schema is not None:
184            return self._schema
185
186        schema = DatabaseSchema()
187
188        # Generate label schemas
189        for label, model in self._node_models.items():
190            schema.labels[label] = self._generate_label_schema(model)
191
192        # Generate edge type schemas
193        for edge_type_name, edge_model in self._edge_models.items():
194            schema.edge_types[edge_type_name] = self._generate_edge_type_schema(
195                edge_model
196            )
197
198        # Also generate labels from relationships in node models
199        for model in self._node_models.values():
200            for rel_name, rel_config in model.get_relationship_fields().items():
201                edge_type = rel_config.edge_type
202                if edge_type not in schema.edge_types:
203                    # Create a minimal edge type schema
204                    schema.edge_types[edge_type] = EdgeTypeSchema(
205                        name=edge_type,
206                        from_labels=list(self._node_models.keys()),
207                        to_labels=list(self._node_models.keys()),
208                    )
209
210        self._schema = schema
211        return schema

Generate the complete database schema.

def apply_to_database(self, db: Uni) -> None:
213    def apply_to_database(self, db: uni_db.Uni) -> None:
214        """Apply the generated schema to a database using SchemaBuilder.
215
216        Uses db.schema() for atomic schema application with additive-only
217        semantics. Creates labels, edge types, properties, and indexes.
218        """
219        schema = self.generate()
220
221        # Build the full schema using SchemaBuilder, skipping existing labels/edge types
222        builder = db.schema()
223        has_changes = False
224
225        for label, label_schema in schema.labels.items():
226            if db.label_exists(label):
227                continue  # Additive-only: skip existing labels
228            lb = builder.label(label)
229            for prop in label_schema.properties.values():
230                # Check for vector type
231                if prop.data_type.startswith("vector:"):
232                    dims = int(prop.data_type.split(":")[1])
233                    lb = lb.vector(prop.name, dims)
234                elif prop.nullable:
235                    lb = lb.property_nullable(prop.name, prop.data_type)
236                else:
237                    lb = lb.property(prop.name, prop.data_type)
238
239                # Add indexes (not vector — vector is handled by .vector())
240                if prop.index_type and prop.index_type in ("btree", "hash"):
241                    lb = lb.index(prop.name, prop.index_type)
242            builder = lb.done()
243            has_changes = True
244
245        for edge_type, edge_schema in schema.edge_types.items():
246            if db.edge_type_exists(edge_type):
247                continue  # Skip existing edge types
248            eb = builder.edge_type(
249                edge_type, edge_schema.from_labels, edge_schema.to_labels
250            )
251            for prop in edge_schema.properties.values():
252                if prop.nullable:
253                    eb = eb.property_nullable(prop.name, prop.data_type)
254                else:
255                    eb = eb.property(prop.name, prop.data_type)
256            builder = eb.done()
257            has_changes = True
258
259        if has_changes:
260            builder.apply()
261
262        # Create vector and fulltext indexes via schema builder
263        for label, label_schema in schema.labels.items():
264            for prop in label_schema.properties.values():
265                if prop.index_type == "vector":
266                    metric = prop.metric or "l2"
267                    try:
268                        db.schema().label(label).index(
269                            prop.name, {"type": "vector", "metric": metric}
270                        ).apply()
271                    except Exception:
272                        pass  # Index may already exist
273                elif prop.index_type == "fulltext":
274                    try:
275                        db.schema().label(label).index(prop.name, "fulltext").apply()
276                    except Exception:
277                        pass  # Index may already exist

Apply the generated schema to a database using SchemaBuilder.

Uses db.schema() for atomic schema application with additive-only semantics. Creates labels, edge types, properties, and indexes.

async def async_apply_to_database(self, db: AsyncUni) -> None:
279    async def async_apply_to_database(self, db: uni_db.AsyncUni) -> None:
280        """Apply the generated schema to an async database.
281
282        Async variant of apply_to_database using AsyncSchemaBuilder.
283        """
284        schema = self.generate()
285
286        # Build the full schema using AsyncSchemaBuilder, skipping existing labels/edge types
287        builder = db.schema()
288        has_changes = False
289
290        for label, label_schema in schema.labels.items():
291            if await db.label_exists(label):
292                continue
293            lb = builder.label(label)
294            for prop in label_schema.properties.values():
295                if prop.data_type.startswith("vector:"):
296                    dims = int(prop.data_type.split(":")[1])
297                    lb = lb.vector(prop.name, dims)
298                elif prop.nullable:
299                    lb = lb.property_nullable(prop.name, prop.data_type)
300                else:
301                    lb = lb.property(prop.name, prop.data_type)
302
303                if prop.index_type and prop.index_type in ("btree", "hash"):
304                    lb = lb.index(prop.name, prop.index_type)
305            builder = lb.done()
306            has_changes = True
307
308        for edge_type, edge_schema in schema.edge_types.items():
309            if await db.edge_type_exists(edge_type):
310                continue
311            eb = builder.edge_type(
312                edge_type, edge_schema.from_labels, edge_schema.to_labels
313            )
314            for prop in edge_schema.properties.values():
315                if prop.nullable:
316                    eb = eb.property_nullable(prop.name, prop.data_type)
317                else:
318                    eb = eb.property(prop.name, prop.data_type)
319            builder = eb.done()
320            has_changes = True
321
322        if has_changes:
323            await builder.apply()
324
325        # Create vector and fulltext indexes via schema builder
326        for label, label_schema in schema.labels.items():
327            for prop in label_schema.properties.values():
328                if prop.index_type == "vector":
329                    metric = prop.metric or "l2"
330                    try:
331                        await (
332                            db.schema()
333                            .label(label)
334                            .index(prop.name, {"type": "vector", "metric": metric})
335                            .apply()
336                        )
337                    except Exception:
338                        pass  # Index may already exist
339                elif prop.index_type == "fulltext":
340                    try:
341                        await (
342                            db.schema()
343                            .label(label)
344                            .index(prop.name, "fulltext")
345                            .apply()
346                        )
347                    except Exception:
348                        pass  # Index may already exist

Apply the generated schema to an async database.

Async variant of apply_to_database using AsyncSchemaBuilder.

@dataclass
class DatabaseSchema:
52@dataclass
53class DatabaseSchema:
54    """Complete database schema generated from models."""
55
56    labels: dict[str, LabelSchema] = field(default_factory=dict)
57    edge_types: dict[str, EdgeTypeSchema] = field(default_factory=dict)

Complete database schema generated from models.

DatabaseSchema( labels: dict[str, LabelSchema] = <factory>, edge_types: dict[str, EdgeTypeSchema] = <factory>)
labels: dict[str, LabelSchema]
edge_types: dict[str, EdgeTypeSchema]
@dataclass
class LabelSchema:
34@dataclass
35class LabelSchema:
36    """Schema for a vertex label."""
37
38    name: str
39    properties: dict[str, PropertySchema] = field(default_factory=dict)

Schema for a vertex label.

LabelSchema( name: str, properties: dict[str, PropertySchema] = <factory>)
name: str
properties: dict[str, PropertySchema]
@dataclass
class EdgeTypeSchema:
42@dataclass
43class EdgeTypeSchema:
44    """Schema for an edge type."""
45
46    name: str
47    from_labels: list[str] = field(default_factory=list)
48    to_labels: list[str] = field(default_factory=list)
49    properties: dict[str, PropertySchema] = field(default_factory=dict)

Schema for an edge type.

EdgeTypeSchema( name: str, from_labels: list[str] = <factory>, to_labels: list[str] = <factory>, properties: dict[str, PropertySchema] = <factory>)
name: str
from_labels: list[str]
to_labels: list[str]
properties: dict[str, PropertySchema]
@dataclass
class PropertySchema:
21@dataclass
22class PropertySchema:
23    """Schema for a single property."""
24
25    name: str
26    data_type: str
27    nullable: bool = False
28    index_type: str | None = None
29    unique: bool = False
30    tokenizer: str | None = None
31    metric: str | None = None

Schema for a single property.

PropertySchema( name: str, data_type: str, nullable: bool = False, index_type: str | None = None, unique: bool = False, tokenizer: str | None = None, metric: str | None = None)
name: str
data_type: str
nullable: bool = False
index_type: str | None = None
unique: bool = False
tokenizer: str | None = None
metric: str | None = None
def generate_schema( *models: type[UniNode] | type[UniEdge]) -> DatabaseSchema:
351def generate_schema(*models: type[UniNode] | type[UniEdge]) -> DatabaseSchema:
352    """Generate a database schema from the given models."""
353    generator = SchemaGenerator()
354    generator.register(*models)
355    return generator.generate()

Generate a database schema from the given models.

class UniDatabase:
15class UniDatabase:
16    """
17    Thin wrapper around uni-db UniBuilder for ergonomic database creation.
18
19    Example:
20        >>> db = UniDatabase.open("./path").cache_size(1024*1024).build()
21        >>> db = UniDatabase.temporary().build()
22        >>> db = UniDatabase.in_memory().build()
23    """
24
25    def __init__(self, builder: uni_db.UniBuilder) -> None:
26        self._builder = builder
27
28    @classmethod
29    def open(cls, path: str) -> UniDatabase:
30        """Open or create a database at the given path."""
31        import uni_db
32
33        return cls(uni_db.UniBuilder.open(path))
34
35    @classmethod
36    def create(cls, path: str) -> UniDatabase:
37        """Create a new database at the given path."""
38        import uni_db
39
40        return cls(uni_db.UniBuilder.create(path))
41
42    @classmethod
43    def open_existing(cls, path: str) -> UniDatabase:
44        """Open an existing database (must already exist)."""
45        import uni_db
46
47        return cls(uni_db.UniBuilder.open_existing(path))
48
49    @classmethod
50    def temporary(cls) -> UniDatabase:
51        """Create an ephemeral in-memory database."""
52        import uni_db
53
54        return cls(uni_db.UniBuilder.temporary())
55
56    @classmethod
57    def in_memory(cls) -> UniDatabase:
58        """Create a persistent in-memory database."""
59        import uni_db
60
61        return cls(uni_db.UniBuilder.in_memory())
62
63    def cache_size(self, bytes_: int) -> UniDatabase:
64        """Set the cache size in bytes."""
65        self._builder = self._builder.cache_size(bytes_)
66        return self
67
68    def parallelism(self, n: int) -> UniDatabase:
69        """Set the parallelism level."""
70        self._builder = self._builder.parallelism(n)
71        return self
72
73    def build(self) -> uni_db.Uni:
74        """Build and return the database instance."""
75        return self._builder.build()

Thin wrapper around uni-db UniBuilder for ergonomic database creation.

Example:

db = UniDatabase.open("./path").cache_size(1024*1024).build()
db = UniDatabase.temporary().build()
db = UniDatabase.in_memory().build()

UniDatabase(builder: UniBuilder)
25    def __init__(self, builder: uni_db.UniBuilder) -> None:
26        self._builder = builder
@classmethod
def open(cls, path: str) -> UniDatabase:
28    @classmethod
29    def open(cls, path: str) -> UniDatabase:
30        """Open or create a database at the given path."""
31        import uni_db
32
33        return cls(uni_db.UniBuilder.open(path))

Open or create a database at the given path.

@classmethod
def create(cls, path: str) -> UniDatabase:
35    @classmethod
36    def create(cls, path: str) -> UniDatabase:
37        """Create a new database at the given path."""
38        import uni_db
39
40        return cls(uni_db.UniBuilder.create(path))

Create a new database at the given path.

@classmethod
def open_existing(cls, path: str) -> UniDatabase:
42    @classmethod
43    def open_existing(cls, path: str) -> UniDatabase:
44        """Open an existing database (must already exist)."""
45        import uni_db
46
47        return cls(uni_db.UniBuilder.open_existing(path))

Open an existing database (must already exist).

@classmethod
def temporary(cls) -> UniDatabase:
49    @classmethod
50    def temporary(cls) -> UniDatabase:
51        """Create an ephemeral in-memory database."""
52        import uni_db
53
54        return cls(uni_db.UniBuilder.temporary())

Create an ephemeral in-memory database.

@classmethod
def in_memory(cls) -> UniDatabase:
56    @classmethod
57    def in_memory(cls) -> UniDatabase:
58        """Create a persistent in-memory database."""
59        import uni_db
60
61        return cls(uni_db.UniBuilder.in_memory())

Create a persistent in-memory database.

def cache_size(self, bytes_: int) -> UniDatabase:
63    def cache_size(self, bytes_: int) -> UniDatabase:
64        """Set the cache size in bytes."""
65        self._builder = self._builder.cache_size(bytes_)
66        return self

Set the cache size in bytes.

def parallelism(self, n: int) -> UniDatabase:
68    def parallelism(self, n: int) -> UniDatabase:
69        """Set the parallelism level."""
70        self._builder = self._builder.parallelism(n)
71        return self

Set the parallelism level.

def build(self) -> Uni:
73    def build(self) -> uni_db.Uni:
74        """Build and return the database instance."""
75        return self._builder.build()

Build and return the database instance.

class AsyncUniDatabase:
 78class AsyncUniDatabase:
 79    """
 80    Thin wrapper around uni-db AsyncUniBuilder for ergonomic async database creation.
 81
 82    Example:
 83        >>> db = await AsyncUniDatabase.open("./path").build()
 84        >>> db = await AsyncUniDatabase.temporary().build()
 85    """
 86
 87    def __init__(self, builder: uni_db.AsyncUniBuilder) -> None:
 88        self._builder = builder
 89
 90    @classmethod
 91    def open(cls, path: str) -> AsyncUniDatabase:
 92        """Open or create a database at the given path."""
 93        import uni_db
 94
 95        return cls(uni_db.AsyncUniBuilder.open(path))
 96
 97    @classmethod
 98    def temporary(cls) -> AsyncUniDatabase:
 99        """Create an ephemeral in-memory database."""
100        import uni_db
101
102        return cls(uni_db.AsyncUniBuilder.temporary())
103
104    @classmethod
105    def in_memory(cls) -> AsyncUniDatabase:
106        """Create a persistent in-memory database."""
107        import uni_db
108
109        return cls(uni_db.AsyncUniBuilder.in_memory())
110
111    def cache_size(self, bytes_: int) -> AsyncUniDatabase:
112        """Set the cache size in bytes."""
113        self._builder = self._builder.cache_size(bytes_)
114        return self
115
116    def parallelism(self, n: int) -> AsyncUniDatabase:
117        """Set the parallelism level."""
118        self._builder = self._builder.parallelism(n)
119        return self
120
121    async def build(self) -> uni_db.AsyncUni:
122        """Build and return the async database instance."""
123        return await self._builder.build()

Thin wrapper around uni-db AsyncUniBuilder for ergonomic async database creation.

Example:

db = await AsyncUniDatabase.open("./path").build()
db = await AsyncUniDatabase.temporary().build()

AsyncUniDatabase(builder: AsyncUniBuilder)
87    def __init__(self, builder: uni_db.AsyncUniBuilder) -> None:
88        self._builder = builder
@classmethod
def open(cls, path: str) -> AsyncUniDatabase:
90    @classmethod
91    def open(cls, path: str) -> AsyncUniDatabase:
92        """Open or create a database at the given path."""
93        import uni_db
94
95        return cls(uni_db.AsyncUniBuilder.open(path))

Open or create a database at the given path.

@classmethod
def temporary(cls) -> AsyncUniDatabase:
 97    @classmethod
 98    def temporary(cls) -> AsyncUniDatabase:
 99        """Create an ephemeral in-memory database."""
100        import uni_db
101
102        return cls(uni_db.AsyncUniBuilder.temporary())

Create an ephemeral in-memory database.

@classmethod
def in_memory(cls) -> AsyncUniDatabase:
104    @classmethod
105    def in_memory(cls) -> AsyncUniDatabase:
106        """Create a persistent in-memory database."""
107        import uni_db
108
109        return cls(uni_db.AsyncUniBuilder.in_memory())

Create a persistent in-memory database.

def cache_size(self, bytes_: int) -> AsyncUniDatabase:
111    def cache_size(self, bytes_: int) -> AsyncUniDatabase:
112        """Set the cache size in bytes."""
113        self._builder = self._builder.cache_size(bytes_)
114        return self

Set the cache size in bytes.

def parallelism(self, n: int) -> AsyncUniDatabase:
116    def parallelism(self, n: int) -> AsyncUniDatabase:
117        """Set the parallelism level."""
118        self._builder = self._builder.parallelism(n)
119        return self

Set the parallelism level.

async def build(self) -> AsyncUni:
121    async def build(self) -> uni_db.AsyncUni:
122        """Build and return the async database instance."""
123        return await self._builder.build()

Build and return the async database instance.

def before_create(func: ~F) -> ~F:
40def before_create(func: F) -> F:
41    """
42    Mark a method to be called before the entity is created in the database.
43
44    The method is called after validation but before the INSERT operation.
45    Useful for setting timestamps, generating IDs, or final validation.
46
47    Example:
48        >>> class Person(UniNode):
49        ...     name: str
50        ...     created_at: datetime | None = None
51        ...
52        ...     @before_create
53        ...     def set_created_at(self):
54        ...         self.created_at = datetime.now()
55    """
56    return _mark_hook(_BEFORE_CREATE)(func)

Mark a method to be called before the entity is created in the database.

The method is called after validation but before the INSERT operation. Useful for setting timestamps, generating IDs, or final validation.

Example:

class Person(UniNode):
    name: str
    created_at: datetime | None = None

    @before_create
    def set_created_at(self):
        self.created_at = datetime.now()

def after_create(func: ~F) -> ~F:
59def after_create(func: F) -> F:
60    """
61    Mark a method to be called after the entity is created in the database.
62
63    The method is called after the INSERT operation completes successfully.
64    The entity will have its vid/eid assigned at this point.
65
66    Example:
67        >>> class Person(UniNode):
68        ...     name: str
69        ...
70        ...     @after_create
71        ...     def log_creation(self):
72        ...         logger.info(f"Created person {self.name} with vid={self.vid}")
73    """
74    return _mark_hook(_AFTER_CREATE)(func)

Mark a method to be called after the entity is created in the database.

The method is called after the INSERT operation completes successfully. The entity will have its vid/eid assigned at this point.

Example:

class Person(UniNode):
    name: str

    @after_create
    def log_creation(self):
        logger.info(f"Created person {self.name} with vid={self.vid}")

def before_update(func: ~F) -> ~F:
77def before_update(func: F) -> F:
78    """
79    Mark a method to be called before the entity is updated in the database.
80
81    The method is called before the UPDATE operation.
82    Useful for validation or updating timestamps.
83
84    Example:
85        >>> class Person(UniNode):
86        ...     name: str
87        ...     updated_at: datetime | None = None
88        ...
89        ...     @before_update
90        ...     def validate_and_timestamp(self):
91        ...         if not self.name:
92        ...             raise ValueError("Name cannot be empty")
93        ...         self.updated_at = datetime.now()
94    """
95    return _mark_hook(_BEFORE_UPDATE)(func)

Mark a method to be called before the entity is updated in the database.

The method is called before the UPDATE operation. Useful for validation or updating timestamps.

Example:

class Person(UniNode):
    name: str
    updated_at: datetime | None = None

    @before_update
    def validate_and_timestamp(self):
        if not self.name:
            raise ValueError("Name cannot be empty")
        self.updated_at = datetime.now()

def after_update(func: ~F) -> ~F:
 98def after_update(func: F) -> F:
 99    """
100    Mark a method to be called after the entity is updated in the database.
101
102    The method is called after the UPDATE operation completes successfully.
103
104    Example:
105        >>> class Person(UniNode):
106        ...     name: str
107        ...
108        ...     @after_update
109        ...     def notify_change(self):
110        ...         events.emit("person_updated", self.vid)
111    """
112    return _mark_hook(_AFTER_UPDATE)(func)

Mark a method to be called after the entity is updated in the database.

The method is called after the UPDATE operation completes successfully.

Example:

class Person(UniNode):
    name: str

    @after_update
    def notify_change(self):
        events.emit("person_updated", self.vid)

def before_delete(func: ~F) -> ~F:
115def before_delete(func: F) -> F:
116    """
117    Mark a method to be called before the entity is deleted from the database.
118
119    The method is called before the DELETE operation.
120    Useful for cleanup or validation.
121
122    Example:
123        >>> class Person(UniNode):
124        ...     name: str
125        ...
126        ...     @before_delete
127        ...     def cleanup(self):
128        ...         # Remove related data
129        ...         pass
130    """
131    return _mark_hook(_BEFORE_DELETE)(func)

Mark a method to be called before the entity is deleted from the database.

The method is called before the DELETE operation. Useful for cleanup or validation.

Example:

class Person(UniNode):
    name: str

    @before_delete
    def cleanup(self):
        # Remove related data
        pass

def after_delete(func: ~F) -> ~F:
134def after_delete(func: F) -> F:
135    """
136    Mark a method to be called after the entity is deleted from the database.
137
138    The method is called after the DELETE operation completes successfully.
139    The entity's vid/eid will be cleared at this point.
140
141    Example:
142        >>> class Person(UniNode):
143        ...     name: str
144        ...
145        ...     @after_delete
146        ...     def log_deletion(self):
147        ...         logger.info(f"Deleted person {self.name}")
148    """
149    return _mark_hook(_AFTER_DELETE)(func)

Mark a method to be called after the entity is deleted from the database.

The method is called after the DELETE operation completes successfully. The entity's vid/eid will be cleared at this point.

Example:

class Person(UniNode):
    name: str

    @after_delete
    def log_deletion(self):
        logger.info(f"Deleted person {self.name}")

def before_load(func: ~F) -> ~F:
152def before_load(func: F) -> F:
153    """
154    Mark a method to be called before the entity is loaded from the database.
155
156    This is a class method that receives the raw property dictionary.
157    Can be used to transform data before model instantiation.
158
159    Example:
160        >>> class Person(UniNode):
161        ...     name: str
162        ...
163        ...     @classmethod
164        ...     @before_load
165        ...     def transform_data(cls, props: dict) -> dict:
166        ...         # Normalize name
167        ...         if 'name' in props:
168        ...             props['name'] = props['name'].strip()
169        ...         return props
170    """
171    return _mark_hook(_BEFORE_LOAD)(func)

Mark a method to be called before the entity is loaded from the database.

This is a class method that receives the raw property dictionary. Can be used to transform data before model instantiation.

Example:

class Person(UniNode):
    name: str

    @classmethod
    @before_load
    def transform_data(cls, props: dict) -> dict:
        # Normalize name
        if 'name' in props:
            props['name'] = props['name'].strip()
        return props

def after_load(func: ~F) -> ~F:
174def after_load(func: F) -> F:
175    """
176    Mark a method to be called after the entity is loaded from the database.
177
178    The method is called after the entity is instantiated from database data.
179    Useful for computing derived values or initializing non-persisted state.
180
181    Example:
182        >>> class Person(UniNode):
183        ...     first_name: str
184        ...     last_name: str
185        ...     full_name: str | None = None
186        ...
187        ...     @after_load
188        ...     def compute_full_name(self):
189        ...         self.full_name = f"{self.first_name} {self.last_name}"
190    """
191    return _mark_hook(_AFTER_LOAD)(func)

Mark a method to be called after the entity is loaded from the database.

The method is called after the entity is instantiated from database data. Useful for computing derived values or initializing non-persisted state.

Example:

class Person(UniNode):
    first_name: str
    last_name: str
    full_name: str | None = None

    @after_load
    def compute_full_name(self):
        self.full_name = f"{self.first_name} {self.last_name}"

class UniPydanticError(builtins.Exception):
15class UniPydanticError(Exception):
16    """Base exception for all uni-pydantic errors."""

Base exception for all uni-pydantic errors.

class SchemaError(uni_pydantic.UniPydanticError):
19class SchemaError(UniPydanticError):
20    """Error related to schema definition or generation."""
21
22    def __init__(self, message: str, model: type | None = None) -> None:
23        self.model = model
24        super().__init__(message)

Error related to schema definition or generation.

SchemaError(message: str, model: type | None = None)
22    def __init__(self, message: str, model: type | None = None) -> None:
23        self.model = model
24        super().__init__(message)
model
class TypeMappingError(uni_pydantic.SchemaError):
27class TypeMappingError(SchemaError):
28    """Error mapping Python type to Uni DataType."""
29
30    def __init__(self, python_type: Any, message: str | None = None) -> None:
31        self.python_type = python_type
32        msg = message or f"Cannot map Python type {python_type!r} to Uni DataType"
33        super().__init__(msg)

Error mapping Python type to Uni DataType.

TypeMappingError(python_type: Any, message: str | None = None)
30    def __init__(self, python_type: Any, message: str | None = None) -> None:
31        self.python_type = python_type
32        msg = message or f"Cannot map Python type {python_type!r} to Uni DataType"
33        super().__init__(msg)
python_type
class ValidationError(uni_pydantic.UniPydanticError):
36class ValidationError(UniPydanticError):
37    """Validation error for model instances."""

Validation error for model instances.

class SessionError(uni_pydantic.UniPydanticError):
40class SessionError(UniPydanticError):
41    """Error related to session operations."""

Error related to session operations.

class NotRegisteredError(uni_pydantic.SessionError):
44class NotRegisteredError(SessionError):
45    """Model type not registered with session."""
46
47    def __init__(self, model: type[UniNode] | type[UniEdge]) -> None:
48        self.model = model
49        super().__init__(
50            f"Model {model.__name__!r} is not registered with this session. "
51            f"Call session.register({model.__name__}) first."
52        )

Model type not registered with session.

NotRegisteredError(model: type[UniNode] | type[UniEdge])
47    def __init__(self, model: type[UniNode] | type[UniEdge]) -> None:
48        self.model = model
49        super().__init__(
50            f"Model {model.__name__!r} is not registered with this session. "
51            f"Call session.register({model.__name__}) first."
52        )
model
class NotPersisted(uni_pydantic.SessionError):
55class NotPersisted(SessionError):
56    """Operation requires a persisted entity."""
57
58    def __init__(self, entity: UniNode | UniEdge) -> None:
59        self.entity = entity
60        super().__init__(
61            f"Entity {entity!r} is not persisted. Call session.add() and commit() first."
62        )

Operation requires a persisted entity.

NotPersisted(entity: UniNode | UniEdge)
58    def __init__(self, entity: UniNode | UniEdge) -> None:
59        self.entity = entity
60        super().__init__(
61            f"Entity {entity!r} is not persisted. Call session.add() and commit() first."
62        )
entity
class NotTrackedError(uni_pydantic.SessionError):
65class NotTrackedError(SessionError):
66    """Entity is not tracked by this session."""

Entity is not tracked by this session.

class TransactionError(uni_pydantic.SessionError):
69class TransactionError(SessionError):
70    """Error related to transaction operations."""

Error related to transaction operations.

class QueryError(uni_pydantic.UniPydanticError):
73class QueryError(UniPydanticError):
74    """Error executing a query."""

Error executing a query.

class RelationshipError(uni_pydantic.UniPydanticError):
77class RelationshipError(UniPydanticError):
78    """Error related to relationship operations."""

Error related to relationship operations.

class LazyLoadError(uni_pydantic.RelationshipError):
81class LazyLoadError(RelationshipError):
82    """Error lazy-loading a relationship."""
83
84    def __init__(self, field_name: str, reason: str) -> None:
85        self.field_name = field_name
86        super().__init__(f"Cannot lazy-load relationship '{field_name}': {reason}")

Error lazy-loading a relationship.

LazyLoadError(field_name: str, reason: str)
84    def __init__(self, field_name: str, reason: str) -> None:
85        self.field_name = field_name
86        super().__init__(f"Cannot lazy-load relationship '{field_name}': {reason}")
field_name
class BulkLoadError(uni_pydantic.UniPydanticError):
89class BulkLoadError(UniPydanticError):
90    """Error during bulk loading operations."""

Error during bulk loading operations.

class CypherInjectionError(uni_pydantic.QueryError):
93class CypherInjectionError(QueryError):
94    """Property name validation failure — potential Cypher injection."""
95
96    def __init__(self, name: str, reason: str | None = None) -> None:
97        self.name = name
98        msg = reason or f"Invalid property name {name!r}: possible Cypher injection"
99        super().__init__(msg)

Property name validation failure — potential Cypher injection.

CypherInjectionError(name: str, reason: str | None = None)
96    def __init__(self, name: str, reason: str | None = None) -> None:
97        self.name = name
98        msg = reason or f"Invalid property name {name!r}: possible Cypher injection"
99        super().__init__(msg)
name