Retrieval-Augmented Generation (RAG) with uni-pydantic
This notebook combines vector search with knowledge-graph traversal for hybrid retrieval over the documentation of a Python web framework.
In [1]:
Copied!
import os
import shutil
import tempfile
import uni_db
from uni_pydantic import UniNode, UniEdge, UniSession, Field, Relationship, Vector
import os
import shutil
import tempfile
import uni_db
from uni_pydantic import UniNode, UniEdge, UniSession, Field, Relationship, Vector
1. Define Models
Text chunks carry embeddings and are linked to named entities via `MENTIONS` edges.
In [2]:
Copied!
class Chunk(UniNode):
    """A chunk of text with semantic embedding.

    Fields:
        chunk_id: caller-assigned identifier string (e.g. "c1").
        text: the raw chunk text.
        embedding: 4-dimensional vector whose axes are the topic scores
            [auth, routing, database, testing], declared with metric "l2".
    """
    __label__ = "Chunk"
    chunk_id: str
    text: str
    embedding: Vector[4] = Field(metric="l2")  # [auth, routing, database, testing]
    # Relationships
    # Outgoing MENTIONS edges to Entity nodes. "Entity" is a string forward
    # reference because the Entity class is declared after this one.
    entities: list["Entity"] = Relationship("MENTIONS", direction="outgoing")


class Entity(UniNode):
    """A named entity extracted from text.

    Fields:
        name: the entity's name (e.g. "JWT").
        entity_type: free-form category string; this notebook uses
            "technology", "concept" and "class". Defaults to "unknown".
    """
    __label__ = "Entity"
    name: str
    entity_type: str = Field(default="unknown")
    # Relationships
    # Reverse side of Chunk.entities: incoming MENTIONS edges from chunks.
    mentioned_in: list[Chunk] = Relationship("MENTIONS", direction="incoming")


class Mentions(UniEdge):
    """Edge representing a chunk mentioning an entity.

    Directed from a Chunk node to an Entity node.
    """
    # Same edge-type string used by the Relationship declarations above.
    __edge_type__ = "MENTIONS"
    __from__ = Chunk
    __to__ = Entity
class Chunk(UniNode):
    """A chunk of text with semantic embedding.

    Fields:
        chunk_id: caller-assigned identifier string (e.g. "c1").
        text: the raw chunk text.
        embedding: 4-dimensional vector whose axes are the topic scores
            [auth, routing, database, testing], declared with metric "l2".
    """
    __label__ = "Chunk"
    chunk_id: str
    text: str
    embedding: Vector[4] = Field(metric="l2")  # [auth, routing, database, testing]
    # Relationships
    # Outgoing MENTIONS edges to Entity nodes. "Entity" is a string forward
    # reference because the Entity class is declared after this one.
    entities: list["Entity"] = Relationship("MENTIONS", direction="outgoing")


class Entity(UniNode):
    """A named entity extracted from text.

    Fields:
        name: the entity's name (e.g. "JWT").
        entity_type: free-form category string; this notebook uses
            "technology", "concept" and "class". Defaults to "unknown".
    """
    __label__ = "Entity"
    name: str
    entity_type: str = Field(default="unknown")
    # Relationships
    # Reverse side of Chunk.entities: incoming MENTIONS edges from chunks.
    mentioned_in: list[Chunk] = Relationship("MENTIONS", direction="incoming")


class Mentions(UniEdge):
    """Edge representing a chunk mentioning an entity.

    Directed from a Chunk node to an Entity node.
    """
    # Same edge-type string used by the Relationship declarations above.
    __edge_type__ = "MENTIONS"
    __from__ = Chunk
    __to__ = Entity
2. Set Up Database and Session
In [3]:
Copied!
# Fixed scratch location under the system temp dir. It is wiped whenever it
# already exists, so every run starts from an empty database.
db_path = os.path.join(tempfile.gettempdir(), "rag_pydantic_db")
if os.path.exists(db_path):
    shutil.rmtree(db_path)

db = uni_db.Database(db_path)

# Bind a session to the database, register the node and edge models,
# then sync the registered schema with the database.
session = UniSession(db)
session.register(Chunk, Entity, Mentions)
session.sync_schema()

print(f"Opened database at {db_path}")
# Fixed scratch location under the system temp dir. It is wiped whenever it
# already exists, so every run starts from an empty database.
db_path = os.path.join(tempfile.gettempdir(), "rag_pydantic_db")
if os.path.exists(db_path):
    shutil.rmtree(db_path)

db = uni_db.Database(db_path)

# Bind a session to the database, register the node and edge models,
# then sync the registered schema with the database.
session = UniSession(db)
session.register(Chunk, Entity, Mentions)
session.sync_schema()

print(f"Opened database at {db_path}")
Opened database at /tmp/rag_pydantic_db
3. Create Data
Eight documentation chunks across four topics (auth, routing, database, testing), plus six named entities.
In [4]:
Copied!
# Hand-made 4D topic embeddings, one axis per topic:
# [auth, routing, database, testing]. Rows are (chunk_id, text, embedding).
c1, c2, c3, c4, c5, c6, c7, c8 = (
    Chunk(chunk_id=cid, text=body, embedding=vec)
    for cid, body, vec in [
        ("c1", "JWT tokens issued by /auth/login endpoint. Tokens expire after 1 hour.",
         [1.0, 0.0, 0.0, 0.0]),
        ("c2", "Token refresh via /auth/refresh. Send expired token, receive new one.",
         [0.95, 0.05, 0.0, 0.0]),
        ("c3", "Password hashing uses bcrypt with cost factor 12.",
         [0.85, 0.0, 0.0, 0.15]),
        ("c4", "Routes defined with @app.route decorator. Supports GET, POST, PUT, DELETE.",
         [0.0, 1.0, 0.0, 0.0]),
        ("c5", "Middleware intercepts requests before handlers. Register with app.use().",
         [0.05, 0.9, 0.05, 0.0]),
        ("c6", "ConnectionPool manages DB connections. Max pool size defaults to 10.",
         [0.0, 0.0, 1.0, 0.0]),
        ("c7", "ORM models inherit from BaseModel. Columns map to database fields.",
         [0.0, 0.1, 0.9, 0.0]),
        ("c8", "TestClient simulates HTTP requests without starting a server.",
         [0.0, 0.2, 0.0, 0.8]),
    ]
)

# Six named entities as (name, entity_type) pairs.
jwt, auth_entity, routing_ent, db_entity, bcrypt_ent, pool_entity = (
    Entity(name=label, entity_type=kind)
    for label, kind in [
        ("JWT", "technology"),
        ("authentication", "concept"),
        ("routing", "concept"),
        ("database", "concept"),
        ("bcrypt", "technology"),
        ("ConnectionPool", "class"),
    ]
)

# Persist every node in one batch.
session.add_all([c1, c2, c3, c4, c5, c6, c7, c8,
                 jwt, auth_entity, routing_ent, db_entity, bcrypt_ent, pool_entity])
session.commit()

# Build the L2 vector index on Chunk.embedding only after the commit;
# keep this ordering.
db.create_vector_index("Chunk", "embedding", "l2")
print("Data ingested and vector index created")
# Hand-made 4D topic embeddings, one axis per topic:
# [auth, routing, database, testing]. Rows are (chunk_id, text, embedding).
c1, c2, c3, c4, c5, c6, c7, c8 = (
    Chunk(chunk_id=cid, text=body, embedding=vec)
    for cid, body, vec in [
        ("c1", "JWT tokens issued by /auth/login endpoint. Tokens expire after 1 hour.",
         [1.0, 0.0, 0.0, 0.0]),
        ("c2", "Token refresh via /auth/refresh. Send expired token, receive new one.",
         [0.95, 0.05, 0.0, 0.0]),
        ("c3", "Password hashing uses bcrypt with cost factor 12.",
         [0.85, 0.0, 0.0, 0.15]),
        ("c4", "Routes defined with @app.route decorator. Supports GET, POST, PUT, DELETE.",
         [0.0, 1.0, 0.0, 0.0]),
        ("c5", "Middleware intercepts requests before handlers. Register with app.use().",
         [0.05, 0.9, 0.05, 0.0]),
        ("c6", "ConnectionPool manages DB connections. Max pool size defaults to 10.",
         [0.0, 0.0, 1.0, 0.0]),
        ("c7", "ORM models inherit from BaseModel. Columns map to database fields.",
         [0.0, 0.1, 0.9, 0.0]),
        ("c8", "TestClient simulates HTTP requests without starting a server.",
         [0.0, 0.2, 0.0, 0.8]),
    ]
)

# Six named entities as (name, entity_type) pairs.
jwt, auth_entity, routing_ent, db_entity, bcrypt_ent, pool_entity = (
    Entity(name=label, entity_type=kind)
    for label, kind in [
        ("JWT", "technology"),
        ("authentication", "concept"),
        ("routing", "concept"),
        ("database", "concept"),
        ("bcrypt", "technology"),
        ("ConnectionPool", "class"),
    ]
)

# Persist every node in one batch.
session.add_all([c1, c2, c3, c4, c5, c6, c7, c8,
                 jwt, auth_entity, routing_ent, db_entity, bcrypt_ent, pool_entity])
session.commit()

# Build the L2 vector index on Chunk.embedding only after the commit;
# keep this ordering.
db.create_vector_index("Chunk", "embedding", "l2")
print("Data ingested and vector index created")
Data ingested and vector index created
In [5]:
Copied!
# MENTIONS edges: (chunk, entity) pairs, created in this order.
# c8 is deliberately left without any entity mention.
mention_pairs = [
    (c1, jwt), (c1, auth_entity),
    (c2, jwt), (c2, auth_entity),
    (c3, bcrypt_ent), (c3, auth_entity),
    (c4, routing_ent),
    (c5, routing_ent),
    (c6, db_entity), (c6, pool_entity),
    (c7, db_entity),
]
for src_chunk, dst_entity in mention_pairs:
    session.create_edge(src_chunk, "MENTIONS", dst_entity)
session.commit()
print("Entity mention edges created")
# MENTIONS edges: (chunk, entity) pairs, created in this order.
# c8 is deliberately left without any entity mention.
mention_pairs = [
    (c1, jwt), (c1, auth_entity),
    (c2, jwt), (c2, auth_entity),
    (c3, bcrypt_ent), (c3, auth_entity),
    (c4, routing_ent),
    (c5, routing_ent),
    (c6, db_entity), (c6, pool_entity),
    (c7, db_entity),
]
for src_chunk, dst_entity in mention_pairs:
    session.create_edge(src_chunk, "MENTIONS", dst_entity)
session.commit()
print("Entity mention edges created")
Entity mention edges created
4. Pure Vector Search
Find the 3 chunks whose embeddings are closest to an authentication query vector.
In [6]:
Copied!
# Query vector pointing purely along the "auth" axis.
auth_query = [1.0, 0.0, 0.0, 0.0]

# Top-3 nearest chunks from the Chunk.embedding vector index.
# NOTE(review): the recorded distances look like squared L2 -- c2 is
# sqrt(0.005) ~= 0.071 away in plain L2 but prints as 0.0050.
query_vec = """
CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
YIELD node, distance
RETURN node.chunk_id AS chunk_id, node.text AS text, distance
ORDER BY distance
"""
results = session.cypher(query_vec, {"vec": auth_query})

print('Top 3 chunks for auth query:')
for row in results:
    print(f" [{row['distance']:.4f}] {row['chunk_id']}: {row['text'][:60]}...")

# The three auth-heavy chunks must be the hits (order not checked).
chunk_ids = [row['chunk_id'] for row in results]
assert set(chunk_ids) == {'c1', 'c2', 'c3'}, f'Expected auth chunks c1/c2/c3, got {chunk_ids}'
# Query vector pointing purely along the "auth" axis.
auth_query = [1.0, 0.0, 0.0, 0.0]

# Top-3 nearest chunks from the Chunk.embedding vector index.
# NOTE(review): the recorded distances look like squared L2 -- c2 is
# sqrt(0.005) ~= 0.071 away in plain L2 but prints as 0.0050.
query_vec = """
CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
YIELD node, distance
RETURN node.chunk_id AS chunk_id, node.text AS text, distance
ORDER BY distance
"""
results = session.cypher(query_vec, {"vec": auth_query})

print('Top 3 chunks for auth query:')
for row in results:
    print(f" [{row['distance']:.4f}] {row['chunk_id']}: {row['text'][:60]}...")

# The three auth-heavy chunks must be the hits (order not checked).
chunk_ids = [row['chunk_id'] for row in results]
assert set(chunk_ids) == {'c1', 'c2', 'c3'}, f'Expected auth chunks c1/c2/c3, got {chunk_ids}'
Top 3 chunks for auth query: [0.0000] c1: JWT tokens issued by /auth/login endpoint. Tokens expire aft... [0.0050] c2: Token refresh via /auth/refresh. Send expired token, receive... [0.0450] c3: Password hashing uses bcrypt with cost factor 12....
5. Graph Expansion
Start from the same vector seeds and follow `MENTIONS` edges to show which entities each chunk mentions.
In [7]:
Copied!
# Same top-3 vector seeds, then one hop along outgoing MENTIONS edges:
# one result row per (seed chunk, mentioned entity) pair.
query_expand = """
CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
YIELD node, distance
MATCH (node)-[:MENTIONS]->(e:Entity)
RETURN node.chunk_id AS chunk_id, e.name AS entity, distance
ORDER BY distance, entity
"""
results = session.cypher(query_expand, {"vec": auth_query})

print('Entities mentioned by top auth chunks:')
for row in results:
    print(f" {row['chunk_id']} -> {row['entity']}")
# Same top-3 vector seeds, then one hop along outgoing MENTIONS edges:
# one result row per (seed chunk, mentioned entity) pair.
query_expand = """
CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
YIELD node, distance
MATCH (node)-[:MENTIONS]->(e:Entity)
RETURN node.chunk_id AS chunk_id, e.name AS entity, distance
ORDER BY distance, entity
"""
results = session.cypher(query_expand, {"vec": auth_query})

print('Entities mentioned by top auth chunks:')
for row in results:
    print(f" {row['chunk_id']} -> {row['entity']}")
Entities mentioned by top auth chunks: c1 -> JWT c1 -> authentication c2 -> JWT c2 -> authentication c3 -> authentication c3 -> bcrypt
6. Entity Bridging
Find chunks related to the auth seeds through shared entity mentions — the core graph-RAG technique.
In [8]:
Copied!
# Two-hop pattern anchor -> entity <- related: every other chunk that
# mentions an entity also mentioned by one of the top-3 anchors.
# The _vid comparison excludes the anchor itself from `related`.
query_bridge = """
CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
YIELD node AS anchor, distance
MATCH (anchor)-[:MENTIONS]->(e:Entity)<-[:MENTIONS]-(related:Chunk)
WHERE related._vid <> anchor._vid
RETURN anchor.chunk_id AS anchor_id, e.name AS bridge_entity,
related.chunk_id AS related_id
ORDER BY anchor_id, bridge_entity
"""
results = session.cypher(query_bridge, {"vec": auth_query})

print('Entity bridges between auth chunks:')
for row in results:
    print(f" {row['anchor_id']} <-> {row['related_id']} (via {row['bridge_entity']})")
# Two-hop pattern anchor -> entity <- related: every other chunk that
# mentions an entity also mentioned by one of the top-3 anchors.
# The _vid comparison excludes the anchor itself from `related`.
query_bridge = """
CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
YIELD node AS anchor, distance
MATCH (anchor)-[:MENTIONS]->(e:Entity)<-[:MENTIONS]-(related:Chunk)
WHERE related._vid <> anchor._vid
RETURN anchor.chunk_id AS anchor_id, e.name AS bridge_entity,
related.chunk_id AS related_id
ORDER BY anchor_id, bridge_entity
"""
results = session.cypher(query_bridge, {"vec": auth_query})

print('Entity bridges between auth chunks:')
for row in results:
    print(f" {row['anchor_id']} <-> {row['related_id']} (via {row['bridge_entity']})")
Entity bridges between auth chunks: c1 <-> c2 (via JWT) c1 <-> c2 (via authentication) c1 <-> c3 (via authentication) c2 <-> c1 (via JWT) c2 <-> c3 (via authentication) c2 <-> c1 (via authentication) c3 <-> c1 (via authentication) c3 <-> c2 (via authentication)
7. Context Assembly
Full hybrid pipeline: vector seeds plus graph bridging, collected into a set of unique chunks for the LLM context window.
In [9]:
Copied!
# Hybrid context assembly: the top-3 vector seeds plus every chunk that
# shares an entity with a seed, de-duplicated by chunk_id.

# Plain vector hits. These are queried separately so a seed that mentions no
# entity still reaches the context: the non-optional MATCH in query_ctx
# yields no row for a seed without outgoing MENTIONS edges (in this dataset
# c8 has none), which would silently drop it.
query_seeds = """
CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
YIELD node, distance
RETURN node.chunk_id AS chunk_id, node.text AS text, distance
ORDER BY distance
"""
# Graph bridging: seed -[:MENTIONS]-> entity <-[:MENTIONS]- related chunk.
# The query does not exclude related == seed; the dict below de-duplicates
# either way.
query_ctx = """
CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
YIELD node AS seed, distance
MATCH (seed)-[:MENTIONS]->(e:Entity)<-[:MENTIONS]-(related:Chunk)
RETURN seed.chunk_id AS seed_id, seed.text AS seed_text,
related.chunk_id AS related_id, related.text AS related_text,
e.name AS shared_entity
ORDER BY seed_id, shared_entity
"""
seed_results = session.cypher(query_seeds, {"vec": auth_query})
results = session.cypher(query_ctx, {"vec": auth_query})

# Collect all unique chunk texts for LLM context, keyed by chunk_id.
context_chunks = {}
for r in seed_results:
    context_chunks[r['chunk_id']] = r['text']
for r in results:
    context_chunks[r['seed_id']] = r['seed_text']
    context_chunks[r['related_id']] = r['related_text']

print(f'Assembled {len(context_chunks)} unique chunks for LLM context:')
for cid, text in sorted(context_chunks.items()):
    print(f' [{cid}] {text[:70]}...')
# Hybrid context assembly: the top-3 vector seeds plus every chunk that
# shares an entity with a seed, de-duplicated by chunk_id.

# Plain vector hits. These are queried separately so a seed that mentions no
# entity still reaches the context: the non-optional MATCH in query_ctx
# yields no row for a seed without outgoing MENTIONS edges (in this dataset
# c8 has none), which would silently drop it.
query_seeds = """
CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
YIELD node, distance
RETURN node.chunk_id AS chunk_id, node.text AS text, distance
ORDER BY distance
"""
# Graph bridging: seed -[:MENTIONS]-> entity <-[:MENTIONS]- related chunk.
# The query does not exclude related == seed; the dict below de-duplicates
# either way.
query_ctx = """
CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
YIELD node AS seed, distance
MATCH (seed)-[:MENTIONS]->(e:Entity)<-[:MENTIONS]-(related:Chunk)
RETURN seed.chunk_id AS seed_id, seed.text AS seed_text,
related.chunk_id AS related_id, related.text AS related_text,
e.name AS shared_entity
ORDER BY seed_id, shared_entity
"""
seed_results = session.cypher(query_seeds, {"vec": auth_query})
results = session.cypher(query_ctx, {"vec": auth_query})

# Collect all unique chunk texts for LLM context, keyed by chunk_id.
context_chunks = {}
for r in seed_results:
    context_chunks[r['chunk_id']] = r['text']
for r in results:
    context_chunks[r['seed_id']] = r['seed_text']
    context_chunks[r['related_id']] = r['related_text']

print(f'Assembled {len(context_chunks)} unique chunks for LLM context:')
for cid, text in sorted(context_chunks.items()):
    print(f' [{cid}] {text[:70]}...')
Assembled 3 unique chunks for LLM context: [c1] JWT tokens issued by /auth/login endpoint. Tokens expire after 1 hour.... [c2] Token refresh via /auth/refresh. Send expired token, receive new one.... [c3] Password hashing uses bcrypt with cost factor 12....
8. Query Builder Demo
Use the type-safe query builder to browse entities and count chunks.
In [10]:
Copied!
# Query builder: all Entity nodes whose entity_type is "technology".
tech_entities = session.query(Entity).filter(Entity.entity_type == "technology").all()

print("Technology entities:")
for tech in tech_entities:
    print(f" - {tech.name}")

# Count of all stored Chunk nodes.
total_chunks = session.query(Chunk).count()
print(f"Total chunks in knowledge base: {total_chunks}")
# Query builder: all Entity nodes whose entity_type is "technology".
tech_entities = session.query(Entity).filter(Entity.entity_type == "technology").all()

print("Technology entities:")
for tech in tech_entities:
    print(f" - {tech.name}")

# Count of all stored Chunk nodes.
total_chunks = session.query(Chunk).count()
print(f"Total chunks in knowledge base: {total_chunks}")
Technology entities: - JWT - bcrypt Total chunks in knowledge base: 8