from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional, Sequence, TypeVar, Union
import numpy as np
import torch
from clinicadl.utils.device import DeviceType, check_device
if TYPE_CHECKING:
from clinicadl.data.structures import DataPoint
# Alias for anything that carries batches: a single ``Batch``, a sequence of
# ``Batch`` objects, or a dictionary whose values are ``Batch`` objects.
BatchType = Union["Batch", Sequence["Batch"], dict[Any, "Batch"]]
# Type parameter of ``Batch``: the kind of ``DataPoint`` stored in the batch.
T = TypeVar("T", bound="DataPoint")


class Batch(list[T]):
    """
    A batch container for :class:`~clinicadl.data.structures.DataPoint` objects.

    ``Batch`` is simply a list of ``DataPoints``, with additional useful functions.

    Parameters
    ----------
    datapoints : Sequence[DataPoint]
        Sequence of :py:class:`DataPoints <clinicadl.data.structures.DataPoint>` forming the batch.

    Raises
    ------
    ValueError
        If the input sequence is empty.
    """

    # Transfer settings recorded by ``to`` and only applied to the tensors
    # built in ``get_field``; the stored DataPoints are never moved.
    _device: Optional[torch.device] = None
    _non_blocking: bool = False
    _channels_last: bool = False

    def __init__(self, datapoints: Sequence[T]):
        super().__init__(datapoints)
        if not self:
            raise ValueError("The batch is empty!")

    @property
    def device(self) -> Optional[torch.device]:
        """
        The device where :py:class:`Tensors <torch.Tensor>` will be sent
        when calling :py:meth:`get_field`.

        It is specified via :py:meth:`to`.

        By default it is ``None``, meaning that the tensors will stay on their
        origin device.
        """
        return self._device

    @property
    def channels_last(self) -> bool:
        """
        Whether `Channels Last Memory Format <https://docs.pytorch.org/tutorials/intermediate/memory_format_tutorial.html>`_
        is used for 4D (NCWH) or 5D (NCDWH) :py:class:`Tensors <torch.Tensor>` returned by :py:meth:`get_field`.

        It is specified via :py:meth:`to`.

        By default, it is ``False``.
        """
        return self._channels_last

    def to(
        self,
        device: Optional[DeviceType] = None,
        non_blocking: bool = False,
        channels_last: Optional[bool] = None,
    ) -> None:
        """
        To send the :py:class:`Tensors <torch.Tensor>` in the ``Batch`` on the specified device
        and with the specified memory format.

        .. important::
            Nothing is applied to the ``Batch`` itself, which remains on its original device,
            but all the tensors obtained with :py:meth:`get_field`
            will have the specified device and memory format.

        Parameters
        ----------
        device : Optional[DeviceType], default=None
            The device where to send the tensors. Can be:

            - an ``int``: the device id;
            - ``"cuda"``;
            - ``"cpu"``
            - ``"cuda-<id>"``: where ``<id>`` is the device id;
            - a :py:class:`torch.device`;
            - ``None``: the device won't be changed.
        non_blocking : bool, default=False
            "When ``non_blocking`` is set to ``True``, the function attempts to perform the
            conversion asynchronously with respect to the host, if possible.
            This asynchronous behavior applies to both pinned and pageable memory."
            (see :torch:`PyTorch documentation <generated/torch.Tensor.to.html>`)

            ``non_blocking=True`` may speed up data transfer across devices.

            Unlike ``device`` and ``channels_last``, this setting is overwritten
            at every call of :py:meth:`to`.
        channels_last : Optional[bool], default=None
            Whether to use `Channels Last Memory Format <https://docs.pytorch.org/tutorials/intermediate/memory_format_tutorial.html>`_
            for 4D (NCWH) or 5D (NCDWH) tensors.

            If ``False``, the default contiguous memory format will be used.

            If ``None``, the current memory format will be kept.

        Examples
        --------
        .. code-block:: python

            from clinicadl.data.structures.examples import Colin27DataPoint
            from clinicadl.data.dataloader import Batch
            import torch

            datapoint = Colin27DataPoint()
            batch = Batch([datapoint, datapoint])

        .. code-block:: python

            >>> batch.get_field("image").device
            device(type='cpu')
            >>> batch.get_field("image").stride()  # contiguous memory format (default)
            (7109137, 7109137, 39277, 181, 1)

        .. code-block:: python

            >>> batch.to("cuda", non_blocking=True, channels_last=True)
            >>> batch.device
            device(type='cuda')
            >>> batch.get_field("image").device  # get_field is affected
            device(type='cuda')
            >>> batch.get_field("image").stride()  # Channels-last memory format
            (7109137, 1, 39277, 181, 1)
        """
        if device is not None:
            self._device = check_device(device)
        self._non_blocking = non_blocking
        if channels_last is not None:
            self._channels_last = channels_last

    def get_field(
        self,
        field_name: str,
        dtype: Optional[torch.dtype] = None,
        ensure_channel_dim: bool = False,
    ) -> Union[torch.Tensor, list[Any]]:
        """
        Gathers all the values for a key of the :py:class:`DataPoints <clinicadl.data.structures.DataPoint>`
        forming the batch.

        The function will try to return the output as a batch-first :py:class:`~torch.Tensor`. If not possible,
        it will return a list of the values, which are converted to ``Tensors`` if possible.

        If the output is a unique ``Tensor``, it will respect the device and the memory format potentially
        specified with :py:meth:`to`. Besides, the desired data type can be specified here via ``dtype``.

        Parameters
        ----------
        field_name : str
            The key to the field in the underlying :py:class:`DataPoints <clinicadl.data.structures.DataPoint>`.
        dtype : Optional[torch.dtype], default=None
            Specifies the output data type, if the output is a ``Tensor``. If ``None``, the output will not
            be cast into a specific data type.
        ensure_channel_dim : bool, default=False
            If ``True``, a 1D ``Tensor`` output batch (N) will be unsqueezed to a 2D ``Tensor`` with a channel dimension (NC).

        Returns
        -------
        Union[torch.Tensor, list[Any]]
            A :py:class:`~torch.Tensor` or a list containing all the values of ``field_name`` in the batch.

        Raises
        ------
        KeyError
            If not all the :py:class:`DataPoints <clinicadl.data.structures.DataPoint>` have the requested ``field_name``.

        Examples
        --------
        .. code-block:: python

            from clinicadl.data.structures.examples import Colin27DataPoint
            from clinicadl.data.dataloader import Batch

            datapoint = Colin27DataPoint()
            batch = Batch([datapoint, datapoint])

        .. code-block:: python

            >>> datapoint
            Colin27DataPoint(Keys: ('head', 'image', 'participant_id', 'session_id'); images: 2)
            >>> datapoint["head"]
            LabelMap(shape: (1, 181, 217, 181); spacing: (1.00, 1.00, 1.00); orientation: RAS+; ...)
            >>> datapoint["participant_id"]
            'sub-colin'

        .. code-block:: python

            >>> batch.get_field("head").shape
            torch.Size([2, 1, 181, 217, 181])
            >>> batch.get_field("participant_id")
            ['sub-colin', 'sub-colin']
        """
        # Collect all the values, converting each one to a tensor when possible.
        values: list[Any] = []
        all_tensors = True
        for datapoint in self:
            value = self._get_field(datapoint, field_name)
            try:
                value = self._to_tensor(value)
            except TypeError:
                # Keep the raw value; the output can then only be a list.
                all_tensors = False
            values.append(value)

        if not all_tensors:
            return values

        try:
            stacked = torch.stack(values, dim=0)
        except RuntimeError:
            # The tensors cannot be stacked (e.g. different shapes):
            # fall back to the list of individual tensors.
            return values

        # Format the batch tensor.
        if stacked.ndim == 1 and ensure_channel_dim:  # at least two dimensions
            stacked = stacked.unsqueeze(1)

        memory_format = self._get_memory_format(
            stacked, channels_last=self._channels_last
        )
        return stacked.to(
            dtype=dtype,
            device=self._device,
            non_blocking=self._non_blocking,
            memory_format=memory_format,
        )

    def add_field(self, values: Sequence[Any], field_name: str) -> None:
        """
        To add a field to the :py:class:`DataPoints <clinicadl.data.structures.DataPoint>`
        inside the current ``Batch``.

        This method is useful for example when a neural network returns a batch of outputs
        that one wants to store in the original ``Batch``.

        Parameters
        ----------
        values : Sequence[Any]
            The values of the field for each element of the ``Batch``. Obviously, the sequence must
            be the same size as the ``Batch``.
        field_name : str
            The name of the field.

        Raises
        ------
        ValueError
            If ``values`` does not have the same length as the ``Batch``.

        Examples
        --------
        .. code-block:: python

            from clinicadl.data.structures.examples import Colin27DataPoint
            from clinicadl.data.dataloader import Batch

            datapoint = Colin27DataPoint()
            batch = Batch([datapoint, datapoint])

        .. code-block:: python

            >>> batch[0]
            Colin27DataPoint(Keys: ('head', 'image', 'participant_id', 'session_id'); images: 2)

        .. code-block:: python

            >>> import torch
            >>> batch.add_field(torch.randn(2, 1, 3, 3, 3), "output")
            >>> batch[0]
            Colin27DataPoint(Keys: ('head', 'image', 'participant_id', 'session_id', 'output'); images: 2)
            >>> batch[0]["output"].shape
            torch.Size([1, 3, 3, 3])
        """
        self._check_same_length(values, "values")
        for datapoint, value in zip(self, values):
            datapoint[field_name] = value

    def add_images(self, images: torch.Tensor, image_name: str) -> None:
        """
        To add an image to each of the :py:class:`DataPoints <clinicadl.data.structures.DataPoint>`
        inside the current ``Batch``.

        The images are expected to be passed via a batched :py:class:`~torch.Tensor`.

        .. important::
            The image added in a ``DataPoint`` will
            take the same affine matrix as the main image of this ``DataPoint``.

        Parameters
        ----------
        images : torch.Tensor
            The 4D images to add, passed via a 5D batch-first ``Tensor``.
        image_name : str
            The name that the image will take in the ``DataPoints``.

        Raises
        ------
        ValueError
            If the first dimension of ``images`` does not match the size of the ``Batch``.

        Examples
        --------
        .. code-block:: python

            import torch
            from clinicadl.data.structures.examples import Colin27DataPoint
            from clinicadl.data.dataloader import Batch

            batch = Batch([Colin27DataPoint(), Colin27DataPoint()])

        .. code-block:: python

            >>> batch.add_images(torch.randn(2, 1, 10, 10, 10), "new_image")
            >>> batch[0]["new_image"]
            ScalarImage(shape: (1, 10, 10, 10); spacing: (1.00, 1.00, 1.00); orientation: RAS+; dtype: torch.FloatTensor; memory: 3.9 KiB)  # spacing (1.00, 1.00, 1.00) like Colin27DataPoint().image

        See Also
        --------
        :py:meth:`add_field`
            To add any kind of value to the ``Batch``.
        :py:meth:`DataPoint.add_image <clinicadl.data.structures.DataPoint.add_image>`
            To add an image to a ``DataPoint``.
        """
        # Without this check, ``zip`` would silently skip some DataPoints
        # (or some images) when the sizes differ.
        self._check_same_length(images, "images")
        for datapoint, image in zip(self, images):
            datapoint.add_image(image, image_name)

    def add_masks(self, masks: torch.Tensor, mask_name: str) -> None:
        """
        To add a mask to each of the :py:class:`DataPoints <clinicadl.data.structures.DataPoint>`
        inside the current ``Batch``.

        The masks are expected to be passed via a batched ``Tensor``.

        .. important::
            The mask added in a ``DataPoint`` will
            take the same affine matrix as the main image of this ``DataPoint``.

        Parameters
        ----------
        masks : torch.Tensor
            The 4D masks to add, passed via a 5D batch-first ``Tensor``.
        mask_name : str
            The name that the mask will take in the ``DataPoints``.

        Raises
        ------
        ValueError
            If the first dimension of ``masks`` does not match the size of the ``Batch``.

        Examples
        --------
        .. code-block:: python

            import torch
            from clinicadl.data.structures.examples import Colin27DataPoint
            from clinicadl.data.dataloader import Batch

            batch = Batch([Colin27DataPoint(), Colin27DataPoint()])

        .. code-block:: python

            >>> batch.add_masks(torch.randint(0, 2, (2, 1, 10, 10, 10)), "new_mask")
            >>> batch[0]["new_mask"]
            LabelMap(shape: (1, 10, 10, 10); spacing: (1.00, 1.00, 1.00); orientation: RAS+; dtype: torch.LongTensor; memory: 7.8 KiB)  # spacing (1.00, 1.00, 1.00) like Colin27DataPoint().image

        See Also
        --------
        :py:meth:`add_field`
            To add any kind of value to the ``Batch``.
        :py:meth:`DataPoint.add_mask <clinicadl.data.structures.DataPoint.add_mask>`
            To add a mask to a ``DataPoint``.
        """
        # Without this check, ``zip`` would silently skip some DataPoints
        # (or some masks) when the sizes differ.
        self._check_same_length(masks, "masks")
        for datapoint, mask in zip(self, masks):
            datapoint.add_mask(mask, mask_name)

    def _check_same_length(self, values: Any, arg_name: str) -> None:
        """
        Checks that ``values`` has one element per element of the batch.

        A real exception is raised (rather than using ``assert``) so that the
        validation is still performed when Python runs with ``-O``.

        Raises
        ------
        ValueError
            If ``len(values)`` differs from the size of the batch.
        """
        if len(values) != len(self):
            raise ValueError(
                f"'{arg_name}' must have the same length as the batch. Got {len(values)} values, "
                f"whereas the batch has only {len(self)} elements"
            )

    @staticmethod
    def _get_field(datapoint: DataPoint, field_name: str) -> Any:
        """Returns the specified field and transforms images to tensors."""
        try:
            # Fields listed as images are returned through ``get_image_tensor``;
            # any other field is returned as stored.
            if field_name in datapoint.get_images_names():
                return datapoint.get_image_tensor(field_name)
            return datapoint[field_name]
        except KeyError as e:
            raise KeyError(
                f"You want to get '{field_name}', but there is no such key in some DataPoints in the batch."
            ) from e

    @classmethod
    def _to_tensor(cls, value: Any) -> torch.Tensor:
        """
        Tries to convert to a tensor.

        The returned tensor never shares memory with ``value``.

        Raises
        ------
        TypeError
            If ``value`` cannot be converted to a tensor.
        """
        if isinstance(value, np.ndarray):
            # ``from_numpy`` shares memory with the array, hence the clone.
            return torch.from_numpy(value).clone()
        if isinstance(value, torch.Tensor):
            return value.clone()
        try:
            # ``torch.tensor`` always copies the data: no extra clone needed.
            return torch.tensor(value)
        except (TypeError, ValueError, RuntimeError) as exc:
            raise TypeError(
                f"Cannot convert a value of type '{type(value).__name__}' to a tensor."
            ) from exc

    @staticmethod
    def _get_memory_format(
        tensor: torch.Tensor,
        channels_last: bool,
    ) -> torch.memory_format:
        """
        Gets the desired memory format.

        Channels-last formats are only returned for 4D and 5D tensors.
        In every other case, the contiguous format is returned.
        """
        if channels_last:
            if tensor.ndim == 4:
                return torch.channels_last
            if tensor.ndim == 5:
                return torch.channels_last_3d
        return torch.contiguous_format