clinicadl.split.SingleSplit

class clinicadl.split.SingleSplit(split_dir: Path | str)

Handles a single training/validation split, as opposed to KFold, which can handle several splits.

This object reads a split directory returned by make_split(). get_split() can then split any Dataset, provided that every (participant, session) pair in the dataset is listed in the split directory.
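To illustrate the condition above, here is a minimal, stdlib-only sketch of the filtering step. It is not ClinicaDL's implementation. The contents of the split directory are replaced by two in-memory sets of (participant, session) pairs, the helper partition_pairs is hypothetical, and raising a ValueError for an unlisted pair is an assumption of this sketch (this page does not say what the library does in that case).

```python
# Sketch only: two in-memory sets stand in for the files of a split directory.
Pair = tuple[str, str]  # (participant_id, session_id)


def partition_pairs(
    pairs: list[Pair],
    train_pairs: set[Pair],
    val_pairs: set[Pair],
) -> tuple[list[Pair], list[Pair]]:
    """Hypothetical helper: assign each pair to the training or validation subset."""
    train: list[Pair] = []
    val: list[Pair] = []
    for pair in pairs:
        if pair in train_pairs:
            train.append(pair)
        elif pair in val_pairs:
            val.append(pair)
        else:
            # Assumption of this sketch: a pair missing from the split is an error.
            raise ValueError(f"{pair} is not mentioned in the split")
    return train, val


pairs = [("sub-000", "ses-M000"), ("sub-010", "ses-M003"), ("sub-100", "ses-M012")]
train, val = partition_pairs(
    pairs,
    train_pairs={("sub-000", "ses-M000"), ("sub-100", "ses-M012")},
    val_pairs={("sub-010", "ses-M003")},
)
print(train)  # [('sub-000', 'ses-M000'), ('sub-100', 'ses-M012')]
print(val)  # [('sub-010', 'ses-M003')]
```

Pairs keep the order they have in the dataset, which matches the row order shown in the examples further down.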

Parameters:

split_dir (PathType) – The split directory, returned by make_split().

Raises:

FileNotFoundError – If split_dir does not exist or if a required file is missing in this directory.
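The check behind this FileNotFoundError can be sketched as follows. The names check_split_dir and REQUIRED_FILES are hypothetical, and so are the file names train.tsv and validation.tsv, because this page does not list which files a split directory must contain.

```python
from __future__ import annotations

import tempfile
from pathlib import Path

# Hypothetical: the files a real split directory must contain are not listed here.
REQUIRED_FILES = ("train.tsv", "validation.tsv")


def check_split_dir(split_dir: Path | str) -> Path:
    """Return split_dir as a Path, or raise FileNotFoundError (sketch only)."""
    split_dir = Path(split_dir)
    if not split_dir.is_dir():
        raise FileNotFoundError(f"Split directory not found: {split_dir}")
    for name in REQUIRED_FILES:
        if not (split_dir / name).is_file():
            raise FileNotFoundError(f"Missing file in {split_dir}: {name}")
    return split_dir


with tempfile.TemporaryDirectory() as tmp:
    try:
        check_split_dir(tmp)  # the directory exists but is still empty
    except FileNotFoundError as err:
        print(err)  # reports the first missing file
    for name in REQUIRED_FILES:
        (Path(tmp) / name).touch()
    print(check_split_dir(tmp) == Path(tmp))  # True
```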

get_split(dataset: Dataset, eval_dataset: Dataset | None = None) → Split

Splits a dataset according to the split found in the split directory.

Parameters:
  • dataset (Dataset) – The dataset to split.

  • eval_dataset (Optional[Dataset], default=None) – If provided, the validation dataset is built from eval_dataset, while the training dataset is still built from dataset (see Examples). If None, both the training and validation datasets are built from dataset.

Returns:

Split – A Split object, with the training and validation datasets.
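The role of eval_dataset can be sketched with plain dictionaries keyed by (participant, session) pairs. This is only an illustration under stated assumptions: get_split_sketch and SplitSketch are hypothetical names, the string values stand in for samples (patches or whole images), and the real Split object exposes train_dataset and val_dataset as Dataset objects rather than dicts.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

Pair = tuple[str, str]  # (participant_id, session_id)


@dataclass
class SplitSketch:
    """Hypothetical stand-in for the Split object returned by get_split()."""

    train_dataset: dict[Pair, str]
    val_dataset: dict[Pair, str]


def get_split_sketch(
    dataset: dict[Pair, str],
    train_pairs: set[Pair],
    val_pairs: set[Pair],
    eval_dataset: Optional[dict[Pair, str]] = None,
) -> SplitSketch:
    """Hypothetical helper: training samples always come from `dataset`."""
    # Validation samples come from eval_dataset when given, else from dataset.
    val_source = dataset if eval_dataset is None else eval_dataset
    return SplitSketch(
        train_dataset={p: s for p, s in dataset.items() if p in train_pairs},
        val_dataset={p: s for p, s in val_source.items() if p in val_pairs},
    )


pairs = [("sub-000", "ses-M000"), ("sub-010", "ses-M003")]
patches = {p: "patch" for p in pairs}  # stands in for the patch-based dataset
images = {p: "image" for p in pairs}  # stands in for the whole-image dataset
split = get_split_sketch(
    patches,
    train_pairs={("sub-000", "ses-M000")},
    val_pairs={("sub-010", "ses-M003")},
    eval_dataset=images,
)
print(split.train_dataset)  # {('sub-000', 'ses-M000'): 'patch'}
print(split.val_dataset)  # {('sub-010', 'ses-M003'): 'image'}
```

Calling get_split_sketch without eval_dataset would fill both subsets with "patch" samples, mirroring the documented behavior when eval_dataset is None.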

Examples

>>> df  # a quick look at the data
    participant_id  session_id
0   sub-000         ses-M000
1   sub-000         ses-M003
2   sub-010         ses-M003
3   sub-010         ses-M012
4   sub-100         ses-M000
5   sub-100         ses-M012
6   sub-999         ses-M099
7   sub-999         ses-M999
from clinicadl.split import SingleSplit
from clinicadl.data.datasets import BidsDataset
from clinicadl.transforms import TransformsHandler, extraction

dataset = BidsDataset(
    "bids_dir",
    data=df,
    transforms=TransformsHandler(extraction=extraction.Patch(patch_size=64)),
    ...
)
splitter = SingleSplit("split_dir")
split = splitter.get_split(dataset)
>>> split.train_dataset.df
    participant_id  session_id
0   sub-000         ses-M000
1   sub-000         ses-M003
2   sub-100         ses-M000
3   sub-100         ses-M012
4   sub-999         ses-M099
5   sub-999         ses-M999
>>> split.val_dataset.df
    participant_id  session_id
0   sub-010         ses-M003

Now, let’s say you want to train your model on patches but evaluate it on whole images:

eval_dataset = BidsDataset(
    "bids_dir",
    data=df,
    transforms=TransformsHandler(),
    ...
)

split = splitter.get_split(dataset, eval_dataset=eval_dataset)
>>> split.train_dataset[0].spatial_shape
(64, 64, 64)
>>> split.val_dataset[0].spatial_shape
(181, 217, 181)