"""We chose to override MONAI's CumulativeIterationMetric because herewe wanted to work with ``DataPoints``, and to be able to compute the metricfor each element of the batch individually.Besides, we think our implementation facilitates the creation fo customtransforms by the user."""from__future__importannotationsfromabcimportABC,abstractmethodfromcollections.abcimportSequencefromtypingimportTYPE_CHECKING,Literal,Unionimporttorchfrommonai.metrics.metricimportCumulativeIterationMetricfrom.enumimportOptimumifTYPE_CHECKING:fromclinicadl.data.dataloaderimportBatchTensorOrList=Union[torch.Tensor,Sequence[torch.Tensor]]
class Metric(CumulativeIterationMetric, ABC):
    """
    To define metrics to evaluate a model.

    Adapted from :py:class:`monai.metrics.CumulativeIterationMetric`.

    A metric must inherit from this class to work with ``ClinicaDL``.
    :py:meth:`_aggregate` and :py:meth:`_accumulate` must be implemented.

    The user must also define the attribute ``_optimum``:

    - use "min" when a lower metric value indicates better performance;
    - use "max" when a higher metric value indicates better performance.

    Finally, ``__init__`` can be overridden, but don't forget to call
    ``super().__init__()`` inside.

    Examples
    --------
    .. code-block::

        from clinicadl.metrics import Metric

        class MyMetric(Metric):
            _optimum = "max"

            def __init__(self, ...):
                super().__init__()
                ...

            def _aggregate(self, data: TensorOrList) -> float:
                ...

            def _accumulate(self, batch: Batch) -> TensorOrList:
                ...

        metric = MyMetric(...)

    .. code-block::

        >>> loader_iterator = iter(dataloader)
        >>> metric(next(loader_iterator))
        tensor([0., 1., 0.])  # metric value for the 3 images of the batch
        >>> metric(next(loader_iterator))
        tensor([0., 1., 0.])
        >>> metric.aggregate()
        0.3333333333333333  # here it is the average on all the images
    """

    _optimum: Literal["min", "max"]

    # Whether the last call to ``_accumulate`` (made in ``__call__``) returned a
    # sequence of tensors rather than a single tensor. ``aggregate`` uses it to
    # give ``_aggregate`` the same kind of input as ``__call__`` does.
    _accumulates_sequence: bool = False

    @abstractmethod
    def _aggregate(self, data: TensorOrList) -> float:
        """
        Aggregation logic.

        This function tells how to aggregate the data returned by
        :py:meth:`_accumulate` to compute the metric.

        Parameters
        ----------
        data : TensorOrList
            Data useful to compute the metric, as returned by
            :py:meth:`_accumulate`.

        Returns
        -------
        float
            The aggregated metric.
        """

    @abstractmethod
    def _accumulate(self, batch: Batch) -> TensorOrList:
        """
        To accumulate data useful for the final metric computation.

        For example, for segmentation, to compute the accuracy, this function
        would just return the confusion matrix for each element of the batch.

        Parameters
        ----------
        batch : Batch
            The batch of :py:class:`~clinicadl.data.structures.DataPoint`,
            passed via a :py:class:`~clinicadl.data.dataloader.Batch`.

        Returns
        -------
        TensorOrList
            Useful results for the final aggregation, as a "batch-first" tensor,
            or a sequence of "batch-first" tensors.
        """

    @property
    def optimum(self) -> Optimum:
        """Optimization criterion for the metric."""
        return Optimum(self._optimum)

    # pylint: disable=arguments-differ
    def aggregate(self) -> float:
        """
        See :py:meth:`monai.metrics.Cumulative.aggregate`.

        Computes the metric on all the data stored so far by
        :py:meth:`__call__`, using :py:meth:`_aggregate`.

        Returns
        -------
        float
            The aggregated metric.
        """
        data = self.get_buffer()
        # MONAI's ``get_buffer`` returns a bare tensor when there is a single
        # buffer. If ``_accumulate`` returned a one-element sequence, wrap it
        # back so that ``_aggregate`` receives a sequence, as in ``__call__``.
        if self._accumulates_sequence and isinstance(data, torch.Tensor):
            data = [data]
        return self._aggregate(data)

    # pylint: disable=signature-differs
    def __call__(self, batch: Batch) -> torch.Tensor:
        """
        See :py:meth:`monai.metrics.CumulativeIterationMetric.__call__`.

        It is modified to accept a batch of
        :py:class:`~clinicadl.data.structures.DataPoint`, and to get the
        metric for each element of the batch, whereas the original method only
        accumulates.

        The data returned by :py:meth:`_accumulate` is stored in the buffers
        (for a later call to :py:meth:`aggregate`). Then :py:meth:`_aggregate`
        is called on each element of the batch individually. Each call gets the
        same structure as the accumulated data (a tensor, or a list of
        tensors), with a batch dimension of size 1. Non-tensor array-like
        outputs are converted with :py:func:`torch.as_tensor` for this
        per-element computation.

        Parameters
        ----------
        batch : Batch
            The batch of :py:class:`~clinicadl.data.structures.DataPoint`,
            passed via a :py:class:`~clinicadl.data.dataloader.Batch`.

        Returns
        -------
        torch.Tensor
            The metric value for each element of the batch.
        """
        # get the data for metric computation
        data = self._accumulate(batch)
        is_sequence = isinstance(data, Sequence)
        self._accumulates_sequence = is_sequence

        # store the data in the buffers (one buffer per returned tensor)
        if is_sequence:
            self.extend(*data)
        else:
            self.extend(data)

        # compute the metric for each element of the batch; ``unsqueeze(0)``
        # keeps a "batch-first" shape with a batch of size 1.
        # ``torch.as_tensor`` returns tensors unchanged, and converts other
        # array-likes that would otherwise silently yield no result.
        if is_sequence:
            tensors = [torch.as_tensor(item) for item in data]
            results = [
                self._aggregate([elem.unsqueeze(0) for elem in elems])
                for elems in zip(*tensors)
            ]
        else:
            results = [
                self._aggregate(elem.unsqueeze(0)) for elem in torch.as_tensor(data)
            ]

        return torch.tensor(results)

    def _compute_tensor(self, batch: Batch) -> TensorOrList:
        """
        See :py:meth:`monai.metrics.metric.IterationMetric._compute_tensor`.

        Note: :py:meth:`_accumulate` is defined just to have a name more explicit.
        ``_compute_tensor`` is actually not used, but it is mandatory to override
        it (see :py:class:`monai.metrics.metric.IterationMetric`).
        """
        return self._accumulate(batch)