clinicadl.data.dataloader.DataLoader

class clinicadl.data.dataloader.DataLoader(dataset: Dataset[SampleT], *, batch_size: int = 1, sampling_weights: str | None = None, shuffle: bool = True, num_workers: int = 0, pin_memory: bool = True, drop_last: bool = False, prefetch_factor: int | None = None, persistent_workers: bool = False, collate_fn: CollateFn | None = None)

Loads data in batches.

It inherits from torch.utils.data.DataLoader. Only this dataloader is guaranteed to work with ClinicaDL.

The format of the batches returned by the iterator can be customized via collate_fn.

Parameters:
  • dataset (Dataset) – Dataset from which to load the data.

  • batch_size (PositiveInt, default=1) – Number of samples per batch.

  • sampling_weights (Optional[str], default=None) –

    Name of the column in Dataset.df that contains the sampling weights. The column must contain float values.

    The probability of drawing a given sample is proportional to its value in this column.

    Warning

    sampling_weights is not supported with an UnpairedDataset.

  • shuffle (bool, default=True) –

    Whether to shuffle the data.

    Note

    If sampling_weights is passed, samples are drawn randomly with replacement, regardless of the value of shuffle.

  • num_workers (NonNegativeInt, default=0) – Number of workers for data loading.

  • pin_memory (bool, default=True) – Whether to copy tensors into device/CUDA pinned memory before returning them.

  • drop_last (bool, default=False) – Whether to drop the last incomplete batch.

  • prefetch_factor (Optional[int], default=None) – Number of batches loaded in advance by each worker. Can’t be passed if num_workers=0.

  • persistent_workers (bool, default=False) – Whether to keep the worker processes alive after the dataset has been consumed once, i.e. at the end of an epoch. Can’t be passed if num_workers=0.

  • collate_fn (Optional[CollateFn], default=None) – Customizes how samples are collated into batches. See clinicadl.data.dataloader.collate.
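To make the interaction between sampling_weights, shuffle, batch_size and drop_last concrete, the sketch below uses only the standard library: weighted draws with replacement when weights are given (shuffle is then ignored), otherwise an optional shuffle, then grouping into batches with optional dropping of the last incomplete batch. It is a hypothetical illustration, not ClinicaDL's implementation: batch_indices and its seed argument are invented for the example, and drawing exactly one index per sample per epoch is an assumption (it matches torch.utils.data.WeightedRandomSampler with num_samples=len(dataset)).

```python
import random
from typing import Optional, Sequence


def batch_indices(
    n_samples: int,
    batch_size: int = 1,
    weights: Optional[Sequence[float]] = None,
    shuffle: bool = True,
    drop_last: bool = False,
    seed: int = 0,
) -> list[list[int]]:
    """Hypothetical sketch: draw sample indices for one epoch and group them into batches."""
    rng = random.Random(seed)
    if weights is not None:
        # Weighted sampling WITH replacement: P(index i) is proportional to weights[i].
        # Assumption: one epoch draws n_samples indices; `shuffle` is ignored here.
        order = rng.choices(range(n_samples), weights=weights, k=n_samples)
    elif shuffle:
        # Random permutation without replacement: every index appears exactly once.
        order = list(range(n_samples))
        rng.shuffle(order)
    else:
        # Sequential order.
        order = list(range(n_samples))

    # Split the index order into consecutive chunks of batch_size.
    batches = [order[i : i + batch_size] for i in range(0, len(order), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the last incomplete batch
    return batches


print(batch_indices(7, batch_size=3, shuffle=False))  # [[0, 1, 2], [3, 4, 5], [6]]
print(batch_indices(7, batch_size=3, shuffle=False, drop_last=True))  # [[0, 1, 2], [3, 4, 5]]
```

With 7 samples and batch_size=3, the last batch holds a single sample; drop_last=True removes it, so one epoch yields 2 batches instead of 3.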

Raises:
  • ValueError – If prefetch_factor or persistent_workers is passed, but num_workers=0.

  • ValueError – If the dataset is an UnpairedDataset and sampling_weights is not None.

  • KeyError – If sampling_weights is not None, but the dataframe of the dataset has no column with that name.

  • ValueError – If sampling_weights is not None and the associated column cannot be converted to float values.
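The checks listed above can be restated as a small validation function. This is a hypothetical sketch, not part of ClinicaDL: check_loader_args is invented for the example, the dataset's dataframe is replaced by a plain mapping from column names to values, the UnpairedDataset check is reduced to an is_unpaired flag, "passing" persistent_workers is assumed to mean setting it to True, and the order in which ClinicaDL actually runs these checks is not documented here.

```python
from typing import Mapping, Optional, Sequence


def check_loader_args(
    columns: Mapping[str, Sequence[object]],
    *,
    is_unpaired: bool = False,
    sampling_weights: Optional[str] = None,
    num_workers: int = 0,
    prefetch_factor: Optional[int] = None,
    persistent_workers: bool = False,
) -> Optional[list[float]]:
    """Hypothetical re-statement of the documented errors; returns the weights as floats."""
    # Worker-related options only make sense with at least one worker process.
    # Assumption: "passing" persistent_workers means setting it to True.
    if num_workers == 0 and (prefetch_factor is not None or persistent_workers):
        raise ValueError("prefetch_factor and persistent_workers require num_workers > 0")
    if sampling_weights is None:
        return None
    # is_unpaired stands in for an isinstance check against UnpairedDataset.
    if is_unpaired:
        raise ValueError("sampling_weights is not supported with an UnpairedDataset")
    # `columns` stands in for the dataset's dataframe.
    if sampling_weights not in columns:
        raise KeyError(sampling_weights)
    try:
        return [float(v) for v in columns[sampling_weights]]
    except (TypeError, ValueError) as err:
        raise ValueError(f"column {sampling_weights!r} must contain float values") from err


print(check_loader_args({"w": ["0.5", 1]}, sampling_weights="w"))  # [0.5, 1.0]
```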

Examples

Suppose the data are organized as follows:

bids
├── metadata.tsv
├── sub-001
│   ├── ses-M000
│   │   └── pet
│   │       └── sub-001_ses-M000_trc-18FAV45_pet.nii.gz
    ...
...

The "metadata.tsv" file looks like:

participant_id  session_id   age   sex
sub-001         ses-M000     55.0  M
...
from clinicadl.data.datasets import BidsDataset, PairedDataset
from clinicadl.io.bids import BidsFileType
from clinicadl.data.dataloader import DataLoader

bids = BidsDataset(
    "bids",
    file_type=BidsFileType(data_type="pet", suffix="pet"),
    data="bids/metadata.tsv",
)
dataloader = DataLoader(bids, batch_size=3, shuffle=False)
>>> batch = next(iter(dataloader))
>>> batch
[Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
    Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
    Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1)]

Now, let’s see what happens with a PairedDataset:

paired_dataset = PairedDataset([bids, bids])
dataloader = DataLoader(paired_dataset, batch_size=3, shuffle=False)
>>> batch = next(iter(dataloader))
>>> batch
([Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
    Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
    Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1)],
    [Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
    Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
    Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1)])

Because the default behavior is to collate batches with ToBatchesCollate, we obtain here a tuple of \(n\) batches, where \(n\) is the number of paired datasets.
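As a rough illustration of this regrouping, the sketch below assumes that each item of a PairedDataset is a tuple holding one sample per paired dataset; to_batches is a hypothetical stand-in written for this example, not the actual ToBatchesCollate. Transposing the list of tuples with zip(*samples) yields one list per paired dataset, which is the tuple-of-batches shape shown in the output above.

```python
from typing import Any, Sequence


def to_batches(samples: Sequence[Any]) -> Any:
    """Hypothetical sketch of the default collation behavior described above."""
    if samples and isinstance(samples[0], tuple):
        # Paired case (assumed item layout): each item is (s_1, ..., s_n), one sample
        # per paired dataset. zip(*samples) regroups them into n batches.
        return tuple(list(group) for group in zip(*samples))
    # Single dataset: the batch is simply the list of samples.
    return list(samples)


print(to_batches(["s1", "s2", "s3"]))  # ['s1', 's2', 's3']
print(to_batches([("a1", "b1"), ("a2", "b2"), ("a3", "b3")]))  # (['a1', 'a2', 'a3'], ['b1', 'b2', 'b3'])
```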

See also

torch.utils.data.DataLoader

For more details on the parameters.

set_epoch(epoch: int) → None

Sets the epoch.

This ensures a different random ordering for torch.utils.data.distributed.DistributedSampler and a different random mapping for clinicadl.data.datasets.UnpairedDataset for each epoch.

Parameters:

epoch (int) – Epoch number.
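A minimal sketch of why set_epoch matters, assuming a sampler that shuffles with a generator seeded by seed + epoch (the approach torch.utils.data.distributed.DistributedSampler uses): the same epoch always reproduces the same order, and different epochs give different orders. EpochShuffler is a hypothetical class written for this example, using the standard library instead of torch.

```python
import random
from typing import Iterator


class EpochShuffler:
    """Hypothetical sampler: a different, reproducible shuffle for each epoch."""

    def __init__(self, n_samples: int, seed: int = 0) -> None:
        self.n_samples = n_samples
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Only stores the epoch; it takes effect at the next iteration.
        self.epoch = epoch

    def __iter__(self) -> Iterator[int]:
        # Seed with seed + epoch, so the permutation depends on the epoch number.
        rng = random.Random(self.seed + self.epoch)
        order = list(range(self.n_samples))
        rng.shuffle(order)
        return iter(order)


sampler = EpochShuffler(10, seed=42)
for epoch in range(3):
    sampler.set_epoch(epoch)  # call before iterating over the epoch
    print(epoch, list(sampler))
```

In this sketch, forgetting to call set_epoch means every epoch reuses the epoch-0 order.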