# Source code for clinicadl.data.datasets.bids

from collections.abc import Sequence
from copy import copy, deepcopy
from pathlib import Path
from typing import Any, Iterable, Optional, TypeAlias

import torchio as tio
from pydantic import Field, field_validator

from clinicadl.io.bids import Bids, BidsFileType
from clinicadl.transforms import TransformsHandler
from clinicadl.utils.config import ObjectConfig
from clinicadl.utils.objects import HasConfig
from clinicadl.utils.typing import DataFrameType, PathType

from ..structures import (
    CommonMask,
    Image,
    IndividualMask,
)
from ..tensors import TensorConversion
from ..utils import DEFAULT_SPATIAL_CHECKS, SpatialCheck
from .bids_utils import (
    BidsNiftiDataset,
    BidsTypeDatasetConfig,
    BidsTypeDatasetWithConfig,
    ColumnsType,
)
from .tensor import TensorDataset

# Type of the ``masks`` argument of ``BidsDataset``: a dictionary mapping the
# name given to each mask to one of:
#   - a path to a NIfTI image (same mask for all (participant, session) pairs);
#   - a ``BidsFileType`` (participant- and session-specific mask, looked for in
#     the same BIDS as the image);
#   - a tuple (path or ``Bids``, ``BidsFileType``) (participant- and
#     session-specific mask, looked for in another BIDS).
MasksType: TypeAlias = dict[
    str, PathType | BidsFileType | tuple[PathType | Bids, BidsFileType]
]


def _deserialize_masks(serialized_masks: Optional[dict]) -> Optional[MasksType]:
    """
    To read serialized masks.
    """
    if serialized_masks is None:
        return None

    masks = dict()
    for name, mask in serialized_masks.items():
        if isinstance(mask, dict):
            masks[name] = BidsFileType.from_dict(mask)
        elif isinstance(mask, Sequence) and not isinstance(mask, str):
            masks[name] = (Bids.from_dict(mask[0]), BidsFileType.from_dict(mask[1]))
        else:
            masks[name] = mask

    return masks


class BidsDatasetConfig(ObjectConfig["BidsDataset"], BidsTypeDatasetConfig):
    """
    Config class to check ``BidsDataset`` inputs.

    On top of the fields inherited from ``BidsTypeDatasetConfig``, it declares
    ``bids``, ``file_type`` and ``masks``. Each of these fields carries a
    ``"reader"`` entry in its JSON schema extras, naming the function used to
    rebuild the field from its serialized form.
    """

    bids: Bids = Field(json_schema_extra={"reader": Bids.from_dict})
    file_type: BidsFileType = Field(
        json_schema_extra={"reader": BidsFileType.from_dict}
    )
    masks: Optional[dict[str, Path | BidsFileType | tuple[Bids, BidsFileType]]] = Field(
        json_schema_extra={"reader": _deserialize_masks}
    )

    @field_validator("bids", mode="before")
    @classmethod
    def _convert_to_bids(cls, v: Any) -> Any:
        """
        Convert a path (``str`` or ``Path``) to a ``Bids``.

        Any other value is returned unchanged.
        """
        if isinstance(v, (str, Path)):
            return Bids(v)
        return v

    @field_validator("masks", mode="before")
    @classmethod
    def _convert_to_bids_(cls, v: Any) -> Any:
        """
        For masks passed as tuples, convert the first element to a ``Bids``
        if it is a path.

        A new dictionary is returned: the dictionary passed by the caller is
        never modified. Values that are not dictionaries are returned unchanged.
        """
        if not isinstance(v, dict):
            return v

        converted = {}
        for name, value in v.items():
            # Tuples with fewer than two elements are left untouched so that
            # the type validation rejects them, instead of an IndexError
            # being raised from here.
            if isinstance(value, tuple) and len(value) >= 2:
                value = (cls._convert_to_bids(value[0]), value[1])
            converted[name] = value
        return converted

    @field_validator("masks", mode="after")
    @classmethod
    def _resolve_path(
        cls, v: Optional[dict[str, Path | BidsFileType | tuple[Bids, BidsFileType]]]
    ) -> Optional[dict[str, Path | BidsFileType | tuple[Bids, BidsFileType]]]:
        """
        Resolve the masks passed as paths (see ``Path.resolve``).

        The other kinds of masks are kept as they are. ``None`` and empty
        dictionaries are returned unchanged.
        """
        if not v:
            return v

        return {
            name: value.resolve() if isinstance(value, Path) else value
            for name, value in v.items()
        }

    @classmethod
    def _get_class(cls):
        """Return the class associated with this config, i.e. ``BidsDataset``."""
        return BidsDataset


class BidsDataset(
    BidsNiftiDataset, HasConfig[BidsDatasetConfig], BidsTypeDatasetWithConfig
):
    """
    A :py:class:`~clinicadl.data.datasets.Dataset` working with neuroimaging data
    organized in :term:`BIDS` (or derivative) format.

    The user specifies the path to the :term:`BIDS` directory via ``bids``, the type
    of data to load via ``file_type``, and the (participant, session) pairs to work on
    via ``data``.

    ``BidsDataset`` loads the image and the potential masks (see ``masks`` argument),
    and puts them in a :py:class:`~clinicadl.data.structures.Sample`. The user can add
    additional data in this ``Sample`` via the arguments ``columns``.

    Transformations (e.g., preprocessing or data augmentation) can be applied to the
    loaded data (see ``transforms`` argument).

    With ``BidsDataset``, it is possible to work on the whole images, or on patches or
    slices extracted from the images. This is also specified via the ``transforms``
    argument (e.g., ``transforms=TransformsHandler(extraction=Slice())``).

    .. note::
        - The size of the ``BidsDataset`` depends on the type of data you are working
          on. For example, if you have 10 images with 100 slices each, and you want to
          work on slices, the length of your dataset will be
          :math:`10\\times100=1,000`.
        - To avoid confusion, we will use the term "sample" to refer to the actual
          element of the images we are working on (patch, slice or the whole image).

    Finally, you may be interested in :py:meth:`to_tensors`, that will convert your
    :term:`NIfTI` images to tensors (saved in ``.pt`` files). Since opening a ``.pt``
    file is much faster than opening a NIfTI file, this may speed up data loading.

    Parameters
    ----------
    bids : PathType | Bids
        The :term:`BIDS` (or derivative) directory where the data will be loaded from.
        Can be passed as a path or directly as :py:class:`~clinicadl.io.bids.Bids`.
    file_type : BidsFileType
        Defines the files to load in the BIDS directory. The
        :py:class:`~clinicadl.io.bids.BidsFileType` must contain the requirements
        necessary to select only the relevant files.
    data : Optional[DataFrameType], default=None
        A :py:class:`pandas.DataFrame` (or a path to a ``TSV`` file containing the
        DataFrame) with the list of (participant, session) pairs to consider, as well
        as any other relevant information (e.g. the age of the participants).
        Only (participant, session) pairs mentioned in this TSV file will be in the
        ``BidsDataset``.

        If ``None``, all (participant, session) pairs in ``bids`` that have the right
        ``file_type`` will be considered.

        .. warning::
            Be careful if you pass a DataFrame with a column named ``"n_samples"``.
            ``BidsDataset`` will understand it as the number of samples for each
            (participant, session) pair.

    transforms : TransformsHandler, default=TransformsHandler()
        Transformation pipeline to apply to the data after loading. The user also
        specifies here whether to work on images, patches, or slices.
        See :py:class:`clinicadl.transforms.TransformsHandler`.
    columns : Optional[ColumnsType], default=None
        Columns to get in the DataFrame ``data`` and to put in the output
        :py:class:`~clinicadl.data.structures.Sample`. Can be passed via:

        - a list of strings (e.g. ``["age", "sex"]``), corresponding to the names of
          the columns;
        - or a dictionary (e.g. ``{"age": <function>, "sex": None}``), where the keys
          are the names of the columns, and the values are functions to apply to the
          columns. If the function is ``None``, no function will be applied to the
          column.

        .. note::
            The potential functions applied to the columns are applied to the
            **whole column**. They must take as input a :py:class:`pandas.Series`,
            and return a :py:class:`pandas.Series`. For example, it is useful to
            convert string labels to integer labels for classification.

    masks : Optional[MasksType], default=None
        Masks to be loaded along with the images. The masks are passed via a
        dictionary, whose names will be the names given to the masks in the output
        :py:class:`~clinicadl.data.structures.Sample`, and whose values can be:

        - a path (``str`` or :py:class:`pathlib.Path`) to a :term:`NIfTI` image: the
          same mask is used for all the (participant, session) pairs.
        - a :py:class:`~clinicadl.io.bids.BidsFileType`: the mask is participant- and
          session-specific and the pattern to find the mask in the ``bids`` is given
          via the ``BidsFileType``.
        - a tuple (PathType | :py:class:`~clinicadl.io.bids.Bids`,
          :py:class:`~clinicadl.io.bids.BidsFileType`): the mask is participant- and
          session-specific but is not in the same BIDS dataset as the image. So, here
          the BIDS where to look for the mask must be passed in the first element of
          the tuple.

    Raises
    ------
    DataFrameError
        If the DataFrame in ``data`` is empty.
    DataFrameError
        If the DataFrame in ``data`` does not contain the columns
        ``"participant_id"`` and ``"session_id"``.
    DataFrameError
        If the DataFrame in ``data`` contains duplicated
        (``participant_id``, ``session_id``) pairs.
    RuntimeError
        If for some (participant, session) pairs, an image corresponding to
        ``file_type`` cannot be found in ``bids``.
    ValueError
        If a key is used in ``columns`` and ``masks``.
    ValueError
        If a key in ``columns`` or ``masks`` is equal to the name of one of the
        attributes of :py:class:`~clinicadl.data.structures.Sample`.

    Examples
    --------
    .. code-block:: text

        bids
        ├── dataset_description.json
        ├── metadata.tsv
        ├── sub-001
        │   ├── ses-M000
        │   │   └── anat
        │   │       ├── sub-001_ses-M000_T1w.nii.gz
        │   │       └── sub-001_ses-M000_label-head_mask.nii.gz
        │   ...
        ...
        └── derivatives
            ├── registration
            │   ├── space-MNI152NLin2009cSym_mask.nii.gz
            │   ...
            └── masks
                ├── dataset_description.json
                ├── sub-001
                │   ├── ses-M000
                │   │   └── anat
                │   │       └── sub-001_ses-M000_label-brain_mask.nii.gz
                │   ...
                ...

        The "metadata.tsv" file looks like:

        participant_id  session_id  age   sex  diagnosis
        sub-001         ses-M000    55.0  M    control
        sub-001         ses-M024    57.0  M    control
        sub-002         ses-M000    62.0  F    control
        sub-002         ses-M024    64.0  F    patient
        sub-003         ses-M000    67.0  F    patient
        ...

    .. code-block:: python

        from clinicadl.data.datasets import BidsDataset
        from clinicadl.io.bids import BidsFileType
        from clinicadl.transforms import TransformsHandler, extraction
        from clinicadl.transforms.config import (
            ZNormalizationConfig,
            ResampleConfig,
            RandomFlipConfig,
        )
        import pandas as pd

        # to convert diagnosis to numeric values
        def diagnosis_to_number(column: pd.Series) -> pd.Series:
            encoding = {"control": 0, "patient": 1}
            return column.apply(lambda x: encoding[x])

    .. code-block:: python

        >>> dataset = BidsDataset(
                bids="bids",
                file_type=BidsFileType(
                    data_type="anat",
                    suffix="T1w",
                ),
                data="bids/metadata.tsv",
                columns=["age"],
            )
        >>> dataset[0]
        Sample(Keys: ('age', 'file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 1)
        >>> dataset[0].spatial_shape
        (169, 208, 179)  # full image
        >>> len(dataset)
        50  # 50 lines in the metadata.tsv
        >>> dataset[0].participant_id, dataset[0].session_id, dataset[0].age
        ('sub-001', 'ses-M000', 55.0)

    .. code-block:: python

        >>> dataset = BidsDataset(
                bids="bids",
                file_type=BidsFileType(
                    data_type="anat",
                    suffix="T1w",
                ),
                data="bids/metadata.tsv",
                columns={"age": None, "diagnosis": diagnosis_to_number},
                transforms=TransformsHandler(
                    extraction=extraction.Patch(patch_size=64),
                ),
            )
        >>> dataset[0]["diagnosis"]
        0  # diagnosis is now encoded
        >>> dataset[0].spatial_shape
        (64, 64, 64)  # patches
        >>> len(dataset)
        1800  # 36 patches per image

    .. code-block:: python

        >>> dataset = BidsDataset(
                bids="bids",
                file_type=BidsFileType(
                    data_type="anat",
                    suffix="T1w",
                ),
                transforms=TransformsHandler(
                    image_transforms=[
                        ResampleConfig(target="mni"),  # masks can be used in transforms
                        ZNormalizationConfig(masking_method="head"),
                    ],
                    augmentations=[RandomFlipConfig(flip_probability=0.5)],
                ),
                masks={
                    "head": BidsFileType(
                        data_type="anat", suffix="mask", with_entities={"label": "head"}
                    ),  # participant- and session-specific mask that is in the same BIDS
                    "brain": (
                        "bids/derivatives/masks",
                        BidsFileType(
                            data_type="anat", suffix="mask", with_entities={"label": "brain"}
                        ),  # participant- and session-specific mask that is in another BIDS
                    ),
                    "mni": "bids/derivatives/registration/space-MNI152NLin2009cSym_mask.nii.gz",  # same mask for all (participant, session)
                },
            )
        >>> dataset[0]
        Sample(Keys: ('head', 'brain', 'mni', 'file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 4)
        >>> len(dataset)
        60  # all the (participant, session) that have T1w images. Not only the ones in metadata.tsv

    See Also
    --------
    :py:class:`~clinicadl.data.datasets.TensorDataset`
    :py:class:`~clinicadl.data.datasets.ConcatDataset`
    :py:class:`~clinicadl.data.datasets.PairedDataset`
    :py:class:`~clinicadl.data.datasets.UnpairedDataset`
    """

    _config_type = BidsDatasetConfig

    # NOTE(review): the default ``TransformsHandler()`` is created once, when the
    # function is defined, and is therefore the same object for every call relying
    # on the default. Whether the config copies it is not visible here — confirm
    # before mutating ``transforms`` in place anywhere.
    def __init__(
        self,
        bids: PathType | Bids,
        file_type: BidsFileType,
        data: Optional[DataFrameType] = None,
        transforms: TransformsHandler = TransformsHandler(),
        columns: Optional[ColumnsType] = None,
        masks: Optional[MasksType] = None,
    ):
        # The raw arguments first go through the config; only the values stored
        # in the config are then used to build the dataset.
        self.config = self._config_type(
            bids=bids,
            file_type=file_type,
            data=data,
            transforms=transforms,
            columns=columns,
            masks=masks,
        )
        super().__init__(
            image=Image(self.config.bids, self.config.file_type),
            data=self.config.data,
            transforms=self.config.transforms,
            columns=self.config.columns,
            # ``_read_masks`` builds a new dictionary, so the masks stored in the
            # config are left untouched.
            masks=self._read_masks(self.config.masks),
        )

    def _read_masks(
        self,
        masks: Optional[dict[str, PathType | BidsFileType | tuple[Bids, BidsFileType]]],
    ) -> Optional[dict[str, IndividualMask | CommonMask]]:
        """
        Converts masks to the right format.

        Each mask is converted according to its type:

        - a ``BidsFileType`` becomes an ``IndividualMask`` built with the BIDS of
          the dataset (``self.config.bids``);
        - a tuple becomes an ``IndividualMask`` built with its first two elements;
        - anything else becomes a ``CommonMask``.

        The input dictionary is not modified. Returns ``None`` if ``masks`` is
        ``None`` or empty.
        """
        if not masks:
            return None

        converted: dict[str, IndividualMask | CommonMask] = {}
        for name, mask in masks.items():
            if isinstance(mask, BidsFileType):
                converted[name] = IndividualMask(self.config.bids, mask)
            elif isinstance(mask, tuple):
                converted[name] = IndividualMask(mask[0], mask[1])
            else:
                converted[name] = CommonMask(mask)

        return converted

    def to_tensors(
        self,
        conversion_name: Optional[str] = None,
        spatial_checks: Optional[Iterable[str | SpatialCheck]] = DEFAULT_SPATIAL_CHECKS,
        save_transforms: bool = False,
        description: Optional[str] = None,
        overwrite: bool = False,
        check_transforms: bool = True,
        n_proc: int = 1,
    ) -> TensorDataset:
        """
        Converts NIfTI files in the current ``BidsDataset`` to tensors (in PyTorch's
        ``.pt`` format).

        Conversion to tensors may significantly **speed up data loading**.

        The tensors are saved in a BIDS derivative named ``tensors``. The location of
        this folder relative to the original BIDS depends on the type of BIDS (see
        :py:class:`clinicadl.io.bids.Bids`). A ``.json`` file describing the
        conversion is also saved at the root of ``tensors``, as well as other metadata
        files (see examples).

        Masks will be converted and saved in the same file as the image they are
        associated with.

        The user has the possibility to save transformed images, i.e. images on which
        image transforms have already been applied (see ``image_transforms`` in
        :py:class:`clinicadl.transforms.TransformsHandler`). This practice will speed
        up dataloading during training or inference as the images will not have to be
        transformed each time they are loaded.

        .. note::
            Images are converted to the same coordinate system (:term:`RAS+`).

        Parameters
        ----------
        conversion_name : Optional[str], default=None
            The name of the tensor conversion. Must be **alphanumerical**. The output
            tensors and the output ``.json`` file describing the conversion will have
            ``"conv-<conversion_name>"`` in their filenames.

            If a conversion with this name already exists:

            - if ``overwrite=True``, the old tensors will be overwritten;
            - else, ``to_tensors`` will try to append the new conversion to the
              pre-existing tensor conversion if they concern the same type of data
              (same modality, same transforms applied, etc.), otherwise an error will
              be raised.

            If ``None``, the conversion name will be ``"raw"``. ``conversion_name``
            **cannot** be ``None`` if ``save_transforms=True``.
        spatial_checks : Optional[Iterable[str | SpatialCheck]], default=("affine", "shape", "global_spacing")
            Potential spatial checks to perform on the images while converting them:

            - ``"spacing"``: checks **intra-sample voxel spacing consistency**, i.e.
              that all the images and masks in the
              :py:class:`~clinicadl.data.structures.Sample` output by the current
              dataset have the same voxel spacing.
            - ``"affine"``: checks **intra-sample affine matrix consistency** (so it
              includes ``"spacing"``).
            - ``"shape"``: checks **intra-sample spatial shape consistency**.
            - ``"global_spacing"``: checks **inter-sample voxel spacing consistency**,
              i.e. that all the ``Samples`` in the dataset have the same voxel spacing
              (so it includes ``"spacing"``).
            - ``"global_shape"``: checks **inter-sample spatial shape consistency**
              (so it includes ``"shape"``).

            If ``None``, no spatial check is performed.
        save_transforms : bool, default=False
            Whether to save raw images without transforms as tensors (``False``), or
            images with the applied transforms (``True``).
        description : Optional[str], default=None
            A potential description of the tensor conversion that will be saved in the
            description ``.json`` file.
        overwrite : bool, default=False
            Whether to overwrite a pre-existing tensor conversion that has the same
            ``conversion_name``. If a conversion named ``conversion_name`` already
            exists and ``overwrite=False``, ``to_tensors`` will try to append the
            current tensor conversion to the pre-existing one.
        check_transforms : bool, default=True
            ``check_transforms`` determines whether transforms will be checked when
            appending to a pre-existing conversion. If ``True``, ``to_tensors`` will
            check that the current transforms match the transforms applied during the
            pre-existing conversion.

            ``check_transforms=False`` is useful when you use custom transforms (i.e.
            transforms not in ``ClinicaDL``), which cannot be checked.

            .. note::
                If ``save_transforms=False``, no such check will be performed.

            .. warning::
                **To use carefully**. You must be sure that the transforms match
                before setting ``check_transforms=False``.

        n_proc : int, default=1
            Number of cores to use to parallelize the conversion.

        Returns
        -------
        TensorDataset
            A ``TensorDataset`` containing the converted tensors.

        Raises
        ------
        ValueError
            If the user passed ``"raw"`` as a ``conversion_name``.
        ValueError
            If ``conversion_name`` is ``None`` and ``save_transforms=True``.
        TensorConversionError
            If a conversion named ``conversion_name`` already exists and the new
            conversion cannot be appended to the pre-existing one.
        RuntimeError
            If some checks in ``spatial_checks`` fail.

        Examples
        --------
        .. code-block:: text

            bids
            ├── dataset_description.json
            ├── metadata.tsv
            ├── sub-001
            │   ├── ses-M000
            │   │   └── anat
            │   │       ├── sub-001_ses-M000_T1w.nii.gz
            │   │       └── sub-001_ses-M000_label-head_mask.nii.gz
            │   ...
            ...
            └── derivatives
                └── registration
                    ├── space-MNI152NLin2009cSym_mask.nii.gz
                    ...

        .. code-block:: python

            from clinicadl.data.datasets import BidsDataset
            from clinicadl.io.bids import BidsFileType
            from clinicadl.transforms import TransformsHandler, extraction
            from clinicadl.transforms.config import ResampleConfig

            dataset = BidsDataset(
                bids="bids",
                file_type=BidsFileType(
                    data_type="anat",
                    suffix="T1w",
                ),
                transforms=TransformsHandler(
                    image_transforms=[
                        ResampleConfig(target="mni"),
                    ],
                ),
                masks={
                    "head": BidsFileType(
                        data_type="anat", suffix="mask", with_entities={"label": "head"}
                    ),
                    "mni": "bids/derivatives/registration/space-MNI152NLin2009cSym_mask.nii.gz",
                },
            )

            tensor_dataset = dataset.to_tensors(
                conversion_name="T1WithMasks",
                save_transforms=True,
            )

        Data are now as follows:

        .. code-block:: text

            bids
            ├── dataset_description.json
            ├── metadata.tsv
            ├── sub-001
            │   ├── ses-M000
            │   │   └── anat
            │   │       ├── sub-001_ses-M000_T1w.nii.gz
            │   │       └── sub-001_ses-M000_label-head_mask.nii.gz
            │   ...
            ...
            └── derivatives
                ├── registration
                │   ├── space-MNI152NLin2009cSym_mask.nii.gz
                │   ...
                └── tensors
                    ├── dataset_description.json
                    ├── conversions.tsv    <- contains the list of all the conversions
                    ├── src-T1w_conv-T1WithMasks_description.json    <- contains a description of the conversion
                    ├── src-T1w_conv-T1WithMasks_participantsXsessions.tsv    <- contains the list of the (participant, session) pairs converted
                    ├── sub-001
                    │   ├── ses-M000
                    │   │   └── anat
                    │   │       ├── sub-001_ses-M000_src-T1w_conv-T1WithMasks_tensors.json    <- contains path to the source files
                    │   │       └── sub-001_ses-M000_src-T1w_conv-T1WithMasks_tensors.pt    <- contains the tensors (the transformed image and masks)
                    │   ...
                    ...

        .. code-block:: python

            >>> tensor_dataset[0]
            Sample(Keys: ('head', 'mni', 'file_type', 'image_path', 'sample_type', 'sample_position', 'image', 'participant_id', 'session_id'); images: 3)
        """
        converter = TensorConversion(self)
        conversion = converter.to_tensors(
            conversion_name=conversion_name,
            spatial_checks=spatial_checks,
            save_transforms=save_transforms,
            description=description,
            overwrite=overwrite,
            check_transforms=check_transforms,
            n_proc=n_proc,
        )

        # Work on a deep copy so that the transforms of the current dataset are
        # not modified.
        transforms = deepcopy(self.transforms)
        if conversion.transforms:
            # The conversion holds transforms: the image transforms of the
            # returned dataset are replaced by an empty composition.
            transforms.image_transforms = tio.Compose([])

        return TensorDataset(
            conversion.get_json_path(converter.tensors_dir.path),
            data=copy(self.df),
            transforms=transforms,
            columns=copy(self.columns),
            to_load=list(conversion.masks.keys()) + conversion.additional_data,
        )