"""Configuration classes for the spatial TorchIO transforms (``clinicadl.transforms.config.spatial``)."""

from pathlib import Path
from typing import Optional, Tuple, Union

import numpy as np
import torchio as tio
from pydantic import (
    PositiveFloat,
    PositiveInt,
    field_validator,
    model_validator,
)
from torchio import Image

from clinicadl.utils.doc import add_suffix_to_doc
from clinicadl.utils.factories import get_defaults_from

from .base import DOCUMENT_EXTRA_PARAMETERS, Bounds, TorchioTransformConfig
from .enum import EnsureShapeMultipleMode, InterpolationMode, PaddingMode

__all__ = [
    "CropOrPadConfig",
    "ToCanonicalConfig",
    "ResizeConfig",
    "ResampleConfig",
    "EnsureShapeMultipleConfig",
    "CropConfig",
    "PadConfig",
]

# Default argument values of the wrapped TorchIO transforms, obtained with
# ``get_defaults_from`` (presumably a mapping from parameter name to default
# value). The config classes in this module index them by parameter name,
# e.g. ``["padding_mode"]``, to use them as their own field defaults.
CROP_OR_PAD_TORCHIO_DEFAULTS = get_defaults_from(tio.transforms.CropOrPad)
TO_CANONICAL_TORCHIO_DEFAULTS = get_defaults_from(tio.transforms.ToCanonical)
RESIZE_TORCHIO_DEFAULTS = get_defaults_from(tio.transforms.Resize)
RESAMPLE_TORCHIO_DEFAULTS = get_defaults_from(tio.transforms.Resample)
ENSURE_SHAPE_MULTIPLE_TORCHIO_DEFAULTS = get_defaults_from(
    tio.transforms.EnsureShapeMultiple
)
CROP_TORCHIO_DEFAULTS = get_defaults_from(tio.transforms.Crop)
PAD_TORCHIO_DEFAULTS = get_defaults_from(tio.transforms.Pad)

# The two singular names below predate the ``*_TORCHIO_DEFAULTS`` convention
# followed by every other constant here; they are kept as aliases so that
# existing references keep working.
CROP_TORCHIO_DEFAULT = CROP_TORCHIO_DEFAULTS
PAD_TORCHIO_DEFAULT = PAD_TORCHIO_DEFAULTS


@add_suffix_to_doc(DOCUMENT_EXTRA_PARAMETERS)
class CropOrPadConfig(TorchioTransformConfig):
    """
    Configuration for :py:class:`torchio.transforms.CropOrPad`.

    At least one of ``target_shape`` and ``mask_name`` must be provided, and
    ``labels`` may only be set when ``mask_name`` is provided.
    """

    target_shape: Union[
        PositiveInt,
        Tuple[PositiveInt, PositiveInt, PositiveInt],
        None,
    ] = CROP_OR_PAD_TORCHIO_DEFAULTS["target_shape"]
    padding_mode: Union[float, PaddingMode] = CROP_OR_PAD_TORCHIO_DEFAULTS[
        "padding_mode"
    ]
    mask_name: Optional[str] = CROP_OR_PAD_TORCHIO_DEFAULTS["mask_name"]
    labels: Optional[Tuple[int, ...]] = CROP_OR_PAD_TORCHIO_DEFAULTS["labels"]

    @model_validator(mode="after")
    def check_shape(self):
        """Ensures 'target_shape', 'mask_name' and 'labels' are mutually compatible."""
        if self.mask_name:
            # A mask is available: any combination of the other fields is accepted.
            return self

        # No mask: the output shape must come from 'target_shape' ...
        if not self.target_shape:
            raise ValueError(
                "If 'target_shape' is None or is not passed, a valid 'mask_name' must be passed."
            )
        # ... and 'labels' has nothing to refer to.
        if self.labels:
            raise ValueError(
                "If 'mask_name' is not passed, 'labels' must be left to None."
            )
        return self
@add_suffix_to_doc(DOCUMENT_EXTRA_PARAMETERS)
class ToCanonicalConfig(TorchioTransformConfig):
    """
    Configuration for :py:class:`torchio.transforms.ToCanonical`.

    No fields are declared here beyond those inherited from the base config.
    """
@add_suffix_to_doc(DOCUMENT_EXTRA_PARAMETERS)
class ResizeConfig(TorchioTransformConfig):
    """
    Configuration for :py:class:`torchio.transforms.Resize`.

    ``target_shape`` is either a single size or one size per spatial
    dimension; every size must be a positive integer or ``-1``.
    """

    target_shape: Union[int, Tuple[int, int, int]]
    image_interpolation: InterpolationMode = RESIZE_TORCHIO_DEFAULTS[
        "image_interpolation"
    ]
    label_interpolation: InterpolationMode = RESIZE_TORCHIO_DEFAULTS[
        "label_interpolation"
    ]

    @field_validator("target_shape", mode="after")
    @classmethod
    def validator_target_shape(cls, v):
        """Rejects any size in 'target_shape' that is neither positive nor -1."""
        # Bring both accepted forms to a common iterable of sizes.
        if isinstance(v, int):
            sizes = (v,)
        elif isinstance(v, tuple):
            sizes = v
        else:
            sizes = ()

        for size in sizes:
            cls._check_dimension(size)
        return v

    @staticmethod
    def _check_dimension(dim: int) -> None:
        """Raises a ValueError unless 'dim' is a positive integer or -1."""
        if dim > 0 or dim == -1:
            return
        raise ValueError(
            "The size of dimensions passed in 'target_shape' must be positive "
            f"integers or -1. Got {dim}"
        )
@add_suffix_to_doc(DOCUMENT_EXTRA_PARAMETERS)
class ResampleConfig(TorchioTransformConfig):
    """
    Configuration for :py:class:`torchio.transforms.Resample`.

    ``target`` may not be a TorchIO ``Image``. A ``Path`` target must point to
    an existing file. A ``(spatial_shape, affine)`` target must carry a 4x4
    affine. ``pre_affine_name`` is not supported and must stay ``None``.
    """

    target: Union[
        PositiveFloat,
        Tuple[PositiveFloat, PositiveFloat, PositiveFloat],
        str,
        Path,
        Tuple[Tuple[PositiveInt, PositiveInt, PositiveInt], np.ndarray],
    ] = RESAMPLE_TORCHIO_DEFAULTS["target"]
    pre_affine_name: Optional[str] = RESAMPLE_TORCHIO_DEFAULTS["pre_affine_name"]
    image_interpolation: InterpolationMode = RESAMPLE_TORCHIO_DEFAULTS[
        "image_interpolation"
    ]
    label_interpolation: InterpolationMode = RESAMPLE_TORCHIO_DEFAULTS[
        "label_interpolation"
    ]
    scalars_only: bool = RESAMPLE_TORCHIO_DEFAULTS["scalars_only"]

    @field_validator("pre_affine_name", mode="before")
    @classmethod
    def validator_pre_affine_name(cls, v):
        """Refuses any explicit value for 'pre_affine_name'."""
        if v is None:
            return v
        raise ValueError("'pre_affine_name' is not supported in ClinicaDL.")

    @field_validator("target", mode="before")
    @classmethod
    def not_tio_image(cls, v):
        """Refuses a TorchIO Image passed as 'target'."""
        if not isinstance(v, Image):
            return v
        raise ValueError("TorchIO Image not supported for 'target'.")

    @field_validator("target", mode="after")
    @classmethod
    def validator_target(cls, v):
        """Checks path targets and '(spatial_shape, affine)' targets."""
        if isinstance(v, Path):
            if not v.is_file():
                raise ValueError(
                    f"Got a path for 'target', but {v} is not a valid file."
                )
            return v

        # Only the '(spatial_shape, affine)' form of 'target' has two elements.
        if isinstance(v, tuple) and len(v) == 2:
            _, affine = v
            if affine.shape != (4, 4):
                raise ValueError(
                    "If 'target' is passed as '(spatial_shape, affine)', 'affine' must be "
                    f"a numpy array of shape (4, 4). Got shape {affine.shape}"
                )
        return v
@add_suffix_to_doc(DOCUMENT_EXTRA_PARAMETERS)
class EnsureShapeMultipleConfig(TorchioTransformConfig):
    """
    Configuration for :py:class:`torchio.transforms.EnsureShapeMultiple`.

    ``target_multiple`` is required and is either one positive integer or one
    positive integer per spatial dimension.
    """

    target_multiple: Union[
        PositiveInt,
        Tuple[PositiveInt, PositiveInt, PositiveInt],
    ]
    method: EnsureShapeMultipleMode = ENSURE_SHAPE_MULTIPLE_TORCHIO_DEFAULTS[
        "method"
    ]
@add_suffix_to_doc(DOCUMENT_EXTRA_PARAMETERS)
class CropConfig(TorchioTransformConfig):
    """
    Configuration for :py:class:`torchio.transforms.Crop`.

    ``cropping`` has no default value and must always be provided.
    """

    cropping: Bounds
@add_suffix_to_doc(DOCUMENT_EXTRA_PARAMETERS)
class PadConfig(TorchioTransformConfig):
    """
    Configuration for :py:class:`torchio.transforms.Pad`.

    ``padding`` must always be provided. ``padding_mode`` is either a constant
    fill value or a :py:class:`PaddingMode`, and falls back to TorchIO's own
    default.
    """

    padding: Bounds
    padding_mode: Union[float, PaddingMode] = PAD_TORCHIO_DEFAULT[
        "padding_mode"
    ]