# Source code for clinicadl.data.datasets.tensor

from collections.abc import Sequence
from pathlib import Path
from typing import Optional

from pydantic import field_validator

from clinicadl.io.bids import Bids
from clinicadl.transforms import TransformsHandler
from clinicadl.utils.config import ObjectConfig
from clinicadl.utils.objects import HasConfig
from clinicadl.utils.typing import DataFrameType, PathType

from ..structures import Tensor
from ..tensors import TensorDescription
from .bids_utils import (
    BidsTensorDataset,
    BidsTypeDatasetConfig,
    BidsTypeDatasetWithConfig,
    ColumnsType,
)


class TensorDatasetConfig(ObjectConfig["TensorDataset"], BidsTypeDatasetConfig):
    """
    Configuration model validating the arguments given to ``TensorDataset``.

    On top of the fields inherited from ``BidsTypeDatasetConfig``, it declares
    the two fields that are specific to ``TensorDataset``.
    """

    # Path to the descriptive ``.json`` file; resolved by ``_resolve_path``.
    description_json: Path
    # Names of the data to load; the annotation also accepts ``None``.
    to_load: Optional[Sequence[str]]

    @classmethod
    def _get_class(cls):
        """Return the class this configuration is attached to (``TensorDataset``)."""
        return TensorDataset

    @field_validator("description_json", mode="after")
    @classmethod
    def _resolve_path(cls, path: Path) -> Path:
        """Return ``description_json`` after calling :py:meth:`pathlib.Path.resolve`."""
        resolved = path.resolve()
        return resolved


class TensorDataset(
    BidsTensorDataset, HasConfig[TensorDatasetConfig], BidsTypeDatasetWithConfig
):
    """
    A :py:class:`~clinicadl.data.datasets.Dataset` to read images saved as tensors
    in ``.pt`` files.

    This dataset enables to load tensors saved with
    :py:meth:`BidsDataset.to_tensors <clinicadl.data.datasets.BidsDataset.to_tensors>`.

    Parameters
    ----------
    description_json : PathType
        The path to the ``.json`` file saved by
        :py:meth:`BidsDataset.to_tensors <clinicadl.data.datasets.BidsDataset.to_tensors>`
        that describes the conversion.
    data : Optional[DataFrameType], default=None
        A :py:class:`pandas.DataFrame` (or a path to a ``TSV`` file containing the
        DataFrame) with the list of (participant, session) pairs to consider, as well
        as any other relevant information (e.g. the age of the participants).

        Only (participant, session) pairs mentioned in this TSV file will be in the
        ``TensorDataset``.

        If ``None``, all (participant, session) pairs whose images have been converted
        during this conversion will be considered.

        .. warning::
            Be careful if you pass a DataFrame with a column named ``"n_samples"``.
            ``BidsDataset`` will understand it as the number of samples for each
            (participant, session) pair.

    transforms : TransformsHandler, default=TransformsHandler()
        Transformation pipeline to apply to the data after loading. The user also
        specifies here whether to work on images, patches, or slices.
        See :py:class:`clinicadl.transforms.TransformsHandler`.

        .. warning::
            If transformed images were saved in the ``.pt`` files, make sure that you
            don't apply these transforms again here (``image_transforms`` should
            probably be empty in the ``TransformsHandler`` here).

    columns : Optional[ColumnsType], default=None
        Columns to get in the DataFrame ``data`` and to put in the output
        :py:class:`~clinicadl.data.structures.Sample`. Can be passed via:

        - a list of strings (e.g. ``["age", "sex"]``), corresponding to the names of
          the columns;
        - or a dictionary (e.g. ``{"age": <function>, "sex": None}``), where the keys
          are the names of the columns, and the values are functions to apply to the
          columns. If the function is ``None``, no function will be applied to the
          column.

        .. note::
            The potential functions applied to the columns are applied to the
            **whole column**. They must take as input a :py:class:`pandas.Series`,
            and return a :py:class:`pandas.Series`. For example, it is useful to
            convert string labels to integer labels for classification.

    to_load : Optional[Sequence[str]], default=None
        The data to load from the ``.pt`` files. Data saved in these files are
        described in the descriptive ``.json`` file.
        If ``None``, everything inside the files will be loaded.

    Examples
    --------
    .. code-block:: text

        bids
        ├── dataset_description.json
        ├── metadata.tsv
        ...
        └── derivatives
            └── tensors
                ├── dataset_description.json
                ├── conversions.tsv
                ├── src-T1w_conv-T1WithMasks_description.json
                ├── src-T1w_conv-T1WithMasks_participantsXsessions.tsv
                ├── sub-001
                │   ├── ses-M000
                │   │   └── anat
                │   │       ├── sub-001_ses-M000_src-T1w_conv-T1WithMasks_tensors.json
                │   │       └── sub-001_ses-M000_src-T1w_conv-T1WithMasks_tensors.pt <- contains the image + 2 masks named 'head' and 'mni'
                │   ...
                ...

        The "metadata.tsv" file looks like:

        participant_id   session_id   age    sex   diagnosis
        sub-001          ses-M000     55.0   M     control
        sub-001          ses-M024     57.0   M     control
        sub-002          ses-M000     62.0   F     control
        sub-002          ses-M024     64.0   F     patient
        sub-003          ses-M000     67.0   F     patient
        ...

    .. code-block::

        from clinicadl.data.datasets import TensorDataset
        from clinicadl.transforms import TransformsHandler, extraction
        import pandas as pd

        # to convert diagnosis to numeric values
        def diagnosis_to_number(column: pd.Series) -> pd.Series:
            encoding = {"control": 0, "patient": 1}
            return column.apply(lambda x: encoding[x])

    .. code-block::

        >>> dataset = TensorDataset(
                description_json="bids/derivatives/tensors/src-T1w_conv-T1WithMasks_description.json",
                data="bids/metadata.tsv",
                columns=["age"],
            )
        >>> dataset[0]
        Sample(Keys: ('head', 'mni', 'age', 'file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 3)
        >>> dataset[0].spatial_shape
        (169, 208, 179)

    .. code-block::

        >>> dataset = TensorDataset(
                description_json=bids / "derivatives" / "tensors" / "res-1d3x1d2x1d1_src-T1w_conv-T1Masks_description.json",
                transforms=TransformsHandler(
                    extraction=extraction.Patch(patch_size=2),
                ),
                to_load=["head"],
            )
        >>> dataset[0]
        Sample(Keys: ('head', 'file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 3)
        >>> dataset[0].spatial_shape
        (64, 64, 64)

    See Also
    --------
    :py:class:`clinicadl.data.datasets.BidsDataset`
    """

    _config_type = TensorDatasetConfig

    # NOTE(review): the ``TransformsHandler()`` default is evaluated once, when the
    # class is defined, so every call relying on the default receives the same
    # instance. Confirm that it is never mutated in place downstream.
    def __init__(
        self,
        description_json: PathType,
        data: Optional[DataFrameType] = None,
        transforms: TransformsHandler = TransformsHandler(),
        columns: Optional[ColumnsType] = None,
        to_load: Optional[Sequence[str]] = None,
    ):
        # All raw arguments go through the config class first; only the values
        # stored on ``self.config`` are used afterwards.
        self.config = self._config_type(
            description_json=description_json,
            data=data,
            transforms=transforms,
            columns=columns,
            to_load=to_load,
        )
        # The description file provides the tensor type of this conversion.
        tensors = TensorDescription.read(self.config.description_json)
        super().__init__(
            tensor=Tensor(
                # The directory containing the description file is handed to ``Bids``.
                Bids(self.config.description_json.parent),
                tensors.tensor_type,
                to_load=self.config.to_load,
            ),
            data=self.config.data,
            transforms=self.config.transforms,
            columns=self.config.columns,
        )