causalkit.refutation.refute_irm_orthogonality
- causalkit.refutation.refute_irm_orthogonality(inference_fn, data, trim_propensity=(0.02, 0.98), n_basis_funcs=None, n_folds_oos=5, target='ATE', clip_eps=0.01, strict_oos=False, **inference_kwargs)
Comprehensive AIPW orthogonality diagnostics for DoubleML estimators.
Implements three diagnostic approaches based on the efficient influence function (EIF):
1. Out-of-sample moment check (non-tautological)
2. Orthogonality (Gateaux derivative) tests
3. Influence diagnostics
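The EIF-based moment check can be illustrated with a minimal, pure-Python sketch. It is not causalkit's implementation: `aipw_scores` and `fold_moment_tstats` are hypothetical helpers, and the nuisance predictions `g0`, `g1` (outcome regressions for D=0 and D=1) and `m` (propensity score) are assumed to be cross-fitted values supplied by the caller. The AIPW score for the ATE is g1 - g0 + D(Y - g1)/m - (1 - D)(Y - g0)/(1 - m); its mean is the estimate theta, so the in-sample mean of psi = score - theta is zero by construction. The sketch instead estimates theta without fold k and tests the moment on fold k.

```python
import math
from typing import List, Sequence


def aipw_scores(y, d, g0, g1, m) -> List[float]:
    """Uncentered AIPW scores for the ATE; their mean is the AIPW point estimate.

    Hypothetical helper: g0, g1 are outcome predictions E[Y | X, D=0/1] and m is
    the propensity P(D=1 | X), all assumed to be cross-fitted by the caller.
    """
    return [
        g1_i - g0_i + d_i * (y_i - g1_i) / m_i - (1 - d_i) * (y_i - g0_i) / (1 - m_i)
        for y_i, d_i, g0_i, g1_i, m_i in zip(y, d, g0, g1, m)
    ]


def fold_moment_tstats(scores: Sequence[float], n_folds: int = 5) -> List[float]:
    """For each fold k: estimate theta on the other folds, then test E[psi] = 0 on fold k."""
    n = len(scores)
    tstats = []
    for k in range(n_folds):
        held_out = [scores[i] for i in range(k, n, n_folds)]  # interleaved fold k
        train = [scores[i] for i in range(n) if i % n_folds != k]
        theta_minus_k = sum(train) / len(train)
        psi = [s - theta_minus_k for s in held_out]  # EIF evaluated out of sample
        mean = sum(psi) / len(psi)
        var = sum((p - mean) ** 2 for p in psi) / (len(psi) - 1)
        tstats.append(mean / math.sqrt(var / len(psi)))
    return tstats
```

Because the scores are held fixed in this sketch, pooling all held-out psi values over equal-size folds would again give a mean of exactly zero, which is why it reports one t-statistic per fold.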
- Parameters:
inference_fn (Callable) – The inference function (dml_ate or dml_att)
data (CausalData) – The causal data object
trim_propensity (Tuple[float, float], default (0.02, 0.98)) – Propensity score trimming bounds (min, max) to avoid extreme weights
n_basis_funcs (Optional[int], default None) – Number of basis functions (a constant plus covariates) for the orthogonality derivative tests. If None, defaults to len(confounders) + 1: one per confounder in data plus the constant term.
n_folds_oos (int, default 5) – Number of folds for the out-of-sample moment check
target (str, default 'ATE') – Target estimand
clip_eps (float, default 0.01)
strict_oos (bool, default False)
**inference_kwargs (dict) – Additional arguments passed to inference_fn
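The orthogonality derivative tests and the trim_propensity and n_basis_funcs parameters can be sketched as follows, under stated assumptions. For the AIPW ATE score, perturbing g1 in direction h(X) changes the mean moment at rate E[h(X)(1 - D/m)]; for g0 the rate is E[h(X)((1 - D)/(1 - m) - 1)], and for m it is -E[h(X)(D(Y - g1)/m^2 + (1 - D)(Y - g0)/(1 - m)^2)]. Under Neyman orthogonality these should be close to zero, so a large t-statistic in some direction is a warning sign. `gateaux_tstats`, its output keys (`t_g1`, `t_g0`, `t_m`), and the choice to trim by dropping units are all hypothetical, not causalkit's API.

```python
import math
from typing import Dict, List, Optional, Sequence, Tuple


def _tstat(values: Sequence[float]) -> float:
    """One-sample t-statistic for H0: E[value] = 0 (nan if the sample has no spread)."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / (n - 1)
    return mean / math.sqrt(var / n) if var > 0 else float("nan")


def gateaux_tstats(
    X: Sequence[Sequence[float]],
    y: Sequence[float],
    d: Sequence[int],
    g0: Sequence[float],
    g1: Sequence[float],
    m: Sequence[float],
    trim: Tuple[float, float] = (0.02, 0.98),
    n_basis_funcs: Optional[int] = None,
) -> List[Dict[str, float]]:
    """Directional-derivative t-stats of the AIPW ATE moment, one row per basis function."""
    # Assumption: trimming drops units whose propensity lies outside the bounds.
    keep = [i for i, m_i in enumerate(m) if trim[0] <= m_i <= trim[1]]
    n_conf = len(X[0])
    # Basis: constant plus one function per confounder (len(confounders) + 1 by default).
    n_basis = n_conf + 1 if n_basis_funcs is None else min(n_basis_funcs, n_conf + 1)
    rows = []
    for j in range(n_basis):
        h = {i: 1.0 if j == 0 else X[i][j - 1] for i in keep}
        # Per-unit derivative of psi at t = 0 for g1 + t*h, g0 + t*h, and m + t*h.
        dg1 = [h[i] * (1 - d[i] / m[i]) for i in keep]
        dg0 = [h[i] * ((1 - d[i]) / (1 - m[i]) - 1) for i in keep]
        dm = [
            -h[i] * (d[i] * (y[i] - g1[i]) / m[i] ** 2
                     + (1 - d[i]) * (y[i] - g0[i]) / (1 - m[i]) ** 2)
            for i in keep
        ]
        rows.append({"basis": j, "n": len(keep),
                     "t_g1": _tstat(dg1), "t_g0": _tstat(dg0), "t_m": _tstat(dm)})
    return rows
```

Each row corresponds to one basis direction; a rule such as flagging rows with |t| > 2 mirrors the check in the example below.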
- Returns:
Dictionary containing:
- oos_moment_test: Out-of-sample moment condition results
- orthogonality_derivatives: Gateaux derivative test results
- influence_diagnostics: Influence function diagnostics
- theta: Original treatment effect estimate
- trimmed_diagnostics: Results on the trimmed sample
- overall_assessment: Summary diagnostic assessment
- Return type:
Dict[str, Any]
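The contents of influence_diagnostics are not spelled out here. As an assumption, the sketch below computes a few common summaries of the centered EIF values psi_i = score_i - theta from any list of uncentered per-observation scores (for example, AIPW scores): the EIF-based standard error, the largest standardized |psi_i|, the share of observations above a threshold, and the most influential observations. `influence_summary` and its `threshold` and `top_k` parameters are hypothetical.

```python
import math
from typing import Dict, Sequence


def influence_summary(
    scores: Sequence[float], threshold: float = 4.0, top_k: int = 3
) -> Dict[str, object]:
    """Hypothetical influence summaries computed from uncentered per-unit scores."""
    n = len(scores)
    theta = sum(scores) / n
    psi = [s - theta for s in scores]  # centered EIF values
    sd = math.sqrt(sum(p * p for p in psi) / (n - 1))
    z = [abs(p) / sd for p in psi]  # standardized influence of each unit
    top = sorted(range(n), key=lambda i: z[i], reverse=True)[:top_k]
    return {
        "theta": theta,
        "se": sd / math.sqrt(n),  # EIF-based standard error of theta
        "max_abs_z": max(z),
        "share_abs_z_gt_threshold": sum(zi > threshold for zi in z) / n,
        "top_influential": top,  # indices of the most influential units
    }
```

A single unit with a very large standardized influence usually points to extreme propensity weights, which is what trim_propensity is meant to guard against.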
Examples
>>> from causalkit.refutation import refute_irm_orthogonality
>>> from causalkit.inference.ate import dml_ate
>>>
>>> # Comprehensive orthogonality check on an existing CausalData object `causal_data`
>>> ortho_results = refute_irm_orthogonality(dml_ate, causal_data)
>>>
>>> # Check key diagnostics
>>> print(f"OOS moment t-stat: {ortho_results['oos_moment_test']['tstat']:.3f}")
>>> print(f"Calibration issues: {len(ortho_results['orthogonality_derivatives'].query('abs(t_g) > 2'))}")
>>> print(f"Assessment: {ortho_results['overall_assessment']}")