from __future__ import annotations
import math
from collections import defaultdict
from collections.abc import Sequence
from logging import getLogger
from typing import TYPE_CHECKING, Any, Mapping, Optional, TypeVar, Union
from pydantic import Field, NonNegativeFloat, NonNegativeInt, model_validator
from typing_extensions import Self
from clinicadl.metrics.enum import Optimum
from clinicadl.metrics.handler import MetricsHandler
from clinicadl.utils.config import ObjectConfig
from clinicadl.utils.objects import HasConfig
from ..base import Callback
from .utils import QuantityMonitoring
if TYPE_CHECKING:
    # Only needed for annotations: with ``from __future__ import annotations``
    # they are never evaluated at runtime, so the import is skipped then
    # (presumably to avoid an import cycle with clinicadl.train — confirm).
    from clinicadl.train import TrainerState

# Module-level logger used to report why training is stopped.
logger = getLogger(__name__)
class OneMetricEarlyStoppingConfig(ObjectConfig["OneMetricEarlyStopping"]):
    """
    Config class for ``OneMetricEarlyStopping``.

    Attributes
    ----------
    metric : str
        Name of the monitored metric.
    patience : NonNegativeInt
        Number of evaluation phases without improvement that are tolerated.
    min_delta : NonNegativeFloat
        Minimum change in the metric to qualify as an improvement.
    mode : Optional[Optimum]
        Whether the metric is minimized or maximized. It can be ``None`` when
        the config is built, and filled in later.
    check_finite : bool
        Whether a NaN or infinite metric value stops training.
    upper_bound : Optional[float]
        Training stops if the metric exceeds this value.
    lower_bound : Optional[float]
        Training stops if the metric falls below this value.
    """

    metric: str
    patience: NonNegativeInt
    min_delta: NonNegativeFloat
    mode: Optional[Optimum]  # unknown until the metrics are available
    check_finite: bool
    upper_bound: Optional[float]
    lower_bound: Optional[float]

    @model_validator(mode="after")
    def _check_bounds(self) -> Self:
        """Reject configurations whose lower bound exceeds the upper bound."""
        lower, upper = self.lower_bound, self.upper_bound
        both_bounds_set = lower is not None and upper is not None
        if both_bounds_set and upper < lower:
            raise ValueError("Upper bound should be greater than lower bound.")
        return self

    @classmethod
    def _get_class(cls):
        return OneMetricEarlyStopping
class OneMetricEarlyStopping(
    QuantityMonitoring, HasConfig[OneMetricEarlyStoppingConfig]
):
    """
    Early stopping for a single metric.

    Besides the patience criterion (based on ``num_non_improvements``, which
    is maintained through ``QuantityMonitoring.step``), training can be stopped
    when the metric is not finite or leaves the ``[lower_bound, upper_bound]``
    interval.

    Parameters
    ----------
    metric : str
        Name of the monitored metric.
    patience : int
        Number of evaluations without improvement tolerated before stopping.
    min_delta : float
        Minimum change in the metric to qualify as an improvement.
    mode : Optimum
        Whether the metric should be minimized or maximized.
    check_finite : bool
        Whether to stop when the metric is NaN or infinite.
    upper_bound : Optional[float]
        Stop when the metric exceeds this value.
    lower_bound : Optional[float]
        Stop when the metric falls below this value.
    """

    _config_type = OneMetricEarlyStoppingConfig

    def __init__(
        self,
        metric: str,
        patience: int,
        min_delta: float,
        mode: Optimum,
        check_finite: bool,
        upper_bound: Optional[float],
        lower_bound: Optional[float],
    ) -> None:
        # The config validates the parameters (e.g. bound ordering).
        cfg = self._config_type(
            metric=metric,
            patience=patience,
            min_delta=min_delta,
            mode=mode,
            check_finite=check_finite,
            upper_bound=upper_bound,
            lower_bound=lower_bound,
        )
        self.config = cfg
        super().__init__(name=cfg.metric, min_delta=cfg.min_delta, mode=cfg.mode)

    # pylint: disable=arguments-renamed
    def step(self, metrics: MetricsHandler, state: TrainerState) -> bool:
        """
        Checks if training should stop.

        Criteria are checked in this order: finiteness (only if
        ``check_finite``), upper bound, lower bound, then patience. When one
        of the first three is met, ``True`` is returned before the monitored
        quantity is updated with the current value.

        Parameters
        ----------
        metrics : MetricsHandler
            The :py:class:`clinicadl.metrics.MetricsHandler` containing the validation metrics.
        state : TrainerState
            The state of the trainer.

        Returns
        -------
        bool
            ``True`` if training should stop.
        """
        cfg = self.config
        epoch = state.current_epoch
        value = metrics.get_metric_value(metric=cfg.metric, epoch=epoch)

        if cfg.check_finite and not math.isfinite(value):
            logger.warning(
                "Metric '%s' value at epoch %s is not a finite float. Stopping training.",
                cfg.metric,
                epoch,
            )
            return True

        if cfg.upper_bound is not None and value > cfg.upper_bound:
            logger.warning(
                "Metric '%s' value %s exceeds upper bound %s at epoch %s. Stopping training.",
                cfg.metric,
                value,
                cfg.upper_bound,
                epoch,
            )
            return True

        if cfg.lower_bound is not None and value < cfg.lower_bound:
            logger.warning(
                "Metric '%s' value %s falls below lower bound %s at epoch %s. Stopping training.",
                cfg.metric,
                value,
                cfg.lower_bound,
                epoch,
            )
            return True

        # Delegate to QuantityMonitoring, which presumably updates
        # ``num_non_improvements`` (read just below).
        super().step(value, log=True)
        if self.num_non_improvements <= cfg.patience:
            return False

        logger.info(
            "Early stopping triggered on metric '%s' after %s evaluation(s) without improvement.",
            cfg.metric,
            self.num_non_improvements,
        )
        return True
# Generic element type used by ``EarlyStoppingCallbackConfig._ensure_sequence``.
T = TypeVar("T")
class EarlyStoppingCallbackConfig(ObjectConfig["EarlyStoppingCallback"]):
    """
    Config class for ``EarlyStoppingCallback``.

    Attributes
    ----------
    stoppers : Sequence[OneMetricEarlyStoppingConfig]
        One early-stopper configuration per monitored metric.
    """

    stoppers: Sequence[OneMetricEarlyStoppingConfig] = Field(
        json_schema_extra={
            # Converts each serialized stopper with
            # ``OneMetricEarlyStoppingConfig.from_dict`` (presumably used by
            # ObjectConfig when reading a config back — confirm).
            "reader": lambda stoppers: list(
                map(OneMetricEarlyStoppingConfig.from_dict, stoppers)
            )
        }
    )

    @classmethod
    def from_parameters(
        cls,
        metric: Union[str, Sequence[str]],
        patience: Union[int, Sequence[int]],
        min_delta: Union[float, Sequence[float]],
        check_finite: Union[bool, Sequence[bool]],
        upper_bound: Union[Optional[float], Sequence[Optional[float]]],
        lower_bound: Union[Optional[float], Sequence[Optional[float]]],
    ) -> Self:
        """
        Creates a sequence of Early Stoppers from sequences of parameters.

        Every parameter other than ``metric`` may be a single value (shared by
        all metrics), a length-1 sequence (broadcast to all metrics), or a
        sequence with exactly one value per metric. The ``mode`` of each
        stopper is left to ``None`` and must be filled in later (e.g. by
        ``EarlyStoppingCallback.on_train_start``).

        Raises
        ------
        ValueError
            If no metric is passed, or if a parameter sequence does not match
            the number of metrics.
        """
        metrics = (
            list(metric)
            if isinstance(metric, Sequence) and not isinstance(metric, str)
            else [metric]
        )
        # Without any stopper, the callback would compute all([]) == True and
        # stop training at the very first validation.
        if not metrics:
            raise ValueError(
                f"For {cls._get_name()}, at least one metric to monitor must be passed."
            )
        n = len(metrics)

        configs = [
            OneMetricEarlyStoppingConfig(
                metric=m,
                patience=p,
                min_delta=m_d,
                check_finite=c_f,
                upper_bound=u_b,
                lower_bound=l_b,
                mode=None,
            )
            for m, p, m_d, c_f, u_b, l_b in zip(
                metrics,
                cls._ensure_sequence(patience, n, "patience"),
                cls._ensure_sequence(min_delta, n, "min_delta"),
                cls._ensure_sequence(check_finite, n, "check_finite"),
                cls._ensure_sequence(upper_bound, n, "upper_bound"),
                cls._ensure_sequence(lower_bound, n, "lower_bound"),
            )
        ]
        return cls(stoppers=configs)

    @classmethod
    def _ensure_sequence(
        cls, x: Union[T, Sequence[T]], len_: int, name: str
    ) -> Sequence[T]:
        """
        Ensure a sequence of length ``len_`` for any parameter.

        A scalar (strings included) or a length-1 sequence is repeated
        ``len_`` times; a sequence of length ``len_`` is returned as a list.

        Raises
        ------
        ValueError
            If ``x`` is a sequence whose length is neither 1 nor ``len_``.
        """
        if not isinstance(x, Sequence) or isinstance(x, str):
            return [x] * len_
        # Materialize as a list: not every Sequence (e.g. ``range``) supports
        # repetition with ``*``.
        values = list(x)
        if len(values) == 1:
            return values * len_
        if len(values) != len_:
            raise ValueError(
                f"For {cls._get_name()}, there are {len_} metrics, but you passed {len(x)} '{name}': {x}"
            )
        return values

    @classmethod
    def _get_class(cls):
        return EarlyStoppingCallback
class EarlyStoppingCallback(Callback, HasConfig[EarlyStoppingCallbackConfig]):
    """
    Early Stopping callback monitoring one or multiple metrics.

    This callback stops training if the monitored metric(s) do not improve for a
    specified number of evaluation phases (which does not necessarily happen every
    epoch, see :py:class:`clinicadl.optim.OptimizationConfig`).

    It can monitor multiple metrics simultaneously and allows separate configuration
    for each metric. For any parameter listed below, you may provide either a single
    value—applied uniformly to all monitored metrics—or a sequence of values to
    configure metrics individually.

    .. note::
        Passing multiple metrics here means that training should stop when **all** the
        monitored metrics have met their stopping criterion. If you want to stop the
        training when **any** of them has met its stopping criterion, you can instantiate
        multiple ``EarlyStoppingCallbacks`` that will monitor each metric independently.

    .. note::
        No need to specify if the monitored quantity should be minimized or maximized,
        it is specified in :py:attr:`clinicadl.metrics.Metric.optimum`.

    Parameters
    ----------
    metric : Union[str, Sequence[str]]
        Metric(s) to monitor.
    patience : Union[int, Sequence[int]], default=3
        The number of allowed evaluation phases with no improvement. For example, if ``patience=0``
        the stop signal is triggered as soon as the monitored quantity is not improved.
    min_delta : Union[float, Sequence[float]], default=0.0
        Minimum absolute change in a monitored metric to qualify as an improvement.
    check_finite : Union[bool, Sequence[bool]], default=True
        Whether to stop if the metric becomes NaN or infinite.
    upper_bound : Union[Optional[float], Sequence[Optional[float]]], default=None
        Optional upper threshold that will trigger stopping if exceeded.
    lower_bound : Union[Optional[float], Sequence[Optional[float]]], default=None
        Optional lower threshold that triggers stopping when the value falls below it.

    Examples
    --------
    .. code-block:: python

        from clinicadl.callbacks import EarlyStoppingCallback
        from clinicadl.train import Trainer
        from clinicadl.metrics.config import MSEMetricConfig, LossMetricConfig

        trainer = Trainer(
            metrics={"loss": LossMetricConfig(), "mse": MSEMetricConfig()},
            callbacks=[EarlyStoppingCallback(metric="mse", patience=5)],
            ...
        )
    """

    _config_type = EarlyStoppingCallbackConfig

    def __init__(
        self,
        metric: Union[str, Sequence[str]],
        patience: Union[int, Sequence[int]] = 3,
        min_delta: Union[float, Sequence[float]] = 0.0,
        check_finite: Union[bool, Sequence[bool]] = True,
        upper_bound: Union[Optional[float], Sequence[Optional[float]]] = None,
        lower_bound: Union[Optional[float], Sequence[Optional[float]]] = None,
    ) -> None:
        self.config = self._config_type.from_parameters(
            metric=metric,
            patience=patience,
            min_delta=min_delta,
            check_finite=check_finite,
            upper_bound=upper_bound,
            lower_bound=lower_bound,
        )
        # Built later (see _init_stoppers): each stopper needs the mode of its
        # metric, which is only known once the metrics are available.
        self.stoppers: Optional[list[OneMetricEarlyStopping]] = None

    def _init_stoppers(self) -> None:
        """
        Initializes all the underlying early stoppers from their configs.

        Raises
        ------
        RuntimeError
            If the mode of at least one stopper is still unknown.
        """
        have_modes = all(config.mode is not None for config in self.config.stoppers)
        if not have_modes:
            raise RuntimeError(
                "Cannot initialize early stoppers because their modes "
                "need to be specified. E.g. by calling on_train_start"
            )
        self.stoppers = [config.get_object() for config in self.config.stoppers]

    def _add_modes(self, modes: dict[str, Optimum]) -> None:
        """
        Writes the mode of each monitored metric into its stopper config.

        ``modes`` must contain an entry for every monitored metric.
        """
        for config in self.config.stoppers:
            config.mode = modes[config.metric]

    def _reset(self) -> None:
        """
        Resets the monitoring state of the existing stoppers, if any.
        """
        if self.stoppers:
            for stopper in self.stoppers:
                stopper.reset()

    # pylint: disable=arguments-differ, unused-argument
    def on_train_start(self, metrics: MetricsHandler, **kwargs):
        """
        Checks the monitored metric names, retrieves their modes from
        ``metrics`` and (re)builds the stoppers.
        """
        self._reset()
        for config in self.config.stoppers:
            metrics.check_metric_name(metric=config.metric)
        modes = {name: metric.optimum for name, metric in metrics.metrics.items()}
        self._add_modes(modes)
        self._init_stoppers()

    def on_resume(self, **kwargs):
        """
        Rebuilds the stoppers from their configs (modes must already be known).
        """
        # NOTE(review): this creates fresh stoppers, so any state restored
        # earlier with load_state_dict would be discarded — confirm that the
        # trainer calls load_state_dict after on_resume.
        self._init_stoppers()

    def on_validation_end(
        self, *, state: TrainerState, metrics: MetricsHandler, **kwargs
    ) -> None:
        """
        Steps every stopper and sets ``state.should_stop`` to ``True`` only
        if all of them ask to stop.

        Raises
        ------
        RuntimeError
            If the stoppers have not been initialized yet.
        """
        if self.stoppers is None:
            raise RuntimeError(
                "Early stoppers are not initialized. on_train_start (or on_resume) "
                "must be called before on_validation_end."
            )
        # A list rather than a generator inside all(): every stopper must be
        # stepped (each updates its own monitoring in step), even when an
        # earlier one already returned False.
        should_stops = [stopper.step(metrics, state) for stopper in self.stoppers]
        should_stop = all(should_stops)
        if should_stop:
            logger.info(
                "Early stopping criteria met for all monitored metrics. Stopping training."
            )
        state.should_stop = should_stop

    def state_dict(self) -> Mapping[str, Any]:
        """
        Returns the state of each stopper, keyed by monitored metric name
        (an empty dict if the stoppers are not built yet).
        """
        if self.stoppers is None:
            return {}
        return {
            stopper.config.metric: stopper.state_dict() for stopper in self.stoppers
        }

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        """
        Loads the state of each stopper from ``state_dict`` (as produced by
        ``state_dict``). Does nothing if the stoppers are not built yet or if
        ``state_dict`` is empty.
        """
        if self.stoppers is not None and state_dict:
            for stopper in self.stoppers:
                stopper.load_state_dict(state_dict[stopper.config.metric])

    @classmethod
    def _from_config(cls, config):
        """
        Rebuilds a callback from an ``EarlyStoppingCallbackConfig``.

        Per-stopper parameters are regrouped into one list per parameter and
        passed to ``__init__``; the stored modes (possibly ``None``) are then
        restored, and the stoppers are built if every mode is known.
        """
        args = defaultdict(list)
        for stopper in config.stoppers:
            # Iterating a pydantic model yields (field_name, value) pairs.
            for k, v in stopper:
                args[k].append(v)
        # 'mode' is not an __init__ parameter; it is restored just below.
        args.pop("mode", None)
        early_stopper = cls(**args)
        modes = {stopper.metric: stopper.mode for stopper in config.stoppers}
        early_stopper._add_modes(modes)
        try:
            early_stopper._init_stoppers()
        except RuntimeError:
            # Best effort: if some modes are still unknown, the stoppers are
            # built later, in on_train_start.
            pass
        return early_stopper