
Locy Flagship: DDI Risk + Joint Regimen Safety with R-GCN-Style Drug Embeddings

Clinical pharmacists triage drug-drug interaction warnings for elderly polypharmacy patients. The clinical question is not "is this pairwise interaction dangerous?" — it's "given this patient's entire regimen of 6 drugs, what's the joint probability that any clinically significant interaction occurs?" This notebook delivers:

  • A real Hetionet v1.0 drug subgraph: 40 Compound nodes + their Gene targets, sourced from the Hetionet TSV.
  • Offline-trained 64-dim drug embeddings from TruncatedSVD over the Compound-Gene bipartite adjacency (vendored as parquet). In production swap in an R-GCN; the deployment pattern is identical.
  • Pseudo-DDI labels from the Vilar-style shared-target heuristic: drugs sharing ≥2 targeted genes are tagged is_dangerous=true.
  • A registered Python classifier that loads the ONNX MLP head + embeddings parquet once, then for each pair resolves the two embeddings, concatenates them, and runs ONNX inference.
  • A joint_regimen_safety rule: FOLD MPROD(1.0 - interaction_score(rec.pair_id)) across all distinct drug pairs in each patient's regimen — inline classifier invocation inside the aggregator.
  • In-Locy CALIBRATE against the is_dangerous labels and VALIDATE reporting Brier + accuracy.
  • Patient risk ranking with worst-contributing-pair annotation.
  • EXPLAIN trace surfacing the classifier's NeuralProvenance per derivation.
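The two offline steps in the list above both start from the same Compound × Gene incidence matrix: the shared-target labels and the TruncatedSVD embeddings. Below is a minimal numpy sketch on a hypothetical 4-compound, 5-gene matrix (not Hetionet data). It assumes a binary "compound targets gene" relation.

- The Gram matrix A Aᵀ counts shared targets for every compound pair.
- The U Σ product from an SVD gives the embedding rows.

scikit-learn's TruncatedSVD returns the same U Σ product, but by default it uses a randomized solver. Its values can therefore differ from this exact SVD in sign and by small numerical error.

```python
import numpy as np

# Hypothetical toy Compound x Gene incidence matrix: 1 = compound targets gene.
# Rows: 4 compounds, columns: 5 genes (illustrative, not real Hetionet data).
A = np.array([
    [1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 0, 0, 1],
], dtype=np.float32)

# Shared-target counts for every compound pair:
# (A @ A.T)[i, j] = |targets_i & targets_j|.
shared = A @ A.T

# Shared-target pseudo-label: a pair is "dangerous" when it shares >= 2 targets.
MIN_SHARED = 2
labels = {
    (i, j): bool(shared[i, j] >= MIN_SHARED)
    for i in range(A.shape[0])
    for j in range(i + 1, A.shape[0])
}

# Truncated SVD embedding: keep the top-K singular directions, scaled by the
# singular values (the U * Sigma product that TruncatedSVD.fit_transform returns).
K = 2  # the notebook uses 64; 2 keeps the toy example readable
U, S, Vt = np.linalg.svd(A, full_matrices=False)
embeddings = U[:, :K] * S[:K]

print(shared.astype(int))
print(labels)
print(embeddings.shape)
```

Rows of `embeddings` play the role of the per-drug vectors stored in the parquet. Replacing this factorization with an R-GCN changes how the rows are produced, not how they are consumed downstream.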

Data: Hetionet v1.0 (CC0 1.0 Universal; Himmelstein DS et al., eLife 2017, DOI: 10.7554/eLife.26726). Runtime dependencies: onnxruntime, pandas, pyarrow (see the notebook-runtime extras group in bindings/uni-db/pyproject.toml).

1) Setup + Schema

Three node labels: Drug, InteractionRecord, and Patient. Two edge types are created during ingest:

  • HAS_INTERACTION_WITH, from each drug of a pair to that pair's InteractionRecord.
  • TAKES, from patient to drug.

The pair_id on each InteractionRecord is the lookup key passed to the classifier, so it can resolve the two drug embeddings at inference time.

import csv
import tempfile
import shutil
from pathlib import Path

import uni_db

WORK_DIR = Path(tempfile.mkdtemp(prefix='uni_locy_ddi_'))
db = uni_db.Uni.open(str(WORK_DIR / 'db'))

(db.schema()
    .label('Drug')
        .property('drug_id', 'string')
        .property('name', 'string')
    .done()
    .label('InteractionRecord')
        .property('pair_id', 'string')
        .property('shared_targets', 'int')
        .property('is_dangerous', 'bool')
    .done()
    .label('Patient')
        .property('patient_id', 'string')
    .done()
    .apply())
print('DB initialized')
DB initialized

2) Load Vendored Hetionet DDI Data + ONNX Artifacts

The prep script (website/scripts/prepare_drug_drug_interaction_notebook_data.py) vendors the curated drug/gene CSVs, the pseudo-DDI pair list, patient regimens, the 64-dim drug embeddings parquet, and the ONNX MLP head.

def _find_data_dir():
    rel = 'website/docs/examples/data/locy_drug_drug_interaction'
    cur = Path.cwd().resolve()
    for parent in (cur, *cur.parents):
        candidate = parent / rel
        if candidate.exists():
            return candidate
    raise AssertionError(
        f'Data directory not found from {cur}. '
        f'Run `python website/scripts/prepare_drug_drug_interaction_notebook_data.py` first.'
    )

DATA_DIR = _find_data_dir()

def _read_csv(name):
    with open(DATA_DIR / name, encoding='utf-8') as f:
        return list(csv.DictReader(f))

DRUG_ROWS = _read_csv('hetionet_ddi_drugs.csv')
PAIR_ROWS = _read_csv('ddi_pairs.csv')
PATIENT_ROWS = _read_csv('ddi_patients.csv')
REGIMEN_ROWS = _read_csv('ddi_patient_regimens.csv')

print(f'Loaded {len(DRUG_ROWS)} Hetionet drugs, {len(PAIR_ROWS)} pseudo-DDI pairs '
      f'({sum(1 for r in PAIR_ROWS if r["is_dangerous"] == "true")} dangerous), '
      f'{len(PATIENT_ROWS)} patients with {len(REGIMEN_ROWS)} drug-regimen edges')
Loaded 40 Hetionet drugs, 568 pseudo-DDI pairs (451 dangerous), 8 patients with 36 drug-regimen edges
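The ingest step in the next section interpolates `is_dangerous` as a bare boolean literal. It also uses `MATCH ... CREATE` for every pair and regimen edge. Under Cypher semantics, a MATCH that finds no node simply creates nothing, so a dangling drug id would silently drop a row.

A small pre-ingest check can catch both problems. Below is a sketch with a hypothetical helper name; it uses only the CSV column names the notebook reads.

```python
def check_referential_integrity(drug_rows, pair_rows, patient_rows, regimen_rows):
    """Return a list of problems; an empty list means every reference resolves."""
    drug_ids = {r['drug_id'] for r in drug_rows}
    patient_ids = {r['patient_id'] for r in patient_rows}
    problems = []
    for r in pair_rows:
        # Both endpoints of a pair must exist, or the MATCH would create nothing.
        for col in ('drug_a_id', 'drug_b_id'):
            if r[col] not in drug_ids:
                problems.append(f"pair {r['pair_id']}: unknown {col} {r[col]}")
        # The value is interpolated as a literal, so it must be exactly true/false.
        if r['is_dangerous'] not in ('true', 'false'):
            problems.append(f"pair {r['pair_id']}: is_dangerous={r['is_dangerous']!r}")
    for r in regimen_rows:
        if r['patient_id'] not in patient_ids:
            problems.append(f"regimen row: unknown patient {r['patient_id']}")
        if r['drug_id'] not in drug_ids:
            problems.append(f"regimen row: unknown drug {r['drug_id']}")
    return problems

# Tiny illustrative rows using the same column names as the vendored CSVs.
drugs = [{'drug_id': 'D1', 'name': 'alpha'}, {'drug_id': 'D2', 'name': 'beta'}]
pairs = [
    {'pair_id': 'PR1', 'drug_a_id': 'D1', 'drug_b_id': 'D2', 'is_dangerous': 'true'},
    {'pair_id': 'PR2', 'drug_a_id': 'D1', 'drug_b_id': 'D9', 'is_dangerous': 'false'},
]
patients = [{'patient_id': 'P1'}]
regimens = [{'patient_id': 'P1', 'drug_id': 'D1'}]
print(check_referential_integrity(drugs, pairs, patients, regimens))
```

In the notebook, the check would be called as `check_referential_integrity(DRUG_ROWS, PAIR_ROWS, PATIENT_ROWS, REGIMEN_ROWS)` before opening the ingest transactions.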

3) Ingest into Uni

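The first transaction creates the Drug and Patient nodes. The second creates the InteractionRecord nodes with their HAS_INTERACTION_WITH edges, plus the TAKES regimen edges. Its MATCH clauses rely on the nodes committed in the first.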

session = db.session()

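def _esc(s):
    # Escape backslashes first, then single quotes, for single-quoted string literals.
    return str(s).replace('\\', '\\\\').replace("'", "\\'")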

# tx1: Drug + Patient nodes.
tx = session.tx()
for d in DRUG_ROWS:
    tx.execute(
        f"CREATE (:Drug {{drug_id: '{_esc(d['drug_id'])}', name: '{_esc(d['name'])}'}})"
    )
for p in PATIENT_ROWS:
    tx.execute(
        f"CREATE (:Patient {{patient_id: '{_esc(p['patient_id'])}'}})"
    )
tx.commit()

# tx2: InteractionRecord nodes + HAS_INTERACTION_WITH edges (one from each
# drug of the pair into the record) + TAKES regimen edges.
tx = session.tx()
for r in PAIR_ROWS:
    tx.execute(
        f"MATCH (a:Drug {{drug_id: '{_esc(r['drug_a_id'])}'}}), "
        f"      (b:Drug {{drug_id: '{_esc(r['drug_b_id'])}'}}) "
        f"CREATE (rec:InteractionRecord {{pair_id: '{_esc(r['pair_id'])}', "
        f"shared_targets: {r['shared_targets']}, is_dangerous: {r['is_dangerous']}}}), "
        f"       (a)-[:HAS_INTERACTION_WITH]->(rec), "
        f"       (b)-[:HAS_INTERACTION_WITH]->(rec)"
    )
for r in REGIMEN_ROWS:
    tx.execute(
        f"MATCH (p:Patient {{patient_id: '{_esc(r['patient_id'])}'}}), "
        f"      (d:Drug {{drug_id: '{_esc(r['drug_id'])}'}}) "
        f"CREATE (p)-[:TAKES]->(d)"
    )
tx.commit()
INGESTED_DRUGS = len(DRUG_ROWS)
INGESTED_PAIRS = len(PAIR_ROWS)
INGESTED_PATIENTS = len(PATIENT_ROWS)
print(f'Ingested {INGESTED_DRUGS} Drug, {INGESTED_PAIRS} InteractionRecord, '
      f'{INGESTED_PATIENTS} Patient')
Ingested 40 Drug, 568 InteractionRecord, 8 Patient

4) Register the ONNX-Backed Pairwise Classifier

The drug embeddings parquet and the ONNX MLP head are loaded once at module level, and the registered callable reuses them on every invocation. For each batch of input rows, the callable:

  1. Receives the InteractionRecord's pair_id as the FEATURES value, keyed by the INPUT name rec.
  2. Resolves pair_id to its two drug_ids via the precomputed mapping.
  3. Looks up the two 64-dim embeddings.
  4. Concatenates them into one 128-dim feature row and runs ONNX inference over the whole batch.
  5. Returns one probability per input row, clamped to [0, 1].

This mirrors the production pattern:

  • Offline graph learning produces the embeddings.
  • A small ONNX head consumes them at runtime.
  • The registered callable bridges Locy and the ONNX runtime.
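The ONNX head is treated as a black box here, and its architecture is not described. For illustration only, here is a numpy sketch of what such a head computes on a batch of concatenated pair embeddings. It assumes a hypothetical one-hidden-layer MLP with ReLU and a sigmoid output, and uses random, untrained weights.

```python
import numpy as np

EMBED_DIM = 64
HIDDEN = 32  # hypothetical hidden width

rng = np.random.default_rng(0)
# Hypothetical, untrained weights standing in for the exported ONNX head.
W1 = rng.normal(scale=0.1, size=(2 * EMBED_DIM, HIDDEN)).astype(np.float32)
b1 = np.zeros(HIDDEN, dtype=np.float32)
W2 = rng.normal(scale=0.1, size=(HIDDEN, 1)).astype(np.float32)
b2 = np.zeros(1, dtype=np.float32)

def mlp_head(concat_embeddings: np.ndarray) -> np.ndarray:
    """Map a (batch, 2*EMBED_DIM) feature matrix to (batch, 1) probabilities."""
    hidden = np.maximum(concat_embeddings @ W1 + b1, 0.0)  # ReLU
    logits = hidden @ W2 + b2
    return 1.0 / (1.0 + np.exp(-logits))  # sigmoid

emb_a = rng.normal(size=EMBED_DIM).astype(np.float32)
emb_b = rng.normal(size=EMBED_DIM).astype(np.float32)

# Concatenation is order-dependent: [a | b] and [b | a] are different inputs,
# so the head is not guaranteed to score a pair symmetrically.
ab = np.concatenate([emb_a, emb_b])[None, :]
ba = np.concatenate([emb_b, emb_a])[None, :]
p_ab, p_ba = mlp_head(ab)[0, 0], mlp_head(ba)[0, 0]
print(p_ab, p_ba)
```

Because each pair_id is stored with a fixed (drug_a_id, drug_b_id) order, the callable always concatenates the two embeddings in that order. A given pair therefore scores consistently, even though the head itself is not symmetric.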

import numpy as np
import onnxruntime as ort
import pyarrow.parquet as pq

# Load drug embeddings (rows: drug_id, e0..e63).
_emb_table = pq.read_table(DATA_DIR / 'drug_embeddings.parquet').to_pylist()
_DRUG_EMBED = {
    row['drug_id']: np.asarray(
        [row[f'e{i}'] for i in range(len(row) - 1)], dtype=np.float32
    )
    for row in _emb_table
}
_EMBED_DIM = next(iter(_DRUG_EMBED.values())).shape[0]
print(f'Loaded {len(_DRUG_EMBED)} drug embeddings × {_EMBED_DIM} dim')

# Load ONNX MLP head.
_ONNX_SESSION = ort.InferenceSession(
    str(DATA_DIR / 'ddi_mlp_head.onnx'),
    providers=['CPUExecutionProvider'],
)

# pair_id -> (drug_a_id, drug_b_id) lookup, sourced from the vendored CSV.
_PAIR_TO_DRUGS = {
    r['pair_id']: (r['drug_a_id'], r['drug_b_id'])
    for r in PAIR_ROWS
}

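def interaction_score(inputs):
    """ONNX-backed DDI classifier: returns one probability per input row."""
    if not inputs:
        return []
    feats = np.zeros((len(inputs), 2 * _EMBED_DIM), dtype=np.float32)
    unresolved = []  # row indices whose pair or embeddings could not be found
    for i, row in enumerate(inputs):
        pair_id = row.get('rec')  # FEATURES value: the InteractionRecord's pair_id
        drugs = _PAIR_TO_DRUGS.get(pair_id) if pair_id is not None else None
        emb_a = _DRUG_EMBED.get(drugs[0]) if drugs else None
        emb_b = _DRUG_EMBED.get(drugs[1]) if drugs else None
        if emb_a is None or emb_b is None:
            unresolved.append(i)
            continue
        # Concatenate [drug_a | drug_b] in the order stored for this pair_id.
        feats[i, :_EMBED_DIM] = emb_a
        feats[i, _EMBED_DIM:] = emb_b
    preds = _ONNX_SESSION.run(['p_interact'], {'concat_embeddings': feats})[0]
    probs = [float(max(0.0, min(1.0, p))) for p in preds.flatten()]
    # A zero feature row is not guaranteed to map to 0.5 through the MLP,
    # so assign the neutral 0.5 prediction explicitly for unresolved rows.
    for i in unresolved:
        probs[i] = 0.5
    return probs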

config = uni_db.LocyConfig()
config.register_classifier('interaction_score', interaction_score)
print(f'Registered classifiers: {config.classifier_aliases()}')
Loaded 40 drug embeddings × 64 dim
Registered classifiers: ['interaction_score']



5) Score Pairs + Compose Joint Regimen Safety

  • scored_interactions: per-pair classifier output via the ONNX MLP head.
  • joint_regimen_safety: per patient, FOLD MPROD(1.0 - interaction_score(rec.pair_id)) across every distinct drug pair in their regimen that has an InteractionRecord. The d1.drug_id < d2.drug_id filter visits each pair once. The classifier is invoked inside the aggregator, so per-pair ONNX inference and regimen composition happen in a single declarative step.
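The MPROD composition can be checked by hand. If pairwise interactions are assumed independent, ∏(1 − p_ij) is the probability that none occurs. In that case 1 − safety is the joint "any interaction" probability the overview asks for.

The sketch below uses hypothetical drug names and risk values. It mirrors the rule: each unordered pair is enumerated once, and pairs without a risk entry are skipped.

```python
from itertools import combinations
from math import prod

# Hypothetical pair risks keyed by an unordered drug pair (illustrative values).
PAIR_RISK = {
    frozenset({'D1', 'D2'}): 0.99,
    frozenset({'D1', 'D3'}): 0.10,
    frozenset({'D2', 'D3'}): 0.20,
}

def joint_safety(regimen, pair_risk):
    """Product of (1 - p) over every distinct drug pair that has a risk entry."""
    factors = [
        1.0 - pair_risk[frozenset(pair)]
        for pair in combinations(sorted(regimen), 2)  # each unordered pair once
        if frozenset(pair) in pair_risk
    ]
    return prod(factors)  # empty product is 1.0: no scored pairs, no risk

safety = joint_safety({'D1', 'D2', 'D3'}, PAIR_RISK)
print(f'{safety:.4f}')  # 0.01 * 0.90 * 0.80
```

A single pair scored near 1.0 contributes a factor near 0, which drives the regimen's safety toward 0 regardless of the other pairs.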
COMPOSE_PROGRAM = '''
CREATE MODEL interaction_score AS
  INPUT (rec)
  FEATURES rec.pair_id
  OUTPUT PROB risk
  USING xervo('classify/ddi-v1')
  VERSION '1.0.0'

CREATE RULE scored_interactions AS
  MATCH (rec:InteractionRecord)
  YIELD KEY rec, interaction_score(rec.pair_id) AS risk PROB

CREATE RULE joint_regimen_safety AS
  MATCH (p:Patient)-[:TAKES]->(d1:Drug)-[:HAS_INTERACTION_WITH]->(rec:InteractionRecord)<-[:HAS_INTERACTION_WITH]-(d2:Drug)<-[:TAKES]-(p)
  WHERE d1.drug_id < d2.drug_id
  FOLD safety = MPROD(1.0 - interaction_score(rec.pair_id))
  YIELD KEY p, safety

// BEST BY: pick the single highest-risk pair per patient.
// One row per Patient. Demonstrates BEST BY as a Locy-
// declarative replacement for the Python max() that drives
// the patient-ranking cell below.
CREATE RULE worst_pair_per_patient AS
  MATCH (p:Patient)-[:TAKES]->(d1:Drug)-[:HAS_INTERACTION_WITH]->(rec:InteractionRecord)<-[:HAS_INTERACTION_WITH]-(d2:Drug)<-[:TAKES]-(p)
  WHERE d1.drug_id < d2.drug_id
  BEST BY pair_risk DESC
  YIELD KEY p, interaction_score(rec.pair_id) AS pair_risk
'''

compose_result = session.locy_with(COMPOSE_PROGRAM).with_config(config).run()
SCORED_COUNT = len(compose_result.derived.get('scored_interactions', []))
JOINT_SAFETY_COUNT = len(compose_result.derived.get('joint_regimen_safety', []))
WORST_PAIR_PER_PATIENT_COUNT = len(compose_result.derived.get('worst_pair_per_patient', []))
print(f'Derived: scored_interactions={SCORED_COUNT}  joint_regimen_safety={JOINT_SAFETY_COUNT}  '
      f'worst_pair_per_patient={WORST_PAIR_PER_PATIENT_COUNT} (one BEST BY row per patient)')

print('\nJoint regimen safety per patient (lower = riskier):')
for row in sorted(compose_result.derived.get('joint_regimen_safety', []), key=lambda r: r['safety']):
    p = row.get('p')
    pid = p.properties.get('patient_id') if hasattr(p, 'properties') else '?'
    print(f'  patient={pid:<8}  safety={row["safety"]:.4f}')
Derived: scored_interactions=568  joint_regimen_safety=8  worst_pair_per_patient=8 (one BEST BY row per patient)

Joint regimen safety per patient (lower = riskier):
  patient=PAT04     safety=0.0000
  patient=PAT06     safety=0.0000
  patient=PAT08     safety=0.0000
  patient=PAT05     safety=0.0000
  patient=PAT03     safety=0.0000
  patient=PAT01     safety=0.0000
  patient=PAT02     safety=0.0000
  patient=PAT07     safety=0.0000

6) Calibrate Against the Vilar-Derived is_dangerous Labels

CALIBRATE_PROGRAM = '''
CREATE MODEL interaction_score AS
  INPUT (rec)
  FEATURES rec.pair_id
  OUTPUT PROB risk
  USING xervo('classify/ddi-v1')
  VERSION '1.0.0'

CALIBRATE interaction_score
  ON MATCH (rec:InteractionRecord)
  TARGET rec.is_dangerous
  METHOD platt_scaling
'''

calib_result = session.locy_with(CALIBRATE_PROGRAM).with_config(config).run()
calib_records = [c for c in calib_result.command_results if isinstance(c, dict) and c.get('type') == 'calibrate']
BRIER_DELTA = None
CALIBRATOR = None  # used downstream by the patient-ranking cell
if calib_records:
    c = calib_records[0]
    print(f'Calibration: {c["method"]}')
    print(f'  raw        brier={c["raw_brier"]:.4f}  ece={c["raw_ece"]:.4f}')
    print(f'  calibrated brier={c["calibrated_brier"]:.4f}  ece={c["calibrated_ece"]:.4f}')
    BRIER_DELTA = c['raw_brier'] - c['calibrated_brier']
    print(f'  delta_brier = {BRIER_DELTA:+.4f}')
    CALIBRATOR = c.get('calibrator')
    print(f'  fitted calibrator: {CALIBRATOR}')
Calibration: Platt
  raw        brier=0.3039  ece=0.3849
  calibrated brier=0.1560  ece=0.0163
  delta_brier = +0.1479
  fitted calibrator: Calibrator(method=Platt)
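Platt scaling fits a two-parameter sigmoid, sigmoid(a · score + b), that maps raw classifier outputs to calibrated probabilities by minimizing log loss against the labels.

This section does not state whether Locy's `platt_scaling` feeds it the raw probability or its logit. The sketch below fits on the raw score by plain gradient descent, using synthetic, deliberately overconfident scores. All helper names are hypothetical.

```python
import numpy as np

def fit_platt(scores, labels, lr=0.1, steps=5000):
    """Fit sigmoid(a * score + b) to 0/1 labels by gradient descent on log loss."""
    s = np.asarray(scores, dtype=np.float64)
    y = np.asarray(labels, dtype=np.float64)
    a, b = 1.0, 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(a * s + b)))
        grad = p - y  # derivative of log loss with respect to the logit
        a -= lr * np.mean(grad * s)
        b -= lr * np.mean(grad)
    return a, b

def apply_platt(a, b, scores):
    return 1.0 / (1.0 + np.exp(-(a * np.asarray(scores, dtype=np.float64) + b)))

def log_loss(p, y):
    p = np.clip(np.asarray(p, dtype=np.float64), 1e-12, 1 - 1e-12)
    y = np.asarray(y, dtype=np.float64)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))

# Synthetic overconfident model: raw scores sit in [0.9, 1.0),
# but only about 60% of the labels are positive.
rng = np.random.default_rng(0)
raw = rng.uniform(0.9, 1.0, size=500)
y = (rng.uniform(size=500) < 0.6).astype(np.float64)

a, b = fit_platt(raw, y)
calibrated = apply_platt(a, b, raw)
print(log_loss(raw, y), log_loss(calibrated, y))
```

With a > 0 the fitted mapping is monotone, so it rescales confidence without reordering pairs. The per-pair risk ranking is then the same whether raw or calibrated scores are used.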

7) Validate

VALIDATE_PROGRAM = '''
CREATE MODEL interaction_score AS
  INPUT (rec)
  FEATURES rec.pair_id
  OUTPUT PROB risk
  USING xervo('classify/ddi-v1')
  VERSION '1.0.0'

CREATE RULE scored_interactions AS
  MATCH (rec:InteractionRecord)
  YIELD KEY rec, interaction_score(rec.pair_id) AS risk PROB

VALIDATE scored_interactions
  ON MATCH (rec:InteractionRecord)
  TARGET rec.is_dangerous
  METRICS brier_score, accuracy
'''

val_result = session.locy_with(VALIDATE_PROGRAM).with_config(config).run()
val_records = [c for c in val_result.command_results if isinstance(c, dict) and c.get('type') == 'validate']
VALIDATE_METRICS = val_records[0]['metrics'] if val_records else {}
print(f'Validation metrics: {VALIDATE_METRICS}')
Validation metrics: {'BrierScore': 0.000600578898840041, 'Accuracy': 0.9982394366197183}
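For reference, here are hand-rolled versions of the metrics reported by VALIDATE (Brier, accuracy) and CALIBRATE (ECE).

- Brier is the mean squared gap between probability and 0/1 label.
- Accuracy thresholds the probability.
- ECE averages |confidence − observed rate| over probability bins, weighted by bin size.

The 0.5 threshold and the 10 equal-width bins are assumptions, since this section does not say what Locy uses.

```python
import numpy as np

def brier_score(p, y):
    """Mean squared error between predicted probabilities and 0/1 labels."""
    p, y = np.asarray(p, float), np.asarray(y, float)
    return float(np.mean((p - y) ** 2))

def accuracy(p, y, threshold=0.5):
    """Fraction of rows where (p >= threshold) agrees with the label."""
    p, y = np.asarray(p, float), np.asarray(y, float)
    return float(np.mean((p >= threshold) == (y == 1.0)))

def expected_calibration_error(p, y, n_bins=10):
    """Equal-width-bin ECE: bin-weighted |mean confidence - observed positive rate|."""
    p, y = np.asarray(p, float), np.asarray(y, float)
    bins = np.minimum((p * n_bins).astype(int), n_bins - 1)  # p == 1.0 goes in last bin
    ece = 0.0
    for b in range(n_bins):
        mask = bins == b
        if mask.any():
            ece += mask.mean() * abs(p[mask].mean() - y[mask].mean())
    return float(ece)

p = [0.9, 0.8, 0.3, 0.1]
y = [1, 1, 0, 1]
print(brier_score(p, y), accuracy(p, y), expected_calibration_error(p, y))
```

Here Brier is (0.01 + 0.04 + 0.09 + 0.81) / 4 = 0.2375. Three of the four thresholded predictions match their labels. Each prediction lands in its own bin, so ECE is the plain mean of the four gaps, 1.5 / 4 = 0.375.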

8) EXPLAIN — Pair Audit

The pair-level EXPLAIN trace shows the classifier inputs and outputs for one InteractionRecord: the first pair labeled is_dangerous=true in the vendored pair list.

first_dangerous = next((r['pair_id'] for r in PAIR_ROWS if r['is_dangerous'] == 'true'), None)
EXPLAIN_PROGRAM = f'''
CREATE MODEL interaction_score AS
  INPUT (rec)
  FEATURES rec.pair_id
  OUTPUT PROB risk
  USING xervo('classify/ddi-v1')
  VERSION '1.0.0'

CREATE RULE scored_interactions AS
  MATCH (rec:InteractionRecord)
  YIELD KEY rec, interaction_score(rec.pair_id) AS risk PROB

EXPLAIN RULE scored_interactions WHERE rec.pair_id = '{first_dangerous}'
'''

explain_result = session.locy_with(EXPLAIN_PROGRAM).with_config(config).run()
explain_records = [c for c in explain_result.command_results if isinstance(c, uni_db.ExplainCommandResult)]
EXPLAIN_PRODUCED = len(explain_records)
print(f'EXPLAIN pair records: {EXPLAIN_PRODUCED} (for pair {first_dangerous})')

def _format_node(node, depth=0, out=None):
    if out is None:
        out = []
    if not isinstance(node, dict):
        return out
    indent = '  ' * depth
    rule = node.get('rule', '?')
    bindings = node.get('bindings', {}) or {}
    pp = node.get('proof_probability')
    out.append(f'{indent}rule={rule}  clause={node.get("clause_index")}  '
               f'proof_p={pp}')
    if bindings:
        keys = sorted(k for k in bindings if not k.startswith('__'))
        kv = ', '.join(f'{k}={bindings[k]!r}' for k in keys[:4])
        out.append(f'{indent}  bindings: {kv}')
    for call in node.get('neural_calls', []) or []:
        out.append(
            f'{indent}  neural: model={call["model_name"]!r} '
            f'raw={call["raw_probability"]:.4f} '
            f'calibrated={call["calibrated_probability"]} '
            f'band={call["confidence_band"]}'
        )
    for child in node.get('children', []) or []:
        _format_node(child, depth + 1, out)
    return out

if explain_records:
    tree = getattr(explain_records[0], 'tree', None)
    if tree is not None:
        print('\n'.join(_format_node(tree)))
EXPLAIN pair records: 1 (for pair PR0001)
rule=scored_interactions  clause=0  proof_p=None
  rule=scored_interactions  clause=0  proof_p=None
    bindings: rec=Node(id=48, labels=["InteractionRecord"], properties={'shared_targets': 16, 'is_dangerous': True, 'pair_id': 'PR0001'}), risk=0.999962568283081
    neural: model='interaction_score' raw=1.0000 calibrated=None band=None
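For audit logs it can help to flatten the EXPLAIN tree into rows rather than print it. The sketch below assumes the tree is a nested dict with the `rule`, `neural_calls` and `children` keys that `_format_node` reads. The helper name and the hand-built tree are hypothetical.

```python
def collect_neural_calls(node, path=()):
    """Flatten every neural call in an EXPLAIN tree into (rule_path, model, raw) tuples."""
    if not isinstance(node, dict):
        return []
    path = path + (node.get('rule', '?'),)
    calls = [
        (path, call.get('model_name'), call.get('raw_probability'))
        for call in node.get('neural_calls') or []
    ]
    for child in node.get('children') or []:
        calls.extend(collect_neural_calls(child, path))
    return calls

# Hand-built tree shaped like the dict _format_node walks (values illustrative).
tree = {
    'rule': 'scored_interactions',
    'children': [{
        'rule': 'scored_interactions',
        'neural_calls': [{'model_name': 'interaction_score', 'raw_probability': 0.99}],
    }],
}
print(collect_neural_calls(tree))
```

Each tuple carries the rule path from the root, so a call can be traced back to the derivation that produced it.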

9) Patient Risk Ranking + Worst-Pair Annotation

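This step combines joint regimen safety with the highest-risk pair in each patient's regimen. Pairs are ranked by the classifier's per-pair score, calibrated when a calibrator is available and raw otherwise. The top pair is the actionable substitution target.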

patient_drug_set = {}
for r in REGIMEN_ROWS:
    patient_drug_set.setdefault(r['patient_id'], set()).add(r['drug_id'])

# Map each pair_id back to (drug_a, drug_b) for cross-reference.
pair_to_drugs = {
    r['pair_id']: (r['drug_a_id'], r['drug_b_id'])
    for r in PAIR_ROWS
}

# Map each pair_id → CALIBRATED classifier-derived risk drawn from
# the scored_interactions rule output. The overview promises the
# 'worst pair' is chosen by the classifier's per-pair prediction,
# so we rank by interaction_score (calibrated when available),
# not the static shared_targets count.
pair_to_risk = {}
for row in compose_result.derived.get('scored_interactions', []):
    rec = row.get('rec')
    pid = rec.properties.get('pair_id') if hasattr(rec, 'properties') else None
    if pid is None:
        continue
    raw = row['risk']
    pair_to_risk[pid] = (
        CALIBRATOR.apply(raw) if CALIBRATOR is not None else raw
    )
if CALIBRATOR is None:
    print('NOTE: no calibrator returned — worst-pair selection uses RAW risk')

patient_worst_pair = {}
for pid, (a, b) in pair_to_drugs.items():
    risk = pair_to_risk.get(pid)
    if risk is None:
        continue
    for pat, drugs in patient_drug_set.items():
        if a in drugs and b in drugs:
            best = patient_worst_pair.get(pat)
            if best is None or risk > best[2]:
                patient_worst_pair[pat] = (pid, (a, b), risk)

ranking = []
for row in compose_result.derived.get('joint_regimen_safety', []):
    p = row.get('p')
    pat = p.properties.get('patient_id') if hasattr(p, 'properties') else '?'
    worst = patient_worst_pair.get(pat)
    ranking.append((pat, row['safety'], worst))
ranking.sort(key=lambda r: r[1])
PATIENT_RANKING_LEN = len(ranking)

print(f'Patient risk ranking ({PATIENT_RANKING_LEN} regimens):')
print(f'  {"patient":<8} {"safety":>7}  worst_pair')
for pat, safety, worst in ranking:
    if worst is None:
        print(f'  {pat:<8} {safety:>7.4f}  (no cross-pair)')
    else:
        pid, (a, b), risk = worst
        print(f'  {pat:<8} {safety:>7.4f}  {pid} ({a}+{b}, pair_risk={risk:.4f})')
Patient risk ranking (8 regimens):
  patient   safety  worst_pair
  PAT04     0.0000  PR0404 (DB00619+DB00753, pair_risk=0.9978)
  PAT06     0.0000  PR0473 (DB00171+DB01151, pair_risk=0.9958)
  PAT08     0.0000  PR0531 (DB01236+DB00997, pair_risk=0.0690)
  PAT05     0.0000  PR0258 (DB08877+DB00619, pair_risk=0.9992)
  PAT03     0.0000  PR0296 (DB00143+DB00661, pair_risk=0.0401)
  PAT01     0.0000  PR0443 (DB01238+DB01236, pair_risk=0.9992)
  PAT02     0.0000  PR0229 (DB05294+DB00143, pair_risk=0.9867)
  PAT07     0.0000  PR0306 (DB06637+DB00753, pair_risk=1.0000)
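The cell above finds each patient's worst pair by scanning every pair for every patient. An alternative sketch, with hypothetical names, enumerates only the pairs inside each regimen and looks them up by an unordered frozenset key. Its cost therefore scales with regimen size, and the drug order stored in the CSV does not matter.

It applies the same max-by-risk rule that the worst_pair_per_patient BEST BY rule expresses declaratively. Ties keep the first pair encountered, which may differ from the original loop's iteration order.

```python
from itertools import combinations

def worst_pair_per_patient(regimens, pair_index, pair_risk):
    """Map each patient to (pair_id, risk) for the riskiest scored pair in their regimen.

    regimens:   {patient_id: set of drug_ids}
    pair_index: {frozenset({drug_a_id, drug_b_id}): pair_id}  (order-free key)
    pair_risk:  {pair_id: risk}
    """
    worst = {}
    for patient, drugs in regimens.items():
        for pair in combinations(sorted(drugs), 2):
            pair_id = pair_index.get(frozenset(pair))
            risk = pair_risk.get(pair_id)
            if risk is None:
                continue  # no InteractionRecord, or no score, for this pair
            if patient not in worst or risk > worst[patient][1]:
                worst[patient] = (pair_id, risk)
    return worst

# Hypothetical toy data, not the notebook's patients.
regimens = {'P1': {'A', 'B', 'C'}, 'P2': {'A', 'D'}}
pair_index = {frozenset({'A', 'B'}): 'PR1', frozenset({'B', 'C'}): 'PR2',
              frozenset({'A', 'C'}): 'PR3'}
pair_risk = {'PR1': 0.2, 'PR2': 0.9, 'PR3': 0.5}
print(worst_pair_per_patient(regimens, pair_index, pair_risk))
```

Patients whose regimen contains no scored pair, like P2 here, are simply absent from the result, matching the "(no cross-pair)" branch above.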

10) Summary + Build-Time Assertions

End to end, this notebook combined four inputs:

  • A real Hetionet drug subgraph.
  • Offline-trained TruncatedSVD 64-dim drug embeddings.
  • An ONNX MLP head loaded by the registered classifier.
  • Vilar-derived is_dangerous ground-truth labels.

On top of these it produced:

  • Joint-regimen-safety composition via FOLD MPROD, with inline ONNX inference inside the aggregator.
  • In-Locy Platt calibration.
  • Brier + accuracy validation.
  • Patient ranking.
  • An EXPLAIN audit trail.

The assertions below guard these results at build time.

assert INGESTED_DRUGS >= 30, f'expected at least 30 drugs, got {INGESTED_DRUGS}'
assert SCORED_COUNT == INGESTED_PAIRS, f'expected {INGESTED_PAIRS} scored rows, got {SCORED_COUNT}'
# A patient yields a joint_regimen_safety row only if at least one pair of their drugs has an InteractionRecord.
assert JOINT_SAFETY_COUNT >= 4, f'JOINT_SAFETY_COUNT={JOINT_SAFETY_COUNT}'
assert PATIENT_RANKING_LEN == JOINT_SAFETY_COUNT, (
    f'ranking should match joint_regimen_safety: {PATIENT_RANKING_LEN} vs {JOINT_SAFETY_COUNT}'
)
assert BRIER_DELTA is not None, 'CALIBRATE should return a record'
assert any('Brier' in k or 'brier' in k for k in VALIDATE_METRICS), (
    f'missing Brier metric: {VALIDATE_METRICS}'
)
assert EXPLAIN_PRODUCED >= 1, 'EXPLAIN should produce at least one record'
print('All build-time assertions passed.')
All build-time assertions passed.

11) Cleanup

del db
shutil.rmtree(WORK_DIR, ignore_errors=True)
print(f'Cleaned up {WORK_DIR}')
Cleaned up /tmp/uni_locy_ddi_jv2k_1m6