Design Module

The causalkit.design module provides utilities for designing experiments and splitting traffic.

Overview

This module includes functions for:

  • Splitting traffic for experiments with customizable ratios
  • Supporting stratified splitting to maintain distribution of key variables

API Reference

Utility functions for splitting traffic data from DataFrames.

split_traffic(df, split_ratio=0.5, stratify_column=None, random_state=None)

Split a DataFrame into multiple parts based on the specified ratio.

Parameters

df : pd.DataFrame
    The input DataFrame containing traffic data.
split_ratio : float or list of floats, default 0.5
    If float, represents the proportion of the DataFrame to include in the first split. If list, each value represents the proportion for each split. The values must sum to less than 1; the remaining rows form the final split.
stratify_column : str, optional
    Column name to use for stratified splitting. If provided, the splits will have the same proportion of values in this column. If the column is not present in df, the function falls back to a plain random split.
random_state : int, optional
    Random seed for reproducibility.
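
To make the meaning of split_ratio concrete, here is a minimal sketch of how ratios can be turned into row boundaries with a cumulative sum, the same approach the source listing takes. The helper name split_boundaries and its eps parameter are hypothetical and are not part of causalkit; the small epsilon is an assumption added so that floating-point error in the cumulative sum does not shift a boundary by one row.

```python
import numpy as np


def split_boundaries(n_rows, split_ratio, eps=1e-9):
    """Hypothetical helper: map split ratios to row boundaries and split sizes."""
    # Accept a single number or a list of numbers
    if isinstance(split_ratio, (int, float)):
        ratios = [split_ratio]
    else:
        ratios = list(split_ratio)

    # The ratios must leave room for the final (remainder) split
    if sum(ratios) >= 1:
        raise ValueError("Sum of split ratios should be less than 1.")

    # Cumulative ratios mark where each split ends
    cum = np.cumsum(ratios)

    # eps guards against float error such as 0.7 + 0.2 == 0.8999999999999999
    bounds = [int(n_rows * c + eps) for c in cum]

    # Sizes are the gaps between consecutive boundaries, including 0 and n_rows
    edges = [0] + bounds + [n_rows]
    sizes = [hi - lo for lo, hi in zip(edges[:-1], edges[1:])]
    return bounds, sizes


print(split_boundaries(100, 0.8))         # ([80], [80, 20])
print(split_boundaries(100, [0.7, 0.2]))  # ([70, 90], [70, 20, 10])
```

A list of k ratios always yields k + 1 sizes, which matches the tuple length described under Returns.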

Returns

tuple of pd.DataFrame
    A tuple containing the split DataFrames. If split_ratio is a float, returns a tuple of two DataFrames. If split_ratio is a list, returns a tuple with length equal to len(split_ratio) + 1.

Examples

>>> import pandas as pd
>>> df = pd.DataFrame({'user_id': range(100), 'group': ['A', 'B'] * 50})
>>> train_df, test_df = split_traffic(df, split_ratio=0.8, random_state=42)
>>> len(train_df), len(test_df)
(80, 20)

>>> train_df, val_df, test_df = split_traffic(df, split_ratio=[0.7, 0.2], random_state=42)
>>> len(train_df), len(val_df), len(test_df)
(70, 20, 10)
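
The examples above do not exercise stratify_column. As an illustration of what a stratified split is meant to guarantee, the sketch below uses plain pandas (groupby followed by sample) rather than split_traffic itself, so it is an equivalent technique and not the library's implementation. It assumes the same toy DataFrame as the examples above.

```python
import pandas as pd

df = pd.DataFrame({"user_id": range(100), "group": ["A", "B"] * 50})

# Draw 80% of the rows within each value of "group" ...
first = df.groupby("group").sample(frac=0.8, random_state=42)
# ... and put every remaining row in the second split
second = df.drop(first.index)

# Each split keeps the original 50/50 balance between A and B
first_counts = first["group"].value_counts().to_dict()
second_counts = second["group"].value_counts().to_dict()
print(len(first), first_counts)
print(len(second), second_counts)
```

Both splits preserve the 50/50 balance of the group column, which is the property that stratify_column is documented to maintain.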

Source code in causalkit/design/traffic_splitter.py
from typing import List, Optional, Tuple, Union

import numpy as np
import pandas as pd


def split_traffic(
    df: pd.DataFrame,
    split_ratio: Union[float, List[float]] = 0.5,
    stratify_column: Optional[str] = None,
    random_state: Optional[int] = None
) -> Tuple[pd.DataFrame, ...]:
    """
    Split a DataFrame into multiple parts based on the specified ratio.

    Parameters
    ----------
    df : pd.DataFrame
        The input DataFrame containing traffic data.
    split_ratio : float or list of floats, default 0.5
        If float, represents the proportion of the DataFrame to include in the first split.
        If list, each value represents the proportion for each split. The values must sum to
        less than 1; the remaining rows form the final split.
    stratify_column : str, optional
        Column name to use for stratified splitting. If provided, the splits will have
        the same proportion of values in this column.
    random_state : int, optional
        Random seed for reproducibility.

    Returns
    -------
    tuple of pd.DataFrame
        A tuple containing the split DataFrames. If split_ratio is a float, returns a tuple of two DataFrames.
        If split_ratio is a list, returns a tuple with length equal to len(split_ratio) + 1.

    Examples
    --------
    >>> import pandas as pd
    >>> df = pd.DataFrame({'user_id': range(100), 'group': ['A', 'B'] * 50})
    >>> train_df, test_df = split_traffic(df, split_ratio=0.8, random_state=42)
    >>> len(train_df), len(test_df)
    (80, 20)

    >>> train_df, val_df, test_df = split_traffic(df, split_ratio=[0.7, 0.2], random_state=42)
    >>> len(train_df), len(val_df), len(test_df)
    (70, 20, 10)
    """
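    # Seed NumPy's global RNG so the shuffles below are reproducible
    # (note: this also resets the caller's global random state)
    np.random.seed(random_state)

    # Wrap a single ratio in a list; accept ints as well as floats
    if isinstance(split_ratio, (int, float)):
        split_ratio = [split_ratio]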

    # Validate split_ratio
    if sum(split_ratio) >= 1:
        raise ValueError("Sum of split ratios should be less than 1.")

    # Calculate the cumulative split points
    cum_splits = np.cumsum(split_ratio)

    # Create a list to store the split DataFrames
    split_dfs = []

    if stratify_column is not None and stratify_column in df.columns:
        # Stratified split
        unique_strata = df[stratify_column].unique()
        strata_dfs = []

        for stratum in unique_strata:
            stratum_df = df[df[stratify_column] == stratum].copy()
            stratum_indices = stratum_df.index.tolist()
            np.random.shuffle(stratum_indices)

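            # Calculate split indices for this stratum. The small epsilon keeps float
            # error in the cumulative sums (0.7 + 0.2 == 0.8999999999999999) from
            # moving a boundary down by one row.
            stratum_splits = [int(len(stratum_indices) * split + 1e-9) for split in cum_splits]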

            # Initialize list for this stratum's splits
            stratum_split_dfs = []

            # First split
            stratum_split_dfs.append(stratum_df.loc[stratum_indices[:stratum_splits[0]]])

            # Middle splits (if any)
            for i in range(1, len(stratum_splits)):
                stratum_split_dfs.append(
                    stratum_df.loc[stratum_indices[stratum_splits[i-1]:stratum_splits[i]]]
                )

            # Last split
            stratum_split_dfs.append(stratum_df.loc[stratum_indices[stratum_splits[-1]:]])

            strata_dfs.append(stratum_split_dfs)

        # Combine strata for each split
        for i in range(len(cum_splits) + 1):
            split_dfs.append(pd.concat([strata_df[i] for strata_df in strata_dfs], axis=0))
    else:
        # Random split
        indices = df.index.tolist()
        np.random.shuffle(indices)

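        # Calculate split indices. The small epsilon keeps float error in the
        # cumulative sums (0.7 + 0.2 == 0.8999999999999999) from moving a
        # boundary down by one row.
        splits = [int(len(indices) * split + 1e-9) for split in cum_splits]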

        # First split
        split_dfs.append(df.loc[indices[:splits[0]]])

        # Middle splits (if any)
        for i in range(1, len(splits)):
            split_dfs.append(df.loc[indices[splits[i-1]:splits[i]]])

        # Last split
        split_dfs.append(df.loc[indices[splits[-1]:]])

    return tuple(split_dfs)
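
The listing seeds NumPy's global generator with np.random.seed(random_state), which also changes the random state seen by the calling code. One possible alternative, sketched here under the assumption that a local generator is acceptable, is np.random.default_rng. The helper name shuffled_index is hypothetical and not part of causalkit.

```python
import numpy as np
import pandas as pd


def shuffled_index(df, random_state=None):
    """Hypothetical helper: return df's index labels in a reproducible random order."""
    # A local Generator leaves NumPy's global random state untouched
    rng = np.random.default_rng(random_state)
    return rng.permutation(df.index.to_numpy())


df = pd.DataFrame({"user_id": range(10)})

# Record what the global RNG would produce right after seeding it
np.random.seed(0)
before = np.random.rand()

# Re-seed, shuffle with the local generator, then draw from the global RNG again
np.random.seed(0)
order = shuffled_index(df, random_state=42)
after = np.random.rand()

print(before == after)  # True: the shuffle did not consume global random state
print(sorted(order) == list(df.index))  # True: every label appears exactly once
```

Slicing the shuffled labels at the computed boundaries and passing each slice to df.loc would then give the individual splits, as in the listing.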