Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added test-resources/core/mimic3demo/admissions.csv.gz
Binary file not shown.
Binary file not shown.
Binary file added test-resources/core/mimic3demo/patients.csv.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
98 changes: 54 additions & 44 deletions tests/core/test_mimic3.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,94 @@
import unittest
import tempfile
import shutil
import subprocess
import os
from pathlib import Path

from pyhealth.datasets import MIMIC3Dataset


class TestMIMIC3Demo(unittest.TestCase):
"""Test MIMIC3 dataset with demo data downloaded from PhysioNet."""
"""Test MIMIC3 dataset with demo data from local test resources."""

def setUp(self):
"""Download and set up demo dataset for each test."""
self.temp_dir = tempfile.mkdtemp()
self._download_demo_dataset()
"""Set up demo dataset path for each test."""
self._setup_dataset_path()
self._load_dataset()

def tearDown(self):
"""Clean up downloaded dataset after each test."""
if self.temp_dir and os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir)

def _download_demo_dataset(self):
"""Download MIMIC-III demo dataset using wget."""
download_url = "https://physionet.org/files/mimiciii-demo/1.4/"

# Use wget to download the demo dataset recursively
cmd = [
"wget",
"-r",
"-N",
"-c",
"-np",
"--directory-prefix",
self.temp_dir,
download_url,
]

try:
subprocess.run(cmd, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as e:
raise unittest.SkipTest(f"Failed to download MIMIC-III demo dataset: {e}")
except FileNotFoundError:
raise unittest.SkipTest("wget not available - skipping download test")

# Find the downloaded dataset path
physionet_dir = (
Path(self.temp_dir) / "physionet.org" / "files" / "mimiciii-demo" / "1.4"
)
if physionet_dir.exists():
self.demo_dataset_path = str(physionet_dir)
else:
raise unittest.SkipTest("Downloaded dataset not found in expected location")
def _setup_dataset_path(self):
"""Get path to local MIMIC-III demo dataset in test resources."""
# Get the path to the test-resources/core/mimic3demo directory
test_dir = Path(__file__).parent.parent
self.demo_dataset_path = str(test_dir / "test-resources" / "core" / "mimic3demo")

print(f"\n{'='*60}")
print(f"Setting up MIMIC-III demo dataset")
print(f"Dataset path: {self.demo_dataset_path}")

# Verify the dataset exists
if not os.path.exists(self.demo_dataset_path):
raise unittest.SkipTest(
f"MIMIC-III demo dataset not found at {self.demo_dataset_path}"
)

# List files in the dataset directory
files = os.listdir(self.demo_dataset_path)
print(f"Found {len(files)} files in dataset directory:")
for f in sorted(files):
file_path = os.path.join(self.demo_dataset_path, f)
size = os.path.getsize(file_path) / 1024 # KB
print(f" - {f} ({size:.1f} KB)")
print(f"{'='*60}\n")

def _load_dataset(self):
"""Load the dataset for testing."""
tables = ["diagnoses_icd", "procedures_icd", "prescriptions", "noteevents"]
tables = ["diagnoses_icd", "procedures_icd", "prescriptions"]
print(f"Loading MIMIC3Dataset with tables: {tables}")
self.dataset = MIMIC3Dataset(root=self.demo_dataset_path, tables=tables)
print(f"✓ Dataset loaded successfully")
print(f" Total patients: {len(self.dataset.patients)}")
print()

def test_stats(self):
"""Test .stats() method execution."""
print(f"\n{'='*60}")
print("TEST: test_stats()")
print(f"{'='*60}")
try:
print("Calling dataset.stats()...")
self.dataset.stats()
print("✓ dataset.stats() executed successfully")
except Exception as e:
print(f"✗ dataset.stats() failed with error: {e}")
self.fail(f"dataset.stats() failed: {e}")

def test_get_events(self):
"""Test get_patient and get_events methods with patient 10006."""
print(f"\n{'='*60}")
print("TEST: test_get_events()")
print(f"{'='*60}")

# Test get_patient method
print("Getting patient 10006...")
patient = self.dataset.get_patient("10006")
self.assertIsNotNone(patient, msg="Patient 10006 should exist in demo dataset")
print(f"✓ Patient 10006 found: {patient}")

# Test get_events method
print("Getting events for patient 10006...")
events = patient.get_events()
self.assertIsNotNone(events, msg="get_events() should not return None")
self.assertIsInstance(events, list, msg="get_events() should return a list")
self.assertGreater(
len(events), 0, msg="get_events() should not return an empty list"
)
print(f"✓ Retrieved {len(events)} events")
print(f" Event types: {set(e[0] for e in events)}")

# Show sample events
print(f"\nSample events (first 3):")
for i, event in enumerate(events[:3]):
print(f" {i+1}. Type: {event[0]}, Time: {event[1]}, Data: {event[2]}")

print(f"✓ test_get_events() passed successfully")


if __name__ == "__main__":
Expand Down
125 changes: 74 additions & 51 deletions tests/core/test_mimic3_mortality_prediction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import unittest
import tempfile
import shutil
import subprocess
import os
from pathlib import Path

Expand All @@ -13,89 +10,96 @@


class TestMIMIC3MortalityPrediction(unittest.TestCase):
"""Test MIMIC-3 mortality prediction tasks with demo data downloaded from PhysioNet."""
"""Test MIMIC-3 mortality prediction tasks with demo data from local test resources."""

def setUp(self):
"""Download and set up demo dataset for each test."""
self.temp_dir = tempfile.mkdtemp()
self._download_demo_dataset()
"""Set up demo dataset path for each test."""
self._setup_dataset_path()
self._load_dataset()

def tearDown(self):
"""Clean up downloaded dataset after each test."""
if self.temp_dir and os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir)

def _download_demo_dataset(self):
"""Download MIMIC-III demo dataset using wget."""
download_url = "https://physionet.org/files/mimiciii-demo/1.4/"

# Use wget to download the demo dataset recursively
cmd = [
"wget",
"-r",
"-N",
"-c",
"-np",
"--directory-prefix",
self.temp_dir,
download_url,
]

try:
subprocess.run(cmd, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as e:
raise unittest.SkipTest(f"Failed to download MIMIC-III demo dataset: {e}")
except FileNotFoundError:
raise unittest.SkipTest("wget not available - skipping download test")

# Find the downloaded dataset path
physionet_dir = (
Path(self.temp_dir) / "physionet.org" / "files" / "mimiciii-demo" / "1.4"
)
if physionet_dir.exists():
self.demo_dataset_path = str(physionet_dir)
else:
raise unittest.SkipTest("Downloaded dataset not found in expected location")
def _setup_dataset_path(self):
"""Get path to local MIMIC-III demo dataset in test resources."""
# Get the path to the test-resources/core/mimic3demo directory
test_dir = Path(__file__).parent.parent
self.demo_dataset_path = str(test_dir / "test-resources" / "core" / "mimic3demo")

print(f"\n{'='*60}")
print(f"Setting up MIMIC-III demo dataset for mortality prediction")
print(f"Dataset path: {self.demo_dataset_path}")

# Verify the dataset exists
if not os.path.exists(self.demo_dataset_path):
raise unittest.SkipTest(
f"MIMIC-III demo dataset not found at {self.demo_dataset_path}"
)

# List files in the dataset directory
files = os.listdir(self.demo_dataset_path)
print(f"Found {len(files)} files in dataset directory:")
for f in sorted(files):
file_path = os.path.join(self.demo_dataset_path, f)
size = os.path.getsize(file_path) / 1024 # KB
print(f" - {f} ({size:.1f} KB)")
print(f"{'='*60}\n")

def _load_dataset(self):
"""Load the dataset for testing."""
tables = ["diagnoses_icd", "procedures_icd", "prescriptions", "noteevents"]
tables = ["diagnoses_icd", "procedures_icd", "prescriptions"]
print(f"Loading MIMIC3Dataset with tables: {tables}")
self.dataset = MIMIC3Dataset(root=self.demo_dataset_path, tables=tables)
print(f"✓ Dataset loaded successfully")
print(f" Total patients: {len(self.dataset.patients)}")
print()

def test_dataset_stats(self):
"""Test that the dataset loads correctly and stats() works."""
print(f"\n{'='*60}")
print("TEST: test_dataset_stats()")
print(f"{'='*60}")
try:
print("Calling dataset.stats()...")
self.dataset.stats()
print("✓ dataset.stats() executed successfully")
except Exception as e:
print(f"✗ dataset.stats() failed with error: {e}")
self.fail(f"dataset.stats() failed: {e}")

def test_mortality_prediction_mimic3_set_task(self):
"""Test MortalityPredictionMIMIC3 task with set_task() method."""
print(f"\n{'='*60}")
print("TEST: test_mortality_prediction_mimic3_set_task()")
print(f"{'='*60}")

print("Initializing MortalityPredictionMIMIC3 task...")
task = MortalityPredictionMIMIC3()

# Test that task is properly initialized
print(f"✓ Task initialized: {task.task_name}")
self.assertEqual(task.task_name, "MortalityPredictionMIMIC3")
self.assertIn("conditions", task.input_schema)
self.assertIn("procedures", task.input_schema)
self.assertIn("drugs", task.input_schema)
self.assertIn("mortality", task.output_schema)
print(f" Input schema: {list(task.input_schema.keys())}")
print(f" Output schema: {list(task.output_schema.keys())}")

# Test using set_task method
try:
print("\nCalling dataset.set_task()...")
sample_dataset = self.dataset.set_task(task)
self.assertIsNotNone(sample_dataset, "set_task should return a dataset")
self.assertTrue(
hasattr(sample_dataset, "samples"), "Sample dataset should have samples"
)
print(f"✓ set_task() completed")

# Verify we got some samples
self.assertGreater(
len(sample_dataset.samples), 0, "Should generate at least one sample"
)
num_samples = len(sample_dataset.samples)
self.assertGreater(num_samples, 0, "Should generate at least one sample")
print(f"✓ Generated {num_samples} mortality prediction samples")

# Test sample structure
if len(sample_dataset.samples) > 0:
if num_samples > 0:
sample = sample_dataset.samples[0]
required_keys = [
"hadm_id",
Expand All @@ -105,20 +109,39 @@ def test_mortality_prediction_mimic3_set_task(self):
"drugs",
"mortality",
]

print(f"\nFirst sample structure:")
print(f" Sample keys: {list(sample.keys())}")

for key in required_keys:
self.assertIn(key, sample, f"Sample should contain key: {key}")
if key in ["conditions", "procedures", "drugs"]:
print(f" - {key}: {len(sample[key])} items")
else:
print(f" - {key}: {sample[key]}")

# Verify mortality label is binary (0 or 1)
self.assertIn(
sample["mortality"], [0, 1], "Mortality label should be 0 or 1"
)

print(f"Generated {len(sample_dataset.samples)} mortality samples")
print(f"Sample keys: {list(sample.keys())}")

# Count mortality distribution
mortality_counts = {0: 0, 1: 0}
for s in sample_dataset.samples:
mortality_counts[s["mortality"]] += 1
print(f"\nMortality label distribution:")
print(f" Survived (0): {mortality_counts[0]} ({mortality_counts[0]/num_samples*100:.1f}%)")
print(f" Died (1): {mortality_counts[1]} ({mortality_counts[1]/num_samples*100:.1f}%)")

print(f"\n✓ test_mortality_prediction_mimic3_set_task() passed successfully")

except Exception as e:
print(f"✗ Failed with error: {e}")
import traceback
traceback.print_exc()
self.fail(f"Failed to use set_task with MortalityPredictionMIMIC3: {e}")

@unittest.skip("Skipping multimodal test - noteevents not included in test resources")
def test_multimodal_mortality_prediction_mimic3_set_task(self):
"""Test MultimodalMortalityPredictionMIMIC3 task with set_task() method."""
task = MultimodalMortalityPredictionMIMIC3()
Expand Down