1.5. Batching data for training
When training a neural network, you typically feed it multiple images together,
known as a batch. ClinicaDL provides a
DataLoader that iterates over a
Dataset and groups its samples into a
Batch (or a sequence of Batch), ready to be turned into tensors.
1.5.1. Grouping samples in a batch
A Batch is a list of
DataPoints with a few extra
conveniences. The most useful is
get_field(), which gathers one field
across all the samples and returns it as a batch-first torch.Tensor when
possible (and a plain list otherwise):
from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.data.dataloader import Batch
batch = Batch([Colin27DataPoint(), Colin27DataPoint()])
>>> batch.get_field("image").shape
torch.Size([2, 1, 181, 217, 181])
>>> batch.get_field("participant_id")
['sub-colin', 'sub-colin']
A Batch can also move its tensors to a device or memory format via
to(), and accept new fields produced by a
network via add_field(),
add_images() and
add_masks(), which is handy for storing a model’s
output next to its input.
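The idea behind get_field() and add_field() can be sketched without any dependency. The MiniBatch class below is a hypothetical stand-in written for illustration, not ClinicaDL's Batch: samples are plain dictionaries, and no tensor stacking or device handling is attempted.

```python
class MiniBatch(list):
    """Hypothetical stand-in for a batch: a list of dict samples with helpers."""

    def get_field(self, name):
        # Gather one field across all samples, in batch order.
        return [sample[name] for sample in self]

    def add_field(self, name, values):
        # Store one new value per sample, e.g. a model output.
        if len(values) != len(self):
            raise ValueError("expected one value per sample")
        for sample, value in zip(self, values):
            sample[name] = value


batch = MiniBatch([
    {"participant_id": "sub-01", "age": 61},
    {"participant_id": "sub-02", "age": 74},
])
batch.add_field("prediction", [0.2, 0.9])  # made-up model outputs
ids = batch.get_field("participant_id")
predictions = batch.get_field("prediction")
```

Keeping the prediction in the same container as the participant identifier means later steps never have to re-align outputs with inputs.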
1.5.2. Iterating over a dataset with a DataLoader
DataLoader iterates over a
Dataset and groups its samples into a
Batch (or a sequence of Batch). It is a subclass of
torch.utils.data.DataLoader, so it behaves like the PyTorch dataloader
you may already know, with many common parameters (batch_size, shuffle,
num_workers, pin_memory, etc.).
Important
Only the ClinicaDL DataLoader is guaranteed to work with ClinicaDL datasets.
Prefer it over the raw PyTorch one.
from clinicadl.data.datasets import BidsDataset
from clinicadl.io.bids import BidsFileType
from clinicadl.data.dataloader import DataLoader
dataset = BidsDataset(
"bids_directory",
file_type=BidsFileType(data_type="anat", suffix="T1w"),
)
loader = DataLoader(dataset, batch_size=3, shuffle=True)
>>> batch = next(iter(loader))
>>> len(batch)
3
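To make the roles of batch_size and shuffle concrete, here is a minimal, dependency-free sketch of how a loader might group sample indices. The function iter_batches is a hypothetical helper, not part of ClinicaDL or PyTorch; it keeps the last, smaller batch, which matches PyTorch's behaviour when drop_last is left at its default of False.

```python
import random


def iter_batches(n_samples, batch_size, shuffle=False, seed=None):
    """Yield lists of sample indices, the way a data loader groups them."""
    indices = list(range(n_samples))
    if shuffle:
        # Shuffle once per pass over the dataset.
        random.Random(seed).shuffle(indices)
    for start in range(0, n_samples, batch_size):
        yield indices[start:start + batch_size]


batches = list(iter_batches(7, batch_size=3, shuffle=True, seed=0))
batch_sizes = [len(b) for b in batches]
```

With 7 samples and a batch size of 3, the pass yields two full batches and one batch holding the single remaining sample.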
Beyond the standard PyTorch arguments, the DataLoader adds sampling_weights:
the name of a column of the dataset’s df
whose values are used as sampling probabilities. This is convenient for oversampling
under-represented classes.
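One common way to fill such a column is inverse class frequency. The sketch below uses only the standard library and made-up labels; whether ClinicaDL expects the values to be normalised is not stated here, so treat the unnormalised weights as an assumption.

```python
import random
from collections import Counter

# Toy labels: 8 controls and 2 patients (hypothetical data).
labels = ["CN"] * 8 + ["AD"] * 2
counts = Counter(labels)

# Inverse-frequency weight per sample: rarer classes get larger weights.
weights = [1.0 / counts[label] for label in labels]

# Draw sample indices with replacement, proportionally to the weights.
rng = random.Random(0)
drawn = rng.choices(range(len(labels)), weights=weights, k=10_000)
ad_fraction = sum(labels[i] == "AD" for i in drawn) / len(drawn)
```

Because each class carries the same total weight, the minority class ends up in roughly half of the draws even though it makes up 20% of the data.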
1.5.2.1. Collating: how to assemble samples?
How individual samples are assembled into a Batch by the DataLoader is decided by a collate
function, passed through the collate_fn argument and described by the abstract
CollateFn. ClinicaDL chooses a sensible
default, so you usually do not need to set it:
- ToBatchCollate: the default when each dataset element is a single Sample. It produces a single Batch.
- ToBatchesCollate: the default when each element is a tuple of samples, as returned by a PairedDataset or UnpairedDataset (see joining datasets). It produces one Batch per dataset.
- MergeBatchesCollate: merges such a tuple into a single Batch instead.
So, with a PairedDataset, the loader returns a
tuple of batches by default:
from clinicadl.data.datasets import PairedDataset
# dataset_t1 and dataset_pet are two datasets built beforehand, one per modality
paired = PairedDataset([dataset_t1, dataset_pet])
loader = DataLoader(paired, batch_size=3, shuffle=False)
>>> batch_t1, batch_pet = next(iter(loader)) # one Batch per modality
With MergeBatchesCollate, the loader returns a single batch instead:
from clinicadl.data.dataloader import MergeBatchesCollate
loader = DataLoader(paired, batch_size=3, shuffle=False, collate_fn=MergeBatchesCollate())
>>> batch = next(iter(loader))
>>> batch.get_field("image").shape
torch.Size([3, 2, 181, 217, 181]) # 2 channels because the two images were merged
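The difference between the two behaviours can be illustrated on toy data. In this sketch, to_batches and merge_batches are hypothetical functions that only mimic the shape of the result; they are not ClinicaDL's collate classes, and each "image" is a plain list of channel labels rather than a tensor.

```python
def to_batches(pairs):
    """One list per dataset: [(a1, b1), (a2, b2)] -> ([a1, a2], [b1, b2])."""
    return tuple(list(group) for group in zip(*pairs))


def merge_batches(pairs):
    """One merged list: the channels of each pair are concatenated."""
    return [
        {"image": first["image"] + second["image"]}
        for first, second in pairs
    ]


# Each toy image is a list of channels; a channel is just a label here.
pairs = [
    ({"image": ["t1_0"]}, {"image": ["pet_0"]}),
    ({"image": ["t1_1"]}, {"image": ["pet_1"]}),
    ({"image": ["t1_2"]}, {"image": ["pet_2"]}),
]
batch_t1, batch_pet = to_batches(pairs)
merged = merge_batches(pairs)
```

Both results hold three samples. In the merged one, every sample carries two channels, which mirrors the channel dimension of 2 in the shape shown above.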
To define your own collating behaviour, subclass
CollateFn and implement its __call__
method.
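As a rough sketch of what such a callable looks like, the class below keeps only selected fields and groups them per field. It is a plain Python class rather than a CollateFn subclass, because the exact interface of CollateFn is not shown here; the class name, the field names and the dictionary samples are all illustrative assumptions. The one thing it relies on is that a data loader calls the collate function with the list of samples picked for one batch.

```python
class KeepFieldsCollate:
    """Hypothetical collate: keep only selected fields, grouped per field."""

    def __init__(self, fields):
        self.fields = list(fields)

    def __call__(self, samples):
        # `samples` is the list of dataset elements picked for one batch.
        return {
            field: [sample[field] for sample in samples]
            for field in self.fields
        }


collate = KeepFieldsCollate(["participant_id", "label"])
batch = collate([
    {"participant_id": "sub-01", "label": 0, "site": "A"},
    {"participant_id": "sub-02", "label": 1, "site": "B"},
])
```

Calling the object directly on a list of samples, as done here, is a convenient way to test a collate function before handing it to a loader.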
1.5.2.2. Building DataLoaders from a split
In practice you build one loader for the training set and one for the validation
set. The Split returned by a splitter (see
Splitting data) provides methods to do so:
split.build_train_loader(batch_size=8, shuffle=True)
split.build_val_loader(batch_size=8)
# now you can access the dataloaders
train_loader = split.train_loader
val_loader = split.val_loader
This closes Chapter 1: you can now load your data, transform it, split it without leakage, and iterate over it in batches. The next chapter demonstrates how these batches are used to train a model.