2.1. Defining a model
In ClinicaDL, a model is more than a neural network. A
Model bundles together everything needed to train and
evaluate a network: the network itself, a loss function, an optimizer, and the logic
that defines how a batch flows forward, how gradients are computed, how the weights
are optimized, and how the model is evaluated. By gathering this logic in one object, ClinicaDL can offer a generic
Trainer that works with any model.
2.1.1. The Model class
Every model in ClinicaDL inherits from the base Model,
which is itself a torch.nn.Module. Model defines the interface that the
Trainer relies on — a handful of methods that capture the
essential logic of an experiment:
- forward_step(): how a batch is passed forward and the loss computed;
- backward_step() and optimization_step(): how gradients are computed and applied;
- evaluation_step(): how inference is performed during model evaluation;
- build_optimizers() and get_loss_functions(): how the optimizers and loss functions are built.
You will rarely implement all of this yourself. ClinicaDL ships two ready-to-use
models that already define this logic for the most common cases —
SupervisedModel and
ReconstructionModel, described below. When you need a
different behaviour, you can subclass one of them (or Model itself) and override only
the relevant method, as covered in Chapter 4.
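To make the division of labour concrete, here is a minimal sketch in plain PyTorch of how a generic training loop could rely only on these step methods. ToyModel and toy_training_loop are hypothetical names, and the method signatures (a dictionary batch, a loss tensor passed to backward_step()) are assumptions made for illustration; they are not ClinicaDL's actual Model or Trainer.

```python
import torch


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in for a model object; not ClinicaDL's Model class."""

    def __init__(self) -> None:
        super().__init__()
        self.network = torch.nn.Linear(4, 2)
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.network.parameters(), lr=0.1)

    def forward_step(self, batch: dict) -> torch.Tensor:
        # Pass the batch through the network and compute the loss.
        return self.loss_fn(self.network(batch["image"]), batch["label"])

    def backward_step(self, loss: torch.Tensor) -> None:
        # Compute the gradients of the loss.
        loss.backward()

    def optimization_step(self) -> None:
        # Apply the gradients, then reset them for the next batch.
        self.optimizer.step()
        self.optimizer.zero_grad()

    @torch.no_grad()
    def evaluation_step(self, batch: dict) -> torch.Tensor:
        # Inference only: return the predicted class of each sample.
        return self.network(batch["image"]).argmax(dim=1)


def toy_training_loop(model: ToyModel, batches: list, epochs: int) -> list:
    """Generic loop: it calls the step methods and never touches the network."""
    history = []
    for _ in range(epochs):
        for batch in batches:
            loss = model.forward_step(batch)
            model.backward_step(loss)
            model.optimization_step()
            history.append(loss.item())
    return history


torch.manual_seed(0)
batch = {"image": torch.randn(8, 4), "label": torch.randint(0, 2, (8,))}
model = ToyModel()
losses = toy_training_loop(model, [batch], epochs=20)
predictions = model.evaluation_step(batch)
```

Because the loop only knows about the step methods, changing how the loss is computed means overriding forward_step() in a subclass while the loop stays untouched.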
2.1.1.1. The SupervisedModel
SupervisedModel is the model for standard supervised
tasks: classification, regression and segmentation. You give it a network, a loss
and an optimizer, and tell it which field of your Samples
holds the label:
import torch
from clinicadl.models import SupervisedModel
from clinicadl.networks.nn import CNN
from clinicadl.optim.optimizers.config import AdamConfig

model = SupervisedModel(
    network=CNN(
        in_shape=(1, 169, 208, 179),
        num_outputs=2,
        conv_args={"channels": [8, 16, 32]},
    ),
    loss=torch.nn.CrossEntropyLoss(),  # any PyTorch-style loss
    optimizer=AdamConfig(),
    label_key="diagnosis",  # the label field in the Sample
)
Three ingredients deserve a closer look:
- network: the neural network. You can pass any torch.nn.Module, either one of the architectures from clinicadl.networks.nn (see below) or your own, or a network configuration object.
- loss: the loss function. Any PyTorch-style loss works: a callable returning a one-item torch.Tensor and exposing a reduction attribute that can be set to "none". This includes the losses of torch.nn, the losses of MONAI, your own, or a loss configuration object.
- optimizer: the optimizer, passed as a configuration object such as AdamConfig or SGDConfig.
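As an illustration of the loss contract described above, the following sketch defines a custom loss in plain PyTorch that returns a one-item tensor by default and exposes a reduction attribute that can be switched to "none". ScaledL1Loss is a hypothetical class written for this example; it is not part of ClinicaDL, torch.nn or MONAI.

```python
import torch


class ScaledL1Loss(torch.nn.Module):
    """Hypothetical custom loss following the PyTorch `reduction` convention."""

    def __init__(self, scale: float = 1.0, reduction: str = "mean") -> None:
        super().__init__()
        self.scale = scale
        self.reduction = reduction  # can be reassigned after construction

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        per_element = self.scale * (output - target).abs()
        if self.reduction == "none":
            return per_element  # one value per element, no aggregation
        if self.reduction == "sum":
            return per_element.sum()
        return per_element.mean()


loss = ScaledL1Loss(scale=2.0)
output = torch.tensor([[1.0, 2.0], [3.0, 5.0]])
target = torch.tensor([[1.0, 0.0], [0.0, 5.0]])

reduced = loss(output, target)    # one-item tensor
loss.reduction = "none"
unreduced = loss(output, target)  # same shape as the inputs
```

The built-in losses of torch.nn, such as torch.nn.CrossEntropyLoss, already follow this convention.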
Note
Losses and networks can be passed as raw objects, but the optimizer is always passed as a configuration object here. Configuration classes — which record an object’s parameters in a serialisable, reproducible form — are the subject of Chapter 3.
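One way to see why a configuration object suits an optimizer: a PyTorch optimizer needs the network's parameters when it is constructed, whereas its settings can be written down beforehand. The sketch below shows that idea with a hypothetical ToyAdamConfig dataclass; its fields and its build() method are assumptions made for illustration and may not match ClinicaDL's AdamConfig.

```python
import json
from dataclasses import asdict, dataclass

import torch


@dataclass
class ToyAdamConfig:
    """Hypothetical configuration; ClinicaDL's AdamConfig may differ."""

    lr: float = 1e-3
    weight_decay: float = 0.0

    def build(self, network: torch.nn.Module) -> torch.optim.Adam:
        # The optimizer can only be created once the parameters exist.
        return torch.optim.Adam(
            network.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )


config = ToyAdamConfig(lr=0.01)
serialised = json.dumps(asdict(config))             # plain text, easy to store
restored = ToyAdamConfig(**json.loads(serialised))  # same settings, read back
optimizer = restored.build(torch.nn.Linear(4, 2))
```

The settings survive a round trip through JSON, which is what makes an experiment described this way reproducible.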
By default, a SupervisedModel passes the whole image through the network during inference. To run
inference patch-by-patch or slice-by-slice instead, pass an
Inferer via the inferer argument — this is covered
in Evaluating.
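The idea of patch-by-patch inference can be sketched in plain PyTorch as follows. This is only an illustration of the principle: patchwise_predict is a hypothetical helper, and both the splitting along a single axis and the averaging of the patch outputs are assumptions, not a description of how ClinicaDL's Inferer works.

```python
import torch


def patchwise_predict(
    network: torch.nn.Module, image: torch.Tensor, patch_depth: int
) -> torch.Tensor:
    """Run `network` on consecutive chunks of the image and average the outputs."""
    # image has shape (batch, channels, depth, height, width);
    # it is cut into chunks of `patch_depth` along the depth axis.
    patches = torch.split(image, patch_depth, dim=2)
    with torch.no_grad():
        outputs = [network(patch) for patch in patches]
    # Aggregate the per-patch predictions into one prediction per image.
    return torch.stack(outputs).mean(dim=0)


torch.manual_seed(0)
network = torch.nn.Sequential(
    torch.nn.Conv3d(1, 4, kernel_size=3, padding=1),
    torch.nn.AdaptiveAvgPool3d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(4, 2),
)
network.eval()

image = torch.randn(1, 1, 16, 8, 8)
logits = patchwise_predict(network, image, patch_depth=4)  # 4 patches of depth 4
```

Each forward pass only sees a fraction of the volume, which is the usual motivation for this strategy when whole images do not fit in memory.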
2.1.1.2. The ReconstructionModel
ReconstructionModel is the counterpart for image
reconstruction, e.g. with an autoencoder. It works just like a
SupervisedModel, except that the loss compares the network’s output to the input
image — so there is no label to specify:
import torch
from clinicadl.models import ReconstructionModel
from clinicadl.networks.nn import AutoEncoder
from clinicadl.optim.optimizers.config import AdamConfig

model = ReconstructionModel(
    network=AutoEncoder(
        in_shape=(1, 80, 96, 80),
        latent_size=128,
        conv_args={"channels": [8, 16, 32]},
    ),
    loss=torch.nn.MSELoss(),
    optimizer=AdamConfig(),
)
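What "the loss compares the network's output to the input image" means can be shown with a few lines of plain PyTorch. The tiny fully connected autoencoder below is a hypothetical stand-in chosen to keep the example small; it is not ClinicaDL's AutoEncoder, and the snippet does not use ReconstructionModel.

```python
import torch

torch.manual_seed(0)

# A toy autoencoder: flatten, compress to 8 values, expand back to the image shape.
network = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(64, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 64),
    torch.nn.Unflatten(1, (1, 8, 8)),
)
loss_fn = torch.nn.MSELoss()

image = torch.randn(4, 1, 8, 8)        # a batch of 4 single-channel 8x8 images
reconstruction = network(image)
loss = loss_fn(reconstruction, image)  # the input image is its own target
loss.backward()
```

No label enters the computation at any point, which is why no label_key is needed.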
2.1.2. Neural networks
clinicadl.networks.nn provides a catalogue of neural networks, all
subclasses of torch.nn.Module, organised in three families:
- Builders: generic, fully configurable networks that you assemble from your own specifications:
  - MLP: a multilayer perceptron;
  - ConvEncoder / ConvDecoder: convolutional encoders and decoders;
  - CNN: a convolutional encoder followed by an MLP;
  - Generator: an MLP followed by a convolutional decoder (the mirror image of CNN);
  - AutoEncoder and VAE.
- Common architectures: well-known networks, configurable in their depth and width.
- Literature variants: ready-to-use architectures with the exact settings from their original papers, for instance ResNet18 … ResNet152, DenseNet121 … DenseNet201, or ViTB16.
from clinicadl.networks.nn import ConvEncoder, ResNet18
# a builder: you specify the architecture
encoder = ConvEncoder(spatial_dims=3, in_channels=1, channels=[8, 16, 32])
# a literature variant: ready to use
resnet = ResNet18(num_outputs=2)
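To illustrate what "a convolutional encoder followed by an MLP" amounts to, here is a rough sketch of such a network in plain PyTorch. ToyCNN is a hypothetical class: its layer choices (kernel size, pooling, hidden size of 16) are assumptions made for this example and do not reproduce the internals of ClinicaDL's CNN builder.

```python
import torch


class ToyCNN(torch.nn.Module):
    """Hypothetical encoder-plus-MLP network; not ClinicaDL's CNN."""

    def __init__(self, in_shape: tuple, num_outputs: int, channels: list) -> None:
        super().__init__()
        layers = []
        in_channels = in_shape[0]
        for out_channels in channels:
            layers += [
                torch.nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool3d(2),  # halves each spatial dimension
            ]
            in_channels = out_channels
        self.encoder = torch.nn.Sequential(*layers)

        # Infer the size of the flattened features with a dummy forward pass.
        with torch.no_grad():
            n_features = self.encoder(torch.zeros(1, *in_shape)).numel()

        self.mlp = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(n_features, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, num_outputs),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(self.encoder(x))


toy_cnn = ToyCNN(in_shape=(1, 16, 16, 16), num_outputs=2, channels=[4, 8])
logits = toy_cnn(torch.randn(3, 1, 16, 16, 16))  # one row of 2 logits per image
```

A builder saves you from writing this boilerplate: you state the input shape, the channels and the number of outputs, and the layers are assembled for you.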
Each network has a matching configuration class in
clinicadl.networks.config (e.g.
ResNet18Config), so a network can also be
described in a serialisable way — see Chapter 3.
With a model in hand, you are ready to train it.