from __future__ import annotations
import warnings
from collections.abc import Sequence
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Generator,
Iterator,
Optional,
TypeVar,
Union,
)
import torch
from torch.amp.autocast_mode import autocast
from typing_extensions import Self
from clinicadl.callbacks import CallbacksHandler
from clinicadl.callbacks.base import Event
from clinicadl.data.dataloader.batch import Batch
from clinicadl.data.dataloader.loader import DataLoader, DataLoaderConfig
from clinicadl.data.datasets.factory import get_dataset_from_json
from clinicadl.io.maps.maps import Maps
from clinicadl.metrics import MetricsHandler
from clinicadl.metrics.config import LossMetricConfig
from clinicadl.models.factory import get_model_from_json
from clinicadl.optim.config import OptimizationConfig
from clinicadl.split.split import Split
from clinicadl.train.computational import ComputationalConfig
from clinicadl.train.trainer_state import TrainerState
from clinicadl.utils.dictionary.words import CPU
from clinicadl.utils.enum import TrainerCall
from clinicadl.utils.exceptions import (
CannotReadJsonError,
CannotReadJsonFieldError,
add_note,
)
from clinicadl.utils.seed import seed_everything_context
if TYPE_CHECKING:
from torch.amp.grad_scaler import GradScaler
from torch.optim import Optimizer
from clinicadl.callbacks import Callback
from clinicadl.data.dataloader import BatchType, DataLoader
from clinicadl.data.datasets import Dataset
from clinicadl.io.maps.inference import InferenceDirType
from clinicadl.io.maps.utils import DataDir
from clinicadl.metrics.types import MetricOrConfig
from clinicadl.models import Model
from clinicadl.utils.typing import PathType
T = TypeVar("T")
[docs]
class Trainer:
"""
The core class to manage model **training** and **evaluation**.
This class makes all ``ClinicaDL`` objects work together in order to have
functional training and evaluation pipelines. It handles the logic
common to most PyTorch workflow, while allowing the user to customize the essential
logic in the :py:class:`~clinicadl.models.Model`, or the non-essential logic via
:py:mod:`callbacks <clinicadl.callbacks>`.
``Trainer`` also offers high-performance computing options, configured via
:py:class:`~clinicadl.train.ComputationalConfig`.
``Trainer`` will record all outputs and results in a :term:`Maps`
directory, as well as the configurations used in order to reproduce the experiment.
Main methods:
- :py:meth:`train`: for training your model;
- :py:meth:`resume`: for resuming an interrupted training;
- :py:meth:`validate`: for computing new metrics on your validation data;
- :py:meth:`test`: for evaluating your model on test data.
Parameters
----------
maps : Union[PathType, Maps]
Directory where outputs, results and configurations will be saved.
model : Model
The model to train or evaluate.
metrics : Optional[Union[dict[str, MetricOrConfig], MetricsHandler]], default=None
Dictionary of metric names and metric instances for monitoring model performance.
Metric instances can be passed via a :py:class:`clinicadl.metrics.Metric` or a
:py:mod:`configuration class <clinicadl.metrics.config>`.
A :py:class:`~clinicadl.metrics.MetricsHandler` containing the metrics can also be passed directly.
.. note::
- Here you define the metrics useful within the scope of your ``Trainer``,
but it doesn't mean that they will all always be computed. The list of the metrics to compute
will be given to the relevant methods of the trainer.
- You can still add metrics later via :py:meth:`add_metrics`.
By default, only the loss returned by :py:meth:`Model.get_loss_functions <clinicadl.models.Model.get_loss_functions>`
will be monitored (it is expected to be called ``"loss"``).
optimization : Optional[OptimizationConfig], default=None
Configuration for the optimization of the neural network during training
(e.g., the number of epochs). If ``None``, the default parameters of
:py:class:`~clinicadl.optim.OptimizationConfig` will be used.
callbacks : Optional[Union[list[Callback], CallbacksHandler]], default=None
List of :py:class:`~clinicadl.callbacks.Callback` to customize the trainer.
A :py:class:`~clinicadl.callbacks.CallbacksHandler` containing the desired callbacks can also be passed directly.\n
If ``None``, only the default callbacks in :py:class:`~clinicadl.callbacks.CallbacksHandler` will
be applied.\n
overwrite : bool, default=False
Whether to overwrite the :term:`MAPS` if it already exists.
"""
def __init__(
    self,
    maps: Union[PathType, Maps],
    model: Model,
    metrics: Optional[Union[dict[str, MetricOrConfig], MetricsHandler]] = None,
    optimization: Optional[OptimizationConfig] = None,
    callbacks: Optional[Union[list[Callback], CallbacksHandler]] = None,
    overwrite: bool = False,
) -> None:
    # Accept either a ready-made ``Maps`` or anything ``Maps`` can be built from.
    maps_dir = maps if isinstance(maps, Maps) else Maps(maps)
    # Create the directory on disk before any attribute is instantiated.
    maps_dir.create(overwrite=overwrite)
    self._instantiate_attributes(
        maps_dir,
        model,
        metrics,
        optimization,
        callbacks,
    )
@property
def maps(self) -> Maps:
    """
    The :term:`MAPS` associated to the ``Trainer``.

    Read-only accessor for the ``_maps`` attribute.
    """
    return self._maps
@property
def model(self) -> Model:
    """
    The model associated to the ``Trainer``.

    Read-only accessor for the ``_model`` attribute.
    """
    return self._model
@property
def metrics(self) -> MetricsHandler:
    """
    The metrics computed by ``Trainer``.

    Read-only accessor for the ``_metrics`` attribute.
    """
    return self._metrics
@property
def callbacks(self) -> CallbacksHandler:
    """
    The callbacks called by ``Trainer``.

    Read-only accessor for the ``_callbacks`` attribute.
    """
    return self._callbacks
@property
def optimization(self) -> OptimizationConfig:
    """
    The optimization configuration of the ``Trainer``.

    Read-only accessor for the ``_optim_config`` attribute.
    """
    return self._optim_config
@property
def state(self) -> TrainerState:
    """
    The current state of the ``Trainer``.

    Read-only accessor for the ``_state`` attribute.
    """
    return self._state
def _instantiate_attributes(
    self,
    maps: Maps,
    model: Model,
    metrics: Optional[Union[dict[str, MetricOrConfig], MetricsHandler]],
    optimization: Optional[OptimizationConfig],
    callbacks: Optional[Union[list[Callback], CallbacksHandler]],
) -> None:
    """
    Sets every private attribute of the ``Trainer`` and fires ``Event.INIT``.

    Shared by ``__init__`` and ``from_maps``. Missing arguments are replaced by
    defaults: a default ``OptimizationConfig``, a single ``"loss"`` metric, and a
    ``CallbacksHandler`` built from an empty tuple of callbacks.
    """
    self._maps, self._model = maps, model

    optim_config = optimization if optimization else OptimizationConfig()

    # When the model is not reset between trainings, keep a CPU copy of the
    # initial weights so that they can be reloaded later on.
    self._initial_state_dict = (
        None
        if optim_config.reset_model
        else deepcopy(self._model.cpu().state_dict())
    )

    metrics_spec = (
        {"loss": LossMetricConfig(loss_name="loss")} if metrics is None else metrics
    )
    if not isinstance(metrics_spec, MetricsHandler):
        metrics_spec = MetricsHandler(**metrics_spec)
    self._metrics = metrics_spec
    self._metrics.init_metrics(self._model)

    self._callbacks = (
        callbacks
        if isinstance(callbacks, CallbacksHandler)
        else CallbacksHandler(callbacks=callbacks or ())
    )

    self._optim_config = optim_config
    self._state = TrainerState()

    self._call_event(
        Event.INIT,
        metrics=self.metrics,
        optimization=self.optimization,
        callbacks=self.callbacks,
    )
[docs]
@classmethod
def from_maps(cls, maps_path: PathType, **kwargs) -> Self:
    """
    Restores a ``Trainer`` from an existing :term:`MAPS` directory.

    The model, the metrics and the callbacks are rebuilt from the ``json``
    configuration files stored in the :term:`MAPS`; the optimization
    configuration is always read from the :term:`MAPS`.

    If a component cannot be restored (e.g., you used your own
    :py:class:`~clinicadl.models.Model`, so the ``Trainer`` cannot read it),
    rebuild it yourself and pass it via a keyword argument
    (e.g., ``Trainer.from_maps(..., model=...)``).

    Parameters
    ----------
    maps_path : PathType
        Path to the :term:`MAPS` directory.
    **kwargs : Any
        Objects that the ``Trainer`` cannot restore, passed under the names
        ``model``, ``metrics`` or ``callbacks``.

    Returns
    -------
    Self
        The ``Trainer`` associated to the :term:`MAPS`.

    Examples
    --------
    .. code-block::

        trainer = Trainer(
            maps="maps",
            model=MyModel()
            ...
        )
        ...

    .. code-block::

        >>> Trainer.from_maps("maps")
        CannotReadJsonError: Cannot read the model [...]  # if MyModel is not a model supported natively by ClinicaDL
        >>> Trainer.from_maps("maps", model=MyModel())
    """
    maps = Maps(maps_path)
    maps.read()

    # (attribute name, name of the json path on the MAPS, reader), in restoration order
    sources = (
        ("model", "model_json", get_model_from_json),
        ("metrics", "metrics_json", MetricsHandler.from_json),
        ("callbacks", "callbacks_json", CallbacksHandler.from_json),
    )
    restored = {}
    for attr_name, json_attr, reader in sources:
        restored[attr_name] = cls._read_attribute_in_json(
            attr_name, getattr(maps, json_attr), reader, kwargs
        )

    optimization = OptimizationConfig.from_json(maps.training.optimization_json)

    trainer = cls.__new__(cls)  # bypass init because maps exist
    trainer._instantiate_attributes(
        maps,
        restored["model"],
        restored["metrics"],
        optimization,
        restored["callbacks"],
    )
    return trainer
@staticmethod
def _read_attribute_in_json(
name: str, path: Path, reader: Callable[[Path], T], kwargs: dict[str, Any]
) -> T:
"""
Reads a Trainer attribute in a json, and handles potential reading errors.
"""
try:
return kwargs.get(name) or reader(path)
except CannotReadJsonFieldError as e:
raise CannotReadJsonError(
f"Cannot read the {name} (in {path}). Please pass it to from_maps via a keyword "
f"argument (e.g. Trainer.from_maps(..., {name}=...))."
) from e
[docs]
def add_metrics(self, **metrics: MetricOrConfig) -> None:
"""
To add new metrics to compute to :py:attr:`metrics`.
Parameters
----------
**metrics : MetricOrConfig
The metrics to add, passed via a :py:class:`clinicadl.metrics.Metric` or a
:py:mod:`configuration class <clinicadl.metrics.config>`.
"""
self.metrics.add_metrics(**metrics)
self.metrics.to_json(self._maps.metrics_json, overwrite=True)
[docs]
def add_callbacks(self, callbacks: Sequence[Callback]) -> None:
"""
To add new callbacks to :py:attr:`callbacks`.
Parameters
----------
callbacks : Sequence[Callback]
The callbacks to add.
"""
self._callbacks.add_callbacks(callbacks)
self._callbacks.to_json(self._maps.callbacks_json, overwrite=True)
[docs]
def train(
self,
split: Split,
computational: Optional[ComputationalConfig] = None,
metrics: Optional[Sequence[str]] = None,
) -> None:
"""
To train a model.
Parameters
----------
split : Split
The split containing the training and validation data.
computational : Optional[ComputationalConfig], default=None
Computational configuration.
This is where you can setup high-performance computing features, or a seed to make
your training reproducible.
metrics : Optional[Sequence[str]], default=None
The names of the metric to compute on the validation data. The metrics mentioned
here must be in :py:attr:`metrics`, so they must have been defined beforehand when
instantiating the ``Trainer`` or via :py:meth:`add_metrics`.\n
By default, all the defined metrics will be computed.
See Also
--------
:py:meth:`resume`
To resume an interrupted training.
"""
self.maps.read()
self._create_split(split.index)
with self._seed_context(
computational.seed, computational.deterministic
), self._exception_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
self._train(
split=split, computational=computational, metrics=metrics, resume=False
)
[docs]
def resume(self, split_idx: int, split: Optional[Split] = None) -> None:
"""
To resume an interrupted training launched with :py:meth:`train`.
``Trainer`` will look for the last checkpoint saved with :py:class:`clinicadl.callbacks.TrainingCheckpointCallback`,
and restart training from there.
``Trainer`` will attempt to load your training and validation data from the :term:`MAPS`
directory, so, typically, providing the split index is sufficient. However, it it fails, you can
manually supply the split.
.. important::
Here, the computational setup will be the same as the one used when training
the model before the interruption. So, if the model was first trained on a GPU,
make sure a GPU is available when calling ``resume``.
Parameters
----------
split_idx : int
The index of the split to resume training on.
split : Optional[Split], default=None
The split containing the training and validation data.\n
If you pass a split here, it will be used; otherwise, ``Trainer``
will attempt to load the split from the :term:`MAPS`.
"""
self.maps.read()
self._check_split_exists(split_idx)
if not split:
split = self._get_split(split_idx)
else:
assert (
split_idx == split.index
), f"split_idx does not match split.index. Got {split_idx} and {split.index}"
computational = self._get_old_computational_config(split_idx)
with self._seed_context(
computational.seed, computational.deterministic
), self._exception_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
self._train(
split=split,
computational=computational,
metrics=None, # metrics filtering is done when loading the checkpoint
resume=True,
)
def _train(
    self,
    split: Split,
    computational: ComputationalConfig,
    metrics: Optional[Sequence[str]],
    resume: bool,
) -> None:
    """
    Prepares a training run and launches the training loop.

    Builds the optimizers and the gradient scaler, resets the training
    objects, sends the model to the requested device, fires ``Event.RESUME``
    or ``Event.TRAIN_START`` depending on ``resume``, runs the training loop
    and finally fires ``Event.TRAIN_END``.
    """
    # An empty/None selection means "use every metric of the Trainer".
    metrics_handler = self.metrics.subset(metrics) if metrics else self.metrics

    optimizers = self.model.build_optimizers()
    grad_scaler = computational.get_scaler()

    self._reset_train(
        split=split,
        metrics=metrics_handler,
        optimizers=optimizers,
    )
    self._model_to(computational)

    # Both starting events receive exactly the same payload.
    start_event = Event.RESUME if resume else Event.TRAIN_START
    self._call_event(
        start_event,
        split=split,
        optimizers=optimizers,
        grad_scaler=grad_scaler,
        optimization=self.optimization,
        metrics=metrics_handler,
        callbacks=self.callbacks,
        computational=computational,
    )

    self._train_loop(
        split=split,
        optimizers=optimizers,
        grad_scaler=grad_scaler,
        metrics=metrics_handler,
        computational=computational,
    )
    self._call_event(Event.TRAIN_END)
def _train_loop(
    self,
    split: Split,
    optimizers: dict[str, Optimizer],
    grad_scaler: GradScaler,
    metrics: MetricsHandler,
    computational: ComputationalConfig,
) -> None:
    """
    The core training logic with the iteration on the epochs, and the
    nested iteration on the training batches.

    Epochs run from ``state.current_epoch + 1`` to ``optimization.num_epochs``
    (inclusive) and stop early as soon as ``state.should_stop`` is set.
    For each training batch: the batch is sent to the device, the forward
    pass runs under ``autocast``, the backward pass follows, and an
    optimizer step is performed every ``optimization.accumulation_steps``
    batches. Callback events are fired around each of these phases.

    Parameters
    ----------
    split : Split
        Provides ``train_loader`` and ``val_loader``.
    optimizers : dict[str, Optimizer]
        The optimizers passed to ``Model.optimization_step`` and zeroed
        after each optimizer step.
    grad_scaler : GradScaler
        The scaler passed to the backward and optimization steps, and
        updated after each optimizer step.
    metrics : MetricsHandler
        The metrics computed during the validation phases.
    computational : ComputationalConfig
        Provides the device, the ``amp`` flag and the batch transfer options.
    """
    for epoch in range(
        self.state.current_epoch + 1, self.optimization.num_epochs + 1
    ):
        # should_stop is read at the start of every epoch
        if self.state.should_stop:
            break
        self._reset_epoch(epoch, train_loader=split.train_loader)
        self._call_event(Event.EPOCH_START)
        for batch_idx, batch in enumerate(split.train_loader, start=1):
            self.state.current_train_batch = batch_idx
            self._call_event(Event.BATCH_START, batch=batch)
            self._batch_to(batch, computational=computational)
            self._call_event(Event.FORWARD_START, batch=batch)
            with autocast(
                device_type=computational.device.type,
                enabled=computational.amp,
            ):
                loss = self.model.forward_step(batch)
            self._call_event(
                Event.BACKWARD_START, loss=loss, grad_scaler=grad_scaler
            )
            self.model.backward_step(loss, grad_scaler=grad_scaler)
            self._call_event(Event.BACKWARD_END)
            # Gradient accumulation: only step every `accumulation_steps` batches.
            # NOTE(review): when the number of batches is not a multiple of
            # `accumulation_steps`, the trailing batches trigger neither an
            # optimizer step nor a zero_grad in this loop - confirm that this
            # is intended.
            if batch_idx % self.optimization.accumulation_steps == 0:
                self._call_event(
                    Event.OPTIM_STEP_START,
                    optimizers=optimizers,
                    grad_scaler=grad_scaler,
                )
                self._clip_gradients()
                self.model.optimization_step(
                    optimizers=optimizers, grad_scaler=grad_scaler
                )
                self.state.optim_step += 1
                grad_scaler.update()
                for optimizer in optimizers.values():
                    optimizer.zero_grad(set_to_none=True)
                self._call_event(
                    Event.OPTIM_STEP_END,
                    optimizers=optimizers,
                    grad_scaler=grad_scaler,
                )
            self._call_event(Event.BATCH_END)
        if (
            self.state.current_epoch
            % self.optimization.evaluation_interval  # validate every `evaluation_interval` epochs; NOTE(review): the former comment said "always validate the first epoch", but epoch 1 only satisfies this test when the interval is 1 - confirm intent
            == 0
        ):
            self._validation(
                split.val_loader,
                metrics=metrics,
                computational=computational,
            )
        self._call_event(Event.EPOCH_END)
[docs]
def validate(
self,
split_idx: int,
metrics: Sequence[str],
dataloader: Optional[DataLoader] = None,
model_checkpoint: Optional[str] = None,
) -> None:
"""
To evaluate your model on your validation data with new metrics.
This method should be used when you wish to compute new validation
metrics once the training phase is complete.
``Trainer`` attempts to load your validation data from the :term:`MAPS`
directory, so, typically, providing the split index is sufficient. However, it it fails, you can
manually supply the validation dataloader.
.. important::
Here, the computational setup will be the same as the one used when training
the model. So, if the model was trained on a GPU, make sure a GPU is available when
calling ``validate``.
Parameters
----------
split_idx : int
The index of the split on which the model to validate has been trained.
metrics : Sequence[str]
The names of the metrics to compute on the validation data. The metrics mentioned
here must be in :py:attr:`metrics`, so they must have been defined beforehand when
instantiating the ``Trainer`` or via :py:meth:`add_metrics`.
dataloader : Optional[DataLoader], default=None
The dataloader containing the validation data on which the model will be evaluated.
If you pass a dataloader, it will be your validation data, otherwise, ``Trainer``
will attempt to load the validation data from the :term:`MAPS`.
model_checkpoint: Optional[str], default=None
The name of the model checkpoint to validate. If ``None``, all the checkpoints will
be evaluated. The name of the checkpoint must follow one of these formats:
- ``"best-<metric-name>"``: the best model trained on split ``split_idx``
w.r.t. the metric ``<metric-name>``;
- ``"epoch-<epoch-idx>"``: the checkpoint of the model trained on split ``split_idx``
at epoch ``<epoch-idx>``;
- ``"final"``: the model at the end of training on split ``split_idx``.
Examples
--------
.. code-block::
from clinicadl.train import Trainer
from clinicadl.metrics.config import (
LossMetricConfig,
DiceMetricConfig,
SurfaceDiceMetricConfig,
)
from clinicadl.callbacks import ModelCheckpointCallback
trainer = Trainer(
...
metrics={"loss": LossMetricConfig(), "dice": DiceMetricConfig()},
callbacks=[ModelCheckpointCallback(metric="loss")], # save the best model w.r.t. "loss"
)
trainer.train(...) # training on split 1, validation on "loss" and "dice"
trainer.add_metrics(
nsd=SurfaceDiceMetricConfig(class_thresholds=[1])
) # define a new metric named "nsd"
.. code-block::
trainer.validate(
split_idx=1, metrics=["nsd"], model_checkpoint="best-loss"
) # validation of a specific model checkpoint with the new metric "nsd"
.. code-block::
trainer.validate(split_idx=1, metrics=["nsd"]) # validation of all checkpoints
.. code-block::
>>> trainer.validate(split_idx=1, metrics=["nsd"])
CannotReadJsonError: ClinicaDL could not read the dataloader in ...
>>> trainer.validate(
split_idx=1, metrics=["nsd"], dataloader=dataloader
) # pass the dataloader if a reading error is raised
"""
self.maps.read()
self._check_split_exists(split_idx)
computational = self._get_old_computational_config(split_idx)
with self._seed_context(
computational.seed, computational.deterministic
), self._exception_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
self._validate(
split_idx=split_idx,
metrics=metrics,
dataloader=dataloader,
model_checkpoint=model_checkpoint,
computational=computational,
)
def _validate(
    self,
    split_idx: int,
    metrics: Sequence[str],
    dataloader: Optional[DataLoader] = None,
    model_checkpoint: Optional[str] = None,
    computational: ComputationalConfig = ComputationalConfig(),
) -> None:
    """
    Evaluates one or all the checkpoints of a split on validation data.

    For each selected checkpoint: resets the validation objects, loads the
    model weights, sends the model to the specified device and runs the
    evaluation loop, between ``Event.VALIDATE_START`` and ``Event.VALIDATE_END``.

    Parameters
    ----------
    split_idx : int
        The index of the split whose checkpoints are evaluated.
    metrics : Sequence[str]
        The names of the metrics to compute (a subset of :py:attr:`metrics`).
    dataloader : Optional[DataLoader], default=None
        The validation dataloader. If ``None``, it is rebuilt from the
        validation data of the split stored in the MAPS.
    model_checkpoint : Optional[str], default=None
        The checkpoint to evaluate. If ``None``, all the checkpoints of the split.
    computational : ComputationalConfig
        The computational configuration.
    """
    # Explicit None check: the truth value of a dataloader goes through
    # ``__len__``, so a user-supplied loader could otherwise be discarded
    # (empty loader) or make the test itself fail.
    if dataloader is None:
        dataloader = self._get_dataloader(
            self.maps.training.data.validation.splits[split_idx]
        )
    metrics_handler = self.metrics.subset(metrics)
    for chkpt_name, chkpt in self._get_models_in_split(split_idx, model_checkpoint):
        self._reset_validate(split_idx, dataloader, metrics_handler)
        self._load_model_checkpoint(chkpt)
        self._model_to(computational)
        self._call_event(
            Event.VALIDATE_START,
            dataloader=dataloader,
            model_checkpoint=chkpt_name,
            metrics=metrics_handler,
            callbacks=self.callbacks,
            computational=computational,
        )
        self._evaluation_loop(
            dataloader,
            metrics=metrics_handler,
            computational=computational,
        )
        self._call_event(
            Event.VALIDATE_END,
            metrics=metrics_handler,
        )
[docs]
def test(
self,
model_checkpoint: str,
group_name: str,
metrics: Optional[Sequence[str]] = None,
dataloader: Optional[DataLoader] = None,
computational: ComputationalConfig = ComputationalConfig(),
) -> None:
"""
To test you model.
This method checks for :term:`data leakage` between your test and your training/validation data.
Your test data must be associated with a ``group_name``. If the ``group_name`` already exists
in the :term:`MAPS`, ``Trainer`` attempts to load the test data from it. However, it it fails, you can
manually supply the test dataloader.
Parameters
----------
model_checkpoint : str
The name of the model checkpoint to test. The name of the checkpoint must follow one of these formats:
- ``"split-<split-idx>_best-<metric-name>"``: the best model trained on split ``<split-idx>``
w.r.t. the metric ``<metric-name>``;
- ``"split-<split-idx>_epoch-<epoch-idx>"``: the checkpoint of the model trained on split ``<split-idx>``
at epoch ``<epoch-idx>``;
- ``"split-<split-idx>_final"``: the model at the end of training on split ``<split-idx>``.
group_name : str
The name given to these test data. If the group name was already used, ``Trainer``
will attempt to restore the data from the :term:`MAPS`; otherwise, you must pass
a ``dataloader`` containing your test data.
metrics : Optional[Sequence[str]], default=None
The names of the metric to compute on the test data. The metrics mentioned
here must be in :py:attr:`metrics`, so they must have been defined beforehand when
instantiating the ``Trainer`` or via :py:meth:`add_metrics`.\n
By default, all the defined metrics will be computed.
dataloader : Optional[DataLoader], default=None
The dataloader containing the test data on which the model will be evaluated.
If you pass a dataloader, it will be your test data; otherwise, ``Trainer``
will attempt to load the test data from the :term:`MAPS`.
computational : ComputationalConfig, default=ComputationalConfig()
Computational configuration.
This is where you can setup high-performance computing features.
Examples
--------
.. code-block::
from clinicadl.train import Trainer
from clinicadl.metrics.config import (
LossMetricConfig,
DiceMetricConfig,
)
from clinicadl.callbacks import ModelCheckpointCallback
trainer = Trainer(
...
metrics={"loss": LossMetricConfig(), "dice": DiceMetricConfig()},
callbacks=[
ModelCheckpointCallback(metric="loss")
], # save the best model w.r.t. "loss"
)
trainer.train(...) # training on split 1
.. code-block::
trainer.test(
model_checkpoint="split-1_best-loss", group_name="adni"
) # compute all the metrics
.. code-block::
trainer.test(
model_checkpoint="split-1_best-loss", group_name="adni", metrics=["dice"]
) # compute only "dice"
.. code-block::
>>> trainer.test(model_checkpoint="split-1_best-loss", group_name="adni")
CannotReadJsonError: ClinicaDL could not read the dataloader in ...
>>> trainer.test(
model_checkpoint="split-1_best-loss", group_name="adni", dataloader=dataloader
) # pass the dataloader if the Trainer cannot read the dataloader associated to "adni"
"""
with self._seed_context(
computational.seed, computational.deterministic
), self._exception_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
self._test(
model_checkpoint=model_checkpoint,
metrics=metrics,
group_name=group_name,
dataloader=dataloader,
computational=computational,
)
def _test(
    self,
    model_checkpoint: str,
    group_name: str,
    metrics: Optional[Sequence[str]],
    dataloader: Optional[DataLoader] = None,
    computational: ComputationalConfig = ComputationalConfig(),
) -> None:
    """
    Evaluates a model checkpoint on a group of test data.

    Creates the results directory, resets the test objects, loads the model
    weights, sends the model to the specified device and runs the evaluation
    loop, between ``Event.TEST_START`` and ``Event.TEST_END``.

    Parameters
    ----------
    model_checkpoint : str
        The name of the checkpoint to test.
    group_name : str
        The name of the group of test data.
    metrics : Optional[Sequence[str]]
        The names of the metrics to compute. If empty or ``None``, all the
        metrics in :py:attr:`metrics` are computed.
    dataloader : Optional[DataLoader], default=None
        The test dataloader. If ``None``, the group must already exist in the
        MAPS and the dataloader is rebuilt from it.
    computational : ComputationalConfig
        The computational configuration.

    Raises
    ------
    KeyError
        If ``dataloader`` is ``None`` and ``group_name`` does not exist yet.
    FileExistsError
        If results already exist for this checkpoint and this group.
    """
    self.maps.read()
    # Explicit None check: the truth value of a dataloader goes through
    # ``__len__``, so a user-supplied loader could otherwise be discarded
    # (empty loader) or make the test itself fail.
    if dataloader is None:
        self._check_group_exists(group_name)
        dataloader = self._get_dataloader(self.maps.test.groups[group_name])
    self._create_results_dir(
        self.maps.test, group_name=group_name, model_checkpoint=model_checkpoint
    )
    if metrics:
        metrics_handler = self.metrics.subset(metrics)
    else:
        metrics_handler = self.metrics
    self._reset_test(dataloader, metrics_handler)
    self._load_model_checkpoint(
        self.maps.training.get_checkpoint_dir(model_checkpoint).model_pt
    )
    self._model_to(computational)
    self._call_event(
        Event.TEST_START,
        dataloader=dataloader,
        model_checkpoint=model_checkpoint,
        metrics=metrics_handler,
        group_name=group_name,
        callbacks=self.callbacks,
        computational=computational,
    )
    self._evaluation_loop(
        dataloader, metrics=metrics_handler, computational=computational
    )
    self._call_event(
        Event.TEST_END,
        metrics=metrics_handler,
    )
# def predict(
# self,
# model_checkpoint: str,
# group_name: str,
# dataloader: Optional[DataLoader] = None,
# computational: ComputationalConfig = ComputationalConfig(),
# ) -> None:
# """
# .. admonition:: Not Implemented
# :class: warning
# ``Trainer.predict`` will be implemented in a future release.
# """
# raise NotImplementedError(
# "Trainer.predict will be implemented in a future release"
# )
def _validation(
    self,
    dataloader: DataLoader,
    metrics: MetricsHandler,
    computational: ComputationalConfig,
) -> None:
    """
    Runs a validation phase inside a training.

    Resets the validation objects, then runs the evaluation loop for the
    current epoch between ``Event.VAL_START`` and ``Event.VAL_END``.
    """
    self._reset_validation(dataloader, metrics)

    self._call_event(Event.VAL_START, dataloader=dataloader, metrics=metrics)
    self._evaluation_loop(
        dataloader=dataloader,
        metrics=metrics,
        computational=computational,
        epoch=self.state.current_epoch,
    )
    self._call_event(Event.VAL_END, metrics=metrics)
def _evaluation_loop(
    self,
    dataloader: DataLoader,
    metrics: MetricsHandler,
    computational: ComputationalConfig,
    epoch: Optional[int] = None,
) -> None:
    """
    Iterates on the evaluation batches, without gradient tracking.

    For each batch: records the batch index in the trainer state, sends the
    batch to the device, runs ``Model.evaluation_step`` under ``autocast``
    and computes the metrics on the output. The metrics are aggregated once
    all the batches have been processed.

    Parameters
    ----------
    dataloader : DataLoader
        The data to evaluate the model on.
    metrics : MetricsHandler
        The metrics to compute.
    computational : ComputationalConfig
        Provides the device, the ``amp`` flag and the batch transfer options.
    epoch : Optional[int], default=None
        The epoch forwarded to the metrics handler.
    """
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader, start=1):
            # The batch counter to update depends on the kind of call in progress.
            counter_name = (
                "current_test_batch"
                if self.state.called == TrainerCall.TEST
                else "current_val_batch"
            )
            setattr(self.state, counter_name, batch_idx)

            self._call_event(Event.BATCH_START, batch=batch)
            self._batch_to(batch, computational=computational)

            self._call_event(Event.EVAL_START, batch=batch)
            amp_context = autocast(
                device_type=computational.device.type,
                enabled=computational.amp,
            )
            with amp_context:
                output_batch = self.model.evaluation_step(batch)

            self._call_event(Event.METRIC_START, output=output_batch, metrics=metrics)
            detailed_df = metrics(output_batch, epoch=epoch)
            self._call_event(Event.METRIC_END, detailed_metrics_df=detailed_df)

            self._call_event(Event.BATCH_END)

    metrics.aggregate(epoch=epoch)
def _reset_train(
self, split: Split, metrics: MetricsHandler, optimizers: dict[str, Optimizer]
) -> None:
"""
Resets the relevant objects before training.
"""
self.state.reset_training(
split_idx=split.index, num_epochs=self._optim_config.num_epochs
)
self._reset_model()
split.train_dataset.train()
split.val_dataset.eval()
metrics.reset(reset_df=True)
for optimizer in optimizers.values():
optimizer.zero_grad()
def _reset_model(self):
    """
    Puts the model back in its starting state before a training.

    The model is first moved to CPU in contiguous memory format. Then,
    depending on ``optimization.reset_model``, it is either reset via
    ``Model.reset`` or reloaded from the state dict saved at instantiation.
    """
    # otherwise seeding will not be consistent across memory formats and devices!
    self.model.to(device=CPU, memory_format=torch.contiguous_format)

    if not self.optimization.reset_model:
        self.model.load_state_dict(self._initial_state_dict)
        return
    self.model.reset()
def _reset_epoch(self, epoch: int, train_loader: DataLoader) -> None:
"""
Resets the relevant objects before a new epoch.
"""
self.state.reset_epoch(current_epoch=epoch, train_loader=train_loader)
self.model.train()
train_loader.set_epoch(epoch)
def _reset_validation(
self, val_loader: DataLoader, metrics: MetricsHandler
) -> None:
"""
Resets the relevant objects before a validation phase in :py:meth:`train`.
"""
self.state.reset_validation(
split_idx=self.state.split_idx, val_loader=val_loader, in_training=True
)
self.model.eval()
metrics.reset(reset_df=False)
def _reset_validate(
self, split_idx: int, dataloader: DataLoader, metrics: MetricsHandler
) -> None:
"""
Resets the relevant objects before a validation phase in :py:meth:`validate`.
"""
self.state.reset_validation(
split_idx=split_idx, val_loader=dataloader, in_training=False
)
self.model.eval()
dataloader.dataset.eval()
metrics.reset(reset_df=True)
def _reset_test(self, dataloader: DataLoader, metrics: MetricsHandler) -> None:
"""
Resets the relevant objects before a test phase.
"""
self.state.reset_test(dataloader)
self.model.eval()
dataloader.dataset.eval()
metrics.reset(reset_df=True)
@staticmethod
def _seed_context(seed: Optional[int], deterministic: bool) -> ContextManager[None]:
"""
Returns a context manager for reproducibility if a seed is passed and a
null context otherwise.
"""
return (
seed_everything_context(seed=seed, deterministic=deterministic)
if seed is not None
else nullcontext()
)
@contextmanager
def _exception_context(self) -> Generator[None, None, None]:
"""
Context manager to handle exceptions.
"""
try:
yield
except Exception as e:
self._call_event(Event.EXCEPTION, exception=e)
raise
def _model_to(self, comp_config: ComputationalConfig) -> None:
"""
Sends the model to the right device and converts to the specified memory format.
"""
comp_config.check_device()
self.model.to(device=comp_config.device, non_blocking=comp_config.non_blocking)
if comp_config.channels_last:
try:
self.model.to(memory_format=torch.channels_last)
except RuntimeError:
self.model.to(memory_format=torch.channels_last_3d)
@classmethod
def _batch_to(cls, batch: BatchType, computational: ComputationalConfig) -> None:
    """
    Moves a batch to the configured device and memory format.

    A ``Batch`` is moved directly; a dict or any other iterable is walked
    recursively so that every nested ``Batch`` is moved.
    """
    if isinstance(batch, Batch):
        batch.to(
            device=computational.device,
            non_blocking=computational.non_blocking,
            channels_last=computational.channels_last,
        )
        return

    # batch has been checked by ChecksCallback, so if it is not a dict it must be a sequence
    children = batch.values() if isinstance(batch, dict) else batch
    for child in children:
        cls._batch_to(child, computational=computational)
def _load_model_checkpoint(self, model_path: Path) -> None:
    """
    Loads the weights stored in ``model_path`` into the model.

    The model is moved to CPU before the state dict, opened via the MAPS,
    is loaded.
    """
    model = self.model
    model.to(CPU)  # load weights on cpu
    model.load_state_dict(self.maps.open_file(model_path))
def _clip_gradients(
self,
) -> None:
"""
To clip gradients.
"""
if self.optimization.clip_grad_value is not None:
torch.nn.utils.clip_grad_value_(
self.model.parameters(), clip_value=self.optimization.clip_grad_value
)
if self.optimization.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
max_norm=self.optimization.clip_grad_norm,
norm_type=self.optimization.grad_norm_type,
)
def _call_event(self, event: Event, **kwargs):
"""
Calls a callback event.
"""
self.callbacks.call_event(
event, model=self.model, maps=self.maps, state=self.state, **kwargs
)
def _get_split(self, split_idx: int) -> Split:
    """
    Rebuilds the split of a previous training from the MAPS.

    The datasets and dataloader configurations are read from the ``json``
    files of the split, the datasets are restricted to the data listed in
    the ``tsv`` files of the split, and the dataloaders are rebuilt.

    Raises
    ------
    CannotReadJsonError
        If a dataset or a dataloader configuration cannot be read.
    """
    try:
        train_dir = self.maps.training.data.train.splits[split_idx]
        train_dataset, train_loader_config = _get_data(train_dir)
        val_dir = self.maps.training.data.validation.splits[split_idx]
        val_dataset, val_loader_config = _get_data(val_dir)
    except CannotReadJsonError as err:
        add_note(err, "Please pass directly the split via 'split'.")
        raise

    # Restrict both datasets to the data actually used by the split.
    train_data = self.maps.open_file(
        self.maps.training.data.train.splits[split_idx].data_tsv
    )
    val_data = self.maps.open_file(
        self.maps.training.data.validation.splits[split_idx].data_tsv
    )

    split = Split(
        index=split_idx,
        train_dataset=train_dataset.subset(train_data),
        val_dataset=val_dataset.subset(val_data),
    )
    split.build_train_loader(**train_loader_config.to_raw_dict())
    split.build_val_loader(**val_loader_config.to_raw_dict())
    return split
def _get_models_in_split(
self, split_idx: int, model_checkpoint: Optional[str]
) -> Iterator[tuple[str, Path]]:
"""
Gets either the paths to all the models in a split, or the
path to a specific checkpoint. Returns a generator in both cases.
"""
if model_checkpoint:
yield (
model_checkpoint,
self.maps.training.splits[split_idx]
.models.get_checkpoint_dir(model_checkpoint)
.model_pt,
)
else:
for model_checkpoint in self.maps.training.splits[
split_idx
].models.get_all_models():
yield from self._get_models_in_split(split_idx, model_checkpoint)
def _get_dataloader(self, data_dir: DataDir) -> DataLoader:
    """
    Rebuilds a dataloader from the dataset and dataloader configuration
    stored in ``data_dir``.

    Raises
    ------
    CannotReadJsonError
        If the dataset or the dataloader configuration cannot be read.
    """
    try:
        loaded = _get_data(data_dir)
    except CannotReadJsonError as err:
        add_note(err, "Please pass directly the dataloader via 'dataloader'.")
        raise
    dataset, dataloader_config = loaded
    return dataloader_config.get_object(dataset)
def _get_old_computational_config(self, split_idx) -> ComputationalConfig:
    """
    Reads the computational configuration saved for the training on
    split ``split_idx``.
    """
    config_json = self.maps.training.splits[split_idx].computational_json
    return ComputationalConfig.from_json(config_json)
def _create_split(self, split_idx: int) -> None:
"""
Checks that the split directory doesn't exist and creates it.
"""
if split_idx in self.maps.training.splits_list:
raise ValueError(
f"Training on split {split_idx} has already been performed. To relaunch a training on this split, "
"first delete it properly with clinicadl.io.Maps.delete_split; to resume a training on this split, "
"set resume=True."
)
self.maps.training.create_split(split_idx, exist_ok=True)
def _create_results_dir(
self,
inference_dir: InferenceDirType,
group_name: str,
model_checkpoint: str,
) -> None:
"""
Creates the directory to save the results of the checkpoint on the group.
"""
split_idx, chkpt = self.maps.training.read_checkpoint_name(model_checkpoint)
inference_dir.create_group(group_name, exist_ok=True)
group_dir = inference_dir.groups[group_name].results
group_dir.create_split(split_idx, exist_ok=True)
if chkpt in group_dir.splits[split_idx].models_list:
raise FileExistsError(
f"There are already some results for checkpoint '{chkpt}' in {group_dir.splits[split_idx].models[chkpt].path}. "
f"If you want to continue, please first delete the folder."
)
group_dir.splits[split_idx].create_model(chkpt, exist_ok=True)
def _check_split_exists(self, split_idx: int) -> None:
"""
Checks that a split directory exists.
"""
if split_idx not in self.maps.training.splits_list:
raise KeyError(f"No training performed on split {split_idx}.")
def _check_group_exists(self, group: str) -> None:
"""
Checks that a group already exists.
"""
if group not in self.maps.test.groups_list:
raise KeyError(
f"The group you passed ('{group}') does not exist yet, so you must pass a dataloader."
)
def _get_data(data_dir: DataDir) -> tuple[Dataset, DataLoaderConfig]:
    """
    Reads the dataset and the dataloader configuration stored in ``data_dir``.

    Returns
    -------
    tuple[Dataset, DataLoaderConfig]
        The dataset built from ``data_dir.dataset_json`` and the configuration
        read from ``data_dir.dataloader_json``.

    Raises
    ------
    CannotReadJsonError
        If any exception occurs while reading one of the two files; the
        original exception is chained as the cause.
    """
    try:
        dataset = get_dataset_from_json(data_dir.dataset_json)
    except Exception as err:
        raise CannotReadJsonError(
            f"ClinicaDL could not read the dataset in {data_dir.dataset_json}."
        ) from err

    try:
        dataloader_config = DataLoaderConfig.from_json(data_dir.dataloader_json)
    except Exception as err:
        raise CannotReadJsonError(
            f"ClinicaDL could not read the dataloader in {data_dir.dataloader_json}."
        ) from err

    return dataset, dataloader_config