# Source code for clinicadl.callbacks.base

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, Mapping

import pandas as pd
import torch

if TYPE_CHECKING:
    from clinicadl.data.dataloader import Batch, BatchType, DataLoader
    from clinicadl.io.maps import Maps
    from clinicadl.losses.types import LossType
    from clinicadl.metrics import MetricsHandler
    from clinicadl.models import Model
    from clinicadl.optim.config import OptimizationConfig
    from clinicadl.split import Split
    from clinicadl.train import TrainerState
    from clinicadl.train.computational import ComputationalConfig

    from .handler import CallbacksHandler


class Event(str, Enum):
    """
    Events that can trigger an action from a :py:class:`clinicadl.callbacks.Callback`.

    Each member is a ``str`` whose value is the name of the ``Callback`` hook
    (``"on_..."``) associated with the event.
    """

    # --- events not tied to a specific workflow ---
    EXCEPTION = "on_exception"
    INIT = "on_trainer_init"

    # --- training workflow ---
    TRAIN_START = "on_train_start"
    RESUME = "on_resume"
    TRAIN_END = "on_train_end"
    EPOCH_START = "on_epoch_start"
    EPOCH_END = "on_epoch_end"
    BATCH_START = "on_batch_start"
    BATCH_END = "on_batch_end"
    FORWARD_START = "on_forward_step_start"
    BACKWARD_START = "on_backward_step_start"
    BACKWARD_END = "on_backward_step_end"
    OPTIM_STEP_START = "on_optimization_step_start"
    OPTIM_STEP_END = "on_optimization_step_end"

    # --- validation (loop inside training, and standalone ``validate``) ---
    VAL_START = "on_validation_start"
    VAL_END = "on_validation_end"
    VALIDATE_START = "on_validate_start"
    VALIDATE_END = "on_validate_end"
    EVAL_START = "on_evaluation_step_start"
    METRIC_START = "on_metrics_computation_start"
    METRIC_END = "on_metrics_computation_end"

    # --- test workflow ---
    TEST_START = "on_test_start"
    TEST_END = "on_test_end"

    # --- prediction workflow ---
    PREDICT_START = "on_predict_start"
    PREDICTION_START = "on_prediction_step_start"
    PREDICTION_END = "on_prediction_step_end"
    PREDICT_END = "on_predict_end"


class Callback:
    """
    To define arbitrary actions to perform at certain points of the training
    and evaluation workflows.

    Each method of this class starting with ``"on_..."`` is associated to an event
    of the training or evaluation phase performed by the
    :py:class:`~clinicadl.train.Trainer`. By overriding these methods, the user can
    define actions to perform when the event happens.

    .. important::
        Callbacks should capture **NON-ESSENTIAL** logic such as saving checkpoints
        or logging. The essential logic should be defined in a
        :py:class:`clinicadl.models.Model`.

    To define your own callback, you can override any of the methods associated to
    an event. You can also override :py:meth:`state_dict` and
    :py:meth:`load_state_dict`, to be able to recover the state of your callback
    when resuming an interrupted training.

    .. dropdown:: List of events
        :icon: list-unordered
        :color: info

        - :py:meth:`on_trainer_init`: called when the trainer is instantiated.
        - :py:meth:`on_exception`: called whenever an uncaught exception is raised.

        .. tab-set::

            .. tab-item:: :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`

                - :py:meth:`on_backward_step_end`
                - :py:meth:`on_backward_step_start`
                - :py:meth:`on_batch_end`
                - :py:meth:`on_batch_start`
                - :py:meth:`on_epoch_end`
                - :py:meth:`on_epoch_start`
                - :py:meth:`on_evaluation_step_start`
                - :py:meth:`on_forward_step_start`
                - :py:meth:`on_metrics_computation_end`
                - :py:meth:`on_metrics_computation_start`
                - :py:meth:`on_optimization_step_end`
                - :py:meth:`on_optimization_step_start`
                - :py:meth:`on_resume`
                - :py:meth:`on_train_end`
                - :py:meth:`on_train_start`
                - :py:meth:`on_validation_end`
                - :py:meth:`on_validation_start`

                .. code-block::
                    :caption: pseudocode
                    :emphasize-lines: 2, 4, 8, 11, 15, 17, 19, 22, 24, 27, 31, 34, 38, 40, 42, 44, 47, 50, 53

                    IF resume THEN
                        CALL on_resume
                    ELSE
                        CALL on_train_start
                    END IF

                    FOR EACH epoch IN training_epochs DO
                        CALL on_epoch_start

                        FOR EACH batch IN training_dataloader DO
                            CALL on_batch_start

                            MOVE batch TO appropriate_device

                            CALL on_forward_step_start
                            DO forward_step
                            CALL on_backward_step_start
                            DO backward_step
                            CALL on_backward_step_end

                            IF batch_idx MOD optimization_interval == 0 THEN
                                CALL on_optimization_step_start
                                DO optimization_step
                                CALL on_optimization_step_end
                            END IF

                            CALL on_batch_end
                        END FOR

                        IF epoch_idx MOD evaluation_interval == 0 THEN
                            CALL on_validation_start

                            FOR EACH batch IN validation_dataloader DO
                                CALL on_batch_start

                                MOVE batch TO appropriate_device

                                CALL on_evaluation_step_start
                                DO evaluation_step
                                CALL on_metrics_computation_start
                                DO metrics_computation
                                CALL on_metrics_computation_end

                                CALL on_batch_end
                            END FOR

                            CALL on_validation_end
                        END IF

                        CALL on_epoch_end
                    END FOR

                    CALL on_train_end

            .. tab-item:: :py:meth:`Trainer.validate <clinicadl.train.Trainer.validate>`

                - :py:meth:`on_batch_end`
                - :py:meth:`on_batch_start`
                - :py:meth:`on_evaluation_step_start`
                - :py:meth:`on_metrics_computation_end`
                - :py:meth:`on_metrics_computation_start`
                - :py:meth:`on_validate_end`
                - :py:meth:`on_validate_start`

                .. code-block::
                    :caption: pseudocode
                    :emphasize-lines: 1, 4, 8, 10, 12, 14, 17

                    CALL on_validate_start

                    FOR EACH batch IN validation_dataloader DO
                        CALL on_batch_start

                        MOVE batch TO appropriate_device

                        CALL on_evaluation_step_start
                        DO evaluation_step
                        CALL on_metrics_computation_start
                        DO metrics_computation
                        CALL on_metrics_computation_end

                        CALL on_batch_end
                    END FOR

                    CALL on_validate_end

            .. tab-item:: :py:meth:`Trainer.test <clinicadl.train.Trainer.test>`

                - :py:meth:`on_batch_end`
                - :py:meth:`on_batch_start`
                - :py:meth:`on_evaluation_step_start`
                - :py:meth:`on_metrics_computation_end`
                - :py:meth:`on_metrics_computation_start`
                - :py:meth:`on_test_end`
                - :py:meth:`on_test_start`

                .. code-block::
                    :caption: pseudocode
                    :emphasize-lines: 1, 4, 8, 10, 12, 14, 17

                    CALL on_test_start

                    FOR EACH batch IN test_dataloader DO
                        CALL on_batch_start

                        MOVE batch TO appropriate_device

                        CALL on_evaluation_step_start
                        DO evaluation_step
                        CALL on_metrics_computation_start
                        DO metrics_computation
                        CALL on_metrics_computation_end

                        CALL on_batch_end
                    END FOR

                    CALL on_test_end
    """

    def state_dict(self) -> Mapping[str, Any]:
        """
        To get a checkpoint of the current state of the callback.

        The base ``Callback`` holds no state, so an empty ``dict`` is returned.
        Override this method (together with :py:meth:`load_state_dict`) in
        stateful callbacks.

        Returns
        -------
        Mapping[str, Any]
            The current state in a ``dict``.
        """
        # Return a fresh dict on every call so that callers can never mutate
        # a shared object; this also honours the annotated return type.
        return {}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        """
        Sets the callback to a given state.

        The base implementation does nothing.

        Parameters
        ----------
        state_dict : Mapping[str, Any]
            The desired state of the ``Callback``, as returned by :py:meth:`state_dict`.
        """

    def on_backward_step_end(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
    ) -> None:
        """
        Called every time
        :py:meth:`Model.backward_step <clinicadl.models.Model.backward_step>`
        has just been called in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        """

    def on_backward_step_start(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        loss: LossType,
        grad_scaler: torch.amp.GradScaler,
    ) -> None:
        """
        Called every time
        :py:meth:`Model.backward_step <clinicadl.models.Model.backward_step>`
        will be called in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.

        .. note::
            This event is equivalent to ``on_forward_step_end``.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        loss : LossType
            The loss output by
            :py:meth:`Model.forward_step <clinicadl.models.Model.forward_step>`
            and input by
            :py:meth:`Model.backward_step <clinicadl.models.Model.backward_step>`.
        grad_scaler : torch.amp.GradScaler
            The :torch:`torch.amp.GradScaler <amp.html#gradient-scaling>` used to
            scale gradients.
        """

    def on_batch_end(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
    ) -> None:
        """
        Called every time the processing of a batch is completed during training,
        validation, or test phase.

        .. note::
            This event may be redundant with other events: e.g., in evaluation
            phases, it is equivalent to :py:meth:`on_metrics_computation_end`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        """

    def on_batch_start(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        batch: BatchType,
    ) -> None:
        """
        Called every time a new batch has been loaded in training, validation or
        test phase.

        .. note::
            This event may be redundant with other events: e.g., in evaluation
            phases, it is equivalent to :py:meth:`on_evaluation_step_start`
            (except if the batch is sent to another device).

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        batch : BatchType
            The input batch.
        """

    def on_evaluation_step_start(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        batch: BatchType,
    ) -> None:
        """
        Called every time
        :py:meth:`Model.evaluation_step <clinicadl.models.Model.evaluation_step>`
        will be called in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`,
        :py:meth:`Trainer.validate <clinicadl.train.Trainer.validate>`, or
        :py:meth:`Trainer.test <clinicadl.train.Trainer.test>`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        batch : BatchType
            The batch input to
            :py:meth:`Model.evaluation_step <clinicadl.models.Model.evaluation_step>`.
        """

    def on_epoch_end(self, *, model: Model, maps: Maps, state: TrainerState) -> None:
        """
        Called at the end of an epoch in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        """

    def on_epoch_start(self, *, model: Model, maps: Maps, state: TrainerState) -> None:
        """
        Called at the beginning of an epoch in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        """

    def on_exception(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        exception: Exception,
    ) -> None:
        """
        Called when an exception interrupts an execution of the
        :py:class:`~clinicadl.train.Trainer`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        exception : Exception
            The exception that has been raised.
        """

    def on_forward_step_start(
        self, *, model: Model, maps: Maps, state: TrainerState, batch: BatchType
    ) -> None:
        """
        Called every time
        :py:meth:`Model.forward_step <clinicadl.models.Model.forward_step>`
        will be called in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        batch : BatchType
            The batch input to
            :py:meth:`Model.forward_step <clinicadl.models.Model.forward_step>`.
        """

    def on_metrics_computation_end(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        detailed_metrics_df: pd.DataFrame,
    ) -> None:
        """
        Called every time metrics have just been computed on a batch.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        detailed_metrics_df : pandas.DataFrame
            The evaluation metrics on the batch.
        """

    def on_metrics_computation_start(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        output: Batch,
        metrics: MetricsHandler,
    ) -> None:
        """
        Called every time
        :py:meth:`Model.evaluation_step <clinicadl.models.Model.evaluation_step>`
        has been called and metrics will now be computed.

        .. note::
            This event is equivalent to ``on_evaluation_step_end``.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        output : Batch
            The batch output by
            :py:meth:`Model.evaluation_step <clinicadl.models.Model.evaluation_step>`.
        metrics : MetricsHandler
            The metrics to compute.
        """

    def on_optimization_step_end(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        optimizers: dict[str, torch.optim.Optimizer],
        grad_scaler: torch.amp.GradScaler,
    ) -> None:
        """
        Called every time
        :py:meth:`Model.optimization_step <clinicadl.models.Model.optimization_step>`
        has just been called in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        optimizers : dict[str, torch.optim.Optimizer]
            The current :py:class:`torch.optim.Optimizer`, as returned by
            :py:meth:`Model.build_optimizers <clinicadl.models.Model.build_optimizers>`.
        grad_scaler : torch.amp.GradScaler
            The :torch:`torch.amp.GradScaler <amp.html#gradient-scaling>` used to
            scale gradients.
        """

    def on_optimization_step_start(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        optimizers: dict[str, torch.optim.Optimizer],
        grad_scaler: torch.amp.GradScaler,
    ) -> None:
        """
        Called every time
        :py:meth:`Model.optimization_step <clinicadl.models.Model.optimization_step>`
        will be called in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        optimizers : dict[str, torch.optim.Optimizer]
            The current :py:class:`torch.optim.Optimizer`, as returned by
            :py:meth:`Model.build_optimizers <clinicadl.models.Model.build_optimizers>`.
        grad_scaler : torch.amp.GradScaler
            The :torch:`torch.amp.GradScaler <amp.html#gradient-scaling>` used to
            scale gradients.
        """

    def on_predict_end(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
    ) -> None:
        """
        Called once at the end of
        :py:meth:`Trainer.predict <clinicadl.train.Trainer.predict>`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        """

    def on_predict_start(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        dataloader: DataLoader,
        model_checkpoint: str,
        group_name: str,
        callbacks: CallbacksHandler,
        computational: ComputationalConfig,
    ) -> None:
        """
        Called once at the beginning of
        :py:meth:`Trainer.predict <clinicadl.train.Trainer.predict>`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        dataloader : DataLoader
            The dataloader on which the prediction is performed.
        model_checkpoint : str
            The model checkpoint currently being used.
        group_name : str
            The name given to the data.
        callbacks : CallbacksHandler
            The callbacks passed to the ``Trainer``.
        computational : ComputationalConfig
            The :py:class:`clinicadl.train.ComputationalConfig` defining the
            computational specifications of the prediction phase.
        """

    def on_prediction_step_end(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        output: Batch,
    ) -> None:
        """
        Called every time
        :py:meth:`Model.evaluation_step <clinicadl.models.Model.evaluation_step>`
        has just been called.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        output : Batch
            The :py:class:`clinicadl.data.dataloader.Batch` output by
            :py:meth:`Model.evaluation_step <clinicadl.models.Model.evaluation_step>`.
        """

    def on_prediction_step_start(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        batch: BatchType,
    ) -> None:
        """
        Called every time
        :py:meth:`Model.evaluation_step <clinicadl.models.Model.evaluation_step>`
        will be called.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        batch : BatchType
            The batch input to
            :py:meth:`Model.evaluation_step <clinicadl.models.Model.evaluation_step>`.
        """

    def on_resume(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        split: Split,
        optimizers: dict[str, torch.optim.Optimizer],
        grad_scaler: torch.amp.GradScaler,
        optimization: OptimizationConfig,
        metrics: MetricsHandler,
        callbacks: CallbacksHandler,
        computational: ComputationalConfig,
    ) -> None:
        """
        Called once when :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`
        is resuming a training.

        More precisely, this method will be called just before loading the
        checkpoints.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        split : Split
            The :py:class:`clinicadl.split.Split` on which training is performed.
        optimizers : dict[str, torch.optim.Optimizer]
            The current :py:class:`torch.optim.Optimizer`, as returned by
            :py:meth:`Model.build_optimizers <clinicadl.models.Model.build_optimizers>`.
        grad_scaler : torch.amp.GradScaler
            The :torch:`torch.amp.GradScaler <amp.html#gradient-scaling>` used to
            scale gradients.
        optimization : OptimizationConfig
            The optimization specifications of the training phase.
        metrics : MetricsHandler
            The validation metrics to compute.
        callbacks : CallbacksHandler
            The callbacks passed to the ``Trainer``.
        computational : ComputationalConfig
            The :py:class:`clinicadl.train.ComputationalConfig` defining the
            computational specifications of the training phase.
        """

    def on_test_end(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        metrics: MetricsHandler,
    ) -> None:
        """
        Called once at the end of
        :py:meth:`Trainer.test <clinicadl.train.Trainer.test>`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        metrics : MetricsHandler
            The test metrics computed.
        """

    def on_test_start(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        dataloader: DataLoader,
        model_checkpoint: str,
        metrics: MetricsHandler,
        group_name: str,
        callbacks: CallbacksHandler,
        computational: ComputationalConfig,
    ) -> None:
        """
        Called once at the beginning of
        :py:meth:`Trainer.test <clinicadl.train.Trainer.test>`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        dataloader : DataLoader
            The dataloader on which the test is performed.
        model_checkpoint : str
            The model checkpoint currently being tested.
        metrics : MetricsHandler
            The test metrics to compute.
        group_name : str
            The name given to the test data.
        callbacks : CallbacksHandler
            The callbacks passed to the ``Trainer``.
        computational : ComputationalConfig
            The :py:class:`clinicadl.train.ComputationalConfig` defining the
            computational specifications of the test phase.
        """

    def on_train_end(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
    ) -> None:
        """
        Called once at the end of
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        """

    def on_train_start(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        split: Split,
        optimizers: dict[str, torch.optim.Optimizer],
        grad_scaler: torch.amp.GradScaler,
        optimization: OptimizationConfig,
        metrics: MetricsHandler,
        callbacks: CallbacksHandler,
        computational: ComputationalConfig,
    ) -> None:
        """
        Called once at the beginning of
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>` if ``resume=False``.

        If resuming a training, :py:meth:`on_resume` will be called instead.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        split : Split
            The :py:class:`clinicadl.split.Split` on which training is performed.
        optimizers : dict[str, torch.optim.Optimizer]
            The current :py:class:`torch.optim.Optimizer`, as returned by
            :py:meth:`Model.build_optimizers <clinicadl.models.Model.build_optimizers>`.
        grad_scaler : torch.amp.GradScaler
            The :torch:`torch.amp.GradScaler <amp.html#gradient-scaling>` used to
            scale gradients.
        optimization : OptimizationConfig
            The optimization specifications of the training phase.
        metrics : MetricsHandler
            The validation metrics to compute.
        callbacks : CallbacksHandler
            The callbacks passed to the ``Trainer``.
        computational : ComputationalConfig
            The :py:class:`clinicadl.train.ComputationalConfig` defining the
            computational specifications of the training phase.
        """

    def on_trainer_init(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        metrics: MetricsHandler,
        optimization: OptimizationConfig,
        callbacks: CallbacksHandler,
    ) -> None:
        """
        Called once when the :py:class:`~clinicadl.train.Trainer` is created or
        restored with
        :py:meth:`Trainer.from_maps <clinicadl.train.Trainer.from_maps>`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        metrics : MetricsHandler
            The metrics passed to the ``Trainer``.
        optimization : OptimizationConfig
            The optimization specifications of the training phase.
        callbacks : CallbacksHandler
            The callbacks passed to the ``Trainer``.
        """

    def on_validate_end(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        metrics: MetricsHandler,
    ) -> None:
        """
        Called at the end of a model validation in
        :py:meth:`Trainer.validate <clinicadl.train.Trainer.validate>`.

        Not to be confused with :py:meth:`on_validation_end`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        metrics : MetricsHandler
            The validation metrics computed.
        """

    def on_validate_start(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        dataloader: DataLoader,
        model_checkpoint: str,
        metrics: MetricsHandler,
        callbacks: CallbacksHandler,
        computational: ComputationalConfig,
    ) -> None:
        """
        Called when a model is validated in
        :py:meth:`Trainer.validate <clinicadl.train.Trainer.validate>`.

        Not to be confused with :py:meth:`on_validation_start`.

        .. important::
            If ``model_checkpoint=None`` was passed to
            :py:meth:`Trainer.validate <clinicadl.train.Trainer.validate>`, all the
            models saved during training will be validated. Therefore,
            ``on_validate_start`` will be called for each model.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        dataloader : DataLoader
            The dataloader on which validation is performed.
        model_checkpoint : str
            The model checkpoint currently being validated.
        metrics : MetricsHandler
            The validation metrics to compute.
        callbacks : CallbacksHandler
            The callbacks passed to the ``Trainer``.
        computational : ComputationalConfig
            The :py:class:`clinicadl.train.ComputationalConfig` defining the
            computational specifications of the validation phase.
        """

    def on_validation_end(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        metrics: MetricsHandler,
    ) -> None:
        """
        Called at the end of every validation loop in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.

        Not to be confused with :py:meth:`on_validate_end`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        metrics : MetricsHandler
            The validation metrics computed.
        """

    def on_validation_start(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        dataloader: DataLoader,
        metrics: MetricsHandler,
    ) -> None:
        """
        Called at the beginning of every validation loop in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.

        Not to be confused with :py:meth:`on_validate_start`.

        Parameters
        ----------
        model : Model
            The model associated to the ``Trainer``.
        maps : Maps
            The :term:`MAPS` associated to the ``Trainer``.
        state : TrainerState
            The current state of the ``Trainer``.
        dataloader : DataLoader
            The dataloader on which validation is performed.
        metrics : MetricsHandler
            The validation metrics to compute.
        """