clinicadl.data.dataloader.Batch
- class clinicadl.data.dataloader.Batch(datapoints: Sequence[T])
A batch container for DataPoint objects. Batch is simply a list of DataPoints, with additional useful methods.
- Parameters:
datapoints (Sequence[DataPoint]) – Sequence of DataPoints forming the batch.
- Raises:
ValueError – If the input sequence is empty.
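Because Batch is described as simply a list of DataPoints, its container behaviour can be pictured with a plain list subclass. The sketch below is a hypothetical illustration, not clinicadl's code: MiniBatch is an invented name, dictionaries stand in for DataPoints, and only the documented ValueError on empty input is reproduced.

```python
from typing import Any, Sequence


class MiniBatch(list):
    """Hypothetical stand-in for Batch: a plain Python list of DataPoints.

    Dictionaries play the role of DataPoints here; this is not clinicadl code.
    """

    def __init__(self, datapoints: Sequence[Any]) -> None:
        if len(datapoints) == 0:
            # Mirrors the documented ValueError for an empty input sequence.
            raise ValueError("A batch needs at least one DataPoint.")
        super().__init__(datapoints)


batch = MiniBatch([{"participant_id": "sub-01"}, {"participant_id": "sub-02"}])
print(len(batch), batch[1]["participant_id"])  # 2 sub-02
```

Since it is a list, indexing, iteration and len() work without any extra code, which matches how the examples below use batch[0].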
- property device: device | None
The device where Tensors will be sent when calling get_field(). It is specified via to().
By default, it is None, meaning that the tensors will stay on their original device.
- property channels_last: bool
Whether Channels Last Memory Format is used for 4D (NCWH) or 5D (NCDWH) Tensors returned by get_field(). It is specified via to().
By default, it is False.
- to(device: str | int | device | None = None, non_blocking: bool = False, channels_last: bool | None = None) → None
To send the Tensors in the Batch to the specified device and with the specified memory format.
Important
Nothing is applied to the Batch itself, which remains on its original device, but all the tensors obtained with get_field() will have the specified device and memory format.
- Parameters:
device (Optional[DeviceType], default=None) –
The device to which the tensors will be sent. Can be:
  - an int: the device id;
  - "cuda";
  - "cpu";
  - "cuda-<id>": where <id> is the device id;
  - a torch.device;
  - None: the device won’t be changed.
non_blocking (bool, default=False) –
“When non_blocking is set to True, the function attempts to perform the conversion asynchronously with respect to the host, if possible. This asynchronous behavior applies to both pinned and pageable memory.” (see PyTorch documentation)
non_blocking=True may speed up data transfer across devices.
channels_last (Optional[bool], default=None) –
Whether to use Channels Last Memory Format for 4D (NCWH) or 5D (NCDWH) tensors. If False, the default contiguous memory format will be used. If None, the current memory format will be kept.
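The Important note above means that to() only records settings, which get_field() then applies to what it returns. A minimal pure-Python model of that deferred design follows; it assumes nothing about clinicadl's internals, DeferredBatch and FieldView are hypothetical names, and strings stand in for devices and tensors.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FieldView:
    """What a hypothetical get_field() hands back: values plus the requested settings."""
    values: list[Any]
    device: Optional[str]
    channels_last: bool


class DeferredBatch:
    """Hypothetical model of the to()/get_field() split; not clinicadl's implementation."""

    def __init__(self, datapoints: list[dict]) -> None:
        self.datapoints = datapoints
        self.device: Optional[str] = None  # None: values stay on their original device
        self.channels_last = False

    def to(self, device: Optional[str] = None, channels_last: Optional[bool] = None) -> None:
        # Only record the requested settings; None means "keep the current one".
        if device is not None:
            self.device = device
        if channels_last is not None:
            self.channels_last = channels_last

    def get_field(self, field_name: str) -> FieldView:
        # The stored DataPoints are never modified; settings apply to the returned view only.
        values = [datapoint[field_name] for datapoint in self.datapoints]
        return FieldView(values, self.device, self.channels_last)


batch = DeferredBatch([{"image": "img-0"}, {"image": "img-1"}])
batch.to("cuda", channels_last=True)
batch.to(None)  # device stays "cuda"
print(batch.get_field("image"))
# FieldView(values=['img-0', 'img-1'], device='cuda', channels_last=True)
```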
Examples
from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.data.dataloader import Batch
import torch
datapoint = Colin27DataPoint()
batch = Batch([datapoint, datapoint])
>>> batch.get_field("image").device
device(type='cpu')
>>> batch.get_field("image").stride()  # contiguous memory format (default)
(7109137, 7109137, 39277, 181, 1)
>>> batch.to("cuda", non_blocking=True, channels_last=True)
>>> batch.device
device(type='cuda')
>>> batch.get_field("image").device  # get_field is affected
device(type='cuda', index=0)
>>> batch.get_field("image").stride()  # channels-last memory format
(7109137, 1, 39277, 181, 1)
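The stride values in the example above can be reproduced by hand. In the default layout the last dimension varies fastest; in the channels-last layout the channel dimension does, so its stride becomes 1. The minimal pure-Python sketch below uses hypothetical helper names and assumes PyTorch's (N, C, D, H, W) dimension naming.

```python
def contiguous_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Strides (in elements) of a tensor in the default row-major layout."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def channels_last_3d_strides(shape: tuple[int, int, int, int, int]) -> tuple[int, ...]:
    """Strides of a 5D (N, C, D, H, W) tensor whose memory is laid out as N, D, H, W, C."""
    _, c, d, h, w = shape
    return (d * h * w * c, 1, h * w * c, w * c, c)


shape = (2, 1, 181, 217, 181)  # two single-channel images of the size used in the example
print(contiguous_strides(shape))        # (7109137, 7109137, 39277, 181, 1)
print(channels_last_3d_strides(shape))  # (7109137, 1, 39277, 181, 1)
```

With a single channel, only the channel stride differs between the two layouts, which is exactly what the two stride() outputs show.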
- get_field(field_name: str, dtype: dtype | None = None, ensure_channel_dim: bool = False) → Tensor | list[Any]
Gathers all the values for a key of the DataPoints forming the batch.
The function will try to return the output as a batch-first Tensor. If that is not possible, it will return a list of the values, which are converted to Tensors when possible.
If the output is a single Tensor, it will respect the device and the memory format potentially specified with to(). The desired data type can also be specified here via dtype.
- Parameters:
field_name (str) – The key of the field in the underlying DataPoints.
dtype (Optional[torch.dtype], default=None) – Specifies the output data type, if the output is a Tensor. If None, the output will not be cast to a specific data type.
ensure_channel_dim (bool, default=False) – If True, a 1D Tensor output batch (N) will be unsqueezed to a 2D Tensor with a channel dimension (NC).
- Returns:
Union[torch.Tensor, list[Any]] – A Tensor or a list containing all the values of field_name in the batch.
- Raises:
KeyError – If not all the DataPoints have the requested field_name.
Examples
from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.data.dataloader import Batch
datapoint = Colin27DataPoint()
batch = Batch([datapoint, datapoint])
>>> datapoint
Colin27DataPoint(Keys: ('head', 'image', 'participant_id', 'session_id'); images: 2)
>>> datapoint["head"]
LabelMap(shape: (1, 181, 217, 181); spacing: (1.00, 1.00, 1.00); orientation: RAS+; ...)
>>> datapoint["participant_id"]
'sub-colin'
>>> batch.get_field("head").shape
torch.Size([2, 1, 181, 217, 181])
>>> batch.get_field("participant_id")
['sub-colin', 'sub-colin']
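The key check and the ensure_channel_dim reshaping described above can be modelled without PyTorch. This is a hypothetical sketch (gather_field is an invented name): dictionaries stand in for DataPoints and nested lists for Tensors, and it deliberately leaves out stacking, dtype casting and device transfer.

```python
from typing import Any, Sequence


def gather_field(datapoints: Sequence[dict], field_name: str, ensure_channel_dim: bool = False) -> list[Any]:
    """Hypothetical pure-Python model of the gathering step of get_field().

    Dictionaries stand in for DataPoints and nested lists for Tensors; no stacking,
    casting or device transfer is modelled.
    """
    missing = [i for i, datapoint in enumerate(datapoints) if field_name not in datapoint]
    if missing:
        # Every DataPoint must provide the field, otherwise a KeyError is raised.
        raise KeyError(f"{field_name!r} is missing from DataPoints at positions {missing}")
    values = [datapoint[field_name] for datapoint in datapoints]
    if ensure_channel_dim and all(isinstance(v, (int, float)) for v in values):
        # A 1D batch of shape (N,) becomes a 2D batch of shape (N, 1).
        values = [[v] for v in values]
    return values


datapoints = [{"label": 0, "participant_id": "sub-01"}, {"label": 1, "participant_id": "sub-02"}]
print(gather_field(datapoints, "participant_id"))                  # ['sub-01', 'sub-02']
print(gather_field(datapoints, "label", ensure_channel_dim=True))  # [[0], [1]]
```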
- add_field(values: Sequence[Any], field_name: str) → None
To add a field to the DataPoints inside the current Batch.
This method is useful, for example, when a neural network returns a batch of outputs that one wants to store in the original Batch.
- Parameters:
values (Sequence[Any]) – The values of the field for each element of the Batch. The sequence must have the same length as the Batch.
field_name (str) – The name of the field.
Examples
from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.data.dataloader import Batch
datapoint = Colin27DataPoint()
batch = Batch([datapoint, datapoint])
>>> batch[0]
Colin27DataPoint(Keys: ('head', 'image', 'participant_id', 'session_id'); images: 2)
>>> import torch
>>> batch.add_field(torch.randn(2, 1, 3, 3, 3), "output")
>>> batch[0]
Colin27DataPoint(Keys: ('head', 'image', 'participant_id', 'session_id', 'output'); images: 2)
>>> batch[0]["output"].shape
torch.Size([1, 3, 3, 3])
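In the example above, the (2, 1, 3, 3, 3) output leaves a (1, 3, 3, 3) tensor in each DataPoint because iterating over a batch-first tensor yields its slices along the first dimension. The pairing of value i with DataPoint i can be sketched in plain Python. This is a hypothetical model (add_field here is a standalone function over dictionaries), and raising ValueError on a size mismatch is an assumption, since the documentation only states that the sizes must match.

```python
from typing import Any, Sequence


def add_field(datapoints: list[dict], values: Sequence[Any], field_name: str) -> None:
    """Hypothetical model of Batch.add_field(): value i is stored in DataPoint i."""
    if len(values) != len(datapoints):
        # The documentation only states that sizes must match; raising ValueError is an assumption.
        raise ValueError(f"got {len(values)} values for a batch of {len(datapoints)} DataPoints")
    for datapoint, value in zip(datapoints, values):
        datapoint[field_name] = value


datapoints = [{"participant_id": "sub-01"}, {"participant_id": "sub-02"}]
outputs = [[0.2, 0.8], [0.9, 0.1]]  # e.g. per-sample network outputs, batch first
add_field(datapoints, outputs, "output")
print(datapoints[0])  # {'participant_id': 'sub-01', 'output': [0.2, 0.8]}
```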
- add_images(images: Tensor, image_name: str) → None
To add an image to each of the DataPoints inside the current Batch.
The images are expected to be passed via a batched Tensor.
Important
The image added to a DataPoint will take the same affine matrix as the main image of this DataPoint.
- Parameters:
images (torch.Tensor) – The 4D images to add, passed via a 5D batch-first Tensor.
image_name (str) – The name that the image will take in the DataPoints.
Examples
import torch
from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.data.dataloader import Batch
batch = Batch([Colin27DataPoint(), Colin27DataPoint()])
>>> batch.add_images(torch.randn(2, 1, 10, 10, 10), "new_image")
>>> batch[0]["new_image"]  # spacing (1.00, 1.00, 1.00), like Colin27DataPoint().image
ScalarImage(shape: (1, 10, 10, 10); spacing: (1.00, 1.00, 1.00); orientation: RAS+; dtype: torch.FloatTensor; memory: 3.9 KiB)
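The added image reports the same spacing as the main image because it reuses that image's affine, and voxel spacing can be read off an affine as the Euclidean norms of the first three columns of its 3x3 linear part. The sketch below is a hypothetical pure-Python model (both helper names are invented, dictionaries stand in for DataPoints and images, and the 1 mm affine is illustrative, not Colin27's actual matrix).

```python
import math


def spacing_from_affine(affine: list[list[float]]) -> tuple[float, ...]:
    """Voxel spacing: Euclidean norm of each of the first three columns of a 4x4 affine."""
    return tuple(
        math.sqrt(sum(affine[row][col] ** 2 for row in range(3))) for col in range(3)
    )


def add_images_sharing_affine(datapoints: list[dict], images: list, image_name: str, main_key: str = "image") -> None:
    """Hypothetical model: image i is stored in DataPoint i with that DataPoint's main-image affine."""
    for datapoint, image in zip(datapoints, images):
        datapoint[image_name] = {"data": image, "affine": datapoint[main_key]["affine"]}


one_mm = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
datapoints = [{"image": {"data": "main-0", "affine": one_mm}}, {"image": {"data": "main-1", "affine": one_mm}}]
add_images_sharing_affine(datapoints, ["new-0", "new-1"], "new_image")
print(spacing_from_affine(datapoints[0]["new_image"]["affine"]))  # (1.0, 1.0, 1.0)
```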
See also
add_field() – To add any kind of value to the Batch.
DataPoint.add_image – To add an image to a DataPoint.
- add_masks(masks: Tensor, mask_name: str) → None
To add a mask to each of the DataPoints inside the current Batch.
The masks are expected to be passed via a batched Tensor.
Important
The mask added to a DataPoint will take the same affine matrix as the main image of this DataPoint.
- Parameters:
masks (torch.Tensor) – The 4D masks to add, passed via a 5D batch-first Tensor.
mask_name (str) – The name that the mask will take in the DataPoints.
Examples
import torch
from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.data.dataloader import Batch
batch = Batch([Colin27DataPoint(), Colin27DataPoint()])
>>> batch.add_masks(torch.randint(0, 2, (2, 1, 10, 10, 10)), "new_mask")
>>> batch[0]["new_mask"]  # spacing (1.00, 1.00, 1.00), like Colin27DataPoint().image
LabelMap(shape: (1, 10, 10, 10); spacing: (1.00, 1.00, 1.00); orientation: RAS+; dtype: torch.LongTensor; memory: 7.8 KiB)
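The memory figure above, twice the 3.9 KiB reported for the float image added with add_images, follows from element count times element size: torch.randint returns int64 values (8 bytes each) by default, whereas torch.randn returns float32 values (4 bytes each). A small arithmetic sketch (tensor_memory_kib is a hypothetical helper, not part of clinicadl):

```python
def tensor_memory_kib(shape: tuple[int, ...], bytes_per_element: int) -> float:
    """Memory of a dense tensor in KiB: number of elements times element size, over 1024."""
    n_elements = 1
    for size in shape:
        n_elements *= size
    return n_elements * bytes_per_element / 1024


shape = (1, 10, 10, 10)  # one image or mask of the size used in the examples
print(f"{tensor_memory_kib(shape, 4):.1f} KiB")  # float32: 3.9 KiB
print(f"{tensor_memory_kib(shape, 8):.1f} KiB")  # int64: 7.8 KiB
```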
See also
add_field() – To add any kind of value to the Batch.
DataPoint.add_mask – To add a mask to a DataPoint.