1.5. Batching data for training

When training a neural network, you typically feed it several images at once, known as a batch. ClinicaDL provides a DataLoader that iterates over a Dataset and groups its samples into a Batch (or a sequence of Batch objects), ready to be turned into tensors.

1.5.1. Grouping samples in a batch

A Batch is a list of DataPoints with a few extra conveniences. The most useful is get_field(), which gathers one field across all the samples and returns it as a batch-first torch.Tensor when possible (and a plain list otherwise):

>>> from clinicadl.data.structures.examples import Colin27DataPoint
>>> from clinicadl.data.dataloader import Batch

>>> batch = Batch([Colin27DataPoint(), Colin27DataPoint()])
>>> batch.get_field("image").shape
torch.Size([2, 1, 181, 217, 181])
>>> batch.get_field("participant_id")
['sub-colin', 'sub-colin']
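The "tensor when possible" rule comes down to shapes: stacking N values along a new leading axis yields a batch-first shape (N, *shape), which is only defined when all values share the same shape. The sketch below illustrates that shape logic with a hypothetical helper, stacked_shape, working on shape tuples. It is not ClinicaDL's implementation, and the same-shape condition is an assumption about what "when possible" means.

```python
def stacked_shape(shapes):
    """Batch-first shape obtained by stacking values of the given shapes,
    or None when they cannot be stacked (hypothetical helper, shapes only)."""
    if not shapes:
        return None
    first = tuple(shapes[0])
    # Stacking along a new leading axis requires every value to share one shape.
    if any(tuple(shape) != first for shape in shapes):
        return None
    return (len(shapes),) + first


# Two single-channel Colin27 images of size 181 x 217 x 181.
print(stacked_shape([(1, 181, 217, 181), (1, 181, 217, 181)]))  # (2, 1, 181, 217, 181)
# Mismatched shapes cannot be stacked, so a plain list would be returned instead.
print(stacked_shape([(1, 181, 217, 181), (1, 91, 109, 91)]))    # None
```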

A Batch can also move its tensors to a device or memory format with to(), and it accepts new fields produced by a network through add_field(), add_images() and add_masks(), which is handy to store a model's output next to its input.
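The exact signatures of add_field(), add_images() and add_masks() are not given here, so the following is only an illustration of the idea using a hypothetical ToyBatch class (a list of dict samples): the network's output for the whole batch is split back so that sample i receives the i-th value, next to its other fields.

```python
class ToyBatch(list):
    """Hypothetical stand-in for a Batch: a list of samples stored as dicts."""

    def get_field(self, name):
        # Gather one field across all samples, in batch order.
        return [sample[name] for sample in self]

    def add_field(self, name, values):
        # Store one value per sample, e.g. a model prediction next to its input.
        values = list(values)
        if len(values) != len(self):
            raise ValueError(f"expected {len(self)} values, got {len(values)}")
        for sample, value in zip(self, values):
            sample[name] = value


batch = ToyBatch([{"participant_id": "sub-01"}, {"participant_id": "sub-02"}])
batch.add_field("prediction", [0.9, 0.2])  # e.g. one model score per sample
print(batch[1])  # {'participant_id': 'sub-02', 'prediction': 0.2}
```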

1.5.2. Iterating over a dataset with a DataLoader

DataLoader is a subclass of torch.utils.data.DataLoader, so it behaves like the PyTorch DataLoader you may already know and accepts its common parameters (batch_size, shuffle, num_workers, pin_memory, etc.).

Important

Only the ClinicaDL DataLoader is guaranteed to work with ClinicaDL datasets. Prefer it over the raw PyTorch one.

>>> from clinicadl.data.datasets import BidsDataset
>>> from clinicadl.io.bids import BidsFileType
>>> from clinicadl.data.dataloader import DataLoader

>>> dataset = BidsDataset(
...     "bids_directory",  # path to your BIDS dataset
...     file_type=BidsFileType(data_type="anat", suffix="T1w"),
... )
>>> loader = DataLoader(dataset, batch_size=3, shuffle=True)
>>> batch = next(iter(loader))
>>> len(batch)
3
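Because DataLoader subclasses the PyTorch one, the usual PyTorch batching rule should apply with the default sampler: when the dataset size is not a multiple of batch_size, the last batch is smaller, unless drop_last=True (a standard PyTorch argument; that ClinicaDL forwards it unchanged is an assumption). The hypothetical helper below lists the batch sizes you would get:

```python
import math


def batch_sizes(n_samples, batch_size, drop_last=False):
    """Sizes of the successive batches yielded by a PyTorch-style loader
    with the default sampler (hypothetical helper)."""
    full, remainder = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # the final, incomplete batch
    return sizes


print(batch_sizes(10, 3))                            # [3, 3, 3, 1]
print(batch_sizes(10, 3, drop_last=True))            # [3, 3, 3]
print(len(batch_sizes(10, 3)) == math.ceil(10 / 3))  # True
```

So len(batch) is 3 for every batch except possibly the last one.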

Beyond the standard PyTorch arguments, the DataLoader adds sampling_weights: the name of a column of the dataset's df whose values are used as sampling probabilities. This is a convenient way to oversample under-represented classes.
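A common choice for such a column is inverse class frequency, so that every class carries the same total weight. The sketch below computes these weights in plain Python and then draws indices with replacement, proportionally to the weights, in the spirit of torch.utils.data.WeightedRandomSampler. Whether ClinicaDL samples with replacement is an assumption here, and the labels are made up for illustration.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """One weight per row: 1 / (number of rows sharing that label)."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


labels = ["AD", "CN", "CN", "CN"]  # e.g. one diagnosis per row of the dataset's df
weights = inverse_frequency_weights(labels)
print(weights)  # AD row: 1.0, each CN row: 1/3, so both classes sum to 1.0

# Draw many indices with replacement, proportionally to the weights.
rng = random.Random(0)
indices = rng.choices(range(len(labels)), weights=weights, k=10_000)
share_ad = sum(labels[i] == "AD" for i in indices) / len(indices)
print(f"share of AD samples: {share_ad:.2f}")  # close to 0.50, versus 0.25 in the data
```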

1.5.2.1. Collating: how to assemble samples?

The DataLoader assembles individual samples into a Batch with a collate function, passed through the collate_fn argument and described by the abstract class CollateFn. ClinicaDL chooses a sensible default, so you usually do not need to set it:

  • ToBatchCollate: the default when each dataset element is a single Sample. It produces a single Batch.

  • ToBatchesCollate: the default when each element is a tuple of samples, as returned by a PairedDataset or UnpairedDataset (see joining datasets). It produces one Batch per dataset.

  • MergeBatchesCollate: pass it explicitly to merge such a tuple into a single Batch instead.

So, with a PairedDataset, the loader returns a tuple of batches by default:

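>>> from clinicadl.data.datasets import PairedDataset
>>> from clinicadl.data.dataloader import DataLoader

>>> # dataset_t1 and dataset_pet: two datasets to pair, e.g. a T1w and a PET
>>> # BidsDataset built as above with different file types
>>> paired = PairedDataset([dataset_t1, dataset_pet])
>>> loader = DataLoader(paired, batch_size=3, shuffle=False)
>>> batch_t1, batch_pet = next(iter(loader))  # one Batch per modality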

With MergeBatchesCollate, it returns a single Batch instead:

>>> from clinicadl.data.dataloader import MergeBatchesCollate

>>> loader = DataLoader(paired, batch_size=3, shuffle=False, collate_fn=MergeBatchesCollate())
>>> batch = next(iter(loader))
>>> # the two images of each pair are merged, hence 2 channels
>>> batch.get_field("image").shape
torch.Size([3, 2, 181, 217, 181])
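The two behaviours above can be mimicked in plain Python to see where the shapes come from. In this sketch, each image is a list of channels (channel names stand in for the actual arrays), and to_batches and merge_batches are hypothetical helpers mirroring "one Batch per dataset" and "channel-wise merge into a single Batch"; they are not ClinicaDL code.

```python
def to_batches(elements):
    """One batch per dataset: regroup [(a1, b1), (a2, b2)] into ([a1, a2], [b1, b2])."""
    return tuple(list(group) for group in zip(*elements))


def merge_batches(elements):
    """A single batch: concatenate the channels of each element's images."""
    return [[channel for image in element for channel in image] for element in elements]


# Three paired elements, each a (T1w image, PET image) tuple with one channel each.
elements = [(["t1"], ["pet"]) for _ in range(3)]

batch_t1, batch_pet = to_batches(elements)
print(len(batch_t1), len(batch_pet))  # 3 3
merged = merge_batches(elements)
print(len(merged), len(merged[0]))    # 3 2 -> 3 samples with 2 channels each
```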

To define your own collating behaviour, subclass CollateFn and implement its __call__ method.
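The sketch below shows what such a subclass can look like. Since the exact signature of CollateFn.__call__ is not documented here, it uses a hypothetical stand-in base class, CollateFnSketch, and assumes that __call__ receives the list of dataset elements drawn for one batch. A real implementation would subclass CollateFn and return a Batch rather than a plain list.

```python
from abc import ABC, abstractmethod


class CollateFnSketch(ABC):
    """Hypothetical stand-in for ClinicaDL's abstract CollateFn."""

    @abstractmethod
    def __call__(self, elements):
        """Turn the list of dataset elements drawn for one batch into a batch."""


class SelectDatasetCollate(CollateFnSketch):
    """Keep only one dataset of each paired element, e.g. to ignore the PET images."""

    def __init__(self, index):
        self.index = index

    def __call__(self, elements):
        return [element[self.index] for element in elements]


collate = SelectDatasetCollate(index=0)
print(collate([("t1-a", "pet-a"), ("t1-b", "pet-b")]))  # ['t1-a', 't1-b']
```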

1.5.2.2. Building DataLoaders from a split

In practice you build one loader for the training set and one for the validation set. The Split returned by a splitter (see Splitting data) provides methods to do so:

# split: a Split returned by a splitter (see Splitting data)
split.build_train_loader(batch_size=8, shuffle=True)
split.build_val_loader(batch_size=8)

# the dataloaders are now available as attributes of the split
train_loader = split.train_loader
val_loader = split.val_loader
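When you iterate over val_loader to evaluate a model, keep in mind that the last batch can hold fewer than batch_size samples. Averaging per-batch means directly then over-weights that batch; weighting each batch mean by its size recovers the mean over all samples. A minimal sketch with made-up per-batch losses:

```python
def dataset_mean(batch_means, sizes):
    """Mean over all samples, given the mean and the size of each batch."""
    total = sum(mean * size for mean, size in zip(batch_means, sizes))
    return total / sum(sizes)


# e.g. 17 validation samples with batch_size=8 -> batches of 8, 8 and 1 samples
batch_means = [0.40, 0.50, 1.30]  # hypothetical per-batch losses
sizes = [8, 8, 1]
print(round(dataset_mean(batch_means, sizes), 3))     # 0.5
print(round(sum(batch_means) / len(batch_means), 3))  # 0.733, biased by the last batch
```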

This closes Chapter 1: you can now load your data, transform it, split it without leakage, and iterate over it in batches. The next chapter shows how these batches are used to train a model.