clinicadl.callbacks.TrainingCheckpointCallback
- class clinicadl.callbacks.TrainingCheckpointCallback(every_n_epochs: int = 10, enabled: bool = True)
Saves checkpoints during a training phase.
The user can then resume training from the last saved checkpoint by calling Trainer.resume. The checkpoints are deleted once the training is completed. To save checkpoints of your neural network permanently, use ModelCheckpointCallback instead.
- Parameters:
every_n_epochs (int) – Default: 10.
enabled (bool) – Default: True.
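This page does not spell out how every_n_epochs and enabled interact. As a minimal sketch, assuming epochs are counted from 1 and a checkpoint is written whenever the epoch number is a multiple of every_n_epochs (an assumption not confirmed here), the epochs that would produce a checkpoint can be listed as follows. `checkpoint_epochs` is a hypothetical helper, not part of clinicadl.

```python
def checkpoint_epochs(n_epochs: int, every_n_epochs: int = 10, enabled: bool = True) -> list[int]:
    """Return the (1-based) epochs after which a checkpoint would be saved.

    Hypothetical helper: assumes a save whenever the epoch number is a
    multiple of ``every_n_epochs``, and no save at all when ``enabled`` is False.
    """
    if every_n_epochs < 1:
        raise ValueError("every_n_epochs must be a positive integer")
    if not enabled:
        return []
    return [epoch for epoch in range(1, n_epochs + 1) if epoch % every_n_epochs == 0]


print(checkpoint_epochs(25))                    # [10, 20]
print(checkpoint_epochs(25, every_n_epochs=5))  # [5, 10, 15, 20, 25]
print(checkpoint_epochs(25, enabled=False))     # []
```

Under this assumption, a run interrupted at epoch 19 with the default every_n_epochs=10 would resume from the epoch-10 checkpoint, so a smaller value trades disk writes for less repeated work.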
- on_train_start(*, metrics: MetricsHandler, callbacks: CallbacksHandler, optimizers: dict[str, Optimizer], grad_scaler: GradScaler, **kwargs) → None
Called once at the beginning of Trainer.train if resume=False. If a training is being resumed, on_resume() is called instead.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
split (Split) – The clinicadl.split.Split on which training is performed.
optimizers (dict[str, torch.optim.Optimizer]) – The current optimizers (torch.optim.Optimizer instances), as returned by Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
optimization (OptimizationConfig) – The optimization specifications of the training phase.
metrics (MetricsHandler) – The validation metrics to compute.
callbacks (CallbacksHandler) – The callbacks passed to the Trainer.
computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the training phase.
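The start/resume dispatch described above means exactly one of the two start hooks runs per call. The sketch below mimics that order only; `MiniTrainer` and `RecordingCallback` are hypothetical stand-ins, not clinicadl's Trainer, and the real hooks receive many more keyword arguments.

```python
class RecordingCallback:
    """Hypothetical callback that records which start hook was called."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def on_train_start(self, **kwargs) -> None:
        self.calls.append("on_train_start")

    def on_resume(self, **kwargs) -> None:
        self.calls.append("on_resume")


class MiniTrainer:
    """Hypothetical trainer: calls exactly one of the two start hooks."""

    def __init__(self, callback: RecordingCallback) -> None:
        self.callback = callback

    def train(self, resume: bool = False) -> None:
        if resume:
            self.callback.on_resume()  # replaces on_train_start when resuming
        else:
            self.callback.on_train_start()
        # ... the training loop would follow here ...

    def resume(self) -> None:
        self.train(resume=True)


fresh = RecordingCallback()
MiniTrainer(fresh).train()
resumed = RecordingCallback()
MiniTrainer(resumed).resume()
print(fresh.calls, resumed.calls)  # ['on_train_start'] ['on_resume']
```

A callback that initializes state in on_train_start therefore needs to perform the equivalent setup in on_resume as well.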
- on_exception(*, maps: Maps, state: TrainerState, **kwargs) → None
Called when an exception interrupts the execution of the Trainer.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
exception (Exception) – The exception that has been raised.
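A trainer can invoke such a hook from a try/except block around its training loop. This page does not say whether the exception is re-raised after the hook runs; the minimal sketch below assumes it is. `run_with_exception_hook` is a hypothetical name.

```python
from typing import Callable


def run_with_exception_hook(
    train_fn: Callable[[], None],
    on_exception: Callable[[Exception], None],
) -> None:
    """Run ``train_fn``; if it raises, notify ``on_exception`` and re-raise."""
    try:
        train_fn()
    except Exception as exception:  # KeyboardInterrupt derives from BaseException, so it is not caught here
        on_exception(exception)
        raise


seen: list[Exception] = []


def failing_training() -> None:
    raise RuntimeError("out of memory")


try:
    run_with_exception_hook(failing_training, seen.append)
except RuntimeError:
    pass

print(seen)  # [RuntimeError('out of memory')]
```

Re-raising keeps the failure visible to the caller while still giving callbacks a chance to react.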
- on_resume(*, model: Model, maps: Maps, state: TrainerState, metrics: MetricsHandler, callbacks: CallbacksHandler, optimizers: dict[str, Optimizer], grad_scaler: GradScaler, **kwargs) → None
Called once when Trainer.train resumes a training. More precisely, this method is called just before the checkpoints are loaded.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
split (Split) – The clinicadl.split.Split on which training is performed.
optimizers (dict[str, torch.optim.Optimizer]) – The current optimizers (torch.optim.Optimizer instances), as returned by Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
optimization (OptimizationConfig) – The optimization specifications of the training phase.
metrics (MetricsHandler) – The validation metrics to compute.
callbacks (CallbacksHandler) – The callbacks passed to the Trainer.
computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the training phase.
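The ordering stated above (hook first, checkpoint loading second) matters when a callback must act before the saved state overwrites the current one. The sketch below only records that order; `resume_training`, the JSON file and the dictionary-based state are hypothetical simplifications, not clinicadl's checkpoint format.

```python
import json
import tempfile
from pathlib import Path


def resume_training(checkpoint_file: Path, events: list[str]) -> dict:
    """Hypothetical resume sequence: call the hook, then load the checkpoint."""
    events.append("on_resume")  # stands in for callback.on_resume(...)
    state = json.loads(checkpoint_file.read_text())
    events.append("checkpoint loaded")
    return state


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "checkpoint.json"
    ckpt.write_text(json.dumps({"epoch": 10}))
    events: list[str] = []
    state = resume_training(ckpt, events)

print(events)          # ['on_resume', 'checkpoint loaded']
print(state["epoch"])  # 10
```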
- on_epoch_end(*, model: Model, maps: Maps, state: TrainerState) → None
Called at the end of each epoch in Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
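The end of an epoch is where a checkpointing callback can decide whether to write to disk. The minimal sketch below assumes 1-based epoch numbers and a single file that is overwritten at each save, so that only the last checkpoint is kept; `EpochEndCheckpointer`, its JSON format and its simplified keyword arguments are hypothetical and differ from the real hook's signature.

```python
import json
import tempfile
from pathlib import Path


class EpochEndCheckpointer:
    """Hypothetical sketch: save a JSON checkpoint every ``every_n_epochs`` epochs."""

    def __init__(self, directory: Path, every_n_epochs: int = 10, enabled: bool = True) -> None:
        self.directory = directory
        self.every_n_epochs = every_n_epochs
        self.enabled = enabled

    @property
    def path(self) -> Path:
        return self.directory / "last_checkpoint.json"

    def on_epoch_end(self, *, epoch: int, state: dict) -> None:
        # epoch is assumed 1-based; skip unless it is a multiple of every_n_epochs
        if not self.enabled or epoch % self.every_n_epochs != 0:
            return
        self.directory.mkdir(parents=True, exist_ok=True)
        # overwrite the previous file, so only the last checkpoint is kept
        self.path.write_text(json.dumps({"epoch": epoch, **state}))


with tempfile.TemporaryDirectory() as tmp:
    cb = EpochEndCheckpointer(Path(tmp) / "checkpoints", every_n_epochs=3)
    for epoch in range(1, 8):
        cb.on_epoch_end(epoch=epoch, state={"loss": 1.0 / epoch})
    saved = json.loads(cb.path.read_text())

print(saved["epoch"])  # 6
```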
- on_train_end(*, maps: Maps, state: TrainerState, **kwargs) → None
Called once at the end of Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
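Because the class description states that checkpoints are deleted once training completes, on_train_end is the natural place for that cleanup. A minimal sketch, assuming the checkpoints live in a dedicated directory (a hypothetical layout); `on_train_end_cleanup` is not a clinicadl function.

```python
import shutil
import tempfile
from pathlib import Path


def on_train_end_cleanup(checkpoint_dir: Path) -> None:
    """Hypothetical cleanup: remove the temporary checkpoint directory, if any."""
    if checkpoint_dir.exists():
        shutil.rmtree(checkpoint_dir)


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp) / "checkpoints"
    ckpt_dir.mkdir()
    (ckpt_dir / "last_checkpoint.json").write_text("{}")
    on_train_end_cleanup(ckpt_dir)
    removed = not ckpt_dir.exists()

print(removed)  # True
```

Running the cleanup only in on_train_end, and not in on_exception, presumably leaves the checkpoints of an interrupted run in place for Trainer.resume.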