clinicadl.data.dataloader.DataLoader
- class clinicadl.data.dataloader.DataLoader(dataset: Dataset[SampleT], *, batch_size: int = 1, sampling_weights: str | None = None, shuffle: bool = True, num_workers: int = 0, pin_memory: bool = True, drop_last: bool = False, prefetch_factor: int | None = None, persistent_workers: bool = False, collate_fn: CollateFn | None = None)
Loads data in batches.
It inherits from torch.utils.data.DataLoader. Only this dataloader is guaranteed to work with ClinicaDL.
The format of the output of the iterator can be specified via collate_fn.
- Parameters:
dataset (Dataset) – Dataset from which to load the data.
batch_size (PositiveInt, default=1) – Batch size.
sampling_weights (Optional[str], default=None) –
Name of the column in Dataset.df that contains the sampling weights. The column must contain float values.
The probability of drawing a given sample is proportional to its value in this column.
Warning
sampling_weights doesn’t work with an UnpairedDataset.
shuffle (bool, default=True) –
Whether to shuffle the data.
Note
If sampling_weights is passed, the data will be fetched randomly with replacement, no matter the value of shuffle.
num_workers (NonNegativeInt, default=0) – Number of workers for data loading.
pin_memory (bool, default=True) – Whether to copy tensors into device/CUDA pinned memory before returning them.
drop_last (bool, default=False) – Whether to drop the last incomplete batch.
prefetch_factor (Optional[int], default=None) – Number of batches loaded in advance by each worker. Can’t be passed if num_workers=0.
persistent_workers (bool, default=False) – Whether to keep the worker processes alive at the end of an epoch. Can’t be passed if num_workers=0.
collate_fn (Optional[CollateFn], default=None) – To customize the way samples are collated into batches. See clinicadl.data.dataloader.collate.
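To make the sampling semantics of sampling_weights concrete, here is a minimal standard-library sketch. It is not ClinicaDL's implementation: the draw_indices helper and the weight values are hypothetical, and random.choices merely stands in for whatever sampler the library uses internally. It shows indices being drawn with replacement, with a probability proportional to each weight.

```python
import random
from collections import Counter


def draw_indices(weights, num_samples, seed=0):
    """Draw sample indices with replacement, proportionally to `weights`.

    Hypothetical helper for illustration only (not part of ClinicaDL).
    """
    rng = random.Random(seed)
    return rng.choices(range(len(weights)), weights=weights, k=num_samples)


# Hypothetical content of the weights column: one float per sample.
weights = [0.1, 0.1, 0.8]
indices = draw_indices(weights, num_samples=10_000)
counts = Counter(indices)

# Sample 2 carries 80% of the total weight, so it is drawn roughly 80% of
# the time. The same index shows up repeatedly: sampling is with replacement.
print({i: counts[i] / len(indices) for i in range(len(weights))})
```

Because the draw is with replacement, a single pass over the loader may see some samples several times and others not at all, which is why shuffle has no effect in that mode.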
- Raises:
ValueError – If prefetch_factor or persistent_workers is passed, but num_workers=0.
ValueError – If the dataset is an UnpairedDataset and sampling_weights is not None.
KeyError – If sampling_weights is not None, but the dataframe of the dataset has no column with that name.
ValueError – If sampling_weights is not None and the associated column cannot be converted to float values.
Examples
Data look like:

bids
├── metadata.tsv
├── sub-001
│   ├── ses-M000
│   │   ├── pet
│   │       └── sub-001_ses-M000_trc-18FAV45_pet.nii.gz
...
...

The "metadata.tsv" file looks like:

participant_id  session_id  age   sex
sub-001         ses-M000    55.0  M
...

from clinicadl.data.datasets import BidsDataset, PairedDataset
from clinicadl.io.bids import BidsFileType
from clinicadl.data.dataloader import DataLoader

bids = BidsDataset(
    "bids",
    file_type=BidsFileType(data_type="pet", suffix="pet"),
    data="bids/metadata.tsv",
)
dataloader = DataLoader(bids, batch_size=3, shuffle=False)

>>> batch = next(iter(dataloader))
>>> batch
[Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
 Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
 Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1)]
Now, let’s see what happens with a PairedDataset:

paired_dataset = PairedDataset([bids, bids])
dataloader = DataLoader(paired_dataset, batch_size=3, shuffle=False)

>>> batch = next(iter(dataloader))
>>> batch
([Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
  Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
  Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1)],
 [Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
  Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
  Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1)])
Because the default behavior is to use ToBatchesCollate to collate batches, we obtain here a tuple of \(n\) batches, where \(n\) is the number of datasets that we paired.
See also
torch.utils.data.DataLoader
For more details on the parameters.
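The tuple-of-batches layout obtained with a PairedDataset can be pictured with a small pure-Python sketch. It assumes that each item of a paired dataset is a tuple holding one sample per underlying dataset; the collate_paired helper and the strings standing in for Sample objects are hypothetical, and the code does not reproduce ToBatchesCollate itself.

```python
def collate_paired(items):
    """Regroup a list of n-tuples into a tuple of n lists (one batch per dataset).

    Hypothetical helper for illustration only (not part of ClinicaDL).
    """
    return tuple(list(column) for column in zip(*items))


# Three items drawn from two paired datasets; strings stand in for Sample objects.
items = [("a0", "b0"), ("a1", "b1"), ("a2", "b2")]
batch = collate_paired(items)

print(len(batch))  # number of paired datasets
print(batch[0])    # batch built from the first dataset
print(batch[1])    # batch built from the second dataset
```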
- set_epoch(epoch: int) -> None
Sets the epoch.
This ensures a different random ordering for torch.utils.data.distributed.DistributedSampler and a different random mapping for clinicadl.data.datasets.UnpairedDataset for each epoch.
- Parameters:
epoch (int) – Epoch number.
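set_epoch is typically called once at the start of every epoch of a training loop. The standard-library sketch below shows why this matters when the shuffling seed is derived from the epoch number, as torch.utils.data.distributed.DistributedSampler does with its seed and epoch: if the epoch is never updated, the same order is replayed. The epoch_order helper is hypothetical and is not ClinicaDL code.

```python
import random


def epoch_order(num_samples, epoch, seed=0):
    """Return a shuffled index order that depends on both `seed` and `epoch`.

    Hypothetical helper for illustration only (not part of ClinicaDL).
    """
    rng = random.Random(seed + epoch)
    order = list(range(num_samples))
    rng.shuffle(order)
    return order


# Same epoch number: the order is reproduced exactly, so every process
# that shares the seed agrees on it.
same_a = epoch_order(100, epoch=0)
same_b = epoch_order(100, epoch=0)

# Different epoch number: another permutation of the same indices.
next_epoch = epoch_order(100, epoch=1)

print(same_a == same_b)
print(sorted(next_epoch) == list(range(100)))
```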