clinicadl.train.Trainer

class clinicadl.train.Trainer(maps: Path | str | Maps, model: Model, metrics: dict[str, Metric | MetricConfig] | MetricsHandler | None = None, optimization: OptimizationConfig | None = None, callbacks: list[Callback] | CallbacksHandler | None = None, overwrite: bool = False) None[source]

The core class to manage model training and evaluation.

This class makes all ClinicaDL objects work together in order to have functional training and evaluation pipelines. It handles the logic common to most PyTorch workflows, while allowing the user to customize the essential logic in the Model, or the non-essential logic via callbacks.

Trainer also offers high-performance computing options, configured via ComputationalConfig.

Trainer will record all outputs and results in a Maps directory, as well as the configurations used in order to reproduce the experiment.

Main methods:

  • train(): for training your model;

  • resume(): for resuming an interrupted training;

  • validate(): for computing new metrics on your validation data;

  • test(): for evaluating your model on test data.

Parameters:
  • maps (Union[PathType, Maps]) – Directory where outputs, results and configurations will be saved.

  • model (Model) – The model to train or evaluate.

  • metrics (Optional[Union[dict[str, MetricOrConfig], MetricsHandler]], default=None) –

    Dictionary of metric names and metric instances for monitoring model performance. Metric instances can be passed via a clinicadl.metrics.Metric or a configuration class. A MetricsHandler containing the metrics can also be passed directly.

    Note

    • Here you define the metrics useful within the scope of your Trainer, but it doesn’t mean that they will all always be computed. The list of the metrics to compute will be given to the relevant methods of the trainer.

    • You can still add metrics later via add_metrics().

    By default, only the loss returned by Model.get_loss_functions will be monitored (it is expected to be called "loss").

  • optimization (Optional[OptimizationConfig], default=None) – Configuration for the optimization of the neural network during training (e.g., the number of epochs). If None, the default parameters of OptimizationConfig will be used.

  • callbacks (Optional[Union[list[Callback], CallbacksHandler]], default=None) –

    List of Callback to customize the trainer. A CallbacksHandler containing the desired callbacks can also be passed directly.

    If None, only the default callbacks in CallbacksHandler will be applied.

  • overwrite (bool, default=False) – Whether to overwrite the MAPS if it already exists.

property maps: Maps

The MAPS associated to the Trainer.

property model: Model

The model associated to the Trainer.

property metrics: MetricsHandler

The metrics computed by Trainer.

property callbacks: CallbacksHandler

The callbacks called by Trainer.

property optimization: OptimizationConfig

The optimization configuration of the Trainer.

property state: TrainerState

The current state of the Trainer.

classmethod from_maps(maps_path: Path | str, **kwargs) Self[source]

To restore a Trainer from a MAPS directory.

This classmethod will try to restore the Trainer that created the MAPS by reading the json configuration files saved inside.

If errors happen when restoring a component of the Trainer (e.g., you used your own Model, so the Trainer cannot read it), you can restore the problematic objects on your own and pass them via keyword arguments (e.g., Trainer.from_maps(..., model=...)).

Parameters:
  • maps_path (PathType) – Path to the MAPS directory.

  • **kwargs (Any) – To pass objects that the Trainer cannot restore.

Returns:

Self – The Trainer associated to the MAPS.

Examples

trainer = Trainer(
    maps="maps",
    model=MyModel(),
    ...
)
...
>>> Trainer.from_maps("maps")
CannotReadJsonError: Cannot read the model [...]    # if MyModel is not a model supported natively by ClinicaDL
>>> Trainer.from_maps("maps", model=MyModel())

add_metrics(**metrics: Metric | MetricConfig) None[source]

To add new metrics to metrics, so that they can be computed later.

Parameters:

**metrics (MetricOrConfig) – The metrics to add, passed via a clinicadl.metrics.Metric or a configuration class.

add_callbacks(callbacks: Sequence[Callback]) None[source]

To add new callbacks to callbacks.

Parameters:

callbacks (Sequence[Callback]) – The callbacks to add.

train(split: Split, computational: ComputationalConfig | None = None, metrics: Sequence[str] | None = None) None[source]

To train a model.

Parameters:
  • split (Split) – The split containing the training and validation data.

  • computational (Optional[ComputationalConfig], default=None) – Computational configuration. This is where you can set up high-performance computing features, or a seed to make your training reproducible.

  • metrics (Optional[Sequence[str]], default=None) –

    The names of the metrics to compute on the validation data. The metrics mentioned here must be in metrics, so they must have been defined beforehand when instantiating the Trainer or via add_metrics().

    By default, all the defined metrics will be computed.
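The rule for the metrics argument can be sketched in plain Python. This only illustrates the documented behaviour and is not ClinicaDL's implementation: resolve_metric_names is a hypothetical helper, and raising KeyError for an undefined name is an assumption.

```python
from collections.abc import Sequence
from typing import Optional


def resolve_metric_names(
    defined: Sequence[str], requested: Optional[Sequence[str]] = None
) -> list[str]:
    """Hypothetical helper: choose the metrics to compute among the defined ones."""
    if requested is None:
        # Documented default: compute every metric defined in the Trainer.
        return list(defined)
    unknown = [name for name in requested if name not in defined]
    if unknown:
        # Assumption: the real Trainer may raise a different exception type.
        raise KeyError(f"Metrics not defined in the Trainer: {unknown}")
    return list(requested)


defined = ["loss", "dice"]
print(resolve_metric_names(defined))            # ['loss', 'dice']
print(resolve_metric_names(defined, ["dice"]))  # ['dice']
```

A name that has not been defined yet (for instance "nsd") would first have to be registered with add_metrics().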

See also

resume()

To resume an interrupted training.

resume(split_idx: int, split: Split | None = None) None[source]

To resume an interrupted training launched with train().

Trainer will look for the last checkpoint saved with clinicadl.callbacks.TrainingCheckpointCallback, and restart training from there.

Trainer will attempt to load your training and validation data from the MAPS directory, so, typically, providing the split index is sufficient. However, if it fails, you can manually supply the split.

Important

Here, the computational setup will be the same as the one used when training the model before the interruption. So, if the model was first trained on a GPU, make sure a GPU is available when calling resume.

Parameters:
  • split_idx (int) – The index of the split to resume training on.

  • split (Optional[Split], default=None) –

    The split containing the training and validation data.

    If you pass a split here, it will be used; otherwise, Trainer will attempt to load the split from the MAPS.

validate(split_idx: int, metrics: Sequence[str], dataloader: DataLoader | None = None, model_checkpoint: str | None = None) None[source]

To evaluate your model on your validation data with new metrics.

This method should be used when you wish to compute new validation metrics once the training phase is complete.

Trainer attempts to load your validation data from the MAPS directory, so, typically, providing the split index is sufficient. However, if it fails, you can manually supply the validation dataloader.

Important

Here, the computational setup will be the same as the one used when training the model. So, if the model was trained on a GPU, make sure a GPU is available when calling validate.

Parameters:
  • split_idx (int) – The index of the split on which the model to validate has been trained.

  • metrics (Sequence[str]) – The names of the metrics to compute on the validation data. The metrics mentioned here must be in metrics, so they must have been defined beforehand when instantiating the Trainer or via add_metrics().

  • dataloader (Optional[DataLoader], default=None) – The dataloader containing the validation data on which the model will be evaluated. If you pass a dataloader, it will be your validation data; otherwise, Trainer will attempt to load the validation data from the MAPS.

  • model_checkpoint (Optional[str], default=None) –

    The name of the model checkpoint to validate. If None, all the checkpoints will be evaluated. The name of the checkpoint must follow one of these formats:

    • "best-<metric-name>": the best model trained on split split_idx w.r.t. the metric <metric-name>;

    • "epoch-<epoch-idx>": the checkpoint of the model trained on split split_idx at epoch <epoch-idx>;

    • "final": the model at the end of training on split split_idx.
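The three name formats accepted by model_checkpoint can be checked before calling validate(). The sketch below is a minimal illustration: parse_checkpoint_name and its regular expression are hypothetical and not part of ClinicaDL, which may validate names differently.

```python
import re
from typing import Union

# Hypothetical pattern mirroring the three documented formats.
_CHECKPOINT_RE = re.compile(r"best-(?P<metric>.+)|epoch-(?P<epoch>\d+)|final")


def parse_checkpoint_name(name: str) -> tuple[str, Union[str, int, None]]:
    """Return the kind of checkpoint and its metric name or epoch index, if any."""
    match = _CHECKPOINT_RE.fullmatch(name)
    if match is None:
        raise ValueError(f"Unrecognized checkpoint name: {name!r}")
    if match["metric"] is not None:
        return ("best", match["metric"])
    if match["epoch"] is not None:
        return ("epoch", int(match["epoch"]))
    return ("final", None)


print(parse_checkpoint_name("best-loss"))  # ('best', 'loss')
print(parse_checkpoint_name("epoch-12"))   # ('epoch', 12)
print(parse_checkpoint_name("final"))      # ('final', None)
```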

Examples

from clinicadl.train import Trainer
from clinicadl.metrics.config import (
    LossMetricConfig,
    DiceMetricConfig,
    SurfaceDiceMetricConfig,
)
from clinicadl.callbacks import ModelCheckpointCallback

trainer = Trainer(
    ...
    metrics={"loss": LossMetricConfig(), "dice": DiceMetricConfig()},
    callbacks=[ModelCheckpointCallback(metric="loss")],  # save the best model w.r.t. "loss"
)

trainer.train(...)  # training on split 1, validation on "loss" and "dice"

trainer.add_metrics(
    nsd=SurfaceDiceMetricConfig(class_thresholds=[1])
)  # define a new metric named "nsd"
trainer.validate(
    split_idx=1, metrics=["nsd"], model_checkpoint="best-loss"
)  # validation of a specific model checkpoint with the new metric "nsd"
trainer.validate(split_idx=1, metrics=["nsd"])  # validation of all checkpoints
>>> trainer.validate(split_idx=1, metrics=["nsd"])
CannotReadJsonError: ClinicaDL could not read the dataloader in ...
>>> trainer.validate(
        split_idx=1, metrics=["nsd"], dataloader=dataloader
    )  # pass the dataloader if a reading error is raised

test(model_checkpoint: str, group_name: str, metrics: Sequence[str] | None = None, dataloader: DataLoader | None = None, computational: ComputationalConfig = ComputationalConfig(gpu=True, non_blocking=True, amp=True, channels_last=True, seed=None, deterministic=False)) None[source]

To test your model.

This method checks for data leakage between your test and your training/validation data.
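To illustrate what a data-leakage check involves, the sketch below looks for identifiers shared between the training/validation data and the test data. How ClinicaDL performs its own check is not described here, so the function name and the use of participant identifiers are assumptions.

```python
from collections.abc import Iterable


def find_leakage(train_val_ids: Iterable[str], test_ids: Iterable[str]) -> set[str]:
    """Hypothetical check: return the identifiers present in both datasets."""
    return set(train_val_ids) & set(test_ids)


# Made-up participant identifiers, for illustration only.
train_val_ids = ["sub-001", "sub-002", "sub-003"]
test_ids = ["sub-003", "sub-004"]

leaked = find_leakage(train_val_ids, test_ids)
if leaked:
    print(f"Data leakage detected for: {sorted(leaked)}")
```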

Your test data must be associated with a group_name. If the group_name already exists in the MAPS, Trainer attempts to load the test data from it. However, if it fails, you can manually supply the test dataloader.

Parameters:
  • model_checkpoint (str) –

    The name of the model checkpoint to test. The name of the checkpoint must follow one of these formats:

    • "split-<split-idx>_best-<metric-name>": the best model trained on split <split-idx> w.r.t. the metric <metric-name>;

    • "split-<split-idx>_epoch-<epoch-idx>": the checkpoint of the model trained on split <split-idx> at epoch <epoch-idx>;

    • "split-<split-idx>_final": the model at the end of training on split <split-idx>.

  • group_name (str) – The name given to these test data. If the group name was already used, Trainer will attempt to restore the data from the MAPS; otherwise, you must pass a dataloader containing your test data.

  • metrics (Optional[Sequence[str]], default=None) –

    The names of the metrics to compute on the test data. The metrics mentioned here must be in metrics, so they must have been defined beforehand when instantiating the Trainer or via add_metrics().

    By default, all the defined metrics will be computed.

  • dataloader (Optional[DataLoader], default=None) – The dataloader containing the test data on which the model will be evaluated. If you pass a dataloader, it will be your test data; otherwise, Trainer will attempt to load the test data from the MAPS.

  • computational (ComputationalConfig, default=ComputationalConfig()) – Computational configuration. This is where you can set up high-performance computing features.
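Judging from the formats listed for model_checkpoint, the names expected by test() are the names used by validate() with a "split-<split-idx>_" prefix. A minimal sketch of building and splitting such names follows; both helpers are hypothetical and not part of ClinicaDL.

```python
def full_checkpoint_name(split_idx: int, checkpoint: str) -> str:
    """Hypothetical helper: prefix a per-split checkpoint name with its split."""
    if split_idx < 0:
        raise ValueError("split_idx must be non-negative")
    return f"split-{split_idx}_{checkpoint}"


def split_checkpoint_name(name: str) -> tuple[int, str]:
    """Hypothetical helper: recover the split index and the per-split name."""
    prefix, sep, checkpoint = name.partition("_")
    index = prefix[len("split-"):]
    if not sep or not checkpoint or not prefix.startswith("split-") or not index.isdigit():
        raise ValueError(f"Unrecognized checkpoint name: {name!r}")
    return int(index), checkpoint


print(full_checkpoint_name(1, "best-loss"))        # split-1_best-loss
print(split_checkpoint_name("split-1_best-loss"))  # (1, 'best-loss')
print(split_checkpoint_name("split-0_final"))      # (0, 'final')
```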

Examples

from clinicadl.train import Trainer
from clinicadl.metrics.config import (
    LossMetricConfig,
    DiceMetricConfig,
)
from clinicadl.callbacks import ModelCheckpointCallback

trainer = Trainer(
    ...
    metrics={"loss": LossMetricConfig(), "dice": DiceMetricConfig()},
    callbacks=[
        ModelCheckpointCallback(metric="loss")
    ],  # save the best model w.r.t. "loss"
)

trainer.train(...)  # training on split 1
trainer.test(
    model_checkpoint="split-1_best-loss", group_name="adni"
)  # compute all the metrics
trainer.test(
    model_checkpoint="split-1_best-loss", group_name="adni", metrics=["dice"]
)  # compute only "dice"
>>> trainer.test(model_checkpoint="split-1_best-loss", group_name="adni")
CannotReadJsonError: ClinicaDL could not read the dataloader in ...
>>> trainer.test(
    model_checkpoint="split-1_best-loss", group_name="adni", dataloader=dataloader
)  # pass the dataloader if the Trainer cannot read the dataloader associated to "adni"