Source code for causalis.inference.atte.dml_atte

"""
Simple IRM-based implementation for estimating the ATT (Average Treatment Effect on the Treated).

This module provides the function dml_atte to estimate the ATT using our internal
DoubleML-style IRM estimator, which consumes CausalData directly (not DoubleML objects).
"""
from __future__ import annotations

import warnings
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd

from causalis.data.causaldata import CausalData
from causalis.inference.estimators import IRM


def dml_atte(
    data: CausalData,
    ml_g: Optional[Any] = None,
    ml_m: Optional[Any] = None,
    n_folds: int = 4,
    n_rep: int = 1,
    confidence_level: float = 0.95,
    normalize_ipw: bool = False,
    trimming_rule: str = "truncate",
    trimming_threshold: float = 1e-2,
    random_state: Optional[int] = None,
    store_diagnostic_data: bool = True,
) -> Dict[str, Any]:
    """
    Estimate the average treatment effect on the treated (ATT) using the internal IRM estimator.

    Parameters
    ----------
    data : CausalData
        The CausalData object containing treatment, outcome, and confounders.
    ml_g : estimator, optional
        Learner for g(D, X) = E[Y | X, D]. If the outcome is binary and the learner is a
        classifier, predict_proba is used; otherwise predict(). If None, a CatBoostRegressor
        is used.
    ml_m : classifier, optional
        Learner for m(X) = E[D | X] (propensity). If None, a CatBoostClassifier is used.
    n_folds : int, default 4
        Number of folds for cross-fitting.
    n_rep : int, default 1
        Number of repetitions (currently only 1 is supported by IRM).
    confidence_level : float, default 0.95
        Confidence level for the confidence interval, in (0, 1).
    normalize_ipw : bool, default False
        Whether to normalize IPW terms within the score.
    trimming_rule : str, default "truncate"
        Trimming approach for the propensity score (only "truncate" is supported).
    trimming_threshold : float, default 1e-2
        Trimming threshold for the propensity score.
    random_state : int, optional
        Random seed for fold creation.
    store_diagnostic_data : bool, default True
        Whether to store a 'diagnostic_data' dictionary in the result.

    Returns
    -------
    Dict[str, Any]
        Keys: coefficient, std_error, p_value, confidence_interval, model, diagnostic_data.

    Notes
    -----
    By default, this function stores a comprehensive 'diagnostic_data' dictionary in the
    result. You can disable this by setting store_diagnostic_data=False.
    """
    # Basic validations, mirroring the existing wrappers
    if data.treatment is None:
        raise ValueError("CausalData object must have a treatment variable defined")
    if data.target is None:
        raise ValueError("CausalData object must have an outcome variable defined")
    if not data.confounders:
        raise ValueError("CausalData object must have confounder variables defined")
    if not 0 < confidence_level < 1:
        raise ValueError("confidence_level must be between 0 and 1 (exclusive)")

    # Default learners: lazily import CatBoost only if needed
    if ml_g is None or ml_m is None:
        try:
            from catboost import CatBoostRegressor, CatBoostClassifier  # type: ignore
        except ImportError as e:
            raise ImportError(
                "CatBoost is required for default learners. Install 'catboost' or provide ml_g and ml_m."
            ) from e
        if ml_g is None:
            ml_g = CatBoostRegressor(
                thread_count=-1,
                verbose=False,
                allow_writing_files=False,
            )
        if ml_m is None:
            ml_m = CatBoostClassifier(
                thread_count=-1,
                verbose=False,
                allow_writing_files=False,
            )

    # Normalize a boolean treatment to 0/1 so the CausalData stays consistent for IRM
    df = data.get_df().copy()
    tname = data.treatment.name
    if df[tname].dtype == bool:
        df[tname] = df[tname].astype(int)
        data = CausalData(df=df, treatment=tname, outcome=data.target.name, confounders=data.confounders)
    else:
        uniq = np.unique(df[tname].values)
        if not np.array_equal(np.sort(uniq), np.array([0, 1])) and not np.array_equal(
            np.sort(uniq), np.array([0.0, 1.0])
        ):
            raise ValueError(f"Treatment must be binary 0/1 or boolean; found {uniq}.")

    # Fit IRM with the ATT (ATTE) score
    irm = IRM(
        data,
        ml_g=ml_g,
        ml_m=ml_m,
        n_folds=n_folds,
        n_rep=n_rep,
        score="ATTE",
        normalize_ipw=normalize_ipw,
        trimming_rule=trimming_rule,
        trimming_threshold=trimming_threshold,
        random_state=random_state,
    )
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        irm.fit()

    # Confidence interval
    ci_df = irm.confint(level=confidence_level)
    if isinstance(ci_df, pd.DataFrame):
        ci_lower = float(ci_df.iloc[0, 0])
        ci_upper = float(ci_df.iloc[0, 1])
    else:
        arr = np.asarray(ci_df)
        ci_lower = float(arr[0, 0])
        ci_upper = float(arr[0, 1])

    # Collect diagnostic data for overlap/weight and score checks (optional)
    diagnostic_data = None
    if store_diagnostic_data:
        df_diag = data.get_df()
        y_diag = df_diag[data.target.name].to_numpy(dtype=float)
        d_diag = df_diag[data.treatment.name].to_numpy().astype(int)
        x_diag = df_diag[data.confounders].to_numpy(dtype=float)
        p1_diag = float(np.mean(d_diag))
        diagnostic_data = {
            "m_hat": np.asarray(irm.m_hat_, dtype=float),
            "g0_hat": np.asarray(irm.g0_hat_, dtype=float),
            "g1_hat": np.asarray(irm.g1_hat_, dtype=float),
            "y": y_diag,
            "d": d_diag,
            "x": x_diag,
            "psi": np.asarray(irm.psi_, dtype=float),
            "psi_a": np.asarray(irm.psi_a_, dtype=float),
            "psi_b": np.asarray(irm.psi_b_, dtype=float),
            "folds": (
                np.asarray(getattr(irm, "folds_", None), dtype=int)
                if getattr(irm, "folds_", None) is not None
                else None
            ),
            "score": "ATTE",
            "normalize_ipw": bool(normalize_ipw),
            "trimming_threshold": float(trimming_threshold),
            "p1": p1_diag,
        }

    return {
        "coefficient": float(irm.coef[0]),
        "std_error": float(irm.se[0]),
        "p_value": float(irm.pvalues[0]),
        "confidence_interval": (ci_lower, ci_upper),
        "model": irm,
        "diagnostic_data": diagnostic_data,
    }
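
A hedged sketch of supplying custom nuisance learners instead of the CatBoost defaults. It assumes the internal IRM accepts scikit-learn-style estimators (the CatBoost defaults expose that interface) and that cd is a CausalData object as in the sketch above; the learner choices here are illustrative.

    # Assumption: any sklearn-compatible regressor/classifier works as ml_g/ml_m.
    from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor

    result = dml_atte(
        cd,
        ml_g=GradientBoostingRegressor(random_state=0),   # outcome regression g(D, X) = E[Y | X, D]
        ml_m=GradientBoostingClassifier(random_state=0),  # propensity m(X) = E[D | X]
        n_folds=5,
        normalize_ipw=True,
        random_state=0,
    )
    att, se = result["coefficient"], result["std_error"]
    lo, hi = result["confidence_interval"]
    diag = result["diagnostic_data"]  # dict with m_hat, g0_hat, g1_hat, psi, folds, ...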