Welcome to discotime’s documentation!
Discotime is a Python package for discrete-time survival analysis with competing risks using neural networks. It relies on PyTorch Lightning to provide an easy-to-use interface, but can still be customized to your heart's content. The package contains an implementation of discrete time-to-event models for neural networks (using PyTorch), different evaluation metrics, and a couple of competing risk datasets.
Getting started
Look at the examples/recipes (TODO)
Read our paper (TODO)
Check the repository (TODO)
Discrete-time survival analysis
Installation
Discotime supports Python 3.9 or newer.
We recommend installing via pip:
$ pip install discotime
You can also install the development version of discotime from the master branch of the Git repository:
$ pip install git+https://github.com/peterchristofferholm/discotime.git
Getting started with discotime
import warnings
import numpy as np
import pandas as pd
import torch
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping
from discotime.datasets import Mgus2, DataConfig
from discotime.models import LitSurvModule, ModelConfig
For this example, we will be using the Mgus2 dataset from the survival package in R. The dataset contains the natural history of 1341 patients with monoclonal gammopathy of undetermined significance (MGUS) [1]. In the discotime package, this dataset is supplied as a LightningDataModule that automatically handles downloading, preprocessing, splitting into training and testing data, etc.
For initialization of the data module, arguments are supplied in the form of a DataConfig object as follows.
mgus2dm = Mgus2(
DataConfig(
batch_size=128,
n_time_bins=20,
discretization_scheme="number",
)
)
mgus2dm
<discotime.datasets.from_rstats.Mgus2 object at 0x7effb9c66770>
n_features: 7, n_risks: 2, n_time_bins: 20
If we want to inspect the data, we can load the fit (train + val) and test data by calling setup() on the object. During normal use, this is handled automatically by the lightning.Trainer.
mgus2dm.prepare_data()
mgus2dm.setup()
mgus2dm.dset_fit
<discotime.datasets.utils.SurvDataset at 0x7effa2047e50>
mgus2dm.dset_fit[0]
SurvData(features=array([ 0. , 1. , 1.0580127 , 0.87056947, -0.7417322 ,
-0.4214596 , -1.1587092 ], dtype=float32), event_time_disc=4, event_status_disc=1, event_time_cont=29, event_status_cont=1)
The data consist of features (a dataframe) and labels. The features should be fairly self-explanatory, and the labels are a tuple of survival times and event indicators. In the discotime package and across the other examples, an event/status of 0 indicates censoring.
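Each element of the dataset is a SurvData named tuple (documented under discotime.utils.typing below), so the individual fields can be accessed by name:

sample = mgus2dm.dset_fit[0]
sample.features           # standardized covariate vector
sample.event_time_disc    # event time mapped to a discrete bin
sample.event_status_disc  # event indicator after discretization
sample.event_time_cont    # original (continuous) follow-up time
sample.event_status_cont  # original event indicator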
To create the survival model, we use a different configuration object.
model = LitSurvModule(
ModelConfig(
learning_rate=1e-3,
n_hidden_units=15,
n_sequential_blocks=5,
use_skip_connections=True,
)
)
The last thing we need now is to instantiate the Lightning Trainer. For this example, we will only be using a single CPU, but with the Lightning Trainer it's easy to scale training to multiple devices and/or GPUs.
early_stopping = EarlyStopping(
monitor="val_loss", min_delta=0.001, patience=30, mode="min"
)
warnings.filterwarnings("ignore", "GPU available")
trainer = Trainer(
accelerator="cpu",
devices=1,
max_epochs=3000,
enable_checkpointing=False,
enable_progress_bar=False,
logger=False,
reload_dataloaders_every_n_epochs=1,
callbacks=[early_stopping],
)
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
warnings.filterwarnings("ignore", ".*bottleneck")
trainer.fit(model, mgus2dm)
/home/pechris/work/discotime/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py:411: UserWarning: A layer with UninitializedParameter was found. Thus, the total number of parameters detected may be inaccurate.
warning_cache.warn(
| Name | Type | Params
--------------------------------
0 | _model | Net | 0
--------------------------------
0 Trainable params
0 Non-trainable params
0 Total params
0.000 Total estimated model params size (MB)
__ = trainer.test(model, mgus2dm)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ DataLoader 0 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ test_IPA_cause1 │ 0.19919501087221964 │
│ test_IPA_cause2 │ -0.08615107008082612 │
│ test_loss │ 2.665034770965576 │
└───────────────────────────┴───────────────────────────┘
discotime
discotime package
Subpackages
discotime.datasets package
Submodules
discotime.datasets.from_rstats module
- class discotime.datasets.from_rstats.Mgus2(data_config: DataConfig, data_dir: Path = PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/discotime/envs/latest/lib/python3.10/site-packages/discotime/datasets/_data'), seed: int = 13411341)[source]
Bases:
LitSurvDataModule
Natural history of 1341 sequential patients with monoclonal gammopathy of undetermined significance (MGUS).
- [1]: R. Kyle, T. Therneau, V. Rajkumar, J. Offord, D. Larson, M. Plevak,
and L. J. Melton III, A long-term study of prognosis in monoclonal gammopathy of undetermined significance. New Engl J Med, 346:564-569 (2002).
- property dset_fit: SurvDataset
- property dset_test: SurvDataset
- n_features = 7
- n_risks = 2
- prepare_data() None [source]
Fetch the built-in mgus2 from the R survival package.
The mgus2 data [1] is extracted from the rdata file and saved as a csv in the discotime installation directory. If the csv is already available, then the download logic is skipped.
[1]: Therneau T (2023). A Package for Survival Analysis in R.
- property time_range
discotime.datasets.utils module
- class discotime.datasets.utils.DataConfig(*, batch_size: int = 32, n_time_bins: int = 20, discretization_scheme: str = 'number', discretization_grid: list[float] | None = None, max_time: float | None = None)[source]
Bases:
object
Configuration class for data modules.
- class discotime.datasets.utils.LabelDiscretizer(scheme: str | None = None, n_bins: int | None = None, *, cut_points: Iterable[int | int64 | float | float64] | None = None, max_time: int | int64 | float | float64 | None = None)[source]
Bases:
object
Discretize continuous time/event pairs.
The class can either learn a discretization grid from the training data using one of the built-in discretization schemes, or the user can supply an iterable with cut points.
Implementation heavily inspired by pycox.preprocessing.label_transform [1].
[1]: Kvamme, Håvard, Ørnulf Borgan, and Ida Scheel. “Time-to-event prediction with neural networks and Cox regression.” arXiv preprint arXiv:1907.00825 (2019).
- fit_transform(time: Iterable[int | int64 | float | float64], event: Iterable[int | int64]) tuple[numpy.ndarray[Any, numpy.dtype[numpy.integer]], numpy.ndarray[Any, numpy.dtype[numpy.integer]]] [source]
- transform(time: Iterable[int | int64 | float | float64], event: Iterable[int | int64]) tuple[numpy.ndarray[Any, numpy.dtype[numpy.integer]], numpy.ndarray[Any, numpy.dtype[numpy.integer]]] [source]
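A minimal sketch of the two usage modes (the exact cut-point semantics are an assumption; consult the source for details):

import numpy as np
from discotime.datasets.utils import LabelDiscretizer

time = np.array([1.2, 3.4, 0.5, 7.8, 2.2])
event = np.array([1, 0, 2, 1, 0])  # 0 = censored

# learn a grid with 3 bins from the training data ...
ld = LabelDiscretizer(scheme="number", n_bins=3)
t_disc, e_disc = ld.fit_transform(time, event)

# ... or supply explicit cut points instead
ld = LabelDiscretizer(cut_points=[0.0, 2.0, 4.0, 8.0])
t_disc, e_disc = ld.fit_transform(time, event)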
- class discotime.datasets.utils.LitSurvDataModule[source]
Bases:
LightningDataModule
- property config: DataConfig
- property lab_transformer: LabelTransformer
- class discotime.datasets.utils.SurvDataset(features: _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], event_time: _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], event_status: _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], discretizer: LabelTransformer)[source]
Bases:
Dataset
Assemble a survival dataset for discrete-time survival analysis.
A discrete time survival dataset \(\mathfrak{D}\) is a set of \(n\) tuples \((t_{i}, \delta_{i}, \mathbf{x}_{i})\) where \((t_i = \min \{T_i, C_i\})\) is the event time, \(\delta_{i} \in \{0, ..., m\}\) is the event indicator (with \((\delta_i = 0)\) defined as censoring), and \(\mathbf{x}_{i} \in \mathbb{R}^d\) is a \(d\)-dimensional vector of time-independent predictors or covariates.
- Parameters:
features – time-independent features.
event_time – follow-up time (continuous).
event_status – event indicator (0=censored, 1/2/…=competing risks).
discretizer – discretizer that follows the LabelTransformer protocol and converts continuous time/event tuples to their respective discretized versions. Typically this would be LabelDiscretizer unless a custom discretization object is used.
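A minimal construction sketch (whether SurvDataset fits the discretizer itself or expects a pre-fitted one is an assumption; check the source if in doubt):

import numpy as np
from discotime.datasets.utils import LabelDiscretizer, SurvDataset

time = np.array([2.0, 5.1, 3.3, 8.7])
event = np.array([1, 0, 2, 1])  # 0 = censored
features = np.random.default_rng(0).normal(size=(4, 3)).astype("float32")

dset = SurvDataset(
    features, time, event,
    discretizer=LabelDiscretizer(scheme="number", n_bins=5),
)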
discotime.metrics package
Submodules
discotime.metrics.brier_score module
- class discotime.metrics.brier_score.BrierScore(timepoints: _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], causes: list[int] | None = None)[source]
Bases:
object
Brier score of survival model with right-censored data.
In the case of right-censored data, possibly unknown time-dependent event status can be replaced with jackknife pseudovalues from the marginal Aalen-Johansen estimates [1].
- Parameters:
timepoints – sequence of t timepoints to evaluate.
causes – causes for which the Brier score is calculated.
- Refs:
[1] Cortese, Giuliana, Thomas A. Gerds, and Per K. Andersen. “Comparing predictions among competing risks models with time‐dependent covariates.” Statistics in medicine 32.18 (2013): 3089-3101.
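A minimal construction sketch per the signature above; how the resulting object is then applied to model predictions is not documented here and is left out:

from discotime.metrics.brier_score import BrierScore

bs = BrierScore(timepoints=[12, 24, 48], causes=[1, 2])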
discotime.metrics.ipa module
- class discotime.metrics.ipa.IPA(timepoints: ndarray[Any, dtype[float64]], causes: list[int] | None = None, integrate: bool = True)[source]
Bases:
object
Index of prediction accuracy (IPA) for right-censored survival models.
In the context of survival analysis, the IPA is also known as the Brier skill score (BSS).
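For orientation: the IPA at time \(t\) is commonly defined as \(\mathrm{IPA}(t) = 1 - \mathrm{BS}_{\mathrm{model}}(t) / \mathrm{BS}_{\mathrm{null}}(t)\), where the null model is a covariate-free marginal estimate; values above zero indicate an improvement over that baseline.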
discotime.models package
Submodules
discotime.models.components module
- class discotime.models.components.Block(*, n_hidden_units: int, add_skip_connection: bool = True, activation_function: Type[torch.nn.Module] = torch.nn.SiLU, batch_normalization: bool = True, dropout_rate: float | None = None)[source]
Bases:
Module
Neural network building block for the Net class.
- Parameters:
n_hidden_units – number of units in each hidden layer.
add_skip_connection – Defaults to True.
activation_function – Defaults to nn.SiLU.
batch_normalization – Should batch normalization be performed? Defaults to True.
dropout_rate – dropout_rate is being passed along to nn.Dropout(). If None, then dropout is not being used. Defaults to None.
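A minimal construction sketch using the documented parameters (that the input width must match n_hidden_units is an assumption, following from the additive skip connection):

import torch
from discotime.models.components import Block

block = Block(n_hidden_units=16, add_skip_connection=True, dropout_rate=0.1)
out = block(torch.randn(8, 16))  # lazy layers infer the input size on first call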
- discotime.models.components.negative_log_likelihood(logits: Tensor, time: Tensor, event: Tensor) Tensor [source]
Negative log-likelihood for logistic hazard model with competing risks.
The hazards are expected to be given on the logit scale, i.e. they should not have been passed through torch.log_softmax() or similar. An implementation of equation (8.6) from Tutz and Schmid [1], inspired by the one in pycox following Kvamme et al. [2].
- Parameters:
logits (torch.Tensor) – input logits.
time (torch.Tensor) – discretized event times.
event (torch.Tensor) – events (0=censored, 1/2/…=events).
[1]: Tutz, Gerhard, and Matthias Schmid. Modeling discrete time-to-event data. New York: Springer, 2016.
[2]: Kvamme, Håvard, Ørnulf Borgan, and Ida Scheel. “Time-to-event prediction with neural networks and Cox regression.” arXiv preprint arXiv:1907.00825 (2019).
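For orientation, a common way to write this negative log-likelihood (a sketch based on the general discrete-time formulation; consult Tutz and Schmid for the exact equation (8.6)) uses cause-specific hazards \(\lambda_{j}(k \mid \mathbf{x})\) and the overall hazard \(\lambda(k \mid \mathbf{x}) = \sum_{j=1}^{m} \lambda_{j}(k \mid \mathbf{x})\):
\[-\ell = -\sum_{i=1}^{n} \left[ \mathbb{1}(\delta_i > 0)\,\log \lambda_{\delta_i}(t_i \mid \mathbf{x}_i) + \mathbb{1}(\delta_i = 0)\,\log\bigl(1 - \lambda(t_i \mid \mathbf{x}_i)\bigr) + \sum_{k=1}^{t_i - 1} \log\bigl(1 - \lambda(k \mid \mathbf{x}_i)\bigr) \right]\]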
discotime.models.models module
- class discotime.models.models.LitSurvModule(config: ModelConfig)[source]
Bases:
LightningModule
- property datamodule: LitSurvDataModule
- class discotime.models.models.ModelConfig(*, learning_rate: float = 0.001, activation_function: str = 'SiLU', n_sequential_blocks: int = 2, n_hidden_units: int = 20, dropout_rate: float | None = None, batch_normalization: bool = True, use_skip_connections: bool = False, evaluation_grid_size: int = 50, evaluation_grid_cuts: int | None = None, n_time_bins: int | None = None, n_risks: int | None = None)[source]
Bases:
object
- activation_function: str = 'SiLU'
Name of activation function (str). The name needs to match one of the activation functions in the torch.nn module.
- dropout_rate: float | None = None
Should nn.Dropout() be used in each block, and if so, what is the rate of dropout? If None, then dropout is not used. Default is None.
- evaluation_grid_cuts: int | None = None
Either the length of the grid (int) or the specific grid (sequence of floats) at which model metrics are calculated. If an integer n is passed, then n evenly distributed timepoints are chosen from the range of the data.
- evaluation_grid_size: int = 50
Either the length of the grid (int) or the specific grid (sequence of floats) at which model metrics are calculated. If an integer n is passed, then n evenly distributed timepoints are chosen from the range of the data.
- n_hidden_units: int = 20
Number of neurons in each of the hidden layers.
For ease of use, each hidden layer is constrained to have the exact same size as all the others. Default is 20.
- n_risks: int | None = None
How many competing risks are we dealing with?
If None, which is the default, then we try to get the information from the attached datamodule during setup.
- n_sequential_blocks: int = 2
Number of neural-network layer blocks.
Each block is modelled after a ResNet skip-block and contains the following key elements:
Sequential(
  (0): LazyBatchNorm1d()
  (1): LazyLinear()
  (2): Dropout()
  (3): SiLU()
  (4): LazyLinear()
  (5): Dropout()
)
Per default the output of the model is self.act_fn(x + self.net(x)), but the skip part can be removed by setting use_skip_connections to False. n_hidden_units controls the size of the linear layers in each block. See components.Block for more details on the actual implementation.
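A sketch of the resulting forward pass under these defaults (hypothetical sizes; the real implementation lives in components.Block):

import torch
from torch import nn

act_fn = nn.SiLU()
net = nn.Sequential(
    nn.LazyBatchNorm1d(),
    nn.LazyLinear(20),
    nn.Dropout(p=0.0),  # dropout disabled, mirroring dropout_rate=None
    nn.SiLU(),
    nn.LazyLinear(20),
    nn.Dropout(p=0.0),
)
x = torch.randn(8, 20)
out = act_fn(x + net(x))  # with use_skip_connections=True
# out = act_fn(net(x))    # with use_skip_connections=False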
discotime.utils package
Submodules
discotime.utils.estimators module
- class discotime.utils.estimators.AalenJohansen(time: _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], event: _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes], n_causes: int | int64 | None = None)[source]
Bases:
object
Obtain cumulative incidence curves with the Aalen-Johansen method.
- Parameters:
time – event times
event – event indicator (0/1/../c) with 0=censored
n_causes – how many causes should be included? If None (the default), then all observed causes are included.
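A minimal construction sketch; how cumulative incidence values are then queried from the fitted object (e.g. whether it is callable like KaplanMeier below) is left unspecified here:

from discotime.utils.estimators import AalenJohansen

aj = AalenJohansen(
    time=[1.0, 2.5, 2.5, 4.0, 6.1],
    event=[1, 0, 2, 1, 0],  # 0 = censored; causes 1 and 2 compete
)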
- class discotime.utils.estimators.KaplanMeier(time: Iterable[int | int64 | float | float64], event: Iterable[int | int64])[source]
Bases:
object
Simple implementation of the Kaplan-Meier estimator.
- Parameters:
time – event times.
event – event indicator (0/1) where 0 is censoring.
Example
>>> km = KaplanMeier(time=[0, 1.5, 1.3, 3], event=[0, 1, 0, 0])
>>> km(0)
array([1.])
>>> km([0, 1.0, 1.1, 1.5])
array([1. , 1. , 1. , 0.5])
- percentile(p: Iterable[int | int64 | float | float64], dtype=numpy.float64) ndarray[Any, dtype[float64]] [source]
Obtain approximate timepoint t such that P(t) = p.
The stepwise Kaplan-Meier estimator is piecewise linearly interpolated such that unique timepoints can be obtained.
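For example, an (interpolated) median survival time can be read off the estimator from the example above — a sketch:

median = km.percentile([0.5])  # approximate timepoint t where P(t) = 0.5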
- discotime.utils.estimators.interpolate2d(x: Tensor, xp: Tensor, yp: Tensor)[source]
Perform stepwise linear interpolation of a discrete function.
xp and yp are tensors of values used to approximate f: y = f(x). This function uses interpolation to find the values of new points x.
- Parameters:
x (torch.Tensor) – a 1D tensor of real values.
xp (torch.Tensor) – a 1D tensor of real values.
yp (torch.Tensor) – an ND tensor of real values. The length of yp along the second axis (dim=1) must match the length of xp.
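A minimal sketch under the shape constraints described above (the exact output shape is an assumption):

import torch
from discotime.utils.estimators import interpolate2d

xp = torch.tensor([0.0, 1.0, 2.0])  # grid of length 3
yp = torch.rand(4, 3)               # 4 functions evaluated on that grid
x = torch.tensor([0.5, 1.5])        # new query points
y = interpolate2d(x, xp, yp)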
discotime.utils.misc module
- class discotime.utils.misc.BaseFabricTrainer(module: LitSurvModule, dset_train: SurvDataset, dset_valid: SurvDataset, batch_size: int)[source]
Bases:
object
Bare-bones trainer class using Fabric.
This trainer class supports only a very limited subset of the full lightning functionality. It can be useful for quick prototyping or for k-fold cross-validation, as exemplified in the scripts in the /experiments folder on GitHub.
- class discotime.utils.misc.EarlyStopping(patience: int, min_delta: float, direction: str = 'minimize')[source]
Bases:
object
Lightning-free implementation of early stopping.
To be used with the bare-bones simple trainer, which has no support for lightning callbacks.
- class discotime.utils.misc.OptunaCallback(trial: Trial, monitor: str, minvalue: float = -inf, maxvalue: float = inf)[source]
Bases:
Callback
Lightning callback for hyperparameter optimization with Optuna.
- Parameters:
trial – A Trial corresponding to the current evaluation of the objective function.
monitor – An evaluation metric for pruning, e.g., val_loss or val_acc.
- on_validation_end(trainer: Trainer, module: LightningModule) None [source]
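A hedged sketch of how this callback might be wired into an Optuna objective (the model and datamodule setup is elided; the Optuna calls are standard):

import optuna
from lightning import Trainer
from discotime.utils.misc import OptunaCallback

def objective(trial: optuna.Trial) -> float:
    trainer = Trainer(
        max_epochs=100,
        callbacks=[OptunaCallback(trial, monitor="val_loss")],
    )
    # trainer.fit(model, datamodule)  # model/datamodule defined elsewhere
    return trainer.callback_metrics["val_loss"].item()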
- discotime.utils.misc.split_dataset(dataset: Dataset[T], k: int = 5) Generator[tuple[torch.utils.data.dataset.ConcatDataset[T], torch.utils.data.dataset.Subset[T]], None, None] [source]
Split dataset into k train/validation splits.
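For example, to iterate over 5 cross-validation folds of the fit data (a sketch):

from discotime.utils.misc import split_dataset

for dset_train, dset_valid in split_dataset(mgus2dm.dset_fit, k=5):
    ...  # train on dset_train, evaluate on dset_valid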
- discotime.utils.misc.update_mapping(d: Mapping, u: Mapping) Mapping [source]
Update nested mappings recursively.
- Parameters:
d (Mapping) – dictionary/mapping of things.
u (Mapping) – other dictionary/mapping of things to update d with.
Example
>>> foo = {"dog": {"color": "black", "age": 10}}
>>> bar = {"dog": {"age": 11}, "cat": None}
>>> update_mapping(foo, bar)
{'dog': {'color': 'black', 'age': 11}, 'cat': None}
discotime.utils.typing module
- class discotime.utils.typing.LabelTransformer(*args, **kwargs)[source]
Bases:
Transformer, Protocol
- fit_transform(time: Iterable[int | int64 | float | float64], event: Iterable[int | int64]) tuple[numpy.ndarray[Any, numpy.dtype[numpy.integer]], numpy.ndarray[Any, numpy.dtype[numpy.integer]]] [source]
- transform(time: Iterable[int | int64 | float | float64], event: Iterable[int | int64]) tuple[numpy.ndarray[Any, numpy.dtype[numpy.integer]], numpy.ndarray[Any, numpy.dtype[numpy.integer]]] [source]
- class discotime.utils.typing.SurvData(features, event_time_disc, event_status_disc, event_time_cont, event_status_cont)[source]
Bases:
NamedTuple