Maximum Likelihood Estimation is Minimizing KL Divergence

Maximum likelihood estimation is about finding model parameters that maximize the likelihood of the data. KL divergence measures how much one probability distribution differs from another. So what do these have in common? This:

$$\argmin_\theta D_{\text{KL}}(p \parallel q) = \argmax_\theta p(\mathcal{D}|\theta)$$

Let's take a look at the definition of the KL divergence.

$$D_\text{KL}(p \parallel q) = \mathbb{E}_{x \sim p} \left[ \log \frac{p(x)}{q(x)} \right]$$
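For discrete distributions, this expectation is just a weighted sum, so it's easy to compute directly. A minimal sketch in Python (the two coin distributions are made-up examples):

```python
import math

def kl_divergence(p, q):
    """D_KL(p || q) = sum_x p(x) * log(p(x) / q(x)) for discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.5, 0.5]  # a fair coin
q = [0.9, 0.1]  # a heavily biased model of it

print(kl_divergence(p, q))  # positive: the distributions differ
print(kl_divergence(p, p))  # 0.0: no divergence from itself
```

Note that the divergence is zero exactly when the two distributions coincide, and positive otherwise.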


Here, $p$ is the underlying data distribution. We never truly have access to it, but we want to approximate it with our model $q$, which has parameters $\theta$. So when we fit $q$, we want to choose $\theta$ such that this divergence is minimized.

$$\argmin_\theta D_\text{KL}(p \parallel q) = \argmin_\theta \mathbb{E}_{x \sim p} \left[ \log \frac{p(x)}{q(x)} \right]$$

Now, because we only care about the $\argmin$, we can simplify a bit. The term $\mathbb{E}_{x \sim p}[\log p(x)]$ doesn't depend on the model parameters $\theta$ at all, so within the optimization problem it's a constant and we can drop it.

$$\begin{aligned} \argmin_\theta D_\text{KL}(p \parallel q) &= \argmin_\theta \mathbb{E}_{x \sim p} \left[ \log \frac{p(x)}{q(x)} \right]\\ &= \argmin_\theta \mathbb{E}_{x \sim p} \left[ \log p(x) - \log q(x) \right]\\ &= \argmin_\theta \mathbb{E}_{x \sim p} \left[ -\log q(x) \right]\\ &= \argmin_\theta -\mathbb{E}_{x \sim p} \left[ \log q(x) \right] \end{aligned}$$
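The quantity left over, $-\mathbb{E}_{x \sim p}[\log q(x)]$, is the cross-entropy between $p$ and $q$; the term we dropped is $\theta$-independent, so the KL divergence and the cross-entropy differ by the same constant no matter which model $q$ we try. A quick numerical check of that claim, with made-up discrete distributions:

```python
import math

def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def cross_entropy(p, q):  # -E_{x~p}[log q(x)]
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.2, 0.5, 0.3]

# For every candidate model q, KL and cross-entropy differ by the same
# constant: -H(p), which doesn't depend on the model at all.
diffs = [kl(p, q) - cross_entropy(p, q)
         for q in ([0.1, 0.6, 0.3], [0.3, 0.3, 0.4], [0.2, 0.5, 0.3])]
print(diffs)  # three (nearly) identical values
```

Because the two objectives differ by a constant, they share the same minimizer.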

Now: finding the minimum of some function is the same as finding the maximum of that function flipped upside down, right? So we can use the minus sign here to flip that $\argmin$ into an $\argmax$.

$$\argmax_\theta \mathbb{E}_{x \sim p} \left[ \log q(x) \right]$$

Beautiful. And this looks familiar, too! Let's write the expectation out.

We can't evaluate an expectation over $p$ exactly, but we can approximate it with samples from $p$, and that is exactly what our dataset $\mathcal{D} = \{x_1, \dots, x_N\}$ is. Writing the model as $q(x) = p(x|\theta)$, approximating the expectation with the sample average, and noting that the constant factor $\frac{1}{N}$ doesn't change the $\argmax$ and that a sum of logs is the log of a product:

$$\argmax_\theta \mathbb{E}_{x \sim p} \left[ \log q(x) \right] \approx \argmax_\theta \frac{1}{N} \sum_{i=1}^N \log p(x_i|\theta) = \argmax_\theta \log \prod_{i=1}^N p(x_i|\theta)$$
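That sample-average step is just Monte Carlo estimation. A small simulation (the Bernoulli setup is a made-up example, and we pretend to know $p$ so we can compare against the exact expectation):

```python
import math
import random

random.seed(0)

def log_q(x, theta):
    """Log-likelihood of one observation under a Bernoulli(theta) model."""
    return math.log(theta if x == 1 else 1 - theta)

theta = 0.6    # some model parameter
p_heads = 0.7  # the (normally unknown) underlying distribution p

# Exact E_{x~p}[log q(x)] -- only computable here because we know p.
exact = p_heads * log_q(1, theta) + (1 - p_heads) * log_q(0, theta)

# Monte Carlo estimate: average log-likelihood over N samples drawn from p.
data = [1 if random.random() < p_heads else 0 for _ in range(10_000)]
estimate = sum(log_q(x, theta) for x in data) / len(data)

print(exact, estimate)  # the two values should be close
```

With more samples the estimate concentrates around the true expectation, which is why maximizing the average log-likelihood stands in for maximizing $\mathbb{E}_{x \sim p}[\log q(x)]$.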

Hah, this is maximizing the log likelihood!

For i.i.d. data, the product of the individual likelihoods is the likelihood of the whole dataset:

$$\argmax_\theta \log \prod_{i=1}^N p(x_i|\theta) = \argmax_\theta \log p(\mathcal{D}|\theta)$$

And because the logarithm is monotonically increasing, maximizing the log likelihood gives you the same $\theta$ as maximizing the likelihood itself.

$$\argmax_\theta \log p(\mathcal{D}|\theta) = \argmax_\theta p(\mathcal{D}|\theta)$$

And there we have it. Doing maximum likelihood estimation is the same as minimizing KL divergence.

$$\argmin_\theta D_{\text{KL}}(p \parallel q) = \argmax_\theta p(\mathcal{D}|\theta)$$
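As a sanity check, here's a small sketch (a Bernoulli model with made-up coin-flip counts) showing that minimizing the KL divergence from the empirical distribution and maximizing the likelihood pick out the same $\theta$:

```python
import math

n, k = 100, 63              # made-up dataset: 63 heads in 100 flips
p_hat = [1 - k / n, k / n]  # empirical distribution of the data

def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def log_likelihood(theta):
    return k * math.log(theta) + (n - k) * math.log(1 - theta)

# Grid search over theta with each objective.
thetas = [i / 1000 for i in range(1, 1000)]
theta_kl = min(thetas, key=lambda t: kl(p_hat, [1 - t, t]))
theta_mle = max(thetas, key=log_likelihood)

print(theta_kl, theta_mle)  # both land on 0.63, the sample mean
```

Both objectives select the sample mean $k/n$, as the derivation above predicts.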