Coverage for app \ knowledge_graph \ patient_graph_reader.py: 36%

139 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-24 13:18 +0530

"""
Patient Graph Reader

Purpose:
- Write patient-related FACTS into Neo4j
- Read patient medical history as clean, structured JSON
- Used by Hybrid Graph-RAG (NO reasoning, NO research papers)

Design Principles:
- Deterministic
- Auditable
- LLM-safe (facts only, no inference)

Neo4j Domain (relationship types as used by the queries in this module):
Patient -[:HAS_DISEASE]-> Disease -[:TREATED_BY]-> Medication
Patient -[:PRESCRIBED]-> Medication
Disease -[:HAS_LAB_RESULT]-> LabTest
Patient -[:HAS_METRIC]-> WearableMetric -[:RECORDED_AS]-> Reading
"""

20 

import math
import os
from typing import Any, Dict, List, Optional

from neo4j import GraphDatabase

24 

25 

26# ------------------------------------------------------------------ 

27# Neo4j connection 

28# ------------------------------------------------------------------ 

29 

def _get_driver():
    """
    Build a Neo4j driver from the NEO4J_URI / NEO4J_USER / NEO4J_PASSWORD
    environment variables.

    Unset variables fall back to local development defaults. The caller owns
    the returned driver and is responsible for closing it.
    """
    env = os.environ
    uri = env.get("NEO4J_URI", "bolt://localhost:7687")
    credentials = (
        env.get("NEO4J_USER", "neo4j"),
        env.get("NEO4J_PASSWORD", "password"),
    )
    return GraphDatabase.driver(uri, auth=credentials)

35 

36 

37# ------------------------------------------------------------------ 

38# WRITE: Ensure patient exists (facts only) 

39# ------------------------------------------------------------------ 

40 

def upsert_user_from_question(user_id: str, question: str) -> None:
    """
    Ensure a Patient node exists for ``user_id`` (minimal, deterministic).

    On first sight the node is created with ``created_at``; on later calls
    only ``updated_at`` is refreshed. ``question`` is currently unused: no
    diseases or symptoms are extracted from it.

    Args:
        user_id: Value for the Patient node's ``id`` property.
        question: The user's question (ignored by this function).
    """
    cypher = """
    MERGE (p:Patient {id: $user_id})
    ON CREATE SET p.created_at = datetime()
    ON MATCH SET p.updated_at = datetime()
    """

    driver = _get_driver()
    try:
        with driver.session() as session:
            # consume() drains the result so any server-side error surfaces here.
            session.run(cypher, user_id=user_id).consume()
    finally:
        # Close the driver even if the query raises (previously leaked).
        driver.close()

59 

60 

61# ------------------------------------------------------------------ 

62# READ: Fetch patient FACTS profile 

63# ------------------------------------------------------------------ 

64 

def get_patient_profile(user_id: str) -> Dict[str, Any]:
    """
    Fetch the complete patient medical profile as FACTS ONLY.

    No inference, no reasoning; the output is intended to be passed into LLM
    context as-is.

    Fetches:
    - Demographics
    - Conditions, each carrying ONLY the medications (TREATED_BY) and lab
      results (HAS_LAB_RESULT) linked to that specific disease
    - Patient-level medications (PRESCRIBED)
    - A flat, de-duplicated list of all disease-linked lab results
    - Wearable metrics with all dated readings

    Args:
        user_id: Value of the Patient node's ``id`` property.

    Returns:
        The profile dict. If no Patient with ``user_id`` exists, a minimal
        dict containing only ``patient_id``, ``name`` (None) and
        ``facts_version``.
    """
    # ── Patient node + patient-level medications + flat lab list ────────
    # Each OPTIONAL MATCH is aggregated before the next one, so the
    # intermediate row count never becomes diseases x meds x labs.
    core_cypher = """
    MATCH (p:Patient {id: $user_id})

    OPTIONAL MATCH (p)-[:HAS_DISEASE]->(:Disease)-[:HAS_LAB_RESULT]->(l:LabTest)
    WITH p, collect(DISTINCT l) AS lab_tests

    OPTIONAL MATCH (p)-[:PRESCRIBED]->(pm:Medication)

    RETURN
        p,
        collect(DISTINCT pm) AS patient_medications,
        lab_tests
    """

    # ── One row per disease, with ONLY that disease's meds and labs ─────
    disease_cypher = """
    MATCH (p:Patient {id: $user_id})-[:HAS_DISEASE]->(d:Disease)

    OPTIONAL MATCH (d)-[:TREATED_BY]->(m:Medication)
    WITH d, collect(DISTINCT m) AS medications

    OPTIONAL MATCH (d)-[:HAS_LAB_RESULT]->(l:LabTest)

    RETURN
        d AS disease,
        medications,
        collect(DISTINCT l) AS lab_tests
    ORDER BY disease.name
    """

    # ── Wearables + all readings ─────────────────────────────────────────
    wearables_cypher = """
    MATCH (p:Patient {id: $user_id})-[:HAS_METRIC]->(wm:WearableMetric)
    OPTIONAL MATCH (wm)-[:RECORDED_AS]->(r:Reading)

    RETURN
        wm.id AS metric_id,
        wm.type AS type,
        wm.name AS name,
        wm.unit AS unit,
        wm.normalRange AS normal_range,
        collect({
            value: toString(r.value),
            timestamp: toString(r.timestamp)
        }) AS readings
    ORDER BY wm.name
    """

    driver = _get_driver()
    try:
        with driver.session() as session:
            core_record = session.run(core_cypher, user_id=user_id).single()
            if core_record is None:
                # Unknown patient: nothing else to fetch.
                return {
                    "patient_id": user_id,
                    "name": None,
                    "facts_version": "1.0"
                }
            disease_rows = list(session.run(disease_cypher, user_id=user_id))
            wearable_rows = session.run(wearables_cypher, user_id=user_id).data()
    finally:
        # Always release the driver, including when a query raises.
        driver.close()

    patient_node = core_record["p"]

    # Format one disease at a time so each condition only receives its own
    # medications and lab results (previously every disease got all of them).
    conditions: List[Dict[str, Any]] = []
    for row in disease_rows:
        conditions.extend(
            _format_diseases([row["disease"]], row["medications"], row["lab_tests"])
        )

    profile = {
        "patient_id": patient_node.get("id"),
        "name": patient_node.get("name"),
        "age": patient_node.get("age"),
        "gender": patient_node.get("gender"),
        "bloodType": patient_node.get("bloodType"),
        "facts_version": "1.0",

        # ── Demographics (used by prompt_builder) ──────────────────────
        "demographics": {
            "age": patient_node.get("age"),
            "gender": patient_node.get("gender"),
            "blood_type": patient_node.get("bloodType"),
        },

        # ── Conditions with their own labs + meds attached ─────────────
        "conditions": conditions,

        # ── Patient-level medications (used by prompt_builder) ─────────
        "medications": _format_medications(core_record["patient_medications"]),

        # ── Lab results flat list (used by prompt_builder) ─────────────
        "lab_results": _format_labs_flat(core_record["lab_tests"]),

        # ── Wearables block (used by prompt_builder) ───────────────────
        "wearables": _format_wearables(wearable_rows),
    }

    return profile

171 

172 

173# ------------------------------------------------------------------ 

174# Support: Frontend user management 

175# ------------------------------------------------------------------ 

176 

def get_all_patients() -> List[Dict[str, Any]]:
    """
    List all patients (basic facts only) for the frontend dropdown.

    Returns:
        One dict per Patient node with keys ``id``, ``name``, ``age``,
        ``gender`` and ``bloodType``, ordered by ``id``. Properties missing
        on a node come back as None.
    """
    cypher = """
    MATCH (p:Patient)
    RETURN p.id AS id,
           p.name AS name,
           p.age AS age,
           p.gender AS gender,
           p.bloodType AS bloodType
    ORDER BY p.id
    """

    driver = _get_driver()
    try:
        with driver.session() as session:
            result = session.run(cypher)
            # Fully materialize the records before the session closes.
            return [
                {
                    "id": record["id"],
                    "name": record["name"],
                    "age": record["age"],
                    "gender": record["gender"],
                    "bloodType": record["bloodType"],
                }
                for record in result
            ]
    finally:
        # Close the driver even if the query raises (previously leaked).
        driver.close()

209 

210 

def create_patient(
    user_id: str,
    name: Optional[str] = None,
    age: Optional[int] = None,
    gender: Optional[str] = None,
    blood_type: Optional[str] = None,
) -> bool:
    """
    Create a new Patient node (facts only) unless one with ``user_id`` exists.

    The existence check and the create happen in a single MERGE statement.
    Previously a separate MATCH was followed by a CREATE, which let two
    concurrent callers both create the same patient. MERGE alone is only
    race-free if a uniqueness constraint exists on :Patient(id).
    NOTE(review): confirm that constraint is defined.

    Args:
        user_id: Value for the node's ``id`` property.
        name, age, gender, blood_type: Optional demographics. Values passed
            as None are not stored as properties.

    Returns:
        True if a node was created, False if the patient already existed.
        An existing node is left untouched.
    """
    cypher = """
    MERGE (p:Patient {id: $user_id})
    ON CREATE SET
        p.name = $name,
        p.age = $age,
        p.gender = $gender,
        p.bloodType = $blood_type,
        p.created_at = datetime()
    """

    driver = _get_driver()
    try:
        with driver.session() as session:
            summary = session.run(
                cypher,
                user_id=user_id,
                name=name,
                age=age,
                gender=gender,
                blood_type=blood_type,
            ).consume()
    finally:
        # Close the driver even if the query raises (previously leaked).
        driver.close()

    # ON CREATE only fires, and a node is only counted, for a new patient.
    return summary.counters.nodes_created > 0

254 

255 

256# ------------------------------------------------------------------ 

257# Helpers — FACT NORMALIZATION ONLY 

258# ------------------------------------------------------------------ 

259 

260def _format_diseases( 

261 disease_nodes: List[Any], 

262 medication_nodes: List[Any], 

263 lab_nodes: List[Any], 

264) -> List[Dict[str, Any]]: 

265 """ 

266 Normalize disease-related FACTS. 

267 Each disease gets its own labs and medications attached. 

268 """ 

269 diseases = [] 

270 

271 for d in disease_nodes: 

272 if not d: 

273 continue 

274 

275 disease = { 

276 "name": d.get("name"), 

277 "category": d.get("category"), 

278 "severity": d.get("severity"), 

279 "status": d.get("status"), 

280 "icd10": d.get("icd10"), 

281 "diagnosed": _safe_date(d.get("diagnosisDate")), 

282 "medications": [], 

283 "lab_results": [], 

284 } 

285 

286 # Attach lab results 

287 for l in lab_nodes: 

288 if not l: 

289 continue 

290 disease["lab_results"].append({ 

291 "name": l.get("name"), 

292 "result": l.get("result"), 

293 "unit": l.get("unit"), 

294 "normal_range": l.get("normalRange"), 

295 "status": l.get("status"), 

296 "date": _safe_date(l.get("testDate")), 

297 }) 

298 

299 # Attach disease-specific medications 

300 for m in medication_nodes: 

301 if not m: 

302 continue 

303 disease["medications"].append({ 

304 "name": m.get("name"), 

305 "dosage": m.get("dosage"), 

306 "frequency": m.get("frequency"), 

307 "purpose": m.get("purpose"), 

308 }) 

309 

310 diseases.append(disease) 

311 

312 return diseases 

313 

314 

315def _format_medications(nodes: List[Any]) -> List[Dict[str, Any]]: 

316 """ 

317 Normalize patient-level medication FACTS. 

318 """ 

319 meds = [] 

320 

321 for m in nodes: 

322 if not m: 

323 continue 

324 meds.append({ 

325 "name": m.get("name"), 

326 "dosage": m.get("dosage"), 

327 "frequency": m.get("frequency"), 

328 "purpose": m.get("purpose"), 

329 "atc_code": m.get("atcCode"), 

330 }) 

331 

332 return meds 

333 

334 

335def _format_labs_flat(lab_nodes: List[Any]) -> List[Dict[str, Any]]: 

336 """ 

337 Flat list of all lab results for prompt_builder. 

338 Separate from disease-attached labs so prompt_builder 

339 can show them in a dedicated section. 

340 """ 

341 labs = [] 

342 

343 for l in lab_nodes: 

344 if not l: 

345 continue 

346 labs.append({ 

347 "name": l.get("name"), 

348 "result": l.get("result"), 

349 "unit": l.get("unit"), 

350 "normal_range": l.get("normalRange"), 

351 "status": l.get("status"), 

352 "date": _safe_date(l.get("testDate")), 

353 }) 

354 

355 return labs 

356 

357 

358def _format_wearables(wearable_rows: List[Dict]) -> Dict[str, Any]: 

359 """ 

360 Build wearables block from raw Neo4j rows. 

361 

362 For each metric: 

363 - Extracts all dated readings 

364 - Calculates latest, previous, average 

365 - Computes clean trend label (never exposes raw internals) 

366 - Returns block ready for prompt_builder._format_wearables() 

367 """ 

368 

369 if not wearable_rows: 

370 return {"available": False, "metrics": []} 

371 

372 metrics = [] 

373 

374 for row in wearable_rows: 

375 raw_readings = row.get("readings", []) 

376 

377 # Filter out empty readings (no value) 

378 valid_readings = [ 

379 r for r in raw_readings 

380 if r.get("value") and r.get("value") not in ("None", "", "null") 

381 ] 

382 

383 # Sort readings by timestamp ascending 

384 valid_readings.sort(key=lambda r: r.get("timestamp", "")) 

385 

386 # Build clean dated readings list 

387 dated_readings = [ 

388 { 

389 "date": _clean_timestamp(r.get("timestamp", "")), 

390 "value": f"{r.get('value')} {row.get('unit', '')}".strip(), 

391 } 

392 for r in valid_readings 

393 ] 

394 

395 # Compute latest, previous, average for numeric metrics 

396 latest_value = "not recorded yet" 

397 previous_value = "not recorded yet" 

398 average_value = "not recorded yet" 

399 trend = "monitoring ongoing — more readings needed" 

400 

401 numeric_vals = _extract_numeric_values(valid_readings) 

402 

403 if dated_readings: 

404 latest_value = dated_readings[-1]["value"] 

405 

406 if len(dated_readings) >= 2: 

407 previous_value = dated_readings[-2]["value"] 

408 

409 if len(numeric_vals) >= 2: 

410 avg = sum(numeric_vals) / len(numeric_vals) 

411 average_value = f"{avg:.1f} {row.get('unit', '')}".strip() 

412 trend = _compute_trend(numeric_vals) 

413 

414 elif len(numeric_vals) == 1: 

415 average_value = f"{numeric_vals[0]} {row.get('unit', '')}".strip() 

416 trend = "monitoring ongoing — more readings needed" 

417 

418 metrics.append({ 

419 "metric": row.get("name", "Unknown Metric"), 

420 "type": row.get("type"), 

421 "unit": row.get("unit"), 

422 "normal_range": row.get("normal_range", "N/A"), 

423 "latest_value": latest_value, 

424 "previous_value": previous_value, 

425 "average_value": average_value, 

426 "trend": trend, 

427 "readings": dated_readings, 

428 }) 

429 

430 return { 

431 "available": len(metrics) > 0, 

432 "metrics": metrics, 

433 } 

434 

435 

436def _extract_numeric_values(readings: List[Dict]) -> List[float]: 

437 """ 

438 Extract numeric values from readings. 

439 Skips non-numeric values like "NSR", "138/88" etc. 

440 For blood pressure takes systolic only. 

441 """ 

442 values = [] 

443 for r in readings: 

444 raw = str(r.get("value", "")).strip() 

445 # Handle blood pressure format "138/88" 

446 if "/" in raw: 

447 try: 

448 systolic = float(raw.split("/")[0]) 

449 values.append(systolic) 

450 except ValueError: 

451 pass 

452 else: 

453 try: 

454 values.append(float(raw)) 

455 except ValueError: 

456 pass # Non-numeric like "NSR" — skip cleanly 

457 return values 

458 

459 

460def _compute_trend(values: List[float]) -> str: 

461 """ 

462 Compute a clean human-readable trend label. 

463 Never exposes raw internal values. 

464 """ 

465 if len(values) < 2: 

466 return "monitoring ongoing — more readings needed" 

467 

468 first = values[0] 

469 last = values[-1] 

470 diff = last - first 

471 

472 if abs(diff) < 2: 

473 return "stable" 

474 elif diff > 0: 

475 pct = (diff / first) * 100 if first != 0 else 0 

476 return f"increasing ({pct:.1f}% rise over recorded period)" 

477 else: 

478 pct = (abs(diff) / first) * 100 if first != 0 else 0 

479 return f"decreasing ({pct:.1f}% drop over recorded period)" 

480 

481 

482def _clean_timestamp(ts: str) -> str: 

483 """ 

484 Convert Neo4j timestamp string to readable date. 

485 e.g. "2026-02-08T08:00:00Z" → "2026-02-08" 

486 """ 

487 if not ts: 

488 return "unknown date" 

489 return ts[:10] # Take YYYY-MM-DD portion only 

490 

491 

492def _safe_date(value) -> str: 

493 return str(value) if value else None