# Source code for clinicadl.infer.simple

from typing import Callable, Optional, Sequence

import torch

from clinicadl.transforms.types import TransformOrConfig
from clinicadl.utils.dictionary.words import OUTPUT
from clinicadl.utils.objects import HasConfig

from .base import BaseInferer, BaseInfererConfig, OutputType


class SimpleInfererConfig(BaseInfererConfig):
    """Configuration counterpart of ``SimpleInferer``."""

    @classmethod
    def _get_class(cls):
        """Return the class tied to this configuration, i.e. ``SimpleInferer``."""
        # The name is resolved when the method is called, so ``SimpleInferer``
        # may be defined further down in the module.
        inferer_cls = SimpleInferer
        return inferer_cls


class SimpleInferer(BaseInferer, HasConfig[SimpleInfererConfig]):
    """
    An :py:class:`~clinicadl.infer.Inferer` for classical inference, i.e. when the
    whole image is passed through the neural network and the output is returned
    (with optional postprocessing).

    The inference output will be added to the input
    :py:class:`~clinicadl.data.structures.DataPoint` (or to each ``DataPoint`` of
    the input :py:class:`~clinicadl.data.dataloader.Batch`).

    Parameters
    ----------
    postprocessing : Optional[Sequence[TransformOrConfig]], default=None
        To apply postprocessing transformations (e.g. activations) after the
        forward pass in the neural network. Accepted transforms are functions
        that take as input a ``DataPoint`` and return a ``DataPoint``, or
        :py:mod:`configuration classes <clinicadl.transforms.config>`.
    postprocessing_on_cpu : bool, default=False
        Whether to force postprocessing to be applied on CPU. If ``False``,
        postprocessing will be applied on the device where the data and the
        neural network are located.
    output_name : str, default="output"
        The name to give to the output in the ``DataPoint``.

        .. important::
            If your postprocessing transform comes from
            :py:class:`clinicadl.transforms.config`, do not forget to specify
            ``include=["<output_name>"]`` to apply the postprocessing to the
            output of the neural network.
    output_type : str | OutputType, default="tensor"
        Determines the data type of the output:

        - ``"image"``: the output will be converted to a :py:class:`torchio.ScalarImage`;
        - ``"mask"``: the output will be converted to a :py:class:`torchio.LabelMap`;
        - ``"tensor"``: the output will remain a :py:class:`torch.Tensor`.

    Examples
    --------
    .. code-block:: python

        import torch
        from clinicadl.infer import SimpleInferer
        from clinicadl.data.structures.examples import Colin27DataPoint
        from clinicadl.data.dataloader import Batch
        from clinicadl.networks.nn import ConvEncoder
        from clinicadl.transforms.config import ActivationsConfig

        net = ConvEncoder(spatial_dims=3, in_channels=1, channels=[2, 4])
        datapoint = Colin27DataPoint()

    .. code-block:: python

        >>> datapoint.image.shape
        (1, 181, 217, 181)
        >>> inferer = SimpleInferer()
        >>> with torch.no_grad():
        ...     out = inferer(datapoint, net)
        >>> out["output"].shape
        (4, 177, 213, 177)
        >>> out["output"].tensor
        tensor([[[[ 0.5923, 0.5979, 0.5816, ..., 0.6592, 0.6592, 0.6592],
                  [ 0.5887, 0.5832, 0.5785, ..., 0.6592, 0.6592, 0.6592],
                  [ 0.5883, 0.5832, 0.5858, ..., 0.6592, 0.6592, 0.6592],
                  ...,

    Working with a specific precision:

    .. code-block:: python

        >>> net.to(dtype=torch.half)
        >>> with torch.no_grad():
        ...     out = inferer(datapoint, net, input_dtype=torch.half)
        >>> out["output"].tensor.dtype
        torch.float16

    With postprocessing:

    .. code-block:: python

        >>> net.to(dtype=torch.float)
        >>> inferer = SimpleInferer(postprocessing=[ActivationsConfig(sigmoid=True, include=["output"])])
        >>> with torch.no_grad():
        ...     out = inferer(datapoint, net)
        >>> out["output"].tensor
        tensor([[[[0.6439, 0.6452, 0.6414, ..., 0.6591, 0.6591, 0.6591],
                  [0.6430, 0.6418, 0.6407, ..., 0.6591, 0.6591, 0.6591],
                  [0.6429, 0.6418, 0.6424, ..., 0.6591, 0.6591, 0.6591],
                  ...,

    With a :py:class:`clinicadl.data.dataloader.Batch`:

    .. code-block:: python

        >>> from copy import deepcopy
        >>> batch = Batch([datapoint, deepcopy(datapoint)])
        >>> with torch.no_grad():
        ...     out = inferer(batch, net)
        >>> out[0]["output"].shape
        (4, 177, 213, 177)
    """

    _config_type = SimpleInfererConfig

    def __init__(
        self,
        postprocessing: Optional[Sequence[TransformOrConfig]] = None,
        postprocessing_on_cpu: bool = False,
        output_name: str = OUTPUT,
        output_type: str | OutputType = OutputType.TENSOR,
    ):
        # All arguments are forwarded unchanged to the parent initializer;
        # this class adds no state of its own.
        super().__init__(
            postprocessing=postprocessing,
            postprocessing_on_cpu=postprocessing_on_cpu,
            output_name=output_name,
            output_type=output_type,
        )

    def _forward_pass(
        self, tensor: torch.Tensor, network: Callable[..., torch.Tensor], **kwargs
    ) -> torch.Tensor:
        """
        Pass the whole tensor through the network in a single call.

        Parameters
        ----------
        tensor : torch.Tensor
            The input passed as first positional argument to ``network``.
        network : Callable[..., torch.Tensor]
            The callable (e.g. a neural network) to apply to ``tensor``.
        **kwargs
            Extra keyword arguments forwarded as is to ``network``.

        Returns
        -------
        torch.Tensor
            Whatever ``network`` returns for ``tensor``.
        """
        return network(tensor, **kwargs)