clinicadl.split.make_kfold
- clinicadl.split.make_kfold(data: Path | str | DataFrame, n_splits: int = 5, output_dir: Path | str | None = None, subset_name: str = 'validation', stratification: str | bool = False, longitudinal: bool = False, seed: int | None = None) → Path
Performs K-Fold splitting on a DataFrame with optional stratification.
Stratification can be performed based on a categorical variable of the DataFrame.
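Stratified K-fold aims to keep the proportion of each category roughly the same in every fold. The sketch below illustrates one common way to do this at the participant level. It is not clinicadl's internal algorithm, which may differ (for instance by relying on a library routine), and the function name `stratified_fold_ids` is hypothetical. Each category (stratum) is shuffled with a seeded random generator, and its participants are then dealt round-robin across the folds:

```python
import random
from collections import defaultdict


def stratified_fold_ids(labels, n_splits, seed=None):
    """Assign each participant to a fold index in range(n_splits).

    labels maps participant_id -> categorical value (e.g. "M" or "F").
    Using a single label per participant (e.g. from its baseline row) is an
    assumption of this sketch.
    """
    if n_splits < 2:
        raise ValueError("n_splits must be at least 2")
    rng = random.Random(seed)  # seeded generator so the split is reproducible

    # Group participants by category (stratum).
    strata = defaultdict(list)
    for participant, label in labels.items():
        strata[label].append(participant)

    fold_of = {}
    offset = 0  # continue the round-robin across strata so fold sizes stay balanced
    for label in sorted(strata, key=str):
        members = sorted(strata[label])  # sort first so the result depends only on the seed
        rng.shuffle(members)
        for i, participant in enumerate(members):
            fold_of[participant] = (offset + i) % n_splits
        offset += len(members)
    return fold_of
```

For example, with 10 "F" and 10 "M" participants and `n_splits=5`, every fold receives exactly 2 participants of each category. Passing the same `seed` reproduces the same assignment.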
Note
make_kfold splits the participants in your data, not the individual sessions. This means that, if not all participants have the same number of sessions, you will likely end up with training sets of different sizes across your splits. In addition, by default, only one session per participant is kept in the validation sets (see the argument longitudinal).
- Parameters:
data (Union[Path, str, DataFrame]) – A pandas.DataFrame (or a path to a TSV file containing the dataframe) with the list of participant/session pairs to split.
n_splits (int, default=5) – Number of folds. Must be at least 2.
output_dir (Optional[Union[Path, str]], default=None) – Directory where to save the output files of the split, passed as a str or a pathlib.Path. If data is a path and output_dir is not passed, the parent directory of the TSV file will be used. If data is a pandas.DataFrame, output_dir must be passed.
subset_name (str, default="validation") – Name for the validation subset.
stratification (Union[str, bool], default=False) – Whether to perform stratification. If True, the column "sex" will be used for stratification. If a str is passed, the column with this name will be used. The column must contain a categorical variable.
longitudinal (bool, default=False) – Whether to include all the sessions of the validation participants in the validation set. If False, only the baseline session of each validation participant is included; if True, all their sessions are included. Regardless of this argument, all sessions are always kept in the training set.
seed (Optional[int], default=None) – Seed to control the randomness of the split. Useful for reproducibility.
- Returns:
Path – Directory containing the generated split files.
- Raises:
ValueError – If data is a pandas.DataFrame and no output_dir is passed.
DataFrameError – If the DataFrame does not contain the columns "participant_id" and "session_id".
KeyError – If the column passed via stratification cannot be found in the DataFrame.
ValueError – If the column passed via stratification is not a categorical variable.
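As the note above explains, the split is done on participants rather than on sessions, so a fold whose training participants have many sessions ends up with a larger training set. The standard-library sketch below reproduces that behaviour on (participant_id, session_id) pairs. It is an illustration only, not clinicadl's code, and `participant_kfold` is a hypothetical name. Participants are shuffled with a seeded generator and dealt into folds; training sets keep every session, while validation sets keep all sessions only when `longitudinal=True`. Taking the baseline as the smallest `session_id` is an assumption that works for zero-padded labels such as `ses-M000`.

```python
import random
from collections import defaultdict


def participant_kfold(rows, n_splits=5, longitudinal=False, seed=None):
    """Yield (train_rows, validation_rows) for each fold, splitting by participant.

    rows is a list of (participant_id, session_id) pairs. Using the smallest
    session_id as the baseline session is an assumption of this sketch.
    """
    sessions = defaultdict(list)
    for participant, session in rows:
        sessions[participant].append(session)

    participants = sorted(sessions)
    random.Random(seed).shuffle(participants)

    for k in range(n_splits):
        # Every n_splits-th shuffled participant, starting at k, is held out in fold k.
        val_ids = set(participants[k::n_splits])
        # The training set keeps all sessions of the remaining participants.
        train = [(p, s) for p, s in rows if p not in val_ids]
        if longitudinal:
            val = [(p, s) for p, s in rows if p in val_ids]
        else:
            # One (baseline) session per validation participant.
            val = [(p, min(sessions[p])) for p in sorted(val_ids)]
        yield train, val


rows = [
    ("sub-01", "ses-M000"), ("sub-01", "ses-M012"), ("sub-01", "ses-M024"),
    ("sub-02", "ses-M000"),
    ("sub-03", "ses-M006"), ("sub-03", "ses-M018"),
    ("sub-04", "ses-M000"),
]
for k, (train, val) in enumerate(participant_kfold(rows, n_splits=2, seed=0)):
    # The number of training sessions generally differs from one fold to the next.
    print(f"split-{k}: {len(train)} training sessions, {len(val)} validation sessions")
```

Because whole participants move between subsets, no participant ever appears in both the training and validation sets of the same fold, which is the main reason to split by participant in longitudinal data.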
See also
- clinicadl.split.make_split
Examples
>>> df.head(5)  # quick look at the data
  participant_id session_id age sex diagnosis
0        sub-003   ses-M000  40   M       MCI
1        sub-004   ses-M000  56   M        CN
2        sub-004   ses-M054  75   F       MCI
3        sub-005   ses-M006  85   F        CN
4        sub-005   ses-M018  64   M        AD
>>> len(df)
64
>>> from clinicadl.split import make_kfold
>>> split_dir = make_kfold(
...     df,
...     n_splits=5,
...     output_dir="splits",
...     stratification="sex",
... )
>>> split_dir
PosixPath('splits/5_fold')
# splits/5_fold
# ├── kfold_config.json
# ├── split-0
# │   ├── train.tsv
# │   ├── train_baseline.tsv
# │   └── validation_baseline.tsv
# ├── split-1
# │   └── ...
# ├── split-2
# │   └── ...
# ├── split-3
# │   └── ...
# └── split-4
#     └── ...
>>> import pandas as pd
>>> train_baseline = pd.read_csv(split_dir / "split-0" / "train_baseline.tsv", sep="\t")
>>> train_baseline.head(5)
  participant_id session_id sex
0        sub-003   ses-M000   M
1        sub-005   ses-M006   F
2        sub-006   ses-M006   M
3        sub-014   ses-M000   M
4        sub-015   ses-M006   F
>>> len(train_baseline)
32
>>> val_baseline = pd.read_csv(split_dir / "split-0" / "validation_baseline.tsv", sep="\t")
>>> val_baseline.head(5)
  participant_id session_id sex
0        sub-004   ses-M000   M
1        sub-007   ses-M006   F
2        sub-013   ses-M000   M
3        sub-023   ses-M000   M
4        sub-026   ses-M006   M
>>> len(val_baseline)
8
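After running make_kfold, it can be useful to check that no participant leaks between the training and validation sets of a fold, and that each participant is validated exactly once. The sketch below is a hypothetical helper (`check_kfold_dir` is not part of clinicadl) that reads the baseline TSV files with the standard `csv` module. It assumes the directory layout shown above, and it assumes that the validation file is named `<subset_name>_baseline.tsv`. That naming matches the default `validation_baseline.tsv` but is not confirmed for other values of subset_name.

```python
import csv
from pathlib import Path


def read_participants(tsv_path):
    """Return the set of participant_id values found in a TSV file."""
    with open(tsv_path, newline="") as f:
        return {row["participant_id"] for row in csv.DictReader(f, delimiter="\t")}


def check_kfold_dir(split_dir, n_splits, subset_name="validation"):
    """Check participant-level consistency of a K-fold split directory.

    Raises ValueError on the first problem found; otherwise returns the
    number of participants. File names follow the assumed layout
    split-<k>/train_baseline.tsv and split-<k>/<subset_name>_baseline.tsv.
    """
    split_dir = Path(split_dir)
    validated = set()  # participants already seen in a validation set
    everyone = None  # participants of split-0, used as the reference set
    for k in range(n_splits):
        fold = split_dir / f"split-{k}"
        train = read_participants(fold / "train_baseline.tsv")
        val = read_participants(fold / f"{subset_name}_baseline.tsv")
        if train & val:
            raise ValueError(f"split-{k}: participants in both subsets: {sorted(train & val)}")
        if validated & val:
            raise ValueError(f"split-{k}: participants validated twice: {sorted(validated & val)}")
        validated |= val
        if everyone is None:
            everyone = train | val
        elif (train | val) != everyone:
            raise ValueError(f"split-{k}: participant set differs from split-0")
    if validated != everyone:
        raise ValueError("some participants never appear in a validation set")
    return len(everyone)
```

On the example above, `check_kfold_dir(split_dir, n_splits=5)` would be expected to return 40, since split-0 holds 32 training and 8 validation participants.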