"""
AIPW orthogonality diagnostics for IRM-based estimators.
This module implements comprehensive orthogonality diagnostics for AIPW/IRM-based
estimators like dml_ate and dml_att to validate the key assumptions required
for valid causal inference. Based on the efficient influence function (EIF) framework.
Key diagnostics implemented:
- Out-of-sample moment check (non-tautological)
- Orthogonality (Gateaux derivative) tests
- Influence diagnostics
"""
from __future__ import annotations
import numpy as np
import pandas as pd
from typing import Callable, Any, Dict, List, Tuple, Optional
from scipy import stats
from sklearn.model_selection import KFold
from causalis.data.causaldata import CausalData
from copy import deepcopy
def aipw_score_ate(y: np.ndarray, d: np.ndarray, g0: np.ndarray, g1: np.ndarray, m: np.ndarray, theta: float, trimming_threshold: float = 0.01) -> np.ndarray:
"""
Efficient influence function (EIF) for ATE.
Uses IRM naming: g0,g1 are outcome regressions E[Y|X,D=0/1], m is propensity P(D=1|X).
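Examples
--------
A minimal sketch on synthetic data (all names constructed here): at the true
θ with the true nuisances, the sample mean of the score should be near zero.

>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> x = rng.normal(size=10_000)
>>> m = 1 / (1 + np.exp(-x))
>>> d = rng.binomial(1, m).astype(float)
>>> g0, g1 = x, x + 2.0
>>> y = np.where(d == 1, g1, g0) + rng.normal(size=10_000)
>>> psi = aipw_score_ate(y, d, g0, g1, m, theta=2.0)
>>> bool(abs(psi.mean()) < 0.1)
True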
"""
m = np.clip(m, trimming_threshold, 1 - trimming_threshold)
return (g1 - g0) + d * (y - g1) / m - (1 - d) * (y - g0) / (1 - m) - theta
def aipw_score_atte(y: np.ndarray, d: np.ndarray, g0: np.ndarray, g1: np.ndarray, m: np.ndarray, theta: float, p_treated: Optional[float] = None, trimming_threshold: float = 0.01) -> np.ndarray:
"""
Efficient influence function (EIF) for ATTE under IRM/AIPW.
ψ_ATTE(W; θ, η) = [ D*(Y - g0(X) - θ) - (1-D)*{ m(X)/(1-m(X)) }*(Y - g0(X)) ] / E[D]
Notes:
- Matches DoubleML's `score='ATTE'` (weights ω=D/E[D], \bar{ω}=m(X)/E[D]).
- g1 enters only via θ; ∂ψ/∂g1 = 0.
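Examples
--------
A minimal sketch on synthetic data (constant effect, so ATTE = ATE = 2):

>>> import numpy as np
>>> rng = np.random.default_rng(1)
>>> x = rng.normal(size=10_000)
>>> m = 1 / (1 + np.exp(-x))
>>> d = rng.binomial(1, m).astype(float)
>>> y = x + 2.0 * d + rng.normal(size=10_000)
>>> psi = aipw_score_atte(y, d, g0=x, g1=x + 2.0, m=m, theta=2.0)
>>> bool(abs(psi.mean()) < 0.1)
True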
"""
m = np.clip(m, trimming_threshold, 1 - trimming_threshold)
if p_treated is None:
p_treated = float(np.mean(d))
gamma = m / (1.0 - m)
num = d * (y - g0 - theta) - (1.0 - d) * gamma * (y - g0)
return num / (p_treated + 1e-12)
def oos_moment_check_with_fold_nuisances(
fold_thetas: List[float],
fold_indices: List[np.ndarray],
fold_nuisances: List[Tuple[np.ndarray, np.ndarray, np.ndarray]],
y: np.ndarray,
d: np.ndarray,
score_fn: Optional[Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]] = None,
) -> Tuple[pd.DataFrame, float]:
"""
Out-of-sample moment check using fold-specific nuisances to avoid tautological results.
For each fold k, evaluates the AIPW score using θ fitted on other folds and
nuisance predictions from the fold-specific model, then tests if the combined
moment condition holds.
Parameters
----------
fold_thetas : List[float]
Treatment effects estimated excluding each fold
fold_indices : List[np.ndarray]
Indices for each fold
fold_nuisances : List[Tuple[np.ndarray, np.ndarray, np.ndarray]]
Fold-specific nuisance predictions (m, g0, g1) for each fold
y, d : np.ndarray
Observed outcomes and treatments
score_fn : Optional[Callable], default None
Custom score ψ(y, d, g0, g1, m, θ); defaults to `aipw_score_ate`
Returns
-------
Tuple[pd.DataFrame, float]
Fold-wise results and combined t-statistic
"""
rows = []
for k, (idx, nuisances) in enumerate(zip(fold_indices, fold_nuisances)):
th = fold_thetas[k]
m_fold, g0_fold, g1_fold = nuisances
# Compute scores using fold-specific theta and nuisances
if score_fn is None:
psi = aipw_score_ate(y[idx], d[idx], g0_fold, g1_fold, m_fold, th)
else:
psi = score_fn(y[idx], d[idx], g0_fold, g1_fold, m_fold, th)
rows.append({"fold": k, "n": len(idx), "psi_mean": psi.mean(), "psi_var": psi.var(ddof=1)})
df = pd.DataFrame(rows)
num = (df["n"] * df["psi_mean"]).sum()
den = np.sqrt((df["n"] * df["psi_var"]).sum())
tstat = num / den if den > 0 else 0.0
return df, float(tstat)
def oos_moment_check(
fold_thetas: List[float],
fold_indices: List[np.ndarray],
y: np.ndarray,
d: np.ndarray,
g0: np.ndarray,
g1: np.ndarray,
m: np.ndarray,
score_fn: Optional[Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]] = None,
) -> Tuple[pd.DataFrame, float]:
"""
Out-of-sample moment check to avoid tautological results (legacy/simple version).
For each fold k, evaluates the AIPW score using θ fitted on other folds,
then tests if the combined moment condition holds.
Parameters
----------
fold_thetas : List[float]
Treatment effects estimated excluding each fold
fold_indices : List[np.ndarray]
Indices for each fold
y, d, g0, g1, m : np.ndarray
Data arrays (outcomes, treatment, cross-fitted predictions)
score_fn : Optional[Callable], default None
Custom score ψ(y, d, g0, g1, m, θ); defaults to `aipw_score_ate`
Returns
-------
Tuple[pd.DataFrame, float]
Fold-wise results and combined t-statistic
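Examples
--------
A minimal sketch: two folds, θ held at its true value on both (synthetic data):

>>> import numpy as np
>>> rng = np.random.default_rng(2)
>>> n = 2_000
>>> x = rng.normal(size=n)
>>> m = 1 / (1 + np.exp(-x))
>>> d = rng.binomial(1, m).astype(float)
>>> g0, g1 = x, x + 2.0
>>> y = np.where(d == 1, g1, g0) + rng.normal(size=n)
>>> folds = [np.arange(0, n // 2), np.arange(n // 2, n)]
>>> df, tstat = oos_moment_check([2.0, 2.0], folds, y, d, g0, g1, m)
>>> bool(abs(tstat) < 4)
True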
"""
rows = []
for k, idx in enumerate(fold_indices):
th = fold_thetas[k]
if score_fn is None:
psi = aipw_score_ate(y[idx], d[idx], g0[idx], g1[idx], m[idx], th)
else:
psi = score_fn(y[idx], d[idx], g0[idx], g1[idx], m[idx], th)
rows.append({"fold": k, "n": len(idx), "psi_mean": psi.mean(), "psi_var": psi.var(ddof=1)})
df = pd.DataFrame(rows)
num = (df["n"] * df["psi_mean"]).sum()
den = np.sqrt((df["n"] * df["psi_var"]).sum())
tstat = num / den if den > 0 else 0.0
return df, float(tstat)
def orthogonality_derivatives(X_basis: np.ndarray, y: np.ndarray, d: np.ndarray,
g0: np.ndarray, g1: np.ndarray, m: np.ndarray, trimming_threshold: float = 0.01) -> pd.DataFrame:
"""
Compute orthogonality (Gateaux derivative) tests for nuisance functions (ATE case).
Uses IRM naming: g0,g1 outcomes; m propensity.
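Examples
--------
A minimal sketch with a constant-only basis; at the true nuisances all
derivative t-statistics should be small:

>>> import numpy as np
>>> rng = np.random.default_rng(3)
>>> n = 5_000
>>> x = rng.normal(size=n)
>>> m = 1 / (1 + np.exp(-x))
>>> d = rng.binomial(1, m).astype(float)
>>> g0, g1 = x, x + 2.0
>>> y = np.where(d == 1, g1, g0) + rng.normal(size=n)
>>> out = orthogonality_derivatives(np.ones((n, 1)), y, d, g0, g1, m)
>>> bool(out[["t_g0", "t_g1", "t_m"]].abs().to_numpy().max() < 4)
True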
"""
n, B = X_basis.shape
# Clip propensity scores to avoid division by zero and extreme weights
m_clipped = np.clip(m, trimming_threshold, 1 - trimming_threshold)
# g1 direction: ∂_{g1} φ̄[h1] = (1/n)∑ h1(Xi)(1 - Di/mi)
dg1_terms = X_basis * (1 - d / m_clipped)[:, None]
dg1 = dg1_terms.mean(axis=0)
dg1_se = dg1_terms.std(axis=0, ddof=1) / np.sqrt(n)
# g0 direction: ∂_{g0} φ̄[h0] = (1/n)∑ h0(Xi)((1-Di)/(1-mi) - 1)
dg0_terms = X_basis * ((1 - d) / (1 - m_clipped) - 1)[:, None]
dg0 = dg0_terms.mean(axis=0)
dg0_se = dg0_terms.std(axis=0, ddof=1) / np.sqrt(n)
# m direction: ∂_m φ̄[s] = -(1/n)∑ s(Xi)[Di(Yi-g1i)/mi² + (1-Di)(Yi-g0i)/(1-mi)²]
m_summand = (d * (y - g1) / m_clipped**2 + (1 - d) * (y - g0) / (1 - m_clipped)**2)
dm_terms = -X_basis * m_summand[:, None]
dm = dm_terms.mean(axis=0)
dm_se = dm_terms.std(axis=0, ddof=1) / np.sqrt(n)
out = pd.DataFrame({
"basis": np.arange(B),
"d_g1": dg1, "se_g1": dg1_se, "t_g1": dg1 / np.maximum(dg1_se, 1e-12),
"d_g0": dg0, "se_g0": dg0_se, "t_g0": dg0 / np.maximum(dg0_se, 1e-12),
"d_m": dm, "se_m": dm_se, "t_m": dm / np.maximum(dm_se, 1e-12),
})
return out
def influence_summary(y: np.ndarray, d: np.ndarray, g0: np.ndarray, g1: np.ndarray,
m: np.ndarray, theta_hat: float, k: int = 10, score: str = "ATE", trimming_threshold: float = 0.01) -> Dict[str, Any]:
"""
Compute influence diagnostics showing where uncertainty comes from.
Parameters
----------
y, d, g0, g1, m : np.ndarray
Data arrays
theta_hat : float
Estimated treatment effect
k : int, default 10
Number of top influential observations to return
score : str, default "ATE"
Which influence function to use: 'ATE' or 'ATTE'
trimming_threshold : float, default 0.01
Propensity clipping level used inside the score
Returns
-------
Dict[str, Any]
Influence diagnostics including SE, heavy-tail metrics, and top-k cases
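Examples
--------
A minimal sketch on synthetic data:

>>> import numpy as np
>>> rng = np.random.default_rng(4)
>>> n = 5_000
>>> x = rng.normal(size=n)
>>> m = 1 / (1 + np.exp(-x))
>>> d = rng.binomial(1, m).astype(float)
>>> y = x + 2.0 * d + rng.normal(size=n)
>>> diag = influence_summary(y, d, g0=x, g1=x + 2.0, m=m, theta_hat=2.0)
>>> sorted(diag)[:3]
['kurtosis', 'p99_over_med', 'se_plugin']
>>> len(diag['top_influential'])
10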
"""
score_u = str(score).upper()
if score_u == "ATE":
psi = aipw_score_ate(y, d, g0, g1, m, theta_hat, trimming_threshold=trimming_threshold)
elif score_u in ("ATTE", "ATT"):
psi = aipw_score_atte(y, d, g0, g1, m, theta_hat, p_treated=float(np.mean(d)), trimming_threshold=trimming_threshold)
else:
raise ValueError("score must be 'ATE' or 'ATTE'")
se = psi.std(ddof=1) / np.sqrt(len(psi))
idx = np.argsort(-np.abs(psi))[:k]
top = pd.DataFrame({
"i": idx,
"psi": psi[idx],
"m": m[idx],
"res_t": d[idx] * (y[idx] - g1[idx]),
"res_c": (1 - d[idx]) * (y[idx] - g0[idx])
})
return {
"se_plugin": float(se),
"kurtosis": float(((psi - psi.mean())**4).mean() / (psi.var(ddof=1)**2 + 1e-12)),
"p99_over_med": float(np.quantile(np.abs(psi), 0.99) / (np.median(np.abs(psi)) + 1e-12)),
"top_influential": top
}
def refute_irm_orthogonality(
inference_fn: Callable[..., Dict[str, Any]],
data: CausalData,
trim_propensity: Tuple[float, float] = (0.02, 0.98),
n_basis_funcs: Optional[int] = None,
n_folds_oos: int = 4,
score: Optional[str] = None,
trimming_threshold: float = 0.01,
strict_oos: bool = True,
**inference_kwargs,
) -> Dict[str, Any]:
"""
Comprehensive AIPW orthogonality diagnostics for IRM estimators.
Implements three key diagnostic approaches based on the efficient influence function (EIF):
1. Out-of-sample moment check (non-tautological)
2. Orthogonality (Gateaux derivative) tests
3. Influence diagnostics
Parameters
----------
inference_fn : Callable
The inference function (dml_ate or dml_att)
data : CausalData
The causal data object
trim_propensity : Tuple[float, float], default (0.02, 0.98)
Propensity score trimming bounds (min, max) to avoid extreme weights
n_basis_funcs : Optional[int], default None (len(confounders)+1)
Number of basis functions for orthogonality derivative tests (constant + covariates).
If None, defaults to the number of confounders in `data` plus 1 for the constant term.
n_folds_oos : int, default 4
Number of folds for the out-of-sample moment check (used only on the slow
fallback path when cached ψ components are unavailable)
score : Optional[str], default None
'ATE' or 'ATTE'; if None, inferred from the fitted model
trimming_threshold : float, default 0.01
Propensity clipping level used inside the scores
strict_oos : bool, default True
If True, report the strict (pooled) OOS t-statistic as the main statistic
**inference_kwargs : dict
Additional arguments passed to inference_fn
Returns
-------
Dict[str, Any]
Dictionary containing:
- oos_moment_test: Out-of-sample moment condition results
- orthogonality_derivatives: Gateaux derivative test results
- influence_diagnostics: Influence function diagnostics
- theta: Original treatment effect estimate
- trimmed_diagnostics: Results on trimmed sample
- overall_assessment: Summary diagnostic assessment
Examples
--------
>>> from causalis.refutation.orthogonality import refute_irm_orthogonality
>>> from causalis.inference.ate.dml_ate import dml_ate
>>>
>>> # Comprehensive orthogonality check
>>> ortho_results = refute_irm_orthogonality(dml_ate, causal_data)
>>>
>>> # Check key diagnostics
>>> print(f"OOS moment t-stat: {ortho_results['oos_moment_test']['tstat']:.3f}")
>>> print(f"Assessment: {ortho_results['overall_assessment']}")
"""
# Run inference to get fitted IRM model
result = inference_fn(data, **inference_kwargs)
if 'model' not in result:
raise ValueError("Inference function must return a dictionary with 'model' key")
dml_model = result['model']
# Prefer wrapper-reported coefficient; fallback to IRM attributes if missing
if 'coefficient' in result:
theta_hat = float(result['coefficient'])
else:
theta_hat = float(getattr(dml_model, 'coef_', [np.nan])[0])
# Extract cross-fitted predictions (m, g0, g1) from the fitted IRM model;
# `extract_nuisances` is assumed to be provided elsewhere in the package
m_propensity, g0_outcomes, g1_outcomes = extract_nuisances(dml_model)
# Get observed data
y = data.target.values.astype(float)
d = data.treatment.values.astype(float)
# Determine score
score_u = str(score).upper() if score is not None else "AUTO"
if score_u == "AUTO":
score_attr = getattr(dml_model, 'score', None)
if score_attr is None:
score_attr = getattr(dml_model, '_score', 'ATE')
score_str = str(score_attr).upper()
score_u = 'ATTE' if 'ATT' in score_str else 'ATE'
# Build score function
if score_u == 'ATE':
def score_fn(y_, d_, g0_, g1_, m_, th_):
return aipw_score_ate(y_, d_, g0_, g1_, m_, th_, trimming_threshold=trimming_threshold)
elif score_u in ('ATTE', 'ATT'):
p_treated = float(np.mean(d))
def score_fn(y_, d_, g0_, g1_, m_, th_):
return aipw_score_atte(y_, d_, g0_, g1_, m_, th_, p_treated=p_treated, trimming_threshold=trimming_threshold)
else:
raise ValueError("score must be 'ATE' or 'ATTE'")
# Apply propensity trimming
trim_min, trim_max = trim_propensity
trim_mask = (m_propensity >= trim_min) & (m_propensity <= trim_max)
n_trimmed = int(np.sum(~trim_mask))
# Create trimmed arrays
y_trim = y[trim_mask]
d_trim = d[trim_mask]
m_trim = m_propensity[trim_mask]
g0_trim = g0_outcomes[trim_mask]
g1_trim = g1_outcomes[trim_mask]
# === 1. OUT-OF-SAMPLE MOMENT CHECK ===
# Prefer fast path using cached ψ_a, ψ_b and training folds from the fitted IRM model
use_fast = hasattr(dml_model, "psi_a_") and hasattr(dml_model, "psi_b_") and getattr(dml_model, "folds_", None) is not None
if use_fast:
folds_arr = np.asarray(dml_model.folds_, dtype=int)
K = int(folds_arr.max() + 1) if folds_arr.size > 0 else 0
fold_indices = [np.where(folds_arr == k)[0] for k in range(K)]
# Compute both fold-aggregated and strict t-stats using only ψ components
oos_df, oos_tstat, tstat_strict = oos_moment_check_from_psi(
np.asarray(dml_model.psi_a_, dtype=float),
np.asarray(dml_model.psi_b_, dtype=float),
fold_indices,
strict=True,
)
oos_pvalue = float(2 * (1 - stats.norm.cdf(abs(oos_tstat))))
pvalue_strict = float(2 * (1 - stats.norm.cdf(abs(tstat_strict)))) if tstat_strict is not None else float("nan")
strict_applied = bool(strict_oos)
main_tstat = float(tstat_strict if strict_applied else oos_tstat)
main_pvalue = float(2 * (1 - stats.norm.cdf(abs(main_tstat))))
else:
# Fallback: legacy slow path with K refits and score recomputations
kf = KFold(n_splits=n_folds_oos, shuffle=True, random_state=42)
fold_thetas: List[float] = []
fold_indices: List[np.ndarray] = []
fold_nuisances: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = []
df_full = data.get_df()
for train_idx, test_idx in kf.split(df_full):
# Create fold-specific data using public APIs
train_data = CausalData(
df=df_full.iloc[train_idx].copy(),
treatment=data.treatment.name,
outcome=data.target.name,
confounders=list(data.confounders) if data.confounders else None
)
# Fit model on training fold to get theta
fold_result = inference_fn(train_data, **inference_kwargs)
fold_thetas.append(float(fold_result['coefficient']))
fold_indices.append(test_idx)
# Use full-model cross-fitted nuisances indexed by test fold (fast OOS)
m_fold = m_propensity[test_idx]
g0_fold = g0_outcomes[test_idx]
g1_fold = g1_outcomes[test_idx]
fold_nuisances.append((m_fold, g0_fold, g1_fold))
# Run out-of-sample moment check with fold-specific nuisances (fold-aggregated)
oos_df, oos_tstat = oos_moment_check_with_fold_nuisances(
fold_thetas, fold_indices, fold_nuisances, y, d, score_fn=score_fn
)
oos_pvalue = float(2 * (1 - stats.norm.cdf(abs(oos_tstat))))
# Strict aggregation over all held-out psi without creating a giant array
total_n = 0
total_sum = 0.0
total_sumsq = 0.0
for k, (idx, (m_fold, g0_fold, g1_fold)) in enumerate(zip(fold_indices, fold_nuisances)):
th = fold_thetas[k]
psi_k = score_fn(y[idx], d[idx], g0_fold, g1_fold, m_fold, th)
n_k = psi_k.size
s_k = float(psi_k.sum())
ss_k = float((psi_k * psi_k).sum())
total_n += n_k
total_sum += s_k
total_sumsq += ss_k
if total_n > 1:
mean_all = total_sum / total_n
var_all = (total_sumsq - total_n * (mean_all ** 2)) / (total_n - 1)
se_all = float(np.sqrt(max(var_all, 0.0) / total_n))
tstat_strict = float(mean_all / se_all) if se_all > 0 else 0.0
else:
tstat_strict = 0.0
pvalue_strict = float(2 * (1 - stats.norm.cdf(abs(tstat_strict))))
# Choose which t-stat to report as main
strict_applied = bool(strict_oos)
main_tstat = float(tstat_strict if strict_applied else oos_tstat)
main_pvalue = float(pvalue_strict if strict_applied else oos_pvalue)
# === 2. ORTHOGONALITY DERIVATIVE TESTS ===
# Create basis functions: constant + first few standardized confounders
if data.confounders is None or len(data.confounders) == 0:
raise ValueError("CausalData object must have confounders defined for orthogonality diagnostics")
# If n_basis_funcs is not provided, default to (number of confounders + 1) for the constant term
if n_basis_funcs is None:
n_basis_funcs = len(data.confounders) + 1
# Ensure confounders are represented as a float ndarray to avoid object-dtype reductions
df_conf = data.get_df()[list(data.confounders)]
X = df_conf.to_numpy(dtype=float)
n_covs = min(max(n_basis_funcs - 1, 0), X.shape[1]) # -1 for constant term
if n_covs > 0:
X_selected = X[:, :n_covs]
X_std = (X_selected - np.mean(X_selected, axis=0)) / (np.std(X_selected, axis=0) + 1e-8)
X_basis = np.c_[np.ones(len(X)), X_std] # Constant + standardized covariates
else:
X_basis = np.ones((len(X), 1)) # Only constant term
if score_u == 'ATE':
ortho_derivs_full = orthogonality_derivatives(X_basis, y, d, g0_outcomes, g1_outcomes, m_propensity, trimming_threshold=trimming_threshold)
X_basis_trim = X_basis[trim_mask]
ortho_derivs_trim = orthogonality_derivatives(X_basis_trim, y_trim, d_trim, g0_trim, g1_trim, m_trim, trimming_threshold=trimming_threshold)
problematic_derivs_full = ortho_derivs_full[(np.abs(ortho_derivs_full['t_g1']) > 2) | (np.abs(ortho_derivs_full['t_g0']) > 2) | (np.abs(ortho_derivs_full['t_m']) > 2)]
problematic_derivs_trim = ortho_derivs_trim[(np.abs(ortho_derivs_trim['t_g1']) > 2) | (np.abs(ortho_derivs_trim['t_g0']) > 2) | (np.abs(ortho_derivs_trim['t_m']) > 2)]
derivs_interpretation = 'Large |t| (>2) indicate calibration issues'
derivs_full_ok = len(problematic_derivs_full) == 0
derivs_trim_ok = len(problematic_derivs_trim) == 0
else: # ATTE / ATT
p_treated_full = float(np.mean(d))
ortho_derivs_full = orthogonality_derivatives_atte(X_basis, y, d, g0_outcomes, m_propensity, p_treated_full, trimming_threshold=trimming_threshold)
X_basis_trim = X_basis[trim_mask]
p_treated_trim = float(np.mean(d_trim))
ortho_derivs_trim = orthogonality_derivatives_atte(X_basis_trim, y_trim, d_trim, g0_trim, m_trim, p_treated_trim, trimming_threshold=trimming_threshold)
# Backward-compatible alias columns for expected legacy names
try:
ortho_derivs_full = ortho_derivs_full.copy()
ortho_derivs_full["d_g"] = ortho_derivs_full.get("d_g0", np.nan)
ortho_derivs_full["t_g"] = ortho_derivs_full.get("t_g0", np.nan)
ortho_derivs_full["d_m0"] = ortho_derivs_full.get("d_m", np.nan)
ortho_derivs_full["t_m0"] = ortho_derivs_full.get("t_m", np.nan)
ortho_derivs_trim = ortho_derivs_trim.copy()
ortho_derivs_trim["d_g"] = ortho_derivs_trim.get("d_g0", np.nan)
ortho_derivs_trim["t_g"] = ortho_derivs_trim.get("t_g0", np.nan)
ortho_derivs_trim["d_m0"] = ortho_derivs_trim.get("d_m", np.nan)
ortho_derivs_trim["t_m0"] = ortho_derivs_trim.get("t_m", np.nan)
except Exception:
pass
problematic_derivs_full = ortho_derivs_full[(np.abs(ortho_derivs_full['t_g0']) > 2) | (np.abs(ortho_derivs_full['t_m']) > 2)]
problematic_derivs_trim = ortho_derivs_trim[(np.abs(ortho_derivs_trim['t_g0']) > 2) | (np.abs(ortho_derivs_trim['t_m']) > 2)]
derivs_interpretation = 'ATTE: check g0 & m only; large |t| (>2) => calibration issues'
derivs_full_ok = len(problematic_derivs_full) == 0
derivs_trim_ok = len(problematic_derivs_trim) == 0
# === 3. INFLUENCE DIAGNOSTICS ===
influence_full = influence_summary(y, d, g0_outcomes, g1_outcomes, m_propensity, theta_hat, score=score_u, trimming_threshold=trimming_threshold)
# Re-estimate theta on the trimmed sample for fair trimmed influence diagnostics
theta_hat_trim = theta_hat
try:
df_full_local = data.get_df()
data_trim_obj = CausalData(
df=df_full_local.loc[trim_mask].copy(),
treatment=data.treatment.name,
outcome=data.target.name,
confounders=list(data.confounders) if data.confounders else None
)
res_trim = inference_fn(data_trim_obj, **inference_kwargs)
theta_hat_trim = float(res_trim.get('coefficient', theta_hat))
except Exception:
pass
influence_trim = influence_summary(y_trim, d_trim, g0_trim, g1_trim, m_trim, theta_hat_trim, score=score_u, trimming_threshold=trimming_threshold)
# === DIAGNOSTIC ASSESSMENT ===
model_se = result.get('std_error')
if model_se is not None:
se_reasonable = abs(influence_full['se_plugin'] - model_se) < 2 * model_se
else:
se_reasonable = False
conditions = {
'oos_moment_ok': abs(main_tstat) < 2.0,
'derivs_full_ok': derivs_full_ok,
'derivs_trim_ok': derivs_trim_ok,
'se_reasonable': se_reasonable,
'no_extreme_influence': influence_full['p99_over_med'] < 10,
'trimming_reasonable': n_trimmed < 0.1 * len(y)
}
n_passed = sum(bool(v) for v in conditions.values())
if n_passed >= 5:
overall_assessment = "PASS: Strong evidence for orthogonality"
elif n_passed >= 4:
overall_assessment = "CAUTION: Most conditions satisfied"
elif n_passed >= 3:
overall_assessment = "WARNING: Several orthogonality violations"
else:
overall_assessment = "FAIL: Major orthogonality violations detected"
# ATTE-specific overlap and trim-sensitivity
overlap_atte = None
trim_curve_atte = None
if score_u in ('ATTE', 'ATT'):
overlap_atte = overlap_diagnostics_atte(m_propensity, d, eps_list=[0.95, 0.97, 0.98, 0.99])
try:
trim_curve_atte = trim_sensitivity_curve_atte(
inference_fn, data, m_propensity, d,
thresholds=np.linspace(0.90, 0.995, 12),
**inference_kwargs
)
except Exception:
trim_curve_atte = None
# p-treated parameterization
params_extra = {}
if score_u in ('ATTE','ATT'):
p_full = float(np.mean(d))
p_trim = float(np.mean(d_trim)) if len(d_trim) > 0 else float('nan')
params_extra = {
'p_treated': p_full,
'p_treated_full': p_full,
'p_treated_trim': p_trim,
}
return {
'theta': float(theta_hat),
'params': {
'score': score_u,
'trimming_threshold': trimming_threshold,
'strict_oos_requested': bool(strict_oos),
'strict_oos_applied': bool(strict_applied),
**params_extra,
},
'oos_moment_test': {
'fold_results': oos_df,
'tstat': float(main_tstat),
'pvalue': float(main_pvalue),
'tstat_fold_agg': float(oos_tstat),
'pvalue_fold_agg': float(oos_pvalue),
'tstat_strict': float(tstat_strict),
'pvalue_strict': float(pvalue_strict),
'aggregation': 'strict' if strict_applied else 'fold',
'interpretation': 'Should be ≈ 0 if moment condition holds'
},
'orthogonality_derivatives': {
'full_sample': ortho_derivs_full,
'trimmed_sample': ortho_derivs_trim,
'problematic_full': problematic_derivs_full,
'problematic_trimmed': problematic_derivs_trim,
'interpretation': derivs_interpretation
},
'influence_diagnostics': {
'full_sample': influence_full,
'trimmed_sample': influence_trim,
'interpretation': 'Heavy tails or extreme kurtosis suggest instability'
},
'overlap_atte': overlap_atte,
'robustness': {
'trim_curve_atte': trim_curve_atte,
'interpretation': 'ATTE typically more sensitive to trimming near m→1 (controls).'
},
'trimming_info': {
'bounds': trim_propensity,
'n_trimmed': int(n_trimmed),
'pct_trimmed': float(n_trimmed / len(y) * 100.0)
},
'diagnostic_conditions': conditions,
'overall_assessment': overall_assessment
}
def orthogonality_derivatives_atte(
X_basis: np.ndarray, y: np.ndarray, d: np.ndarray,
g0: np.ndarray, m: np.ndarray, p_treated: float, trimming_threshold: float = 0.01
) -> pd.DataFrame:
"""
Gateaux derivatives of the ATTE score wrt nuisances (g0, m). g1-derivative is 0.
For ψ_ATTE = [ D*(Y - g0 - θ) - (1-D)*(m/(1-m))*(Y - g0) ] / p_treated:
∂_{g0}[h] : (1/n) Σ h(X_i) * [ ((1-D_i)*m_i/(1-m_i) - D_i) / p_treated ]
∂_{m}[s] : (1/n) Σ s(X_i) * [ -(1-D_i)*(Y_i - g0_i) / ( p_treated * (1-m_i)^2 ) ]
Both have 0 expectation at the truth (Neyman orthogonality).
"""
n, B = X_basis.shape
m = np.clip(m, trimming_threshold, 1 - trimming_threshold)
odds = m / (1.0 - m)
dg0_terms = X_basis * (((1.0 - d) * odds - d) / (p_treated + 1e-12))[:, None]
dg0 = dg0_terms.mean(axis=0)
dg0_se = dg0_terms.std(axis=0, ddof=1) / np.sqrt(n)
dm_terms = - X_basis * (((1.0 - d) * (y - g0)) / ((p_treated + 1e-12) * (1.0 - m)**2))[:, None]
dm = dm_terms.mean(axis=0)
dm_se = dm_terms.std(axis=0, ddof=1) / np.sqrt(n)
dg1 = np.zeros(B)
dg1_se = np.zeros(B)
t = lambda est, se: est / np.maximum(se, 1e-12)
return pd.DataFrame({
"basis": np.arange(B),
"d_g1": dg1, "se_g1": dg1_se, "t_g1": np.zeros(B),
"d_g0": dg0, "se_g0": dg0_se, "t_g0": t(dg0, dg0_se),
"d_m": dm, "se_m": dm_se, "t_m": t(dm, dm_se),
})
def overlap_diagnostics_atte(
m: np.ndarray, d: np.ndarray, eps_list: List[float] = [0.95, 0.97, 0.98, 0.99]
) -> pd.DataFrame:
"""
Key overlap metrics for ATTE: availability of suitable controls.
Reports conditional shares: among CONTROLS, fraction with m(X) ≥ threshold; among TREATED, fraction with m(X) ≤ 1 - threshold.
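Examples
--------
One of three controls has m ≥ 0.95, so the conditional share is 33.3%:

>>> import numpy as np
>>> m = np.array([0.2, 0.5, 0.96, 0.97])
>>> d = np.array([0, 0, 0, 1])
>>> df = overlap_diagnostics_atte(m, d, eps_list=[0.95])
>>> round(float(df.loc[0, "pct_controls_with_m_ge_thr"]), 1)
33.3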
"""
rows = []
m = np.asarray(m)
d_bool = np.asarray(d).astype(bool)
ctrl = ~d_bool
trt = d_bool
n_ctrl = int(ctrl.sum())
n_trt = int(trt.sum())
for thr in eps_list:
pct_ctrl = float(100.0 * ((m[ctrl] >= thr).mean() if n_ctrl else np.nan))
pct_trt = float(100.0 * ((m[trt] <= (1.0 - thr)).mean() if n_trt else np.nan))
rows.append({
"threshold": thr,
"pct_controls_with_m_ge_thr": pct_ctrl,
"pct_treated_with_m_le_1_minus_thr": pct_trt,
})
return pd.DataFrame(rows)
def trim_sensitivity_curve_atte(
inference_fn: Callable[..., Dict[str, Any]],
data: CausalData,
m: np.ndarray, d: np.ndarray,
thresholds: np.ndarray = np.linspace(0.90, 0.995, 12),
**inference_kwargs
) -> pd.DataFrame:
"""
Re-estimate θ while progressively trimming CONTROLS with large m(X).
"""
df_full = data.get_df()
rows = []
for thr in thresholds:
controls = ~np.asarray(d).astype(bool)
high_p = (np.asarray(m) >= thr)
drop = controls & high_p
keep = ~drop
df_trim = df_full.loc[keep].copy()
data_trim = CausalData(
df=df_trim, treatment=data.treatment.name,
outcome=data.target.name, confounders=list(data.confounders) if data.confounders else None
)
res = inference_fn(data_trim, **inference_kwargs)
rows.append({
"trim_threshold": float(thr),
"n": int(keep.sum()),
"pct_dropped": float(100.0 * drop.mean()),
"theta": float(res["coefficient"]),
"se": float(res.get("std_error", np.nan))
})
return pd.DataFrame(rows)
def trim_sensitivity_curve_ate(
m_hat: np.ndarray,
D: np.ndarray,
Y: np.ndarray,
g0_hat: np.ndarray,
g1_hat: np.ndarray,
eps_grid: tuple[float, ...] = (0.0, 0.005, 0.01, 0.02, 0.05),
) -> pd.DataFrame:
"""
Sensitivity of ATE estimate to propensity clipping epsilon (no re-fit).
For each epsilon in eps_grid, compute the AIPW/IRM ATE estimate using
m_clipped = clip(m_hat, eps, 1-eps) over the full sample and report
the plug-in standard error from the EIF.
Parameters
----------
m_hat, D, Y, g0_hat, g1_hat : np.ndarray
Cross-fitted nuisances and observed arrays.
eps_grid : tuple[float, ...]
Sequence of clipping thresholds ε to evaluate.
Returns
-------
pd.DataFrame
Columns: ['trim_eps','n','pct_clipped','theta','se'].
pct_clipped is the percent of observations with m outside [ε,1-ε].
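Examples
--------
A minimal sketch: θ and its plug-in SE across clipping levels, with no re-fitting:

>>> import numpy as np
>>> rng = np.random.default_rng(5)
>>> n = 5_000
>>> x = rng.normal(size=n)
>>> m = 1 / (1 + np.exp(-x))
>>> d = rng.binomial(1, m).astype(float)
>>> y = x + 2.0 * d + rng.normal(size=n)
>>> curve = trim_sensitivity_curve_ate(m, d, y, x, x + 2.0, eps_grid=(0.0, 0.02))
>>> list(curve.columns)
['trim_eps', 'n', 'pct_clipped', 'theta', 'se']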
"""
m = np.asarray(m_hat, dtype=float).ravel()
d = np.asarray(D, dtype=float).ravel()
y = np.asarray(Y, dtype=float).ravel()
g0 = np.asarray(g0_hat, dtype=float).ravel()
g1 = np.asarray(g1_hat, dtype=float).ravel()
n = y.size
rows: list[dict[str, float]] = []
for eps in eps_grid:
eps_f = float(eps)
m_clip = np.clip(m, eps_f, 1.0 - eps_f)
# share clipped (either side)
pct_clipped = float(100.0 * np.mean((m <= eps_f) | (m >= 1.0 - eps_f))) if n > 0 else float("nan")
# AIPW ATE with clipped propensity
theta_terms = (g1 - g0) + d * (y - g1) / m_clip - (1.0 - d) * (y - g0) / (1.0 - m_clip)
theta = float(np.mean(theta_terms)) if n > 0 else float("nan")
psi = theta_terms - theta
se = float(np.std(psi, ddof=1) / np.sqrt(n)) if n > 1 else 0.0
rows.append({
"trim_eps": eps_f,
"n": int(n),
"pct_clipped": pct_clipped,
"theta": theta,
"se": se,
})
return pd.DataFrame(rows)
# ------------------------------------------------------------------
# Placebo / robustness checks (moved from placebo.py)
# ------------------------------------------------------------------
def _is_binary(series: pd.Series) -> bool:
"""
Determine if a pandas Series contains binary data (0/1 or True/False).
Parameters
----------
series : pd.Series
The series to check
Returns
-------
bool
True if the series appears to be binary
"""
unique_vals = set(series.dropna().unique())
# Check for 0/1 binary
if unique_vals == {0, 1} or unique_vals == {0} or unique_vals == {1}:
return True
# Check for True/False binary
if unique_vals == {True, False} or unique_vals == {True} or unique_vals == {False}:
return True
# Check for numeric that could be treated as binary (only two distinct values)
if len(unique_vals) == 2 and all(isinstance(x, (int, float, bool, np.integer, np.floating)) for x in unique_vals):
return True
return False
def _generate_random_outcome(original_outcome: pd.Series, rng: np.random.Generator) -> np.ndarray:
"""
Generate random outcome variables matching the distribution of the original outcome.
For binary outcomes: generates random binary variables with same proportion as original
For continuous outcomes: generates random continuous variables from normal distribution
fitted to the original data
Parameters
----------
original_outcome : pd.Series
The original outcome variable
rng : np.random.Generator
Random number generator
Returns
-------
np.ndarray
Generated random outcome variables
"""
n = len(original_outcome)
if _is_binary(original_outcome):
# For binary outcome, generate random binary with same proportion
original_rate = float(original_outcome.mean())
return rng.binomial(1, original_rate, size=n).astype(original_outcome.dtype)
else:
# For continuous outcome, generate from normal distribution fitted to original data
mean = float(original_outcome.mean())
std = float(original_outcome.std())
if std == 0:
# If no variance, generate constant values equal to the mean
return np.full(n, mean, dtype=original_outcome.dtype)
return rng.normal(mean, std, size=n).astype(original_outcome.dtype)
def _generate_random_treatment(original_treatment: pd.Series, rng: np.random.Generator) -> np.ndarray:
"""
Generate random binary treatment variables with same proportion as original.
Parameters
----------
original_treatment : pd.Series
The original treatment variable
rng : np.random.Generator
Random number generator
Returns
-------
np.ndarray
Generated random binary treatment variables
"""
n = len(original_treatment)
treatment_rate = float(original_treatment.mean())
return rng.binomial(1, treatment_rate, size=n).astype(original_treatment.dtype)
def _run_inference(
inference_fn: Callable[..., Dict[str, Any]],
data: CausalData,
**kwargs,
) -> Dict[str, float]:
"""
Helper that executes `inference_fn` and extracts the two items of interest.
Provides a safe fallback for p_value if missing by using a normal approximation
based on coefficient and std_error.
"""
res = inference_fn(data, **kwargs)
out = {"theta": float(res["coefficient"])}
p = res.get("p_value")
if p is None and "std_error" in res:
try:
se = float(res["std_error"])
th = float(res["coefficient"])
z = 0.0 if se == 0.0 else th / se
p = 2.0 * (1 - stats.norm.cdf(abs(z)))
except Exception:
p = None
out["p_value"] = float(p) if p is not None else float("nan")
return out
# ------------------------------------------------------------------
# 1. Placebo: generate random outcome
# ------------------------------------------------------------------
def _public_names(data: CausalData) -> tuple[str, str]:
tname = getattr(getattr(data, "treatment", None), "name", None) or getattr(data, "_treatment", "D")
yname = getattr(getattr(data, "target", None), "name", None) or getattr(data, "_target", "Y")
return tname, yname
def refute_placebo_outcome(
inference_fn: Callable[..., Dict[str, Any]],
data: CausalData,
random_state: int | None = None,
**inference_kwargs,
) -> Dict[str, float]:
"""
Generate random outcome (target) variables while keeping treatment
and covariates intact. For binary outcomes, generates random binary
variables with the same proportion. For continuous outcomes, generates
random variables from a normal distribution fitted to the original data.
A valid causal design should now yield θ ≈ 0 and a large p-value.
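Examples
--------
>>> from causalis.inference.ate.dml_ate import dml_ate
>>>
>>> # Random outcome should break the effect: expect θ ≈ 0 and a large p-value
>>> res = refute_placebo_outcome(dml_ate, causal_data, random_state=0)
>>> print(f"theta: {res['theta']:.3f}, p-value: {res['p_value']:.3f}")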
"""
rng = np.random.default_rng(random_state)
df_mod = data.get_df().copy()
tname, yname = _public_names(data)
original_outcome = df_mod[yname]
df_mod[yname] = _generate_random_outcome(original_outcome, rng)
ck_mod = CausalData(
df=df_mod,
treatment=tname,
outcome=yname,
confounders=(list(data.confounders) if data.confounders else None),
)
return _run_inference(inference_fn, ck_mod, **inference_kwargs)
# ------------------------------------------------------------------
# 2. Placebo: generate random treatment
# ------------------------------------------------------------------
def refute_placebo_treatment(
inference_fn: Callable[..., Dict[str, Any]],
data: CausalData,
random_state: int | None = None,
**inference_kwargs,
) -> Dict[str, float]:
"""
Generate random binary treatment variables while keeping outcome and
covariates intact. Generates random binary treatment with the same
proportion as the original treatment. Breaks the treatment–outcome link.
"""
rng = np.random.default_rng(random_state)
df_mod = data.get_df().copy()
tname, yname = _public_names(data)
original_treatment = df_mod[tname]
df_mod[tname] = _generate_random_treatment(original_treatment, rng)
ck_mod = CausalData(
df=df_mod,
treatment=tname,
outcome=yname,
confounders=(list(data.confounders) if data.confounders else None),
)
return _run_inference(inference_fn, ck_mod, **inference_kwargs)
# ------------------------------------------------------------------
# 3. Subset robustness check
# ------------------------------------------------------------------
def refute_subset(
inference_fn: Callable[..., Dict[str, Any]],
data: CausalData,
fraction: float = 0.8,
random_state: int | None = None,
**inference_kwargs,
) -> Dict[str, float]:
"""
Re-estimate the effect on a random subset (default 80%)
to check sample-stability of the estimate.
"""
if not 0.0 < fraction <= 1.0:
raise ValueError("`fraction` must lie in (0, 1].")
rng = np.random.default_rng(random_state)
df = data.get_df()
n = len(df)
idx = rng.choice(n, size=int(np.floor(fraction * n)), replace=False)
tname, yname = _public_names(data)
df_mod = df.iloc[idx].copy()
ck_mod = CausalData(
df=df_mod,
treatment=tname,
outcome=yname,
confounders=(list(data.confounders) if data.confounders else None),
)
return _run_inference(inference_fn, ck_mod, **inference_kwargs)
def _fast_fold_thetas_from_psi(psi_a: np.ndarray, psi_b: np.ndarray, fold_indices: List[np.ndarray]) -> List[float]:
"""
Compute leave-fold-out θ_{-k} using cached ψ_a, ψ_b.
θ_{-k} = - mean_R(ψ_b) / mean_R(ψ_a), where R is the complement of fold k.
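Examples
--------
With the ATE convention ψ_a ≡ -1, θ_{-k} is simply the mean of ψ_b off fold k:

>>> import numpy as np
>>> psi_a = -np.ones(4)
>>> psi_b = np.array([1.0, 2.0, 3.0, 4.0])
>>> _fast_fold_thetas_from_psi(psi_a, psi_b, [np.array([0, 1]), np.array([2, 3])])
[3.5, 1.5]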
"""
psi_a = np.asarray(psi_a, dtype=float).ravel()
psi_b = np.asarray(psi_b, dtype=float).ravel()
n = psi_a.size
sum_a = float(psi_a.sum())
sum_b = float(psi_b.sum())
thetas: List[float] = []
for idx in fold_indices:
idx = np.asarray(idx, dtype=int)
n_r = n - idx.size
if n_r <= 0:
thetas.append(0.0)
continue
sa = sum_a - float(psi_a[idx].sum())
sb = sum_b - float(psi_b[idx].sum())
mean_a = sa / n_r
mean_b = sb / n_r
denom = mean_a if abs(mean_a) > 1e-12 else -1.0  # guard against div-by-zero; -1.0 matches the AIPW convention ψ_a ≡ -1
thetas.append(-mean_b / denom)
return thetas
def oos_moment_check_from_psi(
psi_a: np.ndarray,
psi_b: np.ndarray,
fold_indices: List[np.ndarray],
*,
strict: bool = False,
) -> Tuple[pd.DataFrame, float, Optional[float]]:
"""
OOS moment check using cached ψ_a, ψ_b only.
Returns (fold-wise DF, t_fold_agg, t_strict if requested).
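Examples
--------
A minimal sketch using the ATE convention ψ_a ≡ -1 (synthetic ψ_b):

>>> import numpy as np
>>> rng = np.random.default_rng(6)
>>> psi_b = rng.normal(loc=2.0, size=1_000)
>>> psi_a = -np.ones(1_000)
>>> folds = [np.arange(0, 500), np.arange(500, 1_000)]
>>> df, t_fold, t_strict = oos_moment_check_from_psi(psi_a, psi_b, folds, strict=True)
>>> bool(abs(t_fold) < 4) and bool(abs(t_strict) < 4)
True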
"""
psi_a = np.asarray(psi_a, dtype=float).ravel()
psi_b = np.asarray(psi_b, dtype=float).ravel()
fold_thetas = _fast_fold_thetas_from_psi(psi_a, psi_b, fold_indices)
rows: List[Dict[str, Any]] = []
total_n = 0
total_sum = 0.0
total_sumsq = 0.0
for k, idx in enumerate(fold_indices):
idx = np.asarray(idx, dtype=int)
th = fold_thetas[k]
# ψ_k = ψ_b[idx] + ψ_a[idx]*θ_{-k}
psi_k = psi_b[idx] + psi_a[idx] * th
n_k = psi_k.size
m_k = float(psi_k.mean()) if n_k > 0 else 0.0
v_k = float(psi_k.var(ddof=1)) if n_k > 1 else 0.0
rows.append({"fold": k, "n": n_k, "psi_mean": m_k, "psi_var": v_k})
if strict and n_k > 0:
total_n += n_k
s = float(psi_k.sum())
total_sum += s
total_sumsq += float((psi_k * psi_k).sum())
df = pd.DataFrame(rows)
# Fold-aggregated t-stat
num = float((df["n"] * df["psi_mean"]).sum())
den = float(np.sqrt((df["n"] * df["psi_var"]).sum()))
t_fold = (num / den) if den > 0 else 0.0
t_strict: Optional[float] = None
if strict and total_n > 1:
mean_all = total_sum / total_n
var_all = (total_sumsq - total_n * (mean_all ** 2)) / (total_n - 1)
se_all = float(np.sqrt(max(var_all, 0.0) / total_n))
t_strict = (mean_all / se_all) if se_all > 0 else 0.0
return df, float(t_fold), (None if t_strict is None else float(t_strict))
# ---- High-level entry point similar to overlap_validation.run_overlap_diagnostics ----
ResultLike = Dict[str, Any] | Any  # a dml_ate/dml_att-style result dict, or an IRM-like model
def _extract_score_inputs_from_result(res: ResultLike) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, float, str, Optional[np.ndarray], Optional[np.ndarray], Optional[List[np.ndarray]]]:
"""
Extract (y, d, g0, g1, m, theta, score, psi_a, psi_b, fold_indices) from a
dml_ate/dml_att-like result dict or an IRM-like model instance.
Supports two paths:
- Dict with keys {'model', 'coefficient'} and optional 'diagnostic_data'.
- IRM/DoubleML-like model object with .data, cross-fitted nuisances, and optionally psi caches.
"""
model = None
theta = float('nan')
score_str = 'ATE'
# 1) Result dict
if isinstance(res, dict):
if 'model' in res:
model = res['model']
if 'coefficient' in res:
theta = float(res['coefficient'])
# Try to infer score from result params if present
params = res.get('params') or res.get('Parameters') or {}
sc = params.get('score') if isinstance(params, dict) else None
if sc is not None:
score_str = str(sc).upper()
# If diagnostic_data carries arrays, prefer them
dd = res.get('diagnostic_data', {}) if isinstance(res.get('diagnostic_data', {}), dict) else {}
m = dd['m_hat'] if 'm_hat' in dd else (dd['m'] if 'm' in dd else None)
g0 = dd['g0_hat'] if 'g0_hat' in dd else (dd['g0'] if 'g0' in dd else None)
g1 = dd['g1_hat'] if 'g1_hat' in dd else (dd['g1'] if 'g1' in dd else None)
y = dd['y'] if 'y' in dd else (dd['Y'] if 'Y' in dd else None)
d = dd['d'] if 'd' in dd else (dd['D'] if 'D' in dd else None)
if all(v is not None for v in (y, d, g0, g1, m)):
y = np.asarray(y, dtype=float).ravel()
d = np.asarray(d, dtype=float).ravel()
g0 = np.asarray(g0, dtype=float).ravel()
g1 = np.asarray(g1, dtype=float).ravel()
m = np.asarray(m, dtype=float).ravel()
# Score fallback from coefficient if missing
if np.isnan(theta) and 'coefficient' in res:
theta = float(res['coefficient'])
# psi caches from model if available
if model is None:
return y, d, g0, g1, m, theta, score_str, None, None, None
# otherwise continue to model extraction below
else:
model = res
if model is None:
raise ValueError("Result must contain 'model' or provide diagnostic_data arrays.")
# 2) Model extraction
# Try to determine score setting
sc_attr = getattr(model, 'score', None)
if sc_attr is None:
sc_attr = getattr(model, '_score', None)
if sc_attr is not None:
score_str = 'ATTE' if 'ATT' in str(sc_attr).upper() else 'ATE'
# Cross-fitted nuisances
m, g0, g1 = extract_nuisances(model)
# Observed y,d from model.data
df = None
data_obj = getattr(model, 'data', None)
if data_obj is not None and hasattr(data_obj, 'get_df'):
df = data_obj.get_df()
# Try robust attribute names
if hasattr(data_obj, 'treatment') and hasattr(data_obj, 'target'):
tname = data_obj.treatment.name
yname = data_obj.target.name
else:
# Fallback to common attributes used in CausalData
tname = getattr(data_obj, '_treatment', 'D')
yname = getattr(data_obj, '_target', 'Y')
d = df[tname].to_numpy(dtype=float)
y = df[yname].to_numpy(dtype=float)
else:
raise ValueError("Model.data with get_df() is required to extract y and d.")
# Theta from wrapper or model
if np.isnan(theta):
coef = getattr(model, 'coef_', None)
if coef is not None:
try:
theta = float(coef[0])
except Exception:
theta = float(coef)
else:
theta = float('nan')
# Optional fast-path caches
psi_a = getattr(model, 'psi_a_', None)
psi_b = getattr(model, 'psi_b_', None)
folds = getattr(model, 'folds_', None)
if psi_a is not None and psi_b is not None and folds is not None:
# Build fold indices list
folds_arr = np.asarray(folds, dtype=int)
K = int(folds_arr.max() + 1) if folds_arr.size > 0 else 0
fold_indices = [np.where(folds_arr == k)[0] for k in range(K)]
return y, d, g0, g1, m, float(theta), score_str, np.asarray(psi_a, dtype=float), np.asarray(psi_b, dtype=float), fold_indices
return y, d, g0, g1, m, float(theta), score_str, None, None, None
def run_score_diagnostics(
res: ResultLike = None,
*,
y: Optional[np.ndarray] = None,
d: Optional[np.ndarray] = None,
g0: Optional[np.ndarray] = None,
g1: Optional[np.ndarray] = None,
m: Optional[np.ndarray] = None,
theta: Optional[float] = None,
score: Optional[str] = None,
trimming_threshold: float = 0.01,
n_basis_funcs: Optional[int] = None,
return_summary: bool = True,
) -> Dict[str, Any]:
"""
Single entry-point for score diagnostics (orthogonality) akin to run_overlap_diagnostics.
You can call it in TWO ways:
A) With raw arrays:
run_score_diagnostics(y=..., d=..., g0=..., g1=..., m=..., theta=...)
B) With a model/result:
run_score_diagnostics(res=<dml_ate/dml_att result dict or IRM-like model>)
Returns a dictionary with:
- params (score, trimming_threshold)
- oos_moment_test (if fast-path caches available on model; else omitted)
- orthogonality_derivatives (DataFrame)
- influence_diagnostics (full_sample)
- summary (compact DataFrame) if return_summary=True
- meta
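Examples
--------
Calling style A with raw arrays (a sketch on synthetic data):

>>> import numpy as np
>>> rng = np.random.default_rng(7)
>>> n = 5_000
>>> x = rng.normal(size=n)
>>> m = 1 / (1 + np.exp(-x))
>>> d = rng.binomial(1, m).astype(float)
>>> y = x + 2.0 * d + rng.normal(size=n)
>>> rep = run_score_diagnostics(y=y, d=d, g0=x, g1=x + 2.0, m=m, theta=2.0)
>>> list(rep['summary'].columns)
['metric', 'value', 'flag']
>>> rep['overall_flag'] in {'GREEN', 'YELLOW', 'RED', 'NA'}
True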
"""
# Resolve inputs
psi_a = psi_b = None
fold_indices = None
if any(v is None for v in (y, d, g0, g1, m)):
if res is None:
raise ValueError("Pass either (y,d,g0,g1,m,theta) or `res`.")
y, d, g0, g1, m, theta_ex, score_detected, psi_a, psi_b, fold_indices = _extract_score_inputs_from_result(res)
if theta is None:
theta = theta_ex
if score is None:
score = score_detected
else:
if score is None:
score = 'ATE'
score_u = str(score).upper()
# Enforce binary D and scrub non-finite across inputs
d = (np.asarray(d) > 0.5).astype(float)
y_arr = np.asarray(y, dtype=float)
g0_arr = np.asarray(g0, dtype=float)
g1_arr = np.asarray(g1, dtype=float)
m_arr = np.asarray(m, dtype=float)
mask = np.isfinite(y_arr) & np.isfinite(d) & np.isfinite(g0_arr) & np.isfinite(g1_arr) & np.isfinite(m_arr)
y, d, g0, g1, m = y_arr[mask], d[mask], g0_arr[mask], g1_arr[mask], m_arr[mask]
# Influence diagnostics
infl = influence_summary(y, d, g0, g1, m, float(theta if theta is not None else 0.0), score=score_u, trimming_threshold=trimming_threshold)
# Build basis for orthogonality derivatives
# If we can access confounders via res.model.data, use them; else constant-only basis
X_basis = None
use_data_basis = False
if res is not None:
model = res.get('model') if isinstance(res, dict) else res
data_obj = getattr(model, 'data', None) if model is not None else None
if data_obj is not None and hasattr(data_obj, 'get_df') and getattr(data_obj, 'confounders', None) is not None:
df_conf = data_obj.get_df()[list(data_obj.confounders)]
X = df_conf.to_numpy(dtype=float)
if n_basis_funcs is None:
n_basis_funcs = len(data_obj.confounders) + 1
n_covs = min(max(n_basis_funcs - 1, 0), X.shape[1])
if n_covs > 0:
X_sel = X[:, :n_covs]
X_std = (X_sel - np.mean(X_sel, axis=0)) / (np.std(X_sel, axis=0) + 1e-8)
X_basis = np.c_[np.ones(X.shape[0]), X_std]
use_data_basis = True
if X_basis is None:
X_basis = np.ones((len(y), 1))
elif use_data_basis:
# Align X_basis with scrubbed arrays
try:
X_basis = X_basis[mask]
except Exception:
# If masking fails due to shape issues, fall back to constant basis
X_basis = np.ones((len(y), 1))
if score_u == 'ATE':
ortho = orthogonality_derivatives(X_basis, y, d, g0, g1, m, trimming_threshold=trimming_threshold)
elif score_u in ('ATTE', 'ATT'):
p_treated = float(np.mean(d))
ortho = orthogonality_derivatives_atte(X_basis, y, d, g0, m, p_treated, trimming_threshold=trimming_threshold)
else:
raise ValueError("score must be 'ATE' or 'ATTE'")
# Optional fast-path OOS moment check using cached psi if available
oos = None
if psi_a is not None and psi_b is not None and fold_indices is not None:
df_oos, t_fold, t_strict = oos_moment_check_from_psi(psi_a, psi_b, fold_indices, strict=True)
p_fold = float(2 * (1 - stats.norm.cdf(abs(t_fold))))
p_strict = float(2 * (1 - stats.norm.cdf(abs(t_strict)))) if t_strict is not None else float('nan')
oos = {
'fold_results': df_oos,
'tstat_fold_agg': float(t_fold),
'pvalue_fold_agg': float(p_fold),
'tstat_strict': (None if t_strict is None else float(t_strict)),
'pvalue_strict': (float('nan') if t_strict is None else float(2 * (1 - stats.norm.cdf(abs(t_strict))))),
'interpretation': 'Near 0 indicates moment condition holds.'
}
report: Dict[str, Any] = {
'params': {
'score': score_u,
'trimming_threshold': float(trimming_threshold),
},
'orthogonality_derivatives': ortho,
'influence_diagnostics': infl,
}
if oos is not None:
report['oos_moment_test'] = oos
if return_summary:
# Build compact summary rows
max_t_g1 = float(np.nanmax(np.abs(ortho['t_g1'])) if 't_g1' in ortho.columns else np.nan)
max_t_g0 = float(np.nanmax(np.abs(ortho['t_g0'])) if 't_g0' in ortho.columns else np.nan)
max_t_m = float(np.nanmax(np.abs(ortho['t_m'])) if 't_m' in ortho.columns else np.nan)
rows = [
{'metric': 'se_plugin', 'value': infl['se_plugin']},
{'metric': 'psi_p99_over_med', 'value': infl['p99_over_med']},
{'metric': 'psi_kurtosis', 'value': infl['kurtosis']},
{'metric': 'max_|t|_g1', 'value': max_t_g1},
{'metric': 'max_|t|_g0', 'value': max_t_g0},
{'metric': 'max_|t|_m', 'value': max_t_m},
]
if oos is not None:
rows.append({'metric': 'oos_tstat_fold', 'value': oos['tstat_fold_agg']})
rows.append({'metric': 'oos_tstat_strict', 'value': (np.nan if oos['tstat_strict'] is None else oos['tstat_strict'])})
report['summary'] = pd.DataFrame(rows)
report['meta'] = {'n': int(len(y)), 'score': score_u}
# Augment with flags and thresholds, and enrich summary
try:
report = add_score_flags(report)
except Exception:
# be permissive: if flags fail, still return base report
pass
return report
# -------------- Flags and summary augmentation for score diagnostics --------------
def _grade(val: float, warn: float, strong: float, *, larger_is_worse: bool = True) -> str:
"""Map a scalar to GREEN/YELLOW/RED with optional direction."""
if val is None or (isinstance(val, float) and (np.isnan(val) or np.isinf(val))):
return "NA"
v = float(val)
if larger_is_worse:
return "GREEN" if v < warn else ("YELLOW" if v < strong else "RED")
else:
# smaller is worse → thresholds are lower bounds (expect warn > strong)
return "GREEN" if v > warn else ("YELLOW" if v > strong else "RED")
def add_score_flags(rep_score: dict, thresholds: dict | None = None, *, effect_size_guard: float = 0.02, oos_gate: bool = True, se_rule: str | None = None, se_ref: float | None = None) -> dict:
"""
Augment run_score_diagnostics(...) dict with:
- rep['flags'] (per-metric flags)
- rep['thresholds'] (the cutoffs used)
- rep['summary'] with a new 'flag' column
- rep['overall_flag'] (rollup)
Additional logic:
- Practical effect-size guard: if the constant-basis derivative magnitude is tiny
(<= effect_size_guard), then downgrade an orthogonality RED to GREEN (if OOS is GREEN)
or to YELLOW (otherwise). Controlled by `oos_gate`.
- Huge-n relaxation: for very large n (>= 200k), relax tail/kurtosis flags slightly
under specified value gates.
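Examples
--------
A contrived sketch with a hand-built report dict and custom t-stat cutoffs:

>>> rep = {'influence_diagnostics': {'se_plugin': 0.1, 'p99_over_med': 5.0, 'kurtosis': 4.0},
...        'orthogonality_derivatives': None, 'meta': {'n': 100}}
>>> out = add_score_flags(rep, thresholds={'t_warn': 1.5, 't_strong': 3.0})
>>> out['flags']['psi_tail_ratio'], out['thresholds']['t_warn']
('GREEN', 1.5)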
"""
rep = deepcopy(rep_score)
# --- defaults tuned to your style (>= means worse) ---
thr = {
# heavy tails / stability
"tail_ratio_warn": 10.0, # |psi| p99 / median
"tail_ratio_strong": 20.0,
"kurt_warn": 10.0, # normal≈3; allow slack for EIFs
"kurt_strong": 30.0,
# orthogonality t-stats (|t| should be ~0)
"t_warn": 2.0,
"t_strong": 4.0,
# OOS moment t-stat (|t| should be ~0)
"oos_warn": 2.0,
"oos_strong": 3.0,
}
if thresholds:
thr.update(thresholds)
# --- fetch raw pieces robustly ---
infl = rep.get("influence_diagnostics", {}) or {}
ortho = rep.get("orthogonality_derivatives")
oos = rep.get("oos_moment_test")
# Influence metrics
tail_ratio = infl.get("p99_over_med") if isinstance(infl, dict) else None
kurtosis = infl.get("kurtosis") if isinstance(infl, dict) else None
# Max |t| over bases (compute if not already summarized)
def _max_abs(df: pd.DataFrame, col: str) -> float:
if isinstance(df, pd.DataFrame) and col in df.columns:
a = np.abs(df[col].to_numpy(dtype=float))
return float(np.nanmax(a)) if a.size else float("nan")
return float("nan")
max_t_g1 = _max_abs(ortho, "t_g1")
max_t_g0 = _max_abs(ortho, "t_g0")
max_t_m = _max_abs(ortho, "t_m")
# OOS t-stats (if available)
oos_t_fold = float(oos.get("tstat_fold_agg")) if isinstance(oos, dict) and "tstat_fold_agg" in oos else float("nan")
oos_t_strict = float(oos.get("tstat_strict")) if isinstance(oos, dict) and "tstat_strict" in oos and oos.get("tstat_strict") is not None else float("nan")
# --- compute flags ---
flags = {
# influence / tails
"psi_tail_ratio": _grade(tail_ratio, thr["tail_ratio_warn"], thr["tail_ratio_strong"], larger_is_worse=True),
"psi_kurtosis": _grade(kurtosis, thr["kurt_warn"], thr["kurt_strong"], larger_is_worse=True),
# orthogonality derivatives
"ortho_max_|t|_g1": _grade(max_t_g1, thr["t_warn"], thr["t_strong"], larger_is_worse=True),
"ortho_max_|t|_g0": _grade(max_t_g0, thr["t_warn"], thr["t_strong"], larger_is_worse=True),
"ortho_max_|t|_m": _grade(max_t_m, thr["t_warn"], thr["t_strong"], larger_is_worse=True),
# OOS moment check (prefer strict if present)
"oos_tstat_fold": _grade(abs(oos_t_fold), thr["oos_warn"], thr["oos_strong"], larger_is_worse=True),
"oos_tstat_strict": _grade(abs(oos_t_strict), thr["oos_warn"], thr["oos_strong"], larger_is_worse=True),
}
# choose one OOS flag as canonical (strict if present)
canonical_oos = "oos_tstat_strict" if flags["oos_tstat_strict"] != "NA" else "oos_tstat_fold"
flags["oos_moment"] = flags[canonical_oos]
# --- attach thresholds used ---
rep["thresholds"] = {
"tail_ratio_warn": thr["tail_ratio_warn"],
"tail_ratio_strong": thr["tail_ratio_strong"],
"kurt_warn": thr["kurt_warn"],
"kurt_strong": thr["kurt_strong"],
"t_warn": thr["t_warn"],
"t_strong": thr["t_strong"],
"oos_warn": thr["oos_warn"],
"oos_strong": thr["oos_strong"],
}
# --- build/update compact summary ---
# Start from existing summary if present; otherwise create it.
summary_rows = [
("se_plugin", infl.get("se_plugin") if isinstance(infl, dict) else np.nan),
("psi_p99_over_med", tail_ratio),
("psi_kurtosis", kurtosis),
("max_|t|_g1", max_t_g1),
("max_|t|_g0", max_t_g0),
("max_|t|_m", max_t_m),
]
if not np.isnan(oos_t_fold):
summary_rows.append(("oos_tstat_fold", oos_t_fold))
if not np.isnan(oos_t_strict):
summary_rows.append(("oos_tstat_strict", oos_t_strict))
# Map summary metric → flag key and rule
summary_flag_rules = {
"psi_p99_over_med": ("psi_tail_ratio", True, thr["tail_ratio_warn"], thr["tail_ratio_strong"]),
"psi_kurtosis": ("psi_kurtosis", True, thr["kurt_warn"], thr["kurt_strong"]),
"max_|t|_g1": ("ortho_max_|t|_g1", True, thr["t_warn"], thr["t_strong"]),
"max_|t|_g0": ("ortho_max_|t|_g0", True, thr["t_warn"], thr["t_strong"]),
"max_|t|_m": ("ortho_max_|t|_m", True, thr["t_warn"], thr["t_strong"]),
"oos_tstat_fold": ("oos_tstat_fold", True, thr["oos_warn"], thr["oos_strong"]),
"oos_tstat_strict": ("oos_tstat_strict", True, thr["oos_warn"], thr["oos_strong"]),
# se_plugin is scale-dependent → mark NA by default
"se_plugin": (None, True, np.nan, np.nan),
}
df_sum = pd.DataFrame(summary_rows, columns=["metric", "value"])
df_sum["flag"] = [
(
_grade(abs(v) if k.startswith("oos_tstat") else v,
summary_flag_rules[k][2], summary_flag_rules[k][3],
larger_is_worse=summary_flag_rules[k][1])
if summary_flag_rules[k][0] is not None else "NA"
)
for k, v in zip(df_sum["metric"], df_sum["value"])
]
# If run_score_diagnostics already produced a summary, align/merge on known metrics.
if "summary" in rep and isinstance(rep["summary"], pd.DataFrame):
# left-join by 'metric' and prefer new flags
cur = rep["summary"].copy()
cur = cur.drop(columns=[c for c in cur.columns if c not in ("metric","value","flag")], errors="ignore")
rep["summary"] = (cur.drop(columns=["flag"], errors="ignore")
.merge(df_sum[["metric","flag"]], on="metric", how="left"))
else:
rep["summary"] = df_sum
# Assign initial flags
rep["flags"] = flags
# --- SE flagging (optional) ---
try:
# thresholds defaults for SE relative gap
thr.setdefault("se_rel_warn", 0.25)
thr.setdefault("se_rel_strong", 0.50)
se = float(infl.get("se_plugin", float("nan"))) if isinstance(infl, dict) else float("nan")
se_flag = "NA"
if se_rule is not None:
rule = str(se_rule).lower()
if rule == "model" and se_ref is not None and np.isfinite(se) and np.isfinite(float(se_ref)):
rel = abs(se - float(se_ref)) / max(float(se_ref), 1e-12)
se_flag = _grade(rel, thr["se_rel_warn"], thr["se_rel_strong"], larger_is_worse=True)
# write back to flags and summary if available
rep["flags"]["se_plugin"] = se_flag
if "summary" in rep and isinstance(rep["summary"], pd.DataFrame):
try:
rep["summary"].loc[rep["summary"]["metric"] == "se_plugin", "flag"] = se_flag
except Exception:
pass
# also expose thresholds used
rep.setdefault("thresholds", {})
rep["thresholds"]["se_rel_warn"] = thr["se_rel_warn"]
rep["thresholds"]["se_rel_strong"] = thr["se_rel_strong"]
except Exception:
# keep permissive behavior if anything goes wrong
pass
# ================= Practical adjustments =================
# 1) Effect-size guard for orthogonality REDs
try:
ortho_df = rep.get("orthogonality_derivatives")
if isinstance(ortho_df, pd.DataFrame):
if "basis" in ortho_df.columns and 0 in set(ortho_df["basis"].tolist()):
row0 = ortho_df.loc[ortho_df["basis"] == 0]
# Absolute derivative magnitudes at constant basis
d_g1_abs = float(abs(row0["d_g1"].values[0])) if "d_g1" in ortho_df.columns else float("nan")
d_g0_abs = float(abs(row0["d_g0"].values[0])) if "d_g0" in ortho_df.columns else float("nan")
d_m_abs = float(abs(row0["d_m" ].values[0])) if "d_m" in ortho_df.columns else float("nan")
else:
d_g1_abs = d_g0_abs = d_m_abs = float("nan")
oos_ok = (rep["flags"].get("oos_moment") == "GREEN") if oos_gate else True
def soften(_):
return "GREEN" if oos_ok else "YELLOW"
if not np.isnan(d_g1_abs) and d_g1_abs <= effect_size_guard and rep["flags"].get("ortho_max_|t|_g1") == "RED":
rep["flags"]["ortho_max_|t|_g1"] = soften(rep["flags"]["ortho_max_|t|_g1"])
if not np.isnan(d_g0_abs) and d_g0_abs <= effect_size_guard and rep["flags"].get("ortho_max_|t|_g0") == "RED":
rep["flags"]["ortho_max_|t|_g0"] = soften(rep["flags"]["ortho_max_|t|_g0"])
if not np.isnan(d_m_abs) and d_m_abs <= effect_size_guard and rep["flags"].get("ortho_max_|t|_m") == "RED":
rep["flags"]["ortho_max_|t|_m"] = soften(rep["flags"]["ortho_max_|t|_m"])
except Exception:
pass
# 2) Huge-n synthetic run relaxation for tails
try:
n = int(rep.get("meta", {}).get("n", 0))
if n >= 200_000:
# Tail ratio: YELLOW→GREEN if value ≤ 12.0
try:
val_tail = float(df_sum.loc[df_sum["metric"] == "psi_p99_over_med", "value"].iloc[0])
except Exception:
val_tail = tail_ratio if tail_ratio is not None else float("nan")
if rep["flags"].get("psi_tail_ratio") == "YELLOW" and np.isfinite(val_tail) and val_tail <= 12.0:
rep["flags"]["psi_tail_ratio"] = "GREEN"
# Kurtosis: RED→YELLOW if value ≤ 50.0
try:
val_kurt = float(df_sum.loc[df_sum["metric"] == "psi_kurtosis", "value"].iloc[0])
except Exception:
val_kurt = kurtosis if kurtosis is not None else float("nan")
if rep["flags"].get("psi_kurtosis") == "RED" and np.isfinite(val_kurt) and val_kurt <= 50.0:
rep["flags"]["psi_kurtosis"] = "YELLOW"
except Exception:
pass
# 3) Recompute overall and sync summary flags with final per-metric flags
order = {"GREEN": 0, "YELLOW": 1, "RED": 2, "NA": -1}
worst = max((order.get(f, -1) for f in rep["flags"].values()), default=-1)
inv = {v: k for k, v in order.items()}
rep["overall_flag"] = inv.get(worst, "NA")
# Update summary flags in-place to reflect final flags
map_flags = {
"psi_p99_over_med": rep["flags"].get("psi_tail_ratio", "NA"),
"psi_kurtosis": rep["flags"].get("psi_kurtosis", "NA"),
"max_|t|_g1": rep["flags"].get("ortho_max_|t|_g1", "NA"),
"max_|t|_g0": rep["flags"].get("ortho_max_|t|_g0", "NA"),
"max_|t|_m": rep["flags"].get("ortho_max_|t|_m", "NA"),
"oos_tstat_fold": rep["flags"].get("oos_tstat_fold", "NA"),
"oos_tstat_strict": rep["flags"].get("oos_tstat_strict", "NA"),
"se_plugin": rep["flags"].get("se_plugin", "NA"),
}
try:
rep["summary"]["flag"] = rep["summary"]["metric"].map(map_flags).fillna(rep["summary"].get("flag"))
except Exception:
# If mapping fails, keep existing flags in summary
pass
return rep