clinicadl.split.KFold
- class clinicadl.split.KFold(split_dir: Path | str)
A K-Fold cross-validator.
This object reads a split directory returned by make_kfold() and can then be used to split any Dataset using get_splits(), provided that all the (participant, session) pairs in the dataset are mentioned in the split directory.
- Parameters:
split_dir (PathType) – The split directory, returned by make_kfold().
- Raises:
FileNotFoundError – If split_dir does not exist or if a required file is missing from this directory.
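The on-disk layout written by make_kfold() is not described on this page. As a rough, non-authoritative illustration, the sketch below models each fold as an in-memory pair of "train"/"val" sets of (participant, session) pairs and shows the precondition stated above: every pair of the dataset must be mentioned somewhere in the folds. The names toy_kfold and missing_pairs, the round-robin fold assignment and the participant-level grouping are all assumptions made for this sketch, not ClinicaDL's implementation (the grouping is merely consistent with the get_splits() example, where both sessions of a participant stay on the same side). The pairs are taken from the example data in get_splits().

```python
from typing import Dict, List, Set, Tuple

Pair = Tuple[str, str]  # (participant_id, session_id)
Fold = Dict[str, Set[Pair]]  # hypothetical layout: {"train": {...}, "val": {...}}


def toy_kfold(pairs: List[Pair], k: int) -> List[Fold]:
    """Hypothetical participant-level K-fold; NOT the make_kfold() algorithm.

    Participants are dealt round-robin into k validation folds, so all the
    sessions of one participant land on the same side of every split.
    """
    participants = sorted({participant for participant, _ in pairs})
    all_pairs = set(pairs)
    folds = []
    for i in range(k):
        val_participants = set(participants[i::k])
        val = {pair for pair in all_pairs if pair[0] in val_participants}
        folds.append({"train": all_pairs - val, "val": val})
    return folds


def missing_pairs(pairs: List[Pair], folds: List[Fold]) -> Set[Pair]:
    """Pairs that no fold mentions; KFold needs this set to be empty."""
    mentioned: Set[Pair] = set()
    for fold in folds:
        mentioned |= fold["train"] | fold["val"]
    return set(pairs) - mentioned


pairs = [
    ("sub-000", "ses-M000"), ("sub-000", "ses-M003"),
    ("sub-100", "ses-M000"), ("sub-100", "ses-M012"),
    ("sub-999", "ses-M099"), ("sub-999", "ses-M999"),
]
folds = toy_kfold(pairs, k=3)
print(missing_pairs(pairs, folds))  # set()
print(missing_pairs(pairs + [("sub-555", "ses-M000")], folds))  # {('sub-555', 'ses-M000')}
```

The fold assignment here is arbitrary and will not reproduce the folds of a real make_kfold() directory; only the coverage check reflects the documented requirement.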
- get_splits(dataset: Dataset, eval_dataset: Dataset | None = None, splits: Sequence[int] | None = None) → Generator[Split, None, None]
Splits a dataset according to the splits found in the K-Fold directory and yields the requested splits, selected by their indices.
- Parameters:
dataset (Dataset) – The dataset to split.
eval_dataset (Optional[Dataset], default=None) – If not None, it is used as the dataset from which the validation dataset is created, while dataset is used to create the training dataset (see examples). If None, both the training and validation datasets are built from dataset.
splits (Optional[Sequence[int]], default=None) – Indices of the splits to get. If None, all the splits are returned.
- Yields:
Split – The train and validation datasets for each requested split, in a Split object.
- Raises:
IndexError – If one of the requested split indices is out of range.
Examples
>>> df  # a quick look at the data
  participant_id session_id
0        sub-000   ses-M000
1        sub-000   ses-M003
2        sub-100   ses-M000
3        sub-100   ses-M012
4        sub-999   ses-M099
5        sub-999   ses-M999
from clinicadl.split import KFold
from clinicadl.data.datasets import BidsDataset
from clinicadl.transforms import TransformsHandler, extraction

dataset = BidsDataset(
    "bids_dir",
    data=df,
    transforms=TransformsHandler(extraction=extraction.Patch(patch_size=64)),
    ...
)
splitter = KFold("split_dir/3_fold")
splits_iterator = splitter.get_splits(dataset)
split = next(iter(splits_iterator))
>>> split.train_dataset.df
  participant_id session_id
0        sub-000   ses-M000
1        sub-000   ses-M003
2        sub-100   ses-M000
3        sub-100   ses-M012
>>> split.val_dataset.df
  participant_id session_id
0        sub-999   ses-M099
1        sub-999   ses-M999
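Judging from the printed output, split.train_dataset.df and split.val_dataset.df look like pandas DataFrames with participant_id and session_id columns, so a quick check for leakage between training and validation can be written with pandas alone. The sketch rebuilds the two frames by hand from the output shown above so that it runs without ClinicaDL; with a real Split you would pass split.train_dataset.df and split.val_dataset.df instead. The helper names shared_participants and shared_pairs are hypothetical.

```python
import pandas as pd

# The two DataFrames printed in the example, rebuilt by hand.
train_df = pd.DataFrame({
    "participant_id": ["sub-000", "sub-000", "sub-100", "sub-100"],
    "session_id": ["ses-M000", "ses-M003", "ses-M000", "ses-M012"],
})
val_df = pd.DataFrame({
    "participant_id": ["sub-999", "sub-999"],
    "session_id": ["ses-M099", "ses-M999"],
})


def shared_participants(train: pd.DataFrame, val: pd.DataFrame) -> set:
    """Participants present on both sides of a split (empty: no subject-level leakage)."""
    return set(train["participant_id"]) & set(val["participant_id"])


def shared_pairs(train: pd.DataFrame, val: pd.DataFrame) -> pd.DataFrame:
    """(participant, session) pairs present on both sides of a split."""
    return train.merge(val, on=["participant_id", "session_id"], how="inner")


print(shared_participants(train_df, val_df))  # set()
print(len(shared_pairs(train_df, val_df)))  # 0
```

Checking participants rather than only (participant, session) pairs is the stricter test: two sessions of the same subject on opposite sides of a split would pass the pair check but still leak subject-specific information.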
Now, let’s say you want to train your model on patches, but evaluate it on images:
eval_dataset = BidsDataset(
    "bids_dir",
    data=df,
    transforms=TransformsHandler(),
    ...
)
splits_iterator = splitter.get_splits(dataset, eval_dataset=eval_dataset)
split = next(iter(splits_iterator))
>>> split.train_dataset[0].spatial_shape
(64, 64, 64)
>>> split.val_dataset[0].spatial_shape
(181, 217, 181)