Coverage for app \ rag \ graph_rag_pipeline.py: 88%
72 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-24 13:18 +0530
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-24 13:18 +0530
"""
Hybrid Graph-RAG Pipeline

Flow:
1. Update patient graph from question (Neo4j)
2. Fetch patient medical profile (Neo4j) — includes labs + wearables
3. Check drug interactions (Neo4j / rules)
4. Search medical research papers (Qdrant)
5. Merge all context
6. Build clinical prompt
7. Generate LLM response
8. Extract structured medical claims
"""
15from app.knowledge_graph.patient_graph_reader import (
16 upsert_user_from_question,
17 get_patient_profile,
18)
19from app.knowledge_graph.drug_interactions import (
20 check_drug_interactions,
21)
22from app.vector_store.paper_search import (
23 search_papers,
24)
25from app.rag.prompt_builder import (
26 build_medical_prompt,
27)
28from app.rag.claim_extractor import (
29 extract_claims,
30)
31from app.llm.ollama_client import (
32 call_ollama,
33)
34from app.utils.logger import get_logger
36logger = get_logger(__name__)
def run_hybrid_rag_pipeline(
    user_id: str,
    question: str,
) -> dict:
    """
    Run the end-to-end Hybrid Graph-RAG flow for one patient question.

    Steps, in order: upsert the patient node, load the patient profile,
    check drug interactions for the profile's medications, search research
    papers for the question, merge everything into one context dict, build
    the prompt, call the LLM, extract claims from the response, and print
    a console summary.

    No exceptions are caught here; errors raised by any of the helpers
    propagate to the caller.

    Args:
        user_id: Patient ID from Neo4j (e.g. "user_1").
        question: The patient's health question.

    Returns:
        dict with keys:
            response: value returned by ``call_ollama``.
            claims: value returned by ``extract_claims``.
            context: dict with keys ``patient``, ``wearables``, ``papers``
                and ``drug_interactions``.
    """

    logger.info("Starting Hybrid Graph-RAG", extra={"user_id": user_id})

    # ----------------------------------------------------------------
    # 1. Ensure patient node exists in Neo4j
    # ----------------------------------------------------------------
    logger.info("Step 1 — Upserting patient node")
    upsert_user_from_question(user_id, question)

    # ----------------------------------------------------------------
    # 2. Fetch FULL patient profile from Neo4j
    #    Includes: demographics, conditions, medications,
    #    lab_results (flat), wearables + readings
    # ----------------------------------------------------------------
    logger.info("Step 2 — Fetching patient profile (full graph)")
    patient_profile = get_patient_profile(user_id)
    if patient_profile is None:
        # Keep the pipeline running with an empty profile instead of
        # failing on the .pop()/.get() calls below.
        logger.warning(
            "No patient profile returned; continuing with empty profile",
            extra={"user_id": user_id},
        )
        patient_profile = {}

    # Wearables come directly from patient_graph_reader; they are split out
    # of the profile so the context carries them under their own key.
    wearables = patient_profile.pop("wearables", None)
    if wearables is None:
        # Covers both a missing key and an explicit None value.
        wearables = {"available": False, "metrics": []}

    logger.info(
        "Patient profile loaded",
        extra={
            "conditions": len(patient_profile.get("conditions", [])),
            "medications": len(patient_profile.get("medications", [])),
            "lab_results": len(patient_profile.get("lab_results", [])),
            "wearable_metrics": len(wearables.get("metrics", [])),
        },
    )

    # ----------------------------------------------------------------
    # 3. Drug interaction safety check
    # ----------------------------------------------------------------
    logger.info("Step 3 — Checking drug interactions")
    drug_interactions = check_drug_interactions(
        medications=patient_profile.get("medications", [])
    )

    # ----------------------------------------------------------------
    # 4. Search relevant research papers (Qdrant vector search)
    # ----------------------------------------------------------------
    logger.info("Step 4 — Searching research papers")
    papers = search_papers(
        query=question,
        top_k=3,  # Keep to 3 — prompt_builder uses max 3
    )

    # Lazy %-style arguments: the message is only formatted if emitted.
    logger.info("Found %d relevant papers", len(papers))

    # ----------------------------------------------------------------
    # 5. Merge full context
    # ----------------------------------------------------------------
    context = {
        "patient": patient_profile,
        "wearables": wearables,
        "papers": papers,
        "drug_interactions": drug_interactions,
    }

    logger.info(
        "Step 5 — Context merged",
        extra={
            "has_wearables": wearables.get("available", False),
            "has_papers": len(papers) > 0,
            "has_drug_facts": bool(drug_interactions),
        },
    )

    # ----------------------------------------------------------------
    # 6. Build clinical-grade prompt
    # ----------------------------------------------------------------
    logger.info("Step 6 — Building clinical prompt")
    prompt = build_medical_prompt(
        question=question,
        context=context,
    )

    # ----------------------------------------------------------------
    # 7. LLM generation
    # ----------------------------------------------------------------
    logger.info("Step 7 — Calling LLM")
    response = call_ollama(prompt)

    # ----------------------------------------------------------------
    # 8. Extract structured medical claims
    # ----------------------------------------------------------------
    logger.info("Step 8 — Extracting medical claims")
    claims = extract_claims(response)

    # ----------------------------------------------------------------
    # Output to console
    # ----------------------------------------------------------------
    _print_results(
        user_id=user_id,
        question=question,
        patient_profile=patient_profile,
        wearables=wearables,
        response=response,
        claims=claims,
    )

    logger.info("Hybrid Graph-RAG completed successfully")

    return {
        "response": response,
        "claims": claims,
        "context": context,
    }
# ------------------------------------------------------------------
# Console output helper
# ------------------------------------------------------------------

def _print_results(
    user_id: str,
    question: str,
    patient_profile: dict,
    wearables: dict,
    response: str,
    claims: list,
) -> None:
    """
    Print a structured summary of one pipeline run to stdout (debug aid).

    Sections printed, in order: patient summary, conditions, medications,
    lab results, wearable metrics with their readings, the LLM response,
    and the extracted claims. The conditions, medications, lab results and
    wearable sections are skipped when their list is empty or missing.

    Args:
        user_id: Patient identifier shown next to the patient name.
        question: The question that was asked.
        patient_profile: Dict read via ``.get`` for ``name``, ``age``,
            ``conditions``, ``medications`` and ``lab_results``.
        wearables: Dict read via ``.get`` for ``metrics``.
        response: LLM response, printed verbatim.
        claims: Iterable of claims. Dict entries are printed using their
            ``type`` and ``claim`` keys; any other entry (e.g. a plain
            string) is printed as a ``general`` claim. ``None`` is treated
            as no claims.

    Returns:
        None. Output goes to stdout only.
    """
    banner = "=" * 60

    print("\n" + banner)
    print("HYBRID GRAPH-RAG PIPELINE — RESULTS")
    print(banner)

    # Patient summary
    print(f"\n👤 Patient : {patient_profile.get('name', 'Unknown')} ({user_id})")
    print(f" Age : {patient_profile.get('age', 'N/A')}")
    print(f" Question : {question}")

    # Conditions
    conditions = patient_profile.get("conditions", [])
    if conditions:
        print(f"\n🏥 Conditions ({len(conditions)}):")
        for condition in conditions:
            print(f" - {condition.get('name')} [{condition.get('severity')}]")

    # Medications
    meds = patient_profile.get("medications", [])
    if meds:
        print(f"\n💊 Medications ({len(meds)}):")
        for med in meds:
            print(f" - {med.get('name')} {med.get('dosage')} — {med.get('frequency')}")

    # Lab results
    labs = patient_profile.get("lab_results", [])
    if labs:
        print(f"\n🧪 Lab Results ({len(labs)}):")
        for lab in labs:
            print(
                f" - {lab.get('name')}: {lab.get('result')} "
                f"{lab.get('unit')} [{lab.get('status')}]"
            )

    # Wearables
    metrics = wearables.get("metrics", [])
    if metrics:
        print(f"\n⌚ Wearable Metrics ({len(metrics)}):")
        for metric in metrics:
            print(
                f" - {metric.get('metric')}: Latest {metric.get('latest_value')}"
                f" | Trend: {metric.get('trend')}"
            )
            for reading in metric.get("readings", []):
                print(f" [{reading.get('date')}] → {reading.get('value')}")

    # LLM Response
    print("\n" + banner)
    print("💬 Medical Assistant Response")
    print(banner)
    print(response)

    # Claims
    print("\n" + banner)
    print("✅ Extracted Medical Claims")
    print(banner)
    for claim in claims or []:
        if isinstance(claim, dict):
            claim_type = claim.get("type", "general")
            claim_text = claim.get("claim", claim)
        else:
            # Unstructured claim (e.g. a plain string): print it as-is
            # rather than failing on a missing .get().
            claim_type = "general"
            claim_text = claim
        print(f" {claim_type} — {claim_text}")
# ------------------------------------------------------------------
# Entry point
# ------------------------------------------------------------------

if __name__ == "__main__":
    import argparse

    # run_hybrid_rag_pipeline requires both user_id and question, so they
    # are taken from the command line rather than calling it with no
    # arguments. Test with different patients by passing --user-id.
    parser = argparse.ArgumentParser(
        description="Run the Hybrid Graph-RAG pipeline for one question.",
    )
    parser.add_argument(
        "question",
        help="The patient's health question",
    )
    parser.add_argument(
        "--user-id",
        default="user_1",
        help="Patient ID from Neo4j (default: user_1)",
    )
    cli_args = parser.parse_args()

    run_hybrid_rag_pipeline(
        user_id=cli_args.user_id,
        question=cli_args.question,
    )