"""
Group Average Treatment Effect (GATE) estimation using DoubleML orthogonal signals.
This module provides a function that, given a (possibly filtered) CausalData object,
fits a DoubleML IRM model, computes per-observation CATEs (orthogonal signals),
forms groups (by default CATE quintiles), and returns group-level estimates
(theta), standard errors, p-values, and confidence intervals.
It prefers DoubleML's native `gate()` and `confint()` methods if available;
otherwise falls back to a simple normal approximation using the group mean of
orthogonal signals and its standard error.
"""
from typing import Any, Optional, Union
import numpy as np
import pandas as pd
from scipy.stats import norm
import warnings
import doubleml as dml
from catboost import CatBoostRegressor, CatBoostClassifier
from causalkit.data.causaldata import CausalData
def _fit_doubleml_irm(
data: CausalData,
ml_g: Optional[Any],
ml_m: Optional[Any],
n_folds: int,
n_rep: int,
):
if ml_g is None:
ml_g = CatBoostRegressor(iterations=100, depth=5, min_data_in_leaf=2, thread_count=-1, verbose=False)
if ml_m is None:
ml_m = CatBoostClassifier(iterations=100, depth=5, min_data_in_leaf=2, thread_count=-1, verbose=False)
df = data.get_df()
dml_data = dml.DoubleMLData(
df,
y_col=data._target,
d_cols=data._treatment,
x_cols=data._confounders,
)
obj = dml.DoubleMLIRM(
dml_data,
ml_g=ml_g,
ml_m=ml_m,
n_folds=n_folds,
n_rep=n_rep,
score="ATE",
)
# Suppress scikit-learn FutureWarning about 'force_all_finite' rename during fit
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=".*'force_all_finite' was renamed to 'ensure_all_finite'.*",
category=FutureWarning,
)
obj.fit()
return obj, df
def _extract_signals(obj: Any) -> np.ndarray:
# Prefer private attribute if present
if hasattr(obj, "_orthogonal_signals"):
sig = np.asarray(obj._orthogonal_signals).reshape(-1)
return sig
# Fallback via psi_elements if available
if hasattr(obj, "psi_elements") and isinstance(obj.psi_elements, dict) and "psi_b" in obj.psi_elements:
arr = np.asarray(obj.psi_elements["psi_b"]) # shape could include folds/reps
if arr.ndim == 1:
return arr
axes = tuple(range(1, arr.ndim))
return np.nanmean(arr, axis=axes)
raise AttributeError("Could not extract orthogonal signals from DoubleMLIRM object.")
[docs]
def gate_esimand(
data: CausalData,
groups: Optional[Union[pd.Series, pd.DataFrame]] = None,
n_groups: int = 5,
ml_g: Optional[Any] = None,
ml_m: Optional[Any] = None,
n_folds: int = 5,
n_rep: int = 1,
confidence_level: float = 0.95,
) -> pd.DataFrame:
"""
Estimate Group Average Treatment Effects (GATEs) by grouping observations
using CATE-based quantiles unless custom groups are provided.
Parameters
----------
data : CausalData
The (possibly filtered) CausalData object. Filtering should be done by
subsetting data.df before constructing CausalData, or by preparing a
filtered CausalData instance.
groups : pd.Series or pd.DataFrame, optional
Group assignments per observation. If a Series is passed, it will be
used as a single column named 'q'. If a DataFrame, it should contain a
single column specifying groups. If None, groups are formed by
pd.qcut over the in-sample CATEs into `n_groups` quantiles labeled 0..n_groups-1.
n_groups : int, default 5
Number of quantile groups if `groups` is None.
ml_g, ml_m, n_folds, n_rep :
Learners and DoubleML cross-fitting controls (as in ATE/ATT).
confidence_level : float, default 0.95
Confidence level for two-sided normal-approximation intervals.
Returns
-------
pd.DataFrame
A DataFrame with columns:
- group: group label
- n: group size
- theta: estimated group average treatment effect
- std_error: standard error (normal approx if fallback path)
- p_value: two-sided p-value for H0: theta=0
- ci_lower, ci_upper: confidence interval bounds
"""
# Validate inputs similar to other inference functions
if data.treatment is None:
raise ValueError("CausalData object must have a treatment variable defined")
if data.target is None:
raise ValueError("CausalData object must have a outcome variable defined")
if data.confounders is None:
raise ValueError("CausalData object must have confounders variables defined")
# Binary treatment check {0,1}
unique_treatments = pd.Series(data.treatment).unique()
if len(unique_treatments) != 2:
raise ValueError("Treatment variable must be binary (have exactly 2 unique values)")
if set(unique_treatments) != {0, 1}:
raise ValueError("Treatment variable must have values 0 and 1")
if not 0 < confidence_level < 1:
raise ValueError("confidence_level must be between 0 and 1 (exclusive)")
# Fit DoubleML model and extract signals
obj, df = _fit_doubleml_irm(data, ml_g, ml_m, n_folds, n_rep)
cate_hat = _extract_signals(obj)
# Prepare groups
if groups is None:
# Build quantile groups from CATEs
try:
q = pd.qcut(cate_hat, n_groups, labels=False, duplicates="drop")
except ValueError:
# In case of too many ties, fall back to cut on unique bins
q = pd.Series(pd.cut(cate_hat, n_groups, labels=False, duplicates="drop"))
groups_df = pd.DataFrame({"q": q})
else:
if isinstance(groups, pd.Series):
groups_df = groups.to_frame(name="q")
elif isinstance(groups, pd.DataFrame):
if groups.shape[1] != 1:
raise ValueError("groups DataFrame must have exactly one column")
groups_df = groups.copy()
groups_df.columns = ["q"]
else:
raise TypeError("groups must be a pandas Series or DataFrame if provided")
if len(groups_df) != len(df):
raise ValueError("groups length must match number of observations in data")
# Try DoubleML native gate() if available
if hasattr(obj, "gate"):
try:
gate_obj = obj.gate(groups=groups_df)
# confint for simultaneous intervals (if available)
ci_df = None
try:
ci_df = gate_obj.confint(level=confidence_level)
except Exception:
ci_df = None
# Extract estimates
# gate_obj might expose .coef, .se, .pval similar to DoubleML objects
if hasattr(gate_obj, "coef") and hasattr(gate_obj, "se") and hasattr(gate_obj, "pval"):
theta = np.asarray(gate_obj.coef).reshape(-1)
se = np.asarray(gate_obj.se).reshape(-1)
pval = np.asarray(gate_obj.pval).reshape(-1)
# Map groups to unique sorted labels
labels, counts = np.unique(groups_df["q"].to_numpy(), return_counts=True)
z = norm.ppf(1 - (1 - confidence_level) / 2)
ci_lower = theta - z * se
ci_upper = theta + z * se
if isinstance(ci_df, pd.DataFrame) and ci_df.shape[0] == theta.shape[0]:
# prefer provided CIs if shapes align
ci_lower = ci_df.iloc[:, 0].to_numpy()
ci_upper = ci_df.iloc[:, 1].to_numpy()
out = pd.DataFrame(
{
"group": labels,
"n": counts,
"theta": theta,
"std_error": se,
"p_value": pval,
"ci_lower": ci_lower,
"ci_upper": ci_upper,
}
)
return out.sort_values("group").reset_index(drop=True)
except Exception:
# Fall back to manual computation below
pass
# Manual group-wise estimation from orthogonal signals
z = norm.ppf(1 - (1 - confidence_level) / 2)
groups_df = groups_df.reset_index(drop=True)
sig_ser = pd.Series(cate_hat).reset_index(drop=True)
df_g = pd.concat([groups_df, sig_ser.rename("signal")], axis=1)
agg = df_g.groupby("q")["signal"].agg(["mean", "count", "std"]).reset_index()
# Handle std = NaN (groups of size 1) by setting SE to inf (pval=1, CI wide)
se = (agg["std"] / np.sqrt(agg["count"].clip(lower=1))).to_numpy()
theta = agg["mean"].to_numpy()
with np.errstate(divide="ignore", invalid="ignore"):
zstat = np.where(se > 0, theta / se, 0.0)
pval = 2 * (1 - norm.cdf(np.abs(zstat)))
ci_lower = theta - z * se
ci_upper = theta + z * se
out = pd.DataFrame(
{
"group": agg["q"].to_numpy(),
"n": agg["count"].to_numpy(),
"theta": theta,
"std_error": se,
"p_value": pval,
"ci_lower": ci_lower,
"ci_upper": ci_upper,
}
)
return out.sort_values("group").reset_index(drop=True)