clinicadl.callbacks.ModelCheckpointCallback

class clinicadl.callbacks.ModelCheckpointCallback(metric: str | None = None, epochs: Sequence[int] | None = None, save_last: bool = True)

Saves checkpoints of the neural network weights at different points during training.

Checkpoints can be saved after specified epochs and/or according to a monitored metric. In the latter case, only the best model according to this metric will be saved. The neural network weights after the last epoch can also be saved.
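The selection rules above can be sketched as a pure function. This is an illustration, not clinicadl's implementation: `checkpoint_plan` is hypothetical, it assumes one validation per epoch, and it assumes that a lower metric value is better (as for an MSE), which this page does not specify.

```python
from typing import Optional, Sequence


def checkpoint_plan(
    n_epochs: int,
    metric_values: Optional[Sequence[float]] = None,
    epochs: Optional[Sequence[int]] = None,
    save_last: bool = True,
) -> dict:
    """Hypothetical helper: which 1-indexed epochs would lead to a checkpoint."""
    # Fixed epochs: keep only those that actually occur during training.
    fixed = sorted(e for e in set(epochs or []) if 1 <= e <= n_epochs)

    # Monitored metric: a single "best" checkpoint, overwritten at each
    # improvement. Assumption: lower values are better (e.g. MSE).
    best_epoch = None
    if metric_values is not None:
        best_value = float("inf")
        for epoch, value in enumerate(metric_values, start=1):
            if value < best_value:
                best_value, best_epoch = value, epoch

    # Last epoch, only if requested.
    last_epoch = n_epochs if save_last and n_epochs > 0 else None
    return {"epochs": fixed, "best": best_epoch, "last": last_epoch}


print(checkpoint_plan(5, [0.9, 0.5, 0.6, 0.4, 0.45], epochs=[1, 3]))
# {'epochs': [1, 3], 'best': 4, 'last': 5}
```

Here epochs 1 and 3 are saved because they are listed, epoch 4 holds the best metric value, and epoch 5 is the last epoch.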

Parameters:
  • metric (Optional[str], default=None) – The name of the metric to monitor. If None, no metric is monitored.

  • epochs (Optional[Sequence[int]], default=None) –

    The list of epochs after which the neural network weights should be saved.

    Important

    Epochs are indexed from 1.

  • save_last (bool, default=True) – Whether to save the neural network weights after the last epoch.
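Since epochs are indexed from 1, a training loop that counts epochs from 0 has to add one before comparing against epochs. The helper `epochs_that_would_save` is hypothetical and only illustrates this indexing convention; it also shows that `range(1, 100, 10)` selects epochs 1, 11, ..., 91.

```python
from typing import List, Sequence


def epochs_that_would_save(n_epochs: int, epochs: Sequence[int]) -> List[int]:
    """Hypothetical helper: 1-based epochs, among the first n_epochs, listed in `epochs`."""
    wanted = set(epochs)
    saved = []
    for epoch_idx in range(n_epochs):  # 0-based loop counter
        epoch_number = epoch_idx + 1  # 1-based, as expected by `epochs`
        if epoch_number in wanted:
            saved.append(epoch_number)
    return saved


print(list(range(1, 100, 10)))  # [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]
print(epochs_that_would_save(25, range(1, 100, 10)))  # [1, 11, 21]
```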

Examples

from clinicadl.callbacks import ModelCheckpointCallback
from clinicadl.train import Trainer
from clinicadl.metrics.config import MSEMetricConfig, LossMetricConfig

trainer = Trainer(
    metrics={"loss": LossMetricConfig(), "mse": MSEMetricConfig()},
    callbacks=[
        ModelCheckpointCallback(
            metric="mse", epochs=range(1, 100, 10), save_last=True
        )
    ],
    # ... other Trainer arguments
)
on_train_start(*, maps: Maps, state: TrainerState, metrics: MetricsHandler, **kwargs) → None

Called once at the beginning of Trainer.train if resume=False.

When training is resumed, on_resume() is called instead.

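Parameters:
  • maps (Maps) – The MAPS associated with the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • metrics (MetricsHandler) – The metrics handler of the Trainer.
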
on_resume(*, maps: Maps, state: TrainerState, metrics: MetricsHandler, **kwargs) → None

Called once when Trainer.train resumes a previous training.

More precisely, this method is called just before the checkpoints are loaded.
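The difference between a fresh start and a resumed run can be sketched as follows. `RecordingCallback` and `begin_training` are hypothetical stand-ins for a callback and for the beginning of Trainer.train; the real hooks also receive the maps, state and metrics keyword arguments, which are omitted here. The point is the ordering: on_resume() runs before the checkpoints are loaded, and on_train_start() is not called.

```python
from typing import List


class RecordingCallback:
    """Hypothetical callback that logs which hooks are called."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def on_train_start(self, **kwargs) -> None:
        self.log.append("on_train_start")

    def on_resume(self, **kwargs) -> None:
        self.log.append("on_resume")


def begin_training(callbacks: List[RecordingCallback], resume: bool, log: List[str]) -> None:
    """Hypothetical stand-in for the beginning of Trainer.train."""
    if resume:
        for callback in callbacks:
            callback.on_resume()
        # Checkpoints are loaded only after on_resume has been called.
        log.append("load_checkpoints")
    else:
        for callback in callbacks:
            callback.on_train_start()


log: List[str] = []
begin_training([RecordingCallback(log)], resume=True, log=log)
print(log)  # ['on_resume', 'load_checkpoints']
```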

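Parameters:
  • maps (Maps) – The MAPS associated with the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • metrics (MetricsHandler) – The metrics handler of the Trainer.
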
on_validation_end(*, model: Model, maps: Maps, state: TrainerState, metrics: MetricsHandler, **kwargs) → None

Called at the end of every validation loop in Trainer.train.

Not to be confused with on_validate_end().

Parameters:
  • model (Model) – The model associated with the Trainer.

  • maps (Maps) – The MAPS associated with the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • metrics (MetricsHandler) – The computed validation metrics.

on_epoch_end(*, model: Model, maps: Maps, state: TrainerState, **kwargs) → None

Called at the end of an epoch in Trainer.train.

Parameters:
  • model (Model) – The model associated with the Trainer.

  • maps (Maps) – The MAPS associated with the Trainer.

  • state (TrainerState) – The current state of the Trainer.

on_train_end(*, model: Model, maps: Maps, state: TrainerState, **kwargs) → None

Called once at the end of Trainer.train.

Parameters:
  • model (Model) – The model associated with the Trainer.

  • maps (Maps) – The MAPS associated with the Trainer.

  • state (TrainerState) – The current state of the Trainer.
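A possible order of the hooks during a fresh call to Trainer.train is sketched here. `hook_order` is hypothetical, and it assumes that a validation loop runs at every epoch and finishes before the epoch ends; this page does not state that ordering, so it should be checked against the Trainer documentation before being relied on.

```python
from typing import List


def hook_order(n_epochs: int) -> List[str]:
    """Hypothetical order of callback hooks during one fresh Trainer.train call.

    Assumption: one validation loop per epoch, completed before the epoch ends.
    """
    calls = ["on_train_start"]  # called once, at the beginning
    for epoch in range(1, n_epochs + 1):  # epochs indexed from 1
        calls.append(f"on_validation_end (epoch {epoch})")
        calls.append(f"on_epoch_end (epoch {epoch})")
    calls.append("on_train_end")  # called once, at the end
    return calls


for name in hook_order(2):
    print(name)
```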

state_dict() → Mapping[str, Any]

Returns a checkpoint of the current state of the callback.

Returns:

Mapping[str, Any] – The current state, as a dictionary.

load_state_dict(state_dict: Mapping[str, Any]) → None

Sets the callback to a given state.

Parameters:

state_dict (Mapping[str, Any]) – The desired state of the Callback, as returned by state_dict().
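state_dict() and load_state_dict() let a callback carry information across an interrupted training. For a callback that keeps only the best model, the best value seen so far is a natural candidate: in the sketch below, if that value were lost, the first validation after resuming would always count as a new best. `BestMetricState` is a hypothetical, simplified stand-in (lower values are assumed to be better), not the actual state stored by ModelCheckpointCallback; JSON is used only to mimic writing the state to disk.

```python
import json
from typing import Any, Mapping, Optional


class BestMetricState:
    """Hypothetical callback state: the best monitored value seen so far."""

    def __init__(self) -> None:
        self.best_value: Optional[float] = None
        self.best_epoch: Optional[int] = None

    def update(self, epoch: int, value: float) -> bool:
        """Record `value`; return True if it is a new best (lower is better)."""
        if self.best_value is None or value < self.best_value:
            self.best_value, self.best_epoch = value, epoch
            return True
        return False

    def state_dict(self) -> Mapping[str, Any]:
        return {"best_value": self.best_value, "best_epoch": self.best_epoch}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        self.best_value = state_dict["best_value"]
        self.best_epoch = state_dict["best_epoch"]


# Before the interruption: two validation results.
tracker = BestMetricState()
tracker.update(1, 0.8)
tracker.update(2, 0.5)
saved = json.dumps(tracker.state_dict())  # e.g. written next to the weights

# After resuming: restore the state, then keep tracking.
resumed = BestMetricState()
resumed.load_state_dict(json.loads(saved))
print(resumed.update(3, 0.6))  # False: 0.6 does not beat 0.5
```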