# Source code for discotime.metrics.brier_score

from typing import Optional, TypeVar
from collections import namedtuple

import numpy as np
import numpy.typing as npt
from numpy.typing import NDArray

from discotime.utils.estimators import AalenJohansen

T_co = TypeVar("T_co", bound=np.generic, covariant=True)

Int = int | np.int_
Num = Int | float | np.float_


def _loo_idx(m: int) -> npt.NDArray[np.int64]:
    """Create indices for leave-one-out resampling.

    Example:
        >>> _loo_idx(4)
        array([[1, 2, 3],
               [0, 2, 3],
               [0, 1, 3],
               [0, 1, 2]])
    """
    return np.arange(1, m, dtype=np.int64) - np.tri(m, m - 1, k=-1, dtype=bool)


def _cic_pv(
    time: npt.NDArray[np.float64],
    event: npt.NDArray[np.int_],
    tau: npt.ArrayLike,
) -> NDArray[np.float64]:
    """Obtain jackknife pseudovalues from the marginal Aalen-Johansen estimate.

    For subject ``i`` the pseudovalue is ``N * F(tau) - (N - 1) * F_{-i}(tau)``.
    ``F`` is the Aalen-Johansen estimate on all ``N`` subjects, and ``F_{-i}``
    is the same estimate with subject ``i`` left out.

    Args:
        time: 1-D array of observed times.
        event: 1-D integer array of event codes. The largest code is passed
            to ``AalenJohansen`` as the number of causes.
        tau: time point(s) at which the cumulative incidence is evaluated.

    Returns:
        Pseudovalues stacked over subjects along the first axis. The remaining
        axes follow the output of ``AalenJohansen(...)(tau)`` (the example
        below gives ``(N, len(tau), n_causes)``).

    Example:
        >>> rng = np.random.default_rng()
        >>> time, event = rng.weibull(5., size=100), rng.integers(3, size=100)
        >>> aj = _cic_pv(time, event, tau=[0.5, 3, 4])
        >>> aj.shape
        (100, 3, 2)
    """
    # Accept array-likes: `.size` and fancy indexing below need ndarrays.
    time = np.asarray(time)
    event = np.asarray(event)
    N = time.size
    # np.float64 rather than the np.float_ alias (removed in NumPy 2.0).
    tau = np.asarray(tau, np.float64)

    # n_causes is taken from the full sample and reused for every
    # leave-one-out fit. Presumably this keeps all fits the same shape even
    # when the omitted subject holds the only occurrence of the largest cause.
    n_causes = np.max(event)
    cumulative_incidence = AalenJohansen(time, event, n_causes)(tau)
    cumulative_incidence_loo = np.array(
        [AalenJohansen(time[i], event[i], n_causes)(tau) for i in _loo_idx(N)]
    )
    return (N * cumulative_incidence) - ((N - 1) * cumulative_incidence_loo)


class BrierScore:
    """Brier score of survival model with right-censored data.

    With right-censored data, the time-dependent event status may be unknown.
    It can be replaced with jackknife pseudovalues from the marginal
    Aalen-Johansen estimates [1].

    Args:
        timepoints: sequence of `t` timepoints to evaluate.
        causes: causes for which the brier score is calculated. If ``None``,
            every non-zero event code present in ``event`` is used.

    Raises:
        ValueError: if ``causes`` contains 0, the censoring code.

    Refs:
        [1] Cortese, Giuliana, Thomas A. Gerds, and Per K. Andersen.
        "Comparing predictions among competing risks models with
        time‐dependent covariates." Statistics in medicine 32.18 (2013):
        3089-3101.
    """

    Results = namedtuple("Results", ["null", "model"])

    def __init__(
        self, timepoints: npt.ArrayLike, causes: Optional[list[int]] = None
    ) -> None:
        if causes is not None and 0 in causes:
            raise ValueError("cause 0 is censoring and cannot be included.")
        self.timepoints = np.asanyarray(timepoints)
        self.causes = causes

    @staticmethod
    def _brier_score(
        label: npt.ArrayLike, estimate: npt.ArrayLike
    ) -> NDArray[np.floating]:
        """Average ``label * (1 - 2 * estimate) + estimate**2`` over axis 0.

        When ``label`` is 0 or 1 this equals ``(label - estimate)**2``, because
        ``label**2 == label``. The expression is linear in ``label``, so
        non-binary jackknife pseudovalues can stand in for the unobserved
        event status.

        Returns:
            The mean score over the first axis, as an array of at least 1-D.
        """
        estimate = np.asanyarray(estimate)
        score = label * (1 - (2 * estimate)) + np.square(estimate)
        return np.atleast_1d(np.mean(score, axis=0))

    def __call__(
        self,
        estimates: Optional[npt.ArrayLike],
        time: npt.ArrayLike,
        event: npt.ArrayLike,
    ) -> Results:
        """Calculate Brier score of survival model.

        Also includes the Brier score of a null model based on the cumulative
        incidence curve obtained with the Aalen-Johansen estimator.

        Args:
            estimates: an array with shape (`m`, `t`, `e`), where `m` is the
                batch size, `t` is the number of time bins, and `e` is the
                number of competing causes/risks. ``None`` skips the model
                score.
            time: survival time (1-D).
            event: integer event indicator (1-D), where event=0 denotes
                censoring.

        Returns:
            ``Results(null, model)``. ``model`` is ``None`` when ``estimates``
            is ``None``.

        Raises:
            TypeError: if ``event`` is not of integer dtype.
            ValueError: if ``time`` or ``event`` is not 1-D, or if their
                lengths differ.
        """
        time, event = np.asanyarray(time), np.asanyarray(event)
        if not np.issubdtype(event.dtype, np.integer):
            raise TypeError("`event` can only be of integer type.")

        # check dimensionality of input arrays
        if time.ndim != 1:
            raise ValueError(f"`time` is a {time.ndim}D array.")
        if event.ndim != 1:
            raise ValueError(f"`event` is a {event.ndim}D array.")
        # A length mismatch would otherwise only surface, if at all, inside
        # the estimators.
        if time.shape[0] != event.shape[0]:
            raise ValueError(
                "`time` and `event` have different lengths "
                f"({time.shape[0]} and {event.shape[0]})."
            )

        # use all available causes if none is specified
        if self.causes is None:
            causes = np.trim_zeros(np.unique(event))
        else:
            causes = np.asanyarray(self.causes)
        # Cause k (k >= 1; 0 is censoring) is stored at index k - 1 of the
        # last axis.
        cause_idx = causes - 1

        # obtain jack-knife pseudovalues
        cic_pv = _cic_pv(time, event, self.timepoints)[..., cause_idx]

        # the null model is the cic from the AalenJohansen estimator
        p_null = AalenJohansen(time, event)(self.timepoints)[..., cause_idx]
        bs_null = self._brier_score(cic_pv, p_null)

        # calculate brier score on the model estimates
        if estimates is not None:
            estimates = np.asanyarray(estimates)
            bs_model = self._brier_score(cic_pv, estimates[..., cause_idx])
        else:
            bs_model = None

        return self.Results(null=bs_null, model=bs_model)