"""
Other functions to perform various utility tasks on data.
"""
from __future__ import annotations
from dataclasses import fields
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Optional
import numpy as np
import pandas as pd
from clinicadl.io.bids import Bids
from clinicadl.utils.dictionary.utils import TSV_SEP
from clinicadl.utils.enum import BaseEnum
from clinicadl.utils.exceptions import add_note
from clinicadl.utils.typing import PathType
from clinicadl.utils.variables import SPACING_RTOL
if TYPE_CHECKING:
from .datasets import Dataset
from .structures import DataPoint, Sample
[docs]
def remove_tensors(description_json: PathType) -> None:
    """
    To delete tensors in a dataset.

    It will remove all the tensors saved with :py:meth:`BidsDataset.to_tensors <clinicadl.data.datasets.BidsDataset.to_tensors>`
    associated with the input ``.json`` file. The ``.json`` file itself and the ``.tsv`` file listing
    the converted (participant, session) pairs are deleted too, and the conversion is removed from the
    conversions ``.tsv`` file stored in the same directory.

    Parameters
    ----------
    description_json : PathType
        Path to the ``.json`` file associated with the tensor conversion you want to delete.

    Examples
    --------
    .. code-block::

        from clinicadl.data.datasets import BidsDataset
        from clinicadl.data.utils import remove_tensors
        from clinicadl.io.bids import BidsFileType
        from pathlib import Path

        dataset = BidsDataset(
            bids="bids_path", file_type=BidsFileType(data_type="anat", suffix="T1w")
        )
        dataset.to_tensors()  # the json file is "bids_path/derivatives/tensors/src-T1w_conv-raw_description.json"

    .. code-block::

        >>> remove_tensors("bids_path/derivatives/tensors/src-T1w_conv-raw_description.json")
        >>> Path("bids_path/derivatives/tensors/src-T1w_conv-raw_description.json").is_file()
        False
        >>> Path("bids_path/derivatives/tensors/src-T1w_conv-raw_participantsXsessions.tsv").is_file()
        False
        >>> Path("bids_path/derivatives/tensors/sub-000/ses-M000/sub-000_ses-M000_src-T1w_conv-raw_tensors.pt").is_file()
        False

    See Also
    --------
    :py:meth:`clinicadl.data.datasets.BidsDataset.to_tensors`
    """
    # Local import, as in the original implementation (kept out of module scope).
    from .tensors.utils import ConversionRow, TensorDescription

    json_path = Path(description_json)
    tensors_bids = Bids(json_path.parent)
    description = TensorDescription.read(json_path)

    # Any file under the tensors directory whose name contains the stem built
    # for this conversion's tensor type is treated as one of its outputs.
    stem = tensors_bids.build_path(description.tensor_type).stem
    for tensor_file in tensors_bids.path.rglob(f"*{stem}*"):
        tensor_file.unlink()

    # Then delete the conversion's own .tsv file and its description .json.
    description.get_tsv_path(tensors_bids.path).unlink()
    json_path.unlink()

    # Finally, drop the row referencing this .json from the conversions .tsv file.
    conversions_tsv = description.get_conversions_tsv_path(tensors_bids.path)
    conversions = pd.read_csv(conversions_tsv, sep=TSV_SEP)
    json_column = next(
        field.name for field in fields(ConversionRow) if "json" in field.name
    )
    conversions = conversions[conversions[json_column] != json_path.name]
    conversions.to_csv(conversions_tsv, sep=TSV_SEP, index=False)
class SpatialCheck(str, BaseEnum):
    """
    Possible spatial checks performed to check the consistency of images and
    masks within a sample and across samples.
    """

    # Voxel spacing must be consistent within a sample. The dedicated
    # intra-sample check is skipped when GLOBAL_SPACING is also requested,
    # because the global check reads the spacing the same way first.
    SPACING = "spacing"
    # Affine matrix must be consistent within a sample.
    AFFINE = "affine"
    # Spatial shape must be consistent within a sample. The dedicated
    # intra-sample check is skipped when GLOBAL_SHAPE is also requested.
    SHAPE = "shape"
    # Voxel spacing must match a reference sample's (relative tolerance
    # SPACING_RTOL) across the whole dataset.
    GLOBAL_SPACING = "global_spacing"
    # Spatial shape must be identical to a reference sample's across the
    # whole dataset.
    GLOBAL_SHAPE = "global_shape"
# String values of SpatialCheck.AFFINE, SpatialCheck.SHAPE and
# SpatialCheck.GLOBAL_SPACING. Presumably used as the default value of
# ``spatial_checks`` elsewhere (not visible here).
DEFAULT_SPATIAL_CHECKS = ("affine", "shape", "global_spacing")
class DatasetChecker:
    """
    To perform spatial checks on a :py:class:`clinicadl.data.datasets.Dataset`.

    Parameters
    ----------
    spatial_checks : Optional[Iterable[str | SpatialCheck]]
        The spatial checks to perform, given as :py:class:`SpatialCheck` members
        or their string values. A single check may also be passed on its own
        (e.g. ``"affine"``). ``None`` or an empty iterable disables all spatial
        checks.

    Attributes
    ----------
    spatial_checks : list[SpatialCheck]
        The validated checks.
    ref_sample : Optional[DataPoint]
        The first data point checked since the last :py:meth:`reset`. It is the
        reference for the cross-sample (``global_*``) checks.
    enabled : bool
        If ``False``, :py:meth:`check_data_point` does nothing. It is set back to
        ``True`` by :py:meth:`reset`.
    """

    def __init__(
        self,
        spatial_checks: Optional[Iterable[str | SpatialCheck]],
    ):
        if isinstance(spatial_checks, str):
            # A lone string (this also covers a single SpatialCheck member, as
            # SpatialCheck subclasses str) would otherwise be iterated character
            # by character, e.g. SpatialCheck("a") for "affine".
            spatial_checks = [spatial_checks]
        self.spatial_checks: list[SpatialCheck] = (
            [SpatialCheck(check) for check in spatial_checks] if spatial_checks else []
        )
        self.ref_sample: Optional[DataPoint] = None
        self.enabled = True

    def reset(self):
        """
        Resets the running statistics tracked on the dataset: forgets the
        reference sample and re-enables the checks.
        """
        self.ref_sample = None
        self.enabled = True

    def check(self, dataset: Dataset[Sample]) -> None:
        """
        Performs the checks on the input dataset.

        It iterates over all the samples to see if they are loaded correctly,
        and performs spatial checks on the samples, depending on the value of ``self.spatial_checks``.
        The checker is reset first, so the first sample of ``dataset`` becomes the reference.
        """
        self.reset()
        for sample in dataset:
            self.check_data_point(sample)

    def check_data_point(self, data_point: DataPoint) -> None:
        """
        Checks spacing, affine matrix, and/or image shape consistency in a sample.

        Also compares to a reference sample to check consistency across samples.
        The first data point checked after a reset becomes that reference.

        Raises
        ------
        RuntimeError
            If one of the requested checks fails.
        """
        if not self.enabled:
            return

        # The intra-sample spacing/shape checks are skipped when the matching
        # global check is requested: _check_dataset_consistency reads the
        # attribute through _check_intra_sample_consistency anyway.
        if (
            SpatialCheck.SPACING in self.spatial_checks
            and SpatialCheck.GLOBAL_SPACING not in self.spatial_checks
        ):
            _check_intra_sample_consistency(
                data_point, attr="spacing", desc="voxel spacing"
            )
        if SpatialCheck.AFFINE in self.spatial_checks:
            _check_intra_sample_consistency(
                data_point, attr="affine", desc="affine matrix"
            )
        if (
            SpatialCheck.SHAPE in self.spatial_checks
            and SpatialCheck.GLOBAL_SHAPE not in self.spatial_checks
        ):
            _check_intra_sample_consistency(
                data_point, attr="spatial_shape", desc="spatial shape"
            )

        if self.ref_sample is None:
            self.ref_sample = data_point

        if SpatialCheck.GLOBAL_SPACING in self.spatial_checks:
            _check_dataset_consistency(
                data_point,
                self.ref_sample,
                tolerance=SPACING_RTOL,
                attr="spacing",
                desc="voxel spacing",
            )
        if SpatialCheck.GLOBAL_SHAPE in self.spatial_checks:
            # Shapes are integers: zero relative tolerance means exact equality.
            _check_dataset_consistency(
                data_point,
                self.ref_sample,
                tolerance=0,
                attr="spatial_shape",
                desc="spatial shape",
            )
def _check_intra_sample_consistency(data_point: DataPoint, attr: str, desc: str) -> Any:
    """
    Checks that an attribute (e.g., voxel spacing) is consistent in a sample.

    The check consists in reading ``attr`` on ``data_point``. The attribute is
    presumably implemented so that it raises a ``RuntimeError`` when the images
    and masks of the sample disagree (TODO confirm in ``DataPoint``).

    Parameters
    ----------
    data_point : DataPoint
        The data point to check.
    attr : str
        Name of the attribute to read (e.g. ``"spacing"``).
    desc : str
        Human-readable description of the attribute, used in the error note.

    Returns
    -------
    Any
        The value of the attribute.

    Raises
    ------
    RuntimeError
        Re-raised from the attribute access, with a note naming the
        (participant, session) pair and how to disable the check.
        Other exceptions propagate unchanged.
    """
    try:
        value = getattr(data_point, attr)
    except RuntimeError as error:
        hint = (
            f"\nAn error occurred when checking ({data_point.participant_id}, {data_point.session_id}) (see above). "
            f"If you don't care about {desc} consistency and want to ignore this error, please modify 'spatial_checks'."
        )
        add_note(error, hint)
        raise
    return value
def _check_dataset_consistency(
    data_point: DataPoint,
    ref_data_point: Optional[DataPoint],
    tolerance: float,
    attr: str,
    desc: str,
) -> None:
    """
    Checks that a sample attribute (e.g., voxel spacing) is consistent
    with a reference sample.

    The attribute is first read on ``data_point`` through
    ``_check_intra_sample_consistency``, so intra-sample errors get an
    explanatory note. It is then compared element-wise to the same attribute
    of ``ref_data_point``.

    Parameters
    ----------
    data_point : DataPoint
        The data point to check.
    ref_data_point : Optional[DataPoint]
        The reference data point. If ``None``, there is nothing to compare to,
        and only the intra-sample check is performed.
    tolerance : float
        Relative tolerance, passed as ``rtol`` to :py:func:`numpy.isclose`
        (``atol`` keeps NumPy's default).
    attr : str
        Name of the attribute to compare (e.g. ``"spacing"``).
    desc : str
        Human-readable description of the attribute, used in error messages.

    Raises
    ------
    RuntimeError
        If the two values do not have the same shape, or are not close.
    """
    attr_value = _check_intra_sample_consistency(data_point, attr, desc)
    if ref_data_point is None:
        return

    ref_attr_value = getattr(ref_data_point, attr)
    values = np.asarray(attr_value)
    ref_values = np.asarray(ref_attr_value)
    # Compare shapes explicitly. np.isclose broadcasts its inputs, so values of
    # different lengths would either raise an unrelated ValueError or be wrongly
    # accepted (e.g. a length-1 value broadcast against a 3-element one).
    consistent = values.shape == ref_values.shape and bool(
        np.isclose(values, ref_values, rtol=tolerance).all()
    )
    if not consistent:
        raise RuntimeError(
            f"Different {desc} found in the dataset: "
            f"for example, {desc} is {attr_value} for ({data_point.participant_id}, {data_point.session_id}), "
            f"but {ref_attr_value} for ({ref_data_point.participant_id}, {ref_data_point.session_id}).\n"
            f"If you don't care about {desc} consistency and want to ignore this error, please modify 'spatial_checks'."
        )