clinicadl.callbacks.Callback

class clinicadl.callbacks.Callback

Base class to define arbitrary actions to perform at specific points of the training and evaluation workflows.

Each method of this class whose name starts with "on_" is associated with an event of the training or evaluation phase performed by the Trainer. By overriding these methods, the user can define actions to perform when the corresponding event occurs.

Important

Callbacks should capture NON-ESSENTIAL logic such as saving checkpoints or logging. The essential logic should be defined in a clinicadl.models.Model.

To define your own callback, override any of the methods associated with an event. You can also override state_dict() and load_state_dict() so that the state of your callback can be recovered when resuming an interrupted training.
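
A minimal sketch of a custom callback, assuming clinicadl is importable (otherwise a stand-in base class is defined so the snippet runs on its own). `EpochTimer` is a hypothetical name; it overrides only on_epoch_start and on_epoch_end and accepts their documented keyword-only arguments.

```python
import time

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Stand-in base class so this sketch also runs without clinicadl installed.
    class Callback:
        pass


class EpochTimer(Callback):
    """Measures the wall-clock duration of each training epoch."""

    def __init__(self):
        super().__init__()
        self.durations = []  # one entry (in seconds) per completed epoch
        self._start = None

    def on_epoch_start(self, *, model, maps, state) -> None:
        self._start = time.perf_counter()

    def on_epoch_end(self, *, model, maps, state) -> None:
        if self._start is not None:
            self.durations.append(time.perf_counter() - self._start)
            self._start = None
```
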

List of events
Trainer.train (pseudocode)
IF resume THEN
    CALL on_resume
ELSE
    CALL on_train_start
END IF

FOR EACH epoch IN training_epochs DO
    CALL on_epoch_start

    FOR EACH batch IN training_dataloader DO
        CALL on_batch_start

        MOVE batch TO appropriate_device

        CALL on_forward_step_start
        DO forward_step
        CALL on_backward_step_start
        DO backward_step
        CALL on_backward_step_end

        IF batch_idx MOD optimization_interval == 0 THEN
            CALL on_optimization_step_start
            DO optimization_step
            CALL on_optimization_step_end
        END IF

        CALL on_batch_end
    END FOR

    IF epoch_idx MOD evaluation_interval == 0 THEN
        CALL on_validation_start

        FOR EACH batch IN validation_dataloader DO
            CALL on_batch_start

            MOVE batch TO appropriate_device

            CALL on_evaluation_step_start
            DO evaluation_step
            CALL on_metrics_computation_start
            DO metrics_computation
            CALL on_metrics_computation_end

            CALL on_batch_end
        END FOR

        CALL on_validation_end
    END IF

    CALL on_epoch_end
END FOR

CALL on_train_end
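
To check in which order the hooks of a custom callback fire, the training pseudocode above can be transcribed into a toy Python function that only records event names. `simulate_training` and its parameters are hypothetical, no real training happens, and epoch and batch indices are assumed to start at 0 (which determines which iterations satisfy the MOD tests).

```python
def simulate_training(n_epochs, n_train_batches, n_val_batches,
                      optimization_interval=1, evaluation_interval=1, resume=False):
    """Return the hook names fired by one Trainer.train call, in order.

    Toy transcription of the pseudocode: nothing is computed, only the
    sequence of events is recorded. Indices are assumed to start at 0.
    """
    events = ["on_resume" if resume else "on_train_start"]
    for epoch_idx in range(n_epochs):
        events.append("on_epoch_start")
        for batch_idx in range(n_train_batches):
            # The batch would be moved to the device between these two events.
            events += ["on_batch_start", "on_forward_step_start",
                       "on_backward_step_start", "on_backward_step_end"]
            if batch_idx % optimization_interval == 0:
                events += ["on_optimization_step_start", "on_optimization_step_end"]
            events.append("on_batch_end")
        if epoch_idx % evaluation_interval == 0:
            events.append("on_validation_start")
            for _ in range(n_val_batches):
                events += ["on_batch_start", "on_evaluation_step_start",
                           "on_metrics_computation_start",
                           "on_metrics_computation_end", "on_batch_end"]
            events.append("on_validation_end")
        events.append("on_epoch_end")
    events.append("on_train_end")
    return events
```
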

Trainer.validate (pseudocode)
CALL on_validate_start

FOR EACH batch IN validation_dataloader DO
    CALL on_batch_start

    MOVE batch TO appropriate_device

    CALL on_evaluation_step_start
    DO evaluation_step
    CALL on_metrics_computation_start
    DO metrics_computation
    CALL on_metrics_computation_end

    CALL on_batch_end
END FOR

CALL on_validate_end

Trainer.test (pseudocode)
CALL on_test_start

FOR EACH batch IN test_dataloader DO
    CALL on_batch_start

    MOVE batch TO appropriate_device

    CALL on_evaluation_step_start
    DO evaluation_step
    CALL on_metrics_computation_start
    DO metrics_computation
    CALL on_metrics_computation_end

    CALL on_batch_end
END FOR

CALL on_test_end

state_dict() → Mapping[str, Any]

Returns a checkpoint of the current state of the callback.

Returns:

Mapping[str, Any] – The current state in a dict.

load_state_dict(state_dict: Mapping[str, Any]) → None

Sets the callback to a given state.

Parameters:

state_dict (Mapping[str, Any]) – The desired state of the Callback, as returned by state_dict().
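
A sketch of a stateful callback, assuming clinicadl is importable (otherwise a stand-in base class is used). The hypothetical `BatchCounter` keeps its state as plain Python values; how the Trainer actually stores callback states in its checkpoints is not described here, so the JSON round trip below only stands in for saving and reloading.

```python
import json

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Stand-in base class so this sketch also runs without clinicadl installed.
    class Callback:
        pass


class BatchCounter(Callback):
    """Counts processed batches; the count survives an interrupted training."""

    def __init__(self):
        super().__init__()
        self.n_batches = 0

    def on_batch_end(self, *, model, maps, state) -> None:
        self.n_batches += 1

    def state_dict(self):
        # Only plain Python values, so the state is easy to serialize.
        return {"n_batches": self.n_batches}

    def load_state_dict(self, state_dict) -> None:
        self.n_batches = state_dict["n_batches"]


# Round trip: the JSON string stands in for a checkpoint file.
counter = BatchCounter()
for _ in range(3):
    counter.on_batch_end(model=None, maps=None, state=None)
saved = json.dumps(counter.state_dict())

restored = BatchCounter()
restored.load_state_dict(json.loads(saved))
```
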

on_backward_step_end(*, model: Model, maps: Maps, state: TrainerState) → None

Called every time Model.backward_step has just been called in Trainer.train.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

on_backward_step_start(*, model: Model, maps: Maps, state: TrainerState, loss: Tensor | dict[str, Tensor], grad_scaler: GradScaler) → None

Called every time Model.backward_step will be called in Trainer.train.

Note

This event is equivalent to on_forward_step_end.

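Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • loss (Tensor | dict[str, Tensor]) – The loss computed in the forward step, on which the backward step will be performed.

  • grad_scaler (GradScaler) – The gradient scaler used by the Trainer.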
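
For instance, a callback can record the loss just before the backward step. This sketch (hypothetical `LossRecorder`, with a stand-in base class if clinicadl is not importable) handles both forms allowed by the signature and assumes each loss is a single-element tensor, which `float()` can convert.

```python
try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Stand-in base class so this sketch also runs without clinicadl installed.
    class Callback:
        pass


class LossRecorder(Callback):
    """Records the loss of each training batch just before the backward pass."""

    def __init__(self):
        super().__init__()
        self.history = []

    def on_backward_step_start(self, *, model, maps, state, loss, grad_scaler) -> None:
        # `loss` is either one scalar tensor or a dict of named scalar tensors.
        if isinstance(loss, dict):
            self.history.append({name: float(value) for name, value in loss.items()})
        else:
            self.history.append(float(loss))
```

Converting to a Python float rather than storing the tensor avoids keeping the autograd graph of every batch alive; on a GPU, each conversion does force a device synchronization.
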

on_batch_end(*, model: Model, maps: Maps, state: TrainerState) → None

Called every time the processing of a batch is completed during training, validation, or test phase.

Note

This event may be redundant with other events: e.g., in evaluation phases, it is equivalent to on_metrics_computation_end().

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

on_batch_start(*, model: Model, maps: Maps, state: TrainerState, batch: Batch | Sequence[Batch] | dict[Any, Batch]) → None

Called every time a new batch has been loaded in training, validation or test phase.

Note

This event may be redundant with other events: e.g., in evaluation phases, it is equivalent to on_evaluation_step_start(), except that the batch has not yet been moved to the appropriate device.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • batch (BatchType) – The input batch.

on_evaluation_step_start(*, model: Model, maps: Maps, state: TrainerState, batch: Batch | Sequence[Batch] | dict[Any, Batch]) → None

Called every time Model.evaluation_step will be called in Trainer.train, Trainer.validate, or Trainer.test.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • batch (BatchType) – The batch input to Model.evaluation_step.

on_epoch_end(*, model: Model, maps: Maps, state: TrainerState) → None

Called at the end of an epoch in Trainer.train.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

on_epoch_start(*, model: Model, maps: Maps, state: TrainerState) → None

Called at the beginning of an epoch in Trainer.train.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

on_exception(*, model: Model, maps: Maps, state: TrainerState, exception: Exception) → None

Called when an exception interrupts an execution of the Trainer.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • exception (Exception) – The exception that has been raised.
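
A minimal sketch of on_exception that keeps a formatted traceback of each failure, for example to write it to a log afterwards. `ExceptionLogger` is hypothetical, a stand-in base class is used if clinicadl is not importable, and this page does not say whether the Trainer re-raises the exception after the hook, so the sketch only records it.

```python
import traceback

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Stand-in base class so this sketch also runs without clinicadl installed.
    class Callback:
        pass


class ExceptionLogger(Callback):
    """Keeps the full traceback of every exception that interrupts the Trainer."""

    def __init__(self):
        super().__init__()
        self.reports = []

    def on_exception(self, *, model, maps, state, exception) -> None:
        report = "".join(
            traceback.format_exception(type(exception), exception, exception.__traceback__)
        )
        self.reports.append(report)
```
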

on_forward_step_start(*, model: Model, maps: Maps, state: TrainerState, batch: Batch | Sequence[Batch] | dict[Any, Batch]) → None

Called every time Model.forward_step will be called in Trainer.train.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • batch (BatchType) – The batch input to Model.forward_step.

on_metrics_computation_end(*, model: Model, maps: Maps, state: TrainerState, detailed_metrics_df: DataFrame) → None

Called every time metrics have just been computed on a batch.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • detailed_metrics_df (pandas.DataFrame) – The evaluation metrics on the batch.
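
The per-batch DataFrames can be gathered into a single table per validation loop of Trainer.train. This sketch assumes pandas is installed and that all batch DataFrames share the same columns; `MetricsCollector` is hypothetical, and a stand-in base class is used if clinicadl is not importable. The buffer is reset in on_validation_start because on_metrics_computation_end also fires in Trainer.validate and Trainer.test.

```python
import pandas as pd

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Stand-in base class so this sketch also runs without clinicadl installed.
    class Callback:
        pass


class MetricsCollector(Callback):
    """Concatenates the per-batch metrics of each in-training validation loop."""

    def __init__(self):
        super().__init__()
        self._batches = []
        self.last_validation = None  # DataFrame of the most recent validation loop

    def on_validation_start(self, *, model, maps, state, dataloader, metrics) -> None:
        self._batches = []

    def on_metrics_computation_end(self, *, model, maps, state, detailed_metrics_df) -> None:
        self._batches.append(detailed_metrics_df)

    def on_validation_end(self, *, model, maps, state, metrics) -> None:
        if self._batches:
            self.last_validation = pd.concat(self._batches, ignore_index=True)
```
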

on_metrics_computation_start(*, model: Model, maps: Maps, state: TrainerState, output: Batch, metrics: MetricsHandler) → None

Called every time Model.evaluation_step has been called and metrics will now be computed.

Note

This event is equivalent to on_evaluation_step_end.

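Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • output (Batch) – The output of Model.evaluation_step.

  • metrics (MetricsHandler) – The metrics to compute.

on_optimization_step_end(*, model: Model, maps: Maps, state: TrainerState, optimizers: dict[str, Optimizer], grad_scaler: GradScaler) → None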

Called every time Model.optimization_step has just been called in Trainer.train.

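Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • optimizers (dict[str, Optimizer]) – The optimizers used by the Trainer.

  • grad_scaler (GradScaler) – The gradient scaler used by the Trainer.

on_optimization_step_start(*, model: Model, maps: Maps, state: TrainerState, optimizers: dict[str, Optimizer], grad_scaler: GradScaler) → None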

Called every time Model.optimization_step will be called in Trainer.train.

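Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • optimizers (dict[str, Optimizer]) – The optimizers used by the Trainer.

  • grad_scaler (GradScaler) – The gradient scaler used by the Trainer.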
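
Before each optimization step, a callback can inspect the optimizers, for example to track the learning rate. This sketch assumes the values of `optimizers` are `torch.optim.Optimizer` instances, whose `param_groups` are dicts holding an `"lr"` entry for PyTorch's built-in optimizers, and that the keys are optimizer names. `LearningRateLogger` is hypothetical, and a stand-in base class is used if clinicadl is not importable.

```python
try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Stand-in base class so this sketch also runs without clinicadl installed.
    class Callback:
        pass


class LearningRateLogger(Callback):
    """Records the learning rate of every parameter group before each optimization step."""

    def __init__(self):
        super().__init__()
        self.history = []  # one dict per optimization step

    def on_optimization_step_start(self, *, model, maps, state, optimizers, grad_scaler) -> None:
        snapshot = {}
        for name, optimizer in optimizers.items():
            # Each parameter group of a torch optimizer is a dict of hyperparameters.
            snapshot[name] = [group["lr"] for group in optimizer.param_groups]
        self.history.append(snapshot)
```
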

on_resume(*, model: Model, maps: Maps, state: TrainerState, split: Split, optimizers: dict[str, Optimizer], grad_scaler: GradScaler, optimization: OptimizationConfig, metrics: MetricsHandler, callbacks: CallbacksHandler, computational: ComputationalConfig) → None

Called once when Trainer.train resumes an interrupted training.

More precisely, this method is called just before the checkpoints are loaded.

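Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • split (Split) – The split on which the model is trained.

  • optimizers (dict[str, Optimizer]) – The optimizers used by the Trainer.

  • grad_scaler (GradScaler) – The gradient scaler used by the Trainer.

  • optimization (OptimizationConfig) – The optimization specifications of the training phase.

  • metrics (MetricsHandler) – The metrics passed to the Trainer.

  • callbacks (CallbacksHandler) – The callbacks passed to the Trainer.

  • computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the training phase.

on_test_end(*, model: Model, maps: Maps, state: TrainerState, metrics: MetricsHandler) → None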

Called once at the end of Trainer.test.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • metrics (MetricsHandler) – The test metrics computed.

on_test_start(*, model: Model, maps: Maps, state: TrainerState, dataloader: DataLoader, model_checkpoint: str, metrics: MetricsHandler, group_name: str, callbacks: CallbacksHandler, computational: ComputationalConfig) → None

Called once at the beginning of Trainer.test.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • dataloader (DataLoader) – The dataloader on which the test is performed.

  • model_checkpoint (str) – The model checkpoint currently being tested.

  • metrics (MetricsHandler) – The test metrics to compute.

  • group_name (str) – The name given to the test data.

  • callbacks (CallbacksHandler) – The callbacks passed to the Trainer.

  • computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the test phase.
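
Because on_test_end does not receive model_checkpoint or group_name, a callback that needs them at the end of a test can store them in on_test_start. A sketch with the hypothetical `CheckpointTestLog`, using a stand-in base class if clinicadl is not importable:

```python
try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Stand-in base class so this sketch also runs without clinicadl installed.
    class Callback:
        pass


class CheckpointTestLog(Callback):
    """Lists the (group_name, model_checkpoint) pairs whose test has completed."""

    def __init__(self):
        super().__init__()
        self.runs = []
        self._current = None

    def on_test_start(self, *, model, maps, state, dataloader, model_checkpoint,
                      metrics, group_name, callbacks, computational) -> None:
        self._current = (group_name, model_checkpoint)

    def on_test_end(self, *, model, maps, state, metrics) -> None:
        # on_test_end does not receive these names, so reuse the stored ones.
        self.runs.append(self._current)
        self._current = None
```
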

on_train_end(*, model: Model, maps: Maps, state: TrainerState) → None

Called once at the end of Trainer.train.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

on_train_start(*, model: Model, maps: Maps, state: TrainerState, split: Split, optimizers: dict[str, Optimizer], grad_scaler: GradScaler, optimization: OptimizationConfig, metrics: MetricsHandler, callbacks: CallbacksHandler, computational: ComputationalConfig) → None

Called once at the beginning of Trainer.train if resume=False.

If resuming a training, on_resume() will be called instead.
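
Since exactly one of on_train_start and on_resume runs per call to Trainer.train, setup that must happen in both cases can live in a shared helper. In this sketch (hypothetical `TrainingRunLog`, stand-in base class if clinicadl is not importable), both overrides accept the arguments they do not use through `**kwargs`, which works because the signatures are keyword-only (note the `*`), so every argument arrives by name.

```python
try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Stand-in base class so this sketch also runs without clinicadl installed.
    class Callback:
        pass


class TrainingRunLog(Callback):
    """Keeps a short log of how each Trainer.train call started."""

    def __init__(self):
        super().__init__()
        self.log = []

    # Both hooks receive the same keyword arguments (split, optimizers,
    # grad_scaler, ...); **kwargs absorbs the ones unused here.
    def on_train_start(self, *, model, maps, state, **kwargs) -> None:
        self._on_run_start(resumed=False)

    def on_resume(self, *, model, maps, state, **kwargs) -> None:
        self._on_run_start(resumed=True)

    def _on_run_start(self, resumed: bool) -> None:
        # Shared setup: exactly one of the two hooks above runs per Trainer.train call.
        self.log.append("resumed" if resumed else "started")
```

`**kwargs` keeps the overrides short; the trade-off is that the override no longer documents, or checks, which arguments it receives.
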

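Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • split (Split) – The split on which the model is trained.

  • optimizers (dict[str, Optimizer]) – The optimizers used by the Trainer.

  • grad_scaler (GradScaler) – The gradient scaler used by the Trainer.

  • optimization (OptimizationConfig) – The optimization specifications of the training phase.

  • metrics (MetricsHandler) – The metrics passed to the Trainer.

  • callbacks (CallbacksHandler) – The callbacks passed to the Trainer.

  • computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the training phase.

on_trainer_init(*, model: Model, maps: Maps, state: TrainerState, metrics: MetricsHandler, optimization: OptimizationConfig, callbacks: CallbacksHandler) → None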

Called once when the Trainer is created or restored with Trainer.from_maps.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • metrics (MetricsHandler) – The metrics passed to the Trainer.

  • optimization (OptimizationConfig) – The optimization specifications of the training phase.

  • callbacks (CallbacksHandler) – The callbacks passed to the Trainer.

on_validate_end(*, model: Model, maps: Maps, state: TrainerState, metrics: MetricsHandler) → None

Called at the end of a model validation in Trainer.validate.

Not to be confused with on_validation_end().

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • metrics (MetricsHandler) – The validation metrics computed.

on_validate_start(*, model: Model, maps: Maps, state: TrainerState, dataloader: DataLoader, model_checkpoint: str, metrics: MetricsHandler, callbacks: CallbacksHandler, computational: ComputationalConfig) → None

Called at the beginning of a model validation in Trainer.validate.

Not to be confused with on_validation_start().

Important

If model_checkpoint=None was passed to Trainer.validate, all the models saved during training will be validated. Therefore, on_validate_start will be called for each model.

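Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • dataloader (DataLoader) – The dataloader on which validation is performed.

  • model_checkpoint (str) – The model checkpoint currently being validated.

  • metrics (MetricsHandler) – The validation metrics to compute.

  • callbacks (CallbacksHandler) – The callbacks passed to the Trainer.

  • computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the validation phase.

on_validation_end(*, model: Model, maps: Maps, state: TrainerState, metrics: MetricsHandler) → None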

Called at the end of every validation loop in Trainer.train.

Not to be confused with on_validate_end().

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • metrics (MetricsHandler) – The validation metrics computed.

on_validation_start(*, model: Model, maps: Maps, state: TrainerState, dataloader: DataLoader, metrics: MetricsHandler) → None

Called at the beginning of every validation loop in Trainer.train.

Not to be confused with on_validate_start().

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • dataloader (DataLoader) – The dataloader on which validation is performed.

  • metrics (MetricsHandler) – The validation metrics to compute.
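
The validate/validation distinction matters when a callback needs to know which loop a batch belongs to, since on_batch_start fires in all of them. A sketch (hypothetical `BatchesPerPhase`, stand-in base class if clinicadl is not importable) that tracks the current phase through the start and end events; on_validate_start takes its remaining keyword-only arguments through `**kwargs`.

```python
from collections import Counter

try:
    from clinicadl.callbacks import Callback
except ImportError:
    # Stand-in base class so this sketch also runs without clinicadl installed.
    class Callback:
        pass


class BatchesPerPhase(Callback):
    """Counts batches, separating in-training validation from Trainer.validate."""

    def __init__(self):
        super().__init__()
        self.counts = Counter()
        self.phase = None  # None: outside any tracked phase

    def on_epoch_start(self, *, model, maps, state) -> None:
        self.phase = "train"

    def on_epoch_end(self, *, model, maps, state) -> None:
        self.phase = None

    # Validation loop run inside Trainer.train.
    def on_validation_start(self, *, model, maps, state, dataloader, metrics) -> None:
        self.phase = "train/validation"

    def on_validation_end(self, *, model, maps, state, metrics) -> None:
        self.phase = "train"

    # Stand-alone validation with Trainer.validate (once per validated checkpoint).
    def on_validate_start(self, *, model, maps, state, **kwargs) -> None:
        self.phase = "validate"

    def on_validate_end(self, *, model, maps, state, metrics) -> None:
        self.phase = None

    # on_batch_start fires in every phase; the current phase tells them apart.
    def on_batch_start(self, *, model, maps, state, batch) -> None:
        if self.phase is not None:
            self.counts[self.phase] += 1
```
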