clinicadl.callbacks.EarlyStoppingCallback

class clinicadl.callbacks.EarlyStoppingCallback(metric: str | Sequence[str], patience: int | Sequence[int] = 3, min_delta: float | Sequence[float] = 0.0, check_finite: bool | Sequence[bool] = True, upper_bound: float | None | Sequence[float | None] = None, lower_bound: float | None | Sequence[float | None] = None) -> None

Early Stopping callback monitoring one or multiple metrics.

This callback stops training if the monitored metric(s) do not improve for a specified number of evaluation phases (which do not necessarily occur every epoch; see clinicadl.optim.OptimizationConfig).

It can monitor multiple metrics simultaneously and allows separate configuration for each metric. For any parameter listed below, you may provide either a single value—applied uniformly to all monitored metrics—or a sequence of values to configure metrics individually.

Note

When multiple metrics are passed, training stops only once all the monitored metrics have met their stopping criterion. To stop training as soon as any one of them meets its criterion, instantiate one EarlyStoppingCallback per metric; each will then monitor its metric independently.
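
The difference between the two setups can be illustrated with a minimal sketch in plain Python. It does not use or reproduce clinicadl's code: the helper `broadcast`, the length check, the stale counters, and the rule "criterion met when the number of non-improving phases exceeds patience" are assumptions made for illustration only.

```python
from typing import Any, List


def broadcast(value: Any, n_metrics: int) -> List[Any]:
    """Hypothetical helper: expand a single value to one entry per metric."""
    if isinstance(value, (list, tuple)):
        if len(value) != n_metrics:
            # Assumption: a sequence must provide exactly one value per metric.
            raise ValueError("expected one value per monitored metric")
        return list(value)
    return [value] * n_metrics


metrics = ["loss", "mse"]
patience = broadcast(5, len(metrics))              # same patience for both metrics
min_delta = broadcast([0.0, 0.01], len(metrics))   # one min_delta per metric

# Made-up counts of consecutive evaluation phases without improvement.
stale = {"loss": 6, "mse": 2}
criterion_met = [stale[m] > p for m, p in zip(metrics, patience)]  # [True, False]

# One callback monitoring both metrics: every criterion must be met.
stop_single_callback = all(criterion_met)     # False
# One callback per metric: the first criterion met stops training.
stop_separate_callbacks = any(criterion_met)  # True
```

In this made-up situation, a single callback monitoring both metrics would let training continue, whereas two independent callbacks would stop it because "loss" has exhausted its patience.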

Note

There is no need to specify whether the monitored quantity should be minimized or maximized: this is defined by clinicadl.metrics.Metric.optimum.

Parameters:
  • metric (Union[str, Sequence[str]]) – Metric(s) to monitor.

  • patience (Union[int, Sequence[int]], default=3) – The number of allowed evaluation phases with no improvement. For example, if patience=0, the stop signal is triggered as soon as the monitored quantity fails to improve.

  • min_delta (Union[float, Sequence[float]], default=0.0) – Minimum absolute change in a monitored metric to qualify as an improvement.

  • check_finite (Union[bool, Sequence[bool]], default=True) – Whether to stop if the metric becomes NaN or infinite.

  • upper_bound (Union[Optional[float], Sequence[Optional[float]]], default=None) – Optional upper threshold that will trigger stopping if exceeded.

  • lower_bound (Union[Optional[float], Sequence[Optional[float]]], default=None) – Optional lower threshold that triggers stopping when the value falls below it.
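
How these parameters might combine into a stop decision for a single metric can be sketched as follows. This is an illustrative approximation, not clinicadl's implementation: the class `MetricMonitor`, its `minimize` flag (in clinicadl the direction comes from Metric.optimum), the order of the checks, and the comparison `stale > patience` are all assumptions.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class MetricMonitor:
    """Hypothetical single-metric tracker; not clinicadl's actual class."""

    patience: int = 3
    min_delta: float = 0.0
    check_finite: bool = True
    upper_bound: Optional[float] = None
    lower_bound: Optional[float] = None
    minimize: bool = True          # assumption: direction passed explicitly here
    best: Optional[float] = None   # best value seen so far
    stale: int = 0                 # consecutive phases without improvement

    def update(self, value: float) -> bool:
        """Record a new value; return True if the stopping criterion is met."""
        if self.check_finite and not math.isfinite(value):
            return True  # NaN or infinite value
        if self.upper_bound is not None and value > self.upper_bound:
            return True  # upper threshold exceeded
        if self.lower_bound is not None and value < self.lower_bound:
            return True  # value fell below the lower threshold
        if math.isnan(value):
            improved = False  # only reachable when check_finite is False
        elif self.best is None:
            improved = True   # first recorded value
        elif self.minimize:
            improved = value < self.best - self.min_delta
        else:
            improved = value > self.best + self.min_delta
        if improved:
            self.best = value
            self.stale = 0
        else:
            self.stale += 1
        return self.stale > self.patience


monitor = MetricMonitor(patience=1, min_delta=0.1)
decisions = [monitor.update(v) for v in [1.0, 0.95, 0.5, 0.5, 0.5]]
# decisions == [False, False, False, False, True]
```

In this sketch, the drop from 1.0 to 0.95 does not count as an improvement because it is smaller than min_delta, and the stop signal is raised only once the number of consecutive non-improving phases exceeds patience.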

Examples

from clinicadl.callbacks import EarlyStoppingCallback
from clinicadl.train import Trainer
from clinicadl.metrics.config import MSEMetricConfig, LossMetricConfig

trainer = Trainer(
    metrics={"loss": LossMetricConfig(), "mse": MSEMetricConfig()},
    callbacks=[EarlyStoppingCallback(metric="mse", patience=5)],
    ...
)

on_train_start(metrics: MetricsHandler, **kwargs)

Called once at the beginning of Trainer.train if resume=False.

When resuming training, on_resume() is called instead.

on_resume(**kwargs)

Called once when Trainer.train resumes a training run.

More precisely, this method will be called just before loading the checkpoints.

on_validation_end(*, state: TrainerState, metrics: MetricsHandler, **kwargs) -> None

Called at the end of every validation loop in Trainer.train.

Not to be confused with on_validate_end().

Parameters:
  • model (Model) – The model associated with the Trainer.

  • maps (Maps) – The MAPS associated with the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • metrics (MetricsHandler) – The validation metrics computed.

state_dict() -> Mapping[str, Any]

Returns a checkpoint of the current state of the callback.

Returns:

Mapping[str, Any] – The current state in a dict.

load_state_dict(state_dict: Mapping[str, Any]) -> None

Sets the callback to a given state.

Parameters:

state_dict (Mapping[str, Any]) – The desired state of the Callback, as returned by state_dict().
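
The round trip between state_dict() and load_state_dict() can be sketched with a stand-in class. `CounterCallback` and the keys it stores ("best", "stale") are hypothetical; what EarlyStoppingCallback actually keeps in its state is not documented here.

```python
from typing import Any, Mapping, Optional


class CounterCallback:
    """Hypothetical callback exposing the same two methods."""

    def __init__(self) -> None:
        self.best: Optional[float] = None
        self.stale: int = 0

    def state_dict(self) -> Mapping[str, Any]:
        # Assumption: the state is a flat dict of plain Python values.
        return {"best": self.best, "stale": self.stale}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        self.best = state_dict["best"]
        self.stale = state_dict["stale"]


# Before interruption: the callback has tracked some progress.
before = CounterCallback()
before.best, before.stale = 0.42, 2
checkpoint = before.state_dict()

# On resume: a fresh callback is restored from the checkpoint.
resumed = CounterCallback()
resumed.load_state_dict(checkpoint)
```

Restoring the counters this way means a resumed run would keep its patience budget from before the interruption instead of starting over.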