clinicadl.split.make_split
- clinicadl.split.make_split(data: Path | str | DataFrame, n_test: float = 0.2, output_dir: Path | str | None = None, subset_name: str = 'test', stratification: Sequence[str] | bool = False, p_categorical_threshold: float = 0.8, p_continuous_threshold: float = 0.8, longitudinal: bool = False, n_try_max: int = 1000, seed: int | None = None) -> Path
Performs a single train-test split on a DataFrame with optional stratification.
Stratification can be performed based on one or several variables present in the DataFrame:
- If a variable is categorical, a chi-squared test is performed to check that the train and test sets have the same distribution;
- If a variable is continuous, a t-test is performed.
make_split will try random splits until one yields a p-value greater than p_categorical_threshold for all categorical variables used for stratification, and greater than p_continuous_threshold for all continuous variables. So, p_categorical_threshold and p_continuous_threshold control the required level of similarity between the train and test distributions. The higher the threshold, the more demanding the similarity test. Therefore, too high a threshold may prevent you from finding a valid split.
- Parameters:
data (DataFrameType) – A pandas.DataFrame (or a path to a TSV file containing the dataframe) with the list of participant/session pairs to split.
n_test (float, default=0.2) – A positive float. If >=1, it specifies the number of test participants. If <1, it is treated as the proportion of the input participants to put in the test set.
Note: Here, we are talking about the number of participants. So, n_test=0.2 doesn't mean that 80% of your data is in the training set, but rather that 80% of your participants are in the training set.
output_dir (Optional[PathType], default=None) – Directory where the output files of the split are saved, passed as a str or a pathlib.Path. If data is a path and output_dir is not passed, the parent directory of the TSV file will be used.
subset_name (str, default="test") – Name for the test subset.
stratification (Union[Sequence[str], bool], default=False) – Whether to perform stratification. If True, the columns "age" and "sex" will be used for stratification. If a list of str is passed, these columns will be used.
p_categorical_threshold (float, default=0.80) – Threshold for acceptable categorical stratification. Must be between 0 and 1.
p_continuous_threshold (float, default=0.80) – Threshold for acceptable continuous stratification. Must be between 0 and 1.
longitudinal (bool, default=False) – Whether to include only the baseline sessions in the test set (longitudinal=False) or all the sessions of the test participants (longitudinal=True). Regardless of this argument, all sessions are always kept in the training set.
n_try_max (int, default=1000) – Maximum number of attempts to find a valid split.
seed (Optional[int], default=None) – Seed to control the randomness of the split. Useful for reproducibility.
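The semantics of n_test and longitudinal described above can be illustrated with a small standalone sketch. It is not clinicadl's implementation: the helper names resolve_n_test and select_test_sessions are hypothetical, the rounding rule for proportions is an assumption, and so is the choice of the first session in sorted order as the "baseline".

```python
def resolve_n_test(n_test: float, n_participants: int) -> int:
    """Turn n_test into a number of test participants (hypothetical helper)."""
    if n_test <= 0:
        raise ValueError("n_test must be positive")
    if n_test >= 1:
        # Values >= 1 are read as an absolute number of participants.
        count = int(n_test)
    else:
        # Values < 1 are read as a proportion of the participants.
        # Assumption: rounded to the nearest integer; the library may differ.
        count = round(n_test * n_participants)
    if count > n_participants:
        raise ValueError("n_test exceeds the number of participants")
    return count


def select_test_sessions(rows, test_ids, longitudinal=False):
    """Pick the test rows from a list of (participant_id, session_id) pairs."""
    selected = [row for row in rows if row[0] in test_ids]
    if longitudinal:
        # Keep every session of the test participants.
        return selected
    # Assumption: the baseline is the first session in sorted order.
    baseline = {}
    for participant_id, session_id in sorted(selected):
        baseline.setdefault(participant_id, session_id)
    return sorted(baseline.items())


rows = [
    ("sub-003", "ses-M000"),
    ("sub-004", "ses-M000"),
    ("sub-004", "ses-M054"),
    ("sub-005", "ses-M006"),
    ("sub-005", "ses-M018"),
]
n_participants = len({participant_id for participant_id, _ in rows})
print(resolve_n_test(0.2, 40), resolve_n_test(2, n_participants))
print(select_test_sessions(rows, {"sub-005"}))
print(select_test_sessions(rows, {"sub-005"}, longitudinal=True))
```

The point of the sketch is that both parameters act on participants first: the test participants are chosen, and only then is the set of their sessions narrowed or kept whole.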
- Returns:
Path – Directory containing the split files, including some statistics on the stratification variables.
- Raises:
ValueError – If data is a pandas.DataFrame and no output_dir is passed.
DataFrameError – If the DataFrame does not contain the columns "participant_id" and "session_id".
KeyError – If the stratification columns mentioned via stratification cannot be found in the DataFrame.
RuntimeError – If no good split was found after n_try_max tries.
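Since a RuntimeError signals that no split met the thresholds within n_try_max tries, one possible way to handle it is to retry with progressively lower thresholds. The sketch below is only an illustration: split_with_fallback and fake_split are hypothetical names, and fake_split is a stand-in so that the block runs without clinicadl installed.

```python
def split_with_fallback(split_fn, thresholds=(0.9, 0.8, 0.6)):
    """Call split_fn with decreasing thresholds until one attempt succeeds.

    split_fn must accept p_categorical_threshold and p_continuous_threshold
    as keyword arguments and raise RuntimeError when no split is found.
    Returns the result of split_fn and the threshold that worked.
    """
    last_error = None
    for threshold in thresholds:
        try:
            result = split_fn(
                p_categorical_threshold=threshold,
                p_continuous_threshold=threshold,
            )
            return result, threshold
        except RuntimeError as error:
            last_error = error  # too demanding: try the next threshold
    raise RuntimeError("no valid split for any threshold") from last_error


def fake_split(p_categorical_threshold, p_continuous_threshold):
    """Stand-in for make_split (hypothetical): fails above a 0.7 threshold."""
    if max(p_categorical_threshold, p_continuous_threshold) > 0.7:
        raise RuntimeError("no good split found")
    return "splits/split"


result, used_threshold = split_with_fallback(fake_split)
print(result, used_threshold)

# With clinicadl installed, split_fn could be built from the documented
# signature, for example:
#   from functools import partial
#   split_fn = partial(make_split, df, output_dir="splits",
#                      stratification=["sex", "age"])
```

Lowering the thresholds trades similarity between train and test distributions for a higher chance of finding a split, so the threshold that finally succeeded is worth recording.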
Examples
>>> df.head(5)  # quick look at the data
  participant_id session_id  age sex diagnosis
0        sub-003   ses-M000   40   M       MCI
1        sub-004   ses-M000   56   M        CN
2        sub-004   ses-M054   75   F       MCI
3        sub-005   ses-M006   85   F        CN
4        sub-005   ses-M018   64   M        AD
>>> len(df)
64
>>> from clinicadl.split import make_split
>>> split_dir = make_split(
...     df,
...     output_dir="splits",
...     stratification=["sex", "age"],
...     p_categorical_threshold=0.9,
...     p_continuous_threshold=0.9,
... )
>>> split_dir
PosixPath('splits/split')
# splits/split
# ├── single_split_config.json
# ├── split_categorical_stats.tsv
# ├── split_continuous_stats.tsv
# ├── test_baseline.tsv
# ├── train.tsv
# └── train_baseline.tsv
>>> import pandas as pd
>>> pd.read_csv(split_dir / "train.tsv", sep="\t").head(5)
  participant_id session_id
0        sub-005   ses-M006
1        sub-005   ses-M018
2        sub-065   ses-M006
3        sub-065   ses-M018
4        sub-044   ses-M000
>>> train_baseline = pd.read_csv(split_dir / "train_baseline.tsv", sep="\t")
>>> train_baseline.head(5)
  participant_id session_id sex  age
0        sub-005   ses-M006   F   85
1        sub-065   ses-M006   F   58
2        sub-044   ses-M000   M   64
3        sub-014   ses-M000   M   56
4        sub-043   ses-M000   M   71
>>> len(train_baseline)
32
>>> test_baseline = pd.read_csv(split_dir / "test_baseline.tsv", sep="\t")
>>> test_baseline.head(5)
  participant_id session_id sex  age
0        sub-037   ses-M006   F   56
1        sub-003   ses-M000   M   40
2        sub-077   ses-M006   F   23
3        sub-023   ses-M000   M   69
4        sub-057   ses-M006   F   77
>>> len(test_baseline)
8
>>> pd.read_csv(split_dir / "split_continuous_stats.tsv", sep="\t")
  label statistic  train  test
0   age      mean   62.8  63.6
1   age       std   18.0  22.4
>>> pd.read_csv(split_dir / "split_categorical_stats.tsv", sep="\t")
  label value   statistic  train  test
0   sex     F  proportion   0.41  0.38
1   sex     F       count   13.0   3.0
2   sex     M  proportion   0.59  0.62
3   sex     M       count   19.0   5.0
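To make the search behind these statistics more concrete, here is a rough, standard-library-only sketch of the retry loop described at the top of this page: random participant-level splits are drawn until the p-values clear both thresholds. It is not clinicadl's code. The function names and the toy data are hypothetical, the chi-squared test is restricted to two-level variables, and the t-test is replaced by a normal approximation to keep the example dependency-free.

```python
import math
import random
from statistics import mean, variance


def chi2_p_value(train, test):
    """Pearson chi-squared p-value for a two-level variable (1 degree of freedom)."""
    first = min(set(train) | set(test))
    a = sum(value == first for value in train)
    b = len(train) - a
    c = sum(value == first for value in test)
    d = len(test) - c
    denom = (a + b) * (c + d) * (a + c) * (b + d)
    if denom == 0:  # only one level present: nothing to compare
        return 1.0
    chi2 = (a + b + c + d) * (a * d - b * c) ** 2 / denom
    # With 1 degree of freedom, P(X > chi2) = erfc(sqrt(chi2 / 2)).
    return math.erfc(math.sqrt(chi2 / 2))


def mean_p_value(train, test):
    """Two-sided p-value for a difference in means.

    Normal approximation standing in for the t-test; needs >= 2 values per set.
    """
    se = math.sqrt(variance(train) / len(train) + variance(test) / len(test))
    if se == 0:
        return 1.0 if mean(train) == mean(test) else 0.0
    z = abs(mean(train) - mean(test)) / se
    return math.erfc(z / math.sqrt(2))


def find_split(participants, n_test, categorical, continuous,
               p_categorical_threshold=0.8, p_continuous_threshold=0.8,
               n_try_max=1000, seed=None):
    """Draw random test sets until every stratification variable passes."""
    rng = random.Random(seed)
    ids = sorted(participants)
    for _ in range(n_try_max):
        test_ids = set(rng.sample(ids, n_test))
        train = [participants[i] for i in ids if i not in test_ids]
        test = [participants[i] for i in ids if i in test_ids]
        categorical_ok = all(
            chi2_p_value([p[v] for p in train], [p[v] for p in test])
            > p_categorical_threshold
            for v in categorical
        )
        continuous_ok = all(
            mean_p_value([p[v] for p in train], [p[v] for p in test])
            > p_continuous_threshold
            for v in continuous
        )
        if categorical_ok and continuous_ok:
            return sorted(test_ids)
    raise RuntimeError(f"no valid split found after {n_try_max} tries")


# Toy data (hypothetical): one baseline row per participant.
participants = {
    f"sub-{i:03d}": {"sex": "F" if i % 2 else "M", "age": 40 + 3 * (i % 10)}
    for i in range(40)
}
test_ids = find_split(participants, n_test=8, categorical=["sex"],
                      continuous=["age"], seed=0)
print(test_ids)
```

Raising either threshold toward 1 shrinks the set of acceptable splits, which is why the loop is bounded by n_try_max and ends with a RuntimeError when nothing qualifies.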