clinicadl.models.SupervisedModel
- class clinicadl.models.SupervisedModel(network: torch.nn.modules.module.Module | clinicadl.networks.nn.utils.config.NetworkConfig, loss: typing.Callable[[...], torch.Tensor] | clinicadl.losses.config.configs.LossConfig, optimizer: clinicadl.optim.optimizers.config.base.OptimizerConfig, inferer: clinicadl.infer.abstract.Inferer = <clinicadl.infer.simple.SimpleInferer object>, label_key: str = 'label')
A vanilla supervised model for standard classification, regression, or segmentation tasks.
- Parameters:
network (NetworkOrConfig) – The neural network, passed as a torch.nn.Module or a configuration object.
loss (LossOrConfig) – The loss function, passed as a callable that returns a 1-item Tensor, or a configuration object.
Important
The loss function must have a PyTorch style, with an attribute named reduction that can be set to "none".
optimizer (OptimizerConfig) – The optimizer, passed as a configuration object.
label_key (str, default="label") – The key of the label in the training samples.
See also
ReconstructionModel – For image reconstruction.
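As an illustration of the "PyTorch style" required for the loss, here is a minimal sketch of a custom loss that exposes a reduction attribute. The class name ScaledL1Loss and its scale parameter are hypothetical and are not part of ClinicaDL; the sketch only assumes that the loss is called with a prediction and a target.

```python
import torch
from torch import nn


class ScaledL1Loss(nn.Module):
    """Hypothetical custom loss, written in the PyTorch style."""

    def __init__(self, scale: float = 1.0, reduction: str = "mean") -> None:
        super().__init__()
        self.scale = scale
        # Public attribute that can be switched to "none" after construction.
        self.reduction = reduction

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        per_element = self.scale * (prediction - target).abs()
        if self.reduction == "none":
            return per_element  # one value per element, no aggregation
        if self.reduction == "sum":
            return per_element.sum()
        return per_element.mean()  # default: 1-item tensor


loss_fn = ScaledL1Loss(scale=2.0)
prediction = torch.tensor([1.0, 2.0, 4.0])
target = torch.tensor([1.0, 0.0, 1.0])

reduced = loss_fn(prediction, target)     # 1-item tensor, suitable for backward()
loss_fn.reduction = "none"
per_sample = loss_fn(prediction, target)  # unreduced values
```

Built-in losses such as torch.nn.MSELoss already follow this convention, so they should not need a wrapper of this kind.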
- forward_step(batch: Batch) → Tensor
Performs a classical supervised forward step and returns the computed loss.
- Parameters:
batch (Batch) – The batch of DataPoints.
- Returns:
torch.Tensor – The computed loss, as a 1-item torch.Tensor.
- backward_step(loss: torch.Tensor, grad_scaler: torch.amp.grad_scaler.GradScaler = <torch.amp.grad_scaler.GradScaler object>) → None
Performs a classical gradient computation using the loss returned by forward_step().
- Parameters:
loss (torch.Tensor) – The loss on which gradients will be computed.
grad_scaler (GradScaler, default=GradScaler(enabled=False)) – A potential torch.amp.GradScaler used to scale gradients.
- optimization_step(optimizers: dict[str, torch.optim.optimizer.Optimizer], grad_scaler: torch.amp.grad_scaler.GradScaler = <torch.amp.grad_scaler.GradScaler object>) → None
Performs a classical optimization step using the gradients accumulated in backward_step().
- Parameters:
optimizers (dict[str, torch.optim.Optimizer]) – The optimizer, as defined in build_optimizers().
grad_scaler (GradScaler, default=GradScaler(enabled=False)) – A potential torch.amp.GradScaler used to scale gradients.
- evaluation_step(batch: Batch) → Batch
Passes the input images through the network and saves the output in the batch.
- Parameters:
batch (Batch) – The batch of DataPoints.
- Returns:
Batch – The output Batch.
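The idea behind the evaluation step can be sketched in plain PyTorch as follows. A plain dict stands in for ClinicaDL's Batch, and the keys "image" and "output" are hypothetical; the actual Batch type and its field names may differ.

```python
import torch
from torch import nn


def evaluation_step_sketch(network: nn.Module, batch: dict) -> dict:
    """Run the network on the batch images and store the result in the batch."""
    network.eval()  # disable training-only behaviour such as dropout
    with torch.no_grad():  # no gradients are needed for evaluation
        batch["output"] = network(batch["image"])
    return batch


network = nn.Linear(4, 2)
batch = {"image": torch.randn(3, 4), "label": torch.tensor([0, 1, 1])}

out = evaluation_step_sketch(network, batch)
```

Because the output is stored next to the inputs and labels, metrics can be computed later from the returned batch alone.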
- build_optimizers() → dict[str, Optimizer]
Returns a new instance of the optimizer.
- Returns:
dict[str, optim.Optimizer] – The optimizer, named "optimizer".
- get_loss_functions() → dict[str, Callable[[...], Tensor]]
Returns the loss function.
- Returns:
dict[str, Loss] – The loss function, named "loss".