clinicadl.data.dataloader.Batch

class clinicadl.data.dataloader.Batch(datapoints: Sequence[T])

A batch container for DataPoint objects.

Batch is simply a list of DataPoints, with additional useful functions.

Parameters:

datapoints (Sequence[DataPoint]) – Sequence of DataPoints forming the batch.

Raises:

ValueError – If the input sequence is empty.

property device: device | None

The device where Tensors will be sent when calling get_field(). It is specified via to().

By default it is None, meaning that the tensors will stay on their origin device.

property channels_last: bool

Whether Channels Last Memory Format is used for 4D (NCWH) or 5D (NCDWH) Tensors returned by get_field(). It is specified via to().

By default, it is False.

to(device: str | int | device | None = None, non_blocking: bool = False, channels_last: bool | None = None) → None

Sends the Tensors in the Batch to the specified device and applies the specified memory format.

Important

Nothing is applied to the Batch itself, which remains on its original device, but all the tensors obtained with get_field() will have the specified device and memory format.

Parameters:
  • device (Optional[DeviceType], default=None) –

    The device to send the tensors to. Can be:

    • an int: the device id;

    • "cuda";

    • "cpu";

    • "cuda-<id>": where <id> is the device id;

    • a torch.device;

    • None: the device won’t be changed.

  • non_blocking (bool, default=False) –

    “When non_blocking is set to True, the function attempts to perform the conversion asynchronously with respect to the host, if possible. This asynchronous behavior applies to both pinned and pageable memory.” (see PyTorch documentation)

    non_blocking=True may speed up data transfer across devices.

  • channels_last (Optional[bool], default=None) –

    Whether to use Channels Last Memory Format for 4D (NCWH) or 5D (NCDWH) tensors. If False, the default contiguous memory format will be used. If None, the current memory format will be kept.

Examples

from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.data.dataloader import Batch
import torch

datapoint = Colin27DataPoint()
batch = Batch([datapoint, datapoint])
>>> batch.get_field("image").device
device(type='cpu')
>>> batch.get_field("image").stride()   # contiguous memory format (default)
(7109137, 7109137, 39277, 181, 1)
>>> batch.to("cuda", non_blocking=True, channels_last=True)
>>> batch.device
device(type='cuda')
>>> batch.get_field("image").device   # get_field is affected
device(type='cuda')
>>> batch.get_field("image").stride()   # Channels-last memory format
(7109137, 1, 39277, 181, 1)
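
The strides printed in the example above can be reproduced by hand. The following is a minimal sketch in plain Python, not clinicadl or PyTorch code: the helper names contiguous_strides and channels_last_strides are hypothetical, and the sketch assumes that a channels-last tensor keeps its logical (N, C, *spatial) shape while being stored in memory as (N, *spatial, C).

```python
def contiguous_strides(shape):
    """Strides (in elements) of a row-major tensor with the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def channels_last_strides(shape):
    """Strides of a tensor with logical shape (N, C, *spatial) whose memory
    layout is (N, *spatial, C). Assumption: this mirrors channels-last."""
    n, c, *spatial = shape
    # Lay the data out physically as (N, *spatial, C) ...
    physical = contiguous_strides((n, *spatial, c))
    # ... then report the strides in the logical (N, C, *spatial) order.
    return (physical[0], physical[-1], *physical[1:-1])


shape = (2, 1, 181, 217, 181)  # two Colin27 images stacked batch-first
print(contiguous_strides(shape))     # (7109137, 7109137, 39277, 181, 1)
print(channels_last_strides(shape))  # (7109137, 1, 39277, 181, 1)
```

Only the stride of the channel dimension changes here because the example image has a single channel; with more channels, the spatial strides would be multiplied by the channel count as well.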

get_field(field_name: str, dtype: dtype | None = None, ensure_channel_dim: bool = False) → Tensor | list[Any]

Gathers all the values for a key of the DataPoints forming the batch.

The function will try to return the output as a batch-first Tensor. If not possible, it will return a list of the values, which are converted to Tensors if possible.

If the output is a single Tensor, it will respect the device and the memory format specified with to(), if any. In addition, the desired data type can be specified here via dtype.

Parameters:
  • field_name (str) – The key to the field in the underlying DataPoints.

  • dtype (Optional[torch.dtype], default=None) – Specifies the output data type, if the output is a Tensor. If None, the output will not be cast into a specific data type.

  • ensure_channel_dim (bool, default=False) – If True, a 1D output Tensor of shape (N) will be unsqueezed to a 2D Tensor with a channel dimension (NC).

Returns:

Union[torch.Tensor, list[Any]] – A Tensor or a list containing all the values of field_name in the batch.

Raises:

KeyError – If at least one of the DataPoints does not have the requested field_name.
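
The gathering rule described above can be illustrated without clinicadl. The following is a rough sketch under stated assumptions: plain dicts stand in for DataPoints, plain lists stand in for Tensors, and the function gather_field and the "label" field are hypothetical names chosen for illustration. It shows only the KeyError rule and the effect of ensure_channel_dim, not the conversion to Tensors.

```python
from typing import Any, Sequence


def gather_field(
    datapoints: Sequence[dict[str, Any]],
    field_name: str,
    ensure_channel_dim: bool = False,
) -> list[Any]:
    """Collect one field across all datapoints (dicts stand in for DataPoints)."""
    missing = [i for i, dp in enumerate(datapoints) if field_name not in dp]
    if missing:
        # The field must be present in every element of the batch.
        raise KeyError(f"'{field_name}' is missing from datapoints {missing}")
    values = [dp[field_name] for dp in datapoints]
    if ensure_channel_dim and all(isinstance(v, (int, float)) for v in values):
        # Shape (N) becomes (N, 1): each scalar gets its own channel axis.
        values = [[v] for v in values]
    return values


# Hypothetical batch of two datapoints with a scalar "label" field.
batch = [
    {"participant_id": "sub-colin", "label": 0},
    {"participant_id": "sub-colin", "label": 1},
]
print(gather_field(batch, "participant_id"))                  # ['sub-colin', 'sub-colin']
print(gather_field(batch, "label", ensure_channel_dim=True))  # [[0], [1]]
```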

Examples

from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.data.dataloader import Batch

datapoint = Colin27DataPoint()
batch = Batch([datapoint, datapoint])
>>> datapoint
Colin27DataPoint(Keys: ('head', 'image', 'participant_id', 'session_id'); images: 2)
>>> datapoint["head"]
LabelMap(shape: (1, 181, 217, 181); spacing: (1.00, 1.00, 1.00); orientation: RAS+; ...)
>>> datapoint["participant_id"]
'sub-colin'
>>> batch.get_field("head").shape
torch.Size([2, 1, 181, 217, 181])
>>> batch.get_field("participant_id")
['sub-colin', 'sub-colin']

add_field(values: Sequence[Any], field_name: str) → None

Adds a field to the DataPoints inside the current Batch.

This method is useful, for example, when a neural network returns a batch of outputs that should be stored in the original Batch.

Parameters:
  • values (Sequence[Any]) – The values of the field for each element of the Batch. The sequence must have the same length as the Batch.

  • field_name (str) – The name of the field.
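
The one-value-per-element rule can be sketched as follows. This is an illustration in plain Python, not the clinicadl implementation: dicts stand in for DataPoints, nested lists stand in for a batched network output, and raising ValueError on a length mismatch is an assumption (the documentation above does not say which error, if any, is raised).

```python
from typing import Any, Sequence


def add_field(
    datapoints: list[dict[str, Any]],
    values: Sequence[Any],
    field_name: str,
) -> None:
    """Store values[i] under field_name in the i-th datapoint, in place."""
    if len(values) != len(datapoints):
        # Assumption: a mismatch is rejected up front with a ValueError.
        raise ValueError(
            f"Got {len(values)} values for a batch of {len(datapoints)} elements."
        )
    for datapoint, value in zip(datapoints, values):
        datapoint[field_name] = value


batch = [{"participant_id": "sub-colin"}, {"participant_id": "sub-colin"}]
outputs = [[0.1, 0.9], [0.7, 0.3]]  # hypothetical network output, one row per element
add_field(batch, outputs, "output")
print(batch[0]["output"])  # [0.1, 0.9]
```

As in the example below, the batch dimension is consumed by the split: each element receives its own slice of the batched output.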

Examples

from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.data.dataloader import Batch

datapoint = Colin27DataPoint()
batch = Batch([datapoint, datapoint])
>>> batch[0]
Colin27DataPoint(Keys: ('head', 'image', 'participant_id', 'session_id'); images: 2)
>>> import torch
>>> batch.add_field(torch.randn(2, 1, 3, 3, 3), "output")
>>> batch[0]
Colin27DataPoint(Keys: ('head', 'image', 'participant_id', 'session_id', 'output'); images: 2)
>>> batch[0]["output"].shape
torch.Size([1, 3, 3, 3])

add_images(images: Tensor, image_name: str) → None

Adds an image to each of the DataPoints inside the current Batch.

The images are expected to be passed via a batched Tensor.

Important

The image added in a DataPoint will take the same affine matrix as the main image of this DataPoint.

Parameters:
  • images (torch.Tensor) – The 4D images to add, passed via a 5D batch-first Tensor.

  • image_name (str) – The name that the image will take in the DataPoints.

Examples

import torch
from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.data.dataloader import Batch

batch = Batch([Colin27DataPoint(), Colin27DataPoint()])
>>> batch.add_images(torch.randn(2, 1, 10, 10, 10), "new_image")
>>> batch[0]["new_image"]
ScalarImage(shape: (1, 10, 10, 10); spacing: (1.00, 1.00, 1.00); orientation: RAS+; dtype: torch.FloatTensor; memory: 3.9 KiB)  # spacing (1.00, 1.00, 1.00) like Colin27DataPoint().image

See also

add_field()

To add any kind of value to the Batch.

DataPoint.add_image

To add an image to a DataPoint.

add_masks(masks: Tensor, mask_name: str) → None

Adds a mask to each of the DataPoints inside the current Batch.

The masks are expected to be passed via a batched Tensor.

Important

The mask added in a DataPoint will take the same affine matrix as the main image of this DataPoint.

Parameters:
  • masks (torch.Tensor) – The 4D masks to add, passed via a 5D batch-first Tensor.

  • mask_name (str) – The name that the mask will take in the DataPoints.

Examples

import torch
from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.data.dataloader import Batch

batch = Batch([Colin27DataPoint(), Colin27DataPoint()])
>>> batch.add_masks(torch.randint(0, 2, (2, 1, 10, 10, 10)), "new_mask")
>>> batch[0]["new_mask"]
LabelMap(shape: (1, 10, 10, 10); spacing: (1.00, 1.00, 1.00); orientation: RAS+; dtype: torch.LongTensor; memory: 7.8 KiB)  # spacing (1.00, 1.00, 1.00) like Colin27DataPoint().image

See also

add_field()

To add any kind of value to the Batch.

DataPoint.add_mask

To add a mask to a DataPoint.