clinicadl.split.make_kfold

clinicadl.split.make_kfold(data: Path | str | DataFrame, n_splits: int = 5, output_dir: Path | str | None = None, subset_name: str = 'validation', stratification: str | bool = False, longitudinal: bool = False, seed: int | None = None) -> Path

Performs K-Fold splitting on a DataFrame with optional stratification.

Stratification can be performed based on a categorical variable of the DataFrame.

Note

make_kfold splits the participants in your data, not the sessions. Consequently, if participants do not all have the same number of sessions, the training sets will likely differ in size across splits. In addition, by default, only one session per participant is kept in the validation sets (see the argument longitudinal).
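The toy sketch below illustrates this note; it is not clinicadl code. It assigns hypothetical participants to folds in a fixed round-robin order (the real function shuffles them and can stratify), then counts the sessions in each training set. Because sub-01 has three sessions and sub-02 only one, the two training sets end up with different sizes even though each fold holds two participants.

```python
from collections import Counter

# Hypothetical toy data: (participant_id, session_id) pairs.
# Participants deliberately have different numbers of sessions.
sessions = [
    ("sub-01", "ses-M000"), ("sub-01", "ses-M012"), ("sub-01", "ses-M024"),
    ("sub-02", "ses-M000"),
    ("sub-03", "ses-M000"), ("sub-03", "ses-M006"),
    ("sub-04", "ses-M000"),
]


def participant_folds(participants, n_splits):
    """Assign participants (not sessions) to folds in round-robin order."""
    return [participants[i::n_splits] for i in range(n_splits)]


participants = sorted({p for p, _ in sessions})
n_sessions = Counter(p for p, _ in sessions)
folds = participant_folds(participants, n_splits=2)

# Training set of fold k = all sessions of the participants outside fold k.
train_sizes = [
    sum(n for p, n in n_sessions.items() if p not in fold) for fold in folds
]
print(train_sizes)
```

Here train_sizes is [2, 5]: the first training set contains sub-02 and sub-04 (one session each), the second contains sub-01 and sub-03 (five sessions in total).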

Parameters:
  • data (Union[Path, str, DataFrame]) – A pandas.DataFrame, or a path to a TSV file containing the DataFrame, with the list of participant/session pairs to split.

  • n_splits (int, default=5) – Number of folds. Must be at least 2.

  • output_dir (Optional[Union[Path, str]], default=None) – Directory in which to save the output files of the split, passed as a str or a pathlib.Path. If data is a path and output_dir is not passed, the parent directory of the TSV file is used.

  • subset_name (str, default="validation") – Name for the validation subset.

  • stratification (Union[str, bool], default=False) – Whether to perform stratification. If True, the column "sex" is used for stratification. If a str is passed, the column with that name is used. The variable in that column must be categorical.

  • longitudinal (bool, default=False) – Whether to include all sessions of the validation participants in the validation set. If False, only their baseline session is included; if True, all their sessions are included. Regardless of this argument, all sessions are always kept in the training set.

  • seed (Optional[int], default=None) – Seed to control the randomness of the split. Useful for reproducibility.
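To make the interaction of stratification, longitudinal and seed more concrete, here is a minimal pure-Python sketch under stated assumptions. It is not clinicadl's implementation: the helpers stratified_participant_folds and baseline_sessions are hypothetical, participants are dealt across folds stratum by stratum after a seeded shuffle, and the baseline session is assumed to be the one with the smallest zero-padded session label (e.g. ses-M000 before ses-M006).

```python
import random
from collections import defaultdict


def stratified_participant_folds(participant_labels, n_splits, seed=None):
    """Assign participants to n_splits folds, balancing a categorical label.

    Hypothetical helper. participant_labels maps participant_id -> category
    (e.g. "M"/"F"). Returns a list of n_splits lists of participant_ids.
    """
    if n_splits < 2:
        raise ValueError("n_splits must be at least 2")
    rng = random.Random(seed)  # same seed and same input -> same folds
    by_label = defaultdict(list)
    for participant, label in sorted(participant_labels.items()):
        by_label[label].append(participant)
    folds = [[] for _ in range(n_splits)]
    offset = 0
    for label in sorted(by_label):
        members = by_label[label]
        rng.shuffle(members)
        # Deal this stratum across the folds, continuing where the previous
        # stratum stopped so that fold sizes stay balanced overall.
        for i, participant in enumerate(members):
            folds[(offset + i) % n_splits].append(participant)
        offset += len(members)
    return folds


def baseline_sessions(sessions):
    """Keep one session per participant, as with longitudinal=False.

    Hypothetical helper. Assumes the baseline is the smallest session_id,
    which holds for zero-padded labels such as "ses-M000" < "ses-M006".
    """
    first = {}
    for participant, session in sorted(sessions):
        first.setdefault(participant, session)
    return sorted(first.items())


labels = {"sub-01": "M", "sub-02": "F", "sub-03": "M",
          "sub-04": "F", "sub-05": "M", "sub-06": "F"}
folds = stratified_participant_folds(labels, n_splits=3, seed=42)
print(folds)
print(baseline_sessions([("sub-01", "ses-M012"), ("sub-01", "ses-M000"),
                         ("sub-02", "ses-M006")]))
```

With three "M" and three "F" participants and n_splits=3, each fold receives one participant of each sex, and calling the function again with seed=42 returns the same folds.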

Returns:

Path – Directory containing the generated split files.

Raises:
  • ValueError – If data is a pandas.DataFrame and no output_dir is passed.

  • DataFrameError – If the DataFrame does not contain the columns "participant_id" and "session_id".

  • KeyError – If the stratification column specified via stratification cannot be found in the DataFrame.

  • ValueError – If the stratification column specified via stratification is not a categorical variable.

See also

clinicadl.split.make_split

Examples

>>> df.head(5)  # quick look at the data
    participant_id  session_id      age     sex     diagnosis
0   sub-003         ses-M000        40      M       MCI
1   sub-004         ses-M000        56      M       CN
2   sub-004         ses-M054        75      F       MCI
3   sub-005         ses-M006        85      F       CN
4   sub-005         ses-M018        64      M       AD
>>> len(df)
64
>>> import pandas as pd
>>> from clinicadl.split import make_kfold
>>> split_dir = make_kfold(
...     df,
...     n_splits=5,
...     output_dir="splits",
...     stratification="sex",
... )
>>> split_dir
PosixPath('splits/5_fold')
# splits/5_fold
# ├── kfold_config.json
# ├── split-0
# │   ├── train.tsv
# │   ├── train_baseline.tsv
# │   └── validation_baseline.tsv
# ├── split-1
# │   └── ...
# ├── split-2
# │   └── ...
# ├── split-3
# │   └── ...
# └── split-4
#     └── ...
>>> train_baseline = pd.read_csv(split_dir / "split-0" / "train_baseline.tsv", sep="\t")
>>> train_baseline.head(5)
    participant_id  session_id      sex
0   sub-003         ses-M000        M
1   sub-005         ses-M006        F
2   sub-006         ses-M006        M
3   sub-014         ses-M000        M
4   sub-015         ses-M006        F
>>> len(train_baseline)
32
>>> val_baseline = pd.read_csv(split_dir / "split-0" / "validation_baseline.tsv", sep="\t")
>>> val_baseline.head(5)
    participant_id  session_id      sex
0   sub-004         ses-M000        M
1   sub-007         ses-M006        F
2   sub-013         ses-M000        M
3   sub-023         ses-M000        M
4   sub-026         ses-M006        M
>>> len(val_baseline)
8
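A quick sanity check on the generated splits can catch mistakes in code that consumes them. The sketch below uses a hypothetical helper, check_kfold_splits, that is not part of clinicadl: it takes, for each split, the participant IDs of the training and validation sets, and returns True only if no participant appears on both sides of a split and every participant is used for validation in exactly one split, as expected from participant-level K-fold splitting.

```python
def check_kfold_splits(splits):
    """Hypothetical sanity check for participant-level K-fold splits.

    splits: list of (train_participants, validation_participants) pairs.
    Returns True if train and validation never share a participant within a
    split and every participant is in exactly one validation set.
    """
    validated = []
    everyone = set()
    for train, validation in splits:
        train, validation = set(train), set(validation)
        if train & validation:
            return False  # leakage: same participant on both sides
        validated.extend(validation)
        everyone |= train | validation
    # Exactly once: no duplicates in `validated`, and nobody missing from it.
    return sorted(validated) == sorted(everyone)


splits = [(["sub-01", "sub-02"], ["sub-03"]),
          (["sub-01", "sub-03"], ["sub-02"]),
          (["sub-02", "sub-03"], ["sub-01"])]
print(check_kfold_splits(splits))
```

On the output above, the input pairs could be built from the participant_id column of split-k/train.tsv and split-k/validation_baseline.tsv for k in range(5), each read with pd.read_csv(..., sep="\t").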