# Source code for clinicadl.transforms.homemade

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Sequence, Union

import numpy as np
import torch
import torchio as tio
from monai.utils.type_conversion import (
    convert_data_type,
    convert_to_dst_type,
)

from clinicadl.utils.dtype import DtypeLike
from clinicadl.utils.numerics import merge_numerics

if TYPE_CHECKING:
    from clinicadl.data.structures import DataPoint


class Format(tio.Transform):
    """
    Transform to modify the shape or the type of some values in a
    :py:class:`~clinicadl.data.structures.DataPoint`.

    To be transformed, the values are expected to be :py:class:`torch.Tensor`
    or :py:class:`numpy.ndarray`. Any other value is first converted with
    :py:func:`numpy.array`, so the output is then a :py:class:`numpy.ndarray`.
    The output keeps the container type (tensor or array) of the input.

    This transform inherits from :py:class:`torchio.Transform`, and therefore any
    argument accepted by this parent class can be passed via a keyword argument
    here. Particularly, you may be interested in ``include`` or ``exclude`` to
    specify the keys whose values should be modified, and ``copy`` to specify if
    the output ``DataPoint`` should be the same object as the input or a deepcopy.

    Parameters
    ----------
    dtype : Optional[DtypeLike], default=None
        The wanted data type, passed as a :py:class:`torch.dtype`, a
        :py:class:`numpy.dtype`, or a ``str`` (e.g., "float32", "int64").
        If ``None``, input's dtype will be kept.
    squeeze : Union[bool, int, Sequence[int]], default=False
        Whether to squeeze the tensor/array, i.e. removing dimension(s) of size 1.
        If ``True``, all such dimensions will be removed. Specific dimension(s) to
        remove can be specified via an ``int`` or a sequence of ``ints``.
    unsqueeze : Optional[int], default=None
        The position where to insert the new dimension. If ``None``, no
        unsqueezing will be performed.

        .. note::
            Squeezing is performed before unsqueezing.

    **kwargs : Any
        Any keyword argument accepted by :py:class:`torchio.Transform`.

    Notes
    -----
    Reshaping is done out of place: the shape of an input
    :py:class:`torch.Tensor` is never changed by this transform. The output may
    still share memory with the input when no dtype conversion is needed.

    Example
    -------
    .. code-block::

        from clinicadl.transforms import Format
        from clinicadl.data.structures.examples import Colin27DataPoint
        import numpy as np

        data = Colin27DataPoint(age=55.0, array=np.array([1, 2]))

    .. code-block::

        >>> Format(dtype="int64", include=["age"])(data)["age"]
        55
        >>> Format(unsqueeze=1, include=["array"])(data)["array"]
        [[1]
         [2]]
    """

    def __init__(
        self,
        dtype: Optional[DtypeLike] = None,
        squeeze: Union[bool, int, Sequence[int]] = False,
        unsqueeze: Optional[int] = None,
        **kwargs,
    ):
        # The parent initialiser must run first: it resets ``args_names``.
        super().__init__(**kwargs)
        self.dtype = dtype
        self.squeeze = squeeze
        self.unsqueeze = unsqueeze
        # Attributes torchio reads to describe/reproduce this transform.
        self.args_names = ["dtype", "squeeze", "unsqueeze"]

    def apply_transform(self, datapoint: DataPoint) -> DataPoint:
        """Format every value selected by ``include``/``exclude``."""
        for key in datapoint.get_keys(include=self.include, exclude=self.exclude):
            datapoint[key] = self._format(datapoint[key])
        return datapoint

    def _format(
        self,
        value: Union[np.ndarray, torch.Tensor],
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Squeeze/unsqueeze and cast a single value.

        Parameters
        ----------
        value : Union[np.ndarray, torch.Tensor]
            The value to format. Anything else is wrapped with ``np.array``.

        Returns
        -------
        Union[np.ndarray, torch.Tensor]
            The formatted value, of the same container type as ``value`` (after
            the optional ``np.array`` wrapping) and with dtype ``self.dtype``
            (or the input dtype if ``self.dtype`` is ``None``).
        """
        if not isinstance(value, (np.ndarray, torch.Tensor)):
            value = np.array(value)

        # Work on a torch view of the data; for a torch input this may be the
        # very same tensor object, hence the out-of-place reshaping below.
        value_t, *_ = convert_data_type(value, output_type=torch.Tensor)

        # ``is`` checks: ``squeeze=0`` or ``squeeze=1`` are dimension indices,
        # not booleans.
        if self.squeeze is True:
            value_t = value_t.squeeze()
        elif self.squeeze is not False:
            value_t = value_t.squeeze(self.squeeze)

        if self.unsqueeze is not None:
            value_t = value_t.unsqueeze(self.unsqueeze)

        # Convert back to the type of the (possibly wrapped) input, casting to
        # ``self.dtype`` if requested.
        out, *_ = convert_to_dst_type(value_t, value, dtype=self.dtype)
        return out
class MergeFields(tio.Transform):
    """
    Merge several values of a :py:class:`~clinicadl.data.structures.DataPoint`
    into a single new entry.

    How the values are combined depends on their type:

    - :py:class:`torch.Tensor` and :py:class:`numpy.ndarray` are stacked along a
      new dimension (they must therefore all share the same shape);
    - :py:class:`torchio.Image` are concatenated along the channel dimension
      (they must therefore all share the same spatial shape);
    - ``lists`` and ``tuples`` are concatenated;
    - any other values are simply gathered in a list.

    Being a :py:class:`torchio.Transform` subclass, this transform accepts every
    keyword argument of its parent. In particular, ``copy`` controls whether the
    returned ``DataPoint`` is the input object itself or a deepcopy of it.

    Parameters
    ----------
    *keys : str
        Keys of the values to merge, in the order they should be merged.
    output_key : str
        Key under which the merged result is stored in the
        :py:class:`~clinicadl.data.structures.DataPoint`.
    **kwargs : Any
        Any keyword argument accepted by :py:class:`torchio.Transform`.

    Example
    -------
    .. code-block::

        from clinicadl.transforms import MergeFields
        from clinicadl.data.structures.examples import Colin27DataPoint
        import numpy as np

        data = Colin27DataPoint(age=55, sex="M", array_1=np.zeros(2), array_2=np.ones(2))

    .. code-block::

        >>> MergeFields("age", "sex", output_key="label")(data)["label"]
        [55, 'M']
        >>> MergeFields("array_1", "array_2", output_key="label")(data)["label"]
        [[0. 0.]
         [1. 1.]]
    """

    def __init__(self, *keys: str, output_key: str, **kwargs):
        # Initialise the torchio base first; it sets its own ``args_names``,
        # which is overridden below.
        super().__init__(**kwargs)
        self.output_key = output_key
        self.keys = keys
        # Attributes torchio reads to describe/reproduce this transform.
        self.args_names = ["keys", "output_key"]

    def apply_transform(self, datapoint: DataPoint) -> DataPoint:
        """Store the merger of ``self.keys`` under ``self.output_key``."""
        to_merge = []
        for name in self.keys:
            to_merge.append(datapoint[name])
        merged = merge_numerics(to_merge)
        datapoint[self.output_key] = merged
        return datapoint