Causal Forests

Estimating treatment effects often requires moving beyond linear models, which make rather strict assumptions about the relationships in the data and treat them as uniform across all observations. Causal forests (Athey, Tibshirani, and Wager 2019), a non-parametric ensemble method, extend the principles of random forests to estimate treatment effects that vary across subgroups. This approach captures complex, localized patterns in the data. Unlike traditional decision trees, which focus on improving overall predictive accuracy, causal trees (Athey and Imbens 2016) within a causal forest are designed to split the data based on variations in treatment effects rather than variability in the observed outcome variable.

This design allows causal forests to identify subgroups with differing (average) treatment effects, uncovering complex interactions and local structures. Causal forests perform adaptive smoothing: they average treatment effects over similar data points, producing smooth estimates while guarding against overfitting.
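To make the idea of splitting on treatment effect variation concrete, here is a minimal sketch in base R. It uses a deliberately simplified criterion (the size-weighted squared gap between the difference-in-means estimates of the two child nodes), which is an assumption for illustration and not the splitting rule implemented in grf. The toy DGP with a jump in the effect at x = 0.5 and all object names are hypothetical.

```r
# Toy illustration: choose the split that best separates treatment effects
# (simplified criterion, NOT the rule used internally by grf)
set.seed(1)
n <- 2000
x <- runif(n)
w <- rbinom(n, 1, 0.5)
tau <- ifelse(x > 0.5, 3, 1)   # hypothetical effect that jumps at x = 0.5
y <- x + tau * w + rnorm(n)

# difference in mean outcomes between treated and control units
diff_means <- function(y, w) mean(y[w == 1]) - mean(y[w == 0])

# score a candidate cut point by the size-weighted squared gap
# between the estimated effects in the left and right child
split_score <- function(cut) {
  left <- x <= cut
  gap <- diff_means(y[left], w[left]) - diff_means(y[!left], w[!left])
  (sum(left) * sum(!left) / n^2) * gap^2
}

cuts <- seq(0.1, 0.9, by = 0.05)
scores <- sapply(cuts, split_score)
best_cut <- cuts[which.max(scores)]
best_cut
```

A split on the outcome alone would also react to the linear trend in x; scoring splits by the gap in estimated effects targets the heterogeneity itself.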

In the following, we demonstrate how causal forests leverage local smoothing to estimate conditional average treatment effects (CATEs) in an example with a highly non-linear dependence between two features and individual treatment effects.1

CATE estimation using CF adaptive smoothing

We begin by simulating data where the true treatment effect follows a complex non-linear function of the features and then compare different estimation approaches.

First, we load the necessary libraries and then set up our data-generating process (DGP).

library(cowplot)
library(dplyr)
library(ggplot2)
library(grf)
library(tidyr)
library(tibble)

Data Generation

Now we’ll simulate our dataset according to the DGP in eq. \eqref{eq:cfdgp}. $$ \begin{align} \begin{split} x_1 & \sim \text{U}(0, 1), \quad x_2 \sim \text{U}(0, 1), \quad \text{T} \sim \text{Bern}(0.5) \\ \tau & = 2 + 5 \cdot \left(1 - e^{-(2x_1^2 + 2x_2^2)}\right) + \sin(3\pi x_1) + \cos(3\pi x_2) + e^{-x_1 \cdot x_2} \\ \text{y} & = 2 + x_1 + x_2 + \tau \cdot \text{T} + \epsilon, \quad \epsilon \sim \mathcal{N}(0, 1) \end{split}\label{eq:cfdgp} \end{align} $$

The functional form of the relationship between the covariates $x_1$ and $x_2$ and the treatment effect $\tau$ is highly non-linear, including exponential terms, periodic components (to introduce some cyclical patterns), and an interaction term between $x_1$ and $x_2$.2

We’ll create a treatment effect function true_TE() that combines these non-linear components.

# Function for computing the true heterogeneous (individual) treatment effects
true_TE <- function(x1, x2) {
  term1 <- 2 + 5 * (1 - exp(-(2 * x1^2 + 2 * x2^2)))  # constant + exponential term
  term2 <- sin(3 * pi * x1) + cos(3 * pi * x2)        # periodic components
  term3 <- exp(-x1 * x2)                              # interaction term
  return(term1 + term2 + term3)
}
# Data generation
set.seed(123)

n <- 20000  # Sample size
x1_rand <- runif(n)
x2_rand <- runif(n)
treat <- rbinom(n, 1, 0.5) 

# true treatment effects...
TE <- true_TE(x1_rand, x2_rand)

# ... and outcomes
y <- 2 + x1_rand + x2_rand + treat * TE + rnorm(n)

# gather
df_rand <- data.frame(
  x1 = x1_rand,
  x2 = x2_rand,
  treat = treat,
  y = y
)
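As a sanity check on the DGP, the population ATE implied by the treatment effect function can be approximated numerically. The sketch below redefines true_TE() so that it runs on its own and averages it over a midpoint grid on the unit square, which relies on $x_1$ and $x_2$ being independent U(0, 1) as stated in the DGP. The grid size m and the name true_ATE are illustrative choices.

```r
# Same treatment effect function as in the text, redefined so the
# snippet is self-contained
true_TE <- function(x1, x2) {
  2 + 5 * (1 - exp(-(2 * x1^2 + 2 * x2^2))) +
    sin(3 * pi * x1) + cos(3 * pi * x2) + exp(-x1 * x2)
}

# midpoint grid on the unit square (m is an illustrative choice)
m <- 400
mid <- (seq_len(m) - 0.5) / m

# averaging tau over the grid approximates E[tau] for uniform covariates
true_ATE <- mean(outer(mid, mid, true_TE))
true_ATE
```

The value should come out near 6.2 and can serve as a benchmark for the ATE estimates.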

Model Fitting and Comparison

We’ll compare three approaches for estimating CATEs:

  1. A linear regression model with polynomial terms and interactions
  2. An “out-of-the-box” causal forest, fitted with the grf package by Tibshirani et al. (2024).3
  3. A tuned causal forest with optimized parameters4
# Prepare matrices for causal forest
X_rand <- df_rand %>% 
  select(x1, x2) %>% 
  as.matrix()

y_rand <- df_rand$y
W_rand <- df_rand$treat

# fit the CF models
cf_untuned <- causal_forest(
  X = X_rand, 
  Y = y_rand, 
  W = W_rand,
  num.trees = 4000
)

cf_tuned <- causal_forest(
  X = X_rand, 
  Y = y_rand, 
  W = W_rand,
  tune.parameters = "all",
  num.trees = 4000
)
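Before fitting on the full sample, it can be useful to try the same call on a small toy problem where the truth is known. The following sketch assumes a hypothetical linear effect function and a reduced number of trees; the object names are illustrative. It uses predict() without new data, which returns out-of-bag CATE estimates for the training observations.

```r
library(grf)

set.seed(42)
n <- 2000
X <- matrix(runif(n * 2), ncol = 2, dimnames = list(NULL, c("x1", "x2")))
W <- rbinom(n, 1, 0.5)
tau <- 1 + 2 * X[, "x1"]   # hypothetical, simple effect function
Y <- X[, "x1"] + X[, "x2"] + tau * W + rnorm(n)

# small forest with default parameters apart from the number of trees
cf_small <- causal_forest(X = X, Y = Y, W = W, num.trees = 500)

# predict() without newdata returns out-of-bag CATE estimates
tau_hat_oob <- predict(cf_small)$predictions

# agreement between estimated and true individual effects
cor_oob <- cor(tau_hat_oob, tau)
cor_oob
```

Out-of-bag predictions avoid evaluating each tree on observations it was trained on, so they give a more honest picture of fit than in-sample predictions would.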

We can compute the estimated average treatment effect (ATE) using grf::average_treatment_effect().

# ATE for untuned causal forest
average_treatment_effect(cf_untuned)
  estimate    std.err 
6.17246241 0.01832128 
# ATE for tuned causal forest
average_treatment_effect(cf_tuned)
  estimate    std.err 
6.17160020 0.01798565 

The estimated Average Treatment Effect (ATE) is approximately 6.17. The grf package leverages asymptotic theory from Athey, Tibshirani, and Wager (2019) to provide standard errors for these estimates, enabling (asymptotically normal) inference about treatment effects. This theoretical foundation is a key advantage of causal forests over many other machine learning methods.
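The reported standard error translates directly into a confidence interval. As a minimal sketch, the snippet below plugs in the estimate and standard error printed above for the untuned forest and forms a two-sided 95% Wald interval, which relies on the asymptotic normality mentioned in the text. The variable names are illustrative.

```r
# point estimate and standard error as printed above (untuned forest)
est <- 6.17246241
se  <- 0.01832128

# two-sided 95% Wald interval based on asymptotic normality
z <- qnorm(0.975)
ci <- c(lower = est - z * se, upper = est + z * se)
round(ci, 3)
```

With a fitted forest at hand, the same two numbers are the elements estimate and std.err of the vector returned by average_treatment_effect().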

We proceed by fitting a linear regression model that interacts polynomial terms in $x_1$ and $x_2$ with the treatment indicator $\text{treat}$.

# fit the linear model
linear_model <- lm(
  formula = 
  y ~ (
        x1 + x2 + I(x1^2) + I(x2^2) + I(x1^3) + I(x2^3) 
        + x1:x2 + I(x1^2):x2 + x1:I(x2^2) + I(x1^3):I(x2^3)
      ) * treat, 
  data = df_rand
)

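In the regression model, we read the treatment effect off the estimated coefficient of treat. Because treat is interacted with uncentered covariate terms, this coefficient is, strictly speaking, the estimated effect at $x_1 = x_2 = 0$; it matches the sample-average effect when each covariate term is centered at its sample mean. With this caveat in mind, we use it as a simple ATE estimate,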

$$ \widehat{\tau}_{\textup{ATE}} = \widehat{\beta}_{\textup{treat}}. $$

# ATE with linear regression
summary(linear_model)$coefficients["treat", ]
     Estimate    Std. Error       t value      Pr(>|t|) 
 5.872876e+00  1.591547e-01  3.690042e+01 2.055669e-288 

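This estimate (about 5.87) is in the same range as the causal forest ATEs (about 6.17), but its standard error (about 0.16) is nearly nine times larger than those of the forests (about 0.018).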
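When treat is interacted with covariates, the coefficient of treat and the ATE need not coincide. The following sketch illustrates this on a hypothetical one-covariate DGP, chosen so that the effect is 1 at x = 0 and 2 on average. It compares the raw coefficient, the sample average of predicted differences, and the coefficient after centering the covariate. All names are illustrative.

```r
set.seed(1)
n <- 5000
x <- runif(n)
treat <- rbinom(n, 1, 0.5)
tau <- 1 + 2 * x   # hypothetical effect: 1 at x = 0, ATE = 2
y <- x + tau * treat + rnorm(n)
d <- data.frame(x = x, treat = treat, y = y)

fit <- lm(y ~ x * treat, data = d)

# coefficient of treat: estimated effect at x = 0, not the ATE
effect_at_zero <- unname(coef(fit)["treat"])

# ATE: sample average of predicted y(treat = 1) - y(treat = 0)
ate_avg <- mean(
  predict(fit, newdata = transform(d, treat = 1)) -
    predict(fit, newdata = transform(d, treat = 0))
)

# equivalent: center x first, then the coefficient of treat is the ATE
d$xc <- d$x - mean(d$x)
fit_c <- lm(y ~ xc * treat, data = d)
ate_centered <- unname(coef(fit_c)["treat"])

c(effect_at_zero = effect_at_zero, ate_avg = ate_avg, ate_centered = ate_centered)
```

The averaging approach carries over unchanged to richer specifications such as the polynomial model, since it only needs predict() under treat = 1 and treat = 0.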

Now that we’ve examined the Average Treatment Effects, let’s dive deeper into how these models perform when estimating CATEs across different regions of our feature space. We’ll create a detailed visualization comparing CATE predictions over a fine grid spanning the domain of our covariates $x_1$ and $x_2$. This will help us understand how each model captures the underlying heterogeneity in treatment effects.

CATE predictions

To compare the reliability of each approach in terms of CATE estimation, we’ll create a grid of predictions and visualize the estimated treatment effects. For the causal forest models, we get CATEs using the corresponding predict() method. For the linear model, we compute the CATE estimate at a point $(x_1, x_2)'$ as

$$ \widehat{\tau}_{\text{CATE}}(x_1, x_2) = \widehat{f}(x_1, x_2 \mid \text{treat} = 1) - \widehat{f}(x_1, x_2 \mid \text{treat} = 0), $$

where $\widehat{f}(x_1, x_2 \mid \text{treat})$ denotes the estimated regression function.

# Create prediction grid
# grid values for covariates (x1, x2); reused below for the true-effect plot
x1 <- seq(0, 1, length.out = 100)
x2 <- x1

# all 100 x 100 combinations of (x1, x2)
grid_data <- expand.grid(x1 = x1, x2 = x2)
X_grid <- as.matrix(grid_data)

# Generate TE estimates over grid
## CATE with tuned CF
grid_data$pred_CATE_tuned <- predict(cf_tuned, X_grid)$predictions

## CATE with untuned CF
grid_data$pred_CATE_untuned <- predict(cf_untuned, X_grid)$predictions

## CATE with linear model: prediction under treat = 1 minus prediction under treat = 0
grid_data$pred_CATE_lm <- 
    predict(
        linear_model, 
        newdata = data.frame(
            x1 = grid_data$x1,
            x2 = grid_data$x2,
            treat = 1
        )
    ) -
    predict(
        linear_model, 
        newdata = data.frame(
            x1 = grid_data$x1,
            x2 = grid_data$x2,
            treat = 0
        )
    )
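Point predictions can be complemented with pointwise uncertainty. The predict() method for causal forests accepts estimate.variance = TRUE and then also returns variance estimates. The sketch below applies this to a hypothetical toy DGP with a small evaluation grid (all names are illustrative) and builds approximate 95% intervals under a normal approximation.

```r
library(grf)

set.seed(7)
n <- 1000
X <- matrix(runif(n * 2), ncol = 2)
W <- rbinom(n, 1, 0.5)
Y <- X[, 1] + (1 + 2 * X[, 1]) * W + rnorm(n)   # hypothetical toy DGP

cf_toy <- causal_forest(X = X, Y = Y, W = W, num.trees = 1000)

# small evaluation grid (5 x 5 points), stripped of dimnames to match X
X_new <- unname(as.matrix(
  expand.grid(seq(0.1, 0.9, by = 0.2), seq(0.1, 0.9, by = 0.2))
))

# request variance estimates along with the CATE predictions
pred <- predict(cf_toy, X_new, estimate.variance = TRUE)
se_hat <- sqrt(pred$variance.estimates)

cate_ci <- data.frame(
  x1 = X_new[, 1],
  x2 = X_new[, 2],
  cate = pred$predictions,
  lower = pred$predictions - qnorm(0.975) * se_hat,
  upper = pred$predictions + qnorm(0.975) * se_hat
)
head(cate_ci)
```

The same call pattern should work for cf_tuned and X_grid from the text, at the cost of some extra computation.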

We next visualize the true effects in a hexbin plot with ggplot2 and compare them against the treatment effect predictions of the models, gathering the plots in a cowplot::plot_grid().

# Plot the true treatment effects
p_truth <- 
  outer(
    X = x1,
    Y = x2,
    FUN = true_TE
  ) %>%
  as_tibble() %>%
  rownames_to_column("x1") %>%
  pivot_longer(cols = -x1, names_to = "x2", values_to = "value") %>%
  mutate(
    # map row/column indices 1..100 back to the grid values seq(0, 1, length.out = 100)
    x1 = (as.numeric(x1) - 1) / 99,
    x2 = (as.numeric(gsub("V", "", x2)) - 1) / 99
  ) %>%

  ggplot(aes(x1, x2, fill = value)) +
  geom_hex(stat = "identity") +
  scale_fill_viridis_c(
    name = "TE", 
    option = "mako",     
    guide = guide_colorbar(
      barheight = unit(6, "inch"), 
      legend.position = "right", 
      legend.direction = "vertical"
    )
  ) +
  labs(title = "True Effect") +
  theme_minimal() +
  theme(
    axis.title.y = element_blank(), 
    axis.title.x = element_blank()
  )

# Plot function for est. effects to avoid code repetition
plot_CATE <- function(data, CATE_col, title, legendpos = "none") {
  ggplot(data, aes(x1, x2, fill = !!sym(CATE_col))) +
    geom_tile() +
    scale_fill_viridis_c(option = "mako") +
    labs(title = title) +
    theme_minimal() +
    theme(
      legend.position = legendpos, 
      axis.title.y = element_blank(), 
      axis.title.x = element_blank()
    )
}

# Plot predicted CATEs with and without tuning, and from the linear model
p_CF_untuned <- plot_CATE(
  data = grid_data, 
  "pred_CATE_untuned", 
  "CF: Est. CATE without Tuning"
)
p_CF_tuned <- plot_CATE(
  data = grid_data, 
  "pred_CATE_tuned", 
  "CF: Est. CATE with Tuning"
)
p_lm <- plot_CATE(
  data = grid_data, 
  "pred_CATE_lm", 
  "LM: Est. CATE"
)
# Gather all plots in a plot_grid
ggdraw(
  plot_grid(
    plot_grid(
      plotlist = list(
        p_truth + theme(legend.position = "none"), p_lm,
        p_CF_untuned, p_CF_tuned
      ), ncol = 2,  align = "hv"
    ), 
    NULL, get_legend(p_truth), 
    ncol = 3, rel_widths = c(1, .025, .075)
  )
)

Figure 1: Truth vs. predicted CATEs

Figure 1 reveals several key properties of the methods:

The true treatment effect surface (top-left) shows complex non-linear patterns and interactions that are difficult to specify correctly in a parametric approach like linear regression (top-right). Clearly, including polynomial terms and their interactions in an attempt to account for unknown non-linearities (which are often present in empirical applications) fails to capture the complexity of the true treatment effect surface.

The untuned causal forest (bottom-left) demonstrates its ability to approximate the true treatment effect surface by capturing the overall structure of complex non-linear patterns and interactions. This performance arises from the core mechanism of causal forests: combining ensemble information from many causal trees, each of which partitions the predictor space ($x_1$, $x_2$) to identify regions with homogeneous treatment effects. By leveraging these partitions, the causal forest effectively estimates heterogeneous treatment effects while averaging over similar observations locally.

While the predictions show some granularity5, the ensemble approach ensures that the model identifies key regions of variation in the true effects, offering a good baseline for further refinement. The causal forest’s ability to perform adaptive smoothing—averaging treatment effects across observations within each partition—results in approximations of the true effect surface that are both flexible and robust.

The tuned causal forest (bottom-right) slightly refines the partitions, adjusting the smoothing process based on cross-validated parameter selection. This may yield predictions with somewhat less noise, showing how the tuning process optimizes the balance between capturing fine-grained local effects and avoiding overfitting.
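The visual comparison in Figure 1 can be complemented by a numeric summary such as the root mean squared error (RMSE) between predicted and true effects on a grid. The sketch below is self-contained and uses base R only: it re-simulates the DGP with a smaller sample to keep it fast, refits the polynomial linear model, and evaluates its CATE predictions. The helper rmse(), the grid size, and the reduced n are illustrative choices, not part of the original analysis.

```r
set.seed(123)

# treatment effect function from the text, redefined for self-containment
true_TE <- function(x1, x2) {
  2 + 5 * (1 - exp(-(2 * x1^2 + 2 * x2^2))) +
    sin(3 * pi * x1) + cos(3 * pi * x2) + exp(-x1 * x2)
}

n <- 5000   # smaller than in the text to keep the sketch fast
d <- data.frame(x1 = runif(n), x2 = runif(n), treat = rbinom(n, 1, 0.5))
d$y <- 2 + d$x1 + d$x2 + d$treat * true_TE(d$x1, d$x2) + rnorm(n)

# same polynomial specification as the linear model in the text
fit <- lm(
  y ~ (x1 + x2 + I(x1^2) + I(x2^2) + I(x1^3) + I(x2^3)
       + x1:x2 + I(x1^2):x2 + x1:I(x2^2) + I(x1^3):I(x2^3)) * treat,
  data = d
)

# CATE on a grid: prediction under treat = 1 minus prediction under treat = 0
grid <- expand.grid(
  x1 = seq(0, 1, length.out = 50),
  x2 = seq(0, 1, length.out = 50)
)
cate_lm <- predict(fit, newdata = transform(grid, treat = 1)) -
  predict(fit, newdata = transform(grid, treat = 0))

# root mean squared error between estimated and true effects
rmse <- function(est, truth) sqrt(mean((est - truth)^2))
rmse_lm <- rmse(cate_lm, true_TE(grid$x1, grid$x2))
rmse_lm
```

Applied to the objects from the text, the same helper can score the columns pred_CATE_lm, pred_CATE_untuned, and pred_CATE_tuned of grid_data against true_TE(grid_data$x1, grid_data$x2).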

Conclusion

Accurate estimation of Conditional Average Treatment Effects (CATEs) is crucial for identifying subgroups that would benefit most from targeted interventions. For instance, during the COVID-19 pandemic, vaccination awareness campaigns had varying impacts across different populations, influenced by factors such as education levels, healthcare access, and cultural attitudes toward vaccines.6 With reliable CATE estimates, policymakers can allocate resources more effectively, ensuring that public health initiatives reach and resonate with the communities that need them most. This targeted approach may not only enhance overall health outcomes but also address disparities, promoting equity in healthcare delivery.

References

Athey, Susan, and Guido Imbens. 2016. “Recursive Partitioning for Heterogeneous Causal Effects.” Proceedings of the National Academy of Sciences 113 (27): 7353–60. https://doi.org/10.1073/pnas.1510489113.

Athey, Susan, Julie Tibshirani, and Stefan Wager. 2019. “Generalized Random Forests.” The Annals of Statistics 47 (2). https://doi.org/10.1214/18-aos1709.

Tibshirani, Julie, Susan Athey, Erik Sverdrup, and Stefan Wager. 2024. Grf: Generalized Random Forests. https://CRAN.R-project.org/package=grf.


  1. Empirical applications of causal forests often involve datasets with many covariates. For simplicity, this example is limited to two covariates to better illustrate the data-generating process (DGP).

  2. For notational simplicity, we omit observation indices here. Note, though, that $\tau$ represents an individual treatment effect (ITE).

  3. Here we simply use the default parameters of grf::causal_forest().

  4. The grf package uses a cross-validation procedure to tune model parameters. See ?grf::causal_forest.

  5. ‘Granularity’ here means that the model is picking up some local structure in the training data that is due to randomness, i.e., the model is slightly overfitting.

  6. See, for example, https://www.cdc.gov/pcd/issues/2020/20_0245.htm