clinicadl.infer.PatchesToImageInferer
- class clinicadl.infer.PatchesToImageInferer(patch_size: int | tuple[int, int, int], overlap: float | tuple[float, float, float] = 0.25, avg_mode: str | AveragingMode = AveragingMode.CONSTANT, sigma_scale: float = 0.125, batch_size: int = 1, postprocessing: Sequence[Callable[[DataPointT], DataPointT] | TransformConfig] | None = None, postprocessing_on_cpu: bool = False, output_name: str = 'output', output_type: str | OutputType = OutputType.TENSOR)
Splits a 3D volume into 3D patches, passes them through a 3D neural network, and merges the outputs into a 3D output volume.
Adapted from monai.inferers.SlidingWindowInferer.
- Parameters:
patch_size (Union[int, tuple[int, int, int]]) – The size of the patches. If a single value is passed, the same patch size will be used for the three spatial dimensions.
overlap (Union[float, tuple[float, float, float]], default=0.25) – A float in \([0.0, 1.0)\) that defines the relative patch overlap in each dimension. If a single value is passed, the same overlap will be used for the three spatial dimensions.
avg_mode (AveragingMode, default="constant") – How to average the outputs when patches overlap.
  - "constant": gives equal weight to all predictions;
  - "gaussian": gives less weight to the predictions on the edges of patches.
sigma_scale (float, default=0.125) – The standard deviation coefficient of the Gaussian window when avg_mode is "gaussian". The actual sigma is sigma_scale * patch_size.
batch_size (int, default=1) – The size of the batch passed to the neural network. If you pass a batch of images to the inferer, this batch will be rearranged to match batch_size. E.g., if a batch of 2 images is passed, with 3 patches in each image, and batch_size=4, then the first batch passed to the neural network will contain the three patches of the first image and the first patch of the second.
postprocessing (Optional[Sequence[TransformOrConfig]], default=None) – Postprocessing transformations (e.g. activations) to apply after the forward pass through the neural network. Accepted transforms are functions that take a DataPoint as input and return a DataPoint, or configuration classes.
postprocessing_on_cpu (bool, default=False) – Whether to force postprocessing to run on the CPU. If False, postprocessing will be applied on the device where the data and the neural network are located.
output_name (str, default="output") –
The name to give to the output in the DataPoint.
Important
If your postprocessing transform comes from clinicadl.transforms.config, do not forget to specify include=["<output_name>"] to apply the postprocessing to the output of the neural network.
output_type (OutputType, default="tensor") – Determines the data type of the output:
  - "image": the output will be converted to a torchio.ScalarImage;
  - "mask": the output will be converted to a torchio.LabelMap;
  - "tensor": the output will remain a torch.Tensor.
Examples
import torch
from clinicadl.infer import PatchesToImageInferer
from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.networks.nn import AutoEncoder

# A small autoencoder whose input shape matches the 64x64x64 patches
net = AutoEncoder(in_shape=(1, 64, 64, 64), latent_size=16, conv_args={"channels": [2]})
datapoint = Colin27DataPoint()
inferer = PatchesToImageInferer(patch_size=64, batch_size=16, overlap=1/5)
>>> datapoint.image.shape
(1, 181, 217, 181)
>>> with torch.no_grad():
...     out = inferer(datapoint, net)
>>> out["output"].shape
(1, 181, 217, 181)
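To get a feel for the cost of the example above, the following sketch estimates how many patches and forward passes it implies. It assumes the usual sliding-window convention (stride equal to `int(patch_size * (1 - overlap))`, with the last patch clamped to the image border); whether PatchesToImageInferer follows exactly this convention is not stated here, and `n_patches` is a hypothetical helper, not part of clinicadl.

```python
import math


def n_patches(size, patch, overlap):
    """Number of patches along one axis (assumed sliding-window convention)."""
    if patch >= size:
        return 1
    # Assumption: the stride is the patch size reduced by the relative overlap.
    step = max(int(patch * (1 - overlap)), 1)
    return math.ceil((size - patch) / step) + 1


spatial_shape = (181, 217, 181)  # spatial shape of the image in the example
per_dim = [n_patches(s, patch=64, overlap=1 / 5) for s in spatial_shape]
total = math.prod(per_dim)
n_forward = math.ceil(total / 16)  # batch_size=16 in the example

print(per_dim, total, n_forward)  # [4, 4, 4] 64 4
```

Under these assumptions, the volume is covered by 64 patches, which are processed in 4 batches of 16.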
See also
clinicadl.infer.SimpleInferer
For classical inference.
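The roles of overlap, avg_mode and sigma_scale can be illustrated in one dimension. The sketch below is not the library's implementation: `patch_starts`, `window` and `merge` are hypothetical helpers, and the stride and clamping conventions are assumptions borrowed from common sliding-window inference. It only shows how overlapping predictions are combined by a weighted average, with either constant or Gaussian weights.

```python
import math


def patch_starts(size, patch, overlap):
    """Start indices of 1D patches covering [0, size) (assumed convention)."""
    if patch >= size:
        return [0]
    step = max(int(patch * (1 - overlap)), 1)
    n = math.ceil((size - patch) / step) + 1
    # Clamp the last patch so that it stays inside the signal.
    return [min(i * step, size - patch) for i in range(n)]


def window(patch, mode="constant", sigma_scale=0.125):
    """Per-position weights of one patch."""
    if mode == "constant":
        return [1.0] * patch
    sigma = sigma_scale * patch  # "the actual sigma is sigma_scale * patch_size"
    center = (patch - 1) / 2
    return [math.exp(-0.5 * ((i - center) / sigma) ** 2) for i in range(patch)]


def merge(predictions, starts, size, weights):
    """Weighted average of overlapping patch predictions."""
    total = [0.0] * size
    norm = [0.0] * size
    for pred, start in zip(predictions, starts):
        for i, (p, w) in enumerate(zip(pred, weights)):
            total[start + i] += w * p
            norm[start + i] += w
    return [t / n for t, n in zip(total, norm)]


signal = [float(i) for i in range(10)]
starts = patch_starts(len(signal), patch=4, overlap=0.25)
patches = [signal[s:s + 4] for s in starts]  # identity "network"
merged = merge(patches, starts, len(signal), window(4, "gaussian"))
```

With an identity "network", the merged output reproduces the input signal whatever the weighting. The two modes only differ when overlapping patches disagree, in which case "gaussian" trusts the prediction closest to a patch center.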
- __call__(x: DataT, network: Module, input_dtype: dtype | None = None, **kwargs: Any) → DataT
Defines the inference logic.
- Parameters:
x (DataT) – The input image(s). Can be a DataPoint or a Batch of images.
network (torch.nn.Module) – The neural network.
input_dtype (Optional[torch.dtype], default=None) – The data type to which the input image is converted before being processed by network. If None, single precision (i.e. float32) will be used (except if the inferer is run in an AMP context).
kwargs (Any) – Optional keyword args to be passed to network.
- Returns:
DataT – The same data structure as the input, containing the inference output.
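When a Batch of images is passed, the patches of all images are regrouped to match batch_size, as described for the class parameters. The following sketch reproduces the example given there (2 images, 3 patches each, batch_size=4). The `rebatch` helper is hypothetical and only tracks indices; it is an illustration of the regrouping, not clinicadl code.

```python
def rebatch(patches_per_image, batch_size):
    """Group (image_index, patch_index) pairs into network batches."""
    flat = [
        (img, patch)
        for img, n in enumerate(patches_per_image)
        for patch in range(n)
    ]
    return [flat[i:i + batch_size] for i in range(0, len(flat), batch_size)]


batches = rebatch([3, 3], batch_size=4)
for batch in batches:
    print(batch)
# [(0, 0), (0, 1), (0, 2), (1, 0)]
# [(1, 1), (1, 2)]
```

Keeping the image index alongside each patch is what allows the outputs to be routed back to the right image, so that the returned structure matches the input.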