Source code for discotime.datasets.from_rstats

from pathlib import Path
from typing import Optional
from functools import cached_property

from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import train_test_split
from torch import Generator
import requests
import rdata
import numpy as np
import pandas as pd

import discotime
from discotime.datasets.utils import (
    LitSurvDataModule,
    DataConfig,
    default_fts_transformer,
    SurvDataset,
)
from discotime.datasets.utils import LabelDiscretizer


# Directory that contains the installed `discotime` package.
_ROOT_PATH = Path(discotime.__file__).parents[0]
# Default directory in which the dataset csv files are stored.
_DATA_PATH = _ROOT_PATH.joinpath("datasets", "_data")


class Mgus2(LitSurvDataModule):
    """Data module for the ``mgus2`` dataset from the R survival package.

    Natural history of 1341 sequential patients with monoclonal gammopathy of
    undetermined significance (MGUS).

    The data is downloaded once (see :meth:`prepare_data`), cached as a csv
    file, and split into a "fit" part (80%) and a held-out test part (20%).
    The fit part is further split into training (80%) and validation (20%)
    subsets when `setup` is called for a fitting/validation stage.

    [1]: R. Kyle, T. Therneau, V. Rajkumar, J. Offord, D. Larson, M. Plevak,
    and L. J. Melton III, A long-term study of prognosis in monoclonal
    gammopathy of undetermined significance. New Engl J Med, 346:564-569
    (2002).

    Args:
        data_config: configuration object; its `n_time_bins` and
            `discretization_scheme` attributes configure the label
            discretizer.
        data_dir: directory in which `mgus2.csv` is stored.
        seed: seed used both for the fit/test split and for the generator
            driving the train/validation split.
    """

    n_features = 7
    n_risks = 2

    def __init__(
        self,
        data_config: DataConfig,
        data_dir: Path = _DATA_PATH,
        seed: int = 13411341,
    ) -> None:
        super().__init__()
        self.seed = seed
        self.rng = Generator().manual_seed(seed)
        self.mgus2_csv = data_dir / "mgus2.csv"

        self.fts_transformer = default_fts_transformer()
        self.lab_transformer = LabelDiscretizer(
            n_bins=data_config.n_time_bins,
            scheme=data_config.discretization_scheme,
        )

        self._config = data_config
        # Populated lazily by `setup`.
        self._dset_fit: Optional[SurvDataset] = None
        self._dset_test: Optional[SurvDataset] = None

        self.save_hyperparameters(ignore=["data_dir"])

    def prepare_data(self) -> None:
        """Fetch the built-in mgus2 from the R survival package.

        The mgus2 data [1] is extracted from the rdata file and saved as a
        csv in the discotime installation directory. If the csv is already
        available, then the download logic is skipped.

        The saved csv holds the feature columns `age` through `mspike`
        together with two derived columns:

        * `time`: row-wise minimum of `futime` and `ptime`.
        * `event`: row-wise maximum of `2 * pstat` and `death`.

        Raises:
            requests.HTTPError: if the server answers with an error status.
            requests.Timeout: if the server does not respond in time.

        [1]: Therneau T (2023). A Package for Survival Analysis in R.
        """
        url = "https://github.com/therneau/survival/raw/649851/data/cancer.rda"
        if self.mgus2_csv.is_file():
            return

        # create data directory if needed
        self.mgus2_csv.parent.mkdir(parents=True, exist_ok=True)

        # download and parse rdata file; fail loudly on a bad response
        # instead of handing an error page to the rdata parser
        r = requests.get(url, timeout=60)
        r.raise_for_status()
        parsed = rdata.parser.parse_data(r.content)
        df = rdata.conversion.convert(parsed)["mgus2"]

        # slight reformatting
        x = df.loc[:, "age":"mspike"]  # type: ignore
        t = df[["futime", "ptime"]].min(axis=1)
        # NOTE(review): assuming `pstat` and `death` are 0/1 indicators, this
        # yields 0 for neither, 1 for death only, and 2 whenever pstat is 1.
        e = pd.concat((df["pstat"] * 2, df["death"]), axis=1).max(axis=1)
        y = pd.concat({"time": t, "event": e}, axis=1)

        # combine and save as csv
        pd.concat((x, y), axis=1).to_csv(self.mgus2_csv, index=False)

    @property
    def dset_fit(self) -> SurvDataset:
        """Dataset used for fitting (training + validation).

        Raises:
            AttributeError: if `setup` has not been called yet.
        """
        if self._dset_fit is None:
            raise AttributeError(
                "`dset_fit` not initialized yet. "
                "Maybe try calling `.setup()` first?"
            )
        return self._dset_fit

    @property
    def dset_test(self) -> SurvDataset:
        """Held-out test dataset.

        Raises:
            AttributeError: if `setup` has not been called yet.
        """
        if self._dset_test is None:
            raise AttributeError(
                "`dset_test` not initialized yet. "
                "Maybe try calling `.setup()` first?"
            )
        return self._dset_test

    @cached_property
    def time_range(self):
        """Time interval covered by every non-zero event type.

        Reads the cached csv, drops rows with `event == 0`, and returns a
        tuple `(lower, upper)` where `lower` is the largest of the per-event
        minimum times and `upper` is the smallest of the per-event maximum
        times. The result is computed once and cached on the instance.
        """
        data = pd.read_csv(self.mgus2_csv, dtype={"event": np.int32})
        data = data[data.event != 0]
        time = data.groupby(["event"])["time"]
        return (time.min().max(), time.max().min())

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets needed for the given stage.

        On the first call the csv is read, split 80/20 into fit and test
        parts (seeded by `self.seed`), the feature and label transformers are
        fitted on the fit part only, and both datasets are created.

        For the stages "fit", "validate" and "debug" the fit dataset is also
        split 80/20 into `dset_train` and `dset_val`.

        Args:
            stage: name of the stage the data is being set up for.
        """
        if self._dset_fit is None:
            data = pd.read_csv(self.mgus2_csv, dtype={"event": np.int32})
            features = data.loc[:, "age":"mspike"]  # type: ignore
            outcomes = data.loc[:, ["time", "event"]]

            X_fit, X_test, Y_fit, Y_test = train_test_split(
                features, outcomes, test_size=0.2, random_state=self.seed
            )

            # prepare transformers
            self.lab_transformer.fit(Y_fit.time.values, Y_fit.event.values)
            self.fts_transformer.fit(X_fit)

            # initialize the datasets
            self._dset_fit = SurvDataset(
                features=self.fts_transformer.transform(X_fit),
                event_time=Y_fit.time.values,
                event_status=Y_fit.event.values,
                discretizer=self.lab_transformer,
            )
            self._dset_test = SurvDataset(
                features=self.fts_transformer.transform(X_test),
                event_time=Y_test.time.values,
                event_status=Y_test.event.values,
                discretizer=self.lab_transformer,
            )

        # `self.rng` is stateful, so splitting again on a later call (e.g.
        # "validate" after "fit") would produce a different partition and
        # leak training samples into the validation set. Split only once.
        needs_split = getattr(self, "dset_train", None) is None
        if stage in {"fit", "validate", "debug"} and needs_split:
            self.dset_train, self.dset_val = random_split(
                dataset=self.dset_fit,
                lengths=[0.8, 0.2],
                generator=self.rng,
            )

    # NOTE(review): `self.batch_size` is not assigned in this class; it is
    # presumably provided by `LitSurvDataModule` -- confirm.

    def train_dataloader(self):
        """Return a shuffling loader over the training subset."""
        return DataLoader(
            self.dset_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=1,
        )

    def val_dataloader(self):
        """Return a loader over the validation subset."""
        return DataLoader(
            self.dset_val, batch_size=self.batch_size, num_workers=1
        )

    def test_dataloader(self):
        """Return a loader over the held-out test dataset."""
        return DataLoader(
            self.dset_test, batch_size=self.batch_size, num_workers=1
        )

    def predict_dataloader(self):
        """Return a loader over the held-out test dataset for prediction."""
        return DataLoader(
            self.dset_test, batch_size=self.batch_size, num_workers=1
        )