# Coverage for app\knowledge_graph\patient_graph_reader.py: 36%
# 139 statements
# coverage.py v7.13.4, created at 2026-02-24 13:18 +0530
"""
Patient Graph Reader

Purpose:
- Write minimal patient FACTS (Patient node create/upsert) into Neo4j
- Read patient medical history as clean, structured JSON
- Used by Hybrid Graph-RAG (NO reasoning, NO research papers)

Design Principles:
- Deterministic
- Auditable
- LLM-safe (facts only, no inference)

Neo4j Domain:
Patient -[:HAS_DISEASE]-> Disease -[:TREATED_BY]-> Medication
Patient -[:PRESCRIBED]-> Medication
Disease -[:HAS_LAB_RESULT]-> LabTest
Patient -[:HAS_METRIC]-> WearableMetric -[:RECORDED_AS]-> Reading
"""
import math
import os
from typing import Any, Dict, List, Optional

from neo4j import GraphDatabase
# ------------------------------------------------------------------
# Neo4j connection
# ------------------------------------------------------------------
def _get_driver():
    """
    Create a new Neo4j driver from environment configuration.

    Reads NEO4J_URI, NEO4J_USER and NEO4J_PASSWORD, falling back to
    "bolt://localhost:7687", "neo4j" and "password" when they are unset.
    Every call builds a fresh driver; the caller must close it.
    """
    # NOTE(review): the fallback password is a hard-coded credential; consider
    # failing fast when NEO4J_PASSWORD is unset outside local development.
    env = os.environ
    return GraphDatabase.driver(
        env.get("NEO4J_URI", "bolt://localhost:7687"),
        auth=(
            env.get("NEO4J_USER", "neo4j"),
            env.get("NEO4J_PASSWORD", "password"),
        ),
    )
# ------------------------------------------------------------------
# WRITE: Ensure patient exists (facts only)
# ------------------------------------------------------------------
def upsert_user_from_question(user_id: str, question: str) -> None:
    """
    Ensure a Patient node with the given id exists (minimal, deterministic).

    - Creates the node and sets ``created_at`` if it does not exist.
    - Sets ``updated_at`` if it already exists.
    - Does NOT extract diseases or symptoms. ``question`` is accepted for
      interface compatibility but is not read.
    """
    cypher = """
    MERGE (p:Patient {id: $user_id})
    ON CREATE SET p.created_at = datetime()
    ON MATCH SET p.updated_at = datetime()
    """

    driver = _get_driver()
    try:
        with driver.session() as session:
            # consume() waits for the auto-commit write to finish and surfaces
            # any server error while the driver is still open.
            session.run(cypher, user_id=user_id).consume()
    finally:
        # Always release the driver's connection pool, even if the query fails.
        driver.close()
# ------------------------------------------------------------------
# READ: Fetch patient FACTS profile
# ------------------------------------------------------------------
def get_patient_profile(user_id: str) -> Dict[str, Any]:
    """
    Fetch one patient's medical profile as FACTS ONLY.

    No inference, no reasoning; the output is intended to be passed into an
    LLM context as-is.

    Fetches:
    - Demographics
    - Conditions (severity, diagnosis date). Each condition carries only the
      medications linked to that disease via TREATED_BY and the lab tests
      linked to it via HAS_LAB_RESULT.
    - Patient-level medications (PRESCRIBED)
    - A flat, de-duplicated list of every lab test linked to any of the
      patient's diseases
    - Wearable metrics with all dated readings

    Args:
        user_id: value of the ``id`` property on the Patient node.

    Returns:
        The profile dict. When no Patient node matches, a minimal stub with
        only ``patient_id``, ``name`` (None) and ``facts_version``.
    """
    # ── Core patient + per-disease meds/labs + patient meds + flat labs ──
    # Each OPTIONAL MATCH is aggregated before the next one. This keeps the
    # rows from forming a disease x medication x lab Cartesian product, and
    # keeps each disease paired with only its own medications and labs.
    core_cypher = """
    MATCH (p:Patient {id: $user_id})

    OPTIONAL MATCH (p)-[:HAS_DISEASE]->(d:Disease)
    OPTIONAL MATCH (d)-[:TREATED_BY]->(m:Medication)
    WITH p, d, collect(DISTINCT m) AS meds
    OPTIONAL MATCH (d)-[:HAS_LAB_RESULT]->(l:LabTest)
    WITH p, d, meds, collect(DISTINCT l) AS labs
    WITH p, collect({disease: d, medications: meds, labs: labs}) AS disease_rows

    OPTIONAL MATCH (p)-[:PRESCRIBED]->(pm:Medication)
    WITH p, disease_rows, collect(DISTINCT pm) AS patient_medications

    OPTIONAL MATCH (p)-[:HAS_DISEASE]->(:Disease)-[:HAS_LAB_RESULT]->(lt:LabTest)
    RETURN
        p,
        disease_rows,
        patient_medications,
        collect(DISTINCT lt) AS lab_tests
    """

    # ── Wearables + all readings ─────────────────────────────────────
    wearables_cypher = """
    MATCH (p:Patient {id: $user_id})-[:HAS_METRIC]->(wm:WearableMetric)
    OPTIONAL MATCH (wm)-[:RECORDED_AS]->(r:Reading)

    RETURN
        wm.id AS metric_id,
        wm.type AS type,
        wm.name AS name,
        wm.unit AS unit,
        wm.normalRange AS normal_range,
        collect({
            value: toString(r.value),
            timestamp: toString(r.timestamp)
        }) AS readings
    ORDER BY name
    """

    driver = _get_driver()
    try:
        with driver.session() as session:
            core_record = session.run(core_cypher, user_id=user_id).single()
            wearable_rows = session.run(wearables_cypher, user_id=user_id).data()
    finally:
        # Release the driver even when a query raises.
        driver.close()

    if core_record is None:
        return {
            "patient_id": user_id,
            "name": None,
            "facts_version": "1.0"
        }

    patient_node = core_record["p"]

    # _format_diseases attaches every lab/medication it is given to every
    # disease it is given. Calling it once per disease with that disease's
    # own lists therefore yields correct per-condition attachments. A row
    # whose disease is null (patient has no diseases) is skipped inside the
    # formatter.
    conditions: List[Dict[str, Any]] = []
    for row in core_record["disease_rows"]:
        conditions.extend(
            _format_diseases([row["disease"]], row["medications"], row["labs"])
        )

    profile = {
        "patient_id": patient_node.get("id"),
        "name": patient_node.get("name"),
        "age": patient_node.get("age"),
        "gender": patient_node.get("gender"),
        "bloodType": patient_node.get("bloodType"),
        "facts_version": "1.0",

        # ── Demographics (used by prompt_builder) ──────────────────
        "demographics": {
            "age": patient_node.get("age"),
            "gender": patient_node.get("gender"),
            "blood_type": patient_node.get("bloodType"),
        },

        # ── Conditions, each with only its own labs + meds ─────────
        "conditions": conditions,

        # ── Patient-level medications (used by prompt_builder) ──────
        "medications": _format_medications(core_record["patient_medications"]),

        # ── Lab results flat list (used by prompt_builder) ──────────
        "lab_results": _format_labs_flat(core_record["lab_tests"]),

        # ── Wearables block (used by prompt_builder) ────────────────
        "wearables": _format_wearables(wearable_rows),
    }

    return profile
# ------------------------------------------------------------------
# Support: Frontend user management
# ------------------------------------------------------------------
def get_all_patients() -> List[Dict[str, Any]]:
    """
    List all patients for the frontend dropdown (basic facts only).

    Returns:
        One dict per Patient node, ordered by ``id``, with keys ``id``,
        ``name``, ``age``, ``gender`` and ``bloodType`` (None when the
        property is absent).
    """
    cypher = """
    MATCH (p:Patient)
    RETURN p.id AS id,
           p.name AS name,
           p.age AS age,
           p.gender AS gender,
           p.bloodType AS bloodType
    ORDER BY p.id
    """

    driver = _get_driver()
    try:
        with driver.session() as session:
            result = session.run(cypher)
            patients = [
                {
                    "id": record["id"],
                    "name": record["name"],
                    "age": record["age"],
                    "gender": record["gender"],
                    "bloodType": record["bloodType"],
                }
                for record in result
            ]
    finally:
        # Release the driver even when the query or iteration raises.
        driver.close()

    return patients
def create_patient(
    user_id: str,
    name: Optional[str] = None,
    age: Optional[int] = None,
    gender: Optional[str] = None,
    blood_type: Optional[str] = None,
) -> bool:
    """
    Create a new Patient node (facts only).

    Uses a single MERGE statement instead of a separate existence check
    followed by CREATE. This narrows the window in which two concurrent
    calls could both create the same patient. An existing patient is left
    untouched.

    Args:
        user_id: ``id`` property of the patient.
        name, age, gender, blood_type: optional properties stored only on
            creation (``blood_type`` is stored as ``bloodType``).

    Returns:
        True if a node was created, False if a Patient with this id already
        existed.
    """
    # NOTE(review): full race-safety needs a uniqueness constraint on
    # :Patient(id); confirm one exists in the database schema.
    cypher = """
    MERGE (p:Patient {id: $user_id})
    ON CREATE SET
        p.name = $name,
        p.age = $age,
        p.gender = $gender,
        p.bloodType = $blood_type,
        p.created_at = datetime()
    """

    driver = _get_driver()
    try:
        with driver.session() as session:
            summary = session.run(
                cypher,
                user_id=user_id,
                name=name,
                age=age,
                gender=gender,
                blood_type=blood_type,
            ).consume()
    finally:
        # Release the driver even when the write fails.
        driver.close()

    # MERGE creates a node only when no match existed.
    return summary.counters.nodes_created > 0
# ------------------------------------------------------------------
# Helpers — FACT NORMALIZATION ONLY
# ------------------------------------------------------------------
260def _format_diseases(
261 disease_nodes: List[Any],
262 medication_nodes: List[Any],
263 lab_nodes: List[Any],
264) -> List[Dict[str, Any]]:
265 """
266 Normalize disease-related FACTS.
267 Each disease gets its own labs and medications attached.
268 """
269 diseases = []
271 for d in disease_nodes:
272 if not d:
273 continue
275 disease = {
276 "name": d.get("name"),
277 "category": d.get("category"),
278 "severity": d.get("severity"),
279 "status": d.get("status"),
280 "icd10": d.get("icd10"),
281 "diagnosed": _safe_date(d.get("diagnosisDate")),
282 "medications": [],
283 "lab_results": [],
284 }
286 # Attach lab results
287 for l in lab_nodes:
288 if not l:
289 continue
290 disease["lab_results"].append({
291 "name": l.get("name"),
292 "result": l.get("result"),
293 "unit": l.get("unit"),
294 "normal_range": l.get("normalRange"),
295 "status": l.get("status"),
296 "date": _safe_date(l.get("testDate")),
297 })
299 # Attach disease-specific medications
300 for m in medication_nodes:
301 if not m:
302 continue
303 disease["medications"].append({
304 "name": m.get("name"),
305 "dosage": m.get("dosage"),
306 "frequency": m.get("frequency"),
307 "purpose": m.get("purpose"),
308 })
310 diseases.append(disease)
312 return diseases
315def _format_medications(nodes: List[Any]) -> List[Dict[str, Any]]:
316 """
317 Normalize patient-level medication FACTS.
318 """
319 meds = []
321 for m in nodes:
322 if not m:
323 continue
324 meds.append({
325 "name": m.get("name"),
326 "dosage": m.get("dosage"),
327 "frequency": m.get("frequency"),
328 "purpose": m.get("purpose"),
329 "atc_code": m.get("atcCode"),
330 })
332 return meds
335def _format_labs_flat(lab_nodes: List[Any]) -> List[Dict[str, Any]]:
336 """
337 Flat list of all lab results for prompt_builder.
338 Separate from disease-attached labs so prompt_builder
339 can show them in a dedicated section.
340 """
341 labs = []
343 for l in lab_nodes:
344 if not l:
345 continue
346 labs.append({
347 "name": l.get("name"),
348 "result": l.get("result"),
349 "unit": l.get("unit"),
350 "normal_range": l.get("normalRange"),
351 "status": l.get("status"),
352 "date": _safe_date(l.get("testDate")),
353 })
355 return labs
358def _format_wearables(wearable_rows: List[Dict]) -> Dict[str, Any]:
359 """
360 Build wearables block from raw Neo4j rows.
362 For each metric:
363 - Extracts all dated readings
364 - Calculates latest, previous, average
365 - Computes clean trend label (never exposes raw internals)
366 - Returns block ready for prompt_builder._format_wearables()
367 """
369 if not wearable_rows:
370 return {"available": False, "metrics": []}
372 metrics = []
374 for row in wearable_rows:
375 raw_readings = row.get("readings", [])
377 # Filter out empty readings (no value)
378 valid_readings = [
379 r for r in raw_readings
380 if r.get("value") and r.get("value") not in ("None", "", "null")
381 ]
383 # Sort readings by timestamp ascending
384 valid_readings.sort(key=lambda r: r.get("timestamp", ""))
386 # Build clean dated readings list
387 dated_readings = [
388 {
389 "date": _clean_timestamp(r.get("timestamp", "")),
390 "value": f"{r.get('value')} {row.get('unit', '')}".strip(),
391 }
392 for r in valid_readings
393 ]
395 # Compute latest, previous, average for numeric metrics
396 latest_value = "not recorded yet"
397 previous_value = "not recorded yet"
398 average_value = "not recorded yet"
399 trend = "monitoring ongoing — more readings needed"
401 numeric_vals = _extract_numeric_values(valid_readings)
403 if dated_readings:
404 latest_value = dated_readings[-1]["value"]
406 if len(dated_readings) >= 2:
407 previous_value = dated_readings[-2]["value"]
409 if len(numeric_vals) >= 2:
410 avg = sum(numeric_vals) / len(numeric_vals)
411 average_value = f"{avg:.1f} {row.get('unit', '')}".strip()
412 trend = _compute_trend(numeric_vals)
414 elif len(numeric_vals) == 1:
415 average_value = f"{numeric_vals[0]} {row.get('unit', '')}".strip()
416 trend = "monitoring ongoing — more readings needed"
418 metrics.append({
419 "metric": row.get("name", "Unknown Metric"),
420 "type": row.get("type"),
421 "unit": row.get("unit"),
422 "normal_range": row.get("normal_range", "N/A"),
423 "latest_value": latest_value,
424 "previous_value": previous_value,
425 "average_value": average_value,
426 "trend": trend,
427 "readings": dated_readings,
428 })
430 return {
431 "available": len(metrics) > 0,
432 "metrics": metrics,
433 }
436def _extract_numeric_values(readings: List[Dict]) -> List[float]:
437 """
438 Extract numeric values from readings.
439 Skips non-numeric values like "NSR", "138/88" etc.
440 For blood pressure takes systolic only.
441 """
442 values = []
443 for r in readings:
444 raw = str(r.get("value", "")).strip()
445 # Handle blood pressure format "138/88"
446 if "/" in raw:
447 try:
448 systolic = float(raw.split("/")[0])
449 values.append(systolic)
450 except ValueError:
451 pass
452 else:
453 try:
454 values.append(float(raw))
455 except ValueError:
456 pass # Non-numeric like "NSR" — skip cleanly
457 return values
460def _compute_trend(values: List[float]) -> str:
461 """
462 Compute a clean human-readable trend label.
463 Never exposes raw internal values.
464 """
465 if len(values) < 2:
466 return "monitoring ongoing — more readings needed"
468 first = values[0]
469 last = values[-1]
470 diff = last - first
472 if abs(diff) < 2:
473 return "stable"
474 elif diff > 0:
475 pct = (diff / first) * 100 if first != 0 else 0
476 return f"increasing ({pct:.1f}% rise over recorded period)"
477 else:
478 pct = (abs(diff) / first) * 100 if first != 0 else 0
479 return f"decreasing ({pct:.1f}% drop over recorded period)"
482def _clean_timestamp(ts: str) -> str:
483 """
484 Convert Neo4j timestamp string to readable date.
485 e.g. "2026-02-08T08:00:00Z" → "2026-02-08"
486 """
487 if not ts:
488 return "unknown date"
489 return ts[:10] # Take YYYY-MM-DD portion only
492def _safe_date(value) -> str:
493 return str(value) if value else None