# Source code for clinicadl.data.dataloader.loader

from __future__ import annotations

from collections.abc import Sequence
from typing import Any, Optional

from pydantic import Field, NonNegativeInt, PositiveInt, model_validator
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import DistributedSampler, Sampler, WeightedRandomSampler

from clinicadl.data.datasets import (
    UnpairedDataset,
)
from clinicadl.utils.config import ObjectConfig
from clinicadl.utils.factories import safe_factory_from_json
from clinicadl.utils.objects import HasConfig
from clinicadl.utils.seed import pl_worker_init_function
from clinicadl.utils.typing import PathType

from ..datasets.base import Dataset, SampleT
from .collate import CollateFn, ToBatchCollate, ToBatchesCollate
from .collate.factory import get_collate_from_dict


def _read_collate(serialized_collate: Any) -> Any:
    """
    To read the field 'collate_fn'.
    """
    if isinstance(serialized_collate, dict):
        return get_collate_from_dict(serialized_collate)
    return serialized_collate


class DataLoaderConfig(ObjectConfig):
    """
    Config class for ``DataLoader``.

    Holds the parameters of a ``DataLoader`` and checks that the
    worker-related options are consistent with each other.
    """

    # Batch size.
    batch_size: PositiveInt
    # Name of the dataframe column holding the sampling weights (None to disable).
    sampling_weights: Optional[str]
    # Whether to shuffle the data.
    shuffle: bool
    # Number of workers for data loading.
    num_workers: NonNegativeInt
    # Whether to copy tensors into pinned memory before returning them.
    pin_memory: bool
    # Whether to drop the last incomplete batch.
    drop_last: bool
    # Number of batches loaded in advance by each worker; must be None if num_workers=0.
    prefetch_factor: Optional[NonNegativeInt]
    # Whether to keep the workers alive between epochs; must be False if num_workers=0.
    persistent_workers: bool
    # Custom collate function. `_read_collate` is registered as the "reader" of this
    # field (presumably used by ObjectConfig when deserializing -- TODO confirm).
    collate_fn: Optional[CollateFn] = Field(json_schema_extra={"reader": _read_collate})

    @model_validator(mode="after")
    def _validate_worker_parameters(self):
        """
        Checks that 'prefetch_factor' is None and 'persistent_workers' is False
        if 'num_workers' = 0.

        Raises
        ------
        ValueError
            If ``prefetch_factor`` is not ``None`` or ``persistent_workers`` is ``True``
            while ``num_workers=0``.
        """
        if self.num_workers == 0:
            # Compare with None explicitly: 'prefetch_factor=0' is falsy but is still
            # a value passed by the user, which torch's DataLoader rejects when
            # there is no worker.
            if self.prefetch_factor is not None:
                raise ValueError(
                    "'prefetch_factor' option can only be specified num_workers > 0. Got "
                    f"prefetch_factor={self.prefetch_factor} and num_workers={self.num_workers}"
                )
            if self.persistent_workers:
                raise ValueError(
                    "'persistent_workers' option can only be specified num_workers > 0. Got "
                    f"persistent_workers={self.persistent_workers} and num_workers={self.num_workers}"
                )
        return self

    @classmethod
    def _get_class(cls) -> Any:
        """Returns the class built by this configuration, i.e. ``DataLoader``."""
        return DataLoader

    def get_object(self, dataset: Dataset) -> DataLoader:
        """
        Returns the dataloader associated to this configuration,
        parametrized with the parameters passed by the user.

        Parameters
        ----------
        dataset : Dataset
            The dataset from which data are loaded.

        Returns
        -------
        DataLoader
            The ``DataLoader``.
        """
        # The fields of the config are forwarded as keyword arguments.
        return self._get_class()(dataset, **self.to_raw_dict())


class DataLoader(HasConfig["DataLoaderConfig"], TorchDataLoader[SampleT]):
    """
    To load data in batches.

    It inherits from :py:class:`torch.utils.data.DataLoader`.
    **Only this dataloader is guaranteed to work with ClinicaDL.**

    The format of the output of the iterator can be specified via ``collate_fn``.

    Parameters
    ----------
    dataset : Dataset
        Dataset from which to load the data.
    batch_size : PositiveInt, default=1
        Batch size.
    sampling_weights : Optional[str], default=None
        Name of the column in the :py:attr:`Dataset.df <clinicadl.data.datasets.Dataset.df>`
        where to find the sampling weights. The column must contain ``float`` values.
        The probability of sampling a certain sample is proportional to the associated
        value in this column.

        .. warning::
            ``sampling_weights`` doesn't work with an
            :py:class:`~clinicadl.data.datasets.UnpairedDataset`.

    shuffle : bool, default=True
        Whether to shuffle the data.

        .. note::
            If ``sampling_weights`` is passed, the data will be fetched randomly with
            replacement, no matter the value of ``shuffle``.

    num_workers : NonNegativeInt, default=0
        Number of workers for data loading.
    pin_memory : bool, default=True
        Whether to copy tensors into device/CUDA pinned memory before returning them.
    drop_last : bool, default=False
        Whether to drop the last incomplete batch.
    prefetch_factor : Optional[int], default=None
        Number of batches loaded in advance by each worker. Can't be passed
        if ``num_workers=0``.
    persistent_workers : bool, default=False
        Whether to maintain the worker processes alive at the end of an epoch.
        Can't be passed if ``num_workers=0``.
    collate_fn : Optional[CollateFn], default=None
        To customize the way samples are collated into batches.
        See :py:mod:`clinicadl.data.dataloader.collate`.

    Raises
    ------
    ValueError
        If ``prefetch_factor`` or ``persistent_workers`` is passed, but ``num_workers=0``.
    ValueError
        If the dataset is an :py:class:`~clinicadl.data.datasets.UnpairedDataset` and
        ``sampling_weights`` is not ``None``.
    KeyError
        If ``sampling_weights`` is not ``None``, but there is no column named like
        ``sampling_weights`` in the dataframe of the dataset.
    ValueError
        If ``sampling_weights`` is not ``None`` and the associated column cannot be
        converted to float values.

    Examples
    --------
    .. code-block:: text

        Data look like:

        bids
        ├── metadata.tsv
        ├── sub-001
        │   ├── ses-M000
        │   │   ├── pet
        │   │   └── sub-001_ses-M000_trc-18FAV45_pet.nii.gz
        ...
        ...

        The "metadata.tsv" file looks like:

        participant_id    session_id    age     sex
        sub-001           ses-M000      55.0    M
        ...

    .. code-block:: python

        from clinicadl.data.datasets import BidsDataset, PairedDataset
        from clinicadl.io.bids import BidsFileType
        from clinicadl.data.dataloader import DataLoader

        bids = BidsDataset(
            "bids",
            file_type=BidsFileType(data_type="pet", suffix="pet"),
            data="bids/metadata.tsv",
        )

        dataloader = DataLoader(bids, batch_size=3, shuffle=False)

    .. code-block:: python

        >>> batch = next(iter(dataloader))
        >>> batch
        [Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
         Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
         Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1)]

    Now, let's see what happens with a :py:class:`~clinicadl.data.datasets.PairedDataset`:

    .. code-block:: python

        paired_dataset = PairedDataset([bids, bids])
        dataloader = DataLoader(paired_dataset, batch_size=3, shuffle=False)

    .. code-block:: python

        >>> batch = next(iter(dataloader))
        >>> batch
        ([Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
          Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
          Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1)],
         [Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
          Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1),
          Sample(Keys: ('file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1)])

    Because the default behavior is to use :py:class:`~clinicadl.data.dataloader.ToBatchesCollate`
    to collate batches, we obtain here a tuple of :math:`n` batches, where :math:`n` is the number
    of datasets that we paired.

    See Also
    --------
    :py:class:`torch.utils.data.DataLoader`
        For more details on the parameters.
    """

    _config_type = DataLoaderConfig
    dataset: Dataset[SampleT]

    def __init__(
        self,
        dataset: Dataset[SampleT],
        *,
        batch_size: int = 1,
        sampling_weights: Optional[str] = None,
        shuffle: bool = True,
        num_workers: int = 0,
        pin_memory: bool = True,
        drop_last: bool = False,
        prefetch_factor: Optional[int] = None,
        persistent_workers: bool = False,
        collate_fn: Optional[CollateFn] = None,
    ):
        # The config validates the arguments and must exist before the sampler
        # is generated, as '_generate_sampler' reads it.
        self.config = self._config_type(
            batch_size=batch_size,
            sampling_weights=sampling_weights,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
            collate_fn=collate_fn,
        )
        if self.config.collate_fn:
            collate_fn = self.config.collate_fn
        elif isinstance(dataset[0], Sequence):
            # The dataset returns several samples per index: one batch per element.
            collate_fn = ToBatchesCollate()
        else:
            collate_fn = ToBatchCollate()

        super().__init__(
            dataset=dataset,
            sampler=self._generate_sampler(dataset, dp_degree=1, rank=0),
            worker_init_fn=pl_worker_init_function,
            collate_fn=collate_fn,
            # 'sampling_weights' and 'shuffle' are handled by the sampler, and
            # 'collate_fn' is passed explicitly above.
            **self.config.to_dict(
                exclude={"sampling_weights", "shuffle", "collate_fn", "name_"}
            ),
        )

    def set_epoch(self, epoch: int) -> None:
        """
        Sets the epoch.

        This ensures a different random ordering for
        :py:class:`torch.utils.data.distributed.DistributedSampler` and a different
        random mapping for :py:class:`clinicadl.data.datasets.UnpairedDataset`
        for each epoch.

        Parameters
        ----------
        epoch : int
            Epoch number.
        """
        # Only forward the epoch to the objects that expose a 'set_epoch' method.
        if callable(set_epoch := getattr(self.sampler, "set_epoch", None)):
            set_epoch(epoch)
        if callable(set_epoch := getattr(self.dataset, "set_epoch", None)):
            set_epoch(epoch)

    def _generate_sampler(
        self,
        dataset: Dataset,
        dp_degree: int,
        rank: int,
    ) -> Sampler:
        """
        Returns a WeightedRandomSampler if self.sampling_weights is not None, otherwise
        a DistributedSampler, even when data parallelism is not performed
        (in this case the degree of data parallelism is set to 1, so it is equivalent to
        a simple PyTorch RandomSampler if self.shuffle is True or no sampler if self.shuffle
        is False).
        """
        if self.config.sampling_weights:
            weights = self._get_weights(dataset, self.config.sampling_weights)
            # Share of the samples for this rank: the first
            # 'len(weights) % dp_degree' ranks get one extra sample.
            length = len(weights) // dp_degree + int(rank < len(weights) % dp_degree)
            sampler = WeightedRandomSampler(weights, num_samples=length)  # type: ignore
        else:
            sampler = DistributedSampler(
                dataset,
                num_replicas=dp_degree,
                rank=rank,
                shuffle=self.config.shuffle,
                drop_last=False,  # not the same as self.drop_last
            )

        return sampler

    @staticmethod
    def _get_weights(dataset: Dataset, weights_name: str) -> list[float]:
        """
        Gets the list of weights from the column of the dataframe.

        Parameters
        ----------
        dataset : Dataset
            The dataset whose samples are weighted.
        weights_name : str
            The name of the column containing the weights.

        Returns
        -------
        list[float]
            One weight per index of the dataset.

        Raises
        ------
        ValueError
            If the dataset is an ``UnpairedDataset``.
        KeyError
            If the column cannot be found.
        ValueError
            If the values of the column cannot be converted to float.
        """
        if isinstance(dataset, UnpairedDataset):
            raise ValueError("Can't use 'sampling_weights' with UnpairedDataset.")
        try:
            weights = [
                dataset.get_sample_info(idx, weights_name)
                for idx in range(len(dataset))
            ]
        except KeyError as exc:
            raise KeyError(
                f"Failed to get the column '{weights_name}' in the dataframe of the dataset."
            ) from exc
        try:
            weights = [float(weight) for weight in weights]
        except (TypeError, ValueError) as exc:
            # 'float' raises TypeError (not ValueError) for None and other
            # non-numeric objects; both cases are reported the same way.
            raise ValueError(
                f"Got '{weights_name}' for 'sampling_weights' but cannot convert "
                "this column to float values."
            ) from exc

        return weights

    @classmethod
    def from_json(
        cls, json_path: PathType, dataset: Dataset[SampleT], **kwargs: Any
    ) -> DataLoader:
        """
        To create the object from a ``json`` file saved with :py:meth:`to_json`.

        Parameters
        ----------
        json_path : PathType
            Path to the ``json`` file.
        dataset : Dataset
            The dataset from which the data are loaded.
        kwargs : Any
            Any field of the ``json`` to override.

        Returns
        -------
        DataLoader
            The object instantiated from the file.
        """
        config = cls._config_type.from_json(json_path, **kwargs)
        return config.get_object(dataset)

    @classmethod
    def from_dict(
        cls,
        config_dict: dict[str, Any],
        dataset: Dataset[SampleT],
        **kwargs: Any,
    ) -> DataLoader:
        """
        To create the object from a dictionary returned by :py:meth:`to_dict`.

        Parameters
        ----------
        config_dict : dict[str, Any]
            The input dictionary.
        dataset : Dataset
            The dataset from which the data are loaded.
        kwargs : Any
            Any field of the dictionary to override.

        Returns
        -------
        DataLoader
            The object instantiated from the dictionary.
        """
        config = cls._config_type.from_dict(config_dict, **kwargs)
        return config.get_object(dataset)
@safe_factory_from_json(factory=DataLoaderConfig.from_json)
def get_dataloader_from_json_safely(
    json_path: PathType, default: Optional[DataLoaderConfig] = None
) -> tuple[Optional[DataLoaderConfig], list[str]]:
    """
    Factory function to get a :py:class:`DataLoaderConfig` from the file saved with
    :py:meth:`DataLoaderConfig.to_json`, which will not raise errors.

    If some fields of the serialized dataloader cannot be read, they will be reported,
    and the field of ``default`` will be used to override them (if not ``None``).

    If it was impossible to read the serialized dataloader, the factory returns ``None``.

    Parameters
    ----------
    json_path : PathType
        The path to the serialized dataloader.
    default : Optional[DataLoaderConfig], default=None
        The :py:class:`DataLoaderConfig` from which to take the default arguments.

    Returns
    -------
    Optional[DataLoaderConfig]
        The deserialized DataLoaderConfig. ``None`` if deserialization was impossible.
    list[str]
        The list of fields that could not be read in the serialized dataloader.
    """