clinicadl.split.make_split

clinicadl.split.make_split(data: Path | str | DataFrame, n_test: float = 0.2, output_dir: Path | str | None = None, subset_name: str = 'test', stratification: Sequence[str] | bool = False, p_categorical_threshold: float = 0.8, p_continuous_threshold: float = 0.8, longitudinal: bool = False, n_try_max: int = 1000, seed: int | None = None) → Path

Performs a single train-test split on a DataFrame with optional stratification.

Stratification can be performed based on one or several variables present in the DataFrame:

  • If a variable is categorical, a chi-squared test is performed to check that the train and test sets have the same distribution;

  • If a variable is continuous, a t-test is performed.
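To make the two kinds of checks concrete, here is a minimal stdlib-only sketch. It is not ClinicaDL's implementation: the function names are hypothetical, the categorical check is restricted to a two-level variable (so the chi-squared statistic has one degree of freedom and its p-value is erfc(sqrt(chi2 / 2))), no continuity correction is applied, and the continuous check swaps the exact t distribution for a large-sample normal approximation of Welch's t-test.

```python
import math
from statistics import mean, variance


def chi2_pvalue_binary(train_counts, test_counts):
    """Chi-squared homogeneity test for a two-level variable (e.g. sex).

    train_counts and test_counts are (count_level_a, count_level_b) pairs.
    Assumes both levels occur at least once overall. The 2x2 table has one
    degree of freedom, so the p-value is erfc(sqrt(chi2 / 2)).
    No continuity correction is applied (an assumption of this sketch).
    """
    table = [list(train_counts), list(test_counts)]
    row_totals = [sum(row) for row in table]
    col_totals = [sum(col) for col in zip(*table)]
    total = sum(row_totals)
    chi2 = 0.0
    for i, row in enumerate(table):
        for j, observed in enumerate(row):
            expected = row_totals[i] * col_totals[j] / total
            chi2 += (observed - expected) ** 2 / expected
    return math.erfc(math.sqrt(chi2 / 2))


def welch_pvalue_normal_approx(train_values, test_values):
    """Welch t-statistic with a two-sided p-value from a normal approximation.

    NOT an exact t-test: the t distribution is replaced by a standard
    normal, which is only reasonable for large samples. Each set needs at
    least two values and the pooled standard error must be non-zero.
    """
    se = math.sqrt(variance(train_values) / len(train_values)
                   + variance(test_values) / len(test_values))
    t = (mean(train_values) - mean(test_values)) / se
    return math.erfc(abs(t) / math.sqrt(2))
```

A high p-value means the test found no evidence that the train and test distributions differ, which is why make_split accepts a split only when the p-values are *above* the thresholds.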

make_split tries random splits until it finds one whose p-values are greater than p_categorical_threshold for every categorical stratification variable and greater than p_continuous_threshold for every continuous one. p_categorical_threshold and p_continuous_threshold therefore control how similar the train and test distributions must be: the higher the threshold, the more demanding the similarity test. A threshold that is too high may prevent make_split from finding a valid split within n_try_max attempts.
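The retry strategy can be sketched as a loop over seeded random participant-level splits. This is an illustration under assumptions, not ClinicaDL's code: random_participant_split and the is_valid callback are hypothetical, and is_valid stands in for "every p-value exceeds its threshold". Sampling participants rather than rows keeps all sessions of a participant on the same side of the split.

```python
import random


def random_participant_split(participants, n_test, is_valid, n_try_max=1000, seed=None):
    """Draw random participant-level splits until is_valid(train, test) is True.

    participants: iterable of participant IDs (duplicates are ignored).
    n_test: number of test participants (an integer in this sketch).
    is_valid: hypothetical callback standing in for the p-value checks.
    """
    rng = random.Random(seed)  # dedicated, seeded generator for reproducibility
    pool = sorted(set(participants))  # fixed order so the seed alone decides the draw
    for _ in range(n_try_max):
        test = set(rng.sample(pool, n_test))
        train = [p for p in pool if p not in test]
        if is_valid(train, sorted(test)):
            return train, sorted(test)
    raise RuntimeError(f"No valid split found after {n_try_max} tries.")


subjects = [f"sub-{i:03d}" for i in range(40)]
train, test = random_participant_split(subjects, 8, lambda tr, te: True, seed=0)
```

In a real check, is_valid would compute one p-value per stratification variable (for instance with tests like the ones sketched above) and return True only if all of them exceed the corresponding threshold.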

Parameters:
  • data (DataFrameType) – A pandas.DataFrame (or a path to a TSV file containing the dataframe) with the list of participant/session pairs to split.

  • n_test (float, default=0.2) –

    A positive float. If >=1, it specifies the number of test participants. If <1, it is treated as the proportion of the input participants to put in the test set.

    Note

    n_test refers to participants, not sessions. So n_test=0.2 does not mean that 80% of your sessions end up in the training set, but that 80% of your participants do.

  • output_dir (Optional[PathType], default=None) – Directory where to save the output files of the split, passed as a str or a pathlib.Path. If data is a path and output_dir is not passed, the parent directory of the TSV file will be used.

  • subset_name (str, default="test") – Name for the test subset.

  • stratification (Union[Sequence[str], bool], default=False) – Whether to perform stratification. If True, the columns "age" and "sex" will be used for stratification. If a list of str is passed, these columns will be used.

  • p_categorical_threshold (float, default=0.80) – Threshold for acceptable categorical stratification. Must be between 0 and 1.

  • p_continuous_threshold (float, default=0.80) – Threshold for acceptable continuous stratification. Must be between 0 and 1.

  • longitudinal (bool, default=False) – Whether to include all the sessions of the test participants in the test set. If False, only the baseline sessions of the test participants are included; if True, all their sessions are included. Regardless of this argument, all sessions of the training participants are always kept in the training set.

  • n_try_max (int, default=1000) – Maximum number of attempts to find a valid split.

  • seed (Optional[int], default=None) – Seed to control the randomness of the split. Useful for reproducibility.
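The n_test and longitudinal parameters can be illustrated with the short sketch below. The helper names are hypothetical, and two details are assumptions rather than documented behavior: a proportion is converted to a participant count with round(), and "baseline" is taken to mean each participant's earliest session (consistent with sub-005's baseline being ses-M006 in the example below). Session IDs are assumed to follow the ses-M<months> pattern.

```python
def resolve_n_test(n_test, n_participants):
    """Turn n_test into a number of test participants.

    n_test >= 1 is read as a count, n_test < 1 as a proportion of participants.
    Using round() for proportions is an assumption of this sketch.
    """
    if n_test <= 0:
        raise ValueError("n_test must be positive")
    if n_test >= 1:
        return int(n_test)
    return round(n_test * n_participants)


def session_month(session_id):
    """'ses-M018' -> 18 (assumes the ses-M<months> naming used in the example)."""
    return int(session_id.split("-M", 1)[1])


def select_test_sessions(rows, test_participants, longitudinal=False):
    """Return the (participant_id, session_id) rows that go into the test set.

    With longitudinal=False, keep only each test participant's earliest session.
    """
    selected = [row for row in rows if row[0] in test_participants]
    if longitudinal:
        return selected
    baseline = {}
    for participant, session in selected:
        current = baseline.get(participant)
        if current is None or session_month(session) < session_month(current):
            baseline[participant] = session
    return sorted(baseline.items())
```

With 40 participants, resolve_n_test(0.2, 40) gives 8 test participants, which matches the 32/8 participant split shown in the example below.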

Returns:

Path – Directory containing the split files, including some statistics on the stratification variables.

Raises:
  • ValueError – If data is a pandas.DataFrame and no output_dir is passed.

  • DataFrameError – If the DataFrame does not contain the columns "participant_id" and "session_id".

  • KeyError – If the stratification columns mentioned via stratification cannot be found in the DataFrame.

  • RuntimeError – If no good split was found after n_try_max tries.
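The input-related error conditions above, together with how the stratification argument selects columns, can be mirrored on a plain list of column names. This is a hypothetical sketch: check_inputs is not part of ClinicaDL, the DataFrameError class here is only a local stand-in for ClinicaDL's exception, and the order in which the checks run is an assumption.

```python
class DataFrameError(Exception):
    """Hypothetical local stand-in for ClinicaDL's DataFrameError."""


def check_inputs(columns, stratification=False, data_is_path=False, output_dir=None):
    """Mirror the documented error conditions; return the stratification columns."""
    if not data_is_path and output_dir is None:
        raise ValueError("output_dir is required when data is a DataFrame.")
    missing = {"participant_id", "session_id"} - set(columns)
    if missing:
        raise DataFrameError(f"Missing required columns: {sorted(missing)}")
    if stratification is True:
        strat_cols = ["age", "sex"]  # documented default columns
    elif stratification is False:
        strat_cols = []  # no stratification
    else:
        strat_cols = list(stratification)  # user-provided column names
    unknown = [c for c in strat_cols if c not in columns]
    if unknown:
        raise KeyError(f"Stratification columns not found: {unknown}")
    return strat_cols
```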

See also

make_kfold()

Examples

>>> df.head(5)  # quick look at the data
    participant_id  session_id      age     sex     diagnosis
0   sub-003         ses-M000        40      M       MCI
1   sub-004         ses-M000        56      M       CN
2   sub-004         ses-M054        75      F       MCI
3   sub-005         ses-M006        85      F       CN
4   sub-005         ses-M018        64      M       AD
>>> len(df)
64
>>> import pandas as pd
>>> from clinicadl.split import make_split
>>> split_dir = make_split(
...     df,
...     output_dir="splits",
...     stratification=["sex", "age"],
...     p_categorical_threshold=0.9,
...     p_continuous_threshold=0.9,
... )
>>> split_dir
PosixPath('splits/split')
# splits/split
# ├── single_split_config.json
# ├── split_categorical_stats.tsv
# ├── split_continuous_stats.tsv
# ├── test_baseline.tsv
# ├── train.tsv
# └── train_baseline.tsv
>>> pd.read_csv(split_dir / "train.tsv", sep="\t").head(5)
    participant_id  session_id
0   sub-005         ses-M006
1   sub-005         ses-M018
2   sub-065         ses-M006
3   sub-065         ses-M018
4   sub-044         ses-M000
>>> train_baseline = pd.read_csv(split_dir / "train_baseline.tsv", sep="\t")
>>> train_baseline.head(5)
    participant_id  session_id      sex     age
0   sub-005         ses-M006        F       85
1   sub-065         ses-M006        F       58
2   sub-044         ses-M000        M       64
3   sub-014         ses-M000        M       56
4   sub-043         ses-M000        M       71
>>> len(train_baseline)
32
>>> test_baseline = pd.read_csv(split_dir / "test_baseline.tsv", sep="\t")
>>> test_baseline.head(5)
    participant_id  session_id      sex     age
0   sub-037         ses-M006        F       56
1   sub-003         ses-M000        M       40
2   sub-077         ses-M006        F       23
3   sub-023         ses-M000        M       69
4   sub-057         ses-M006        F       77
>>> len(test_baseline)
8
>>> pd.read_csv(split_dir / "split_continuous_stats.tsv", sep="\t")
    label   statistic       train   test
0   age         mean        62.8    63.6
1   age         std         18.0    22.4
>>> pd.read_csv(split_dir / "split_categorical_stats.tsv", sep="\t")
    label   value   statistic       train   test
0   sex         F       proportion  0.41    0.38
1   sex         F       count       13.0    3.0
2   sex         M       proportion  0.59    0.62
3   sex         M       count       19.0    5.0
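After a split, a quick sanity check is to confirm that no participant appears on both sides. The sketch below uses only the standard library and assumes the file names from the example above (train.tsv and test_baseline.tsv, i.e. the default subset_name with longitudinal=False); the helper names are hypothetical.

```python
import csv
from pathlib import Path


def participants_in(tsv_path):
    """Return the set of participant_id values in a tab-separated split file."""
    with open(tsv_path, newline="") as f:
        return {row["participant_id"] for row in csv.DictReader(f, delimiter="\t")}


def check_no_leakage(split_dir):
    """Raise AssertionError if a participant is in both train and test files.

    Returns (number of train participants, number of test participants).
    """
    split_dir = Path(split_dir)
    train = participants_in(split_dir / "train.tsv")
    test = participants_in(split_dir / "test_baseline.tsv")
    overlap = train & test
    if overlap:
        raise AssertionError(f"Participants in both sets: {sorted(overlap)}")
    return len(train), len(test)
```

On the example above, check_no_leakage(split_dir) would be expected to return the participant counts of the two sets, since the split is performed at the participant level.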