2.4. Callbacks

A Trainer always runs the same steps, but at many points of that sequence you may want to do something extra — log a message, save a checkpoint, stop early, schedule the learning rate. These non-essential actions are the job of callbacks.

Important

Callbacks capture non-essential logic (logging, checkpointing, etc.). The essential logic of your experiment — how a batch is processed, how the loss is computed — belongs in your Model (see Defining a model).

A Callback defines actions to perform at specific events of the training and evaluation workflows (the start of an epoch, the end of a batch, the moment metrics are computed, etc.). You attach callbacks to a Trainer through its callbacks argument.

Here is an example of a Trainer that stops training after 10 evaluation steps without improvement in the validation loss, saves the best model with respect to the validation loss, and stores model checkpoints every 10 epochs.

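from clinicadl.train import Trainer
from clinicadl.callbacks import EarlyStoppingCallback, ModelCheckpointCallback

# `model` and `LossMetricConfig` are assumed to be defined/imported
# as in the previous sections of this chapter.
trainer = Trainer(
    maps="maps_directory",
    model=model,
    metrics={"loss": LossMetricConfig(loss_name="loss")},
    callbacks=[
        # Stop after 10 evaluation steps without improvement of "loss".
        EarlyStoppingCallback(metric="loss", patience=10),
        # Keep the best model w.r.t. "loss" and save at epochs 1, 11, 21, ..., 91.
        ModelCheckpointCallback(metric="loss", epochs=range(1, 100, 10)),
    ],
)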
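To make the role of patience concrete, here is a minimal, library-independent sketch of the patience logic for a metric that should decrease. The names PatienceTracker and first_stop are hypothetical helpers written for this illustration; they are not part of ClinicaDL, and the actual EarlyStoppingCallback may differ in its details (for instance in how it treats ties or minimum improvements).

```python
from dataclasses import dataclass


@dataclass
class PatienceTracker:
    """Hypothetical helper (not a ClinicaDL class): tracks a metric to minimise."""

    patience: int
    best: float = float("inf")
    bad_evals: int = 0

    def update(self, value: float) -> bool:
        """Record one evaluation and return True when training should stop."""
        if value < self.best:
            # Improvement: remember the new best value and reset the counter.
            self.best = value
            self.bad_evals = 0
        else:
            # No improvement at this evaluation step.
            self.bad_evals += 1
        return self.bad_evals >= self.patience


def first_stop(losses, patience):
    """Return the index of the evaluation at which training would stop, or None."""
    tracker = PatienceTracker(patience=patience)
    for step, loss in enumerate(losses):
        if tracker.update(loss):
            return step
    return None
```

With patience=3 and validation losses 1.0, 0.8, 0.9, 0.85, 0.81, the best value is reached at the second evaluation and the three following ones bring no improvement, so the sketch stops at the fifth evaluation.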

2.4.1. Available callbacks

| Callback | Role |
| --- | --- |
| EarlyStoppingCallback | Stops training when a monitored metric stops improving. |
| LRSchedulerCallback | Adjusts the learning rate during training. |
| ModelCheckpointCallback | Saves the network weights at chosen epochs and/or the best model w.r.t. a metric. |
| TrainingCheckpointCallback | Saves the training state so that an interrupted training can be resumed. |
| LoggerCallback | Configures logging (console and files) and the progress bar. |
| MonitorCallback | Records computational statistics (time, GPU memory, throughput). |

Note

Some callbacks are active by default: LoggerCallback, MonitorCallback, ModelCheckpointCallback and TrainingCheckpointCallback. To change their settings, pass a new instance with the parameters you want; it overrides the default one. All callbacks are managed by a CallbacksHandler, in which their order matters.
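One plausible way to picture the override rule is a merge keyed on the callback class. The sketch below is only an illustration under assumptions: the classes are stand-ins defined here, resolve_callbacks is a hypothetical function, and the policy shown (defaults keep their position, extra callbacks are appended in the order given) is assumed. The text above does not specify how CallbacksHandler actually resolves and orders callbacks.

```python
class Callback:
    """Stand-in base class for this sketch only (not clinicadl.callbacks.Callback)."""


class Logger(Callback):
    def __init__(self, level="INFO"):
        self.level = level


class Monitor(Callback):
    pass


class EarlyStopping(Callback):
    pass


def resolve_callbacks(defaults, user):
    """Replace each default by a user instance of the same class, keep the rest.

    Assumed policy: defaults keep their position; user callbacks that do not
    override a default are appended in the order they were given.
    """
    # Index user callbacks by class (dicts preserve insertion order).
    by_type = {type(cb): cb for cb in user}
    # For each default, take the user's instance of that class if there is one.
    resolved = [by_type.pop(type(cb), cb) for cb in defaults]
    # Whatever is left did not override a default: append it.
    resolved.extend(by_type.values())
    return resolved


defaults = [Logger(), Monitor()]
custom_logger = Logger(level="DEBUG")
early_stopping = EarlyStopping()
active = resolve_callbacks(defaults, [early_stopping, custom_logger])
```

Here active contains the custom logger in place of the default one, then the untouched default monitor, then the early-stopping callback.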

2.4.2. Writing your own callback

When the built-in callbacks are not enough, you can write your own by subclassing Callback and overriding the on_... methods associated with the events you care about:

from clinicadl.callbacks import Callback

class PrintEpochCallback(Callback):
    def on_epoch_start(self, *, model, maps, state) -> None:
        print(f"Starting epoch {state.current_epoch}")
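To see why overriding a single on_... method is enough, the following self-contained toy mimics the mechanism without ClinicaDL. Everything in it is an assumption made for illustration: the Callback class is a stand-in whose hooks are no-ops, on_epoch_end is a hypothetical hook name, and run_epochs is a simplified loop, not the Trainer's actual implementation.

```python
from types import SimpleNamespace


class Callback:
    """Stand-in for the real base class: every hook is a no-op by default."""

    def on_epoch_start(self, *, model, maps, state) -> None:
        pass

    def on_epoch_end(self, *, model, maps, state) -> None:  # hypothetical hook
        pass


class EpochRecorder(Callback):
    """Overrides a single hook; the other one keeps its no-op behaviour."""

    def __init__(self):
        self.seen = []

    def on_epoch_start(self, *, model, maps, state) -> None:
        self.seen.append(state.current_epoch)


def run_epochs(callbacks, n_epochs):
    """Toy loop: fire each event on every callback, in list order."""
    state = SimpleNamespace(current_epoch=0)
    for epoch in range(n_epochs):
        state.current_epoch = epoch
        for cb in callbacks:
            cb.on_epoch_start(model=None, maps=None, state=state)
        # ... the actual training step would happen here ...
        for cb in callbacks:
            cb.on_epoch_end(model=None, maps=None, state=state)


recorder = EpochRecorder()
run_epochs([recorder], n_epochs=3)
```

After the call, recorder.seen holds the epochs at which the hook fired. Because the loop only calls hooks, a callback can observe the workflow without changing how batches are processed.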

Writing your own ClinicaDL objects — callbacks and others — is the subject of Chapter 4.


This closes Chapter 2: you can now define a model, train it, evaluate it, and customise your workflow. The next chapter explains how ClinicaDL helps you manage and reproduce your experiments.