Retrieval-Augmented Generation (RAG) with uni-pydantic¶
Combining vector search with knowledge-graph traversal to build better retrieval context, using Pydantic-style models.
In [1]:
import os
import shutil
import tempfile
import uni_db
from uni_pydantic import UniNode, UniEdge, UniSession, Field, Relationship, Vector
1. Define Models¶
Chunks of text with embeddings, linked to named Entities for knowledge graph traversal.
In [2]:
class Chunk(UniNode):
    """A chunk of text with semantic embedding."""
    __label__ = "Chunk"

    text: str
    embedding: Vector[4] = Field(metric="cosine")  # 4-dim vector for demo

    # Relationships
    entities: list["Entity"] = Relationship("MENTIONS", direction="outgoing")


class Entity(UniNode):
    """A named entity extracted from text."""
    __label__ = "Entity"

    name: str
    entity_type: str = Field(default="unknown")  # function, class, variable, etc.

    # Relationships
    mentioned_in: list[Chunk] = Relationship("MENTIONS", direction="incoming")


class Mentions(UniEdge):
    """Edge representing a chunk mentioning an entity."""
    __edge_type__ = "MENTIONS"
    __from__ = Chunk
    __to__ = Entity
2. Setup Database and Session¶
In [3]:
db_path = os.path.join(tempfile.gettempdir(), "rag_pydantic_db")
if os.path.exists(db_path):
    shutil.rmtree(db_path)
db = uni_db.Database(db_path)
# Create session and register models
session = UniSession(db)
session.register(Chunk, Entity, Mentions)
session.sync_schema()
print(f"Opened database at {db_path}")
Opened database at /tmp/rag_pydantic_db
3. Create Data¶
Ingest text chunks with embeddings and link them to entities.
In [4]:
# Create chunks with embeddings
chunk1 = Chunk(
    text="Function verify() checks cryptographic signatures.",
    embedding=[1.0, 0.0, 0.0, 0.0],
)
chunk2 = Chunk(
    text="The verify function validates input before processing.",
    embedding=[0.9, 0.1, 0.0, 0.0],  # Similar to chunk1
)
chunk3 = Chunk(
    text="Database connections are pooled for efficiency.",
    embedding=[0.0, 0.0, 1.0, 0.0],  # Different topic
)

# Create entities
verify_entity = Entity(name="verify", entity_type="function")
database_entity = Entity(name="database", entity_type="concept")

# Add all to session
session.add_all([chunk1, chunk2, chunk3, verify_entity, database_entity])
session.commit()
print("Created 3 chunks and 2 entities")
Created 3 chunks and 2 entities
In [5]:
# Link chunks to entities
session.create_edge(chunk1, "MENTIONS", verify_entity)
session.create_edge(chunk2, "MENTIONS", verify_entity)
session.create_edge(chunk3, "MENTIONS", database_entity)
session.commit()
print("Created entity mention relationships")
Created entity mention relationships
4. Vector Search¶
Find semantically similar chunks using vector similarity.
In [6]:
# Query vector (similar to chunk1)
query_vec = [0.95, 0.05, 0.0, 0.0]

# Find similar chunks
query = """
MATCH (c:Chunk)
WHERE vector_similarity(c.embedding, $query_vec) > 0.8
RETURN c.text as text
"""
results = session.cypher(query, {"query_vec": query_vec})

print("Chunks similar to query:")
for r in results:
    print(f"  - {r['text']}")
Chunks similar to query:
  - Function verify() checks cryptographic signatures.
  - The verify function validates input before processing.
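As a sanity check, these scores can be reproduced in plain Python. This sketch assumes `vector_similarity` computes cosine similarity, consistent with the `metric="cosine"` declared on the `embedding` field; the helper below is illustrative and not part of uni-pydantic:

```python
import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity: dot(a, b) / (|a| * |b|)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

query_vec = [0.95, 0.05, 0.0, 0.0]
embeddings = {
    "chunk1": [1.0, 0.0, 0.0, 0.0],
    "chunk2": [0.9, 0.1, 0.0, 0.0],
    "chunk3": [0.0, 0.0, 1.0, 0.0],
}
scores = {name: cosine_similarity(query_vec, v) for name, v in embeddings.items()}
# chunk1 and chunk2 clear the 0.8 threshold; chunk3 is orthogonal to the query
# (dot product zero), so its similarity is 0.0 and it is filtered out.
```

This explains the query result above: only the two verify-related chunks pass the `> 0.8` filter.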
5. Hybrid Retrieval¶
Find chunks related to a specific chunk via shared entities (knowledge graph traversal).
In [7]:
# Find chunks that share entities with chunk1
query = """
MATCH (c:Chunk)-[:MENTIONS]->(e:Entity)<-[:MENTIONS]-(related:Chunk)
WHERE c._vid = $cid AND related._vid <> c._vid
RETURN related.text as text, e.name as shared_entity
"""
results = session.cypher(query, {"cid": chunk1.vid})

print("Chunks related to chunk1 via shared entities:")
for r in results:
    print(f"  - '{r['text']}' (via entity: {r['shared_entity']})")
Chunks related to chunk1 via shared entities:
  - 'The verify function validates input before processing.' (via entity: verify)
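The two signals — vector similarity and shared-entity overlap — can also be combined into a single ranking outside the database. A minimal sketch over the demo data; the `hybrid_rank` helper, its `alpha` weighting, and the scoring formula are illustrative choices, not part of uni-pydantic:

```python
import math

def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b)))

# Demo data: (text, embedding, entities mentioned)
chunks = [
    ("Function verify() checks cryptographic signatures.", [1.0, 0.0, 0.0, 0.0], {"verify"}),
    ("The verify function validates input before processing.", [0.9, 0.1, 0.0, 0.0], {"verify"}),
    ("Database connections are pooled for efficiency.", [0.0, 0.0, 1.0, 0.0], {"database"}),
]

def hybrid_rank(query_vec, query_entities, alpha=0.7):
    """Score = alpha * vector similarity + (1 - alpha) * entity overlap."""
    ranked = []
    for text, emb, ents in chunks:
        overlap = len(ents & query_entities) / max(len(query_entities), 1)
        score = alpha * cosine(query_vec, emb) + (1 - alpha) * overlap
        ranked.append((score, text))
    return sorted(ranked, reverse=True)

results = hybrid_rank([0.95, 0.05, 0.0, 0.0], {"verify"})
# The two verify chunks rank first; the database chunk scores lowest on both signals.
```

A production pipeline would compute both signals inside the database as the queries above do; this sketch just makes the re-ranking arithmetic explicit.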
6. Query Builder Demo¶
Using the type-safe query builder to find entities.
In [8]:
# Find all function entities
function_entities = (
    session.query(Entity)
    .filter(Entity.entity_type == "function")
    .all()
)

print("Function entities:")
for entity in function_entities:
    print(f"  - {entity.name} (type: {entity.entity_type})")
Function entities:
  - verify (type: function)
In [9]:
# Count total chunks
total_chunks = session.query(Chunk).count()
print(f"Total chunks in knowledge base: {total_chunks}")
Total chunks in knowledge base: 3
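Once similar and related chunks are retrieved, the final RAG step is packing them into a prompt context for the language model. A minimal sketch; the formatting and the character budget are illustrative choices, not part of uni-pydantic:

```python
def build_context(retrieved: list[str], max_chars: int = 500) -> str:
    """Pack retrieved chunk texts into one context block under a character budget."""
    parts: list[str] = []
    used = 0
    for text in retrieved:
        line = f"- {text}"
        if used + len(line) > max_chars:
            break  # stop before exceeding the budget
        parts.append(line)
        used += len(line)
    return "Context:\n" + "\n".join(parts)

context = build_context([
    "Function verify() checks cryptographic signatures.",
    "The verify function validates input before processing.",
])
```

In practice `retrieved` would be the texts returned by the vector and graph queries above, ordered by score, and the budget would be measured in tokens rather than characters.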