# Source code for clinicadl.infer.slices_to_image

from typing import Callable, Optional, Sequence

import torch
from monai.inferers import SliceInferer
from pydantic import PositiveInt

from clinicadl.transforms.extraction.slice import SliceDirection
from clinicadl.transforms.types import TransformOrConfig
from clinicadl.utils.dictionary.words import OUTPUT
from clinicadl.utils.objects import HasConfig

from .base import BaseInfererConfig, OutputType
from .utils import Batched3DTo3DInferer


class SlicesToImageInfererConfig(BaseInfererConfig):
    """
    Configuration object backing a :py:class:`SlicesToImageInferer`.

    Stores the slicing direction and the number of slices sent to the
    neural network per forward pass.
    """

    # Axis along which each 3D volume is cut into 2D slices.
    slice_direction: SliceDirection
    # Number of slices per forward pass; validated as strictly positive.
    batch_size: PositiveInt

    @classmethod
    def _get_class(cls):
        """Return the inferer class associated with this configuration."""
        inferer_cls = SlicesToImageInferer
        return inferer_cls


class SlicesToImageInferer(
    Batched3DTo3DInferer, HasConfig[SlicesToImageInfererConfig]
):
    """
    Splits a 3D volume into 2D slices, passes them through a 2D neural network,
    and merges the outputs into a 3D output volume.

    Adapted from :py:class:`monai.inferers.SliceInferer`.

    Parameters
    ----------
    slice_direction : int | str | SliceDirection
        The slicing direction. Can be ``0``, ``1`` or ``2``.

        .. warning::
            Be careful with the orientation of your image. If your image is in
            :term:`RAS+` (e.g. you used
            :py:class:`~clinicadl.transforms.config.ToCanonicalConfig`), ``0``
            refers to the sagittal direction, ``1`` to the coronal direction,
            and ``2`` to the axial direction.

    batch_size : int, default=1
        The size of the batch passed to the neural network. If you pass a batch
        of images to the inferer, this batch will be rearranged to match
        ``batch_size``. E.g., if a batch of 2 images is passed, with 3 slices
        in each image, and ``batch_size=4``, then the first batch passed to the
        neural network will contain the three slices of the first image and the
        first slice of the second.
    postprocessing : Optional[Sequence[TransformOrConfig]], default=None
        Postprocessing transformations (e.g. activations) to apply after the
        forward pass through the neural network. Accepted transforms are
        functions that take as input a ``DataPoint`` and return a
        ``DataPoint``, or :py:mod:`configuration classes
        <clinicadl.transforms.config>`.
    postprocessing_on_cpu : bool, default=False
        Whether to force postprocessing to run on CPU. If ``False``,
        postprocessing will be applied on the device where the data and the
        neural network are.
    output_name : str, default="output"
        The name to give to the output in the ``DataPoint``.

        .. important::
            If your postprocessing transform comes from
            :py:mod:`clinicadl.transforms.config`, do not forget to specify
            ``include=["<output_name>"]`` to apply the postprocessing to the
            output of the neural network.

    output_type : str | OutputType, default="tensor"
        Determines the data type of the output:

        - ``"image"``: the output will be converted to a
          :py:class:`torchio.ScalarImage`;
        - ``"mask"``: the output will be converted to a
          :py:class:`torchio.LabelMap`;
        - ``"tensor"``: the output will remain a :py:class:`torch.Tensor`.

    Examples
    --------
    .. code-block::

        import torch
        from clinicadl.infer import SlicesToImageInferer
        from clinicadl.data.structures.examples import Colin27DataPoint
        from clinicadl.networks.nn import ConvEncoder

        net = ConvEncoder(
            spatial_dims=2, in_channels=1, channels=[2, 4], kernel_size=7
        )
        datapoint = Colin27DataPoint()
        inferer = SlicesToImageInferer(slice_direction=1, batch_size=16)

    .. code-block::

        >>> datapoint.image.shape
        (1, 181, 217, 181)
        >>> with torch.no_grad():
        ...     out = inferer(datapoint, net)
        >>> out["output"].shape
        (4, 169, 217, 169)  # the 2D network is applied to the 217 coronal slices

    See Also
    --------
    clinicadl.infer.SimpleInferer
        For classical inference.
    """

    config: SlicesToImageInfererConfig
    _config_type = SlicesToImageInfererConfig

    def __init__(
        self,
        slice_direction: int | str | SliceDirection,
        batch_size: int = 1,
        postprocessing: Optional[Sequence[TransformOrConfig]] = None,
        postprocessing_on_cpu: bool = False,
        output_name: str = OUTPUT,
        output_type: str | OutputType = OutputType.TENSOR,
    ):
        super().__init__(
            slice_direction=slice_direction,
            batch_size=batch_size,
            postprocessing=postprocessing,
            postprocessing_on_cpu=postprocessing_on_cpu,
            output_name=output_name,
            output_type=output_type,
        )

    def _forward_pass(
        self, tensor: torch.Tensor, network: Callable[..., torch.Tensor], **kwargs
    ) -> torch.Tensor:
        """
        Runs ``network`` slice by slice on ``tensor`` with MONAI's SliceInferer.

        Parameters
        ----------
        tensor : torch.Tensor
            The input batch. The first two dimensions are skipped when computing
            the slice size; presumably the layout is
            ``(batch, channels, dim0, dim1, dim2)`` — TODO confirm.
        network : Callable[..., torch.Tensor]
            The 2D neural network applied to each slice.
        **kwargs
            Forwarded to the ``SliceInferer`` call.

        Returns
        -------
        torch.Tensor
            The output of the ``SliceInferer`` call, i.e. the per-slice outputs
            merged back into a volume.
        """
        # Normalize the configured direction to a plain int once: it is used both
        # as a list index and as MONAI's ``spatial_dim``.
        axis = int(self.config.slice_direction)

        # ROI = spatial shape without the slicing axis, so that each window
        # covers a whole 2D slice.
        shape_2d = list(tensor.shape[2:])
        shape_2d.pop(axis)

        inferer = SliceInferer(
            spatial_dim=axis,
            roi_size=shape_2d,  # we always keep the whole slice here
            sw_batch_size=self.config.batch_size,
        )
        return inferer(inputs=tensor, network=network, **kwargs)