Source code for clinicadl.transforms.handlers.transforms
from __future__ import annotations
from logging import getLogger
from typing import TYPE_CHECKING, Any, Optional, Sequence, TypeVar
import torchio as tio
from pydantic import Field, ValidationInfo, field_validator
from clinicadl.transforms.config import TransformConfig
from clinicadl.transforms.extraction import Extraction, Image
from clinicadl.utils.config import ObjectConfig, SequenceOfObjects
from clinicadl.utils.dictionary.words import AUGMENTATION, IMAGE, SAMPLE, TRANSFORMATION
from clinicadl.utils.objects import HasConfig, equal_if_config_equal
from ..extraction import get_extraction_from_dict
from ..factory import get_transform_from_dict
from ..types import Transform, TransformOrConfig
from .utils import get_transform_name
if TYPE_CHECKING:
from clinicadl.data.structures import DataPoint
# Module-level logger, named after this module.
logger = getLogger(__name__)
# Type variable for ``DataPoint`` (sub)types, so that the ``TransformsHandler``
# methods can declare that they return the same type they receive.
# The bound is given as a string because ``DataPoint`` is only imported under
# ``TYPE_CHECKING`` (i.e. not at runtime).
DataPointT = TypeVar("DataPointT", bound="DataPoint")
def _transform_sequence_field() -> Any:
    """
    Builds the pydantic ``Field`` shared by every transform-sequence field of
    ``TransformsHandlerConfig``.

    The field has no default (so it is required) and stores, in its JSON-schema
    extras, a ``reader`` built from ``get_transform_from_dict`` (presumably used
    to rebuild the transforms when the config is loaded from a dictionary).
    A new ``Field`` and a new reader are created on every call.
    """
    transform_reader = SequenceOfObjects.build_reader(get_transform_from_dict)
    return Field(json_schema_extra={"reader": transform_reader})


class TransformsHandlerConfig(ObjectConfig["TransformsHandler"]):
    """Config class for ``TransformsHandler``."""

    # The extraction is stored as-is; its reader is ``get_extraction_from_dict``.
    extraction: Extraction = Field(
        json_schema_extra={"reader": get_extraction_from_dict}
    )

    # The three transform pipelines share the same field definition.
    image_transforms: SequenceOfObjects[Transform, TransformConfig] = (
        _transform_sequence_field()
    )
    sample_transforms: SequenceOfObjects[Transform, TransformConfig] = (
        _transform_sequence_field()
    )
    augmentations: SequenceOfObjects[Transform, TransformConfig] = (
        _transform_sequence_field()
    )

    @field_validator(
        "image_transforms", "sample_transforms", "augmentations", mode="before"
    )
    @classmethod
    def _handle_sequence(
        cls, value: Any, info: ValidationInfo
    ) -> SequenceOfObjects:
        """
        Converts the raw value of a transform-sequence field into a
        ``SequenceOfObjects``, before pydantic's own validation runs.

        Parameters
        ----------
        value : Any
            The raw value passed for the field (e.g. a sequence of transforms
            and/or transform configs).
        info : ValidationInfo
            Pydantic validation info; only its ``field_name`` is used, and it is
            forwarded to ``SequenceOfObjects.from_sequence``.

        Returns
        -------
        SequenceOfObjects
            The wrapped sequence.
        """
        name_of_field = info.field_name
        return SequenceOfObjects.from_sequence(value, field_name=name_of_field)

    @classmethod
    def _get_class(cls) -> type[TransformsHandler]:
        """Returns ``TransformsHandler``, the class associated to this config class."""
        return TransformsHandler
[docs]
@equal_if_config_equal
class TransformsHandler(HasConfig[TransformsHandlerConfig]):
    """
    Handles the transformation pipeline applied to images.

    ``ClinicaDL`` defines 4 types of transforms:

    - ``extraction``: defines which elements of the image to work on
      (the whole image, patches or slices).
    - ``image_transforms``: transforms applied on the whole image, **before the
      potential extraction** is applied. This is typically where you want to
      do normalization (to normalize on the whole image and not only on a patch
      or a slice).
    - ``sample_transforms``: transforms applied on a sample (a patch or a slice),
      **after extraction**. This is typically where you want to
      resize your sample so that it fits in your network.
    - ``augmentations``: transforms applied after ``image_transforms``, ``extraction``
      and ``sample_transforms``, **only during training**.

    For ``image_transforms``, ``sample_transforms`` and ``augmentations``, the transforms
    must be passed as sequences. ``TransformsHandler`` will compose the transforms in these
    sequences, so **the order in the sequences is important**.

    Note that ``TransformsHandler`` itself has no notion of training: each step is exposed
    as a separate method, and it is up to the caller to call :py:meth:`apply_augmentations`
    only on training data.

    Parameters
    ----------
    extraction : Optional[Extraction], default=None
        The extraction applied. See :py:mod:`clinicadl.transforms.extraction`. If ``None``,
        :py:class:`~clinicadl.transforms.extraction.Image` is used (which is equivalent to no extraction).
    image_transforms : Sequence[TransformOrConfig], default=()
        A sequence of transforms to apply on the whole image, **before extraction**.
        Passed as callables that take as input and return a :py:class:`~clinicadl.data.structures.DataPoint`,
        or :py:mod:`configuration class <clinicadl.transforms.config>`.
    sample_transforms : Sequence[TransformOrConfig], default=()
        A sequence of transforms to apply on samples (patches or slices).
        Passed as callables that take as input and return a :py:class:`~clinicadl.data.structures.DataPoint`,
        or :py:mod:`configuration class <clinicadl.transforms.config>`.
    augmentations : Sequence[TransformOrConfig], default=()
        A sequence of augmentation transforms, to apply on samples, only during training.
        Passed as callables that take as input and return a :py:class:`~clinicadl.data.structures.DataPoint`,
        or :py:mod:`configuration class <clinicadl.transforms.config>`.

    Examples
    --------
    .. code-block:: python

        >>> from clinicadl.transforms import TransformsHandler
        >>> from clinicadl.transforms.extraction import Patch
        >>> from clinicadl.transforms.config import ZNormalizationConfig, RandomFlipConfig
        >>> import torchio
        >>> transforms = TransformsHandler(
        ...     extraction=Patch(patch_size=32, stride=32),
        ...     image_transforms=[ZNormalizationConfig(), torchio.CropOrPad(64)],
        ...     sample_transforms=[],
        ...     augmentations=[RandomFlipConfig(flip_probability=0.3)],
        ... )
    """

    _config_type = TransformsHandlerConfig

    def __init__(
        self,
        extraction: Optional[Extraction] = None,
        image_transforms: Sequence[TransformOrConfig] = (),
        sample_transforms: Sequence[TransformOrConfig] = (),
        augmentations: Sequence[TransformOrConfig] = (),
    ):
        # Only an omitted extraction (``None``) falls back to ``Image``; a truthiness
        # test would also discard a user-provided extraction that evaluates as falsy.
        if extraction is None:
            extraction = Image()

        # The config validates the inputs and turns configs into transform objects.
        self.config = TransformsHandlerConfig(
            extraction=extraction,
            image_transforms=image_transforms,
            sample_transforms=sample_transforms,
            augmentations=augmentations,
        )

        self.extraction = self.config.extraction
        # copy=False: copying is specified in the individual transforms, so the
        # composed pipelines do not add a copy of their own.
        self.image_transforms = tio.Compose(
            self.config.image_transforms.get_object(), copy=False
        )
        self.sample_transforms = tio.Compose(
            self.config.sample_transforms.get_object(), copy=False
        )
        self.augmentations = tio.Compose(
            self.config.augmentations.get_object(), copy=False
        )

    def __str__(self) -> str:
        """
        Returns a detailed string representation of the ``TransformsHandler`` object,
        showing the extraction type and, for image transforms, sample transforms and
        augmentations, either the list of transform names or a "No ... applied" line.
        """
        # (transforms, object label, transform-kind label) for each pipeline, in order.
        sections = (
            (self.image_transforms.transforms, IMAGE, TRANSFORMATION),
            (self.sample_transforms.transforms, SAMPLE, TRANSFORMATION),
            (self.augmentations.transforms, SAMPLE, AUGMENTATION),
        )

        lines = [
            "TransformsHandler configuration for "
            f"{self.extraction.sample_type} extraction:"
        ]
        for transforms, object_, transfo_ in sections:
            if transforms:
                lines.append(f"* {object_} {transfo_}:")
                lines.extend(
                    f" - {get_transform_name(transform)}" for transform in transforms
                )
            else:
                lines.append(f"* No {object_} {transfo_} applied.")

        # Every line, including the last one, ends with a newline.
        return "\n".join(lines) + "\n"

    def apply_image_transforms(self, datapoint: DataPointT) -> DataPointT:
        """
        Applies the transforms passed in ``image_transforms`` and returns the
        output.

        Parameters
        ----------
        datapoint : DataPoint
            The input ``DataPoint``.

        Returns
        -------
        DataPoint
            The transformed ``DataPoint``.
        """
        return self.image_transforms(datapoint)

    def extract_sample(self, datapoint: DataPointT, sample_index: int) -> DataPointT:
        """
        Extracts the sample at ``sample_index`` using the handler's extraction.

        See: :py:class:`clinicadl.transforms.extraction.Extraction`.

        Parameters
        ----------
        datapoint : DataPoint
            The input ``DataPoint``.
        sample_index : int
            Index of the sample to extract.

        Returns
        -------
        DataPoint
            The sample in a ``DataPoint``.
        """
        return self.extraction(datapoint, sample_index)

    def apply_sample_transforms(self, datapoint: DataPointT) -> DataPointT:
        """
        Applies the transforms passed in ``sample_transforms`` and returns the
        output.

        Parameters
        ----------
        datapoint : DataPoint
            The input ``DataPoint``.

        Returns
        -------
        DataPoint
            The transformed ``DataPoint``.
        """
        return self.sample_transforms(datapoint)

    def apply_augmentations(self, datapoint: DataPointT) -> DataPointT:
        """
        Applies the transforms passed in ``augmentations`` and returns the
        output.

        This method does not check whether the model is training; callers are
        expected to use it only on training data.

        Parameters
        ----------
        datapoint : DataPoint
            The input ``DataPoint``.

        Returns
        -------
        DataPoint
            The transformed ``DataPoint``.
        """
        return self.augmentations(datapoint)