# Source code for clinicadl.models.base

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch.amp as amp
import torch.nn as nn
import torch.optim as optim

from clinicadl.utils.objects import JsonReaderWriter

if TYPE_CHECKING:
    from clinicadl.data.dataloader import Batch, BatchType
    from clinicadl.losses.types import Loss, LossType


class Model(JsonReaderWriter, ABC, nn.Module):
    """
    The base model from which every model that works with ``ClinicaDL`` must inherit.

    ``Model`` itself inherits from :py:class:`torch.nn.Module`. So you can classically
    define your neural networks in the ``__init__`` method (don't forget to call
    ``super().__init__()`` first!).

    Besides, the following abstract methods must be overridden:

    - :py:meth:`forward_step`: defines the forward logic during training;
    - :py:meth:`backward_step`: defines the gradient computation logic;
    - :py:meth:`optimization_step`: defines the optimization logic;
    - :py:meth:`evaluation_step`: defines the evaluation logic;
    - :py:meth:`prediction_step`: defines the inference logic;
    - :py:meth:`build_optimizers`: builds the optimizers used for training;
    - :py:meth:`get_loss_functions`: gives access to the loss functions used during
      training.

    You can also override :py:meth:`get_summary` to give a description of your neural
    network(s).

    .. tip::
        Since rewriting all these methods can be tedious, feel free to inherit from an
        existing ``Model`` with shared logic and rewrite only the relevant methods.

    See Also
    --------
    :py:class:`~clinicadl.models.SupervisedModel`
        A ``Model`` for supervised training.
    :py:class:`~clinicadl.models.ReconstructionModel`
        A ``Model`` for image reconstruction.
    """

    @abstractmethod
    def forward_step(self, batch: BatchType) -> LossType:
        """
        Performs the training forward step using the provided batch of data and returns
        the computed loss.

        Several losses can be computed during this step. Gradients will be computed on
        this loss (or these losses).

        .. note::
            No need to bother with computational aspects (sending the batch to GPU,
            :term:`AMP`, etc.), ``ClinicaDL`` takes care of this.

        Parameters
        ----------
        batch : Union[Batch, Sequence[Batch], dict[Any, Batch]]
            The batch of :py:class:`DataPoints <clinicadl.data.structures.DataPoint>`.
            It can either be a :py:class:`~clinicadl.data.dataloader.Batch`, a sequence
            of ``Batch`` or a dictionary of ``Batch``.

        Returns
        -------
        Union[torch.Tensor, dict[str, torch.Tensor]]
            The computed loss(es), as a **1-item** :py:class:`torch.Tensor`, or a
            dictionary of such ``Tensors``.
        """

    @abstractmethod
    def backward_step(
        self,
        loss: LossType,
        # NOTE(review): this default is evaluated once, at class-definition time, and
        # the same instance is shared by every call that omits the argument. Presumably
        # harmless because the scaler is disabled; confirm before relying on it.
        grad_scaler: amp.GradScaler = amp.GradScaler(enabled=False),
    ) -> None:
        """
        Performs gradient computation using the loss(es) returned by
        :py:meth:`forward_step`.

        Parameters
        ----------
        loss : Union[torch.Tensor, dict[str, torch.Tensor]]
            The loss(es) on which gradients will be computed.
        grad_scaler : GradScaler, default=GradScaler(enabled=False)
            A potential :torch:`torch.amp.GradScaler <amp.html#gradient-scaling>` used
            to scale gradients.
        """

    @abstractmethod
    def optimization_step(
        self,
        optimizers: dict[str, optim.Optimizer],
        # NOTE(review): same shared, definition-time default as in ``backward_step``.
        grad_scaler: amp.GradScaler = amp.GradScaler(enabled=False),
    ) -> None:
        """
        Performs the optimization step using the gradients accumulated in
        :py:meth:`backward_step`.

        .. note::
            ``ClinicaDL`` takes care of zeroing gradients after this step.

        Parameters
        ----------
        optimizers : dict[str, torch.optim.Optimizer]
            The optimizers, as defined in :py:meth:`build_optimizers`.
        grad_scaler : GradScaler, default=GradScaler(enabled=False)
            A potential :torch:`torch.amp.GradScaler <amp.html#gradient-scaling>` used
            to scale gradients.
        """

    @abstractmethod
    def evaluation_step(self, batch: BatchType) -> Batch:
        """
        Performs the evaluation step where a validation/test batch is passed through
        the neural network. Metrics will be computed on the outputs of this method.

        .. note::
            No need to bother with computational aspects (sending the batch to GPU,
            disabling gradients, etc.), ``ClinicaDL`` takes care of this.

        Parameters
        ----------
        batch : Union[Batch, Sequence[Batch], dict[Any, Batch]]
            The batch of :py:class:`DataPoints <clinicadl.data.structures.DataPoint>`.
            It can either be a :py:class:`~clinicadl.data.dataloader.Batch`, a sequence
            of ``Batch`` or a dictionary of ``Batch``.

        Returns
        -------
        Batch
            The output :py:class:`~clinicadl.data.dataloader.Batch`.

            .. important::
                Even if the input batch is a sequence or a dict of ``Batch``, the
                output must be a single ``Batch``. Metrics will be computed on each
                element of this output batch.
        """

    @abstractmethod
    def prediction_step(self, batch: BatchType) -> Batch:
        """
        Performs inference on a batch.

        As opposed to :py:meth:`evaluation_step`, no metrics will be computed on the
        outputs. This method is meant to use the model for inference once it has been
        trained and tested.

        .. note::
            No need to send tensors to another device or to wrap your prediction
            logic in the ``torch.no_grad()`` context manager, ``ClinicaDL`` takes care
            of this.

        Parameters
        ----------
        batch : Union[Batch, Sequence[Batch], dict[Any, Batch]]
            The batch of :py:class:`DataPoints <clinicadl.data.structures.DataPoint>`.
            It can either be a :py:class:`~clinicadl.data.dataloader.Batch`, a sequence
            of ``Batch`` (e.g. a ``tuple`` of ``Batch`` if you use
            :py:class:`~clinicadl.data.datasets.PairedDataset`) or a dictionary of
            ``Batch``.

        Returns
        -------
        Batch
            The output :py:class:`~clinicadl.data.dataloader.Batch`.

            .. important::
                Even if the input batch is a sequence or a dict of ``Batch``, the
                output must be a single :py:class:`~clinicadl.data.dataloader.Batch`.
        """

    @abstractmethod
    def build_optimizers(self) -> dict[str, optim.Optimizer]:
        """
        Builds the optimizers that will be used during training.

        All optimizers must be given a name.

        Returns
        -------
        dict[str, torch.optim.Optimizer]
            The optimizers and their names.
        """

    @abstractmethod
    def get_loss_functions(self) -> dict[str, Loss]:
        """
        Retrieves the loss functions used during training.

        All loss functions must be given a name.

        .. important::
            All loss functions must have a
            :torch:`PyTorch style <nn.html#loss-functions>`, i.e. be a callable that
            returns a :py:class:`torch.Tensor` and has an attribute named
            ``reduction`` that can be set to ``"none"``.

        Returns
        -------
        dict[str, Callable[..., Tensor]]
            The loss functions and their names.
        """

    def get_summary(
        self,
        input_data: BatchType,
    ) -> str:
        """
        Returns a summary of your neural network, produced for example by
        `torchinfo <https://github.com/TylerYep/torchinfo>`_.

        If this method is not implemented, the ``nn_summary.txt`` file in your
        :term:`MAPS` will not be created.

        Parameters
        ----------
        input_data : Union[Batch, Sequence[Batch], dict[Any, Batch]]
            Input data to pass to the neural network to build the summary.

        Returns
        -------
        str
            The summary.

        Raises
        ------
        NotImplementedError
            If the subclass does not override this method.
        """
        raise NotImplementedError()

    def reset(self) -> None:
        """
        Randomly re-initializes the neural network's weights and removes all the
        accumulated gradients.

        **Only modules that directly own at least one trainable (i.e. with
        ``requires_grad=True``) parameter and that define a ``reset_parameters``
        method are reset.** Such a module is re-initialized by its own
        ``reset_parameters``. Depending on that module's implementation, this may also
        re-initialize frozen parameters that it owns.
        """
        self.zero_grad()
        for module in self.modules():
            # recurse=False: inspect only the parameters owned by this very module;
            # its submodules are visited on their own by ``self.modules()``.
            has_trainable_params = any(
                p.requires_grad for p in module.parameters(recurse=False)
            )
            if has_trainable_params and hasattr(module, "reset_parameters"):
                module.reset_parameters()