# Source code for causalis.inference.gate.gate_esimand

"""
Group Average Treatment Effect (GATE) estimation using local DML IRM and BLP.
"""

from typing import Any, Optional, Union

import numpy as np
import pandas as pd
from catboost import CatBoostRegressor, CatBoostClassifier

from causalis.data.causaldata import CausalData
from causalis.inference.estimators.irm import IRM


def gate_esimand(
    data: CausalData,
    groups: Optional[Union[pd.Series, pd.DataFrame]] = None,
    n_groups: int = 5,
    ml_g: Optional[Any] = None,
    ml_m: Optional[Any] = None,
    n_folds: int = 5,
    n_rep: int = 1,
    confidence_level: float = 0.95,
) -> pd.DataFrame:
    """
    Estimate Group Average Treatment Effects (GATEs).

    An IRM model is fitted with ``score="ATE"`` and its ``gate`` method is
    called on a grouping frame. When no grouping is supplied, observations
    are binned by quantiles of the plug-in CATE proxy ``g1_hat - g0_hat``
    taken from the fitted IRM.

    Parameters
    ----------
    data : CausalData
        Dataset handed to ``IRM`` unchanged.
    groups : pd.Series or pd.DataFrame, optional
        User-defined grouping. It is copied, and a Series is converted to a
        one-column DataFrame before being passed to ``IRM.gate``. If None,
        groups are derived from the CATE proxy.
    n_groups : int, default 5
        Number of bins requested when groups are derived automatically.
        Fewer bins may result because duplicate bin edges are dropped.
    ml_g : object, optional
        Learner passed to ``IRM`` as ``ml_g``. Defaults to a
        ``CatBoostRegressor``.
    ml_m : object, optional
        Learner passed to ``IRM`` as ``ml_m``. Defaults to a
        ``CatBoostClassifier``.
    n_folds : int, default 5
        Forwarded to ``IRM`` as ``n_folds``.
    n_rep : int, default 1
        Forwarded to ``IRM`` as ``n_rep``.
    confidence_level : float, default 0.95
        Level used both for ``IRM.gate`` and for the confidence intervals.

    Returns
    -------
    pd.DataFrame
        One row per group with columns ``group``, ``n``, ``theta``,
        ``std_error``, ``p_value``, ``ci_lower`` and ``ci_upper``, sorted by
        ``group`` with a fresh integer index.

    Raises
    ------
    RuntimeError
        If ``groups`` is None and the fitted IRM exposes no ``g1_hat_`` or
        ``g0_hat_`` estimates.
    """
    # Fall back to CatBoost learners that stay quiet and write no files.
    catboost_defaults = dict(thread_count=-1, verbose=False, allow_writing_files=False)
    if ml_g is None:
        ml_g = CatBoostRegressor(**catboost_defaults)
    if ml_m is None:
        ml_m = CatBoostClassifier(**catboost_defaults)

    # Fit the local IRM implementation, which exposes .gate().
    irm_model = IRM(
        data=data,
        ml_g=ml_g,
        ml_m=ml_m,
        n_folds=n_folds,
        n_rep=n_rep,
        score="ATE",  # GATE uses the ATE orthogonal signal
    )
    irm_model.fit()

    if groups is not None:
        # Work on a copy so the caller's object is never modified.
        group_frame = groups.copy()
        if isinstance(group_frame, pd.Series):
            group_frame = group_frame.to_frame()
    else:
        # Sort on the plug-in estimate g1 - g0 rather than on outcomes,
        # to avoid overfitting the grouping to noise in Y.
        if irm_model.g1_hat_ is None or irm_model.g0_hat_ is None:
            raise RuntimeError("IRM model did not produce g1/g0 estimates.")
        proxy = irm_model.g1_hat_ - irm_model.g0_hat_

        try:
            bin_ids = pd.qcut(proxy, n_groups, labels=False, duplicates="drop")
        except ValueError:
            # Quantile binning failed (e.g. ties); use equal-width bins.
            bin_ids = pd.cut(proxy, n_groups, labels=False, duplicates="drop")

        group_frame = pd.DataFrame({"Group": bin_ids})

    # IRM.gate returns the fitted BLP object for the given groups.
    blp = irm_model.gate(group_frame, level=confidence_level)

    coef_table = blp.summary
    intervals = blp.confint(level=confidence_level)
    # Column sums of the BLP basis give the number of observations per group.
    group_sizes = blp.basis.sum(axis=0).astype(int)

    # All pieces are combined positionally; the summary, basis columns and
    # confidence-interval rows are expected to share the same group order.
    report_columns = {
        "group": coef_table.index,
        "n": group_sizes.values,
        "theta": coef_table["coef"].values,
        "std_error": coef_table["std err"].values,
        "p_value": coef_table["P>|t|"].values,
        # confint columns are read by position: [lower, effect, upper]
        "ci_lower": intervals.iloc[:, 0].values,
        "ci_upper": intervals.iloc[:, 2].values,
    }

    # Sort by group label so the output order is consistent.
    return pd.DataFrame(report_columns).sort_values("group").reset_index(drop=True)