causalkit.refutation.refute_irm_orthogonality

causalkit.refutation.refute_irm_orthogonality(inference_fn, data, trim_propensity=(0.02, 0.98), n_basis_funcs=None, n_folds_oos=5, target='ATE', clip_eps=0.01, strict_oos=False, **inference_kwargs)

Comprehensive AIPW orthogonality diagnostics for DoubleML estimators.

Implements three key diagnostic approaches based on the efficient influence function (EIF):

  1. Out-of-sample moment check (non-tautological)

  2. Orthogonality (Gateaux derivative) tests

  3. Influence diagnostics
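All three diagnostics are built on the AIPW score for the ATE. The sketch below writes that score in plain NumPy; it is not causalkit's internal code, and the names `aipw_ate_score` and `aipw_ate_estimate` are hypothetical. It also shows why an in-sample moment check is tautological: with the nuisance predictions g0, g1 and m held fixed, the score is linear in theta, so its sample mean at the estimate is zero whatever the nuisances are.

```python
import numpy as np


def aipw_ate_score(y, d, g0, g1, m, theta):
    """Per-observation AIPW (efficient influence function) score for the ATE.

    g0, g1: predicted E[Y | X, D=0] and E[Y | X, D=1]; m: propensity P(D=1 | X).
    """
    y, d, g0, g1, m = (np.asarray(a, dtype=float) for a in (y, d, g0, g1, m))
    return g1 - g0 + d * (y - g1) / m - (1 - d) * (y - g0) / (1 - m) - theta


def aipw_ate_estimate(y, d, g0, g1, m):
    """Solve mean(score) = 0 for theta; the score is linear in theta."""
    return float(np.mean(aipw_ate_score(y, d, g0, g1, m, theta=0.0)))


# Simulated data with a known ATE of 1.0 and oracle nuisance functions.
rng = np.random.default_rng(0)
n = 2000
x = rng.normal(size=n)
m = 1 / (1 + np.exp(-0.5 * x))
d = rng.binomial(1, m)
y = 1.0 * d + x + rng.normal(size=n)
g0, g1 = x, x + 1.0

theta_hat = aipw_ate_estimate(y, d, g0, g1, m)
psi = aipw_ate_score(y, d, g0, g1, m, theta_hat)
# The in-sample mean of psi is zero up to rounding, by construction.
print(theta_hat, psi.mean())
```

Because the in-sample mean carries no information, a meaningful moment check has to score observations with quantities that were not fitted on them.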

Parameters:
  • inference_fn (Callable) – The inference function (dml_ate or dml_att)

  • data (CausalData) – The causal data object

  • trim_propensity (Tuple[float, float], default (0.02, 0.98)) – Propensity score trimming bounds (min, max) to avoid extreme weights

  • n_basis_funcs (Optional[int], default None) – Number of basis functions for the orthogonality derivative tests (a constant term plus the covariates). If None, defaults to len(confounders) + 1, i.e. the number of confounders in data plus 1 for the constant term.

  • n_folds_oos (int, default 5) – Number of folds for the out-of-sample moment check

  • target (str, default 'ATE') – Target estimand

  • clip_eps (float, default 0.01)

  • strict_oos (bool, default False)

  • **inference_kwargs (dict) – Additional keyword arguments passed to inference_fn
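The orthogonality derivative tests can be sketched by perturbing each nuisance in the direction of a basis function h(X) and differentiating the mean AIPW score at zero; under Neyman orthogonality each derivative has mean zero. This is a minimal sketch under stated assumptions, not causalkit's implementation: it assumes the ATE score, a basis of a constant plus standardized covariates (mirroring the n_basis_funcs default), and reads trim_propensity as dropping units whose propensity lies outside the bounds. The function name and the `t_g1`/`t_g0`/`t_m` keys are hypothetical and may not match the columns causalkit reports.

```python
import numpy as np


def orthogonality_derivative_tests(X, y, d, g0, g1, m, trim=(0.02, 0.98)):
    """t-statistics of Gateaux derivatives of the mean AIPW ATE score.

    Each derivative has mean zero at the true nuisances, so |t| well
    above 2 points to a miscalibrated nuisance estimate.
    """
    X = np.asarray(X, dtype=float).reshape(len(y), -1)
    y, d, g0, g1, m = (np.asarray(a, dtype=float) for a in (y, d, g0, g1, m))

    # Trimming: keep only units whose propensity lies inside the bounds.
    keep = (m >= trim[0]) & (m <= trim[1])
    X, y, d, g0, g1, m = X[keep], y[keep], d[keep], g0[keep], g1[keep], m[keep]
    n = len(y)

    # Basis h(X): constant plus standardized covariates (k + 1 functions).
    Z = (X - X.mean(axis=0)) / X.std(axis=0)
    H = np.column_stack([np.ones(n), Z])

    # d/dt of the score when one nuisance moves to nuisance + t * h(X).
    dg1 = (1 - d / m)[:, None] * H
    dg0 = ((1 - d) / (1 - m) - 1)[:, None] * H
    dm = (-d * (y - g1) / m**2 - (1 - d) * (y - g0) / (1 - m) ** 2)[:, None] * H

    def tstat(v):
        return v.mean(axis=0) / (v.std(axis=0, ddof=1) / np.sqrt(n))

    return {"t_g1": tstat(dg1), "t_g0": tstat(dg0), "t_m": tstat(dm)}


# Simulated data with oracle nuisances, then a propensity that is too high.
rng = np.random.default_rng(1)
n = 4000
X = rng.normal(size=(n, 2))
m = 1 / (1 + np.exp(-0.5 * X[:, 0]))
d = rng.binomial(1, m)
y = d + X.sum(axis=1) + rng.normal(size=n)
g0, g1 = X.sum(axis=1), X.sum(axis=1) + 1.0

good = orthogonality_derivative_tests(X, y, d, g0, g1, m)
bad = orthogonality_derivative_tests(X, y, d, g0, g1, np.clip(m + 0.1, 0.01, 0.99))
print(good["t_g1"], bad["t_g1"])
```

With oracle nuisances every |t| stays small, while the shifted propensity pushes the constant-direction statistics for g1 and g0 far from zero.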

Returns:

Dictionary containing:

  • oos_moment_test – Out-of-sample moment condition results

  • orthogonality_derivatives – Gateaux derivative test results

  • influence_diagnostics – Influence function diagnostics

  • theta – Original treatment effect estimate

  • trimmed_diagnostics – Results on the trimmed sample

  • overall_assessment – Summary diagnostic assessment

Return type:

Dict[str, Any]
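The out-of-sample moment check avoids the in-sample tautology by never scoring an observation with quantities fitted on it. A minimal sketch, assuming K random folds, linear outcome regressions and a logistic propensity model refit on the training folds (stand-ins for whatever learners inference_fn uses), and a pooled t-statistic over the held-out scores. The helper names are hypothetical and causalkit's exact procedure, including what strict_oos changes, may differ.

```python
import numpy as np


def fit_ols(Z, y):
    """Least-squares coefficients of y on the columns of Z."""
    return np.linalg.lstsq(Z, y, rcond=None)[0]


def fit_logit(Z, d, n_iter=25):
    """Logistic-regression coefficients via Newton-Raphson."""
    beta = np.zeros(Z.shape[1])
    for _ in range(n_iter):
        p = 1 / (1 + np.exp(-Z @ beta))
        hess = Z.T @ (Z * (p * (1 - p))[:, None])
        beta += np.linalg.solve(hess, Z.T @ (d - p))
    return beta


def aipw_phi(y, d, g0, g1, m):
    """AIPW ATE score without the -theta term; theta_hat is its mean."""
    return g1 - g0 + d * (y - g1) / m - (1 - d) * (y - g0) / (1 - m)


def oos_moment_check(X, y, d, n_folds=5, seed=0):
    X, y, d = (np.asarray(a, dtype=float) for a in (X, y, d))
    n = len(y)
    Z = np.column_stack([np.ones(n), X])
    folds = np.random.default_rng(seed).permutation(n) % n_folds
    psi = np.empty(n)
    fold_thetas = []
    for k in range(n_folds):
        test, train = folds == k, folds != k
        # Nuisance models are fit on the training folds only.
        b0 = fit_ols(Z[train & (d == 0)], y[train & (d == 0)])
        b1 = fit_ols(Z[train & (d == 1)], y[train & (d == 1)])
        bm = fit_logit(Z[train], d[train])

        def preds(idx):
            # (g0, g1, m) predictions from the training-fold models.
            return Z[idx] @ b0, Z[idx] @ b1, 1 / (1 + np.exp(-Z[idx] @ bm))

        theta_k = aipw_phi(y[train], d[train], *preds(train)).mean()
        # Score the held-out fold with a theta and nuisances it did not shape.
        psi[test] = aipw_phi(y[test], d[test], *preds(test)) - theta_k
        fold_thetas.append(float(theta_k))
    tstat = psi.mean() / (psi.std(ddof=1) / np.sqrt(n))
    return {"tstat": float(tstat), "fold_thetas": fold_thetas}


rng = np.random.default_rng(2)
n = 3000
X = rng.normal(size=(n, 2))
d = rng.binomial(1, 1 / (1 + np.exp(-0.5 * X[:, 0])))
y = 1.0 * d + X @ np.array([1.0, -0.5]) + rng.normal(size=n)
result = oos_moment_check(X, y, d, n_folds=5)
print(result["tstat"], result["fold_thetas"])
```

With correctly specified, non-overfitting nuisance models the held-out scores average close to zero; a large |t| suggests that the nuisances or theta do not generalize beyond the data they were fitted on.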

Examples

>>> from causalkit.refutation import refute_irm_orthogonality
>>> from causalkit.inference.ate import dml_ate
>>>
>>> # Comprehensive orthogonality check (causal_data is an existing CausalData object)
>>> ortho_results = refute_irm_orthogonality(dml_ate, causal_data)
>>>
>>> # Check key diagnostics
>>> print(f"OOS moment t-stat: {ortho_results['oos_moment_test']['tstat']:.3f}")
>>> print(f"Calibration issues: {len(ortho_results['orthogonality_derivatives'].query('abs(t_g) > 2'))}")
>>> print(f"Assessment: {ortho_results['overall_assessment']}")
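The influence_diagnostics entry summarizes how concentrated the per-observation EIF scores are. As a minimal sketch, assuming the scores psi are available as an array, a few common summaries are the plug-in standard error, the kurtosis, a tail ratio and the most influential units. The function name and keys below are illustrative and may differ from what causalkit returns.

```python
import numpy as np


def influence_summary(psi, top_k=5):
    """Summaries of per-observation influence (EIF) scores."""
    psi = np.asarray(psi, dtype=float)
    n = psi.size
    c = psi - psi.mean()
    abs_c = np.abs(c)
    return {
        # Plug-in standard error of theta_hat: sd(psi) / sqrt(n).
        "se_plugin": float(c.std(ddof=1) / np.sqrt(n)),
        # About 3 for Gaussian scores; large values mean a few units dominate.
        "kurtosis": float(np.mean(c**4) / np.mean(c**2) ** 2),
        # Tail heaviness: 99th percentile of |psi| relative to its median.
        "p99_over_median": float(np.quantile(abs_c, 0.99) / np.median(abs_c)),
        # Indices of the units with the largest |psi|, most influential first.
        "top_indices": np.argsort(abs_c)[::-1][:top_k].tolist(),
    }


rng = np.random.default_rng(3)
psi = rng.normal(size=1000)
psi[17] = 50.0  # one extreme score, e.g. from a propensity near 0 or 1
summary = influence_summary(psi)
print(summary["top_indices"][0], summary["kurtosis"])
```

A single extreme score inflates the kurtosis and tops the influence ranking, which is the pattern propensity trimming is meant to reduce.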