CausalData at a glance
CausalData is the lightweight input container used across CausalKit. It wraps a pandas DataFrame and records which columns are the outcome (target), the treatment, and the confounders.
Quick start
```python
from causalkit.data import generate_rct_data, CausalData

# Example data
df = generate_rct_data(n_users=5_000)

# Declare column roles
causal_data = CausalData(
    df=df,
    treatment='treatment',
    outcome='outcome',
    confounders=['age', 'cnt_trans', 'platform_Android', 'platform_iOS', 'invited_friend']
)
```
Note: Internally, the stored DataFrame is trimmed to only these columns: [outcome, treatment, confounders].
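In plain pandas, this trimming amounts to a single column selection. The sketch below only illustrates the documented behavior and is not CausalKit's code; the toy columns, including the unused user_id, are made up for illustration.

```python
import pandas as pd

# Toy data (illustrative only); 'user_id' is given no role.
df = pd.DataFrame({
    "user_id": [1, 2, 3, 4],
    "age": [25, 31, 47, 52],
    "treatment": [0, 1, 0, 1],
    "outcome": [1.0, 2.5, 0.7, 3.1],
})

outcome, treatment, confounders = "outcome", "treatment", ["age"]

# Keep only the role columns, in the order [outcome, treatment, *confounders];
# every other column is dropped.
trimmed = df[[outcome, treatment, *confounders]]
print(list(trimmed.columns))  # ['outcome', 'treatment', 'age']
```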
API essentials
Init parameters
df: pandas DataFrame (no NaNs)
treatment: name of the treatment column (numeric)
outcome: name of the outcome/target column (numeric)
confounders: one or more confounder column names (numeric)
Properties
target: pandas Series (the outcome)
outcome: alias of target
treatment: pandas Series
confounders: list[str] of confounder column names
df: the stored DataFrame, trimmed to the role columns
Method
get_df(columns=None, include_treatment=True, include_target=True, include_confounders=True) -> DataFrame
Selects columns by name (columns) and/or by role (the include_* flags) and returns a copy.
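The name-and/or-role selection can be pictured with a small pandas helper. select_columns is hypothetical, not CausalKit's implementation; it assumes that names passed via columns are combined with the columns picked by the role flags (so the flags must be switched off to get only the named columns), and it does not model every edge case of the real method.

```python
import pandas as pd


def select_columns(df, outcome, treatment, confounders, columns=None,
                   include_treatment=True, include_target=True,
                   include_confounders=True):
    """Hypothetical stand-in for get_df: union of named and role columns."""
    selected = list(columns) if columns is not None else []
    if include_target:
        selected.append(outcome)
    if include_treatment:
        selected.append(treatment)
    if include_confounders:
        selected.extend(confounders)
    # Drop repeated names while keeping first-seen order.
    selected = list(dict.fromkeys(selected))
    # .copy() so callers can modify the result without touching df.
    return df[selected].copy()


df = pd.DataFrame({"outcome": [1.0, 2.0, 3.0], "treatment": [0, 1, 0],
                   "age": [30, 40, 50], "cnt_trans": [2, 5, 1]})
roles = dict(outcome="outcome", treatment="treatment",
             confounders=["age", "cnt_trans"])

X = select_columns(df, **roles, include_target=False, include_treatment=False)
print(list(X.columns))  # ['age', 'cnt_trans']

only_age = select_columns(df, **roles, columns=["age"], include_target=False,
                          include_treatment=False, include_confounders=False)
print(list(only_age.columns))  # ['age']

# The result is a copy: editing it leaves df unchanged.
only_age.loc[0, "age"] = -1
print(df.loc[0, "age"])  # 30
```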
Validation (on construction)
No missing values anywhere in df.
All referenced columns must exist.
Outcome, treatment, and confounders must be numeric (int/float).
None of these columns can be constant (zero variance).
No two used columns may contain identical values (raises ValueError).
Duplicate rows across the used columns trigger a warning (not an error).
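The checks above can be re-created in plain pandas to test a DataFrame before building CausalData. validate_roles is a hypothetical helper, not CausalKit's code: the error messages and the order of checks are assumptions, and only the identical-columns case is documented as a ValueError, so the exception types for the other failures are a guess.

```python
import warnings

import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype


def validate_roles(df, outcome, treatment, confounders):
    """Hypothetical re-creation of the documented construction checks."""
    # 1. No missing values anywhere in df.
    if df.isna().any().any():
        raise ValueError("df contains missing values")
    used = [outcome, treatment, *confounders]
    # 2. Every referenced column must exist.
    missing = [c for c in used if c not in df.columns]
    if missing:
        raise ValueError(f"columns not found: {missing}")
    for col in used:
        # 3. Numeric dtype (note: pandas also counts bool as numeric).
        if not is_numeric_dtype(df[col]):
            raise ValueError(f"column {col!r} is not numeric")
        # 4. Not constant: at least two distinct values.
        if df[col].nunique() < 2:
            raise ValueError(f"column {col!r} is constant")
    # 5. No two used columns may hold identical values.
    for i, a in enumerate(used):
        for b in used[i + 1:]:
            if np.array_equal(df[a].to_numpy(), df[b].to_numpy()):
                raise ValueError(f"columns {a!r} and {b!r} are identical")
    # 6. Duplicate rows over the used columns only trigger a warning.
    if df[used].duplicated().any():
        warnings.warn("duplicate rows found in the used columns")


ok = pd.DataFrame({"outcome": [1.0, 2.0, 3.0, 1.0],
                   "treatment": [0, 1, 1, 0],
                   "age": [30, 41, 52, 30]})
# Passes, but rows 0 and 3 are identical, so a UserWarning is emitted.
validate_roles(ok, "outcome", "treatment", ["age"])

bad = ok.assign(age_copy=ok["age"])
try:
    validate_roles(bad, "outcome", "treatment", ["age", "age_copy"])
except ValueError as err:
    print(err)  # columns 'age' and 'age_copy' are identical
```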
Common snippets
```python
from causalkit.data import generate_rct_data, CausalData

df = generate_rct_data(n_users=1_000)
causal_data = CausalData(
    df=df,
    treatment='treatment',
    outcome='outcome',
    confounders=['age', 'cnt_trans', 'platform_Android', 'platform_iOS', 'invited_friend']
)

# Access pieces
causal_data.treatment    # Series
causal_data.target       # Series
causal_data.confounders  # list[str]

# Full data used by CausalData
default_df = causal_data.df        # the stored DataFrame
default_df = causal_data.get_df()  # same columns, returned as a copy

# DataFrame of only confounders
X = causal_data.get_df(include_target=False, include_treatment=False)
# or
X = causal_data.df[causal_data.confounders]

# Select a subset by name(s); switch off the role flags so only 'age' is returned
small = causal_data.get_df(columns=['age'], include_treatment=False,
                           include_target=False, include_confounders=False)
```
Tips
For categorical confounders, encode them numerically (e.g., one-hot) before creating CausalData.
If you see the duplicate-rows warning and the duplicates are unintended, consider deduplicating df before constructing CausalData.
The repr of a CausalData object shows the stored shape and the declared roles for quick inspection.
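Both data-preparation tips can be handled in pandas before CausalData is created. The raw data below is made up; the dummy column names mirror the platform_Android and platform_iOS confounders in the quick start, which suggests (but this is an assumption) that those columns came from a similar one-hot encoding.

```python
import pandas as pd

# Illustrative raw data with a categorical 'platform' column (made up).
raw = pd.DataFrame({
    "age": [25, 31, 47, 31],
    "platform": ["Android", "iOS", "Android", "iOS"],
    "treatment": [0, 1, 0, 1],
    "outcome": [1.2, 2.0, 0.9, 2.0],
})

# One-hot encode the categorical column; dtype=int yields 0/1 numeric columns
# named platform_Android and platform_iOS, appended after the other columns.
encoded = pd.get_dummies(raw, columns=["platform"], dtype=int)
print(list(encoded.columns))
# ['age', 'treatment', 'outcome', 'platform_Android', 'platform_iOS']

# Drop exact duplicate rows over the columns you plan to use,
# but only if those duplicates are unintended.
used = ["outcome", "treatment", "age", "platform_Android", "platform_iOS"]
deduped = encoded.drop_duplicates(subset=used)
print(len(encoded), "->", len(deduped))  # 4 -> 3
```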