Bayesian Shrinkage

What is shrinkage?

Imagine you’re a baseball scout. It’s early in the season and you need to estimate the true batting average for 50 players, each with only 20 at-bats.

One player went 10-for-20 (.500). Another went 1-for-20 (.050). Are those their true abilities? Probably not — with only 20 at-bats, there’s a ton of noise. The .500 hitter probably got lucky. The .050 hitter probably got unlucky.

Shrinkage says: don’t take the raw numbers at face value. Pull (“shrink”) every estimate toward the overall average. The more uncertain you are about an individual estimate, the more you pull.

\[\hat{\theta}_i^{\text{shrunk}} = w_i \cdot \bar{\theta}_{\text{overall}} + (1 - w_i) \cdot \hat{\theta}_i^{\text{raw}}\]

This is the core of empirical Bayes and James-Stein estimation. It sounds like you’re adding bias — and you are — but you’re reducing variance by more than enough to compensate. The result: better predictions overall.
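To make the formula concrete, here is a minimal R sketch that shrinks the two players above toward a league-wide average of .260. The weight of 0.8 is an assumed value chosen only for illustration; the app below estimates the weight from the data itself.

# Two players from the example: 10-for-20 and 1-for-20
raw    <- c(10/20, 1/20)   # raw averages: .500 and .050
league <- 0.260            # overall mean to shrink toward
w      <- 0.8              # assumed shrinkage weight, for illustration only

shrunk <- w * league + (1 - w) * raw
round(shrunk, 3)           # 0.308 and 0.218: both pulled strongly toward .260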

#| standalone: true
#| viewerHeight: 620

library(shiny)

ui <- fluidPage(
  tags$head(tags$style(HTML("
    .stats-box {
      background: #f0f4f8; border-radius: 6px; padding: 14px;
      margin-top: 12px; font-size: 14px; line-height: 1.9;
    }
    .stats-box b { color: #2c3e50; }
    .good { color: #27ae60; font-weight: bold; }
    .bad  { color: #e74c3c; font-weight: bold; }
  "))),

  sidebarLayout(
    sidebarPanel(
      width = 3,

      sliderInput("n_players", "Number of players:",
                  min = 10, max = 100, value = 40, step = 5),

      sliderInput("at_bats", "At-bats per player:",
                  min = 5, max = 200, value = 20, step = 5),

      sliderInput("true_spread", "True talent spread (SD):",
                  min = 0.01, max = 0.08, value = 0.03, step = 0.005),

      actionButton("go", "New season", class = "btn-primary", width = "100%"),

      uiOutput("results")
    ),

    mainPanel(
      width = 9,
      fluidRow(
        column(6, plotOutput("shrinkage_plot", height = "420px")),
        column(6, plotOutput("mse_plot", height = "420px"))
      )
    )
  )
)

server <- function(input, output, session) {

  dat <- reactive({
    input$go
    k   <- input$n_players
    n   <- input$at_bats
    tau <- input$true_spread

    # True batting averages (centered around .260)
    true_avg <- rnorm(k, mean = 0.260, sd = tau)
    true_avg <- pmin(pmax(true_avg, 0.100), 0.400)

    # Observed: hits in n at-bats
    hits <- rbinom(k, size = n, prob = true_avg)
    obs_avg <- hits / n

    # Grand mean
    grand_mean <- mean(obs_avg)

    # Empirical Bayes shrinkage
    # Estimate prior variance from data
    obs_var <- var(obs_avg)
    sampling_var <- mean(obs_avg * (1 - obs_avg) / n)
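    # On average, var(obs_avg) = var(true ability) + average sampling variance,
    # so subtracting the sampling variance gives a method-of-moments estimate
    # of the spread in true ability (floored at a small positive value below)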
    prior_var <- max(obs_var - sampling_var, 0.0001)

    # Shrinkage weight (toward grand mean)
    w <- sampling_var / (sampling_var + prior_var)
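    # w is near 1 when sampling noise swamps the true spread (heavy shrinkage)
    # and near 0 when players genuinely differ (the data mostly speak for themselves)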
    shrunk_avg <- w * grand_mean + (1 - w) * obs_avg

    # MSE
    mse_raw   <- mean((obs_avg - true_avg)^2)
    mse_shrunk <- mean((shrunk_avg - true_avg)^2)

    # Future performance (another n at-bats from true ability)
    future_hits <- rbinom(k, size = n, prob = true_avg)
    future_avg  <- future_hits / n

    pred_err_raw   <- mean((obs_avg - future_avg)^2)
    pred_err_shrunk <- mean((shrunk_avg - future_avg)^2)

    list(true_avg = true_avg, obs_avg = obs_avg, shrunk_avg = shrunk_avg,
         grand_mean = grand_mean, w = w,
         mse_raw = mse_raw, mse_shrunk = mse_shrunk,
         pred_err_raw = pred_err_raw, pred_err_shrunk = pred_err_shrunk,
         k = k, n = n)
  })

  output$shrinkage_plot <- renderPlot({
    d <- dat()
    par(mar = c(4.5, 4.5, 3, 1))

    ord <- order(d$obs_avg)

    plot(d$obs_avg[ord], seq_along(ord), pch = 16, col = "#e74c3c",
         xlab = "Batting average", ylab = "Player (sorted by raw avg)",
         main = "Shrinkage in Action",
         xlim = range(c(d$obs_avg, d$shrunk_avg, d$true_avg)))

    points(d$shrunk_avg[ord], seq_along(ord), pch = 17, col = "#3498db")
    points(d$true_avg[ord], seq_along(ord), pch = 4, col = "#27ae60", cex = 0.8)

    # Draw arrows from raw to shrunk
    arrows(d$obs_avg[ord], seq_along(ord),
           d$shrunk_avg[ord], seq_along(ord),
           length = 0.05, col = "#bdc3c780", lwd = 1)

    # Grand mean
    abline(v = d$grand_mean, lty = 2, col = "gray50", lwd = 1.5)

    legend("bottomright", bty = "n", cex = 0.8,
           legend = c("Raw average", "Shrunk estimate",
                      "True ability", "Grand mean"),
           col = c("#e74c3c", "#3498db", "#27ae60", "gray50"),
           pch = c(16, 17, 4, NA),
           lty = c(NA, NA, NA, 2), lwd = c(NA, NA, NA, 1.5))
  })

  output$mse_plot <- renderPlot({
    d <- dat()
    par(mar = c(4.5, 6, 3, 1))

    vals <- c(d$mse_raw, d$mse_shrunk, d$pred_err_raw, d$pred_err_shrunk)
    cols <- c("#e74c3c", "#3498db", "#e74c3c80", "#3498db80")
    labels <- c("Raw\nvs truth", "Shrunk\nvs truth",
                "Raw\nvs future", "Shrunk\nvs future")

    bp <- barplot(vals, col = cols, border = NA,
                  names.arg = labels, cex.names = 0.8,
                  main = "Mean Squared Error",
                  ylab = "MSE", las = 1)
    text(bp, vals + max(vals) * 0.03, round(vals, 5), cex = 0.8)

    pct1 <- round((1 - d$mse_shrunk / d$mse_raw) * 100, 0)
    pct2 <- round((1 - d$pred_err_shrunk / d$pred_err_raw) * 100, 0)

    mtext(paste0("Shrinkage cuts estimation error by ~", pct1,
                 "% and prediction error by ~", pct2, "%"),
          side = 1, line = 3.5, cex = 0.85, col = "#2c3e50")
  })

  output$results <- renderUI({
    d <- dat()
    pct_est <- round((1 - d$mse_shrunk / d$mse_raw) * 100, 1)
    pct_pred <- round((1 - d$pred_err_shrunk / d$pred_err_raw) * 100, 1)

    tags$div(class = "stats-box",
      HTML(paste0(
        "<b>Shrinkage weight:</b> ", round(d$w * 100, 1),
        "% toward grand mean<br>",
        "<b>Grand mean:</b> ", round(d$grand_mean, 3), "<br>",
        "<hr style='margin:8px 0'>",
        "<b>MSE (raw):</b> ", round(d$mse_raw, 5), "<br>",
        "<b>MSE (shrunk):</b> <span class='good'>",
        round(d$mse_shrunk, 5), "</span><br>",
        "<b>Improvement:</b> <span class='good'>", pct_est, "%</span><br>",
        "<hr style='margin:8px 0'>",
        "<b>Prediction error (raw):</b> ", round(d$pred_err_raw, 5), "<br>",
        "<b>Prediction error (shrunk):</b> <span class='good'>",
        round(d$pred_err_shrunk, 5), "</span><br>",
        "<b>Improvement:</b> <span class='good'>", pct_pred, "%</span>"
      ))
    )
  })
}

shinyApp(ui, server)

Things to try

  • At-bats = 5: extreme noise. Raw averages are all over the place (some players show .000 or .600). Shrinkage pulls them heavily toward the mean — and the green crosses (true ability) confirm the shrunk estimates are closer.
  • At-bats = 200: lots of data per player. Shrinkage is much lighter because the raw averages are already fairly precise, and as the at-bat count keeps growing the pull toward the grand mean fades toward zero.
  • Look at the MSE bars: shrinkage almost always wins, especially with small samples. It also predicts future performance better.
  • True talent spread = 0.01 (everyone is similar): shrinkage is aggressive because individual differences are small relative to noise.
  • True talent spread = 0.08 (wide range of talent): shrinkage is lighter because individual differences are real, not noise. (A back-of-envelope version of the weight calculation follows this list.)
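The weight the app reports can be checked roughly by hand. The sketch below plugs the default slider settings into the same formula the app uses (sampling variance over sampling variance plus true-talent variance); the numbers are approximate because the app estimates both variances from the simulated data.

# Back-of-envelope shrinkage weight for the default settings
p            <- 0.260               # typical batting average
n            <- 20                  # at-bats per player (default slider value)
sampling_var <- p * (1 - p) / n     # binomial noise, about 0.0096
prior_var    <- 0.03^2              # true talent spread of 0.03, squared

w <- sampling_var / (sampling_var + prior_var)
round(w, 2)                         # about 0.91: estimates move ~91% of the way to the mean

# With 200 at-bats the sampling noise drops tenfold, and so does the pull
round((p * (1 - p) / 200) / (p * (1 - p) / 200 + prior_var), 2)   # about 0.52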

Why does this work?

It seems wrong to move estimates away from the data. But consider what happens without shrinkage:

  • Players who got lucky are overestimated
  • Players who got unlucky are underestimated
  • These errors don’t cancel — they inflate the overall MSE

Shrinkage dampens both overestimates and underestimates simultaneously. The small bias it introduces (pulling everyone toward the mean) is more than offset by the massive reduction in variance. This is the bias-variance tradeoff in action.
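A standard decomposition makes this precise: mean squared error is squared bias plus variance. Treating the grand mean as approximately fixed, shrinking a player's estimate toward it by weight w gives

\[\text{MSE}\big(\hat{\theta}_i^{\text{shrunk}}\big) \approx \underbrace{w^2 (\theta_i - \bar{\theta})^2}_{\text{squared bias}} + \underbrace{(1 - w)^2 \sigma_i^2}_{\text{variance}}\]

Averaged over players, the squared-bias term is the weight squared times the true spread in ability, while the variance term is the complementary weight squared times the sampling noise. Minimizing that sum over the weight gives

\[w = \frac{\sigma^2}{\sigma^2 + \tau^2}\]

with sigma-squared the typical sampling variance and tau-squared the variance of true abilities. This is exactly the weight the app computes as sampling_var / (sampling_var + prior_var): the noisier the raw averages are relative to the real spread in talent, the harder you should pull.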

Where you see this in practice

| Method                | What gets shrunk                          |
|-----------------------|-------------------------------------------|
| Ridge regression      | Coefficients toward zero                  |
| LASSO                 | Coefficients toward zero (with selection) |
| Random effects models | Group means toward grand mean             |
| Empirical Bayes       | Individual estimates toward overall mean  |
| Bayesian priors       | Posterior means toward the prior mean     |

They all share the same logic: when you have many noisy estimates, borrowing strength across them improves every single one.
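The same pull toward a center is easy to see outside baseball. Here is a minimal ridge regression sketch in base R on simulated data (all names and numbers are illustrative, not a recipe): the penalty lambda plays the role of the shrinkage weight, pulling every coefficient toward zero as it grows.

# Simulated data: 50 observations, 10 predictors, only 2 real signals
set.seed(1)
n <- 50; p <- 10
X <- scale(matrix(rnorm(n * p), n, p))      # standardized predictors
beta_true <- c(2, -1.5, rep(0, p - 2))
y <- X %*% beta_true + rnorm(n, sd = 3)
y <- y - mean(y)                            # centered, so no intercept needed

# Ridge solution: (X'X + lambda I)^{-1} X'y
ridge <- function(lambda) solve(t(X) %*% X + lambda * diag(p), t(X) %*% y)

coefs <- sapply(c(0, 10, 100), ridge)       # lambda = 0 is ordinary least squares
colnames(coefs) <- c("lambda_0", "lambda_10", "lambda_100")
round(coefs, 2)
# Each column shows the fitted coefficients pulled further toward zero as
# lambda grows: a little bias in exchange for much less variance.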