clinicadl.callbacks.EarlyStoppingCallback
- class clinicadl.callbacks.EarlyStoppingCallback(metric: str | Sequence[str], patience: int | Sequence[int] = 3, min_delta: float | Sequence[float] = 0.0, check_finite: bool | Sequence[bool] = True, upper_bound: float | None | Sequence[float | None] = None, lower_bound: float | None | Sequence[float | None] = None) -> None
Early stopping callback that monitors one or more metrics.
This callback stops training if the monitored metric(s) do not improve for a specified number of evaluation phases. Evaluation phases do not necessarily happen every epoch (see clinicadl.optim.OptimizationConfig). The callback can monitor several metrics simultaneously, each with its own configuration: for any parameter listed below, you may provide either a single value, applied uniformly to all monitored metrics, or a sequence of values, one per metric.
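The scalar-or-sequence convention can be illustrated with a small helper. This is a minimal sketch, not clinicadl's code. `broadcast_param` is a hypothetical name, and raising an error on a length mismatch is an assumption about how such a case might be handled. Strings are treated as single values here, so `metric="mse"` is not split into characters.

```python
from typing import List, Sequence, TypeVar, Union

T = TypeVar("T")


def broadcast_param(value: Union[T, Sequence[T]], n_metrics: int, name: str) -> List[T]:
    """Expand a parameter into one value per monitored metric (hypothetical helper).

    A single value (including None or a string) is repeated for every metric.
    A list or tuple must provide exactly one value per metric.
    """
    if isinstance(value, (list, tuple)):
        if len(value) != n_metrics:
            raise ValueError(f"{name}: expected {n_metrics} values, got {len(value)}")
        return list(value)
    # Single value: applied uniformly to all monitored metrics.
    return [value] * n_metrics


metrics = ["loss", "mse"]
print(broadcast_param(3, len(metrics), "patience"))          # [3, 3]
print(broadcast_param([3, 10], len(metrics), "patience"))    # [3, 10]
print(broadcast_param(None, len(metrics), "upper_bound"))    # [None, None]
```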
Note
Passing multiple metrics means that training stops only when all the monitored metrics have met their stopping criteria. To stop training as soon as any one of them meets its criterion, instantiate several EarlyStoppingCallback objects, each monitoring a single metric.
Note
There is no need to specify whether the monitored quantity should be minimized or maximized: this is defined by clinicadl.metrics.Metric.optimum.
- Parameters:
metric (Union[str, Sequence[str]]) – The name(s) of the metric(s) to monitor.
patience (Union[int, Sequence[int]], default=3) – The number of evaluation phases without improvement allowed before stopping. For example, with patience=0, the stop signal is triggered as soon as the monitored quantity fails to improve.
min_delta (Union[float, Sequence[float]], default=0.0) – Minimum absolute change in a monitored metric to qualify as an improvement.
check_finite (Union[bool, Sequence[bool]], default=True) – Whether to stop if the metric becomes NaN or infinite.
upper_bound (Union[Optional[float], Sequence[Optional[float]]], default=None) – Optional upper threshold that will trigger stopping if exceeded.
lower_bound (Union[Optional[float], Sequence[Optional[float]]], default=None) – Optional lower threshold that triggers stopping when the value falls below it.
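The following is a hypothetical, self-contained re-implementation of the stopping rule for a single metric. It shows how these parameters interact and is a sketch under stated assumptions, not clinicadl's implementation. `MetricStopper` and `should_stop` are illustrative names. `optimum` stands in for clinicadl.metrics.Metric.optimum and is assumed to take the values "min" or "max". The order in which the checks are applied is a guess. The rule is fed one value per evaluation phase.

```python
import math
from typing import Optional


class MetricStopper:
    """Hypothetical stopping rule for ONE metric (not clinicadl's code).

    `optimum` is assumed to be "min" or "max", mimicking Metric.optimum.
    """

    def __init__(self, optimum: str, patience: int = 3, min_delta: float = 0.0,
                 check_finite: bool = True, upper_bound: Optional[float] = None,
                 lower_bound: Optional[float] = None) -> None:
        self.optimum = optimum
        self.patience = patience
        self.min_delta = min_delta
        self.check_finite = check_finite
        self.upper_bound = upper_bound
        self.lower_bound = lower_bound
        self.best: Optional[float] = None
        self.wait = 0  # evaluation phases without improvement

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True
        # The change must exceed min_delta, in the direction given by optimum.
        if self.optimum == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def should_stop(self, value: float) -> bool:
        """Update with the value of one evaluation phase; return True to stop."""
        if self.check_finite and not math.isfinite(value):
            return True
        if self.upper_bound is not None and value > self.upper_bound:
            return True
        if self.lower_bound is not None and value < self.lower_bound:
            return True
        if self._improved(value):
            self.best = value
            self.wait = 0
            return False
        self.wait += 1
        # With patience=0, the first phase without improvement triggers the stop.
        return self.wait > self.patience


stopper = MetricStopper(optimum="min", patience=1, min_delta=0.01)
for phase, mse in enumerate([0.50, 0.40, 0.395, 0.398, 0.30]):
    if stopper.should_stop(mse):
        print(f"stop at evaluation phase {phase}")  # prints "stop at evaluation phase 3"
        break
```

Here 0.395 is not a large enough improvement over 0.40 given min_delta=0.01. With patience=1, the second phase without improvement therefore triggers the stop.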
Examples
from clinicadl.callbacks import EarlyStoppingCallback
from clinicadl.train import Trainer
from clinicadl.metrics.config import MSEMetricConfig, LossMetricConfig

trainer = Trainer(
    metrics={"loss": LossMetricConfig(), "mse": MSEMetricConfig()},
    callbacks=[EarlyStoppingCallback(metric="mse", patience=5)],
    # ... other Trainer arguments
)
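The note about multiple metrics describes two set-ups. The "all" case is a single EarlyStoppingCallback(metric=["loss", "mse"], patience=1). The "any" case is two callbacks, EarlyStoppingCallback(metric="loss", patience=1) and EarlyStoppingCallback(metric="mse", patience=1). The toy, patience-only rule for minimized metrics below sketches how the two differ. The helper `phases_until_stop` is hypothetical. The sketch also assumes that once a metric has met its criterion, it stays met; this holds for the sample values used.

```python
from typing import Sequence


def phases_until_stop(values: Sequence[float], patience: int) -> int:
    """Index of the evaluation phase at which a minimized metric meets its
    stopping criterion (patience only, min_delta=0), or len(values) if never.
    Hypothetical helper for illustration."""
    best = float("inf")
    wait = 0
    for i, value in enumerate(values):
        if value < best:
            best, wait = value, 0
        else:
            wait += 1
            if wait > patience:
                return i
    return len(values)


loss = [1.0, 0.9, 0.95, 0.97, 0.99, 0.98]  # stops improving after phase 1
mse = [0.5, 0.4, 0.3, 0.31, 0.32, 0.33]    # stops improving after phase 2

stop_loss = phases_until_stop(loss, patience=1)  # 3
stop_mse = phases_until_stop(mse, patience=1)    # 4

# One callback monitoring both metrics: stop once ALL criteria are met.
print("single multi-metric callback stops at phase", max(stop_loss, stop_mse))  # 4
# One callback per metric: stop as soon as ANY criterion is met.
print("separate callbacks stop at phase", min(stop_loss, stop_mse))  # 3
```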
- on_train_start(metrics: MetricsHandler, **kwargs)
Called once at the beginning of Trainer.train if resume=False. When resuming a training, on_resume() is called instead.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
split (Split) – The clinicadl.split.Split on which training is performed.
optimizers (dict[str, torch.optim.Optimizer]) – The current optimizers (torch.optim.Optimizer), as returned by Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
optimization (OptimizationConfig) – The optimization specifications of the training phase.
metrics (MetricsHandler) – The validation metrics to compute.
callbacks (CallbacksHandler) – The callbacks passed to the Trainer.
computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the training phase.
- on_resume(**kwargs)
Called once when Trainer.train resumes a training. More precisely, this method is called just before the checkpoints are loaded.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
split (Split) – The clinicadl.split.Split on which training is performed.
optimizers (dict[str, torch.optim.Optimizer]) – The current optimizers (torch.optim.Optimizer), as returned by Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
optimization (OptimizationConfig) – The optimization specifications of the training phase.
metrics (MetricsHandler) – The validation metrics to compute.
callbacks (CallbacksHandler) – The callbacks passed to the Trainer.
computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the training phase.
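The call order documented for on_train_start() and on_resume() can be sketched with a toy training function. This only illustrates the documented order and is not the Trainer's code. `RecordingCallback` and `toy_train_hooks` are hypothetical, and "load checkpoints" is a placeholder log entry standing in for the checkpoint loading step.

```python
from typing import List


class RecordingCallback:
    """Hypothetical callback that logs which hooks are called."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def on_train_start(self, **kwargs) -> None:
        self.log.append("on_train_start")

    def on_resume(self, **kwargs) -> None:
        self.log.append("on_resume")


def toy_train_hooks(resume: bool) -> List[str]:
    """Toy stand-in for Trainer.train reproducing only the documented hook order."""
    log: List[str] = []
    callback = RecordingCallback(log)
    if resume:
        callback.on_resume()          # called just before checkpoints are loaded
        log.append("load checkpoints")
    else:
        callback.on_train_start()     # called only when resume=False
    log.append("training loop")
    return log


print(toy_train_hooks(resume=False))  # ['on_train_start', 'training loop']
print(toy_train_hooks(resume=True))   # ['on_resume', 'load checkpoints', 'training loop']
```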
- on_validation_end(*, state: TrainerState, metrics: MetricsHandler, **kwargs) -> None
Called at the end of every validation loop in Trainer.train. Not to be confused with on_validate_end().
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
metrics (MetricsHandler) – The validation metrics computed.
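Because on_validation_end() runs after each validation loop, patience is counted in evaluation phases rather than in epochs. The toy loop below sketches this under explicit assumptions. `ToyState`, its `stop` flag, `ToyEarlyStopping`, `toy_train_with_validation` and the `validation_every` schedule are all hypothetical stand-ins; the real TrainerState, MetricsHandler and OptimizationConfig APIs are not reproduced. Metrics are passed as a plain dict of floats.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyState:
    """Hypothetical stand-in for TrainerState; `stop` is an assumed flag."""

    epoch: int = 0
    stop: bool = False


class ToyEarlyStopping:
    """Patience-only rule for one minimized metric (illustrative only)."""

    def __init__(self, metric: str, patience: int) -> None:
        self.metric = metric
        self.patience = patience
        self.best = float("inf")
        self.wait = 0  # counts validation phases, not epochs

    def on_validation_end(self, *, state: ToyState, metrics: Dict[str, float], **kwargs) -> None:
        value = metrics[self.metric]
        if value < self.best:
            self.best, self.wait = value, 0
        else:
            self.wait += 1
            if self.wait > self.patience:
                state.stop = True


def toy_train_with_validation(n_epochs: int, validation_every: int,
                              val_mse: List[float], patience: int) -> int:
    """Return the epoch at which training stops (or the last epoch run)."""
    state = ToyState()
    callback = ToyEarlyStopping("mse", patience)
    phase = 0
    for epoch in range(n_epochs):
        state.epoch = epoch
        # ... one epoch of training would run here ...
        if (epoch + 1) % validation_every == 0:
            # One value of the validation curve per evaluation phase.
            callback.on_validation_end(state=state, metrics={"mse": val_mse[phase]})
            phase += 1
            if state.stop:
                break
    return state.epoch


curve = [0.50, 0.40, 0.45, 0.46, 0.47, 0.48]
print(toy_train_with_validation(20, 3, curve, patience=1))  # 11
print(toy_train_with_validation(6, 1, curve, patience=1))   # 3
```

With the same validation curve and patience=1, validating every 3 epochs stops training at epoch 11. Validating every epoch stops it at epoch 3.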