2.2. Training
Training is orchestrated by the Trainer. It ties
together a Model, the data of a
Split, the metrics to monitor and the callbacks to run,
and takes care of the training loop — moving data to the GPU, mixed precision,
gradient accumulation, evaluation, checkpointing, etc. — so that you only provide the
pieces specific to your experiment.
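The following framework-free sketch shows the kind of loop the Trainer writes for you: an epoch loop that evaluates periodically and notifies callbacks. It illustrates the idea under stated assumptions and is not ClinicaDL's implementation. `Callback`, `on_epoch_end`, `MetricLogger`, `run_training` and `eval_every` are hypothetical names, and the training and evaluation steps are stand-in functions.

```python
from typing import Callable, Dict, List, Tuple


class Callback:
    """Hypothetical base class: a callback reacts to events emitted by the loop."""

    def on_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
        pass


class MetricLogger(Callback):
    """Hypothetical callback that records the validation loss of evaluated epochs."""

    def __init__(self) -> None:
        self.records: List[Tuple[int, float]] = []

    def on_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
        if metrics:  # an empty dict means no evaluation happened this epoch
            self.records.append((epoch, metrics["val_loss"]))


def run_training(
    train_one_epoch: Callable[[int], None],
    evaluate: Callable[[int], float],
    num_epochs: int,
    eval_every: int,
    callbacks: List[Callback],
) -> None:
    """Run the epochs, evaluate every `eval_every` epochs and notify all callbacks."""
    for epoch in range(1, num_epochs + 1):
        train_one_epoch(epoch)
        metrics = {"val_loss": evaluate(epoch)} if epoch % eval_every == 0 else {}
        for callback in callbacks:
            callback.on_epoch_end(epoch, metrics)


logger = MetricLogger()
run_training(
    train_one_epoch=lambda epoch: None,  # stand-in for a real training epoch
    evaluate=lambda epoch: 1.0 / epoch,  # stand-in validation loss that decreases
    num_epochs=20,
    eval_every=5,
    callbacks=[logger],
)
print([epoch for epoch, _ in logger.records])  # [5, 10, 15, 20]
```

The loop itself stays generic, and experiment-specific behaviour such as logging, early stopping or checkpointing lives in callbacks. This means the same loop can serve many experiments without modification.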
2.2.1. A first training
Putting together what we built in Chapter 1 and in Defining a model:
from clinicadl.train import Trainer
# `model` is a Model and `split` a Split (see the previous sections)
trainer = Trainer(maps="maps_directory", model=model)
trainer.train(split)
The Trainer writes everything it produces — trained weights, metrics, logs and
the configuration used — into the maps directory (the MAPS, see
Chapter 3).
train() runs the training on a single split. To
train on every fold of a KFold, simply loop over the
splits:
# `splitter` (e.g. a KFold) and `dataset` are the ones built in Chapter 1
for split in splitter.get_splits(dataset):
    split.build_train_loader(batch_size=8, shuffle=True)
    split.build_val_loader(batch_size=8)
    trainer.train(split)
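The sketch below shows what this loop iterates over. It partitions sample indices into k folds the way a plain, unshuffled k-fold does: each fold holds out one contiguous block for validation and trains on the rest, so the loop above trains k models, each validated on different data. `kfold_indices` is a hypothetical helper written for illustration. ClinicaDL's KFold splitter may shuffle, stratify or group samples differently, so refer to its documentation for the exact behaviour.

```python
from typing import Iterator, List, Tuple


def kfold_indices(n_samples: int, n_splits: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, val_indices) for each fold, using contiguous validation blocks."""
    if not 2 <= n_splits <= n_samples:
        raise ValueError("n_splits must be between 2 and n_samples")
    indices = list(range(n_samples))
    # Spread the remainder over the first folds so fold sizes differ by at most one.
    fold_sizes = [n_samples // n_splits + (1 if i < n_samples % n_splits else 0)
                  for i in range(n_splits)]
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


for fold, (train_idx, val_idx) in enumerate(kfold_indices(n_samples=10, n_splits=5)):
    print(fold, val_idx)  # fold 0 validates on [0, 1], fold 1 on [2, 3], ...
```

Across the folds, every sample is used for validation exactly once. This is why cross-validation results are usually reported as an average over all folds.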
2.2.2. Configuring the optimization
How the optimization is run — the number of epochs, gradient accumulation, gradient
clipping, how often to evaluate — is described by an
OptimizationConfig, passed to the Trainer:
from clinicadl.optim import OptimizationConfig
from clinicadl.train import Trainer
trainer = Trainer(
    maps="maps_directory",
    model=model,
    optimization=OptimizationConfig(
        num_epochs=100,
        accumulation_steps=2,  # accumulate gradients over 2 batches: doubles the effective batch size
        evaluation_interval=5,  # evaluate every 5 epochs
    ),
)
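Gradient accumulation trades time for memory: the gradients of several batches are summed before each optimizer step. The pure-Python sketch below checks the underlying arithmetic on a one-parameter linear model with a mean squared error loss. It assumes that each batch loss is averaged over its samples and divided by `accumulation_steps`. This is common practice, but it is an assumption about ClinicaDL's internals. Under that assumption, two accumulated batches of 4 samples give exactly the gradient of one batch of 8. `grad_mse` and the data are made up for the example.

```python
from typing import List


def grad_mse(w: float, xs: List[float], ys: List[float]) -> float:
    """Gradient with respect to w of the mean squared error of the model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
w = 0.5
batch_size, accumulation_steps = 4, 2

# Accumulate: add up the gradient of each batch of 4 samples, scaled by
# 1 / accumulation_steps, before a single parameter update would be made.
accumulated = 0.0
for step in range(accumulation_steps):
    start = step * batch_size
    batch_x, batch_y = xs[start:start + batch_size], ys[start:start + batch_size]
    accumulated += grad_mse(w, batch_x, batch_y) / accumulation_steps

# Reference: gradient of one batch of batch_size * accumulation_steps = 8 samples.
full_batch = grad_mse(w, xs, ys)
print(abs(accumulated - full_batch) < 1e-9)  # True
```

The equivalence only covers the gradient. Layers that compute statistics over the batch, such as batch normalization, still see batches of the original size. This is why the batch size is only virtually doubled.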
2.2.3. Controlling the hardware
Computational aspects — GPU, AMP, memory format, and the seed for
reproducibility — are set per training run through a
ComputationalConfig passed to
train():
from clinicadl.train import ComputationalConfig
trainer.train(
    split,
    computational=ComputationalConfig(gpu=True, amp=True, seed=42, deterministic=True),
)
Tip
Setting seed and deterministic=True makes a training run reproducible.
A global seed can also be set once with
clinicadl.utils.seed.seed_everything() or clinicadl.utils.seed.seed_everything_context().
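Seeding makes every pseudo-random draw repeat identically from one run to the next. One example is the order in which a shuffled loader visits samples. The sketch below shows the principle using only Python's standard `random` module. `shuffled_order` is a hypothetical helper. A full seeding utility such as `seed_everything()` would also need to seed NumPy and PyTorch; this sketch assumes it does, so check its documentation. Seeding alone does not make some GPU kernels deterministic, which is presumably what `deterministic=True` addresses.

```python
import random
from typing import List


def shuffled_order(seed: int, n_samples: int = 10) -> List[int]:
    """Return the order in which a shuffled loader would visit n_samples samples."""
    rng = random.Random(seed)  # dedicated generator, seeded explicitly
    order = list(range(n_samples))
    rng.shuffle(order)
    return order


first_run = shuffled_order(seed=42)
second_run = shuffled_order(seed=42)
print(first_run == second_run)  # True: same seed, same order
```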
2.2.4. Monitoring the training
Monitoring the training is done through metrics and callbacks (logging, early stopping, etc.). We cover metrics in Evaluating and callbacks in Callbacks.
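The following sketch illustrates what a monitoring callback can do with a minimal patience-based early-stopping rule. Training stops once the validation loss has failed to improve for `patience` consecutive evaluations. The `EarlyStopping` class and its `should_stop` method are hypothetical and independent of ClinicaDL; the library's own callbacks are described in Callbacks.

```python
class EarlyStopping:
    """Hypothetical patience-based early stopping on a validation loss (lower is better)."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_evaluations = 0

    def should_stop(self, val_loss: float) -> bool:
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # improvement: remember it and reset the counter
            self.bad_evaluations = 0
        else:
            self.bad_evaluations += 1  # no improvement at this evaluation
        return self.bad_evaluations >= self.patience


stopper = EarlyStopping(patience=2)
val_losses = [0.9, 0.7, 0.65, 0.66, 0.67, 0.6]
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
print(stopped_at)  # 4: two evaluations in a row without beating 0.65
```

The `min_delta` threshold keeps tiny fluctuations from counting as improvements. With the default of 0, any decrease resets the counter.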
2.2.5. Resuming an interrupted training
Long training runs can be interrupted, for instance by a bug or a power cut. As long
as a TrainingCheckpointCallback was active (it is one
of the default callbacks), the Trainer periodically saves a checkpoint of the
training state in the MAPS, and you can pick up where it stopped with
resume():
trainer.resume(split_idx=0)
If the Trainer object is no longer in memory — typically in a fresh Python
session — rebuild it from the MAPS first, then resume:
from clinicadl.train import Trainer
trainer = Trainer.from_maps("maps_directory")
trainer.resume(split_idx=0)
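The toy loop below pictures the mechanism behind resume(). It writes its state to a checkpoint file after every completed epoch and, when that file exists, restarts from the first unfinished epoch. The JSON file, its fields and the `train` function are hypothetical. ClinicaDL's checkpoints live in the MAPS and store the real training state (presumably model weights, optimizer state and so on), not a list of epoch numbers.

```python
import json
import tempfile
from pathlib import Path
from typing import List, Optional


def train(num_epochs: int, checkpoint: Path, crash_at: Optional[int] = None) -> List[int]:
    """Toy training loop that checkpoints after each epoch and resumes from the checkpoint."""
    if checkpoint.exists():
        state = json.loads(checkpoint.read_text())  # resume: reload the saved state
    else:
        state = {"next_epoch": 0, "history": []}  # fresh start
    for epoch in range(state["next_epoch"], num_epochs):
        if epoch == crash_at:
            raise RuntimeError(f"interrupted during epoch {epoch}")
        state["history"].append(epoch)  # stands in for one epoch of training
        state["next_epoch"] = epoch + 1
        checkpoint.write_text(json.dumps(state))  # checkpoint after each completed epoch
    return state["history"]


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "checkpoint.json"
    try:
        train(num_epochs=5, checkpoint=ckpt, crash_at=3)
    except RuntimeError as err:
        print(err)  # interrupted during epoch 3
    history = train(num_epochs=5, checkpoint=ckpt)  # resumes at epoch 3
print(history)  # [0, 1, 2, 3, 4]
```

Because the checkpoint is written only after an epoch completes, at most one epoch of work is lost. The epoch that was running when the interruption happened is simply redone.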
Important
The computational setup recorded at training time is reused on resume. If the model was first trained on a GPU, make sure a GPU is available when you resume.
You now know how to set up the training of a model. The next section shows how to evaluate it.