clinicadl.split.KFold

class clinicadl.split.KFold(split_dir: Path | str)

Handles a K-Fold cross-validation split.

Reads a split directory returned by make_kfold(). The object can then split any Dataset with get_splits(), provided that every (participant, session) pair in the dataset is listed in the split directory.
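The coverage requirement above can be illustrated with a minimal pure-Python sketch. It assumes the split directory has already been read into a set of (participant_id, session_id) pairs (how ClinicaDL reads the directory is not shown), and `missing_pairs` is a hypothetical helper, not part of ClinicaDL. It lists the dataset pairs that the directory does not mention:

```python
from typing import Iterable


def missing_pairs(
    dataset_pairs: Iterable[tuple[str, str]],
    split_pairs: Iterable[tuple[str, str]],
) -> set[tuple[str, str]]:
    """Return the dataset pairs absent from the split directory (hypothetical helper)."""
    return set(dataset_pairs) - set(split_pairs)


# Hypothetical pairs, assumed to have been read from a split directory.
split_pairs = {
    ("sub-000", "ses-M000"),
    ("sub-000", "ses-M003"),
    ("sub-100", "ses-M000"),
}

# Every pair of this dataset is covered, so it could be split...
covered = [("sub-000", "ses-M000"), ("sub-100", "ses-M000")]
print(missing_pairs(covered, split_pairs))  # set()

# ...whereas this one contains a pair the split directory does not know.
uncovered = covered + [("sub-999", "ses-M099")]
print(missing_pairs(uncovered, split_pairs))  # {('sub-999', 'ses-M099')}
```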

Parameters:

split_dir (PathType) – The split directory, returned by make_kfold().

Raises:

FileNotFoundError – If split_dir does not exist or if a required file is missing in this directory.

See also

SingleSplit

get_splits(dataset: Dataset, eval_dataset: Dataset | None = None, splits: Sequence[int] | None = None) → Generator[Split, None, None]

Splits a dataset according to the splits defined in the K-Fold directory, and yields the requested splits, selected by their indices.

Parameters:
  • dataset (Dataset) – The dataset to split.

  • eval_dataset (Optional[Dataset], default=None) – If not None, the validation dataset is built from eval_dataset, and dataset is only used to build the training dataset (see Examples). If None, both the training and validation datasets are built from dataset.

  • splits (Optional[Sequence[int]], default=None) – Indices of the splits to get. If None, all the splits are returned.

Yields:

Split – The train and validation datasets for each requested split, in a Split object.

Raises:

IndexError – If one of the requested split indices is out of range.
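The parameters, yielded value and error described above can be summarized in a pure-Python sketch. This is not ClinicaDL's implementation: datasets are modeled as lists of dicts, each split is assumed to be a (training pairs, validation pairs) tuple already read from the K-Fold directory, and `SplitSketch` and `get_splits_sketch` are hypothetical names. The sketch also assumes that only indices 0 to k-1 are valid.

```python
from dataclasses import dataclass
from typing import Iterator, Optional, Sequence

Pair = tuple[str, str]  # (participant_id, session_id)
Row = dict[str, str]    # one sample of a dataset, modeled as a plain dict


@dataclass
class SplitSketch:
    """Hypothetical stand-in for a Split: one training and one validation set."""
    index: int
    train: list[Row]
    val: list[Row]


def _select(rows: list[Row], pairs: set[Pair]) -> list[Row]:
    # Keep the rows whose (participant, session) pair is in `pairs`.
    return [r for r in rows if (r["participant_id"], r["session_id"]) in pairs]


def get_splits_sketch(
    folds: Sequence[tuple[set[Pair], set[Pair]]],
    dataset: list[Row],
    eval_dataset: Optional[list[Row]] = None,
    splits: Optional[Sequence[int]] = None,
) -> Iterator[SplitSketch]:
    # splits=None means "every split found in the K-Fold directory".
    indices = range(len(folds)) if splits is None else splits
    # Check all requested indices before yielding anything.
    for i in indices:
        if not 0 <= i < len(folds):
            raise IndexError(f"split {i} is out of range ({len(folds)} splits)")
    # Validation rows come from eval_dataset when it is given.
    val_source = dataset if eval_dataset is None else eval_dataset
    for i in indices:
        train_pairs, val_pairs = folds[i]
        yield SplitSketch(i, _select(dataset, train_pairs), _select(val_source, val_pairs))


# Hypothetical 3-fold layout over three (participant, session) pairs.
A, B, C = ("sub-000", "ses-M000"), ("sub-100", "ses-M000"), ("sub-999", "ses-M099")
folds = [({A, B}, {C}), ({A, C}, {B}), ({B, C}, {A})]
patches = [{"participant_id": p, "session_id": s, "kind": "patch"} for p, s in (A, B, C)]
images = [{"participant_id": p, "session_id": s, "kind": "image"} for p, s in (A, B, C)]

first = next(get_splits_sketch(folds, patches, eval_dataset=images))
print([r["kind"] for r in first.train], [r["kind"] for r in first.val])
# ['patch', 'patch'] ['image']
print([s.index for s in get_splits_sketch(folds, patches, splits=[2, 0])])
# [2, 0]
```

Because every requested index is checked before the first yield, this sketch produces no split at all when one index is out of range; since it is a generator, the IndexError surfaces on the first iteration rather than at call time.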

Examples

>>> df  # a quick look at the data
    participant_id  session_id
0   sub-000         ses-M000
1   sub-000         ses-M003
2   sub-100         ses-M000
3   sub-100         ses-M012
4   sub-999         ses-M099
5   sub-999         ses-M999
from clinicadl.split import KFold
from clinicadl.data.datasets import BidsDataset
from clinicadl.transforms import TransformsHandler, extraction

# training data: 64 x 64 x 64 patches extracted from each image
dataset = BidsDataset(
    "bids_dir",
    data=df,
    transforms=TransformsHandler(extraction=extraction.Patch(patch_size=64)),
    ...
)
# read the split directory returned by make_kfold()
splitter = KFold("split_dir/3_fold")

splits_iterator = splitter.get_splits(dataset)  # yields one Split per fold
split = next(splits_iterator)  # first split
>>> split.train_dataset.df
    participant_id  session_id
0   sub-000         ses-M000
1   sub-000         ses-M003
2   sub-100         ses-M000
3   sub-100         ses-M012
>>> split.val_dataset.df
    participant_id  session_id
0   sub-999         ses-M099
1   sub-999         ses-M999

Now, let’s say you want to train your model on patches but evaluate it on whole images:

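# validation data: no patch extraction, so samples are whole images
eval_dataset = BidsDataset(
    "bids_dir",
    data=df,
    transforms=TransformsHandler(),
    ...
)

# training samples come from dataset, validation samples from eval_dataset
splits_iterator = splitter.get_splits(dataset, eval_dataset=eval_dataset)
split = next(splits_iterator)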
>>> split.train_dataset[0].spatial_shape
(64, 64, 64)
>>> split.val_dataset[0].spatial_shape
(181, 217, 181)