Refutation Module#
The causalkit.refutation package provides robustness and refutation utilities to stress-test causal estimates by perturbing data, checking identifying assumptions, and running sensitivity analyses.
Overview#
Key utilities:
- Placebo tests (randomize outcome or treatment, subsample)
- Sensitivity analysis for unobserved confounding (including set-based benchmarking)
- Orthogonality/IRM moment checks with out-of-sample (OOS) diagnostics
API Reference (Package)#
- refute_placebo_outcome: Generate random outcome (target) variables while keeping treatment and covariates intact.
- refute_placebo_treatment: Generate random binary treatment variables while keeping outcome and covariates intact.
- Subset refutation: Re-estimate the effect on a random subset (default 80%) to check sample stability of the estimate.
- sensitivity_analysis: Perform sensitivity analysis on a causal effect estimate.
- get_sensitivity_summary: Get the sensitivity summary string from a sensitivity-analyzed model.
- sensitivity_analysis_set: Benchmark a set of observed confounders to assess robustness, mirroring DoubleML's sensitivity_benchmark for DoubleMLIRM.
- refute_irm_orthogonality: Comprehensive AIPW orthogonality diagnostics for DoubleML estimators.
Placebo and Subset Refutation#
Placebo / robustness checks for CausalKit.
The functions below deliberately break or perturb assumptions, re-estimate the causal effect, and report the resulting theta-hat and p-value so that users can judge the credibility of their original estimate.
All helpers share the same signature:
def refute_xxx(
    inference_fn, data: CausalData, random_state: int | None = None, **inference_kwargs,
) -> dict
- param inference_fn:
Any CausalKit estimator (e.g. dml_att) returning a dict that contains the keys 'coefficient' and 'p_value'.
- type inference_fn:
callable
- param data:
The original data object.
- type data:
CausalData
- param random_state:
Seed for reproducibility.
- type random_state:
int, optional
- param **inference_kwargs:
Extra keyword args forwarded verbatim to inference_fn (e.g. alternative learners or the number of cross-fitting folds).
- returns:
{'theta': float, 'p_value': float}
- rtype:
dict
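Any estimator that honors this contract can be passed in. A minimal sketch of a conforming inference_fn (my_inference_fn is a hypothetical wrapper around dml_ate, shown only to illustrate the contract):

>>> from causalkit.inference.ate import dml_ate
>>>
>>> def my_inference_fn(data, **kwargs):
...     """Conforming estimator: returns the keys the refutation helpers read."""
...     res = dml_ate(data, **kwargs)
...     return {"coefficient": res["coefficient"], "p_value": res["p_value"]}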
- causalkit.refutation.placebo.refute_placebo_outcome(inference_fn, data, random_state=None, **inference_kwargs)[source]#
Generate random outcome (target) variables while keeping treatment and covariates intact. For binary outcomes, generates random binary variables with the same proportion. For continuous outcomes, generates random variables from a normal distribution fitted to the original data. A valid causal design should now yield θ ≈ 0 and a large p-value.
- causalkit.refutation.placebo.refute_placebo_treatment(inference_fn, data, random_state=None, **inference_kwargs)[source]#
Generate random binary treatment variables while keeping outcome and covariates intact. Generates random binary treatment with the same proportion as the original treatment. Breaks the treatment–outcome link.
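A minimal usage sketch for both placebo checks, assuming a CausalData object named ck as in the sensitivity example below:

>>> from causalkit.inference.ate import dml_ate
>>> from causalkit.refutation.placebo import refute_placebo_outcome, refute_placebo_treatment
>>>
>>> # Both placebo checks should yield theta near 0 with a large p-value
>>> res_y = refute_placebo_outcome(dml_ate, ck, random_state=42)
>>> res_d = refute_placebo_treatment(dml_ate, ck, random_state=42)
>>> res_y['theta'], res_y['p_value']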
Sensitivity Analysis#
Sensitivity analysis for causal inference using DoubleML.
This module provides functions to perform sensitivity analysis on causal effect estimates to assess the robustness of the results to potential unobserved confounding.
- causalkit.refutation.sensitivity.sensitivity_analysis(effect_estimation, cf_y, cf_d, rho=1.0, level=0.95)[source]#
Perform sensitivity analysis on a causal effect estimate.
This function takes a DoubleML effect estimation result and performs sensitivity analysis to assess robustness to unobserved confounding.
- Parameters:
effect_estimation (Dict[str, Any]) – A dictionary containing the effect estimation results. Must include:
- 'model': a fitted DoubleML model object (e.g., DoubleMLIRM)
- other keys such as 'coefficient', 'std_error', 'p_value', etc.
cf_y (float) – Sensitivity parameter for the outcome equation (confounding strength)
cf_d (float) – Sensitivity parameter for the treatment equation (confounding strength)
rho (float, default 1.0) – Correlation parameter between unobserved confounders
level (float, default 0.95) – Confidence level for the sensitivity analysis
- Returns:
A formatted sensitivity analysis summary report
- Return type:
str
- Raises:
ValueError – If the effect_estimation does not contain a ‘model’ key or if the model does not support sensitivity analysis
KeyError – If required keys are missing from the effect_estimation dictionary
TypeError – If the model is not a DoubleML object that supports sensitivity analysis
Examples
>>> from causalkit.data import generate_rct_data, CausalData
>>> from causalkit.inference.ate import dml_ate
>>> from causalkit.refutation.sensitivity import sensitivity_analysis
>>>
>>> # Generate data and estimate effect
>>> df = generate_rct_data()
>>> ck = CausalData(df=df, outcome='outcome', treatment='treatment',
...                 confounders=['age', 'invited_friend'])
>>> results = dml_ate(ck)
>>>
>>> # Perform sensitivity analysis
>>> sensitivity_report = sensitivity_analysis(results, cf_y=0.04, cf_d=0.03)
>>> print(sensitivity_report)
- causalkit.refutation.sensitivity.get_sensitivity_summary(effect_estimation)[source]#
Get the sensitivity summary string from a sensitivity-analyzed model.
- causalkit.refutation.sensitivity.sensitivity_analysis_set(effect_estimation, benchmarking_set, level=0.95, null_hypothesis=0.0, **kwargs)[source]#
Benchmark a set of observed confounders to assess robustness, mirroring DoubleML’s sensitivity_benchmark for DoubleMLIRM.
- Parameters:
effect_estimation (Dict[str, Any]) – A dictionary containing the effect estimation results with a fitted DoubleML model under the ‘model’ key.
benchmarking_set (Union[str, List[str], List[List[str]]]) – One or multiple names of observed confounders to benchmark (e.g., ["inc"], ["pira"], ["twoearn"]). Accepts:
- a single string (benchmarks that single confounder),
- a list of strings (interpreted as multiple single-variable benchmarks, each run separately), or
- a list of lists/tuples of strings to specify explicit benchmarking groups (each inner list is run once together).
level (float, default 0.95) – Confidence level used by the benchmarking procedure.
null_hypothesis (float, default 0.0) – The null hypothesis value for the target parameter.
**kwargs (Any) – Additional keyword arguments passed through to the underlying DoubleML sensitivity_benchmark method.
- Returns:
If a single confounder/group is provided, returns the object from a single call to model.sensitivity_benchmark(benchmarking_set=[…]).
If multiple confounders/groups are provided, returns a dict mapping each confounder (str) or group (tuple[str, …]) to its corresponding result object.
- Return type:
Any
- Raises:
TypeError – If inputs have invalid types or if the model does not support sensitivity benchmarking.
ValueError – If required inputs are missing or invalid (e.g., empty benchmarking_set, invalid level).
RuntimeError – If the underlying sensitivity_benchmark call fails.
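A usage sketch (results as returned by dml_ate above; the confounder names are illustrative):

>>> from causalkit.refutation.sensitivity import sensitivity_analysis_set
>>>
>>> # Three single-variable benchmarks, run separately -> dict of results
>>> bench = sensitivity_analysis_set(results, benchmarking_set=["inc", "pira", "twoearn"])
>>> # One explicit group, run together in a single call
>>> bench_group = sensitivity_analysis_set(results, benchmarking_set=[["inc", "pira"]])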
Orthogonality Checks#
AIPW orthogonality diagnostics for DoubleML estimators.
This module implements comprehensive orthogonality diagnostics for AIPW/IRM-based estimators like dml_ate and dml_att to validate the key assumptions required for valid causal inference. Based on the efficient influence function (EIF) framework.
Key diagnostics implemented:
- Out-of-sample moment check (non-tautological)
- Orthogonality (Gateaux derivative) tests
- Influence diagnostics
- causalkit.refutation.orthogonality.aipw_score_ate(y, d, m0, m1, g, theta, eps=0.01)[source]#
Efficient influence function (EIF) for ATE.
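For reference, the standard AIPW efficient influence function for the ATE (the textbook form this score corresponds to, written in the same notation as the ATT score below; eps clips g away from 0 and 1):

ψ_ATE(W; θ, η) = m1(X) - m0(X) + D*(Y - m1(X))/g(X) - (1-D)*(Y - m0(X))/(1-g(X)) - θ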
- causalkit.refutation.orthogonality.aipw_score_att(y, d, m0, m1, g, theta, p1=None, eps=0.01)[source]#
Efficient influence function (EIF) for ATT (a.k.a. ATTE) under IRM/AIPW.
ψ_ATT(W; θ, η) = [ D*(Y - m0(X) - θ) - (1-D)*{ g(X)/(1-g(X)) }*(Y - m0(X)) ] / E[D]
- Parameters:
y (np.ndarray) – Observed outcomes
d (np.ndarray) – Binary treatment indicators
m0 (np.ndarray) – Outcome predictions E[Y|X, D=0]
m1 (np.ndarray) – Outcome predictions E[Y|X, D=1]; enters only via θ
g (np.ndarray) – Propensity scores P(D=1|X)
theta (float) – Treatment effect estimate
p1 (float, optional) – E[D]; estimated from d when not supplied
eps (float, default 0.01) – Clipping bound for propensity scores
- Return type:
np.ndarray
Notes
This matches DoubleML's score='ATTE' (weights ω = D/E[D], ω̄ = m(X)/E[D]; in DoubleML's notation m(X) denotes the propensity score, i.e. g(X) here).
m1 enters only via θ; ∂ψ/∂m1 = 0.
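A minimal NumPy sketch transcribing the displayed formula (a standalone illustration, not the library implementation; m1 is omitted since it enters only via θ):

>>> import numpy as np
>>>
>>> def aipw_score_att_sketch(y, d, m0, g, theta, eps=0.01):
...     """ATT efficient influence function, following the formula above."""
...     g = np.clip(g, eps, 1 - eps)  # guard against extreme propensities
...     p1 = d.mean()                 # E[D], estimated from the sample
...     return (d * (y - m0 - theta) - (1 - d) * (g / (1 - g)) * (y - m0)) / p1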
- causalkit.refutation.orthogonality.extract_nuisances(dml_model, test_indices=None)[source]#
Robustly extract nuisance function predictions from DoubleML model.
Handles different DoubleML prediction key layouts and provides clear error messages.
- Parameters:
dml_model (DoubleML model) – Fitted DoubleML model with predictions
test_indices (np.ndarray, optional) – If provided, extract predictions only for these indices
- Returns:
(g, m0, m1) where:
- g: propensity scores P(D=1|X)
- m0: outcome predictions E[Y|X, D=0]
- m1: outcome predictions E[Y|X, D=1]
- Return type:
Tuple[np.ndarray, np.ndarray, np.ndarray]
- Raises:
KeyError – If required prediction keys cannot be found
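Typical use, assuming results['model'] is the fitted DoubleML model returned by an estimator such as dml_ate:

>>> from causalkit.refutation.orthogonality import extract_nuisances
>>>
>>> # Cross-fitted propensity and outcome predictions for the full sample
>>> g, m0, m1 = extract_nuisances(results['model'])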
- causalkit.refutation.orthogonality.oos_moment_check_with_fold_nuisances(fold_thetas, fold_indices, fold_nuisances, y, d, score_fn=None)[source]#
Out-of-sample moment check using fold-specific nuisances to avoid tautological results.
For each fold k, evaluates the AIPW score using θ fitted on other folds and nuisance predictions from the fold-specific model, then tests if the combined moment condition holds.
- Parameters:
fold_thetas (List[float]) – Treatment effects estimated excluding each fold
fold_indices (List[np.ndarray]) – Indices for each fold
fold_nuisances (List[Tuple[np.ndarray, np.ndarray, np.ndarray]]) – Fold-specific nuisance predictions (g, m0, m1) for each fold
y (np.ndarray) – Observed outcomes
d (np.ndarray) – Observed binary treatment indicators
score_fn (Callable[[ndarray, ndarray, ndarray, ndarray, ndarray, float], ndarray] | None)
- Returns:
Fold-wise results and combined t-statistic
- Return type:
Tuple[pd.DataFrame, float]
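The mechanics, in a simplified sketch (a hypothetical helper, assuming each fold's nuisance arrays are aligned with its indices): stack each fold's out-of-fold scores and t-test their grand mean against zero.

>>> import numpy as np
>>> from causalkit.refutation.orthogonality import aipw_score_ate
>>>
>>> def oos_tstat_sketch(fold_thetas, fold_indices, fold_nuisances, y, d):
...     """Combined moment t-stat from out-of-fold AIPW scores."""
...     psi = np.concatenate([
...         aipw_score_ate(y[idx], d[idx], m0, m1, g, theta_k)
...         for theta_k, idx, (g, m0, m1) in zip(fold_thetas, fold_indices, fold_nuisances)
...     ])
...     return psi.mean() / (psi.std(ddof=1) / np.sqrt(psi.size))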
- causalkit.refutation.orthogonality.oos_moment_check(fold_thetas, fold_indices, y, d, m0, m1, g, score_fn=None)[source]#
Out-of-sample moment check to avoid tautological results (legacy version).
For each fold k, evaluates the AIPW score using θ fitted on other folds, then tests if the combined moment condition holds.
- Parameters:
fold_thetas (List[float]) – Treatment effects estimated excluding each fold
fold_indices (List[np.ndarray]) – Indices for each fold
y (np.ndarray) – Observed outcomes
d (np.ndarray) – Observed binary treatment indicators
m0 (np.ndarray) – Outcome predictions E[Y|X, D=0]
m1 (np.ndarray) – Outcome predictions E[Y|X, D=1]
g (np.ndarray) – Propensity scores P(D=1|X)
score_fn (Callable[[ndarray, ndarray, ndarray, ndarray, ndarray, float], ndarray] | None)
- Returns:
Fold-wise results and combined t-statistic
- Return type:
Tuple[pd.DataFrame, float]
- causalkit.refutation.orthogonality.orthogonality_derivatives(X_basis, y, d, m0, m1, g, eps=0.01)[source]#
Compute orthogonality (Gateaux derivative) tests for nuisance functions.
Tests directional derivatives of the AIPW signal with respect to nuisances. For true nuisances, these derivatives should be ≈ 0 for rich sets of directions.
- Parameters:
X_basis (np.ndarray, shape (n, B)) – Matrix of direction functions evaluated at X (include column of 1s for calibration)
y (np.ndarray) – Observed outcomes
d (np.ndarray) – Observed binary treatment indicators
m0 (np.ndarray) – Outcome predictions E[Y|X, D=0]
m1 (np.ndarray) – Outcome predictions E[Y|X, D=1]
g (np.ndarray) – Propensity scores P(D=1|X)
eps (float, default 0.01) – Clipping bound for propensity scores to avoid extreme weights
- Returns:
Derivative estimates, standard errors, and t-statistics for each basis function
- Return type:
pd.DataFrame
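A sketch of building X_basis and flagging problem directions (X is assumed to be the confounder matrix; the t_g column name follows the refute_irm_orthogonality example further below):

>>> import numpy as np
>>> from causalkit.refutation.orthogonality import orthogonality_derivatives
>>>
>>> # Direction functions: constant column (calibration) plus each confounder
>>> X_basis = np.column_stack([np.ones(len(X)), X])
>>> deriv_df = orthogonality_derivatives(X_basis, y, d, m0, m1, g)
>>> deriv_df.query('abs(t_g) > 2')  # directions where the g-derivative is significantly nonzero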
- causalkit.refutation.orthogonality.influence_summary(y, d, m0, m1, g, theta_hat, k=10, target='ATE', clip_eps=0.01)[source]#
Compute influence diagnostics showing where uncertainty comes from.
- Parameters:
y (np.ndarray) – Observed outcomes
d (np.ndarray) – Observed binary treatment indicators
m0 (np.ndarray) – Outcome predictions E[Y|X, D=0]
m1 (np.ndarray) – Outcome predictions E[Y|X, D=1]
g (np.ndarray) – Propensity scores P(D=1|X)
theta_hat (float) – Estimated treatment effect
k (int, default 10) – Number of highest-influence observations to report
target (str, default 'ATE') – Target estimand
clip_eps (float, default 0.01) – Clipping bound for propensity scores
- Returns:
Influence diagnostics including SE, heavy-tail metrics, and top-k cases
- Return type:
Dict[str, Any]
- causalkit.refutation.orthogonality.refute_irm_orthogonality(inference_fn, data, trim_propensity=(0.02, 0.98), n_basis_funcs=None, n_folds_oos=5, target='ATE', clip_eps=0.01, strict_oos=False, **inference_kwargs)[source]#
Comprehensive AIPW orthogonality diagnostics for DoubleML estimators.
Implements three key diagnostic approaches based on the efficient influence function (EIF):
1. Out-of-sample moment check (non-tautological)
2. Orthogonality (Gateaux derivative) tests
3. Influence diagnostics
- Parameters:
inference_fn (Callable) – The inference function (dml_ate or dml_att)
data (CausalData) – The causal data object
trim_propensity (Tuple[float, float], default (0.02, 0.98)) – Propensity score trimming bounds (min, max) to avoid extreme weights
n_basis_funcs (Optional[int], default None (len(confounders)+1)) – Number of basis functions for orthogonality derivative tests (constant + covariates). If None, defaults to the number of confounders in data plus 1 for the constant term.
n_folds_oos (int, default 5) – Number of folds for out-of-sample moment check
target (str, default 'ATE') – Target estimand
clip_eps (float, default 0.01) – Clipping bound for propensity scores
strict_oos (bool, default False)
**inference_kwargs (dict) – Additional arguments passed to inference_fn
- Returns:
Dictionary containing:
- oos_moment_test: Out-of-sample moment condition results
- orthogonality_derivatives: Gateaux derivative test results
- influence_diagnostics: Influence function diagnostics
- theta: Original treatment effect estimate
- trimmed_diagnostics: Results on trimmed sample
- overall_assessment: Summary diagnostic assessment
- Return type:
Dict[str, Any]
Examples
>>> from causalkit.refutation import refute_irm_orthogonality
>>> from causalkit.inference.ate import dml_ate
>>>
>>> # Comprehensive orthogonality check
>>> ortho_results = refute_irm_orthogonality(dml_ate, causal_data)
>>>
>>> # Check key diagnostics
>>> print(f"OOS moment t-stat: {ortho_results['oos_moment_test']['tstat']:.3f}")
>>> print(f"Calibration issues: {len(ortho_results['orthogonality_derivatives'].query('abs(t_g) > 2'))}")
>>> print(f"Assessment: {ortho_results['overall_assessment']}")
- causalkit.refutation.orthogonality.orthogonality_derivatives_att(X_basis, y, d, m0, g, p1, eps=0.01)[source]#
Gateaux derivatives of the ATT score with respect to the nuisances (m0, g); the m1-derivative is 0.
For ψ_ATT = [ D*(Y - m0 - θ) - (1-D)*{ g/(1-g) }*(Y - m0) ] / p1, the Gateaux derivatives are:

∂_{m0}[h] = (1/n) Σ h(X_i) * [ ((1-D_i)*g_i/(1-g_i) - D_i) / p1 ]
∂_{g}[s]  = (1/n) Σ s(X_i) * [ -(1-D_i)*(Y_i - m0_i) / ( p1 * (1-g_i)^2 ) ]

Both have zero expectation at the truth (Neyman orthogonality).
- Return type:
DataFrame
- causalkit.refutation.orthogonality.overlap_diagnostics_att(g, d, eps_list=[0.95, 0.97, 0.98, 0.99])[source]#
Key overlap metrics for ATT: availability of suitable controls. Reports conditional shares: among CONTROLS, fraction with m(X) ≥ threshold; among TREATED, fraction with m(X) ≤ 1 - threshold.
- causalkit.refutation.orthogonality.trim_sensitivity_curve_att(inference_fn, data, g, d, thresholds=np.linspace(0.90, 0.995, 12), **inference_kwargs)[source]#
Re-estimate θ while progressively trimming CONTROLS with large m(X).
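A closing sketch tying the ATT overlap tools together (assuming a dml_att estimator importable analogously to dml_ate, with g and d obtained from extract_nuisances as above):

>>> from causalkit.refutation.orthogonality import (
...     overlap_diagnostics_att, trim_sensitivity_curve_att)
>>>
>>> overlap = overlap_diagnostics_att(g, d)  # availability of suitable controls
>>> curve = trim_sensitivity_curve_att(dml_att, causal_data, g, d)
>>> # A theta that drifts sharply as high-m(X) controls are trimmed signals weak overlap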