uni_pydantic

uni-pydantic: Pydantic-based OGM for Uni Graph Database.

This package provides a type-safe Object-Graph Mapping layer on top of the Uni graph database, using Pydantic v2 for model definitions.

Example:

    >>> from uni_db import Uni
    >>> from uni_pydantic import UniNode, UniSession, Field, Relationship, Vector
    >>>
    >>> class Person(UniNode):
    ...     name: str
    ...     age: int | None = None
    ...     email: str = Field(unique=True)
    ...     embedding: Vector[1536]
    ...     friends: list["Person"] = Relationship("FRIEND_OF", direction="both")
    >>>
    >>> db = Uni("./my_graph")
    >>> session = UniSession(db)
    >>> session.register(Person)
    >>> session.sync_schema()
    >>>
    >>> alice = Person(name="Alice", age=30, email="alice@example.com")
    >>> session.add(alice)
    >>> session.commit()
    >>>
    >>> # Query with type safety
    >>> adults = session.query(Person).filter(Person.age >= 18).all()

  1# SPDX-License-Identifier: Apache-2.0
  2# Copyright 2024-2026 Dragonscale Team
  3
  4"""
  5uni-pydantic: Pydantic-based OGM for Uni Graph Database.
  6
  7This package provides a type-safe Object-Graph Mapping layer on top of
  8the Uni graph database, using Pydantic v2 for model definitions.
  9
 10Example:
 11    >>> from uni_db import Uni
 12    >>> from uni_pydantic import UniNode, UniSession, Field, Relationship, Vector
 13    >>>
 14    >>> class Person(UniNode):
 15    ...     name: str
 16    ...     age: int | None = None
 17    ...     email: str = Field(unique=True)
 18    ...     embedding: Vector[1536]
 19    ...     friends: list["Person"] = Relationship("FRIEND_OF", direction="both")
 20    >>>
 21    >>> db = Uni("./my_graph")
 22    >>> session = UniSession(db)
 23    >>> session.register(Person)
 24    >>> session.sync_schema()
 25    >>>
 26    >>> alice = Person(name="Alice", age=30, email="alice@example.com")
 27    >>> session.add(alice)
 28    >>> session.commit()
 29    >>>
 30    >>> # Query with type safety
 31    >>> adults = session.query(Person).filter(Person.age >= 18).all()
 32"""
 33
 34__version__ = "2.4.1"
 35
 36# Base classes
 37# Async support
 38from .async_query import AsyncQueryBuilder
 39from .async_session import AsyncUniSession, AsyncUniTransaction
 40from .base import UniEdge, UniNode
 41
 42# Database wrappers
 43from .database import AsyncUniDatabase, UniDatabase
 44
 45# Exceptions
 46from .exceptions import (
 47    BulkLoadError,
 48    CypherInjectionError,
 49    LazyLoadError,
 50    NotPersisted,
 51    NotRegisteredError,
 52    NotTrackedError,
 53    QueryError,
 54    RelationshipError,
 55    SchemaError,
 56    SessionError,
 57    TransactionError,
 58    TypeMappingError,
 59    UniPydanticError,
 60    ValidationError,
 61)
 62
 63# Field configuration
 64from .fields import (
 65    Direction,
 66    Field,
 67    FieldConfig,
 68    IndexType,
 69    Relationship,
 70    RelationshipConfig,
 71    RelationshipDescriptor,
 72    VectorMetric,
 73    get_field_config,
 74)
 75
 76# Lifecycle hooks
 77from .hooks import (
 78    after_create,
 79    after_delete,
 80    after_load,
 81    after_update,
 82    before_create,
 83    before_delete,
 84    before_load,
 85    before_update,
 86)
 87
 88# Query builder
 89from .query import (
 90    FilterExpr,
 91    FilterOp,
 92    ModelProxy,
 93    OrderByClause,
 94    PropertyProxy,
 95    QueryBuilder,
 96    SparseSearchConfig,
 97    TraversalStep,
 98    VectorSearchConfig,
 99)
100
101# Schema generation
102from .schema import (
103    DatabaseSchema,
104    EdgeTypeSchema,
105    LabelSchema,
106    PropertySchema,
107    SchemaGenerator,
108    generate_schema,
109)
110
111# Session management
112from .session import UniSession, UniTransaction
113
114# Type utilities
115from .types import (
116    DATETIME_TYPES,
117    Btic,
118    SparseVector,
119    Vector,
120    db_to_python_value,
121    get_sparse_vector_dimensions,
122    get_vector_dimensions,
123    is_list_type,
124    is_optional,
125    python_to_db_value,
126    python_type_to_uni,
127    uni_to_python_type,
128    unwrap_annotated,
129)
130
131__all__ = [
132    # Version
133    "__version__",
134    # Base classes
135    "UniNode",
136    "UniEdge",
137    # Session
138    "UniSession",
139    "UniTransaction",
140    # Async Session
141    "AsyncUniSession",
142    "AsyncUniTransaction",
143    # Fields
144    "Field",
145    "FieldConfig",
146    "Relationship",
147    "RelationshipConfig",
148    "RelationshipDescriptor",
149    "get_field_config",
150    "IndexType",
151    "Direction",
152    "VectorMetric",
153    # Types
154    "Btic",
155    "Vector",
156    "SparseVector",
157    "python_type_to_uni",
158    "uni_to_python_type",
159    "get_vector_dimensions",
160    "get_sparse_vector_dimensions",
161    "is_optional",
162    "is_list_type",
163    "unwrap_annotated",
164    "python_to_db_value",
165    "db_to_python_value",
166    "DATETIME_TYPES",
167    # Query
168    "QueryBuilder",
169    "AsyncQueryBuilder",
170    "FilterExpr",
171    "FilterOp",
172    "PropertyProxy",
173    "ModelProxy",
174    "OrderByClause",
175    "TraversalStep",
176    "VectorSearchConfig",
177    "SparseSearchConfig",
178    # Schema
179    "SchemaGenerator",
180    "DatabaseSchema",
181    "LabelSchema",
182    "EdgeTypeSchema",
183    "PropertySchema",
184    "generate_schema",
185    # Database
186    "UniDatabase",
187    "AsyncUniDatabase",
188    # Hooks
189    "before_create",
190    "after_create",
191    "before_update",
192    "after_update",
193    "before_delete",
194    "after_delete",
195    "before_load",
196    "after_load",
197    # Exceptions
198    "UniPydanticError",
199    "SchemaError",
200    "TypeMappingError",
201    "ValidationError",
202    "SessionError",
203    "NotRegisteredError",
204    "NotPersisted",
205    "NotTrackedError",
206    "TransactionError",
207    "QueryError",
208    "RelationshipError",
209    "LazyLoadError",
210    "BulkLoadError",
211    "CypherInjectionError",
212]
__version__ = '2.4.1'
class UniNode(pydantic.main.BaseModel):
class UniNode(BaseModel, metaclass=UniModelMeta):
    """
    Base class for graph node models.

    Subclass this to define your node types. Each UniNode subclass
    represents a vertex label in the graph database.

    Attributes:
        __label__: The vertex label name. Defaults to the class name.
        __relationships__: Dictionary of relationship configurations.

    Private Attributes:
        _vid: The vertex ID assigned by the database.
        _uid: The unique identifier (content-addressed hash).
        _session: Reference to the owning session.
        _dirty: Set of modified field names.
        _is_new: True until the node is attached to a session, marked
            clean, or loaded with a vertex ID.

    Example:
        >>> class Person(UniNode):
        ...     __label__ = "Person"
        ...
        ...     name: str
        ...     age: int | None = None
        ...     email: str = Field(unique=True)
        ...
        ...     friends: list["Person"] = Relationship("FRIEND_OF", direction="both")
    """

    model_config = ConfigDict(
        # Reject fields that are not declared on the model
        extra="forbid",
        # Validate on assignment for dirty tracking
        validate_assignment=True,
        # Allow arbitrary types (for Vector, etc.)
        arbitrary_types_allowed=True,
        # Use enum values
        use_enum_values=True,
    )

    # Class-level configuration
    __label__: ClassVar[str] = ""
    __relationships__: ClassVar[dict[str, RelationshipConfig]] = {}

    # Private attributes for session tracking
    _vid: int | None = PrivateAttr(default=None)
    _uid: str | None = PrivateAttr(default=None)
    _session: UniSession | None = PrivateAttr(default=None)
    _dirty: set[str] = PrivateAttr(default_factory=set)
    _is_new: bool = PrivateAttr(default=True)

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Default ``__label__`` to the class name when none is set."""
        super().__init_subclass__(**kwargs)
        # NOTE(review): the lookup goes through the MRO, so a subclass of an
        # already-labelled model appears to keep its parent's label unless
        # UniModelMeta handles that case -- confirm this is intended.
        if not cls.__label__:
            cls.__label__ = cls.__name__

    def model_post_init(self, __context: Any) -> None:
        """Clear dirty tracking after construction."""
        super().model_post_init(__context)
        self._dirty = set()

    @property
    def vid(self) -> int | None:
        """The vertex ID assigned by the database."""
        return self._vid

    @property
    def uid(self) -> str | None:
        """The unique identifier (content-addressed hash)."""
        return self._uid

    @property
    def is_persisted(self) -> bool:
        """Whether this node has been saved to the database."""
        return self._vid is not None

    @property
    def is_dirty(self) -> bool:
        """Whether this node has unsaved changes."""
        return bool(self._dirty)

    def __setattr__(self, name: str, value: Any) -> None:
        """Assign an attribute and record public names as dirty.

        Names starting with an underscore are never tracked.
        """
        # Assign first: with validate_assignment=True an invalid value
        # raises here, and a rejected assignment must not be reported as
        # an unsaved change.
        super().__setattr__(name, value)
        # Track dirty fields (but not private attributes)
        if not name.startswith("_") and hasattr(self, "_dirty"):
            self._dirty.add(name)

    def _mark_clean(self) -> None:
        """Mark all fields as clean (called after commit)."""
        self._dirty.clear()
        self._is_new = False

    def _attach_session(
        self, session: UniSession, vid: int, uid: str | None = None
    ) -> None:
        """Attach this node to a session with its database IDs."""
        self._session = session
        self._vid = vid
        self._uid = uid
        self._is_new = False

    @classmethod
    def get_property_fields(cls) -> dict[str, FieldInfo]:
        """Get all property fields (excluding relationships)."""
        return {
            name: info
            for name, info in cls.model_fields.items()
            if name not in cls.__relationships__
        }

    @classmethod
    def get_relationship_fields(cls) -> dict[str, RelationshipConfig]:
        """Get all relationship field configurations."""
        return cls.__relationships__

    def to_properties(self) -> dict[str, Any]:
        """Convert to a property dictionary for database storage.

        Uses python_to_db_value() for type conversion. Includes None
        explicitly so null-outs work.
        """
        return _model_to_properties(self, self.get_property_fields())

    @classmethod
    def from_properties(
        cls,
        props: dict[str, Any],
        *,
        vid: int | None = None,
        uid: str | None = None,
        session: UniSession | None = None,
    ) -> UniNode:
        """Create an instance from a property dictionary.

        Accepts _id (string->int vid) and _label from uni-db node dicts.
        Does not mutate the input dict.

        Args:
            props: Property dictionary; may carry ``_id`` and ``_label``.
            vid: Vertex ID. When given, it takes precedence over ``_id``.
            uid: Unique identifier to store on the instance.
            session: Session to store on the instance.

        Returns:
            A validated instance with empty dirty tracking.
        """
        # Copy so the reserved keys can be popped without touching `props`.
        data = dict(props)

        raw_id = data.pop("_id", None)
        if raw_id is not None and vid is None:
            vid = int(raw_id) if not isinstance(raw_id, int) else raw_id
        # The label is implied by `cls`, so the stored value is discarded.
        data.pop("_label", None)

        converted = _convert_db_values(data, cls)

        instance = cls.model_validate(converted)
        if vid is not None:
            instance._vid = vid
        if uid is not None:
            instance._uid = uid
        if session is not None:
            instance._session = session
        instance._is_new = vid is None
        instance._dirty = set()
        return instance

    def __repr__(self) -> str:
        """Return the class name, the vid (or "unsaved"), and the model repr."""
        # Compare against None so vid 0 is shown as a vid, matching
        # is_persisted.
        vid_str = f"vid={self._vid}" if self._vid is not None else "unsaved"
        return f"{self.__class__.__name__}({vid_str}, {super().__repr__()})"

Base class for graph node models.

Subclass this to define your node types. Each UniNode subclass represents a vertex label in the graph database.

Attributes:
    __label__: The vertex label name. Defaults to the class name.
    __relationships__: Dictionary of relationship configurations.

Private Attributes:
    _vid: The vertex ID assigned by the database.
    _uid: The unique identifier (content-addressed hash).
    _session: Reference to the owning session.
    _dirty: Set of modified field names.

Example:

    >>> class Person(UniNode):
    ...     __label__ = "Person"
    ...
    ...     name: str
    ...     age: int | None = None
    ...     email: str = Field(unique=True)
    ...
    ...     friends: list["Person"] = Relationship("FRIEND_OF", direction="both")

vid: int | None
220    @property
221    def vid(self) -> int | None:
222        """The vertex ID assigned by the database."""
223        return self._vid

The vertex ID assigned by the database.

uid: str | None
225    @property
226    def uid(self) -> str | None:
227        """The unique identifier (content-addressed hash)."""
228        return self._uid

The unique identifier (content-addressed hash).

is_persisted: bool
230    @property
231    def is_persisted(self) -> bool:
232        """Whether this node has been saved to the database."""
233        return self._vid is not None

Whether this node has been saved to the database.

is_dirty: bool
235    @property
236    def is_dirty(self) -> bool:
237        """Whether this node has unsaved changes."""
238        return bool(self._dirty)

Whether this node has unsaved changes.

@classmethod
def get_property_fields(cls) -> dict[str, pydantic.fields.FieldInfo]:
260    @classmethod
261    def get_property_fields(cls) -> dict[str, FieldInfo]:
262        """Get all property fields (excluding relationships)."""
263        return {
264            name: info
265            for name, info in cls.model_fields.items()
266            if name not in cls.__relationships__
267        }

Get all property fields (excluding relationships).

@classmethod
def get_relationship_fields(cls) -> dict[str, RelationshipConfig]:
269    @classmethod
270    def get_relationship_fields(cls) -> dict[str, RelationshipConfig]:
271        """Get all relationship field configurations."""
272        return cls.__relationships__

Get all relationship field configurations.

def to_properties(self) -> dict[str, typing.Any]:
274    def to_properties(self) -> dict[str, Any]:
275        """Convert to a property dictionary for database storage.
276
277        Uses python_to_db_value() for type conversion. Includes None
278        explicitly so null-outs work.
279        """
280        return _model_to_properties(self, self.get_property_fields())

Convert to a property dictionary for database storage.

Uses python_to_db_value() for type conversion. Includes None explicitly so null-outs work.

@classmethod
def from_properties( cls, props: dict[str, typing.Any], *, vid: int | None = None, uid: str | None = None, session: UniSession | None = None) -> UniNode:
282    @classmethod
283    def from_properties(
284        cls,
285        props: dict[str, Any],
286        *,
287        vid: int | None = None,
288        uid: str | None = None,
289        session: UniSession | None = None,
290    ) -> UniNode:
291        """Create an instance from a property dictionary.
292
293        Accepts _id (string->int vid) and _label from uni-db node dicts.
294        Does not mutate the input dict.
295        """
296        data = dict(props)
297
298        raw_id = data.pop("_id", None)
299        if raw_id is not None and vid is None:
300            vid = int(raw_id) if not isinstance(raw_id, int) else raw_id
301        data.pop("_label", None)
302
303        converted = _convert_db_values(data, cls)
304
305        instance = cls.model_validate(converted)
306        if vid is not None:
307            instance._vid = vid
308        if uid is not None:
309            instance._uid = uid
310        if session is not None:
311            instance._session = session
312        instance._is_new = vid is None
313        instance._dirty = set()
314        return instance

Create an instance from a property dictionary.

Accepts _id (string->int vid) and _label from uni-db node dicts. Does not mutate the input dict.

class UniEdge(pydantic.main.BaseModel):
class UniEdge(BaseModel, metaclass=UniModelMeta):
    """
    Base class for graph edge models with properties.

    Subclass this to define edge types with typed properties.
    Edges represent relationships between nodes.

    Attributes:
        __edge_type__: The edge type name. Defaults to the class name.
        __from__: The source node type(s).
        __to__: The target node type(s).

    Private Attributes:
        _eid: The edge ID assigned by the database.
        _src_vid: The source vertex ID.
        _dst_vid: The destination vertex ID.
        _session: Reference to the owning session.
        _is_new: True until the edge is attached or loaded with an edge ID.

    Example:
        >>> class FriendshipEdge(UniEdge):
        ...     __edge_type__ = "FRIEND_OF"
        ...     __from__ = Person
        ...     __to__ = Person
        ...
        ...     since: date
        ...     strength: float = 1.0
    """

    model_config = ConfigDict(
        # Reject fields that are not declared on the model
        extra="forbid",
        validate_assignment=True,
        arbitrary_types_allowed=True,
        use_enum_values=True,
    )

    # Class-level configuration
    __edge_type__: ClassVar[str] = ""
    __from__: ClassVar[type[UniNode] | tuple[type[UniNode], ...] | None] = None
    __to__: ClassVar[type[UniNode] | tuple[type[UniNode], ...] | None] = None
    __relationships__: ClassVar[dict[str, RelationshipConfig]] = {}

    # Private attributes
    _eid: int | None = PrivateAttr(default=None)
    _src_vid: int | None = PrivateAttr(default=None)
    _dst_vid: int | None = PrivateAttr(default=None)
    _session: UniSession | None = PrivateAttr(default=None)
    _is_new: bool = PrivateAttr(default=True)

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Default ``__edge_type__`` to the class name when none is set."""
        super().__init_subclass__(**kwargs)
        # Set default edge type to class name if not specified
        if not cls.__edge_type__:
            cls.__edge_type__ = cls.__name__

    @property
    def eid(self) -> int | None:
        """The edge ID assigned by the database."""
        return self._eid

    @property
    def src_vid(self) -> int | None:
        """The source vertex ID."""
        return self._src_vid

    @property
    def dst_vid(self) -> int | None:
        """The destination vertex ID."""
        return self._dst_vid

    @property
    def is_persisted(self) -> bool:
        """Whether this edge has been saved to the database."""
        return self._eid is not None

    def _attach(
        self,
        session: UniSession,
        eid: int,
        src_vid: int,
        dst_vid: int,
    ) -> None:
        """Attach this edge to a session with its database IDs."""
        self._session = session
        self._eid = eid
        self._src_vid = src_vid
        self._dst_vid = dst_vid
        self._is_new = False

    @classmethod
    def get_from_labels(cls) -> list[str]:
        """Get the source label names."""
        if cls.__from__ is None:
            return []
        if isinstance(cls.__from__, tuple):
            return [n.__label__ for n in cls.__from__]
        return [cls.__from__.__label__]

    @classmethod
    def get_to_labels(cls) -> list[str]:
        """Get the target label names."""
        if cls.__to__ is None:
            return []
        if isinstance(cls.__to__, tuple):
            return [n.__label__ for n in cls.__to__]
        return [cls.__to__.__label__]

    @classmethod
    def get_property_fields(cls) -> dict[str, FieldInfo]:
        """Get all property fields."""
        return dict(cls.model_fields)

    def to_properties(self) -> dict[str, Any]:
        """Convert to a property dictionary for database storage."""
        return _model_to_properties(self, self.get_property_fields())

    @classmethod
    def from_properties(
        cls,
        props: dict[str, Any],
        *,
        eid: int | None = None,
        src_vid: int | None = None,
        dst_vid: int | None = None,
        session: UniSession | None = None,
    ) -> UniEdge:
        """Create an instance from a property dictionary.

        Accepts _id, _type, _src, _dst from uni-db edge dicts.
        Does not mutate the input dict.

        Args:
            props: Property dictionary; may carry the reserved keys above.
            eid: Edge ID. When given, it takes precedence over ``_id``.
            src_vid: Source vertex ID. Takes precedence over ``_src``.
            dst_vid: Destination vertex ID. Takes precedence over ``_dst``.
            session: Session to store on the instance.

        Returns:
            A validated edge instance.
        """
        # Copy so the reserved keys can be popped without touching `props`.
        data = dict(props)

        raw_id = data.pop("_id", None)
        if raw_id is not None and eid is None:
            eid = int(raw_id) if not isinstance(raw_id, int) else raw_id
        # The edge type is implied by `cls`, so the stored value is discarded.
        data.pop("_type", None)
        raw_src = data.pop("_src", None)
        if raw_src is not None and src_vid is None:
            src_vid = int(raw_src) if not isinstance(raw_src, int) else raw_src
        raw_dst = data.pop("_dst", None)
        if raw_dst is not None and dst_vid is None:
            dst_vid = int(raw_dst) if not isinstance(raw_dst, int) else raw_dst

        converted = _convert_db_values(data, cls)

        instance = cls.model_validate(converted)
        if eid is not None:
            instance._eid = eid
        if src_vid is not None:
            instance._src_vid = src_vid
        if dst_vid is not None:
            instance._dst_vid = dst_vid
        if session is not None:
            instance._session = session
        instance._is_new = eid is None
        return instance

    @classmethod
    def from_edge_result(
        cls,
        data: dict[str, Any],
        *,
        session: UniSession | None = None,
    ) -> UniEdge:
        """Create an instance from a uni-db edge result dict.

        Convenience method that handles _id, _type, _src, _dst keys.
        """
        return cls.from_properties(data, session=session)

    def __repr__(self) -> str:
        """Return the class name, the eid (or "unsaved"), and the model repr."""
        # Compare against None so eid 0 is shown as an eid, matching
        # is_persisted.
        eid_str = f"eid={self._eid}" if self._eid is not None else "unsaved"
        return f"{self.__class__.__name__}({eid_str}, {super().__repr__()})"

Base class for graph edge models with properties.

Subclass this to define edge types with typed properties. Edges represent relationships between nodes.

Attributes:
    __edge_type__: The edge type name.
    __from__: The source node type(s).
    __to__: The target node type(s).

Private Attributes:
    _eid: The edge ID assigned by the database.
    _src_vid: The source vertex ID.
    _dst_vid: The destination vertex ID.
    _session: Reference to the owning session.

Example:

    >>> class FriendshipEdge(UniEdge):
    ...     __edge_type__ = "FRIEND_OF"
    ...     __from__ = Person
    ...     __to__ = Person
    ...
    ...     since: date
    ...     strength: float = 1.0

eid: int | None
375    @property
376    def eid(self) -> int | None:
377        """The edge ID assigned by the database."""
378        return self._eid

The edge ID assigned by the database.

src_vid: int | None
380    @property
381    def src_vid(self) -> int | None:
382        """The source vertex ID."""
383        return self._src_vid

The source vertex ID.

dst_vid: int | None
385    @property
386    def dst_vid(self) -> int | None:
387        """The destination vertex ID."""
388        return self._dst_vid

The destination vertex ID.

is_persisted: bool
390    @property
391    def is_persisted(self) -> bool:
392        """Whether this edge has been saved to the database."""
393        return self._eid is not None

Whether this edge has been saved to the database.

@classmethod
def get_from_labels(cls) -> list[str]:
409    @classmethod
410    def get_from_labels(cls) -> list[str]:
411        """Get the source label names."""
412        if cls.__from__ is None:
413            return []
414        if isinstance(cls.__from__, tuple):
415            return [n.__label__ for n in cls.__from__]
416        return [cls.__from__.__label__]

Get the source label names.

@classmethod
def get_to_labels(cls) -> list[str]:
418    @classmethod
419    def get_to_labels(cls) -> list[str]:
420        """Get the target label names."""
421        if cls.__to__ is None:
422            return []
423        if isinstance(cls.__to__, tuple):
424            return [n.__label__ for n in cls.__to__]
425        return [cls.__to__.__label__]

Get the target label names.

@classmethod
def get_property_fields(cls) -> dict[str, pydantic.fields.FieldInfo]:
427    @classmethod
428    def get_property_fields(cls) -> dict[str, FieldInfo]:
429        """Get all property fields."""
430        return dict(cls.model_fields)

Get all property fields.

def to_properties(self) -> dict[str, typing.Any]:
432    def to_properties(self) -> dict[str, Any]:
433        """Convert to a property dictionary for database storage."""
434        return _model_to_properties(self, self.get_property_fields())

Convert to a property dictionary for database storage.

@classmethod
def from_properties( cls, props: dict[str, typing.Any], *, eid: int | None = None, src_vid: int | None = None, dst_vid: int | None = None, session: UniSession | None = None) -> UniEdge:
436    @classmethod
437    def from_properties(
438        cls,
439        props: dict[str, Any],
440        *,
441        eid: int | None = None,
442        src_vid: int | None = None,
443        dst_vid: int | None = None,
444        session: UniSession | None = None,
445    ) -> UniEdge:
446        """Create an instance from a property dictionary.
447
448        Accepts _id, _type, _src, _dst from uni-db edge dicts.
449        Does not mutate the input dict.
450        """
451        data = dict(props)
452
453        raw_id = data.pop("_id", None)
454        if raw_id is not None and eid is None:
455            eid = int(raw_id) if not isinstance(raw_id, int) else raw_id
456        data.pop("_type", None)
457        raw_src = data.pop("_src", None)
458        if raw_src is not None and src_vid is None:
459            src_vid = int(raw_src) if not isinstance(raw_src, int) else raw_src
460        raw_dst = data.pop("_dst", None)
461        if raw_dst is not None and dst_vid is None:
462            dst_vid = int(raw_dst) if not isinstance(raw_dst, int) else raw_dst
463
464        converted = _convert_db_values(data, cls)
465
466        instance = cls.model_validate(converted)
467        if eid is not None:
468            instance._eid = eid
469        if src_vid is not None:
470            instance._src_vid = src_vid
471        if dst_vid is not None:
472            instance._dst_vid = dst_vid
473        if session is not None:
474            instance._session = session
475        instance._is_new = eid is None
476        return instance

Create an instance from a property dictionary.

Accepts _id, _type, _src, _dst from uni-db edge dicts. Does not mutate the input dict.

@classmethod
def from_edge_result( cls, data: dict[str, typing.Any], *, session: UniSession | None = None) -> UniEdge:
478    @classmethod
479    def from_edge_result(
480        cls,
481        data: dict[str, Any],
482        *,
483        session: UniSession | None = None,
484    ) -> UniEdge:
485        """Create an instance from a uni-db edge result dict.
486
487        Convenience method that handles _id, _type, _src, _dst keys.
488        """
489        return cls.from_properties(data, session=session)

Create an instance from a uni-db edge result dict.

Convenience method that handles _id, _type, _src, _dst keys.

class UniSession:
160class UniSession:
161    """
162    Session for interacting with the graph database using Pydantic models.
163
164    The session manages model registration, schema synchronization,
165    and provides CRUD operations and query building.
166
167    Example:
168        >>> from uni_db import Uni
169        >>> from uni_pydantic import UniSession
170        >>>
171        >>> db = Uni("./my_graph")
172        >>> session = UniSession(db)
173        >>> session.register(Person, Company)
174        >>> session.sync_schema()
175        >>>
176        >>> alice = Person(name="Alice", age=30)
177        >>> session.add(alice)
178        >>> session.commit()
179    """
180
181    def __init__(self, db: uni_db.Uni) -> None:
182        self._db = db
183        self._db_session = db.session()
184        self._schema_gen = SchemaGenerator()
185        self._identity_map: WeakValueDictionary[tuple[str, int], UniNode] = (
186            WeakValueDictionary()
187        )
188        self._pending_new: list[UniNode] = []
189        self._pending_delete: list[UniNode] = []
190
191    def __enter__(self) -> UniSession:
192        return self
193
194    def __exit__(
195        self,
196        exc_type: type[BaseException] | None,
197        exc_val: BaseException | None,
198        exc_tb: TracebackType | None,
199    ) -> None:
200        self.close()
201
202    def close(self) -> None:
203        """Close the session and clear all pending state."""
204        self._pending_new.clear()
205        self._pending_delete.clear()
206
207    @property
208    def db(self) -> uni_db.Uni:
209        """Access the underlying uni_db.Uni for low-level operations."""
210        return self._db
211
212    def locy(
213        self, program: str, params: dict[str, Any] | None = None
214    ) -> uni_db.LocyResult:
215        """
216        Evaluate a Locy program and return derived facts, stats, and warnings.
217
218        Delegates to the underlying ``uni_db.Session.locy()``.
219        """
220        return self._db_session.locy(program, params)
221
222    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
223        """
224        Register model classes with the session.
225
226        Registered models can be used for schema generation and queries.
227
228        Args:
229            *models: UniNode or UniEdge subclasses to register.
230        """
231        self._schema_gen.register(*models)
232
233    def sync_schema(self) -> None:
234        """
235        Synchronize database schema with registered models.
236
237        Creates labels, edge types, properties, and indexes as needed.
238        This is additive-only; it won't remove existing schema elements.
239        """
240        self._schema_gen.apply_to_database(self._db)
241
242    def query(self, model: type[NodeT]) -> QueryBuilder[NodeT]:
243        """
244        Create a query builder for the given model.
245
246        Args:
247            model: The UniNode subclass to query.
248
249        Returns:
250            A QueryBuilder for constructing queries.
251        """
252        return QueryBuilder(self, model)
253
254    def add(self, entity: UniNode) -> None:
255        """
256        Add a new entity to be persisted.
257
258        The entity will be inserted on the next commit().
259        """
260        if entity.is_persisted:
261            raise SessionError(f"Entity {entity!r} is already persisted")
262        entity._session = self
263        self._pending_new.append(entity)
264
265    def add_all(self, entities: Sequence[UniNode]) -> None:
266        """Add multiple entities to be persisted."""
267        for entity in entities:
268            self.add(entity)
269
270    def delete(self, entity: UniNode) -> None:
271        """Mark an entity for deletion."""
272        if not entity.is_persisted:
273            raise NotPersisted(entity)
274        self._pending_delete.append(entity)
275
276    def get(
277        self,
278        model: type[NodeT],
279        vid: int | None = None,
280        uid: str | None = None,
281        **kwargs: Any,
282    ) -> NodeT | None:
283        """
284        Get an entity by ID or unique properties.
285
286        Args:
287            model: The model type to retrieve.
288            vid: Vertex ID to look up.
289            uid: Unique ID to look up.
290            **kwargs: Property equality filters.
291
292        Returns:
293            The model instance or None if not found.
294        """
295        # Check identity map first
296        if vid is not None:
297            cached = self._identity_map.get((model.__label__, vid))
298            if cached is not None:
299                return cached  # type: ignore[return-value]
300
301        # Build query
302        label = model.__label__
303        params: dict[str, Any] = {}
304
305        if vid is not None:
306            cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
307            params["vid"] = vid
308        elif uid is not None:
309            cypher = f"MATCH (n:{label}) WHERE n._uid = $uid RETURN {_NODE_RETURN}"
310            params["uid"] = uid
311        elif kwargs:
312            # Validate property names
313            for k in kwargs:
314                _validate_property(k, model)
315            conditions = [f"n.{k} = ${k}" for k in kwargs]
316            cypher = f"MATCH (n:{label}) WHERE {' AND '.join(conditions)} RETURN {_NODE_RETURN} LIMIT 1"
317            params.update(kwargs)
318        else:
319            raise ValueError("Must provide vid, uid, or property filters")
320
321        results = self._db_session.query(cypher, params)
322        if not results:
323            return None
324
325        node_data = _row_to_node_dict(results[0].to_dict())
326        if node_data is None:
327            return None
328        return self._result_to_model(node_data, model)
329
330    def refresh(self, entity: UniNode) -> None:
331        """Refresh an entity's properties from the database."""
332        if not entity.is_persisted:
333            raise NotPersisted(entity)
334
335        label = entity.__class__.__label__
336        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
337        results = self._db_session.query(cypher, {"vid": entity._vid})
338
339        if not results:
340            raise SessionError(f"Entity with vid={entity._vid} no longer exists")
341
342        # Update properties
343        props = _row_to_node_dict(results[0].to_dict())
344        if props is None:
345            raise SessionError(f"Entity with vid={entity._vid} no longer exists")
346        try:
347            hints = get_type_hints(type(entity))
348        except Exception:
349            hints = {}
350
351        for field_name in entity.get_property_fields():
352            if field_name in props:
353                value = props[field_name]
354                if field_name in hints:
355                    value = db_to_python_value(value, hints[field_name])
356                setattr(entity, field_name, value)
357
358        entity._mark_clean()
359
360    def commit(self) -> None:
361        """
362        Commit all pending changes to the database.
363
364        This persists new entities, updates dirty entities,
365        and deletes marked entities.
366        """
367        # Insert new entities
368        for entity in self._pending_new:
369            self._create_node(entity)
370
371        # Update dirty entities in identity map
372        for (label, vid), entity in list(self._identity_map.items()):
373            if entity.is_dirty and entity.is_persisted:
374                self._update_node(entity)
375
376        # Delete marked entities
377        for entity in self._pending_delete:
378            self._delete_node(entity)
379
380        # Flush to storage
381        self._db.flush()
382
383        # Clear pending lists
384        self._pending_new.clear()
385        self._pending_delete.clear()
386
387    def rollback(self) -> None:
388        """Discard all pending changes."""
389        # Clear pending new — detach entities
390        for entity in self._pending_new:
391            entity._session = None
392        self._pending_new.clear()
393
394        # Clear pending deletes
395        self._pending_delete.clear()
396
397        # Invalidate dirty identity map entries
398        for entity in list(self._identity_map.values()):
399            if entity.is_dirty:
400                self.refresh(entity)
401
402    @contextmanager
403    def transaction(self) -> Iterator[UniTransaction]:
404        """Create a transaction context."""
405        tx = UniTransaction(self)
406        with tx:
407            yield tx
408
409    def begin(self) -> UniTransaction:
410        """Begin a new transaction."""
411        tx = UniTransaction(self)
412        tx._tx = self._db_session.tx()
413        return tx
414
415    def cypher(
416        self,
417        query: str,
418        params: dict[str, Any] | None = None,
419        result_type: type[NodeT] | None = None,
420    ) -> list[NodeT] | list[dict[str, Any]]:
421        """
422        Execute a raw Cypher query.
423
424        Args:
425            query: Cypher query string.
426            params: Query parameters.
427            result_type: Optional model type for result mapping.
428
429        Returns:
430            List of results (model instances if result_type provided).
431        """
432        results = self._db_session.query(query, params)
433
434        if result_type is None:
435            return [r.to_dict() for r in results]
436
437        # Map results to model instances
438        mapped = []
439        for raw_row in results:
440            row = raw_row.to_dict()
441            # Try to find node data in the row
442            for key, value in row.items():
443                if isinstance(value, dict):
444                    # Check for _id/_label keys (uni-db node dict)
445                    if "_id" in value and "_label" in value:
446                        instance = self._result_to_model(value, result_type)
447                        if instance:
448                            mapped.append(instance)
449                            break
450                    # Also check if _label matches registered model
451                    elif "_label" in value:
452                        label = value["_label"]
453                        if label in self._schema_gen._node_models:
454                            model = self._schema_gen._node_models[label]
455                            instance = self._result_to_model(value, model)
456                            if instance:
457                                mapped.append(instance)
458                                break
459            else:
460                # Try the first column
461                first_value = next(iter(row.values()), None)
462                if isinstance(first_value, dict):
463                    instance = self._result_to_model(first_value, result_type)
464                    if instance:
465                        mapped.append(instance)
466
467        return mapped
468
469    @staticmethod
470    def _validate_edge_endpoints(
471        source: UniNode, target: UniNode
472    ) -> tuple[int, int, str, str]:
473        """Validate that both endpoints are persisted and return (src_vid, dst_vid, src_label, dst_label)."""
474        if not source.is_persisted:
475            raise NotPersisted(source)
476        if not target.is_persisted:
477            raise NotPersisted(target)
478        return (
479            source._vid,
480            target._vid,
481            source.__class__.__label__,
482            target.__class__.__label__,
483        )
484
485    @staticmethod
486    def _normalize_edge_properties(
487        properties: dict[str, Any] | UniEdge | None,
488    ) -> dict[str, Any]:
489        """Normalize edge properties from dict, UniEdge, or None."""
490        if isinstance(properties, UniEdge):
491            return properties.to_properties()
492        if properties:
493            return properties
494        return {}
495
496    def create_edge(
497        self,
498        source: UniNode,
499        edge_type: str,
500        target: UniNode,
501        properties: dict[str, Any] | UniEdge | None = None,
502    ) -> None:
503        """Create an edge between two nodes."""
504        src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
505            source, target
506        )
507        props = self._normalize_edge_properties(properties)
508
509        # Build CREATE edge query with labels (required by Cypher implementation)
510        props_str = ", ".join(f"{k}: ${k}" for k in props)
511        if props_str:
512            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type} {{{props_str}}}]->(b)"
513        else:
514            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type}]->(b)"
515
516        params = {"src": src_vid, "dst": dst_vid, **props}
517        with self._db_session.tx() as tx:
518            tx.execute(cypher, params)
519            tx.commit()
520
521    def delete_edge(
522        self,
523        source: UniNode,
524        edge_type: str,
525        target: UniNode,
526    ) -> int:
527        """Delete edges between two nodes. Returns the number of deleted edges."""
528        src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
529            source, target
530        )
531        cypher = (
532            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
533            f"WHERE a._vid = $src AND b._vid = $dst "
534            f"DELETE r RETURN count(r) as count"
535        )
536        with self._db_session.tx() as tx:
537            results = tx.query(cypher, {"src": src_vid, "dst": dst_vid})
538            tx.commit()
539        return cast(int, results[0]["count"]) if results else 0
540
541    def update_edge(
542        self,
543        source: UniNode,
544        edge_type: str,
545        target: UniNode,
546        properties: dict[str, Any],
547    ) -> int:
548        """Update properties on edges between two nodes. Returns the number of updated edges."""
549        src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
550            source, target
551        )
552        set_parts = [f"r.{k} = ${k}" for k in properties]
553        params: dict[str, Any] = {"src": src_vid, "dst": dst_vid, **properties}
554        cypher = (
555            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
556            f"WHERE a._vid = $src AND b._vid = $dst "
557            f"SET {', '.join(set_parts)} "
558            f"RETURN count(r) as count"
559        )
560        with self._db_session.tx() as tx:
561            results = tx.query(cypher, params)
562            tx.commit()
563        return cast(int, results[0]["count"]) if results else 0
564
565    def get_edge(
566        self,
567        source: UniNode,
568        edge_type: str,
569        target: UniNode,
570        edge_model: type[EdgeT] | None = None,
571    ) -> list[dict[str, Any]] | list[EdgeT]:
572        """Get edges between two nodes. Returns dicts or edge model instances."""
573        src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
574            source, target
575        )
576        cypher = (
577            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
578            f"WHERE a._vid = $src AND b._vid = $dst "
579            f"RETURN properties(r) AS _props, id(r) AS _eid"
580        )
581        results = self._db_session.query(cypher, {"src": src_vid, "dst": dst_vid})
582        rows = [r.to_dict() for r in results]
583
584        if edge_model is None:
585            edge_dicts: list[dict[str, Any]] = []
586            for row in rows:
587                props = row.get("_props", {})
588                if isinstance(props, dict):
589                    edge_dict = dict(props)
590                    edge_dict["_eid"] = row.get("_eid")
591                    edge_dicts.append(edge_dict)
592            return edge_dicts
593
594        edges = []
595        for row in rows:
596            r_data = row.get("_props", {})
597            if isinstance(r_data, dict):
598                edge = edge_model.from_properties(
599                    r_data,
600                    src_vid=src_vid,
601                    dst_vid=dst_vid,
602                    session=self,
603                )
604                edges.append(edge)
605        return edges
606
607    def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
608        """
609        Bulk-add entities using bulk_writer for performance.
610
611        Groups entities by label and uses db.bulk_writer().
612        Returns VIDs and attaches sessions.
613
614        Args:
615            entities: Sequence of UniNode instances to bulk-insert.
616
617        Returns:
618            List of assigned vertex IDs.
619
620        Raises:
621            BulkLoadError: If bulk insertion fails.
622        """
623        if not entities:
624            return []
625
626        # Group by label
627        by_label: dict[str, list[UniNode]] = {}
628        for entity in entities:
629            label = entity.__class__.__label__
630            if label not in by_label:
631                by_label[label] = []
632            by_label[label].append(entity)
633
634        all_vids: list[int] = []
635        try:
636            for label, group in by_label.items():
637                # Run before_create hooks
638                for entity in group:
639                    run_hooks(entity, _BEFORE_CREATE)
640
641                # Convert to property dicts
642                prop_dicts = [e.to_properties() for e in group]
643
644                # Bulk insert via transaction
645                tx = self._db_session.tx()
646                with tx.bulk_writer().build() as bw:
647                    vids = bw.insert_vertices(label, prop_dicts)
648                    bw.commit()
649                tx.commit()
650
651                # Attach sessions and record VIDs
652                for entity, vid in zip(group, vids):
653                    entity._attach_session(self, vid)
654                    self._identity_map[(label, vid)] = entity
655                    run_hooks(entity, _AFTER_CREATE)
656                    entity._mark_clean()
657
658                all_vids.extend(vids)
659        except Exception as e:
660            raise BulkLoadError(f"Bulk insert failed: {e}") from e
661
662        return all_vids
663
664    def explain(self, cypher: str) -> uni_db.ExplainOutput:
665        """Get the query execution plan without running it."""
666        return self._db_session.explain(cypher)
667
668    def profile(self, cypher: str) -> tuple[uni_db.QueryResult, uni_db.ProfileOutput]:
669        """Run the query with profiling and return results + stats."""
670        return self._db_session.profile(cypher)
671
672    def save_schema(self, path: str) -> None:
673        """Save the database schema to a file."""
674        self._db.save_schema(path)
675
676    def load_schema(self, path: str) -> None:
677        """Load a database schema from a file."""
678        self._db.load_schema(path)
679
680    # -------------------------------------------------------------------------
681    # Internal Methods
682    # -------------------------------------------------------------------------
683
684    def _create_node(self, entity: UniNode) -> None:
685        """Create a node in the database."""
686        # Run before_create hooks
687        run_hooks(entity, _BEFORE_CREATE)
688
689        label = entity.__class__.__label__
690        props = entity.to_properties()
691
692        # Build CREATE query
693        props_str = ", ".join(f"{k}: ${k}" for k in props)
694        cypher = f"CREATE (n:{label} {{{props_str}}}) RETURN id(n) as vid"
695
696        with self._db_session.tx() as tx:
697            results = tx.query(cypher, props)
698            tx.commit()
699        if results:
700            vid = results[0]["vid"]
701            entity._attach_session(self, vid)
702
703            # Add to identity map
704            self._identity_map[(label, vid)] = entity
705
706        # Run after_create hooks
707        run_hooks(entity, _AFTER_CREATE)
708        entity._mark_clean()
709
710    def _create_node_in_tx(self, entity: UniNode, tx: uni_db.Transaction) -> None:
711        """Create a node within a transaction."""
712        run_hooks(entity, _BEFORE_CREATE)
713
714        label = entity.__class__.__label__
715        props = entity.to_properties()
716
717        props_str = ", ".join(f"{k}: ${k}" for k in props)
718        cypher = f"CREATE (n:{label} {{{props_str}}}) RETURN id(n) as vid"
719
720        results = tx.query(cypher, props)
721        if results:
722            vid = results[0]["vid"]
723            entity._attach_session(self, vid)
724            self._identity_map[(label, vid)] = entity
725
726        run_hooks(entity, _AFTER_CREATE)
727
728    def _create_edge_in_tx(
729        self,
730        source: UniNode,
731        edge_type: str,
732        target: UniNode,
733        properties: UniEdge | None,
734        tx: uni_db.Transaction,
735    ) -> None:
736        """Create an edge within a transaction."""
737        props = properties.to_properties() if properties else {}
738        src_label = source.__class__.__label__
739        dst_label = target.__class__.__label__
740
741        props_str = ", ".join(f"{k}: ${k}" for k in props)
742        if props_str:
743            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[:{edge_type} {{{props_str}}}]->(b)"
744        else:
745            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[:{edge_type}]->(b)"
746
747        params = {"src": source._vid, "dst": target._vid, **props}
748        tx.query(cypher, params)
749
750    def _update_node(self, entity: UniNode) -> None:
751        """Update a node in the database."""
752        run_hooks(entity, _BEFORE_UPDATE)
753
754        label = entity.__class__.__label__
755
756        # Convert dirty prop values via python_to_db_value
757        try:
758            hints = get_type_hints(type(entity))
759        except Exception:
760            hints = {}
761
762        dirty_props = {}
763        for name in entity._dirty:
764            value = getattr(entity, name)
765            if name in hints:
766                value = python_to_db_value(value, hints[name])
767            dirty_props[name] = value
768
769        if not dirty_props:
770            return
771
772        set_clause = ", ".join(f"n.{k} = ${k}" for k in dirty_props)
773        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid SET {set_clause}"
774        params = {"vid": entity._vid, **dirty_props}
775
776        with self._db_session.tx() as tx:
777            tx.execute(cypher, params)
778            tx.commit()
779
780        run_hooks(entity, _AFTER_UPDATE)
781        entity._mark_clean()
782
783    def _delete_node(self, entity: UniNode) -> None:
784        """Delete a node from the database."""
785        run_hooks(entity, _BEFORE_DELETE)
786
787        label = entity.__class__.__label__
788        vid = entity._vid
789
790        # DETACH DELETE to also remove connected edges
791        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid DETACH DELETE n"
792        with self._db_session.tx() as tx:
793            tx.execute(cypher, {"vid": vid})
794            tx.commit()
795
796        # Remove from identity map
797        if vid is not None and (label, vid) in self._identity_map:
798            del self._identity_map[(label, vid)]
799
800        # Clear entity IDs
801        entity._vid = None
802        entity._uid = None
803        entity._session = None
804
805        run_hooks(entity, _AFTER_DELETE)
806
807    def _result_to_model(
808        self,
809        data: dict[str, Any],
810        model: type[NodeT],
811    ) -> NodeT | None:
812        """Convert a query result row to a model instance.
813
814        Does not mutate the input dict.
815        """
816        if not data:
817            return None
818
819        # Work on a copy
820        data = dict(data)
821
822        # Run before_load hooks
823        data = run_class_hooks(model, _BEFORE_LOAD, data) or data
824
825        # Extract _id → vid (uni-db returns _id as string or int)
826        vid = data.pop("_id", None)
827        if vid is None:
828            vid = data.pop("_vid", None)
829        if vid is None:
830            vid = data.pop("vid", None)
831        if vid is not None and not isinstance(vid, int):
832            vid = int(vid)
833
834        # Remove _label (informational)
835        data.pop("_label", None)
836
837        try:
838            instance = cast(
839                NodeT,
840                model.from_properties(
841                    data,
842                    vid=vid,
843                    session=self,
844                ),
845            )
846        except Exception:
847            # If validation fails, return None
848            return None
849
850        # Add to identity map if we have a vid
851        if vid is not None:
852            existing = self._identity_map.get((model.__label__, vid))
853            if existing is not None:
854                return cast(NodeT, existing)
855            self._identity_map[(model.__label__, vid)] = instance
856
857        # Run after_load hooks
858        run_hooks(instance, _AFTER_LOAD)
859
860        return instance
861
862    def _load_relationship(
863        self,
864        entity: UniNode,
865        descriptor: RelationshipDescriptor[Any],
866    ) -> list[UniNode] | UniNode | None:
867        """Load a relationship for an entity."""
868        if not entity.is_persisted:
869            raise NotPersisted(entity)
870
871        config = descriptor.config
872        label = entity.__class__.__label__
873        pattern = _edge_pattern(config.edge_type, config.direction)
874
875        cypher = (
876            f"MATCH (a:{label}){pattern}(b) WHERE id(a) = $vid "
877            f"RETURN properties(b) AS _props, id(b) AS _vid, labels(b) AS _labels"
878        )
879        results = self._db_session.query(cypher, {"vid": entity._vid})
880
881        nodes = []
882        for raw_row in results:
883            row = raw_row.to_dict()
884            node_data = _row_to_node_dict(row)
885            if node_data is None:
886                continue
887            # Try to find the model for this node
888            node_label = node_data.get("_label")
889            if node_label and node_label in self._schema_gen._node_models:
890                model = self._schema_gen._node_models[node_label]
891                instance = self._result_to_model(node_data, model)
892                if instance:
893                    nodes.append(instance)
894
895        if not descriptor.is_list:
896            return nodes[0] if nodes else None
897        return nodes
898
899    def _eager_load_relationships(
900        self,
901        entities: list[NodeT],
902        relationships: list[str],
903    ) -> None:
904        """Eager load relationships for a list of entities."""
905        if not entities:
906            return
907
908        model = type(entities[0])
909        rel_configs = model.get_relationship_fields()
910
911        for rel_name in relationships:
912            if rel_name not in rel_configs:
913                continue
914
915            config = rel_configs[rel_name]
916            label = model.__label__
917            vids = [e._vid for e in entities if e._vid is not None]
918
919            if not vids:
920                continue
921
922            pattern = _edge_pattern(config.edge_type, config.direction)
923            cypher = (
924                f"MATCH (a:{label}){pattern}(b) WHERE id(a) IN $vids "
925                f"RETURN id(a) as src_vid, properties(b) AS _props, id(b) AS _vid, labels(b) AS _labels"
926            )
927            results = self._db_session.query(cypher, {"vids": vids})
928
929            # Group results by source vid
930            by_source: dict[int, list[Any]] = {}
931            for raw_row in results:
932                row = raw_row.to_dict()
933                src_vid = row["src_vid"]
934                node_data = _row_to_node_dict(row)
935                if node_data is None:
936                    continue
937                if src_vid not in by_source:
938                    by_source[src_vid] = []
939                by_source[src_vid].append(node_data)
940
941            # Set cached values on entities
942            for entity in entities:
943                if entity._vid in by_source:
944                    related = by_source[entity._vid]
945                    cache_attr = f"_rel_cache_{rel_name}"
946                    setattr(entity, cache_attr, related)

Session for interacting with the graph database using Pydantic models.

The session manages model registration, schema synchronization, and provides CRUD operations and query building.

Example:

    >>> from uni_db import Uni
    >>> from uni_pydantic import UniSession
    >>>
    >>> db = Uni("./my_graph")
    >>> session = UniSession(db)
    >>> session.register(Person, Company)
    >>> session.sync_schema()
    >>>
    >>> alice = Person(name="Alice", age=30)
    >>> session.add(alice)
    >>> session.commit()

UniSession(db: Uni)
181    def __init__(self, db: uni_db.Uni) -> None:
182        self._db = db
183        self._db_session = db.session()
184        self._schema_gen = SchemaGenerator()
185        self._identity_map: WeakValueDictionary[tuple[str, int], UniNode] = (
186            WeakValueDictionary()
187        )
188        self._pending_new: list[UniNode] = []
189        self._pending_delete: list[UniNode] = []
def close(self) -> None:
202    def close(self) -> None:
203        """Close the session and clear all pending state."""
204        self._pending_new.clear()
205        self._pending_delete.clear()

Close the session and clear all pending state.

db: Uni
207    @property
208    def db(self) -> uni_db.Uni:
209        """Access the underlying uni_db.Uni for low-level operations."""
210        return self._db

Access the underlying uni_db.Uni for low-level operations.

def locy( self, program: str, params: dict[str, typing.Any] | None = None) -> LocyResult:
212    def locy(
213        self, program: str, params: dict[str, Any] | None = None
214    ) -> uni_db.LocyResult:
215        """
216        Evaluate a Locy program and return derived facts, stats, and warnings.
217
218        Delegates to the underlying ``uni_db.Session.locy()``.
219        """
220        return self._db_session.locy(program, params)

Evaluate a Locy program and return derived facts, stats, and warnings.

Delegates to the underlying uni_db.Session.locy().

def register( self, *models: type[UniNode] | type[UniEdge]) -> None:
222    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
223        """
224        Register model classes with the session.
225
226        Registered models can be used for schema generation and queries.
227
228        Args:
229            *models: UniNode or UniEdge subclasses to register.
230        """
231        self._schema_gen.register(*models)

Register model classes with the session.

Registered models can be used for schema generation and queries.

Args: *models: UniNode or UniEdge subclasses to register.

def sync_schema(self) -> None:
233    def sync_schema(self) -> None:
234        """
235        Synchronize database schema with registered models.
236
237        Creates labels, edge types, properties, and indexes as needed.
238        This is additive-only; it won't remove existing schema elements.
239        """
240        self._schema_gen.apply_to_database(self._db)

Synchronize database schema with registered models.

Creates labels, edge types, properties, and indexes as needed. This is additive-only; it won't remove existing schema elements.

def query(self, model: type[~NodeT]) -> QueryBuilder[~NodeT]:
242    def query(self, model: type[NodeT]) -> QueryBuilder[NodeT]:
243        """
244        Create a query builder for the given model.
245
246        Args:
247            model: The UniNode subclass to query.
248
249        Returns:
250            A QueryBuilder for constructing queries.
251        """
252        return QueryBuilder(self, model)

Create a query builder for the given model.

Args: model: The UniNode subclass to query.

Returns: A QueryBuilder for constructing queries.

def add(self, entity: UniNode) -> None:
254    def add(self, entity: UniNode) -> None:
255        """
256        Add a new entity to be persisted.
257
258        The entity will be inserted on the next commit().
259        """
260        if entity.is_persisted:
261            raise SessionError(f"Entity {entity!r} is already persisted")
262        entity._session = self
263        self._pending_new.append(entity)

Add a new entity to be persisted.

The entity will be inserted on the next commit().

def add_all(self, entities: Sequence[UniNode]) -> None:
265    def add_all(self, entities: Sequence[UniNode]) -> None:
266        """Add multiple entities to be persisted."""
267        for entity in entities:
268            self.add(entity)

Add multiple entities to be persisted.

def delete(self, entity: UniNode) -> None:
270    def delete(self, entity: UniNode) -> None:
271        """Mark an entity for deletion."""
272        if not entity.is_persisted:
273            raise NotPersisted(entity)
274        self._pending_delete.append(entity)

Mark an entity for deletion.

def get( self, model: type[~NodeT], vid: int | None = None, uid: str | None = None, **kwargs: Any) -> Optional[~NodeT]:
276    def get(
277        self,
278        model: type[NodeT],
279        vid: int | None = None,
280        uid: str | None = None,
281        **kwargs: Any,
282    ) -> NodeT | None:
283        """
284        Get an entity by ID or unique properties.
285
286        Args:
287            model: The model type to retrieve.
288            vid: Vertex ID to look up.
289            uid: Unique ID to look up.
290            **kwargs: Property equality filters.
291
292        Returns:
293            The model instance or None if not found.
294        """
295        # Check identity map first
296        if vid is not None:
297            cached = self._identity_map.get((model.__label__, vid))
298            if cached is not None:
299                return cached  # type: ignore[return-value]
300
301        # Build query
302        label = model.__label__
303        params: dict[str, Any] = {}
304
305        if vid is not None:
306            cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
307            params["vid"] = vid
308        elif uid is not None:
309            cypher = f"MATCH (n:{label}) WHERE n._uid = $uid RETURN {_NODE_RETURN}"
310            params["uid"] = uid
311        elif kwargs:
312            # Validate property names
313            for k in kwargs:
314                _validate_property(k, model)
315            conditions = [f"n.{k} = ${k}" for k in kwargs]
316            cypher = f"MATCH (n:{label}) WHERE {' AND '.join(conditions)} RETURN {_NODE_RETURN} LIMIT 1"
317            params.update(kwargs)
318        else:
319            raise ValueError("Must provide vid, uid, or property filters")
320
321        results = self._db_session.query(cypher, params)
322        if not results:
323            return None
324
325        node_data = _row_to_node_dict(results[0].to_dict())
326        if node_data is None:
327            return None
328        return self._result_to_model(node_data, model)

Get an entity by ID or unique properties.

Args: model: The model type to retrieve. vid: Vertex ID to look up. uid: Unique ID to look up. **kwargs: Property equality filters.

Returns: The model instance or None if not found.

def refresh(self, entity: UniNode) -> None:
330    def refresh(self, entity: UniNode) -> None:
331        """Refresh an entity's properties from the database."""
332        if not entity.is_persisted:
333            raise NotPersisted(entity)
334
335        label = entity.__class__.__label__
336        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
337        results = self._db_session.query(cypher, {"vid": entity._vid})
338
339        if not results:
340            raise SessionError(f"Entity with vid={entity._vid} no longer exists")
341
342        # Update properties
343        props = _row_to_node_dict(results[0].to_dict())
344        if props is None:
345            raise SessionError(f"Entity with vid={entity._vid} no longer exists")
346        try:
347            hints = get_type_hints(type(entity))
348        except Exception:
349            hints = {}
350
351        for field_name in entity.get_property_fields():
352            if field_name in props:
353                value = props[field_name]
354                if field_name in hints:
355                    value = db_to_python_value(value, hints[field_name])
356                setattr(entity, field_name, value)
357
358        entity._mark_clean()

Refresh an entity's properties from the database.

def commit(self) -> None:
360    def commit(self) -> None:
361        """
362        Commit all pending changes to the database.
363
364        This persists new entities, updates dirty entities,
365        and deletes marked entities.
366        """
367        # Insert new entities
368        for entity in self._pending_new:
369            self._create_node(entity)
370
371        # Update dirty entities in identity map
372        for (label, vid), entity in list(self._identity_map.items()):
373            if entity.is_dirty and entity.is_persisted:
374                self._update_node(entity)
375
376        # Delete marked entities
377        for entity in self._pending_delete:
378            self._delete_node(entity)
379
380        # Flush to storage
381        self._db.flush()
382
383        # Clear pending lists
384        self._pending_new.clear()
385        self._pending_delete.clear()

Commit all pending changes to the database.

This persists new entities, updates dirty entities, and deletes marked entities.

def rollback(self) -> None:
387    def rollback(self) -> None:
388        """Discard all pending changes."""
389        # Clear pending new — detach entities
390        for entity in self._pending_new:
391            entity._session = None
392        self._pending_new.clear()
393
394        # Clear pending deletes
395        self._pending_delete.clear()
396
397        # Invalidate dirty identity map entries
398        for entity in list(self._identity_map.values()):
399            if entity.is_dirty:
400                self.refresh(entity)

Discard all pending changes.

@contextmanager
def transaction(self) -> Iterator[UniTransaction]:
402    @contextmanager
403    def transaction(self) -> Iterator[UniTransaction]:
404        """Create a transaction context."""
405        tx = UniTransaction(self)
406        with tx:
407            yield tx

Create a transaction context.

def begin(self) -> UniTransaction:
409    def begin(self) -> UniTransaction:
410        """Begin a new transaction."""
411        tx = UniTransaction(self)
412        tx._tx = self._db_session.tx()
413        return tx

Begin a new transaction.

def cypher( self, query: str, params: dict[str, typing.Any] | None = None, result_type: type[~NodeT] | None = None) -> list[~NodeT] | list[dict[str, typing.Any]]:
    def cypher(
        self,
        query: str,
        params: dict[str, Any] | None = None,
        result_type: type[NodeT] | None = None,
    ) -> list[NodeT] | list[dict[str, Any]]:
        """
        Execute a raw Cypher query.

        Without ``result_type`` every row is returned as a plain dict. With
        ``result_type`` each row contributes at most one model instance: the
        first column value that can be mapped wins, and rows that yield no
        instance are dropped from the result.

        Args:
            query: Cypher query string.
            params: Query parameters.
            result_type: Optional model type for result mapping.

        Returns:
            List of results (model instances if result_type provided).
        """
        results = self._db_session.query(query, params)

        if result_type is None:
            return [r.to_dict() for r in results]

        # Map results to model instances
        mapped = []
        for raw_row in results:
            row = raw_row.to_dict()
            # Try to find node data in the row
            for key, value in row.items():
                if isinstance(value, dict):
                    # Check for _id/_label keys (uni-db node dict)
                    if "_id" in value and "_label" in value:
                        instance = self._result_to_model(value, result_type)
                        if instance:
                            mapped.append(instance)
                            # One instance per row: stop scanning columns.
                            break
                    # Also check if _label matches registered model
                    elif "_label" in value:
                        label = value["_label"]
                        if label in self._schema_gen._node_models:
                            # Here the registered model for the label is used,
                            # not the caller-supplied result_type.
                            model = self._schema_gen._node_models[label]
                            instance = self._result_to_model(value, model)
                            if instance:
                                mapped.append(instance)
                                break
            else:
                # for/else: reached only when no column produced an instance
                # (the loop finished without a break).
                # Try the first column
                first_value = next(iter(row.values()), None)
                if isinstance(first_value, dict):
                    instance = self._result_to_model(first_value, result_type)
                    if instance:
                        mapped.append(instance)

        return mapped

Execute a raw Cypher query.

Args: query: Cypher query string. params: Query parameters. result_type: Optional model type for result mapping.

Returns: List of results (model instances if result_type provided).

def create_edge( self, source: UniNode, edge_type: str, target: UniNode, properties: dict[str, typing.Any] | UniEdge | None = None) -> None:
496    def create_edge(
497        self,
498        source: UniNode,
499        edge_type: str,
500        target: UniNode,
501        properties: dict[str, Any] | UniEdge | None = None,
502    ) -> None:
503        """Create an edge between two nodes."""
504        src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
505            source, target
506        )
507        props = self._normalize_edge_properties(properties)
508
509        # Build CREATE edge query with labels (required by Cypher implementation)
510        props_str = ", ".join(f"{k}: ${k}" for k in props)
511        if props_str:
512            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type} {{{props_str}}}]->(b)"
513        else:
514            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type}]->(b)"
515
516        params = {"src": src_vid, "dst": dst_vid, **props}
517        with self._db_session.tx() as tx:
518            tx.execute(cypher, params)
519            tx.commit()

Create an edge between two nodes.

def delete_edge( self, source: UniNode, edge_type: str, target: UniNode) -> int:
521    def delete_edge(
522        self,
523        source: UniNode,
524        edge_type: str,
525        target: UniNode,
526    ) -> int:
527        """Delete edges between two nodes. Returns the number of deleted edges."""
528        src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
529            source, target
530        )
531        cypher = (
532            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
533            f"WHERE a._vid = $src AND b._vid = $dst "
534            f"DELETE r RETURN count(r) as count"
535        )
536        with self._db_session.tx() as tx:
537            results = tx.query(cypher, {"src": src_vid, "dst": dst_vid})
538            tx.commit()
539        return cast(int, results[0]["count"]) if results else 0

Delete edges between two nodes. Returns the number of deleted edges.

def update_edge( self, source: UniNode, edge_type: str, target: UniNode, properties: dict[str, typing.Any]) -> int:
541    def update_edge(
542        self,
543        source: UniNode,
544        edge_type: str,
545        target: UniNode,
546        properties: dict[str, Any],
547    ) -> int:
548        """Update properties on edges between two nodes. Returns the number of updated edges."""
549        src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
550            source, target
551        )
552        set_parts = [f"r.{k} = ${k}" for k in properties]
553        params: dict[str, Any] = {"src": src_vid, "dst": dst_vid, **properties}
554        cypher = (
555            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
556            f"WHERE a._vid = $src AND b._vid = $dst "
557            f"SET {', '.join(set_parts)} "
558            f"RETURN count(r) as count"
559        )
560        with self._db_session.tx() as tx:
561            results = tx.query(cypher, params)
562            tx.commit()
563        return cast(int, results[0]["count"]) if results else 0

Update properties on edges between two nodes. Returns the number of updated edges.

def get_edge( self, source: UniNode, edge_type: str, target: UniNode, edge_model: type[~EdgeT] | None = None) -> list[dict[str, typing.Any]] | list[~EdgeT]:
565    def get_edge(
566        self,
567        source: UniNode,
568        edge_type: str,
569        target: UniNode,
570        edge_model: type[EdgeT] | None = None,
571    ) -> list[dict[str, Any]] | list[EdgeT]:
572        """Get edges between two nodes. Returns dicts or edge model instances."""
573        src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
574            source, target
575        )
576        cypher = (
577            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
578            f"WHERE a._vid = $src AND b._vid = $dst "
579            f"RETURN properties(r) AS _props, id(r) AS _eid"
580        )
581        results = self._db_session.query(cypher, {"src": src_vid, "dst": dst_vid})
582        rows = [r.to_dict() for r in results]
583
584        if edge_model is None:
585            edge_dicts: list[dict[str, Any]] = []
586            for row in rows:
587                props = row.get("_props", {})
588                if isinstance(props, dict):
589                    edge_dict = dict(props)
590                    edge_dict["_eid"] = row.get("_eid")
591                    edge_dicts.append(edge_dict)
592            return edge_dicts
593
594        edges = []
595        for row in rows:
596            r_data = row.get("_props", {})
597            if isinstance(r_data, dict):
598                edge = edge_model.from_properties(
599                    r_data,
600                    src_vid=src_vid,
601                    dst_vid=dst_vid,
602                    session=self,
603                )
604                edges.append(edge)
605        return edges

Get edges between two nodes. Returns dicts or edge model instances.

def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
607    def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
608        """
609        Bulk-add entities using bulk_writer for performance.
610
611        Groups entities by label and uses db.bulk_writer().
612        Returns VIDs and attaches sessions.
613
614        Args:
615            entities: Sequence of UniNode instances to bulk-insert.
616
617        Returns:
618            List of assigned vertex IDs.
619
620        Raises:
621            BulkLoadError: If bulk insertion fails.
622        """
623        if not entities:
624            return []
625
626        # Group by label
627        by_label: dict[str, list[UniNode]] = {}
628        for entity in entities:
629            label = entity.__class__.__label__
630            if label not in by_label:
631                by_label[label] = []
632            by_label[label].append(entity)
633
634        all_vids: list[int] = []
635        try:
636            for label, group in by_label.items():
637                # Run before_create hooks
638                for entity in group:
639                    run_hooks(entity, _BEFORE_CREATE)
640
641                # Convert to property dicts
642                prop_dicts = [e.to_properties() for e in group]
643
644                # Bulk insert via transaction
645                tx = self._db_session.tx()
646                with tx.bulk_writer().build() as bw:
647                    vids = bw.insert_vertices(label, prop_dicts)
648                    bw.commit()
649                tx.commit()
650
651                # Attach sessions and record VIDs
652                for entity, vid in zip(group, vids):
653                    entity._attach_session(self, vid)
654                    self._identity_map[(label, vid)] = entity
655                    run_hooks(entity, _AFTER_CREATE)
656                    entity._mark_clean()
657
658                all_vids.extend(vids)
659        except Exception as e:
660            raise BulkLoadError(f"Bulk insert failed: {e}") from e
661
662        return all_vids

Bulk-add entities using bulk_writer for performance.

Groups entities by label and uses db.bulk_writer(). Returns VIDs and attaches sessions.

Args: entities: Sequence of UniNode instances to bulk-insert.

Returns: List of assigned vertex IDs.

Raises: BulkLoadError: If bulk insertion fails.

def explain(self, cypher: str) -> ExplainOutput:
664    def explain(self, cypher: str) -> uni_db.ExplainOutput:
665        """Get the query execution plan without running it."""
666        return self._db_session.explain(cypher)

Get the query execution plan without running it.

def profile(self, cypher: str) -> tuple[QueryResult, ProfileOutput]:
668    def profile(self, cypher: str) -> tuple[uni_db.QueryResult, uni_db.ProfileOutput]:
669        """Run the query with profiling and return results + stats."""
670        return self._db_session.profile(cypher)

Run the query with profiling and return results + stats.

def save_schema(self, path: str) -> None:
672    def save_schema(self, path: str) -> None:
673        """Save the database schema to a file."""
674        self._db.save_schema(path)

Save the database schema to a file.

def load_schema(self, path: str) -> None:
676    def load_schema(self, path: str) -> None:
677        """Load a database schema from a file."""
678        self._db.load_schema(path)

Load a database schema from a file.

class UniTransaction:
 61class UniTransaction:
 62    """
 63    Transaction context for atomic operations.
 64
 65    Provides commit/rollback semantics for a group of operations.
 66
 67    Example:
 68        >>> with session.transaction() as tx:
 69        ...     alice = Person(name="Alice")
 70        ...     tx.add(alice)
 71        ...     # Auto-commits on success, rolls back on exception
 72    """
 73
 74    def __init__(self, session: UniSession) -> None:
 75        self._session = session
 76        self._tx: uni_db.Transaction | None = None
 77        self._pending_nodes: list[UniNode] = []
 78        self._pending_edges: list[tuple[UniNode, str, UniNode, UniEdge | None]] = []
 79        self._committed = False
 80        self._rolled_back = False
 81
 82    def __enter__(self) -> UniTransaction:
 83        self._tx = self._session._db_session.tx()
 84        return self
 85
 86    def __exit__(
 87        self,
 88        exc_type: type[BaseException] | None,
 89        exc_val: BaseException | None,
 90        exc_tb: TracebackType | None,
 91    ) -> None:
 92        if exc_type is not None:
 93            self.rollback()
 94            return
 95        if not self._committed and not self._rolled_back:
 96            self.commit()
 97
 98    def add(self, entity: UniNode) -> None:
 99        """Add a node to be created in this transaction."""
100        self._pending_nodes.append(entity)
101
102    def create_edge(
103        self,
104        source: UniNode,
105        edge_type: str,
106        target: UniNode,
107        properties: UniEdge | None = None,
108        **kwargs: Any,
109    ) -> None:
110        """Create an edge between two nodes in this transaction."""
111        if not source.is_persisted:
112            raise NotPersisted(source)
113        if not target.is_persisted:
114            raise NotPersisted(target)
115        self._pending_edges.append((source, edge_type, target, properties))
116
117    def commit(self) -> None:
118        """Commit the transaction."""
119        if self._committed:
120            raise TransactionError("Transaction already committed")
121        if self._rolled_back:
122            raise TransactionError("Transaction already rolled back")
123
124        if self._tx is None:
125            raise TransactionError("Transaction not started")
126
127        try:
128            # Create pending nodes
129            for node in self._pending_nodes:
130                self._session._create_node_in_tx(node, self._tx)
131
132            # Create pending edges
133            for source, edge_type, target, props in self._pending_edges:
134                self._session._create_edge_in_tx(
135                    source, edge_type, target, props, self._tx
136                )
137
138            self._tx.commit()
139            self._committed = True
140
141            # Mark nodes as clean
142            for node in self._pending_nodes:
143                node._mark_clean()
144
145        except Exception as e:
146            self.rollback()
147            raise TransactionError(f"Commit failed: {e}") from e
148
149    def rollback(self) -> None:
150        """Rollback the transaction."""
151        if self._rolled_back:
152            return
153        if self._tx is not None:
154            self._tx.rollback()
155        self._rolled_back = True
156        self._pending_nodes.clear()
157        self._pending_edges.clear()

Transaction context for atomic operations.

Provides commit/rollback semantics for a group of operations.

Example:

with session.transaction() as tx: ... alice = Person(name="Alice") ... tx.add(alice) ... # Auto-commits on success, rolls back on exception

UniTransaction(session: UniSession)
    def __init__(self, session: UniSession) -> None:
        # Session whose _db_session supplies the database transaction and whose
        # _create_*_in_tx helpers perform the writes at commit time.
        self._session = session
        # Underlying database transaction; None until the transaction is started.
        self._tx: uni_db.Transaction | None = None
        # Nodes queued by add() for creation at commit().
        self._pending_nodes: list[UniNode] = []
        # Edges queued by create_edge() as (source, edge_type, target, properties).
        self._pending_edges: list[tuple[UniNode, str, UniNode, UniEdge | None]] = []
        # Terminal-state flags checked by commit() and rollback().
        self._committed = False
        self._rolled_back = False
def add(self, entity: UniNode) -> None:
 98    def add(self, entity: UniNode) -> None:
 99        """Add a node to be created in this transaction."""
100        self._pending_nodes.append(entity)

Add a node to be created in this transaction.

def create_edge( self, source: UniNode, edge_type: str, target: UniNode, properties: UniEdge | None = None, **kwargs: Any) -> None:
102    def create_edge(
103        self,
104        source: UniNode,
105        edge_type: str,
106        target: UniNode,
107        properties: UniEdge | None = None,
108        **kwargs: Any,
109    ) -> None:
110        """Create an edge between two nodes in this transaction."""
111        if not source.is_persisted:
112            raise NotPersisted(source)
113        if not target.is_persisted:
114            raise NotPersisted(target)
115        self._pending_edges.append((source, edge_type, target, properties))

Create an edge between two nodes in this transaction.

def commit(self) -> None:
117    def commit(self) -> None:
118        """Commit the transaction."""
119        if self._committed:
120            raise TransactionError("Transaction already committed")
121        if self._rolled_back:
122            raise TransactionError("Transaction already rolled back")
123
124        if self._tx is None:
125            raise TransactionError("Transaction not started")
126
127        try:
128            # Create pending nodes
129            for node in self._pending_nodes:
130                self._session._create_node_in_tx(node, self._tx)
131
132            # Create pending edges
133            for source, edge_type, target, props in self._pending_edges:
134                self._session._create_edge_in_tx(
135                    source, edge_type, target, props, self._tx
136                )
137
138            self._tx.commit()
139            self._committed = True
140
141            # Mark nodes as clean
142            for node in self._pending_nodes:
143                node._mark_clean()
144
145        except Exception as e:
146            self.rollback()
147            raise TransactionError(f"Commit failed: {e}") from e

Commit the transaction.

def rollback(self) -> None:
149    def rollback(self) -> None:
150        """Rollback the transaction."""
151        if self._rolled_back:
152            return
153        if self._tx is not None:
154            self._tx.rollback()
155        self._rolled_back = True
156        self._pending_nodes.clear()
157        self._pending_edges.clear()

Rollback the transaction.

class AsyncUniSession:
134class AsyncUniSession:
135    """
136    Async session for interacting with the graph database.
137
138    Mirrors UniSession with async methods. Uses AsyncUni.
139
140    Example:
141        >>> from uni_db import AsyncUni
142        >>> from uni_pydantic import AsyncUniSession
143        >>>
144        >>> db = await AsyncUni.open("./my_graph")
145        >>> async with AsyncUniSession(db) as session:
146        ...     session.register(Person)
147        ...     await session.sync_schema()
148        ...     alice = Person(name="Alice", age=30)
149        ...     session.add(alice)
150        ...     await session.commit()
151    """
152
153    def __init__(self, db: uni_db.AsyncUni) -> None:
154        self._db = db
155        self._db_session = db.session()
156        self._schema_gen = SchemaGenerator()
157        self._identity_map: WeakValueDictionary[tuple[str, int], UniNode] = (
158            WeakValueDictionary()
159        )
160        self._pending_new: list[UniNode] = []
161        self._pending_delete: list[UniNode] = []
162
163    async def __aenter__(self) -> AsyncUniSession:
164        return self
165
166    async def __aexit__(
167        self,
168        exc_type: type[BaseException] | None,
169        exc_val: BaseException | None,
170        exc_tb: TracebackType | None,
171    ) -> None:
172        self.close()
173
174    def close(self) -> None:
175        """Close the session and clear pending state."""
176        self._pending_new.clear()
177        self._pending_delete.clear()
178
179    @property
180    def db(self) -> uni_db.AsyncUni:
181        """Access the underlying uni_db.AsyncUni for low-level operations."""
182        return self._db
183
184    async def locy(self, program: str, params: dict[str, Any] | None = None) -> Any:
185        """
186        Evaluate a Locy program and return derived facts, stats, and warnings.
187
188        Delegates to the underlying ``uni_db.AsyncSession.locy()``.
189        """
190        return await self._db_session.locy(program, params)
191
192    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
193        """Register model classes with the session (sync)."""
194        self._schema_gen.register(*models)
195
196    async def sync_schema(self) -> None:
197        """Synchronize database schema with registered models."""
198        await self._schema_gen.async_apply_to_database(self._db)
199
200    def query(self, model: type[NodeT]) -> AsyncQueryBuilder[NodeT]:
201        """Create an async query builder for the given model."""
202        return AsyncQueryBuilder(self, model)
203
204    def add(self, entity: UniNode) -> None:
205        """Add a new entity to be persisted (sync — just collects)."""
206        if entity.is_persisted:
207            raise SessionError(f"Entity {entity!r} is already persisted")
208        entity._session = self
209        self._pending_new.append(entity)
210
211    def add_all(self, entities: Sequence[UniNode]) -> None:
212        """Add multiple entities (sync — just collects)."""
213        for entity in entities:
214            self.add(entity)
215
216    def delete(self, entity: UniNode) -> None:
217        """Mark an entity for deletion (sync — just collects)."""
218        if not entity.is_persisted:
219            raise NotPersisted(entity)
220        self._pending_delete.append(entity)
221
222    async def get(
223        self,
224        model: type[NodeT],
225        vid: int | None = None,
226        uid: str | None = None,
227        **kwargs: Any,
228    ) -> NodeT | None:
229        """Get an entity by ID or unique properties."""
230        if vid is not None:
231            cached = self._identity_map.get((model.__label__, vid))
232            if cached is not None:
233                return cached  # type: ignore[return-value]
234
235        label = model.__label__
236        params: dict[str, Any] = {}
237
238        if vid is not None:
239            cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
240            params["vid"] = vid
241        elif uid is not None:
242            cypher = f"MATCH (n:{label}) WHERE n._uid = $uid RETURN {_NODE_RETURN}"
243            params["uid"] = uid
244        elif kwargs:
245            for k in kwargs:
246                _validate_property(k, model)
247            conditions = [f"n.{k} = ${k}" for k in kwargs]
248            cypher = f"MATCH (n:{label}) WHERE {' AND '.join(conditions)} RETURN {_NODE_RETURN} LIMIT 1"
249            params.update(kwargs)
250        else:
251            raise ValueError("Must provide vid, uid, or property filters")
252
253        results = await self._db_session.query(cypher, params)
254        if not results:
255            return None
256
257        node_data = _row_to_node_dict(results[0].to_dict())
258        if node_data is None:
259            return None
260        return self._result_to_model(node_data, model)
261
262    async def refresh(self, entity: UniNode) -> None:
263        """Refresh an entity's properties from the database."""
264        if not entity.is_persisted:
265            raise NotPersisted(entity)
266
267        label = entity.__class__.__label__
268        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
269        results = await self._db_session.query(cypher, {"vid": entity._vid})
270
271        if not results:
272            raise SessionError(f"Entity with vid={entity._vid} no longer exists")
273
274        props = _row_to_node_dict(results[0].to_dict())
275        if props is None:
276            raise SessionError(f"Entity with vid={entity._vid} no longer exists")
277        try:
278            hints = get_type_hints(type(entity))
279        except Exception:
280            hints = {}
281
282        for field_name in entity.get_property_fields():
283            if field_name in props:
284                value = props[field_name]
285                if field_name in hints:
286                    value = db_to_python_value(value, hints[field_name])
287                setattr(entity, field_name, value)
288
289        entity._mark_clean()
290
291    async def commit(self) -> None:
292        """Commit all pending changes."""
293        for entity in self._pending_new:
294            await self._create_node(entity)
295
296        for (label, vid), entity in list(self._identity_map.items()):
297            if entity.is_dirty and entity.is_persisted:
298                await self._update_node(entity)
299
300        for entity in self._pending_delete:
301            await self._delete_node(entity)
302
303        await self._db.flush()
304        self._pending_new.clear()
305        self._pending_delete.clear()
306
307    async def rollback(self) -> None:
308        """Discard all pending changes."""
309        for entity in self._pending_new:
310            entity._session = None
311        self._pending_new.clear()
312        self._pending_delete.clear()
313        for entity in list(self._identity_map.values()):
314            if entity.is_dirty:
315                await self.refresh(entity)
316
317    async def transaction(self) -> AsyncUniTransaction:
318        """Create an async transaction. Use as `async with session.transaction() as tx:`."""
319        return AsyncUniTransaction(self)
320
321    async def cypher(
322        self,
323        query: str,
324        params: dict[str, Any] | None = None,
325        result_type: type[NodeT] | None = None,
326    ) -> list[NodeT] | list[dict[str, Any]]:
327        """Execute a raw Cypher query."""
328        results = await self._db_session.query(query, params)
329
330        if result_type is None:
331            return [r.to_dict() for r in results]
332
333        mapped = []
334        for raw_row in results:
335            row = raw_row.to_dict()
336            for key, value in row.items():
337                if isinstance(value, dict):
338                    if "_id" in value and "_label" in value:
339                        instance = self._result_to_model(value, result_type)
340                        if instance:
341                            mapped.append(instance)
342                            break
343                    elif "_label" in value:
344                        label = value["_label"]
345                        if label in self._schema_gen._node_models:
346                            model = self._schema_gen._node_models[label]
347                            instance = self._result_to_model(value, model)
348                            if instance:
349                                mapped.append(instance)
350                                break
351            else:
352                first_value = next(iter(row.values()), None)
353                if isinstance(first_value, dict):
354                    instance = self._result_to_model(first_value, result_type)
355                    if instance:
356                        mapped.append(instance)
357
358        return mapped
359
360    async def create_edge(
361        self,
362        source: UniNode,
363        edge_type: str,
364        target: UniNode,
365        properties: dict[str, Any] | UniEdge | None = None,
366    ) -> None:
367        """Create an edge between two nodes."""
368        src_vid, dst_vid, src_label, dst_label = UniSession._validate_edge_endpoints(
369            source, target
370        )
371        props = UniSession._normalize_edge_properties(properties)
372
373        props_str = ", ".join(f"{k}: ${k}" for k in props)
374        if props_str:
375            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type} {{{props_str}}}]->(b)"
376        else:
377            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type}]->(b)"
378
379        async with await self._db_session.tx() as tx:
380            await tx.execute(cypher, {"src": src_vid, "dst": dst_vid, **props})
381            await tx.commit()
382
383    async def delete_edge(
384        self, source: UniNode, edge_type: str, target: UniNode
385    ) -> int:
386        """Delete edges between two nodes. Returns the number of deleted edges."""
387        src_vid, dst_vid, src_label, dst_label = UniSession._validate_edge_endpoints(
388            source, target
389        )
390        cypher = (
391            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
392            f"WHERE a._vid = $src AND b._vid = $dst "
393            f"DELETE r RETURN count(r) as count"
394        )
395        async with await self._db_session.tx() as tx:
396            results = await tx.query(cypher, {"src": src_vid, "dst": dst_vid})
397            await tx.commit()
398        return cast(int, results[0]["count"]) if results else 0
399
400    async def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
401        """Bulk-add entities using bulk_writer."""
402        if not entities:
403            return []
404
405        by_label: dict[str, list[UniNode]] = {}
406        for entity in entities:
407            label = entity.__class__.__label__
408            if label not in by_label:
409                by_label[label] = []
410            by_label[label].append(entity)
411
412        all_vids: list[int] = []
413        try:
414            for label, group in by_label.items():
415                for entity in group:
416                    run_hooks(entity, _BEFORE_CREATE)
417                prop_dicts = [e.to_properties() for e in group]
418                tx = await self._db_session.tx()
419                async with await tx.bulk_writer().build() as bw:
420                    vids = await bw.insert_vertices(label, prop_dicts)
421                    await bw.commit()
422                await tx.commit()
423                for entity, vid in zip(group, vids):
424                    entity._attach_session(self, vid)
425                    self._identity_map[(label, vid)] = entity
426                    run_hooks(entity, _AFTER_CREATE)
427                    entity._mark_clean()
428                all_vids.extend(vids)
429        except Exception as e:
430            raise BulkLoadError(f"Bulk insert failed: {e}") from e
431
432        return all_vids
433
434    async def explain(self, cypher: str) -> Any:
435        """Get the query execution plan."""
436        return await self._db_session.explain(cypher)
437
438    async def profile(self, cypher: str) -> Any:
439        """Run the query with profiling and return results + stats."""
440        return await self._db_session.profile(cypher)
441
442    async def save_schema(self, path: str) -> None:
443        """Save the database schema to a file."""
444        await self._db.save_schema(path)
445
446    async def load_schema(self, path: str) -> None:
447        """Load a database schema from a file."""
448        await self._db.load_schema(path)
449
450    # ---- Internal methods ----
451
452    async def _create_node(self, entity: UniNode) -> None:
453        run_hooks(entity, _BEFORE_CREATE)
454        label = entity.__class__.__label__
455        props = entity.to_properties()
456        props_str = ", ".join(f"{k}: ${k}" for k in props)
457        cypher = f"CREATE (n:{label} {{{props_str}}}) RETURN id(n) as vid"
458        async with await self._db_session.tx() as tx:
459            results = await tx.query(cypher, props)
460            await tx.commit()
461        if results:
462            vid = results[0]["vid"]
463            entity._attach_session(self, vid)
464            self._identity_map[(label, vid)] = entity
465        run_hooks(entity, _AFTER_CREATE)
466        entity._mark_clean()
467
468    async def _create_node_in_tx(
469        self, entity: UniNode, tx: uni_db.AsyncTransaction
470    ) -> None:
471        run_hooks(entity, _BEFORE_CREATE)
472        label = entity.__class__.__label__
473        props = entity.to_properties()
474        props_str = ", ".join(f"{k}: ${k}" for k in props)
475        cypher = f"CREATE (n:{label} {{{props_str}}}) RETURN id(n) as vid"
476        results = await tx.query(cypher, props)
477        if results:
478            vid = results[0]["vid"]
479            entity._attach_session(self, vid)
480            self._identity_map[(label, vid)] = entity
481        run_hooks(entity, _AFTER_CREATE)
482
483    async def _create_edge_in_tx(
484        self,
485        source: UniNode,
486        edge_type: str,
487        target: UniNode,
488        properties: UniEdge | None,
489        tx: uni_db.AsyncTransaction,
490    ) -> None:
491        props = properties.to_properties() if properties else {}
492        src_label = source.__class__.__label__
493        dst_label = target.__class__.__label__
494        props_str = ", ".join(f"{k}: ${k}" for k in props)
495        if props_str:
496            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[:{edge_type} {{{props_str}}}]->(b)"
497        else:
498            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[:{edge_type}]->(b)"
499        params = {"src": source._vid, "dst": target._vid, **props}
500        await tx.query(cypher, params)
501
502    async def _update_node(self, entity: UniNode) -> None:
503        run_hooks(entity, _BEFORE_UPDATE)
504        label = entity.__class__.__label__
505        try:
506            hints = get_type_hints(type(entity))
507        except Exception:
508            hints = {}
509        dirty_props = {}
510        for name in entity._dirty:
511            value = getattr(entity, name)
512            if name in hints:
513                value = python_to_db_value(value, hints[name])
514            dirty_props[name] = value
515        if not dirty_props:
516            return
517        set_clause = ", ".join(f"n.{k} = ${k}" for k in dirty_props)
518        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid SET {set_clause}"
519        params = {"vid": entity._vid, **dirty_props}
520        async with await self._db_session.tx() as tx:
521            await tx.execute(cypher, params)
522            await tx.commit()
523        run_hooks(entity, _AFTER_UPDATE)
524        entity._mark_clean()
525
526    async def _delete_node(self, entity: UniNode) -> None:
527        run_hooks(entity, _BEFORE_DELETE)
528        label = entity.__class__.__label__
529        vid = entity._vid
530        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid DETACH DELETE n"
531        async with await self._db_session.tx() as tx:
532            await tx.execute(cypher, {"vid": vid})
533            await tx.commit()
534        if vid is not None and (label, vid) in self._identity_map:
535            del self._identity_map[(label, vid)]
536        entity._vid = None
537        entity._uid = None
538        entity._session = None
539        run_hooks(entity, _AFTER_DELETE)
540
541    def _result_to_model(
542        self,
543        data: dict[str, Any],
544        model: type[NodeT],
545    ) -> NodeT | None:
546        """Convert a query result row to a model instance (sync — pure dict processing)."""
547        if not data:
548            return None
549
550        data = dict(data)
551        data = run_class_hooks(model, _BEFORE_LOAD, data) or data
552
553        vid = data.pop("_id", None)
554        if vid is None:
555            vid = data.pop("_vid", None)
556        if vid is None:
557            vid = data.pop("vid", None)
558        if vid is not None and not isinstance(vid, int):
559            vid = int(vid)
560        data.pop("_label", None)
561
562        try:
563            instance = cast(
564                NodeT,
565                model.from_properties(data, vid=vid, session=self),
566            )
567        except Exception:
568            return None
569
570        if vid is not None:
571            existing = self._identity_map.get((model.__label__, vid))
572            if existing is not None:
573                return cast(NodeT, existing)
574            self._identity_map[(model.__label__, vid)] = instance
575
576        run_hooks(instance, _AFTER_LOAD)
577        return instance
578
579    def _load_relationship(
580        self,
581        entity: UniNode,
582        descriptor: RelationshipDescriptor[Any],
583    ) -> list[UniNode] | UniNode | None:
584        """Sync relationship loading — raises error for async session.
585        Use _async_load_relationship instead."""
586        raise SessionError(
587            "Cannot synchronously load relationships in an async session. "
588            "Use eager_load() or access relationships via async queries."
589        )
590
591    async def _async_eager_load_relationships(
592        self,
593        entities: list[NodeT],
594        relationships: list[str],
595    ) -> None:
596        """Eager load relationships for a list of entities (async)."""
597        if not entities:
598            return
599
600        model = type(entities[0])
601        rel_configs = model.get_relationship_fields()
602
603        for rel_name in relationships:
604            if rel_name not in rel_configs:
605                continue
606
607            config = rel_configs[rel_name]
608            label = model.__label__
609            vids = [e._vid for e in entities if e._vid is not None]
610
611            if not vids:
612                continue
613
614            pattern = _edge_pattern(config.edge_type, config.direction)
615            cypher = (
616                f"MATCH (a:{label}){pattern}(b) WHERE id(a) IN $vids "
617                f"RETURN id(a) as src_vid, properties(b) AS _props, id(b) AS _vid, labels(b) AS _labels"
618            )
619            results = await self._db_session.query(cypher, {"vids": vids})
620
621            by_source: dict[int, list[Any]] = {}
622            for raw_row in results:
623                row = raw_row.to_dict()
624                src_vid = row["src_vid"]
625                node_data = _row_to_node_dict(row)
626                if node_data is None:
627                    continue
628                if src_vid not in by_source:
629                    by_source[src_vid] = []
630                by_source[src_vid].append(node_data)
631
632            for entity in entities:
633                if entity._vid in by_source:
634                    related = by_source[entity._vid]
635                    cache_attr = f"_rel_cache_{rel_name}"
636                    setattr(entity, cache_attr, related)

Async session for interacting with the graph database.

Mirrors UniSession with async methods. Uses AsyncUni.

Example:

    from uni_db import AsyncUni
    from uni_pydantic import AsyncUniSession

    db = await AsyncUni.open("./my_graph")
    async with AsyncUniSession(db) as session:
        session.register(Person)
        await session.sync_schema()
        alice = Person(name="Alice", age=30)
        session.add(alice)
        await session.commit()

AsyncUniSession(db: AsyncUni)
153    def __init__(self, db: uni_db.AsyncUni) -> None:
154        self._db = db
155        self._db_session = db.session()
156        self._schema_gen = SchemaGenerator()
157        self._identity_map: WeakValueDictionary[tuple[str, int], UniNode] = (
158            WeakValueDictionary()
159        )
160        self._pending_new: list[UniNode] = []
161        self._pending_delete: list[UniNode] = []
def close(self) -> None:
174    def close(self) -> None:
175        """Close the session and clear pending state."""
176        self._pending_new.clear()
177        self._pending_delete.clear()

Close the session and clear pending state.

db: AsyncUni
179    @property
180    def db(self) -> uni_db.AsyncUni:
181        """Access the underlying uni_db.AsyncUni for low-level operations."""
182        return self._db

Access the underlying uni_db.AsyncUni for low-level operations.

async def locy(self, program: str, params: dict[str, typing.Any] | None = None) -> Any:
184    async def locy(self, program: str, params: dict[str, Any] | None = None) -> Any:
185        """
186        Evaluate a Locy program and return derived facts, stats, and warnings.
187
188        Delegates to the underlying ``uni_db.AsyncSession.locy()``.
189        """
190        return await self._db_session.locy(program, params)

Evaluate a Locy program and return derived facts, stats, and warnings.

Delegates to the underlying uni_db.AsyncSession.locy().

def register( self, *models: type[UniNode] | type[UniEdge]) -> None:
192    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
193        """Register model classes with the session (sync)."""
194        self._schema_gen.register(*models)

Register model classes with the session (sync).

async def sync_schema(self) -> None:
196    async def sync_schema(self) -> None:
197        """Synchronize database schema with registered models."""
198        await self._schema_gen.async_apply_to_database(self._db)

Synchronize database schema with registered models.

def query( self, model: type[~NodeT]) -> AsyncQueryBuilder[~NodeT]:
200    def query(self, model: type[NodeT]) -> AsyncQueryBuilder[NodeT]:
201        """Create an async query builder for the given model."""
202        return AsyncQueryBuilder(self, model)

Create an async query builder for the given model.

def add(self, entity: UniNode) -> None:
204    def add(self, entity: UniNode) -> None:
205        """Add a new entity to be persisted (sync — just collects)."""
206        if entity.is_persisted:
207            raise SessionError(f"Entity {entity!r} is already persisted")
208        entity._session = self
209        self._pending_new.append(entity)

Add a new entity to be persisted (sync — just collects).

def add_all(self, entities: Sequence[UniNode]) -> None:
211    def add_all(self, entities: Sequence[UniNode]) -> None:
212        """Add multiple entities (sync — just collects)."""
213        for entity in entities:
214            self.add(entity)

Add multiple entities (sync — just collects).

def delete(self, entity: UniNode) -> None:
216    def delete(self, entity: UniNode) -> None:
217        """Mark an entity for deletion (sync — just collects)."""
218        if not entity.is_persisted:
219            raise NotPersisted(entity)
220        self._pending_delete.append(entity)

Mark an entity for deletion (sync — just collects).

async def get( self, model: type[~NodeT], vid: int | None = None, uid: str | None = None, **kwargs: Any) -> Optional[~NodeT]:
222    async def get(
223        self,
224        model: type[NodeT],
225        vid: int | None = None,
226        uid: str | None = None,
227        **kwargs: Any,
228    ) -> NodeT | None:
229        """Get an entity by ID or unique properties."""
230        if vid is not None:
231            cached = self._identity_map.get((model.__label__, vid))
232            if cached is not None:
233                return cached  # type: ignore[return-value]
234
235        label = model.__label__
236        params: dict[str, Any] = {}
237
238        if vid is not None:
239            cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
240            params["vid"] = vid
241        elif uid is not None:
242            cypher = f"MATCH (n:{label}) WHERE n._uid = $uid RETURN {_NODE_RETURN}"
243            params["uid"] = uid
244        elif kwargs:
245            for k in kwargs:
246                _validate_property(k, model)
247            conditions = [f"n.{k} = ${k}" for k in kwargs]
248            cypher = f"MATCH (n:{label}) WHERE {' AND '.join(conditions)} RETURN {_NODE_RETURN} LIMIT 1"
249            params.update(kwargs)
250        else:
251            raise ValueError("Must provide vid, uid, or property filters")
252
253        results = await self._db_session.query(cypher, params)
254        if not results:
255            return None
256
257        node_data = _row_to_node_dict(results[0].to_dict())
258        if node_data is None:
259            return None
260        return self._result_to_model(node_data, model)

Get an entity by ID or unique properties.

async def refresh(self, entity: UniNode) -> None:
262    async def refresh(self, entity: UniNode) -> None:
263        """Refresh an entity's properties from the database."""
264        if not entity.is_persisted:
265            raise NotPersisted(entity)
266
267        label = entity.__class__.__label__
268        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
269        results = await self._db_session.query(cypher, {"vid": entity._vid})
270
271        if not results:
272            raise SessionError(f"Entity with vid={entity._vid} no longer exists")
273
274        props = _row_to_node_dict(results[0].to_dict())
275        if props is None:
276            raise SessionError(f"Entity with vid={entity._vid} no longer exists")
277        try:
278            hints = get_type_hints(type(entity))
279        except Exception:
280            hints = {}
281
282        for field_name in entity.get_property_fields():
283            if field_name in props:
284                value = props[field_name]
285                if field_name in hints:
286                    value = db_to_python_value(value, hints[field_name])
287                setattr(entity, field_name, value)
288
289        entity._mark_clean()

Refresh an entity's properties from the database.

async def commit(self) -> None:
291    async def commit(self) -> None:
292        """Commit all pending changes."""
293        for entity in self._pending_new:
294            await self._create_node(entity)
295
296        for (label, vid), entity in list(self._identity_map.items()):
297            if entity.is_dirty and entity.is_persisted:
298                await self._update_node(entity)
299
300        for entity in self._pending_delete:
301            await self._delete_node(entity)
302
303        await self._db.flush()
304        self._pending_new.clear()
305        self._pending_delete.clear()

Commit all pending changes.

async def rollback(self) -> None:
307    async def rollback(self) -> None:
308        """Discard all pending changes."""
309        for entity in self._pending_new:
310            entity._session = None
311        self._pending_new.clear()
312        self._pending_delete.clear()
313        for entity in list(self._identity_map.values()):
314            if entity.is_dirty:
315                await self.refresh(entity)

Discard all pending changes.

async def transaction(self) -> AsyncUniTransaction:
317    async def transaction(self) -> AsyncUniTransaction:
318        """Create an async transaction. Use as `async with session.transaction() as tx:`."""
319        return AsyncUniTransaction(self)

Create an async transaction. Because this method is a coroutine, await it first: use as `async with await session.transaction() as tx:`.

async def cypher( self, query: str, params: dict[str, typing.Any] | None = None, result_type: type[~NodeT] | None = None) -> list[~NodeT] | list[dict[str, typing.Any]]:
321    async def cypher(
322        self,
323        query: str,
324        params: dict[str, Any] | None = None,
325        result_type: type[NodeT] | None = None,
326    ) -> list[NodeT] | list[dict[str, Any]]:
327        """Execute a raw Cypher query."""
328        results = await self._db_session.query(query, params)
329
330        if result_type is None:
331            return [r.to_dict() for r in results]
332
333        mapped = []
334        for raw_row in results:
335            row = raw_row.to_dict()
336            for key, value in row.items():
337                if isinstance(value, dict):
338                    if "_id" in value and "_label" in value:
339                        instance = self._result_to_model(value, result_type)
340                        if instance:
341                            mapped.append(instance)
342                            break
343                    elif "_label" in value:
344                        label = value["_label"]
345                        if label in self._schema_gen._node_models:
346                            model = self._schema_gen._node_models[label]
347                            instance = self._result_to_model(value, model)
348                            if instance:
349                                mapped.append(instance)
350                                break
351            else:
352                first_value = next(iter(row.values()), None)
353                if isinstance(first_value, dict):
354                    instance = self._result_to_model(first_value, result_type)
355                    if instance:
356                        mapped.append(instance)
357
358        return mapped

Execute a raw Cypher query.

async def create_edge( self, source: UniNode, edge_type: str, target: UniNode, properties: dict[str, typing.Any] | UniEdge | None = None) -> None:
360    async def create_edge(
361        self,
362        source: UniNode,
363        edge_type: str,
364        target: UniNode,
365        properties: dict[str, Any] | UniEdge | None = None,
366    ) -> None:
367        """Create an edge between two nodes."""
368        src_vid, dst_vid, src_label, dst_label = UniSession._validate_edge_endpoints(
369            source, target
370        )
371        props = UniSession._normalize_edge_properties(properties)
372
373        props_str = ", ".join(f"{k}: ${k}" for k in props)
374        if props_str:
375            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type} {{{props_str}}}]->(b)"
376        else:
377            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type}]->(b)"
378
379        async with await self._db_session.tx() as tx:
380            await tx.execute(cypher, {"src": src_vid, "dst": dst_vid, **props})
381            await tx.commit()

Create an edge between two nodes.

async def delete_edge( self, source: UniNode, edge_type: str, target: UniNode) -> int:
383    async def delete_edge(
384        self, source: UniNode, edge_type: str, target: UniNode
385    ) -> int:
386        """Delete edges between two nodes. Returns the number of deleted edges."""
387        src_vid, dst_vid, src_label, dst_label = UniSession._validate_edge_endpoints(
388            source, target
389        )
390        cypher = (
391            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
392            f"WHERE a._vid = $src AND b._vid = $dst "
393            f"DELETE r RETURN count(r) as count"
394        )
395        async with await self._db_session.tx() as tx:
396            results = await tx.query(cypher, {"src": src_vid, "dst": dst_vid})
397            await tx.commit()
398        return cast(int, results[0]["count"]) if results else 0

Delete edges between two nodes. Returns the number of deleted edges.

async def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
400    async def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
401        """Bulk-add entities using bulk_writer."""
402        if not entities:
403            return []
404
405        by_label: dict[str, list[UniNode]] = {}
406        for entity in entities:
407            label = entity.__class__.__label__
408            if label not in by_label:
409                by_label[label] = []
410            by_label[label].append(entity)
411
412        all_vids: list[int] = []
413        try:
414            for label, group in by_label.items():
415                for entity in group:
416                    run_hooks(entity, _BEFORE_CREATE)
417                prop_dicts = [e.to_properties() for e in group]
418                tx = await self._db_session.tx()
419                async with await tx.bulk_writer().build() as bw:
420                    vids = await bw.insert_vertices(label, prop_dicts)
421                    await bw.commit()
422                await tx.commit()
423                for entity, vid in zip(group, vids):
424                    entity._attach_session(self, vid)
425                    self._identity_map[(label, vid)] = entity
426                    run_hooks(entity, _AFTER_CREATE)
427                    entity._mark_clean()
428                all_vids.extend(vids)
429        except Exception as e:
430            raise BulkLoadError(f"Bulk insert failed: {e}") from e
431
432        return all_vids

Bulk-add entities using bulk_writer.

async def explain(self, cypher: str) -> Any:
434    async def explain(self, cypher: str) -> Any:
435        """Get the query execution plan."""
436        return await self._db_session.explain(cypher)

Get the query execution plan.

async def profile(self, cypher: str) -> Any:
438    async def profile(self, cypher: str) -> Any:
439        """Run the query with profiling and return results + stats."""
440        return await self._db_session.profile(cypher)

Run the query with profiling and return results + stats.

async def save_schema(self, path: str) -> None:
    async def save_schema(self, path: str) -> None:
        """Save the database schema to a file.

        Args:
            path: Destination path, passed unchanged to the database
                handle's ``save_schema``.
        """
        await self._db.save_schema(path)

Save the database schema to a file.

async def load_schema(self, path: str) -> None:
    async def load_schema(self, path: str) -> None:
        """Load a database schema from a file.

        Args:
            path: Source path, passed unchanged to the database handle's
                ``load_schema``.
        """
        await self._db.load_schema(path)

Load a database schema from a file.

class AsyncUniTransaction:
 54class AsyncUniTransaction:
 55    """Async transaction context for atomic operations."""
 56
 57    def __init__(self, session: AsyncUniSession) -> None:
 58        self._session = session
 59        self._tx: uni_db.AsyncTransaction | None = None
 60        self._pending_nodes: list[UniNode] = []
 61        self._pending_edges: list[tuple[UniNode, str, UniNode, UniEdge | None]] = []
 62        self._committed = False
 63        self._rolled_back = False
 64
 65    async def __aenter__(self) -> AsyncUniTransaction:
 66        self._tx = await self._session._db_session.tx()
 67        return self
 68
 69    async def __aexit__(
 70        self,
 71        exc_type: type[BaseException] | None,
 72        exc_val: BaseException | None,
 73        exc_tb: TracebackType | None,
 74    ) -> None:
 75        if exc_type is not None:
 76            await self.rollback()
 77            return
 78        if not self._committed and not self._rolled_back:
 79            await self.commit()
 80
 81    def add(self, entity: UniNode) -> None:
 82        """Add a node to be created in this transaction (sync — just collects)."""
 83        self._pending_nodes.append(entity)
 84
 85    def create_edge(
 86        self,
 87        source: UniNode,
 88        edge_type: str,
 89        target: UniNode,
 90        properties: UniEdge | None = None,
 91    ) -> None:
 92        """Create an edge between two nodes in this transaction (sync — just collects)."""
 93        if not source.is_persisted:
 94            raise NotPersisted(source)
 95        if not target.is_persisted:
 96            raise NotPersisted(target)
 97        self._pending_edges.append((source, edge_type, target, properties))
 98
 99    async def commit(self) -> None:
100        """Commit the transaction."""
101        if self._committed:
102            raise TransactionError("Transaction already committed")
103        if self._rolled_back:
104            raise TransactionError("Transaction already rolled back")
105        if self._tx is None:
106            raise TransactionError("Transaction not started")
107
108        try:
109            for node in self._pending_nodes:
110                await self._session._create_node_in_tx(node, self._tx)
111            for source, edge_type, target, props in self._pending_edges:
112                await self._session._create_edge_in_tx(
113                    source, edge_type, target, props, self._tx
114                )
115            await self._tx.commit()
116            self._committed = True
117            for node in self._pending_nodes:
118                node._mark_clean()
119        except Exception as e:
120            await self.rollback()
121            raise TransactionError(f"Commit failed: {e}") from e
122
123    async def rollback(self) -> None:
124        """Rollback the transaction."""
125        if self._rolled_back:
126            return
127        if self._tx is not None:
128            await self._tx.rollback()
129        self._rolled_back = True
130        self._pending_nodes.clear()
131        self._pending_edges.clear()

Async transaction context for atomic operations.

AsyncUniTransaction(session: AsyncUniSession)
    def __init__(self, session: AsyncUniSession) -> None:
        """Set up an unstarted transaction bound to ``session``."""
        self._session = session
        # Underlying database transaction; stays None until one is opened.
        self._tx: uni_db.AsyncTransaction | None = None
        # Nodes collected for creation.
        self._pending_nodes: list[UniNode] = []
        # Edges collected as (source, edge_type, target, properties).
        self._pending_edges: list[tuple[UniNode, str, UniNode, UniEdge | None]] = []
        # Outcome flags; both start False.
        self._committed = False
        self._rolled_back = False
def add(self, entity: UniNode) -> None:
    def add(self, entity: UniNode) -> None:
        """Add a node to be created in this transaction (sync — just collects).

        The node is only appended to the pending list here; nothing is
        written by this call.
        """
        self._pending_nodes.append(entity)

Add a node to be created in this transaction (sync — just collects).

def create_edge( self, source: UniNode, edge_type: str, target: UniNode, properties: UniEdge | None = None) -> None:
85    def create_edge(
86        self,
87        source: UniNode,
88        edge_type: str,
89        target: UniNode,
90        properties: UniEdge | None = None,
91    ) -> None:
92        """Create an edge between two nodes in this transaction (sync — just collects)."""
93        if not source.is_persisted:
94            raise NotPersisted(source)
95        if not target.is_persisted:
96            raise NotPersisted(target)
97        self._pending_edges.append((source, edge_type, target, properties))

Create an edge between two nodes in this transaction (sync — just collects).

async def commit(self) -> None:
 99    async def commit(self) -> None:
100        """Commit the transaction."""
101        if self._committed:
102            raise TransactionError("Transaction already committed")
103        if self._rolled_back:
104            raise TransactionError("Transaction already rolled back")
105        if self._tx is None:
106            raise TransactionError("Transaction not started")
107
108        try:
109            for node in self._pending_nodes:
110                await self._session._create_node_in_tx(node, self._tx)
111            for source, edge_type, target, props in self._pending_edges:
112                await self._session._create_edge_in_tx(
113                    source, edge_type, target, props, self._tx
114                )
115            await self._tx.commit()
116            self._committed = True
117            for node in self._pending_nodes:
118                node._mark_clean()
119        except Exception as e:
120            await self.rollback()
121            raise TransactionError(f"Commit failed: {e}") from e

Commit the transaction.

async def rollback(self) -> None:
123    async def rollback(self) -> None:
124        """Rollback the transaction."""
125        if self._rolled_back:
126            return
127        if self._tx is not None:
128            await self._tx.rollback()
129        self._rolled_back = True
130        self._pending_nodes.clear()
131        self._pending_edges.clear()

Rollback the transaction.

def Field( default: Any = Ellipsis, *, default_factory: Callable[[], typing.Any] | None = None, alias: str | None = None, title: str | None = None, description: str | None = None, examples: list[typing.Any] | None = None, exclude: bool = False, json_schema_extra: dict[str, typing.Any] | None = None, index: Optional[Literal['btree', 'hash', 'fulltext', 'vector', 'sparse']] = None, unique: bool = False, tokenizer: str | None = None, metric: Optional[Literal['l2', 'cosine', 'dot']] = None, generated: str | None = None) -> Any:
 69def Field(
 70    default: Any = ...,
 71    *,
 72    default_factory: Callable[[], Any] | None = None,
 73    alias: str | None = None,
 74    title: str | None = None,
 75    description: str | None = None,
 76    examples: list[Any] | None = None,
 77    exclude: bool = False,
 78    json_schema_extra: dict[str, Any] | None = None,
 79    # Uni-specific options
 80    index: IndexType | None = None,
 81    unique: bool = False,
 82    tokenizer: str | None = None,
 83    metric: VectorMetric | None = None,
 84    generated: str | None = None,
 85) -> Any:
 86    """
 87    Create a field with uni-pydantic configuration.
 88
 89    This extends Pydantic's Field with graph database options.
 90
 91    Args:
 92        default: Default value for the field.
 93        default_factory: Factory function for default value.
 94        alias: Field alias for serialization.
 95        title: Human-readable title.
 96        description: Field description.
 97        examples: Example values.
 98        exclude: Exclude from serialization.
 99        json_schema_extra: Extra JSON schema properties.
100        index: Index type ("btree", "hash", "fulltext", "vector").
101        unique: Whether to create a unique constraint.
102        tokenizer: Tokenizer for fulltext index (default: "standard").
103        metric: Distance metric for vector index ("l2", "cosine", "dot").
104        generated: Expression for generated/computed property.
105
106    Returns:
107        A Pydantic FieldInfo with uni-pydantic metadata attached.
108
109    Examples:
110        >>> class Person(UniNode):
111        ...     name: str = Field(index="btree")
112        ...     email: str = Field(unique=True)
113        ...     bio: str = Field(index="fulltext", tokenizer="standard")
114        ...     embedding: Vector[768] = Field(metric="cosine")
115    """
116    # Default tokenizer for fulltext indexes
117    if index == "fulltext" and tokenizer is None:
118        tokenizer = "standard"
119
120    # Store uni config in json_schema_extra
121    uni_config = FieldConfig(
122        index=index,
123        unique=unique,
124        tokenizer=tokenizer,
125        metric=metric,
126        generated=generated,
127        default=default,
128        default_factory=default_factory,
129        alias=alias,
130        title=title,
131        description=description,
132        examples=examples,
133        exclude=exclude,
134        json_schema_extra=json_schema_extra,
135    )
136
137    # Merge uni config into json_schema_extra
138    extra = json_schema_extra or {}
139    extra["uni_config"] = uni_config
140
141    # Create Pydantic FieldInfo
142    from pydantic.fields import FieldInfo as PydanticFieldInfo
143
144    if default_factory is not None:
145        return PydanticFieldInfo(
146            default_factory=default_factory,
147            alias=alias,
148            title=title,
149            description=description,
150            examples=examples,
151            exclude=exclude,
152            json_schema_extra=extra,
153        )
154    elif default is not ...:
155        return PydanticFieldInfo(
156            default=default,
157            alias=alias,
158            title=title,
159            description=description,
160            examples=examples,
161            exclude=exclude,
162            json_schema_extra=extra,
163        )
164    else:
165        return PydanticFieldInfo(
166            alias=alias,
167            title=title,
168            description=description,
169            examples=examples,
170            exclude=exclude,
171            json_schema_extra=extra,
172        )

Create a field with uni-pydantic configuration.

This extends Pydantic's Field with graph database options.

Args: default: Default value for the field. default_factory: Factory function for default value. alias: Field alias for serialization. title: Human-readable title. description: Field description. examples: Example values. exclude: Exclude from serialization. json_schema_extra: Extra JSON schema properties. index: Index type ("btree", "hash", "fulltext", "vector", "sparse"). unique: Whether to create a unique constraint. tokenizer: Tokenizer for fulltext index (default: "standard"). metric: Distance metric for vector index ("l2", "cosine", "dot"). generated: Expression for generated/computed property.

Returns: A Pydantic FieldInfo with uni-pydantic metadata attached.

Examples:

class Person(UniNode): ... name: str = Field(index="btree") ... email: str = Field(unique=True) ... bio: str = Field(index="fulltext", tokenizer="standard") ... embedding: Vector[768] = Field(metric="cosine")

@dataclass
class FieldConfig:
@dataclass
class FieldConfig:
    """Configuration for a uni-pydantic field.

    Holds the Uni-specific options (index, uniqueness, tokenizer, metric,
    generated expression) next to the standard Pydantic field options that
    were passed along with them.
    """

    # Index configuration
    index: IndexType | None = None  # None means no index was requested
    unique: bool = False  # whether a unique constraint was requested

    # Fulltext index options
    tokenizer: str | None = None

    # Vector index options
    metric: VectorMetric | None = None

    # Generated/computed property
    generated: str | None = None

    # Pydantic field options (passed through)
    # ``default`` falls back to Ellipsis (``...``) when not supplied.
    default: Any = dataclass_field(default_factory=lambda: ...)
    default_factory: Callable[[], Any] | None = None
    alias: str | None = None
    title: str | None = None
    description: str | None = None
    examples: list[Any] | None = None
    exclude: bool = False
    json_schema_extra: dict[str, Any] | None = None

Configuration for a uni-pydantic field.

FieldConfig( index: Optional[Literal['btree', 'hash', 'fulltext', 'vector', 'sparse']] = None, unique: bool = False, tokenizer: str | None = None, metric: Optional[Literal['l2', 'cosine', 'dot']] = None, generated: str | None = None, default: Any = <factory>, default_factory: Callable[[], typing.Any] | None = None, alias: str | None = None, title: str | None = None, description: str | None = None, examples: list[typing.Any] | None = None, exclude: bool = False, json_schema_extra: dict[str, typing.Any] | None = None)
index: Optional[Literal['btree', 'hash', 'fulltext', 'vector', 'sparse']] = None
unique: bool = False
tokenizer: str | None = None
metric: Optional[Literal['l2', 'cosine', 'dot']] = None
generated: str | None = None
default: Any
default_factory: Callable[[], typing.Any] | None = None
alias: str | None = None
title: str | None = None
description: str | None = None
examples: list[typing.Any] | None = None
exclude: bool = False
json_schema_extra: dict[str, typing.Any] | None = None
def Relationship( edge_type: str, *, direction: Literal['outgoing', 'incoming', 'both'] = 'outgoing', edge_model: type[UniEdge] | None = None, eager: bool = False, cascade_delete: bool = False) -> Any:
def Relationship(
    edge_type: str,
    *,
    direction: Direction = "outgoing",
    edge_model: type[UniEdge] | None = None,
    eager: bool = False,
    cascade_delete: bool = False,
) -> Any:
    """
    Declare a relationship to another node type.

    Relationships are lazy-loaded by default. Use eager=True or
    query.eager_load() to load them with the parent query.

    Args:
        edge_type: The edge type name (e.g., "FRIEND_OF", "WORKS_AT").
        direction: Which edges to follow relative to this node:
            - "outgoing": edges from this node (default)
            - "incoming": edges to this node
            - "both": edges in either direction
        edge_model: Optional UniEdge subclass for typed edge properties.
        eager: Whether to eager-load this relationship by default.
        cascade_delete: Whether to delete related edges when this node is deleted.

    Returns:
        A relationship marker wrapping the resulting RelationshipConfig; the
        marker is processed by the metaclass.

    Examples:
        >>> class Person(UniNode):
        ...     # Outgoing relationship (default)
        ...     follows: list["Person"] = Relationship("FOLLOWS")
        ...
        ...     # Incoming relationship
        ...     followers: list["Person"] = Relationship("FOLLOWS", direction="incoming")
        ...
        ...     # Single optional relationship
        ...     manager: "Person | None" = Relationship("REPORTS_TO")
        ...
        ...     # Relationship with edge properties
        ...     friendships: list[tuple["Person", FriendshipEdge]] = Relationship(
        ...         "FRIEND_OF",
        ...         edge_model=FriendshipEdge
        ...     )
    """
    # The marker is what the metaclass later picks up and processes.
    return _RelationshipMarker(
        RelationshipConfig(
            edge_type=edge_type,
            direction=direction,
            edge_model=edge_model,
            eager=eager,
            cascade_delete=cascade_delete,
        )
    )

Declare a relationship to another node type.

Relationships are lazy-loaded by default. Use eager=True or query.eager_load() to load them with the parent query.

Args: edge_type: The edge type name (e.g., "FRIEND_OF", "WORKS_AT"). direction: Relationship direction: - "outgoing": Follow edges from this node (default) - "incoming": Follow edges to this node - "both": Follow edges in both directions edge_model: Optional UniEdge subclass for typed edge properties. eager: Whether to eager-load this relationship by default. cascade_delete: Whether to delete related edges when this node is deleted.

Returns: A relationship marker (carrying the RelationshipConfig) that will be processed by the metaclass during model creation.

Examples:

class Person(UniNode): ... # Outgoing relationship (default) ... follows: list["Person"] = Relationship("FOLLOWS") ... ... # Incoming relationship ... followers: list["Person"] = Relationship("FOLLOWS", direction="incoming") ... ... # Single optional relationship ... manager: "Person | None" = Relationship("REPORTS_TO") ... ... # Relationship with edge properties ... friendships: list[tuple["Person", FriendshipEdge]] = Relationship( ... "FRIEND_OF", ... edge_model=FriendshipEdge ... )

@dataclass
class RelationshipConfig:
@dataclass
class RelationshipConfig:
    """Configuration for a relationship field."""

    # Edge type name, e.g. "FRIEND_OF".
    edge_type: str
    # Which edges to follow: "outgoing", "incoming" or "both".
    direction: Direction = "outgoing"
    # Optional UniEdge subclass describing typed edge properties.
    edge_model: type[UniEdge] | None = None
    # Whether the relationship is eager-loaded by default.
    eager: bool = False
    # Whether related edges are deleted together with the node.
    cascade_delete: bool = False

Configuration for a relationship field.

RelationshipConfig( edge_type: str, direction: Literal['outgoing', 'incoming', 'both'] = 'outgoing', edge_model: type[UniEdge] | None = None, eager: bool = False, cascade_delete: bool = False)
edge_type: str
direction: Literal['outgoing', 'incoming', 'both'] = 'outgoing'
edge_model: type[UniEdge] | None = None
eager: bool = False
cascade_delete: bool = False
class RelationshipDescriptor(typing.Generic[~NodeT]):
class RelationshipDescriptor(Generic[NodeT]):
    """
    Descriptor backing relationship fields, with lazy loading.

    Instance access yields the related node(s), served from a per-instance
    cache attribute once loaded. Class access yields the descriptor itself
    so it can be used for query building.
    """

    def __init__(
        self,
        config: RelationshipConfig,
        field_name: str,
        target_type: type[NodeT] | str | None = None,
        is_list: bool = True,
    ) -> None:
        self.config = config
        self.target_type = target_type
        self.is_list = is_list
        self.field_name = field_name
        self._cache_attr = f"_rel_cache_{field_name}"

    def __set_name__(self, owner: type, name: str) -> None:
        # Keep the field name and the cache attribute name in step.
        self.field_name, self._cache_attr = name, f"_rel_cache_{name}"

    @overload
    def __get__(
        self, obj: None, objtype: type[NodeT]
    ) -> RelationshipDescriptor[NodeT]: ...

    @overload
    def __get__(
        self, obj: NodeT, objtype: type[NodeT] | None = None
    ) -> list[NodeT] | NodeT | None: ...

    def __get__(
        self, obj: NodeT | None, objtype: type[NodeT] | None = None
    ) -> RelationshipDescriptor[NodeT] | list[NodeT] | NodeT | None:
        # Accessed on the class: hand back the descriptor.
        if obj is None:
            return self

        # Accessed on an instance: a cached value wins, even when it is None.
        missing = object()
        cached = getattr(obj, self._cache_attr, missing)
        if cached is not missing:
            return cast("list[NodeT] | NodeT | None", cached)

        # Nothing cached yet, so a session is required to load lazily.
        session = getattr(obj, "_session", None)
        if session is not None:
            loaded = session._load_relationship(obj, self)
            setattr(obj, self._cache_attr, loaded)
            return cast("list[NodeT] | NodeT | None", loaded)

        from .exceptions import LazyLoadError

        raise LazyLoadError(
            self.field_name,
            "No session attached. Use session.get() or enable eager loading.",
        )

    def __set__(self, obj: NodeT, value: list[NodeT] | NodeT | None) -> None:
        # Assignment simply fills the cache (used e.g. by eager loading).
        setattr(obj, self._cache_attr, value)

    def __repr__(self) -> str:
        cfg = self.config
        return f"Relationship({cfg.edge_type!r}, direction={cfg.direction!r})"

Descriptor for relationship fields that enables lazy loading.

When accessed on an instance, it returns the related nodes. When accessed on the class, it returns the descriptor for query building.

RelationshipDescriptor( config: RelationshipConfig, field_name: str, target_type: type[~NodeT] | str | None = None, is_list: bool = True)
    def __init__(
        self,
        config: RelationshipConfig,
        field_name: str,
        target_type: type[NodeT] | str | None = None,
        is_list: bool = True,
    ) -> None:
        """Store the relationship settings for one field.

        Args:
            config: Relationship configuration for this field.
            field_name: Name of the field the descriptor is attached to.
            target_type: Related node type, given as a class or a string.
            is_list: Flag stored as ``self.is_list``; defaults to True.
        """
        self.config = config
        self.field_name = field_name
        self.target_type = target_type
        self.is_list = is_list
        # Name of the per-instance attribute that caches loaded values.
        self._cache_attr = f"_rel_cache_{field_name}"
config
field_name
target_type
is_list
def get_field_config( field_info: pydantic.fields.FieldInfo) -> FieldConfig | None:
def get_field_config(field_info: FieldInfo) -> FieldConfig | None:
    """Return the FieldConfig stored on a Pydantic FieldInfo, if any.

    The config is looked up under the ``"uni_config"`` key of
    ``json_schema_extra``. ``None`` is returned when that attribute is not a
    dict or the entry is not a FieldConfig.
    """
    extra = field_info.json_schema_extra
    if not isinstance(extra, dict):
        return None
    candidate = extra.get("uni_config")
    return candidate if isinstance(candidate, FieldConfig) else None

Extract uni-pydantic config from a Pydantic FieldInfo.

IndexType = typing.Literal['btree', 'hash', 'fulltext', 'vector', 'sparse']
Direction = typing.Literal['outgoing', 'incoming', 'both']
VectorMetric = typing.Literal['l2', 'cosine', 'dot']
class Btic:
class Btic:
    """A BTIC temporal interval value for Uni graph database.

    Construct from an ISO 8601-inspired string literal::

        Btic("1985")
        Btic("1985-03/2024-06")
        Btic("~1985")           # approximate certainty
        Btic("2020-03/")        # ongoing (unbounded hi)

    Use as a Pydantic model field type::

        class Event(UniNode):
            when: Btic

    The instance wraps a ``_PyBtic`` value (stored as ``_inner``) and
    forwards every property to it.
    """

    def __init__(self, value: str | object) -> None:
        """Wrap ``value``, which may be a string, a ``_PyBtic`` or a ``Btic``.

        Raises:
            ImportError: If ``_PyBtic`` is unavailable (``uni_db`` missing).
            TypeError: If ``value`` is none of the accepted types.
        """
        if _PyBtic is None:
            raise ImportError("uni_db is required for Btic type")
        if isinstance(value, str):
            self._inner = _PyBtic(value)
        elif isinstance(value, _PyBtic):
            self._inner = value
        elif isinstance(value, Btic):
            self._inner = value._inner
        else:
            raise TypeError(f"Expected str or Btic, got {type(value)}")

    @property
    def lo(self) -> int:
        """Lower bound in milliseconds since epoch."""
        return self._inner.lo

    @property
    def hi(self) -> int:
        """Upper bound in milliseconds since epoch."""
        return self._inner.hi

    @property
    def meta(self) -> int:
        """Raw 64-bit metadata word."""
        return self._inner.meta

    @property
    def lo_granularity(self) -> str:
        """Lower bound granularity name."""
        return self._inner.lo_granularity

    @property
    def hi_granularity(self) -> str:
        """Upper bound granularity name."""
        return self._inner.hi_granularity

    @property
    def lo_certainty(self) -> str:
        """Lower bound certainty name."""
        return self._inner.lo_certainty

    @property
    def hi_certainty(self) -> str:
        """Upper bound certainty name."""
        return self._inner.hi_certainty

    @property
    def duration_ms(self) -> int | None:
        """Duration in milliseconds, or None if unbounded."""
        return self._inner.duration_ms

    @property
    def is_instant(self) -> bool:
        """True if the interval is exactly 1 millisecond wide."""
        return self._inner.is_instant

    @property
    def is_unbounded(self) -> bool:
        """True if either bound is infinite."""
        return self._inner.is_unbounded

    @property
    def is_finite(self) -> bool:
        """True if both bounds are finite."""
        return self._inner.is_finite

    def __repr__(self) -> str:
        return f'Btic("{self._inner}")'

    def __str__(self) -> str:
        return str(self._inner)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Btic):
            return self._inner == other._inner
        return False

    def __hash__(self) -> int:
        return hash(self._inner)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Make Btic compatible with Pydantic v2.

        Validation accepts a ``Btic``, a string, or a ``_PyBtic``;
        serialization emits the string form of the wrapped value.
        """

        def validate_btic(v: Any) -> Btic:
            if isinstance(v, Btic):
                return v
            if isinstance(v, str):
                return Btic(v)
            if _PyBtic is not None and isinstance(v, _PyBtic):
                return Btic(v)
            # Pydantic turns ValueError (not TypeError) raised by a validator
            # into a ValidationError, so report bad input types that way.
            raise ValueError(f"Expected str or Btic, got {type(v)}")

        return core_schema.no_info_plain_validator_function(
            validate_btic,
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda v: str(v._inner) if isinstance(v, Btic) else str(v),
                info_arg=False,
            ),
        )

A BTIC temporal interval value for Uni graph database.

Construct from an ISO 8601-inspired string literal::

Btic("1985")
Btic("1985-03/2024-06")
Btic("~1985")           # approximate certainty
Btic("2020-03/")        # ongoing (unbounded hi)

Use as a Pydantic model field type::

class Event(UniNode):
    when: Btic
Btic(value: str | object)
297    def __init__(self, value: str | object) -> None:
298        if _PyBtic is None:
299            raise ImportError("uni_db is required for Btic type")
300        if isinstance(value, str):
301            self._inner = _PyBtic(value)
302        elif _PyBtic is not None and isinstance(value, _PyBtic):
303            self._inner = value
304        elif isinstance(value, Btic):
305            self._inner = value._inner
306        else:
307            raise TypeError(f"Expected str or Btic, got {type(value)}")
lo: int
    @property
    def lo(self) -> int:
        """Lower bound in milliseconds since epoch.

        Forwarded unchanged from the wrapped ``_inner`` value.
        """
        return self._inner.lo

Lower bound in milliseconds since epoch.

hi: int
    @property
    def hi(self) -> int:
        """Upper bound in milliseconds since epoch.

        Forwarded unchanged from the wrapped ``_inner`` value.
        """
        return self._inner.hi

Upper bound in milliseconds since epoch.

meta: int
    @property
    def meta(self) -> int:
        """Raw 64-bit metadata word.

        Forwarded unchanged from the wrapped ``_inner`` value.
        """
        return self._inner.meta

Raw 64-bit metadata word.

lo_granularity: str
    @property
    def lo_granularity(self) -> str:
        """Lower bound granularity name.

        Forwarded unchanged from the wrapped ``_inner`` value.
        """
        return self._inner.lo_granularity

Lower bound granularity name.

hi_granularity: str
    @property
    def hi_granularity(self) -> str:
        """Upper bound granularity name.

        Forwarded unchanged from the wrapped ``_inner`` value.
        """
        return self._inner.hi_granularity

Upper bound granularity name.

lo_certainty: str
    @property
    def lo_certainty(self) -> str:
        """Lower bound certainty name.

        Forwarded unchanged from the wrapped ``_inner`` value.
        """
        return self._inner.lo_certainty

Lower bound certainty name.

hi_certainty: str
    @property
    def hi_certainty(self) -> str:
        """Upper bound certainty name.

        Forwarded unchanged from the wrapped ``_inner`` value.
        """
        return self._inner.hi_certainty

Upper bound certainty name.

duration_ms: int | None
    @property
    def duration_ms(self) -> int | None:
        """Duration in milliseconds, or None if unbounded.

        Forwarded unchanged from the wrapped ``_inner`` value.
        """
        return self._inner.duration_ms

Duration in milliseconds, or None if unbounded.

is_instant: bool
    @property
    def is_instant(self) -> bool:
        """True if the interval is exactly 1 millisecond wide.

        Forwarded unchanged from the wrapped ``_inner`` value.
        """
        return self._inner.is_instant

True if the interval is exactly 1 millisecond wide.

is_unbounded: bool
    @property
    def is_unbounded(self) -> bool:
        """True if either bound is infinite.

        Forwarded unchanged from the wrapped ``_inner`` value.
        """
        return self._inner.is_unbounded

True if either bound is infinite.

is_finite: bool
    @property
    def is_finite(self) -> bool:
        """True if both bounds are finite.

        Forwarded unchanged from the wrapped ``_inner`` value.
        """
        return self._inner.is_finite

True if both bounds are finite.

class Vector(typing.Generic[~N]):
 67class Vector(Generic[N], metaclass=VectorMeta):
 68    """
 69    A vector type with fixed dimensions for embeddings.
 70
 71    Usage:
 72        embedding: Vector[1536]  # 1536-dimensional vector
 73
 74    At runtime, vectors are stored as list[float].
 75    """
 76
 77    __dimensions__: int = 0
 78    __origin__: type | None = None
 79
 80    def __init__(self, values: list[float]) -> None:
 81        expected = self.__class__.__dimensions__
 82        if expected > 0 and len(values) != expected:
 83            raise ValueError(f"Vector expects {expected} dimensions, got {len(values)}")
 84        self._values = values
 85
 86    @property
 87    def values(self) -> list[float]:
 88        return self._values
 89
 90    def __repr__(self) -> str:
 91        dims = self.__class__.__dimensions__
 92        return (
 93            f"Vector[{dims}]({self._values[:3]}...)"
 94            if len(self._values) > 3
 95            else f"Vector[{dims}]({self._values})"
 96        )
 97
 98    def __eq__(self, other: object) -> bool:
 99        if isinstance(other, Vector):
100            return self._values == other._values
101        if isinstance(other, list):
102            return self._values == other
103        return False
104
105    def __len__(self) -> int:
106        return len(self._values)
107
108    def __iter__(self):  # type: ignore[no-untyped-def]
109        return iter(self._values)
110
111    @classmethod
112    def __get_pydantic_core_schema__(
113        cls, source_type: Any, handler: GetCoreSchemaHandler
114    ) -> CoreSchema:
115        """Make Vector compatible with Pydantic v2."""
116        dimensions = getattr(source_type, "__dimensions__", 0)
117        vec_cls = source_type if dimensions > 0 else cls
118
119        def validate_vector(v: Any) -> Vector:  # type: ignore[type-arg]
120            if isinstance(v, Vector):
121                if dimensions > 0 and len(v) != dimensions:
122                    raise ValueError(
123                        f"Vector expects {dimensions} dimensions, got {len(v)}"
124                    )
125                return v
126            if isinstance(v, list):
127                if dimensions > 0 and len(v) != dimensions:
128                    raise ValueError(
129                        f"Vector expects {dimensions} dimensions, got {len(v)}"
130                    )
131                return vec_cls([float(x) for x in v])
132            raise TypeError(f"Expected list or Vector, got {type(v)}")
133
134        return core_schema.no_info_plain_validator_function(
135            validate_vector,
136            serialization=core_schema.plain_serializer_function_ser_schema(
137                lambda v: v.values if isinstance(v, Vector) else list(v),
138                info_arg=False,
139            ),
140        )

A vector type with fixed dimensions for embeddings.

Usage: embedding: Vector[1536] # 1536-dimensional vector

At runtime, vectors are stored as list[float].

Vector(values: list[float])
80    def __init__(self, values: list[float]) -> None:
81        expected = self.__class__.__dimensions__
82        if expected > 0 and len(values) != expected:
83            raise ValueError(f"Vector expects {expected} dimensions, got {len(values)}")
84        self._values = values
values: list[float]
86    @property
87    def values(self) -> list[float]:
88        return self._values
class SparseVector(typing.Generic[~N]):
class SparseVector(Generic[N], metaclass=SparseVectorMeta):
    """
    Sparse vector of learned term weights (SPLADE / BGE-M3) over a fixed vocabulary.

    Usage:
        terms: SparseVector[30522]  # SPLADE head over a 30522-term BERT vocab

    An instance keeps two parallel lists: ``indices`` (term ids) and ``values``
    (weights). Pydantic validation accepts an existing instance, the Rust-side
    sparse vector binding (when it is importable), a ``dict[int, float]`` of
    term id -> weight, or an ``(indices, values)`` pair. Serialization yields
    the Rust binding when available and an
    ``{"indices": [...], "values": [...]}`` mapping otherwise.
    """

    __sparse_dimensions__: int = 0
    __origin__: type | None = None

    def __init__(self, indices: list[int], values: list[float]) -> None:
        """Store copies of ``indices`` (as ints) and ``values`` (as floats).

        Raises:
            ValueError: If the two sequences differ in length.
        """
        n_indices, n_values = len(indices), len(values)
        if n_indices != n_values:
            raise ValueError(
                f"SparseVector indices/values length mismatch: {n_indices} vs {n_values}"
            )
        self._indices = list(map(int, indices))
        self._values = list(map(float, values))

    @property
    def indices(self) -> list[int]:
        """Term ids, parallel to ``values``."""
        return self._indices

    @property
    def values(self) -> list[float]:
        """Weights, parallel to ``indices``."""
        return self._values

    @classmethod
    def from_dict(cls, mapping: dict[int, float]) -> SparseVector[Any]:
        """Build from a ``{term_id: weight}`` mapping (sorted by term id)."""
        ordered = sorted(mapping.items())
        term_ids = [term_id for term_id, _ in ordered]
        weights = [weight for _, weight in ordered]
        return cls(term_ids, weights)

    def __repr__(self) -> str:
        """Show the class-level dimension count plus both parallel lists."""
        size = type(self).__sparse_dimensions__
        return f"SparseVector[{size}](indices={self._indices}, values={self._values})"

    def __eq__(self, other: object) -> bool:
        """Equal only to another SparseVector with identical indices and values."""
        if not isinstance(other, SparseVector):
            return False
        return self._indices == other._indices and self._values == other._values

    def __len__(self) -> int:
        """Number of stored (index, value) entries."""
        return len(self._indices)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Build the Pydantic v2 core schema for SparseVector fields."""
        dimensions = getattr(source_type, "__sparse_dimensions__", 0)
        # Construct the annotated class when it declares a positive size;
        # otherwise fall back to the class this hook was called on.
        sv_cls = source_type if dimensions > 0 else cls

        def validate_sparse(v: Any) -> SparseVector:  # type: ignore[type-arg]
            # Existing instances pass through untouched.
            if isinstance(v, SparseVector):
                return v
            binding_available = _PySparseVector is not None
            if binding_available and isinstance(v, _PySparseVector):
                return sv_cls(list(v.indices), list(v.values))
            if isinstance(v, dict):
                return sv_cls.from_dict(v)
            # A two-element tuple/list is read as (indices, values).
            if isinstance(v, (tuple, list)) and len(v) == 2:
                raw_indices, raw_values = v[0], v[1]
                return sv_cls(list(raw_indices), list(raw_values))
            raise TypeError(
                f"Expected SparseVector, dict, or (indices, values), got {type(v)}"
            )

        def serialize_sparse(v: SparseVector) -> Any:  # type: ignore[type-arg]
            if _PySparseVector is None:
                return {"indices": v.indices, "values": v.values}
            return _PySparseVector(v.indices, v.values)

        return core_schema.no_info_plain_validator_function(
            validate_sparse,
            serialization=core_schema.plain_serializer_function_ser_schema(
                serialize_sparse,
                info_arg=False,
            ),
        )

A learned-sparse (SPLADE / BGE-M3) vector over a fixed-size vocabulary.

Usage: terms: SparseVector[30522] # SPLADE head over a 30522-term BERT vocab

At runtime, holds parallel indices (term ids) and values (weights). Accepts a dict[int, float] of term id -> weight, a uni_db.SparseVector, or an existing instance; ingestion serializes to the typed Rust binding when available, otherwise to an {"indices": [...], "values": [...]} mapping.

SparseVector(indices: list[int], values: list[float])
198    def __init__(self, indices: list[int], values: list[float]) -> None:
199        if len(indices) != len(values):
200            raise ValueError(
201                f"SparseVector indices/values length mismatch: "
202                f"{len(indices)} vs {len(values)}"
203            )
204        self._indices = [int(i) for i in indices]
205        self._values = [float(v) for v in values]
indices: list[int]
207    @property
208    def indices(self) -> list[int]:
209        return self._indices
values: list[float]
211    @property
212    def values(self) -> list[float]:
213        return self._values
@classmethod
def from_dict(cls, mapping: dict[int, float]) -> 'SparseVector[Any]':
215    @classmethod
216    def from_dict(cls, mapping: dict[int, float]) -> SparseVector[Any]:
217        """Build from a ``{term_id: weight}`` mapping (sorted by term id)."""
218        items = sorted(mapping.items())
219        return cls([k for k, _ in items], [v for _, v in items])

Build from a {term_id: weight} mapping (sorted by term id).

def python_type_to_uni(type_hint: Any, *, nullable: bool = False) -> tuple[str, bool]:
def python_type_to_uni(type_hint: Any, *, nullable: bool = False) -> tuple[str, bool]:
    """
    Map a Python type hint onto a Uni DataType string.

    Args:
        type_hint: The Python type hint to convert.
        nullable: Whether the field is explicitly nullable.

    Returns:
        Tuple of (uni_data_type, is_nullable). ``is_nullable`` is always True
        for ``T | None`` hints; otherwise it echoes ``nullable``.

    Raises:
        TypeMappingError: If the type cannot be mapped (including unresolved
            string forward references).
    """
    # Drop Annotated[...] metadata so every check below sees the bare type.
    bare_hint, _metadata = unwrap_annotated(type_hint)

    # T | None: map T on its own and force the nullable flag on.
    optional, wrapped = is_optional(bare_hint)
    if optional:
        return python_type_to_uni(wrapped)[0], True

    # Sparse vectors are probed before dense ones (distinct marker attribute).
    vocab_size = get_sparse_vector_dimensions(bare_hint)
    if vocab_size is not None:
        return f"sparse_vector:{vocab_size}", nullable

    dense_dims = get_vector_dimensions(bare_hint)
    if dense_dims is not None:
        return f"vector:{dense_dims}", nullable

    list_like, item_type = is_list_type(bare_hint)
    if list_like:
        if item_type in (str, int, float, bool):
            # Lists of simple scalars keep a typed element.
            item_uni = TYPE_MAP.get(item_type, "string")
            return f"list:{item_uni}", nullable
        # list[Vector[N]] (multi-vector / ColBERT) -> list:vector:N
        item_dims = get_vector_dimensions(item_type)
        if item_dims is not None:
            return f"list:vector:{item_dims}", nullable
        # Any other element type is stored as JSON.
        return "json", nullable

    # Types with a direct entry in the mapping table.
    if bare_hint in TYPE_MAP:
        return TYPE_MAP[bare_hint], nullable

    # dict[str, V] becomes a typed MAP<STRING, V> when V maps to a concrete Uni
    # type; every other dict shape is schemaless JSON. Keys must be `str` because
    # the storage value model (Value::Map) is string-keyed.
    if get_origin(bare_hint) is dict:
        type_args = get_args(bare_hint)
        if len(type_args) == 2 and type_args[0] is str:
            try:
                # Recursion handles nested values, e.g.
                # dict[str, list[int]] -> map:string:list:int64,
                # dict[str, Vector[N]] -> map:string:vector:N.
                value_uni = python_type_to_uni(type_args[1])[0]
            except TypeMappingError:
                value_uni = "json"
            if value_uni != "json":
                return f"map:string:{value_uni}", nullable
        return "json", nullable

    # A plain string is an unresolved forward reference; it cannot be mapped here.
    if isinstance(bare_hint, str):
        raise TypeMappingError(
            bare_hint,
            f"Cannot resolve forward reference {bare_hint!r}. "
            "Ensure the referenced class is defined before schema sync.",
        )

    raise TypeMappingError(bare_hint)

Convert a Python type hint to a Uni DataType string.

Args: type_hint: The Python type hint to convert. nullable: Whether the field is explicitly nullable.

Returns: Tuple of (uni_data_type, is_nullable)

Raises: TypeMappingError: If the type cannot be mapped.

def uni_to_python_type(uni_type: str) -> type:
def uni_to_python_type(uni_type: str) -> type:
    """
    Convert a Uni DataType string to a Python type.

    Matching is case-insensitive for both parameterized prefixes and scalar
    names.

    Args:
        uni_type: The Uni data type string (e.g. ``"int64"``, ``"vector:1536"``,
            ``"list:string"``, ``"map:string:int64"``).

    Returns:
        The corresponding Python type: ``list`` for ``vector:`` and ``list:``
        types, ``dict`` for ``map:`` types, the matching scalar type for known
        scalar names, and ``str`` for anything unrecognized.
    """
    # Normalize once so prefix checks and the scalar lookup agree on case.
    normalized = uni_type.lower()

    # Vectors are stored as list[float]; list:* types are lists by definition.
    if normalized.startswith(("vector:", "list:")):
        return list

    # map:string:* is what python_type_to_uni emits for dict[str, V].
    if normalized.startswith("map:"):
        return dict

    # NOTE(review): "sparse_vector:N" has no entry here and falls through to the
    # `str` default below; confirm which Python type callers expect for it.

    # Reverse mapping, manually constructed to avoid bytes overwriting str for "string".
    scalar_types: dict[str, type] = {
        "string": str,
        "int64": int,
        "float64": float,
        "bool": bool,
        "datetime": datetime,
        "date": date,
        "time": time,
        "duration": timedelta,
        "json": dict,
        "btic": Btic,
        "bytes": bytes,
    }
    return scalar_types.get(normalized, str)

Convert a Uni DataType string to a Python type.

Args: uni_type: The Uni data type string.

Returns: The corresponding Python type.

def get_vector_dimensions(type_hint: Any) -> int | None:
def get_vector_dimensions(type_hint: Any) -> int | None:
    """Return the dimension count carried by a Vector[N] type hint, else None.

    Anything exposing a ``__dimensions__`` attribute is answered from that
    attribute; otherwise a generic alias whose origin is ``Vector`` is answered
    from its first type argument when that argument is an ``int``.
    """
    try:
        declared: int = type_hint.__dimensions__
    except AttributeError:
        pass
    else:
        return declared

    if get_origin(type_hint) is not Vector:
        return None
    params = get_args(type_hint)
    if params and isinstance(params[0], int):
        return params[0]
    return None

Extract vector dimensions from a Vector[N] type hint.

def get_sparse_vector_dimensions(type_hint: Any) -> int | None:
def get_sparse_vector_dimensions(type_hint: Any) -> int | None:
    """Return the vocabulary size carried by a SparseVector[N] type hint, else None.

    Anything exposing a ``__sparse_dimensions__`` attribute is answered from
    that attribute; otherwise a generic alias whose origin is ``SparseVector``
    is answered from its first type argument when that argument is an ``int``.
    """
    try:
        declared: int = type_hint.__sparse_dimensions__
    except AttributeError:
        pass
    else:
        return declared

    if get_origin(type_hint) is not SparseVector:
        return None
    params = get_args(type_hint)
    if params and isinstance(params[0], int):
        return params[0]
    return None

Extract the vocabulary size from a SparseVector[N] type hint.

def is_optional(type_hint: Any) -> tuple[bool, typing.Any]:
402def is_optional(type_hint: Any) -> tuple[bool, Any]:
403    """
404    Check if a type hint is Optional (T | None).
405
406    Returns:
407        Tuple of (is_optional, inner_type)
408    """
409    origin = get_origin(type_hint)
410
411    # Handle Union types (including T | None which is Union[T, None])
412    if origin is Union:
413        args = get_args(type_hint)
414        non_none_args = [arg for arg in args if arg is not type(None)]
415        if len(non_none_args) == 1 and type(None) in args:
416            return True, non_none_args[0]
417
418    # Python 3.10+ uses types.UnionType for X | Y syntax
419    if isinstance(type_hint, types.UnionType):
420        args = get_args(type_hint)
421        non_none_args = [arg for arg in args if arg is not type(None)]
422        if len(non_none_args) == 1 and type(None) in args:
423            return True, non_none_args[0]
424
425    return False, type_hint

Check if a type hint is Optional (T | None).

Returns: Tuple of (is_optional, inner_type)

def is_list_type(type_hint: Any) -> tuple[bool, typing.Any | None]:
def is_list_type(type_hint: Any) -> tuple[bool, Any | None]:
    """
    Detect a parameterizable list hint and expose its element type.

    Returns:
        Tuple of (is_list, element_type). The element type is ``Any`` when the
        list hint has no type arguments, and ``None`` for non-list hints.
    """
    if get_origin(type_hint) is not list:
        return False, None
    params = get_args(type_hint)
    element = params[0] if params else Any
    return True, element

Check if a type hint is a list type.

Returns: Tuple of (is_list, element_type)

def unwrap_annotated(type_hint: Any) -> tuple[typing.Any, tuple[typing.Any, ...]]:
def unwrap_annotated(type_hint: Any) -> tuple[Any, tuple[Any, ...]]:
    """
    Split an ``Annotated[T, ...]`` hint into ``T`` and its metadata.

    Returns:
        Tuple of (base_type, metadata_tuple). Hints that are not ``Annotated``
        come back unchanged with an empty metadata tuple.
    """
    if get_origin(type_hint) is not Annotated:
        return type_hint, ()
    base, *extras = get_args(type_hint)
    return base, tuple(extras)

Unwrap an Annotated type.

Returns: Tuple of (base_type, metadata_tuple)

def python_to_db_value(value: Any, type_hint: Any) -> Any:
def python_to_db_value(value: Any, type_hint: Any) -> Any:
    """Prepare a Python value for the database layer.

    ``Vector`` values (and lists whose first item is a ``Vector``) are reduced
    to plain lists, and ``Btic`` is unwrapped to its inner Rust object.
    Everything else, including datetime/date/time/timedelta, is returned
    unchanged for the Rust layer to convert. ``type_hint`` is accepted but not
    consulted.
    """
    if value is None:
        return None

    # list[Vector[N]] -> list[list[float]] (multi-vector / ColBERT).
    # Only the first item decides whether the list is treated this way.
    if isinstance(value, list) and value and isinstance(value[0], Vector):
        flattened = []
        for item in value:
            flattened.append(item.values if isinstance(item, Vector) else item)
        return flattened

    # Vector -> list[float]
    if isinstance(value, Vector):
        return value.values

    # Btic -> the wrapped Rust PyBtic, for py_object_to_value
    if isinstance(value, Btic):
        return value._inner

    # Temporal values and all remaining types pass through untouched; the Rust
    # py_object_to_value handles conversion to Value::Temporal.
    return value

Convert a Python value to a database-compatible value.

Passes datetime/date/time/timedelta through to the Rust layer which converts them to proper Value::Temporal types. Converts Vector to list[float] and passes through everything else.

def db_to_python_value(value: Any, type_hint: Any) -> Any:
def db_to_python_value(value: Any, type_hint: Any) -> Any:
    """Convert a value read from the database into the Python type of its field.

    The Rust layer returns proper Python datetime/date/time objects via
    Value::Temporal, so most values pass through unchanged. Conversions
    performed here:

    * a Rust ``_PyBtic`` is wrapped in ``Btic`` when the hint is ``Btic``;
    * a ``dict`` carrying ``nanos_since_epoch`` becomes a ``datetime`` when the
      hint is ``datetime`` (``None`` if the key is absent or null);
    * ``list`` values become ``Vector[N]`` / ``list[Vector[N]]`` when the hint
      asks for them.

    ``Annotated`` and ``T | None`` wrappers are peeled off the hint first, in
    either nesting order.

    Args:
        value: The raw value from the database layer.
        type_hint: The declared type of the target field.

    Returns:
        The converted value, or ``value`` itself when no conversion applies.
    """
    if value is None:
        return None

    # Peel Annotated -> Optional -> Annotated so that both
    # ``Annotated[T | None, ...]`` and ``Annotated[T, ...] | None`` reduce to T.
    # is_optional returns the hint unchanged when it is not optional.
    type_hint, _ = unwrap_annotated(type_hint)
    _, type_hint = is_optional(type_hint)
    type_hint, _ = unwrap_annotated(type_hint)

    # If value is already the right temporal type, pass it through.
    for temporal in (datetime, date, time, timedelta):
        if type_hint is temporal and isinstance(value, temporal):
            return value

    # Btic: wrap the Rust PyBtic in the pydantic Btic wrapper.
    if type_hint is Btic and _PyBtic is not None and isinstance(value, _PyBtic):
        return Btic(value)

    # Handle struct dict from Arrow deserialization (e.g. datetime struct).
    if type_hint is datetime and isinstance(value, dict):
        nanos = value.get("nanos_since_epoch")
        if nanos is None:
            return None
        # NOTE(review): fromtimestamp() without tz yields a naive datetime in the
        # machine's local timezone; confirm that is what callers expect.
        return datetime.fromtimestamp(nanos / 1_000_000_000)

    # list[Vector[N]] fields: list[list[float]] -> list[Vector[N]] (multi-vector)
    is_lst, elem_type = is_list_type(type_hint)
    if is_lst:
        elem_dims = get_vector_dimensions(elem_type)
        if elem_dims is not None and isinstance(value, list):
            vec_cls = Vector[elem_dims]
            return [vec_cls(v) if isinstance(v, list) else v for v in value]

    # Vector fields: list[float] -> Vector
    dims = get_vector_dimensions(type_hint)
    if dims is not None and isinstance(value, list):
        vec_cls = Vector[dims]
        return vec_cls(value)

    return value

Convert a database value back to a Python value.

The Rust layer now returns proper Python datetime/date/time objects via Value::Temporal, so in most cases values pass through directly.

DATETIME_TYPES = {<class 'datetime.date'>, <class 'datetime.datetime'>, <class 'datetime.timedelta'>, <class 'datetime.time'>, <class 'Btic'>}
class QueryBuilder(uni_pydantic.query._QueryBuilderBase[~NodeT]):
class QueryBuilder(_QueryBuilderBase[NodeT]):
    """
    Immutable, type-safe query builder for graph queries (synchronous).

    Builder methods return a **new** QueryBuilder instance and leave the
    original untouched, giving a fluent API for composing Cypher queries with
    type checking and IDE autocomplete support. This class supplies the
    methods that actually execute the query.

    Example:
        >>> adults = (
        ...     session.query(Person)
        ...     .filter(Person.age >= 18)
        ...     .order_by(Person.name)
        ...     .limit(10)
        ...     .all()
        ... )
    """

    def __init__(self, session: UniSession, model: type[NodeT]) -> None:
        """Bind the builder to ``session`` and ``model`` via ``_init_state``."""
        self._init_state(session, model)

    def _execute_query(
        self, cypher: str, params: dict[str, Any]
    ) -> list[dict[str, Any]]:
        """Run ``cypher`` and return each row as a dict.

        Uses ``query_with`` when a timeout or max_memory is set, and the plain
        ``query`` call otherwise.
        """
        db = self._session._db_session
        if self._timeout is None and self._max_memory is None:
            rows = db.query(cypher, params)
        else:
            pending = db.query_with(cypher)
            if params:
                pending = pending.params(params)
            if self._timeout is not None:
                pending = pending.timeout(self._timeout)
            if self._max_memory is not None:
                pending = pending.max_memory(self._max_memory)
            rows = pending.fetch_all()
        return [row.to_dict() for row in rows]

    def all(self) -> list[NodeT]:
        """Execute the query and return every matching model instance."""
        statement, bindings = self._build_cypher()
        rows = self._execute_query(statement, bindings)
        nodes = self._rows_to_instances(rows)
        # Eager loading only runs when it was requested and there is something to load.
        if not (self._eager_load and nodes):
            return nodes
        self._session._eager_load_relationships(nodes, self._eager_load)
        return nodes

    def first(self) -> NodeT | None:
        """Execute the query with the limit forced to 1; None when nothing matches."""
        limited = self._clone()
        limited._limit = 1
        matches = limited.all()
        if matches:
            return matches[0]
        return None

    def one(self) -> NodeT:
        """Execute the query and return exactly one result.

        Raises QueryError if no results or more than one result.
        """
        # Two rows are enough to tell "exactly one" from "more than one".
        probe = self._clone()
        probe._limit = 2
        matches = probe.all()
        if not matches:
            raise QueryError("Query returned no results")
        if len(matches) > 1:
            raise QueryError("Query returned more than one result")
        return matches[0]

    def count(self) -> int:
        """Execute the count query; 0 when it yields no rows."""
        statement, bindings = self._build_count_cypher()
        rows = self._execute_query(statement, bindings)
        if not rows:
            return 0
        return cast(int, rows[0]["count"])

    def exists(self) -> bool:
        """Report whether the exists query yields at least one row."""
        statement, bindings = self._build_exists_cypher()
        rows = self._execute_query(statement, bindings)
        return len(rows) >= 1

    def delete(self) -> int:
        """Delete all matching records (DETACH DELETE) and return the reported count."""
        statement, bindings = self._build_delete_cypher()
        with self._session._db_session.tx() as tx:
            outcome = tx.query(statement, bindings)
            tx.commit()
        if not outcome:
            return 0
        return outcome[0].to_dict()["count"]

    def update(self, **kwargs: Any) -> int:
        """Update all matching records with ``kwargs`` and return the reported count."""
        statement, bindings = self._build_update_cypher(**kwargs)
        with self._session._db_session.tx() as tx:
            outcome = tx.query(statement, bindings)
            tx.commit()
        if not outcome:
            return 0
        return outcome[0].to_dict()["count"]

Immutable, type-safe query builder for graph queries.

Each method returns a new QueryBuilder instance. The original is never mutated. Provides a fluent API for building Cypher queries with type checking and IDE autocomplete support.

Example:

adults = ( ... session.query(Person) ... .filter(Person.age >= 18) ... .order_by(Person.name) ... .limit(10) ... .all() ... )

QueryBuilder(session: UniSession, model: type[~NodeT])
713    def __init__(self, session: UniSession, model: type[NodeT]) -> None:
714        self._init_state(session, model)
def all(self) -> list[~NodeT]:
733    def all(self) -> list[NodeT]:
734        """Execute the query and return all results."""
735        cypher, params = self._build_cypher()
736        results = self._execute_query(cypher, params)
737        instances = self._rows_to_instances(results)
738        if self._eager_load and instances:
739            self._session._eager_load_relationships(instances, self._eager_load)
740        return instances

Execute the query and return all results.

def first(self) -> Optional[~NodeT]:
742    def first(self) -> NodeT | None:
743        """Execute the query and return the first result."""
744        clone = self._clone()
745        clone._limit = 1
746        results = clone.all()
747        return results[0] if results else None

Execute the query and return the first result.

def one(self) -> ~NodeT:
749    def one(self) -> NodeT:
750        """Execute the query and return exactly one result.
751
752        Raises QueryError if no results or more than one result.
753        """
754        clone = self._clone()
755        clone._limit = 2
756        results = clone.all()
757        if not results:
758            raise QueryError("Query returned no results")
759        if len(results) > 1:
760            raise QueryError("Query returned more than one result")
761        return results[0]

Execute the query and return exactly one result.

Raises QueryError if no results or more than one result.

def count(self) -> int:
763    def count(self) -> int:
764        """Execute the query and return the count of results."""
765        cypher, params = self._build_count_cypher()
766        results = self._execute_query(cypher, params)
767        return cast(int, results[0]["count"]) if results else 0

Execute the query and return the count of results.

def exists(self) -> bool:
769    def exists(self) -> bool:
770        """Check if any matching records exist."""
771        cypher, params = self._build_exists_cypher()
772        results = self._execute_query(cypher, params)
773        return len(results) > 0

Check if any matching records exist.

def delete(self) -> int:
775    def delete(self) -> int:
776        """Delete all matching records (DETACH DELETE)."""
777        cypher, params = self._build_delete_cypher()
778        with self._session._db_session.tx() as tx:
779            results = tx.query(cypher, params)
780            tx.commit()
781        return results[0].to_dict()["count"] if results else 0

Delete all matching records (DETACH DELETE).

def update(self, **kwargs: Any) -> int:
783    def update(self, **kwargs: Any) -> int:
784        """Update all matching records."""
785        cypher, params = self._build_update_cypher(**kwargs)
786        with self._session._db_session.tx() as tx:
787            results = tx.query(cypher, params)
788            tx.commit()
789        return results[0].to_dict()["count"] if results else 0

Update all matching records.

class AsyncQueryBuilder(uni_pydantic.query._QueryBuilderBase[~NodeT]):
 26class AsyncQueryBuilder(_QueryBuilderBase[NodeT]):
 27    """
 28    Immutable, async query builder for graph queries.
 29
 30    Inherits all Cypher-building and immutable builder methods from
 31    ``_QueryBuilderBase``. Only the execution methods are async.
 32    """
 33
 34    def __init__(self, session: AsyncUniSession, model: type[NodeT]) -> None:
 35        self._init_state(session, model)
 36
 37    async def _execute_query(
 38        self, cypher: str, params: dict[str, Any]
 39    ) -> list[dict[str, Any]]:
 40        """Execute a query, using query_with if timeout/max_memory is set."""
 41        if self._timeout is not None or self._max_memory is not None:
 42            builder = self._session._db_session.query_with(cypher)
 43            if params:
 44                builder = builder.params(params)
 45            if self._timeout is not None:
 46                builder = builder.timeout(self._timeout)
 47            if self._max_memory is not None:
 48                builder = builder.max_memory(self._max_memory)
 49            result = await builder.fetch_all()
 50        else:
 51            result = await self._session._db_session.query(cypher, params)
 52        return [row.to_dict() for row in result]
 53
 54    async def all(self) -> list[NodeT]:
 55        """Execute the query and return all results."""
 56        cypher, params = self._build_cypher()
 57        results = await self._execute_query(cypher, params)
 58        instances = self._rows_to_instances(results)
 59        if self._eager_load and instances:
 60            await self._session._async_eager_load_relationships(
 61                instances, self._eager_load
 62            )
 63        return instances
 64
 65    async def first(self) -> NodeT | None:
 66        """Execute the query and return the first result."""
 67        clone = self._clone()
 68        clone._limit = 1
 69        results = await clone.all()
 70        return results[0] if results else None
 71
 72    async def one(self) -> NodeT:
 73        """Execute the query and return exactly one result.
 74
 75        Raises QueryError if no results or more than one result.
 76        """
 77        clone = self._clone()
 78        clone._limit = 2
 79        results = await clone.all()
 80        if not results:
 81            raise QueryError("Query returned no results")
 82        if len(results) > 1:
 83            raise QueryError("Query returned more than one result")
 84        return results[0]
 85
 86    async def count(self) -> int:
 87        """Execute the query and return the count of results."""
 88        cypher, params = self._build_count_cypher()
 89        results = await self._execute_query(cypher, params)
 90        return cast(int, results[0]["count"]) if results else 0
 91
 92    async def exists(self) -> bool:
 93        """Check if any matching records exist."""
 94        cypher, params = self._build_exists_cypher()
 95        results = await self._execute_query(cypher, params)
 96        return len(results) > 0
 97
 98    async def delete(self) -> int:
 99        """Delete all matching records (DETACH DELETE)."""
100        cypher, params = self._build_delete_cypher()
101        async with await self._session._db_session.tx() as tx:
102            results = await tx.query(cypher, params)
103            await tx.commit()
104        return results[0].to_dict()["count"] if results else 0
105
106    async def update(self, **kwargs: Any) -> int:
107        """Update all matching records."""
108        cypher, params = self._build_update_cypher(**kwargs)
109        async with await self._session._db_session.tx() as tx:
110            results = await tx.query(cypher, params)
111            await tx.commit()
112        return results[0].to_dict()["count"] if results else 0

Immutable, async query builder for graph queries.

Inherits all Cypher-building and immutable builder methods from _QueryBuilderBase. Only the execution methods are async.

AsyncQueryBuilder( session: AsyncUniSession, model: type[~NodeT])
34    def __init__(self, session: AsyncUniSession, model: type[NodeT]) -> None:
35        self._init_state(session, model)
async def all(self) -> list[~NodeT]:
54    async def all(self) -> list[NodeT]:
55        """Execute the query and return all results."""
56        cypher, params = self._build_cypher()
57        results = await self._execute_query(cypher, params)
58        instances = self._rows_to_instances(results)
59        if self._eager_load and instances:
60            await self._session._async_eager_load_relationships(
61                instances, self._eager_load
62            )
63        return instances

Execute the query and return all results.

async def first(self) -> Optional[~NodeT]:
65    async def first(self) -> NodeT | None:
66        """Execute the query and return the first result."""
67        clone = self._clone()
68        clone._limit = 1
69        results = await clone.all()
70        return results[0] if results else None

Execute the query and return the first result.

async def one(self) -> ~NodeT:
72    async def one(self) -> NodeT:
73        """Execute the query and return exactly one result.
74
75        Raises QueryError if no results or more than one result.
76        """
77        clone = self._clone()
78        clone._limit = 2
79        results = await clone.all()
80        if not results:
81            raise QueryError("Query returned no results")
82        if len(results) > 1:
83            raise QueryError("Query returned more than one result")
84        return results[0]

Execute the query and return exactly one result.

Raises QueryError if no results or more than one result.

async def count(self) -> int:
86    async def count(self) -> int:
87        """Execute the query and return the count of results."""
88        cypher, params = self._build_count_cypher()
89        results = await self._execute_query(cypher, params)
90        return cast(int, results[0]["count"]) if results else 0

Execute the query and return the count of results.

async def exists(self) -> bool:
92    async def exists(self) -> bool:
93        """Check if any matching records exist."""
94        cypher, params = self._build_exists_cypher()
95        results = await self._execute_query(cypher, params)
96        return len(results) > 0

Check if any matching records exist.

async def delete(self) -> int:
 98    async def delete(self) -> int:
 99        """Delete all matching records (DETACH DELETE)."""
100        cypher, params = self._build_delete_cypher()
101        async with await self._session._db_session.tx() as tx:
102            results = await tx.query(cypher, params)
103            await tx.commit()
104        return results[0].to_dict()["count"] if results else 0

Delete all matching records (DETACH DELETE).

async def update(self, **kwargs: Any) -> int:
106    async def update(self, **kwargs: Any) -> int:
107        """Update all matching records."""
108        cypher, params = self._build_update_cypher(**kwargs)
109        async with await self._session._db_session.tx() as tx:
110            results = await tx.query(cypher, params)
111            await tx.commit()
112        return results[0].to_dict()["count"] if results else 0

Update all matching records.

@dataclass
class FilterExpr:
@dataclass
class FilterExpr:
    """One filter condition: a property, an operator, and an optional value."""

    property_name: str
    op: FilterOp
    value: Any = None

    def to_cypher(self, node_var: str, param_name: str) -> tuple[str, dict[str, Any]]:
        """Render this condition as a Cypher WHERE fragment.

        Returns the fragment text together with the parameters it refers to.
        The null checks bind no parameter; every other operator binds
        ``self.value`` under ``param_name``.
        """
        prop = f"{node_var}.{self.property_name}"
        op = self.op

        # Null checks take no operand, so they carry no parameters.
        if op == FilterOp.IS_NULL:
            return f"{prop} IS NULL", {}
        if op == FilterOp.IS_NOT_NULL:
            return f"{prop} IS NOT NULL", {}

        bound = {param_name: self.value}

        # NOT IN is the only operator whose keyword precedes the property.
        if op == FilterOp.NOT_IN:
            return f"NOT {prop} IN ${param_name}", bound

        infix_keywords = (
            (FilterOp.IN, "IN"),
            (FilterOp.LIKE, "=~"),
            (FilterOp.STARTS_WITH, "STARTS WITH"),
            (FilterOp.ENDS_WITH, "ENDS WITH"),
            (FilterOp.CONTAINS, "CONTAINS"),
        )
        for candidate, keyword in infix_keywords:
            if op == candidate:
                return f"{prop} {keyword} ${param_name}", bound

        # Remaining operators use the enum value as the operator text.
        return f"{prop} {op.value} ${param_name}", bound

A filter expression for a query.

FilterExpr( property_name: str, op: FilterOp, value: Any = None)
property_name: str
op: FilterOp
value: Any = None
def to_cypher( self, node_var: str, param_name: str) -> tuple[str, dict[str, typing.Any]]:
124    def to_cypher(self, node_var: str, param_name: str) -> tuple[str, dict[str, Any]]:
125        """Convert to Cypher WHERE clause fragment."""
126        prop = f"{node_var}.{self.property_name}"
127
128        if self.op == FilterOp.IS_NULL:
129            return f"{prop} IS NULL", {}
130        elif self.op == FilterOp.IS_NOT_NULL:
131            return f"{prop} IS NOT NULL", {}
132        elif self.op == FilterOp.IN:
133            return f"{prop} IN ${param_name}", {param_name: self.value}
134        elif self.op == FilterOp.NOT_IN:
135            return f"NOT {prop} IN ${param_name}", {param_name: self.value}
136        elif self.op == FilterOp.LIKE:
137            return f"{prop} =~ ${param_name}", {param_name: self.value}
138        elif self.op == FilterOp.STARTS_WITH:
139            return f"{prop} STARTS WITH ${param_name}", {param_name: self.value}
140        elif self.op == FilterOp.ENDS_WITH:
141            return f"{prop} ENDS WITH ${param_name}", {param_name: self.value}
142        elif self.op == FilterOp.CONTAINS:
143            return f"{prop} CONTAINS ${param_name}", {param_name: self.value}
144        else:
145            return f"{prop} {self.op.value} ${param_name}", {param_name: self.value}

Convert to Cypher WHERE clause fragment.

class FilterOp(enum.Enum):
class FilterOp(Enum):
    """Filter operation types.

    Each member names one kind of condition a FilterExpr can express.
    """

    # Comparison operators; FilterExpr.to_cypher emits these values verbatim
    # between the property and the parameter.
    EQ = "="
    NE = "<>"
    LT = "<"
    LE = "<="
    GT = ">"
    GE = ">="
    # List membership. FilterExpr.to_cypher renders NOT_IN as "NOT <prop> IN ...".
    IN = "IN"
    NOT_IN = "NOT IN"
    # Pattern match, rendered with the "=~" operator.
    LIKE = "=~"
    # Null checks; these take no operand value.
    IS_NULL = "IS NULL"
    IS_NOT_NULL = "IS NOT NULL"
    # String predicates.
    STARTS_WITH = "STARTS WITH"
    ENDS_WITH = "ENDS WITH"
    CONTAINS = "CONTAINS"

Filter operation types.

EQ = <FilterOp.EQ: '='>
NE = <FilterOp.NE: '<>'>
LT = <FilterOp.LT: '<'>
LE = <FilterOp.LE: '<='>
GT = <FilterOp.GT: '>'>
GE = <FilterOp.GE: '>='>
IN = <FilterOp.IN: 'IN'>
NOT_IN = <FilterOp.NOT_IN: 'NOT IN'>
LIKE = <FilterOp.LIKE: '=~'>
IS_NULL = <FilterOp.IS_NULL: 'IS NULL'>
IS_NOT_NULL = <FilterOp.IS_NOT_NULL: 'IS NOT NULL'>
STARTS_WITH = <FilterOp.STARTS_WITH: 'STARTS WITH'>
ENDS_WITH = <FilterOp.ENDS_WITH: 'ENDS WITH'>
CONTAINS = <FilterOp.CONTAINS: 'CONTAINS'>
class PropertyProxy(typing.Generic[~T]):
class PropertyProxy(Generic[T]):
    """
    Stand-in for a model property that turns comparisons into filters.

    Comparison operators and the named helper methods each return a
    FilterExpr describing the condition instead of evaluating anything.

    Example:
        >>> query.filter(Person.age >= 18)
        >>> query.filter(Person.name.starts_with("A"))
    """

    def __init__(self, property_name: str, model: type[UniNode]) -> None:
        self._model = model
        self._property_name = property_name

    def _expr(self, op: FilterOp, value: Any = None) -> FilterExpr:
        """Build a FilterExpr on this property for the given operator."""
        return FilterExpr(property_name=self._property_name, op=op, value=value)

    def __eq__(self, other: Any) -> FilterExpr:  # type: ignore[override]
        return self._expr(FilterOp.EQ, other)

    def __ne__(self, other: Any) -> FilterExpr:  # type: ignore[override]
        return self._expr(FilterOp.NE, other)

    def __lt__(self, other: Any) -> FilterExpr:
        return self._expr(FilterOp.LT, other)

    def __le__(self, other: Any) -> FilterExpr:
        return self._expr(FilterOp.LE, other)

    def __gt__(self, other: Any) -> FilterExpr:
        return self._expr(FilterOp.GT, other)

    def __ge__(self, other: Any) -> FilterExpr:
        return self._expr(FilterOp.GE, other)

    def in_(self, values: Sequence[T]) -> FilterExpr:
        """Filter for the property being one of ``values``."""
        return self._expr(FilterOp.IN, list(values))

    def not_in(self, values: Sequence[T]) -> FilterExpr:
        """Filter for the property being none of ``values``."""
        return self._expr(FilterOp.NOT_IN, list(values))

    def like(self, pattern: str) -> FilterExpr:
        """Filter for the property matching a regex pattern."""
        return self._expr(FilterOp.LIKE, pattern)

    def is_null(self) -> FilterExpr:
        """Filter for the property being null."""
        return self._expr(FilterOp.IS_NULL)

    def is_not_null(self) -> FilterExpr:
        """Filter for the property being non-null."""
        return self._expr(FilterOp.IS_NOT_NULL)

    def starts_with(self, prefix: str) -> FilterExpr:
        """Filter for the string property beginning with ``prefix``."""
        return self._expr(FilterOp.STARTS_WITH, prefix)

    def ends_with(self, suffix: str) -> FilterExpr:
        """Filter for the string property ending with ``suffix``."""
        return self._expr(FilterOp.ENDS_WITH, suffix)

    def contains(self, substring: str) -> FilterExpr:
        """Filter for the string property containing ``substring``."""
        return self._expr(FilterOp.CONTAINS, substring)

Proxy for model properties that enables filter expressions.

Used in query builder to create type-safe filter conditions.

Example:

query.filter(Person.age >= 18)
query.filter(Person.name.starts_with("A"))

PropertyProxy(property_name: str, model: type[UniNode])
    def __init__(self, property_name: str, model: type[UniNode]) -> None:
        """Remember the property name and the model class it was requested for."""
        self._property_name = property_name
        self._model = model
def in_(self, values: Sequence[~T]) -> FilterExpr:
181    def in_(self, values: Sequence[T]) -> FilterExpr:
182        """Check if value is in a list."""
183        return FilterExpr(self._property_name, FilterOp.IN, list(values))

Check if value is in a list.

def not_in(self, values: Sequence[~T]) -> FilterExpr:
185    def not_in(self, values: Sequence[T]) -> FilterExpr:
186        """Check if value is not in a list."""
187        return FilterExpr(self._property_name, FilterOp.NOT_IN, list(values))

Check if value is not in a list.

def like(self, pattern: str) -> FilterExpr:
189    def like(self, pattern: str) -> FilterExpr:
190        """Match a regex pattern."""
191        return FilterExpr(self._property_name, FilterOp.LIKE, pattern)

Match a regex pattern.

def is_null(self) -> FilterExpr:
193    def is_null(self) -> FilterExpr:
194        """Check if value is null."""
195        return FilterExpr(self._property_name, FilterOp.IS_NULL)

Check if value is null.

def is_not_null(self) -> FilterExpr:
197    def is_not_null(self) -> FilterExpr:
198        """Check if value is not null."""
199        return FilterExpr(self._property_name, FilterOp.IS_NOT_NULL)

Check if value is not null.

def starts_with(self, prefix: str) -> FilterExpr:
201    def starts_with(self, prefix: str) -> FilterExpr:
202        """Check if string starts with prefix."""
203        return FilterExpr(self._property_name, FilterOp.STARTS_WITH, prefix)

Check if string starts with prefix.

def ends_with(self, suffix: str) -> FilterExpr:
205    def ends_with(self, suffix: str) -> FilterExpr:
206        """Check if string ends with suffix."""
207        return FilterExpr(self._property_name, FilterOp.ENDS_WITH, suffix)

Check if string ends with suffix.

def contains(self, substring: str) -> FilterExpr:
209    def contains(self, substring: str) -> FilterExpr:
210        """Check if string contains substring."""
211        return FilterExpr(self._property_name, FilterOp.CONTAINS, substring)

Check if string contains substring.

class ModelProxy(typing.Generic[~NodeT]):
class ModelProxy(Generic[NodeT]):
    """
    Wrapper around a model class whose attributes are property proxies.

    Any public attribute name looked up on the wrapper yields a
    PropertyProxy for that name, which can then be used in query filters.

    Example:
        >>> Person.name  # Returns PropertyProxy for 'name'
        >>> query.filter(Person.age >= 18)
    """

    def __init__(self, model: type[NodeT]) -> None:
        self._model = model

    def __getattr__(self, name: str) -> PropertyProxy[Any]:
        # Underscore-prefixed names are never treated as properties; they
        # fail like any missing attribute.
        if not name.startswith("_"):
            return PropertyProxy(name, self._model)
        raise AttributeError(name)

Proxy for model classes that provides property proxies.

Enables type-safe property access in query filters.

Example:

Person.name  # Returns PropertyProxy for 'name'
query.filter(Person.age >= 18)

ModelProxy(model: type[~NodeT])
    def __init__(self, model: type[NodeT]) -> None:
        """Remember the model class that property proxies are created for."""
        self._model = model
@dataclass
class OrderByClause:
@dataclass
class OrderByClause:
    """An ORDER BY clause: one property name plus a sort direction."""

    # Name of the property to sort by.
    property_name: str
    # Sort direction flag; defaults to False (not descending).
    descending: bool = False

An ORDER BY clause.

OrderByClause(property_name: str, descending: bool = False)
property_name: str
descending: bool = False
@dataclass
class TraversalStep:
@dataclass
class TraversalStep:
    """A relationship traversal step."""

    # Edge type to traverse.
    edge_type: str
    # Which way the edge is followed; restricted to these three literals.
    direction: Literal["outgoing", "incoming", "both"]
    # Optional label for the node at the far end; None when not given.
    # NOTE(review): presumably None means "any label" — confirm in the query builder.
    target_label: str | None = None

A relationship traversal step.

TraversalStep( edge_type: str, direction: Literal['outgoing', 'incoming', 'both'], target_label: str | None = None)
edge_type: str
direction: Literal['outgoing', 'incoming', 'both']
target_label: str | None = None
@dataclass
class VectorSearchConfig:
@dataclass
class VectorSearchConfig:
    """Configuration for vector similarity search."""

    # Name of the property the search runs against.
    property_name: str
    # The query vector, as a list of floats.
    query_vector: list[float]
    # NOTE(review): presumably the number of nearest results to return — confirm.
    k: int
    # Optional cut-off; None when not set.
    # NOTE(review): whether this is a distance or a similarity bound is not visible here.
    threshold: float | None = None
    # Optional filter string; None when not set.
    # NOTE(review): presumably applied before the vector search — confirm.
    pre_filter: str | None = None

Configuration for vector similarity search.

VectorSearchConfig( property_name: str, query_vector: list[float], k: int, threshold: float | None = None, pre_filter: str | None = None)
property_name: str
query_vector: list[float]
k: int
threshold: float | None = None
pre_filter: str | None = None
@dataclass
class SparseSearchConfig:
@dataclass
class SparseSearchConfig:
    """Configuration for learned-sparse (SPLADE) similarity search."""

    # Name of the property the search runs against.
    property_name: str
    # Indices and values of the sparse query vector.
    # NOTE(review): presumably parallel lists of equal length — confirm against the caller.
    query_indices: list[int]
    query_values: list[float]
    # NOTE(review): presumably the number of results to return — confirm.
    k: int
    # Optional cut-off; None when not set.
    threshold: float | None = None
    # Optional filter string; None when not set.
    # NOTE(review): presumably applied before the sparse search — confirm.
    pre_filter: str | None = None

Configuration for learned-sparse (SPLADE) similarity search.

SparseSearchConfig( property_name: str, query_indices: list[int], query_values: list[float], k: int, threshold: float | None = None, pre_filter: str | None = None)
property_name: str
query_indices: list[int]
query_values: list[float]
k: int
threshold: float | None = None
pre_filter: str | None = None
class SchemaGenerator:
 65class SchemaGenerator:
 66    """Generates Uni database schema from registered models."""
 67
 68    def __init__(self) -> None:
 69        self._node_models: dict[str, type[UniNode]] = {}
 70        self._edge_models: dict[str, type[UniEdge]] = {}
 71        self._schema: DatabaseSchema | None = None
 72
 73    def register_node(self, model: type[UniNode]) -> None:
 74        """Register a node model for schema generation."""
 75        label = model.__label__
 76        if not label:
 77            raise SchemaError(f"Model {model.__name__} has no __label__", model)
 78        self._node_models[label] = model
 79        self._schema = None  # Invalidate cached schema
 80
 81    def register_edge(self, model: type[UniEdge]) -> None:
 82        """Register an edge model for schema generation."""
 83        edge_type = model.__edge_type__
 84        if not edge_type:
 85            raise SchemaError(f"Model {model.__name__} has no __edge_type__", model)
 86        self._edge_models[edge_type] = model
 87        self._schema = None
 88
 89    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
 90        """Register multiple models."""
 91        for model in models:
 92            if issubclass(model, UniEdge):
 93                self.register_edge(model)
 94            elif issubclass(model, UniNode):
 95                self.register_node(model)
 96            else:
 97                raise SchemaError(
 98                    f"Model {model.__name__} must be a subclass of UniNode or UniEdge"
 99                )
100
101    def _generate_property_schema(
102        self,
103        model: type[UniNode] | type[UniEdge],
104        field_name: str,
105    ) -> PropertySchema:
106        """Generate schema for a single property field."""
107        field_info = model.model_fields[field_name]
108
109        # Get type hints with forward refs resolved
110        try:
111            hints = get_type_hints(model)
112            type_hint = hints.get(field_name, field_info.annotation)
113        except Exception:
114            type_hint = field_info.annotation
115
116        # Check for nullability
117        is_nullable, inner_type = is_optional(type_hint)
118
119        # Get Uni data type
120        data_type, nullable = python_type_to_uni(type_hint, nullable=is_nullable)
121
122        # Check for vector dimensions
123        vec_dims = get_vector_dimensions(inner_type if is_nullable else type_hint)
124        if vec_dims:
125            data_type = f"vector:{vec_dims}"
126
127        # Check for sparse-vector dimensions (vocabulary size)
128        sparse_dims = get_sparse_vector_dimensions(
129            inner_type if is_nullable else type_hint
130        )
131        if sparse_dims:
132            data_type = f"sparse_vector:{sparse_dims}"
133
134        # Get field config for index settings
135        config = get_field_config(field_info)
136        index_type = config.index if config else None
137        unique = config.unique if config else False
138        tokenizer = config.tokenizer if config else None
139        metric = config.metric if config else None
140
141        # Auto-create vector index for Vector fields (regardless of Field config)
142        if vec_dims and not index_type:
143            index_type = "vector"
144
145        # Auto-create sparse index for SparseVector fields
146        if sparse_dims and not index_type:
147            index_type = "sparse"
148
149        return PropertySchema(
150            name=field_name,
151            data_type=data_type,
152            nullable=nullable,
153            index_type=index_type,
154            unique=unique,
155            tokenizer=tokenizer,
156            metric=metric,
157        )
158
159    def _generate_label_schema(self, model: type[UniNode]) -> LabelSchema:
160        """Generate schema for a node model."""
161        label = model.__label__
162
163        properties = {}
164        for field_name in model.get_property_fields():
165            prop_schema = self._generate_property_schema(model, field_name)
166            properties[field_name] = prop_schema
167
168        return LabelSchema(
169            name=label,
170            properties=properties,
171        )
172
173    def _generate_edge_type_schema(self, model: type[UniEdge]) -> EdgeTypeSchema:
174        """Generate schema for an edge model."""
175        edge_type = model.__edge_type__
176        from_labels = model.get_from_labels()
177        to_labels = model.get_to_labels()
178
179        # If from/to not specified, allow any labels
180        if not from_labels:
181            from_labels = list(self._node_models.keys())
182        if not to_labels:
183            to_labels = list(self._node_models.keys())
184
185        properties = {}
186        for field_name in model.get_property_fields():
187            prop_schema = self._generate_property_schema(model, field_name)
188            properties[field_name] = prop_schema
189
190        return EdgeTypeSchema(
191            name=edge_type,
192            from_labels=from_labels,
193            to_labels=to_labels,
194            properties=properties,
195        )
196
197    def generate(self) -> DatabaseSchema:
198        """Generate the complete database schema."""
199        if self._schema is not None:
200            return self._schema
201
202        schema = DatabaseSchema()
203
204        # Generate label schemas
205        for label, model in self._node_models.items():
206            schema.labels[label] = self._generate_label_schema(model)
207
208        # Generate edge type schemas
209        for edge_type_name, edge_model in self._edge_models.items():
210            schema.edge_types[edge_type_name] = self._generate_edge_type_schema(
211                edge_model
212            )
213
214        # Also generate labels from relationships in node models
215        for model in self._node_models.values():
216            for rel_name, rel_config in model.get_relationship_fields().items():
217                edge_type = rel_config.edge_type
218                if edge_type not in schema.edge_types:
219                    # Create a minimal edge type schema
220                    schema.edge_types[edge_type] = EdgeTypeSchema(
221                        name=edge_type,
222                        from_labels=list(self._node_models.keys()),
223                        to_labels=list(self._node_models.keys()),
224                    )
225
226        self._schema = schema
227        return schema
228
229    def apply_to_database(self, db: uni_db.Uni) -> None:
230        """Apply the generated schema to a database using SchemaBuilder.
231
232        Uses db.schema() for atomic schema application with additive-only
233        semantics. Creates labels, edge types, properties, and indexes.
234        """
235        schema = self.generate()
236
237        # Build the full schema using SchemaBuilder, skipping existing labels/edge types
238        builder = db.schema()
239        has_changes = False
240
241        for label, label_schema in schema.labels.items():
242            if db.label_exists(label):
243                continue  # Additive-only: skip existing labels
244            lb = builder.label(label)
245            for prop in label_schema.properties.values():
246                # Check for vector type
247                if prop.data_type.startswith("vector:"):
248                    dims = int(prop.data_type.split(":")[1])
249                    lb = lb.vector(prop.name, dims)
250                elif prop.nullable:
251                    lb = lb.property_nullable(prop.name, prop.data_type)
252                else:
253                    lb = lb.property(prop.name, prop.data_type)
254
255                # Add indexes (not vector — vector is handled by .vector())
256                if prop.index_type and prop.index_type in ("btree", "hash"):
257                    lb = lb.index(prop.name, prop.index_type)
258            builder = lb.done()
259            has_changes = True
260
261        for edge_type, edge_schema in schema.edge_types.items():
262            if db.edge_type_exists(edge_type):
263                continue  # Skip existing edge types
264            eb = builder.edge_type(
265                edge_type, edge_schema.from_labels, edge_schema.to_labels
266            )
267            for prop in edge_schema.properties.values():
268                if prop.nullable:
269                    eb = eb.property_nullable(prop.name, prop.data_type)
270                else:
271                    eb = eb.property(prop.name, prop.data_type)
272            builder = eb.done()
273            has_changes = True
274
275        if has_changes:
276            builder.apply()
277
278        # Create vector and fulltext indexes via schema builder
279        for label, label_schema in schema.labels.items():
280            for prop in label_schema.properties.values():
281                if prop.index_type == "vector":
282                    metric = prop.metric or "l2"
283                    try:
284                        db.schema().label(label).index(
285                            prop.name, {"type": "vector", "metric": metric}
286                        ).apply()
287                    except Exception:
288                        pass  # Index may already exist
289                elif prop.index_type == "sparse":
290                    try:
291                        cfg = {"type": "sparse"}
292                        if prop.data_type.startswith("sparse_vector:"):
293                            cfg["dimensions"] = int(prop.data_type.split(":")[1])
294                        db.schema().label(label).index(prop.name, cfg).apply()
295                    except Exception:
296                        pass  # Index may already exist
297                elif prop.index_type == "fulltext":
298                    try:
299                        db.schema().label(label).index(prop.name, "fulltext").apply()
300                    except Exception:
301                        pass  # Index may already exist
302
303    async def async_apply_to_database(self, db: uni_db.AsyncUni) -> None:
304        """Apply the generated schema to an async database.
305
306        Async variant of apply_to_database using AsyncSchemaBuilder.
307        """
308        schema = self.generate()
309
310        # Build the full schema using AsyncSchemaBuilder, skipping existing labels/edge types
311        builder = db.schema()
312        has_changes = False
313
314        for label, label_schema in schema.labels.items():
315            if await db.label_exists(label):
316                continue
317            lb = builder.label(label)
318            for prop in label_schema.properties.values():
319                if prop.data_type.startswith("vector:"):
320                    dims = int(prop.data_type.split(":")[1])
321                    lb = lb.vector(prop.name, dims)
322                elif prop.nullable:
323                    lb = lb.property_nullable(prop.name, prop.data_type)
324                else:
325                    lb = lb.property(prop.name, prop.data_type)
326
327                if prop.index_type and prop.index_type in ("btree", "hash"):
328                    lb = lb.index(prop.name, prop.index_type)
329            builder = lb.done()
330            has_changes = True
331
332        for edge_type, edge_schema in schema.edge_types.items():
333            if await db.edge_type_exists(edge_type):
334                continue
335            eb = builder.edge_type(
336                edge_type, edge_schema.from_labels, edge_schema.to_labels
337            )
338            for prop in edge_schema.properties.values():
339                if prop.nullable:
340                    eb = eb.property_nullable(prop.name, prop.data_type)
341                else:
342                    eb = eb.property(prop.name, prop.data_type)
343            builder = eb.done()
344            has_changes = True
345
346        if has_changes:
347            await builder.apply()
348
349        # Create vector and fulltext indexes via schema builder
350        for label, label_schema in schema.labels.items():
351            for prop in label_schema.properties.values():
352                if prop.index_type == "vector":
353                    metric = prop.metric or "l2"
354                    try:
355                        await (
356                            db.schema()
357                            .label(label)
358                            .index(prop.name, {"type": "vector", "metric": metric})
359                            .apply()
360                        )
361                    except Exception:
362                        pass  # Index may already exist
363                elif prop.index_type == "sparse":
364                    try:
365                        cfg = {"type": "sparse"}
366                        if prop.data_type.startswith("sparse_vector:"):
367                            cfg["dimensions"] = int(prop.data_type.split(":")[1])
368                        await db.schema().label(label).index(prop.name, cfg).apply()
369                    except Exception:
370                        pass  # Index may already exist
371                elif prop.index_type == "fulltext":
372                    try:
373                        await (
374                            db.schema()
375                            .label(label)
376                            .index(prop.name, "fulltext")
377                            .apply()
378                        )
379                    except Exception:
380                        pass  # Index may already exist

Generates Uni database schema from registered models.

def register_node(self, model: type[UniNode]) -> None:
73    def register_node(self, model: type[UniNode]) -> None:
74        """Register a node model for schema generation."""
75        label = model.__label__
76        if not label:
77            raise SchemaError(f"Model {model.__name__} has no __label__", model)
78        self._node_models[label] = model
79        self._schema = None  # Invalidate cached schema

Register a node model for schema generation.

def register_edge(self, model: type[UniEdge]) -> None:
81    def register_edge(self, model: type[UniEdge]) -> None:
82        """Register an edge model for schema generation."""
83        edge_type = model.__edge_type__
84        if not edge_type:
85            raise SchemaError(f"Model {model.__name__} has no __edge_type__", model)
86        self._edge_models[edge_type] = model
87        self._schema = None

Register an edge model for schema generation.

def register( self, *models: type[UniNode] | type[UniEdge]) -> None:
89    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
90        """Register multiple models."""
91        for model in models:
92            if issubclass(model, UniEdge):
93                self.register_edge(model)
94            elif issubclass(model, UniNode):
95                self.register_node(model)
96            else:
97                raise SchemaError(
98                    f"Model {model.__name__} must be a subclass of UniNode or UniEdge"
99                )

Register multiple models.

def generate(self) -> DatabaseSchema:
197    def generate(self) -> DatabaseSchema:
198        """Generate the complete database schema."""
199        if self._schema is not None:
200            return self._schema
201
202        schema = DatabaseSchema()
203
204        # Generate label schemas
205        for label, model in self._node_models.items():
206            schema.labels[label] = self._generate_label_schema(model)
207
208        # Generate edge type schemas
209        for edge_type_name, edge_model in self._edge_models.items():
210            schema.edge_types[edge_type_name] = self._generate_edge_type_schema(
211                edge_model
212            )
213
214        # Also generate labels from relationships in node models
215        for model in self._node_models.values():
216            for rel_name, rel_config in model.get_relationship_fields().items():
217                edge_type = rel_config.edge_type
218                if edge_type not in schema.edge_types:
219                    # Create a minimal edge type schema
220                    schema.edge_types[edge_type] = EdgeTypeSchema(
221                        name=edge_type,
222                        from_labels=list(self._node_models.keys()),
223                        to_labels=list(self._node_models.keys()),
224                    )
225
226        self._schema = schema
227        return schema

Generate the complete database schema.

def apply_to_database(self, db: Uni) -> None:
def apply_to_database(self, db: uni_db.Uni) -> None:
    """Apply the generated schema to a database using SchemaBuilder.

    Labels and edge types that do not exist yet are collected on a single
    ``db.schema()`` builder and applied together; ones that already exist
    are skipped entirely (additive-only). Vector, sparse and fulltext
    indexes are then requested one at a time for every label in the
    schema, on a best-effort basis: a failure is logged at debug level and
    does not abort the remaining work.

    Args:
        db: Database handle providing ``schema()``, ``label_exists()`` and
            ``edge_type_exists()``.
    """
    import logging

    log = logging.getLogger(__name__)
    schema = self.generate()

    # Build the full schema on one builder, skipping existing labels/edge types.
    builder = db.schema()
    has_changes = False

    for label, label_schema in schema.labels.items():
        if db.label_exists(label):
            continue  # Additive-only: skip existing labels
        lb = builder.label(label)
        for prop in label_schema.properties.values():
            if prop.data_type.startswith("vector:"):
                # Vector types are spelled "vector:<dims>".
                dims = int(prop.data_type.split(":")[1])
                lb = lb.vector(prop.name, dims)
            elif prop.nullable:
                lb = lb.property_nullable(prop.name, prop.data_type)
            else:
                lb = lb.property(prop.name, prop.data_type)

            # Only btree/hash indexes are declared here; the other index
            # kinds are created in the second pass below.
            if prop.index_type in ("btree", "hash"):
                lb = lb.index(prop.name, prop.index_type)
        builder = lb.done()
        has_changes = True

    for edge_type, edge_schema in schema.edge_types.items():
        if db.edge_type_exists(edge_type):
            continue  # Skip existing edge types
        eb = builder.edge_type(
            edge_type, edge_schema.from_labels, edge_schema.to_labels
        )
        for prop in edge_schema.properties.values():
            if prop.nullable:
                eb = eb.property_nullable(prop.name, prop.data_type)
            else:
                eb = eb.property(prop.name, prop.data_type)
        builder = eb.done()
        has_changes = True

    if has_changes:
        builder.apply()

    # Create vector, sparse and fulltext indexes via the schema builder.
    for label, label_schema in schema.labels.items():
        for prop in label_schema.properties.values():
            if prop.index_type not in ("vector", "sparse", "fulltext"):
                continue
            try:
                if prop.index_type == "vector":
                    index_config = {"type": "vector", "metric": prop.metric or "l2"}
                elif prop.index_type == "sparse":
                    index_config = {"type": "sparse"}
                    if prop.data_type.startswith("sparse_vector:"):
                        index_config["dimensions"] = int(
                            prop.data_type.split(":")[1]
                        )
                else:
                    index_config = "fulltext"
                db.schema().label(label).index(prop.name, index_config).apply()
            except Exception:
                # Best effort: the index may already exist. Keep going, but
                # leave a trace so genuine failures can be diagnosed.
                log.debug(
                    "Skipping %s index on %s.%s",
                    prop.index_type,
                    label,
                    prop.name,
                    exc_info=True,
                )

Apply the generated schema to a database using SchemaBuilder.

Uses db.schema() for atomic schema application with additive-only semantics. Creates labels, edge types, properties, and indexes.

async def async_apply_to_database(self, db: AsyncUni) -> None:
async def async_apply_to_database(self, db: uni_db.AsyncUni) -> None:
    """Apply the generated schema to an async database.

    Async variant of apply_to_database using AsyncSchemaBuilder. Labels
    and edge types that do not exist yet are collected on one builder and
    applied together; existing ones are skipped. Vector, sparse and
    fulltext indexes are then requested one at a time on a best-effort
    basis: a failure is logged at debug level and does not abort the
    remaining work.

    Args:
        db: Async database handle providing ``schema()`` and the awaitable
            ``label_exists()`` / ``edge_type_exists()`` checks.
    """
    import logging

    log = logging.getLogger(__name__)
    schema = self.generate()

    # Build the full schema on one builder, skipping existing labels/edge types.
    builder = db.schema()
    has_changes = False

    for label, label_schema in schema.labels.items():
        if await db.label_exists(label):
            continue  # Additive-only: skip existing labels
        lb = builder.label(label)
        for prop in label_schema.properties.values():
            if prop.data_type.startswith("vector:"):
                # Vector types are spelled "vector:<dims>".
                dims = int(prop.data_type.split(":")[1])
                lb = lb.vector(prop.name, dims)
            elif prop.nullable:
                lb = lb.property_nullable(prop.name, prop.data_type)
            else:
                lb = lb.property(prop.name, prop.data_type)

            # Only btree/hash indexes are declared here; the other index
            # kinds are created in the second pass below.
            if prop.index_type in ("btree", "hash"):
                lb = lb.index(prop.name, prop.index_type)
        builder = lb.done()
        has_changes = True

    for edge_type, edge_schema in schema.edge_types.items():
        if await db.edge_type_exists(edge_type):
            continue  # Skip existing edge types
        eb = builder.edge_type(
            edge_type, edge_schema.from_labels, edge_schema.to_labels
        )
        for prop in edge_schema.properties.values():
            if prop.nullable:
                eb = eb.property_nullable(prop.name, prop.data_type)
            else:
                eb = eb.property(prop.name, prop.data_type)
        builder = eb.done()
        has_changes = True

    if has_changes:
        await builder.apply()

    # Create vector, sparse and fulltext indexes via the schema builder.
    for label, label_schema in schema.labels.items():
        for prop in label_schema.properties.values():
            if prop.index_type not in ("vector", "sparse", "fulltext"):
                continue
            try:
                if prop.index_type == "vector":
                    index_config = {"type": "vector", "metric": prop.metric or "l2"}
                elif prop.index_type == "sparse":
                    index_config = {"type": "sparse"}
                    if prop.data_type.startswith("sparse_vector:"):
                        index_config["dimensions"] = int(
                            prop.data_type.split(":")[1]
                        )
                else:
                    index_config = "fulltext"
                await db.schema().label(label).index(prop.name, index_config).apply()
            except Exception:
                # Best effort: the index may already exist. Keep going, but
                # leave a trace so genuine failures can be diagnosed.
                log.debug(
                    "Skipping %s index on %s.%s",
                    prop.index_type,
                    label,
                    prop.name,
                    exc_info=True,
                )

Apply the generated schema to an async database.

Async variant of apply_to_database using AsyncSchemaBuilder.

@dataclass
class DatabaseSchema:
@dataclass
class DatabaseSchema:
    """Complete database schema generated from models.

    Groups the label schemas and the edge type schemas; both mappings start
    out empty.
    """

    # Vertex label schemas, keyed by label name.
    labels: dict[str, LabelSchema] = field(default_factory=dict)
    # Edge type schemas, keyed by edge type name.
    edge_types: dict[str, EdgeTypeSchema] = field(default_factory=dict)

Complete database schema generated from models.

DatabaseSchema( labels: dict[str, LabelSchema] = <factory>, edge_types: dict[str, EdgeTypeSchema] = <factory>)
labels: dict[str, LabelSchema]
edge_types: dict[str, EdgeTypeSchema]
@dataclass
class LabelSchema:
@dataclass
class LabelSchema:
    """Schema for a vertex label."""

    # Name of the label.
    name: str
    # Property schemas belonging to this label (empty by default).
    # presumably keyed by property name — TODO confirm against the generator.
    properties: dict[str, PropertySchema] = field(default_factory=dict)

Schema for a vertex label.

LabelSchema( name: str, properties: dict[str, PropertySchema] = <factory>)
name: str
properties: dict[str, PropertySchema]
@dataclass
class EdgeTypeSchema:
@dataclass
class EdgeTypeSchema:
    """Schema for an edge type."""

    # Name of the edge type.
    name: str
    # Labels given as the source side when the edge type is created.
    from_labels: list[str] = field(default_factory=list)
    # Labels given as the target side when the edge type is created.
    to_labels: list[str] = field(default_factory=list)
    # Property schemas carried by edges of this type (empty by default).
    properties: dict[str, PropertySchema] = field(default_factory=dict)

Schema for an edge type.

EdgeTypeSchema( name: str, from_labels: list[str] = <factory>, to_labels: list[str] = <factory>, properties: dict[str, PropertySchema] = <factory>)
name: str
from_labels: list[str]
to_labels: list[str]
properties: dict[str, PropertySchema]
@dataclass
class PropertySchema:
@dataclass
class PropertySchema:
    """Schema for a single property."""

    # Property name.
    name: str
    # Data type as a string. Schema application treats "vector:<dims>" and
    # "sparse_vector:<dims>" specially and passes other values through as-is.
    data_type: str
    # Whether the property is declared as nullable when the schema is applied.
    nullable: bool = False
    # Index kind, if any. Schema application handles "btree", "hash",
    # "vector", "sparse" and "fulltext".
    index_type: str | None = None
    # NOTE(review): presumably marks a uniqueness constraint; not read by the
    # schema application code — confirm where it is enforced.
    unique: bool = False
    # NOTE(review): presumably the fulltext tokenizer; not read by the schema
    # application code — confirm where it is used.
    tokenizer: str | None = None
    # Metric for a vector index; schema application uses "l2" when unset.
    metric: str | None = None

Schema for a single property.

PropertySchema( name: str, data_type: str, nullable: bool = False, index_type: str | None = None, unique: bool = False, tokenizer: str | None = None, metric: str | None = None)
name: str
data_type: str
nullable: bool = False
index_type: str | None = None
unique: bool = False
tokenizer: str | None = None
metric: str | None = None
def generate_schema( *models: type[UniNode] | type[UniEdge]) -> DatabaseSchema:
def generate_schema(*models: type[UniNode] | type[UniEdge]) -> DatabaseSchema:
    """Build a database schema for *models* with a one-off ``SchemaGenerator``."""
    one_shot = SchemaGenerator()
    one_shot.register(*models)
    schema = one_shot.generate()
    return schema

Generate a database schema from the given models.

class UniDatabase:
class UniDatabase:
    """
    Fluent helper wrapping a uni-db ``UniBuilder``.

    Start from one of the classmethods, optionally chain configuration
    calls, then call ``build()`` to obtain the database.

    Example:
        >>> db = UniDatabase.open("./path").cache_size(1024*1024).build()
        >>> db = UniDatabase.temporary().build()
        >>> db = UniDatabase.in_memory().build()
    """

    def __init__(self, builder: uni_db.UniBuilder) -> None:
        # Configuration methods replace this with whatever the builder returns.
        self._builder = builder

    @classmethod
    def open(cls, path: str) -> UniDatabase:
        """Wrap ``UniBuilder.open(path)``: open or create the database at *path*."""
        import uni_db

        underlying = uni_db.UniBuilder.open(path)
        return cls(underlying)

    @classmethod
    def create(cls, path: str) -> UniDatabase:
        """Wrap ``UniBuilder.create(path)``: create a new database at *path*."""
        import uni_db

        underlying = uni_db.UniBuilder.create(path)
        return cls(underlying)

    @classmethod
    def open_existing(cls, path: str) -> UniDatabase:
        """Wrap ``UniBuilder.open_existing(path)``: the database must already exist."""
        import uni_db

        underlying = uni_db.UniBuilder.open_existing(path)
        return cls(underlying)

    @classmethod
    def temporary(cls) -> UniDatabase:
        """Wrap ``UniBuilder.temporary()``: an ephemeral database."""
        import uni_db

        underlying = uni_db.UniBuilder.temporary()
        return cls(underlying)

    @classmethod
    def in_memory(cls) -> UniDatabase:
        """Wrap ``UniBuilder.in_memory()``: an in-memory database."""
        import uni_db

        underlying = uni_db.UniBuilder.in_memory()
        return cls(underlying)

    def cache_size(self, bytes_: int) -> UniDatabase:
        """Configure the cache size, in bytes, and return this wrapper."""
        configured = self._builder.cache_size(bytes_)
        self._builder = configured
        return self

    def parallelism(self, n: int) -> UniDatabase:
        """Configure the parallelism level and return this wrapper."""
        configured = self._builder.parallelism(n)
        self._builder = configured
        return self

    def build(self) -> uni_db.Uni:
        """Build the database from the wrapped builder and return it."""
        database = self._builder.build()
        return database

Thin wrapper around uni-db UniBuilder for ergonomic database creation.

Example:

    >>> db = UniDatabase.open("./path").cache_size(1024*1024).build()
    >>> db = UniDatabase.temporary().build()
    >>> db = UniDatabase.in_memory().build()

UniDatabase(builder: UniBuilder)
def __init__(self, builder: uni_db.UniBuilder) -> None:
    """Store the ``UniBuilder`` that the other methods configure and build."""
    self._builder = builder
@classmethod
def open(cls, path: str) -> UniDatabase:
@classmethod
def open(cls, path: str) -> UniDatabase:
    """Wrap ``UniBuilder.open(path)``: open or create the database at *path*."""
    import uni_db

    underlying = uni_db.UniBuilder.open(path)
    return cls(underlying)

Open or create a database at the given path.

@classmethod
def create(cls, path: str) -> UniDatabase:
@classmethod
def create(cls, path: str) -> UniDatabase:
    """Wrap ``UniBuilder.create(path)``: create a new database at *path*."""
    import uni_db

    underlying = uni_db.UniBuilder.create(path)
    return cls(underlying)

Create a new database at the given path.

@classmethod
def open_existing(cls, path: str) -> UniDatabase:
@classmethod
def open_existing(cls, path: str) -> UniDatabase:
    """Wrap ``UniBuilder.open_existing(path)``: the database must already exist."""
    import uni_db

    underlying = uni_db.UniBuilder.open_existing(path)
    return cls(underlying)

Open an existing database (must already exist).

@classmethod
def temporary(cls) -> UniDatabase:
@classmethod
def temporary(cls) -> UniDatabase:
    """Wrap ``UniBuilder.temporary()``: an ephemeral database."""
    import uni_db

    underlying = uni_db.UniBuilder.temporary()
    return cls(underlying)

Create an ephemeral in-memory database.

@classmethod
def in_memory(cls) -> UniDatabase:
@classmethod
def in_memory(cls) -> UniDatabase:
    """Wrap ``UniBuilder.in_memory()``: an in-memory database."""
    import uni_db

    underlying = uni_db.UniBuilder.in_memory()
    return cls(underlying)

Create a persistent in-memory database.

def cache_size(self, bytes_: int) -> UniDatabase:
def cache_size(self, bytes_: int) -> UniDatabase:
    """Configure the cache size, in bytes, and return this wrapper."""
    configured = self._builder.cache_size(bytes_)
    self._builder = configured
    return self

Set the cache size in bytes.

def parallelism(self, n: int) -> UniDatabase:
def parallelism(self, n: int) -> UniDatabase:
    """Configure the parallelism level and return this wrapper."""
    configured = self._builder.parallelism(n)
    self._builder = configured
    return self

Set the parallelism level.

def build(self) -> Uni:
def build(self) -> uni_db.Uni:
    """Build the database from the wrapped builder and return it."""
    database = self._builder.build()
    return database

Build and return the database instance.

class AsyncUniDatabase:
 78class AsyncUniDatabase:
 79    """
 80    Thin wrapper around uni-db AsyncUniBuilder for ergonomic async database creation.
 81
 82    Example:
 83        >>> db = await AsyncUniDatabase.open("./path").build()
 84        >>> db = await AsyncUniDatabase.temporary().build()
 85    """
 86
 87    def __init__(self, builder: uni_db.AsyncUniBuilder) -> None:
 88        self._builder = builder
 89
 90    @classmethod
 91    def open(cls, path: str) -> AsyncUniDatabase:
 92        """Open or create a database at the given path."""
 93        import uni_db
 94
 95        return cls(uni_db.AsyncUniBuilder.open(path))
 96
 97    @classmethod
 98    def temporary(cls) -> AsyncUniDatabase:
 99        """Create an ephemeral in-memory database."""
100        import uni_db
101
102        return cls(uni_db.AsyncUniBuilder.temporary())
103
104    @classmethod
105    def in_memory(cls) -> AsyncUniDatabase:
106        """Create a persistent in-memory database."""
107        import uni_db
108
109        return cls(uni_db.AsyncUniBuilder.in_memory())
110
111    def cache_size(self, bytes_: int) -> AsyncUniDatabase:
112        """Set the cache size in bytes."""
113        self._builder = self._builder.cache_size(bytes_)
114        return self
115
116    def parallelism(self, n: int) -> AsyncUniDatabase:
117        """Set the parallelism level."""
118        self._builder = self._builder.parallelism(n)
119        return self
120
121    async def build(self) -> uni_db.AsyncUni:
122        """Build and return the async database instance."""
123        return await self._builder.build()

Thin wrapper around uni-db AsyncUniBuilder for ergonomic async database creation.

Example:

    >>> db = await AsyncUniDatabase.open("./path").build()
    >>> db = await AsyncUniDatabase.temporary().build()

AsyncUniDatabase(builder: AsyncUniBuilder)
def __init__(self, builder: uni_db.AsyncUniBuilder) -> None:
    """Store the ``AsyncUniBuilder`` that the other methods configure and build."""
    self._builder = builder
@classmethod
def open(cls, path: str) -> AsyncUniDatabase:
@classmethod
def open(cls, path: str) -> AsyncUniDatabase:
    """Wrap ``AsyncUniBuilder.open(path)``: open or create the database at *path*."""
    import uni_db

    underlying = uni_db.AsyncUniBuilder.open(path)
    return cls(underlying)

Open or create a database at the given path.

@classmethod
def temporary(cls) -> AsyncUniDatabase:
 97    @classmethod
 98    def temporary(cls) -> AsyncUniDatabase:
 99        """Create an ephemeral in-memory database."""
100        import uni_db
101
102        return cls(uni_db.AsyncUniBuilder.temporary())

Create an ephemeral in-memory database.

@classmethod
def in_memory(cls) -> AsyncUniDatabase:
@classmethod
def in_memory(cls) -> AsyncUniDatabase:
    """Wrap ``AsyncUniBuilder.in_memory()``: an in-memory database."""
    import uni_db

    underlying = uni_db.AsyncUniBuilder.in_memory()
    return cls(underlying)

Create a persistent in-memory database.

def cache_size(self, bytes_: int) -> AsyncUniDatabase:
def cache_size(self, bytes_: int) -> AsyncUniDatabase:
    """Configure the cache size, in bytes, and return this wrapper."""
    configured = self._builder.cache_size(bytes_)
    self._builder = configured
    return self

Set the cache size in bytes.

def parallelism(self, n: int) -> AsyncUniDatabase:
def parallelism(self, n: int) -> AsyncUniDatabase:
    """Configure the parallelism level and return this wrapper."""
    configured = self._builder.parallelism(n)
    self._builder = configured
    return self

Set the parallelism level.

async def build(self) -> AsyncUni:
async def build(self) -> uni_db.AsyncUni:
    """Await the wrapped builder's ``build()`` and return the async database."""
    database = await self._builder.build()
    return database

Build and return the async database instance.

def before_create(func: ~F) -> ~F:
def before_create(func: F) -> F:
    """
    Decorator flagging a method as a before-create hook.

    The hook runs once validation has passed and before the INSERT
    operation, which makes it a good place to set timestamps, generate
    IDs, or perform final validation.

    Example:
        >>> class Person(UniNode):
        ...     name: str
        ...     created_at: datetime | None = None
        ...
        ...     @before_create
        ...     def set_created_at(self):
        ...         self.created_at = datetime.now()
    """
    mark = _mark_hook(_BEFORE_CREATE)
    return mark(func)

Mark a method to be called before the entity is created in the database.

The method is called after validation but before the INSERT operation. Useful for setting timestamps, generating IDs, or final validation.

Example:

    >>> class Person(UniNode):
    ...     name: str
    ...     created_at: datetime | None = None
    ...
    ...     @before_create
    ...     def set_created_at(self):
    ...         self.created_at = datetime.now()

def after_create(func: ~F) -> ~F:
def after_create(func: F) -> F:
    """
    Decorator flagging a method as an after-create hook.

    The hook runs once the INSERT operation has completed successfully;
    by then the entity has its vid/eid assigned.

    Example:
        >>> class Person(UniNode):
        ...     name: str
        ...
        ...     @after_create
        ...     def log_creation(self):
        ...         logger.info(f"Created person {self.name} with vid={self.vid}")
    """
    mark = _mark_hook(_AFTER_CREATE)
    return mark(func)

Mark a method to be called after the entity is created in the database.

The method is called after the INSERT operation completes successfully. The entity will have its vid/eid assigned at this point.

Example:

    >>> class Person(UniNode):
    ...     name: str
    ...
    ...     @after_create
    ...     def log_creation(self):
    ...         logger.info(f"Created person {self.name} with vid={self.vid}")

def before_update(func: ~F) -> ~F:
def before_update(func: F) -> F:
    """
    Decorator flagging a method as a before-update hook.

    The hook runs ahead of the UPDATE operation, so it can validate the
    entity or refresh timestamps.

    Example:
        >>> class Person(UniNode):
        ...     name: str
        ...     updated_at: datetime | None = None
        ...
        ...     @before_update
        ...     def validate_and_timestamp(self):
        ...         if not self.name:
        ...             raise ValueError("Name cannot be empty")
        ...         self.updated_at = datetime.now()
    """
    mark = _mark_hook(_BEFORE_UPDATE)
    return mark(func)

Mark a method to be called before the entity is updated in the database.

The method is called before the UPDATE operation. Useful for validation or updating timestamps.

Example:

    >>> class Person(UniNode):
    ...     name: str
    ...     updated_at: datetime | None = None
    ...
    ...     @before_update
    ...     def validate_and_timestamp(self):
    ...         if not self.name:
    ...             raise ValueError("Name cannot be empty")
    ...         self.updated_at = datetime.now()

def after_update(func: ~F) -> ~F:
 98def after_update(func: F) -> F:
 99    """
100    Mark a method to be called after the entity is updated in the database.
101
102    The method is called after the UPDATE operation completes successfully.
103
104    Example:
105        >>> class Person(UniNode):
106        ...     name: str
107        ...
108        ...     @after_update
109        ...     def notify_change(self):
110        ...         events.emit("person_updated", self.vid)
111    """
112    return _mark_hook(_AFTER_UPDATE)(func)

Mark a method to be called after the entity is updated in the database.

The method is called after the UPDATE operation completes successfully.

Example:

    >>> class Person(UniNode):
    ...     name: str
    ...
    ...     @after_update
    ...     def notify_change(self):
    ...         events.emit("person_updated", self.vid)

def before_delete(func: ~F) -> ~F:
def before_delete(func: F) -> F:
    """
    Decorator flagging a method as a before-delete hook.

    The hook runs ahead of the DELETE operation, so it can clean up or
    validate.

    Example:
        >>> class Person(UniNode):
        ...     name: str
        ...
        ...     @before_delete
        ...     def cleanup(self):
        ...         # Remove related data
        ...         pass
    """
    mark = _mark_hook(_BEFORE_DELETE)
    return mark(func)

Mark a method to be called before the entity is deleted from the database.

The method is called before the DELETE operation. Useful for cleanup or validation.

Example:

    >>> class Person(UniNode):
    ...     name: str
    ...
    ...     @before_delete
    ...     def cleanup(self):
    ...         # Remove related data
    ...         pass

def after_delete(func: ~F) -> ~F:
def after_delete(func: F) -> F:
    """
    Decorator flagging a method as an after-delete hook.

    The hook runs once the DELETE operation has completed successfully;
    by then the entity's vid/eid has been cleared.

    Example:
        >>> class Person(UniNode):
        ...     name: str
        ...
        ...     @after_delete
        ...     def log_deletion(self):
        ...         logger.info(f"Deleted person {self.name}")
    """
    mark = _mark_hook(_AFTER_DELETE)
    return mark(func)

Mark a method to be called after the entity is deleted from the database.

The method is called after the DELETE operation completes successfully. The entity's vid/eid will be cleared at this point.

Example:

    >>> class Person(UniNode):
    ...     name: str
    ...
    ...     @after_delete
    ...     def log_deletion(self):
    ...         logger.info(f"Deleted person {self.name}")

def before_load(func: ~F) -> ~F:
def before_load(func: F) -> F:
    """
    Decorator flagging a method as a before-load hook.

    The hook is a class method that receives the raw property dictionary,
    so it can transform the data before the model is instantiated.

    Example:
        >>> class Person(UniNode):
        ...     name: str
        ...
        ...     @classmethod
        ...     @before_load
        ...     def transform_data(cls, props: dict) -> dict:
        ...         # Normalize name
        ...         if 'name' in props:
        ...             props['name'] = props['name'].strip()
        ...         return props
    """
    mark = _mark_hook(_BEFORE_LOAD)
    return mark(func)

Mark a method to be called before the entity is loaded from the database.

This is a class method that receives the raw property dictionary. Can be used to transform data before model instantiation.

Example:

    >>> class Person(UniNode):
    ...     name: str
    ...
    ...     @classmethod
    ...     @before_load
    ...     def transform_data(cls, props: dict) -> dict:
    ...         # Normalize name
    ...         if 'name' in props:
    ...             props['name'] = props['name'].strip()
    ...         return props

def after_load(func: ~F) -> ~F:
def after_load(func: F) -> F:
    """
    Decorator flagging a method as an after-load hook.

    The hook runs once the entity has been instantiated from database
    data, so it can compute derived values or set up non-persisted state.

    Example:
        >>> class Person(UniNode):
        ...     first_name: str
        ...     last_name: str
        ...     full_name: str | None = None
        ...
        ...     @after_load
        ...     def compute_full_name(self):
        ...         self.full_name = f"{self.first_name} {self.last_name}"
    """
    mark = _mark_hook(_AFTER_LOAD)
    return mark(func)

Mark a method to be called after the entity is loaded from the database.

The method is called after the entity is instantiated from database data. Useful for computing derived values or initializing non-persisted state.

Example:

    >>> class Person(UniNode):
    ...     first_name: str
    ...     last_name: str
    ...     full_name: str | None = None
    ...
    ...     @after_load
    ...     def compute_full_name(self):
    ...         self.full_name = f"{self.first_name} {self.last_name}"

class UniPydanticError(builtins.Exception):
class UniPydanticError(Exception):
    """Base exception for all uni-pydantic errors.

    Catch this type to handle any error raised by the package's own
    exception hierarchy.
    """

Base exception for all uni-pydantic errors.

class SchemaError(uni_pydantic.UniPydanticError):
class SchemaError(UniPydanticError):
    """Raised when a schema cannot be defined or generated.

    Attributes:
        model: The model class involved, or None when none was supplied.
    """

    def __init__(self, message: str, model: type | None = None) -> None:
        super().__init__(message)
        self.model = model

Error related to schema definition or generation.

SchemaError(message: str, model: type | None = None)
def __init__(self, message: str, model: type | None = None) -> None:
    """Set the exception message and remember the related model, if any."""
    super().__init__(message)
    self.model = model
model
class TypeMappingError(uni_pydantic.SchemaError):
class TypeMappingError(SchemaError):
    """Raised when a Python type has no Uni DataType mapping.

    Attributes:
        python_type: The type that could not be mapped.
    """

    def __init__(self, python_type: Any, message: str | None = None) -> None:
        # A missing or empty message falls back to the generic wording.
        if not message:
            message = f"Cannot map Python type {python_type!r} to Uni DataType"
        super().__init__(message)
        self.python_type = python_type

Error mapping Python type to Uni DataType.

TypeMappingError(python_type: Any, message: str | None = None)
def __init__(self, python_type: Any, message: str | None = None) -> None:
    """Record the unmappable type; use the generic wording without a message."""
    if not message:
        message = f"Cannot map Python type {python_type!r} to Uni DataType"
    super().__init__(message)
    self.python_type = python_type
python_type
class ValidationError(uni_pydantic.UniPydanticError):
class ValidationError(UniPydanticError):
    """Validation error for model instances.

    NOTE(review): this shares its name with pydantic's own ValidationError;
    confirm callers import the intended one.
    """

Validation error for model instances.

class SessionError(uni_pydantic.UniPydanticError):
class SessionError(UniPydanticError):
    """Error related to session operations.

    Base class for the more specific session errors.
    """

Error related to session operations.

class NotRegisteredError(uni_pydantic.SessionError):
class NotRegisteredError(SessionError):
    """Raised when a model class has not been registered with the session.

    Attributes:
        model: The unregistered model class.
    """

    def __init__(self, model: type[UniNode] | type[UniEdge]) -> None:
        message = (
            f"Model {model.__name__!r} is not registered with this session. "
            f"Call session.register({model.__name__}) first."
        )
        super().__init__(message)
        self.model = model

Model type not registered with session.

NotRegisteredError( model: type[UniNode] | type[UniEdge])
def __init__(self, model: type[UniNode] | type[UniEdge]) -> None:
    """Record the unregistered model and build the explanatory message."""
    message = (
        f"Model {model.__name__!r} is not registered with this session. "
        f"Call session.register({model.__name__}) first."
    )
    super().__init__(message)
    self.model = model
model
class NotPersisted(uni_pydantic.SessionError):
class NotPersisted(SessionError):
    """Raised when an operation needs an entity that has been persisted.

    Attributes:
        entity: The entity that is not persisted.
    """

    def __init__(self, entity: UniNode | UniEdge) -> None:
        message = (
            f"Entity {entity!r} is not persisted. Call session.add() and commit() first."
        )
        super().__init__(message)
        self.entity = entity

Operation requires a persisted entity.

NotPersisted(entity: UniNode | UniEdge)
def __init__(self, entity: UniNode | UniEdge) -> None:
    """Record the unpersisted entity and build the explanatory message."""
    message = (
        f"Entity {entity!r} is not persisted. Call session.add() and commit() first."
    )
    super().__init__(message)
    self.entity = entity
entity
class NotTrackedError(uni_pydantic.SessionError):
class NotTrackedError(SessionError):
    """Entity is not tracked by this session.

    A SessionError subclass, so handlers for SessionError also catch it.
    """

Entity is not tracked by this session.

class TransactionError(uni_pydantic.SessionError):
class TransactionError(SessionError):
    """Error related to transaction operations.

    A SessionError subclass, so handlers for SessionError also catch it.
    """

Error related to transaction operations.

class QueryError(uni_pydantic.UniPydanticError):
class QueryError(UniPydanticError):
    """Error executing a query.

    Base class for the more specific query errors.
    """

Error executing a query.

class RelationshipError(uni_pydantic.UniPydanticError):
class RelationshipError(UniPydanticError):
    """Error related to relationship operations.

    Base class for the more specific relationship errors.
    """

Error related to relationship operations.

class LazyLoadError(uni_pydantic.RelationshipError):
class LazyLoadError(RelationshipError):
    """Raised when a relationship cannot be lazy-loaded.

    Attributes:
        field_name: Name of the relationship field that failed to load.
    """

    def __init__(self, field_name: str, reason: str) -> None:
        message = f"Cannot lazy-load relationship '{field_name}': {reason}"
        super().__init__(message)
        self.field_name = field_name

Error lazy-loading a relationship.

LazyLoadError(field_name: str, reason: str)
def __init__(self, field_name: str, reason: str) -> None:
    """Record the relationship field and build the message from *reason*."""
    message = f"Cannot lazy-load relationship '{field_name}': {reason}"
    super().__init__(message)
    self.field_name = field_name
field_name
class BulkLoadError(uni_pydantic.UniPydanticError):
class BulkLoadError(UniPydanticError):
    """Error during bulk loading operations.

    Derives directly from UniPydanticError.
    """

Error during bulk loading operations.

class CypherInjectionError(uni_pydantic.QueryError):
class CypherInjectionError(QueryError):
    """Raised when a property name fails validation (potential Cypher injection).

    Attributes:
        name: The rejected property name.
    """

    def __init__(self, name: str, reason: str | None = None) -> None:
        # A missing or empty reason falls back to the generic wording.
        if not reason:
            reason = f"Invalid property name {name!r}: possible Cypher injection"
        super().__init__(reason)
        self.name = name

Property name validation failure — potential Cypher injection.

CypherInjectionError(name: str, reason: str | None = None)
def __init__(self, name: str, reason: str | None = None) -> None:
    """Record the rejected name; use the generic wording without a reason."""
    if not reason:
        reason = f"Invalid property name {name!r}: possible Cypher injection"
    super().__init__(reason)
    self.name = name
name