clinicadl.infer.SlicesToImageInferer

class clinicadl.infer.SlicesToImageInferer(slice_direction: str | SliceDirection, batch_size: int = 1, postprocessing: Sequence[Callable[[DataPointT], DataPointT] | TransformConfig] | None = None, postprocessing_on_cpu: bool = False, output_name: str = 'output', output_type: str | OutputType = OutputType.TENSOR)

Splits a 3D volume into 2D slices, passes them through a 2D neural network, and merges the outputs into a 3D output volume.

Adapted from monai.inferers.SliceInferer.

Parameters:
  • slice_direction (str | SliceDirection, default=0) –

    The slicing direction. Can be 0, 1 or 2.

    Warning

    Be careful with the orientation of your image. If your image is in RAS+ (e.g. you used ToCanonicalConfig), 0 refers to the sagittal direction, 1 to the coronal direction, and 2 to the axial direction.

  • batch_size (int, default=1) –

    The size of the batch passed to the neural network. If you pass a batch of images to the inferer, this batch will be rearranged to match batch_size.

    E.g., if a batch of 2 images is passed, with 3 slices in each image, and batch_size=4, then the first batch passed to the neural network will contain the three slices of the first image and the first slice of the second.

  • postprocessing (Optional[Sequence[TransformOrConfig]], default=None) – Postprocessing transformations (e.g. activations) to apply after the forward pass through the neural network. Accepted transforms are functions that take a DataPoint as input and return a DataPoint, or configuration classes.

  • postprocessing_on_cpu (bool, default=False) – Whether to force postprocessing to run on CPU. If False, postprocessing is applied on the device where the data and the neural network are located.

  • output_name (str, default="output") –

    The name to give to the output in the DataPoint.

    Important

    If your postprocessing transform comes from clinicadl.transforms.config, do not forget to specify include=["<output_name>"] so that the postprocessing is applied to the output of the neural network.

  • output_type (str | OutputType, default="tensor") –

    Determines the data type of the output:

Examples

>>> import torch
>>> from clinicadl.infer import SlicesToImageInferer
>>> from clinicadl.data.structures.examples import Colin27DataPoint
>>> from clinicadl.networks.nn import ConvEncoder

>>> net = ConvEncoder(spatial_dims=2, in_channels=1, channels=[2, 4], kernel_size=7)
>>> datapoint = Colin27DataPoint()
>>> inferer = SlicesToImageInferer(slice_direction=1, batch_size=16)
>>> datapoint.image.shape
(1, 181, 217, 181)
>>> with torch.no_grad(): out = inferer(datapoint, net)
>>> out["output"].shape
(4, 169, 217, 169)  # 2D neural network applies to the 217 coronal slices
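
The output shape above can be checked by hand. The sketch below is plain Python and does not use clinicadl. It assumes that each of the two convolutional layers is unpadded with stride 1 (an assumption, though one consistent with the shapes shown), so each layer shrinks the in-plane dimensions by kernel_size - 1 while the sliced dimension is left untouched. The helper name expected_output_shape is hypothetical.

```python
def expected_output_shape(image_shape, slice_direction, out_channels, kernel_size, n_layers):
    """Predict the shape of the merged 3D output.

    Assumes unpadded, stride-1 2D convolutions (an assumption for this sketch).
    """
    _, *spatial = image_shape  # drop the input channel dimension
    shrink = n_layers * (kernel_size - 1)  # total in-plane reduction
    out_spatial = [
        # The sliced axis is only split and re-stacked, so its size is preserved.
        size if axis == slice_direction else size - shrink
        for axis, size in enumerate(spatial)
    ]
    return (out_channels, *out_spatial)


shape = expected_output_shape(
    image_shape=(1, 181, 217, 181),
    slice_direction=1,
    out_channels=4,
    kernel_size=7,
    n_layers=2,
)
print(shape)  # (4, 169, 217, 169)
```

With slice_direction=1, only the first and third spatial dimensions go through the 2D convolutions, which is why 217 is preserved.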

See also

clinicadl.infer.SimpleInferer

For classical inference.

__call__(x: DataT, network: Module, input_dtype: dtype | None = None, **kwargs: Any) → DataT

Defines the inference logic.

Parameters:
  • x (DataT) – The input image(s). Can be a DataPoint or a Batch of images.

  • network (torch.nn.Module) – The neural network.

  • input_dtype (Optional[torch.dtype], default=None) – The data type to which the input image is converted before being processed by network. If None, single precision (i.e. float32) will be used (except if the inferer is run in an AMP context).

  • kwargs (Any) – Optional keyword args to be passed to network.

Returns:

DataT – The same data structure as the input, containing the inference output.
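
When x is a Batch, the slices of its images are regrouped according to batch_size, as described for that parameter. The following is a minimal sketch of this regrouping in plain Python, not clinicadl's implementation: regroup_slices is a hypothetical helper that labels each slice by (image index, slice index), flattens the slices image after image, and cuts the sequence into chunks of batch_size. How the last, incomplete batch is handled is an assumption here.

```python
def regroup_slices(n_images, n_slices, batch_size):
    """Return the batches of (image_index, slice_index) pairs fed to the network."""
    # Flatten all slices of all images, image after image.
    all_slices = [(img, sl) for img in range(n_images) for sl in range(n_slices)]
    # Cut the flat sequence into consecutive chunks of at most `batch_size`.
    return [all_slices[i:i + batch_size] for i in range(0, len(all_slices), batch_size)]


# Same setting as in the batch_size description: 2 images, 3 slices each, batch_size=4.
batches = regroup_slices(n_images=2, n_slices=3, batch_size=4)
for batch in batches:
    print(batch)
# [(0, 0), (0, 1), (0, 2), (1, 0)]
# [(1, 1), (1, 2)]
```

The first batch contains the three slices of the first image and the first slice of the second, matching the description of batch_size.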