from pathlib import Path
from typing import Generator, Optional, Sequence
from pydantic import PositiveInt, field_validator
from clinicadl.data.datasets import Dataset
from clinicadl.split.split import Split
from clinicadl.split.splitter.splitter import (
Splitter,
SplitterConfig,
SubjectsSessionsSplit,
)
from clinicadl.utils.dictionary.words import SPLIT
class KFoldConfig(SplitterConfig):
    """
    Configuration for K-Fold cross-validation splits.
    """

    # Private attribute; presumably the name under which this config is
    # serialized by SplitterConfig (TODO confirm against SplitterConfig).
    _json_name: str = "kfold_config"
    # Number of folds; positive and, per the validator below, at least 2.
    n_splits: PositiveInt
    # No default, so this field must be provided (it may be None).
    # Presumably the name of the variable used to stratify the folds — TODO confirm.
    stratification: Optional[str]

    @field_validator("n_splits", mode="after")
    @classmethod
    def _n_splits_validator(cls, v: int) -> int:
        """Checks that 'n_splits' is at least 2."""
        # Raise explicitly rather than using `assert`: assertions are removed
        # when Python runs with -O, which would silently disable this check.
        # pydantic converts the ValueError into a ValidationError.
        if v < 2:
            raise ValueError("'n_splits' must be at least 2.")
        return v

    def get_split_subdir(self, split: int, create: bool = False) -> Path:
        """
        Returns the subdirectory of a split of the K-Fold.

        Parameters
        ----------
        split : int
            The index of the split.
        create : bool (optional, default=False)
            If True, creates the directory (and any missing parents) when it
            doesn't exist yet. Otherwise, the filesystem is not touched.

        Returns
        -------
        Path
            The path to the split subdirectory, i.e. ``split_dir / f"{SPLIT}-{split}"``.
        """
        split_dir = self.split_dir / f"{SPLIT}-{split}"
        if create:
            split_dir.mkdir(parents=True, exist_ok=True)
        return split_dir

    def _check_split_dirs(self) -> None:
        """Checks the subdirectories of all the splits, from 0 to ``n_splits - 1``."""
        # `_check_split_dir` is not defined here; presumably it is provided by
        # SplitterConfig.
        for i in range(self.n_splits):
            self._check_split_dir(self.get_split_subdir(i))
class KFold(Splitter):
    """
    K-Fold cross-validator.

    This object will read a split directory returned by :py:func:`~clinicadl.split.make_kfold`
    and can then be used to split any :py:class:`~clinicadl.data.datasets.Dataset` using :py:meth:`~KFold.get_splits`,
    provided that all the (participant, session) pairs in the dataset are mentioned in the split directory.

    Parameters
    ----------
    split_dir : PathType
        The split directory, returned by :py:func:`~clinicadl.split.make_kfold`.

    Raises
    ------
    FileNotFoundError
        If ``split_dir`` does not exist or if a required file is missing in this directory.

    See Also
    --------
    :py:class:`~clinicadl.split.SingleSplit`
    """

    _config_type = KFoldConfig

    def get_splits(
        self,
        dataset: Dataset,
        eval_dataset: Optional[Dataset] = None,
        splits: Optional[Sequence[int]] = None,
    ) -> Generator[Split, None, None]:
        """
        Splits a dataset according to the splits found in the K-Fold directory, and
        yields the splits by their indices.

        Parameters
        ----------
        dataset : Dataset
            The dataset to split.
        eval_dataset : Optional[Dataset], default=None
            If not ``None``, it will be understood as the dataset from which the validation dataset should be created, and
            ``dataset`` will be the dataset from which the training dataset will be created (see examples). If ``None``, both
            training and validation datasets are built from ``dataset``.
        splits : Optional[Sequence[int]], default=None
            Indices of the splits to get. If ``None``, will return all the splits.

        Yields
        ------
        Split
            The train and validation datasets for each requested split, in a :py:class:`~clinicadl.split.Split`
            object.

        Raises
        ------
        IndexError
            If one of the requested split indices is out of range. All the indices are
            checked when the first split is requested, before any split is yielded.

        Examples
        --------
        .. code-block::

            >>> df  # a quick look at the data
              participant_id session_id
            0        sub-000   ses-M000
            1        sub-000   ses-M003
            2        sub-100   ses-M000
            3        sub-100   ses-M012
            4        sub-999   ses-M099
            5        sub-999   ses-M999

        .. code-block::

            from clinicadl.split import KFold
            from clinicadl.data.datasets import BidsDataset
            from clinicadl.transforms import TransformsHandler, extraction

            dataset = BidsDataset(
                "bids_dir",
                data=df,
                transforms=TransformsHandler(extraction=extraction.Patch(patch_size=64)),
                ...
            )
            splitter = KFold("split_dir/3_fold")
            splits_iterator = splitter.get_splits(dataset)
            split = next(iter(splits_iterator))

        .. code-block::

            >>> split.train_dataset.df
              participant_id session_id
            0        sub-000   ses-M000
            1        sub-000   ses-M003
            2        sub-100   ses-M000
            3        sub-100   ses-M012
            >>> split.val_dataset.df
              participant_id session_id
            0        sub-999   ses-M099
            1        sub-999   ses-M999

        Now, let's say you want to train your model on patches, but evaluate it on images:

        .. code-block::

            eval_dataset = BidsDataset(
                "bids_dir",
                data=df,
                transforms=TransformsHandler(),
                ...
            )
            splits_iterator = splitter.get_splits(dataset, eval_dataset=eval_dataset)
            split = next(iter(splits_iterator))

        .. code-block::

            >>> split.train_dataset[0].spatial_shape
            (64, 64, 64)
            >>> split.val_dataset[0].spatial_shape
            (181, 217, 181)
        """
        self.config: KFoldConfig
        n_splits = self.config.n_splits

        # Materialize the requested indices once: they are iterated twice below,
        # which would silently yield nothing for a one-shot iterable.
        if splits is None:
            splits = list(range(n_splits))
        else:
            splits = list(splits)

        # Validate every index before yielding anything, so that an invalid index
        # is reported immediately instead of after earlier splits were consumed
        # (e.g. after a whole training run on the first fold).
        for split in splits:
            if split not in range(n_splits):
                raise IndexError(
                    f"Split '{split}' doesn't exist. There are {n_splits} splits, numbered from 0 to {n_splits - 1}."
                )

        for split in splits:
            yield self._get_split(
                split_id=split, dataset=dataset, eval_dataset=eval_dataset
            )

    def _read_splits(self) -> list[SubjectsSessionsSplit]:
        """
        Load all splits in 'split_dir' from the tsv files.
        """
        # Annotation only, to tell type checkers that the config is a KFoldConfig.
        self.config: KFoldConfig
        # `_read_split` is not defined here; presumably it is provided by Splitter.
        return [
            self._read_split(self.config.get_split_subdir(i))
            for i in range(self.config.n_splits)
        ]