clinicadl.infer.Inferer

class clinicadl.infer.Inferer

Abstract class for inferers, which define how an image is passed through a neural network during inference.

The only method to override is __call__().
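The contract above can be illustrated with a minimal stdlib-only sketch. It imports neither clinicadl nor torch. `InfererSketch` is a hypothetical stand-in that mirrors the documented `__call__` signature, `WholeImageInferer` is a hypothetical subclass, and a plain callable plays the role of the network.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional


class InfererSketch(ABC):
    """Hypothetical stand-in mirroring the documented contract of Inferer."""

    @abstractmethod
    def __call__(
        self,
        x: Any,
        network: Callable[..., Any],
        input_dtype: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """Subclasses define here how `x` is passed through `network`."""


class WholeImageInferer(InfererSketch):
    """Hypothetical subclass: overrides only __call__ and feeds the whole input at once."""

    def __call__(self, x, network, input_dtype=None, **kwargs):
        # Toy version: input_dtype is ignored; a real inferer would convert x first.
        return network(x, **kwargs)
```

Because `__call__` is declared abstract, `InfererSketch()` raises `TypeError`, whereas `WholeImageInferer()([1, 2, 3], lambda v: [2 * i for i in v])` returns `[2, 4, 6]`.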

See also

clinicadl.infer.SimpleInferer

For classical inference.

clinicadl.infer.PatchesToImageInferer

To feed 3D patches into a neural network and merge the outputs into a 3D image.

clinicadl.infer.SlicesToImageInferer

To feed 2D slices into a 2D neural network and merge the outputs into a 3D image.

abstract __call__(x: DataT, network: Module, input_dtype: dtype | None = None, **kwargs: Any) → DataT

Defines the inference logic.

Parameters:
  • x (DataT) – The input image(s). Can be a DataPoint or a Batch of images.

  • network (torch.nn.Module) – The neural network.

  • input_dtype (Optional[torch.dtype], default=None) – The data type to which the input image is converted before being processed by network. If None, single precision (float32) is used, unless the inferer is run in an AMP (automatic mixed precision) context.

  • kwargs (Any) – Optional keyword arguments passed to network.

Returns:

DataT – The same data structure as the input, containing the inference output.
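The parameters and return value above can be illustrated with a stdlib-only sketch of inference logic in the spirit of a simple inferer. Everything here is hypothetical and simplified: a dict with an `"image"` key stands in for a DataPoint, `input_dtype` is modelled as a conversion callable (with Python's `float` standing in for float32, which a real `torch.dtype` is not), and the `amp_enabled` flag replaces whatever AMP detection a real inferer performs.

```python
from typing import Any, Callable, Dict, Optional


def simple_call(
    x: Dict[str, Any],
    network: Callable[..., Any],
    input_dtype: Optional[Callable[[Any], Any]] = None,
    amp_enabled: bool = False,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical inference logic on a dict standing in for a DataPoint."""
    # Mirror the documented default: use single precision unless an AMP
    # context is active (here `float` merely stands in for float32).
    if input_dtype is None and not amp_enabled:
        input_dtype = float
    image = x["image"]
    if input_dtype is not None:
        image = [input_dtype(v) for v in image]
    # Forward any extra keyword arguments to the network unchanged.
    output = network(image, **kwargs)
    # Return the same structure as the input, with the image replaced by the output.
    result = dict(x)
    result["image"] = output
    return result
```

For example, `simple_call({"image": [1, 2, 3], "label": 0}, lambda img, scale=1.0: [scale * v for v in img], scale=2.0)` returns `{"image": [2.0, 4.0, 6.0], "label": 0}`: the extra field is preserved and `scale` reaches the network through `**kwargs`.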