clinicadl.infer.SlicesToImageInferer
- class clinicadl.infer.SlicesToImageInferer(slice_direction: str | SliceDirection, batch_size: int = 1, postprocessing: Sequence[Callable[[DataPointT], DataPointT] | TransformConfig] | None = None, postprocessing_on_cpu: bool = False, output_name: str = 'output', output_type: str | OutputType = OutputType.TENSOR)
Splits a 3D volume into 2D slices, passes them through a 2D neural network, and merges the outputs into a 3D output volume.
Adapted from monai.inferers.SliceInferer.
- Parameters:
slice_direction (str | SliceDirection, default=0) –
The slicing direction. Can be 0, 1 or 2.
Warning
Be careful with the orientation of your image. If your image is in RAS+ (e.g. you used ToCanonicalConfig), 0 refers to the sagittal direction, 1 to the coronal direction, and 2 to the axial direction.
batch_size (int, default=1) –
The size of the batches passed to the neural network. If you pass a batch of images to the inferer, the slices are regrouped into batches of size batch_size.
E.g., if a batch of 2 images is passed, with 3 slices in each image, and batch_size=4, then the first batch passed to the neural network will contain the three slices of the first image and the first slice of the second.
postprocessing (Optional[Sequence[TransformOrConfig]], default=None) – Postprocessing transformations (e.g. activations) to apply after the forward pass through the neural network. Accepted transforms are functions that take as input a
DataPoint and return a DataPoint, or configuration classes.
postprocessing_on_cpu (bool, default=False) – Whether to force postprocessing to run on CPU. If False, postprocessing is applied on the device where the data and the neural network are located.
output_name (str, default="output") –
The name to give to the output in the
DataPoint.
Important
If your postprocessing transform comes from clinicadl.transforms.config, do not forget to specify include=["<output_name>"] to apply the postprocessing to the output of the neural network.
output_type (str | OutputType, default="tensor") –
Determines the data type of the output:
"image": the output will be converted to a torchio.ScalarImage;
"mask": the output will be converted to a torchio.LabelMap;
"tensor": the output will remain a torch.Tensor.
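The regrouping described for batch_size can be pictured with a small plain-Python sketch. This is only an illustration of the behaviour stated above, assuming slices are consumed image by image; `rebatch_slices` is a hypothetical helper, not a clinicadl function.

```python
from itertools import islice
from typing import Iterator


def rebatch_slices(
    n_images: int, n_slices: int, batch_size: int
) -> Iterator[list[tuple[int, int]]]:
    """Yield batches of (image_index, slice_index) pairs.

    Hypothetical helper for illustration only; it is not part of clinicadl.
    """
    # Flatten the slices of all images into one stream, image by image.
    stream = ((i, s) for i in range(n_images) for s in range(n_slices))
    while True:
        # Take up to batch_size slices, regardless of image boundaries.
        batch = list(islice(stream, batch_size))
        if not batch:
            return
        yield batch


batches = list(rebatch_slices(n_images=2, n_slices=3, batch_size=4))
print(batches[0])  # [(0, 0), (0, 1), (0, 2), (1, 0)]
print(batches[1])  # [(1, 1), (1, 2)]
```

With 2 images of 3 slices each and batch_size=4, the first batch holds the three slices of the first image plus the first slice of the second, and the remaining two slices form a smaller final batch.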
Examples
import torch
from clinicadl.infer import SlicesToImageInferer
from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.networks.nn import ConvEncoder

net = ConvEncoder(spatial_dims=2, in_channels=1, channels=[2, 4], kernel_size=7)
datapoint = Colin27DataPoint()
inferer = SlicesToImageInferer(slice_direction=1, batch_size=16)

>>> datapoint.image.shape
(1, 181, 217, 181)
>>> with torch.no_grad():
...     out = inferer(datapoint, net)
>>> out["output"].shape
(4, 169, 217, 169)  # the 2D neural network is applied to each of the 217 coronal slices
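To make the split, forward and merge steps concrete, here is a minimal sketch in plain PyTorch. It is not clinicadl's or MONAI's implementation: `slices_to_volume` is a hypothetical helper, the network is a stand-in made of two unpadded convolutions, and the input is a small random tensor rather than a real image.

```python
import torch
from torch import nn


def slices_to_volume(
    volume: torch.Tensor, network: nn.Module, axis: int, batch_size: int
) -> torch.Tensor:
    """Run a 2D network on every slice of a (C, X, Y, Z) volume.

    Hypothetical helper for illustration; not clinicadl's implementation.
    """
    dim = axis + 1  # skip the channel dimension
    # Move the slicing axis to the front: (n_slices, C, H, W).
    slices = volume.movedim(dim, 0)
    outputs = []
    for start in range(0, slices.shape[0], batch_size):
        # Each chunk of slices is treated as a batch of 2D images.
        outputs.append(network(slices[start : start + batch_size]))
    merged = torch.cat(outputs, dim=0)  # (n_slices, C_out, H_out, W_out)
    # Put the slice axis back in its original position.
    return merged.movedim(0, dim)


# Stand-in 2D network: two unpadded 7x7 convolutions, 1 -> 2 -> 4 channels.
net = nn.Sequential(nn.Conv2d(1, 2, kernel_size=7), nn.Conv2d(2, 4, kernel_size=7))
volume = torch.rand(1, 32, 10, 30)  # small random volume, not real data

with torch.no_grad():
    out = slices_to_volume(volume, net, axis=1, batch_size=4)
print(out.shape)  # torch.Size([4, 20, 10, 18])
```

As in the example above, the number of slices along the slicing axis is preserved (10 here), while the two in-plane dimensions each shrink by 12 voxels because of the two unpadded 7x7 convolutions.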
See also
clinicadl.infer.SimpleInferer
For classical inference.
- __call__(x: DataT, network: Module, input_dtype: dtype | None = None, **kwargs: Any) → DataT
Defines the inference logic.
- Parameters:
x (DataT) – The input image(s). Can be a DataPoint or a Batch of images.
network (torch.nn.Module) – The neural network.
input_dtype (Optional[torch.dtype], default=None) – The data type to which the input image is converted before being processed by network. If None, single precision (i.e. float32) will be used (except if the inferer is run in an AMP context).
kwargs (Any) – Optional keyword arguments to be passed to network.
- Returns:
DataT – The same data structure as the input, containing the inference output.