clinicadl.callbacks.ModelCheckpointCallback
- class clinicadl.callbacks.ModelCheckpointCallback(metric: str | None = None, epochs: Sequence[int] | None = None, save_last: bool = True)
Callback that saves checkpoints of the neural network weights at different points during training.
Checkpoints can be saved after specified epochs and/or according to a monitored metric. In the latter case, only the best model according to this metric will be saved. The neural network weights after the last epoch can also be saved.
- Parameters:
metric (Optional[str], default=None) – The name of the metric to monitor, i.e. one of the keys of the metrics passed to the Trainer.
epochs (Optional[Sequence[int]], default=None) –
The list of epochs after which the neural network weights should be saved.
Important
Epochs are indexed from 1.
save_last (bool, default=True) – Whether to save the neural network weights after the last epoch.
Examples
from clinicadl.callbacks import ModelCheckpointCallback
from clinicadl.train import Trainer
from clinicadl.metrics.config import MSEMetricConfig, LossMetricConfig

trainer = Trainer(
    metrics={"loss": LossMetricConfig(), "mse": MSEMetricConfig()},
    callbacks=[
        ModelCheckpointCallback(
            metric="mse",
            epochs=range(1, 100, 10),  # save after epochs 1, 11, ..., 91
            save_last=True,
        )
    ],
    # ... other Trainer arguments
)
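The example above selects epochs with range(1, 100, 10). The following plain-Python snippet (it does not use ClinicaDL) checks which epochs that range covers and shows that the built-in range() accepts only positional arguments, so the step cannot be passed as step=10.

```python
# range(start, stop, step): stop is excluded, and keyword arguments are rejected.
epochs = range(1, 100, 10)
selected = list(epochs)
print(selected)  # [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]

# Epochs are indexed from 1, so epoch 1 is the first epoch of training.
print(1 in epochs, 100 in epochs)  # True False

try:
    range(1, 100, step=10)  # invalid: range() takes no keyword arguments
except TypeError as err:
    print("TypeError:", err)
```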
- on_train_start(*, maps: Maps, state: TrainerState, metrics: MetricsHandler, **kwargs) → None
Called once at the beginning of Trainer.train if resume=False. If a training is being resumed, on_resume() is called instead.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
split (Split) – The clinicadl.split.Split on which training is performed.
optimizers (dict[str, torch.optim.Optimizer]) – The current torch.optim.Optimizer instances, as returned by Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
optimization (OptimizationConfig) – The optimization specifications of the training phase.
metrics (MetricsHandler) – The validation metrics to compute.
callbacks (CallbacksHandler) – The callbacks passed to the Trainer.
computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the training phase.
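The hook signature is keyword-only (*) and ends with **kwargs, while the parameter list documents more arguments than the signature names. Presumably the Trainer passes all of them by keyword and each callback names only those it needs. The standalone sketch below illustrates that Python mechanism; LoggingHooks is a hypothetical class, not part of ClinicaDL, and plain strings stand in for the real objects.

```python
class LoggingHooks:
    """Hypothetical stand-in for a callback (not part of ClinicaDL)."""

    def __init__(self) -> None:
        self.extra_kwargs: list[str] = []

    def on_train_start(self, *, maps, state, metrics, **kwargs) -> None:
        # Only maps, state and metrics are named; any other keyword argument
        # (model, split, optimizers, ...) is collected in kwargs.
        self.extra_kwargs = sorted(kwargs)


hooks = LoggingHooks()
hooks.on_train_start(
    maps="maps", state="state", metrics="metrics",
    model="model", split="split", optimizers={}, grad_scaler=None,
)
print(hooks.extra_kwargs)  # ['grad_scaler', 'model', 'optimizers', 'split']

# Positional calls are rejected because of the bare * in the signature.
try:
    hooks.on_train_start("maps", "state", "metrics")
except TypeError as err:
    print("TypeError:", err)
```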
- on_resume(*, maps: Maps, state: TrainerState, metrics: MetricsHandler, **kwargs) → None
Called once when Trainer.train resumes a training. More precisely, this method is called just before the checkpoints are loaded.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
split (Split) – The clinicadl.split.Split on which training is performed.
optimizers (dict[str, torch.optim.Optimizer]) – The current torch.optim.Optimizer instances, as returned by Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
optimization (OptimizationConfig) – The optimization specifications of the training phase.
metrics (MetricsHandler) – The validation metrics to compute.
callbacks (CallbacksHandler) – The callbacks passed to the Trainer.
computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the training phase.
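The choice between on_train_start() and on_resume() can be sketched as follows. This is a hypothetical dispatch function, not the actual Trainer.train code: the hooks are shown without their keyword arguments for brevity, and Recorder and load_checkpoints are illustrative placeholders.

```python
from typing import Callable


class Recorder:
    """Hypothetical callback that only records which hooks were called."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def on_train_start(self) -> None:
        self.calls.append("on_train_start")

    def on_resume(self) -> None:
        self.calls.append("on_resume")


def start_training(callback: Recorder, resume: bool, load_checkpoints: Callable[[], None]) -> None:
    """Hypothetical sketch of how a trainer could dispatch the start hooks."""
    if resume:
        callback.on_resume()  # called just before the checkpoints are loaded
        load_checkpoints()
    else:
        callback.on_train_start()  # called only for a training started from scratch


fresh = Recorder()
start_training(fresh, resume=False, load_checkpoints=lambda: fresh.calls.append("load_checkpoints"))
print(fresh.calls)  # ['on_train_start']

resumed = Recorder()
start_training(resumed, resume=True, load_checkpoints=lambda: resumed.calls.append("load_checkpoints"))
print(resumed.calls)  # ['on_resume', 'load_checkpoints']
```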
- on_validation_end(*, model: Model, maps: Maps, state: TrainerState, metrics: MetricsHandler, **kwargs) → None
Called at the end of every validation loop in Trainer.train. Not to be confused with on_validate_end().
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
metrics (MetricsHandler) – The computed validation metrics.
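When a metric is monitored, only the best model according to it is kept. A minimal sketch of that rule, as it could run after each validation loop, is shown below. It assumes lower values are better (true for MSE or a loss); how ClinicaDL decides the direction is not stated here. BestMetricSaver and its epoch and values keyword arguments are hypothetical simplifications of the real state and metrics objects, and the metric values are made up for illustration.

```python
import math


class BestMetricSaver:
    """Hypothetical sketch: save only when the monitored metric improves.

    Assumes lower is better; appending to saved_epochs stands in for
    writing the weights to disk.
    """

    def __init__(self, metric: str) -> None:
        self.metric = metric
        self.best = math.inf
        self.saved_epochs: list[int] = []

    def on_validation_end(self, *, epoch: int, values: dict[str, float], **kwargs) -> None:
        value = values[self.metric]
        if value < self.best:
            self.best = value
            self.saved_epochs.append(epoch)  # the new best replaces the previous one


saver = BestMetricSaver("mse")
for epoch, mse in enumerate([0.9, 0.7, 0.8, 0.5], start=1):  # made-up values
    saver.on_validation_end(epoch=epoch, values={"mse": mse})
print(saver.saved_epochs)  # [1, 2, 4]
print(saver.best)  # 0.5
```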
- on_epoch_end(*, model: Model, maps: Maps, state: TrainerState, **kwargs) → None
Called at the end of an epoch in Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
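The epochs parameter is checked at the end of each epoch. The helper below is a hypothetical sketch of that test, assuming the epoch number it receives is 1-indexed like the epochs parameter; a loop counting from 0 would have to add 1 first.

```python
from typing import Optional, Sequence


def should_save(epoch: int, epochs: Optional[Sequence[int]]) -> bool:
    """Hypothetical helper: True if the weights should be saved after `epoch`.

    `epoch` is 1-indexed, matching the convention of the `epochs` parameter.
    """
    return epochs is not None and epoch in epochs


# Epochs (1-indexed) after which weights would be saved in a 25-epoch run.
n_epochs = 25
saved = [e for e in range(1, n_epochs + 1) if should_save(e, range(1, 100, 10))]
print(saved)  # [1, 11, 21]
print(should_save(1, None))  # False: no epoch-based checkpoints requested
```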
- on_train_end(*, model: Model, maps: Maps, state: TrainerState, **kwargs) → None
Called once at the end of Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
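Taken together, the three saving rules of ModelCheckpointCallback can be simulated over a whole run. The function below is a hypothetical, framework-free sketch, not ClinicaDL code: it assumes lower metric values are better, maps each rule to the hook where it plausibly applies, and uses made-up metric values in the usage line. Passing metric_values=None mirrors metric=None.

```python
import math
from typing import Optional, Sequence


def simulate_checkpoints(
    n_epochs: int,
    metric_values: Optional[Sequence[float]] = None,
    epochs: Optional[Sequence[int]] = None,
    save_last: bool = True,
) -> dict[str, list[int]]:
    """Hypothetical simulation of the three saving rules over one run.

    metric_values[i] is the monitored metric after epoch i + 1 (lower is
    better, an assumption). Returns the epochs at which each kind of
    checkpoint would be written; for "best", each entry replaces the previous.
    """
    saved: dict[str, list[int]] = {"best": [], "epochs": [], "last": []}
    best = math.inf
    for epoch in range(1, n_epochs + 1):
        # on_validation_end: keep only the best model so far
        if metric_values is not None and metric_values[epoch - 1] < best:
            best = metric_values[epoch - 1]
            saved["best"].append(epoch)
        # on_epoch_end: epochs requested by the user (1-indexed)
        if epochs is not None and epoch in epochs:
            saved["epochs"].append(epoch)
    # on_train_end: weights after the last epoch
    if save_last and n_epochs > 0:
        saved["last"].append(n_epochs)
    return saved


result = simulate_checkpoints(5, [0.9, 0.7, 0.8, 0.5, 0.6], epochs=[2, 4])
print(result)  # {'best': [1, 2, 4], 'epochs': [2, 4], 'last': [5]}
```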