clinicadl.models.ReconstructionModel
- class clinicadl.models.ReconstructionModel(network: torch.nn.Module | clinicadl.networks.nn.utils.config.NetworkConfig, loss: typing.Callable[..., torch.Tensor] | clinicadl.losses.config.configs.LossConfig, optimizer: clinicadl.optim.optimizers.config.base.OptimizerConfig, inferer: clinicadl.infer.abstract.Inferer = <clinicadl.infer.simple.SimpleInferer object>, **kwargs)
A vanilla reconstruction model, designed to work with simple autoencoders such as
AutoEncoder. Only forward_step() differs from SupervisedModel.
- Parameters:
network (NetworkOrConfig) – The neural network, passed as a torch.nn.Module or as a configuration object.
loss (LossOrConfig) – The reconstruction loss function, passed as a callable that returns a 1-item Tensor, or as a configuration object.
Important: the loss function must follow the PyTorch convention and expose an attribute named reduction that can be set to "none".
optimizer (OptimizerConfig) – The optimizer, passed as a configuration object.
See also
SupervisedModel: for supervised training.
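To illustrate the requirement on the loss, here is a minimal sketch of a custom reconstruction loss that follows the PyTorch convention. The class name MeanAbsoluteError is hypothetical and not part of clinicadl; built-in losses such as torch.nn.MSELoss or torch.nn.L1Loss already expose a reduction attribute and should satisfy this contract as they are.

```python
import torch
from torch import nn


class MeanAbsoluteError(nn.Module):
    """Hypothetical custom loss following the PyTorch convention.

    Like torch.nn.L1Loss, it exposes a ``reduction`` attribute that callers
    can set to "none", "mean" or "sum".
    """

    def __init__(self, reduction: str = "mean") -> None:
        super().__init__()
        self.reduction = reduction

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        error = (output - target).abs()  # element-wise absolute error
        if self.reduction == "none":
            return error  # same shape as the inputs
        if self.reduction == "sum":
            return error.sum()
        if self.reduction == "mean":
            return error.mean()  # 1-item tensor
        raise ValueError(f"Unsupported reduction: {self.reduction!r}")


loss_fn = MeanAbsoluteError()
output, target = torch.zeros(2, 1, 4, 4), torch.ones(2, 1, 4, 4)
print(loss_fn(output, target))  # tensor(1.)

loss_fn.reduction = "none"  # the switch the model must be able to make
print(loss_fn(output, target).shape)  # torch.Size([2, 1, 4, 4])
```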
- forward_step(batch: Batch) -> torch.Tensor
Performs a forward pass through the neural network and compares the output with the input image.
- Parameters:
batch (Batch) – The batch of DataPoints.
- Returns:
torch.Tensor – The computed loss, as a 1-item torch.Tensor.
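The following is a minimal sketch of the logic described for forward_step(), assuming the images are available as a plain tensor: clinicadl's Batch and DataPoint structures are not reproduced, and the function reconstruction_forward_step and the toy network are hypothetical. The key point is that the input serves both as the network input and as the reconstruction target.

```python
import torch
from torch import nn


def reconstruction_forward_step(
    network: nn.Module, loss_fn: nn.Module, images: torch.Tensor
) -> torch.Tensor:
    """Hypothetical stand-in for forward_step(), operating on a plain tensor."""
    reconstructions = network(images)  # forward pass through the autoencoder
    return loss_fn(reconstructions, images)  # the input image is the target


torch.manual_seed(0)
# Toy "autoencoder": compress 16 features to 4 and back (illustration only).
network = nn.Sequential(nn.Linear(16, 4), nn.ReLU(), nn.Linear(4, 16))
images = torch.randn(8, 16)  # a batch of 8 flattened images
loss = reconstruction_forward_step(network, nn.MSELoss(), images)
print(loss.shape)  # torch.Size([]), i.e. a 1-item tensor
```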
- backward_step(loss: torch.Tensor, grad_scaler: torch.amp.GradScaler = GradScaler(enabled=False)) -> None
Computes gradients in the standard way, using the loss returned by
forward_step().
- Parameters:
loss (torch.Tensor) – The loss from which gradients are computed.
grad_scaler (GradScaler, default=GradScaler(enabled=False)) – An optional torch.amp.GradScaler used to scale gradients.
- optimization_step(optimizers: dict[str, torch.optim.Optimizer], grad_scaler: torch.amp.GradScaler = GradScaler(enabled=False)) -> None
Performs a standard optimization step using the gradients accumulated in
backward_step().
- Parameters:
optimizers (dict[str, torch.optim.Optimizer]) – The optimizer, as returned by build_optimizers().
grad_scaler (GradScaler, default=GradScaler(enabled=False)) – An optional torch.amp.GradScaler used to scale gradients.
- evaluation_step(batch: Batch) -> Batch
Passes the input images through the network and stores the output in the batch.
- Parameters:
batch (Batch) – The batch of DataPoints.
- Returns:
Batch – The resulting Batch, containing the network output.
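A minimal sketch of the evaluation step, assuming a plain dictionary as a stand-in for clinicadl's Batch; the keys "image" and "output" and the function evaluation_step_sketch are hypothetical. Running the network under torch.no_grad() is also an assumption here, a common choice for evaluation since no gradients are needed.

```python
import torch
from torch import nn

torch.manual_seed(0)
network = nn.Sequential(nn.Linear(16, 4), nn.ReLU(), nn.Linear(4, 16))  # toy autoencoder
batch = {"image": torch.randn(8, 16)}  # hypothetical dict stand-in for a Batch


def evaluation_step_sketch(network: nn.Module, batch: dict) -> dict:
    """Hypothetical: run the network and store its output in the batch."""
    with torch.no_grad():  # assumed: no gradient tracking during evaluation
        batch["output"] = network(batch["image"])
    return batch


batch = evaluation_step_sketch(network, batch)
print(batch["output"].shape)  # torch.Size([8, 16])
```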
- build_optimizers() -> dict[str, Optimizer]
Returns a new instance of the optimizer.
- Returns:
dict[str, optim.Optimizer] – The optimizer, stored under the key "optimizer".
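The sketch below mimics the documented return value of build_optimizers(): a dictionary holding one optimizer under the key "optimizer". The function name build_optimizers_sketch, the choice of Adam, and the learning rate are hypothetical. Returning a new instance on each call means each optimizer starts with fresh internal state (for Adam, empty moment estimates).

```python
import torch
from torch import nn

network = nn.Linear(4, 4)  # toy network, illustration only


def build_optimizers_sketch() -> dict[str, torch.optim.Optimizer]:
    """Hypothetical: build a fresh optimizer on each call, under the key "optimizer"."""
    return {"optimizer": torch.optim.Adam(network.parameters(), lr=1e-4)}


first, second = build_optimizers_sketch(), build_optimizers_sketch()
print(list(first))  # ['optimizer']
print(first["optimizer"] is second["optimizer"])  # False: a new instance each time
```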
- get_loss_functions() -> dict[str, Callable[..., Tensor]]
Returns the loss function.
- Returns:
dict[str, Loss] – The loss function, stored under the key "loss".
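A minimal sketch of how the returned dictionary might be used, assuming it holds a torch.nn.MSELoss under the key "loss" (the dictionary below is a hypothetical stand-in for the method's output). It also shows why the reduction attribute matters: setting it to "none" yields element-wise losses, which can then be averaged per image.

```python
import torch
from torch import nn

loss_functions = {"loss": nn.MSELoss()}  # hypothetical output of get_loss_functions()
loss_fn = loss_functions["loss"]

reconstructions, images = torch.zeros(3, 1, 2, 2), torch.ones(3, 1, 2, 2)
print(loss_fn(reconstructions, images))  # tensor(1.)

# Per-image losses: switch reduction to "none", then average over all but the batch dim.
loss_fn.reduction = "none"
per_image = loss_fn(reconstructions, images).flatten(start_dim=1).mean(dim=1)
loss_fn.reduction = "mean"  # restore the default
print(per_image)  # tensor([1., 1., 1.])
```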