2.2. Training

Training is orchestrated by the Trainer. It ties together a Model, the data of a Split, the metrics to monitor and the callbacks to run, and takes care of the training loop — moving data to the GPU, mixed precision, gradient accumulation, evaluation, checkpointing, etc. — so that you only provide the pieces specific to your experiment.

2.2.1. A first training

Putting together what we built in Chapter 1 and in Defining a model:

from clinicadl.train import Trainer

# `model` is a Model and `split` a Split (see the previous sections)
trainer = Trainer(maps="maps_directory", model=model)
trainer.train(split)

The Trainer writes everything it produces — trained weights, metrics, logs and the configuration used — into the maps directory (the MAPS, see Chapter 3).

train() runs the training on a single split. To train on every fold of a KFold, simply loop over the splits:

for split in splitter.get_splits(dataset):
    split.build_train_loader(batch_size=8, shuffle=True)
    split.build_val_loader(batch_size=8)
    trainer.train(split)
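
To see what looping over the folds amounts to, here is a minimal sketch in plain Python of how a K-fold partition could be generated. It is not ClinicaDL's KFold: `kfold_indices` is a hypothetical helper, and the absence of shuffling and stratification is an assumption made to keep the example short.

```python
def kfold_indices(n_samples, n_splits):
    """Yield (train_indices, val_indices) for each fold, without shuffling."""
    # Spread the remainder over the first folds so sizes differ by at most one.
    fold_sizes = [
        n_samples // n_splits + (1 if i < n_samples % n_splits else 0)
        for i in range(n_splits)
    ]
    start = 0
    for size in fold_sizes:
        val = list(range(start, start + size))
        train = [i for i in range(n_samples) if i < start or i >= start + size]
        yield train, val
        start += size


folds = list(kfold_indices(n_samples=10, n_splits=3))
for fold_idx, (train, val) in enumerate(folds):
    print(f"fold {fold_idx}: {len(train)} training samples, validation on {val}")
```

Each sample lands in exactly one validation set, so training once per fold means every sample is used for validation exactly once.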

2.2.2. Configuring the optimization

How the optimization is run — the number of epochs, gradient accumulation, gradient clipping, how often to evaluate — is described by an OptimizationConfig, passed to the Trainer:

from clinicadl.optim import OptimizationConfig
from clinicadl.train import Trainer

trainer = Trainer(
    maps="maps_directory",
    model=model,
    optimization=OptimizationConfig(
        num_epochs=100,
        accumulation_steps=2,      # effectively doubles the batch size
        evaluation_interval=5,     # evaluate every 5 epochs
    ),
)
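
Why does accumulating over 2 steps act like a batch twice as large? The sketch below checks it numerically in plain Python on a toy one-parameter model. It is only an illustration, not ClinicaDL code: `mse_gradient` is a hypothetical helper, and scaling each micro-batch gradient by 1 / accumulation_steps is an assumption about how the accumulation is normalized.

```python
def mse_gradient(w, batch):
    """Gradient of the mean squared error of y = w * x with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 3.9), (3.0, 6.1), (4.0, 8.2)]
w = 0.5

# Gradient computed in one go on the full batch of 4 samples.
full_batch_grad = mse_gradient(w, data)

# Same samples seen as 2 micro-batches of 2: each gradient is scaled by
# 1 / accumulation_steps and summed before a single optimizer step would run.
accumulation_steps = 2
micro_batches = [data[:2], data[2:]]
accumulated_grad = sum(
    mse_gradient(w, micro_batch) / accumulation_steps
    for micro_batch in micro_batches
)

print(abs(full_batch_grad - accumulated_grad) < 1e-9)
```

With equally sized micro-batches and a loss averaged over the batch, the two gradients agree up to floating-point error, while only one micro-batch has to fit in memory at a time.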

2.2.3. Controlling the hardware

Computational aspects — GPU, AMP, memory format, and the seed for reproducibility — are set per training run through a ComputationalConfig passed to train():

from clinicadl.train import ComputationalConfig

trainer.train(
    split,
    computational=ComputationalConfig(gpu=True, amp=True, seed=42, deterministic=True),
)

Tip

Setting seed and deterministic=True makes a training run reproducible, provided the hardware and library versions stay the same. A global seed can also be set once with clinicadl.utils.seed.seed_everything() or clinicadl.utils.seed.seed_everything_context().
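
To give an intuition of what a seeding context manager does, here is a minimal sketch restricted to Python's built-in random module. It is not the implementation of seed_everything_context(): `seeded` is a hypothetical helper, and a real version would presumably also seed NumPy and PyTorch.

```python
import random
from contextlib import contextmanager


@contextmanager
def seeded(seed):
    """Seed Python's RNG inside the block, then restore the previous state."""
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        # Restore the generator so that code outside the block is unaffected.
        random.setstate(state)


with seeded(42):
    first = [random.random() for _ in range(3)]

with seeded(42):
    second = [random.random() for _ in range(3)]

print(first == second)
```

The same seed gives the same sequence of draws in both blocks, which is what makes two runs comparable.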

2.2.4. Monitoring the training

Monitoring the training is done through metrics and callbacks (logging, early stopping, etc.). We cover metrics in Evaluating and callbacks in Callbacks.

2.2.5. Resuming an interrupted training

Long trainings can be interrupted — a bug, a power cut. As long as a TrainingCheckpointCallback was active (it is one of the default callbacks), the Trainer periodically saves a checkpoint of the training state in the MAPS, and you can pick up where it stopped with resume():

trainer.resume(split_idx=0)

If the Trainer object is no longer in memory — typically in a fresh Python session — rebuild it from the MAPS first, then resume:

from clinicadl.train import Trainer

trainer = Trainer.from_maps("maps_directory")
trainer.resume(split_idx=0)

Important

The computational setup recorded at training time is reused on resume. If the model was first trained on a GPU, make sure a GPU is available when you resume.
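
The checkpoint-and-resume mechanism can be pictured with a toy loop in plain Python. This is only a sketch under stated assumptions, not how the Trainer or TrainingCheckpointCallback is implemented: `run`, the JSON file and the per-epoch saving frequency are all hypothetical.

```python
import json
import tempfile
from pathlib import Path


def run(num_epochs, checkpoint, stop_after=None):
    """Toy loop: resume from `checkpoint` if it exists, save after every epoch."""
    state = {"epoch": 0, "history": []}
    if checkpoint.exists():
        state = json.loads(checkpoint.read_text())
    for epoch in range(state["epoch"], num_epochs):
        if stop_after is not None and epoch >= stop_after:
            return state  # simulate an interruption
        state["history"].append(epoch)  # stand-in for one epoch of training
        state["epoch"] = epoch + 1
        checkpoint.write_text(json.dumps(state))
    return state


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "checkpoint.json"
    interrupted = run(num_epochs=5, checkpoint=ckpt, stop_after=2)
    resumed = run(num_epochs=5, checkpoint=ckpt)

print(interrupted["epoch"], resumed["history"])
```

The second call does not start from scratch: it reads the saved state and carries on from the first epoch that was not completed.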


Now you know how to set up the training of a model. The next section shows how to evaluate it.