class Inferer(JsonReaderWriter, ABC):
    """
    Base class for all inferers.

    An inferer describes how an image is fed to a neural network at
    inference time. Concrete inferers only have to implement
    :py:meth:`__call__`.

    See Also
    --------
    clinicadl.infer.SimpleInferer
        For classical inference.
    clinicadl.infer.PatchesToImageInferer
        To feed 3D patches into a neural network and merge the outputs in a 3D image.
    clinicadl.infer.SlicesToImageInferer
        To feed 2D slices into a 2D neural network and merge the outputs in a 3D image.
    """

    @abstractmethod
    def __call__(
        self,
        x: DataT,
        network: torch.nn.Module,
        input_dtype: Optional[torch.dtype] = None,
        **kwargs: Any,
    ) -> DataT:
        """
        Run ``network`` on ``x`` following this inferer's strategy.

        Parameters
        ----------
        x : DataT
            The input image(s): either a
            :py:class:`~clinicadl.data.structures.DataPoint` or a
            :py:class:`~clinicadl.data.dataloader.Batch` of images.
        network : torch.nn.Module
            The neural network to run.
        input_dtype : Optional[torch.dtype], default=None
            Data type the input image is cast to before ``network`` processes
            it. When ``None``, single precision (``float32``) is expected to be
            used, unless the inferer runs inside an :term:`AMP` context.
        kwargs : Any
            Extra keyword arguments forwarded to ``network``.

        Returns
        -------
        DataT
            A structure of the same kind as ``x``, holding the inference output.
        """

    @classmethod
    def _get_input_tensor(
        cls,
        x: Union[DataPoint, Batch],
        input_dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        """
        Extract the ``IMAGE`` field of ``x`` as a :py:class:`torch.Tensor`.

        Parameters
        ----------
        x : Union[DataPoint, Batch]
            Source of the image(s).
        input_dtype : Optional[torch.dtype], default=None
            Forwarded unchanged. For a ``DataPoint`` it goes to
            ``Tensor.to(dtype=...)``; for a ``Batch`` it goes to
            ``Batch.get_field(..., dtype=...)``. This helper does not itself
            substitute ``float32`` when it is ``None``.

        Returns
        -------
        torch.Tensor
            The image tensor.

        Raises
        ------
        TypeError
            If ``x`` is neither a ``DataPoint`` nor a ``Batch``.
        """
        # DataPoint is checked first, as in the original dispatch order.
        if isinstance(x, DataPoint):
            image = x.get_image_tensor(IMAGE)
            return image.to(dtype=input_dtype)
        if isinstance(x, Batch):
            return x.get_field(IMAGE, dtype=input_dtype)
        raise TypeError(f"'x' can be either a DataPoint or a Batch. Got: {x}")