1.3. Transforming data

Raw neuroimaging data is rarely fed to a network as-is: it is normalised, possibly cut into patches or slices, resized to fit the network, and — during training — augmented. In ClinicaDL, this whole pipeline is described by a single object, the TransformsHandler, which you pass to a dataset through its transforms argument.

1.3.1. Defining the transformation pipeline

A TransformsHandler organises transforms into four stages, applied in this order:

  1. image_transforms — transforms applied to the whole image, before extraction. This is where normalisation belongs, so that statistics are computed on the full image and not on a single patch or slice.

  2. extraction — what you work on: the whole image, patches, or slices.

  3. sample_transforms — transforms applied to a sample (an image, a patch or a slice), after extraction. This is typically where you resize a sample to fit your network.

  4. augmentations — transforms applied last, only during training.

from clinicadl.transforms import TransformsHandler, extraction
import torchio as tio

transforms = TransformsHandler(
    image_transforms=[tio.ZNormalization()],
    extraction=extraction.Patch(patch_size=64),
    sample_transforms=[tio.CropOrPad(64)],
    augmentations=[tio.RandomFlip()],
)

image_transforms, sample_transforms and augmentations are sequences: the handler composes them in order, so the order matters. A dataset using these transforms automatically applies the augmentations only when it is in training mode (see Dataset.train / Dataset.eval).
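The stage ordering can be pictured with plain Python callables. The sketch below is not ClinicaDL's implementation: apply_pipeline, tag and the string-based toy transforms are hypothetical names introduced for illustration. It only shows that the stages run in sequence and that augmentations are skipped outside training.

```python
from typing import Callable, Sequence

Transform = Callable[[str], str]


def apply_pipeline(
    data: str,
    image_transforms: Sequence[Transform],
    extraction: Transform,
    sample_transforms: Sequence[Transform],
    augmentations: Sequence[Transform],
    training: bool,
) -> str:
    """Hypothetical helper: apply the four stages in their documented order."""
    for transform in image_transforms:      # stage 1: whole image
        data = transform(data)
    data = extraction(data)                 # stage 2: image, patch or slice
    for transform in sample_transforms:     # stage 3: extracted sample
        data = transform(data)
    if training:                            # stage 4: training mode only
        for transform in augmentations:
            data = transform(data)
    return data


def tag(name: str) -> Transform:
    """Build a toy transform that appends its name to the data string."""
    return lambda data: f"{data}>{name}"


stages = dict(
    image_transforms=[tag("normalise")],
    extraction=tag("patch"),
    sample_transforms=[tag("resize")],
    augmentations=[tag("flip")],
)
train_trace = apply_pipeline("image", training=True, **stages)
eval_trace = apply_pipeline("image", training=False, **stages)
print(train_trace)  # image>normalise>patch>resize>flip
print(eval_trace)   # image>normalise>patch>resize
```

Swapping two entries inside one of the lists changes the trace accordingly, which is why the order within each sequence matters.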

1.3.2. Patches and slices

The extraction decides what a single element of the dataset is. ClinicaDL provides three extractions, in clinicadl.transforms.extraction:

Image

No extraction: each sample is a whole image. This is the default.

Patch

Each sample is a 3D patch, obtained with a sliding window. The main argument is patch_size (a single value, or one per spatial dimension); overlap, pad_mode and pad_value control how patches are tiled across the image.

Slice

Each sample is a 2D slice taken along the axis specified via slice_direction. Which slices to keep can be chosen with slices, discarded_slices, borders or tsv_path.

The extraction also determines how many samples an image yields, which you can check on a single DataPoint:

from clinicadl.transforms.extraction import Patch
from clinicadl.data.structures.examples import Colin27DataPoint

data = Colin27DataPoint()
patch = Patch(patch_size=64)
>>> data.spatial_shape
(181, 217, 181)
>>> patch.num_samples_per_image(data)
36
>>> patch(data, sample_index=0).spatial_shape
(64, 64, 64)
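The figure of 36 patches can be checked by hand. The sketch below assumes no overlap and assumes the image is padded so that the last window along each axis is complete. count_patches is a hypothetical helper, not part of ClinicaDL.

```python
import math


def count_patches(spatial_shape, patch_size):
    """Hypothetical helper: windows per axis and in total, without overlap.

    Assumes the image is padded so that a partial window at the border
    still counts as one patch.
    """
    per_axis = [math.ceil(dim / patch_size) for dim in spatial_shape]
    return per_axis, math.prod(per_axis)


per_axis, total = count_patches((181, 217, 181), patch_size=64)
print(per_axis)  # [3, 4, 3]
print(total)     # 36
```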

Slicing works the same way:

from clinicadl.transforms.extraction import Slice

slices = Slice(borders=10, slice_direction=1)
>>> slices.num_samples_per_image(data)
197     # 217 coronal slices minus 2 × 10 borders
>>> slices(data, sample_index=0).spatial_shape
(181, 1, 181)

Remember that, because each image now yields several samples, the length of a dataset is the number of images times the number of samples per image (see Reading BIDS datasets).
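As a worked example of this arithmetic, the sketch below computes a dataset length and maps a flat dataset index back to an image and a sample within that image. The number of images (100) is made up, and the index layout (all samples of image 0, then all samples of image 1, and so on) is an assumption made for illustration, not a documented guarantee.

```python
def dataset_length(n_images: int, samples_per_image: int) -> int:
    """Total number of samples when every image yields the same count."""
    return n_images * samples_per_image


def locate(index: int, samples_per_image: int) -> tuple[int, int]:
    """Assumed layout: return (image index, sample index within the image)."""
    return divmod(index, samples_per_image)


# Slice extraction from the example above: 217 slices minus two borders of 10.
slices_per_image = 217 - 2 * 10
print(slices_per_image)                 # 197

# Patch extraction from the example above: 36 patches per image.
print(dataset_length(100, 36))          # 3600
print(locate(75, samples_per_image=36)) # (2, 3)
```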

1.3.3. Preprocessing, augmentation and post-processing

ClinicaDL works with any callable that takes a DataPoint and returns a DataPoint. This means you can use, as a transform, either a TorchIO transform or a plain function.

For instance, a simple preprocessing-and-augmentation pipeline built entirely from raw TorchIO transforms:

import torchio as tio
from clinicadl.transforms import TransformsHandler, extraction

transforms = TransformsHandler(
    extraction=extraction.Image(),
    image_transforms=[
        tio.ToCanonical(),         # reorient to RAS+
        tio.ZNormalization(),      # intensity normalisation on the whole image
    ],
    augmentations=[
        tio.RandomFlip(axes=("LR",)),
        tio.RandomAffine(degrees=10),
    ],
)

A custom transform is just a function. Here is one that adds a binary foreground mask derived from the image:

import torch
from clinicadl.data.structures import DataPoint
from clinicadl.data.structures.examples import Colin27DataPoint

def add_foreground_mask(datapoint: DataPoint) -> DataPoint:
    mask = (datapoint.get_image_tensor("image") > 0).to(torch.int)
    datapoint.add_mask(mask, "foreground")
    return datapoint

data = Colin27DataPoint()
>>> add_foreground_mask(data)
Colin27DataPoint(Keys: ('head', 'image', 'participant_id', 'session_id', 'foreground'); images: 3)

Plain functions work, but the recommended way to write your own transform is to inherit from torchio.Transform.

Where do the transforms apply?

Put normalisation in image_transforms (computed on the whole image), resizing to the network input size in sample_transforms (after a patch or slice has been extracted), and data augmentation in augmentations (active only during training).
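To see why the stage matters for normalisation, the toy sketch below z-normalises the same "patch" twice: once with statistics from the whole "image" and once with statistics from the patch alone. The 1D lists and the z_normalise helper are made up for illustration and do not reproduce what tio.ZNormalization computes internally.

```python
from statistics import mean, pstdev


def z_normalise(values, reference):
    """Centre and scale values using the mean and std of reference."""
    mu, sigma = mean(reference), pstdev(reference)
    return [(value - mu) / sigma for value in values]


image = [0.0, 0.0, 0.0, 0.0, 10.0, 20.0, 30.0, 40.0]  # toy whole "image"
patch = image[4:]                                     # toy extracted "patch"

# Statistics from the whole image (what image_transforms would see).
global_norm = z_normalise(patch, reference=image)
# Statistics from the patch only (what sample_transforms would see).
local_norm = z_normalise(patch, reference=patch)

# The patch-level version is centred on zero whatever the patch contains,
# so the information that this patch is brighter than the image is lost.
print(mean(global_norm) > 0)          # True
print(abs(mean(local_norm)) < 1e-9)   # True
```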

Note

In this chapter we deliberately use raw transforms — TorchIO transforms and plain functions. ClinicaDL also offers configuration classes that wrap many of these transforms in a serialisable form, for better reproducibility. Those are introduced later, in Chapter 3.


Your data is now read and transformed. Before training, you need to separate it into training, validation and test sets — the topic of the next section.