"""discotime.utils.estimators: Kaplan-Meier and Aalen-Johansen estimators plus an interpolation helper."""

from typing import Optional
from collections.abc import Iterable

import numpy as np
import numpy.typing as npt
import torch
from einops import repeat

Int = int | np.int_
Num = Int | float | np.float_


def _tabulate(a, v, side="left"):
    """Count how many values of `v` fall into each interval delimited by `a`.

    The sorted array `a` splits the real line into ``len(a) + 1`` intervals.
    Bin ``i`` of the result holds the number of values in `v` whose insertion
    index into `a` (as computed by ``np.searchsorted``) equals ``i``. With
    a=[1, 2, 3] the intervals are

    * side="left":  (-inf, 1], (1, 2], (2, 3] and (3, inf)
    * side="right": (-inf, 1), [1, 2), [2, 3) and [3, inf)

    The result only extends up to the last non-empty interval, so it can be
    shorter than ``len(a) + 1``.

    Example:
        >>> a, v = np.array([1, 2, 3]), np.array([0.5, 1, 2.5, 4])
        >>> _tabulate(a, v)
        array([2, 0, 1, 1])
    """
    interval_index = np.searchsorted(a, v, side=side)
    return np.bincount(interval_index)


class KaplanMeier:
    """Simple implementation of the Kaplan-Meier estimator.

    Args:
        time: observed times (event or censoring), one per observation.
        event: event indicator (0/1) where 0 is censoring.

    Raises:
        ValueError: if `time` or `event` is not one-dimensional or if they
            differ in length.

    Example:
        >>> km = KaplanMeier(time=[0, 1.5, 1.3, 3], event=[0, 1, 0, 0])
        >>> km(0)
        array([1.])
        >>> km([0, 1.0, 1.1, 1.5])
        array([1. , 1. , 1. , 0.5])
    """

    def __init__(self, time: npt.ArrayLike, event: npt.ArrayLike) -> None:
        _time, _event = map(np.asarray, (time, event))
        if _time.ndim != 1:
            raise ValueError(f"`time` is a {_time.ndim}D array.")
        if _event.ndim != 1:
            raise ValueError(f"`event` is a {_event.ndim}D array.")
        if _time.size != _event.size:
            raise ValueError(
                "`time` and `event` differ in length "
                f"({_time.size} != {_event.size})."
            )

        order = np.argsort(_time)
        _time, _event = _time[order], _event[order]
        n_obs = _time.size

        # unique observed event times, with t=0 added if it isn't present
        tj = np.unique(np.pad(_time[_event == 1], (1, 0)))

        def _count(t):
            """Count entries of `t` in each [tj[k], tj[k+1]) (last is open)."""
            bins = np.searchsorted(tj, t, side="right")
            # bin 0 collects values below tj[0] and is discarded; `minlength`
            # guarantees one count per element of `tj` even when the trailing
            # intervals are empty.
            return np.bincount(bins, minlength=tj.size + 1)[1:]

        mj = _count(_time[_event == 1])  # events at each tj
        qj = _count(_time[_event == 0])  # censored in [tj[k], tj[k+1])

        # number at risk just before each tj: everybody, minus those who had
        # an event or were censored in the preceding intervals
        removed = np.cumsum(mj + qj)
        nj = np.concatenate(([n_obs], n_obs - removed[:-1]))

        self._sj = np.cumprod((nj - mj) / nj)
        self._nj, self._tj, self._mj = nj, tj, mj

    def __call__(self, time: npt.ArrayLike) -> npt.NDArray[np.float64]:
        """Obtain Kaplan-Meier estimates for each timepoint.

        Args:
            time (num | seq[num]): `t`'s for which `km(t)` will be returned.

        Returns:
            Array of survival estimates with one entry per timepoint. A
            timepoint that precedes the first tabulated time yields 1.0.
        """
        _time = np.atleast_1d(np.asanyarray(time))
        idx = np.searchsorted(self._tj, _time, side="right") - 1
        # idx == -1 means "before the first tabulated time"; indexing with it
        # directly would wrap around to the *last* survival estimate.
        return np.where(idx < 0, 1.0, self._sj[np.maximum(idx, 0)])

    def percentile(
        self, p: npt.ArrayLike, dtype=np.float64
    ) -> npt.NDArray[np.float64]:
        """Obtain approximate timepoint t such that P(t) = p.

        The stepwise Kaplan-Meier estimator is piecewise linearly interpolated
        such that unique timepoints can be obtained.

        Args:
            p: survival probabilities in [0, 1].
            dtype: dtype that `p` is converted to before interpolation.

        Returns:
            Array of timepoints with one entry per element of `p`. Values of
            `p` above the estimate at the first tabulated time map to 0, and
            values below the final estimate map to the last event time.

        Raises:
            ValueError: if any element of `p` is outside [0, 1].
        """
        p = np.atleast_1d(np.asanyarray(p, dtype=dtype))
        if not np.all((0 <= p) & (p <= 1)):
            raise ValueError(
                "p is a probability and should be between 0 and 1"
            )
        # np.interp needs increasing x-coordinates, hence 1 - S(t)
        return np.interp(1 - p, 1 - self._sj, self._tj, left=0)
[docs]class AalenJohansen: """Obtain cumulative incidence curves with the Aalen-Johansen method. Args: time: event times event: event indicator (0/1/../c) with 0=censored n_causes: how many causes should be included? If None (the default) then all observed causes are include. """ def __init__( self, time: npt.ArrayLike, event: npt.ArrayLike, n_causes: Optional[Int] = None, ) -> None: time, event = map(np.asarray, (time, event)) if time.ndim != 1: raise ValueError(f"`time` is a {time.ndim}D array.") if event.ndim != 1: raise ValueError(f"`event` is a {event.ndim}D array.") order = np.argsort(time) time, event = time[order], event[order] # find unique event times and add t=0 if it isn't present tj = np.unique(np.pad(time, (1, 0))) # number at risk at the start of each interval nj = time.size - np.cumsum(_tabulate(tj, time, side="right")) nj = nj[:-1] # drop last def _pad(a): """Add zeros to end of array to ensure it has same size as `tj`""" assert (tj.size - a.size) >= 0 return np.pad(a, (0, tj.size - a.size), constant_values=0) # number lost in each interval to any cause mj = _pad(_tabulate(tj, time[event != 0])) # lagged survival which starts with P(0) = 1 sj = np.cumprod((nj - mj) / nj)[:-1] sj = np.pad(sj, (1, 0), constant_values=1) # cause-specific incidences n_risks = n_causes if n_causes else np.max(event) ci = np.zeros((tj.size, n_risks)) for e in range(1, n_risks + 1): mcj = _pad(_tabulate(tj, time[event == e])) ci[:, e - 1] = np.cumsum(sj * (mcj / nj)) self._tj = tj self._sj = sj self._ci = ci def __call__(self, timepoints: npt.ArrayLike) -> npt.NDArray[np.float64]: """Return cause-specific cumulative incidence at given timepoints.""" tau = np.asanyarray(timepoints).reshape(-1) idx = np.searchsorted(self._tj, tau, side="right") idx = np.clip(idx - 1, a_min=0, a_max=(self._tj.size - 1)) return self._ci[idx]
def interpolate2d(x: torch.Tensor, xp: torch.Tensor, yp: torch.Tensor):
    """Perform piecewise linear interpolation of a discrete function.

    `xp` and `yp` are tensors of values used to approximate f: y = f(x).
    This function uses interpolation to find the value of new points x.
    Points outside the range of `xp` are extrapolated along the first or
    last segment.

    Args:
        x (:obj:`torch.Tensor`): a 1D tensor of real values.
        xp (:obj:`torch.Tensor`): a 1D tensor of real values, sorted in
            ascending order, with at least two elements.
        yp (:obj:`torch.Tensor`): a 3D tensor of real values. The length of
            yp along the second axis (dim=1) must be the same as the length
            of xp.

    Returns:
        :obj:`torch.Tensor`: tensor of shape
        ``(yp.shape[0], len(x), yp.shape[2])``.

    Raises:
        ValueError: if the inputs do not have the shapes described above.
    """
    if x.ndim != 1:
        raise ValueError(f"`x` is a {x.ndim}D tensor.")
    if xp.ndim != 1:
        raise ValueError(f"`xp` is a {xp.ndim}D tensor.")
    if xp.shape[0] < 2:
        raise ValueError("`xp` must contain at least two points.")
    if yp.ndim != 3 or yp.shape[1] != xp.shape[0]:
        raise ValueError(
            "`yp` must be a 3D tensor whose second axis matches `xp`."
        )

    x, xp = x.to(yp), xp.to(yp)  # move to same device (and dtype) as yp

    # slope and intercept of the line through each pair of adjacent knots;
    # segment k spans [xp[k], xp[k + 1]]
    m = torch.diff(yp, dim=1) / torch.diff(xp)[None, :, None]
    b = yp[:, :-1, :] - m * xp[:-1][None, :, None]

    # segment containing each x: the index of the last knot <= x. The default
    # left-sided search returns the knot at or above x instead, which selects
    # the *next* segment for every x strictly between two knots.
    idx = torch.searchsorted(xp, x, right=True) - 1
    idx = torch.clamp(idx, min=0, max=m.shape[1] - 1)
    return m[:, idx, :] * x[None, :, None] + b[:, idx, :]