causalkit.inference.dml_att
- causalkit.inference.dml_att(data, ml_g=None, ml_m=None, n_folds=5, n_rep=1, confidence_level=0.95)
Estimate the average treatment effect on the treated (ATT) using DoubleML’s interactive regression model (IRM).
- Parameters:
data (CausalData) – The CausalData object containing the treatment, target, and confounder variables.
ml_g (estimator, optional) – A machine learner implementing fit() and predict() methods for the nuisance function g_0(D,X) = E[Y|X,D]. If None, a CatBoostRegressor configured to use all CPU cores is used.
ml_m (classifier, optional) – A machine learner implementing fit() and predict_proba() methods for the nuisance function m_0(X) = E[D|X]. If None, a CatBoostClassifier configured to use all CPU cores is used.
n_folds (int, default 5) – Number of folds for cross-fitting.
n_rep (int, default 1) – Number of repetitions for the sample splitting.
confidence_level (float, default 0.95) – The confidence level for calculating confidence intervals (between 0 and 1).
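To illustrate how n_folds and n_rep interact, here is a minimal, standard-library-only sketch of repeated K-fold sample splitting. The helper cross_fit_splits is hypothetical and is not part of causalkit or DoubleML; the actual splitting happens inside DoubleML and may differ in detail.

```python
import random


def cross_fit_splits(n_obs, n_folds=5, n_rep=1, seed=0):
    """Yield (rep, train_idx, test_idx) for repeated K-fold sample splitting.

    Hypothetical helper for illustration only; not a causalkit function.
    """
    rng = random.Random(seed)
    for rep in range(n_rep):
        # Each repetition draws a fresh random partition of the sample.
        idx = list(range(n_obs))
        rng.shuffle(idx)
        # Deal the shuffled indices into n_folds nearly equal folds.
        folds = [idx[k::n_folds] for k in range(n_folds)]
        for k, test_idx in enumerate(folds):
            # Nuisance models are fit on every fold except the held-out one.
            train_idx = [i for j, fold in enumerate(folds) if j != k for i in fold]
            yield rep, sorted(train_idx), sorted(test_idx)


splits = list(cross_fit_splits(n_obs=10, n_folds=5, n_rep=2))
for rep, train_idx, test_idx in splits:
    print(rep, test_idx)
```

In this sketch every observation is held out exactly once per repetition, so nuisance predictions for an observation never come from a model trained on that observation. Raising n_rep repeats the whole partition with a new shuffle.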
- Returns:
A dictionary containing:
  - coefficient: The estimated average treatment effect on the treated
  - std_error: The standard error of the estimate
  - p_value: The p-value for the null hypothesis that the effect is zero
  - confidence_interval: Tuple of (lower, upper) bounds for the confidence interval
  - model: The fitted DoubleMLIRM object
- Return type:
Dict[str, Any]
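The std_error, p_value, and confidence_interval entries are usually tied together by a normal approximation. The sketch below shows that relationship under the assumption of a two-sided Wald-type test and interval. The function normal_inference and the input numbers are illustrative assumptions, not causalkit code or real results, and the source does not state that dml_att computes its values in exactly this way.

```python
from statistics import NormalDist


def normal_inference(coefficient, std_error, confidence_level=0.95):
    """Return a Wald-type confidence interval and two-sided p-value.

    Hypothetical helper for illustration only; not a causalkit function.
    """
    std_normal = NormalDist()
    # Critical value, e.g. about 1.96 for a 95% confidence level.
    z = std_normal.inv_cdf(1 - (1 - confidence_level) / 2)
    interval = (coefficient - z * std_error, coefficient + z * std_error)
    # Two-sided p-value for the null hypothesis that the effect is zero.
    t_stat = coefficient / std_error
    p_value = 2 * (1 - std_normal.cdf(abs(t_stat)))
    return {"confidence_interval": interval, "p_value": p_value}


# Made-up inputs, chosen only to demonstrate the arithmetic.
summary = normal_inference(coefficient=0.10, std_error=0.04)
print(summary)
```

A higher confidence_level widens the interval but leaves the p-value unchanged, since the p-value depends only on the ratio of the coefficient to its standard error.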
- Raises:
ValueError – If the CausalData object does not have treatment, target, and confounder variables defined, or if the treatment variable is not binary.
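As a rough illustration of the binary-treatment requirement, the sketch below checks a treatment column before estimation. The helper check_binary_treatment is hypothetical, and it assumes that "binary" means a 0/1 coding with both levels present; the exact rule that dml_att applies is not given here.

```python
def check_binary_treatment(values):
    """Raise ValueError unless the treatment takes exactly the values 0 and 1.

    Hypothetical helper for illustration only; not a causalkit function.
    Assumes a 0/1 coding (True/False and 0.0/1.0 compare equal to 0/1).
    """
    unique = set(values)
    if unique != {0, 1}:
        raise ValueError(
            f"Treatment must be binary (0/1); found values {sorted(unique)}"
        )


check_binary_treatment([0, 1, 1, 0])  # passes silently

try:
    check_binary_treatment([0, 1, 2])
except ValueError as err:
    print(err)
```

Running a check like this early gives a clearer error than one surfacing from deep inside model fitting.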
Examples
>>> from causalkit.data import generate_rct_data
>>> from causalkit.data import CausalData
>>> from causalkit.inference.att import dml_att
>>>
>>> # Generate data
>>> df = generate_rct_data()
>>>
>>> # Create causaldata object
>>> ck = CausalData(
...     df=df,
...     outcome='outcome',
...     treatment='treatment',
...     confounders=['age', 'invited_friend']
... )
>>>
>>> # Estimate ATT using DoubleML
>>> results = dml_att(ck)
>>> print(f"ATT: {results['coefficient']:.4f}")
>>> print(f"Standard Error: {results['std_error']:.4f}")
>>> print(f"P-value: {results['p_value']:.4f}")
>>> print(f"Confidence Interval: {results['confidence_interval']}")