Coverage for app \ rag \ graph_rag_pipeline.py: 88%

72 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-24 13:18 +0530

1""" 

2Hybrid Graph-RAG Pipeline 

3 

4Flow: 

51. Update patient graph from question (Neo4j) 

62. Fetch patient medical profile (Neo4j) — includes labs + wearables 

73. Check drug interactions (Neo4j / rules) 

84. Search medical research papers (Qdrant) 

95. Merge all context 

106. Build clinical prompt 

117. Generate LLM response 

128. Extract structured medical claims 

13""" 

14 

15from app.knowledge_graph.patient_graph_reader import ( 

16 upsert_user_from_question, 

17 get_patient_profile, 

18) 

19from app.knowledge_graph.drug_interactions import ( 

20 check_drug_interactions, 

21) 

22from app.vector_store.paper_search import ( 

23 search_papers, 

24) 

25from app.rag.prompt_builder import ( 

26 build_medical_prompt, 

27) 

28from app.rag.claim_extractor import ( 

29 extract_claims, 

30) 

31from app.llm.ollama_client import ( 

32 call_ollama, 

33) 

34from app.utils.logger import get_logger 

35 

36logger = get_logger(__name__) 

37 

38 

def run_hybrid_rag_pipeline(
    user_id: str,
    question: str,
) -> dict:
    """
    Run the Hybrid Graph-RAG pipeline end to end for one patient question.

    The steps are executed strictly in order: upsert the patient node,
    fetch the patient profile, check drug interactions, search research
    papers, merge everything into one context dict, build the prompt,
    call the LLM, and extract structured claims from its answer. The
    results are also echoed to the console via ``_print_results``.

    Args:
        user_id: Patient ID from Neo4j (e.g. "user_1")
        question: The patient's health question

    Returns:
        dict with keys:
            response: the value returned by ``call_ollama``
            claims:   the value returned by ``extract_claims``
            context:  dict with keys ``patient``, ``wearables``,
                      ``papers`` and ``drug_interactions``
    """

    logger.info("Starting Hybrid Graph-RAG", extra={"user_id": user_id})

    # ----------------------------------------------------------------
    # 1. Ensure patient node exists in Neo4j
    # ----------------------------------------------------------------
    logger.info("Step 1 — Upserting patient node")
    upsert_user_from_question(user_id, question)

    # ----------------------------------------------------------------
    # 2. Fetch FULL patient profile from Neo4j
    #    (demographics, conditions, medications, lab_results, wearables)
    # ----------------------------------------------------------------
    logger.info("Step 2 — Fetching patient profile (full graph)")

    # Work on a shallow copy: the "wearables" entry is split off below, and
    # popping it from the reader's own object would silently alter whatever
    # the reader (or a cache behind it) still holds.
    patient_profile = dict(get_patient_profile(user_id))

    # Wearables arrive inside the profile; they are kept as a separate
    # top-level entry of the context, so remove them from the patient part.
    wearables = patient_profile.pop(
        "wearables", {"available": False, "metrics": []}
    )

    logger.info(
        "Patient profile loaded",
        extra={
            "conditions": len(patient_profile.get("conditions", [])),
            "medications": len(patient_profile.get("medications", [])),
            "lab_results": len(patient_profile.get("lab_results", [])),
            "wearable_metrics": len(wearables.get("metrics", [])),
        },
    )

    # ----------------------------------------------------------------
    # 3. Drug interaction safety check
    # ----------------------------------------------------------------
    logger.info("Step 3 — Checking drug interactions")
    drug_interactions = check_drug_interactions(
        medications=patient_profile.get("medications", [])
    )

    # ----------------------------------------------------------------
    # 4. Search relevant research papers (Qdrant vector search)
    # ----------------------------------------------------------------
    logger.info("Step 4 — Searching research papers")
    papers = search_papers(
        query=question,
        top_k=3,  # Keep to 3 — prompt_builder uses max 3
    )

    # Lazy %-formatting: the message is only rendered if INFO is enabled.
    logger.info("Found %d relevant papers", len(papers))

    # ----------------------------------------------------------------
    # 5. Merge full context
    # ----------------------------------------------------------------
    context = {
        "patient": patient_profile,
        "wearables": wearables,
        "papers": papers,
        "drug_interactions": drug_interactions,
    }

    logger.info(
        "Step 5 — Context merged",
        extra={
            "has_wearables": wearables.get("available", False),
            "has_papers": len(papers) > 0,
            "has_drug_facts": bool(drug_interactions),
        },
    )

    # ----------------------------------------------------------------
    # 6. Build clinical-grade prompt
    # ----------------------------------------------------------------
    logger.info("Step 6 — Building clinical prompt")
    prompt = build_medical_prompt(
        question=question,
        context=context,
    )

    # ----------------------------------------------------------------
    # 7. LLM generation
    # ----------------------------------------------------------------
    logger.info("Step 7 — Calling LLM")
    response = call_ollama(prompt)

    # ----------------------------------------------------------------
    # 8. Extract structured medical claims
    # ----------------------------------------------------------------
    logger.info("Step 8 — Extracting medical claims")
    claims = extract_claims(response)

    # ----------------------------------------------------------------
    # Output to console
    # ----------------------------------------------------------------
    _print_results(
        user_id=user_id,
        question=question,
        patient_profile=patient_profile,
        wearables=wearables,
        response=response,
        claims=claims,
    )

    logger.info("Hybrid Graph-RAG completed successfully")

    return {
        "response": response,
        "claims": claims,
        "context": context,
    }

160 

161 

162# ------------------------------------------------------------------ 

163# Console output helper 

164# ------------------------------------------------------------------ 

165 

def _print_results(
    user_id: str,
    question: str,
    patient_profile: dict,
    wearables: dict,
    response: str,
    claims: list,
) -> None:
    """
    Print a structured, human-readable summary of one pipeline run.

    Intended for console debugging only; nothing is returned.

    Args:
        user_id: Patient ID shown next to the patient name.
        question: The question that was asked.
        patient_profile: Profile dict; the keys read here are ``name``,
            ``age``, ``conditions``, ``medications`` and ``lab_results``.
            Each list section is skipped when missing or empty.
        wearables: Dict whose ``metrics`` list is printed, each metric
            followed by its individual ``readings``.
        response: LLM answer, printed verbatim.
        claims: Extracted claims. Each item may be a dict (``type`` and
            ``claim`` keys are read) or any other value, which is printed
            as-is under the type ``general``.
    """

    banner = "=" * 60

    print("\n" + banner)
    print("HYBRID GRAPH-RAG PIPELINE — RESULTS")
    print(banner)

    # Patient summary
    print(f"\n👤 Patient : {patient_profile.get('name', 'Unknown')} ({user_id})")
    print(f" Age : {patient_profile.get('age', 'N/A')}")
    print(f" Question : {question}")

    # Conditions ("or []" also covers a key that is present but None)
    conditions = patient_profile.get("conditions") or []
    if conditions:
        print(f"\n🏥 Conditions ({len(conditions)}):")
        for condition in conditions:
            print(f" - {condition.get('name')} [{condition.get('severity')}]")

    # Medications
    meds = patient_profile.get("medications") or []
    if meds:
        print(f"\n💊 Medications ({len(meds)}):")
        for med in meds:
            # Dosage and frequency are separate fields; keep a space between
            # them so "500mg" and "daily" do not run together.
            print(
                f" - {med.get('name')} {med.get('dosage')} {med.get('frequency')}"
            )

    # Lab results
    labs = patient_profile.get("lab_results") or []
    if labs:
        print(f"\n🧪 Lab Results ({len(labs)}):")
        for lab in labs:
            print(
                f" - {lab.get('name')}: {lab.get('result')} "
                f"{lab.get('unit')} [{lab.get('status')}]"
            )

    # Wearables
    metrics = wearables.get("metrics") or []
    if metrics:
        print(f"\n⌚ Wearable Metrics ({len(metrics)}):")
        for metric in metrics:
            print(
                f" - {metric.get('metric')}: Latest {metric.get('latest_value')}"
                f" | Trend: {metric.get('trend')}"
            )
            for reading in metric.get("readings") or []:
                print(f" [{reading.get('date')}] → {reading.get('value')}")

    # LLM Response
    print("\n" + banner)
    print("💬 Medical Assistant Response")
    print(banner)
    print(response)

    # Claims
    print("\n" + banner)
    print("✅ Extracted Medical Claims")
    print(banner)
    for claim in claims or []:
        if isinstance(claim, dict):
            claim_type = claim.get("type", "general")
            claim_text = claim.get("claim", claim)
        else:
            # A bare (non-dict) claim has no .get(); show it unchanged.
            claim_type, claim_text = "general", claim
        print(f" [{claim_type}] {claim_text}")

229 

230 

231# ------------------------------------------------------------------ 

232# Entry point 

233# ------------------------------------------------------------------ 

234 

if __name__ == "__main__":
    # run_hybrid_rag_pipeline() requires both user_id and question; calling it
    # with no arguments raises TypeError. Take them from the command line so
    # different patients can be tried without editing this file, e.g.:
    #   python -m app.rag.graph_rag_pipeline user_1 "Can I take ibuprofen?"
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the Hybrid Graph-RAG pipeline for one patient question.",
    )
    parser.add_argument("user_id", help='Patient ID from Neo4j (e.g. "user_1")')
    parser.add_argument("question", help="The patient's health question")
    cli_args = parser.parse_args()

    run_hybrid_rag_pipeline(
        user_id=cli_args.user_id,
        question=cli_args.question,
    )