# Source code for clinicadl.transforms.extraction.patch

from __future__ import annotations

from collections.abc import Sequence
from enum import Enum
from logging import getLogger
from typing import TYPE_CHECKING, Any, Optional, Union

import torch
from monai.data.utils import iter_patch_position
from pydantic import (
    NonNegativeFloat,
    PositiveInt,
    field_validator,
)

from clinicadl.utils.config import ObjectConfig
from clinicadl.utils.dictionary.words import SAMPLE_POSITION, SAMPLE_TYPE

from .base import Extraction, ImplementedExtraction

if TYPE_CHECKING:
    from clinicadl.data.structures import DataPoint


# Module-level logger, named after this module (``getLogger(__name__)``).
logger = getLogger(__name__)


class PadMode(str, Enum):
    """
    Padding modes available for patch extraction.

    Being a ``str`` subclass, each member compares equal to its raw string
    value (e.g. ``PadMode.REFLECT == "reflect"``), so either form can be
    passed where a padding mode is expected.
    """

    CONSTANT = "constant"
    REFLECT = "reflect"
    REPLICATE = "replicate"
    CIRCULAR = "circular"


class PatchConfig(ObjectConfig["Patch"]):
    """
    Config class for patch extraction.

    ``patch_size`` and ``overlap`` accept either a single value, which is then
    repeated for the three spatial dimensions, or a sequence of three values.
    """

    patch_size: tuple[PositiveInt, PositiveInt, PositiveInt]
    overlap: tuple[NonNegativeFloat, NonNegativeFloat, NonNegativeFloat]
    pad_mode: Optional[PadMode]
    pad_value: float

    @field_validator("patch_size", "overlap", mode="before")
    @classmethod
    def _ensure_tuple(
        cls,
        value: Any,
    ) -> Any:
        """
        Repeats a non-sequence value three times (one per spatial dimension).

        Sequences are returned unchanged; their length and content are then
        checked against the field type.
        """
        if not isinstance(value, Sequence):
            return (value, value, value)
        return value

    @field_validator("overlap", mode="after")
    @classmethod
    def _overlap_validator(cls, value: tuple) -> tuple:
        """
        Checks that each overlap value is between 0 (included) and 1 (excluded).

        Raises
        ------
        ValueError
            If a value is outside ``[0, 1)`` (reported by pydantic as a
            validation error).
        """
        for v in value:
            # Raise explicitly rather than 'assert': assertions are stripped
            # when Python runs with -O, which would disable this check.
            if not 0 <= v < 1:
                raise ValueError(
                    f"If 'overlap' is a float, it must be between 0 (included) and 1 (excluded). Got {v}"
                )
        return value

    @classmethod
    def _get_class(cls) -> type[Patch]:
        """Returns the class associated to this config class."""
        return Patch


class Patch(Extraction[ObjectConfig]):
    """
    Transform class to extract patches from an image.

    The image is divided into smaller patches using a sliding window approach.

    Adds the following keys to the input :py:class:`~clinicadl.data.structures.DataPoint`:

    - ``sample_type`` : ``"patch"``
    - ``sample_position``: tuple[int, int, int]
        The position of the patch in the image, which is defined as the position
        of its upper left voxel. The origin is defined at the upper left voxel
        of the image.

    Parameters
    ----------
    patch_size : Union[int, tuple[int, int, int]]
        The size of the patches. If a single value is passed, the same patch size
        will be used for the three spatial dimensions.
    overlap : Union[float, tuple[float, float, float]], default=0.0
        A ``float`` in :math:`[0.0, 1.0)` that defines relative patch overlap in
        each dimension. If a single value is passed, the same overlap will be used
        for the three spatial dimensions. The overlap must leave a sliding-window
        step of at least one voxel, otherwise a ``ValueError`` is raised.
    pad_mode : Optional[str | PadMode], default="constant"
        A padding mode accepted by :py:func:`torch.nn.functional.pad`, i.e. one of
        ``"constant"``, ``"reflect"``, ``"replicate"`` or ``"circular"``. The image
        is padded so that it is at least as large as one patch and so that the
        last window ends exactly at the (padded) border.
        If ``None``, no padding will be applied, so the patches that cross the
        border of the image will be dropped (in particular, no patch is returned
        if the image is smaller than ``patch_size``).
    pad_value : float, default=0.0
        The value for ``"constant"`` padding. Ignored for the other modes.

    Examples
    --------
    .. code-block::

        from clinicadl.transforms.extraction import Patch
        from clinicadl.data.structures.examples import Colin27DataPoint

        data = Colin27DataPoint()
        patch = Patch(patch_size=64)

    .. code-block::

        >>> data.spatial_shape
        (181, 217, 181)
        >>> patch.num_samples_per_image(data)
        36
        >>> patch(data, sample_index=0).spatial_shape
        (64, 64, 64)
        >>> next(iter(patch(data))).sample_position
        (0, 0, 0)
    """

    config: PatchConfig
    _config_type = PatchConfig

    def __init__(
        self,
        patch_size: Union[int, tuple[int, int, int]],
        overlap: Union[
            float,
            tuple[float, float, float],
        ] = 0.0,
        pad_mode: Optional[str | PadMode] = PadMode.CONSTANT,
        pad_value: float = 0.0,
    ) -> None:
        self.config = PatchConfig(
            patch_size=patch_size,
            overlap=overlap,
            pad_mode=pad_mode,
            pad_value=pad_value,
        )
        # Fail fast on an overlap too large for the patch size, instead of
        # crashing later with a ZeroDivisionError during extraction.
        self._get_steps()

    @property
    def sample_type(self) -> str:
        """
        The type of the sample returned by this extraction, among {"image", "slice", "patch"}.
        """
        return ImplementedExtraction.PATCH.value.lower()

    def _extract_tensor_sample(
        self, image_tensor: torch.Tensor, sample_position: tuple[int, int, int]
    ) -> torch.Tensor:
        """
        Extracts a single patch from an image.

        Adapted from https://monai-dev.readthedocs.io/en/stable/inferers.html#monai.inferers.SlidingWindowSplitter.__call__.

        Parameters
        ----------
        image_tensor : torch.Tensor
            The image, channel dimension first (spatial dimensions are
            ``image_tensor.shape[1:]``).
        sample_position : tuple[int, int, int]
            Position of the upper left voxel of the patch, in the padded image.

        Returns
        -------
        torch.Tensor
            The patch, with the channel dimension kept.
        """
        spatial_shape = image_tensor.shape[1:]
        pad_size = self._calculate_pad_size(spatial_shape)

        if self.config.pad_mode and any(pad_size):
            mode = PadMode(self.config.pad_mode)
            image_tensor = torch.nn.functional.pad(
                image_tensor,
                # torch.nn.functional.pad starts with the last dimension
                pad_size[-2:] + pad_size[2:4] + pad_size[:2],
                mode=mode.value,
                # non-constant modes reject a non-zero 'value' argument
                value=self.config.pad_value if mode == PadMode.CONSTANT else 0.0,
            )

        return self._get_patch(
            image_tensor, location=sample_position, patch_size=self.config.patch_size
        )

    def _get_sample_positions(
        self, data_point: DataPoint
    ) -> list[tuple[int, int, int]]:
        """
        Returns the positions of the patches in the image.

        Positions are expressed in the padded image (when padding is enabled).
        If the (unpadded) image is smaller than ``patch_size`` in a dimension,
        no full patch fits and an empty list is returned.
        """
        spatial_shape = data_point.image.tensor.shape[1:]
        padded_shape = self._get_padded_shape(spatial_shape)

        if any(sh < ps for sh, ps in zip(padded_shape, self.config.patch_size)):
            # iter_patch_position would silently clip the patch size to the image
            # size and yield truncated patches; drop them instead, as documented.
            return []

        return list(
            iter_patch_position(
                image_size=padded_shape,
                patch_size=self.config.patch_size,
                overlap=self.config.overlap,
                padded=False,
            )
        )

    def _add_info(
        self, data_point: DataPoint, sample_position: tuple[int, int, int]
    ) -> None:
        """
        Adds relevant info in the datapoint (sample type and sample position).
        """
        data_point[SAMPLE_TYPE] = self.sample_type
        data_point[SAMPLE_POSITION] = sample_position

    @staticmethod
    def _get_patch(
        tensor: torch.Tensor,
        patch_size: tuple[int, int, int],
        location: tuple[int, int, int],
    ) -> torch.Tensor:
        """
        Gets a patch from a 4D tensor (all channels are kept).
        """
        slices = (slice(None),) + tuple(
            slice(loc, loc + ps) for loc, ps in zip(location, patch_size)
        )
        return tensor[slices]

    def _get_steps(self) -> tuple[int, ...]:
        """
        Returns the sliding-window step in each spatial dimension.

        Uses the same formula as :py:func:`monai.data.utils.iter_patch_position`
        for float overlaps, so that the padding computed here is aligned with the
        positions generated by MONAI.

        Raises
        ------
        ValueError
            If a step is null, i.e. the overlap is too large for the patch size.
        """
        steps = tuple(
            round(ps * (1.0 - ov))
            for ps, ov in zip(self.config.patch_size, self.config.overlap)
        )
        if any(step < 1 for step in steps):
            raise ValueError(
                f"The overlap {tuple(self.config.overlap)} is too large for the patch "
                f"size {tuple(self.config.patch_size)}: it results in a null "
                f"sliding-window step {steps}."
            )
        return steps

    def _get_padded_shape(
        self, spatial_shape: tuple[int, int, int]
    ) -> tuple[int, int, int]:
        """
        Returns the padded shape from the original shape (unchanged if padding
        is disabled).
        """
        if not self.config.pad_mode:
            return spatial_shape

        pad_size = self._calculate_pad_size(spatial_shape)
        return tuple(
            shape + pad for shape, pad in zip(spatial_shape, pad_size[1::2])
        )

    def _calculate_pad_size(
        self, spatial_shape: tuple[int, int, int]
    ) -> list[int]:
        """
        Returns the pad size for each dimension, as ``[before, after]`` pairs in
        spatial-dimension order. Only the ``after`` side is padded.
        """
        pad_size = [0] * 2 * len(spatial_shape)
        if not self.config.pad_mode:
            return pad_size

        for i, sh, ps, step in zip(
            range(1, len(pad_size), 2),
            spatial_shape,
            self.config.patch_size,
            self._get_steps(),
        ):
            if sh < ps:
                # The padded image must hold at least one full patch.
                pad_size[i] = ps - sh
            else:
                # Make (sh + pad - ps) a multiple of the step, so that the last
                # window ends exactly at the padded border.
                pad_size[i] = (ps - sh) % step
        return pad_size