Estimating Heterogeneous Effects
The previous page defined the estimand: \(\tau(x) = E[Y(1) - Y(0) \mid X = x]\). This page covers the estimation machinery — how to recover \(\tau(x)\) from data when you don’t know in advance which covariates drive heterogeneity or what functional form it takes.
Causal forests
Causal forests (Wager and Athey, 2018) adapt random forests to estimate \(\tau(x)\) directly. The key idea: instead of splitting to predict the outcome \(Y\), split to maximize the heterogeneity in treatment effects across the resulting subgroups.
How it works
- For each tree, draw a random subsample and split it in two (for honesty; see below)
- At each node, find the covariate split that maximizes the difference in treatment effects between the two child nodes
- Within each terminal leaf, estimate the treatment effect by comparing treated and control outcomes (a local difference in means)
- Average predictions across all trees in the forest
The forest produces an estimate \(\hat{\tau}(x)\) for any covariate vector \(x\) — a personalized treatment effect prediction.
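A toy version of this recipe, a single-split "causal stump" with honest estimation, can be sketched in a few lines. Everything here (the simulated step-function DGP, the hand-picked candidate cut points) is an illustrative assumption, not the grf implementation:

```python
import random
import statistics

random.seed(0)

# Simulated randomized experiment with a step CATE: tau(x) = 1 if x < 0 else 4
n = 4000
X = [random.gauss(0, 1) for _ in range(n)]
D = [random.randint(0, 1) for _ in range(n)]
Y = [(1 if x < 0 else 4) * d + x + random.gauss(0, 1) for x, d in zip(X, D)]

def leaf_effect(idx):
    """Local difference in means: treated mean minus control mean in a leaf."""
    t = [Y[i] for i in idx if D[i] == 1]
    c = [Y[i] for i in idx if D[i] == 0]
    return statistics.mean(t) - statistics.mean(c)

# Step 1 (structure half): pick the cut that maximizes the gap in
# treatment effects between the two children -- the causal split criterion.
structure = list(range(n // 2))
holdout = list(range(n // 2, n))
cuts = [-1.0, -0.5, 0.0, 0.5, 1.0]
cut = max(cuts, key=lambda c: abs(
    leaf_effect([i for i in structure if X[i] < c]) -
    leaf_effect([i for i in structure if X[i] >= c])))

# Step 2 (estimation half): honest effect estimates on held-out data.
tau_left = leaf_effect([i for i in holdout if X[i] < cut])
tau_right = leaf_effect([i for i in holdout if X[i] >= cut])
print(cut, round(tau_left, 2), round(tau_right, 2))
```

With enough data the chosen cut lands at the true break and the two leaf estimates recover the two effect levels; a forest averages many such trees, each built on its own subsample.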
Effect-splitting
In a standard random forest, the splitting criterion is variance reduction in \(Y\): find the split that makes \(Y\) most predictable within each child node. In a causal forest, the splitting criterion is treatment effect heterogeneity: find the split such that the treatment effect in the left child differs maximally from the treatment effect in the right child.
This is what makes it a causal forest rather than a predictive one. The algorithm is optimized for finding the covariates that best explain variation in \(\tau(x)\), not variation in \(Y\).
Honesty
A critical innovation. Each tree uses one subsample to determine the splits (the tree structure) and a separate subsample to estimate the treatment effects within leaves. This prevents overfitting: the splits are chosen to find heterogeneity, but the effect estimates are uncontaminated by this search.
This is the same logic as cross-fitting in DML — separate the model selection step from the estimation step. If you use the same data for both, the effect estimates are biased toward finding heterogeneity even when none exists.
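The overfitting that honesty prevents is easy to demonstrate. In the simulation below there is no heterogeneity at all, yet the split chosen to maximize the apparent effect gap looks convincing when re-evaluated on the same data, and shrinks on a held-out half. (The cut points and sample sizes are illustrative choices, not the forest algorithm itself.)

```python
import random
import statistics

random.seed(1)

def effect(Y, D, idx):
    """Difference in means between treated and control units in idx."""
    t = [Y[i] for i in idx if D[i] == 1]
    c = [Y[i] for i in idx if D[i] == 0]
    return statistics.mean(t) - statistics.mean(c)

dishonest, honest = [], []
for _ in range(200):
    n = 400
    X = [random.gauss(0, 1) for _ in range(n)]
    D = [random.randint(0, 1) for _ in range(n)]
    Y = [2 * d + random.gauss(0, 1) for d in D]   # constant effect: tau(x) = 2
    first, second = list(range(n // 2)), list(range(n // 2, n))

    def gap(idx, c):
        """Apparent effect difference between the two children of cut c."""
        return effect(Y, D, [i for i in idx if X[i] < c]) - \
               effect(Y, D, [i for i in idx if X[i] >= c])

    best = max([-1.0, -0.5, 0.0, 0.5, 1.0], key=lambda c: abs(gap(first, c)))
    dishonest.append(abs(gap(first, best)))   # same data chose the cut: inflated
    honest.append(abs(gap(second, best)))     # fresh data: no inflation

print(round(statistics.mean(dishonest), 2), round(statistics.mean(honest), 2))
```

The first average is systematically larger than the second, purely because the cut was selected to maximize it; honesty is what removes that selection bias from the leaf estimates.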
Inference
Causal forests come with pointwise, asymptotically valid confidence intervals for \(\hat{\tau}(x)\). The variance is estimated via the infinitesimal jackknife (a close relative of the bootstrap):
\[ \hat{\tau}(x) \pm z_{\alpha/2} \cdot \hat{\sigma}(x) \]
This is remarkable for a machine learning method: valid frequentist inference on a nonparametric estimand. The key is that the forest targets a local parameter (the CATE at a point), and the honesty + subsampling structure ensures regularity conditions hold.
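Mechanically, turning a point estimate and a jackknife standard error into an interval is just the formula above. A sketch, with made-up numbers standing in for \(\hat{\tau}(x)\) and \(\hat{\sigma}(x)\):

```python
from statistics import NormalDist

# Hypothetical forest output at one point x (both numbers are made up):
tau_hat, sigma_hat = 2.7, 0.45
alpha = 0.05
z = NormalDist().inv_cdf(1 - alpha / 2)   # two-sided critical value, ~1.96
lo, hi = tau_hat - z * sigma_hat, tau_hat + z * sigma_hat
print(f"95% CI for tau(x): [{lo:.2f}, {hi:.2f}]")   # [1.82, 3.58]
```

Note the interval is for the CATE at one point \(x\); intervals at different \(x\) values are separate pointwise statements, not a simultaneous band.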
#| standalone: true
#| viewerHeight: 680
library(shiny)
ui <- fluidPage(
tags$head(tags$style(HTML("
.stats-box {
background: #f0f4f8; border-radius: 6px; padding: 14px;
margin-top: 12px; font-size: 14px; line-height: 1.9;
}
.stats-box b { color: #2c3e50; }
.good { color: #27ae60; font-weight: bold; }
.bad { color: #e74c3c; font-weight: bold; }
"))),
sidebarLayout(
sidebarPanel(
width = 3,
sliderInput("n", "Sample size (n):",
min = 200, max = 2000, value = 500, step = 200),
selectInput("pattern", "Heterogeneity pattern:",
choices = c(
"Linear: \u03C4(x) = 2 + 3x" = "linear",
"Step: \u03C4(x) = 1 if x<0, 4 if x\u22650" = "step",
"No heterogeneity: \u03C4(x) = 2" = "constant"
)),
actionButton("go", "New draw", class = "btn-primary", width = "100%"),
uiOutput("results")
),
mainPanel(
width = 9,
fluidRow(
column(6, plotOutput("scatter_plot", height = "430px")),
column(6, plotOutput("cate_plot", height = "430px"))
)
)
)
)
server <- function(input, output, session) {
dat <- reactive({
input$go
n <- input$n
pattern <- input$pattern
# Generate data (randomized experiment)
X <- rnorm(n)
D <- rbinom(n, 1, 0.5)
# True CATE function
tau_fn <- switch(pattern,
linear = function(x) 2 + 3 * x,
step = function(x) ifelse(x < 0, 1, 4),
constant = function(x) rep(2, length(x))
)
tau_x <- tau_fn(X)
Y <- tau_x * D + X + rnorm(n)
# Binned CATE estimation
n_bins <- 10
breaks <- quantile(X, probs = seq(0, 1, length.out = n_bins + 1))
breaks[1] <- breaks[1] - 0.01
breaks[n_bins + 1] <- breaks[n_bins + 1] + 0.01
bins <- cut(X, breaks = breaks, labels = FALSE)
bin_mids <- numeric(n_bins)
bin_ests <- numeric(n_bins)
bin_valid <- logical(n_bins)
for (j in 1:n_bins) {
idx <- bins == j
bin_mids[j] <- mean(X[idx])
t_idx <- idx & D == 1
c_idx <- idx & D == 0
if (sum(t_idx) > 1 && sum(c_idx) > 1) {
bin_ests[j] <- mean(Y[t_idx]) - mean(Y[c_idx])
bin_valid[j] <- TRUE
} else {
bin_ests[j] <- NA
bin_valid[j] <- FALSE
}
}
# ATE: true and estimated
true_ate <- mean(tau_x)
est_ate <- mean(Y[D == 1]) - mean(Y[D == 0])
# For plotting the true curve
x_seq <- seq(min(X), max(X), length.out = 200)
tau_seq <- tau_fn(x_seq)
list(X = X, Y = Y, D = D, tau_x = tau_x,
x_seq = x_seq, tau_seq = tau_seq,
bin_mids = bin_mids, bin_ests = bin_ests, bin_valid = bin_valid,
true_ate = true_ate, est_ate = est_ate,
pattern = pattern, n = n, tau_fn = tau_fn)
})
output$scatter_plot <- renderPlot({
d <- dat()
par(mar = c(4.5, 4.5, 3, 1))
plot(d$X, d$Y, pch = 16, cex = 0.4,
col = ifelse(d$D == 1, "#3498db40", "#e74c3c40"),
xlab = "X (covariate)", ylab = "Y (outcome)",
main = "Observed Data")
# True tau(x) curve
lines(d$x_seq, d$tau_seq, col = "#2c3e50", lwd = 2.5)
legend("topleft", bty = "n", cex = 0.8,
legend = c("Treated", "Control", expression("True " * tau * "(x)")),
col = c("#3498db", "#e74c3c", "#2c3e50"),
pch = c(16, 16, NA), lwd = c(NA, NA, 2.5))
})
output$cate_plot <- renderPlot({
d <- dat()
par(mar = c(4.5, 4.5, 3, 1))
valid <- d$bin_valid
ylim <- range(c(d$tau_seq, d$bin_ests[valid]), na.rm = TRUE)
pad <- diff(ylim) * 0.15
ylim <- ylim + c(-pad, pad)
plot(NULL, xlim = range(d$x_seq), ylim = ylim,
xlab = "X (covariate)",
ylab = expression("Treatment effect " * tau * "(x)"),
main = "CATE: Binned Estimates vs Truth")
# True curve
lines(d$x_seq, d$tau_seq, col = "#2c3e50", lwd = 2.5)
# Binned estimates
points(d$bin_mids[valid], d$bin_ests[valid],
pch = 19, cex = 1.6, col = "#9b59b6")
# Connect with segments for clarity
segments(d$bin_mids[valid], d$tau_fn(d$bin_mids[valid]),
d$bin_mids[valid], d$bin_ests[valid],
col = "#9b59b680", lty = 2, lwd = 1.5)
legend("topleft", bty = "n", cex = 0.8,
legend = c(expression("True " * tau * "(x)"), "Binned estimate"),
col = c("#2c3e50", "#9b59b6"),
lwd = c(2.5, NA), pch = c(NA, 19))
})
output$results <- renderUI({
d <- dat()
pat_label <- switch(d$pattern,
linear = "Linear: \u03C4(x) = 2 + 3x",
step = "Step: \u03C4(x) = 1 if x<0, 4 if x\u22650",
constant = "No heterogeneity: \u03C4(x) = 2"
)
tags$div(class = "stats-box",
HTML(paste0(
"<b>Pattern:</b> ", pat_label, "<br>",
"<b>n:</b> ", d$n, "<br>",
"<hr style='margin:6px 0'>",
"<b>True ATE:</b> ", round(d$true_ate, 3), "<br>",
"<b>Estimated ATE:</b> ", round(d$est_ate, 3), "<br>",
"<b>Bias:</b> ", round(d$est_ate - d$true_ate, 3)
))
)
})
}
shinyApp(ui, server)
Things to try
- Linear heterogeneity: the treatment effect increases with \(X\). The binned estimates trace out the linear relationship. With larger \(n\), the estimates get tighter around the truth.
- Step function: a sharp change at \(X = 0\). The binned estimator smooths over this discontinuity — bins near the cutoff mix units from both regimes. More data and finer bins would help.
- No heterogeneity: all bins should estimate roughly the same effect. If they look scattered, that’s sampling noise — there is no real heterogeneity to find.
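For readers who want the binned estimator outside the app (which is written in R/Shiny), here is a stand-alone Python sketch. The DGP mirrors the "linear" pattern; the decile binning and sample size are illustrative:

```python
import random
import statistics

random.seed(2)
n = 2000
X = [random.gauss(0, 1) for _ in range(n)]
D = [random.randint(0, 1) for _ in range(n)]
# Linear pattern from the app: tau(x) = 2 + 3x, plus a main effect of X
Y = [(2 + 3 * x) * d + x + random.gauss(0, 1) for x, d in zip(X, D)]

# Decile bins on X; difference in means within each bin
edges = statistics.quantiles(X, n=10)              # 9 interior cut points
bin_of = lambda x: sum(x >= e for e in edges)      # bin index 0..9

est = []
for b in range(10):
    idx = [i for i in range(n) if bin_of(X[i]) == b]
    t = [Y[i] for i in idx if D[i] == 1]
    c = [Y[i] for i in idx if D[i] == 0]
    mid = statistics.mean(X[i] for i in idx)
    tau = statistics.mean(t) - statistics.mean(c)
    est.append((mid, tau))

for mid, tau in est:
    print(f"x ~ {mid:+.2f}   tau_hat = {tau:+.2f}   truth = {2 + 3 * mid:+.2f}")
```

The binned estimates trace the upward-sloping truth from strongly negative effects in the bottom decile to strongly positive ones at the top, with noise shrinking as bins get more observations.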
Meta-learners
Causal forests are one approach to CATE estimation. A broader framework organizes methods by how they decompose the problem into standard supervised learning tasks:
T-learner
Estimate two separate outcome models — one for treated, one for control — and take the difference:
\[ \hat{\tau}(x) = \hat{\mu}_1(x) - \hat{\mu}_0(x) \]
where \(\hat{\mu}_1(x) = \hat{E}[Y \mid X = x, D = 1]\) and \(\hat{\mu}_0(x) = \hat{E}[Y \mid X = x, D = 0]\).
Simple and works with any ML method. But the two models are fit independently — each optimizes prediction of \(Y\), not estimation of the difference. Small treatment effects can be lost in the noise of predicting outcomes.
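A minimal T-learner, using simple linear regression as the "ML method" in each arm (any learner could be substituted); the linear DGP is an illustrative assumption:

```python
import random
import statistics

def ols(xs, ys):
    """Least-squares fit of y ~ a + b*x; returns (a, b)."""
    mx, my = statistics.mean(xs), statistics.mean(ys)
    b = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / \
        sum((x - mx) ** 2 for x in xs)
    return my - b * mx, b

random.seed(3)
n = 2000
X = [random.gauss(0, 1) for _ in range(n)]
D = [random.randint(0, 1) for _ in range(n)]
Y = [x + (2 + 3 * x) * d + random.gauss(0, 1) for x, d in zip(X, D)]

# Two separate outcome models, one per treatment arm
a1, b1 = ols([x for x, d in zip(X, D) if d == 1],
             [y for y, d in zip(Y, D) if d == 1])
a0, b0 = ols([x for x, d in zip(X, D) if d == 0],
             [y for y, d in zip(Y, D) if d == 0])
tau_hat = lambda x: (a1 + b1 * x) - (a0 + b0 * x)   # difference of predictions
print(round(tau_hat(0.0), 2), round(tau_hat(1.0), 2))   # truth: 2 and 5
```

Each arm's model is fit to predict \(Y\) well in that arm alone; nothing in either fit targets the difference, which is exactly the weakness noted above.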
S-learner
Estimate a single model with treatment as a feature:
\[ \hat{\mu}(x, d) = \hat{E}[Y \mid X = x, D = d] \]
Then \(\hat{\tau}(x) = \hat{\mu}(x, 1) - \hat{\mu}(x, 0)\). The model can capture treatment effect heterogeneity through interactions between \(D\) and \(X\).
Risk: the model may not prioritize learning the treatment effect if the main effects of \(X\) dominate. With many strong predictors and a small treatment effect, the model may effectively ignore \(D\).
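The risk is concrete: if the single model has no \(D \times X\) interaction, \(\hat{\tau}(x)\) is constant by construction, whatever the data say. A sketch (the step-function DGP and the deliberately additive design are illustrative assumptions):

```python
import random

def solve3(A, b):
    """Gauss-Jordan elimination for a 3x3 linear system (pivots are safe here)."""
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for i in range(3):
        M[i] = [v / M[i][i] for v in M[i]]
        for j in range(3):
            if j != i:
                f = M[j][i]
                M[j] = [vj - f * vi for vj, vi in zip(M[j], M[i])]
    return [M[i][3] for i in range(3)]

random.seed(4)
n = 2000
X = [random.gauss(0, 1) for _ in range(n)]
D = [random.randint(0, 1) for _ in range(n)]
# True effect is a step: tau(x) = 1 if x < 0 else 4 (average 2.5)
Y = [x + (1 if x < 0 else 4) * d + random.gauss(0, 1) for x, d in zip(X, D)]

# S-learner with additive features [1, x, d] -- no d*x interaction
feats = [[1.0, x, float(d)] for x, d in zip(X, D)]
XtX = [[sum(f[i] * f[j] for f in feats) for j in range(3)] for i in range(3)]
Xty = [sum(f[i] * y for f, y in zip(feats, Y)) for i in range(3)]
a, b, c = solve3(XtX, Xty)

# tau_hat(x) = mu_hat(x, 1) - mu_hat(x, 0) = c for EVERY x:
# the step heterogeneity is invisible to this specification
print(round(c, 2))   # close to the average effect, 2.5
```

Flexible learners with interactions avoid this failure mode in principle, but may still underweight \(D\) when strong main effects of \(X\) dominate the loss.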
R-learner (Robinson, 1988; Nie and Wager, 2021)
Residualize both the outcome and treatment on \(X\) (as in the partially linear model from DML):
\[ \tilde{Y}_i = Y_i - \hat{m}(X_i), \qquad \tilde{D}_i = D_i - \hat{e}(X_i) \]
Then estimate \(\tau(x)\) by minimizing:
\[ \hat{\tau}(\cdot) = \arg\min_\tau \sum_{i=1}^n \left(\tilde{Y}_i - \tau(X_i)\tilde{D}_i\right)^2 \]
This is Frisch-Waugh-Lovell generalized to heterogeneous effects. By partialling out the main effects of \(X\) on both \(Y\) and \(D\), the R-learner isolates the treatment effect signal. It directly targets \(\tau(x)\) rather than backing it out from outcome predictions.
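A sketch of the R-learner in a randomized setting, where \(e(x) = 0.5\) is known and \(m(x)\) happens to be linear, so plain OLS serves for the nuisances and a linear \(\tau(x) = a + bx\) for the CATE model. In observational data both nuisances would be flexible ML fits, cross-fit as in DML; the DGP here is made up:

```python
import random
import statistics

def ols(xs, ys):
    """Least-squares fit of y ~ a + b*x; returns (a, b)."""
    mx, my = statistics.mean(xs), statistics.mean(ys)
    b = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / \
        sum((x - mx) ** 2 for x in xs)
    return my - b * mx, b

random.seed(5)
n = 4000
X = [random.gauss(0, 1) for _ in range(n)]
D = [random.randint(0, 1) for _ in range(n)]
Y = [x + (2 + 3 * x) * d + random.gauss(0, 1) for x, d in zip(X, D)]

# Residualize: m(x) fit by OLS (it is linear in this DGP); e(x) = mean(D)
am, bm = ols(X, Y)
Yt = [y - (am + bm * x) for x, y in zip(X, Y)]
e_hat = statistics.mean(D)
Dt = [d - e_hat for d in D]

# Minimize sum_i (Yt_i - (a + b*x_i) * Dt_i)^2: least squares with
# regressors (Dt, x*Dt) and target Yt -- 2x2 normal equations
f1, f2 = Dt, [x * dt for x, dt in zip(X, Dt)]
s11 = sum(v * v for v in f1)
s12 = sum(u * v for u, v in zip(f1, f2))
s22 = sum(v * v for v in f2)
r1 = sum(u * y for u, y in zip(f1, Yt))
r2 = sum(u * y for u, y in zip(f2, Yt))
det = s11 * s22 - s12 * s12
a = (s22 * r1 - s12 * r2) / det
b = (s11 * r2 - s12 * r1) / det
print(round(a, 2), round(b, 2))   # truth: tau(x) = 2 + 3x
```

Because the main effects of \(X\) are removed from both \(Y\) and \(D\) first, the least-squares step sees only the treatment-effect signal, which is the FWL logic the text describes.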
Comparison
| Method | Models to fit | Strengths | Weaknesses |
|---|---|---|---|
| T-learner | Two outcome models | Simple, any ML method | Doesn’t optimize for \(\tau\) |
| S-learner | One outcome model | Uses all data jointly | May underweight treatment heterogeneity |
| R-learner | Two nuisance + CATE model | Directly targets \(\tau(x)\) | More complex to implement |
| Causal forest | Forest with effect-splitting | Valid inference, adaptive | Requires unconfoundedness |
Practical cautions
Discovery vs confirmation. These methods are powerful for discovering heterogeneity — finding subgroups where effects differ. But discovery is not confirmation. Best practice: use causal forests or meta-learners to identify potential heterogeneity in one sample, then confirm it in a held-out sample or a pre-registered replication. Discovery and confirmation should use different data.
Multiple testing. Reporting that “the effect is significant for subgroup A but not subgroup B” is a comparison, not a test. To claim the effects differ, you need to test the interaction — and if you searched over many subgroups to find A, you have a multiple testing problem.
Sample size. Estimating \(\tau(x)\) requires more data than estimating \(\tau\). You’re estimating a function, not a number, and you need enough treated and control units at each value of \(x\). With small samples, the CATE estimates will be noisy and shrunk toward the ATE.
Overlap. The CATE at \(x\) requires both treated and control units near \(x\). If overlap fails in some region — few treated units with \(X\) near \(x\) — the CATE estimate there is unreliable. Causal forests handle this gracefully (they simply don’t split in sparse regions), but the conceptual limitation remains.
Connecting to the course
- Heterogeneous Treatment Effects: defines the estimand \(\tau(x)\) — the conceptual foundation this page builds on
- DML: the R-learner uses DML-style residualization; causal forests use similar honesty/cross-fitting logic
- Doubly Robust: the AIPW score can be extended to target CATE instead of ATE — doubly robust CATE estimation
- FWL: the R-learner is FWL generalized to heterogeneous effects
- Multiple Testing: subgroup analysis without correction is a multiple testing problem
- Bootstrap: the jackknife variance estimator in causal forests is related to bootstrap inference