Skip to content

Retrieval-Augmented Generation (RAG) with uni-pydantic

Combining vector search with knowledge graph traversal for hybrid retrieval over Python web framework documentation.

import os
import shutil
import tempfile

import uni_db
from uni_pydantic import UniNode, UniEdge, UniSession, Field, Relationship, Vector

1. Define Models

Text chunks with embeddings, linked to named entities via MENTIONS edges.

class Chunk(UniNode):
    """A chunk of text with semantic embedding.

    Stored under the graph label ``Chunk``. ``embedding`` is a 4-dimensional
    vector whose axes are [auth, routing, database, testing]; it is the
    property searched by ``uni.vector.query('Chunk', 'embedding', ...)`` in the
    retrieval queries of this walkthrough.
    """

    # Graph label under which Chunk nodes are stored and matched in Cypher.
    __label__ = "Chunk"

    # Caller-assigned identifier ("c1".."c8" in this walkthrough).
    chunk_id: str
    # Raw chunk text; this is what gets assembled into the LLM context.
    text: str
    # Declared with metric="l2". The distances printed by the vector search
    # look like squared Euclidean distance (e.g. 0.0050 for a 0.05/0.05
    # offset) — NOTE(review): confirm against the uni_db vector-index docs.
    embedding: Vector[4] = Field(metric="l2")  # [auth, routing, database, testing]

    # Relationships
    # Outgoing MENTIONS edges to the entities this chunk refers to. "Entity" is
    # a string forward reference because that class is defined after this one.
    entities: list["Entity"] = Relationship("MENTIONS", direction="outgoing")


class Entity(UniNode):
    """A named entity extracted from text.

    Entities are the targets of MENTIONS edges from ``Chunk`` nodes; two
    chunks that mention the same entity are "bridged" through it in the
    graph-expansion queries.
    """

    # Graph label under which Entity nodes are stored and matched in Cypher.
    __label__ = "Entity"

    # Display name of the entity (e.g. "JWT", "ConnectionPool").
    name: str
    # Category label; defaults to "unknown". This walkthrough uses
    # "technology", "concept" and "class".
    entity_type: str = Field(default="unknown")

    # Relationships
    # Reverse side of Chunk.entities: chunks whose MENTIONS edges point here.
    mentioned_in: list[Chunk] = Relationship("MENTIONS", direction="incoming")


class Mentions(UniEdge):
    """Edge representing a chunk mentioning an entity.

    Declares the MENTIONS edge type, directed from ``Chunk`` (source) to
    ``Entity`` (target). No additional properties are declared on it.
    """

    # Edge type name used in Cypher patterns such as -[:MENTIONS]->.
    __edge_type__ = "MENTIONS"
    # Allowed endpoint models: Chunk -> Entity.
    __from__ = Chunk
    __to__ = Entity

2. Setup Database and Session

# Fixed scratch directory under the system temp dir. Any copy left by a previous
# run is deleted first so the walkthrough always starts from an empty database.
# NOTE(review): shutil.rmtree raises if db_path exists but is a plain file.
db_path = os.path.join(tempfile.gettempdir(), "rag_pydantic_db")
if os.path.exists(db_path):
    shutil.rmtree(db_path)
db = uni_db.Uni.open(db_path)

# Create session and register models.
# sync_schema() runs after registration — presumably it materialises the
# labels, edge type and vector property of the registered models in the
# database; NOTE(review): confirm in the uni_pydantic docs.
session = UniSession(db)
session.register(Chunk, Entity, Mentions)
session.sync_schema()

print(f"Opened database at {db_path}")
Opened database at /tmp/rag_pydantic_db

3. Create Data

8 documentation chunks across 4 topics, with 6 named entities.

# 4D embeddings: [auth, routing, database, testing]
# One row per chunk as (chunk_id, text, embedding); rows are unpacked, in
# order, into the names c1..c8 used by the rest of the walkthrough.
c1, c2, c3, c4, c5, c6, c7, c8 = (
    Chunk(chunk_id=cid, text=body, embedding=vec)
    for cid, body, vec in (
        (
            "c1",
            "JWT tokens issued by /auth/login endpoint. Tokens expire after 1 hour.",
            [1.0, 0.0, 0.0, 0.0],
        ),
        (
            "c2",
            "Token refresh via /auth/refresh. Send expired token, receive new one.",
            [0.95, 0.05, 0.0, 0.0],
        ),
        (
            "c3",
            "Password hashing uses bcrypt with cost factor 12.",
            [0.85, 0.0, 0.0, 0.15],
        ),
        (
            "c4",
            "Routes defined with @app.route decorator. Supports GET, POST, PUT, DELETE.",
            [0.0, 1.0, 0.0, 0.0],
        ),
        (
            "c5",
            "Middleware intercepts requests before handlers. Register with app.use().",
            [0.05, 0.9, 0.05, 0.0],
        ),
        (
            "c6",
            "ConnectionPool manages DB connections. Max pool size defaults to 10.",
            [0.0, 0.0, 1.0, 0.0],
        ),
        (
            "c7",
            "ORM models inherit from BaseModel. Columns map to database fields.",
            [0.0, 0.1, 0.9, 0.0],
        ),
        (
            "c8",
            "TestClient simulates HTTP requests without starting a server.",
            [0.0, 0.2, 0.0, 0.8],
        ),
    )
)

# 6 entities, one row per entity as (name, entity_type).
jwt, auth_entity, routing_ent, db_entity, bcrypt_ent, pool_entity = (
    Entity(name=label, entity_type=kind)
    for label, kind in (
        ("JWT", "technology"),
        ("authentication", "concept"),
        ("routing", "concept"),
        ("database", "concept"),
        ("bcrypt", "technology"),
        ("ConnectionPool", "class"),
    )
)

# Persist all nodes in a single batch: chunks first, then entities.
session.add_all(
    [c1, c2, c3, c4, c5, c6, c7, c8,
     jwt, auth_entity, routing_ent, db_entity, bcrypt_ent, pool_entity]
)
session.commit()

print("Data ingested")
Data ingested
# MENTIONS edges, listed as (chunk, entity) pairs in creation order.
# c8 gets no MENTIONS edge.
for src, dst in (
    (c1, jwt),
    (c1, auth_entity),
    (c2, jwt),
    (c2, auth_entity),
    (c3, bcrypt_ent),
    (c3, auth_entity),
    (c4, routing_ent),
    (c5, routing_ent),
    (c6, db_entity),
    (c6, pool_entity),
    (c7, db_entity),
):
    session.create_edge(src, "MENTIONS", dst)
session.commit()
print("Entity mention edges created")
Entity mention edges created

4. Vector Search

Find the 3 chunks most similar to an authentication query.

# Query vector pointing purely along the "auth" axis of the
# [auth, routing, database, testing] embedding space.
auth_query = [1.0, 0.0, 0.0, 0.0]

# Top-3 nearest Chunk nodes by their 'embedding' property, closest first
# (smaller distance = more similar; the exact match c1 scores 0.0).
query_vec = """
    CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
    YIELD node, distance
    RETURN node.chunk_id AS chunk_id, node.text AS text, distance
    ORDER BY distance
"""
results = session.cypher(query_vec, {"vec": auth_query})
print("Top 3 chunks for auth query:")
for r in results:
    print(f"  [{r['distance']:.4f}] {r['chunk_id']}: {r['text'][:60]}...")

# Sanity check: the three auth-heavy chunks must be the nearest neighbours.
# `results` is iterated a second time here, so this relies on session.cypher
# returning a re-iterable sequence — NOTE(review): confirm in uni_pydantic docs.
chunk_ids = [r["chunk_id"] for r in results]
assert set(chunk_ids) == {"c1", "c2", "c3"}, (
    f"Expected auth chunks c1/c2/c3, got {chunk_ids}"
)
Top 3 chunks for auth query:
  [0.0000] c1: JWT tokens issued by /auth/login endpoint. Tokens expire aft...
  [0.0050] c2: Token refresh via /auth/refresh. Send expired token, receive...
  [0.0450] c3: Password hashing uses bcrypt with cost factor 12....

5. Graph Expansion

Same vector seeds — now also show which entities each chunk mentions.

# Same top-3 vector seeds, each expanded to the entities it MENTIONS. A seed
# produces one row per mentioned entity, ordered by distance then entity name.
query_expand = """
    CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
    YIELD node, distance
    MATCH (node)-[:MENTIONS]->(e:Entity)
    RETURN node.chunk_id AS chunk_id, e.name AS entity, distance
    ORDER BY distance, entity
"""
results = session.cypher(query_expand, {"vec": auth_query})
print("Entities mentioned by top auth chunks:")
for r in results:
    print(f"  {r['chunk_id']} -> {r['entity']}")
Entities mentioned by top auth chunks:
  c1 -> JWT
  c1 -> authentication
  c2 -> JWT
  c2 -> authentication
  c3 -> authentication
  c3 -> bcrypt

6. Entity Bridging

Find all chunks related to the auth seeds via shared entity mentions — the core graph RAG technique.

# For each vector seed ("anchor"), find other chunks that mention the same
# entity. The _vid inequality excludes the anchor itself. A pair is reported
# in both directions and once per shared entity, which is why c1 <-> c2
# appears twice in the output (via JWT and via authentication).
query_bridge = """
    CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
    YIELD node AS anchor, distance
    MATCH (anchor)-[:MENTIONS]->(e:Entity)<-[:MENTIONS]-(related:Chunk)
    WHERE related._vid <> anchor._vid
    RETURN anchor.chunk_id AS anchor_id, e.name AS bridge_entity,
           related.chunk_id AS related_id
    ORDER BY anchor_id, bridge_entity
"""
results = session.cypher(query_bridge, {"vec": auth_query})
print("Entity bridges between auth chunks:")
for r in results:
    print(f"  {r['anchor_id']} <-> {r['related_id']} (via {r['bridge_entity']})")
Entity bridges between auth chunks:
  c1 <-> c2 (via JWT)
  c1 <-> c2 (via authentication)
  c1 <-> c3 (via authentication)
  c2 <-> c1 (via JWT)
  c2 <-> c3 (via authentication)
  c2 <-> c1 (via authentication)
  c3 <-> c1 (via authentication)
  c3 <-> c2 (via authentication)

7. Context Assembly

Full hybrid pipeline: vector seeds + graph bridging → collect unique chunks for the LLM context window.

# Hybrid retrieval: the same top-3 vector seeds, expanded to every chunk that
# shares at least one MENTIONS entity with a seed. The pattern has no
# `related <> seed` filter, so a seed that mentions any entity also comes back
# as its own `related` row.
query_ctx = """
    CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
    YIELD node AS seed, distance
    MATCH (seed)-[:MENTIONS]->(e:Entity)<-[:MENTIONS]-(related:Chunk)
    RETURN seed.chunk_id AS seed_id, seed.text AS seed_text,
           related.chunk_id AS related_id, related.text AS related_text,
           e.name AS shared_entity
    ORDER BY seed_id, shared_entity
"""
results = session.cypher(query_ctx, {"vec": auth_query})

# Collect all unique chunk texts for LLM context, keyed by chunk_id so a chunk
# reached through several seeds/entities is included only once.
context_chunks = {}

# The MATCH above yields no row at all for a seed that mentions no entity
# (c8, for example, has no MENTIONS edges), so such a vector hit would silently
# vanish from the context. Add the plain top-3 vector hits (`query_vec`, the
# vector-search query defined earlier) first so every seed is always kept.
for r in session.cypher(query_vec, {"vec": auth_query}):
    context_chunks[r["chunk_id"]] = r["text"]

# Then add the seeds and their graph-expanded neighbours.
for r in results:
    context_chunks[r["seed_id"]] = r["seed_text"]
    context_chunks[r["related_id"]] = r["related_text"]

print(f"Assembled {len(context_chunks)} unique chunks for LLM context:")
for cid, text in sorted(context_chunks.items()):
    print(f"  [{cid}] {text[:70]}...")
Assembled 3 unique chunks for LLM context:
  [c1] JWT tokens issued by /auth/login endpoint. Tokens expire after 1 hour....
  [c2] Token refresh via /auth/refresh. Send expired token, receive new one....
  [c3] Password hashing uses bcrypt with cost factor 12....

8. Query Builder Demo

Using the type-safe query builder to browse entities.

# Find all technology entities using query builder: the filter is written as a
# Python comparison on the model attribute (Entity.entity_type == "technology").
tech_entities = session.query(Entity).filter(Entity.entity_type == "technology").all()

print("Technology entities:")
for entity in tech_entities:
    print(f"  - {entity.name}")

# Count every stored Chunk node (8 were added in the data-creation step).
total_chunks = session.query(Chunk).count()
print(f"Total chunks in knowledge base: {total_chunks}")
Technology entities:
  - JWT
  - bcrypt
Total chunks in knowledge base: 8