Getting started with discotime

import warnings

import numpy as np
import pandas as pd
import torch

from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping

from discotime.datasets import Mgus2, DataConfig
from discotime.models import LitSurvModule, ModelConfig

For this example, we will be using the Mgus2 dataset from the survival package in R. The dataset contains the natural history of 1341 patients with monoclonal gammopathy of undetermined significance (MGUS) [1]. In the discotime package, this dataset is supplied as a LightningDataModule that automatically handles downloading, preprocessing, splitting into training and testing data, etc.

To initialize the data module, arguments are supplied in the form of a DataConfig object as follows.

mgus2dm = Mgus2(
    DataConfig(
        batch_size=128,
        n_time_bins=20,
        discretization_scheme="number",
    )
)
mgus2dm
<discotime.datasets.from_rstats.Mgus2 object at 0x7effb9c66770>
    n_features: 7, n_risks: 2, n_time_bins: 20
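
The discretization_scheme argument controls how continuous follow-up time is chopped into the requested number of bins. As a rough sketch of what a scheme called "number" typically does (an assumption here; consult the discotime documentation for the authoritative definition), the snippet below builds an equal-count grid from quantiles, so that each bin covers roughly the same number of observations.

import numpy as np

# Illustrative only; not discotime's actual implementation.
rng = np.random.default_rng(0)
event_times = rng.exponential(scale=30.0, size=1000)  # toy follow-up times
cuts = np.quantile(event_times, np.linspace(0, 1, 20 + 1))
bin_index = np.digitize(event_times, cuts[1:-1])  # discrete bins 0..19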

If we want to inspect the data, we can load the fit (train + val) and test splits by calling prepare_data and setup on the object. During normal use, this is handled automatically by the Lightning Trainer.

mgus2dm.prepare_data()
mgus2dm.setup()
mgus2dm.dset_fit
<discotime.datasets.utils.SurvDataset at 0x7effa2047e50>
mgus2dm.dset_fit[0]
SurvData(features=array([ 0.        ,  1.        ,  1.0580127 ,  0.87056947, -0.7417322 ,
       -0.4214596 , -1.1587092 ], dtype=float32), event_time_disc=4, event_status_disc=1, event_time_cont=29, event_status_cont=1)

The data consist of features and labels. The features are the (standardized) covariates, and the labels are a tuple of event times and event statuses, given on both the discrete and the continuous time scale. In the discotime package and across the other examples, an event status of 0 indicates censoring.
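
To verify this coding on the actual data, we can tally the continuous-time event statuses over the fit split. This is a quick sketch; it assumes the dataset supports len() and integer indexing, as used above.

from collections import Counter

# Status 0 is censoring; statuses 1 and 2 are the two competing causes.
Counter(
    mgus2dm.dset_fit[i].event_status_cont
    for i in range(len(mgus2dm.dset_fit))
)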

To create the survival model, we use a similar configuration object, this time a ModelConfig.

model = LitSurvModule(
    ModelConfig(
        learning_rate=1e-3,
        n_hidden_units=15,
        n_sequential_blocks=5,
        use_skip_connections=True,
    )
)
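
As a rough mental model of these settings (a generic sketch, not discotime's actual Net class): with use_skip_connections=True, each of the five sequential blocks would typically compute x + f(x), where f is a small MLP with 15 hidden units.

import torch
from torch import nn

class ResidualBlock(nn.Module):
    """Generic residual block: y = x + MLP(x)."""

    def __init__(self, n_units: int = 15) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n_units, n_units),
            nn.ReLU(),
            nn.Linear(n_units, n_units),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(x)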

The last thing we need is to instantiate the Lightning Trainer. For this example, we will only be using a single CPU, but with the Lightning Trainer it is easy to scale training to multiple devices and/or GPUs (a sketch of a GPU setup follows the trainer output below).

# Stop training once val_loss has failed to improve by at least 0.001 for 30 epochs.
early_stopping = EarlyStopping(
    monitor="val_loss", min_delta=0.001, patience=30, mode="min"
)

warnings.filterwarnings("ignore", "GPU available")
trainer = Trainer(
    accelerator="cpu",
    devices=1,
    max_epochs=3000,
    enable_checkpointing=False,
    enable_progress_bar=False,
    logger=False,
    reload_dataloaders_every_n_epochs=1,
    callbacks=[early_stopping],
)
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
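
As promised above, scaling out is mostly a matter of changing the accelerator arguments. The variant below (not run in this example) would train the same model on two GPUs with standard distributed data parallel.

gpu_trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy="ddp",
    max_epochs=3000,
    callbacks=[early_stopping],
)
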
warnings.filterwarnings("ignore", ".*bottleneck")
trainer.fit(model, mgus2dm)
/home/pechris/work/discotime/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py:411: UserWarning: A layer with UninitializedParameter was found. Thus, the total number of parameters detected may be inaccurate.
  warning_cache.warn(

  | Name   | Type | Params
--------------------------------
0 | _model | Net  | 0
--------------------------------
0         Trainable params
0         Non-trainable params
0         Total params
0.000     Total estimated model params size (MB)
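
The summary reports zero parameters because the network apparently uses lazily initialized layers (hence the UninitializedParameter warning above); their shapes, and thus the parameter count, are only known once the first batch has passed through the model.
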
__ = trainer.test(model, mgus2dm)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      test_IPA_cause1      │    0.19919501087221964    │
│      test_IPA_cause2      │   -0.08615107008082612    │
│         test_loss         │     2.665034770965576     │
└───────────────────────────┴───────────────────────────┘
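
The IPA (index of prediction accuracy) compares the model's Brier score with that of a covariate-free null model: a value of 1 is perfect, 0 means no better than the null model, and negative values (as seen for cause 2 here) mean worse than the null model. Purely as an illustration of the formula (the helper below is hypothetical, not part of discotime):

def ipa(brier_model: float, brier_null: float) -> float:
    """Index of prediction accuracy: 1 - BS_model / BS_null."""
    return 1.0 - brier_model / brier_null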