clinicadl.callbacks.Callback
- class clinicadl.callbacks.Callback
Base class to define arbitrary actions to perform at specific points of the training and evaluation workflows.
Each method of this class whose name starts with "on_" is associated with an event of the training or evaluation phase performed by the Trainer. By overriding these methods, the user can define actions to perform when the event happens.
Important
Callbacks should capture NON-ESSENTIAL logic, such as saving checkpoints or logging. The essential logic should be defined in a clinicadl.models.Model.
To define your own callback, you can override any of the methods associated with an event. You can also override state_dict() and load_state_dict() to be able to recover the state of your callback when resuming an interrupted training.
List of events
- on_trainer_init(): called when the trainer is instantiated.
- on_exception(): called whenever an uncaught exception is raised.
Trainer.train (pseudocode):
    IF resume THEN
        CALL on_resume
    ELSE
        CALL on_train_start
    END IF
    FOR EACH epoch IN training_epochs DO
        CALL on_epoch_start
        FOR EACH batch IN training_dataloader DO
            CALL on_batch_start
            MOVE batch TO appropriate_device
            CALL on_forward_step_start
            DO forward_step
            CALL on_backward_step_start
            DO backward_step
            CALL on_backward_step_end
            IF batch_idx MOD optimization_interval == 0 THEN
                CALL on_optimization_step_start
                DO optimization_step
                CALL on_optimization_step_end
            END IF
            CALL on_batch_end
        END FOR
        IF epoch_idx MOD evaluation_interval == 0 THEN
            CALL on_validation_start
            FOR EACH batch IN validation_dataloader DO
                CALL on_batch_start
                MOVE batch TO appropriate_device
                CALL on_evaluation_step_start
                DO evaluation_step
                CALL on_metrics_computation_start
                DO metrics_computation
                CALL on_metrics_computation_end
                CALL on_batch_end
            END FOR
            CALL on_validation_end
        END IF
        CALL on_epoch_end
    END FOR
    CALL on_train_end
Trainer.validate (pseudocode):
    CALL on_validate_start
    FOR EACH batch IN validation_dataloader DO
        CALL on_batch_start
        MOVE batch TO appropriate_device
        CALL on_evaluation_step_start
        DO evaluation_step
        CALL on_metrics_computation_start
        DO metrics_computation
        CALL on_metrics_computation_end
        CALL on_batch_end
    END FOR
    CALL on_validate_end
Trainer.test (pseudocode):
    CALL on_test_start
    FOR EACH batch IN test_dataloader DO
        CALL on_batch_start
        MOVE batch TO appropriate_device
        CALL on_evaluation_step_start
        DO evaluation_step
        CALL on_metrics_computation_start
        DO metrics_computation
        CALL on_metrics_computation_end
        CALL on_batch_end
    END FOR
    CALL on_test_end
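To make the event order concrete, here is a minimal sketch that replays the pseudocode listings with a recording object in place of a real callback. Everything in it (EventRecorder, toy_train, toy_evaluate and their parameters) is hypothetical and not part of clinicadl: no model is trained, real hooks would receive keyword arguments such as model, maps and state, and batch and epoch indices are assumed to start at 0, exactly as the MOD tests in the pseudocode are written.

```python
from typing import Callable, List


class EventRecorder:
    """Stand-in callback: every on_* hook simply records its own name."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def __getattr__(self, name: str) -> Callable[..., None]:
        # Only reached for attributes that do not exist, i.e. the on_* hooks.
        if not name.startswith("on_"):
            raise AttributeError(name)
        # Real hooks receive keyword arguments (model, maps, state, ...); ignore them.
        return lambda **kwargs: self.events.append(name)


def _evaluation_batches(cb: EventRecorder, n_batches: int) -> None:
    # Inner loop shared by in-training validation, Trainer.validate and Trainer.test.
    for _ in range(n_batches):
        cb.on_batch_start()
        # (the batch would be moved to the appropriate device here)
        cb.on_evaluation_step_start()
        cb.on_metrics_computation_start()
        cb.on_metrics_computation_end()
        cb.on_batch_end()


def toy_train(cb: EventRecorder, n_epochs: int, n_batches: int, n_val_batches: int,
              optimization_interval: int = 1, evaluation_interval: int = 1,
              resume: bool = False) -> None:
    if resume:
        cb.on_resume()
    else:
        cb.on_train_start()
    for epoch_idx in range(n_epochs):
        cb.on_epoch_start()
        for batch_idx in range(n_batches):
            cb.on_batch_start()
            cb.on_forward_step_start()   # forward_step would run after this
            cb.on_backward_step_start()  # backward_step would run after this
            cb.on_backward_step_end()
            if batch_idx % optimization_interval == 0:
                cb.on_optimization_step_start()
                cb.on_optimization_step_end()
            cb.on_batch_end()
        if epoch_idx % evaluation_interval == 0:
            cb.on_validation_start()
            _evaluation_batches(cb, n_val_batches)
            cb.on_validation_end()
        cb.on_epoch_end()
    cb.on_train_end()


def toy_evaluate(cb: EventRecorder, phase: str, n_batches: int) -> None:
    # phase is "validate" (Trainer.validate) or "test" (Trainer.test).
    if phase not in ("validate", "test"):
        raise ValueError(f"unknown phase: {phase!r}")
    getattr(cb, f"on_{phase}_start")()
    _evaluation_batches(cb, n_batches)
    getattr(cb, f"on_{phase}_end")()
```

For example, with two training batches and optimization_interval=2, on_optimization_step_start fires only for batch 0, since 0 MOD 2 == 0 while 1 MOD 2 != 0.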
- state_dict() → Mapping[str, Any]
Returns a checkpoint of the current state of the callback.
- Returns:
Mapping[str, Any] – The current state, as a dict.
- load_state_dict(state_dict: Mapping[str, Any]) → None
Sets the callback to a given state.
- Parameters:
state_dict (Mapping[str, Any]) – The desired state of the Callback, as returned by state_dict().
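As an illustration of state_dict() and load_state_dict(), here is a minimal sketch of a callback that counts completed epochs and can restore that count when a training is resumed. The class name EpochCounter and the "completed_epochs" key are hypothetical; the stand-in base class is only used when clinicadl is not installed, and how the Trainer stores and reloads the returned dict is not shown.

```python
from typing import Any, Mapping

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Minimal stand-in so this sketch also runs without clinicadl installed.
    class Callback:  # type: ignore[no-redef]
        pass


class EpochCounter(Callback):
    """Hypothetical callback counting completed epochs across interruptions."""

    def __init__(self) -> None:
        super().__init__()
        self.completed_epochs = 0

    def on_epoch_end(self, *, model, maps, state) -> None:
        self.completed_epochs += 1

    def state_dict(self) -> Mapping[str, Any]:
        # Keep the checkpoint small and made of plain Python values.
        return {"completed_epochs": self.completed_epochs}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        self.completed_epochs = int(state_dict["completed_epochs"])
```

Returning plain Python values rather than live objects keeps the checkpoint easy to serialize.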
- on_backward_step_end(*, model: Model, maps: Maps, state: TrainerState) → None
Called every time Model.backward_step has just been called in Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
- on_backward_step_start(*, model: Model, maps: Maps, state: TrainerState, loss: Tensor | dict[str, Tensor], grad_scaler: GradScaler) → None
Called every time Model.backward_step is about to be called in Trainer.train.
Note
This event is equivalent to on_forward_step_end.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
loss (LossType) – The loss output by Model.forward_step and input to Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
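Because loss can be either a single tensor or a dict of named tensors, a callback reading it at on_backward_step_start has to handle both shapes. The sketch below records a plain-float copy of the loss at each training step; the class name LossRecorder and the "loss" key used for the single-tensor case are hypothetical. It relies on float() accepting one-element torch tensors (this calls .item(), which waits for the device on GPU), and it also works with plain floats.

```python
from typing import Dict, List, Mapping

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Minimal stand-in so this sketch also runs without clinicadl installed.
    class Callback:  # type: ignore[no-redef]
        pass


class LossRecorder(Callback):
    """Hypothetical callback keeping a float copy of every training loss."""

    def __init__(self) -> None:
        super().__init__()
        self.history: List[Dict[str, float]] = []

    def on_backward_step_start(self, *, model, maps, state, loss, grad_scaler) -> None:
        # LossType is either one tensor or a dict mapping names to tensors.
        if isinstance(loss, Mapping):
            record = {name: float(value) for name, value in loss.items()}
        else:
            record = {"loss": float(loss)}  # hypothetical key for the single-tensor case
        self.history.append(record)
```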
- on_batch_end(*, model: Model, maps: Maps, state: TrainerState) → None
Called every time the processing of a batch is completed during the training, validation, or test phase.
Note
This event may be redundant with other events: e.g., in evaluation phases, it is equivalent to on_metrics_computation_end().
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
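Since on_batch_end fires in every phase, including the validation loop that runs inside Trainer.train, a callback interested only in training batches has to tell the phases apart. A minimal sketch, assuming it is used with Trainer.train only (Trainer.validate and Trainer.test fire on_validate_* and on_test_* rather than on_validation_*): it flips a flag in on_validation_start and on_validation_end. The class and attribute names are hypothetical.

```python
try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Minimal stand-in so this sketch also runs without clinicadl installed.
    class Callback:  # type: ignore[no-redef]
        pass


class BatchCounter(Callback):
    """Hypothetical callback counting training and validation batches per epoch."""

    def __init__(self) -> None:
        super().__init__()
        self._in_validation = False
        self.train_batches = 0
        self.val_batches = 0

    def on_epoch_start(self, *, model, maps, state) -> None:
        self.train_batches = 0
        self.val_batches = 0

    def on_validation_start(self, *, model, maps, state, dataloader, metrics) -> None:
        self._in_validation = True

    def on_validation_end(self, *, model, maps, state, metrics) -> None:
        self._in_validation = False

    def on_batch_end(self, *, model, maps, state) -> None:
        # on_batch_end fires in every phase, so the flag tells them apart.
        if self._in_validation:
            self.val_batches += 1
        else:
            self.train_batches += 1
```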
- on_batch_start(*, model: Model, maps: Maps, state: TrainerState, batch: Batch | Sequence[Batch] | dict[Any, Batch]) → None
Called every time a new batch has been loaded in the training, validation, or test phase.
Note
This event may be redundant with other events: e.g., in evaluation phases, it is equivalent to on_evaluation_step_start() (except that, in between, the batch may be sent to another device).
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
batch (BatchType) – The input batch.
- on_evaluation_step_start(*, model: Model, maps: Maps, state: TrainerState, batch: Batch | Sequence[Batch] | dict[Any, Batch]) → None
Called every time Model.evaluation_step is about to be called in Trainer.train, Trainer.validate, or Trainer.test.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
batch (BatchType) – The batch input to Model.evaluation_step.
- on_epoch_end(*, model: Model, maps: Maps, state: TrainerState) → None
Called at the end of an epoch in Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
- on_epoch_start(*, model: Model, maps: Maps, state: TrainerState) → None
Called at the beginning of an epoch in Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
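Pairing on_epoch_start with on_epoch_end is enough to measure how long each epoch takes. The sketch below is hypothetical (EpochTimer and its clock parameter are not clinicadl names); the clock is injectable so that the example can be tested deterministically. Following the pseudocode of Trainer.train, the measured time includes the validation loop whenever one runs during that epoch.

```python
import time
from typing import Callable, List, Optional

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Minimal stand-in so this sketch also runs without clinicadl installed.
    class Callback:  # type: ignore[no-redef]
        pass


class EpochTimer(Callback):
    """Hypothetical callback recording the wall-clock duration of each epoch."""

    def __init__(self, clock: Callable[[], float] = time.perf_counter) -> None:
        super().__init__()
        self._clock = clock
        self._start: Optional[float] = None
        self.durations: List[float] = []

    def on_epoch_start(self, *, model, maps, state) -> None:
        self._start = self._clock()

    def on_epoch_end(self, *, model, maps, state) -> None:
        if self._start is not None:
            self.durations.append(self._clock() - self._start)
            self._start = None
```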
- on_exception(*, model: Model, maps: Maps, state: TrainerState, exception: Exception) → None
Called when an exception interrupts the execution of the Trainer.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
exception (Exception) – The exception that has been raised.
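A typical use of on_exception is to record what went wrong before the run stops. This sketch only observes the exception: it logs it, with its traceback, through Python's standard logging module and keeps a short description. The class name and the logger name "clinicadl_user_callbacks" are hypothetical, and whether the Trainer re-raises the exception afterwards is not covered here.

```python
import logging
from typing import List

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Minimal stand-in so this sketch also runs without clinicadl installed.
    class Callback:  # type: ignore[no-redef]
        pass

logger = logging.getLogger("clinicadl_user_callbacks")  # hypothetical logger name


class ExceptionLogger(Callback):
    """Hypothetical callback logging the exception that interrupted the Trainer."""

    def __init__(self) -> None:
        super().__init__()
        self.seen: List[str] = []

    def on_exception(self, *, model, maps, state, exception: Exception) -> None:
        description = f"{type(exception).__name__}: {exception}"
        self.seen.append(description)
        # exc_info accepts an exception instance, so its traceback is included.
        logger.error("Execution interrupted by %s", description, exc_info=exception)
```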
- on_forward_step_start(*, model: Model, maps: Maps, state: TrainerState, batch: Batch | Sequence[Batch] | dict[Any, Batch]) → None
Called every time Model.forward_step is about to be called in Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
batch (BatchType) – The batch input to Model.forward_step.
- on_metrics_computation_end(*, model: Model, maps: Maps, state: TrainerState, detailed_metrics_df: DataFrame) → None
Called every time metrics have just been computed on a batch.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
detailed_metrics_df (pandas.DataFrame) – The evaluation metrics on the batch.
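Because on_metrics_computation_end receives the batch metrics as a pandas.DataFrame, a callback can gather them and inspect the detailed results once a validation loop is over. The sketch below is hypothetical: it keeps the frames of the current validation loop in Trainer.train (reset in on_validation_start) and imports pandas only when the frames are concatenated.

```python
from typing import Any, List

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Minimal stand-in so this sketch also runs without clinicadl installed.
    class Callback:  # type: ignore[no-redef]
        pass


class BatchMetricsCollector(Callback):
    """Hypothetical callback gathering the per-batch metrics DataFrames."""

    def __init__(self) -> None:
        super().__init__()
        self.frames: List[Any] = []  # one pandas.DataFrame per evaluated batch

    def on_validation_start(self, *, model, maps, state, dataloader, metrics) -> None:
        self.frames = []  # start fresh for each validation loop in Trainer.train

    def on_metrics_computation_end(self, *, model, maps, state, detailed_metrics_df) -> None:
        self.frames.append(detailed_metrics_df)

    def to_dataframe(self):
        import pandas as pd  # imported lazily; only needed when concatenating

        if not self.frames:
            return pd.DataFrame()
        return pd.concat(self.frames, ignore_index=True)
```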
- on_metrics_computation_start(*, model: Model, maps: Maps, state: TrainerState, output: Batch, metrics: MetricsHandler) → None
Called every time Model.evaluation_step has returned and metrics are about to be computed.
Note
This event is equivalent to on_evaluation_step_end.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
output (Batch) – The batch output by Model.evaluation_step.
metrics (MetricsHandler) – The metrics to compute.
- on_optimization_step_end(*, model: Model, maps: Maps, state: TrainerState, optimizers: dict[str, Optimizer], grad_scaler: GradScaler) → None
Called every time Model.optimization_step has just been called in Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
optimizers (dict[str, torch.optim.Optimizer]) – The current torch.optim.Optimizer objects, as returned by Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
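Since on_optimization_step_end receives the optimizers, it is a natural place to log learning rates, for instance to check a scheduler. This minimal sketch reads the standard param_groups attribute of torch.optim.Optimizer (each group is a dict with an "lr" entry); the class name is hypothetical, and the keys of optimizers are whatever names the model gives its optimizers.

```python
from typing import Dict, List

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Minimal stand-in so this sketch also runs without clinicadl installed.
    class Callback:  # type: ignore[no-redef]
        pass


class LearningRateLogger(Callback):
    """Hypothetical callback recording learning rates after each optimization step."""

    def __init__(self) -> None:
        super().__init__()
        self.history: List[Dict[str, List[float]]] = []

    def on_optimization_step_end(self, *, model, maps, state, optimizers, grad_scaler) -> None:
        # One entry per optimizer name; one learning rate per parameter group.
        self.history.append(
            {name: [group["lr"] for group in opt.param_groups] for name, opt in optimizers.items()}
        )
```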
- on_optimization_step_start(*, model: Model, maps: Maps, state: TrainerState, optimizers: dict[str, Optimizer], grad_scaler: GradScaler) → None
Called every time Model.optimization_step is about to be called in Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
optimizers (dict[str, torch.optim.Optimizer]) – The current torch.optim.Optimizer objects, as returned by Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
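When optimization_interval is greater than 1, the gradients of several backward steps are presumably accumulated into one optimization step. Counting on_backward_step_end events between two on_optimization_step_end events shows how many backward steps each update actually used. A minimal, hypothetical sketch; note that with the MOD test as written in the pseudocode of Trainer.train and batch indices starting at 0, the very first update would follow a single backward step.

```python
from typing import List

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Minimal stand-in so this sketch also runs without clinicadl installed.
    class Callback:  # type: ignore[no-redef]
        pass


class AccumulationMonitor(Callback):
    """Hypothetical callback counting backward steps per optimization step."""

    def __init__(self) -> None:
        super().__init__()
        self._pending = 0
        self.steps_per_update: List[int] = []

    def on_backward_step_end(self, *, model, maps, state) -> None:
        self._pending += 1

    def on_optimization_step_end(self, *, model, maps, state, optimizers, grad_scaler) -> None:
        # Close the current group of backward steps and start a new one.
        self.steps_per_update.append(self._pending)
        self._pending = 0
```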
- on_resume(*, model: Model, maps: Maps, state: TrainerState, split: Split, optimizers: dict[str, Optimizer], grad_scaler: GradScaler, optimization: OptimizationConfig, metrics: MetricsHandler, callbacks: CallbacksHandler, computational: ComputationalConfig) → None
Called once when Trainer.train resumes a training. More precisely, this method is called just before the checkpoints are loaded.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
split (Split) – The clinicadl.split.Split on which training is performed.
optimizers (dict[str, torch.optim.Optimizer]) – The current torch.optim.Optimizer objects, as returned by Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
optimization (OptimizationConfig) – The optimization specifications of the training phase.
metrics (MetricsHandler) – The validation metrics to compute.
callbacks (CallbacksHandler) – The callbacks passed to the Trainer.
computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the training phase.
- on_test_end(*, model: Model, maps: Maps, state: TrainerState, metrics: MetricsHandler) → None
Called once at the end of Trainer.test.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
metrics (MetricsHandler) – The test metrics computed.
- on_test_start(*, model: Model, maps: Maps, state: TrainerState, dataloader: DataLoader, model_checkpoint: str, metrics: MetricsHandler, group_name: str, callbacks: CallbacksHandler, computational: ComputationalConfig) → None
Called once at the beginning of Trainer.test.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
dataloader (DataLoader) – The dataloader on which the test is performed.
model_checkpoint (str) – The model checkpoint currently being tested.
metrics (MetricsHandler) – The test metrics to compute.
group_name (str) – The name given to the test data.
callbacks (CallbacksHandler) – The callbacks passed to the Trainer.
computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the test phase.
- on_train_end(*, model: Model, maps: Maps, state: TrainerState) → None
Called once at the end of Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
- on_train_start(*, model: Model, maps: Maps, state: TrainerState, split: Split, optimizers: dict[str, Optimizer], grad_scaler: GradScaler, optimization: OptimizationConfig, metrics: MetricsHandler, callbacks: CallbacksHandler, computational: ComputationalConfig) → None
Called once at the beginning of Trainer.train if resume=False. If a training is being resumed, on_resume() is called instead.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
split (Split) – The clinicadl.split.Split on which training is performed.
optimizers (dict[str, torch.optim.Optimizer]) – The current torch.optim.Optimizer objects, as returned by Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
optimization (OptimizationConfig) – The optimization specifications of the training phase.
metrics (MetricsHandler) – The validation metrics to compute.
callbacks (CallbacksHandler) – The callbacks passed to the Trainer.
computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the training phase.
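Because exactly one of on_train_start and on_resume is called at the beginning of Trainer.train, a file-based logger can use them to decide whether to start a new file or keep the existing one. A minimal sketch with a hypothetical EpochCSVLogger that writes one row per epoch (a real logger would add metric columns): the two start hooks accept **kwargs as a shorthand for the many keyword arguments they receive, and it is assumed that the callback's own state comes back through load_state_dict() when resuming.

```python
import csv
from pathlib import Path
from typing import Any, Mapping, Union

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Minimal stand-in so this sketch also runs without clinicadl installed.
    class Callback:  # type: ignore[no-redef]
        pass


class EpochCSVLogger(Callback):
    """Hypothetical callback appending one CSV row per completed epoch."""

    def __init__(self, path: Union[str, Path]) -> None:
        super().__init__()
        self.path = Path(path)
        self.epoch = 0

    def on_train_start(self, **kwargs: Any) -> None:
        # Fresh training: overwrite any previous file and write the header row.
        self.epoch = 0
        with self.path.open("w", newline="") as f:
            csv.writer(f).writerow(["epoch"])

    def on_resume(self, **kwargs: Any) -> None:
        # Resumed training: keep the rows already written; self.epoch is
        # expected to be restored through load_state_dict().
        pass

    def on_epoch_end(self, *, model, maps, state) -> None:
        self.epoch += 1
        with self.path.open("a", newline="") as f:
            csv.writer(f).writerow([self.epoch])

    def state_dict(self) -> Mapping[str, Any]:
        return {"epoch": self.epoch}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        self.epoch = int(state_dict["epoch"])
```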
- on_trainer_init(*, model: Model, maps: Maps, state: TrainerState, metrics: MetricsHandler, optimization: OptimizationConfig, callbacks: CallbacksHandler) → None
Called once when the Trainer is created or restored with Trainer.from_maps.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
metrics (MetricsHandler) – The metrics passed to the Trainer.
optimization (OptimizationConfig) – The optimization specifications of the training phase.
callbacks (CallbacksHandler) – The callbacks passed to the Trainer.
- on_validate_end(*, model: Model, maps: Maps, state: TrainerState, metrics: MetricsHandler) → None
Called at the end of a model validation in Trainer.validate. Not to be confused with on_validation_end().
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
metrics (MetricsHandler) – The validation metrics computed.
- on_validate_start(*, model: Model, maps: Maps, state: TrainerState, dataloader: DataLoader, model_checkpoint: str, metrics: MetricsHandler, callbacks: CallbacksHandler, computational: ComputationalConfig) → None
Called at the beginning of a model validation in Trainer.validate. Not to be confused with on_validation_start().
Important
If model_checkpoint=None was passed to Trainer.validate, all the models saved during training are validated. Therefore, on_validate_start is called once for each model.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
dataloader (DataLoader) – The dataloader on which validation is performed.
model_checkpoint (str) – The model checkpoint currently being validated.
metrics (MetricsHandler) – The validation metrics to compute.
callbacks (CallbacksHandler) – The callbacks passed to the Trainer.
computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the validation phase.
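Since on_validate_start can be called once per saved model, a callback can remember which checkpoint is being validated and file the results under that name in on_validate_end. A minimal sketch; the class name is hypothetical, the checkpoint names used in practice come from the Trainer, and the MetricsHandler is stored as-is because reading values out of it depends on its API.

```python
from typing import Any, Dict, Optional

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Minimal stand-in so this sketch also runs without clinicadl installed.
    class Callback:  # type: ignore[no-redef]
        pass


class CheckpointResults(Callback):
    """Hypothetical callback mapping each validated checkpoint to its metrics."""

    def __init__(self) -> None:
        super().__init__()
        self._current: Optional[str] = None
        self.results: Dict[Optional[str], Any] = {}

    def on_validate_start(self, *, model, maps, state, dataloader, model_checkpoint,
                          metrics, callbacks, computational) -> None:
        self._current = model_checkpoint

    def on_validate_end(self, *, model, maps, state, metrics) -> None:
        # If the Trainer reuses one MetricsHandler for every checkpoint,
        # extract or copy the values here instead of keeping a reference.
        self.results[self._current] = metrics
        self._current = None
```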
- on_validation_end(*, model: Model, maps: Maps, state: TrainerState, metrics: MetricsHandler) → None
Called at the end of every validation loop in Trainer.train. Not to be confused with on_validate_end().
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
metrics (MetricsHandler) – The validation metrics computed.
- on_validation_start(*, model: Model, maps: Maps, state: TrainerState, dataloader: DataLoader, metrics: MetricsHandler) → None
Called at the beginning of every validation loop in Trainer.train. Not to be confused with on_validate_start().
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
dataloader (DataLoader) – The dataloader on which validation is performed.
metrics (MetricsHandler) – The validation metrics to compute.