Training Neural Networks as Maximum Likelihood
The loss functions used to train neural networks are not arbitrary design choices. Most of them are negative log-likelihoods in disguise. If you understand MLE, you already understand what neural network training is doing — and, critically, what it is not doing.
Cross-entropy is negative log-likelihood
The standard loss for classification is cross-entropy. For a binary outcome \(Y_i \in \{0, 1\}\) and a model that predicts \(\hat{p}_i = P(Y_i = 1 \mid X_i)\):
\[ \mathcal{L} = -\frac{1}{n}\sum_{i=1}^n \left[Y_i \log \hat{p}_i + (1 - Y_i)\log(1 - \hat{p}_i)\right] \]
Compare this to the log-likelihood of a Bernoulli model from the MLE page:
\[ \ell(p_1, \ldots, p_n) = \sum_{i=1}^n \left[Y_i \log p_i + (1 - Y_i)\log(1 - p_i)\right] \]
They are the same expression, up to a sign and a scaling constant. Minimizing cross-entropy loss is maximizing the Bernoulli log-likelihood. The neural network’s output layer (sigmoid activation) parameterizes \(p_i\) as a flexible function of \(X_i\), but the estimation principle is identical to logistic regression — which is itself MLE.
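To see that these really are the same quantity, not just similar-looking formulas, here is a quick numerical check (a minimal NumPy sketch; the predictions and outcomes are simulated purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
p_hat = rng.uniform(0.01, 0.99, n)   # model's predicted probabilities
y = rng.binomial(1, 0.5, n)          # observed binary outcomes

# Binary cross-entropy, as a deep learning library would compute it
bce = -np.mean(y * np.log(p_hat) + (1 - y) * np.log(1 - p_hat))

# Mean negative Bernoulli log-likelihood, term by term:
# each observation contributes -log of the probability assigned to what occurred
nll = -np.mean(np.log(np.where(y == 1, p_hat, 1 - p_hat)))

assert np.isclose(bce, nll)  # identical up to floating point
```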
For multi-class classification with \(K\) categories, softmax cross-entropy is the negative log-likelihood of a multinomial model. For regression, the squared-error loss
\[ \mathcal{L} = \frac{1}{n}\sum_{i=1}^n (Y_i - \hat{Y}_i)^2 \]
is, up to constants that do not depend on the mean parameters, the negative log-likelihood of a Gaussian model with constant variance — exactly the same equivalence shown on the MLE page between OLS and MLE under normality.
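Spelling out the Gaussian case: if \(Y_i \mid X_i \sim N(\hat{Y}_i, \sigma^2)\) with \(\sigma^2\) fixed, the negative log-likelihood is

\[ -\ell(\theta) = \frac{n}{2}\log(2\pi\sigma^2) + \frac{1}{2\sigma^2}\sum_{i=1}^n \left(Y_i - \hat{Y}_i\right)^2. \]

The first term does not involve the mean parameters, so minimizing \(\mathcal{L}\) and maximizing \(\ell\) select the same \(\hat{\theta}\).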
| Loss function | Equivalent to | Implicit distributional assumption |
|---|---|---|
| Squared error (MSE) | Gaussian MLE | \(Y \mid X \sim N(\hat{Y}, \sigma^2)\) |
| Binary cross-entropy | Bernoulli MLE | \(Y \mid X \sim \text{Bernoulli}(\hat{p})\) |
| Categorical cross-entropy | Multinomial MLE | \(Y \mid X \sim \text{Multinomial}(\hat{p}_1, \ldots, \hat{p}_K)\) |
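The multinomial row of the table can be checked the same way: categorical cross-entropy with one-hot targets is just the mean negative log-probability assigned to the observed class (again a NumPy sketch with simulated logits):

```python
import numpy as np

rng = np.random.default_rng(4)
n, K = 6, 3
logits = rng.normal(size=(n, K))
y = rng.integers(0, K, size=n)        # class labels in {0, ..., K-1}

# Softmax probabilities (stabilized by subtracting the row max)
z = logits - logits.max(axis=1, keepdims=True)
probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)

# Categorical cross-entropy with one-hot targets ...
onehot = np.eye(K)[y]
ce = -np.mean(np.sum(onehot * np.log(probs), axis=1))

# ... equals the mean negative multinomial (single-draw) log-likelihood
nll = -np.mean(np.log(probs[np.arange(n), y]))

assert np.isclose(ce, nll)
```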
SGD as approximate MLE
Classical MLE computes the gradient of the full log-likelihood (the score) and solves the score equation, in closed form or with a few Newton-Raphson steps. Neural networks can't do this: the loss surfaces are nonconvex and the datasets are enormous. Instead, training uses stochastic gradient descent (SGD): at each step, sample a mini-batch of data, compute the gradient of the loss on that mini-batch, and take a step downhill.
This is approximate MLE. The mini-batch gradient is a noisy, unbiased estimate of the full gradient. Over many steps, SGD traces out a path that (under regularity conditions) converges to a local maximum of the likelihood — though not necessarily the global one.
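The unbiasedness claim is easy to verify directly: averaging the mini-batch gradient over many random batches recovers the full-sample gradient (a NumPy sketch for logistic regression; the batch size and number of batches are arbitrary illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 2000
X = np.column_stack([np.ones(n), rng.normal(size=n)])
beta = np.array([0.5, 1.5])
Y = rng.binomial(1, 1 / (1 + np.exp(-X @ beta)))

def grad(b, Xb, Yb):
    """Gradient of the mean negative log-likelihood at b."""
    pb = 1 / (1 + np.exp(-Xb @ b))
    return -Xb.T @ (Yb - pb) / len(Yb)

b0 = np.zeros(2)
full = grad(b0, X, Y)               # full-sample gradient

# Average the gradient over many random mini-batches of size 50
batches = [rng.choice(n, size=50, replace=True) for _ in range(5000)]
avg = np.mean([grad(b0, X[i], Y[i]) for i in batches], axis=0)

print(full, avg)  # the two agree up to Monte Carlo noise
```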
The analogy to classical statistics is instructive:
| Classical MLE | Neural network training |
|---|---|
| Full-sample gradient, solve exactly | Mini-batch gradient, iterate |
| Convex log-likelihood (often) | Nonconvex loss landscape |
| Single global optimum (typically) | Multiple local optima |
| Closed-form or Newton-Raphson | SGD, Adam, or variants |
| Fisher information → standard errors | No standard errors by default |
The last row matters. Classical MLE gives you standard errors through the Fisher information. Neural network training typically does not. The model gives you a point prediction, but no measure of uncertainty — a limitation addressed in Calibration and Uncertainty.
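For contrast, here is the classical machinery at work: fit a logistic regression by Newton-Raphson and read standard errors off the inverse of the observed Fisher information (a NumPy sketch mirroring the app below; this is not something a deep learning framework reports by default):

```python
import numpy as np

rng = np.random.default_rng(2)
n = 500
X = np.column_stack([np.ones(n), rng.normal(size=n)])
beta_true = np.array([0.5, 1.5])
Y = rng.binomial(1, 1 / (1 + np.exp(-X @ beta_true)))

# Fit by Newton-Raphson (same scheme as the simulation below)
b = np.zeros(2)
for _ in range(25):
    p = 1 / (1 + np.exp(-X @ b))
    W = p * (1 - p)
    score = X.T @ (Y - p)
    info = X.T @ (X * W[:, None])   # observed Fisher information
    b = b + np.linalg.solve(info, score)

# Standard errors: square roots of the diagonal of the inverse information
se = np.sqrt(np.diag(np.linalg.inv(info)))
print(b, se)
```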
The simulation below runs gradient descent on a logistic regression in two ways: with the full sample at every step (full-batch, blue) and with random mini-batches (SGD, red). Both converge to the same MLE, but the mini-batch path is noisier — exactly the tradeoff described above.
Things to try
- Full batch (batch size = n): the loss curve descends smoothly to the MLE minimum. No noise.
- Small batch size: the loss curve is noisy — each step uses a different random subset, so the gradient estimate is noisy. But it still converges to the right neighborhood.
- Very small batch + high learning rate: the path oscillates wildly. This is why learning rate schedules (reducing the step size over time) are important in practice.
- Large n: the mini-batch noise is smaller relative to the signal, so even small batches converge smoothly.
#| standalone: true
#| viewerHeight: 680
library(shiny)
ui <- fluidPage(
tags$head(tags$style(HTML("
.stats-box {
background: #f0f4f8; border-radius: 6px; padding: 14px;
margin-top: 12px; font-size: 14px; line-height: 1.9;
}
.stats-box b { color: #2c3e50; }
"))),
sidebarLayout(
sidebarPanel(
width = 3,
sliderInput("n", "Sample size (n):",
min = 200, max = 2000, value = 500, step = 200),
sliderInput("batch", "Batch size:",
min = 10, max = 500, value = 50, step = 10),
sliderInput("lr", "Learning rate:",
min = 0.1, max = 5, value = 1, step = 0.1),
sliderInput("n_iter", "Iterations:",
min = 50, max = 500, value = 200, step = 50),
actionButton("go", "New draw", class = "btn-primary", width = "100%"),
uiOutput("results")
),
mainPanel(
width = 9,
fluidRow(
column(6, plotOutput("loss_plot", height = "470px")),
column(6, plotOutput("beta_plot", height = "470px"))
)
)
)
)
server <- function(input, output, session) {
dat <- reactive({
input$go
n <- input$n
bs <- min(input$batch, n)
lr <- input$lr
n_iter <- input$n_iter
# Generate logistic regression data
X <- rnorm(n)
p_true <- 1 / (1 + exp(-(0.5 + 1.5 * X)))
Y <- rbinom(n, 1, p_true)
# Exact MLE via Newton-Raphson (manual, to avoid glm overhead)
beta_mle <- c(0, 0)
for (nr in 1:50) {
eta <- beta_mle[1] + beta_mle[2] * X
p_hat <- 1 / (1 + exp(-eta))
W_nr <- p_hat * (1 - p_hat)
Xmat <- cbind(1, X)
grad <- t(Xmat) %*% (Y - p_hat)
H <- -t(Xmat) %*% (Xmat * W_nr)
beta_mle <- beta_mle - solve(H) %*% grad
}
beta_mle <- as.numeric(beta_mle)
# MLE loss
eta_mle <- beta_mle[1] + beta_mle[2] * X
p_mle <- 1 / (1 + exp(-eta_mle))
p_mle <- pmax(pmin(p_mle, 1 - 1e-10), 1e-10)
loss_mle <- -mean(Y * log(p_mle) + (1 - Y) * log(1 - p_mle))
# --- Full-batch gradient descent ---
beta_fb <- c(0, 0)
loss_fb <- numeric(n_iter)
beta1_fb <- numeric(n_iter)
for (t in 1:n_iter) {
eta <- beta_fb[1] + beta_fb[2] * X
p_hat <- 1 / (1 + exp(-eta))
grad <- c(-mean(Y - p_hat), -mean((Y - p_hat) * X))
beta_fb <- beta_fb - lr * grad
eta2 <- beta_fb[1] + beta_fb[2] * X
p2 <- 1 / (1 + exp(-eta2))
p2 <- pmax(pmin(p2, 1 - 1e-10), 1e-10)
loss_fb[t] <- -mean(Y * log(p2) + (1 - Y) * log(1 - p2))
beta1_fb[t] <- beta_fb[2]
}
# --- Mini-batch SGD ---
beta_mb <- c(0, 0)
loss_mb <- numeric(n_iter)
beta1_mb <- numeric(n_iter)
for (t in 1:n_iter) {
idx <- sample(1:n, bs, replace = TRUE)  # draw a random mini-batch
x_b <- X[idx]
y_b <- Y[idx]
eta <- beta_mb[1] + beta_mb[2] * x_b
p_hat <- 1 / (1 + exp(-eta))
grad <- c(-mean(y_b - p_hat), -mean((y_b - p_hat) * x_b))
beta_mb <- beta_mb - lr * grad
# Full-sample loss for tracking
eta2 <- beta_mb[1] + beta_mb[2] * X
p2 <- 1 / (1 + exp(-eta2))
p2 <- pmax(pmin(p2, 1 - 1e-10), 1e-10)
loss_mb[t] <- -mean(Y * log(p2) + (1 - Y) * log(1 - p2))
beta1_mb[t] <- beta_mb[2]
}
list(loss_fb = loss_fb, loss_mb = loss_mb,
beta1_fb = beta1_fb, beta1_mb = beta1_mb,
loss_mle = loss_mle, beta_mle = beta_mle,
n_iter = n_iter, bs = bs, n = n)
})
output$loss_plot <- renderPlot({
d <- dat()
par(mar = c(4.5, 4.5, 3, 1))
ylims <- range(c(d$loss_fb, d$loss_mb, d$loss_mle), na.rm = TRUE)
# Cap the upper limit so early mini-batch spikes don't flatten the curves
ylims[2] <- min(ylims[2], d$loss_mle + 2 * (max(d$loss_fb, na.rm = TRUE) - d$loss_mle))
plot(1:d$n_iter, d$loss_fb, type = "l", col = "#3498db", lwd = 2,
main = "Loss vs Iteration",
xlab = "Iteration", ylab = "Negative log-likelihood",
ylim = ylims)
lines(1:d$n_iter, d$loss_mb, col = "#e74c3c", lwd = 1.5)
abline(h = d$loss_mle, lty = 2, lwd = 2, col = "#2c3e50")
legend("topright", bty = "n", cex = 0.85,
legend = c("Full-batch GD", paste0("Mini-batch SGD (bs=", d$bs, ")"),
"MLE optimum"),
col = c("#3498db", "#e74c3c", "#2c3e50"),
lwd = c(2, 1.5, 2), lty = c(1, 1, 2))
})
output$beta_plot <- renderPlot({
d <- dat()
par(mar = c(4.5, 4.5, 3, 1))
ylims <- range(c(d$beta1_fb, d$beta1_mb, d$beta_mle[2]), na.rm = TRUE)
plot(1:d$n_iter, d$beta1_fb, type = "l", col = "#3498db", lwd = 2,
main = expression("Parameter " * hat(beta)[1] * " vs Iteration"),
xlab = "Iteration", ylab = expression(hat(beta)[1]),
ylim = ylims)
lines(1:d$n_iter, d$beta1_mb, col = "#e74c3c", lwd = 1.5)
abline(h = d$beta_mle[2], lty = 2, lwd = 2, col = "#2c3e50")
legend("bottomright", bty = "n", cex = 0.85,
legend = c("Full-batch GD", paste0("Mini-batch SGD (bs=", d$bs, ")"),
expression("MLE " * hat(beta)[1])),
col = c("#3498db", "#e74c3c", "#2c3e50"),
lwd = c(2, 1.5, 2), lty = c(1, 1, 2))
})
output$results <- renderUI({
d <- dat()
fb_final <- d$beta1_fb[d$n_iter]
mb_final <- d$beta1_mb[d$n_iter]
mle_val <- d$beta_mle[2]
fb_err <- abs(fb_final - mle_val)
mb_err <- abs(mb_final - mle_val)
tags$div(class = "stats-box",
HTML(paste0(
"<b>MLE solution:</b><br>",
" \u03b2\u0302\u2080 = ", round(d$beta_mle[1], 3),
", \u03b2\u0302\u2081 = ", round(mle_val, 3), "<br>",
"<hr style='margin:6px 0'>",
"<b>Full-batch final:</b><br>",
" \u03b2\u0302\u2081 = ", round(fb_final, 3),
" (|error| = ", round(fb_err, 4), ")<br>",
"<b>Mini-batch final:</b><br>",
" \u03b2\u0302\u2081 = ", round(mb_final, 3),
" (|error| = ", round(mb_err, 4), ")"
))
)
})
}
shinyApp(ui, server)
What training optimizes — and what it does not
Training a neural network finds parameters \(\hat{\theta}\) that minimize prediction error on the training distribution. This is optimization of a statistical objective: \(\hat{\theta} = \arg\min_\theta \mathcal{L}(\theta)\).
But prediction is not the only thing you might care about. The MLE page noted that if the model is misspecified, MLE converges to the distribution closest to the truth in Kullback-Leibler divergence — not necessarily the “right” answer.
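Written out, with \(p^*\) the true data distribution and \(\{p_\theta\}\) the model family, the misspecified-MLE limit is

\[ \hat{\theta} \;\longrightarrow\; \arg\min_\theta \, \mathrm{KL}\!\left(p^* \,\big\|\, p_\theta\right), \]

the best approximation available within the family — which need not be close to the truth in any other sense.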
For neural networks, this matters acutely:
- The model learns \(P(Y \mid X)\) — the conditional distribution. It does not learn \(P(Y \mid do(X))\) — the interventional distribution. This distinction is explored in Prediction vs Causation in Foundation Models.
- The model minimizes loss on the training distribution. If the deployment distribution differs (distribution shift), the guarantees vanish.
- The model has no notion of identification. It finds a good predictor, not a causally interpretable parameter.
Connecting to the course
This page bridges two frameworks:
- MLE provides the estimation principle. Neural network training is MLE (or regularized MLE) with a flexible function class.
- Regularization as Bayesian inference shows that weight decay, dropout, and other regularization techniques have principled statistical interpretations — they are not ad hoc tricks.
- The Algebra Behind OLS derived standard errors from \((X'X)^{-1}\). Neural networks lack this closed-form machinery, which is why uncertainty quantification requires separate tools.