from __future__ import annotations
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Sequence, Union
import pandas as pd
from pydantic import Field, ValidationError, ValidationInfo, field_validator
from typing_extensions import Self
from clinicadl.utils.config import DictOfObjects, ObjectConfig
from clinicadl.utils.dictionary.utils import TSV_SEP
from clinicadl.utils.dictionary.words import (
CPU,
EPOCH,
METRICS,
PARTICIPANT_ID,
SESSION_ID,
)
from clinicadl.utils.exceptions import CannotReadFieldError
from clinicadl.utils.objects import HasConfig
from .base import Metric
from .config import MetricConfig
from .factory import get_metric_from_dict
from .types import MetricOrConfig
if TYPE_CHECKING:
from clinicadl.data.dataloader import Batch
from clinicadl.models import Model
class MetricsHandlerConfig(ObjectConfig["MetricsHandler"]):
    """
    To check and convert metrics passed by the user.

    The metrics are stored by name in ``metrics``; ``metrics_on_cpu`` is the
    flag forwarded to the associated ``MetricsHandler``.
    """

    # Metrics indexed by their user-chosen name; read back from a dictionary
    # via ``get_metric_from_dict``.
    metrics: DictOfObjects[Metric, MetricConfig] = Field(
        json_schema_extra={"reader": DictOfObjects.build_reader(get_metric_from_dict)}
    )
    metrics_on_cpu: bool

    @field_validator("metrics", mode="before")
    @classmethod
    def _handle_dict(cls, v: Any, info: ValidationInfo) -> DictOfObjects:
        """Wraps the raw value passed for ``metrics`` in a ``DictOfObjects``."""
        return DictOfObjects.from_dict(v, field_name=info.field_name)

    @classmethod
    def from_dict(cls, dict_: dict[str, Any], **kwargs) -> Self:
        """
        Builds the configuration from a dictionary.

        Parameters
        ----------
        dict_ : dict[str, Any]
            The dictionary to read. It is first passed through ``_check_dict``.
        **kwargs
            ``metrics_on_cpu`` (if given and not ``None``) overrides the value
            in ``dict_``; every other keyword argument is added to the metrics
            of ``dict_``.

        Raises
        ------
        CannotReadFieldError
            If a field cannot be read. When the faulty metrics can be
            identified, the error is re-raised with their names.
        """
        dict_ = cls._check_dict(dict_)
        extra_metrics = {
            name: value for name, value in kwargs.items() if name != "metrics_on_cpu"
        }
        dict_[METRICS].update(extra_metrics)

        # Compare to None rather than relying on truthiness: an explicit
        # ``metrics_on_cpu=False`` must override the value in ``dict_``.
        on_cpu = kwargs.get("metrics_on_cpu")
        if on_cpu is not None:
            dict_["metrics_on_cpu"] = on_cpu

        try:
            return super().from_dict(dict_)
        except CannotReadFieldError as e:
            wrong_metrics = cls._read_field_reading_error(e)
            if wrong_metrics:
                raise CannotReadFieldError(
                    field_names=wrong_metrics,
                    object_name=cls._get_name(),
                    error=e.error,
                ) from e
            raise

    @property
    def metric_names(self) -> list[str]:
        """
        The names of the metrics.
        """
        return list(self.metrics.values)

    def add_metrics(
        self,
        metrics: dict[str, MetricOrConfig],
    ) -> None:
        """
        Adds metrics.

        Parameters
        ----------
        metrics : dict[str, MetricOrConfig]
            The metrics to add, indexed by their name.

        Raises
        ------
        ValueError
            If one of the names is already used by an existing metric. In that
            case, nothing is added.
        """
        existing = self.metric_names
        for name in metrics:
            if name in existing:
                raise ValueError(f"A metric named '{name}' already exists!")
        self.metrics = self.metrics.values | metrics

    def remove_metrics(
        self,
        metrics: Sequence[str],
    ) -> None:
        """
        Removes metrics.

        Parameters
        ----------
        metrics : Sequence[str]
            The names of the metrics to remove.

        Raises
        ------
        KeyError
            If one of the names does not match an existing metric. In that
            case, nothing is removed.
        """
        # Validate every name before touching anything, so that a wrong name
        # cannot leave the configuration half-modified.
        unknown = [name for name in metrics if name not in self.metrics.values]
        if unknown:
            raise KeyError(
                f"Cannot remove unknown metric(s) {unknown}. Metrics are: {self.metric_names}"
            )
        # dict.fromkeys drops duplicated names while keeping their order.
        for name in dict.fromkeys(metrics):
            self.metrics.values.pop(name)

    @staticmethod
    def _read_field_reading_error(error: CannotReadFieldError) -> Optional[list[str]]:
        """
        To get the potential metrics that are causing troubles.

        Returns ``None`` when the wrapped error is neither a
        ``CannotReadFieldError`` nor a ``ValidationError``.
        """
        inner = error.error
        if isinstance(inner, CannotReadFieldError):
            return inner.field_names
        if isinstance(inner, ValidationError):
            # The third element of "loc" is used as the metric name; locations
            # that are too short to have one are skipped instead of raising.
            return sorted(
                {err["loc"][2] for err in inner.errors() if len(err["loc"]) > 2}
            )
        return None

    @classmethod
    def _get_class(cls) -> type[MetricsHandler]:
        """Returns the class associated to this config class."""
        return MetricsHandler
class MetricsHandler(HasConfig[MetricsHandlerConfig]):
    """
    To handle all the :py:class:`~clinicadl.metrics.Metric` computed during an evaluation phase.

    ``MetricsHandler`` is itself a callable that works like :py:class:`monai.metrics.CumulativeIterationMetric`,
    with :py:meth:`reset` and :py:meth:`aggregate` methods. So, it can be used like a raw :py:class:`~clinicadl.metrics.Metric`
    object.

    The results are stored in DataFrames (:py:attr:`df` and :py:attr:`detailed_df`), that can be saved with
    :py:meth:`save`.

    Parameters
    ----------
    metrics_on_cpu : bool, default=True
        Whether to move the batch to CPU before computing the metrics. If ``False``,
        the batch is left on the device where it currently is.
    **metrics : MetricOrConfig
        Metrics to add to the ``MetricsHandler``. They must be passed as :py:class:`~clinicadl.metrics.Metric`
        or :py:mod:`configuration classes <clinicadl.metrics.config>`.

    Examples
    --------
    .. code-block::

        from clinicadl.metrics import MetricsHandler, config
        from clinicadl.data.structures.examples import Colin27DataPoint
        from clinicadl.data.dataloader import Batch
        import torch

        metrics = MetricsHandler(
            mse=config.MSEMetricConfig(label_key="ground_truth"),
            mae=config.MAEMetricConfig(label_key="ground_truth"),
        )
        metrics.init_metrics()

        torch.manual_seed(0)
        batch = Batch(
            [
                Colin27DataPoint(output=torch.randn(1, 10), ground_truth=torch.randn(1, 10)),
                Colin27DataPoint(output=torch.randn(1, 10), ground_truth=torch.randn(1, 10)),
            ]
        )

    .. code-block::

        >>> metrics(batch, epoch=1)
           epoch participant_id session_id       mse       mae
        0      1        sub-000   ses-M000  1.155020  0.871509
        1      1        sub-001   ses-M000  1.540214  0.835756
        >>> metrics.detailed_df
           epoch participant_id session_id       mse       mae
        0      1        sub-000   ses-M000  1.155020  0.871509
        1      1        sub-001   ses-M000  1.540214  0.835756
        >>> metrics.aggregate(epoch=1)
        >>> metrics.df
           epoch       mse       mae
        0      1  1.347617  0.853633
    """

    _config_type = MetricsHandlerConfig

    def __init__(
        self,
        metrics_on_cpu: bool = True,
        **metrics: MetricOrConfig,
    ):
        # ``**metrics`` is always a dict (possibly empty), so it can be
        # forwarded as is.
        self.config = MetricsHandlerConfig(
            metrics_on_cpu=metrics_on_cpu, metrics=metrics
        )
        # The metric objects only exist once ``init_metrics`` has been called.
        self._metrics = None
        self._model = None
        self._df = self._init_df()
        self._detailed_df = self._init_detailed_df()

    def init_metrics(self, model: Optional[Model] = None) -> None:
        """
        Instantiates the metrics from their configuration classes.

        Parameters
        ----------
        model : Optional[Model], default=None
            The model that contains the potential losses to compute
            on the validation set (defined in :py:meth:`Model.get_loss_functions <clinicadl.models.Model.get_loss_functions>`).
        """
        self._metrics = self.config.metrics.get_object(model=model)
        # Kept so that metrics added later can be instantiated with the same model.
        self._model = model

    @property
    def metrics(self) -> Optional[dict[str, Metric]]:
        """
        The metrics currently in the ``MetricsHandler``.
        If ``None``, it means that :py:meth:`init_metrics` must be called.
        """
        return self._metrics

    @property
    def df(self) -> pd.DataFrame:
        """
        The :py:class:`pandas.DataFrame` containing the aggregated results, i.e. the results on
        the whole dataset obtained by calling :py:meth:`aggregate`.
        """
        return self._df

    @property
    def detailed_df(self) -> pd.DataFrame:
        """
        The :py:class:`pandas.DataFrame` containing the detailed results,
        i.e. the results for each image.
        """
        return self._detailed_df

    def _init_df(self) -> pd.DataFrame:
        """
        Create an empty DataFrame with a column for each metric.
        """
        return pd.DataFrame(columns=self.config.metric_names)

    def _init_detailed_df(self) -> pd.DataFrame:
        """
        Create an empty DataFrame with a column for each metric,
        as well as columns "participant_id" and "session_id".
        """
        columns = [PARTICIPANT_ID, SESSION_ID] + self.config.metric_names
        return pd.DataFrame(columns=columns)

    @staticmethod
    def _with_columns(df: pd.DataFrame, names: Sequence[str]) -> pd.DataFrame:
        """
        Return ``df`` with a column appended for each of ``names`` that is
        missing; existing columns keep their order and their values.
        """
        missing = [name for name in names if name not in df.columns]
        return df.reindex(columns=[*df.columns, *missing], fill_value=pd.NA)

    @staticmethod
    def _append_results(
        current: pd.DataFrame, new_df: pd.DataFrame, with_epoch: bool
    ) -> pd.DataFrame:
        """
        Return ``current`` with the rows of ``new_df`` appended.

        If ``with_epoch`` is ``True``, the epoch column is moved to the first
        position and cast to ``int`` when possible.
        """
        if current.empty:
            combined = new_df.copy()
        else:
            combined = pd.concat([current, new_df], ignore_index=True)

        if with_epoch:
            combined.insert(0, EPOCH, combined.pop(EPOCH))  # ensure epoch first column
            try:
                # type may have been modified by concat
                combined = combined.astype({EPOCH: int})
            except pd.errors.IntCastingNaNError:
                # Deliberate best-effort: the column contains missing values
                # and cannot be cast to int, so its current dtype is kept.
                pass
        return combined

    def add_metrics(
        self,
        **metrics: MetricOrConfig,
    ) -> None:
        """
        Adds metrics to the ``MetricsHandler`` instance.

        .. warning::
            To be sure that all the metrics are computed on the
            same dataset, ``add_metrics`` will reset all the present
            metrics.

        Parameters
        ----------
        **metrics : MetricOrConfig
            Metrics to add to the ``MetricsHandler``. They must be passed as :py:class:`~clinicadl.metrics.Metric`
            or :py:mod:`configuration classes <clinicadl.metrics.config>`.

        Raises
        ------
        ValueError
            If a metric with the same name already exists.
        """
        self.config.add_metrics(metrics)
        self.reset(reset_df=False)
        if self.metrics is not None:
            self._metrics = self.config.metrics.get_object(model=self._model)

        # ``Index.join`` defaults to a left join, which returns the existing
        # columns unchanged; the new metric columns are appended explicitly.
        names = self.config.metric_names
        self._df = self._with_columns(self._df, names)
        self._detailed_df = self._with_columns(self._detailed_df, names)

    def remove_metrics(
        self,
        metrics: Union[str, Sequence[str]],
    ) -> None:
        """
        Removes metrics from the ``MetricsHandler`` instance.

        Parameters
        ----------
        metrics : Union[str, Sequence[str]]
            The name of the metric(s) to remove.

        Raises
        ------
        KeyError
            If a name does not match an existing metric.
        """
        if isinstance(metrics, str):
            metrics = [metrics]
        # A real list without duplicates: ``DataFrame.drop`` would interpret
        # a tuple as one single label.
        names = list(dict.fromkeys(metrics))

        self.config.remove_metrics(names)
        if self.metrics is not None:
            for name in names:
                self._metrics.pop(name)
        # The names were validated by the config; a column may still be absent
        # from a DataFrame, which must not leave the handler half-updated.
        self._df.drop(columns=names, inplace=True, errors="ignore")
        self._detailed_df.drop(columns=names, inplace=True, errors="ignore")

    def reset(self, reset_df: bool = False) -> None:
        """
        Reset all metric states.

        Parameters
        ----------
        reset_df : bool, default=False
            If ``True``, also reset the DataFrames containing the results.

        See Also
        --------
        :py:meth:`monai.metrics.Cumulative.reset`
        """
        if self.metrics is not None:
            for metric in self.metrics.values():
                metric.reset()
        if reset_df:
            self._df = self._init_df()
            self._detailed_df = self._init_detailed_df()

    def aggregate(
        self,
        epoch: Optional[int] = None,
    ) -> None:
        """
        Aggregates and stores metric results.

        Parameters
        ----------
        epoch : Optional[int], default=None
            Current epoch. This information will be added in the DataFrame.

        Raises
        ------
        RuntimeError
            If :py:meth:`init_metrics` has not been called.

        See Also
        --------
        :py:meth:`monai.metrics.Cumulative.aggregate`
        """
        if self.metrics is None:
            raise RuntimeError("First, call 'init_metrics' to instantiate the metrics.")

        values = {name: metric.aggregate() for name, metric in self.metrics.items()}
        new_df = pd.DataFrame(values, index=[0])
        if epoch is not None:
            new_df.insert(loc=0, column=EPOCH, value=epoch)

        self._df = self._append_results(
            self._df, new_df, with_epoch=epoch is not None
        )

    def __call__(
        self,
        batch: Batch,
        epoch: Optional[int] = None,
    ) -> pd.DataFrame:
        """
        Updates metrics with a new batch.

        Parameters
        ----------
        batch : Batch
            The batch, with the predictions, and the ground truths if required
            by some metrics.
        epoch : Optional[int], default=None
            Current epoch. This information will be added in the DataFrame.

        Returns
        -------
        pd.DataFrame
            The metrics for all the images in the batch.

        Raises
        ------
        RuntimeError
            If :py:meth:`init_metrics` has not been called.
        """
        if self.metrics is None:
            raise RuntimeError("First, call 'init_metrics' to instantiate the metrics.")

        if self.config.metrics_on_cpu:
            batch.to(device=CPU)

        values = {
            PARTICIPANT_ID: batch.get_field(PARTICIPANT_ID),
            SESSION_ID: batch.get_field(SESSION_ID),
        }
        values.update({name: metric(batch) for name, metric in self.metrics.items()})
        new_df = pd.DataFrame(values)
        if epoch is not None:
            new_df.insert(loc=0, column=EPOCH, value=epoch)

        self._detailed_df = self._append_results(
            self._detailed_df, new_df, with_epoch=epoch is not None
        )
        return new_df

    def get_metric_value(self, metric: str, epoch: Optional[int] = None) -> float:
        """
        To get the (aggregated) value of a metric.

        Parameters
        ----------
        metric : str
            The name of the metric.
        epoch : Optional[int], default=None
            The epoch for which the value is wanted. If ``None``, the method will
            return the last computed value.

        Returns
        -------
        float
            The value of the metric.

        Raises
        ------
        KeyError
            If ``metric`` is not one of the metrics of the handler.
        ValueError
            If the stored value cannot be converted to ``float``.
        """
        self.check_metric_name(metric)
        if epoch is not None:
            value = self.df.set_index(EPOCH).loc[epoch, metric]
        else:
            value = self.df.iloc[-1][metric]
        try:
            return float(value)
        except (TypeError, ValueError) as exc:
            raise ValueError(
                f"""Value for metric '{metric}' {f"at epoch {epoch} " if epoch is not None else ""}is not numeric."""
            ) from exc

    def get_metric_values(self, epoch: Optional[int] = None) -> pd.DataFrame:
        """
        To get the (aggregated) values of all the computed metrics.

        Parameters
        ----------
        epoch : Optional[int], default=None
            The epoch for which the values are wanted. If ``None``, the method will
            return the last computed values.

        Returns
        -------
        pd.DataFrame
            A :py:class:`pandas.DataFrame` containing the metric values.
        """
        if epoch is None:
            return self.df.iloc[-1:]
        return self.df[self.df[EPOCH] == epoch]

    def get_detailed_metric_values(self, epoch: int) -> pd.DataFrame:
        """
        To get the detailed values of all the computed metrics
        for a specific epoch.

        Parameters
        ----------
        epoch : int
            The epoch for which the values are wanted.

        Returns
        -------
        pd.DataFrame
            A :py:class:`pandas.DataFrame` containing the metric values.
        """
        return self.detailed_df[self.detailed_df[EPOCH] == epoch]

    def check_metric_name(self, metric: str) -> None:
        """
        Checks if a metric is in the computed metrics.

        Parameters
        ----------
        metric : str
            The name of the metric to check.

        Raises
        ------
        KeyError
            If ``metric`` is not one of the metrics of the handler.
        """
        if metric not in self.config.metric_names:
            raise KeyError(
                f"'{metric}' not found in the computed metrics! Metrics are: {self.config.metric_names}"
            )

    def save(self, path: Path, details_path: Optional[Path] = None) -> None:
        """
        Saves the DataFrames containing the results.

        Parameters
        ----------
        path : Path
            The path where to save :py:attr:`df`.
        details_path: Optional[Path], default=None
            The path where to save :py:attr:`detailed_df`.
            If ``None``, this DataFrame will not be saved.
        """
        self._df.to_csv(path, sep=TSV_SEP, index=False)
        if details_path:
            self._detailed_df.to_csv(details_path, sep=TSV_SEP, index=False)

    def merge(self, path: Path, details_path: Optional[Path] = None) -> None:
        """
        Merges the current DataFrame(s) with the one(s) in the file(s) and
        saves the result.

        The DataFrames held by the handler are not modified; only the files
        are rewritten.

        Parameters
        ----------
        path : Path
            The path to the DataFrame to merge with :py:attr:`df`.
        details_path: Optional[Path], default=None
            The path to the DataFrame to merge with :py:attr:`detailed_df`.
            If ``None``, the detailed results are not merged.
        """
        stored = pd.read_csv(path, sep=TSV_SEP)
        try:
            merged = pd.merge(stored, self._df, how="outer")
        except pd.errors.MergeError:
            # The merge could not be performed: fall back to putting the two
            # DataFrames side by side.
            merged = pd.concat([stored, self._df], axis=1)
        merged.to_csv(path, sep=TSV_SEP, index=False)

        if details_path:
            stored_details = pd.read_csv(details_path, sep=TSV_SEP)
            merged_details = pd.merge(stored_details, self._detailed_df, how="outer")
            merged_details.to_csv(details_path, sep=TSV_SEP, index=False)

    @staticmethod
    def _check_columns(
        df: pd.DataFrame, expected: set, path: Path, description: str
    ) -> None:
        """
        Raise ``AssertionError`` if ``df`` lacks some of the ``expected`` columns.
        """
        missing = expected.difference(df.columns)
        if missing:
            # Raised explicitly rather than with ``assert`` so that the check
            # is not stripped under ``python -O``; the exception type is kept.
            raise AssertionError(
                f"Checkpoint in {str(path)} is not a valid {description}, some columns are missing: {missing}"
            )

    def load(self, path: Path, details_path: Optional[Path] = None) -> None:
        """
        Loads checkpoint DataFrames saved with :py:meth:`save`.

        The metric states are reset, and the current DataFrames are replaced
        by the loaded ones. Nothing is modified if a file is not valid.

        Parameters
        ----------
        path : Path
            The path to the DataFrame with the aggregated results.
        details_path: Optional[Path], default=None
            The path to the DataFrame with the detailed results.
            If ``None``, this DataFrame will not be loaded.

        Raises
        ------
        AssertionError
            If a file lacks some of the expected columns.
        """
        expected_columns = set(self.config.metric_names)

        df = pd.read_csv(path, sep=TSV_SEP)
        self._check_columns(df, expected_columns, path, "metric file")

        detailed_df = None
        if details_path:
            detailed_df = pd.read_csv(details_path, sep=TSV_SEP)
            self._check_columns(
                detailed_df,
                expected_columns.union({PARTICIPANT_ID, SESSION_ID}),
                details_path,
                "metric details file",
            )

        # Both files are validated before any state is touched.
        self.reset(reset_df=True)
        self._df = df
        if detailed_df is not None:
            self._detailed_df = detailed_df

    def subset(self, metrics: Sequence[str]) -> MetricsHandler:
        """
        Creates a new ``MetricsHandler`` containing only a subset of
        the current metrics.

        Parameters
        ----------
        metrics : Sequence[str]
            The names of the metrics to keep in the new instance.

        Returns
        -------
        MetricsHandler
            The new instance. Its metrics are instantiated if and only if
            those of the current instance are.

        Raises
        ------
        KeyError
            If a name does not match an existing metric.
        """
        for metric in metrics:
            self.check_metric_name(metric)

        raw_metrics = self.config.to_raw_dict()[METRICS]
        kept = {name: raw for name, raw in raw_metrics.items() if name in metrics}
        new_metrics = MetricsHandler(
            **deepcopy(kept), metrics_on_cpu=self.config.metrics_on_cpu
        )
        # ``is not None`` rather than truthiness: an initialised handler with
        # no metric holds an empty dict and must still give an initialised subset.
        if self.metrics is not None:
            new_metrics.init_metrics(model=self._model)
        return new_metrics

    @classmethod
    def _from_config(cls: type[Self], config: MetricsHandlerConfig) -> Self:
        """To create the object from the associated config."""
        args = config.to_raw_dict()
        return cls(metrics_on_cpu=args["metrics_on_cpu"], **args[METRICS])