clinicadl.models.SupervisedModel

class clinicadl.models.SupervisedModel(network: torch.nn.Module | clinicadl.networks.nn.utils.config.NetworkConfig, loss: Callable[..., torch.Tensor] | clinicadl.losses.config.configs.LossConfig, optimizer: clinicadl.optim.optimizers.config.base.OptimizerConfig, inferer: clinicadl.infer.abstract.Inferer = clinicadl.infer.simple.SimpleInferer(), label_key: str = 'label')

A vanilla supervised model for standard classification, regression, or segmentation tasks.

Parameters:

See also

ReconstructionModel

For image reconstruction.

forward_step(batch: Batch) → Tensor

Performs a standard supervised forward step and returns the computed loss.

Parameters:

batch (Batch) – The batch of DataPoints.

Returns:

torch.Tensor – The computed loss, as a single-element torch.Tensor.
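For illustration, here is a minimal sketch of what a standard supervised forward step computes, written in plain PyTorch rather than with clinicadl's classes. It assumes a batch is a dict with hypothetical "image" and "label" keys; clinicadl's Batch of DataPoints is a different structure, and only the "label" key echoes the documented label_key='label' default.

```python
import torch
from torch import nn


def supervised_forward_step(
    network: nn.Module,
    loss_fn: nn.Module,
    batch: dict[str, torch.Tensor],
    label_key: str = "label",
) -> torch.Tensor:
    """Hypothetical stand-in for forward_step(): predict, then compare to the labels."""
    # "image" is an assumed key for this sketch; the real Batch stores DataPoints.
    output = network(batch["image"])
    # With the default reduction ("mean"), the loss is a single-element tensor.
    return loss_fn(output, batch[label_key])


# Toy classification setup: 4 samples, 8 features, 3 classes.
network = nn.Linear(8, 3)
batch = {"image": torch.randn(4, 8), "label": torch.tensor([0, 2, 1, 0])}
loss = supervised_forward_step(network, nn.CrossEntropyLoss(), batch)
print(loss.shape, loss.numel())  # torch.Size([]) 1
```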

backward_step(loss: torch.Tensor, grad_scaler: torch.amp.GradScaler = GradScaler(enabled=False)) → None

Computes the gradients by backpropagating the loss returned by forward_step().

Parameters:
  • loss (torch.Tensor) – The loss on which gradients will be computed.

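  • grad_scaler (GradScaler, default=GradScaler(enabled=False)) – An optional torch.amp.GradScaler for mixed-precision training, used to scale the loss before backpropagation. The default scaler is disabled, so the loss is used as is.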

optimization_step(optimizers: dict[str, torch.optim.Optimizer], grad_scaler: torch.amp.GradScaler = GradScaler(enabled=False)) → None

Performs a standard optimization step using the gradients accumulated by backward_step().

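Parameters:
  • optimizers (dict[str, torch.optim.Optimizer]) – The optimizers to step, e.g. as returned by build_optimizers().

  • grad_scaler (GradScaler, default=GradScaler(enabled=False)) – An optional torch.amp.GradScaler for mixed-precision training. It should be the same scaler as the one passed to backward_step().
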
evaluation_step(batch: Batch) → Batch

Passes the input images through the network and stores the output in the batch.

Parameters:

batch (Batch) – The batch of DataPoints.

Returns:

Batch – The output Batch.
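A sketch of an evaluation step, assuming hypothetical "image" and "output" keys in a dict-based batch. Running under torch.no_grad() avoids building a computation graph; whether clinicadl also switches the network to eval mode here is not stated, so this sketch leaves that to the caller.

```python
import torch
from torch import nn


@torch.no_grad()
def evaluation_step(network: nn.Module, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Hypothetical stand-in for evaluation_step(): run the network, keep the output."""
    # "image" and "output" are assumed keys for this sketch only.
    batch["output"] = network(batch["image"])
    return batch


network = nn.Linear(8, 3)
network.eval()  # left to the caller in this sketch
batch = evaluation_step(network, {"image": torch.randn(4, 8)})
print(batch["output"].shape, batch["output"].requires_grad)  # torch.Size([4, 3]) False
```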

build_optimizers() → dict[str, Optimizer]

Returns a new instance of the optimizer.

Returns:

dict[str, optim.Optimizer] – The optimizer, named "optimizer".
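A sketch of what build_optimizers() returns, with torch.optim.Adam and its learning rate standing in, hypothetically, for whatever the OptimizerConfig describes. Only the "optimizer" key comes from the documentation above.

```python
import torch
from torch import nn


def build_optimizers(network: nn.Module, lr: float = 1e-3) -> dict[str, torch.optim.Optimizer]:
    """Hypothetical stand-in: a fresh optimizer on each call, under the key "optimizer"."""
    # Adam and lr are placeholders for the settings held by an OptimizerConfig.
    return {"optimizer": torch.optim.Adam(network.parameters(), lr=lr)}


network = nn.Linear(8, 3)
first, second = build_optimizers(network), build_optimizers(network)
print(list(first), first["optimizer"] is second["optimizer"])  # ['optimizer'] False
```

Because each call builds a new instance, optimizer state such as Adam's running moment estimates is not shared between two optimizers built for the same network.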

get_loss_functions() → dict[str, Callable[..., Tensor]]

Returns the loss function.

Returns:

dict[str, Loss] – The loss function, named "loss".

get_summary(input_data: Batch) → str

Returns a summary of the neural network, produced by torchinfo.

Parameters:

input_data (Batch) – Input data to pass to the neural network to build the summary.

Returns:

str – The summary.

reset() → None

Randomly re-initializes the neural network’s weights and clears all accumulated gradients.

Only the trainable (i.e. with ``requires_grad=True``) weights will be reset.
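A sketch of a reset that leaves frozen weights untouched, in plain PyTorch. It re-initializes, through their reset_parameters() method, only the layers whose own parameters all require gradients, then drops every gradient. Working layer by layer rather than weight by weight is a simplification of this sketch; clinicadl's actual reset may proceed differently.

```python
import torch
from torch import nn


def reset(network: nn.Module) -> None:
    """Hypothetical stand-in for reset(): re-initialize trainable layers, drop gradients."""
    for module in network.modules():
        params = list(module.parameters(recurse=False))
        # Re-initialize only layers that can do so and whose weights are all trainable.
        if params and hasattr(module, "reset_parameters") and all(p.requires_grad for p in params):
            module.reset_parameters()
    network.zero_grad(set_to_none=True)


network = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 3))
network[0].requires_grad_(False)  # freeze the first layer
frozen = network[0].weight.detach().clone()
trainable = network[2].weight.detach().clone()

network(torch.randn(4, 8)).sum().backward()
reset(network)
print(torch.equal(frozen, network[0].weight))     # True: frozen layer untouched
print(torch.equal(trainable, network[2].weight))  # False: re-initialized
print(network[2].weight.grad)                     # None
```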