Source code for clinicadl.data.dataloader.batch

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, Sequence, TypeVar, Union

import numpy as np
import torch

from clinicadl.utils.device import DeviceType, check_device

if TYPE_CHECKING:
    from clinicadl.data.structures import DataPoint

# Accepted batch structures: a single ``Batch``, a sequence of ``Batch``
# objects, or a dictionary (with keys of any type) whose values are ``Batch``
# objects. ``"Batch"`` is a forward reference, as the class is defined below.
BatchType = Union["Batch", Sequence["Batch"], dict[Any, "Batch"]]

# Element type of a ``Batch`` (see ``class Batch(list[T])``), restricted to
# ``DataPoint`` and its subclasses. The bound is given as a string because
# ``DataPoint`` is only imported under ``TYPE_CHECKING``.
T = TypeVar("T", bound="DataPoint")


class Batch(list[T]):
    """
    A batch container for :class:`~clinicadl.data.structures.DataPoint` objects.

    ``Batch`` is simply a list of ``DataPoints``, with additional useful functions.

    Parameters
    ----------
    datapoints : Sequence[DataPoint]
        Sequence of :py:class:`DataPoints <clinicadl.data.structures.DataPoint>`
        forming the batch.

    Raises
    ------
    ValueError
        If the input sequence is empty.
    """

    # Transfer settings applied to the tensors returned by ``get_field``.
    # These class-level values are the defaults; ``to`` overrides them on the
    # instance.
    _device: Optional[torch.device] = None
    _non_blocking: bool = False
    _channels_last: bool = False

    def __init__(self, datapoints: Sequence[T]):
        super().__init__(datapoints)
        if not self:
            raise ValueError("The batch is empty!")

    @property
    def device(self) -> Optional[torch.device]:
        """
        The device where :py:class:`Tensors <torch.Tensor>` will be sent when
        calling :py:meth:`get_field`. It is specified via :py:meth:`to`.

        By default it is ``None``, meaning that the tensors will stay on their
        origin device.
        """
        return self._device

    @property
    def channels_last(self) -> bool:
        """
        Whether `Channels Last Memory Format <https://docs.pytorch.org/tutorials/intermediate/memory_format_tutorial.html>`_
        is used for 4D (NCWH) or 5D (NCDWH) :py:class:`Tensors <torch.Tensor>`
        returned by :py:meth:`get_field`. It is specified via :py:meth:`to`.

        By default, it is ``False``.
        """
        return self._channels_last

    def to(
        self,
        device: Optional[DeviceType] = None,
        non_blocking: bool = False,
        channels_last: Optional[bool] = None,
    ) -> None:
        """
        To send the :py:class:`Tensors <torch.Tensor>` in the ``Batch`` on the
        specified device and with the specified memory format.

        .. important::
            Nothing is applied to the ``Batch`` itself, which remains on its
            original device, but all the tensors obtained with
            :py:meth:`get_field` will have the specified device and memory
            format.

        Parameters
        ----------
        device : Optional[DeviceType], default=None
            The device where to send the tensors. Can be:

            - an ``int``: the device id;
            - ``"cuda"``;
            - ``"cpu"``;
            - ``"cuda-<id>"``: where ``<id>`` is the device id;
            - a :py:class:`torch.device`;
            - ``None``: the device won't be changed.
        non_blocking : bool, default=False
            "When ``non_blocking`` is set to ``True``, the function attempts to
            perform the conversion asynchronously with respect to the host, if
            possible. This asynchronous behavior applies to both pinned and
            pageable memory."
            (see :torch:`PyTorch documentation <generated/torch.Tensor.to.html>`)

            ``non_blocking=True`` may speed up data transfer across devices.
            Note that this value is stored on every call, i.e. calling ``to``
            without ``non_blocking`` resets it to ``False``.
        channels_last : Optional[bool], default=None
            Whether to use `Channels Last Memory Format <https://docs.pytorch.org/tutorials/intermediate/memory_format_tutorial.html>`_
            for 4D (NCWH) or 5D (NCDWH) tensors. If ``False``, the default
            contiguous memory format will be used. If ``None``, the current
            memory format will be kept.

        Examples
        --------
        .. code-block:: python

            from clinicadl.data.structures.examples import Colin27DataPoint
            from clinicadl.data.dataloader import Batch
            import torch

            datapoint = Colin27DataPoint()
            batch = Batch([datapoint, datapoint])

        .. code-block:: python

            >>> batch.get_field("image").device
            device(type='cpu')
            >>> batch.get_field("image").stride()  # contiguous memory format (default)
            (7109137, 7109137, 39277, 181, 1)

        .. code-block:: python

            >>> batch.to("cuda", non_blocking=True, channels_last=True)
            >>> batch.device
            device(type='cuda')
            >>> batch.get_field("image").device  # get_field is affected
            device(type='cuda')
            >>> batch.get_field("image").stride()  # Channels-last memory format
            (7109137, 1, 39277, 181, 1)
        """
        if device is not None:
            self._device = check_device(device)
        self._non_blocking = non_blocking
        if channels_last is not None:
            self._channels_last = channels_last

    def get_field(
        self,
        field_name: str,
        dtype: Optional[torch.dtype] = None,
        ensure_channel_dim: bool = False,
    ) -> Union[torch.Tensor, list[Any]]:
        """
        Gathers all the values for a key of the
        :py:class:`DataPoints <clinicadl.data.structures.DataPoint>` forming
        the batch.

        The function will try to return the output as a batch-first
        :py:class:`~torch.Tensor`. If not possible, it will return a list of
        the values, which are converted to ``Tensors`` if possible.

        If the output is a unique ``Tensor``, it will respect the device and
        the memory format potentially specified with :py:meth:`to`. Besides,
        the desired data type can be specified here via ``dtype``.

        Parameters
        ----------
        field_name : str
            The key to the field in the underlying
            :py:class:`DataPoints <clinicadl.data.structures.DataPoint>`.
        dtype : Optional[torch.dtype], default=None
            Specifies the output data type, if the output is a ``Tensor``.
            If ``None``, the output will not be cast into a specific data type.
        ensure_channel_dim : bool, default=False
            If ``True``, a 1D ``Tensor`` output batch (N) will be unsqueezed to
            a 2D ``Tensor`` with a channel dimension (NC).

        Returns
        -------
        Union[torch.Tensor, list[Any]]
            A :py:class:`~torch.Tensor` or a list containing all the values of
            ``field_name`` in the batch.

        Raises
        ------
        KeyError
            If not all the
            :py:class:`DataPoints <clinicadl.data.structures.DataPoint>` have
            the requested ``field_name``.

        Examples
        --------
        .. code-block:: python

            from clinicadl.data.structures.examples import Colin27DataPoint
            from clinicadl.data.dataloader import Batch

            datapoint = Colin27DataPoint()
            batch = Batch([datapoint, datapoint])

        .. code-block:: python

            >>> datapoint
            Colin27DataPoint(Keys: ('head', 'image', 'participant_id', 'session_id'); images: 2)
            >>> datapoint["head"]
            LabelMap(shape: (1, 181, 217, 181); spacing: (1.00, 1.00, 1.00); orientation: RAS+; ...)
            >>> datapoint["participant_id"]
            'sub-colin'

        .. code-block:: python

            >>> batch.get_field("head").shape
            torch.Size([2, 1, 181, 217, 181])
            >>> batch.get_field("participant_id")
            ['sub-colin', 'sub-colin']
        """
        # collect all the values and try to convert them to tensors
        values = []
        all_tensors = True
        for datapoint in self:
            value = self._get_field(datapoint, field_name)
            try:
                value = self._to_tensor(value)
            except TypeError:
                # keep the raw value; the output will be a list
                all_tensors = False
            values.append(value)

        if not all_tensors:
            return values

        try:
            stacked = torch.stack(values, dim=0)
        except RuntimeError:
            # stacking failed (e.g. not the same shape): batch as tensor is not possible
            return values

        # format the batch tensor
        if stacked.ndim == 1 and ensure_channel_dim:  # at least two dimensions
            stacked = stacked.unsqueeze(1)

        memory_format = self._get_memory_format(
            stacked, channels_last=self._channels_last
        )
        return stacked.to(
            dtype=dtype,
            device=self._device,
            non_blocking=self._non_blocking,
            memory_format=memory_format,
        )

    def add_field(self, values: Sequence[Any], field_name: str) -> None:
        """
        To add a field to the
        :py:class:`DataPoints <clinicadl.data.structures.DataPoint>` inside the
        current ``Batch``.

        This method is useful for example when a neural network returns a batch
        of outputs that one wants to store in the original ``Batch``.

        Parameters
        ----------
        values : Sequence[Any]
            The values of the field for each element of the ``Batch``.
            The sequence must be the same size as the ``Batch``.
        field_name : str
            The name of the field.

        Raises
        ------
        ValueError
            If ``values`` does not have the same length as the ``Batch``.

        Examples
        --------
        .. code-block:: python

            from clinicadl.data.structures.examples import Colin27DataPoint
            from clinicadl.data.dataloader import Batch

            datapoint = Colin27DataPoint()
            batch = Batch([datapoint, datapoint])

        .. code-block:: python

            >>> batch[0]
            Colin27DataPoint(Keys: ('head', 'image', 'participant_id', 'session_id'); images: 2)

        .. code-block:: python

            >>> import torch
            >>> batch.add_field(torch.randn(2, 1, 3, 3, 3), "output")
            >>> batch[0]
            Colin27DataPoint(Keys: ('head', 'image', 'participant_id', 'session_id', 'output'); images: 2)
            >>> batch[0]["output"].shape
            torch.Size([1, 3, 3, 3])
        """
        self._check_same_length(values, "values")
        for datapoint, value in zip(self, values):
            datapoint[field_name] = value

    def add_images(self, images: torch.Tensor, image_name: str) -> None:
        """
        To add an image to each of the
        :py:class:`DataPoints <clinicadl.data.structures.DataPoint>` inside the
        current ``Batch``.

        The images are expected to be passed via a batched
        :py:class:`~torch.Tensor`.

        .. important::
            The image added in a ``DataPoint`` will take the same affine matrix
            as the main image of this ``DataPoint``.

        Parameters
        ----------
        images : torch.Tensor
            The 4D images to add, passed via a 5D batch-first ``Tensor``.
        image_name : str
            The name that the image will take in the ``DataPoints``.

        Raises
        ------
        ValueError
            If the first dimension of ``images`` does not match the size of
            the ``Batch``.

        Examples
        --------
        .. code-block:: python

            import torch
            from clinicadl.data.structures.examples import Colin27DataPoint
            from clinicadl.data.dataloader import Batch

            batch = Batch([Colin27DataPoint(), Colin27DataPoint()])

        .. code-block:: python

            >>> batch.add_images(torch.randn(2, 1, 10, 10, 10), "new_image")
            >>> batch[0]["new_image"]
            ScalarImage(shape: (1, 10, 10, 10); spacing: (1.00, 1.00, 1.00); orientation: RAS+; dtype: torch.FloatTensor; memory: 3.9 KiB)
            # spacing (1.00, 1.00, 1.00) like Colin27DataPoint().image

        See Also
        --------
        :py:meth:`add_field`
            To add any kind of value to the ``Batch``.
        :py:meth:`DataPoint.add_image <clinicadl.data.structures.DataPoint.add_image>`
            To add an image to a ``DataPoint``.
        """
        self._check_same_length(images, "images")
        for datapoint, image in zip(self, images):
            datapoint.add_image(image, image_name)

    def add_masks(self, masks: torch.Tensor, mask_name: str) -> None:
        """
        To add a mask to each of the
        :py:class:`DataPoints <clinicadl.data.structures.DataPoint>` inside the
        current ``Batch``.

        The masks are expected to be passed via a batched ``Tensor``.

        .. important::
            The mask added in a ``DataPoint`` will take the same affine matrix
            as the main image of this ``DataPoint``.

        Parameters
        ----------
        masks : torch.Tensor
            The 4D masks to add, passed via a 5D batch-first ``Tensor``.
        mask_name : str
            The name that the mask will take in the ``DataPoints``.

        Raises
        ------
        ValueError
            If the first dimension of ``masks`` does not match the size of
            the ``Batch``.

        Examples
        --------
        .. code-block:: python

            import torch
            from clinicadl.data.structures.examples import Colin27DataPoint
            from clinicadl.data.dataloader import Batch

            batch = Batch([Colin27DataPoint(), Colin27DataPoint()])

        .. code-block:: python

            >>> batch.add_masks(torch.randint(0, 2, (2, 1, 10, 10, 10)), "new_mask")
            >>> batch[0]["new_mask"]
            LabelMap(shape: (1, 10, 10, 10); spacing: (1.00, 1.00, 1.00); orientation: RAS+; dtype: torch.LongTensor; memory: 7.8 KiB)
            # spacing (1.00, 1.00, 1.00) like Colin27DataPoint().image

        See Also
        --------
        :py:meth:`add_field`
            To add any kind of value to the ``Batch``.
        :py:meth:`DataPoint.add_mask <clinicadl.data.structures.DataPoint.add_mask>`
            To add a mask to a ``DataPoint``.
        """
        self._check_same_length(masks, "masks")
        for datapoint, mask in zip(self, masks):
            datapoint.add_mask(mask, mask_name)

    def _check_same_length(self, values: Any, arg_name: str) -> None:
        """
        Checks that ``values`` contains exactly one element per ``DataPoint``.

        An explicit exception is used rather than ``assert`` (which is removed
        when Python runs with ``-O``), because ``zip`` would otherwise silently
        drop the extra elements.

        Raises
        ------
        ValueError
            If ``len(values)`` differs from the size of the ``Batch``.
        """
        if len(values) != len(self):
            raise ValueError(
                f"'{arg_name}' must have the same length as the batch. Got {len(values)} values, "
                f"whereas the batch has {len(self)} elements"
            )

    @staticmethod
    def _get_field(datapoint: DataPoint, field_name: str) -> Any:
        """Returns the specified field and transforms images to tensors."""
        try:
            if field_name in datapoint.get_images_names():
                return datapoint.get_image_tensor(field_name)
            return datapoint[field_name]
        except KeyError as e:
            raise KeyError(
                f"You want to get '{field_name}', but there is no such key in some DataPoints in the batch."
            ) from e

    @classmethod
    def _to_tensor(cls, value: Any) -> torch.Tensor:
        """
        Tries to convert to a tensor.

        The returned tensor never shares memory with ``value``.

        Raises
        ------
        TypeError
            If ``value`` cannot be converted to a tensor. Any conversion error
            is normalised to ``TypeError``, which is the only exception
            :py:meth:`get_field` handles.
        """
        if isinstance(value, torch.Tensor):
            return value.clone()

        try:
            if isinstance(value, np.ndarray):
                # ``from_numpy`` shares memory with the array, hence the clone
                return torch.from_numpy(value).clone()
            # ``torch.tensor`` already copies the data: no clone needed
            return torch.tensor(value)
        except (TypeError, ValueError, RuntimeError) as exc:
            raise TypeError from exc

    @staticmethod
    def _get_memory_format(
        tensor: torch.Tensor,
        channels_last: bool,
    ) -> torch.memory_format:
        """
        Gets the desired memory format.

        Channels-last formats are only returned for 4D and 5D tensors; in any
        other case the contiguous format is returned.
        """
        if channels_last:
            if tensor.ndim == 4:
                return torch.channels_last
            if tensor.ndim == 5:
                return torch.channels_last_3d
        return torch.contiguous_format