Source code for clinicadl.metrics.config.loss

from __future__ import annotations

from copy import deepcopy
from logging import getLogger
from typing import TYPE_CHECKING

from clinicadl.losses.types import Loss
from clinicadl.utils.doc import add_suffix_to_doc
from clinicadl.utils.factories import get_defaults_from

from ..base import Metric
from ..enum import Optimum
from ..loss import LossMetric
from ..monai_wrapper import MonaiMetricWrapper
from .base import DOCUMENT_EXTRA_PARAMETERS, MetricConfig

if TYPE_CHECKING:
    from clinicadl.models import Model

logger = getLogger(name=__name__)

# Value produced by ``get_defaults_from`` for ``LossMetric`` — presumably the
# default constructor arguments of that metric; TODO confirm against
# clinicadl.utils.factories.get_defaults_from.
LOSS_METRIC_MONAI_DEFAULTS = get_defaults_from(LossMetric)


@add_suffix_to_doc(DOCUMENT_EXTRA_PARAMETERS)
class LossMetricConfig(MetricConfig):
    """
    Special config class to use a loss function as a metric.

    Useful to compute your training losses on your validation set.

    ``loss_name`` is the name given to the loss in
    :py:meth:`clinicadl.models.Model.get_loss_functions`.
    """

    # Key used to look the loss up in the mapping returned by
    # ``model.get_loss_functions()``.
    loss_name: str = "loss"

    @staticmethod
    def optimum() -> Optimum:
        """The optimum of the metric (``Optimum.MIN``: a loss is minimized)."""
        return Optimum.MIN

    def get_object(self, model: Model) -> Metric:
        """
        Returns the metric associated to this configuration, parametrized
        with the parameters passed by the user.

        Parameters
        ----------
        model : Model
            The :py:class:`clinicadl.models.Model` whose
            ``get_loss_functions`` method provides the loss.

        Returns
        -------
        Metric
            The associated metric.

        Raises
        ------
        ValueError
            If ``loss_name`` is not a key of the losses returned by
            ``model.get_loss_functions()``.
        RuntimeError
            If the selected loss has no ``reduction`` attribute.
        """
        losses = model.get_loss_functions()
        try:
            loss = losses[self.loss_name]
        except KeyError as exc:
            raise ValueError(
                f"In LossMetricConfig, loss_name='{self.loss_name}' but there is no such loss "
                "(returned by the 'get_loss_functions' method of your Model). "
                f"Losses are: {list(losses.keys())}"
            ) from exc

        # The loss is evaluated without reduction; its original reduction is
        # applied by the metric instead.
        loss, reduction = self._check_reduction(loss)
        monai_metric = LossMetric(loss_fn=loss, reduction=reduction)

        return MonaiMetricWrapper(
            monai_metric,
            pred_key=self.pred_key,
            label_key=self.label_key,
            optimum=self.optimum(),
            postprocessing=self.postprocessing,
        )

    def _check_reduction(self, loss: Loss) -> tuple[Loss, str]:
        """
        Remove the reduction of the loss to put it at the metric level.

        Parameters
        ----------
        loss : Loss
            The loss selected by ``loss_name``. It is not modified.

        Returns
        -------
        tuple[Loss, str]
            A deep copy of ``loss`` whose ``reduction`` is set to ``"none"``,
            and the original value of ``loss.reduction``.

        Raises
        ------
        RuntimeError
            If ``loss`` has no ``reduction`` attribute.
        """
        # Accessed dynamically: ``Loss`` objects are not guaranteed to expose
        # a ``reduction`` attribute (hence the AttributeError handling).
        try:
            loss_reduction = getattr(loss, "reduction")
        except AttributeError as exc:
            raise RuntimeError(
                f"The loss '{self.loss_name}' (returned by the 'get_loss_functions' method "
                "of your Model) doesn't have a 'reduction' attribute, so ClinicaDL can't "
                "compute the validation loss at the image level."
            ) from exc

        # Work on a copy so the loss instance held by the model keeps its
        # original reduction.
        loss = deepcopy(loss)
        setattr(loss, "reduction", "none")
        return loss, loss_reduction