Welcome to discotime’s documentation!

Discotime is a Python package for discrete-time survival analysis with competing risks using neural networks. It relies on PyTorch Lightning to provide an easy-to-use interface, but can still be customized to your heart’s content. The package contains an implementation of discrete time-to-event models for neural networks (using PyTorch), different evaluation metrics, and a couple of different competing risk datasets.

Getting started

  • Look at the examples/recipes (TODO)

  • Read our paper (TODO)

  • Check the repository (TODO)

Discrete-time survival analysis

Installation

Discotime supports Python 3.9 or newer.

We recommend installing via pip:

$ pip install discotime

You can also install the development version of discotime from the master branch of the Git repository:

$ pip install git+https://github.com/peterchristofferholm/discotime.git

Getting started with discotime

import warnings

import numpy as np
import pandas as pd
import torch

from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping

from discotime.datasets import Mgus2, DataConfig
from discotime.models import LitSurvModule, ModelConfig

For this example, we will be using the Mgus2 dataset from the survival package in R. The dataset contains the natural history of 1341 patients with monoclonal gammopathy of undetermined significance (MGUS) [1]. In the discotime package, this dataset is supplied as a LightningDataModule that automatically handles downloading, preprocessing, splitting into training and testing data, etc.

For initialization of the data module, arguments are supplied in the form of a DataConfig object as follows.

mgus2dm = Mgus2(
    DataConfig(
        batch_size=128,
        n_time_bins=20,
        discretization_scheme="number",
    )
)
mgus2dm
<discotime.datasets.from_rstats.Mgus2 object at 0x7effb9c66770>
    n_features: 7, n_risks: 2, n_time_bins: 20

If we want to inspect the data, we can load the fit (train + val) and test data by calling setup on the object. During normal use, this is handled automatically by the Lightning Trainer.

mgus2dm.prepare_data()
mgus2dm.setup()
mgus2dm.dset_fit
<discotime.datasets.utils.SurvDataset at 0x7effa2047e50>
mgus2dm.dset_fit[0]
SurvData(features=array([ 0.        ,  1.        ,  1.0580127 ,  0.87056947, -0.7417322 ,
       -0.4214596 , -1.1587092 ], dtype=float32), event_time_disc=4, event_status_disc=1, event_time_cont=29, event_status_cont=1)

The data consist of features and labels. The features should be fairly self-explanatory, and the labels are the survival times and event indicators (in both continuous and discretized form). In the discotime package and across the other examples, an event status of 0 indicates censoring.
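Each sample is returned as a SurvData named tuple, so the individual fields can also be accessed by name. As a small sketch (assuming the data module has been set up as above), the learned discretization grid and the labels of the first sample can be inspected like this:

sample = mgus2dm.dset_fit[0]
grid = mgus2dm.cuts  # cut points of the discretization grid
t_disc, e_disc = sample.event_time_disc, sample.event_status_disc  # discretized label
t_cont, e_cont = sample.event_time_cont, sample.event_status_cont  # original (continuous) label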

To create the survival model, we use a different configuration object.

model = LitSurvModule(
    ModelConfig(
        learning_rate=1e-3,
        n_hidden_units=15,
        n_sequential_blocks=5,
        use_skip_connections=True,
    )
)

The last thing we need is to instantiate the Lightning Trainer. For this example, we will only be using a single CPU, but with Lightning it is easy to scale training to multiple devices and/or GPUs.

early_stopping = EarlyStopping(
    monitor="val_loss", min_delta=0.001, patience=30, mode="min"
)

warnings.filterwarnings("ignore", "GPU available")
trainer = Trainer(
    accelerator="cpu",
    devices=1,
    max_epochs=3000,
    enable_checkpointing=False,
    enable_progress_bar=False,
    logger=False,
    reload_dataloaders_every_n_epochs=1,
    callbacks=[early_stopping],
)
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
warnings.filterwarnings("ignore", ".*bottleneck")
trainer.fit(model, mgus2dm)
/home/pechris/work/discotime/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py:411: UserWarning: A layer with UninitializedParameter was found. Thus, the total number of parameters detected may be inaccurate.
  warning_cache.warn(

  | Name   | Type | Params
--------------------------------
0 | _model | Net  | 0
--------------------------------
0         Trainable params
0         Non-trainable params
0         Total params
0.000     Total estimated model params size (MB)
__ = trainer.test(model, mgus2dm)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      test_IPA_cause1      │    0.19919501087221964    │
│      test_IPA_cause2      │   -0.08615107008082612    │
│         test_loss         │     2.665034770965576     │
└───────────────────────────┴───────────────────────────┘

discotime

discotime package

Subpackages
discotime.datasets package
Submodules
discotime.datasets.from_rstats module
class discotime.datasets.from_rstats.Mgus2(data_config: DataConfig, data_dir: Path = PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/discotime/envs/latest/lib/python3.10/site-packages/discotime/datasets/_data'), seed: int = 13411341)[source]

Bases: LitSurvDataModule

Natural history of 1341 sequential patients with monoclonal gammopathy of undetermined significance (MGUS).

[1]: R. Kyle, T. Therneau, V. Rajkumar, J. Offord, D. Larson, M. Plevak, and L. J. Melton III, A long-term study of prognosis in monoclonal gammopathy of undetermined significance. New Engl J Med, 346:564-569 (2002).

property dset_fit: SurvDataset
property dset_test: SurvDataset
n_features = 7
n_risks = 2
predict_dataloader()[source]
prepare_data() None[source]

Fetch the built-in mgus2 from the R survival package.

The mgus2 data [1] is extracted from the rdata file and saved as a csv in the discotime installation directory. If the csv is already available, then the download logic is skipped.

[1]: Therneau T (2023). A Package for Survival Analysis in R.

setup(stage: str | None = None) None[source]
test_dataloader()[source]
property time_range
train_dataloader()[source]
val_dataloader()[source]
discotime.datasets.utils module
class discotime.datasets.utils.DataConfig(*, batch_size: int = 32, n_time_bins: int = 20, discretization_scheme: str = 'number', discretization_grid: list[float] | None = None, max_time: float | None = None)[source]

Bases: object

Configuration class for data modules.

batch_size: int = 32

The batch size defines the number of samples that will be propagated through the network at each training step.

discretization_grid: list[float] | None = None
discretization_scheme: str = 'number'
max_time: float | None = None
n_time_bins: int = 20

Specifies the size of the discretization grid. A value of around 20-30 usually works well.
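As a small usage sketch, a DataConfig is constructed with keyword arguments and passed to a data module (here the Mgus2 module described above):

from discotime.datasets import Mgus2, DataConfig

config = DataConfig(
    batch_size=64,
    n_time_bins=25,
    discretization_scheme="number",
)
mgus2dm = Mgus2(config)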

class discotime.datasets.utils.LabelDiscretizer(scheme: str | None = None, n_bins: int | None = None, *, cut_points: Iterable[int | int64 | float | float64] | None = None, max_time: int | int64 | float | float64 | None = None)[source]

Bases: object

Discretize continuous time/event pairs.

The class can either learn a discretization grid from the training data using one of the built-in discretization schemes, or the user can supply an iterable with cut points.

Implementation heavily inspired by pycox.preprocessing.label_transform [1].

[1]: Kvamme, Håvard, Ørnulf Borgan, and Ida Scheel. “Time-to-event prediction with neural networks and Cox regression.” arXiv preprint arXiv:1907.00825 (2019).

property cuts: ndarray[Any, dtype[float64]]
fit(time: Iterable[int | int64 | float | float64], event: Iterable[int | int64]) None[source]
fit_transform(time: Iterable[int | int64 | float | float64], event: Iterable[int | int64]) tuple[numpy.ndarray[Any, numpy.dtype[numpy.integer]], numpy.ndarray[Any, numpy.dtype[numpy.integer]]][source]
property max_time: int | int64 | float | float64
transform(time: Iterable[int | int64 | float | float64], event: Iterable[int | int64]) tuple[numpy.ndarray[Any, numpy.dtype[numpy.integer]], numpy.ndarray[Any, numpy.dtype[numpy.integer]]][source]
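A minimal sketch of learning a grid and discretizing labels; the "number" scheme string is assumed to be accepted here in the same way as in DataConfig:

import numpy as np
from discotime.datasets.utils import LabelDiscretizer

time = np.array([1.0, 3.5, 7.2, 10.0, 12.3])
event = np.array([1, 0, 2, 1, 0])

discretizer = LabelDiscretizer(scheme="number", n_bins=3)
time_disc, event_disc = discretizer.fit_transform(time, event)  # integer bin indices
grid = discretizer.cuts  # learned cut points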
class discotime.datasets.utils.LitSurvDataModule[source]

Bases: LightningDataModule

property batch_size: int
property config: DataConfig
property cuts: ndarray[Any, dtype[float64]]
property lab_transformer: LabelTransformer
abstract property n_features: int
abstract property n_risks: int
property n_time_bins: int
abstract setup(stage: str | None = None) None[source]
abstract property time_range: tuple[int | numpy.int64 | float | numpy.float64, int | numpy.int64 | float | numpy.float64]
class discotime.datasets.utils.SurvDataset(features: _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], event_time: _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], event_status: _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], discretizer: LabelTransformer)[source]

Bases: Dataset

Assemble a survival dataset for discrete-time survival analysis.

A discrete time survival dataset \(\mathfrak{D}\) is a set of \(n\) tuples \((t_{i}, \delta_{i}, \mathbf{x}_{i})\) where \(t_{i} = \min\{T_{i}, C_{i}\}\) is the event time, \(\delta_{i} \in \{0, \ldots, m\}\) is the event indicator (with \(\delta_{i} = 0\) defined as censoring), and \(\mathbf{x}_{i} \in \mathbb{R}^{d}\) is a \(d\)-dimensional vector of time-independent predictors or covariates.

Parameters:
  • features – time-independent features.

  • event_time – follow-up time (continuous).

  • event_status – event indicator (0=censored, 1/2/…=competing risks).

  • discretizer – discretizer following the LabelTransformer protocol that converts continuous time/event tuples to their respective discretized versions. Typically this would be LabelDiscretizer unless a custom discretization object is used.
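A rough sketch of assembling a SurvDataset by hand; a LabelDiscretizer is used as the discretizer, and it is assumed here that it should be fitted on the labels before being passed in:

import numpy as np
from discotime.datasets.utils import LabelDiscretizer, SurvDataset

features = np.random.rand(5, 3).astype(np.float32)
event_time = np.array([1.0, 3.5, 7.2, 10.0, 12.3])
event_status = np.array([1, 0, 2, 1, 0])

discretizer = LabelDiscretizer(scheme="number", n_bins=3)
discretizer.fit(event_time, event_status)

dset = SurvDataset(features, event_time, event_status, discretizer)
first_sample = dset[0]  # a SurvData named tuple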

discotime.datasets.utils.default_fts_transformer()[source]
discotime.metrics package
Submodules
discotime.metrics.brier_score module
class discotime.metrics.brier_score.BrierScore(timepoints: _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], causes: list[int] | None = None)[source]

Bases: object

Brier score of survival model with right-censored data.

In the case of right-censored data, possibly unknown time-dependent event status can be replaced with jackknife pseudovalues from the marginal Aalen-Johansen estimates [1].

Parameters:
  • timepoints – sequence of t timepoints to evaluate.

  • causes – causes for which the brier score is calculated.

Refs:

[1] Cortese, Giuliana, Thomas A. Gerds, and Per K. Andersen. “Comparing predictions among competing risks models with time‐dependent covariates.” Statistics in medicine 32.18 (2013): 3089-3101.

class Results(null, model)

Bases: tuple

model

Alias for field number 1

null

Alias for field number 0

discotime.metrics.ipa module
class discotime.metrics.ipa.IPA(timepoints: ndarray[Any, dtype[float64]], causes: list[int] | None = None, integrate: bool = True)[source]

Bases: object

Index of prediction accuracy (IPA) for right-censored survival models.

In the context of survival analysis, the IPA is also known as the Brier skill score (BSS).

discotime.models package
Submodules
discotime.models.components module
class discotime.models.components.Block(*, n_hidden_units: int, add_skip_connection: bool = True, activation_function: ~typing.Type[~torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.SiLU'>, batch_normalization: bool = True, dropout_rate: float | None = None)[source]

Bases: Module

Neural network building block for the Net class.

Parameters:
  • n_hidden_units – number of units in each hidden layer.

  • add_skip_connection – Defaults to True.

  • activation_function – Defaults to nn.SiLU.

  • batch_normalization – Should batch normalization be performed? Defaults to True.

  • dropout_rate – passed along to nn.Dropout(). If None, dropout is not used. Defaults to None.

forward(x: Tensor) Tensor[source]
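A small sketch of running a Block on its own; since the layers are lazily initialized, the input dimension is inferred on the first forward pass, and with the default skip connection the number of input features is assumed to equal n_hidden_units so that x + net(x) is well defined:

import torch
from discotime.models.components import Block

block = Block(n_hidden_units=16, add_skip_connection=True, dropout_rate=0.1)
x = torch.randn(8, 16)  # (batch, features), features assumed equal to n_hidden_units
out = block(x)          # output expected to keep the (batch, n_hidden_units) shape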
discotime.models.components.negative_log_likelihood(logits: Tensor, time: Tensor, event: Tensor) Tensor[source]

Negative log-likelihood for logistic hazard model with competing risks.

The hazards are expected to be given on the logit scale, i.e. they should not have been passed through torch.log_softmax() or similar.

An implementation of equation (8.6) from Tutz and Schmid [1], inspired by the one in pycox following Kvamme et al. [2]

Parameters:

[1]: Tutz, Gerhard, and Matthias Schmid. Modeling discrete time-to-event data. New York: Springer, 2016.

[2]: Kvamme, Håvard, Ørnulf Borgan, and Ida Scheel. “Time-to-event prediction with neural networks and Cox regression.” arXiv preprint arXiv:1907.00825 (2019).

discotime.models.models module
class discotime.models.models.LitSurvModule(config: ModelConfig)[source]

Bases: LightningModule

configure_optimizers() Optimizer[source]
property data_cuts: Tensor
property datamodule: LitSurvDataModule
property eval_grid: Tensor
forward(x: Tensor) Tensor[source]
property model: Module
on_load_checkpoint(checkpoint: Dict[str, Any]) None[source]
on_save_checkpoint(checkpoint: Dict[str, Any]) None[source]
on_test_epoch_end() None[source]
on_validation_epoch_end() None[source]
setup(stage)[source]

Called at the beginning of fit (train + validate), validate, test, or predict. This is where the PyTorch model gets instantiated.

Parameters:

stage (str) – either "fit", "validate", "test", or "predict"

test_step(batch: tuple[torch.Tensor, ...], batch_idx: int) Tensor[source]
training_step(batch: tuple[torch.Tensor, ...], batch_idx: int) Tensor[source]
validation_step(batch: tuple[torch.Tensor, ...], batch_idx: int) Tensor[source]
class discotime.models.models.ModelConfig(*, learning_rate: float = 0.001, activation_function: str = 'SiLU', n_sequential_blocks: int = 2, n_hidden_units: int = 20, dropout_rate: float | None = None, batch_normalization: bool = True, use_skip_connections: bool = False, evaluation_grid_size: int = 50, evaluation_grid_cuts: int | None = None, n_time_bins: int | None = None, n_risks: int | None = None)[source]

Bases: object

activation_function: str = 'SiLU'

Name of activation function (str). The name needs to match one of the activation functions in the torch.nn module.

batch_normalization: bool = True

Use batch normalization in each block? Default is True.

dropout_rate: float | None = None

Should nn.Dropout() be used in each block, and if so what is the rate of dropout? If None then dropout is not used. Default is None.

evaluation_grid_cuts: int | None = None

Either the length (int) of the evaluation grid or the specific grid (sequence of floats) at which model metrics are calculated. If an integer n is passed, then n evenly distributed timepoints are chosen from the range of the data.

evaluation_grid_size: int = 50

Either the length (int) of the evaluation grid or the specific grid (sequence of floats) at which model metrics are calculated. If an integer n is passed, then n evenly distributed timepoints are chosen from the range of the data.

learning_rate: float = 0.001
n_hidden_units: int = 20

Number of neurons in each of the hidden layers.

For ease of use, the size of each hidden layer is constrained to be the same as all the others. Default is 20.

n_risks: int | None = None

How many competing risks are we dealing with?

If None, which is the default, then we try to get the information from the attached datamodule during setup.

n_sequential_blocks: int = 2

Number of neural-network layer blocks.

Each block is modelled after a ResNet skip-block and contains the following key elements:

Sequential(
    (0): LazyBatchNorm1d()
    (1): LazyLinear()
    (2): DropOut()
    (3): SiLU()
    (4): LazyLinear()
    (5): DropOut()
)

By default the output of each block is self.act_fn(x + self.net(x)), but the skip part can be removed by setting use_skip_connections to False. n_hidden_units controls the size of the linear layers in each block. See components.Block for more details on the actual implementation.

n_time_bins: int | None = None

Number of time bins in discretization grid.

If None, which is the default, then we try to get the information from the attached datamodule during setup.

use_skip_connections: bool = False

Toggle the use of skip-connections in the model blocks.
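As a short sketch, a more heavily regularized configuration than the one in the getting-started example could look like this; n_risks and n_time_bins are left as None so they are picked up from the attached datamodule during setup:

from discotime.models import LitSurvModule, ModelConfig

config = ModelConfig(
    learning_rate=3e-4,
    n_hidden_units=32,
    n_sequential_blocks=3,
    dropout_rate=0.2,
    batch_normalization=True,
    use_skip_connections=True,
)
model = LitSurvModule(config)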

discotime.utils package
Submodules
discotime.utils.estimators module
class discotime.utils.estimators.AalenJohansen(time: _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], event: _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], n_causes: int | int64 | None = None)[source]

Bases: object

Obtain cumulative incidence curves with the Aalen-Johansen method.

Parameters:
  • time – event times

  • event – event indicator (0/1/../c) with 0=censored

  • n_causes – how many causes should be included? If None (the default) then all observed causes are included.

class discotime.utils.estimators.KaplanMeier(time: Iterable[int | int64 | float | float64], event: Iterable[int | int64])[source]

Bases: object

Simple implementation of the Kaplan-Meier estimator.

Parameters:
  • time – event times.

  • event – event indicator (0/1) where 0 is censoring.

Example

>>> km = KaplanMeier(time=[0, 1.5, 1.3, 3], event=[0, 1, 0, 0])
>>> km(0)
array([1.])
>>> km([0, 1.0, 1.1, 1.5])
array([1. , 1. , 1. , 0.5])
percentile(p: ~collections.abc.Iterable[int | ~numpy.int64 | float | ~numpy.float64], dtype=<class 'numpy.float64'>) ndarray[Any, dtype[float64]][source]

Obtain approximate timepoint t such that P(t) = p.

The stepwise Kaplan-Meier estimator is piecewise linearly interpolated such that unique timepoints can be obtained.
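Continuing the Kaplan-Meier example above, the percentile method can be used to invert the curve (a sketch; the exact values depend on the interpolation):

t75 = km.percentile([0.75])  # approximate time t at which the survival probability equals 0.75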

discotime.utils.estimators.interpolate2d(x: Tensor, xp: Tensor, yp: Tensor)[source]

Perform stepwise linear interpolation of a discrete function.

xp and yp are tensors of values used to approximate f: y = f(x). This function uses interpolation to find the values at the new points x.

Parameters:
  • x (torch.Tensor) – a 1D tensor of real values.

  • xp (torch.Tensor) – a 1D tensor of real values.

  • yp (torch.Tensor) – an ND tensor of real values. The length of yp along the second axis (dim=1) must match the length of xp.
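A minimal sketch under the stated shape assumptions (xp of length m, and yp with m entries along dim=1):

import torch
from discotime.utils.estimators import interpolate2d

xp = torch.tensor([0.0, 1.0, 2.0, 3.0])  # grid of length m = 4
yp = torch.rand(2, 4)                    # two discrete functions evaluated on that grid
x = torch.tensor([0.5, 1.5, 2.5])        # new points at which to interpolate
y = interpolate2d(x, xp, yp)             # interpolated values at x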

discotime.utils.misc module
class discotime.utils.misc.BaseFabricTrainer(module: LitSurvModule, dset_train: SurvDataset, dset_valid: SurvDataset, batch_size: int)[source]

Bases: object

Bare-bones trainer class using Fabric.

This trainer class supports only a very limited subset of the full Lightning functionality. It can be useful for quick prototyping or for k-fold cross-validation, as exemplified by the scripts in the /experiments folder on GitHub.

class discotime.utils.misc.EarlyStopping(patience: int, min_delta: float, direction: str = 'minimize')[source]

Bases: object

Lightning-free implementation of early stopping.

To be used with the bare-bones trainer that has no support for Lightning callbacks.

should_stop(metric) bool[source]
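A rough sketch of how it could be used inside a manual training loop; compute_validation_loss is a hypothetical placeholder for whatever produces the monitored metric:

from discotime.utils.misc import EarlyStopping

stopper = EarlyStopping(patience=10, min_delta=0.001, direction="minimize")
for epoch in range(1000):
    val_loss = compute_validation_loss()  # hypothetical helper, not part of discotime
    if stopper.should_stop(val_loss):
        break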
class discotime.utils.misc.OptunaCallback(trial: Trial, monitor: str, minvalue: float = -inf, maxvalue: float = inf)[source]

Bases: Callback

Lightning callback for hyperparameter optimization with Optuna.

Parameters:
  • trial – A Trial corresponding to the current evaluation of the objective function.

  • monitor – An evaluation metric for pruning, e.g., val_loss or val_acc.

on_validation_end(trainer: Trainer, module: LightningModule) None[source]
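As a sketch, the callback is added to a Lightning Trainer inside an Optuna objective; the objective below is illustrative and reuses the model and data module from the getting-started example:

import optuna
from lightning import Trainer
from discotime.models import LitSurvModule, ModelConfig
from discotime.utils.misc import OptunaCallback

def objective(trial):
    lr = trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True)
    model = LitSurvModule(ModelConfig(learning_rate=lr))
    trainer = Trainer(
        max_epochs=100,
        logger=False,
        enable_checkpointing=False,
        callbacks=[OptunaCallback(trial, monitor="val_loss")],
    )
    trainer.fit(model, mgus2dm)  # mgus2dm as defined in the getting-started example
    return trainer.callback_metrics["val_loss"].item()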
discotime.utils.misc.clamp(value: float, minvalue: float, maxvalue: float)[source]
discotime.utils.misc.get_last_intermediate_value(trial: Trial)[source]
discotime.utils.misc.recursive_defaultdict()[source]
discotime.utils.misc.split_dataset(dataset: Dataset[T], k: int = 5) Generator[tuple[torch.utils.data.dataset.ConcatDataset[T], torch.utils.data.dataset.Subset[T]], None, None][source]

Split dataset into k train/validation splits
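A short sketch of a k-fold loop over an existing dataset (here the fit dataset from the getting-started example):

from discotime.utils.misc import split_dataset

for fold, (dset_train, dset_valid) in enumerate(split_dataset(mgus2dm.dset_fit, k=5)):
    print(fold, len(dset_train), len(dset_valid))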

discotime.utils.misc.update_mapping(d: Mapping, u: Mapping) Mapping[source]

Update nested mappings recursively.

Parameters:
  • d (Mapping) – dictionary/mapping of things.

  • u (Mapping) – other dictionary/mapping of things to update d with.

Example

>>> foo = {"dog" : {"color" : "black", "age" : 10}}
>>> bar = {"dog" : {"age" : 11}, "cat" : None}
>>> update_mapping(foo, bar)
{'dog': {'color': 'black', 'age': 11}, 'cat': None}
discotime.utils.typing module
class discotime.utils.typing.LabelTransformer(*args, **kwargs)[source]

Bases: Transformer, Protocol

property cuts: ndarray[Any, dtype[floating]]
fit(time: Iterable[int | int64 | float | float64], event: Iterable[int | int64]) None[source]
fit_transform(time: Iterable[int | int64 | float | float64], event: Iterable[int | int64]) tuple[numpy.ndarray[Any, numpy.dtype[numpy.integer]], numpy.ndarray[Any, numpy.dtype[numpy.integer]]][source]
property max_time: int | int64 | float | float64
transform(time: Iterable[int | int64 | float | float64], event: Iterable[int | int64]) tuple[numpy.ndarray[Any, numpy.dtype[numpy.integer]], numpy.ndarray[Any, numpy.dtype[numpy.integer]]][source]
class discotime.utils.typing.SurvData(features, event_time_disc, event_status_disc, event_time_cont, event_status_cont)[source]

Bases: NamedTuple

event_status_cont: ndarray[Any, dtype[int64]]

Alias for field number 4

event_status_disc: ndarray[Any, dtype[int64]]

Alias for field number 2

event_time_cont: ndarray[Any, dtype[float64]]

Alias for field number 3

event_time_disc: ndarray[Any, dtype[int64]]

Alias for field number 1

features: ndarray[Any, dtype[float64]]

Alias for field number 0

class discotime.utils.typing.Transformer(*args, **kwargs)[source]

Bases: Protocol

fit(*args, **kwargs) None[source]
fit_transform(*args, **kwargs) Any[source]
transform(*args, **kwargs) Any[source]

License

Reference

Indices and tables