Source code for discotime.models.models

from typing import Optional, Any, Iterable, Callable, Dict
from dataclasses import dataclass
import warnings

import torch
from torch.nn import functional as F
from torch import nn
from lightning import pytorch as pl
import numpy as np

from discotime.models.components import Net, negative_log_likelihood
from discotime.metrics import IPA
from discotime.datasets import LitSurvDataModule
from discotime.utils import interpolate2d


[docs]@dataclass(kw_only=True) class ModelConfig: learning_rate: float = 1e-3 activation_function: str = "SiLU" """Name of activation function (str). The name needs to match one of the activation functions in the ``torch.nn`` module. """ n_sequential_blocks: int = 2 """Number of neural-network layer blocks. Each block is modelled after a ResNet skip-block and contains the following key elements: .. code-block :: Sequential( (0): LazyBatchNorm1d() (1): LazyLinear() (2): DropOut() (3): SiLU() (4): LazyLinear() (5): DropOut() ) Per default the output of the model is ``self.act_fn(x + self.net(x))`` but the skip part can be removed by setting :attr:`use_skip_connections` to ``False``. :attr:`n_hidden_units` controls the size of the linear layers in each block. See :class:`.components.Block` for more details on the actual implementation. """ n_hidden_units: int = 20 """Number of neurons in each of the hidden layers. For ease of of use, the size of each hidden layer have been constrained to to be the exact same size as all the others. Default is 20. """ dropout_rate: Optional[float] = None """Should ``nn.Dropout()`` be used in each block, and if so what is the rate of dropout? If ``None`` then dropout is not used. Default is None. """ batch_normalization: bool = True """Use batch normalization in each block? Default is True.""" use_skip_connections: bool = False """Toggle the use of skip-connections in the model blocks.""" evaluation_grid_size: int = 50 """Either the length of (int) or the specific grid (seq of floats) at which model metrics are calculated. If an integer `n` is passed, then `n` evenly distributed timepoint are chosen from the range of the data. """ evaluation_grid_cuts: Optional[int] = None """Either the length of (int) or the specific grid (seq of floats) at which model metrics are calculated. If an integer `n` is passed, then `n` evenly distributed timepoint are chosen from the range of the data. """ n_time_bins: Optional[int] = None """Number of time bins in discretization grid. If ``None``, which is the default, then we try to get the information from the attached datamodule during setup. """ n_risks: Optional[int] = None """How many competing risks are we dealing with? If ``None``, which is the default, then we try to get the information from the attached datamodule during setup. """
[docs]class LitSurvModule(pl.LightningModule): def __init__(self, config: ModelConfig) -> None: super().__init__() self.config = config self.learning_rate = config.learning_rate self._loss = negative_log_likelihood self._metrics: dict[str, Callable] = {} self._eval_grid: Optional[torch.Tensor] = None self._data_cuts: Optional[torch.Tensor] = None self._model: Optional[nn.Module] = None self._datamodule: Optional[LitSurvDataModule] = None self._test_step_outputs: list[tuple[torch.Tensor, ...]] = [] self._validation_step_outputs: list[tuple[torch.Tensor, ...]] = [] self.save_hyperparameters() @property def datamodule(self) -> LitSurvDataModule: if self._datamodule is not None: return self._datamodule return self.trainer.datamodule # type: ignore @datamodule.setter def datamodule(self, value: LitSurvDataModule) -> None: self._datamodule = value @property def model(self) -> nn.Module: if self._model is None: raise AssertionError( "`self.model` have not been instantiated." "Try calling `self.setup()` first." ) return self._model @property def eval_grid(self) -> torch.Tensor: if self._eval_grid is None: raise AssertionError( ( "`self.eval_grid` has not been instantiated." "Try calling `self.setup()` first." ) ) return self._eval_grid @property def data_cuts(self) -> torch.Tensor: if self._data_cuts is None: raise AssertionError( ( "`self.data_cuts` has not been instantiated." "Try calling `self.setup()` first." ) ) return self._data_cuts
[docs] def setup(self, stage): """Called at the beginning of fit (train + validate), validate, test, or predict. This is where the PyTorch model gets instantiated. Args: stage (`str`): either ``"fit"``, ``"validate"`` ``"test"`` or ``"predict"`` """ if not isinstance((conf := self.config), ModelConfig): raise TypeError("`config` has wrong type!") def _get_data_attribute(attr: str) -> int: """Get attribute from config or attached datamodule.""" if (value := getattr(conf, attr)) is None: try: self.datamodule.setup() # fit transformers return getattr(self.datamodule, attr) except: raise AssertionError(f"`{attr}` could not be found.") return value self.n_time_bins = _get_data_attribute("n_time_bins") self.n_risks = _get_data_attribute("n_risks") if not self._model: with warnings.catch_warnings(): warnings.simplefilter("ignore") self._model = Net( n_out_features=self.n_time_bins * (self.n_risks + 1), n_hidden_units=conf.n_hidden_units, add_skip_connection=conf.use_skip_connections, activation_function=getattr(nn, conf.activation_function), batch_normalization=conf.batch_normalization, dropout_rate=conf.dropout_rate, n_blocks=conf.n_sequential_blocks, ) if stage in {"fit", "test", "validate"}: self.datamodule.setup() # get attributes from the datamodule self._data_cuts = torch.as_tensor(self.datamodule.cuts) if conf.evaluation_grid_cuts is not None: self._eval_grid = torch.as_tensor(conf.evaluation_grid_cuts) else: start, end = self.datamodule.time_range self._eval_grid = torch.linspace( start=start, end=end, steps=conf.evaluation_grid_size # type: ignore ) # instantiate metrics self._metrics["IPA"] = IPA( np.asarray(self.eval_grid, dtype=np.float_), integrate=True )
[docs] def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: checkpoint["model"] = self._model
[docs] def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: self._model = checkpoint["model"]
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x).view(-1, self.n_time_bins, self.n_risks + 1)
[docs] def training_step( self, batch: tuple[torch.Tensor, ...], batch_idx: int ) -> torch.Tensor: x, dt, de, _, _ = batch train_loss = self._loss(self(x), dt, de) self.log("train_loss", train_loss) return train_loss
[docs] def validation_step( self, batch: tuple[torch.Tensor, ...], batch_idx: int ) -> torch.Tensor: x, dt, de, ct, ce = batch val_loss = self._loss(self(x), dt, de) self.log("val_loss", val_loss) # aggregate estimates + labels for batch metrics self._validation_step_outputs.append( (self._predict_estimates(x, self.eval_grid), ct, ce) ) return val_loss
[docs] def test_step( self, batch: tuple[torch.Tensor, ...], batch_idx: int ) -> torch.Tensor: x, dt, de, ct, ce = batch test_loss = self._loss(self(x), dt, de) self.log("test_loss", test_loss) # aggregate estimates + labels for batch metrics self._test_step_outputs.append( (self._predict_estimates(x, self.eval_grid), ct, ce) ) return test_loss
[docs] def on_test_epoch_end(self) -> None: estimates, ct, ce = map( lambda x: torch.cat(x).cpu(), zip(*self._test_step_outputs) ) self._test_step_outputs.clear() # empty list for metric, metric_fn in self._metrics.items(): values = metric_fn(estimates, ct, ce) for cause, value in enumerate(values, start=1): self.log(f"test_{metric}_cause{cause}", value)
[docs] def on_validation_epoch_end(self) -> None: estimates, ct, ce = map( lambda x: torch.cat(x).cpu(), zip(*self._validation_step_outputs) ) self._validation_step_outputs.clear() # empty list for metric, metric_fn in self._metrics.items(): values = metric_fn(estimates, ct, ce) for cause, value in enumerate(values, start=1): self.log(f"val_{metric}_cause{cause}", value)
def _predict_estimates( self, x: torch.Tensor, timepoints: torch.Tensor ) -> torch.Tensor: c_haz = F.softmax(self(x), dim=-1) s = torch.cumprod(c_haz[..., [0]], dim=1) s_lag = torch.roll(s, shifts=1, dims=1) s_lag[:, 0, :] = 1 # convert conditional hazards to cause-specific cumulative incidence proba = torch.cumsum(s_lag * c_haz[..., 1:], dim=1) proba = F.pad(proba, pad=(0, 0, 1, 0)) return interpolate2d(timepoints, self.data_cuts, proba)
[docs] def configure_optimizers(self) -> torch.optim.Optimizer: return torch.optim.Adam(self.model.parameters(), self.learning_rate)