Source code for causalkit.inference.att.bootstrap_diff_means

"""
Bootstrap difference-in-means inference for CausalData (ATT context).

Computes the ATT-style difference in means (treated - control) and provides:
- Two-sided p-value using a normal approximation with bootstrap standard error
- Percentile confidence interval for the absolute difference
- Relative difference (%) and corresponding CI relative to control mean

Input:
- data: CausalData
- confidence_level: float in (0, 1), default 0.95
- n_simul: number of bootstrap simulations (int > 0), default 10000

Output: dict with the same keys as ttest:
- p_value
- absolute_difference
- absolute_ci: (low, high)
- relative_difference
- relative_ci: (low, high)
"""

from typing import Dict, Any

import numpy as np
import pandas as pd
from scipy import stats

from causalkit.data.causaldata import CausalData


[docs] def bootstrap_diff_means( data: CausalData, confidence_level: float = 0.95, n_simul: int = 10000, ) -> Dict[str, Any]: """ Bootstrap inference for difference in means between treated (T=1) and control (T=0). Parameters ---------- data : CausalData The CausalData object containing treatment and outcome variables. confidence_level : float, default 0.95 Confidence level for the percentile confidence interval (0 < level < 1). n_simul : int, default 10000 Number of bootstrap resamples. Returns ------- Dict[str, Any] Dictionary with p_value, absolute_difference, absolute_ci, relative_difference, relative_ci (matching the structure of inference.att.ttest). Raises ------ ValueError If inputs are invalid, treatment is not binary, or groups are empty. """ # Validate inputs if not 0 < confidence_level < 1: raise ValueError("confidence_level must be between 0 and 1 (exclusive)") if not isinstance(n_simul, int) or n_simul <= 0: raise ValueError("n_simul must be a positive integer") treatment = data.treatment target = data.target if not isinstance(treatment, pd.Series) or treatment.empty: raise ValueError("causaldata object must have a treatment variable defined") if not isinstance(target, pd.Series) or target.empty: raise ValueError("causaldata object must have a outcome variable defined") uniq = treatment.unique() if len(uniq) != 2: raise ValueError("Treatment variable must be binary (have exactly 2 unique values)") control = target[treatment == 0] treated = target[treatment == 1] n0 = int(control.shape[0]) n1 = int(treated.shape[0]) if n0 < 1 or n1 < 1: raise ValueError("Not enough observations in one of the groups for bootstrap (need at least 1 per group)") control_mean = float(control.mean()) treated_mean = float(treated.mean()) abs_diff = float(treated_mean - control_mean) # Prepare for bootstrap: indices for resampling within each group ctrl_vals = control.to_numpy() trt_vals = treated.to_numpy() rng = np.random.default_rng() # Vectorized bootstrap using random integers for indices ctrl_idx = rng.integers(0, n0, size=(n_simul, n0)) trt_idx = rng.integers(0, n1, size=(n_simul, n1)) ctrl_boot_means = ctrl_vals[ctrl_idx].mean(axis=1) trt_boot_means = trt_vals[trt_idx].mean(axis=1) boot_diffs = trt_boot_means - ctrl_boot_means # Percentile CI for absolute difference alpha = 1 - confidence_level lower = float(np.quantile(boot_diffs, alpha / 2)) upper = float(np.quantile(boot_diffs, 1 - alpha / 2)) absolute_ci = (lower, upper) # p-value using bootstrap SE and normal approximation se_boot = float(np.std(boot_diffs, ddof=1)) if se_boot == 0: p_value = 1.0 else: z = abs_diff / se_boot p_value = float(2 * (1 - stats.norm.cdf(abs(z)))) # Relative effects and CI by scaling if control_mean == 0: relative_diff = np.inf if abs_diff > 0 else 0.0 if abs_diff == 0 else -np.inf relative_ci = (np.nan, np.nan) else: relative_diff = (abs_diff / abs(control_mean)) * 100.0 rel_lower = (lower / abs(control_mean)) * 100.0 rel_upper = (upper / abs(control_mean)) * 100.0 relative_ci = (float(rel_lower), float(rel_upper)) return { "p_value": float(p_value), "absolute_difference": float(abs_diff), "absolute_ci": (float(absolute_ci[0]), float(absolute_ci[1])), "relative_difference": float(relative_diff), "relative_ci": (float(relative_ci[0]), float(relative_ci[1])), }