causalis.eda.CausalEDA#

class causalis.eda.CausalEDA(data, ps_model=None, n_splits=5, random_state=42)[source]#

Exploratory diagnostics for causal designs with binary treatment.

The class exposes methods to:

  • Summarize outcome by treatment and naive mean difference.

  • Estimate cross-validated propensity scores and assess treatment predictability (AUC) and positivity/overlap.

  • Inspect covariate balance via standardized mean differences (SMD) before/after IPTW weighting; visualize with a love plot.

  • Inspect weight distributions and effective sample size (ESS).
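The weight and ESS diagnostics listed above can be illustrated with a small NumPy sketch. It assumes ATE-style inverse probability weights (1/m for treated units, 1/(1-m) for controls) and Kish's effective sample size; the function names `iptw_weights` and `effective_sample_size` and the clipping default are hypothetical and are not taken from causalis, whose exact weighting scheme is not documented here.

```python
import numpy as np

def iptw_weights(d, m, clip=(0.01, 0.99)):
    """ATE-style weights: 1/m for treated units, 1/(1-m) for controls.

    The clipping bounds are an illustrative assumption, not a causalis default.
    """
    d = np.asarray(d, dtype=float)
    m = np.clip(np.asarray(m, dtype=float), *clip)  # guard against division by ~0
    return d / m + (1.0 - d) / (1.0 - m)

def effective_sample_size(w):
    """Kish effective sample size: (sum of w)^2 / (sum of w^2)."""
    w = np.asarray(w, dtype=float)
    return w.sum() ** 2 / (w ** 2).sum()

# Toy data: treatment indicator d and propensity scores m.
d = np.array([1, 1, 0, 0, 0])
m = np.array([0.5, 0.25, 0.5, 0.2, 0.2])

w = iptw_weights(d, m)
ess_treated = effective_sample_size(w[d == 1])
ess_control = effective_sample_size(w[d == 0])
```

An ESS far below the raw group size suggests that a few units carry most of the weight, which usually goes together with poor overlap.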

Parameters:
  • data (Any)

  • ps_model (Optional[Any])

  • n_splits (int)

  • random_state (int)

__init__(data, ps_model=None, n_splits=5, random_state=42)[source]#
Parameters:
  • data (Any)

  • ps_model (Any | None)

  • n_splits (int)

  • random_state (int)

Methods

__init__(data[, ps_model, n_splits, ...])

auc_m([m])

Compute ROC AUC of treatment assignment vs. m(x).

confounders_means()

Comprehensive confounders balance assessment with means by treatment group.

confounders_roc_auc([ps])

data_shape()

Return the shape information of the causal dataset.

fit_m()

Estimate cross-validated m(x) = P(D=1|X).

fit_propensity()

m_features()

Return feature attribution from the fitted m(x) model.

outcome_boxplot([treatment, target, ...])

Prettified boxplot of the outcome by treatment.

outcome_fit([outcome_model])

Fit a regression model to predict outcome from confounders only.

outcome_hist([treatment, target, bins, ...])

Plot the distribution of the outcome for each treatment on a single, pretty plot.

outcome_plots([treatment, target, bins, ...])

Plot the distribution of the outcome for every treatment on one plot, and also produce a boxplot by treatment to visualize outliers.

outcome_stats()

Comprehensive outcome statistics grouped by treatment.

plot_m_overlap([m])

Plot overlaid histograms of m(x) for treated vs control.

plot_ps_overlap([ps])

positivity_check([ps, bounds])

positivity_check_m([m, bounds])

Check overlap/positivity for m(x) based on thresholds.

treatment_features()

data_shape()[source]#

Return the shape information of the causal dataset.

Returns a dict with:

  • n_rows: number of rows (observations) in the dataset

  • n_columns: number of columns (features) in the dataset

This provides a quick overview of the dataset dimensions for exploratory analysis and reporting purposes.

Returns:

Dictionary containing ‘n_rows’ and ‘n_columns’ keys with corresponding integer values representing the dataset dimensions.

Return type:

Dict[str, int]

Examples

>>> eda = CausalEDA(causal_data)
>>> shape_info = eda.data_shape()
>>> print(f"Dataset has {shape_info['n_rows']} rows and {shape_info['n_columns']} columns")
outcome_stats()[source]#

Comprehensive outcome statistics grouped by treatment.

Returns a DataFrame with detailed outcome statistics for each treatment group, including count, mean, std, min, various percentiles, and max. This method provides comprehensive outcome analysis and returns data in a clean DataFrame format suitable for reporting.

Returns:

DataFrame with treatment groups as index and the following columns:

  • count: number of observations in each group

  • mean: average outcome value

  • std: standard deviation of outcome

  • min: minimum outcome value

  • p10: 10th percentile

  • p25: 25th percentile (Q1)

  • median: 50th percentile (median)

  • p75: 75th percentile (Q3)

  • p90: 90th percentile

  • max: maximum outcome value

Return type:

pd.DataFrame

Examples

>>> eda = CausalEDA(causal_data)
>>> stats = eda.outcome_stats()
>>> print(stats)
           count      mean       std       min       p10       p25    median       p75       p90        max
treatment
0           3000  5.123456  2.345678  0.123456  2.345678  3.456789  5.123456  6.789012  7.890123   9.876543
1           2000  6.789012  2.456789  0.234567  3.456789  4.567890  6.789012  8.901234  9.012345  10.987654
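
The columns described above can be reproduced by hand, which may help when checking a report. The following is a minimal NumPy sketch, not the library's implementation; `outcome_stats_sketch` is a hypothetical name, and the sample standard deviation (ddof=1) and linear percentile interpolation are assumptions.

```python
import numpy as np

def outcome_stats_sketch(y, d):
    """Per-group summary with the same column names as listed above."""
    y = np.asarray(y, dtype=float)
    d = np.asarray(d)
    out = {}
    for g in np.unique(d):
        yg = y[d == g]
        p10, p25, med, p75, p90 = np.percentile(yg, [10, 25, 50, 75, 90])
        out[int(g)] = {
            "count": int(yg.size),
            "mean": float(yg.mean()),
            "std": float(yg.std(ddof=1)),  # sample std; an assumption
            "min": float(yg.min()),
            "p10": float(p10),
            "p25": float(p25),
            "median": float(med),
            "p75": float(p75),
            "p90": float(p90),
            "max": float(yg.max()),
        }
    return out

# Toy data: five control outcomes followed by three treated outcomes.
y = [1, 2, 3, 4, 5, 10, 20, 30]
d = [0, 0, 0, 0, 0, 1, 1, 1]
stats_by_group = outcome_stats_sketch(y, d)
```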
fit_m()[source]#

Estimate cross-validated m(x) = P(D=1|X).

Uses a preprocessing + classifier setup with stratified K-fold to generate out-of-fold probabilities. For CatBoost, data are one-hot encoded via the configured ColumnTransformer before fitting. Returns a PropensityModel.

Return type:

PropensityModel
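
As an illustration of out-of-fold propensity estimation with stratified K-fold, here is a minimal scikit-learn sketch on synthetic data. It uses a plain LogisticRegression without the preprocessing step and is not the internals of fit_m(); the data-generating process and variable names are assumptions made for the example.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Synthetic data: treatment probability depends on the first covariate only.
rng = np.random.default_rng(42)
n = 500
X = rng.normal(size=(n, 3))
p_true = 1.0 / (1.0 + np.exp(-X[:, 0]))
d = rng.binomial(1, p_true)

# Stratified folds keep the treated share similar in every fold.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Each unit's score comes from a model that never saw that unit.
m_hat = cross_val_predict(
    LogisticRegression(max_iter=1000), X, d, cv=cv, method="predict_proba"
)[:, 1]
```

Out-of-fold scores avoid the optimistic bias of in-sample predictions, so downstream AUC and overlap diagnostics are less likely to overstate how predictable treatment is.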

fit_propensity()[source]#
Return type:

PropensityModel

outcome_fit(outcome_model=None)[source]#

Fit a regression model to predict outcome from confounders only.

Uses a preprocessing+CatBoost regressor pipeline with K-fold cross_val_predict to generate out-of-fold predictions. CatBoost uses all available threads and handles categorical features natively. Returns an OutcomeModel instance containing predicted outcomes and diagnostic methods.

The outcome model predicts the baseline outcome from confounders only, excluding treatment. This is essential for proper causal analysis.

Parameters:

outcome_model (Optional[Any]) – Custom regression model to use. If None, uses CatBoostRegressor.

Returns:

An OutcomeModel instance with methods for:

  • scores: RMSE and MAE regression metrics

  • shap: SHAP values DataFrame property for outcome prediction

Return type:

OutcomeModel
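
The idea of out-of-fold outcome predictions scored by RMSE and MAE can be sketched with NumPy alone. This example swaps in ordinary least squares for the CatBoost regressor and uses a simple shuffled K-fold split, so it illustrates the evaluation scheme only; `oof_predictions` and the synthetic data are hypothetical.

```python
import numpy as np

def oof_predictions(X, y, n_splits=5, seed=42):
    """Out-of-fold predictions from a least-squares fit (stand-in model)."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(y)
    Xb = np.column_stack([np.ones(n), X])  # add an intercept column
    # Balanced random fold labels 0..n_splits-1.
    folds = np.random.default_rng(seed).permutation(n) % n_splits
    pred = np.empty(n)
    for k in range(n_splits):
        test = folds == k
        beta, *_ = np.linalg.lstsq(Xb[~test], y[~test], rcond=None)
        pred[test] = Xb[test] @ beta  # predict only on the held-out fold
    return pred

# Synthetic confounders and outcome; treatment is deliberately not a feature.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = 1.0 + 2.0 * X[:, 0] - X[:, 1] + rng.normal(scale=0.5, size=200)

y_hat = oof_predictions(X, y)
rmse = float(np.sqrt(np.mean((y - y_hat) ** 2)))
mae = float(np.mean(np.abs(y - y_hat)))
```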

auc_m(m=None)[source]#

Compute ROC AUC of treatment assignment vs. m(x).

Return type:

float

Parameters:

m (ndarray | None)
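
For intuition, ROC AUC equals the probability that a randomly chosen treated unit has a higher m(x) than a randomly chosen control unit, with ties counted as one half. The sketch below computes it by direct pairwise comparison in NumPy; `roc_auc_pairwise` is a hypothetical helper, not the method's implementation, and its memory use grows with the product of the two group sizes.

```python
import numpy as np

def roc_auc_pairwise(d, m):
    """AUC as the share of (treated, control) pairs ranked correctly."""
    d = np.asarray(d).astype(bool)
    m = np.asarray(m, dtype=float)
    treated, control = m[d], m[~d]
    diff = treated[:, None] - control[None, :]  # all treated-control pairs
    return float((np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / diff.size)

d = np.array([0, 0, 1, 1])
m = np.array([0.1, 0.4, 0.35, 0.8])
auc = roc_auc_pairwise(d, m)
```

A value near 0.5 indicates that the confounders barely predict treatment, as in a randomized design, while a value near 1 points to strong selection and possible overlap problems.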

positivity_check_m(m=None, bounds=(0.05, 0.95))[source]#

Check overlap/positivity for m(x) based on thresholds.

Return type:

Dict[str, Any]

Parameters:
  • m (ndarray | None)

  • bounds – lower and upper thresholds; defaults to (0.05, 0.95)

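
A threshold-based positivity check can be sketched as the share of propensity scores falling outside the bounds. The function name and the dictionary keys below are hypothetical; the keys actually returned by positivity_check_m() are not documented here.

```python
import numpy as np

def positivity_sketch(m, bounds=(0.05, 0.95)):
    """Share of scores below the lower and above the upper threshold."""
    m = np.asarray(m, dtype=float)
    lo, hi = bounds
    share_below = float(np.mean(m < lo))
    share_above = float(np.mean(m > hi))
    return {
        "bounds": bounds,
        "share_below": share_below,
        "share_above": share_above,
        # Flag any mass outside the bounds; the cutoff is an illustrative choice.
        "flag": bool(share_below + share_above > 0),
    }

report = positivity_sketch([0.01, 0.2, 0.5, 0.97, 0.99])
```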
plot_m_overlap(m=None)[source]#

Plot overlaid histograms of m(x) for treated vs control.

Parameters:

m (ndarray | None)

confounders_roc_auc(ps=None)[source]#
Return type:

float

Parameters:

ps (ndarray | None)

positivity_check(ps=None, bounds=(0.05, 0.95))[source]#
Return type:

Dict[str, Any]

Parameters:
  • ps (ndarray | None)

  • bounds – defaults to (0.05, 0.95)

plot_ps_overlap(ps=None)[source]#
Parameters:

ps (ndarray | None)

confounders_means()[source]#

Comprehensive confounders balance assessment with means by treatment group.

Returns a DataFrame with detailed balance information including:

  • Mean values of each confounder for control group (treatment=0)

  • Mean values of each confounder for treated group (treatment=1)

  • Absolute difference between treatment groups

  • Standardized Mean Difference (SMD) for formal balance assessment

  • Kolmogorov–Smirnov statistic (ks) and p-value (ks_pvalue) for distributional differences

This method provides a comprehensive view of confounder balance by showing the actual mean values alongside the standardized differences, making it easier to understand both the magnitude and direction of imbalances.

Returns:

DataFrame with confounders as index and the following columns:

  • mean_t_0: mean value for control group (treatment=0)

  • mean_t_1: mean value for treated group (treatment=1)

  • abs_diff: absolute difference abs(mean_t_1 - mean_t_0)

  • smd: standardized mean difference (Cohen’s d)

  • ks: Kolmogorov–Smirnov statistic

  • ks_pvalue: p-value of the KS test

Return type:

pd.DataFrame

Notes

SMD values > 0.1 in absolute value typically indicate meaningful imbalance. Categorical variables are automatically converted to dummy variables.

Examples

>>> eda = CausalEDA(causal_data)
>>> balance = eda.confounders_means()
>>> print(balance.head())
             mean_t_0  mean_t_1  abs_diff       smd
confounders
age              29.5      31.2      1.7     0.085
income        45000.0   47500.0   2500.0     0.125
education         0.25      0.35      0.1     0.215
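
The smd column can be checked by hand for a single numeric confounder. This sketch assumes the common pooled form, the mean difference divided by sqrt((s1² + s0²) / 2) with sample variances; the exact pooling used by confounders_means() is not stated above, and `smd_sketch` is a hypothetical name.

```python
import numpy as np

def smd_sketch(x, d):
    """Standardized mean difference of x between treated (1) and control (0)."""
    x = np.asarray(x, dtype=float)
    d = np.asarray(d).astype(bool)
    x1, x0 = x[d], x[~d]
    # Pooled SD from the two sample variances (an assumed convention).
    pooled_sd = np.sqrt((x1.var(ddof=1) + x0.var(ddof=1)) / 2.0)
    if pooled_sd == 0:
        return 0.0  # constant covariate: no measurable imbalance
    return float((x1.mean() - x0.mean()) / pooled_sd)

x = [1, 2, 3, 3, 4, 5]
d = [0, 0, 0, 1, 1, 1]
value = smd_sketch(x, d)
```

Here the group means are 2 and 4 and both sample variances equal 1, so the SMD is 2.0, far above the 0.1 rule of thumb.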
outcome_hist(treatment=None, target=None, bins='fd', density=True, alpha=0.45, sharex=True, kde=True, clip=(0.01, 0.99), figsize=(9, 5.5), dpi=220, font_scale=1.15, save=None, save_dpi=None, transparent=False)[source]#

Plot the distribution of the outcome for each treatment on a single, pretty plot.

Features#

  • High-DPI canvas + scalable fonts

  • Default Matplotlib colors; KDE & mean lines match their histogram colors

  • Numeric outcomes: shared x-range (optional), optional KDE, quantile clipping

  • Categorical outcomes: normalized grouped bars by treatment

  • Optional hi-res export (PNG/SVG/PDF)

outcome_boxplot(treatment=None, target=None, figsize=(9, 5.5), dpi=220, font_scale=1.15, showfliers=True, patch_artist=True, save=None, save_dpi=None, transparent=False)[source]#

Prettified boxplot of the outcome by treatment.

Features#

  • High-DPI figure, scalable fonts

  • Soft modern color styling (default Matplotlib palette)

  • Optional outliers, gentle transparency

  • Optional save to PNG/SVG/PDF

outcome_plots(treatment=None, target=None, bins=30, density=True, alpha=0.5, figsize=(7, 4), sharex=True)[source]#

Plot the distribution of the outcome for every treatment on one plot, and also produce a boxplot by treatment to visualize outliers.

Parameters:
  • treatment (Optional[str]) – Treatment column name. Defaults to the treatment stored in the CausalEDA data.

  • target (Optional[str]) – Target/outcome column name. Defaults to the outcome stored in the CausalEDA data.

  • bins (int) – Number of bins for histograms when the outcome is numeric.

  • density (bool) – Whether to normalize histograms to form a density.

  • alpha (float) – Transparency for overlaid histograms.

  • figsize (tuple) – Figure size for the plots.

  • sharex (bool) – If True and the outcome is numeric, use the same x-limits across treatments.

Returns:

(fig_distribution, fig_boxplot)

Return type:

Tuple[matplotlib.figure.Figure, matplotlib.figure.Figure]

m_features()[source]#

Return feature attribution from the fitted m(x) model.

  • CatBoost path: SHAP attributions with columns ‘shap_mean’ and ‘shap_mean_abs’, sorted by ‘shap_mean_abs’. Uses transformed feature names from the preprocessor.

  • Sklearn path (LogisticRegression): absolute coefficients reported as ‘coef_abs’.

Return type:

DataFrame

treatment_features()[source]#
Return type:

DataFrame