clinicadl.split.SingleSplit
- class clinicadl.split.SingleSplit(split_dir: Path | str)
Handles a single training-validation split, as opposed to KFold, which can handle several splits.
This object reads a split directory returned by make_split() and can then be used to split any Dataset with get_split(), provided that all the (participant, session) pairs in the dataset are listed in the split directory.
- Parameters:
split_dir (PathType) – The split directory, returned by make_split().
- Raises:
FileNotFoundError – If split_dir does not exist or if a required file is missing in this directory.
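The behaviour described above amounts to reading sets of (participant, session) pairs from disk and looking each dataset row up in them. The stdlib-only sketch below mimics that idea; it is not ClinicaDL's implementation. It assumes, hypothetically, that the split directory contains two tab-separated files named train.tsv and validation.tsv, each with participant_id and session_id columns. The real layout written by make_split() may differ. The helpers read_split_dir and assign_rows are hypothetical, and raising a KeyError for an unlisted pair is this sketch's own choice, since the docstring only states the precondition.

```python
import csv
import tempfile
from pathlib import Path

# Hypothetical file names: the layout actually written by make_split() may differ.
SPLIT_FILES = {"train": "train.tsv", "validation": "validation.tsv"}


def read_split_dir(split_dir):
    """Read the (participant_id, session_id) pairs of each subset (hypothetical helper)."""
    split_dir = Path(split_dir)
    if not split_dir.is_dir():
        raise FileNotFoundError(f"{split_dir} does not exist.")
    pairs = {}
    for subset, filename in SPLIT_FILES.items():
        path = split_dir / filename
        if not path.is_file():
            raise FileNotFoundError(f"Missing required file: {path}")
        with path.open(newline="") as f:
            reader = csv.DictReader(f, delimiter="\t")
            pairs[subset] = {(row["participant_id"], row["session_id"]) for row in reader}
    return pairs


def assign_rows(rows, pairs):
    """Partition (participant_id, session_id) rows into training and validation lists."""
    train, val = [], []
    for row in rows:
        if row in pairs["train"]:
            train.append(row)
        elif row in pairs["validation"]:
            val.append(row)
        else:
            # Assumption of this sketch: an unlisted pair is treated as an error.
            raise KeyError(f"{row} is not listed in the split directory.")
    return train, val


# Usage on a throwaway split directory built from pairs of the example below.
with tempfile.TemporaryDirectory() as tmp:
    header = "participant_id\tsession_id\n"
    Path(tmp, "train.tsv").write_text(header + "sub-000\tses-M000\nsub-100\tses-M000\n")
    Path(tmp, "validation.tsv").write_text(header + "sub-010\tses-M003\n")
    demo_pairs = read_split_dir(tmp)

demo_train, demo_val = assign_rows(
    [("sub-000", "ses-M000"), ("sub-010", "ses-M003"), ("sub-100", "ses-M000")],
    demo_pairs,
)
print(demo_train, demo_val)
```

Using sets for the pairs keeps each lookup constant-time, so splitting stays linear in the number of dataset rows.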
- get_split(dataset: Dataset, eval_dataset: Dataset | None = None) → Split
Splits a dataset according to the split found in the split directory.
- Parameters:
dataset (Dataset) – The dataset to split.
eval_dataset (Optional[Dataset], default=None) – If not None, it is used as the dataset from which the validation dataset is created, while dataset is used to create the training dataset (see examples). If None, both training and validation datasets are built from dataset.
- Returns:
Split – A Split object containing the training and validation datasets.
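The role of eval_dataset can be illustrated without ClinicaDL: training records are taken from dataset, and validation records are taken from eval_dataset, or from dataset when eval_dataset is None. The sketch below assumes each dataset is simply a list of records carrying participant_id, session_id and a payload label. The Record and SplitSketch classes and the get_split_sketch function are hypothetical stand-ins, not the library's Dataset, Split or get_split().

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass(frozen=True)
class Record:
    """Hypothetical stand-in for one dataset item."""
    participant_id: str
    session_id: str
    payload: str  # e.g. "patch" or "image"; stands in for the actual data


@dataclass
class SplitSketch:
    """Hypothetical stand-in for the Split object returned by get_split()."""
    train_dataset: list = field(default_factory=list)
    val_dataset: list = field(default_factory=list)


def get_split_sketch(dataset, train_pairs, val_pairs, eval_dataset: Optional[list] = None):
    """Training records come from `dataset`; validation records come from
    `eval_dataset` if it is given, otherwise from `dataset`."""
    val_source = dataset if eval_dataset is None else eval_dataset
    train = [r for r in dataset if (r.participant_id, r.session_id) in train_pairs]
    val = [r for r in val_source if (r.participant_id, r.session_id) in val_pairs]
    return SplitSketch(train_dataset=train, val_dataset=val)


train_pairs = {("sub-000", "ses-M000")}
val_pairs = {("sub-010", "ses-M003")}
subjects = [("sub-000", "ses-M000"), ("sub-010", "ses-M003")]
patches = [Record(p, s, "patch") for p, s in subjects]
images = [Record(p, s, "image") for p, s in subjects]

# Train on patches, validate on images (mirrors the second example below).
split = get_split_sketch(patches, train_pairs, val_pairs, eval_dataset=images)
# Without eval_dataset, both subsets come from the same dataset.
split_same = get_split_sketch(patches, train_pairs, val_pairs)
print(split.train_dataset[0].payload, split.val_dataset[0].payload)
```

The same pairs select both subsets; only the source list for validation changes, which is what lets the two datasets differ in their transforms.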
Examples
>>> df  # a quick look at the data
  participant_id session_id
0        sub-000   ses-M000
1        sub-000   ses-M003
2        sub-010   ses-M003
3        sub-010   ses-M012
4        sub-100   ses-M000
5        sub-100   ses-M012
6        sub-999   ses-M099
7        sub-999   ses-M999
from clinicadl.split import SingleSplit
from clinicadl.data.datasets import BidsDataset
from clinicadl.transforms import TransformsHandler, extraction

dataset = BidsDataset(
    "bids_dir",
    data=df,
    transforms=TransformsHandler(extraction=extraction.Patch(patch_size=64)),
    ...
)
splitter = SingleSplit("split_dir")
split = splitter.get_split(dataset)
>>> split.train_dataset.df
  participant_id session_id
0        sub-000   ses-M000
1        sub-000   ses-M003
2        sub-100   ses-M000
3        sub-100   ses-M012
4        sub-999   ses-M099
5        sub-999   ses-M999
>>> split.val_dataset.df
  participant_id session_id
0        sub-010   ses-M003
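In this output, no participant appears in both the training and the validation subsets. If you want to check the same property on your own split, a plain-Python comparison of participant IDs is enough. The pairs below are copied from the output above, and shared_participants is a hypothetical helper; on real data you could build the pairs from split.train_dataset.df and split.val_dataset.df, e.g. with zip(df["participant_id"], df["session_id"]).

```python
# (participant_id, session_id) pairs copied from the example output above.
train_pairs = [
    ("sub-000", "ses-M000"),
    ("sub-000", "ses-M003"),
    ("sub-100", "ses-M000"),
    ("sub-100", "ses-M012"),
    ("sub-999", "ses-M099"),
    ("sub-999", "ses-M999"),
]
val_pairs = [("sub-010", "ses-M003")]


def shared_participants(train, val):
    """Return the participants present in both subsets (hypothetical helper).

    An empty set means no participant is shared between training and validation.
    """
    return {p for p, _ in train} & {p for p, _ in val}


leaked = shared_participants(train_pairs, val_pairs)
print(leaked or "no participant shared between training and validation")
```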
Now, let’s say you want to train your model on patches, but evaluate it on images:
eval_dataset = BidsDataset(
    "bids_dir",
    data=df,
    transforms=TransformsHandler(),
    ...
)
split = splitter.get_split(dataset, eval_dataset=eval_dataset)
>>> split.train_dataset[0].spatial_shape
(64, 64, 64)
>>> split.val_dataset[0].spatial_shape
(181, 217, 181)