clinicadl.models.Model
- class clinicadl.models.Model(*args, **kwargs) -> None
The base model from which every model that works with ClinicaDL must inherit.
Model itself inherits from torch.nn.Module, so you can define your neural networks in the __init__ method as usual (don't forget to call super().__init__() first!).
Besides, the following methods must be overridden:
  - forward_step(): defines the forward logic during training;
  - backward_step(): defines the gradient computation logic;
  - optimization_step(): defines the optimization logic;
  - evaluation_step(): defines the evaluation logic;
  - build_optimizers(): builds the optimizers used for training;
  - get_loss_functions(): gives access to the loss functions used during training.
You can also override get_summary() to give a description of your neural network(s).
Tip
Since rewriting all these methods can be tedious, feel free to inherit from an existing Model with shared logic and rewrite only the relevant methods.
See also
SupervisedModel – A Model for supervised training.
ReconstructionModel – A Model for image reconstruction.
- abstract forward_step(batch: BatchType) -> LossType
Performs the training forward step using the provided batch of data and returns the computed loss.
Several losses can be computed during this step.
The gradients will be computed on the loss(es) returned by this method.
Note
No need to bother with computational aspects (sending the batch to GPU, AMP, etc.); ClinicaDL takes care of this.
- Parameters:
batch (Union[Batch, Sequence[Batch], dict[Any, Batch]]) – The batch of DataPoints. It can be a Batch, a sequence of Batch or a dictionary of Batch.
- Returns:
Union[torch.Tensor, dict[str, torch.Tensor]] – The computed loss(es), as a 1-item torch.Tensor, or a dictionary of such Tensors.
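As an illustration of the dictionary return type, the sketch below computes two named losses in one forward step. It inherits from torch.nn.Module as a stand-in for clinicadl.models.Model, and it assumes the batch can be indexed like a dict with an "image" key; the class name and loss names are hypothetical.

```python
import torch
from torch import nn


class TwoLossModel(nn.Module):  # stand-in for a clinicadl.models.Model subclass
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(6, 2)
        self.decoder = nn.Linear(2, 6)

    def forward_step(self, batch):
        # Assumption: the batch behaves like a dict holding an "image" tensor.
        image = batch["image"]
        latent = self.encoder(image)
        reconstruction = self.decoder(latent)
        # Each value is a 1-item tensor, keyed by the name of the loss.
        return {
            "reconstruction": nn.functional.mse_loss(reconstruction, image),
            "sparsity": latent.abs().mean(),
        }
```

No device transfer or autocast appears in the method, in line with the note above.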
- abstract backward_step(loss: torch.Tensor | dict[str, torch.Tensor], grad_scaler: torch.amp.GradScaler = GradScaler(enabled=False)) -> None
Performs gradient computation using the loss(es) returned by forward_step().
- Parameters:
loss (Union[torch.Tensor, dict[str, torch.Tensor]]) – The loss(es) on which gradients will be computed.
grad_scaler (GradScaler, default=GradScaler(enabled=False)) – An optional torch.amp.GradScaler used to scale gradients.
- abstract optimization_step(optimizers: dict[str, torch.optim.Optimizer], grad_scaler: torch.amp.GradScaler = GradScaler(enabled=False)) -> None
Performs the optimization step using the gradients accumulated in backward_step().
Note
ClinicaDL takes care of zeroing gradients after this step.
- Parameters:
optimizers (dict[str, torch.optim.Optimizer]) – The optimizers, as defined in build_optimizers().
grad_scaler (GradScaler, default=GradScaler(enabled=False)) – An optional torch.amp.GradScaler used to scale gradients.
- abstract evaluation_step(batch: Batch | Sequence[Batch] | dict[Any, Batch]) -> Batch
Performs the evaluation step where a validation/test batch is passed through the neural network.
Metrics will be computed on the outputs of this method.
Note
No need to bother with computational aspects (sending the batch to GPU, disabling gradients, etc.); ClinicaDL takes care of this.
- Parameters:
batch (Union[Batch, Sequence[Batch], dict[Any, Batch]]) – The batch of DataPoints. It can be a Batch, a sequence of Batch or a dictionary of Batch.
- Returns:
Batch – The output Batch.
Important
Even if the input batch is a sequence or a dict of Batch, the output must be a single Batch. Metrics will be computed on each element of this output batch.
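The single-output requirement can be illustrated with a model that receives a dictionary of batches and merges them into one. This is only a sketch: plain dicts of tensors stand in for Batch objects, the keys "t1", "pet", "image", "label" and "output" are hypothetical, and torch.nn.Module stands in for clinicadl.models.Model.

```python
import torch
from torch import nn


class FusionModel(nn.Module):  # stand-in for a clinicadl.models.Model subclass
    def __init__(self):
        super().__init__()
        self.network = nn.Linear(8, 2)

    def evaluation_step(self, batch):
        # Assumption: batch is {"t1": {...}, "pet": {...}}, each inner dict
        # standing in for a Batch with "image" and "label" entries.
        features = torch.cat(
            [batch["t1"]["image"], batch["pet"]["image"]], dim=1
        )
        # One output batch, whatever the structure of the input.
        return {"label": batch["t1"]["label"], "output": self.network(features)}
```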
- abstract build_optimizers() -> dict[str, Optimizer]
Builds the optimizers that will be used during training.
All optimizers must be given a name.
- Returns:
dict[str, optim.Optimizer] – The optimizers and their names.
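Naming matters most when a model trains several sub-networks with separate optimizers. The sketch below assumes a hypothetical two-network model (inheriting from torch.nn.Module as a stand-in) and arbitrary learning rates.

```python
import torch
from torch import nn


class AdversarialModel(nn.Module):  # stand-in for a clinicadl.models.Model subclass
    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(4, 8)
        self.discriminator = nn.Linear(8, 1)

    def build_optimizers(self):
        # One named optimizer per sub-network, each seeing only its own parameters.
        return {
            "generator": torch.optim.Adam(self.generator.parameters(), lr=1e-4),
            "discriminator": torch.optim.Adam(
                self.discriminator.parameters(), lr=4e-4
            ),
        }
```

The same names are the keys of the optimizers dictionary that optimization_step() receives.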
- abstract get_loss_functions() -> dict[str, Callable[[...], Tensor]]
Returns the loss functions used during training.
All loss functions must be given a name.
Important
All loss functions must follow the PyTorch style, i.e. be a callable that returns a torch.Tensor and has an attribute named reduction that can be set to "none".
- Returns:
dict[str, Callable[…, Tensor]] – The loss functions and their names.
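Built-in PyTorch losses such as torch.nn.MSELoss already satisfy this contract. For a custom loss, a small callable class with a reduction attribute is one way to comply; the sketch below is an assumption about what such a class could look like, and MeanAbsoluteError is a hypothetical name.

```python
import torch
from torch import nn


class MeanAbsoluteError:
    """Custom loss in the PyTorch style: callable, with a `reduction` attribute."""

    def __init__(self, reduction="mean"):
        self.reduction = reduction

    def __call__(self, output, target):
        error = (output - target).abs()
        if self.reduction == "none":
            return error  # one value per element, no aggregation
        if self.reduction == "sum":
            return error.sum()
        return error.mean()


def get_loss_functions():
    # Sketch of the method body, written as a free function.
    return {"mse": nn.MSELoss(), "mae": MeanAbsoluteError()}
```

Setting reduction to "none" on either entry makes the call return one loss value per element instead of a single aggregated value.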