from __future__ import annotations
from collections.abc import Sequence
from logging import getLogger
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
import numpy as np
import pandas as pd
import torch
from pydantic import (
NonNegativeInt,
PositiveInt,
field_validator,
model_validator,
)
from typing_extensions import Self
from clinicadl.utils.config import ObjectConfig
from clinicadl.utils.dictionary.words import (
SAMPLE_POSITION,
SAMPLE_TYPE,
SLICE_DIRECTION,
SQUEEZE,
)
from clinicadl.utils.enum import SliceDirection
from clinicadl.utils.exceptions import DataFrameError
from clinicadl.utils.typing import PathType
from .base import Extraction, ImplementedExtraction
if TYPE_CHECKING:
from clinicadl.data.structures import DataPoint
# Logger dedicated to this module (named after the module's import path).
logger = getLogger(name=__name__)
class SliceConfig(ObjectConfig["Slice"]):
    """
    Configuration object holding the parameters of a :py:class:`Slice` extraction.

    On top of the field types, the validators below ensure that:

    - ``tsv_path`` is stored as an absolute path;
    - a single value passed as ``borders`` is duplicated into a ``(first, last)`` pair;
    - ``slices`` and ``tsv_path`` are mutually exclusive, and neither of them can be
      combined with ``discarded_slices`` or ``borders``.
    """

    # indices of the slices to keep, shared by every image
    slices: Optional[list[NonNegativeInt]]
    # TSV file listing the slices to keep for each (participant, session) pair
    tsv_path: Optional[Path]
    # indices of the slices to drop
    discarded_slices: Optional[list[NonNegativeInt]]
    # number of slices to drop at the beginning and at the end of the slicing axis
    borders: Optional[Tuple[PositiveInt, PositiveInt]]
    # axis along which the image is sliced
    slice_direction: SliceDirection
    # whether samples will later be squeezed to 2 spatial dimensions
    squeeze: bool

    @field_validator("tsv_path", mode="after")
    @classmethod
    def _resolve_tsv_path(cls, v: Optional[Path]) -> Optional[Path]:
        """Turns 'tsv_path' into an absolute path; None is left untouched."""
        if v is None:
            return v
        return v.resolve()

    @field_validator("borders", mode="before")
    @classmethod
    def _ensure_tuple(
        cls,
        value: Optional[Union[PositiveInt, Tuple[PositiveInt, PositiveInt]]],
    ) -> Optional[Tuple[PositiveInt, PositiveInt]]:
        """
        Duplicates a scalar 'borders' into a '(value, value)' pair.

        None and sequences are passed through unchanged; pydantic then validates them
        against 'Tuple[PositiveInt, PositiveInt]'.
        """
        if value is None or isinstance(value, Sequence):
            return value
        return (value, value)

    @model_validator(mode="after")
    def validate_slices(self) -> Self:
        """
        Checks that the slice-selection parameters are not combined in a forbidden way.

        Raises
        ------
        ValueError
            If 'slices' and 'tsv_path' are both set, or if either of them is set
            together with 'discarded_slices' or 'borders'.
        """
        if self.slices and self.tsv_path:
            raise ValueError("'slices' and 'tsv_path' can't be passed simultaneously.")

        # without an explicit selection, any combination of the filters is allowed
        if not (self.slices or self.tsv_path):
            return self

        if self.discarded_slices:
            raise ValueError(
                "You can't pass 'discarded_slices' if 'slices' or 'tsv_path' was passed."
            )
        if self.borders:
            raise ValueError(
                "You can't pass 'borders' if 'slices' or 'tsv_path' was passed."
            )
        return self

    @classmethod
    def _get_class(cls) -> type[Slice]:
        """Returns :py:class:`Slice`, the transform class built from this config."""
        return Slice
class Slice(Extraction[SliceConfig]):
    """
    Transform class to extract slices from an image in a specified direction.

    Adds the following keys to the input :py:class:`~clinicadl.data.structures.DataPoint`:

    - ``sample_type`` : ``"slice"``
    - ``sample_position``: int
        The position of the slice in the original image.
    - ``slice_direction``: 0, 1 or 2
        The slicing direction.
    - ``squeeze``: bool
        Whether the tensors will be squeezed to work with 2D neural networks.

    .. note::
        To select slices, use ``slices``, ``tsv_path``, ``discarded_slices``, or
        ``borders``.

        - If none of these parameters is passed, all slices will be kept.
        - ``slices`` and ``tsv_path`` cannot be used in conjunction with another slice selection
          parameter, but ``discarded_slices`` and ``borders`` can be passed together.

    Parameters
    ----------
    slices : Optional[list[int]], default=None
        The slices to select. The slices selected will be the same for all images; if you
        want a different selection for each image, use ``tsv_path``. Duplicated indices are
        kept only once, and selected slices are always returned in increasing order.
    tsv_path : Optional[PathType], default=None
        Path to a ``TSV`` file containing slice indices for each image.
        The ``TSV`` table must have the columns: ``participant_id``, ``session_id``, and ``slice_idx``
        (column names are matched case-insensitively). ``slice_idx`` must contain non-negative
        integers, and every image passed to the transform must have at least one row in the table.
    discarded_slices : Optional[list[int]], default=None
        Indices of the slices to discard. Cannot be used with ``slices`` or ``tsv_path``.
    borders : Optional[Union[int, Tuple[int, int]]], default=None
        The number of border slices that will be filtered out. If an integer ``a`` is passed, the first
        ``a`` slices and the last ``a`` slices will be filtered out. If a tuple ``(a, b)`` is passed, the first
        ``a`` slices and the last ``b`` slices will be filtered out. Borders larger than the image
        filter out every slice.\n
        Cannot be used with ``slices`` or ``tsv_path``.
    slice_direction : int | SliceDirection, default=0
        The slicing direction. Can be ``0``, ``1`` or ``2``.

        .. warning::
            Be careful with the orientation of your image. If your image is in :term:`RAS+` (e.g. you used :py:class:`~clinicadl.transforms.config.ToCanonicalConfig`),
            ``0`` refers to the sagittal direction, ``1`` to the coronal direction, and ``2`` to the axial direction.

    squeeze : bool, default=True
        Whether to later squeeze slices to have images with 2 spatial dimensions.
        If ``False``, slices will still have 3 spatial dimensions.

        .. note::
            Squeezing will be performed by ``ClinicaDL`` just before putting the images in the neural
            network. This is because most of ``ClinicaDL`` tools work with 3D images.

    Raises
    ------
    DataFrameError
        If the ``TSV`` file passed as ``tsv_path`` lacks one of the required columns.
    ValueError
        If column ``slice_idx`` of the ``TSV`` file contains missing, non-integer or negative values.

    Examples
    --------
    .. code-block::

        from clinicadl.transforms.extraction import Slice
        from clinicadl.data.structures.examples import Colin27DataPoint

        data = Colin27DataPoint()
        slices = Slice(borders=10, slice_direction=1)

    .. code-block::

        >>> data.spatial_shape
        (181, 217, 181)
        >>> slices.num_samples_per_image(data)
        197
        >>> slices(data, sample_index=0).spatial_shape
        (181, 1, 181)
        >>> next(iter(slices(data))).sample_position
        10  # because of 'borders' the 10 first slices are ignored
    """

    config: SliceConfig
    _config_type = SliceConfig

    def __init__(
        self,
        slices: Optional[list[int]] = None,
        tsv_path: Optional[PathType] = None,
        discarded_slices: Optional[list[int]] = None,
        borders: Optional[Union[int, Tuple[int, int]]] = None,
        slice_direction: int | SliceDirection = SliceDirection.ZERO,
        squeeze: bool = True,
    ) -> None:
        self.config = SliceConfig(
            slices=slices,
            tsv_path=tsv_path,
            discarded_slices=discarded_slices,
            borders=borders,
            slice_direction=slice_direction,
            squeeze=squeeze,
        )
        # (participant_id, session_id) -> selected slices; only set when a TSV is given
        self._map: Optional[Dict[Tuple[str, str], list[int]]] = None
        if self.config.tsv_path is not None:
            self._map = self._load_tsv(self.config.tsv_path)

    @property
    def sample_type(self) -> str:
        """
        The type of the sample returned by this extraction, among {"image", "slice", "patch"}.
        """
        return ImplementedExtraction.SLICE.value.lower()

    def _extract_tensor_sample(
        self, image_tensor: torch.Tensor, sample_position: int
    ) -> torch.Tensor:
        """
        Gets the wanted slice, according to the slicing direction.

        The sliced dimension is kept with size 1, e.g. a tensor of shape
        ``(C, X, Y, Z)`` sliced in direction 1 gives a tensor of shape ``(C, X, 1, Z)``.
        """
        # The first tensor dimension is not spatial, hence the '+ 1'.
        # 'narrow' with length 1 is equivalent to indexing then 'unsqueeze' on the
        # same dimension, without an if/elif chain per direction.
        return image_tensor.narrow(
            self.config.slice_direction + 1, int(sample_position), 1
        )

    def _get_sample_positions(self, data_point: DataPoint) -> list[int]:
        """
        Returns the positions of the selected slices in the image, in increasing order.

        Raises
        ------
        IndexError
            If a selected or discarded slice is outside ``[0, n_slices - 1]``.
        ValueError
            If ``tsv_path`` was passed but the image has no entry in the TSV file.
        """
        n_slices = data_point.image.tensor.size(self.config.slice_direction + 1)

        # Explicit selection, either per image (TSV) or shared ('slices').
        # 'is not None' on purpose: an empty TSV must not silently keep all slices.
        if self._map is not None:
            slice_indices = self._slices_for(data_point)
            self._check_slice_indices(
                slice_indices,
                n_slices,
                where=(
                    f"'tsv_path' for participant={data_point.participant_id}, "
                    f"session={data_point.session_id}"
                ),
                name="slice_idx",
            )
            return sorted(set(slice_indices))
        if self.config.slices:
            self._check_slice_indices(
                self.config.slices, n_slices, where="'slices'", name="slices"
            )
            return sorted(set(self.config.slices))

        # No explicit selection: start from all slices and filter some out.
        selection = np.ones(n_slices, dtype=bool)
        if self.config.discarded_slices:
            self._check_slice_indices(
                self.config.discarded_slices,
                n_slices,
                where="'discarded_slices'",
                name="discarded_slices",
            )
            selection[self.config.discarded_slices] = False
        if self.config.borders:
            first, last = self.config.borders
            selection[:first] = False
            # Clamp to 0: if 'last' exceeds n_slices, a negative start would only
            # clear a few trailing slices instead of all of them.
            selection[max(n_slices - last, 0) :] = False
        return np.flatnonzero(selection).tolist()

    @staticmethod
    def _check_slice_indices(
        indices: list[int], n_slices: int, where: str, name: str
    ) -> None:
        """Raises an IndexError if any index falls outside [0, n_slices - 1]."""
        # explicit check: numpy would silently wrap negative indices around
        if any(not 0 <= idx < n_slices for idx in indices):
            raise IndexError(
                f"Invalid slices in {where}: "
                f"slices in the image are indexed from 0 to {n_slices - 1}, but got "
                f"{name}={list(indices)}."
            )

    def _add_info(self, data_point: DataPoint, sample_position: int) -> None:
        """
        Adds relevant info in the datapoint.
        """
        data_point[SAMPLE_TYPE] = self.sample_type
        data_point[SAMPLE_POSITION] = sample_position
        data_point[SLICE_DIRECTION] = self.config.slice_direction
        data_point[SQUEEZE] = self.config.squeeze

    @classmethod
    def _load_tsv(cls, path: Path) -> Dict[Tuple[str, str], list[int]]:
        """
        Reads the TSV file and returns the mapping as a dict. The keys
        are the (participant, session) pairs and the values are the list
        of selected slices.

        Raises
        ------
        DataFrameError
            If a required column is missing.
        ValueError
            If 'slice_idx' contains missing, non-integer or negative values.
        """
        # Read everything as text so that identifiers are kept verbatim
        # (e.g. "001" is not turned into the integer 1, which would break the lookup).
        df = pd.read_csv(path, sep="\t", dtype=str)
        df = cls._normalize_cols(df)

        # non-numeric entries become NaN; '% 1' rejects floats such as 3.5
        # instead of silently truncating them
        slice_idx = pd.to_numeric(df["slice_idx"], errors="coerce")
        if slice_idx.isna().any() or (slice_idx % 1 != 0).any():
            raise ValueError("Column 'slice_idx' must contain integers.")
        if (slice_idx < 0).any():
            raise ValueError("Column 'slice_idx' must contain non-negative integers.")
        df["slice_idx"] = slice_idx.astype(int)

        mapping: Dict[Tuple[str, str], list[int]] = {}
        for (sub, ses), group in df.groupby(["participant_id", "session_id"]):
            mapping[(str(sub), str(ses))] = [int(idx) for idx in group["slice_idx"]]
        return mapping

    @classmethod
    def _normalize_cols(cls, df: pd.DataFrame) -> pd.DataFrame:
        """
        Matches the column names case-insensitively (ignoring surrounding whitespace)
        against 'participant_id', 'session_id' and 'slice_idx', and renames them to
        these canonical names.

        Raises
        ------
        DataFrameError
            If one of the three columns is missing.
        """
        cols = {str(c).strip().lower(): c for c in df.columns}
        subj_col = cols.get("participant_id")
        sess_col = cols.get("session_id")
        slice_col = cols.get("slice_idx")
        if subj_col is None or sess_col is None or slice_col is None:
            raise DataFrameError(
                "TSV must contain columns: 'participant_id', 'session_id', 'slice_idx'"
            )
        return df.rename(
            columns={
                subj_col: "participant_id",
                sess_col: "session_id",
                slice_col: "slice_idx",
            }
        )

    def _slices_for(self, data_point: DataPoint) -> list[int]:
        """
        Gets the slice selection for a specific image.

        Raises
        ------
        RuntimeError
            If no TSV was provided.
        ValueError
            If the (participant, session) pair of the image is not in the TSV.
        """
        if self._map is None:
            raise RuntimeError("Called _slices_for but no TSV was provided.")
        key = (data_point.participant_id, data_point.session_id)
        if key not in self._map:
            raise ValueError(
                f"No slices found in TSV for participant={key[0]}, session={key[1]}."
            )
        return self._map[key]