Source code for causalis.data.causaldata

"""
CKit class for storing DataFrame and column metadata for causal inference.
"""

import pandas as pd
import pandas.api.types as pdtypes
from typing import Union, List, Optional
import warnings


[docs] class CausalData: """ Container for causal inference datasets. Wraps a pandas DataFrame and stores the names of treatment, outcome, and optional confounder columns. The stored DataFrame is restricted to only those columns. Parameters ---------- df : pd.DataFrame The DataFrame containing the data. Cannot contain NaN values. Only columns specified in outcome, treatment, and confounders will be stored. treatment : str Column name representing the treatment variable. outcome : str Column name representing the outcome (target) variable. confounders : Union[str, List[str]], optional Column name(s) representing the confounders/covariates. Attributes ---------- df : pd.DataFrame A copy of the original data restricted to [outcome, treatment] + confounders. treatment : str Name of the treatment column. outcome : str Name of the outcome (target) column. confounders : list[str] Names of the confounder columns (may be empty). Examples -------- >>> from causalis.data import generate_rct >>> from causalis.data import CausalData >>> >>> # Generate data >>> df = generate_rct() >>> >>> # Create CausalData object >>> causal_data = CausalData( ... df=df, ... treatment='treatment', ... outcome='outcome', ... confounders=['age', 'invited_friend'] ... ) >>> >>> # Access data >>> causal_data.df.head() >>> >>> # Access columns by role >>> causal_data.target >>> causal_data.confounders >>> causal_data.treatment """
[docs] def __init__( self, df: pd.DataFrame, treatment: str, outcome: str, confounders: Optional[Union[str, List[str]]] = None, ): """ Initialize a CausalData object. """ self._treatment = treatment self._target = outcome # Store confounders as a list of unique names (preserve order) conf_list = self._ensure_list(confounders) if confounders is not None else [] merged: List[str] = [] for v in conf_list: if v not in merged: merged.append(v) self._confounders = merged # Validate column names self._validate_columns(df) # Store only the relevant columns columns_to_keep = [self._target, self._treatment] + self._confounders self.df = df[columns_to_keep].copy() # Coerce boolean columns to integers to ensure stored df is fully numeric for col in self.df.columns: if pdtypes.is_bool_dtype(self.df[col]): # Use int8 for compact 0/1 storage self.df[col] = self.df[col].astype("int8") # Final safeguard: ensure all stored columns are numeric for col in self.df.columns: if not pdtypes.is_numeric_dtype(self.df[col]): raise ValueError( f"All columns in stored DataFrame must be numeric; column '{col}' has dtype {self.df[col].dtype}." ) # Re-run duplicate-column and duplicate-row checks on the stored, normalized subset self._check_duplicate_column_values(self.df) self._check_duplicate_rows(self.df)
def _ensure_list(self, value: Union[str, List[str]]) -> List[str]: """ Ensure that the value is a list of strings. """ if isinstance(value, str): return [value] return value def _validate_columns(self, df): """ Validate that all specified columns exist in the DataFrame and that the DataFrame does not contain NaN values. Also validate that outcome, confounders, and treatment columns contain only int or float values. Also validate that no columns are constant (have zero variance). """ # Check for NaN values in the DataFrame if df.isna().any().any(): raise ValueError("DataFrame contains NaN values, which are not allowed.") all_columns = set(df.columns) # Validate outcome column if self._target not in all_columns: raise ValueError(f"Column '{self._target}' specified as outcome does not exist in the DataFrame.") # Check if outcome column contains numeric or boolean values if not (pdtypes.is_numeric_dtype(df[self._target]) or pdtypes.is_bool_dtype(df[self._target])): raise ValueError(f"Column '{self._target}' specified as outcome must contain only int, float, or bool values.") # Check if outcome column is constant (single unique value) if df[self._target].nunique(dropna=False) <= 1: raise ValueError( f"Column '{self._target}' specified as outcome is constant (has zero variance / single unique value), " f"which is not allowed for causal inference." ) # Validate treatment column if self._treatment not in all_columns: raise ValueError(f"Column '{self._treatment}' specified as treatment does not exist in the DataFrame.") # Check if treatment column contains numeric or boolean values if not (pdtypes.is_numeric_dtype(df[self._treatment]) or pdtypes.is_bool_dtype(df[self._treatment])): raise ValueError( f"Column '{self._treatment}' specified as treatment must contain only int, float, or bool values.") # Check if treatment column is constant (single unique value) if df[self._treatment].nunique(dropna=False) <= 1: raise ValueError( f"Column '{self._treatment}' specified as treatment is constant (has zero variance / single unique value), " f"which is not allowed for causal inference." ) # Validate confounders columns; drop constant ones with a warning kept_confounders: List[str] = [] dropped_constants: List[str] = [] for col in self._confounders: if col not in all_columns: raise ValueError(f"Column '{col}' specified as confounders does not exist in the DataFrame.") # Check if confounder column contains numeric or boolean values if not (pdtypes.is_numeric_dtype(df[col]) or pdtypes.is_bool_dtype(df[col])): raise ValueError(f"Column '{col}' specified as confounders must contain only int, float, or bool values.") # Check if confounder column is constant (single unique value) if df[col].nunique(dropna=False) <= 1: dropped_constants.append(col) continue kept_confounders.append(col) if dropped_constants: warnings.warn( "Dropping constant confounder columns (zero variance): " + ", ".join(dropped_constants), UserWarning, stacklevel=2, ) # Update confounders to exclude dropped constants self._confounders = kept_confounders # Note: duplicate columns and duplicate rows are checked on the stored, normalized subset # after dtype coercion, to reflect the actual data used by the class. def _check_duplicate_column_values(self, df): """ Check for duplicate column values across all used columns. Raises ValueError if any two columns have identical values. """ # Get all columns that will be used in CausalData columns_to_check = [self._target, self._treatment] + self._confounders # Compare each pair of columns (post-normalization) for i, col1 in enumerate(columns_to_check): for j in range(i + 1, len(columns_to_check)): col2 = columns_to_check[j] # Use pandas.Series.equals for exact equality on stored subset (NaN not expected) if df[col1].equals(df[col2]): # Determine the types of columns for better error message col1_type = self._get_column_type(col1) col2_type = self._get_column_type(col2) raise ValueError( f"Columns '{col1}' ({col1_type}) and '{col2}' ({col2_type}) have identical values, " f"which is not allowed for causal inference. Only column names differ." ) def _check_duplicate_rows(self, df): """ Check for duplicate rows in the DataFrame and issue a warning if found. Only checks the columns that will be used in CausalData. """ # Get only the columns that will be used in CausalData columns_to_check = [self._target, self._treatment] + self._confounders df_subset = df[columns_to_check] # Find duplicate rows duplicated_mask = df_subset.duplicated() num_duplicates = int(duplicated_mask.sum()) if num_duplicates > 0: total_rows = int(len(df_subset)) unique_rows = total_rows - num_duplicates warnings.warn( f"Found {num_duplicates} duplicate rows out of {total_rows} total rows in the DataFrame. " f"This leaves {unique_rows} unique rows for analysis. " f"Duplicate rows may affect the quality of causal inference results. " f"Consider removing duplicates if they are not intentional.", UserWarning, stacklevel=2 ) def _get_column_type(self, column_name): """ Determine the type/role of a column (treatment, outcome, or confounder). """ if column_name == self._target: return "outcome" elif column_name == self._treatment: return "treatment" elif column_name in self._confounders: return "confounder" else: return "unknown" @property def target(self) -> pd.Series: """ Get the outcome/outcome variable. Returns ------- pd.Series The outcome column as a pandas Series. """ return self.df[self._target] # Backwards-compat alias expected by CausalEDA: expose `.outcome` as a Series @property def outcome(self) -> pd.Series: return self.target @property def confounders(self) -> List[str]: """List of confounder column names.""" return list(self._confounders) if self._confounders else [] @property def treatment(self) -> pd.Series: """ Get the treatment variable. Returns ------- pd.Series The treatment column as a pandas Series. """ return self.df[self._treatment]
[docs] def get_df( self, columns: Optional[List[str]] = None, include_treatment: bool = True, include_target: bool = True, include_confounders: bool = True ) -> pd.DataFrame: """ Get a DataFrame from the CausalData object with specified columns. Parameters ---------- columns : List[str], optional Specific column names to include in the returned DataFrame. If provided, these columns will be included in addition to any columns specified by the include parameters. If None, columns will be determined solely by the include parameters. If None and no include parameters are True, returns the entire DataFrame. include_treatment : bool, default True Whether to include treatment column(s) in the returned DataFrame. include_target : bool, default True Whether to include target column(s) in the returned DataFrame. include_confounders : bool, default True Whether to include confounder column(s) in the returned DataFrame. Returns ------- pd.DataFrame DataFrame containing the specified columns. Examples -------- >>> from causalis.data import generate_rct >>> from causalis.data import CausalData >>> >>> # Generate data >>> df = generate_rct() >>> >>> # Create CausalData object >>> causal_data = CausalData( ... df=df, ... treatment='treatment', ... outcome='outcome', ... confounders=['age', 'invited_friend'] ... ) >>> >>> # Get specific columns >>> causal_data.get_df(columns=['age']) >>> >>> # Get all columns >>> causal_data.get_df() """ # Start with empty list of columns to include cols_to_include = [] # If specific columns are provided, add them to the list if columns is not None: cols_to_include.extend(columns) # If no specific columns are provided and no include parameters are True, # return the entire DataFrame if columns is None and not any([include_target, include_confounders, include_treatment]): return self.df.copy() # Add columns based on include parameters if include_target: cols_to_include.append(self._target) if include_confounders: cols_to_include.extend(self._confounders) if include_treatment: cols_to_include.append(self._treatment) # Remove duplicates while preserving order cols_to_include = list(dict.fromkeys(cols_to_include)) # Validate that all requested columns exist (only needed if user passed custom columns) missing = [c for c in cols_to_include if c not in self.df.columns] if missing: raise ValueError(f"Column(s) {missing} do not exist in the DataFrame.") # Return the DataFrame with selected columns return self.df[cols_to_include].copy()
[docs] def __repr__(self) -> str: """ String representation of the CausalData object. """ return ( f"CausalData(df={self.df.shape}, " f"treatment='{self._treatment}', " f"outcome='{self._target}', " f"confounders={self._confounders})" )