clinicadl.infer.SimpleInferer
- class clinicadl.infer.SimpleInferer(postprocessing: Sequence[Callable[[DataPointT], DataPointT] | TransformConfig] | None = None, postprocessing_on_cpu: bool = False, output_name: str = 'output', output_type: str | OutputType = OutputType.TENSOR)
An Inferer for classical inference, i.e. when the whole image is passed through the neural network at once and the output is returned (possibly after postprocessing). The inference output is added to the input DataPoint (or to each DataPoint of the input Batch).
- Parameters:
postprocessing (Optional[Sequence[TransformOrConfig]], default=None) – Postprocessing transformations (e.g. activations) to apply after the forward pass through the neural network. Accepted transforms are functions that take a DataPoint as input and return a DataPoint, or configuration classes.
postprocessing_on_cpu (bool, default=False) – Whether to force postprocessing to run on CPU. If False, postprocessing is applied on the device where the data and the neural network are.
output_name (str, default="output") – The name to give to the output in the DataPoint.
Important
If your postprocessing transform comes from clinicadl.transforms.config, do not forget to specify include=["<output_name>"] so that the postprocessing is applied to the output of the neural network.
output_type (str | OutputType, default="tensor") – Determines the data type of the output:
"image": the output will be converted to a torchio.ScalarImage;
"mask": the output will be converted to a torchio.LabelMap;
"tensor": the output will remain a torch.Tensor.
Examples
import torch
from clinicadl.infer import SimpleInferer
from clinicadl.data.structures.examples import Colin27DataPoint
from clinicadl.data.dataloader import Batch
from clinicadl.networks.nn import ConvEncoder
from clinicadl.transforms.config import ActivationsConfig

net = ConvEncoder(spatial_dims=3, in_channels=1, channels=[2, 4])
datapoint = Colin27DataPoint()
>>> datapoint.image.shape
(1, 181, 217, 181)
>>> inferer = SimpleInferer()
>>> with torch.no_grad():
...     out = inferer(datapoint, net)
>>> out["output"].shape
(4, 177, 213, 177)
>>> out["output"].tensor
tensor([[[[ 0.5923, 0.5979, 0.5816, ..., 0.6592, 0.6592, 0.6592],
          [ 0.5887, 0.5832, 0.5785, ..., 0.6592, 0.6592, 0.6592],
          [ 0.5883, 0.5832, 0.5858, ..., 0.6592, 0.6592, 0.6592],
          ...,
Working with a specific precision:
>>> net = net.to(dtype=torch.half)
>>> with torch.no_grad():
...     out = inferer(datapoint, net, input_dtype=torch.half)
>>> out["output"].tensor.dtype
torch.float16
With postprocessing:
>>> net = net.to(dtype=torch.float)
>>> inferer = SimpleInferer(postprocessing=[ActivationsConfig(sigmoid=True, include=["output"])])
>>> with torch.no_grad():
...     out = inferer(datapoint, net)
>>> out["output"].tensor
tensor([[[[0.6439, 0.6452, 0.6414, ..., 0.6591, 0.6591, 0.6591],
          [0.6430, 0.6418, 0.6407, ..., 0.6591, 0.6591, 0.6591],
          [0.6429, 0.6418, 0.6424, ..., 0.6591, 0.6591, 0.6591],
          ...,
With a clinicadl.data.dataloader.Batch:
>>> from copy import deepcopy
>>> batch = Batch([datapoint, deepcopy(datapoint)])
>>> with torch.no_grad():
...     out = inferer(batch, net)
>>> out[0]["output"].shape
(4, 177, 213, 177)
- __call__(x: DataT, network: Module, input_dtype: dtype | None = None, **kwargs: Any) → DataT
Runs inference on x: passes the image(s) through network and adds the (postprocessed) output to x under output_name.
- Parameters:
x (DataT) – The input image(s). Can be a DataPoint or a Batch of DataPoints.
network (torch.nn.Module) – The neural network.
input_dtype (Optional[torch.dtype], default=None) – The data type to which the input image is converted before being passed to network. If None, single precision (i.e. float32) is used, unless the inferer is run within an AMP (automatic mixed precision) context.
kwargs (Any) – Optional keyword arguments passed to network.
- Returns:
DataT – The same data structure as the input, containing the inference output.
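Per the Returns entry, the output structure mirrors the input: a DataPoint in gives a DataPoint back, and a Batch in gives the Batch back with an output added to each element, as in the Batch example above. The following is a minimal sketch of that dispatch and of forwarding kwargs to network, assuming a dict stands in for a DataPoint and a list of dicts for a Batch; infer_one and run_inference are hypothetical names. For simplicity it calls the network once per element, which may differ from how clinicadl actually processes a Batch, and it ignores input_dtype.

```python
from typing import Any, Callable, Dict, List, Union

# Hypothetical stand-ins: a dict for a DataPoint and a list of dicts for a Batch.
DataPointLike = Dict[str, Any]
BatchLike = List[DataPointLike]


def infer_one(
    dp: DataPointLike,
    network: Callable[..., Any],
    output_name: str = "output",
    **kwargs: Any,
) -> DataPointLike:
    """Run the network on one data point and add the result under `output_name`."""
    dp[output_name] = network(dp["image"], **kwargs)  # extra kwargs go to the network
    return dp


def run_inference(
    x: Union[DataPointLike, BatchLike],
    network: Callable[..., Any],
    **kwargs: Any,
) -> Union[DataPointLike, BatchLike]:
    """Return the same structure as the input: one data point, or a whole batch."""
    if isinstance(x, list):  # batch: add the output to every element, in place
        for dp in x:
            infer_one(dp, network, **kwargs)
        return x
    return infer_one(x, network, **kwargs)


def scale(values: List[float], factor: float = 1.0) -> List[float]:
    """Toy network that accepts an extra keyword argument."""
    return [factor * v for v in values]


single = run_inference({"image": [1.0, 2.0]}, scale, factor=3.0)
batch = run_inference([{"image": [1.0]}, {"image": [2.0]}], scale, factor=3.0)
print(single["output"])  # [3.0, 6.0]
print([dp["output"] for dp in batch])  # [[3.0], [6.0]]
```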