uni_pydantic
uni-pydantic: Pydantic-based OGM for Uni Graph Database.
This package provides a type-safe Object-Graph Mapping layer on top of the Uni graph database, using Pydantic v2 for model definitions.
Example:
>>> from uni_db import Uni
>>> from uni_pydantic import UniNode, UniSession, Field, Relationship, Vector
>>>
>>> class Person(UniNode):
...     name: str
...     age: int | None = None
...     email: str = Field(unique=True)
...     embedding: Vector[1536]
...     friends: list["Person"] = Relationship("FRIEND_OF", direction="both")
>>>
>>> db = Uni("./my_graph")
>>> session = UniSession(db)
>>> session.register(Person)
>>> session.sync_schema()
>>>
>>> alice = Person(name="Alice", age=30, email="alice@example.com")
>>> session.add(alice)
>>> session.commit()
>>>
>>> # Query with type safety
>>> adults = session.query(Person).filter(Person.age >= 18).all()
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024-2026 Dragonscale Team

"""
uni-pydantic: Pydantic-based OGM for Uni Graph Database.

This package provides a type-safe Object-Graph Mapping layer on top of
the Uni graph database, using Pydantic v2 for model definitions.

Example:
    >>> from uni_db import Uni
    >>> from uni_pydantic import UniNode, UniSession, Field, Relationship, Vector
    >>>
    >>> class Person(UniNode):
    ...     name: str
    ...     age: int | None = None
    ...     email: str = Field(unique=True)
    ...     embedding: Vector[1536]
    ...     friends: list["Person"] = Relationship("FRIEND_OF", direction="both")
    >>>
    >>> db = Uni("./my_graph")
    >>> session = UniSession(db)
    >>> session.register(Person)
    >>> session.sync_schema()
    >>>
    >>> alice = Person(name="Alice", age=30, email="alice@example.com")
    >>> session.add(alice)
    >>> session.commit()
    >>>
    >>> # Query with type safety
    >>> adults = session.query(Person).filter(Person.age >= 18).all()
"""

__version__ = "1.1.0"

# Imports are kept in alphabetical module order; each section comment
# labels the import(s) directly below it.

# Async support
from .async_query import AsyncQueryBuilder
from .async_session import AsyncUniSession, AsyncUniTransaction

# Base classes
from .base import UniEdge, UniNode

# Database wrappers
from .database import AsyncUniDatabase, UniDatabase

# Exceptions
from .exceptions import (
    BulkLoadError,
    CypherInjectionError,
    LazyLoadError,
    NotPersisted,
    NotRegisteredError,
    NotTrackedError,
    QueryError,
    RelationshipError,
    SchemaError,
    SessionError,
    TransactionError,
    TypeMappingError,
    UniPydanticError,
    ValidationError,
)

# Field configuration
from .fields import (
    Direction,
    Field,
    FieldConfig,
    IndexType,
    Relationship,
    RelationshipConfig,
    RelationshipDescriptor,
    VectorMetric,
    get_field_config,
)

# Lifecycle hooks
from .hooks import (
    after_create,
    after_delete,
    after_load,
    after_update,
    before_create,
    before_delete,
    before_load,
    before_update,
)

# Query builder
from .query import (
    FilterExpr,
    FilterOp,
    ModelProxy,
    OrderByClause,
    PropertyProxy,
    QueryBuilder,
    TraversalStep,
    VectorSearchConfig,
)

# Schema generation
from .schema import (
    DatabaseSchema,
    EdgeTypeSchema,
    LabelSchema,
    PropertySchema,
    SchemaGenerator,
    generate_schema,
)

# Session management
from .session import UniSession, UniTransaction

# Type utilities
from .types import (
    DATETIME_TYPES,
    Btic,
    Vector,
    db_to_python_value,
    get_vector_dimensions,
    is_list_type,
    is_optional,
    python_to_db_value,
    python_type_to_uni,
    uni_to_python_type,
    unwrap_annotated,
)

__all__ = [
    # Version
    "__version__",
    # Base classes
    "UniNode",
    "UniEdge",
    # Session
    "UniSession",
    "UniTransaction",
    # Async Session
    "AsyncUniSession",
    "AsyncUniTransaction",
    # Fields
    "Field",
    "FieldConfig",
    "Relationship",
    "RelationshipConfig",
    "RelationshipDescriptor",
    "get_field_config",
    "IndexType",
    "Direction",
    "VectorMetric",
    # Types
    "Btic",
    "Vector",
    "python_type_to_uni",
    "uni_to_python_type",
    "get_vector_dimensions",
    "is_optional",
    "is_list_type",
    "unwrap_annotated",
    "python_to_db_value",
    "db_to_python_value",
    "DATETIME_TYPES",
    # Query
    "QueryBuilder",
    "AsyncQueryBuilder",
    "FilterExpr",
    "FilterOp",
    "PropertyProxy",
    "ModelProxy",
    "OrderByClause",
    "TraversalStep",
    "VectorSearchConfig",
    # Schema
    "SchemaGenerator",
    "DatabaseSchema",
    "LabelSchema",
    "EdgeTypeSchema",
    "PropertySchema",
    "generate_schema",
    # Database
    "UniDatabase",
    "AsyncUniDatabase",
    # Hooks
    "before_create",
    "after_create",
    "before_update",
    "after_update",
    "before_delete",
    "after_delete",
    "before_load",
    "after_load",
    # Exceptions
    "UniPydanticError",
    "SchemaError",
    "TypeMappingError",
    "ValidationError",
    "SessionError",
    "NotRegisteredError",
    "NotPersisted",
    "NotTrackedError",
    "TransactionError",
    "QueryError",
    "RelationshipError",
    "LazyLoadError",
    "BulkLoadError",
    "CypherInjectionError",
]
class UniNode(BaseModel, metaclass=UniModelMeta):
    """
    Base class for graph node models.

    Subclass this to define your node types. Each UniNode subclass
    represents a vertex label in the graph database.

    Attributes:
        __label__: The vertex label name. Defaults to the class name.
        __relationships__: Dictionary of relationship configurations.

    Private Attributes:
        _vid: The vertex ID assigned by the database.
        _uid: The unique identifier (content-addressed hash).
        _session: Reference to the owning session.
        _dirty: Set of modified field names.
        _is_new: Starts True; cleared by _mark_clean() and
            _attach_session(), and set to ``vid is None`` by
            from_properties().

    Example:
        >>> class Person(UniNode):
        ...     __label__ = "Person"
        ...
        ...     name: str
        ...     age: int | None = None
        ...     email: str = Field(unique=True)
        ...
        ...     friends: list["Person"] = Relationship("FRIEND_OF", direction="both")
    """

    model_config = ConfigDict(
        # Reject unknown fields: a misspelled property name raises a
        # validation error instead of being silently accepted.
        extra="forbid",
        # Validate assigned values too, so tracked (dirty) fields only ever
        # hold values that passed validation.
        validate_assignment=True,
        # Allow arbitrary types (for Vector, etc.)
        arbitrary_types_allowed=True,
        # Store enum members as their underlying values
        use_enum_values=True,
    )

    # Class-level configuration
    __label__: ClassVar[str] = ""
    __relationships__: ClassVar[dict[str, RelationshipConfig]] = {}

    # Private attributes for session tracking
    _vid: int | None = PrivateAttr(default=None)
    _uid: str | None = PrivateAttr(default=None)
    _session: UniSession | None = PrivateAttr(default=None)
    _dirty: set[str] = PrivateAttr(default_factory=set)
    _is_new: bool = PrivateAttr(default=True)

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Set default label to class name if not specified
        if not cls.__label__:
            cls.__label__ = cls.__name__

    def model_post_init(self, __context: Any) -> None:
        """Clear dirty tracking after construction."""
        super().model_post_init(__context)
        self._dirty = set()

    @property
    def vid(self) -> int | None:
        """The vertex ID assigned by the database."""
        return self._vid

    @property
    def uid(self) -> str | None:
        """The unique identifier (content-addressed hash)."""
        return self._uid

    @property
    def is_persisted(self) -> bool:
        """Whether this node has been saved to the database."""
        return self._vid is not None

    @property
    def is_dirty(self) -> bool:
        """Whether this node has unsaved changes."""
        return bool(self._dirty)

    def __setattr__(self, name: str, value: Any) -> None:
        """Assign an attribute and record public field names as dirty.

        The assignment runs first. With ``validate_assignment=True``,
        pydantic may reject the value, or an unknown name under
        ``extra="forbid"``. A rejected assignment must not leave the name
        in ``_dirty``.
        """
        super().__setattr__(name, value)
        # Track dirty fields (but not private attributes). hasattr() presumably
        # guards against assignments made before private storage is set up.
        if not name.startswith("_") and hasattr(self, "_dirty"):
            self._dirty.add(name)

    def _mark_clean(self) -> None:
        """Mark all fields as clean (called after commit)."""
        self._dirty.clear()
        self._is_new = False

    def _attach_session(
        self, session: UniSession, vid: int, uid: str | None = None
    ) -> None:
        """Attach this node to a session with its database IDs."""
        self._session = session
        self._vid = vid
        self._uid = uid
        self._is_new = False

    @classmethod
    def get_property_fields(cls) -> dict[str, FieldInfo]:
        """Get all property fields (excluding relationships)."""
        return {
            name: info
            for name, info in cls.model_fields.items()
            if name not in cls.__relationships__
        }

    @classmethod
    def get_relationship_fields(cls) -> dict[str, RelationshipConfig]:
        """Get all relationship field configurations."""
        return cls.__relationships__

    def to_properties(self) -> dict[str, Any]:
        """Convert to a property dictionary for database storage.

        Uses python_to_db_value() for type conversion. Includes None
        explicitly so null-outs work.
        """
        return _model_to_properties(self, self.get_property_fields())

    @classmethod
    def from_properties(
        cls,
        props: dict[str, Any],
        *,
        vid: int | None = None,
        uid: str | None = None,
        session: UniSession | None = None,
    ) -> UniNode:
        """Create an instance from a property dictionary.

        Accepts _id (string->int vid) and _label from uni-db node dicts.
        Does not mutate the input dict.
        """
        data = dict(props)

        raw_id = data.pop("_id", None)
        if raw_id is not None and vid is None:
            vid = int(raw_id) if not isinstance(raw_id, int) else raw_id
        data.pop("_label", None)

        converted = _convert_db_values(data, cls)

        instance = cls.model_validate(converted)
        if vid is not None:
            instance._vid = vid
        if uid is not None:
            instance._uid = uid
        if session is not None:
            instance._session = session
        instance._is_new = vid is None
        instance._dirty = set()
        return instance

    def __repr__(self) -> str:
        # Compare against None so vid=0 is shown. This matches is_persisted,
        # which treats any non-None vid as persisted.
        vid_str = f"vid={self._vid}" if self._vid is not None else "unsaved"
        return f"{self.__class__.__name__}({vid_str}, {super().__repr__()})"
Base class for graph node models.
Subclass this to define your node types. Each UniNode subclass represents a vertex label in the graph database.
Attributes:
    __label__: The vertex label name. Defaults to the class name.
    __relationships__: Dictionary of relationship configurations.

Private Attributes:
    _vid: The vertex ID assigned by the database.
    _uid: The unique identifier (content-addressed hash).
    _session: Reference to the owning session.
    _dirty: Set of modified field names.
Example:
>>> class Person(UniNode):
...     __label__ = "Person"
...
...     name: str
...     age: int | None = None
...     email: str = Field(unique=True)
...
...     friends: list["Person"] = Relationship("FRIEND_OF", direction="both")
220 @property 221 def vid(self) -> int | None: 222 """The vertex ID assigned by the database.""" 223 return self._vid
The vertex ID assigned by the database.
225 @property 226 def uid(self) -> str | None: 227 """The unique identifier (content-addressed hash).""" 228 return self._uid
The unique identifier (content-addressed hash).
@property
def is_persisted(self) -> bool:
    """``True`` once this node carries a database-assigned vertex ID."""
    if self._vid is None:
        return False
    return True
Whether this node has been saved to the database.
@property
def is_dirty(self) -> bool:
    """``True`` when at least one field has changed since it was last marked clean."""
    return len(self._dirty) != 0
Whether this node has unsaved changes.
@classmethod
def get_property_fields(cls) -> dict[str, FieldInfo]:
    """Return the model fields that are plain properties.

    Any field whose name appears in ``__relationships__`` is left out.
    Declaration order of ``model_fields`` is preserved.
    """
    relationship_names = cls.__relationships__
    property_fields: dict[str, FieldInfo] = {}
    for field_name, field_info in cls.model_fields.items():
        if field_name not in relationship_names:
            property_fields[field_name] = field_info
    return property_fields
Get all property fields (excluding relationships).
@classmethod
def get_relationship_fields(cls) -> dict[str, RelationshipConfig]:
    """Return the relationship configurations, keyed by field name.

    The class-level mapping itself is returned, not a copy.
    """
    relationships = cls.__relationships__
    return relationships
Get all relationship field configurations.
def to_properties(self) -> dict[str, Any]:
    """Serialize this node's property fields into a database-ready dict.

    Conversion is delegated to _model_to_properties(), which applies
    python_to_db_value(). Fields set to None are included explicitly so
    null-outs work.
    """
    property_fields = self.get_property_fields()
    return _model_to_properties(self, property_fields)
Convert to a property dictionary for database storage.
Uses python_to_db_value() for type conversion. Includes None explicitly so null-outs work.
@classmethod
def from_properties(
    cls,
    props: dict[str, Any],
    *,
    vid: int | None = None,
    uid: str | None = None,
    session: UniSession | None = None,
) -> UniNode:
    """Build a validated instance from a uni-db node property mapping.

    The ``_id`` key supplies the vertex ID (coerced to int) unless ``vid``
    is passed explicitly. The ``_label`` key is discarded. The caller's
    dict is not modified. The result has an empty dirty set, and
    ``_is_new`` is True only when no vertex ID was resolved.
    """
    # Shallow copy so the pops below never touch the caller's dict.
    data = dict(props)

    raw_id = data.pop("_id", None)
    if vid is None and raw_id is not None:
        vid = raw_id if isinstance(raw_id, int) else int(raw_id)
    data.pop("_label", None)

    instance = cls.model_validate(_convert_db_values(data, cls))

    # Only overwrite private attributes for which a value is available.
    for attr_name, attr_value in (("_vid", vid), ("_uid", uid), ("_session", session)):
        if attr_value is not None:
            setattr(instance, attr_name, attr_value)
    instance._is_new = vid is None
    instance._dirty = set()
    return instance
Create an instance from a property dictionary.
Accepts _id (string->int vid) and _label from uni-db node dicts. Does not mutate the input dict.
class UniEdge(BaseModel, metaclass=UniModelMeta):
    """
    Base class for graph edge models with properties.

    Subclass this to define edge types with typed properties.
    Edges represent relationships between nodes.

    Attributes:
        __edge_type__: The edge type name. Defaults to the class name.
        __from__: The source node type(s).
        __to__: The target node type(s).

    Private Attributes:
        _eid: The edge ID assigned by the database.
        _src_vid: The source vertex ID.
        _dst_vid: The destination vertex ID.
        _session: Reference to the owning session.
        _is_new: Starts True; cleared by _attach(), and set to
            ``eid is None`` by from_properties().

    Example:
        >>> class FriendshipEdge(UniEdge):
        ...     __edge_type__ = "FRIEND_OF"
        ...     __from__ = Person
        ...     __to__ = Person
        ...
        ...     since: date
        ...     strength: float = 1.0
    """

    model_config = ConfigDict(
        # Reject unknown fields, validate assignments, allow arbitrary
        # types (e.g. Vector), and store enum members as their values.
        extra="forbid",
        validate_assignment=True,
        arbitrary_types_allowed=True,
        use_enum_values=True,
    )

    # Class-level configuration
    __edge_type__: ClassVar[str] = ""
    __from__: ClassVar[type[UniNode] | tuple[type[UniNode], ...] | None] = None
    __to__: ClassVar[type[UniNode] | tuple[type[UniNode], ...] | None] = None
    __relationships__: ClassVar[dict[str, RelationshipConfig]] = {}

    # Private attributes
    _eid: int | None = PrivateAttr(default=None)
    _src_vid: int | None = PrivateAttr(default=None)
    _dst_vid: int | None = PrivateAttr(default=None)
    _session: UniSession | None = PrivateAttr(default=None)
    _is_new: bool = PrivateAttr(default=True)

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Set default edge type to class name if not specified
        if not cls.__edge_type__:
            cls.__edge_type__ = cls.__name__

    @property
    def eid(self) -> int | None:
        """The edge ID assigned by the database."""
        return self._eid

    @property
    def src_vid(self) -> int | None:
        """The source vertex ID."""
        return self._src_vid

    @property
    def dst_vid(self) -> int | None:
        """The destination vertex ID."""
        return self._dst_vid

    @property
    def is_persisted(self) -> bool:
        """Whether this edge has been saved to the database."""
        return self._eid is not None

    def _attach(
        self,
        session: UniSession,
        eid: int,
        src_vid: int,
        dst_vid: int,
    ) -> None:
        """Attach this edge to a session with its database IDs."""
        self._session = session
        self._eid = eid
        self._src_vid = src_vid
        self._dst_vid = dst_vid
        self._is_new = False

    @classmethod
    def get_from_labels(cls) -> list[str]:
        """Get the source label names."""
        if cls.__from__ is None:
            return []
        if isinstance(cls.__from__, tuple):
            return [n.__label__ for n in cls.__from__]
        return [cls.__from__.__label__]

    @classmethod
    def get_to_labels(cls) -> list[str]:
        """Get the target label names."""
        if cls.__to__ is None:
            return []
        if isinstance(cls.__to__, tuple):
            return [n.__label__ for n in cls.__to__]
        return [cls.__to__.__label__]

    @classmethod
    def get_property_fields(cls) -> dict[str, FieldInfo]:
        """Get all property fields."""
        return dict(cls.model_fields)

    def to_properties(self) -> dict[str, Any]:
        """Convert to a property dictionary for database storage."""
        return _model_to_properties(self, self.get_property_fields())

    @classmethod
    def from_properties(
        cls,
        props: dict[str, Any],
        *,
        eid: int | None = None,
        src_vid: int | None = None,
        dst_vid: int | None = None,
        session: UniSession | None = None,
    ) -> UniEdge:
        """Create an instance from a property dictionary.

        Accepts _id, _type, _src, _dst from uni-db edge dicts.
        Does not mutate the input dict.
        """
        data = dict(props)

        raw_id = data.pop("_id", None)
        if raw_id is not None and eid is None:
            eid = int(raw_id) if not isinstance(raw_id, int) else raw_id
        data.pop("_type", None)
        raw_src = data.pop("_src", None)
        if raw_src is not None and src_vid is None:
            src_vid = int(raw_src) if not isinstance(raw_src, int) else raw_src
        raw_dst = data.pop("_dst", None)
        if raw_dst is not None and dst_vid is None:
            dst_vid = int(raw_dst) if not isinstance(raw_dst, int) else raw_dst

        converted = _convert_db_values(data, cls)

        instance = cls.model_validate(converted)
        if eid is not None:
            instance._eid = eid
        if src_vid is not None:
            instance._src_vid = src_vid
        if dst_vid is not None:
            instance._dst_vid = dst_vid
        if session is not None:
            instance._session = session
        instance._is_new = eid is None
        return instance

    @classmethod
    def from_edge_result(
        cls,
        data: dict[str, Any],
        *,
        session: UniSession | None = None,
    ) -> UniEdge:
        """Create an instance from a uni-db edge result dict.

        Convenience method that handles _id, _type, _src, _dst keys.
        """
        return cls.from_properties(data, session=session)

    def __repr__(self) -> str:
        # Compare against None so eid=0 is shown. This matches is_persisted,
        # which treats any non-None eid as persisted.
        eid_str = f"eid={self._eid}" if self._eid is not None else "unsaved"
        return f"{self.__class__.__name__}({eid_str}, {super().__repr__()})"
Base class for graph edge models with properties.
Subclass this to define edge types with typed properties. Edges represent relationships between nodes.
Attributes:
    __edge_type__: The edge type name.
    __from__: The source node type(s).
    __to__: The target node type(s).

Private Attributes:
    _eid: The edge ID assigned by the database.
    _src_vid: The source vertex ID.
    _dst_vid: The destination vertex ID.
    _session: Reference to the owning session.
Example:
>>> class FriendshipEdge(UniEdge):
...     __edge_type__ = "FRIEND_OF"
...     __from__ = Person
...     __to__ = Person
...
...     since: date
...     strength: float = 1.0
375 @property 376 def eid(self) -> int | None: 377 """The edge ID assigned by the database.""" 378 return self._eid
The edge ID assigned by the database.
380 @property 381 def src_vid(self) -> int | None: 382 """The source vertex ID.""" 383 return self._src_vid
The source vertex ID.
385 @property 386 def dst_vid(self) -> int | None: 387 """The destination vertex ID.""" 388 return self._dst_vid
The destination vertex ID.
@property
def is_persisted(self) -> bool:
    """``True`` once this edge carries a database-assigned edge ID."""
    if self._eid is None:
        return False
    return True
Whether this edge has been saved to the database.
@classmethod
def get_from_labels(cls) -> list[str]:
    """Return the label(s) of the allowed source node type(s).

    ``__from__`` may be None (no constraint, empty list), a single node
    class, or a tuple of node classes.
    """
    source = cls.__from__
    if source is None:
        return []
    node_types = source if isinstance(source, tuple) else (source,)
    return [node_type.__label__ for node_type in node_types]
Get the source label names.
@classmethod
def get_to_labels(cls) -> list[str]:
    """Return the label(s) of the allowed target node type(s).

    ``__to__`` may be None (no constraint, empty list), a single node
    class, or a tuple of node classes.
    """
    target = cls.__to__
    if target is None:
        return []
    node_types = target if isinstance(target, tuple) else (target,)
    return [node_type.__label__ for node_type in node_types]
Get the target label names.
@classmethod
def get_property_fields(cls) -> dict[str, FieldInfo]:
    """Return a fresh dict of every model field on this edge class."""
    return {name: info for name, info in cls.model_fields.items()}
Get all property fields.
def to_properties(self) -> dict[str, Any]:
    """Serialize this edge's property fields into a database-ready dict."""
    property_fields = self.get_property_fields()
    return _model_to_properties(self, property_fields)
Convert to a property dictionary for database storage.
@classmethod
def from_properties(
    cls,
    props: dict[str, Any],
    *,
    eid: int | None = None,
    src_vid: int | None = None,
    dst_vid: int | None = None,
    session: UniSession | None = None,
) -> UniEdge:
    """Build a validated edge instance from a uni-db edge property mapping.

    ``_id``, ``_src`` and ``_dst`` supply the edge and endpoint IDs (coerced
    to int) unless the matching keyword is passed explicitly. ``_type`` is
    discarded. The caller's dict is not modified. ``_is_new`` is True only
    when no edge ID was resolved.
    """
    # Shallow copy so the pops below never touch the caller's dict.
    data = dict(props)

    def _as_int(raw: Any) -> int:
        # IDs may arrive as strings; ints pass through unchanged.
        return raw if isinstance(raw, int) else int(raw)

    raw_id = data.pop("_id", None)
    if eid is None and raw_id is not None:
        eid = _as_int(raw_id)
    data.pop("_type", None)
    raw_src = data.pop("_src", None)
    if src_vid is None and raw_src is not None:
        src_vid = _as_int(raw_src)
    raw_dst = data.pop("_dst", None)
    if dst_vid is None and raw_dst is not None:
        dst_vid = _as_int(raw_dst)

    instance = cls.model_validate(_convert_db_values(data, cls))

    # Only overwrite private attributes for which a value is available.
    for attr_name, attr_value in (
        ("_eid", eid),
        ("_src_vid", src_vid),
        ("_dst_vid", dst_vid),
        ("_session", session),
    ):
        if attr_value is not None:
            setattr(instance, attr_name, attr_value)
    instance._is_new = eid is None
    return instance
Create an instance from a property dictionary.
Accepts _id, _type, _src, _dst from uni-db edge dicts. Does not mutate the input dict.
@classmethod
def from_edge_result(
    cls,
    data: dict[str, Any],
    *,
    session: UniSession | None = None,
) -> UniEdge:
    """Build an edge instance from a raw uni-db edge result dict.

    This is a thin convenience wrapper over from_properties(), which
    interprets the ``_id``, ``_type``, ``_src`` and ``_dst`` keys.
    """
    build = cls.from_properties
    return build(data, session=session)
Create an instance from a uni-db edge result dict.
Convenience method that handles _id, _type, _src, _dst keys.
160class UniSession: 161 """ 162 Session for interacting with the graph database using Pydantic models. 163 164 The session manages model registration, schema synchronization, 165 and provides CRUD operations and query building. 166 167 Example: 168 >>> from uni_db import Uni 169 >>> from uni_pydantic import UniSession 170 >>> 171 >>> db = Uni("./my_graph") 172 >>> session = UniSession(db) 173 >>> session.register(Person, Company) 174 >>> session.sync_schema() 175 >>> 176 >>> alice = Person(name="Alice", age=30) 177 >>> session.add(alice) 178 >>> session.commit() 179 """ 180 181 def __init__(self, db: uni_db.Uni) -> None: 182 self._db = db 183 self._db_session = db.session() 184 self._schema_gen = SchemaGenerator() 185 self._identity_map: WeakValueDictionary[tuple[str, int], UniNode] = ( 186 WeakValueDictionary() 187 ) 188 self._pending_new: list[UniNode] = [] 189 self._pending_delete: list[UniNode] = [] 190 191 def __enter__(self) -> UniSession: 192 return self 193 194 def __exit__( 195 self, 196 exc_type: type[BaseException] | None, 197 exc_val: BaseException | None, 198 exc_tb: TracebackType | None, 199 ) -> None: 200 self.close() 201 202 def close(self) -> None: 203 """Close the session and clear all pending state.""" 204 self._pending_new.clear() 205 self._pending_delete.clear() 206 207 @property 208 def db(self) -> uni_db.Uni: 209 """Access the underlying uni_db.Uni for low-level operations.""" 210 return self._db 211 212 def locy( 213 self, program: str, params: dict[str, Any] | None = None 214 ) -> uni_db.LocyResult: 215 """ 216 Evaluate a Locy program and return derived facts, stats, and warnings. 217 218 Delegates to the underlying ``uni_db.Session.locy()``. 219 """ 220 return self._db_session.locy(program, params) 221 222 def register(self, *models: type[UniNode] | type[UniEdge]) -> None: 223 """ 224 Register model classes with the session. 225 226 Registered models can be used for schema generation and queries. 
227 228 Args: 229 *models: UniNode or UniEdge subclasses to register. 230 """ 231 self._schema_gen.register(*models) 232 233 def sync_schema(self) -> None: 234 """ 235 Synchronize database schema with registered models. 236 237 Creates labels, edge types, properties, and indexes as needed. 238 This is additive-only; it won't remove existing schema elements. 239 """ 240 self._schema_gen.apply_to_database(self._db) 241 242 def query(self, model: type[NodeT]) -> QueryBuilder[NodeT]: 243 """ 244 Create a query builder for the given model. 245 246 Args: 247 model: The UniNode subclass to query. 248 249 Returns: 250 A QueryBuilder for constructing queries. 251 """ 252 return QueryBuilder(self, model) 253 254 def add(self, entity: UniNode) -> None: 255 """ 256 Add a new entity to be persisted. 257 258 The entity will be inserted on the next commit(). 259 """ 260 if entity.is_persisted: 261 raise SessionError(f"Entity {entity!r} is already persisted") 262 entity._session = self 263 self._pending_new.append(entity) 264 265 def add_all(self, entities: Sequence[UniNode]) -> None: 266 """Add multiple entities to be persisted.""" 267 for entity in entities: 268 self.add(entity) 269 270 def delete(self, entity: UniNode) -> None: 271 """Mark an entity for deletion.""" 272 if not entity.is_persisted: 273 raise NotPersisted(entity) 274 self._pending_delete.append(entity) 275 276 def get( 277 self, 278 model: type[NodeT], 279 vid: int | None = None, 280 uid: str | None = None, 281 **kwargs: Any, 282 ) -> NodeT | None: 283 """ 284 Get an entity by ID or unique properties. 285 286 Args: 287 model: The model type to retrieve. 288 vid: Vertex ID to look up. 289 uid: Unique ID to look up. 290 **kwargs: Property equality filters. 291 292 Returns: 293 The model instance or None if not found. 
294 """ 295 # Check identity map first 296 if vid is not None: 297 cached = self._identity_map.get((model.__label__, vid)) 298 if cached is not None: 299 return cached # type: ignore[return-value] 300 301 # Build query 302 label = model.__label__ 303 params: dict[str, Any] = {} 304 305 if vid is not None: 306 cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}" 307 params["vid"] = vid 308 elif uid is not None: 309 cypher = f"MATCH (n:{label}) WHERE n._uid = $uid RETURN {_NODE_RETURN}" 310 params["uid"] = uid 311 elif kwargs: 312 # Validate property names 313 for k in kwargs: 314 _validate_property(k, model) 315 conditions = [f"n.{k} = ${k}" for k in kwargs] 316 cypher = f"MATCH (n:{label}) WHERE {' AND '.join(conditions)} RETURN {_NODE_RETURN} LIMIT 1" 317 params.update(kwargs) 318 else: 319 raise ValueError("Must provide vid, uid, or property filters") 320 321 results = self._db_session.query(cypher, params) 322 if not results: 323 return None 324 325 node_data = _row_to_node_dict(results[0].to_dict()) 326 if node_data is None: 327 return None 328 return self._result_to_model(node_data, model) 329 330 def refresh(self, entity: UniNode) -> None: 331 """Refresh an entity's properties from the database.""" 332 if not entity.is_persisted: 333 raise NotPersisted(entity) 334 335 label = entity.__class__.__label__ 336 cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}" 337 results = self._db_session.query(cypher, {"vid": entity._vid}) 338 339 if not results: 340 raise SessionError(f"Entity with vid={entity._vid} no longer exists") 341 342 # Update properties 343 props = _row_to_node_dict(results[0].to_dict()) 344 if props is None: 345 raise SessionError(f"Entity with vid={entity._vid} no longer exists") 346 try: 347 hints = get_type_hints(type(entity)) 348 except Exception: 349 hints = {} 350 351 for field_name in entity.get_property_fields(): 352 if field_name in props: 353 value = props[field_name] 354 if field_name in hints: 355 
value = db_to_python_value(value, hints[field_name]) 356 setattr(entity, field_name, value) 357 358 entity._mark_clean() 359 360 def commit(self) -> None: 361 """ 362 Commit all pending changes to the database. 363 364 This persists new entities, updates dirty entities, 365 and deletes marked entities. 366 """ 367 # Insert new entities 368 for entity in self._pending_new: 369 self._create_node(entity) 370 371 # Update dirty entities in identity map 372 for (label, vid), entity in list(self._identity_map.items()): 373 if entity.is_dirty and entity.is_persisted: 374 self._update_node(entity) 375 376 # Delete marked entities 377 for entity in self._pending_delete: 378 self._delete_node(entity) 379 380 # Flush to storage 381 self._db.flush() 382 383 # Clear pending lists 384 self._pending_new.clear() 385 self._pending_delete.clear() 386 387 def rollback(self) -> None: 388 """Discard all pending changes.""" 389 # Clear pending new — detach entities 390 for entity in self._pending_new: 391 entity._session = None 392 self._pending_new.clear() 393 394 # Clear pending deletes 395 self._pending_delete.clear() 396 397 # Invalidate dirty identity map entries 398 for entity in list(self._identity_map.values()): 399 if entity.is_dirty: 400 self.refresh(entity) 401 402 @contextmanager 403 def transaction(self) -> Iterator[UniTransaction]: 404 """Create a transaction context.""" 405 tx = UniTransaction(self) 406 with tx: 407 yield tx 408 409 def begin(self) -> UniTransaction: 410 """Begin a new transaction.""" 411 tx = UniTransaction(self) 412 tx._tx = self._db_session.tx() 413 return tx 414 415 def cypher( 416 self, 417 query: str, 418 params: dict[str, Any] | None = None, 419 result_type: type[NodeT] | None = None, 420 ) -> list[NodeT] | list[dict[str, Any]]: 421 """ 422 Execute a raw Cypher query. 423 424 Args: 425 query: Cypher query string. 426 params: Query parameters. 427 result_type: Optional model type for result mapping. 
428 429 Returns: 430 List of results (model instances if result_type provided). 431 """ 432 results = self._db_session.query(query, params) 433 434 if result_type is None: 435 return [r.to_dict() for r in results] 436 437 # Map results to model instances 438 mapped = [] 439 for raw_row in results: 440 row = raw_row.to_dict() 441 # Try to find node data in the row 442 for key, value in row.items(): 443 if isinstance(value, dict): 444 # Check for _id/_label keys (uni-db node dict) 445 if "_id" in value and "_label" in value: 446 instance = self._result_to_model(value, result_type) 447 if instance: 448 mapped.append(instance) 449 break 450 # Also check if _label matches registered model 451 elif "_label" in value: 452 label = value["_label"] 453 if label in self._schema_gen._node_models: 454 model = self._schema_gen._node_models[label] 455 instance = self._result_to_model(value, model) 456 if instance: 457 mapped.append(instance) 458 break 459 else: 460 # Try the first column 461 first_value = next(iter(row.values()), None) 462 if isinstance(first_value, dict): 463 instance = self._result_to_model(first_value, result_type) 464 if instance: 465 mapped.append(instance) 466 467 return mapped 468 469 @staticmethod 470 def _validate_edge_endpoints( 471 source: UniNode, target: UniNode 472 ) -> tuple[int, int, str, str]: 473 """Validate that both endpoints are persisted and return (src_vid, dst_vid, src_label, dst_label).""" 474 if not source.is_persisted: 475 raise NotPersisted(source) 476 if not target.is_persisted: 477 raise NotPersisted(target) 478 return ( 479 source._vid, 480 target._vid, 481 source.__class__.__label__, 482 target.__class__.__label__, 483 ) 484 485 @staticmethod 486 def _normalize_edge_properties( 487 properties: dict[str, Any] | UniEdge | None, 488 ) -> dict[str, Any]: 489 """Normalize edge properties from dict, UniEdge, or None.""" 490 if isinstance(properties, UniEdge): 491 return properties.to_properties() 492 if properties: 493 return properties 
494 return {} 495 496 def create_edge( 497 self, 498 source: UniNode, 499 edge_type: str, 500 target: UniNode, 501 properties: dict[str, Any] | UniEdge | None = None, 502 ) -> None: 503 """Create an edge between two nodes.""" 504 src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints( 505 source, target 506 ) 507 props = self._normalize_edge_properties(properties) 508 509 # Build CREATE edge query with labels (required by Cypher implementation) 510 props_str = ", ".join(f"{k}: ${k}" for k in props) 511 if props_str: 512 cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type} {{{props_str}}}]->(b)" 513 else: 514 cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type}]->(b)" 515 516 params = {"src": src_vid, "dst": dst_vid, **props} 517 with self._db_session.tx() as tx: 518 tx.execute(cypher, params) 519 tx.commit() 520 521 def delete_edge( 522 self, 523 source: UniNode, 524 edge_type: str, 525 target: UniNode, 526 ) -> int: 527 """Delete edges between two nodes. Returns the number of deleted edges.""" 528 src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints( 529 source, target 530 ) 531 cypher = ( 532 f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) " 533 f"WHERE a._vid = $src AND b._vid = $dst " 534 f"DELETE r RETURN count(r) as count" 535 ) 536 with self._db_session.tx() as tx: 537 results = tx.query(cypher, {"src": src_vid, "dst": dst_vid}) 538 tx.commit() 539 return cast(int, results[0]["count"]) if results else 0 540 541 def update_edge( 542 self, 543 source: UniNode, 544 edge_type: str, 545 target: UniNode, 546 properties: dict[str, Any], 547 ) -> int: 548 """Update properties on edges between two nodes. 
Returns the number of updated edges.""" 549 src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints( 550 source, target 551 ) 552 set_parts = [f"r.{k} = ${k}" for k in properties] 553 params: dict[str, Any] = {"src": src_vid, "dst": dst_vid, **properties} 554 cypher = ( 555 f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) " 556 f"WHERE a._vid = $src AND b._vid = $dst " 557 f"SET {', '.join(set_parts)} " 558 f"RETURN count(r) as count" 559 ) 560 with self._db_session.tx() as tx: 561 results = tx.query(cypher, params) 562 tx.commit() 563 return cast(int, results[0]["count"]) if results else 0 564 565 def get_edge( 566 self, 567 source: UniNode, 568 edge_type: str, 569 target: UniNode, 570 edge_model: type[EdgeT] | None = None, 571 ) -> list[dict[str, Any]] | list[EdgeT]: 572 """Get edges between two nodes. Returns dicts or edge model instances.""" 573 src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints( 574 source, target 575 ) 576 cypher = ( 577 f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) " 578 f"WHERE a._vid = $src AND b._vid = $dst " 579 f"RETURN properties(r) AS _props, id(r) AS _eid" 580 ) 581 results = self._db_session.query(cypher, {"src": src_vid, "dst": dst_vid}) 582 rows = [r.to_dict() for r in results] 583 584 if edge_model is None: 585 edge_dicts: list[dict[str, Any]] = [] 586 for row in rows: 587 props = row.get("_props", {}) 588 if isinstance(props, dict): 589 edge_dict = dict(props) 590 edge_dict["_eid"] = row.get("_eid") 591 edge_dicts.append(edge_dict) 592 return edge_dicts 593 594 edges = [] 595 for row in rows: 596 r_data = row.get("_props", {}) 597 if isinstance(r_data, dict): 598 edge = edge_model.from_properties( 599 r_data, 600 src_vid=src_vid, 601 dst_vid=dst_vid, 602 session=self, 603 ) 604 edges.append(edge) 605 return edges 606 607 def bulk_add(self, entities: Sequence[UniNode]) -> list[int]: 608 """ 609 Bulk-add entities using bulk_writer for performance. 
610 611 Groups entities by label and uses db.bulk_writer(). 612 Returns VIDs and attaches sessions. 613 614 Args: 615 entities: Sequence of UniNode instances to bulk-insert. 616 617 Returns: 618 List of assigned vertex IDs. 619 620 Raises: 621 BulkLoadError: If bulk insertion fails. 622 """ 623 if not entities: 624 return [] 625 626 # Group by label 627 by_label: dict[str, list[UniNode]] = {} 628 for entity in entities: 629 label = entity.__class__.__label__ 630 if label not in by_label: 631 by_label[label] = [] 632 by_label[label].append(entity) 633 634 all_vids: list[int] = [] 635 try: 636 for label, group in by_label.items(): 637 # Run before_create hooks 638 for entity in group: 639 run_hooks(entity, _BEFORE_CREATE) 640 641 # Convert to property dicts 642 prop_dicts = [e.to_properties() for e in group] 643 644 # Bulk insert via transaction 645 tx = self._db_session.tx() 646 with tx.bulk_writer().build() as bw: 647 vids = bw.insert_vertices(label, prop_dicts) 648 bw.commit() 649 tx.commit() 650 651 # Attach sessions and record VIDs 652 for entity, vid in zip(group, vids): 653 entity._attach_session(self, vid) 654 self._identity_map[(label, vid)] = entity 655 run_hooks(entity, _AFTER_CREATE) 656 entity._mark_clean() 657 658 all_vids.extend(vids) 659 except Exception as e: 660 raise BulkLoadError(f"Bulk insert failed: {e}") from e 661 662 return all_vids 663 664 def explain(self, cypher: str) -> uni_db.ExplainOutput: 665 """Get the query execution plan without running it.""" 666 return self._db_session.explain(cypher) 667 668 def profile(self, cypher: str) -> tuple[uni_db.QueryResult, uni_db.ProfileOutput]: 669 """Run the query with profiling and return results + stats.""" 670 return self._db_session.profile(cypher) 671 672 def save_schema(self, path: str) -> None: 673 """Save the database schema to a file.""" 674 self._db.save_schema(path) 675 676 def load_schema(self, path: str) -> None: 677 """Load a database schema from a file.""" 678 
self._db.load_schema(path) 679 680 # ------------------------------------------------------------------------- 681 # Internal Methods 682 # ------------------------------------------------------------------------- 683 684 def _create_node(self, entity: UniNode) -> None: 685 """Create a node in the database.""" 686 # Run before_create hooks 687 run_hooks(entity, _BEFORE_CREATE) 688 689 label = entity.__class__.__label__ 690 props = entity.to_properties() 691 692 # Build CREATE query 693 props_str = ", ".join(f"{k}: ${k}" for k in props) 694 cypher = f"CREATE (n:{label} {{{props_str}}}) RETURN id(n) as vid" 695 696 with self._db_session.tx() as tx: 697 results = tx.query(cypher, props) 698 tx.commit() 699 if results: 700 vid = results[0]["vid"] 701 entity._attach_session(self, vid) 702 703 # Add to identity map 704 self._identity_map[(label, vid)] = entity 705 706 # Run after_create hooks 707 run_hooks(entity, _AFTER_CREATE) 708 entity._mark_clean() 709 710 def _create_node_in_tx(self, entity: UniNode, tx: uni_db.Transaction) -> None: 711 """Create a node within a transaction.""" 712 run_hooks(entity, _BEFORE_CREATE) 713 714 label = entity.__class__.__label__ 715 props = entity.to_properties() 716 717 props_str = ", ".join(f"{k}: ${k}" for k in props) 718 cypher = f"CREATE (n:{label} {{{props_str}}}) RETURN id(n) as vid" 719 720 results = tx.query(cypher, props) 721 if results: 722 vid = results[0]["vid"] 723 entity._attach_session(self, vid) 724 self._identity_map[(label, vid)] = entity 725 726 run_hooks(entity, _AFTER_CREATE) 727 728 def _create_edge_in_tx( 729 self, 730 source: UniNode, 731 edge_type: str, 732 target: UniNode, 733 properties: UniEdge | None, 734 tx: uni_db.Transaction, 735 ) -> None: 736 """Create an edge within a transaction.""" 737 props = properties.to_properties() if properties else {} 738 src_label = source.__class__.__label__ 739 dst_label = target.__class__.__label__ 740 741 props_str = ", ".join(f"{k}: ${k}" for k in props) 742 if 
props_str: 743 cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[:{edge_type} {{{props_str}}}]->(b)" 744 else: 745 cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[:{edge_type}]->(b)" 746 747 params = {"src": source._vid, "dst": target._vid, **props} 748 tx.query(cypher, params) 749 750 def _update_node(self, entity: UniNode) -> None: 751 """Update a node in the database.""" 752 run_hooks(entity, _BEFORE_UPDATE) 753 754 label = entity.__class__.__label__ 755 756 # Convert dirty prop values via python_to_db_value 757 try: 758 hints = get_type_hints(type(entity)) 759 except Exception: 760 hints = {} 761 762 dirty_props = {} 763 for name in entity._dirty: 764 value = getattr(entity, name) 765 if name in hints: 766 value = python_to_db_value(value, hints[name]) 767 dirty_props[name] = value 768 769 if not dirty_props: 770 return 771 772 set_clause = ", ".join(f"n.{k} = ${k}" for k in dirty_props) 773 cypher = f"MATCH (n:{label}) WHERE id(n) = $vid SET {set_clause}" 774 params = {"vid": entity._vid, **dirty_props} 775 776 with self._db_session.tx() as tx: 777 tx.execute(cypher, params) 778 tx.commit() 779 780 run_hooks(entity, _AFTER_UPDATE) 781 entity._mark_clean() 782 783 def _delete_node(self, entity: UniNode) -> None: 784 """Delete a node from the database.""" 785 run_hooks(entity, _BEFORE_DELETE) 786 787 label = entity.__class__.__label__ 788 vid = entity._vid 789 790 # DETACH DELETE to also remove connected edges 791 cypher = f"MATCH (n:{label}) WHERE id(n) = $vid DETACH DELETE n" 792 with self._db_session.tx() as tx: 793 tx.execute(cypher, {"vid": vid}) 794 tx.commit() 795 796 # Remove from identity map 797 if vid is not None and (label, vid) in self._identity_map: 798 del self._identity_map[(label, vid)] 799 800 # Clear entity IDs 801 entity._vid = None 802 entity._uid = None 803 entity._session = None 804 805 run_hooks(entity, _AFTER_DELETE) 806 807 def 
_result_to_model( 808 self, 809 data: dict[str, Any], 810 model: type[NodeT], 811 ) -> NodeT | None: 812 """Convert a query result row to a model instance. 813 814 Does not mutate the input dict. 815 """ 816 if not data: 817 return None 818 819 # Work on a copy 820 data = dict(data) 821 822 # Run before_load hooks 823 data = run_class_hooks(model, _BEFORE_LOAD, data) or data 824 825 # Extract _id → vid (uni-db returns _id as string or int) 826 vid = data.pop("_id", None) 827 if vid is None: 828 vid = data.pop("_vid", None) 829 if vid is None: 830 vid = data.pop("vid", None) 831 if vid is not None and not isinstance(vid, int): 832 vid = int(vid) 833 834 # Remove _label (informational) 835 data.pop("_label", None) 836 837 try: 838 instance = cast( 839 NodeT, 840 model.from_properties( 841 data, 842 vid=vid, 843 session=self, 844 ), 845 ) 846 except Exception: 847 # If validation fails, return None 848 return None 849 850 # Add to identity map if we have a vid 851 if vid is not None: 852 existing = self._identity_map.get((model.__label__, vid)) 853 if existing is not None: 854 return cast(NodeT, existing) 855 self._identity_map[(model.__label__, vid)] = instance 856 857 # Run after_load hooks 858 run_hooks(instance, _AFTER_LOAD) 859 860 return instance 861 862 def _load_relationship( 863 self, 864 entity: UniNode, 865 descriptor: RelationshipDescriptor[Any], 866 ) -> list[UniNode] | UniNode | None: 867 """Load a relationship for an entity.""" 868 if not entity.is_persisted: 869 raise NotPersisted(entity) 870 871 config = descriptor.config 872 label = entity.__class__.__label__ 873 pattern = _edge_pattern(config.edge_type, config.direction) 874 875 cypher = ( 876 f"MATCH (a:{label}){pattern}(b) WHERE id(a) = $vid " 877 f"RETURN properties(b) AS _props, id(b) AS _vid, labels(b) AS _labels" 878 ) 879 results = self._db_session.query(cypher, {"vid": entity._vid}) 880 881 nodes = [] 882 for raw_row in results: 883 row = raw_row.to_dict() 884 node_data = 
_row_to_node_dict(row) 885 if node_data is None: 886 continue 887 # Try to find the model for this node 888 node_label = node_data.get("_label") 889 if node_label and node_label in self._schema_gen._node_models: 890 model = self._schema_gen._node_models[node_label] 891 instance = self._result_to_model(node_data, model) 892 if instance: 893 nodes.append(instance) 894 895 if not descriptor.is_list: 896 return nodes[0] if nodes else None 897 return nodes 898 899 def _eager_load_relationships( 900 self, 901 entities: list[NodeT], 902 relationships: list[str], 903 ) -> None: 904 """Eager load relationships for a list of entities.""" 905 if not entities: 906 return 907 908 model = type(entities[0]) 909 rel_configs = model.get_relationship_fields() 910 911 for rel_name in relationships: 912 if rel_name not in rel_configs: 913 continue 914 915 config = rel_configs[rel_name] 916 label = model.__label__ 917 vids = [e._vid for e in entities if e._vid is not None] 918 919 if not vids: 920 continue 921 922 pattern = _edge_pattern(config.edge_type, config.direction) 923 cypher = ( 924 f"MATCH (a:{label}){pattern}(b) WHERE id(a) IN $vids " 925 f"RETURN id(a) as src_vid, properties(b) AS _props, id(b) AS _vid, labels(b) AS _labels" 926 ) 927 results = self._db_session.query(cypher, {"vids": vids}) 928 929 # Group results by source vid 930 by_source: dict[int, list[Any]] = {} 931 for raw_row in results: 932 row = raw_row.to_dict() 933 src_vid = row["src_vid"] 934 node_data = _row_to_node_dict(row) 935 if node_data is None: 936 continue 937 if src_vid not in by_source: 938 by_source[src_vid] = [] 939 by_source[src_vid].append(node_data) 940 941 # Set cached values on entities 942 for entity in entities: 943 if entity._vid in by_source: 944 related = by_source[entity._vid] 945 cache_attr = f"_rel_cache_{rel_name}" 946 setattr(entity, cache_attr, related)
Session for interacting with the graph database using Pydantic models.
The session manages model registration, schema synchronization, and provides CRUD operations and query building.
Example:
from uni_db import Uni
from uni_pydantic import UniSession

db = Uni("./my_graph")
session = UniSession(db)
session.register(Person, Company)
session.sync_schema()

alice = Person(name="Alice", age=30)
session.add(alice)
session.commit()
def __init__(self, db: uni_db.Uni) -> None:
    """
    Bind a new OGM session to an open database.

    Args:
        db: The ``uni_db.Uni`` handle. A low-level session is opened on it
            immediately via ``db.session()``.
    """
    self._db = db
    self._db_session = db.session()
    # Holds registered model classes; used by sync_schema() and result mapping.
    self._schema_gen = SchemaGenerator()
    # (label, vid) -> loaded entity. Values are held weakly, so tracking an
    # entity here does not keep it alive once user code drops it.
    tracked: WeakValueDictionary[tuple[str, int], UniNode] = WeakValueDictionary()
    self._identity_map = tracked
    # Work queues consumed by commit() and discarded by rollback()/close().
    self._pending_new: list[UniNode] = []
    self._pending_delete: list[UniNode] = []
def close(self) -> None:
    """Drop every queued insert and delete held by this session.

    Only the two pending queues are cleared; nothing is called on the
    underlying database session.
    """
    for queue in (self._pending_new, self._pending_delete):
        queue.clear()
Clear all pending inserts and deletes held by the session. The underlying database session itself is not closed.
@property
def db(self) -> uni_db.Uni:
    """The wrapped ``uni_db.Uni`` handle, for operations the OGM does not cover."""
    handle = self._db
    return handle
Access the underlying uni_db.Uni for low-level operations.
def locy(
    self, program: str, params: dict[str, Any] | None = None
) -> uni_db.LocyResult:
    """
    Run a Locy program against this session.

    Thin pass-through to ``uni_db.Session.locy()`` on the low-level session.

    Args:
        program: Locy program source.
        params: Optional parameter bindings, forwarded unchanged.

    Returns:
        The ``uni_db.LocyResult`` with derived facts, stats, and warnings.
    """
    low_level = self._db_session
    return low_level.locy(program, params)
Evaluate a Locy program and return derived facts, stats, and warnings.
Delegates to the underlying uni_db.Session.locy().
def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
    """
    Make model classes known to this session.

    Registered models take part in schema generation and result mapping.

    Args:
        *models: UniNode or UniEdge subclasses to register.
    """
    generator = self._schema_gen
    generator.register(*models)
Register model classes with the session.
Registered models can be used for schema generation and queries.
Args: *models: UniNode or UniEdge subclasses to register.
def sync_schema(self) -> None:
    """
    Bring the database schema in line with the registered models.

    Labels, edge types, properties and indexes are created as needed. The
    operation is additive only: existing schema elements are never removed.
    """
    generator = self._schema_gen
    generator.apply_to_database(self._db)
Synchronize database schema with registered models.
Creates labels, edge types, properties, and indexes as needed. This is additive-only; it won't remove existing schema elements.
def query(self, model: type[NodeT]) -> QueryBuilder[NodeT]:
    """
    Start a fluent query over nodes of ``model``.

    Args:
        model: The UniNode subclass to query.

    Returns:
        A new QueryBuilder bound to this session.
    """
    builder: QueryBuilder[NodeT] = QueryBuilder(self, model)
    return builder
Create a query builder for the given model.
Args: model: The UniNode subclass to query.
Returns: A QueryBuilder for constructing queries.
def add(self, entity: UniNode) -> None:
    """
    Queue a new, not-yet-persisted entity for insertion.

    The entity is bound to this session right away and written on the next
    ``commit()``.

    Raises:
        SessionError: If ``entity`` is already persisted.
    """
    if not entity.is_persisted:
        entity._session = self
        self._pending_new.append(entity)
        return
    raise SessionError(f"Entity {entity!r} is already persisted")
Add a new entity to be persisted.
The entity will be inserted on the next commit().
def add_all(self, entities: Sequence[UniNode]) -> None:
    """
    Queue several new entities for insertion, in order.

    Equivalent to calling ``add()`` on each one. Processing stops at the
    first entity ``add()`` rejects; earlier entities stay queued.
    """
    queue_one = self.add
    for candidate in entities:
        queue_one(candidate)
Add multiple entities to be persisted.
def delete(self, entity: UniNode) -> None:
    """
    Mark a persisted entity for deletion on the next ``commit()``.

    Raises:
        NotPersisted: If ``entity`` has never been persisted.
    """
    if entity.is_persisted:
        self._pending_delete.append(entity)
        return
    raise NotPersisted(entity)
Mark an entity for deletion.
def get(
    self,
    model: type[NodeT],
    vid: int | None = None,
    uid: str | None = None,
    **kwargs: Any,
) -> NodeT | None:
    """
    Fetch one entity by vertex ID, unique ID, or property equality.

    Lookup precedence is ``vid``, then ``uid``, then ``kwargs``; only the
    first one supplied is used. A ``vid`` already present in the identity
    map is returned without querying the database.

    Args:
        model: The model type to retrieve.
        vid: Vertex ID to look up.
        uid: Unique ID to look up.
        **kwargs: Property equality filters (names are validated against
            ``model``); the first matching node is returned.

    Returns:
        The model instance, or None if nothing matches.

    Raises:
        ValueError: If no lookup key is given.
    """
    label = model.__label__

    if vid is not None:
        hit = self._identity_map.get((label, vid))
        if hit is not None:
            return hit  # type: ignore[return-value]

    params: dict[str, Any] = {}
    if vid is not None:
        params["vid"] = vid
        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
    elif uid is not None:
        params["uid"] = uid
        cypher = f"MATCH (n:{label}) WHERE n._uid = $uid RETURN {_NODE_RETURN}"
    elif kwargs:
        for prop_name in kwargs:
            _validate_property(prop_name, model)
        where = " AND ".join(f"n.{prop_name} = ${prop_name}" for prop_name in kwargs)
        cypher = f"MATCH (n:{label}) WHERE {where} RETURN {_NODE_RETURN} LIMIT 1"
        params.update(kwargs)
    else:
        raise ValueError("Must provide vid, uid, or property filters")

    rows = self._db_session.query(cypher, params)
    if not rows:
        return None
    node_data = _row_to_node_dict(rows[0].to_dict())
    return None if node_data is None else self._result_to_model(node_data, model)
Get an entity by ID or unique properties.
Args: model: The model type to retrieve. vid: Vertex ID to look up. uid: Unique ID to look up. **kwargs: Property equality filters.
Returns: The model instance or None if not found.
def refresh(self, entity: UniNode) -> None:
    """
    Reload an entity's property fields from the database.

    Every property field present on the stored node is assigned back onto
    ``entity``. A field is first converted with ``db_to_python_value`` when
    the model has a resolvable type hint for it. The entity is then marked
    clean.

    Raises:
        NotPersisted: If ``entity`` has never been persisted.
        SessionError: If the node no longer exists in the database.
    """
    if not entity.is_persisted:
        raise NotPersisted(entity)

    vid = entity._vid
    query_text = (
        f"MATCH (n:{entity.__class__.__label__}) WHERE id(n) = $vid "
        f"RETURN {_NODE_RETURN}"
    )
    rows = self._db_session.query(query_text, {"vid": vid})
    stored = _row_to_node_dict(rows[0].to_dict()) if rows else None
    if stored is None:
        raise SessionError(f"Entity with vid={vid} no longer exists")

    try:
        hints = get_type_hints(type(entity))
    except Exception:
        # Unresolvable annotations: fall back to the raw database values.
        hints = {}

    for field_name in entity.get_property_fields():
        if field_name not in stored:
            continue
        value = stored[field_name]
        if field_name in hints:
            value = db_to_python_value(value, hints[field_name])
        setattr(entity, field_name, value)

    entity._mark_clean()
Refresh an entity's properties from the database.
def commit(self) -> None:
    """
    Commit all pending changes to the database.

    Steps, in order:

    1. Insert queued new entities.
    2. Write back dirty, persisted entities still in the identity map.
    3. Delete entities marked for deletion.
    4. Flush storage.
    5. Clear both queues.

    An entity that is already persisted when its insert comes up is skipped.
    This covers the same object being passed to ``add()`` twice, and a retry
    after a commit that failed part-way (the queues are only cleared on
    success). Without the skip, a duplicate node would be created. Likewise,
    an entity that is no longer persisted is not deleted again, so its delete
    hooks do not run twice.
    """
    for entity in self._pending_new:
        if not entity.is_persisted:
            self._create_node(entity)

    # Snapshot first: the identity map is weak and may shrink while we work.
    for entity in list(self._identity_map.values()):
        if entity.is_dirty and entity.is_persisted:
            self._update_node(entity)

    for entity in self._pending_delete:
        if entity.is_persisted:
            self._delete_node(entity)

    self._db.flush()

    self._pending_new.clear()
    self._pending_delete.clear()
Commit all pending changes to the database.
This persists new entities, updates dirty entities, and deletes marked entities.
def rollback(self) -> None:
    """
    Discard all pending changes.

    Queued inserts are detached from the session and dropped, and queued
    deletes are forgotten. Every tracked entity with unsaved edits is then
    reloaded from the database via ``refresh()``.
    """
    detached = self._pending_new
    for orphan in detached:
        orphan._session = None
    detached.clear()
    self._pending_delete.clear()

    snapshot = list(self._identity_map.values())
    for tracked in snapshot:
        if tracked.is_dirty:
            self.refresh(tracked)
Discard all pending changes: forget queued inserts and deletes, and reload dirty tracked entities from the database.
@contextmanager
def transaction(self) -> Iterator[UniTransaction]:
    """
    Open a :class:`UniTransaction` scoped to a ``with`` block.

    The transaction's own ``__exit__`` decides the outcome. It commits when
    the block finishes normally and rolls back when the block raises.

    Yields:
        The active transaction.
    """
    with UniTransaction(self) as active:
        yield active
Create a transaction context.
def begin(self) -> UniTransaction:
    """
    Start a transaction to be driven manually, without a ``with`` block.

    The underlying ``uni_db`` transaction is opened immediately. Finish it by
    calling ``commit()`` or ``rollback()`` on the returned object.
    """
    manual = UniTransaction(self)
    manual._tx = self._db_session.tx()
    return manual
Begin a new transaction.
def cypher(
    self,
    query: str,
    params: dict[str, Any] | None = None,
    result_type: type[NodeT] | None = None,
) -> list[NodeT] | list[dict[str, Any]]:
    """
    Execute a raw Cypher query.

    Without ``result_type``, each row is returned as a plain dict. With
    ``result_type``, each row contributes at most one model instance, chosen
    in this order:

    1. A column holding a node dict whose ``_label`` equals
       ``result_type.__label__``, mapped to ``result_type``.
    2. A column holding a node dict whose ``_label`` belongs to another
       registered model, mapped to that model.
    3. If the first column is a dict without ``_label``, it is mapped to
       ``result_type``.

    Rows where nothing maps, or where the model rejects the data, are
    dropped. A node carrying an unrelated label is never coerced into
    ``result_type``. Doing so would build a wrongly typed instance and store
    it in the identity map under the wrong label.

    Args:
        query: Cypher query string.
        params: Query parameters.
        result_type: Optional model type for result mapping.

    Returns:
        List of results (model instances if result_type provided).
    """
    results = self._db_session.query(query, params)

    if result_type is None:
        return [r.to_dict() for r in results]

    target_label = result_type.__label__
    node_models = self._schema_gen._node_models

    def map_row(row: dict[str, Any]) -> Any:
        """Return the model instance chosen for one result row, or None."""
        labelled = [
            value
            for value in row.values()
            if isinstance(value, dict) and "_label" in value
        ]
        # 1) Nodes carrying the requested label.
        for value in labelled:
            if value["_label"] == target_label:
                instance = self._result_to_model(value, result_type)
                if instance:
                    return instance
        # 2) Nodes of other registered models.
        for value in labelled:
            label = value["_label"]
            if label != target_label and label in node_models:
                instance = self._result_to_model(value, node_models[label])
                if instance:
                    return instance
        # 3) Unlabelled property map in the first column.
        first_value = next(iter(row.values()), None)
        if isinstance(first_value, dict) and "_label" not in first_value:
            return self._result_to_model(first_value, result_type)
        return None

    mapped = []
    for raw_row in results:
        instance = map_row(raw_row.to_dict())
        if instance:
            mapped.append(instance)
    return mapped
Execute a raw Cypher query.
Args: query: Cypher query string. params: Query parameters. result_type: Optional model type for result mapping.
Returns: List of results (model instances if result_type provided).
def create_edge(
    self,
    source: UniNode,
    edge_type: str,
    target: UniNode,
    properties: dict[str, Any] | UniEdge | None = None,
) -> None:
    """
    Create an edge ``(source)-[:edge_type]->(target)`` and commit it.

    Args:
        source: Persisted start node.
        edge_type: Relationship type to create.
        target: Persisted end node.
        properties: Edge properties as a dict, a UniEdge instance, or None.

    Raises:
        NotPersisted: If either endpoint has not been persisted.
    """
    src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
        source, target
    )
    props = self._normalize_edge_properties(properties)

    # Bind values as $prop_0, $prop_1, ... instead of $<name>. Otherwise an
    # edge property called "src" or "dst" would overwrite the endpoint VIDs.
    params: dict[str, Any] = {"src": src_vid, "dst": dst_vid}
    assignments = []
    for index, (key, value) in enumerate(props.items()):
        param = f"prop_{index}"
        assignments.append(f"{key}: ${param}")
        params[param] = value

    rel = f"r:{edge_type}"
    if assignments:
        rel = f"{rel} {{{', '.join(assignments)}}}"
    # Labels on both endpoints are required by the Cypher implementation.
    cypher = (
        f"MATCH (a:{src_label}), (b:{dst_label}) "
        f"WHERE a._vid = $src AND b._vid = $dst "
        f"CREATE (a)-[{rel}]->(b)"
    )

    with self._db_session.tx() as tx:
        tx.execute(cypher, params)
        tx.commit()
Create an edge between two nodes.
def delete_edge(
    self,
    source: UniNode,
    edge_type: str,
    target: UniNode,
) -> int:
    """
    Delete every ``(source)-[:edge_type]->(target)`` edge.

    Raises:
        NotPersisted: If either endpoint has not been persisted.

    Returns:
        The number of deleted edges (0 if the query returned no rows).
    """
    src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
        source, target
    )
    statement = (
        f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
        + "WHERE a._vid = $src AND b._vid = $dst "
        + "DELETE r RETURN count(r) as count"
    )
    with self._db_session.tx() as tx:
        rows = tx.query(statement, {"src": src_vid, "dst": dst_vid})
        tx.commit()
    if not rows:
        return 0
    return cast(int, rows[0]["count"])
Delete edges between two nodes. Returns the number of deleted edges.
def update_edge(
    self,
    source: UniNode,
    edge_type: str,
    target: UniNode,
    properties: dict[str, Any],
) -> int:
    """
    Update properties on every ``(source)-[:edge_type]->(target)`` edge.

    Args:
        source: Persisted start node.
        edge_type: Relationship type to match.
        target: Persisted end node.
        properties: Property names mapped to their new values. An empty
            mapping updates nothing and returns 0 without querying, because
            an empty ``SET`` clause is invalid Cypher.

    Returns:
        The number of updated edges.

    Raises:
        NotPersisted: If either endpoint has not been persisted.
    """
    src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
        source, target
    )
    if not properties:
        return 0

    # Bind values as $prop_0, $prop_1, ... instead of $<name>. Otherwise a
    # property called "src" or "dst" would overwrite the endpoint VIDs.
    params: dict[str, Any] = {"src": src_vid, "dst": dst_vid}
    set_parts = []
    for index, (key, value) in enumerate(properties.items()):
        param = f"prop_{index}"
        set_parts.append(f"r.{key} = ${param}")
        params[param] = value

    cypher = (
        f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
        f"WHERE a._vid = $src AND b._vid = $dst "
        f"SET {', '.join(set_parts)} "
        f"RETURN count(r) as count"
    )
    with self._db_session.tx() as tx:
        results = tx.query(cypher, params)
        tx.commit()
    return cast(int, results[0]["count"]) if results else 0
Update properties on edges between two nodes. Returns the number of updated edges.
def get_edge(
    self,
    source: UniNode,
    edge_type: str,
    target: UniNode,
    edge_model: type[EdgeT] | None = None,
) -> list[dict[str, Any]] | list[EdgeT]:
    """
    Fetch every ``(source)-[:edge_type]->(target)`` edge.

    Rows whose properties are not a dict are skipped.

    Args:
        source: Persisted start node.
        edge_type: Relationship type to match.
        target: Persisted end node.
        edge_model: Optional UniEdge subclass to build instances with.

    Returns:
        Without ``edge_model``: property dicts, each with an added ``_eid``.
        With ``edge_model``: instances from ``edge_model.from_properties``.

    Raises:
        NotPersisted: If either endpoint has not been persisted.
    """
    src_vid, dst_vid, src_label, dst_label = self._validate_edge_endpoints(
        source, target
    )
    statement = (
        f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
        f"WHERE a._vid = $src AND b._vid = $dst "
        f"RETURN properties(r) AS _props, id(r) AS _eid"
    )
    raw = self._db_session.query(statement, {"src": src_vid, "dst": dst_vid})
    rows = [r.to_dict() for r in raw]

    candidates = [(row.get("_props", {}), row) for row in rows]
    usable = [(props, row) for props, row in candidates if isinstance(props, dict)]

    if edge_model is None:
        return [{**props, "_eid": row.get("_eid")} for props, row in usable]

    return [
        edge_model.from_properties(
            props,
            src_vid=src_vid,
            dst_vid=dst_vid,
            session=self,
        )
        for props, _row in usable
    ]
Get edges between two nodes. Returns dicts or edge model instances.
def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
    """
    Bulk-add entities using a transaction's bulk writer for performance.

    Entities are grouped by label. For each group:

    1. ``before_create`` hooks run.
    2. The group is inserted via ``tx.bulk_writer()`` inside its own
       transaction.
    3. Every entity is attached to this session, recorded in the identity
       map, given its ``after_create`` hooks, and marked clean.

    Args:
        entities: Sequence of UniNode instances to bulk-insert.

    Returns:
        The assigned vertex IDs, grouped by label in first-seen label order
        (within a label, in input order).

    Raises:
        BulkLoadError: If bulk insertion fails. Groups committed before the
            failure stay committed.
    """
    if not entities:
        return []

    by_label: dict[str, list[UniNode]] = {}
    for entity in entities:
        by_label.setdefault(entity.__class__.__label__, []).append(entity)

    all_vids: list[int] = []
    try:
        for label, group in by_label.items():
            for entity in group:
                run_hooks(entity, _BEFORE_CREATE)

            prop_dicts = [e.to_properties() for e in group]

            # Use the transaction as a context manager, like every other
            # write path here, so it is cleaned up if the bulk write raises
            # instead of being left open.
            with self._db_session.tx() as tx:
                with tx.bulk_writer().build() as bw:
                    vids = bw.insert_vertices(label, prop_dicts)
                    bw.commit()
                tx.commit()

            for entity, vid in zip(group, vids):
                entity._attach_session(self, vid)
                self._identity_map[(label, vid)] = entity
                run_hooks(entity, _AFTER_CREATE)
                entity._mark_clean()

            all_vids.extend(vids)
    except Exception as e:
        raise BulkLoadError(f"Bulk insert failed: {e}") from e

    return all_vids
Bulk-add entities using bulk_writer for performance.
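Groups entities by label and inserts each group through the bulk writer of its own transaction (tx.bulk_writer()). Attaches each entity to this session and returns the assigned VIDs, grouped by label.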
Args: entities: Sequence of UniNode instances to bulk-insert.
Returns: List of assigned vertex IDs.
Raises: BulkLoadError: If bulk insertion fails.
def explain(self, cypher: str) -> uni_db.ExplainOutput:
    """Return the execution plan for ``cypher`` without executing it."""
    plan = self._db_session.explain(cypher)
    return plan
Get the query execution plan without running it.
def profile(self, cypher: str) -> tuple[uni_db.QueryResult, uni_db.ProfileOutput]:
    """Execute ``cypher`` with profiling; return ``(results, profile stats)``."""
    outcome = self._db_session.profile(cypher)
    return outcome
Run the query with profiling and return results + stats.
class UniTransaction:
    """
    Transaction context for atomic operations.

    Node and edge creations queued with :meth:`add` and :meth:`create_edge`
    run against the underlying ``uni_db`` transaction only when
    :meth:`commit` is called. This gives commit/rollback semantics for the
    whole group.

    Example:
        >>> with session.transaction() as tx:
        ...     alice = Person(name="Alice")
        ...     tx.add(alice)
        ...     # Auto-commits on success, rolls back on exception
    """

    def __init__(self, session: UniSession) -> None:
        self._session = session
        # Opened by __enter__ (or assigned by UniSession.begin()).
        self._tx: uni_db.Transaction | None = None
        self._pending_nodes: list[UniNode] = []
        # (source, edge_type, target, properties) queued by create_edge().
        self._pending_edges: list[tuple[UniNode, str, UniNode, UniEdge | None]] = []
        self._committed = False
        self._rolled_back = False

    def __enter__(self) -> UniTransaction:
        self._tx = self._session._db_session.tx()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Roll back if the block raised; otherwise commit unless already finished.

        Exceptions raised in the block are not suppressed.
        """
        if exc_type is not None:
            self.rollback()
            return
        if not self._committed and not self._rolled_back:
            self.commit()

    def add(self, entity: UniNode) -> None:
        """Queue a node to be created when this transaction commits."""
        self._pending_nodes.append(entity)

    def create_edge(
        self,
        source: UniNode,
        edge_type: str,
        target: UniNode,
        properties: UniEdge | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Queue an edge ``(source)-[:edge_type]->(target)`` for commit time.

        Both endpoints must already be persisted. Extra keyword arguments are
        accepted for compatibility but currently ignored.

        Raises:
            NotPersisted: If either endpoint has not been persisted.
        """
        if not source.is_persisted:
            raise NotPersisted(source)
        if not target.is_persisted:
            raise NotPersisted(target)
        self._pending_edges.append((source, edge_type, target, properties))

    def commit(self) -> None:
        """
        Run the queued operations and commit the underlying transaction.

        Raises:
            TransactionError: If the transaction was already committed or
                rolled back, was never started, or any queued operation or
                the commit itself fails. On failure the transaction is rolled
                back. Nodes created during the attempt are then detached:
                removed from the session's identity map, with ``_vid``,
                ``_uid`` and ``_session`` reset. This keeps them from looking
                persisted with VIDs that were never committed.
        """
        if self._committed:
            raise TransactionError("Transaction already committed")
        if self._rolled_back:
            raise TransactionError("Transaction already rolled back")

        if self._tx is None:
            raise TransactionError("Transaction not started")

        # Remember which nodes were unpersisted beforehand; rollback() clears
        # the queue, and only these may need detaching after a failure.
        fresh_nodes = [node for node in self._pending_nodes if node._vid is None]

        try:
            for node in self._pending_nodes:
                self._session._create_node_in_tx(node, self._tx)

            for source, edge_type, target, props in self._pending_edges:
                self._session._create_edge_in_tx(
                    source, edge_type, target, props, self._tx
                )

            self._tx.commit()
        except Exception as e:
            self.rollback()
            identity_map = self._session._identity_map
            for node in fresh_nodes:
                vid = node._vid
                if vid is None:
                    continue
                identity_map.pop((node.__class__.__label__, vid), None)
                node._vid = None
                node._uid = None
                node._session = None
            raise TransactionError(f"Commit failed: {e}") from e

        # Outside the try: the commit has succeeded, so a failure below must
        # not trigger a rollback of an already committed transaction.
        self._committed = True
        for node in self._pending_nodes:
            node._mark_clean()

    def rollback(self) -> None:
        """Roll back the underlying transaction and drop queued work (idempotent)."""
        if self._rolled_back:
            return
        if self._tx is not None:
            self._tx.rollback()
        self._rolled_back = True
        self._pending_nodes.clear()
        self._pending_edges.clear()
Transaction context for atomic operations.
Provides commit/rollback semantics for a group of operations.
Example:
with session.transaction() as tx: ... alice = Person(name="Alice") ... tx.add(alice) ... # Auto-commits on success, rolls back on exception
def __init__(self, session: UniSession) -> None:
    """Create an unopened transaction bound to *session* with nothing queued."""
    # Database-level transaction handle, created lazily on __enter__.
    self._tx: uni_db.Transaction | None = None
    self._session = session
    self._pending_edges: list[tuple[UniNode, str, UniNode, UniEdge | None]] = []
    self._pending_nodes: list[UniNode] = []
    self._rolled_back = False
    self._committed = False
def add(self, entity: UniNode) -> None:
    """Queue *entity* so it is created when the transaction commits."""
    queue = self._pending_nodes
    queue.append(entity)
Add a node to be created in this transaction.
def create_edge(
    self,
    source: UniNode,
    edge_type: str,
    target: UniNode,
    properties: UniEdge | None = None,
    **kwargs: Any,
) -> None:
    """Queue an edge between two persisted nodes for creation at commit.

    Raises:
        NotPersisted: If the source (checked first) or the target has not
            been persisted yet.
    """
    for endpoint in (source, target):
        if not endpoint.is_persisted:
            raise NotPersisted(endpoint)
    self._pending_edges.append((source, edge_type, target, properties))
Create an edge between two nodes in this transaction.
def commit(self) -> None:
    """Write all queued nodes and edges, then commit the transaction.

    Raises:
        TransactionError: If the transaction is already finished or was never
            started, or if writing/committing fails (in which case the
            transaction is rolled back first).
    """
    if self._committed:
        raise TransactionError("Transaction already committed")
    if self._rolled_back:
        raise TransactionError("Transaction already rolled back")

    if self._tx is None:
        raise TransactionError("Transaction not started")

    try:
        # Create pending nodes
        for node in self._pending_nodes:
            self._session._create_node_in_tx(node, self._tx)

        # Create pending edges
        for source, edge_type, target, props in self._pending_edges:
            self._session._create_edge_in_tx(
                source, edge_type, target, props, self._tx
            )

        self._tx.commit()
    except Exception as e:
        self.rollback()
        raise TransactionError(f"Commit failed: {e}") from e

    # Only reached once the database commit succeeded. Failures past this
    # point must not trigger a rollback of an already-committed transaction.
    self._committed = True
    for node in self._pending_nodes:
        node._mark_clean()
Commit the transaction.
def rollback(self) -> None:
    """Roll back the transaction and discard queued work (idempotent)."""
    if self._rolled_back:
        return
    try:
        if self._tx is not None:
            self._tx.rollback()
    finally:
        # Mark finished and drop queued work even if the database rollback
        # raised, so the object is never left half-open.
        self._rolled_back = True
        self._pending_nodes.clear()
        self._pending_edges.clear()
Rollback the transaction.
class AsyncUniSession:
    """
    Async session for interacting with the graph database.

    Mirrors UniSession with async methods. Uses AsyncUni.

    Example:
        >>> from uni_db import AsyncUni
        >>> from uni_pydantic import AsyncUniSession
        >>>
        >>> db = await AsyncUni.open("./my_graph")
        >>> async with AsyncUniSession(db) as session:
        ...     session.register(Person)
        ...     await session.sync_schema()
        ...     alice = Person(name="Alice", age=30)
        ...     session.add(alice)
        ...     await session.commit()
    """

    def __init__(self, db: uni_db.AsyncUni) -> None:
        self._db = db
        self._db_session = db.session()
        self._schema_gen = SchemaGenerator()
        # (label, vid) -> loaded instance. Weak values: an entry disappears once
        # nothing else references the instance.
        self._identity_map: WeakValueDictionary[tuple[str, int], UniNode] = (
            WeakValueDictionary()
        )
        self._pending_new: list[UniNode] = []
        self._pending_delete: list[UniNode] = []

    async def __aenter__(self) -> AsyncUniSession:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Clear pending state on exit; exceptions from the block propagate."""
        self.close()

    def close(self) -> None:
        """Close the session and clear pending state."""
        self._pending_new.clear()
        self._pending_delete.clear()

    @property
    def db(self) -> uni_db.AsyncUni:
        """Access the underlying uni_db.AsyncUni for low-level operations."""
        return self._db

    async def locy(self, program: str, params: dict[str, Any] | None = None) -> Any:
        """
        Evaluate a Locy program and return derived facts, stats, and warnings.

        Delegates to the underlying ``uni_db.AsyncSession.locy()``.
        """
        return await self._db_session.locy(program, params)

    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
        """Register model classes with the session (sync)."""
        self._schema_gen.register(*models)

    async def sync_schema(self) -> None:
        """Synchronize database schema with registered models."""
        await self._schema_gen.async_apply_to_database(self._db)

    def query(self, model: type[NodeT]) -> AsyncQueryBuilder[NodeT]:
        """Create an async query builder for the given model."""
        return AsyncQueryBuilder(self, model)

    def add(self, entity: UniNode) -> None:
        """Add a new entity to be persisted (sync — just collects).

        Raises:
            SessionError: If the entity is already persisted.
        """
        if entity.is_persisted:
            raise SessionError(f"Entity {entity!r} is already persisted")
        entity._session = self
        self._pending_new.append(entity)

    def add_all(self, entities: Sequence[UniNode]) -> None:
        """Add multiple entities (sync — just collects)."""
        for entity in entities:
            self.add(entity)

    def delete(self, entity: UniNode) -> None:
        """Mark an entity for deletion (sync — just collects).

        Raises:
            NotPersisted: If the entity has not been persisted.
        """
        if not entity.is_persisted:
            raise NotPersisted(entity)
        self._pending_delete.append(entity)

    async def get(
        self,
        model: type[NodeT],
        vid: int | None = None,
        uid: str | None = None,
        **kwargs: Any,
    ) -> NodeT | None:
        """Get an entity by vid, uid, or equality on property values.

        A vid lookup is served from the identity map when possible. Property
        names in ``kwargs`` are checked with ``_validate_property`` before being
        placed into the query text; their values are passed as parameters.

        Raises:
            ValueError: If no vid, uid, or property filter is given.
        """
        if vid is not None:
            cached = self._identity_map.get((model.__label__, vid))
            if cached is not None:
                return cached  # type: ignore[return-value]

        label = model.__label__
        params: dict[str, Any] = {}

        if vid is not None:
            cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
            params["vid"] = vid
        elif uid is not None:
            cypher = f"MATCH (n:{label}) WHERE n._uid = $uid RETURN {_NODE_RETURN}"
            params["uid"] = uid
        elif kwargs:
            for k in kwargs:
                _validate_property(k, model)
            conditions = [f"n.{k} = ${k}" for k in kwargs]
            cypher = f"MATCH (n:{label}) WHERE {' AND '.join(conditions)} RETURN {_NODE_RETURN} LIMIT 1"
            params.update(kwargs)
        else:
            raise ValueError("Must provide vid, uid, or property filters")

        results = await self._db_session.query(cypher, params)
        if not results:
            return None

        node_data = _row_to_node_dict(results[0].to_dict())
        if node_data is None:
            return None
        return self._result_to_model(node_data, model)

    async def refresh(self, entity: UniNode) -> None:
        """Reload an entity's property fields from the database and mark it clean.

        Raises:
            NotPersisted: If the entity has not been persisted.
            SessionError: If the node can no longer be found.
        """
        if not entity.is_persisted:
            raise NotPersisted(entity)

        label = entity.__class__.__label__
        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
        results = await self._db_session.query(cypher, {"vid": entity._vid})

        if not results:
            raise SessionError(f"Entity with vid={entity._vid} no longer exists")

        props = _row_to_node_dict(results[0].to_dict())
        if props is None:
            raise SessionError(f"Entity with vid={entity._vid} no longer exists")
        try:
            hints = get_type_hints(type(entity))
        except Exception:
            # Unresolvable annotations: fall back to raw database values.
            hints = {}

        for field_name in entity.get_property_fields():
            if field_name in props:
                value = props[field_name]
                if field_name in hints:
                    value = db_to_python_value(value, hints[field_name])
                setattr(entity, field_name, value)

        entity._mark_clean()

    async def commit(self) -> None:
        """Commit all pending changes.

        Creates queued new entities, writes dirty tracked entities, deletes
        queued entities, flushes the database, then clears the queues.

        Entities that are already persisted are not created again. This covers
        an entity added twice, or one created by an earlier commit that failed
        part-way. Entities that are no longer persisted are not deleted again.
        """
        for entity in self._pending_new:
            if not entity.is_persisted:
                await self._create_node(entity)

        for entity in list(self._identity_map.values()):
            if entity.is_dirty and entity.is_persisted:
                await self._update_node(entity)

        for entity in self._pending_delete:
            if entity.is_persisted:
                await self._delete_node(entity)

        await self._db.flush()
        self._pending_new.clear()
        self._pending_delete.clear()

    async def rollback(self) -> None:
        """Discard all pending changes.

        Detaches queued new entities, clears both queues, and refreshes every
        dirty tracked entity from the database.
        """
        for entity in self._pending_new:
            entity._session = None
        self._pending_new.clear()
        self._pending_delete.clear()
        for entity in list(self._identity_map.values()):
            if entity.is_dirty:
                await self.refresh(entity)

    async def transaction(self) -> AsyncUniTransaction:
        """Create an async transaction bound to this session.

        This method is a coroutine, so await it before entering the context:
        ``async with await session.transaction() as tx:``.
        """
        return AsyncUniTransaction(self)

    async def cypher(
        self,
        query: str,
        params: dict[str, Any] | None = None,
        result_type: type[NodeT] | None = None,
    ) -> list[NodeT] | list[dict[str, Any]]:
        """Execute a raw Cypher query.

        Without ``result_type`` each row is returned as a dict. With it, each
        row is mapped to at most one model instance:

        - a dict value carrying ``_id`` and ``_label`` maps to ``result_type``;
        - a dict carrying only a registered ``_label`` maps to that label's model;
        - otherwise the row's first value is mapped to ``result_type`` if it is
          a dict.
        """
        results = await self._db_session.query(query, params)

        if result_type is None:
            return [r.to_dict() for r in results]

        mapped = []
        for raw_row in results:
            row = raw_row.to_dict()
            for value in row.values():
                if isinstance(value, dict):
                    if "_id" in value and "_label" in value:
                        instance = self._result_to_model(value, result_type)
                        if instance:
                            mapped.append(instance)
                        break
                    elif "_label" in value:
                        label = value["_label"]
                        if label in self._schema_gen._node_models:
                            model = self._schema_gen._node_models[label]
                            instance = self._result_to_model(value, model)
                            if instance:
                                mapped.append(instance)
                            break
            else:
                first_value = next(iter(row.values()), None)
                if isinstance(first_value, dict):
                    instance = self._result_to_model(first_value, result_type)
                    if instance:
                        mapped.append(instance)

        return mapped

    async def create_edge(
        self,
        source: UniNode,
        edge_type: str,
        target: UniNode,
        properties: dict[str, Any] | UniEdge | None = None,
    ) -> None:
        """Create an edge between two nodes in its own committed transaction."""
        src_vid, dst_vid, src_label, dst_label = UniSession._validate_edge_endpoints(
            source, target
        )
        props = UniSession._normalize_edge_properties(properties)

        # NOTE(review): edge_type is interpolated into the query text unchecked
        # here; confirm callers only pass trusted identifiers.
        props_str = ", ".join(f"{k}: ${k}" for k in props)
        if props_str:
            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type} {{{props_str}}}]->(b)"
        else:
            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[r:{edge_type}]->(b)"

        async with await self._db_session.tx() as tx:
            await tx.execute(cypher, {"src": src_vid, "dst": dst_vid, **props})
            await tx.commit()

    async def delete_edge(
        self, source: UniNode, edge_type: str, target: UniNode
    ) -> int:
        """Delete edges between two nodes. Returns the number of deleted edges."""
        src_vid, dst_vid, src_label, dst_label = UniSession._validate_edge_endpoints(
            source, target
        )
        cypher = (
            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label}) "
            f"WHERE a._vid = $src AND b._vid = $dst "
            f"DELETE r RETURN count(r) as count"
        )
        async with await self._db_session.tx() as tx:
            results = await tx.query(cypher, {"src": src_vid, "dst": dst_vid})
            await tx.commit()
        return cast(int, results[0]["count"]) if results else 0

    async def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
        """Bulk-add entities using bulk_writer, one transaction per label.

        Returns:
            The vids of the inserted vertices, grouped by label in first-seen
            label order.

        Raises:
            BulkLoadError: Wrapping any failure during the insert.
        """
        if not entities:
            return []

        by_label: dict[str, list[UniNode]] = {}
        for entity in entities:
            by_label.setdefault(entity.__class__.__label__, []).append(entity)

        all_vids: list[int] = []
        try:
            for label, group in by_label.items():
                for entity in group:
                    run_hooks(entity, _BEFORE_CREATE)
                prop_dicts = [e.to_properties() for e in group]
                # Manage the transaction with `async with`, like the other write
                # paths in this class, so it is released even if the insert
                # fails. Presumably this rolls back uncommitted work; confirm
                # against uni_db.
                async with await self._db_session.tx() as tx:
                    async with await tx.bulk_writer().build() as bw:
                        vids = await bw.insert_vertices(label, prop_dicts)
                        await bw.commit()
                    await tx.commit()
                for entity, vid in zip(group, vids):
                    entity._attach_session(self, vid)
                    self._identity_map[(label, vid)] = entity
                    run_hooks(entity, _AFTER_CREATE)
                    entity._mark_clean()
                all_vids.extend(vids)
        except Exception as e:
            raise BulkLoadError(f"Bulk insert failed: {e}") from e

        return all_vids

    async def explain(self, cypher: str) -> Any:
        """Get the query execution plan."""
        return await self._db_session.explain(cypher)

    async def profile(self, cypher: str) -> Any:
        """Run the query with profiling and return results + stats."""
        return await self._db_session.profile(cypher)

    async def save_schema(self, path: str) -> None:
        """Save the database schema to a file."""
        await self._db.save_schema(path)

    async def load_schema(self, path: str) -> None:
        """Load a database schema from a file."""
        await self._db.load_schema(path)

    # ---- Internal methods ----

    async def _create_node(self, entity: UniNode) -> None:
        """Create *entity* in its own transaction and start tracking it.

        If the CREATE returns no row, the entity stays unattached and the
        after-create hooks do not run.
        """
        run_hooks(entity, _BEFORE_CREATE)
        label = entity.__class__.__label__
        props = entity.to_properties()
        props_str = ", ".join(f"{k}: ${k}" for k in props)
        cypher = f"CREATE (n:{label} {{{props_str}}}) RETURN id(n) as vid"
        async with await self._db_session.tx() as tx:
            results = await tx.query(cypher, props)
            await tx.commit()
        if results:
            vid = results[0]["vid"]
            entity._attach_session(self, vid)
            self._identity_map[(label, vid)] = entity
            run_hooks(entity, _AFTER_CREATE)
            entity._mark_clean()

    async def _create_node_in_tx(
        self, entity: UniNode, tx: uni_db.AsyncTransaction
    ) -> None:
        """Create *entity* inside the caller's transaction *tx* (no commit here)."""
        run_hooks(entity, _BEFORE_CREATE)
        label = entity.__class__.__label__
        props = entity.to_properties()
        props_str = ", ".join(f"{k}: ${k}" for k in props)
        cypher = f"CREATE (n:{label} {{{props_str}}}) RETURN id(n) as vid"
        results = await tx.query(cypher, props)
        if results:
            vid = results[0]["vid"]
            entity._attach_session(self, vid)
            self._identity_map[(label, vid)] = entity
            run_hooks(entity, _AFTER_CREATE)

    async def _create_edge_in_tx(
        self,
        source: UniNode,
        edge_type: str,
        target: UniNode,
        properties: UniEdge | None,
        tx: uni_db.AsyncTransaction,
    ) -> None:
        """Create an edge inside the caller's transaction *tx* (no commit here)."""
        props = properties.to_properties() if properties else {}
        src_label = source.__class__.__label__
        dst_label = target.__class__.__label__
        props_str = ", ".join(f"{k}: ${k}" for k in props)
        if props_str:
            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[:{edge_type} {{{props_str}}}]->(b)"
        else:
            cypher = f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst CREATE (a)-[:{edge_type}]->(b)"
        params = {"src": source._vid, "dst": target._vid, **props}
        await tx.query(cypher, params)

    async def _update_node(self, entity: UniNode) -> None:
        """Write *entity*'s dirty fields (converted via type hints) and mark it clean.

        Does nothing, and runs no after-update hooks, when no fields are dirty.
        """
        run_hooks(entity, _BEFORE_UPDATE)
        label = entity.__class__.__label__
        try:
            hints = get_type_hints(type(entity))
        except Exception:
            hints = {}
        dirty_props = {}
        for name in entity._dirty:
            value = getattr(entity, name)
            if name in hints:
                value = python_to_db_value(value, hints[name])
            dirty_props[name] = value
        if not dirty_props:
            return
        set_clause = ", ".join(f"n.{k} = ${k}" for k in dirty_props)
        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid SET {set_clause}"
        params = {"vid": entity._vid, **dirty_props}
        async with await self._db_session.tx() as tx:
            await tx.execute(cypher, params)
            await tx.commit()
        run_hooks(entity, _AFTER_UPDATE)
        entity._mark_clean()

    async def _delete_node(self, entity: UniNode) -> None:
        """DETACH DELETE *entity*, untrack it, and clear its vid/uid/session."""
        run_hooks(entity, _BEFORE_DELETE)
        label = entity.__class__.__label__
        vid = entity._vid
        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid DETACH DELETE n"
        async with await self._db_session.tx() as tx:
            await tx.execute(cypher, {"vid": vid})
            await tx.commit()
        if vid is not None and (label, vid) in self._identity_map:
            del self._identity_map[(label, vid)]
        entity._vid = None
        entity._uid = None
        entity._session = None
        run_hooks(entity, _AFTER_DELETE)

    def _result_to_model(
        self,
        data: dict[str, Any],
        model: type[NodeT],
    ) -> NodeT | None:
        """Convert a query result row to a model instance (sync — pure dict processing).

        The vid is taken from ``_id``, ``_vid`` or ``vid`` (first non-None).
        Returns None for empty data or when model construction raises. If the
        (label, vid) pair is already tracked, the tracked instance is returned
        instead of the new one, and after-load hooks are not run again.
        """
        if not data:
            return None

        data = dict(data)
        data = run_class_hooks(model, _BEFORE_LOAD, data) or data

        vid = data.pop("_id", None)
        if vid is None:
            vid = data.pop("_vid", None)
        if vid is None:
            vid = data.pop("vid", None)
        if vid is not None and not isinstance(vid, int):
            vid = int(vid)
        data.pop("_label", None)

        try:
            instance = cast(
                NodeT,
                model.from_properties(data, vid=vid, session=self),
            )
        except Exception:
            return None

        if vid is not None:
            existing = self._identity_map.get((model.__label__, vid))
            if existing is not None:
                return cast(NodeT, existing)
            self._identity_map[(model.__label__, vid)] = instance

        run_hooks(instance, _AFTER_LOAD)
        return instance

    def _load_relationship(
        self,
        entity: UniNode,
        descriptor: RelationshipDescriptor[Any],
    ) -> list[UniNode] | UniNode | None:
        """Sync relationship loading — raises error for async session.
        Use _async_load_relationship instead."""
        raise SessionError(
            "Cannot synchronously load relationships in an async session. "
            "Use eager_load() or access relationships via async queries."
        )

    async def _async_eager_load_relationships(
        self,
        entities: list[NodeT],
        relationships: list[str],
    ) -> None:
        """Eager load relationships for a list of entities (async).

        Runs one query per requested relationship name. Names that are not
        relationship fields of the first entity's model are skipped.
        """
        if not entities:
            return

        model = type(entities[0])
        rel_configs = model.get_relationship_fields()

        for rel_name in relationships:
            if rel_name not in rel_configs:
                continue

            config = rel_configs[rel_name]
            label = model.__label__
            vids = [e._vid for e in entities if e._vid is not None]

            if not vids:
                continue

            pattern = _edge_pattern(config.edge_type, config.direction)
            cypher = (
                f"MATCH (a:{label}){pattern}(b) WHERE id(a) IN $vids "
                f"RETURN id(a) as src_vid, properties(b) AS _props, id(b) AS _vid, labels(b) AS _labels"
            )
            results = await self._db_session.query(cypher, {"vids": vids})

            by_source: dict[int, list[Any]] = {}
            for raw_row in results:
                row = raw_row.to_dict()
                src_vid = row["src_vid"]
                node_data = _row_to_node_dict(row)
                if node_data is None:
                    continue
                by_source.setdefault(src_vid, []).append(node_data)

            # NOTE(review): the cache receives the raw dicts from
            # _row_to_node_dict rather than model instances. Entities with no
            # related rows get no cache attribute at all. Confirm both match
            # what the relationship descriptor expects.
            for entity in entities:
                if entity._vid in by_source:
                    related = by_source[entity._vid]
                    cache_attr = f"_rel_cache_{rel_name}"
                    setattr(entity, cache_attr, related)
Async session for interacting with the graph database.
Mirrors UniSession with async methods. Uses AsyncUni.
Example:
from uni_db import AsyncUni from uni_pydantic import AsyncUniSession
db = await AsyncUni.open("./my_graph") async with AsyncUniSession(db) as session: ... session.register(Person) ... await session.sync_schema() ... alice = Person(name="Alice", age=30) ... session.add(alice) ... await session.commit()
def __init__(self, db: uni_db.AsyncUni) -> None:
    """Bind the session to *db* and start with empty tracking state."""
    self._db = db
    self._db_session = db.session()
    self._schema_gen = SchemaGenerator()
    # Weak (label, vid) -> instance map; entries vanish once unreferenced.
    identity: WeakValueDictionary[tuple[str, int], UniNode] = WeakValueDictionary()
    self._identity_map = identity
    self._pending_new: list[UniNode] = []
    self._pending_delete: list[UniNode] = []
def close(self) -> None:
    """Drop queued additions and deletions; the database is not touched."""
    for queue in (self._pending_new, self._pending_delete):
        queue.clear()
Close the session and clear pending state.
@property
def db(self) -> uni_db.AsyncUni:
    """The wrapped ``uni_db.AsyncUni`` handle, for low-level operations."""
    handle = self._db
    return handle
Access the underlying uni_db.AsyncUni for low-level operations.
async def locy(self, program: str, params: dict[str, Any] | None = None) -> Any:
    """
    Run a Locy program through the session's ``uni_db.AsyncSession.locy()``.

    Whatever that call returns (derived facts, stats and warnings) is passed
    back unchanged.
    """
    db_session = self._db_session
    return await db_session.locy(program, params)
Evaluate a Locy program and return derived facts, stats, and warnings.
Delegates to the underlying uni_db.AsyncSession.locy().
def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
    """Hand every model class to the schema generator (synchronous)."""
    generator = self._schema_gen
    generator.register(*models)
Register model classes with the session (sync).
async def sync_schema(self) -> None:
    """Apply the registered models' schema to the bound database."""
    target = self._db
    await self._schema_gen.async_apply_to_database(target)
Synchronize database schema with registered models.
def query(self, model: type[NodeT]) -> AsyncQueryBuilder[NodeT]:
    """Return a new async query builder for *model* bound to this session."""
    builder = AsyncQueryBuilder(self, model)
    return builder
Create an async query builder for the given model.
def add(self, entity: UniNode) -> None:
    """Queue a not-yet-persisted entity for creation on the next commit.

    Only collects; no database work happens here.

    Raises:
        SessionError: If the entity is already persisted.
    """
    if not entity.is_persisted:
        entity._session = self
        self._pending_new.append(entity)
        return
    raise SessionError(f"Entity {entity!r} is already persisted")
Add a new entity to be persisted (sync — just collects).
def add_all(self, entities: Sequence[UniNode]) -> None:
    """Queue each entity in order via ``add`` (synchronous, collect-only)."""
    add_one = self.add
    for item in entities:
        add_one(item)
Add multiple entities (sync — just collects).
def delete(self, entity: UniNode) -> None:
    """Queue a persisted entity for deletion on the next commit (collect-only).

    Raises:
        NotPersisted: If the entity has never been persisted.
    """
    if entity.is_persisted:
        self._pending_delete.append(entity)
    else:
        raise NotPersisted(entity)
Mark an entity for deletion (sync — just collects).
async def get(
    self,
    model: type[NodeT],
    vid: int | None = None,
    uid: str | None = None,
    **kwargs: Any,
) -> NodeT | None:
    """Look up one entity of *model* by vid, uid, or property equality.

    Lookup precedence is vid, then uid, then ``kwargs``. A vid already in the
    identity map is returned without querying.

    Raises:
        ValueError: If none of vid, uid, or property filters is supplied.
    """
    label = model.__label__
    if vid is not None:
        hit = self._identity_map.get((label, vid))
        if hit is not None:
            return hit  # type: ignore[return-value]
        cypher = f"MATCH (n:{label}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
        params: dict[str, Any] = {"vid": vid}
    elif uid is not None:
        cypher = f"MATCH (n:{label}) WHERE n._uid = $uid RETURN {_NODE_RETURN}"
        params = {"uid": uid}
    elif kwargs:
        # Property names end up in the query text, so validate them first.
        for key in kwargs:
            _validate_property(key, model)
        where = " AND ".join(f"n.{key} = ${key}" for key in kwargs)
        cypher = f"MATCH (n:{label}) WHERE {where} RETURN {_NODE_RETURN} LIMIT 1"
        params = dict(kwargs)
    else:
        raise ValueError("Must provide vid, uid, or property filters")

    rows = await self._db_session.query(cypher, params)
    if not rows:
        return None
    node_data = _row_to_node_dict(rows[0].to_dict())
    return None if node_data is None else self._result_to_model(node_data, model)
Get an entity by ID or unique properties.
async def refresh(self, entity: UniNode) -> None:
    """Overwrite *entity*'s property fields with the stored values, then mark it clean.

    Raises:
        NotPersisted: If the entity has no database identity.
        SessionError: If the node can no longer be found.
    """
    if not entity.is_persisted:
        raise NotPersisted(entity)

    cypher = f"MATCH (n:{entity.__class__.__label__}) WHERE id(n) = $vid RETURN {_NODE_RETURN}"
    rows = await self._db_session.query(cypher, {"vid": entity._vid})
    stored = _row_to_node_dict(rows[0].to_dict()) if rows else None
    if stored is None:
        raise SessionError(f"Entity with vid={entity._vid} no longer exists")

    try:
        type_hints = get_type_hints(type(entity))
    except Exception:
        # Unresolvable annotations: fall back to raw database values.
        type_hints = {}

    for name in entity.get_property_fields():
        if name not in stored:
            continue
        value = stored[name]
        if name in type_hints:
            value = db_to_python_value(value, type_hints[name])
        setattr(entity, name, value)

    entity._mark_clean()
Refresh an entity's properties from the database.
async def commit(self) -> None:
    """Commit all pending changes.

    Creates queued new entities, writes dirty tracked entities, deletes
    queued entities, flushes the database, then clears the queues.

    Entities that are already persisted are not created again. This covers an
    entity added twice, or one created by an earlier commit that failed
    part-way. Entities that are no longer persisted are not deleted again.
    """
    for entity in self._pending_new:
        if not entity.is_persisted:
            await self._create_node(entity)

    for entity in list(self._identity_map.values()):
        if entity.is_dirty and entity.is_persisted:
            await self._update_node(entity)

    for entity in self._pending_delete:
        if entity.is_persisted:
            await self._delete_node(entity)

    await self._db.flush()
    self._pending_new.clear()
    self._pending_delete.clear()
Commit all pending changes.
async def rollback(self) -> None:
    """Throw away uncommitted work.

    Detaches queued new entities, empties both queues, and reloads every
    dirty tracked entity from the database.
    """
    for queued in self._pending_new:
        queued._session = None
    for queue in (self._pending_new, self._pending_delete):
        queue.clear()
    for tracked in list(self._identity_map.values()):
        if tracked.is_dirty:
            await self.refresh(tracked)
Discard all pending changes.
async def transaction(self) -> AsyncUniTransaction:
    """Create an async transaction bound to this session.

    This method is a coroutine, so await it before entering the context:
    ``async with await session.transaction() as tx:``.
    """
    return AsyncUniTransaction(self)
Create an async transaction. Because this method is a coroutine, await it before entering the context: async with await session.transaction() as tx:
async def cypher(
    self,
    query: str,
    params: dict[str, Any] | None = None,
    result_type: type[NodeT] | None = None,
) -> list[NodeT] | list[dict[str, Any]]:
    """Run raw Cypher; return plain row dicts, or model instances if *result_type* is set.

    When mapping, each row contributes at most one instance:

    - the first dict value carrying ``_id`` and ``_label`` maps to *result_type*;
    - otherwise a dict with a registered ``_label`` maps to that label's model;
    - if neither is found, the row's first value maps to *result_type* when
      it is a dict.

    Falsy conversion results are dropped.
    """
    results = await self._db_session.query(query, params)

    if result_type is None:
        return [r.to_dict() for r in results]

    mapped = []
    for raw_row in results:
        row = raw_row.to_dict()
        instance = None
        matched = False
        for value in row.values():
            if not isinstance(value, dict):
                continue
            if "_id" in value and "_label" in value:
                instance = self._result_to_model(value, result_type)
                matched = True
                break
            if "_label" in value and value["_label"] in self._schema_gen._node_models:
                model = self._schema_gen._node_models[value["_label"]]
                instance = self._result_to_model(value, model)
                matched = True
                break
        if not matched:
            first_value = next(iter(row.values()), None)
            if isinstance(first_value, dict):
                instance = self._result_to_model(first_value, result_type)
        if instance:
            mapped.append(instance)

    return mapped
Execute a raw Cypher query.
async def create_edge(
    self,
    source: UniNode,
    edge_type: str,
    target: UniNode,
    properties: dict[str, Any] | UniEdge | None = None,
) -> None:
    """Create one *edge_type* edge from *source* to *target*.

    The edge is written and committed in its own transaction. Endpoint
    validation and property normalisation are delegated to UniSession helpers.
    """
    src_vid, dst_vid, src_label, dst_label = UniSession._validate_edge_endpoints(
        source, target
    )
    props = UniSession._normalize_edge_properties(properties)

    match_clause = (
        f"MATCH (a:{src_label}), (b:{dst_label}) WHERE a._vid = $src AND b._vid = $dst"
    )
    props_str = ", ".join(f"{k}: ${k}" for k in props)
    rel = f"[r:{edge_type} {{{props_str}}}]" if props_str else f"[r:{edge_type}]"
    cypher = f"{match_clause} CREATE (a)-{rel}->(b)"
    query_params = {"src": src_vid, "dst": dst_vid, **props}

    async with await self._db_session.tx() as tx:
        await tx.execute(cypher, query_params)
        await tx.commit()
Create an edge between two nodes.
async def delete_edge(
    self, source: UniNode, edge_type: str, target: UniNode
) -> int:
    """Remove every *edge_type* edge from *source* to *target*; return how many went."""
    src_vid, dst_vid, src_label, dst_label = UniSession._validate_edge_endpoints(
        source, target
    )
    cypher = " ".join(
        (
            f"MATCH (a:{src_label})-[r:{edge_type}]->(b:{dst_label})",
            "WHERE a._vid = $src AND b._vid = $dst",
            "DELETE r RETURN count(r) as count",
        )
    )
    async with await self._db_session.tx() as tx:
        rows = await tx.query(cypher, {"src": src_vid, "dst": dst_vid})
        await tx.commit()
    if not rows:
        return 0
    return cast(int, rows[0]["count"])
Delete edges between two nodes. Returns the number of deleted edges.
async def bulk_add(self, entities: Sequence[UniNode]) -> list[int]:
    """Bulk-add entities using bulk_writer, one transaction per label.

    Returns:
        The vids of the inserted vertices, grouped by label in first-seen
        label order.

    Raises:
        BulkLoadError: Wrapping any failure during the insert.
    """
    if not entities:
        return []

    by_label: dict[str, list[UniNode]] = {}
    for entity in entities:
        by_label.setdefault(entity.__class__.__label__, []).append(entity)

    all_vids: list[int] = []
    try:
        for label, group in by_label.items():
            for entity in group:
                run_hooks(entity, _BEFORE_CREATE)
            prop_dicts = [e.to_properties() for e in group]
            # Manage the transaction with `async with`, like the other write
            # paths of the session, so it is released even if the insert fails.
            # Presumably this rolls back uncommitted work; confirm against uni_db.
            async with await self._db_session.tx() as tx:
                async with await tx.bulk_writer().build() as bw:
                    vids = await bw.insert_vertices(label, prop_dicts)
                    await bw.commit()
                await tx.commit()
            for entity, vid in zip(group, vids):
                entity._attach_session(self, vid)
                self._identity_map[(label, vid)] = entity
                run_hooks(entity, _AFTER_CREATE)
                entity._mark_clean()
            all_vids.extend(vids)
    except Exception as e:
        raise BulkLoadError(f"Bulk insert failed: {e}") from e

    return all_vids
Bulk-add entities using bulk_writer.
async def explain(self, cypher: str) -> Any:
    """Return the execution plan the database reports for *cypher*."""
    db_session = self._db_session
    return await db_session.explain(cypher)
Get the query execution plan.
async def profile(self, cypher: str) -> Any:
    """Execute *cypher* under profiling and return the database's results plus stats."""
    db_session = self._db_session
    return await db_session.profile(cypher)
Run the query with profiling and return results + stats.
class AsyncUniTransaction:
    """Async transaction context for atomic operations.

    ``add`` and ``create_edge`` only queue work. Everything is written through
    a single database transaction when ``commit`` runs.
    """

    def __init__(self, session: AsyncUniSession) -> None:
        self._session = session
        # Database-level transaction handle; opened in __aenter__.
        self._tx: uni_db.AsyncTransaction | None = None
        self._pending_nodes: list[UniNode] = []
        self._pending_edges: list[tuple[UniNode, str, UniNode, UniEdge | None]] = []
        self._committed = False
        self._rolled_back = False

    async def __aenter__(self) -> AsyncUniTransaction:
        self._tx = await self._session._db_session.tx()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Roll back if the block raised, otherwise commit if still open.

        Returns None, so an exception raised inside the block propagates.
        """
        if exc_type is not None:
            await self.rollback()
            return
        if not self._committed and not self._rolled_back:
            await self.commit()

    def add(self, entity: UniNode) -> None:
        """Add a node to be created in this transaction (sync — just collects)."""
        self._pending_nodes.append(entity)

    def create_edge(
        self,
        source: UniNode,
        edge_type: str,
        target: UniNode,
        properties: UniEdge | None = None,
    ) -> None:
        """Create an edge between two nodes in this transaction (sync — just collects).

        Raises:
            NotPersisted: If either endpoint is not yet persisted.
        """
        if not source.is_persisted:
            raise NotPersisted(source)
        if not target.is_persisted:
            raise NotPersisted(target)
        self._pending_edges.append((source, edge_type, target, properties))

    async def commit(self) -> None:
        """Write all queued nodes and edges, then commit the transaction.

        Raises:
            TransactionError: If the transaction is already finished or was
                never started, or if writing/committing fails (in which case
                the transaction is rolled back first).
        """
        if self._committed:
            raise TransactionError("Transaction already committed")
        if self._rolled_back:
            raise TransactionError("Transaction already rolled back")
        if self._tx is None:
            raise TransactionError("Transaction not started")

        try:
            for node in self._pending_nodes:
                await self._session._create_node_in_tx(node, self._tx)
            for source, edge_type, target, props in self._pending_edges:
                await self._session._create_edge_in_tx(
                    source, edge_type, target, props, self._tx
                )
            await self._tx.commit()
        except Exception as e:
            await self.rollback()
            raise TransactionError(f"Commit failed: {e}") from e

        # Only reached once the database commit succeeded. Failures past this
        # point must not trigger a rollback of an already-committed transaction.
        self._committed = True
        for node in self._pending_nodes:
            node._mark_clean()

    async def rollback(self) -> None:
        """Roll back the transaction and discard queued work (idempotent)."""
        if self._rolled_back:
            return
        try:
            if self._tx is not None:
                await self._tx.rollback()
        finally:
            # Mark finished and drop queued work even if the database rollback
            # raised, so the object is never left half-open.
            self._rolled_back = True
            self._pending_nodes.clear()
            self._pending_edges.clear()
Async transaction context for atomic operations.
def __init__(self, session: AsyncUniSession) -> None:
    """Bind the transaction to *session*; the DB transaction opens on ``__aenter__``."""
    self._session = session
    self._tx: uni_db.AsyncTransaction | None = None
    # Work queued by add()/create_edge() until commit() runs.
    self._pending_nodes: list[UniNode] = []
    self._pending_edges: list[tuple[UniNode, str, UniNode, UniEdge | None]] = []
    # Terminal-state flags.
    self._committed = self._rolled_back = False
def add(self, entity: UniNode) -> None:
    """Queue *entity* for creation on commit (synchronous; nothing is written yet)."""
    pending = self._pending_nodes
    pending.append(entity)
Add a node to be created in this transaction (sync — just collects).
def create_edge(
    self,
    source: UniNode,
    edge_type: str,
    target: UniNode,
    properties: UniEdge | None = None,
) -> None:
    """Queue an edge ``source -[edge_type]-> target`` for creation on commit.

    Synchronous: the edge is only collected here.

    Raises:
        NotPersisted: If ``source`` (checked first) or ``target`` is not persisted.
    """
    for endpoint in (source, target):
        if not endpoint.is_persisted:
            raise NotPersisted(endpoint)
    self._pending_edges.append((source, edge_type, target, properties))
Create an edge between two nodes in this transaction (sync — just collects).
async def commit(self) -> None:
    """Write all pending nodes and edges, then commit the transaction.

    Raises:
        TransactionError: If the transaction was already committed or rolled
            back, was never started, or if writing/committing fails. On
            failure the transaction is rolled back; if that rollback also
            fails, both errors are reported in the message.
    """
    if self._committed:
        raise TransactionError("Transaction already committed")
    if self._rolled_back:
        raise TransactionError("Transaction already rolled back")
    if self._tx is None:
        raise TransactionError("Transaction not started")

    try:
        for node in self._pending_nodes:
            await self._session._create_node_in_tx(node, self._tx)
        for source, edge_type, target, props in self._pending_edges:
            await self._session._create_edge_in_tx(
                source, edge_type, target, props, self._tx
            )
        await self._tx.commit()
    except Exception as e:
        try:
            await self.rollback()
        except Exception as rollback_err:
            # Keep the contract that commit() failures surface as
            # TransactionError, instead of leaking the rollback error.
            raise TransactionError(
                f"Commit failed: {e} (rollback also failed: {rollback_err})"
            ) from e
        raise TransactionError(f"Commit failed: {e}") from e

    # Only reached after a successful DB commit, so a failure below can never
    # trigger a rollback of an already committed transaction.
    self._committed = True
    for node in self._pending_nodes:
        node._mark_clean()
Commit the transaction.
async def rollback(self) -> None:
    """Roll back the DB transaction (if one was opened) and drop pending work.

    Calling it again after a rollback is a no-op.
    """
    if self._rolled_back:
        return
    tx = self._tx
    if tx is not None:
        await tx.rollback()
    self._rolled_back = True
    for pending in (self._pending_nodes, self._pending_edges):
        pending.clear()
Rollback the transaction.
def Field(
    default: Any = ...,
    *,
    default_factory: Callable[[], Any] | None = None,
    alias: str | None = None,
    title: str | None = None,
    description: str | None = None,
    examples: list[Any] | None = None,
    exclude: bool = False,
    json_schema_extra: dict[str, Any] | None = None,
    # Uni-specific options
    index: IndexType | None = None,
    unique: bool = False,
    tokenizer: str | None = None,
    metric: VectorMetric | None = None,
    generated: str | None = None,
) -> Any:
    """
    Create a field with uni-pydantic configuration.

    This extends Pydantic's Field with graph database options.

    Args:
        default: Default value for the field (``...`` means no default).
        default_factory: Factory function for default value; takes
            precedence over ``default``.
        alias: Field alias for serialization.
        title: Human-readable title.
        description: Field description.
        examples: Example values.
        exclude: Exclude from serialization.
        json_schema_extra: Extra JSON schema properties. The dict is copied,
            never modified, so it can safely be shared between fields.
        index: Index type ("btree", "hash", "fulltext", "vector").
        unique: Whether to create a unique constraint.
        tokenizer: Tokenizer for fulltext index (default: "standard").
        metric: Distance metric for vector index ("l2", "cosine", "dot").
        generated: Expression for generated/computed property.

    Returns:
        A Pydantic FieldInfo with uni-pydantic metadata attached under the
        ``"uni_config"`` key of its ``json_schema_extra``.

    Examples:
        >>> class Person(UniNode):
        ...     name: str = Field(index="btree")
        ...     email: str = Field(unique=True)
        ...     bio: str = Field(index="fulltext", tokenizer="standard")
        ...     embedding: Vector[768] = Field(metric="cosine")
    """
    # Default tokenizer for fulltext indexes
    if index == "fulltext" and tokenizer is None:
        tokenizer = "standard"

    uni_config = FieldConfig(
        index=index,
        unique=unique,
        tokenizer=tokenizer,
        metric=metric,
        generated=generated,
        default=default,
        default_factory=default_factory,
        alias=alias,
        title=title,
        description=description,
        examples=examples,
        exclude=exclude,
        json_schema_extra=json_schema_extra,
    )

    # Work on a copy: writing into the caller's dict would leak uni_config
    # into every other field sharing that dict (the last Field() call wins).
    extra: dict[str, Any] = dict(json_schema_extra) if json_schema_extra else {}
    extra["uni_config"] = uni_config

    from pydantic.fields import FieldInfo as PydanticFieldInfo

    common: dict[str, Any] = {
        "alias": alias,
        "title": title,
        "description": description,
        "examples": examples,
        "exclude": exclude,
        "json_schema_extra": extra,
    }
    if default_factory is not None:
        return PydanticFieldInfo(default_factory=default_factory, **common)
    if default is not ...:
        return PydanticFieldInfo(default=default, **common)
    # No default at all: the field is required.
    return PydanticFieldInfo(**common)
Create a field with uni-pydantic configuration.
This extends Pydantic's Field with graph database options.
Args: default: Default value for the field. default_factory: Factory function for default value. alias: Field alias for serialization. title: Human-readable title. description: Field description. examples: Example values. exclude: Exclude from serialization. json_schema_extra: Extra JSON schema properties. index: Index type ("btree", "hash", "fulltext", "vector"). unique: Whether to create a unique constraint. tokenizer: Tokenizer for fulltext index (default: "standard"). metric: Distance metric for vector index ("l2", "cosine", "dot"). generated: Expression for generated/computed property.
Returns: A Pydantic FieldInfo with uni-pydantic metadata attached.
Examples:
class Person(UniNode): ... name: str = Field(index="btree") ... email: str = Field(unique=True) ... bio: str = Field(index="fulltext", tokenizer="standard") ... embedding: Vector[768] = Field(metric="cosine")
@dataclass
class FieldConfig:
    """Configuration for a uni-pydantic field.

    Holds the graph-specific options accepted by ``Field()`` together with a
    copy of the standard Pydantic field options. The declaration order below
    defines the positional ``__init__`` signature generated by ``@dataclass``.
    """

    # Index configuration
    index: IndexType | None = None
    unique: bool = False

    # Fulltext index options
    tokenizer: str | None = None

    # Vector index options
    metric: VectorMetric | None = None

    # Generated/computed property
    generated: str | None = None

    # Pydantic field options (passed through)
    # Ellipsis (``...``) means "no explicit default"; Field() compares
    # ``default is not ...`` to tell the cases apart.
    default: Any = dataclass_field(default_factory=lambda: ...)
    default_factory: Callable[[], Any] | None = None
    alias: str | None = None
    title: str | None = None
    description: str | None = None
    examples: list[Any] | None = None
    exclude: bool = False
    json_schema_extra: dict[str, Any] | None = None
Configuration for a uni-pydantic field.
def Relationship(
    edge_type: str,
    *,
    direction: Direction = "outgoing",
    edge_model: type[UniEdge] | None = None,
    eager: bool = False,
    cascade_delete: bool = False,
) -> Any:
    """
    Declare a relationship to another node type.

    Relationships are lazy-loaded by default. Use eager=True or
    query.eager_load() to load them with the parent query.

    Args:
        edge_type: The edge type name (e.g., "FRIEND_OF", "WORKS_AT").
        direction: Relationship direction:
            - "outgoing": Follow edges from this node (default)
            - "incoming": Follow edges to this node
            - "both": Follow edges in both directions
        edge_model: Optional UniEdge subclass for typed edge properties.
        eager: Whether to eager-load this relationship by default.
        cascade_delete: Whether to delete related edges when this node is deleted.

    Returns:
        A marker object wrapping the RelationshipConfig; the model metaclass
        processes it when the model class is created.

    Examples:
        >>> class Person(UniNode):
        ...     # Outgoing relationship (default)
        ...     follows: list["Person"] = Relationship("FOLLOWS")
        ...
        ...     # Incoming relationship
        ...     followers: list["Person"] = Relationship("FOLLOWS", direction="incoming")
        ...
        ...     # Single optional relationship
        ...     manager: "Person | None" = Relationship("REPORTS_TO")
        ...
        ...     # Relationship with edge properties
        ...     friendships: list[tuple["Person", FriendshipEdge]] = Relationship(
        ...         "FRIEND_OF",
        ...         edge_model=FriendshipEdge
        ...     )
    """
    # The metaclass recognises this marker and turns it into the real field.
    return _RelationshipMarker(
        RelationshipConfig(
            edge_type=edge_type,
            direction=direction,
            edge_model=edge_model,
            eager=eager,
            cascade_delete=cascade_delete,
        )
    )
Declare a relationship to another node type.
Relationships are lazy-loaded by default. Use eager=True or query.eager_load() to load them with the parent query.
Args: edge_type: The edge type name (e.g., "FRIEND_OF", "WORKS_AT"). direction: Relationship direction: - "outgoing": Follow edges from this node (default) - "incoming": Follow edges to this node - "both": Follow edges in both directions edge_model: Optional UniEdge subclass for typed edge properties. eager: Whether to eager-load this relationship by default. cascade_delete: Whether to delete related edges when this node is deleted.
Returns: A marker object wrapping the RelationshipConfig; the model metaclass processes it during model creation.
Examples:
class Person(UniNode): ... # Outgoing relationship (default) ... follows: list["Person"] = Relationship("FOLLOWS") ... ... # Incoming relationship ... followers: list["Person"] = Relationship("FOLLOWS", direction="incoming") ... ... # Single optional relationship ... manager: "Person | None" = Relationship("REPORTS_TO") ... ... # Relationship with edge properties ... friendships: list[tuple["Person", FriendshipEdge]] = Relationship( ... "FRIEND_OF", ... edge_model=FriendshipEdge ... )
@dataclass
class RelationshipConfig:
    """Configuration for a relationship field.

    Built by ``Relationship()`` from its arguments. The declaration order
    below defines the positional ``__init__`` signature.
    """

    # Edge type name, e.g. "FRIEND_OF".
    edge_type: str
    # "outgoing", "incoming" or "both".
    direction: Direction = "outgoing"
    # Optional UniEdge subclass describing typed edge properties.
    edge_model: type[UniEdge] | None = None
    # Load together with the parent query instead of lazily.
    eager: bool = False
    # Delete related edges when the owning node is deleted.
    cascade_delete: bool = False
Configuration for a relationship field.
class RelationshipDescriptor(Generic[NodeT]):
    """
    Lazy-loading descriptor backing relationship fields.

    Class-level access (``Model.field``) yields the descriptor itself, which
    query building uses. Instance-level access yields the related node(s).
    These are taken from a per-instance cache attribute when present.
    Otherwise they are loaded through the session attached to the instance
    and then cached.
    """

    def __init__(
        self,
        config: RelationshipConfig,
        field_name: str,
        target_type: type[NodeT] | str | None = None,
        is_list: bool = True,
    ) -> None:
        self.config = config
        self.target_type = target_type
        self.is_list = is_list
        self.field_name = field_name
        # Per-instance attribute that stores the loaded value.
        self._cache_attr = f"_rel_cache_{field_name}"

    def __set_name__(self, owner: type, name: str) -> None:
        # The attribute name on the owning class wins over the constructor value.
        self.field_name = name
        self._cache_attr = f"_rel_cache_{name}"

    @overload
    def __get__(
        self, obj: None, objtype: type[NodeT]
    ) -> RelationshipDescriptor[NodeT]: ...

    @overload
    def __get__(
        self, obj: NodeT, objtype: type[NodeT] | None = None
    ) -> list[NodeT] | NodeT | None: ...

    def __get__(
        self, obj: NodeT | None, objtype: type[NodeT] | None = None
    ) -> RelationshipDescriptor[NodeT] | list[NodeT] | NodeT | None:
        if obj is None:
            return self

        cache_name = self._cache_attr
        if hasattr(obj, cache_name):
            return cast("list[NodeT] | NodeT | None", getattr(obj, cache_name))

        session = getattr(obj, "_session", None)
        if session is None:
            from .exceptions import LazyLoadError

            raise LazyLoadError(
                self.field_name,
                "No session attached. Use session.get() or enable eager loading.",
            )

        loaded = session._load_relationship(obj, self)
        setattr(obj, cache_name, loaded)
        return cast("list[NodeT] | NodeT | None", loaded)

    def __set__(self, obj: NodeT, value: list[NodeT] | NodeT | None) -> None:
        # Assignment just fills the cache (used e.g. by eager loading).
        setattr(obj, self._cache_attr, value)

    def __repr__(self) -> str:
        cfg = self.config
        return f"Relationship({cfg.edge_type!r}, direction={cfg.direction!r})"
Descriptor for relationship fields that enables lazy loading.
When accessed on an instance, it returns the related nodes. When accessed on the class, it returns the descriptor for query building.
def __init__(
    self,
    config: RelationshipConfig,
    field_name: str,
    target_type: type[NodeT] | str | None = None,
    is_list: bool = True,
) -> None:
    """Store the relationship config and derive the per-instance cache attribute name."""
    self.config = config
    self.target_type = target_type
    self.is_list = is_list
    self.field_name = field_name
    self._cache_attr = "_rel_cache_" + field_name
def get_field_config(field_info: FieldInfo) -> FieldConfig | None:
    """Return the FieldConfig stored under ``json_schema_extra["uni_config"]``, if any.

    Returns None when ``json_schema_extra`` is not a dict (absent or callable)
    or when the stored value is not a FieldConfig.
    """
    extra = field_info.json_schema_extra
    if not isinstance(extra, dict):
        return None
    candidate = extra.get("uni_config")
    return candidate if isinstance(candidate, FieldConfig) else None
Extract uni-pydantic config from a Pydantic FieldInfo.
class Btic:
    """A BTIC temporal interval value for Uni graph database.

    Construct from an ISO 8601-inspired string literal::

        Btic("1985")
        Btic("1985-03/2024-06")
        Btic("~1985")  # approximate certainty
        Btic("2020-03/")  # ongoing (unbounded hi)

    Use as a Pydantic model field type::

        class Event(UniNode):
            when: Btic

    Instances wrap the native ``_PyBtic`` value from ``uni_db``; all
    properties delegate to it.
    """

    def __init__(self, value: str | object) -> None:
        if _PyBtic is None:
            raise ImportError("uni_db is required for Btic type")
        # From here on _PyBtic is known to be available.
        if isinstance(value, str):
            self._inner = _PyBtic(value)
        elif isinstance(value, _PyBtic):
            self._inner = value
        elif isinstance(value, Btic):
            self._inner = value._inner
        else:
            raise TypeError(f"Expected str or Btic, got {type(value)}")

    @property
    def lo(self) -> int:
        """Lower bound in milliseconds since epoch."""
        return self._inner.lo

    @property
    def hi(self) -> int:
        """Upper bound in milliseconds since epoch."""
        return self._inner.hi

    @property
    def meta(self) -> int:
        """Raw 64-bit metadata word."""
        return self._inner.meta

    @property
    def lo_granularity(self) -> str:
        """Lower bound granularity name."""
        return self._inner.lo_granularity

    @property
    def hi_granularity(self) -> str:
        """Upper bound granularity name."""
        return self._inner.hi_granularity

    @property
    def lo_certainty(self) -> str:
        """Lower bound certainty name."""
        return self._inner.lo_certainty

    @property
    def hi_certainty(self) -> str:
        """Upper bound certainty name."""
        return self._inner.hi_certainty

    @property
    def duration_ms(self) -> int | None:
        """Duration in milliseconds, or None if unbounded."""
        return self._inner.duration_ms

    @property
    def is_instant(self) -> bool:
        """True if the interval is exactly 1 millisecond wide."""
        return self._inner.is_instant

    @property
    def is_unbounded(self) -> bool:
        """True if either bound is infinite."""
        return self._inner.is_unbounded

    @property
    def is_finite(self) -> bool:
        """True if both bounds are finite."""
        return self._inner.is_finite

    def __repr__(self) -> str:
        return f'Btic("{self._inner}")'

    def __str__(self) -> str:
        return str(self._inner)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Btic):
            return self._inner == other._inner
        # Defer to the other operand instead of hard-coding False.
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self._inner)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Make Btic compatible with Pydantic v2.

        Accepts ``Btic`` instances, strings and native ``_PyBtic`` values and
        serializes to the interval's string form. Unsupported inputs raise
        ``ValueError`` so Pydantic reports them as a ``ValidationError``. A
        ``TypeError`` would escape validation unwrapped.
        """

        def validate_btic(v: Any) -> Btic:
            if isinstance(v, Btic):
                return v
            if isinstance(v, str):
                return Btic(v)
            if _PyBtic is not None and isinstance(v, _PyBtic):
                return Btic(v)
            raise ValueError(f"Expected str or Btic, got {type(v)}")

        return core_schema.no_info_plain_validator_function(
            validate_btic,
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda v: str(v._inner) if isinstance(v, Btic) else str(v),
                info_arg=False,
            ),
        )
A BTIC temporal interval value for Uni graph database.
Construct from an ISO 8601-inspired string literal::
Btic("1985")
Btic("1985-03/2024-06")
Btic("~1985") # approximate certainty
Btic("2020-03/") # ongoing (unbounded hi)
Use as a Pydantic model field type::
class Event(UniNode):
when: Btic
def __init__(self, value: str | object) -> None:
    """Wrap *value*: a BTIC string, a native ``_PyBtic`` or another ``Btic``.

    Raises:
        ImportError: If ``uni_db`` (and therefore ``_PyBtic``) is unavailable.
        TypeError: If *value* has any other type.
    """
    native = _PyBtic
    if native is None:
        raise ImportError("uni_db is required for Btic type")
    if isinstance(value, str):
        inner = native(value)
    elif isinstance(value, native):
        inner = value
    elif isinstance(value, Btic):
        inner = value._inner
    else:
        raise TypeError(f"Expected str or Btic, got {type(value)}")
    self._inner = inner
@property
def lo(self) -> int:
    """Lower interval bound, as milliseconds since epoch."""
    inner = self._inner
    return inner.lo
Lower bound in milliseconds since epoch.
@property
def hi(self) -> int:
    """Upper interval bound, as milliseconds since epoch."""
    inner = self._inner
    return inner.hi
Upper bound in milliseconds since epoch.
@property
def meta(self) -> int:
    """The raw 64-bit metadata word of the wrapped value."""
    inner = self._inner
    return inner.meta
Raw 64-bit metadata word.
@property
def lo_granularity(self) -> str:
    """Name of the lower bound's granularity."""
    inner = self._inner
    return inner.lo_granularity
Lower bound granularity name.
@property
def hi_granularity(self) -> str:
    """Name of the upper bound's granularity."""
    inner = self._inner
    return inner.hi_granularity
Upper bound granularity name.
@property
def lo_certainty(self) -> str:
    """Name of the lower bound's certainty."""
    inner = self._inner
    return inner.lo_certainty
Lower bound certainty name.
@property
def hi_certainty(self) -> str:
    """Name of the upper bound's certainty."""
    inner = self._inner
    return inner.hi_certainty
Upper bound certainty name.
213 @property 214 def duration_ms(self) -> int | None: 215 """Duration in milliseconds, or None if unbounded.""" 216 return self._inner.duration_ms
Duration in milliseconds, or None if unbounded.
@property
def is_instant(self) -> bool:
    """Whether the interval spans exactly one millisecond."""
    inner = self._inner
    return inner.is_instant
True if the interval is exactly 1 millisecond wide.
class Vector(Generic[N], metaclass=VectorMeta):
    """
    A vector type with fixed dimensions for embeddings.

    Usage:
        embedding: Vector[1536]  # 1536-dimensional vector

    Instances wrap a ``list[float]`` (see :attr:`values`). ``__dimensions__``
    holds the expected length; the base class value 0 means "unconstrained".
    """

    __dimensions__: int = 0
    __origin__: type | None = None

    def __init__(self, values: list[float]) -> None:
        expected = self.__class__.__dimensions__
        if expected > 0 and len(values) != expected:
            raise ValueError(f"Vector expects {expected} dimensions, got {len(values)}")
        self._values = values

    @property
    def values(self) -> list[float]:
        return self._values

    def __repr__(self) -> str:
        dims = self.__class__.__dimensions__
        return (
            f"Vector[{dims}]({self._values[:3]}...)"
            if len(self._values) > 3
            else f"Vector[{dims}]({self._values})"
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Vector):
            return self._values == other._values
        if isinstance(other, list):
            return self._values == other
        # Defer to the other operand instead of hard-coding False.
        return NotImplemented

    def __len__(self) -> int:
        return len(self._values)

    def __iter__(self):  # type: ignore[no-untyped-def]
        return iter(self._values)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Make Vector compatible with Pydantic v2.

        Accepts ``Vector`` instances and lists/tuples of numbers, checking the
        dimension count when the annotated type fixes one. All rejections
        raise ``ValueError`` so Pydantic reports them as a ``ValidationError``.
        A ``TypeError`` would escape validation unwrapped.
        """
        dimensions = getattr(source_type, "__dimensions__", 0)
        vec_cls = source_type if dimensions > 0 else cls

        def validate_vector(v: Any) -> Vector:  # type: ignore[type-arg]
            if isinstance(v, Vector):
                if dimensions > 0 and len(v) != dimensions:
                    raise ValueError(
                        f"Vector expects {dimensions} dimensions, got {len(v)}"
                    )
                return v
            if isinstance(v, (list, tuple)):
                if dimensions > 0 and len(v) != dimensions:
                    raise ValueError(
                        f"Vector expects {dimensions} dimensions, got {len(v)}"
                    )
                try:
                    floats = [float(x) for x in v]
                except TypeError as exc:
                    # e.g. None or a nested list among the elements
                    raise ValueError(f"Vector elements must be numeric: {exc}") from exc
                return vec_cls(floats)
            raise ValueError(f"Expected list or Vector, got {type(v)}")

        return core_schema.no_info_plain_validator_function(
            validate_vector,
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda v: v.values if isinstance(v, Vector) else list(v),
                info_arg=False,
            ),
        )
A vector type with fixed dimensions for embeddings.
Usage: embedding: Vector[1536] # 1536-dimensional vector
At runtime, vectors are stored as list[float].
def python_type_to_uni(type_hint: Any, *, nullable: bool = False) -> tuple[str, bool]:
    """
    Map a Python type hint to a Uni DataType string.

    Args:
        type_hint: The Python type hint to convert (``Annotated`` is unwrapped).
        nullable: Nullability to report for non-Optional hints.

    Returns:
        Tuple of (uni_data_type, is_nullable). ``T | None`` always reports
        nullable=True with the data type of ``T``.

    Raises:
        TypeMappingError: If the type cannot be mapped, including unresolved
            string forward references.
    """
    hint, _metadata = unwrap_annotated(type_hint)

    optional, inner = is_optional(hint)
    if optional:
        return python_type_to_uni(inner)[0], True

    vector_dims = get_vector_dimensions(hint)
    if vector_dims is not None:
        return f"vector:{vector_dims}", nullable

    is_list, element = is_list_type(hint)
    if is_list:
        # Only lists of primitive scalars get a native list type; anything
        # else is stored as JSON.
        if element not in (str, int, float, bool):
            return "json", nullable
        return f"list:{TYPE_MAP.get(element, 'string')}", nullable

    if hint in TYPE_MAP:
        return TYPE_MAP[hint], nullable

    if get_origin(hint) is dict:
        return "json", nullable

    if isinstance(hint, str):
        raise TypeMappingError(
            hint,
            f"Cannot resolve forward reference {hint!r}. "
            "Ensure the referenced class is defined before schema sync.",
        )

    raise TypeMappingError(hint)
Convert a Python type hint to a Uni DataType string.
Args: type_hint: The Python type hint to convert. nullable: Whether the field is explicitly nullable.
Returns: Tuple of (uni_data_type, is_nullable)
Raises: TypeMappingError: If the type cannot be mapped.
def uni_to_python_type(uni_type: str) -> type:
    """
    Convert a Uni DataType string to a Python type.

    Matching is case-insensitive for both container prefixes and scalar names.

    Args:
        uni_type: The Uni data type string (e.g. "string", "vector:768",
            "list:int64").

    Returns:
        The corresponding Python type; ``list`` for vector and list types,
        ``str`` for unknown names.
    """
    normalized = uni_type.lower()

    # Vectors and lists are both surfaced as plain lists (vectors as list[float]).
    if normalized.startswith(("vector:", "list:")):
        return list

    # Built by hand (not inverted from the forward map) so that "string"
    # stays str instead of being overwritten by bytes.
    reverse_map: dict[str, type] = {
        "string": str,
        "int64": int,
        "float64": float,
        "bool": bool,
        "datetime": datetime,
        "date": date,
        "time": time,
        "duration": timedelta,
        "json": dict,
        "btic": Btic,
    }
    return reverse_map.get(normalized, str)
Convert a Uni DataType string to a Python type.
Args: uni_type: The Uni data type string.
Returns: The corresponding Python type.
137def get_vector_dimensions(type_hint: Any) -> int | None: 138 """Extract vector dimensions from a Vector[N] type hint.""" 139 if hasattr(type_hint, "__dimensions__"): 140 dims: int = type_hint.__dimensions__ 141 return dims 142 origin = get_origin(type_hint) 143 if origin is Vector: 144 args = get_args(type_hint) 145 if args and isinstance(args[0], int): 146 return args[0] 147 return None
Extract vector dimensions from a Vector[N] type hint.
271def is_optional(type_hint: Any) -> tuple[bool, Any]: 272 """ 273 Check if a type hint is Optional (T | None). 274 275 Returns: 276 Tuple of (is_optional, inner_type) 277 """ 278 origin = get_origin(type_hint) 279 280 # Handle Union types (including T | None which is Union[T, None]) 281 if origin is Union: 282 args = get_args(type_hint) 283 non_none_args = [arg for arg in args if arg is not type(None)] 284 if len(non_none_args) == 1 and type(None) in args: 285 return True, non_none_args[0] 286 287 # Python 3.10+ uses types.UnionType for X | Y syntax 288 if isinstance(type_hint, types.UnionType): 289 args = get_args(type_hint) 290 non_none_args = [arg for arg in args if arg is not type(None)] 291 if len(non_none_args) == 1 and type(None) in args: 292 return True, non_none_args[0] 293 294 return False, type_hint
Check if a type hint is Optional (T | None).
Returns: Tuple of (is_optional, inner_type)
297def is_list_type(type_hint: Any) -> tuple[bool, Any | None]: 298 """ 299 Check if a type hint is a list type. 300 301 Returns: 302 Tuple of (is_list, element_type) 303 """ 304 origin = get_origin(type_hint) 305 if origin is list: 306 args = get_args(type_hint) 307 return True, args[0] if args else Any 308 return False, None
Check if a type hint is a list type.
Returns: Tuple of (is_list, element_type)
def unwrap_annotated(type_hint: Any) -> tuple[Any, tuple[Any, ...]]:
    """
    Split ``Annotated[T, *meta]`` into its parts.

    Returns:
        ``(T, meta)`` for an Annotated hint, otherwise ``(type_hint, ())``.
    """
    if get_origin(type_hint) is not Annotated:
        return type_hint, ()
    base, *metadata = get_args(type_hint)
    return base, tuple(metadata)
Unwrap an Annotated type.
Returns: Tuple of (base_type, metadata_tuple)
329def python_to_db_value(value: Any, type_hint: Any) -> Any: 330 """Convert a Python value to a database-compatible value. 331 332 Passes datetime/date/time/timedelta through to the Rust layer which 333 converts them to proper Value::Temporal types. Converts Vector to 334 list[float] and passes through everything else. 335 """ 336 if value is None: 337 return None 338 339 # Vector → list[float] 340 if isinstance(value, Vector): 341 return value.values 342 343 # Btic → unwrap to the Rust PyBtic for py_object_to_value 344 if isinstance(value, Btic): 345 return value._inner 346 347 # datetime/date/time/timedelta pass through — the Rust py_object_to_value 348 # handles conversion to Value::Temporal with proper type information. 349 return value
Convert a Python value to a database-compatible value.
Passes datetime/date/time/timedelta through to the Rust layer which converts them to proper Value::Temporal types. Converts Vector to list[float] and passes through everything else.
def db_to_python_value(value: Any, type_hint: Any) -> Any:
    """Convert a database value back to a Python value.

    The Rust layer returns proper Python datetime/date/time objects via
    Value::Temporal, so in most cases values pass through directly.
    Additional conversions:

    * native ``_PyBtic`` values are wrapped in ``Btic`` for ``Btic`` fields;
    * a ``{"nanos_since_epoch": ...}`` struct becomes a (local, naive)
      ``datetime`` for ``datetime`` fields, truncated to microseconds;
    * a list becomes ``Vector[N]`` for vector fields.

    ``Annotated`` and ``Optional`` wrappers are stripped in either nesting
    order before the field type is inspected.
    """
    if value is None:
        return None

    # Handle both Annotated[T | None, ...] and Annotated[T, ...] | None.
    type_hint, _ = unwrap_annotated(type_hint)
    _, type_hint = is_optional(type_hint)
    type_hint, _ = unwrap_annotated(type_hint)

    # Temporal values already converted by the Rust layer pass through.
    for temporal in (datetime, date, time, timedelta):
        if type_hint is temporal and isinstance(value, temporal):
            return value

    # Btic — wrap Rust PyBtic in the pydantic Btic wrapper
    if type_hint is Btic and _PyBtic is not None and isinstance(value, _PyBtic):
        return Btic(value)

    # Struct dict from Arrow deserialization (e.g. datetime struct)
    if type_hint is datetime and isinstance(value, dict):
        nanos = value.get("nanos_since_epoch")
        if nanos is None:
            return None
        # Integer arithmetic avoids float rounding of large nanosecond counts;
        # divmod floors, so negative (pre-1970) values stay correct.
        seconds, remainder_ns = divmod(int(nanos), 1_000_000_000)
        return datetime.fromtimestamp(seconds).replace(
            microsecond=remainder_ns // 1_000
        )

    # Vector fields: list[float] → Vector
    dims = get_vector_dimensions(type_hint)
    if dims is not None and isinstance(value, list):
        return Vector[dims](value)

    return value
Convert a database value back to a Python value.
The Rust layer now returns proper Python datetime/date/time objects via Value::Temporal, so in most cases values pass through directly.
class QueryBuilder(_QueryBuilderBase[NodeT]):
    """
    Immutable, type-safe query builder for graph queries.

    Chaining methods return a **new** QueryBuilder instance and never mutate
    the original. Terminal methods (``all``, ``first``, ``one``, ``count``,
    ``exists``, ``delete``, ``update``) execute the generated Cypher.

    Example:
        >>> adults = (
        ...     session.query(Person)
        ...     .filter(Person.age >= 18)
        ...     .order_by(Person.name)
        ...     .limit(10)
        ...     .all()
        ... )
    """

    def __init__(self, session: UniSession, model: type[NodeT]) -> None:
        self._init_state(session, model)

    def _execute_query(
        self, cypher: str, params: dict[str, Any]
    ) -> list[dict[str, Any]]:
        """Run *cypher*, switching to ``query_with`` when a timeout or memory cap is set."""
        db_session = self._session._db_session
        limits_requested = self._timeout is not None or self._max_memory is not None
        if not limits_requested:
            rows = db_session.query(cypher, params)
        else:
            builder = db_session.query_with(cypher)
            if params:
                builder = builder.params(params)
            if self._timeout is not None:
                builder = builder.timeout(self._timeout)
            if self._max_memory is not None:
                builder = builder.max_memory(self._max_memory)
            rows = builder.fetch_all()
        return [row.to_dict() for row in rows]

    def all(self) -> list[NodeT]:
        """Execute the query and return every matching instance."""
        cypher, params = self._build_cypher()
        instances = self._rows_to_instances(self._execute_query(cypher, params))
        if self._eager_load and instances:
            self._session._eager_load_relationships(instances, self._eager_load)
        return instances

    def first(self) -> NodeT | None:
        """Return the first match, or None when nothing matches."""
        limited = self._clone()
        limited._limit = 1
        matches = limited.all()
        if not matches:
            return None
        return matches[0]

    def one(self) -> NodeT:
        """Return the single match.

        Raises QueryError if no results or more than one result.
        """
        probe = self._clone()
        # One extra row is enough to detect an ambiguous match.
        probe._limit = 2
        matches = probe.all()
        if len(matches) == 0:
            raise QueryError("Query returned no results")
        if len(matches) > 1:
            raise QueryError("Query returned more than one result")
        return matches[0]

    def count(self) -> int:
        """Return the number of matching records."""
        cypher, params = self._build_count_cypher()
        rows = self._execute_query(cypher, params)
        if not rows:
            return 0
        return cast(int, rows[0]["count"])

    def exists(self) -> bool:
        """Return True if at least one record matches."""
        cypher, params = self._build_exists_cypher()
        return bool(self._execute_query(cypher, params))

    def _run_counting_write(self, cypher: str, params: dict[str, Any]) -> int:
        """Run a write in its own committed transaction and return its ``count`` column."""
        with self._session._db_session.tx() as tx:
            results = tx.query(cypher, params)
            tx.commit()
        if not results:
            return 0
        return results[0].to_dict()["count"]

    def delete(self) -> int:
        """Delete all matching records (DETACH DELETE)."""
        cypher, params = self._build_delete_cypher()
        return self._run_counting_write(cypher, params)

    def update(self, **kwargs: Any) -> int:
        """Update all matching records."""
        cypher, params = self._build_update_cypher(**kwargs)
        return self._run_counting_write(cypher, params)
Immutable, type-safe query builder for graph queries.
Each method returns a new QueryBuilder instance. The original is never mutated. Provides a fluent API for building Cypher queries with type checking and IDE autocomplete support.
Example:
adults = session.query(Person).filter(Person.age >= 18).order_by(Person.name).limit(10).all()
def all(self) -> list[NodeT]:
    """Run the built query and return every matching instance."""
    query_text, query_params = self._build_cypher()
    rows = self._execute_query(query_text, query_params)
    found = self._rows_to_instances(rows)
    # Only invoke the eager loader when it is configured and there is
    # something to attach relationships to.
    if self._eager_load and found:
        self._session._eager_load_relationships(found, self._eager_load)
    return found
Execute the query and return all results.
def first(self) -> NodeT | None:
    """Return the first matching instance, or ``None`` when nothing matches."""
    limited = self._clone()
    limited._limit = 1
    matches = limited.all()
    if not matches:
        return None
    return matches[0]
Execute the query and return the first result.
def one(self) -> NodeT:
    """Return the single matching instance.

    Raises:
        QueryError: if the query matches no records, or more than one.
    """
    probe = self._clone()
    # Fetching two rows is enough to tell "exactly one" from "several".
    probe._limit = 2
    matches = probe.all()
    if len(matches) == 0:
        raise QueryError("Query returned no results")
    if len(matches) > 1:
        raise QueryError("Query returned more than one result")
    return matches[0]
Execute the query and return exactly one result.
Raises QueryError if no results or more than one result.
def count(self) -> int:
    """Return how many records the query matches (0 if no row comes back)."""
    cypher_text, bound = self._build_count_cypher()
    rows = self._execute_query(cypher_text, bound)
    if not rows:
        return 0
    return cast(int, rows[0]["count"])
Execute the query and return the count of results.
def exists(self) -> bool:
    """Return ``True`` when the existence query yields at least one row."""
    cypher_text, bound = self._build_exists_cypher()
    return bool(self._execute_query(cypher_text, bound))
Check if any matching records exist.
def delete(self) -> int:
    """Detach-delete every matching record and return the reported count."""
    statement, bound = self._build_delete_cypher()
    with self._session._db_session.tx() as transaction:
        rows = transaction.query(statement, bound)
        transaction.commit()
    if not rows:
        return 0
    return rows[0].to_dict()["count"]
Delete all matching records (DETACH DELETE).
def update(self, **kwargs: Any) -> int:
    """Set the given properties on every matching record; return the count."""
    statement, bound = self._build_update_cypher(**kwargs)
    with self._session._db_session.tx() as transaction:
        rows = transaction.query(statement, bound)
        transaction.commit()
    if not rows:
        return 0
    return rows[0].to_dict()["count"]
Update all matching records.
class AsyncQueryBuilder(_QueryBuilderBase[NodeT]):
    """
    Immutable, async query builder for graph queries.

    Inherits all Cypher-building and immutable builder methods from
    ``_QueryBuilderBase``. Only the execution methods are async.
    """

    def __init__(self, session: AsyncUniSession, model: type[NodeT]) -> None:
        self._init_state(session, model)

    async def _execute_query(
        self, cypher: str, params: dict[str, Any]
    ) -> list[dict[str, Any]]:
        """Execute a read query and return the rows as plain dicts.

        When a timeout or memory cap is configured, the query goes through
        ``query_with`` so those limits are forwarded. Otherwise the plain
        ``query`` call is used.
        """
        db_session = self._session._db_session
        if self._timeout is not None or self._max_memory is not None:
            builder = db_session.query_with(cypher)
            if params:
                builder = builder.params(params)
            if self._timeout is not None:
                builder = builder.timeout(self._timeout)
            if self._max_memory is not None:
                builder = builder.max_memory(self._max_memory)
            result = await builder.fetch_all()
        else:
            result = await db_session.query(cypher, params)
        return [row.to_dict() for row in result]

    async def _execute_write(self, cypher: str, params: dict[str, Any]) -> int:
        """Run a mutating query in its own transaction and return its count.

        The transaction is obtained by awaiting ``_db_session.tx()`` and is
        committed explicitly. The ``count`` column of the first returned row
        is reported; an empty result yields ``0``.
        """
        async with await self._session._db_session.tx() as tx:
            results = await tx.query(cypher, params)
            await tx.commit()
        if not results:
            return 0
        return cast(int, results[0].to_dict()["count"])

    async def all(self) -> list[NodeT]:
        """Execute the query and return all results (with eager loading)."""
        cypher, params = self._build_cypher()
        results = await self._execute_query(cypher, params)
        instances = self._rows_to_instances(results)
        if self._eager_load and instances:
            await self._session._async_eager_load_relationships(
                instances, self._eager_load
            )
        return instances

    async def first(self) -> NodeT | None:
        """Execute the query and return the first result, or ``None``."""
        clone = self._clone()
        clone._limit = 1
        results = await clone.all()
        return results[0] if results else None

    async def one(self) -> NodeT:
        """Execute the query and return exactly one result.

        Raises:
            QueryError: if there are no results or more than one result.
        """
        clone = self._clone()
        # Two rows are enough to distinguish "exactly one" from "more than one".
        clone._limit = 2
        results = await clone.all()
        if not results:
            raise QueryError("Query returned no results")
        if len(results) > 1:
            raise QueryError("Query returned more than one result")
        return results[0]

    async def count(self) -> int:
        """Execute the query and return the count of results."""
        cypher, params = self._build_count_cypher()
        results = await self._execute_query(cypher, params)
        return cast(int, results[0]["count"]) if results else 0

    async def exists(self) -> bool:
        """Check if any matching records exist."""
        cypher, params = self._build_exists_cypher()
        results = await self._execute_query(cypher, params)
        return len(results) > 0

    async def delete(self) -> int:
        """Delete all matching records (DETACH DELETE).

        Returns:
            The ``count`` reported by the delete query (0 if no row is returned).
        """
        cypher, params = self._build_delete_cypher()
        return await self._execute_write(cypher, params)

    async def update(self, **kwargs: Any) -> int:
        """Update all matching records with the given property values.

        Returns:
            The ``count`` reported by the update query (0 if no row is returned).
        """
        cypher, params = self._build_update_cypher(**kwargs)
        return await self._execute_write(cypher, params)
Immutable, async query builder for graph queries.
Inherits all Cypher-building and immutable builder methods from
_QueryBuilderBase. Only the execution methods are async.
async def all(self) -> list[NodeT]:
    """Run the built query asynchronously and return every matching instance."""
    query_text, query_params = self._build_cypher()
    rows = await self._execute_query(query_text, query_params)
    found = self._rows_to_instances(rows)
    # Eager loading is skipped when not configured or when nothing matched.
    if self._eager_load and found:
        await self._session._async_eager_load_relationships(found, self._eager_load)
    return found
Execute the query and return all results.
async def first(self) -> NodeT | None:
    """Return the first matching instance, or ``None`` when nothing matches."""
    limited = self._clone()
    limited._limit = 1
    matches = await limited.all()
    if not matches:
        return None
    return matches[0]
Execute the query and return the first result.
async def one(self) -> NodeT:
    """Return the single matching instance.

    Raises:
        QueryError: if the query matches no records, or more than one.
    """
    probe = self._clone()
    # Fetching two rows is enough to tell "exactly one" from "several".
    probe._limit = 2
    matches = await probe.all()
    if len(matches) == 0:
        raise QueryError("Query returned no results")
    if len(matches) > 1:
        raise QueryError("Query returned more than one result")
    return matches[0]
Execute the query and return exactly one result.
Raises QueryError if no results or more than one result.
async def count(self) -> int:
    """Return how many records the query matches (0 if no row comes back)."""
    cypher_text, bound = self._build_count_cypher()
    rows = await self._execute_query(cypher_text, bound)
    if not rows:
        return 0
    return cast(int, rows[0]["count"])
Execute the query and return the count of results.
async def exists(self) -> bool:
    """Return ``True`` when the existence query yields at least one row."""
    cypher_text, bound = self._build_exists_cypher()
    rows = await self._execute_query(cypher_text, bound)
    return bool(rows)
Check if any matching records exist.
async def delete(self) -> int:
    """Detach-delete every matching record and return the reported count."""
    statement, bound = self._build_delete_cypher()
    async with await self._session._db_session.tx() as transaction:
        rows = await transaction.query(statement, bound)
        await transaction.commit()
    if not rows:
        return 0
    return rows[0].to_dict()["count"]
Delete all matching records (DETACH DELETE).
async def update(self, **kwargs: Any) -> int:
    """Set the given properties on every matching record; return the count."""
    statement, bound = self._build_update_cypher(**kwargs)
    async with await self._session._db_session.tx() as transaction:
        rows = await transaction.query(statement, bound)
        await transaction.commit()
    if not rows:
        return 0
    return rows[0].to_dict()["count"]
Update all matching records.
@dataclass
class FilterExpr:
    """A filter expression for a query: one predicate on one node property."""

    property_name: str  # property on the matched node
    op: FilterOp  # comparison operator
    value: Any = None  # right-hand operand; unused for the IS (NOT) NULL ops

    def to_cypher(self, node_var: str, param_name: str) -> tuple[str, dict[str, Any]]:
        """Convert to a Cypher WHERE clause fragment.

        Args:
            node_var: Cypher variable bound to the node (e.g. ``"n"``).
            param_name: Name of the query parameter that carries ``value``.

        Returns:
            ``(fragment, params)``, where ``params`` maps ``param_name`` to
            ``value``. It is empty for operators that take no value.

        Equality and inequality against ``None`` are rendered as ``IS NULL``
        and ``IS NOT NULL``. In Cypher, ``x = null`` evaluates to ``null`` and
        would silently filter out every row.
        """
        prop = f"{node_var}.{self.property_name}"
        op = self.op

        if self.value is None:
            if op == FilterOp.EQ:
                op = FilterOp.IS_NULL
            elif op == FilterOp.NE:
                op = FilterOp.IS_NOT_NULL

        if op == FilterOp.IS_NULL:
            return f"{prop} IS NULL", {}
        if op == FilterOp.IS_NOT_NULL:
            return f"{prop} IS NOT NULL", {}
        if op == FilterOp.IN:
            return f"{prop} IN ${param_name}", {param_name: self.value}
        if op == FilterOp.NOT_IN:
            return f"NOT {prop} IN ${param_name}", {param_name: self.value}
        if op == FilterOp.LIKE:
            # LIKE is a regular-expression match in Cypher (=~).
            return f"{prop} =~ ${param_name}", {param_name: self.value}
        if op == FilterOp.STARTS_WITH:
            return f"{prop} STARTS WITH ${param_name}", {param_name: self.value}
        if op == FilterOp.ENDS_WITH:
            return f"{prop} ENDS WITH ${param_name}", {param_name: self.value}
        if op == FilterOp.CONTAINS:
            return f"{prop} CONTAINS ${param_name}", {param_name: self.value}
        # Plain comparison operators: the enum value is the Cypher operator.
        return f"{prop} {op.value} ${param_name}", {param_name: self.value}
A filter expression for a query.
def to_cypher(self, node_var: str, param_name: str) -> tuple[str, dict[str, Any]]:
    """Convert to a Cypher WHERE clause fragment.

    Args:
        node_var: Cypher variable bound to the node (e.g. ``"n"``).
        param_name: Name of the query parameter that carries ``value``.

    Returns:
        ``(fragment, params)``, where ``params`` maps ``param_name`` to
        ``value``. It is empty for operators that take no value.

    Equality and inequality against ``None`` are rendered as ``IS NULL`` and
    ``IS NOT NULL``. In Cypher, ``x = null`` evaluates to ``null`` and would
    silently filter out every row.
    """
    prop = f"{node_var}.{self.property_name}"
    op = self.op

    if self.value is None:
        if op == FilterOp.EQ:
            op = FilterOp.IS_NULL
        elif op == FilterOp.NE:
            op = FilterOp.IS_NOT_NULL

    if op == FilterOp.IS_NULL:
        return f"{prop} IS NULL", {}
    if op == FilterOp.IS_NOT_NULL:
        return f"{prop} IS NOT NULL", {}
    if op == FilterOp.IN:
        return f"{prop} IN ${param_name}", {param_name: self.value}
    if op == FilterOp.NOT_IN:
        return f"NOT {prop} IN ${param_name}", {param_name: self.value}
    if op == FilterOp.LIKE:
        # LIKE is a regular-expression match in Cypher (=~).
        return f"{prop} =~ ${param_name}", {param_name: self.value}
    if op == FilterOp.STARTS_WITH:
        return f"{prop} STARTS WITH ${param_name}", {param_name: self.value}
    if op == FilterOp.ENDS_WITH:
        return f"{prop} ENDS WITH ${param_name}", {param_name: self.value}
    if op == FilterOp.CONTAINS:
        return f"{prop} CONTAINS ${param_name}", {param_name: self.value}
    # Plain comparison operators: the enum value is the Cypher operator.
    return f"{prop} {op.value} ${param_name}", {param_name: self.value}
Convert to Cypher WHERE clause fragment.
class FilterOp(Enum):
    """Filter operation types.

    Each member's value is the Cypher operator text for that comparison.
    """

    EQ = "="
    NE = "<>"
    LT = "<"
    LE = "<="
    GT = ">"
    GE = ">="
    IN = "IN"
    NOT_IN = "NOT IN"
    # NOTE: this is Cypher's regular-expression match (=~), not SQL-style
    # LIKE with % / _ wildcards.
    LIKE = "=~"
    IS_NULL = "IS NULL"  # takes no comparison value
    IS_NOT_NULL = "IS NOT NULL"  # takes no comparison value
    STARTS_WITH = "STARTS WITH"
    ENDS_WITH = "ENDS WITH"
    CONTAINS = "CONTAINS"
Filter operation types.
class PropertyProxy(Generic[T]):
    """
    Stand-in for a model attribute that turns Python operators into filters.

    Comparison operators and the helper methods below each return a
    ``FilterExpr`` bound to this property, for use with ``query.filter``.

    Example:
        >>> query.filter(Person.age >= 18)
        >>> query.filter(Person.name.starts_with("A"))
    """

    def __init__(self, property_name: str, model: type[UniNode]) -> None:
        self._property_name = property_name
        self._model = model

    def _expr(self, op: FilterOp, value: Any = None) -> FilterExpr:
        """Build a FilterExpr on this property with the given operator/value."""
        return FilterExpr(self._property_name, op, value)

    # -- comparison operators -------------------------------------------------

    def __eq__(self, other: Any) -> FilterExpr:  # type: ignore[override]
        return self._expr(FilterOp.EQ, other)

    def __ne__(self, other: Any) -> FilterExpr:  # type: ignore[override]
        return self._expr(FilterOp.NE, other)

    def __lt__(self, other: Any) -> FilterExpr:
        return self._expr(FilterOp.LT, other)

    def __le__(self, other: Any) -> FilterExpr:
        return self._expr(FilterOp.LE, other)

    def __gt__(self, other: Any) -> FilterExpr:
        return self._expr(FilterOp.GT, other)

    def __ge__(self, other: Any) -> FilterExpr:
        return self._expr(FilterOp.GE, other)

    # -- membership -----------------------------------------------------------

    def in_(self, values: Sequence[T]) -> FilterExpr:
        """Match when the property equals one of *values*."""
        return self._expr(FilterOp.IN, list(values))

    def not_in(self, values: Sequence[T]) -> FilterExpr:
        """Match when the property equals none of *values*."""
        return self._expr(FilterOp.NOT_IN, list(values))

    # -- null checks ----------------------------------------------------------

    def is_null(self) -> FilterExpr:
        """Match when the property is null."""
        return self._expr(FilterOp.IS_NULL)

    def is_not_null(self) -> FilterExpr:
        """Match when the property is not null."""
        return self._expr(FilterOp.IS_NOT_NULL)

    # -- string matching ------------------------------------------------------

    def like(self, pattern: str) -> FilterExpr:
        """Match the property against a regular-expression *pattern*."""
        return self._expr(FilterOp.LIKE, pattern)

    def starts_with(self, prefix: str) -> FilterExpr:
        """Match when the string property begins with *prefix*."""
        return self._expr(FilterOp.STARTS_WITH, prefix)

    def ends_with(self, suffix: str) -> FilterExpr:
        """Match when the string property ends with *suffix*."""
        return self._expr(FilterOp.ENDS_WITH, suffix)

    def contains(self, substring: str) -> FilterExpr:
        """Match when the string property contains *substring*."""
        return self._expr(FilterOp.CONTAINS, substring)
Proxy for model properties that enables filter expressions.
Used in query builder to create type-safe filter conditions.
Example:
query.filter(Person.age >= 18)
query.filter(Person.name.starts_with("A"))
def in_(self, values: Sequence[T]) -> FilterExpr:
    """Build an ``IN`` filter: the property must equal one of *values*."""
    candidates = list(values)
    return FilterExpr(
        property_name=self._property_name, op=FilterOp.IN, value=candidates
    )
Check if value is in a list.
def not_in(self, values: Sequence[T]) -> FilterExpr:
    """Build a ``NOT IN`` filter: the property must equal none of *values*."""
    excluded = list(values)
    return FilterExpr(
        property_name=self._property_name, op=FilterOp.NOT_IN, value=excluded
    )
Check if value is not in a list.
def like(self, pattern: str) -> FilterExpr:
    """Build a regular-expression match filter (Cypher ``=~``)."""
    return FilterExpr(
        property_name=self._property_name, op=FilterOp.LIKE, value=pattern
    )
Match a regex pattern.
def is_null(self) -> FilterExpr:
    """Build a filter that matches when the property is null."""
    return FilterExpr(
        property_name=self._property_name, op=FilterOp.IS_NULL, value=None
    )
Check if value is null.
def is_not_null(self) -> FilterExpr:
    """Build a filter that matches when the property is not null."""
    return FilterExpr(
        property_name=self._property_name, op=FilterOp.IS_NOT_NULL, value=None
    )
Check if value is not null.
def starts_with(self, prefix: str) -> FilterExpr:
    """Build a ``STARTS WITH`` filter for a string property."""
    return FilterExpr(
        property_name=self._property_name, op=FilterOp.STARTS_WITH, value=prefix
    )
Check if string starts with prefix.
def ends_with(self, suffix: str) -> FilterExpr:
    """Build an ``ENDS WITH`` filter for a string property."""
    return FilterExpr(
        property_name=self._property_name, op=FilterOp.ENDS_WITH, value=suffix
    )
Check if string ends with suffix.
class ModelProxy(Generic[NodeT]):
    """
    Wraps a model class so that attribute access yields property proxies.

    Any public attribute name produces a ``PropertyProxy`` for that name.
    Names starting with an underscore are not proxied and raise
    ``AttributeError``.

    Example:
        >>> Person.name  # Returns PropertyProxy for 'name'
        >>> query.filter(Person.age >= 18)
    """

    def __init__(self, model: type[NodeT]) -> None:
        self._model = model

    def __getattr__(self, name: str) -> PropertyProxy[Any]:
        # Private and dunder names fall through to a normal AttributeError.
        if not name.startswith("_"):
            return PropertyProxy(name, self._model)
        raise AttributeError(name)
Proxy for model classes that provides property proxies.
Enables type-safe property access in query filters.
Example:
Person.name  # Returns a PropertyProxy for 'name'
query.filter(Person.age >= 18)
@dataclass
class OrderByClause:
    """An ORDER BY clause: one property plus a sort direction."""

    property_name: str  # model property to sort on
    descending: bool = False  # True requests descending order; False (default) does not
An ORDER BY clause.
@dataclass
class TraversalStep:
    """A relationship traversal step."""

    edge_type: str  # relationship type to follow
    direction: Literal["outgoing", "incoming", "both"]  # which way(s) to follow the edge
    # Optional label for the node reached by this step; presumably None means
    # "any label" — confirm against the query builder that consumes it.
    target_label: str | None = None
A relationship traversal step.
@dataclass
class VectorSearchConfig:
    """Configuration for vector similarity search."""

    property_name: str  # vector-typed property to search against
    query_vector: list[float]  # the query embedding
    k: int  # presumably the number of nearest neighbours to return — TODO confirm
    # Optional similarity/distance cut-off; its exact meaning depends on the
    # consumer (not visible here) — NOTE(review): confirm.
    threshold: float | None = None
    # Optional filter applied before the search; the expected string format
    # is defined by the consumer — NOTE(review): confirm.
    pre_filter: str | None = None
Configuration for vector similarity search.
class SchemaGenerator:
    """Generates Uni database schema from registered models.

    Node models are keyed by ``__label__`` and edge models by
    ``__edge_type__``. The generated :class:`DatabaseSchema` is cached and
    invalidated whenever another model is registered.
    """

    def __init__(self) -> None:
        self._node_models: dict[str, type[UniNode]] = {}
        self._edge_models: dict[str, type[UniEdge]] = {}
        self._schema: DatabaseSchema | None = None

    # ------------------------------------------------------------------ #
    # Registration
    # ------------------------------------------------------------------ #

    def register_node(self, model: type[UniNode]) -> None:
        """Register a node model for schema generation.

        Raises:
            SchemaError: if the model has no ``__label__``.
        """
        label = model.__label__
        if not label:
            raise SchemaError(f"Model {model.__name__} has no __label__", model)
        self._node_models[label] = model
        self._schema = None  # Invalidate cached schema

    def register_edge(self, model: type[UniEdge]) -> None:
        """Register an edge model for schema generation.

        Raises:
            SchemaError: if the model has no ``__edge_type__``.
        """
        edge_type = model.__edge_type__
        if not edge_type:
            raise SchemaError(f"Model {model.__name__} has no __edge_type__", model)
        self._edge_models[edge_type] = model
        self._schema = None  # Invalidate cached schema

    def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
        """Register multiple models, dispatching on UniEdge / UniNode.

        Raises:
            SchemaError: if a model is neither a UniNode nor a UniEdge subclass.
        """
        for model in models:
            if issubclass(model, UniEdge):
                self.register_edge(model)
            elif issubclass(model, UniNode):
                self.register_node(model)
            else:
                raise SchemaError(
                    f"Model {model.__name__} must be a subclass of UniNode or UniEdge"
                )

    # ------------------------------------------------------------------ #
    # Schema generation
    # ------------------------------------------------------------------ #

    def _generate_property_schema(
        self,
        model: type[UniNode] | type[UniEdge],
        field_name: str,
    ) -> PropertySchema:
        """Generate schema for a single property field."""
        field_info = model.model_fields[field_name]

        # Prefer resolved type hints (forward refs evaluated); fall back to the
        # raw annotation if resolution fails for any reason.
        try:
            hints = get_type_hints(model)
            type_hint = hints.get(field_name, field_info.annotation)
        except Exception:
            type_hint = field_info.annotation

        is_nullable, inner_type = is_optional(type_hint)
        data_type, nullable = python_type_to_uni(type_hint, nullable=is_nullable)

        # Vector fields are encoded as "vector:<dims>".
        vec_dims = get_vector_dimensions(inner_type if is_nullable else type_hint)
        if vec_dims:
            data_type = f"vector:{vec_dims}"

        config = get_field_config(field_info)
        index_type = config.index if config else None
        unique = config.unique if config else False
        tokenizer = config.tokenizer if config else None
        metric = config.metric if config else None

        # Vector fields get a vector index unless the Field config already
        # chose an index type.
        if vec_dims and not index_type:
            index_type = "vector"

        return PropertySchema(
            name=field_name,
            data_type=data_type,
            nullable=nullable,
            index_type=index_type,
            unique=unique,
            tokenizer=tokenizer,
            metric=metric,
        )

    def _generate_label_schema(self, model: type[UniNode]) -> LabelSchema:
        """Generate schema for a node model."""
        properties = {
            field_name: self._generate_property_schema(model, field_name)
            for field_name in model.get_property_fields()
        }
        return LabelSchema(name=model.__label__, properties=properties)

    def _generate_edge_type_schema(self, model: type[UniEdge]) -> EdgeTypeSchema:
        """Generate schema for an edge model.

        If the model declares no from/to labels, every registered node label
        is allowed on that side.
        """
        from_labels = model.get_from_labels() or list(self._node_models.keys())
        to_labels = model.get_to_labels() or list(self._node_models.keys())

        properties = {
            field_name: self._generate_property_schema(model, field_name)
            for field_name in model.get_property_fields()
        }

        return EdgeTypeSchema(
            name=model.__edge_type__,
            from_labels=from_labels,
            to_labels=to_labels,
            properties=properties,
        )

    def generate(self) -> DatabaseSchema:
        """Generate the complete database schema (cached until re-registration)."""
        if self._schema is not None:
            return self._schema

        schema = DatabaseSchema()

        for label, model in self._node_models.items():
            schema.labels[label] = self._generate_label_schema(model)

        for edge_type_name, edge_model in self._edge_models.items():
            schema.edge_types[edge_type_name] = self._generate_edge_type_schema(
                edge_model
            )

        # Relationship fields on node models may reference edge types that
        # have no explicit UniEdge model: declare a property-less edge type
        # that may connect any registered labels.
        for model in self._node_models.values():
            for rel_config in model.get_relationship_fields().values():
                edge_type = rel_config.edge_type
                if edge_type not in schema.edge_types:
                    schema.edge_types[edge_type] = EdgeTypeSchema(
                        name=edge_type,
                        from_labels=list(self._node_models.keys()),
                        to_labels=list(self._node_models.keys()),
                    )

        self._schema = schema
        return schema

    # ------------------------------------------------------------------ #
    # Helpers shared by the sync and async apply paths
    # ------------------------------------------------------------------ #

    @staticmethod
    def _declare_label(builder, label: str, label_schema: LabelSchema):
        """Declare *label* and its properties on *builder*; return the builder."""
        lb = builder.label(label)
        for prop in label_schema.properties.values():
            if prop.data_type.startswith("vector:"):
                dims = int(prop.data_type.split(":")[1])
                lb = lb.vector(prop.name, dims)
            elif prop.nullable:
                lb = lb.property_nullable(prop.name, prop.data_type)
            else:
                lb = lb.property(prop.name, prop.data_type)

            # Only scalar indexes go here; vector/fulltext are created later.
            if prop.index_type in ("btree", "hash"):
                lb = lb.index(prop.name, prop.index_type)
        return lb.done()

    @staticmethod
    def _declare_edge_type(builder, edge_type: str, edge_schema: EdgeTypeSchema):
        """Declare *edge_type* and its properties on *builder*; return the builder."""
        eb = builder.edge_type(
            edge_type, edge_schema.from_labels, edge_schema.to_labels
        )
        for prop in edge_schema.properties.values():
            if prop.nullable:
                eb = eb.property_nullable(prop.name, prop.data_type)
            else:
                eb = eb.property(prop.name, prop.data_type)
        return eb.done()

    @staticmethod
    def _secondary_indexes(
        schema: DatabaseSchema,
    ) -> list[tuple[str, str, str | dict[str, str]]]:
        """List ``(label, property, index spec)`` for vector/fulltext indexes.

        Vector indexes default to the ``"l2"`` metric when none is configured.
        """
        specs: list[tuple[str, str, str | dict[str, str]]] = []
        for label, label_schema in schema.labels.items():
            for prop in label_schema.properties.values():
                if prop.index_type == "vector":
                    spec = {"type": "vector", "metric": prop.metric or "l2"}
                    specs.append((label, prop.name, spec))
                elif prop.index_type == "fulltext":
                    specs.append((label, prop.name, "fulltext"))
        return specs

    @staticmethod
    def _report_index_skip(label: str, prop_name: str, spec: object) -> None:
        """Log (at DEBUG, with traceback) an index creation that was skipped."""
        import logging

        logging.getLogger(__name__).debug(
            "Could not create %r index on %s.%s (it may already exist)",
            spec,
            label,
            prop_name,
            exc_info=True,
        )

    # ------------------------------------------------------------------ #
    # Applying to a database
    # ------------------------------------------------------------------ #

    def apply_to_database(self, db: uni_db.Uni) -> None:
        """Apply the generated schema to a database using SchemaBuilder.

        Uses db.schema() for atomic schema application with additive-only
        semantics. Labels and edge types that already exist are skipped
        entirely. Vector and fulltext indexes are then created one by one on
        a best-effort basis; failures are logged at DEBUG level, not raised.
        """
        schema = self.generate()

        builder = db.schema()
        has_changes = False

        for label, label_schema in schema.labels.items():
            if db.label_exists(label):
                continue  # Additive-only: skip existing labels
            builder = self._declare_label(builder, label, label_schema)
            has_changes = True

        for edge_type, edge_schema in schema.edge_types.items():
            if db.edge_type_exists(edge_type):
                continue  # Additive-only: skip existing edge types
            builder = self._declare_edge_type(builder, edge_type, edge_schema)
            has_changes = True

        if has_changes:
            builder.apply()

        for label, prop_name, spec in self._secondary_indexes(schema):
            try:
                db.schema().label(label).index(prop_name, spec).apply()
            except Exception:
                # Best-effort: most commonly the index already exists.
                self._report_index_skip(label, prop_name, spec)

    async def async_apply_to_database(self, db: uni_db.AsyncUni) -> None:
        """Apply the generated schema to an async database.

        Async variant of apply_to_database using AsyncSchemaBuilder, with
        the same additive-only and best-effort index semantics.
        """
        schema = self.generate()

        builder = db.schema()
        has_changes = False

        for label, label_schema in schema.labels.items():
            if await db.label_exists(label):
                continue  # Additive-only: skip existing labels
            builder = self._declare_label(builder, label, label_schema)
            has_changes = True

        for edge_type, edge_schema in schema.edge_types.items():
            if await db.edge_type_exists(edge_type):
                continue  # Additive-only: skip existing edge types
            builder = self._declare_edge_type(builder, edge_type, edge_schema)
            has_changes = True

        if has_changes:
            await builder.apply()

        for label, prop_name, spec in self._secondary_indexes(schema):
            try:
                await db.schema().label(label).index(prop_name, spec).apply()
            except Exception:
                # Best-effort: most commonly the index already exists.
                self._report_index_skip(label, prop_name, spec)
Generates Uni database schema from registered models.
def register_node(self, model: type[UniNode]) -> None:
    """Add *model* to the node registry under its ``__label__``.

    Raises:
        SchemaError: when the model defines no ``__label__``.
    """
    label = model.__label__
    if label:
        self._node_models[label] = model
        # Any previously generated schema no longer reflects the registry.
        self._schema = None
        return
    raise SchemaError(f"Model {model.__name__} has no __label__", model)
Register a node model for schema generation.
def register_edge(self, model: type[UniEdge]) -> None:
    """Add *model* to the edge registry under its ``__edge_type__``.

    Raises:
        SchemaError: when the model defines no ``__edge_type__``.
    """
    edge_type = model.__edge_type__
    if edge_type:
        self._edge_models[edge_type] = model
        # Any previously generated schema no longer reflects the registry.
        self._schema = None
        return
    raise SchemaError(f"Model {model.__name__} has no __edge_type__", model)
Register an edge model for schema generation.
def register(self, *models: type[UniNode] | type[UniEdge]) -> None:
    """Register each model as an edge or node, checking UniEdge first.

    Raises:
        SchemaError: for a model that subclasses neither UniEdge nor UniNode.
    """
    for candidate in models:
        if issubclass(candidate, UniEdge):
            self.register_edge(candidate)
            continue
        if issubclass(candidate, UniNode):
            self.register_node(candidate)
            continue
        raise SchemaError(
            f"Model {candidate.__name__} must be a subclass of UniNode or UniEdge"
        )
Register multiple models.
def generate(self) -> DatabaseSchema:
    """Generate the complete database schema.

    The result is cached on the generator; registering another model
    clears the cache.
    """
    if self._schema is not None:
        return self._schema

    schema = DatabaseSchema()

    for label, model in self._node_models.items():
        schema.labels[label] = self._generate_label_schema(model)

    for edge_type_name, edge_model in self._edge_models.items():
        schema.edge_types[edge_type_name] = self._generate_edge_type_schema(
            edge_model
        )

    # Relationship fields on node models may reference edge types that have
    # no explicit UniEdge model: declare a property-less edge type that may
    # connect any registered labels.
    for model in self._node_models.values():
        for rel_config in model.get_relationship_fields().values():
            edge_type = rel_config.edge_type
            if edge_type not in schema.edge_types:
                schema.edge_types[edge_type] = EdgeTypeSchema(
                    name=edge_type,
                    from_labels=list(self._node_models.keys()),
                    to_labels=list(self._node_models.keys()),
                )

    self._schema = schema
    return schema
Generate the complete database schema.
def apply_to_database(self, db: uni_db.Uni) -> None:
    """Apply the generated schema to a database using SchemaBuilder.

    Uses db.schema() for atomic schema application with additive-only
    semantics: labels and edge types that already exist are skipped
    entirely. Vector and fulltext indexes are then created one by one on a
    best-effort basis; failures are logged at DEBUG level rather than raised.
    """
    import logging

    logger = logging.getLogger(__name__)
    schema = self.generate()

    # Build the full schema using SchemaBuilder, skipping existing labels/edge types
    builder = db.schema()
    has_changes = False

    for label, label_schema in schema.labels.items():
        if db.label_exists(label):
            continue  # Additive-only: skip existing labels
        lb = builder.label(label)
        for prop in label_schema.properties.values():
            if prop.data_type.startswith("vector:"):
                dims = int(prop.data_type.split(":")[1])
                lb = lb.vector(prop.name, dims)
            elif prop.nullable:
                lb = lb.property_nullable(prop.name, prop.data_type)
            else:
                lb = lb.property(prop.name, prop.data_type)

            # Scalar indexes only; vector/fulltext indexes are created below.
            if prop.index_type in ("btree", "hash"):
                lb = lb.index(prop.name, prop.index_type)
        builder = lb.done()
        has_changes = True

    for edge_type, edge_schema in schema.edge_types.items():
        if db.edge_type_exists(edge_type):
            continue  # Additive-only: skip existing edge types
        eb = builder.edge_type(
            edge_type, edge_schema.from_labels, edge_schema.to_labels
        )
        for prop in edge_schema.properties.values():
            if prop.nullable:
                eb = eb.property_nullable(prop.name, prop.data_type)
            else:
                eb = eb.property(prop.name, prop.data_type)
        builder = eb.done()
        has_changes = True

    if has_changes:
        builder.apply()

    # Create vector and fulltext indexes individually (best-effort).
    for label, label_schema in schema.labels.items():
        for prop in label_schema.properties.values():
            if prop.index_type == "vector":
                spec: str | dict[str, str] = {
                    "type": "vector",
                    "metric": prop.metric or "l2",
                }
            elif prop.index_type == "fulltext":
                spec = "fulltext"
            else:
                continue
            try:
                db.schema().label(label).index(prop.name, spec).apply()
            except Exception:
                # Most commonly the index already exists; keep a trace.
                logger.debug(
                    "Could not create %r index on %s.%s (it may already exist)",
                    spec,
                    label,
                    prop.name,
                    exc_info=True,
                )
Apply the generated schema to a database using SchemaBuilder.
Uses db.schema() for atomic schema application with additive-only semantics. Creates labels, edge types, properties, and indexes.
async def async_apply_to_database(self, db: uni_db.AsyncUni) -> None:
    """Apply the generated schema to an async database.

    Async variant of apply_to_database using AsyncSchemaBuilder. Existing
    labels and edge types are skipped (additive-only). Vector and fulltext
    indexes are created on a best-effort basis; failures are logged at
    DEBUG level rather than raised.
    """
    import logging

    logger = logging.getLogger(__name__)
    schema = self.generate()

    # Build the full schema using AsyncSchemaBuilder, skipping existing labels/edge types
    builder = db.schema()
    has_changes = False

    for label, label_schema in schema.labels.items():
        if await db.label_exists(label):
            continue  # Additive-only: skip existing labels
        lb = builder.label(label)
        for prop in label_schema.properties.values():
            if prop.data_type.startswith("vector:"):
                dims = int(prop.data_type.split(":")[1])
                lb = lb.vector(prop.name, dims)
            elif prop.nullable:
                lb = lb.property_nullable(prop.name, prop.data_type)
            else:
                lb = lb.property(prop.name, prop.data_type)

            # Scalar indexes only; vector/fulltext indexes are created below.
            if prop.index_type in ("btree", "hash"):
                lb = lb.index(prop.name, prop.index_type)
        builder = lb.done()
        has_changes = True

    for edge_type, edge_schema in schema.edge_types.items():
        if await db.edge_type_exists(edge_type):
            continue  # Additive-only: skip existing edge types
        eb = builder.edge_type(
            edge_type, edge_schema.from_labels, edge_schema.to_labels
        )
        for prop in edge_schema.properties.values():
            if prop.nullable:
                eb = eb.property_nullable(prop.name, prop.data_type)
            else:
                eb = eb.property(prop.name, prop.data_type)
        builder = eb.done()
        has_changes = True

    if has_changes:
        await builder.apply()

    # Create vector and fulltext indexes individually (best-effort).
    for label, label_schema in schema.labels.items():
        for prop in label_schema.properties.values():
            if prop.index_type == "vector":
                spec: str | dict[str, str] = {
                    "type": "vector",
                    "metric": prop.metric or "l2",
                }
            elif prop.index_type == "fulltext":
                spec = "fulltext"
            else:
                continue
            try:
                await db.schema().label(label).index(prop.name, spec).apply()
            except Exception:
                # Most commonly the index already exists; keep a trace.
                logger.debug(
                    "Could not create %r index on %s.%s (it may already exist)",
                    spec,
                    label,
                    prop.name,
                    exc_info=True,
                )
Apply the generated schema to an async database.
Async variant of apply_to_database using AsyncSchemaBuilder.
52@dataclass 53class DatabaseSchema: 54 """Complete database schema generated from models.""" 55 56 labels: dict[str, LabelSchema] = field(default_factory=dict) 57 edge_types: dict[str, EdgeTypeSchema] = field(default_factory=dict)
Complete database schema generated from models.
34@dataclass 35class LabelSchema: 36 """Schema for a vertex label.""" 37 38 name: str 39 properties: dict[str, PropertySchema] = field(default_factory=dict)
Schema for a vertex label.
42@dataclass 43class EdgeTypeSchema: 44 """Schema for an edge type.""" 45 46 name: str 47 from_labels: list[str] = field(default_factory=list) 48 to_labels: list[str] = field(default_factory=list) 49 properties: dict[str, PropertySchema] = field(default_factory=dict)
Schema for an edge type.
21@dataclass 22class PropertySchema: 23 """Schema for a single property.""" 24 25 name: str 26 data_type: str 27 nullable: bool = False 28 index_type: str | None = None 29 unique: bool = False 30 tokenizer: str | None = None 31 metric: str | None = None
Schema for a single property.
351def generate_schema(*models: type[UniNode] | type[UniEdge]) -> DatabaseSchema: 352 """Generate a database schema from the given models.""" 353 generator = SchemaGenerator() 354 generator.register(*models) 355 return generator.generate()
Generate a database schema from the given models.
15class UniDatabase: 16 """ 17 Thin wrapper around uni-db UniBuilder for ergonomic database creation. 18 19 Example: 20 >>> db = UniDatabase.open("./path").cache_size(1024*1024).build() 21 >>> db = UniDatabase.temporary().build() 22 >>> db = UniDatabase.in_memory().build() 23 """ 24 25 def __init__(self, builder: uni_db.UniBuilder) -> None: 26 self._builder = builder 27 28 @classmethod 29 def open(cls, path: str) -> UniDatabase: 30 """Open or create a database at the given path.""" 31 import uni_db 32 33 return cls(uni_db.UniBuilder.open(path)) 34 35 @classmethod 36 def create(cls, path: str) -> UniDatabase: 37 """Create a new database at the given path.""" 38 import uni_db 39 40 return cls(uni_db.UniBuilder.create(path)) 41 42 @classmethod 43 def open_existing(cls, path: str) -> UniDatabase: 44 """Open an existing database (must already exist).""" 45 import uni_db 46 47 return cls(uni_db.UniBuilder.open_existing(path)) 48 49 @classmethod 50 def temporary(cls) -> UniDatabase: 51 """Create an ephemeral in-memory database.""" 52 import uni_db 53 54 return cls(uni_db.UniBuilder.temporary()) 55 56 @classmethod 57 def in_memory(cls) -> UniDatabase: 58 """Create a persistent in-memory database.""" 59 import uni_db 60 61 return cls(uni_db.UniBuilder.in_memory()) 62 63 def cache_size(self, bytes_: int) -> UniDatabase: 64 """Set the cache size in bytes.""" 65 self._builder = self._builder.cache_size(bytes_) 66 return self 67 68 def parallelism(self, n: int) -> UniDatabase: 69 """Set the parallelism level.""" 70 self._builder = self._builder.parallelism(n) 71 return self 72 73 def build(self) -> uni_db.Uni: 74 """Build and return the database instance.""" 75 return self._builder.build()
Thin wrapper around uni-db UniBuilder for ergonomic database creation.
Example:
db = UniDatabase.open("./path").cache_size(1024*1024).build() db = UniDatabase.temporary().build() db = UniDatabase.in_memory().build()
28 @classmethod 29 def open(cls, path: str) -> UniDatabase: 30 """Open or create a database at the given path.""" 31 import uni_db 32 33 return cls(uni_db.UniBuilder.open(path))
Open or create a database at the given path.
35 @classmethod 36 def create(cls, path: str) -> UniDatabase: 37 """Create a new database at the given path.""" 38 import uni_db 39 40 return cls(uni_db.UniBuilder.create(path))
Create a new database at the given path.
42 @classmethod 43 def open_existing(cls, path: str) -> UniDatabase: 44 """Open an existing database (must already exist).""" 45 import uni_db 46 47 return cls(uni_db.UniBuilder.open_existing(path))
Open an existing database (must already exist).
49 @classmethod 50 def temporary(cls) -> UniDatabase: 51 """Create an ephemeral in-memory database.""" 52 import uni_db 53 54 return cls(uni_db.UniBuilder.temporary())
Create an ephemeral in-memory database.
56 @classmethod 57 def in_memory(cls) -> UniDatabase: 58 """Create a persistent in-memory database.""" 59 import uni_db 60 61 return cls(uni_db.UniBuilder.in_memory())
Create a persistent in-memory database.
63 def cache_size(self, bytes_: int) -> UniDatabase: 64 """Set the cache size in bytes.""" 65 self._builder = self._builder.cache_size(bytes_) 66 return self
Set the cache size in bytes.
78class AsyncUniDatabase: 79 """ 80 Thin wrapper around uni-db AsyncUniBuilder for ergonomic async database creation. 81 82 Example: 83 >>> db = await AsyncUniDatabase.open("./path").build() 84 >>> db = await AsyncUniDatabase.temporary().build() 85 """ 86 87 def __init__(self, builder: uni_db.AsyncUniBuilder) -> None: 88 self._builder = builder 89 90 @classmethod 91 def open(cls, path: str) -> AsyncUniDatabase: 92 """Open or create a database at the given path.""" 93 import uni_db 94 95 return cls(uni_db.AsyncUniBuilder.open(path)) 96 97 @classmethod 98 def temporary(cls) -> AsyncUniDatabase: 99 """Create an ephemeral in-memory database.""" 100 import uni_db 101 102 return cls(uni_db.AsyncUniBuilder.temporary()) 103 104 @classmethod 105 def in_memory(cls) -> AsyncUniDatabase: 106 """Create a persistent in-memory database.""" 107 import uni_db 108 109 return cls(uni_db.AsyncUniBuilder.in_memory()) 110 111 def cache_size(self, bytes_: int) -> AsyncUniDatabase: 112 """Set the cache size in bytes.""" 113 self._builder = self._builder.cache_size(bytes_) 114 return self 115 116 def parallelism(self, n: int) -> AsyncUniDatabase: 117 """Set the parallelism level.""" 118 self._builder = self._builder.parallelism(n) 119 return self 120 121 async def build(self) -> uni_db.AsyncUni: 122 """Build and return the async database instance.""" 123 return await self._builder.build()
Thin wrapper around uni-db AsyncUniBuilder for ergonomic async database creation.
Example:
db = await AsyncUniDatabase.open("./path").build() db = await AsyncUniDatabase.temporary().build()
90 @classmethod 91 def open(cls, path: str) -> AsyncUniDatabase: 92 """Open or create a database at the given path.""" 93 import uni_db 94 95 return cls(uni_db.AsyncUniBuilder.open(path))
Open or create a database at the given path.
97 @classmethod 98 def temporary(cls) -> AsyncUniDatabase: 99 """Create an ephemeral in-memory database.""" 100 import uni_db 101 102 return cls(uni_db.AsyncUniBuilder.temporary())
Create an ephemeral in-memory database.
104 @classmethod 105 def in_memory(cls) -> AsyncUniDatabase: 106 """Create a persistent in-memory database.""" 107 import uni_db 108 109 return cls(uni_db.AsyncUniBuilder.in_memory())
Create a persistent in-memory database.
111 def cache_size(self, bytes_: int) -> AsyncUniDatabase: 112 """Set the cache size in bytes.""" 113 self._builder = self._builder.cache_size(bytes_) 114 return self
Set the cache size in bytes.
40def before_create(func: F) -> F: 41 """ 42 Mark a method to be called before the entity is created in the database. 43 44 The method is called after validation but before the INSERT operation. 45 Useful for setting timestamps, generating IDs, or final validation. 46 47 Example: 48 >>> class Person(UniNode): 49 ... name: str 50 ... created_at: datetime | None = None 51 ... 52 ... @before_create 53 ... def set_created_at(self): 54 ... self.created_at = datetime.now() 55 """ 56 return _mark_hook(_BEFORE_CREATE)(func)
Mark a method to be called before the entity is created in the database.
The method is called after validation but before the INSERT operation. Useful for setting timestamps, generating IDs, or final validation.
Example:
class Person(UniNode): ... name: str ... created_at: datetime | None = None ... ... @before_create ... def set_created_at(self): ... self.created_at = datetime.now()
59def after_create(func: F) -> F: 60 """ 61 Mark a method to be called after the entity is created in the database. 62 63 The method is called after the INSERT operation completes successfully. 64 The entity will have its vid/eid assigned at this point. 65 66 Example: 67 >>> class Person(UniNode): 68 ... name: str 69 ... 70 ... @after_create 71 ... def log_creation(self): 72 ... logger.info(f"Created person {self.name} with vid={self.vid}") 73 """ 74 return _mark_hook(_AFTER_CREATE)(func)
Mark a method to be called after the entity is created in the database.
The method is called after the INSERT operation completes successfully. The entity will have its vid/eid assigned at this point.
Example:
class Person(UniNode): ... name: str ... ... @after_create ... def log_creation(self): ... logger.info(f"Created person {self.name} with vid={self.vid}")
77def before_update(func: F) -> F: 78 """ 79 Mark a method to be called before the entity is updated in the database. 80 81 The method is called before the UPDATE operation. 82 Useful for validation or updating timestamps. 83 84 Example: 85 >>> class Person(UniNode): 86 ... name: str 87 ... updated_at: datetime | None = None 88 ... 89 ... @before_update 90 ... def validate_and_timestamp(self): 91 ... if not self.name: 92 ... raise ValueError("Name cannot be empty") 93 ... self.updated_at = datetime.now() 94 """ 95 return _mark_hook(_BEFORE_UPDATE)(func)
Mark a method to be called before the entity is updated in the database.
The method is called before the UPDATE operation. Useful for validation or updating timestamps.
Example:
class Person(UniNode): ... name: str ... updated_at: datetime | None = None ... ... @before_update ... def validate_and_timestamp(self): ... if not self.name: ... raise ValueError("Name cannot be empty") ... self.updated_at = datetime.now()
98def after_update(func: F) -> F: 99 """ 100 Mark a method to be called after the entity is updated in the database. 101 102 The method is called after the UPDATE operation completes successfully. 103 104 Example: 105 >>> class Person(UniNode): 106 ... name: str 107 ... 108 ... @after_update 109 ... def notify_change(self): 110 ... events.emit("person_updated", self.vid) 111 """ 112 return _mark_hook(_AFTER_UPDATE)(func)
Mark a method to be called after the entity is updated in the database.
The method is called after the UPDATE operation completes successfully.
Example:
class Person(UniNode): ... name: str ... ... @after_update ... def notify_change(self): ... events.emit("person_updated", self.vid)
115def before_delete(func: F) -> F: 116 """ 117 Mark a method to be called before the entity is deleted from the database. 118 119 The method is called before the DELETE operation. 120 Useful for cleanup or validation. 121 122 Example: 123 >>> class Person(UniNode): 124 ... name: str 125 ... 126 ... @before_delete 127 ... def cleanup(self): 128 ... # Remove related data 129 ... pass 130 """ 131 return _mark_hook(_BEFORE_DELETE)(func)
Mark a method to be called before the entity is deleted from the database.
The method is called before the DELETE operation. Useful for cleanup or validation.
Example:
class Person(UniNode): ... name: str ... ... @before_delete ... def cleanup(self): ... # Remove related data ... pass
134def after_delete(func: F) -> F: 135 """ 136 Mark a method to be called after the entity is deleted from the database. 137 138 The method is called after the DELETE operation completes successfully. 139 The entity's vid/eid will be cleared at this point. 140 141 Example: 142 >>> class Person(UniNode): 143 ... name: str 144 ... 145 ... @after_delete 146 ... def log_deletion(self): 147 ... logger.info(f"Deleted person {self.name}") 148 """ 149 return _mark_hook(_AFTER_DELETE)(func)
Mark a method to be called after the entity is deleted from the database.
The method is called after the DELETE operation completes successfully. The entity's vid/eid will be cleared at this point.
Example:
class Person(UniNode): ... name: str ... ... @after_delete ... def log_deletion(self): ... logger.info(f"Deleted person {self.name}")
152def before_load(func: F) -> F: 153 """ 154 Mark a method to be called before the entity is loaded from the database. 155 156 This is a class method that receives the raw property dictionary. 157 Can be used to transform data before model instantiation. 158 159 Example: 160 >>> class Person(UniNode): 161 ... name: str 162 ... 163 ... @classmethod 164 ... @before_load 165 ... def transform_data(cls, props: dict) -> dict: 166 ... # Normalize name 167 ... if 'name' in props: 168 ... props['name'] = props['name'].strip() 169 ... return props 170 """ 171 return _mark_hook(_BEFORE_LOAD)(func)
Mark a method to be called before the entity is loaded from the database.
This is a class method that receives the raw property dictionary. Can be used to transform data before model instantiation.
Example:
class Person(UniNode): ... name: str ... ... @classmethod ... @before_load ... def transform_data(cls, props: dict) -> dict: ... # Normalize name ... if 'name' in props: ... props['name'] = props['name'].strip() ... return props
174def after_load(func: F) -> F: 175 """ 176 Mark a method to be called after the entity is loaded from the database. 177 178 The method is called after the entity is instantiated from database data. 179 Useful for computing derived values or initializing non-persisted state. 180 181 Example: 182 >>> class Person(UniNode): 183 ... first_name: str 184 ... last_name: str 185 ... full_name: str | None = None 186 ... 187 ... @after_load 188 ... def compute_full_name(self): 189 ... self.full_name = f"{self.first_name} {self.last_name}" 190 """ 191 return _mark_hook(_AFTER_LOAD)(func)
Mark a method to be called after the entity is loaded from the database.
The method is called after the entity is instantiated from database data. Useful for computing derived values or initializing non-persisted state.
Example:
class Person(UniNode): ... first_name: str ... last_name: str ... full_name: str | None = None ... ... @after_load ... def compute_full_name(self): ... self.full_name = f"{self.first_name} {self.last_name}"
Base exception for all uni-pydantic errors.
19class SchemaError(UniPydanticError): 20 """Error related to schema definition or generation.""" 21 22 def __init__(self, message: str, model: type | None = None) -> None: 23 self.model = model 24 super().__init__(message)
Error related to schema definition or generation.
27class TypeMappingError(SchemaError): 28 """Error mapping Python type to Uni DataType.""" 29 30 def __init__(self, python_type: Any, message: str | None = None) -> None: 31 self.python_type = python_type 32 msg = message or f"Cannot map Python type {python_type!r} to Uni DataType" 33 super().__init__(msg)
Error mapping Python type to Uni DataType.
Validation error for model instances.
Error related to session operations.
44class NotRegisteredError(SessionError): 45 """Model type not registered with session.""" 46 47 def __init__(self, model: type[UniNode] | type[UniEdge]) -> None: 48 self.model = model 49 super().__init__( 50 f"Model {model.__name__!r} is not registered with this session. " 51 f"Call session.register({model.__name__}) first." 52 )
Model type not registered with session.
55class NotPersisted(SessionError): 56 """Operation requires a persisted entity.""" 57 58 def __init__(self, entity: UniNode | UniEdge) -> None: 59 self.entity = entity 60 super().__init__( 61 f"Entity {entity!r} is not persisted. Call session.add() and commit() first." 62 )
Operation requires a persisted entity.
Entity is not tracked by this session.
Error related to transaction operations.
Error executing a query.
Error related to relationship operations.
81class LazyLoadError(RelationshipError): 82 """Error lazy-loading a relationship.""" 83 84 def __init__(self, field_name: str, reason: str) -> None: 85 self.field_name = field_name 86 super().__init__(f"Cannot lazy-load relationship '{field_name}': {reason}")
Error lazy-loading a relationship.
Error during bulk loading operations.
93class CypherInjectionError(QueryError): 94 """Property name validation failure — potential Cypher injection.""" 95 96 def __init__(self, name: str, reason: str | None = None) -> None: 97 self.name = name 98 msg = reason or f"Invalid property name {name!r}: possible Cypher injection" 99 super().__init__(msg)
Property name validation failure — potential Cypher injection.