DML ATE Example

This notebook covers the following scenario:

Is RCT | Treatment | Outcome | EDA | Estimands | Refutation
Observational | Binary | Continuous | Yes | ATE | Yes

We will estimate the Average Treatment Effect (ATE) of a binary treatment on a continuous outcome. The notebook also walks through exploratory data analysis and refutation tests.

Generate data

Let’s generate data on how a new feature (treatment) affects ARPU (outcome), with a constant linear effect theta = 1.8.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from causalkit.data import CausalDatasetGenerator, CausalData

# Reproducibility
np.random.seed(42)

confounder_specs = [
    {"name": "tenure_months", "dist": "normal", "mu": 24, "sd": 12},
    {"name": "avg_sessions_week", "dist": "normal", "mu": 5, "sd": 2},
    {"name": "spend_last_month", "dist": "uniform", "a": 0, "b": 200},
    {"name": "premium_user", "dist": "bernoulli", "p": 0.25},
    {"name": "urban_resident", "dist": "bernoulli", "p": 0.60},
]

# Causal effect and noise
theta = 1.8  # ATE: +1.8 ARPU units if new_feature = 1
sigma_y = 3.5  # ARPU noise std
target_t_rate = 0.35  # ~35% treated

# Effects of confounders on ARPU (baseline, additive)
# Order: tenure_months, avg_sessions_week, spend_last_month, premium_user, urban_resident
beta_y = np.array([
    0.05,  # tenure_months: small positive effect
    0.40,  # avg_sessions_week: strong positive effect
    0.02,  # spend_last_month: recent spend correlates with ARPU
    2.00,  # premium_user: premium users have higher ARPU
    1.00,  # urban_resident: urban users slightly higher ARPU
], dtype=float)

# Effects of confounders on treatment assignment (log-odds scale)
beta_t = np.array([
    0.015,  # tenure_months
    0.10,  # avg_sessions_week
    0.002,  # spend_last_month
    0.75,  # premium_user
    0.30,  # urban_resident: more likely to get the feature
], dtype=float)

gen = CausalDatasetGenerator(
    theta=theta,
    outcome_type="continuous",
    sigma_y=sigma_y,
    target_t_rate=target_t_rate,
    seed=42,
    confounder_specs=confounder_specs,
    beta_y=beta_y,
    beta_t=beta_t,
)


# Create dataset
causal_data = gen.to_causal_data(
    n=10000,
    confounders=[
        "tenure_months",
        "avg_sessions_week",
        "spend_last_month",
        "premium_user",
        "urban_resident",
    ],
)

# Show first few rows
causal_data.df.head()
y t tenure_months avg_sessions_week spend_last_month premium_user urban_resident
0 5.927714 1.0 27.656605 5.352554 72.552568 1.0 0.0
1 11.122008 1.0 11.520191 6.798247 188.481287 1.0 0.0
2 10.580393 1.0 33.005414 2.055459 51.040440 0.0 1.0
3 6.982844 1.0 35.286777 4.429404 166.992239 0.0 1.0
4 10.899381 0.0 0.587578 6.658307 179.371126 0.0 0.0
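The generator's internals are not shown in this notebook. As a hedged sketch, assuming `CausalDatasetGenerator` uses a linear outcome model and a logistic treatment model whose intercept is tuned to hit `target_t_rate`, the same data-generating process can be written in plain NumPy. The function `simulate_arpu` is hypothetical, not causalkit code. It also shows why a naive treated-vs-control comparison overstates the effect: every confounder raises both the treatment probability and ARPU.

```python
import numpy as np

def simulate_arpu(n=10_000, theta=1.8, sigma_y=3.5, target_t_rate=0.35, seed=42):
    """Hypothetical re-implementation of the DGP described above (not causalkit's code)."""
    rng = np.random.default_rng(seed)
    X = np.column_stack([
        rng.normal(24, 12, n),       # tenure_months
        rng.normal(5, 2, n),         # avg_sessions_week
        rng.uniform(0, 200, n),      # spend_last_month
        rng.binomial(1, 0.25, n),    # premium_user
        rng.binomial(1, 0.60, n),    # urban_resident
    ]).astype(float)
    beta_y = np.array([0.05, 0.40, 0.02, 2.00, 1.00])
    beta_t = np.array([0.015, 0.10, 0.002, 0.75, 0.30])

    # Assumption: bisect on the logit intercept so the mean propensity equals target_t_rate
    lin_t = X @ beta_t
    lo, hi = -20.0, 20.0
    for _ in range(60):
        alpha = (lo + hi) / 2
        if np.mean(1 / (1 + np.exp(-(alpha + lin_t)))) < target_t_rate:
            lo = alpha
        else:
            hi = alpha
    propensity = 1 / (1 + np.exp(-((lo + hi) / 2 + lin_t)))

    t = rng.binomial(1, propensity).astype(float)
    # Constant (homogeneous) treatment effect theta plus additive confounder effects
    y = theta * t + X @ beta_y + rng.normal(0, sigma_y, n)
    return X, t, y

X, t, y = simulate_arpu()
naive_diff = y[t == 1].mean() - y[t == 0].mean()
print(f"treated share: {t.mean():.3f}, naive difference: {naive_diff:.3f} (true theta = 1.8)")
```

With these coefficients the naive difference lands well above 1.8, much like the gap between the treated and control means reported by the EDA below.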

EDA

from causalkit.eda import CausalEDA
eda = CausalEDA(causal_data)
Note: importing causalkit.eda pulls in catboost (causalkit/eda/eda.py imports CatBoostClassifier and CatBoostRegressor), so the catboost package must be installed, for example with pip install catboost. Without it, this cell fails with ModuleNotFoundError: No module named 'catboost'.

General dataset information

Let’s see how the outcome differs between clients who received the feature and those who didn’t.

# shape of data
eda.data_shape()
{'n_rows': 10000, 'n_columns': 7}
# 1) Outcome statistics by treatment
eda.outcome_stats()
count mean std min p10 p25 median p75 p90 max
treatment
0.0 6519 6.011340 3.925893 -9.866447 1.012888 3.426894 6.030504 8.676781 10.952265 20.770359
1.0 3481 8.553615 3.934260 -5.507710 3.525225 5.909860 8.566860 11.216054 13.602967 21.377687
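The table already hints at confounding. A quick back-of-the-envelope check that uses only the numbers printed above (Welch standard error, normal approximation) shows that the unadjusted gap is far from the true effect:

```python
import math

# Values copied from the outcome_stats() table above
mean_c, sd_c, n_c = 6.011340, 3.925893, 6519   # control (t = 0)
mean_t, sd_t, n_t = 8.553615, 3.934260, 3481   # treated (t = 1)

diff = mean_t - mean_c                          # naive (unadjusted) difference in means
se = math.sqrt(sd_t**2 / n_t + sd_c**2 / n_c)   # Welch standard error
ci = (diff - 1.96 * se, diff + 1.96 * se)       # normal-approximation 95% CI
print(f"naive diff = {diff:.3f}, SE = {se:.4f}, 95% CI = ({ci[0]:.3f}, {ci[1]:.3f})")
```

This gives a naive difference of about 2.54 with a 95% CI of roughly (2.38, 2.70), which excludes the true theta of 1.8. The excess is the bias from confounders that push up both treatment and ARPU, which is what the adjustment below has to remove.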
# 2) Outcome distribution by treatment (hist + boxplot)
fig1, fig2 = eda.outcome_plots()
plt.show()
[Figures: outcome histogram and boxplot by treatment group]

Propensity

Next, let’s examine confounder balance and how propensity scores differ between the treatment groups.

# Shows means of confounders for control/treated groups, absolute differences, and SMD values
confounders_balance_df = eda.confounders_means()
display(confounders_balance_df)
mean_t_0 mean_t_1 abs_diff smd
confounders
premium_user 0.198343 0.347889 0.149545 0.340423
avg_sessions_week 4.905841 5.293265 0.387423 0.193959
tenure_months 23.170108 25.200827 2.030718 0.168676
urban_resident 0.578003 0.643781 0.065778 0.135209
spend_last_month 97.801230 105.271930 7.470700 0.129267
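The SMD column is the absolute mean difference divided by a pooled standard deviation. causalkit's exact convention is not shown here; as a sketch, assuming a pooled SD built from sample variances (ddof=1), the formula below reproduces the reported premium_user value. The arrays are rebuilt from the counts implied by the table (0.347889 × 3481 ≈ 1211 treated premium users, 0.198343 × 6519 ≈ 1293 control premium users).

```python
import numpy as np

def smd(x_treated, x_control):
    """Standardized mean difference with a pooled SD (sample variances, ddof=1)."""
    x1, x0 = np.asarray(x_treated, float), np.asarray(x_control, float)
    pooled_sd = np.sqrt((x1.var(ddof=1) + x0.var(ddof=1)) / 2)
    return abs(x1.mean() - x0.mean()) / pooled_sd

# Rebuild premium_user from the table: 1211 of 3481 treated and 1293 of 6519 controls are premium
premium_t = np.r_[np.ones(1211), np.zeros(3481 - 1211)]
premium_c = np.r_[np.ones(1293), np.zeros(6519 - 1293)]
print(round(smd(premium_t, premium_c), 4))
```

A common rule of thumb flags |SMD| > 0.1 as meaningful imbalance; all five confounders in the table exceed it, so adjustment is needed.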
# Propensity model fit
ps_model = eda.fit_propensity()

# ROC AUC - shows how predictable treatment is from confounders
roc_auc_score = ps_model.roc_auc
print("ROC AUC from PropensityModel:", round(roc_auc_score, 4))
ROC AUC from PropensityModel: 0.6053
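ROC AUC has a simple reading here: it is the probability that a randomly chosen treated unit gets a higher propensity score than a randomly chosen control unit. A value near 0.5 means treatment is barely predictable from the confounders (good overlap), while a value near 1 means the groups are almost separable. A minimal sketch of that Mann-Whitney definition on toy data (the function name is hypothetical):

```python
import numpy as np

def auc_mann_whitney(scores, labels):
    """ROC AUC as P(score of a random treated unit > score of a random control unit); ties count 1/2."""
    scores, labels = np.asarray(scores, float), np.asarray(labels, bool)
    s1, s0 = scores[labels], scores[~labels]
    greater = (s1[:, None] > s0[None, :]).mean()
    ties = (s1[:, None] == s0[None, :]).mean()
    return greater + 0.5 * ties

rng = np.random.default_rng(0)
labels = rng.binomial(1, 0.35, 2000)
uninformative = rng.uniform(size=2000)            # score unrelated to treatment
informative = labels + rng.normal(0, 1.5, 2000)   # weakly predictive score
print(auc_mann_whitney(uninformative, labels), auc_mann_whitney(informative, labels))
```

The reported 0.6053 sits in the "weakly predictable" range: confounders shift treatment odds, but not enough to break overlap.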
# Positivity check - assess overlap between treatment groups
positivity_result = ps_model.positivity_check()
print("Positivity check from PropensityModel:", positivity_result)
Positivity check from PropensityModel: {'bounds': (0.05, 0.95), 'share_below': 0.0008, 'share_above': 0.0, 'flag': False}
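The positivity check reports the share of propensity scores outside the bounds (0.05, 0.95); scores near 0 or 1 produce inverse-propensity weights of 20 or more and unstable estimates. A hedged sketch of the share computation, assuming strict inequalities (the rule behind the `flag` field is not shown, so it is not reproduced):

```python
import numpy as np

def positivity_summary(ps, bounds=(0.05, 0.95)):
    """Share of propensity scores outside [low, high]; extreme scores signal weak overlap."""
    ps = np.asarray(ps, float)
    low, high = bounds
    return {"share_below": float(np.mean(ps < low)), "share_above": float(np.mean(ps > high))}

ps = np.array([0.03, 0.2, 0.35, 0.5, 0.97])
print(positivity_summary(ps))   # {'share_below': 0.2, 'share_above': 0.2}
```

In the notebook only 0.08% of units (about 8 of 10,000) fall below 0.05 and none above 0.95, so overlap is adequate.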
# SHAP values - feature importance for treatment assignment from confounders
shap_values_df = ps_model.shap
display(shap_values_df)
feature shap_mean
0 spend_last_month -0.000322
1 tenure_months 0.000140
2 avg_sessions_week 0.000122
3 premium_user 0.000052
4 urban_resident 0.000008
# Propensity score overlap graph
ps_model.ps_graph()
plt.show()
[Figure: propensity score overlap between treatment groups]

Outcome regression

Let’s analyze how well the confounders predict the outcome.

# Outcome model fit
outcome_model = eda.outcome_fit()

# RMSE and MAE of regression model
print(outcome_model.scores)
{'rmse': 3.6896558290948964, 'mae': 2.94099663606919}
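The RMSE can be judged against the noise in the simulation. Assuming `outcome_fit()` regresses the outcome on the confounders only, both the noise (sigma_y = 3.5) and the unmodelled treatment effect remain in the residuals, so the best achievable RMSE lies between 3.5 and sqrt(sigma_y² + theta² · p(1 − p)) with p = 0.35 (by the law of total variance, the residual variance of the treatment term cannot exceed its marginal variance). A small worked check:

```python
import math

theta, sigma_y, treated_share = 1.8, 3.5, 0.35   # from the data-generating setup above
rmse_reported, mae_reported = 3.6896558290948964, 2.94099663606919

# Upper end of the irreducible RMSE for a confounders-only model (assumption stated above)
rmse_floor_max = math.sqrt(sigma_y**2 + theta**2 * treated_share * (1 - treated_share))

# For Gaussian residuals, MAE is about RMSE * sqrt(2 / pi)
mae_if_gaussian = rmse_reported * math.sqrt(2 / math.pi)
print(f"floor range: {sigma_y:.2f} to {rmse_floor_max:.2f}, reported RMSE: {rmse_reported:.2f}, "
      f"MAE if Gaussian: {mae_if_gaussian:.3f} (reported {mae_reported:.3f})")
```

The reported RMSE (3.69) is only slightly above the roughly 3.60 ceiling of the floor, and the reported MAE (2.941) is close to the Gaussian prediction (about 2.944), consistent with a well-fitting model and near-normal residuals.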
# 2) SHAP values - feature importance for outcome prediction from confounders
shap_outcome_df = outcome_model.shap
display(shap_outcome_df)
feature shap_mean
0 spend_last_month 0.002252
1 tenure_months -0.001747
2 avg_sessions_week -0.000951
3 premium_user 0.000729
4 urban_resident -0.000282

Inference

Now it’s time to estimate the ATE with Double Machine Learning (DML).

from causalkit.inference.ate import dml_ate

# Estimate Average Treatment Effect (ATE)
ate_result = dml_ate(causal_data, n_folds=3, confidence_level=0.95)
ate_result
{'coefficient': 1.7547754158492368,
 'std_error': 0.08120639822820347,
 'p_value': 1.4835396169728733e-103,
 'confidence_interval': (1.5956138000077407, 1.913937031690733),
 'model': <doubleml.irm.irm.DoubleMLIRM at 0x157b9fed0>}

The true theta in our data-generating process is 1.8, which lies inside the estimated 95% confidence interval (1.596, 1.914).
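The `model` entry shows that `dml_ate` wraps a DoubleML interactive regression model (IRM). A minimal NumPy sketch of the same idea, cross-fitted AIPW with 3 folds, assuming linear outcome regressions and a logistic propensity model in place of the learners causalkit actually uses (all function names are hypothetical), on data simulated in the spirit of the setup above:

```python
import numpy as np

def _design(X):
    return np.column_stack([np.ones(len(X)), X])

def fit_ols(X, y):
    coef, *_ = np.linalg.lstsq(_design(X), y, rcond=None)
    return lambda Z: _design(Z) @ coef

def fit_logit(X, t, n_iter=25):
    """Logistic regression by Newton-Raphson (expects roughly standardized X)."""
    D = _design(X)
    w = np.zeros(D.shape[1])
    for _ in range(n_iter):
        p = 1 / (1 + np.exp(-D @ w))
        hess = D.T @ (D * (p * (1 - p))[:, None])
        w += np.linalg.solve(hess, D.T @ (t - p))
    return lambda Z: 1 / (1 + np.exp(-_design(Z) @ w))

def dml_ate_irm(X, t, y, n_folds=3, clip=0.01, seed=0):
    """Cross-fitted AIPW (IRM score) estimate of the ATE with a plug-in standard error."""
    X = (X - X.mean(0)) / X.std(0)
    n = len(y)
    folds = np.random.default_rng(seed).permutation(n) % n_folds
    g0, g1, m = np.empty(n), np.empty(n), np.empty(n)
    for k in range(n_folds):
        tr, te = folds != k, folds == k   # nuisances are fit off-fold, predicted on-fold
        g1[te] = fit_ols(X[tr & (t == 1)], y[tr & (t == 1)])(X[te])       # E[Y | X, T=1]
        g0[te] = fit_ols(X[tr & (t == 0)], y[tr & (t == 0)])(X[te])       # E[Y | X, T=0]
        m[te] = np.clip(fit_logit(X[tr], t[tr])(X[te]), clip, 1 - clip)  # P(T=1 | X)
    psi_b = g1 - g0 + t * (y - g1) / m - (1 - t) * (y - g0) / (1 - m)
    theta = psi_b.mean()
    se = psi_b.std(ddof=1) / np.sqrt(n)
    return theta, se

# Hypothetical data in the spirit of the setup above (true ATE = 1.8)
rng = np.random.default_rng(1)
n = 10_000
X = np.column_stack([rng.normal(24, 12, n), rng.normal(5, 2, n), rng.uniform(0, 200, n),
                     rng.binomial(1, 0.25, n), rng.binomial(1, 0.60, n)]).astype(float)
p = 1 / (1 + np.exp(-(-2.0 + X @ np.array([0.015, 0.10, 0.002, 0.75, 0.30]))))
t = rng.binomial(1, p).astype(float)
y = 1.8 * t + X @ np.array([0.05, 0.40, 0.02, 2.00, 1.00]) + rng.normal(0, 3.5, n)

theta_hat, se_hat = dml_ate_irm(X, t, y)
print(f"theta = {theta_hat:.3f} +/- {1.96 * se_hat:.3f}")
```

Cross-fitting keeps each unit's nuisance predictions out of sample, and the AIPW score corrects the regression difference g1 − g0 with inverse-propensity-weighted residuals. Propensities are clipped to [0.01, 0.99] here as a simple safeguard; DoubleML has its own trimming settings.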

Refutation

Placebo

# Import refutation utilities
from causalkit.refutation import (
    refute_placebo_outcome,
    refute_placebo_treatment,
    refute_subset,
    refute_irm_orthogonality,
    sensitivity_analysis,
    sensitivity_analysis_set
)

Replacing the outcome with a random placebo variable should break the causal relationship, so the estimated effect should be close to zero and not significant.

# Replacing an outcome with placebo
ate_placebo_outcome = refute_placebo_outcome(
    dml_ate,
    causal_data,
    random_state=42
)

print(ate_placebo_outcome)
{'theta': 0.0010896926888271316, 'p_value': 0.863664499475262}

Replacing the treatment with a random placebo variable should likewise break the relationship, so the estimated effect should again be close to zero and not significant.

# Replacing treatment with placebo
ate_placebo_treatment = refute_placebo_treatment(
    dml_ate,
    causal_data,
    random_state=42
)

print(ate_placebo_treatment)
{'theta': 0.05021278313460212, 'p_value': 0.5290821466624945}
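The logic of both placebo tests can be sketched with a plain OLS regression adjustment standing in for `dml_ate` (a swapped-in, simpler estimator). How causalkit generates its placebo variables is not shown, so drawing them as independent normal noise and independent Bernoulli draws is an assumption:

```python
import numpy as np

def adjusted_effect(X, t, y):
    """Coefficient on t from OLS of y on [1, t, X]: a simple stand-in for dml_ate."""
    D = np.column_stack([np.ones(len(y)), t, X])
    coef, *_ = np.linalg.lstsq(D, y, rcond=None)
    return coef[1]

rng = np.random.default_rng(42)
n = 10_000
X = rng.normal(size=(n, 5))
t = rng.binomial(1, 1 / (1 + np.exp(0.6 - X @ np.full(5, 0.3)))).astype(float)
y = 1.8 * t + X @ np.array([0.5, 0.4, 0.3, 0.2, 0.1]) + rng.normal(0, 3.5, n)

effect_real = adjusted_effect(X, t, y)
# Placebo outcome: pure noise, unrelated to t by construction
effect_placebo_y = adjusted_effect(X, t, rng.normal(size=n))
# Placebo treatment: random draws with the same treated share, unrelated to y
effect_placebo_t = adjusted_effect(X, rng.binomial(1, t.mean(), n).astype(float), y)
print(f"real: {effect_real:.3f}, placebo outcome: {effect_placebo_y:.3f}, "
      f"placebo treatment: {effect_placebo_t:.3f}")
```

An estimator that finds a sizeable "effect" in either placebo run is picking up something other than the treatment; the notebook's placebo estimates (0.001 and 0.050, with p-values 0.86 and 0.53) pass this check.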

Next, let’s challenge the estimate by re-running it on random subsets of the data (30% and 50%). Theta should stay close to the full-sample estimate.

# Inference on subsets
subset_fractions = [0.3, 0.5]

ate_subset_results = []
for fraction in subset_fractions:
    subset_result = refute_subset(
        dml_ate,
        causal_data,
        fraction=fraction,
        random_state=42
    )
    ate_subset_results.append(subset_result)

    print(f" With {fraction*100:.0f}% subset: theta = {subset_result['theta']:.4f}, p_value = {subset_result['p_value']:.4f}")
 With 30% subset: theta = 1.5725, p_value = 0.0000
 With 50% subset: theta = 1.8468, p_value = 0.0000
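How close is "close"? For a random subsample containing a share f of the data, and an estimator whose variance scales like 1/n, the subset estimate differs from the full-sample estimate with variance of roughly se_full² · (1/f − 1). This assumes subsampling without replacement from the same data; the numbers below are copied from the outputs above.

```python
import math

theta_full, se_full = 1.7547754158492368, 0.08120639822820347   # full-sample dml_ate result
subset_thetas = {0.3: 1.5725, 0.5: 1.8468}                       # refute_subset results above

z_scores = {}
for fraction, theta_sub in subset_thetas.items():
    # Var(theta_sub - theta_full) ~= se_full**2 * (1/f - 1) for a nested random subsample
    sd_diff = se_full * math.sqrt(1 / fraction - 1)
    z_scores[fraction] = (theta_sub - theta_full) / sd_diff
    print(f"{fraction:.0%} subset: diff = {theta_sub - theta_full:+.3f}, "
          f"expected SD ~ {sd_diff:.3f}, z = {z_scores[fraction]:+.2f}")
```

Both shifts are within two expected standard deviations (z of about −1.47 and +1.13), so the drop to 1.57 on the 30% subset is consistent with sampling noise.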

Orthogonality

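Orthogonality tests check whether the DML moment condition holds and is insensitive to small errors in the nuisance models (outcome regressions and propensity score), i.e. whether the model is specified well enough for valid inference.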
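For reference, the standard IRM (AIPW) score for the ATE, with nuisances g(d, X) = E[Y | D = d, X] and m(X) = P(D = 1 | X), together with the Neyman orthogonality condition that these diagnostics probe, is:

$$
\psi(W;\theta,\eta) = g(1,X) - g(0,X) + \frac{D\,\big(Y - g(1,X)\big)}{m(X)} - \frac{(1-D)\,\big(Y - g(0,X)\big)}{1 - m(X)} - \theta,
\qquad \mathbb{E}\big[\psi(W;\theta_0,\eta_0)\big] = 0
$$

$$
\left.\frac{\partial}{\partial r}\,\mathbb{E}\Big[\psi\big(W;\theta_0,\ \eta_0 + r\,(\eta - \eta_0)\big)\Big]\right|_{r=0} = 0
$$

Presumably the out-of-sample moment check targets the first condition on held-out folds, and the Gateaux derivative tests estimate the second along a few basis directions of the nuisance functions.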

ate_ortho_check = refute_irm_orthogonality(dml_ate, causal_data)
# 1. Out-of-sample moment check
print("\n--- 1. Out-of-Sample Moment Check ---")
oos_test = ate_ortho_check['oos_moment_test']
print(f"T-statistic: {oos_test['tstat']:.4f}")
print(f"P-value: {oos_test['pvalue']:.4f}")
print(f"Interpretation: {oos_test['interpretation']}")
print("\nFold-wise results:")
display(oos_test['fold_results'])
--- 1. Out-of-Sample Moment Check ---
T-statistic: 0.3617
P-value: 0.7176
Interpretation: Should be ≈ 0 if moment condition holds

Fold-wise results:
fold n psi_mean psi_var
0 0 2000 0.142070 64.367381
1 1 2000 0.114837 71.748247
2 2 2000 0.106751 65.548761
3 3 2000 0.091285 62.569903
4 4 2000 -0.307471 68.208318

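psi_mean is close to zero on every fold relative to its sampling noise, and the pooled test is far from significant (t = 0.36, p = 0.72), so the moment condition holds and the test passes.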
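The fold table is enough to reproduce the reported statistics, assuming the test pools fold means and fold variances as below. The per-fold t-statistics put "near zero" on a concrete scale (a fold SE is about sqrt(65 / 2000) ≈ 0.18):

```python
import math

# Fold summaries copied from the table above (n = 2000 per fold)
psi_mean = [0.142070, 0.114837, 0.106751, 0.091285, -0.307471]
psi_var = [64.367381, 71.748247, 65.548761, 62.569903, 68.208318]
n_fold = 2000

fold_t = [m / math.sqrt(v / n_fold) for m, v in zip(psi_mean, psi_var)]
n_total = n_fold * len(psi_mean)
t_overall = (sum(psi_mean) / len(psi_mean)) / math.sqrt((sum(psi_var) / len(psi_var)) / n_total)
p_overall = math.erfc(abs(t_overall) / math.sqrt(2))   # two-sided normal p-value
print([round(x, 2) for x in fold_t], round(t_overall, 4), round(p_overall, 4))
```

This pooling reproduces T = 0.3617 and p = 0.7176, and the largest per-fold |t| is about 1.66 (fold 4), so no fold is individually significant either.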

# 2. Orthogonality derivatives
print("\n--- 2. Orthogonality (Gateaux Derivative) Tests ---")
ortho_derivs = ate_ortho_check['orthogonality_derivatives']
print(f"Interpretation: {ortho_derivs['interpretation']}")

print("\nFull sample derivatives:")
display(ortho_derivs['full_sample'])

print("\nTrimmed sample derivatives:")
display(ortho_derivs['trimmed_sample'])

if len(ortho_derivs['problematic_full']) > 0:
    print("\n⚠ PROBLEMATIC derivatives (full sample):")
    display(ortho_derivs['problematic_full'])
else:
    print("\n✓ No problematic derivatives in full sample")

if len(ortho_derivs['problematic_trimmed']) > 0:
    print("\n⚠ PROBLEMATIC derivatives (trimmed sample):")
    display(ortho_derivs['problematic_trimmed'])
else:
    print("\n✓ No problematic derivatives in trimmed sample")
--- 2. Orthogonality (Gateaux Derivative) Tests ---
Interpretation: Large |t-stats| (>2) indicate calibration issues

Full sample derivatives:
basis d_m1 se_m1 t_m1 d_m0 se_m0 t_m0 d_g se_g t_g
0 0 -0.027778 0.015464 -1.796300 0.009315 0.007849 1.186833 -0.103516 0.296654 -0.348945
1 1 0.005088 0.016336 0.311467 0.003259 0.008088 0.402935 0.137892 0.322104 0.428097
2 2 0.003129 0.016049 0.194959 0.000922 0.008162 0.112934 0.028325 0.385094 0.073553
3 3 -0.001588 0.014938 -0.106328 -0.000145 0.008205 -0.017732 0.021827 0.282692 0.077210
4 4 -0.000312 0.012767 -0.024401 0.002863 0.009497 0.301462 -0.019862 0.231095 -0.085947
5 5 0.000451 0.015937 0.028314 -0.001855 0.007660 -0.242168 0.138553 0.316203 0.438177
Trimmed sample derivatives:
basis d_m1 se_m1 t_m1 d_m0 se_m0 t_m0 d_g se_g t_g
0 0 -0.027778 0.015464 -1.796300 0.009315 0.007849 1.186833 -0.103516 0.296654 -0.348945
1 1 0.005088 0.016336 0.311467 0.003259 0.008088 0.402935 0.137892 0.322104 0.428097
2 2 0.003129 0.016049 0.194959 0.000922 0.008162 0.112934 0.028325 0.385094 0.073553
3 3 -0.001588 0.014938 -0.106328 -0.000145 0.008205 -0.017732 0.021827 0.282692 0.077210
4 4 -0.000312 0.012767 -0.024401 0.002863 0.009497 0.301462 -0.019862 0.231095 -0.085947
5 5 0.000451 0.015937 0.028314 -0.001855 0.007660 -0.242168 0.138553 0.316203 0.438177
✓ No problematic derivatives in full sample

✓ No problematic derivatives in trimmed sample

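Large |t| values (above 2) would signal that the score is sensitive to errors in the nuisance models. Here every |t_m1|, |t_m0| and |t_g| is below 2 for all basis functions (the largest is |t_m1| ≈ 1.80 for basis 0), in both the full and the trimmed sample, so the derivative tests pass.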
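The |t| > 2 rule can be applied directly to the printed table. This sketch assumes causalkit's `problematic_full` uses an equivalent row filter; the values are copied from the full-sample output above.

```python
import pandas as pd

# t-statistics copied from the "Full sample derivatives" table above
t_stats = pd.DataFrame({
    "t_m1": [-1.796300, 0.311467, 0.194959, -0.106328, -0.024401, 0.028314],
    "t_m0": [1.186833, 0.402935, 0.112934, -0.017732, 0.301462, -0.242168],
    "t_g": [-0.348945, 0.428097, 0.073553, 0.077210, -0.085947, 0.438177],
}, index=pd.Index(range(6), name="basis"))

problematic = t_stats[(t_stats.abs() > 2).any(axis=1)]   # rows with any |t| > 2
print("largest |t|:", t_stats.abs().to_numpy().max())
print("problematic bases:", list(problematic.index))
```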

# 3. Influence diagnostics
print("\n--- 3. Influence Diagnostics ---")
influence = ate_ortho_check['influence_diagnostics']
print(f"Interpretation: {influence['interpretation']}")

print("\nFull sample influence metrics:")
print(f"  Plugin SE: {influence['full_sample']['se_plugin']:.4f}")
print(f"  Kurtosis: {influence['full_sample']['kurtosis']:.2f}")
print(f"  P99/Median ratio: {influence['full_sample']['p99_over_med']:.2f}")

print("\nTrimmed sample influence metrics:")
print(f"  Plugin SE: {influence['trimmed_sample']['se_plugin']:.4f}")
print(f"  Kurtosis: {influence['trimmed_sample']['kurtosis']:.2f}")
print(f"  P99/Median ratio: {influence['trimmed_sample']['p99_over_med']:.2f}")

print("\nTop influential observations (full sample):")
display(influence['full_sample']['top_influential'])
--- 3. Influence Diagnostics ---
Interpretation: Heavy tails or extreme kurtosis suggest instability

Full sample influence metrics:
  Plugin SE: 0.0815
  Kurtosis: 7.64
  P99/Median ratio: 6.59

Trimmed sample influence metrics:
  Plugin SE: 0.0815
  Kurtosis: 7.64
  P99/Median ratio: 6.59

Top influential observations (full sample):
i psi g res_t res_c
0 7299 -78.668133 0.100071 -7.970469 -0.000000
1 874 63.471801 0.116169 8.132480 0.000000
2 9405 49.195863 0.169723 8.457407 0.000000
3 9116 -46.282651 0.164969 -7.989467 -0.000000
4 294 -44.627395 0.213603 -9.570962 -0.000000
5 7634 44.388611 0.742768 -0.000000 -12.113462
6 8409 44.028348 0.284715 12.696785 0.000000
7 5821 -43.885460 0.153316 -6.377281 -0.000000
8 8085 -43.798883 0.247829 -11.536566 -0.000000
9 731 43.545110 0.165439 6.820135 0.000000
# Trimming information
print("\n--- Propensity Score Trimming ---")
trim_info = ate_ortho_check['trimming_info']
print(f"Trimming bounds: {trim_info['bounds']}")
print(f"Observations trimmed: {trim_info['n_trimmed']} ({trim_info['pct_trimmed']:.1f}%)")
--- Propensity Score Trimming ---
Trimming bounds: (0.02, 0.98)
Observations trimmed: 0 (0.0%)

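These diagnostics look at the distribution of the influence-function values (psi) before and after propensity trimming. No observations fall outside the trimming bounds (0.02, 0.98), so both samples are identical (plug-in SE 0.0815, kurtosis 7.64, P99/median ratio 6.59), and the test passes.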
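The reported plug-in SE (0.0815) matches sqrt(mean psi_var / N) from the fold table, which suggests se_plugin = sd(psi) / sqrt(N). The kurtosis and P99/median definitions below (Pearson kurtosis of psi, tail ratio of |psi − mean|) are assumptions, shown on simulated normal and heavy-tailed psi for comparison (the function name is hypothetical):

```python
import numpy as np

def influence_summary(psi):
    """Plug-in SE, kurtosis and tail ratio of influence-function values psi (assumed definitions)."""
    psi = np.asarray(psi, float)
    centered = psi - psi.mean()
    se_plugin = psi.std(ddof=1) / np.sqrt(len(psi))
    kurtosis = np.mean(centered**4) / np.mean(centered**2) ** 2   # Pearson kurtosis: 3 for a normal
    abs_c = np.abs(centered)
    p99_over_med = np.quantile(abs_c, 0.99) / np.median(abs_c)    # about 3.8 for a normal
    return {"se_plugin": se_plugin, "kurtosis": kurtosis, "p99_over_med": p99_over_med}

rng = np.random.default_rng(0)
normal_psi = rng.normal(0, 8, 100_000)
heavy_psi = rng.standard_t(5, 100_000) * 8   # heavier tails, as inverse-propensity terms can produce
print(influence_summary(normal_psi))
print(influence_summary(heavy_psi))
```

Under these definitions the reported values (kurtosis 7.64, ratio 6.59) point to tails heavier than normal, typical of IPW terms, yet not extreme enough to trip the no_extreme_influence check.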

# Diagnostic conditions breakdown
print("\n--- Diagnostic Conditions Assessment ---")
conditions = ate_ortho_check['diagnostic_conditions']

print("Individual condition checks:")
for condition, passed in conditions.items():
    status = "✓ PASS" if passed else "✗ FAIL"
    print(f"  {condition}: {status}")

print(f"\nOverall: {ate_ortho_check['overall_assessment']}")
--- Diagnostic Conditions Assessment ---
Individual condition checks:
  oos_moment_ok: ✓ PASS
  derivs_full_ok: ✓ PASS
  derivs_trim_ok: ✓ PASS
  se_reasonable: ✓ PASS
  no_extreme_influence: ✓ PASS
  trimming_reasonable: ✓ PASS

Overall: PASS: Strong evidence for orthogonality

Sensitivity analysis

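Let’s gauge how strong an unobserved confounder could plausibly be by benchmarking against the observed confounders, treating each one in turn as if it were unobserved.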

bench_sets = causal_data.confounders
res = sensitivity_analysis_set(
    ate_result,
    benchmarking_set=bench_sets,
    level=0.95,
    null_hypothesis=0.0,
)
# Build a DataFrame with confounders as rows and the metrics as columns
summary_df = pd.DataFrame({
    name: (df.loc['t'] if 't' in df.index else df.iloc[0])
    for name, df in res.items()
}).T
summary_df
cf_y cf_d rho delta_theta
tenure_months 0.019343 0.004555 1.000000 0.125310
avg_sessions_week 0.031075 0.010474 1.000000 0.172854
spend_last_month 0.080634 0.006735 0.982448 0.177995
premium_user 0.062270 0.031318 1.000000 0.360557
urban_resident 0.018716 0.000000 1.000000 0.093394

Choosing a plausible confounding strength is a question of domain and data knowledge. Here we take the most influential observed confounder, ‘premium_user’ (cf_y ≈ 0.062, cf_d ≈ 0.031, largest delta_theta), and check how theta would change if an unobserved confounder of similar strength were missing from the data.
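To build intuition for what "a confounder as strong as premium_user" means, here is a simplified OLS analogue, not the DML-based benchmark that causalkit and DoubleML compute. It treats premium_user as unobserved and measures (a) the share of the remaining outcome variance it explains (a partial R², loosely analogous to cf_y) and (b) how far the treatment coefficient moves when it is omitted. The data are simulated in the spirit of the generator above.

```python
import numpy as np

def ols(D, y):
    coef, *_ = np.linalg.lstsq(D, y, rcond=None)
    return coef, y - D @ coef

rng = np.random.default_rng(7)
n = 10_000
X = np.column_stack([rng.normal(24, 12, n), rng.normal(5, 2, n), rng.uniform(0, 200, n),
                     rng.binomial(1, 0.25, n), rng.binomial(1, 0.60, n)]).astype(float)
p = 1 / (1 + np.exp(-(-2.0 + X @ np.array([0.015, 0.10, 0.002, 0.75, 0.30]))))
t = rng.binomial(1, p).astype(float)
y = 1.8 * t + X @ np.array([0.05, 0.40, 0.02, 2.00, 1.00]) + rng.normal(0, 3.5, n)

ones = np.ones(n)
# Long model: all confounders observed; short model: premium_user (column 3) dropped
coef_long, res_long = ols(np.column_stack([ones, t, X]), y)
coef_short, res_short = ols(np.column_stack([ones, t, np.delete(X, 3, axis=1)]), y)

partial_r2_y = 1 - res_long.var() / res_short.var()   # outcome variance explained only by premium_user
delta_theta = coef_short[1] - coef_long[1]            # shift in the treatment coefficient
print(f"partial R^2 of premium_user for y: {partial_r2_y:.3f}, theta shift when omitted: {delta_theta:+.3f}")
```

Because premium_user raises both the treatment odds and ARPU, omitting it biases the effect upward; the sensitivity analysis below asks the same question for a hypothetical unobserved confounder of that strength.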

# Run sensitivity analysis on our ATE result
sensitivity_report_1 = sensitivity_analysis(
    ate_result,
    cf_y=0.06,  # Confounding strength in the outcome (benchmark: premium_user cf_y ≈ 0.062)
    cf_d=0.03,  # Confounding strength in the treatment (benchmark: premium_user cf_d ≈ 0.031)
    rho=1.0     # Worst case: confounding biases outcome and treatment in the same direction
)

print(sensitivity_report_1)
================== Sensitivity Analysis ==================

------------------ Scenario          ------------------
Significance Level: level=0.95
Sensitivity parameters: cf_y=0.06; cf_d=0.03, rho=1.0

------------------ Bounds with CI    ------------------
   CI lower  theta lower     theta  theta upper  CI upper
t  1.284916     1.418741  1.754775      2.09081  2.224353

------------------ Robustness Values ------------------
   H_0     RV (%)    RVa (%)
t  0.0  20.106836  18.701604

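Even in this scenario the lower end of the confidence interval (1.28) stays well above 0, so an unobserved confounder as strong as premium_user would not explain away the effect. The robustness values tell the same story: confounding of roughly 20.1% in both cf_y and cf_d (18.7% once sampling uncertainty is included, RVa) would be needed to push the estimate to zero, far more than premium_user’s benchmark (cf_y ≈ 0.06, cf_d ≈ 0.03). The test passes.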
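The robustness value can be reproduced from the printed bounds, assuming causalkit follows DoubleML's bias bound |rho| · sqrt(cf_y · cf_d / (1 − cf_d)) · S, where S is a scale factor built from the residual outcome variance and the Riesz representer. The sketch backs S out of the reported theta bounds and then solves for the common strength cf_y = cf_d = RV (with rho = 1) at which the bound reaches theta. RVa, which also accounts for sampling uncertainty, is not reproduced here.

```python
import math

theta = 1.754775                        # point estimate from the report above
theta_lower, theta_upper = 1.418741, 2.09081
cf_y, cf_d, rho = 0.06, 0.03, 1.0

# Assumed (DoubleML-style) bias bound: |rho| * sqrt(cf_y * cf_d / (1 - cf_d)) * S
bias = (theta_upper - theta_lower) / 2
S = bias / (abs(rho) * math.sqrt(cf_y * cf_d / (1 - cf_d)))   # implied scale factor

# RV solves rv**2 / (1 - rv) = (theta / S)**2, a quadratic in rv (positive root)
a2 = (theta / S) ** 2
rv = (-a2 + math.sqrt(a2**2 + 4 * a2)) / 2
print(f"implied S = {S:.3f}, RV = {rv:.2%}")
```

This gives RV of about 20.11%, in line with the reported 20.106836%, which supports reading RV as the confounding strength (in both cf_y and cf_d) needed to drive the point estimate to zero.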