causalkit.inference.causalforestdml
- causalkit.inference.causalforestdml(data, model_y=None, model_t=None, n_estimators=100, max_depth=None, min_samples_leaf=5, cv=5, n_jobs=-1, random_state=None, confidence_level=0.95)
Estimate the average treatment effect (ATE) using EconML’s CausalForestDML.
- Parameters:
data (CausalData) – The CausalData object that defines the treatment, target (outcome), and confounder variables.
model_y (estimator, optional) – The model for fitting the outcome variable. If None, a CatBoostRegressor configured to use all CPU cores is used.
model_t (estimator, optional) – The model for fitting the treatment variable. If None, a CatBoostRegressor configured to use all CPU cores is used.
n_estimators (int, default 100) – Number of trees in the forest.
max_depth (int, optional) – Maximum depth of the trees. If None, nodes are expanded until all leaves are pure or contain fewer than min_samples_leaf samples.
min_samples_leaf (int, default 5) – Minimum number of samples required to be at a leaf node.
cv (int, default 5) – Number of folds for cross-fitting.
n_jobs (int, default -1) – Number of jobs to run in parallel. -1 uses all available processors.
random_state (int, optional) – Seed that controls the randomness of the estimator; set it for reproducible results.
confidence_level (float, default 0.95) – The confidence level for calculating confidence intervals (between 0 and 1).
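To make the roles of model_y, model_t, and cv concrete, here is a minimal sketch of cross-fitted residualization, the double machine learning step that CausalForestDML builds on. It is an illustration under stated assumptions, not causalkit's or EconML's implementation: the nuisance model is a hypothetical group-mean predictor over a single discrete confounder (standing in for CatBoost), and the final stage is one constant-effect regression of outcome residuals on treatment residuals rather than a causal forest.

```python
import random
from statistics import fmean


def fit_group_means(xs, ys):
    """Hypothetical nuisance model: predict y as the mean of y within each level of a discrete x."""
    groups = {}
    for x, y in zip(xs, ys):
        groups.setdefault(x, []).append(y)
    overall = fmean(ys)
    means = {x: fmean(vals) for x, vals in groups.items()}
    # Fall back to the overall mean for levels unseen in training.
    return lambda x: means.get(x, overall)


def dml_ate(x, t, y, cv=5, seed=0):
    """Sketch of cross-fitted DML with a constant treatment effect (not a causal forest)."""
    n = len(y)
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    folds = [idx[k::cv] for k in range(cv)]  # cv roughly equal folds
    res_y = [0.0] * n
    res_t = [0.0] * n
    for fold in folds:
        held_out = set(fold)
        train = [i for i in idx if i not in held_out]
        # Fit both nuisance models on the other folds only (the "model_y" / "model_t" roles).
        m_y = fit_group_means([x[i] for i in train], [y[i] for i in train])
        m_t = fit_group_means([x[i] for i in train], [t[i] for i in train])
        # Residualize the held-out fold with models that never saw it.
        for i in fold:
            res_y[i] = y[i] - m_y(x[i])
            res_t[i] = t[i] - m_t(x[i])
    # Final stage: regress outcome residuals on treatment residuals.
    return sum(a * b for a, b in zip(res_y, res_t)) / sum(b * b for b in res_t)
```

Cross-fitting is the reason for the cv argument: each observation's residuals come from nuisance models fitted without that observation, which limits overfitting bias in the final estimate.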
- Returns:
A dictionary containing:
  - coefficient: The estimated average treatment effect.
  - std_error: The standard error of the estimate.
  - p_value: The p-value for the null hypothesis that the effect is zero.
  - confidence_interval: Tuple of (lower, upper) bounds for the confidence interval.
  - model: The fitted CausalForestDML object.
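The p_value and confidence_interval entries are commonly tied to coefficient and std_error through a normal (Wald) approximation. The sketch below assumes that approximation; it is not taken from causalkit's source, which may instead read these values from EconML's own inference results. wald_summary is a hypothetical helper.

```python
from statistics import NormalDist


def wald_summary(coefficient, std_error, confidence_level=0.95):
    """Hypothetical helper: Wald p-value and CI from a point estimate and its standard error."""
    if not 0 < confidence_level < 1:
        raise ValueError("confidence_level must be strictly between 0 and 1")
    if std_error <= 0:
        raise ValueError("std_error must be positive")
    std_normal = NormalDist()
    # Two-sided critical value, e.g. about 1.96 for confidence_level=0.95.
    z_crit = std_normal.inv_cdf(0.5 + confidence_level / 2)
    z_stat = coefficient / std_error
    # Two-sided p-value for H0: effect == 0.
    p_value = 2 * (1 - std_normal.cdf(abs(z_stat)))
    return {
        "coefficient": coefficient,
        "std_error": std_error,
        "p_value": p_value,
        "confidence_interval": (coefficient - z_crit * std_error,
                                coefficient + z_crit * std_error),
    }
```

Under this approximation, p_value < 1 - confidence_level exactly when zero lies outside confidence_interval.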
- Return type:
Dict[str, Any]
- Raises:
ValueError – If the CausalData object does not define treatment, target, and confounder variables, or if the treatment variable is not binary.
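A minimal sketch of the two checks listed under Raises, written as a hypothetical standalone helper (validate_inputs) on plain Python values rather than a CausalData object. Treating "binary" as exactly the levels {0, 1} is an assumption; the library's precise rule is not documented here.

```python
def validate_inputs(treatment, target, confounders, treatment_values):
    """Hypothetical helper mirroring the documented ValueError conditions."""
    # Check 1: treatment, target, and confounders must all be defined (non-empty).
    missing = [
        name
        for name, value in (("treatment", treatment), ("target", target), ("confounders", confounders))
        if not value
    ]
    if missing:
        raise ValueError(f"CausalData is missing: {', '.join(missing)}")
    # Check 2 (assumed rule): the treatment column takes exactly the values 0 and 1.
    levels = set(treatment_values)
    if levels != {0, 1}:
        raise ValueError(f"treatment must be binary (0/1); found levels {sorted(levels, key=repr)}")
```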
Examples
>>> from causalkit.data import generate_rct_data
>>> from causalkit.data import CausalData
>>> from causalkit.inference.ate import causalforestdml
>>>
>>> # Generate data
>>> df = generate_rct_data()
>>>
>>> # Create a CausalData object
>>> ck = CausalData(
...     df=df,
...     outcome='outcome',
...     treatment='treatment',
...     confounders=['age', 'invited_friend']
... )
>>>
>>> # Estimate ATE using CausalForestDML
>>> results = causalforestdml(ck)
>>> print(f"ATE: {results['coefficient']:.4f}")
>>> print(f"Standard Error: {results['std_error']:.4f}")
>>> print(f"P-value: {results['p_value']:.4f}")
>>> print(f"Confidence Interval: {results['confidence_interval']}")