clinicadl.infer.SimpleInferer

class clinicadl.infer.SimpleInferer(postprocessing: Sequence[Callable[[DataPointT], DataPointT] | TransformConfig] | None = None, postprocessing_on_cpu: bool = False, output_name: str = 'output', output_type: str | OutputType = OutputType.TENSOR)

An Inferer for classical inference, i.e. when the whole image is passed through the neural network and the output is returned (optionally after postprocessing).

The inference output will be added to the input DataPoint (or to each DataPoint of the input Batch).

Parameters:
  • postprocessing (Optional[Sequence[TransformOrConfig]], default=None) – Postprocessing transforms (e.g. activations) to apply after the forward pass through the neural network. Accepted transforms are functions that take a DataPoint as input and return a DataPoint, or configuration classes.

  • postprocessing_on_cpu (bool, default=False) – Whether to force postprocessing to run on CPU. If False, postprocessing is applied on the device that holds the data and the neural network.

  • output_name (str, default="output") –

    The name to give to the output in the DataPoint.

    Important

    If your postprocessing transform comes from clinicadl.transforms.config, do not forget to specify include=["<output_name>"] so that the postprocessing is applied to the output of the neural network.

  • output_type (OutputType, default="tensor") –

    Determines the data type of the output:

Examples

import torch
from clinicadl.infer import SimpleInferer
from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.data.dataloader import Batch
from clinicadl.networks.nn import ConvEncoder
from clinicadl.transforms.config import ActivationsConfig

net = ConvEncoder(spatial_dims=3, in_channels=1, channels=[2, 4])
datapoint = Colin27DataPoint()
>>> datapoint.image.shape
(1, 181, 217, 181)
>>> inferer = SimpleInferer()
>>> with torch.no_grad(): out = inferer(datapoint, net)
>>> out["output"].shape
(4, 177, 213, 177)
>>> out["output"].tensor
tensor([[[[ 0.5923,  0.5979,  0.5816,  ...,  0.6592,  0.6592,  0.6592],
        [ 0.5887,  0.5832,  0.5785,  ...,  0.6592,  0.6592,  0.6592],
        [ 0.5883,  0.5832,  0.5858,  ...,  0.6592,  0.6592,  0.6592],
        ...,
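
The output shape above can be checked by hand. The sketch below assumes that the encoder applies one stride-1, kernel-size-3, unpadded convolution per entry of channels; this is an inference from the printed shapes, not a statement about the defaults of ConvEncoder. The helper conv_output_size is a hypothetical name used for illustration only.

```python
def conv_output_size(size, kernel_size=3, stride=1, padding=0, dilation=1):
    """Standard output-length formula for one convolution along one axis."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


input_spatial = (181, 217, 181)  # spatial part of datapoint.image.shape
channels = [2, 4]                # as passed to ConvEncoder in the example

spatial = input_spatial
for _ in channels:  # ASSUMPTION: one convolution per entry of `channels`
    spatial = tuple(conv_output_size(s) for s in spatial)

# The channel dimension is taken from the last entry of `channels`.
output_shape = (channels[-1], *spatial)
print(output_shape)  # (4, 177, 213, 177)
```

Under these assumptions each layer removes 2 voxels per spatial axis, which accounts for the 4-voxel difference between input and output.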

Working with a specific precision:

>>> net.to(dtype=torch.half)
>>> with torch.no_grad(): out = inferer(datapoint, net, input_dtype=torch.half)
>>> out["output"].tensor.dtype
torch.float16

With postprocessing:

>>> net.to(dtype=torch.float)
>>> inferer = SimpleInferer(postprocessing=[ActivationsConfig(sigmoid=True, include=["output"])])
>>> with torch.no_grad(): out = inferer(datapoint, net)
>>> out["output"].tensor
tensor([[[[0.6439, 0.6452, 0.6414,  ..., 0.6591, 0.6591, 0.6591],
        [0.6430, 0.6418, 0.6407,  ..., 0.6591, 0.6591, 0.6591],
        [0.6429, 0.6418, 0.6424,  ..., 0.6591, 0.6591, 0.6591],
        ...,
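
The effect of the sigmoid postprocessing can be verified against the values printed in the first example. The sketch below uses plain Python rather than clinicadl or torch, and only recomputes the element-wise sigmoid on a few raw values copied from the first output; it does not show how ActivationsConfig is implemented.

```python
import math


def sigmoid(value: float) -> float:
    """Element-wise logistic function, 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-value))


# First three and last values of the first printed row, before postprocessing.
raw = [0.5923, 0.5979, 0.5816, 0.6592]

post = [round(sigmoid(v), 4) for v in raw]
print(post)  # [0.6439, 0.6452, 0.6414, 0.6591]
```

These match the corresponding entries of the postprocessed tensor shown above.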

With a clinicadl.data.dataloader.Batch:

>>> from copy import deepcopy
>>> batch = Batch([datapoint, deepcopy(datapoint)])
>>> with torch.no_grad(): out = inferer(batch, net)
>>> out[0]["output"].shape
(4, 177, 213, 177)

__call__(x: DataT, network: Module, input_dtype: dtype | None = None, **kwargs: Any) → DataT

Defines the inference logic.

Parameters:
  • x (DataT) – The input image(s). Can be a DataPoint or a Batch of images.

  • network (torch.nn.Module) – The neural network.

  • input_dtype (Optional[torch.dtype], default=None) – The data type to which the input image is converted before being processed by network. If None, single precision (i.e. float32) will be used (except if the inferer is run in an AMP context).

  • kwargs (Any) – Optional keyword args to be passed to network.

Returns:

DataT – The same data structure as the input, containing the inference output.
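
The contract described above (forward pass, optional postprocessing, output stored in the input structure under output_name, extra keyword arguments forwarded to the network) can be modeled with a toy sketch. Everything here is a hypothetical stand-in: a datapoint is a plain dict, a batch is a list of dicts, and toy_infer, scale_network and sigmoid_on_output are illustrative names. It is not clinicadl's implementation, and a real inferer would likely process a batch in a single forward pass rather than item by item.

```python
import math
from typing import Any, Callable, Dict, List, Sequence, Union

# Hypothetical stand-ins for DataPoint and Batch.
DataPoint = Dict[str, Any]
Data = Union[DataPoint, List[DataPoint]]


def toy_infer(
    x: Data,
    network: Callable[..., Any],
    postprocessing: Sequence[Callable[[DataPoint], DataPoint]] = (),
    output_name: str = "output",
    **kwargs: Any,
) -> Data:
    """Toy model of the documented contract (not the library's code)."""
    if isinstance(x, list):  # "batch": simplification, one datapoint at a time
        return [
            toy_infer(dp, network, postprocessing, output_name, **kwargs)
            for dp in x
        ]
    out = dict(x)  # the input entries are kept
    out[output_name] = network(x["image"], **kwargs)  # kwargs go to the network
    for transform in postprocessing:  # DataPoint -> DataPoint callables
        out = transform(out)
    return out


def scale_network(image, factor=1.0):
    """Toy 'network' that multiplies every value by `factor`."""
    return [factor * v for v in image]


def sigmoid_on_output(dp: DataPoint) -> DataPoint:
    """Toy postprocessing that only touches the 'output' entry."""
    dp = dict(dp)
    dp["output"] = [1.0 / (1.0 + math.exp(-v)) for v in dp["output"]]
    return dp


datapoint = {"image": [0.0, 1.0]}
single = toy_infer(
    datapoint, scale_network, postprocessing=[sigmoid_on_output], factor=2.0
)
batch = toy_infer([datapoint, dict(datapoint)], scale_network)
```

In this sketch, single keeps its "image" entry and gains an "output" entry, while batch is a list with one such datapoint per input item, mirroring the "same data structure as the input" return value.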