Skip to content

Retrieval-Augmented Generation (RAG)

Combining vector search with knowledge graph traversal for hybrid retrieval over Python web framework documentation.

import os
import shutil
import tempfile

import uni_db
# Use a fixed scratch directory under the system temp dir, and wipe anything
# left there by an earlier run so every execution starts from an empty store.
db_path = os.path.join(tempfile.gettempdir(), "rag_db")
if os.path.exists(db_path):
    shutil.rmtree(db_path)

# Open the database at that path and start a session used by all later cells.
db = uni_db.Uni.open(db_path)
session = db.session()
print("Opened database at " + db_path)
Opened database at /tmp/rag_db

1. Schema

Text chunks carry a 4-dimensional embedding and are linked to named entities via MENTIONS edges.

# Graph schema:
#   Chunk    -- chunk_id, text, and a 4-dimensional "embedding" vector
#   Entity   -- name and type
#   MENTIONS -- edges from Chunk vertices to Entity vertices
schema = db.schema()
schema = (
    schema.label("Chunk")
    .property("chunk_id", "string")
    .property("text", "string")
    .vector("embedding", 4)
    .done()
)
schema = (
    schema.label("Entity")
    .property("name", "string")
    .property("type", "string")
    .done()
)
schema = schema.edge_type("MENTIONS", ["Chunk"], ["Entity"]).done()
schema.apply()

print("Schema created")
Schema created

2. Ingest Data

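Eight documentation chunks across four topics (auth, routing, database, testing), six entities, and the MENTIONS edges linking each chunk to the entities it references.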

# Source data, declared up front as (chunk_id, text, embedding) rows.
# Embedding axes (4D): [auth, routing, database, testing].
chunk_rows = [
    (
        "c1",
        "JWT tokens issued by /auth/login endpoint. Tokens expire after 1 hour.",
        [1.0, 0.0, 0.0, 0.0],
    ),
    (
        "c2",
        "Token refresh via /auth/refresh. Send expired token, receive new one.",
        [0.95, 0.05, 0.0, 0.0],
    ),
    (
        "c3",
        "Password hashing uses bcrypt with cost factor 12.",
        [0.85, 0.0, 0.0, 0.15],
    ),
    (
        "c4",
        "Routes defined with @app.route decorator. Supports GET, POST, PUT, DELETE.",
        [0.0, 1.0, 0.0, 0.0],
    ),
    (
        "c5",
        "Middleware intercepts requests before handlers. Register with app.use().",
        [0.05, 0.9, 0.05, 0.0],
    ),
    (
        "c6",
        "ConnectionPool manages DB connections. Max pool size defaults to 10.",
        [0.0, 0.0, 1.0, 0.0],
    ),
    (
        "c7",
        "ORM models inherit from BaseModel. Columns map to database fields.",
        [0.0, 0.1, 0.9, 0.0],
    ),
    (
        "c8",
        "TestClient simulates HTTP requests without starting a server.",
        [0.0, 0.2, 0.0, 0.8],
    ),
]

# (name, type) rows for the Entity vertices.
entity_rows = [
    ("JWT", "technology"),
    ("authentication", "concept"),
    ("routing", "concept"),
    ("database", "concept"),
    ("bcrypt", "technology"),
    ("ConnectionPool", "class"),
]

tx = session.tx()
with tx.bulk_writer().build() as bw:
    # The unpacking below relies on the returned ids matching input order.
    chunk_vids = bw.insert_vertices(
        "Chunk",
        [
            {"chunk_id": cid, "text": body, "embedding": vec}
            for cid, body, vec in chunk_rows
        ],
    )
    c1, c2, c3, c4, c5, c6, c7, c8 = chunk_vids

    entity_vids = bw.insert_vertices(
        "Entity",
        [{"name": name, "type": kind} for name, kind in entity_rows],
    )
    jwt, auth_entity, routing_entity, db_entity, bcrypt_entity, pool_entity = (
        entity_vids
    )

    # Which entities each chunk mentions. This is flattened into one MENTIONS
    # edge per (chunk, entity) pair, keeping the listed order. c8 mentions none.
    mentions_by_chunk = [
        (c1, (jwt, auth_entity)),
        (c2, (jwt, auth_entity)),
        (c3, (bcrypt_entity, auth_entity)),
        (c4, (routing_entity,)),
        (c5, (routing_entity,)),
        (c6, (db_entity, pool_entity)),
        (c7, (db_entity,)),
    ]
    bw.insert_edges(
        "MENTIONS",
        [(src, dst, {}) for src, targets in mentions_by_chunk for dst in targets],
    )

    # Flush the bulk writer first, then commit the enclosing transaction.
    bw.commit()
tx.commit()

print("Data ingested")
Data ingested

3. Vector Search

Find the 3 chunks most similar to an authentication query.

# Query vector pointing purely along the "auth" axis of the 4D topic space.
auth_query = [1.0, 0.0, 0.0, 0.0]

# Top-3 nearest chunks to the query, closest first.
results = session.query(
    """
    CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
    YIELD node, distance
    RETURN node.chunk_id AS chunk_id, node.text AS text, distance
    ORDER BY distance
""",
    {"vec": auth_query},
)

print("Top 3 chunks for auth query:")
for row in results:
    preview = row["text"][:60]
    print(f"  [{row['distance']:.4f}] {row['chunk_id']}: {preview}...")

# Sanity check (order-insensitive): every hit must be one of the auth chunks.
chunk_ids = [row["chunk_id"] for row in results]
expected_ids = {"c1", "c2", "c3"}
assert set(chunk_ids) == expected_ids, f"Expected auth chunks c1/c2/c3, got {chunk_ids}"
Top 3 chunks for auth query:
  [0.0000] c1: JWT tokens issued by /auth/login endpoint. Tokens expire aft...
  [0.0050] c2: Token refresh via /auth/refresh. Send expired token, receive...
  [0.0450] c3: Password hashing uses bcrypt with cost factor 12....

4. Graph Expansion

Same vector seeds — now also show which entities each chunk mentions.

# For each of the 3 nearest chunks, list the entities it MENTIONS.
# Rows are sorted by distance, then chunk_id, then entity. The chunk_id
# tie-breaker keeps each chunk's rows contiguous even when two seeds are
# equidistant from the query; without it their rows would interleave by
# entity name. Note that a plain MATCH yields no rows for a seed chunk
# that has no outgoing MENTIONS edge.
results = session.query(
    """
    CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
    YIELD node, distance
    MATCH (node)-[:MENTIONS]->(e:Entity)
    RETURN node.chunk_id AS chunk_id, e.name AS entity, distance
    ORDER BY distance, chunk_id, entity
""",
    {"vec": auth_query},
)

print("Entities mentioned by top auth chunks:")
for r in results:
    print(f"  {r['chunk_id']} -> {r['entity']}")
Entities mentioned by top auth chunks:
  c1 -> JWT
  c1 -> authentication
  c2 -> JWT
  c2 -> authentication
  c3 -> authentication
  c3 -> bcrypt

5. Entity Bridging

Find all chunks related to the auth seeds through shared entity mentions. This is the core graph-RAG technique: expanding context through shared concepts rather than embedding similarity alone.

# Anchor on the 3 nearest chunks, then hop anchor -> Entity <- other Chunk to
# find chunks that share at least one mentioned entity with the anchor. The
# anchor itself is excluded by comparing vertex ids.
# NOTE(review): ORDER BY covers anchor_id and bridge_entity only, so rows that
# tie on both (same anchor, same entity) come back in engine-defined order.
results = session.query(
    """
    CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
    YIELD node AS anchor, distance
    MATCH (anchor)-[:MENTIONS]->(e:Entity)<-[:MENTIONS]-(related:Chunk)
    WHERE related._vid <> anchor._vid
    RETURN anchor.chunk_id AS anchor_id, e.name AS bridge_entity,
           related.chunk_id AS related_id
    ORDER BY anchor_id, bridge_entity
""",
    {"vec": auth_query},
)

print("Entity bridges between auth chunks:")
for bridge in results:
    left, right, via = bridge["anchor_id"], bridge["related_id"], bridge["bridge_entity"]
    print(f"  {left} <-> {right} (via {via})")
Entity bridges between auth chunks:
  c1 <-> c2 (via JWT)
  c1 <-> c2 (via authentication)
  c1 <-> c3 (via authentication)
  c2 <-> c1 (via JWT)
  c2 <-> c1 (via authentication)
  c2 <-> c3 (via authentication)
  c3 <-> c2 (via authentication)
  c3 <-> c1 (via authentication)

6. Context Assembly

Full hybrid pipeline: vector seeds + graph bridging -> collect unique chunks for the LLM context window.

# Step 1 -- seeds: the 3 nearest chunks by vector similarity. They are
# fetched on their own so that a seed mentioning no entity is still kept in
# the context. Otherwise the MATCH below would return no rows for it and the
# seed would silently disappear.
seed_rows = session.query(
    """
    CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
    YIELD node AS seed, distance
    RETURN seed.chunk_id AS seed_id, seed.text AS seed_text
""",
    {"vec": auth_query},
)

# Step 2 -- bridging: every chunk that shares a mentioned entity with a seed.
# There is no inequality filter, so each seed that mentions any entity also
# matches itself.
results = session.query(
    """
    CALL uni.vector.query('Chunk', 'embedding', $vec, 3)
    YIELD node AS seed, distance
    MATCH (seed)-[:MENTIONS]->(e:Entity)<-[:MENTIONS]-(related:Chunk)
    RETURN seed.chunk_id AS seed_id, seed.text AS seed_text,
           related.chunk_id AS related_id, related.text AS related_text,
           e.name AS shared_entity
    ORDER BY seed_id, shared_entity
""",
    {"vec": auth_query},
)

# Collect unique chunk texts for the LLM context, keyed by chunk_id so that
# chunks reached through several seeds or entities are stored only once.
context_chunks = {r["seed_id"]: r["seed_text"] for r in seed_rows}
for r in results:
    context_chunks[r["related_id"]] = r["related_text"]

print(f"Assembled {len(context_chunks)} unique chunks for LLM context:")
for cid, text in sorted(context_chunks.items()):
    print(f"  [{cid}] {text[:70]}...")
Assembled 3 unique chunks for LLM context:
  [c1] JWT tokens issued by /auth/login endpoint. Tokens expire after 1 hour....
  [c2] Token refresh via /auth/refresh. Send expired token, receive new one....
  [c3] Password hashing uses bcrypt with cost factor 12....