Source code for causalis.inference.ate.dml_ate

"""
IRM implementation for estimating average treatment effects.

This module provides dml_ate, a function that estimates average treatment effects
using our internal DoubleML-style IRM estimator, which consumes CausalData directly.
"""
from __future__ import annotations

import warnings
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
from catboost import CatBoostClassifier, CatBoostRegressor

from causalis.data.causaldata import CausalData
from causalis.inference.estimators import IRM
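
# For reference, the psi/psi_a/psi_b arrays stored in `diagnostic_data` below
# are assumed to follow the standard DoubleML IRM linear score
# psi = psi_a * theta + psi_b, which for the ATE reads
#
#     psi(W; theta, eta) = g(1, X) - g(0, X)
#                          + D * (Y - g(1, X)) / m(X)
#                          - (1 - D) * (Y - g(0, X)) / (1 - m(X))
#                          - theta
#
# so psi_a = -1, psi_b is the AIPW term, and theta_hat solves mean(psi) = 0.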


def dml_ate(
    data: CausalData,
    ml_g: Optional[Any] = None,
    ml_m: Optional[Any] = None,
    n_folds: int = 5,
    n_rep: int = 1,
    score: str = "ATE",
    confidence_level: float = 0.95,
    normalize_ipw: bool = False,
    trimming_rule: str = "truncate",
    trimming_threshold: float = 1e-2,
    random_state: Optional[int] = None,
    store_diagnostic_data: bool = True,
) -> Dict[str, Any]:
    """
    Estimate average treatment effects using the internal IRM estimator.

    Parameters
    ----------
    data : CausalData
        The CausalData object containing treatment, outcome, and confounders.
    ml_g : estimator, optional
        Learner for g(D, X) = E[Y | X, D]. If the outcome is binary and the
        learner is a classifier, predict_proba is used; otherwise predict().
        If None, a CatBoostRegressor is used.
    ml_m : classifier, optional
        Learner for m(X) = E[D | X] (propensity). If None, a
        CatBoostClassifier is used.
    n_folds : int, default 5
        Number of folds for cross-fitting.
    n_rep : int, default 1
        Number of repetitions (currently only n_rep=1 is supported by IRM).
    score : {"ATE", "ATTE"}, default "ATE"
        Target estimand.
    confidence_level : float, default 0.95
        Confidence level for the CI, in (0, 1).
    normalize_ipw : bool, default False
        Whether to normalize IPW terms within the score.
    trimming_rule : str, default "truncate"
        Trimming approach for the propensity (only "truncate" is supported).
    trimming_threshold : float, default 1e-2
        Trimming threshold for the propensity.
    random_state : int, optional
        Random seed for fold creation.
    store_diagnostic_data : bool, default True
        Whether to store a 'diagnostic_data' dictionary in the result.

    Returns
    -------
    Dict[str, Any]
        Keys: coefficient, std_error, p_value, confidence_interval, model,
        diagnostic_data.

    Notes
    -----
    By default, this function stores a comprehensive 'diagnostic_data'
    dictionary in the result. Disable it with store_diagnostic_data=False.
    """
    # Basic validations
    if data.treatment is None:
        raise ValueError("CausalData object must have a treatment variable defined")
    if data.target is None:
        raise ValueError("CausalData object must have an outcome variable defined")
    if not data.confounders:
        raise ValueError("CausalData object must have confounder variables defined")

    # Check confidence level
    if not 0 < confidence_level < 1:
        raise ValueError("confidence_level must be between 0 and 1 (exclusive)")

    # Defaults for learners
    if ml_g is None:
        ml_g = CatBoostRegressor(
            thread_count=-1,
            verbose=False,
            allow_writing_files=False,
        )
    if ml_m is None:
        ml_m = CatBoostClassifier(
            thread_count=-1,
            verbose=False,
            allow_writing_files=False,
        )

    # Normalize treatment to 0/1 if boolean (IRM will also check)
    df = data.get_df().copy()
    tname = data.treatment.name
    if df[tname].dtype == bool:
        df[tname] = df[tname].astype(int)
        # Construct a new CausalData with the coerced dtype to keep internal consistency
        data = CausalData(
            df=df,
            treatment=tname,
            outcome=data.target.name,
            confounders=data.confounders,
        )
    else:
        # Ensure binary encoding
        uniq = np.unique(df[tname].values)
        if not np.array_equal(np.sort(uniq), np.array([0, 1])) and not np.array_equal(
            np.sort(uniq), np.array([0.0, 1.0])
        ):
            raise ValueError(f"Treatment must be binary 0/1 or boolean; found {uniq}.")

    # Fit IRM
    irm = IRM(
        data,
        ml_g=ml_g,
        ml_m=ml_m,
        n_folds=n_folds,
        n_rep=n_rep,
        score=score,
        normalize_ipw=normalize_ipw,
        trimming_rule=trimming_rule,
        trimming_threshold=trimming_threshold,
        random_state=random_state,
    )

    # Suppress benign warnings during fit
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        irm.fit()

    # Confidence interval
    ci_df = irm.confint(level=confidence_level)

    # Robust extraction of CI values
    if isinstance(ci_df, pd.DataFrame):
        # Expect exactly two columns
        ci_lower = float(ci_df.iloc[0, 0])
        ci_upper = float(ci_df.iloc[0, 1])
    else:
        # Fallback if the implementation changes
        arr = np.asarray(ci_df)
        ci_lower = float(arr[0, 0])
        ci_upper = float(arr[0, 1])

    # Collect diagnostic data for overlap/weight and score checks (optional)
    diagnostic_data = None
    if store_diagnostic_data:
        df_diag = data.get_df()
        y_diag = df_diag[data.target.name].to_numpy(dtype=float)
        d_diag = df_diag[data.treatment.name].to_numpy().astype(int)
        x_diag = df_diag[data.confounders].to_numpy(dtype=float)
        p1_diag = float(np.mean(d_diag))
        folds = getattr(irm, "folds_", None)
        diagnostic_data = {
            "m_hat": np.asarray(irm.m_hat_, dtype=float),
            "g0_hat": np.asarray(irm.g0_hat_, dtype=float),
            "g1_hat": np.asarray(irm.g1_hat_, dtype=float),
            "y": y_diag,
            "d": d_diag,
            "x": x_diag,
            "psi": np.asarray(irm.psi_, dtype=float),
            "psi_a": np.asarray(irm.psi_a_, dtype=float),
            "psi_b": np.asarray(irm.psi_b_, dtype=float),
            "folds": np.asarray(folds, dtype=int) if folds is not None else None,
            "score": str(score).upper(),
            "normalize_ipw": bool(normalize_ipw),
            "trimming_threshold": float(trimming_threshold),
            "p1": p1_diag,
        }

    return {
        "coefficient": float(irm.coef[0]),
        "std_error": float(irm.se[0]),
        "p_value": float(irm.pvalues[0]),
        "confidence_interval": (ci_lower, ci_upper),
        "model": irm,
        "diagnostic_data": diagnostic_data,
    }
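
# Usage sketch (illustrative only, not part of the module). It assumes a
# pandas DataFrame `df` with a binary treatment column "t", an outcome "y",
# and confounder columns "x1" and "x2"; these column names are hypothetical.
#
#     from causalis.data.causaldata import CausalData
#     from causalis.inference.ate.dml_ate import dml_ate
#
#     cd = CausalData(df=df, treatment="t", outcome="y", confounders=["x1", "x2"])
#     res = dml_ate(cd, n_folds=5, random_state=0)
#     ate, (lo, hi) = res["coefficient"], res["confidence_interval"]
#
#     # Rough overlap check from the stored diagnostics: a large share of
#     # propensity estimates at or beyond the trimming threshold points to
#     # weak overlap.
#     diag = res["diagnostic_data"]
#     m, thr = diag["m_hat"], diag["trimming_threshold"]
#     share_extreme = float(((m <= thr) | (m >= 1 - thr)).mean())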