clinicadl.models.ReconstructionModel

class clinicadl.models.ReconstructionModel(network: torch.nn.Module | clinicadl.networks.nn.utils.config.NetworkConfig, loss: Callable[..., torch.Tensor] | clinicadl.losses.config.configs.LossConfig, optimizer: clinicadl.optim.optimizers.config.base.OptimizerConfig, inferer: clinicadl.infer.abstract.Inferer = SimpleInferer(), **kwargs)

A vanilla reconstruction model, designed for simple autoencoders such as AutoEncoder.

Only forward_step() differs from SupervisedModel.

Parameters:

See also

SupervisedModel

For supervised training.

forward_step(batch: Batch) torch.Tensor

Performs a forward pass through the neural network and compares the output with the input image.

Parameters:

batch (Batch) – The batch of DataPoints.

Returns:

torch.Tensor – The computed loss, as a 1-item torch.Tensor.
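To illustrate what distinguishes this step from supervised training, here is a minimal sketch in plain PyTorch, not ClinicaDL's actual code. It assumes, hypothetically, that the batch is a dict holding the input images under an "image" key; the function name reconstruction_forward_step is also illustrative. The point is that the reconstruction target is the input itself.

```python
import torch
from torch import nn


def reconstruction_forward_step(
    network: nn.Module, loss_fn: nn.Module, batch: dict[str, torch.Tensor]
) -> torch.Tensor:
    """Hypothetical sketch: the batch is assumed to be a dict with an "image" key."""
    images = batch["image"]
    reconstruction = network(images)
    # The target is the input itself; a supervised model would compare
    # the output with a label instead.
    return loss_fn(reconstruction, images)


# Toy autoencoder on flattened 16-value "images".
autoencoder = nn.Sequential(nn.Linear(16, 4), nn.ReLU(), nn.Linear(4, 16))
batch = {"image": torch.randn(8, 16)}
loss = reconstruction_forward_step(autoencoder, nn.MSELoss(), batch)
```

With nn.Identity() as the network, the reconstruction is perfect and the loss is exactly zero, which makes a quick sanity check.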

backward_step(loss: torch.Tensor, grad_scaler: torch.amp.GradScaler = GradScaler(enabled=False)) None

Performs a classical gradient computation using the loss returned by forward_step().

Parameters:
  • loss (torch.Tensor) – The loss on which gradients will be computed.

  • grad_scaler (GradScaler, default=GradScaler(enabled=False)) – An optional torch.amp.GradScaler used to scale the loss, and hence the gradients, for mixed-precision training.

optimization_step(optimizers: dict[str, torch.optim.Optimizer], grad_scaler: torch.amp.GradScaler = GradScaler(enabled=False)) None

Performs a classical optimization step using the gradients accumulated in backward_step().

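Parameters:
  • optimizers (dict[str, torch.optim.Optimizer]) – The optimizers used to update the network's weights, indexed by name.

  • grad_scaler (GradScaler, default=GradScaler(enabled=False)) – An optional torch.amp.GradScaler used to unscale gradients before the update, for mixed-precision training.
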
evaluation_step(batch: Batch) Batch

Passes the input images through the network and stores the output in the batch.

Parameters:

batch (Batch) – The batch of DataPoints.

Returns:

Batch – The output Batch.
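A sketch of this evaluation step in plain PyTorch, assuming a dict-based batch with hypothetical "image" and "output" keys. Switching to eval mode and disabling gradient tracking are assumptions of the sketch (typical for inference), not documented behaviour of the method.

```python
import torch
from torch import nn


def evaluation_step(network: nn.Module, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Hypothetical sketch: "image" and "output" are illustrative keys."""
    network.eval()  # assumed: disable dropout / use running BatchNorm statistics
    with torch.no_grad():  # assumed: no gradients are needed at evaluation time
        batch["output"] = network(batch["image"])
    return batch


autoencoder = nn.Sequential(nn.Linear(16, 4), nn.ReLU(), nn.Linear(4, 16))
result = evaluation_step(autoencoder, {"image": torch.randn(5, 16)})
```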

build_optimizers() dict[str, Optimizer]

Returns a new instance of the optimizer.

Returns:

dict[str, optim.Optimizer] – The optimizer, named "optimizer".
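A sketch of what building a fresh optimizer could look like. Adam and the learning rate are hypothetical stand-ins for whatever the OptimizerConfig specifies; the "optimizer" key matches the name given above. Returning a new instance on each call keeps optimizer state (such as Adam's moment estimates) from leaking between training runs.

```python
import torch
from torch import nn


def build_optimizers(network: nn.Module, lr: float = 1e-3) -> dict[str, torch.optim.Optimizer]:
    """Hypothetical sketch: Adam and lr stand in for the OptimizerConfig settings."""
    # A new optimizer object, with empty state, on every call.
    return {"optimizer": torch.optim.Adam(network.parameters(), lr=lr)}


network = nn.Linear(4, 2)
first = build_optimizers(network)
second = build_optimizers(network)
```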

get_loss_functions() dict[str, Callable[[...], Tensor]]

Returns the loss function.

Returns:

dict[str, Loss] – The loss function, named "loss".

get_summary(input_data: Batch) str

Returns a summary of the neural network, produced by torchinfo.

Parameters:

input_data (Batch) – Input data to pass to the neural network to build the summary.

Returns:

str – The summary.

reset() None

Randomly re-initializes the neural network’s weights and clears all accumulated gradients.

Only trainable weights (i.e. those with requires_grad=True) are reset.
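One way to implement such a reset in plain PyTorch, offered as a sketch under assumptions: layers are re-initialized through their reset_parameters() method (defined by standard layers such as nn.Linear and nn.Conv3d), and a layer is skipped if any of its own parameters is frozen. The function reset below is hypothetical, not ClinicaDL's implementation.

```python
import torch
from torch import nn


def reset(network: nn.Module) -> None:
    """Hypothetical sketch of a reset that spares frozen layers."""
    for module in network.modules():
        own_params = list(module.parameters(recurse=False))
        # Re-initialize only layers whose own parameters are all trainable.
        if own_params and all(p.requires_grad for p in own_params) and hasattr(module, "reset_parameters"):
            module.reset_parameters()
    # Drop accumulated gradients (set to None by default in PyTorch >= 2.0).
    network.zero_grad()


net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for p in net[0].parameters():
    p.requires_grad_(False)  # freeze the first layer
frozen_before = net[0].weight.detach().clone()
trainable_before = net[1].weight.detach().clone()

net(torch.randn(3, 4)).sum().backward()  # accumulate some gradients
reset(net)
```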