clinicadl.train.Trainer
- class clinicadl.train.Trainer(maps: Path | str | Maps, model: Model, metrics: dict[str, Metric | MetricConfig] | MetricsHandler | None = None, optimization: OptimizationConfig | None = None, callbacks: list[Callback] | CallbacksHandler | None = None, overwrite: bool = False) -> None
The core class to manage model training and evaluation.
This class makes all ClinicaDL objects work together in order to have functional training and evaluation pipelines. It handles the logic common to most PyTorch workflows, while allowing the user to customize the essential logic in the Model, or the non-essential logic via callbacks.
Trainer also offers high-performance computing options, configured via ComputationalConfig.
Trainer will record all outputs and results in a MAPS directory, as well as the configurations used, so that the experiment can be reproduced.
Main methods:
- train(): for training your model;
- resume(): for resuming an interrupted training;
- validate(): for computing new metrics on your validation data;
- test(): for evaluating your model on test data.
- Parameters:
maps (Union[PathType, Maps]) – Directory where outputs, results and configurations will be saved.
model (Model) – The model to train or evaluate.
metrics (Optional[Union[dict[str, MetricOrConfig], MetricsHandler]], default=None) –
Dictionary of metric names and metric instances for monitoring model performance. Metric instances can be passed via a clinicadl.metrics.Metric or a configuration class. A MetricsHandler containing the metrics can also be passed directly.
Note
Here you define the metrics available within the scope of your Trainer, but not all of them will necessarily be computed every time. The list of metrics to compute is passed to the relevant methods of the trainer.
You can still add metrics later via add_metrics().
By default, only the loss returned by Model.get_loss_functions will be monitored (it is expected to be called "loss").
optimization (Optional[OptimizationConfig], default=None) – Configuration for the optimization of the neural network during training (e.g., the number of epochs). If None, the default parameters of OptimizationConfig will be used.
callbacks (Optional[Union[list[Callback], CallbacksHandler]], default=None) –
List of Callback objects to customize the trainer. A CallbacksHandler containing the desired callbacks can also be passed directly.
If None, only the default callbacks in CallbacksHandler will be applied.
overwrite (bool, default=False) – Whether to overwrite the MAPS if it already exists.
- property metrics: MetricsHandler
The metrics computed by Trainer.
- property callbacks: CallbacksHandler
The callbacks called by Trainer.
- property optimization: OptimizationConfig
The optimization configuration of the Trainer.
- property state: TrainerState
The current state of the Trainer.
- classmethod from_maps(maps_path: Path | str, **kwargs) -> Self
To restore a Trainer from a MAPS directory.
This classmethod will try to restore the Trainer that created the MAPS by reading the json configuration files saved inside.
If an error occurs when restoring a component of the Trainer (e.g., you used your own Model, so the Trainer cannot read it), you can restore the problematic objects on your own and pass them via keyword arguments (e.g., Trainer.from_maps(..., model=...)).
- Parameters:
maps_path (PathType) – Path to the MAPS directory.
**kwargs (Any) – To pass objects that the Trainer cannot restore.
- Returns:
Self – The Trainer associated with the MAPS.
Examples
trainer = Trainer(
    maps="maps",
    model=MyModel()
    ...
)
...
>>> Trainer.from_maps("maps")
CannotReadJsonError: Cannot read the model [...]  # if MyModel is not a model supported natively by ClinicaDL
>>> Trainer.from_maps("maps", model=MyModel())
- add_metrics(**metrics: Metric | MetricConfig) -> None
To add new metrics to compute to metrics.
- Parameters:
**metrics (MetricOrConfig) – The metrics to add, passed via a clinicadl.metrics.Metric or a configuration class.
- add_callbacks(callbacks: Sequence[Callback]) -> None
To add new callbacks to callbacks.
- Parameters:
callbacks (Sequence[Callback]) – The callbacks to add.
- train(split: Split, computational: ComputationalConfig | None = None, metrics: Sequence[str] | None = None) -> None
To train a model.
- Parameters:
split (Split) – The split containing the training and validation data.
computational (Optional[ComputationalConfig], default=None) – Computational configuration. This is where you can set up high-performance computing features, or a seed to make your training reproducible.
metrics (Optional[Sequence[str]], default=None) –
The names of the metrics to compute on the validation data. The metrics mentioned here must be in metrics, so they must have been defined beforehand when instantiating the Trainer or via add_metrics().
By default, all the defined metrics will be computed.
See also
resume(): To resume an interrupted training.
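The rule for the metrics argument of train() (names must already be defined, and None means all defined metrics) can be sketched in plain Python. This is only an illustration of the documented behaviour, not ClinicaDL code: resolve_metric_names is a hypothetical helper, and the kind of error actually raised by Trainer for an unknown name is an assumption.

```python
from collections.abc import Sequence
from typing import Optional


def resolve_metric_names(
    requested: Optional[Sequence[str]], defined: Sequence[str]
) -> list[str]:
    """Hypothetical helper: choose which of the defined metrics to compute."""
    if requested is None:
        # Documented default: compute every metric that has been defined.
        return list(defined)
    unknown = [name for name in requested if name not in defined]
    if unknown:
        # Assumption: an undefined name is rejected with a ValueError here;
        # the real Trainer may report this differently.
        raise ValueError(f"Undefined metrics: {unknown}. Defined: {list(defined)}")
    return list(requested)


# Metric names defined when instantiating the Trainer (illustrative values).
defined = ["loss", "dice"]
all_names = resolve_metric_names(None, defined)      # every defined metric
only_dice = resolve_metric_names(["dice"], defined)  # a subset
```

With this sketch, asking for a name such as "nsd" before it has been added via add_metrics() would fail, which mirrors the requirement stated for the metrics parameter.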
- resume(split_idx: int, split: Split | None = None) -> None
To resume an interrupted training launched with train().
Trainer will look for the last checkpoint saved with clinicadl.callbacks.TrainingCheckpointCallback, and restart training from there.
Trainer will attempt to load your training and validation data from the MAPS directory, so, typically, providing the split index is sufficient. However, if it fails, you can manually supply the split.
Important
Here, the computational setup will be the same as the one used when training the model before the interruption. So, if the model was first trained on a GPU, make sure a GPU is available when calling resume.
- validate(split_idx: int, metrics: Sequence[str], dataloader: DataLoader | None = None, model_checkpoint: str | None = None) -> None
To evaluate your model on your validation data with new metrics.
This method should be used when you wish to compute new validation metrics once the training phase is complete.
Trainer attempts to load your validation data from the MAPS directory, so, typically, providing the split index is sufficient. However, if it fails, you can manually supply the validation dataloader.
Important
Here, the computational setup will be the same as the one used when training the model. So, if the model was trained on a GPU, make sure a GPU is available when calling validate.
- Parameters:
split_idx (int) – The index of the split on which the model to validate has been trained.
metrics (Sequence[str]) – The names of the metrics to compute on the validation data. The metrics mentioned here must be in metrics, so they must have been defined beforehand when instantiating the Trainer or via add_metrics().
dataloader (Optional[DataLoader], default=None) – The dataloader containing the validation data on which the model will be evaluated. If you pass a dataloader, it will be used as your validation data; otherwise, Trainer will attempt to load the validation data from the MAPS.
model_checkpoint (Optional[str], default=None) –
The name of the model checkpoint to validate. If None, all the checkpoints will be evaluated. The name of the checkpoint must follow one of these formats:
- "best-<metric-name>": the best model trained on split split_idx w.r.t. the metric <metric-name>;
- "epoch-<epoch-idx>": the checkpoint of the model trained on split split_idx at epoch <epoch-idx>;
- "final": the model at the end of training on split split_idx.
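A checkpoint name can be checked against the three formats listed for model_checkpoint before calling validate(). The sketch below is not part of ClinicaDL: parse_checkpoint_name and CheckpointName are hypothetical names, and the regular expression only encodes the formats as they are written in this page (how the library parses names internally is not documented here).

```python
import re
from typing import NamedTuple, Optional


class CheckpointName(NamedTuple):
    """Hypothetical container for the parts of a validate() checkpoint name."""

    kind: str               # "best", "epoch" or "final"
    metric: Optional[str]   # set only when kind == "best"
    epoch: Optional[int]    # set only when kind == "epoch"


# One alternative per documented format.
_PATTERN = re.compile(r"best-(?P<metric>.+)|epoch-(?P<epoch>\d+)|final")


def parse_checkpoint_name(name: str) -> CheckpointName:
    """Split a checkpoint name into its kind and its metric or epoch."""
    match = _PATTERN.fullmatch(name)
    if match is None:
        raise ValueError(
            f"{name!r} is not 'best-<metric-name>', 'epoch-<epoch-idx>' or 'final'"
        )
    if match["metric"] is not None:
        return CheckpointName("best", match["metric"], None)
    if match["epoch"] is not None:
        return CheckpointName("epoch", None, int(match["epoch"]))
    return CheckpointName("final", None, None)


best = parse_checkpoint_name("best-loss")
at_epoch = parse_checkpoint_name("epoch-12")
final = parse_checkpoint_name("final")
```

Checking the name up front gives an early, readable error for a typo such as "latest" instead of a failure once the Trainer starts looking for the checkpoint.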
Examples
from clinicadl.train import Trainer
from clinicadl.metrics.config import (
    LossMetricConfig,
    DiceMetricConfig,
    SurfaceDiceMetricConfig,
)
from clinicadl.callbacks import ModelCheckpointCallback

trainer = Trainer(
    ...
    metrics={"loss": LossMetricConfig(), "dice": DiceMetricConfig()},
    callbacks=[ModelCheckpointCallback(metric="loss")],  # save the best model w.r.t. "loss"
)
trainer.train(...)  # training on split 1, validation on "loss" and "dice"
trainer.add_metrics(
    nsd=SurfaceDiceMetricConfig(class_thresholds=[1])
)  # define a new metric named "nsd"
trainer.validate(
    split_idx=1, metrics=["nsd"], model_checkpoint="best-loss"
)  # validation of a specific model checkpoint with the new metric "nsd"
trainer.validate(split_idx=1, metrics=["nsd"])  # validation of all checkpoints
>>> trainer.validate(split_idx=1, metrics=["nsd"])
CannotReadJsonError: ClinicaDL could not read the dataloader in ...
>>> trainer.validate(
        split_idx=1, metrics=["nsd"], dataloader=dataloader
    )  # pass the dataloader if a reading error is raised
- test(model_checkpoint: str, group_name: str, metrics: Sequence[str] | None = None, dataloader: DataLoader | None = None, computational: ComputationalConfig = ComputationalConfig(gpu=True, non_blocking=True, amp=True, channels_last=True, seed=None, deterministic=False)) -> None
To test your model.
This method checks for data leakage between your test and your training/validation data.
Your test data must be associated with a group_name. If the group_name already exists in the MAPS, Trainer attempts to load the test data from it. However, if it fails, you can manually supply the test dataloader.
- Parameters:
model_checkpoint (str) –
The name of the model checkpoint to test. The name of the checkpoint must follow one of these formats:
- "split-<split-idx>_best-<metric-name>": the best model trained on split <split-idx> w.r.t. the metric <metric-name>;
- "split-<split-idx>_epoch-<epoch-idx>": the checkpoint of the model trained on split <split-idx> at epoch <epoch-idx>;
- "split-<split-idx>_final": the model at the end of training on split <split-idx>.
group_name (str) – The name given to these test data. If the group name was already used, Trainer will attempt to restore the data from the MAPS; otherwise, you must pass a dataloader containing your test data.
metrics (Optional[Sequence[str]], default=None) –
The names of the metrics to compute on the test data. The metrics mentioned here must be in metrics, so they must have been defined beforehand when instantiating the Trainer or via add_metrics().
By default, all the defined metrics will be computed.
dataloader (Optional[DataLoader], default=None) – The dataloader containing the test data on which the model will be evaluated. If you pass a dataloader, it will be used as your test data; otherwise, Trainer will attempt to load the test data from the MAPS.
computational (ComputationalConfig, default=ComputationalConfig()) – Computational configuration. This is where you can set up high-performance computing features.
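As written in this page, the names accepted by test() look like the names accepted by validate() with a "split-<split-idx>_" prefix. Under that reading, the names for several splits can be built in a loop. This is a minimal sketch: checkpoint_for_test is a hypothetical helper, not a ClinicaDL function, and the split indices used are illustrative.

```python
def checkpoint_for_test(split_idx: int, suffix: str) -> str:
    """Hypothetical helper: prefix a validate()-style name with its split.

    suffix is expected to be "best-<metric-name>", "epoch-<epoch-idx>" or "final".
    """
    return f"split-{split_idx}_{suffix}"


# Illustrative split indices; use the ones your own training produced.
names = [checkpoint_for_test(split_idx, "best-loss") for split_idx in (1, 2)]
final_name = checkpoint_for_test(1, "final")
```

Each resulting string could then be passed as model_checkpoint to test(), together with a group_name, to evaluate the corresponding model.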
Examples
from clinicadl.train import Trainer
from clinicadl.metrics.config import (
    LossMetricConfig,
    DiceMetricConfig,
)
from clinicadl.callbacks import ModelCheckpointCallback

trainer = Trainer(
    ...
    metrics={"loss": LossMetricConfig(), "dice": DiceMetricConfig()},
    callbacks=[
        ModelCheckpointCallback(metric="loss")
    ],  # save the best model w.r.t. "loss"
)
trainer.train(...)  # training on split 1
trainer.test(
    model_checkpoint="split-1_best-loss", group_name="adni"
)  # compute all the metrics
trainer.test(
    model_checkpoint="split-1_best-loss", group_name="adni", metrics=["dice"]
)  # compute only "dice"
>>> trainer.test(model_checkpoint="split-1_best-loss", group_name="adni")
CannotReadJsonError: ClinicaDL could not read the dataloader in ...
>>> trainer.test(
        model_checkpoint="split-1_best-loss", group_name="adni", dataloader=dataloader
    )  # pass the dataloader if the Trainer cannot read the dataloader associated with "adni"