# Source code for clinicadl.split.make_splits.single_split

from logging import getLogger
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
from scipy.stats import chisquare, ttest_ind
from sklearn.model_selection import ShuffleSplit

from clinicadl.split.splitter.single_split import SingleSplitConfig
from clinicadl.utils.dictionary.words import (
    AGE,
    COUNT,
    LABEL,
    MEAN,
    PROPORTION,
    SEX,
    SPLIT,
    STATISTIC,
    STD,
    TEST,
    TRAIN,
    VALUE,
)
from clinicadl.utils.tsvtools import read_df
from clinicadl.utils.typing import DataFrameType, PathType

from .utils import (
    extract_baseline,
    find_available_split_dir,
    write_to_tsv,
)

# Module-level logger; make_split uses it to report how many attempts were needed
# to find a valid split.
logger = getLogger(__name__)


def make_split(
    data: DataFrameType,
    n_test: float = 0.2,
    output_dir: Optional[PathType] = None,
    subset_name: str = TEST,
    stratification: Union[Sequence[str], bool] = False,
    p_categorical_threshold: float = 0.80,
    p_continuous_threshold: float = 0.80,
    longitudinal: bool = False,
    n_try_max: int = 1000,
    seed: Optional[int] = None,
) -> Path:
    r"""
    Performs a single train-test split on a DataFrame with optional stratification.

    Stratification can be performed based on one or several variables present in
    the DataFrame (numeric variables with enough distinct values are treated as
    continuous, the others as categorical):

    - If a variable is **categorical**, a `chi-squared test
      <https://en.wikipedia.org/wiki/Chi-squared_test>`_ is performed to check that
      the train and test sets have the same distribution;
    - If a variable is **continuous**, a `t-test
      <https://en.wikipedia.org/wiki/Student%27s_t-test>`_ is performed.

    ``make_split`` will try random splits until one split shows p-values greater
    than or equal to ``p_categorical_threshold`` for all categorical variables used
    for stratification, and greater than or equal to ``p_continuous_threshold`` for
    all continuous variables. So, ``p_categorical_threshold`` and
    ``p_continuous_threshold`` control the required level of similarity between the
    train and the test distributions. The higher the thresholds, the more demanding
    the similarity tests. Therefore, too high a threshold may prevent you from
    finding a valid split.

    Parameters
    ----------
    data : DataFrameType
        A :py:class:`pandas.DataFrame` (or a path to a ``TSV`` file containing the
        dataframe) with the list of participant/session pairs to split.
    n_test : float, default=0.2
        A non-negative number. If ``>=1``, it specifies the number of test
        participants. If ``<1``, it is treated as a proportion of the input
        participants to have in the test data. If it resolves to ``0``, all
        participants are put in the training set.

        .. note::
            Here, we are talking about number of **participants**. So, if
            ``n_test=0.2``, it doesn't mean that you have 80% of your data in the
            training set, but rather that you have 80% of your participants in the
            training set.

    output_dir : Optional[PathType], default=None
        Directory where to save the output files of the split, passed as a ``str``
        or a :py:class:`pathlib.Path`. If ``data`` is a path and ``output_dir`` is
        not passed, the parent directory of the ``TSV`` file will be used.
    subset_name : str, default="test"
        Name for the test subset.
    stratification : Union[Sequence[str], bool], default=False
        Whether to perform stratification. If ``True``, the columns ``"age"`` and
        ``"sex"`` will be used for stratification. If a list of ``str`` is passed,
        these columns will be used.
    p_categorical_threshold : float, default=0.80
        Threshold for acceptable categorical stratification. Must be **between 0
        and 1**.
    p_continuous_threshold : float, default=0.80
        Threshold for acceptable continuous stratification. Must be **between 0
        and 1**.
    longitudinal : bool, default=False
        If ``False``, only the baseline sessions are included in the test set. If
        ``True``, all the sessions of the test participants are included.
        Regardless of this argument, all sessions are always kept in the training
        set.
    n_try_max : int, default=1000
        Maximum number of attempts to find a valid split.
    seed : Optional[int], default=None
        Seed to control the randomness of the split. Useful for reproducibility.

    Returns
    -------
    Path
        Directory containing the split files, including some statistics on the
        stratification variables.

    Raises
    ------
    ValueError
        If ``data`` is a :py:class:`pandas.DataFrame` and no ``output_dir`` is
        passed, or if ``n_test`` is negative.
    DataFrameError
        If the DataFrame does not contain the columns ``"participant_id"`` and
        ``"session_id"``.
    KeyError
        If the stratification columns mentioned via ``stratification`` cannot be
        found in the DataFrame.
    RuntimeError
        If no good split was found after ``n_try_max`` tries.

    See Also
    --------
    :py:func:`~clinicadl.split.make_kfold`

    Examples
    --------
    >>> df.head(5)  # quick look at the data
      participant_id session_id  age sex diagnosis
    0        sub-003   ses-M000   40   M       MCI
    1        sub-004   ses-M000   56   M        CN
    2        sub-004   ses-M054   75   F       MCI
    3        sub-005   ses-M006   85   F        CN
    4        sub-005   ses-M018   64   M        AD
    >>> len(df)
    64
    >>> from clinicadl.split import make_split
    >>> split_dir = make_split(
    ...     df,
    ...     output_dir="splits",
    ...     stratification=["sex", "age"],
    ...     p_categorical_threshold=0.9,
    ...     p_continuous_threshold=0.9,
    ... )
    >>> split_dir
    PosixPath('splits/split')

    # splits/split
    # ├── single_split_config.json
    # ├── split_categorical_stats.tsv
    # ├── split_continuous_stats.tsv
    # ├── test_baseline.tsv
    # ├── train.tsv
    # └── train_baseline.tsv

    >>> pd.read_csv(split_dir / "train.tsv", sep="\t").head(5)
      participant_id session_id
    0        sub-005   ses-M006
    1        sub-005   ses-M018
    2        sub-065   ses-M006
    3        sub-065   ses-M018
    4        sub-044   ses-M000
    >>> train_baseline = pd.read_csv(split_dir / "train_baseline.tsv", sep="\t")
    >>> train_baseline.head(5)
      participant_id session_id sex  age
    0        sub-005   ses-M006   F   85
    1        sub-065   ses-M006   F   58
    2        sub-044   ses-M000   M   64
    3        sub-014   ses-M000   M   56
    4        sub-043   ses-M000   M   71
    >>> len(train_baseline)
    32
    >>> test_baseline = pd.read_csv(split_dir / "test_baseline.tsv", sep="\t")
    >>> test_baseline.head(5)
      participant_id session_id sex  age
    0        sub-037   ses-M006   F   56
    1        sub-003   ses-M000   M   40
    2        sub-077   ses-M006   F   23
    3        sub-023   ses-M000   M   69
    4        sub-057   ses-M006   F   77
    >>> len(test_baseline)
    8
    >>> pd.read_csv(split_dir / "split_continuous_stats.tsv", sep="\t")
      label statistic  train  test
    0   age      mean   62.8  63.6
    1   age       std   18.0  22.4
    >>> pd.read_csv(split_dir / "split_categorical_stats.tsv", sep="\t")
      label value   statistic  train  test
    0   sex     F  proportion   0.41  0.38
    1   sex     F       count  13.00  3.00
    2   sex     M  proportion   0.59  0.62
    3   sex     M       count  19.00  5.00
    """
    df = read_df(data, check_duplicates=False)

    # Resolve the output directory: default to the TSV's folder when a path is given.
    if isinstance(data, (str, Path)):
        output_dir = output_dir or Path(data).parent
    elif isinstance(data, pd.DataFrame) and not output_dir:
        raise ValueError(
            "If you pass a DataFrame, you must specify the output directory."
        )
    output_dir = Path(output_dir)

    # Validate before touching the file system: a value in (-1, 0) would otherwise
    # be truncated to 0 and silently produce an empty test set.
    if n_test < 0:
        raise ValueError(f"n_test must be a non-negative number. Got: {n_test}")

    stratification = _validate_stratification(df, stratification)
    baseline_df = extract_baseline(df, columns=stratification)
    split_dir = find_available_split_dir(output_dir, SPLIT)

    config = SingleSplitConfig(
        split_dir=split_dir,
        subset_name=subset_name,
        longitudinal=longitudinal,
        n_test=n_test,
        p_continuous_threshold=p_continuous_threshold,
        p_categorical_threshold=p_categorical_threshold,
        stratification=stratification,
        seed=seed,
    )

    # From here on, n_test is an absolute number of (baseline) participants.
    n_test = int(n_test) if n_test >= 1 else int(n_test * len(baseline_df))

    continuous_labels, categorical_labels = _categorize_labels(
        df=baseline_df,
        stratification=config.stratification,
        n_test=n_test,
    )

    if n_test == 0:
        train_df = baseline_df
        test_df = pd.DataFrame(columns=train_df.columns)
    else:
        splits = ShuffleSplit(n_splits=n_try_max, test_size=n_test, random_state=seed)
        for n_try, (train_index, test_index) in enumerate(
            splits.split(baseline_df), start=1
        ):
            train_list, test_list = train_index.tolist(), test_index.tolist()
            # Continuous tests run first; the categorical ones are only computed
            # when the continuous ones already pass (short-circuit `and`).
            if (
                _compute_continuous_p_value(
                    continuous_labels, baseline_df, train_list, test_list
                )
                >= p_continuous_threshold
                and _compute_categorical_p_value(
                    categorical_labels, baseline_df, train_list, test_list
                )
                >= p_categorical_threshold
            ):
                logger.info("Valid split found after %d attempts.", n_try)
                break
        else:
            raise RuntimeError(
                f"Unable to find a valid split after {n_try_max} attempts. "
                "Consider lowering thresholds or removing some stratification variables."
            )

        test_df = baseline_df.loc[test_index]
        train_df = baseline_df.loc[train_index]
        _write_continuous_stats(
            split_dir / "split_continuous_stats.tsv",
            continuous_labels,
            test_df,
            train_df,
            subset_name,
        )
        _write_categorical_stats(
            split_dir / "split_categorical_stats.tsv",
            categorical_labels,
            test_df,
            train_df,
            subset_name,
        )

    write_to_tsv(test_df, split_dir, subset_name, df, longitudinal)
    # The training set always keeps every session of its participants.
    write_to_tsv(
        train_df, split_dir, config._training_subset_name, df, longitudinal=True
    )
    config.to_json()

    return split_dir
def _validate_stratification(
    df: pd.DataFrame,
    stratification: Union[Sequence[str], bool],
) -> List[str]:
    """
    Checks and validates the specified stratification columns.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataset.
    stratification : Union[Sequence[str], bool]
        Columns to use for stratification. If ``True``, the ``AGE`` and ``SEX``
        columns are used; if ``False``, there is no stratification. A list or a
        tuple of column names is also accepted.

    Returns
    -------
    List[str]
        Validated list of stratification columns (empty if no stratification is
        applied).

    Raises
    ------
    KeyError
        If some stratification columns are not found in ``df``.
    ValueError
        If ``stratification`` is neither a boolean nor a list/tuple of column names.
    """
    if isinstance(stratification, bool):
        if not stratification:
            return []
        stratification = [AGE, SEX]

    # Only lists and tuples are accepted; a bare string falls through to the
    # ValueError below.
    if isinstance(stratification, (list, tuple)):
        stratification = list(stratification)
        if not set(stratification).issubset(df.columns):
            raise KeyError(
                f"Invalid stratification columns (not found in the dataframe): {set(stratification) - set(df.columns)}"
            )
        return stratification

    raise ValueError(
        f"Invalid stratification option. Stratification must be a list of column names or a boolean. Got: {stratification}"
    )


def _categorize_labels(
    df: pd.DataFrame,
    stratification: List[str],
    n_test: int,
) -> Tuple[List[str], List[str]]:
    """
    Categorize stratification columns into continuous and categorical labels.

    A column is continuous when it has a numeric, non-boolean dtype and at least
    ``n_test / 2`` distinct values; every other column is categorical.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataset.
    stratification : List[str]
        Columns to use for stratification.
    n_test : int
        Number of test samples.

    Returns
    -------
    Tuple[List[str], List[str]]
        Continuous and categorical labels.
    """
    continuous_labels: List[str] = []
    categorical_labels: List[str] = []
    for col in stratification:
        column = df[col]
        # Booleans pass `is_numeric_dtype` but are categories, not measurements,
        # so they must not be compared with a t-test.
        is_numeric = pd.api.types.is_numeric_dtype(
            column
        ) and not pd.api.types.is_bool_dtype(column)
        if is_numeric and column.nunique() >= (n_test / 2):
            continuous_labels.append(col)
        else:
            categorical_labels.append(col)
    return continuous_labels, categorical_labels


def _compute_continuous_p_value(
    continuous_labels: List[str],
    baseline_df: pd.DataFrame,
    train_index: List[int],
    test_index: List[int],
) -> float:
    """
    Compute the minimum p-value for continuous variables between train and test splits.

    Parameters
    ----------
    continuous_labels : List[str]
        List of continuous variable names (can be empty).
    baseline_df : pd.DataFrame
        Dataframe containing the baseline data.
    train_index : List[int]
        Indices for the training set.
    test_index : List[int]
        Indices for the testing set.

    Returns
    -------
    float
        The minimum t-test p-value across all continuous labels (``1.0`` if there
        are none, ``0.0`` if a p-value cannot be computed).
    """
    p_continuous = 1.0
    for label in continuous_labels:
        train_values = baseline_df.loc[train_index, label]
        test_values = baseline_df.loc[test_index, label]
        # NOTE: ks_2samp, or ttost_ind from statsmodels.stats.weightstats, are
        # possible alternatives to the t-test.
        _, new_p_continuous = ttest_ind(
            test_values.tolist(), train_values.tolist(), nan_policy="omit"
        )
        if np.isnan(new_p_continuous):
            return 0.0  # can't compute the p-value so we won't choose this split
        p_continuous = min(p_continuous, new_p_continuous)
    return p_continuous


def _compute_categorical_p_value(
    categorical_labels: List[str],
    baseline_df: pd.DataFrame,
    train_index: List[int],
    test_index: List[int],
) -> float:
    """
    Compute the minimum p-value for categorical variables between train and test splits.

    Parameters
    ----------
    categorical_labels : List[str]
        List of categorical variable names (can be empty).
    baseline_df : pd.DataFrame
        Dataframe containing the baseline data.
    train_index : List[int]
        Indices for the training set.
    test_index : List[int]
        Indices for the testing set.

    Returns
    -------
    float
        The minimum chi-squared p-value across all categorical labels (``1.0`` if
        there are none).
    """
    p_categorical = 1.0
    for label in categorical_labels:
        # Encode the categories as integers so that the chi-squared helper can work
        # on numeric arrays.
        mapping = {
            val: i for i, val in enumerate(np.unique(baseline_df[label].dropna()))
        }
        # `Series.map` turns values absent from `mapping` (i.e. missing values)
        # into NaN, which `_chi2_test` then ignores. Indexing the dict directly
        # would raise KeyError on them.
        train_values = baseline_df.loc[train_index, label].map(mapping)
        test_values = baseline_df.loc[test_index, label].map(mapping)
        new_p_categorical = _chi2_test(test_values, train_values)
        p_categorical = min(p_categorical, new_p_categorical)
    return p_categorical


def _chi2_test(x_test: np.ndarray, x_train: np.ndarray) -> float:
    """
    Perform a Chi-squared goodness-of-fit test of the test categories against the
    train distribution.

    Parameters
    ----------
    x_test : np.ndarray
        Test data (numerically encoded categories, NaN for missing values).
    x_train : np.ndarray
        Train data (numerically encoded categories, NaN for missing values).

    Returns
    -------
    float
        p-value from the Chi-squared test.
    """
    unique_categories = np.unique(np.concatenate([x_test, x_train]))
    unique_categories = unique_categories[~np.isnan(unique_categories)]

    # Observed frequencies must be raw counts for the statistic to follow a
    # chi-squared distribution (normalizing them to proportions divides the
    # statistic by the test size and inflates the p-value). Expected frequencies
    # are the train proportions rescaled to the same total as the observed ones,
    # as required by `chisquare`.
    f_obs = np.array([(x_test == category).sum() for category in unique_categories])
    train_counts = np.array(
        [(x_train == category).sum() for category in unique_categories]
    )
    f_exp = train_counts / np.sum(train_counts) * np.sum(f_obs)

    _, p_value = chisquare(f_obs, f_exp)
    return float(p_value)


def _write_continuous_stats(
    tsv_path: Path,
    continuous_labels: List[str],
    test_df: pd.DataFrame,
    train_df: pd.DataFrame,
    subset_name: str,
):
    """
    Write continuous statistics (mean, std) to a TSV file.

    Nothing is written if ``continuous_labels`` is empty.

    Parameters
    ----------
    tsv_path : Path
        Path to save the output TSV file.
    continuous_labels : List[str]
        List of continuous variable names.
    test_df : pd.DataFrame
        Test dataset.
    train_df : pd.DataFrame
        Train dataset.
    subset_name : str
        Name of the test subset (used as column name).
    """
    if not continuous_labels:
        return

    data = [
        (label, MEAN, train_df[label].mean(), test_df[label].mean())
        for label in continuous_labels
    ] + [
        (label, STD, train_df[label].std(), test_df[label].std())
        for label in continuous_labels
    ]
    df_stats_continuous = pd.DataFrame(
        data, columns=[LABEL, STATISTIC, TRAIN, subset_name]
    )
    df_stats_continuous.to_csv(tsv_path, sep="\t", index=False)


def _count_value(series: pd.Series, value) -> int:
    """Count occurrences of ``value`` in ``series``, matching NaN values with ``isna``."""
    # `series == nan` is always False, so missing values need a dedicated mask.
    mask = series.isna() if pd.isna(value) else series == value
    return int(mask.sum())


def _write_categorical_stats(
    tsv_path: Path,
    categorical_labels: List[str],
    test_df: pd.DataFrame,
    train_df: pd.DataFrame,
    subset_name: str,
):
    """
    Write categorical statistics (proportion, count) to a TSV file.

    Nothing is written if ``categorical_labels`` is empty.

    Parameters
    ----------
    tsv_path : Path
        Path to save the output TSV file.
    categorical_labels : List[str]
        List of categorical variable names.
    test_df : pd.DataFrame
        Test dataset (must not be empty).
    train_df : pd.DataFrame
        Train dataset (must not be empty).
    subset_name : str
        Name of the test subset (used as column name).
    """
    if not categorical_labels:
        return

    data = []
    for label in categorical_labels:
        unique_values = pd.concat([train_df[label], test_df[label]]).unique()
        for value in unique_values:
            test_count = _count_value(test_df[label], value)
            train_count = _count_value(train_df[label], value)
            test_proportion = test_count / len(test_df)
            train_proportion = train_count / len(train_df)
            data.append((label, value, PROPORTION, train_proportion, test_proportion))
            data.append((label, value, COUNT, train_count, test_count))

    df_stats_categorical = pd.DataFrame(
        data, columns=[LABEL, VALUE, STATISTIC, TRAIN, subset_name]
    )
    df_stats_categorical.to_csv(tsv_path, sep="\t", index=False)