# Source code for clinicadl.infer.patches_to_image

from collections.abc import Sequence
from enum import Enum
from typing import Any, Callable, Optional, Union

import torch
from monai.inferers import SlidingWindowInferer
from pydantic import (
    NonNegativeFloat,
    PositiveFloat,
    PositiveInt,
    field_validator,
)

from clinicadl.transforms.types import TransformOrConfig
from clinicadl.utils.dictionary.words import OUTPUT
from clinicadl.utils.objects import HasConfig

from .base import BaseInfererConfig, OutputType
from .utils import Batched3DTo3DInferer


class AveragingMode(str, Enum):
    """
    Averaging method used to merge predictions where patches overlap.

    Members are also plain strings (``str`` mixin), so the raw values
    ``"constant"`` and ``"gaussian"`` can be used to build a member.
    """

    # Every overlapping prediction gets the same weight.
    CONSTANT = "constant"
    # Predictions near the edges of a patch get less weight.
    GAUSSIAN = "gaussian"


class PatchesToImageInfererConfig(BaseInfererConfig):
    """
    Config class for ``PatchesToImageInferer``.

    Scalar values passed for ``patch_size`` or ``overlap`` are broadcast to
    the three spatial dimensions before validation, and every component of
    ``overlap`` must lie in ``[0, 1)``.
    """

    patch_size: tuple[PositiveInt, PositiveInt, PositiveInt]
    overlap: tuple[NonNegativeFloat, NonNegativeFloat, NonNegativeFloat]
    avg_mode: AveragingMode
    sigma_scale: PositiveFloat
    batch_size: PositiveInt

    @field_validator("patch_size", "overlap", mode="before")
    @classmethod
    def _ensure_tuple(
        cls,
        value: Any,
    ) -> tuple:
        """
        Ensures that the argument is a triplet.

        A value that is not a ``Sequence`` is repeated three times (one per
        spatial dimension); a sequence is returned untouched and left to the
        field's own validation.
        """
        if not isinstance(value, Sequence):
            return (value, value, value)
        return value

    @field_validator("overlap", mode="after")
    @classmethod
    def _overlap_validator(cls, value: tuple) -> tuple:
        """
        Checks that each overlap is between 0 (included) and 1 (excluded).

        Raises
        ------
        ValueError
            If one of the overlaps is not in ``[0, 1)``.
        """
        for v in value:
            # Explicit raise rather than ``assert``: assertions are stripped
            # when Python runs with ``-O``, which would silently disable
            # this check.
            if not 0 <= v < 1:
                raise ValueError(
                    f"If 'overlap' is a float, it must be between 0 (included) and 1 (excluded). Got {v}"
                )
        return value

    @classmethod
    def _get_class(cls):
        """Returns the inferer class associated with this config."""
        return PatchesToImageInferer


class PatchesToImageInferer(
    Batched3DTo3DInferer, HasConfig[PatchesToImageInfererConfig]
):
    """
    Splits a 3D volume into 3D patches, passes them in a 3D neural network,
    and merges the outputs in a 3D output volume.

    Adapted from :py:class:`monai.inferers.SlidingWindowInferer`.

    Parameters
    ----------
    patch_size : Union[int, tuple[int, int, int]]
        The size of the patches. If a single value is passed, the same patch
        size will be used for the three spatial dimensions.
    overlap : Union[float, tuple[float, float, float]], default=0.25
        A ``float`` in :math:`[0.0, 1.0)` that defines relative patch overlap
        in each dimension. If a single value is passed, the same overlap will
        be used for the three spatial dimensions.
    avg_mode : AveragingMode, default="constant"
        How to average the outputs when patches overlap.

        - ``"constant"``: gives equal weight to all predictions;
        - ``"gaussian"``: gives less weight to the prediction on edges of patches.

    sigma_scale : float, default=0.125
        The standard deviation coefficient of the Gaussian window when
        ``avg_mode`` is ``"gaussian"``. The actual sigma is
        ``sigma_scale * patch_size``.
    batch_size : int, default=1
        The size of the batch passed to the neural network. If you pass a
        batch of images to the inferer, this batch will be rearranged to
        match ``batch_size``.

        E.g., if a batch of 2 images is passed, with 3 patches in each image,
        and ``batch_size=4``, then the first batch passed to the neural
        network will contain the three patches of the first image, and the
        first patch of the second.
    postprocessing : Optional[Sequence[TransformOrConfig]], default=None
        To apply postprocessing transformations (e.g. activations) after the
        pass forward in the neural network. Accepted transforms are functions
        that take as input a ``DataPoint`` and return a ``DataPoint``, or
        :py:mod:`configuration classes <clinicadl.transforms.config>`.
    postprocessing_on_cpu : bool, default=False
        Whether to necessarily apply postprocessing on CPU.
        If ``False``, postprocessing will be applied on the device where the
        data and the neural network are.
    output_name : str, default="output"
        The name to give to the output in the ``DataPoint``.

        .. important::
            If your postprocessing transform comes from
            :py:class:`clinicadl.transforms.config`, do not forget to specify
            ``include=["<output_name>"]`` to apply the postprocessing to the
            output of the neural network.

    output_type : OutputType, default="tensor"
        Determines the data type of the output:

        - ``"image"``: the output will be converted to a :py:class:`torchio.ScalarImage`;
        - ``"mask"``: the output will be converted to a :py:class:`torchio.LabelMap`;
        - ``"tensor"``: the output will remain a :py:class:`torch.Tensor`.

    Raises
    ------
    ValueError
        At inference time, if the image is smaller than ``patch_size`` in at
        least one spatial dimension.

    Examples
    --------
    .. code-block::

        import torch
        from clinicadl.infer import PatchesToImageInferer
        from clinicadl.data.structures.examples import Colin27DataPoint
        from clinicadl.networks.nn import AutoEncoder

        net = AutoEncoder(in_shape=(1, 64, 64, 64), latent_size=16, conv_args={"channels": [2]})
        datapoint = Colin27DataPoint()

        inferer = PatchesToImageInferer(patch_size=64, batch_size=16, overlap=1/5)

    .. code-block::

        >>> datapoint.image.shape
        (1, 181, 217, 181)
        >>> with torch.no_grad():
        ...     out = inferer(datapoint, net)
        >>> out["output"].shape
        (1, 181, 217, 181)

    See Also
    --------
    clinicadl.infer.SimpleInferer
        For classical inference.
    """

    config: PatchesToImageInfererConfig
    _config_type = PatchesToImageInfererConfig

    def __init__(
        self,
        patch_size: Union[int, tuple[int, int, int]],
        overlap: Union[float, tuple[float, float, float]] = 0.25,
        avg_mode: str | AveragingMode = AveragingMode.CONSTANT,
        sigma_scale: float = 0.125,
        batch_size: int = 1,
        postprocessing: Optional[Sequence[TransformOrConfig]] = None,
        postprocessing_on_cpu: bool = False,
        output_name: str = OUTPUT,
        output_type: str | OutputType = OutputType.TENSOR,
    ):
        super().__init__(
            patch_size=patch_size,
            overlap=overlap,
            avg_mode=avg_mode,
            sigma_scale=sigma_scale,
            batch_size=batch_size,
            postprocessing=postprocessing,
            postprocessing_on_cpu=postprocessing_on_cpu,
            output_name=output_name,
            output_type=output_type,
        )
        # The sliding window is built from the validated config, so scalar
        # 'patch_size' / 'overlap' have already been expanded to triplets.
        self._sliding_window = SlidingWindowInferer(
            roi_size=self.config.patch_size,
            sw_batch_size=self.config.batch_size,
            overlap=self.config.overlap,
            sw_device=None,  # where inference is done
            # where patches are assembled
            device=torch.device("cpu") if postprocessing_on_cpu else None,
            mode=self.config.avg_mode,
            sigma_scale=self.config.sigma_scale,
        )

    def _forward_pass(
        self, tensor: torch.Tensor, network: Callable[..., torch.Tensor], **kwargs
    ) -> torch.Tensor:
        """
        Checks the input shape, then runs the sliding-window inference.

        Extra keyword arguments are forwarded to the sliding-window inferer.
        """
        self._check_shape(tensor)
        return self._sliding_window(inputs=tensor, network=network, **kwargs)

    def _check_shape(self, tensor: torch.Tensor) -> None:
        """
        Checks spatial shape of the input image(s).

        The last three dimensions of ``tensor`` are taken as the spatial
        dimensions.

        Raises
        ------
        ValueError
            If the image is smaller than the patch in at least one spatial
            dimension.
        """
        spatial_shape = tensor.shape[-3:]
        patch_size = self.config.patch_size
        # Compare dimension by dimension: '<' between two tuples is
        # lexicographic, so it would only look at the first differing axis.
        if any(dim < patch for dim, patch in zip(spatial_shape, patch_size)):
            raise ValueError(
                f"'patch_size' is bigger than the image. Got an image of spatial shape {spatial_shape} but patch_size={self.config.patch_size}"
            )