causalkit.inference.dml

causalkit.inference.dml(data, ml_g=None, ml_m=None, n_folds=5, n_rep=1, score='ATE', confidence_level=0.95)

Estimate average treatment effects using DoubleML’s interactive regression model (IRM).

Parameters:
  • data (CausalData) – The CausalData object containing the treatment, target, and confounder variables.

  • ml_g (estimator, optional) – A machine learner implementing fit() and predict() methods for the nuisance function g_0(D,X) = E[Y|X,D]. If None, a CatBoostRegressor configured to use all CPU cores is used.

  • ml_m (classifier, optional) – A machine learner implementing fit() and predict_proba() methods for the nuisance function m_0(X) = E[D|X]. If None, a CatBoostClassifier configured to use all CPU cores is used.

  • n_folds (int, default 5) – Number of folds for cross-fitting.

  • n_rep (int, default 1) – Number of repetitions for the sample splitting.

  • score (str, default "ATE") – The score function: "ATE" for the average treatment effect, or "ATTE" for the average treatment effect on the treated.

  • confidence_level (float, default 0.95) – The confidence level for calculating confidence intervals (between 0 and 1).
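To make the roles of ml_g, ml_m, and n_folds concrete, here is a minimal pure-Python sketch of a cross-fitted, doubly robust (AIPW) estimate of the ATE, the kind of score an interactive regression model is built on. It is not causalkit's or DoubleML's implementation. The helper names (`cell_mean`, `cross_fitted_ate`), the single binary confounder, the propensity clipping bounds, and the cell-mean "learners" standing in for ml_g and ml_m are all assumptions chosen to keep the example dependency-free.

```python
import random
from statistics import fmean


def cell_mean(values, default):
    """Mean of a list, falling back to `default` when the cell is empty."""
    return fmean(values) if values else default


def cross_fitted_ate(rows, n_folds=5, seed=0):
    """Cross-fitted AIPW estimate of the ATE (illustrative only).

    rows: list of (x, d, y) with a binary confounder x and binary treatment d.
    Returns (estimate, standard_error).
    """
    rng = random.Random(seed)
    idx = list(range(len(rows)))
    rng.shuffle(idx)
    folds = [idx[k::n_folds] for k in range(n_folds)]

    scores = []
    for test_idx in folds:
        held_out = set(test_idx)
        # Nuisance functions are fitted on the other folds only.
        train = [rows[i] for i in idx if i not in held_out]
        overall_y = fmean(y for _, _, y in train)
        overall_d = fmean(d for _, d, _ in train)

        # Stand-in for ml_g: g(d, x) = E[Y | D=d, X=x] via cell means.
        g = {
            (d, x): cell_mean(
                [yy for xx, dd, yy in train if dd == d and xx == x], overall_y
            )
            for d in (0, 1)
            for x in (0, 1)
        }
        # Stand-in for ml_m: m(x) = P(D=1 | X=x) via cell frequencies.
        m = {
            x: cell_mean([dd for xx, dd, _ in train if xx == x], overall_d)
            for x in (0, 1)
        }

        # Evaluate the doubly robust score on the held-out fold.
        for i in test_idx:
            x, d, y = rows[i]
            p = min(max(m[x], 0.01), 0.99)  # clip propensities away from 0 and 1
            psi = (
                g[1, x] - g[0, x]
                + d * (y - g[1, x]) / p
                - (1 - d) * (y - g[0, x]) / (1 - p)
            )
            scores.append(psi)

    n = len(scores)
    estimate = fmean(scores)
    variance = sum((s - estimate) ** 2 for s in scores) / n
    return estimate, (variance / n) ** 0.5


# Simulated data with a true effect of 2.0; x drives both treatment and outcome.
data_rng = random.Random(42)
rows = []
for _ in range(4000):
    x = int(data_rng.random() < 0.5)
    d = int(data_rng.random() < (0.7 if x else 0.3))
    y = 2.0 * d + 1.5 * x + data_rng.gauss(0, 1)
    rows.append((x, d, y))

ate, se = cross_fitted_ate(rows)
print(f"ATE: {ate:.3f} (SE {se:.3f})")
```

In this simulation a raw difference in group means would be pulled away from 2.0 by the confounder, while the cross-fitted score adjusts for it. With real data, flexible learners such as the CatBoost defaults replace the cell means.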

Returns:

A dictionary containing:

  • coefficient – The estimated average treatment effect.

  • std_error – The standard error of the estimate.

  • p_value – The p-value for the null hypothesis that the effect is zero.

  • confidence_interval – Tuple of (lower, upper) bounds for the confidence interval.

  • model – The fitted DoubleMLIRM object.

Return type:

Dict[str, Any]
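How the returned fields relate to one another can be sketched with the standard library alone. The snippet assumes the usual normal approximation (a two-sided z-test and Wald interval). This page does not confirm that causalkit computes them exactly this way, and the function name `wald_summary` and the input numbers are hypothetical.

```python
from statistics import NormalDist


def wald_summary(coefficient, std_error, confidence_level=0.95):
    """Two-sided normal-approximation p-value and confidence interval."""
    z = NormalDist()
    # p-value for the null hypothesis that the effect is zero
    p_value = 2 * (1 - z.cdf(abs(coefficient / std_error)))
    # critical value, roughly 1.96 for a 95% confidence level
    crit = z.inv_cdf(0.5 + confidence_level / 2)
    return {
        "coefficient": coefficient,
        "std_error": std_error,
        "p_value": p_value,
        "confidence_interval": (
            coefficient - crit * std_error,
            coefficient + crit * std_error,
        ),
    }


# Hypothetical estimate and standard error, for illustration only.
summary = wald_summary(0.12, 0.05)
print(summary)
```

Raising confidence_level widens the interval but leaves the p-value unchanged, since the p-value depends only on the ratio of the coefficient to its standard error.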

Raises:

ValueError – If the CausalData object does not define treatment, target, and confounder variables, or if the treatment variable is not binary.
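A caller-side pre-check can surface the non-binary case before any model is fitted. The helper below is a hypothetical sketch, not part of causalkit. It assumes "binary" means exactly the two levels 0 and 1 with both present, which may be stricter or looser than the library's own check.

```python
def check_binary_treatment(values):
    """Raise ValueError unless `values` has exactly the levels 0 and 1."""
    levels = set(values)
    if levels != {0, 1}:
        shown = sorted(levels, key=str)
        raise ValueError(
            f"treatment must be binary with levels 0 and 1, got {shown}"
        )


check_binary_treatment([0, 1, 1, 0])  # passes silently

try:
    check_binary_treatment([0, 1, 2])
except ValueError as exc:
    message = str(exc)
    print(message)
```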

Examples

>>> from causalkit.data import generate_rct_data
>>> from causalkit.data import CausalData
>>> from causalkit.inference import dml
>>>
>>> # Generate data
>>> df = generate_rct_data()
>>>
>>> # Create CausalData object
>>> ck = CausalData(
...     df=df,
...     outcome='outcome',
...     treatment='treatment',
...     confounders=['age', 'invited_friend']
... )
>>>
>>> # Estimate ATE using DoubleML
>>> results = dml(ck)
>>> print(f"ATE: {results['coefficient']:.4f}")
>>> print(f"Standard Error: {results['std_error']:.4f}")
>>> print(f"P-value: {results['p_value']:.4f}")
>>> print(f"Confidence Interval: {results['confidence_interval']}")