from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional, Sequence, TypeVar
import numpy as np
import torch
import torchio as tio
from clinicadl.data.structures import Sample2D
from clinicadl.utils.config import ObjectConfig
from clinicadl.utils.dictionary.words import (
FILE_TYPE,
IMAGE_PATH,
PARTICIPANT_ID,
SESSION_ID,
SQUEEZE,
)
from clinicadl.utils.numerics import merge_numerics
from clinicadl.utils.objects import HasConfig
from ..batch import Batch
from .base import ImplementedCollateFn
# Generic type of the samples handled by the collate function (see ``__call__``).
T = TypeVar("T", bound="Sample")
# Generic type of the values of a single sample field (see ``_get_unique_field``).
FieldT = TypeVar("FieldT")

if TYPE_CHECKING:
    # Only imported for static type checkers, to resolve the "Sample" bound of ``T``.
    from clinicadl.data.structures import Sample
class MergeBatchesCollateConfig(ObjectConfig["MergeBatchesCollate"]):
    """
    Configuration object associated with :py:class:`MergeBatchesCollate`.
    """

    # Names of the sample fields to leave out of the merged samples (None: keep all).
    ignore: Optional[Sequence[str]]

    @classmethod
    def _get_class(cls) -> type[MergeBatchesCollate]:
        """Returns the collate class this configuration belongs to."""
        return MergeBatchesCollate
class MergeBatchesCollate(ImplementedCollateFn, HasConfig[MergeBatchesCollateConfig]):
    """
    To merge several batches into a single batch.

    This collating mode is typically used to get a single batch from the outputs
    of a :py:class:`~clinicadl.data.datasets.Dataset` returning a sequence of samples.

    ``MergeBatchesCollate`` will try to merge this sequence of samples by merging each field
    of the :py:class:`~clinicadl.data.structures.Sample`, except those in ``ignore``.

    More precisely:

    - :py:class:`torchio.Images <torchio.Image>` will be concatenated along the channel dimension;
    - :py:class:`numpy.ndarrays <numpy.ndarray>` and :py:class:`torch.Tensors <torch.Tensor>`
      will be stacked along a new dimension;
    - otherwise, the values will be merged in a tuple. If this tuple contains only one unique
      element (i.e. the value is the same in all the samples), a single value will be returned.

    The mandatory fields (``participant_id``, ``session_id`` and, for
    :py:class:`~clinicadl.data.structures.Sample2D`, the squeeze field) must be identical
    in all the merged samples, whereas ``file_type`` and ``image_path`` are concatenated.

    Parameters
    ----------
    ignore : Optional[Sequence[str]], default=None
        To ignore some fields in the samples to merge. Thus, the output sample will not
        have these fields.

        .. important::
            The mandatory arguments of :py:class:`~clinicadl.data.structures.Sample` cannot be
            ignored.

    Examples
    --------
    .. code-block::

        from clinicadl.data.dataloader import MergeBatchesCollate
        from clinicadl.data.structures.examples import Colin27Sample
        import numpy as np

        sample = Colin27Sample(participant_id="sub-001", label=np.array([0, 1]), age=55, sex="M")
        sample_bis = Colin27Sample(participant_id="sub-001", label=np.array([1, 2]), age=56, to_ignore="abc")
        batch = MergeBatchesCollate(ignore=["to_ignore"])([(sample, sample_bis)])

    .. code-block::

        >>> batch
        [Colin27Sample(Keys: ('head', 'sex', 'age', 'label', 'file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 2)]
        >>> batch[0].participant_id  # same value in the two samples
        'sub-001'
        >>> batch[0].age
        (55, 56)
        >>> batch[0].sex  # only in one sample, so kept as it is
        'M'
        >>> batch[0].image.shape  # 2 channels now!
        (2, 181, 217, 181)
        >>> batch[0].label  # arrays are stacked
        array([[0, 1],
               [1, 2]])

    See Also
    --------
    ~clinicadl.data.dataloader.ToBatchesCollate
        To return a sequence of batches.
    """

    _config_type = MergeBatchesCollateConfig

    def __init__(self, ignore: Optional[Sequence[str]] = None):
        self.config = MergeBatchesCollateConfig(ignore=ignore)

    def __call__(self, samples: Sequence[Sequence[T]]) -> Batch[T]:
        """
        Merges a sequence of sequences of :py:class:`~clinicadl.data.structures.Sample`
        in a single :py:class:`~clinicadl.data.dataloader.Batch`.

        Parameters
        ----------
        samples : Sequence[Sequence[T]]
            A sequence of sequences of :py:class:`~clinicadl.data.structures.Sample`, e.g. a sequence
            of outputs of a :py:class:`~clinicadl.data.datasets.PairedDataset`.

        Returns
        -------
        Batch[T]
            A :py:class:`~clinicadl.data.dataloader.Batch`, whose samples
            are the results of the merger of the inner input sequences.

        Raises
        ------
        ValueError
            If one of the inner sequences is empty.
        TypeError
            If the samples of an inner sequence are not all of the same type.
        RuntimeError
            If a mandatory field has different values within an inner sequence.
        """
        # Read the config once; an empty/None ``ignore`` means nothing is ignored.
        ignore = self.config.ignore or ()
        mergers = [
            self._merge_samples(samples_collection, ignore)
            for samples_collection in samples
        ]
        return Batch(mergers)

    def _merge_samples(self, samples_collection: Sequence[T], ignore: Sequence[str]) -> T:
        """
        Merges one inner sequence of samples into a single sample of the same type.
        """
        type_ = self._check_types(samples_collection)

        # Mandatory fields: they must be identical in all the samples.
        args: dict[str, Any] = {
            PARTICIPANT_ID: self._get_unique_field(
                [sample.participant_id for sample in samples_collection], PARTICIPANT_ID
            ),
            SESSION_ID: self._get_unique_field(
                [sample.session_id for sample in samples_collection], SESSION_ID
            ),
        }
        # All samples share the same type (checked above), so testing the first is enough.
        if isinstance(samples_collection[0], Sample2D):
            args[SQUEEZE] = self._get_unique_field(
                [sample[SQUEEZE] for sample in samples_collection], SQUEEZE
            )

        # File types and image paths are concatenated across the samples.
        args[FILE_TYPE] = tuple(
            file_type
            for sample in samples_collection
            for file_type in sample.file_type
        )
        args[IMAGE_PATH] = tuple(
            path for sample in samples_collection for path in sample.image_path
        )

        for field in self._get_all_fields(samples_collection):
            # Fields already set above are mandatory: they are neither re-merged nor ignored.
            if field in args or field in ignore:
                continue
            # Only the samples that actually have the field contribute to it.
            args[field] = self._merge_field(
                [sample[field] for sample in samples_collection if field in sample],
            )

        return type_(**args)

    @staticmethod
    def _check_types(samples_collection: Sequence[Sample]) -> type:
        """
        Checks that all the samples are of the same type, and returns this type.

        Raises
        ------
        ValueError
            If ``samples_collection`` is empty.
        TypeError
            If the samples are not all of the same type.
        """
        if not samples_collection:
            raise ValueError("Cannot merge an empty sequence of samples.")
        types = {type(sample) for sample in samples_collection}
        if len(types) > 1:
            raise TypeError(f"Cannot merge samples of different types. Got {types}")
        return types.pop()

    @staticmethod
    def _get_unique_field(values: Sequence[FieldT], field_name: str) -> FieldT:
        """
        Checks that the values of a field are consistent, and returns the unique value.

        Raises
        ------
        RuntimeError
            If ``values`` contains more than one distinct value.
        """
        unique_values = set(values)
        if len(unique_values) > 1:
            raise RuntimeError(
                f"Got different values for '{field_name}': {unique_values}"
            )
        return unique_values.pop()

    @staticmethod
    def _get_all_fields(samples_collection: Sequence[Sample]) -> set[str]:
        """
        Gets the set of all fields in a set of ``Samples``.
        """
        return {field for sample in samples_collection for field in sample.keys()}

    def _merge_field(self, values: Sequence[Any]) -> Any:
        """
        Tries to merge any field.

        Values merged by ``merge_numerics`` into a tensor, an array or an image are
        returned as such. Otherwise, a single value is returned if all the values are
        equal, and a tuple of the values if not.
        """
        merged = merge_numerics(values, merge_lists=False)
        if isinstance(merged, (torch.Tensor, np.ndarray, tio.Image)):
            return merged

        first = merged[0]
        if all(value == first for value in merged):
            return first
        return tuple(merged)