from math import inf
from collections import deque, defaultdict
from collections.abc import Mapping
from typing import Generator, TypeVar, Optional
from numbers import Number
import lightning.pytorch as pl
from lightning.pytorch import cli
from optuna.trial import Trial
from optuna.exceptions import TrialPruned
from lightning.pytorch.callbacks import Callback
import lightning as L
import torch
from torch.utils.data import (
DataLoader,
Dataset,
random_split,
ConcatDataset,
Subset,
)
from discotime.models import LitSurvModule
from discotime.datasets.utils import SurvDataset
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
[docs]def update_mapping(d: Mapping, u: Mapping) -> Mapping:
"""Update nested mappings recursively.
Args:
d (Mapping): dictionary/mapping of things.
u (Mapping): other dictionary/mapping of things to update `d` with.
Example:
>>> foo = {"dog" : {"color" : "black", "age" : 10}}
>>> bar = {"dog" : {"age" : 11}, "cat" : None}
>>> update_mapping(foo, bar)
{'dog': {'color': 'black', 'age': 11}, 'cat': None}
"""
for k, v in u.items():
if isinstance(v, Mapping):
d[k] = update_mapping(d.get(k, {}), v) # type: ignore
else:
d[k] = v # type: ignore
return d
[docs]def recursive_defaultdict():
return defaultdict(recursive_defaultdict)
[docs]class BaseFabricTrainer:
"""Bare-bones trainer class using Fabric.
This trainer class supports only a very limited set of the full lightning
functionality. It can be useful for quick prototyping or for k-fold
cross-validation as exemplified in the scripts that can be found in the
``/experiments`` folder on github.
"""
def __init__(
self,
module: LitSurvModule,
dset_train: SurvDataset,
dset_valid: SurvDataset,
batch_size: int,
) -> None:
self.fabric = L.Fabric(accelerator="cpu", devices=1)
self.train_dl = DataLoader(
dset_train, shuffle=True, batch_size=batch_size, drop_last=True
)
self.valid_dl = DataLoader(
dset_valid, shuffle=True, batch_size=batch_size
)
self.module = module
def __iter__(self) -> Generator[torch.Tensor, None, None]:
self.fabric.launch()
train_dl = self.fabric.setup_dataloaders(self.train_dl)
valid_dl = self.fabric.setup_dataloaders(self.valid_dl)
module, optimizer = self.fabric.setup(
self.module,
self.module.configure_optimizers(),
)
assert isinstance(train_dl, DataLoader)
assert isinstance(valid_dl, DataLoader)
while True:
self.module.train()
with torch.set_grad_enabled(True):
for batch_idx, batch in enumerate(train_dl):
loss = module.training_step(batch, batch_idx)
optimizer.zero_grad()
self.fabric.backward(loss)
optimizer.step()
with torch.set_grad_enabled(False):
module.eval()
losses = (
module.validation_step(batch, idx) * len(batch[1])
for idx, batch in enumerate(valid_dl)
)
yield sum(losses) / len(valid_dl.sampler) # type: ignore
[docs]def split_dataset(
dataset: Dataset[T], k: int = 5
) -> Generator[tuple[ConcatDataset[T], Subset[T]], None, None]:
"""Split dataset into k train/validation splits"""
splits = deque(random_split(dataset, lengths=[1 / k] * k))
for _ in range(k):
dset_valid: Subset[T] = splits.popleft()
dset_train: ConcatDataset[T] = ConcatDataset(splits)
yield (dset_train, dset_valid)
splits.append(dset_valid)
[docs]class EarlyStopping:
"""Lightning-free implementation of early stopping.
To be used with the bare-bones simple trainer that have no support for
lightning callbacks.
"""
def __init__(
self, patience: int, min_delta: float, direction: str = "minimize"
) -> None:
self.patience = patience
self.min_delta = min_delta
self.wait = 0
if direction not in (allowed := {"minimize", "maximize"}):
raise ValueError(f"direction can be {allowed}, got {direction}")
self.direction = direction
self.best = inf if direction == "minimize" else -inf
[docs] def should_stop(self, metric) -> bool:
match self.direction:
case "minimize":
improvement = metric < (self.best - self.min_delta)
case "maximize":
improvement = metric > (self.best + self.min_delta)
if improvement:
self.wait = 0
self.best = metric
else:
self.wait += 1
return self.wait >= self.patience
[docs]def clamp(value: float, minvalue: float, maxvalue: float):
return max(minvalue, min(value, maxvalue))
[docs]class OptunaCallback(Callback):
"""Lightning callback for hyperparameter optimization with Optuna.
Args:
trial:
A :class:`~optuna.trial.Trial` corresponding to the current
evaluation of the objective function.
monitor:
An evaluation metric for pruning, e.g., ``val_loss`` or
``val_acc``.
"""
def __init__(
self,
trial: Trial,
monitor: str,
minvalue: float = -inf,
maxvalue: float = inf,
) -> None:
super().__init__()
self.minvalue = minvalue
self.maxvalue = maxvalue
self.trial = trial
self.monitor = monitor
[docs] def on_validation_end(
self, trainer: pl.Trainer, module: pl.LightningModule
) -> None:
if trainer.sanity_checking:
return
metric = trainer.callback_metrics.get(self.monitor)
if metric is None:
raise ValueError(
f"The metric '{self.monitor}' is not in the evaluation logs."
f"Available metrics are: {trainer.callback_metrics}"
)
epoch = module.current_epoch
self.trial.report(
clamp(metric.item(), self.minvalue, self.maxvalue),
step=epoch,
)
if not self.trial.should_prune():
return
raise TrialPruned(f"Trial was pruned at epoch {epoch}.")