from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
import torch.amp as amp
import torch.nn as nn
import torch.optim as optim

from clinicadl.utils.objects import JsonReaderWriter

if TYPE_CHECKING:
    from clinicadl.data.dataloader import Batch, BatchType
    from clinicadl.losses.types import Loss, LossType
class Model(JsonReaderWriter, ABC, nn.Module):
    """
    The base model from which every model that works with ``ClinicaDL`` must inherit.

    ``Model`` itself inherits from :py:class:`torch.nn.Module`, so you can define your
    neural networks in the ``__init__`` method as usual (don't forget to call
    ``super().__init__()`` first!).

    Besides, the following methods must be overridden:

    - :py:meth:`forward_step`: defines the forward logic during training;
    - :py:meth:`backward_step`: defines the gradients computation logic;
    - :py:meth:`optimization_step`: defines the optimization logic;
    - :py:meth:`evaluation_step`: defines the evaluation logic;
    - :py:meth:`prediction_step`: defines the inference logic;
    - :py:meth:`build_optimizers`: to build the optimizers used for training;
    - :py:meth:`get_loss_functions`: to access the loss functions used during training.

    You can also override :py:meth:`get_summary` to give a description of your neural network(s).

    .. tip::
        Since rewriting all these methods can be tedious, feel free to inherit from an existing
        ``Model`` with shared logic and rewrite only the relevant methods.

    See Also
    --------
    :py:class:`~clinicadl.models.SupervisedModel`
        A ``Model`` for supervised training.
    :py:class:`~clinicadl.models.ReconstructionModel`
        A ``Model`` for image reconstruction.
    """

    @abstractmethod
    def forward_step(self, batch: BatchType) -> LossType:
        """
        Performs the training forward step using the provided batch of data and returns
        the computed loss.

        Several losses can be computed during this step. Gradients will then be computed
        on this loss (or these losses).

        .. note::
            No need to bother with computational aspects (sending the batch to GPU, :term:`AMP`, etc.),
            ``ClinicaDL`` takes care of this.

        Parameters
        ----------
        batch : Union[Batch, Sequence[Batch], dict[Any, Batch]]
            The batch of :py:class:`DataPoints <clinicadl.data.structures.DataPoint>`. It can either be a
            :py:class:`~clinicadl.data.dataloader.Batch`, a sequence of ``Batch``
            or a dictionary of ``Batch``.

        Returns
        -------
        Union[torch.Tensor, dict[str, torch.Tensor]]
            The computed loss(es), as a **1-item** :py:class:`torch.Tensor`, or a dictionary of such ``Tensors``.
        """

    @abstractmethod
    def backward_step(
        self,
        loss: LossType,
        # NOTE: this default is evaluated once, when the class is defined, so the same
        # disabled scaler instance is shared by every call that relies on it.
        grad_scaler: amp.GradScaler = amp.GradScaler(enabled=False),
    ) -> None:
        """
        Performs gradient computation using the loss(es) returned by :py:meth:`forward_step`.

        Parameters
        ----------
        loss : Union[torch.Tensor, dict[str, torch.Tensor]]
            The loss(es) on which gradients will be computed.
        grad_scaler : GradScaler, default=GradScaler(enabled=False)
            A potential :torch:`torch.amp.GradScaler <amp.html#gradient-scaling>` used to scale gradients.
        """

    @abstractmethod
    def optimization_step(
        self,
        optimizers: dict[str, optim.Optimizer],
        # NOTE: evaluated once at class definition; shared disabled scaler (see backward_step).
        grad_scaler: amp.GradScaler = amp.GradScaler(enabled=False),
    ) -> None:
        """
        Performs the optimization step using the gradients accumulated in
        :py:meth:`backward_step`.

        .. note::
            ``ClinicaDL`` takes care of zeroing gradients after this step.

        Parameters
        ----------
        optimizers : dict[str, torch.optim.Optimizer]
            The optimizers, as defined in :py:meth:`build_optimizers`.
        grad_scaler : GradScaler, default=GradScaler(enabled=False)
            A potential :torch:`torch.amp.GradScaler <amp.html#gradient-scaling>` used to scale gradients.
        """

    @abstractmethod
    def evaluation_step(self, batch: BatchType) -> Batch:
        """
        Performs the evaluation step where a validation/test batch is passed through
        the neural network.

        Metrics will be computed on the outputs of this method.

        .. note::
            No need to bother with computational aspects (sending the batch to GPU, disabling gradients, etc.),
            ``ClinicaDL`` takes care of this.

        Parameters
        ----------
        batch : Union[Batch, Sequence[Batch], dict[Any, Batch]]
            The batch of :py:class:`DataPoints <clinicadl.data.structures.DataPoint>`. It can either be a
            :py:class:`~clinicadl.data.dataloader.Batch`, a sequence of ``Batch``
            or a dictionary of ``Batch``.

        Returns
        -------
        Batch
            The output :py:class:`~clinicadl.data.dataloader.Batch`.

            .. important::
                Even if the input batch is a sequence or a dict of ``Batch``,
                the output must be a single ``Batch``. Metrics will be
                computed on each element of this output batch.
        """

    @abstractmethod
    def prediction_step(self, batch: BatchType) -> Batch:
        """
        Performs inference on a batch.

        As opposed to :py:meth:`evaluation_step`, no metrics will be computed on the outputs:
        this method is meant to use the model for inference once it has been trained and tested.

        .. note::
            No need to send tensors to another device or to wrap your inference logic in the
            ``torch.no_grad()`` context manager, ``ClinicaDL`` takes care of this.

        Parameters
        ----------
        batch : Union[Batch, Sequence[Batch], dict[Any, Batch]]
            The batch of :py:class:`DataPoints <clinicadl.data.structures.DataPoint>`. It can either be a
            :py:class:`~clinicadl.data.dataloader.Batch`, a sequence of ``Batch`` such as a ``tuple``
            (e.g. if you use :py:class:`~clinicadl.data.datasets.PairedDataset`),
            or a dictionary of ``Batch``.

        Returns
        -------
        Batch
            The output :py:class:`~clinicadl.data.dataloader.Batch`.

            .. important::
                Even if the input batch is a sequence or a dict of :py:class:`~clinicadl.data.dataloader.Batch`,
                the output must be a single :py:class:`~clinicadl.data.dataloader.Batch`.
        """

    @abstractmethod
    def build_optimizers(self) -> dict[str, optim.Optimizer]:
        """
        Builds the optimizers that will be used during training.

        All optimizers must be given a name.

        Returns
        -------
        dict[str, optim.Optimizer]
            The optimizers and their names.
        """

    @abstractmethod
    def get_loss_functions(self) -> dict[str, Loss]:
        """
        Retrieves the loss functions used during training.

        All loss functions must be given a name.

        .. important::
            All loss functions must have a :torch:`PyTorch style <nn.html#loss-functions>`, i.e. a
            callable that returns a :py:class:`torch.Tensor` and with an attribute named ``reduction``
            that can be set to ``"none"``.

        Returns
        -------
        dict[str, Loss]
            The loss functions and their names.
        """

    def get_summary(
        self,
        input_data: BatchType,
    ) -> str:
        """
        Returns a summary of your neural network, produced by
        `torchinfo <https://github.com/TylerYep/torchinfo>`_ for example.

        If this method is not overridden, the ``nn_summary.txt`` in your :term:`MAPS` will not be
        created.

        Parameters
        ----------
        input_data : Union[Batch, Sequence[Batch], dict[Any, Batch]]
            Input data to pass to the neural network to build the summary.

        Returns
        -------
        str
            The summary.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError()

    def reset(self) -> None:
        """
        Randomly re-initializes the neural network's weights and removes all the
        accumulated gradients.

        **Only the trainable (i.e. with ``requires_grad=True``) weights will be reset.**

        Every module (the model itself included) that directly owns at least one trainable
        parameter and exposes a ``reset_parameters`` method is re-initialized through that
        method. The frozen parameters of such a module are saved beforehand and restored
        afterwards, so they keep their values.
        """
        self.zero_grad()
        for module in self.modules():
            own_params = dict(module.named_parameters(recurse=False))
            has_trainable_params = any(p.requires_grad for p in own_params.values())
            if not (has_trainable_params and hasattr(module, "reset_parameters")):
                continue

            # ``reset_parameters`` may re-initialize *all* of the module's parameters,
            # including frozen ones (e.g. a frozen bias next to a trainable weight):
            # snapshot them so that the contract above holds.
            frozen = {
                name: param.detach().clone()
                for name, param in own_params.items()
                if not param.requires_grad
            }
            module.reset_parameters()
            with torch.no_grad():
                for name, value in frozen.items():
                    # Look the parameter up again in case ``reset_parameters`` rebound it.
                    getattr(module, name).copy_(value)