Por que a divergência KL é usada com tanta frequência no aprendizado de máquina?

Dec 15 2020

A Divergência KL é muito fácil de calcular em forma fechada para distribuições simples - como gaussianas - mas tem algumas propriedades não muito boas. Por exemplo, não é simétrico (portanto, não é uma métrica) e não respeita a desigualdade triangular.

Por que ele é usado com tanta frequência no ML? Não existem outras distâncias estatísticas que podem ser usadas?

Respostas

2 rhdxor Dec 19 2020 at 16:52

Esta questão é muito geral no sentido de que o motivo pode ser diferente dependendo da área de ML que você está considerando. Abaixo estão duas áreas diferentes de ML onde a divergência KL é uma consequência natural:

  • Classificação: maximizar o log-verossimilhança (ou minimizar o log-verossimilhança negativo) é equivalente a minimizar a divergência KL como típico usado na classificação baseada em DL onde alvos one-hot são comumente usados ​​como referência (verhttps://stats.stackexchange.com/a/357974) Além disso, se você tiver um vetor one-hot$e_y$ com $1$ no índice $y$, minimizando a entropia cruzada $\min_{\hat{p}}H(e_y, \hat{p}) = - \sum_y e_y \log \hat{p}_y = - \log \hat{p}$se resume a maximizar a probabilidade de log. Em resumo, maximizar a probabilidade de log é indiscutivelmente um objetivo natural, e divergência KL (com 0 log 0 definido como 0) surge por causa de sua equivalência com a probabilidade de log em configurações típicas, em vez de ser explicitamente motivado como o objetivo.
  • Bandidos multi-armados (uma subárea de aprendizagem por reforço): Limite de confiança superior (UCB) é um algoritmo derivado de desigualdades de concentração padrão. Se considerarmos MABs com recompensas de Bernoulli, podemos aplicar o limite de Chernoff e otimizar o parâmetro livre para obter um limite superior expresso em termos de divergência KL, conforme indicado abaixo (verhttps://page.mi.fu-berlin.de/mulzer/notes/misc/chernoff.pdf para algumas provas diferentes).

Deixei $X_1, \dots, X_n$ ser iid Bernoulli RVs com parâmetro $p$. $$P(\sum_i X_i \geq (p+t)n) \leq \inf_\lambda M_X (\lambda) e^{-\lambda t} = \exp(-n D_{KL}(p+t||p)).$$

1 ArayKarjauv Dec 19 2020 at 21:11

Em ML, sempre lidamos com distribuições de probabilidade desconhecidas de onde vêm os dados. A maneira mais comum de calcular a distância entre a distribuição real e o modelo é$KL$ divergência.

Por que divergência Kullback-Leibler?

Embora existam outras funções de perda (por exemplo, MSE, MAE), $KL$a divergência é natural quando estamos lidando com distribuições de probabilidade. É uma equação fundamental na teoria da informação que quantifica, em bits, a proximidade de duas distribuições de probabilidade. É também chamada de entropia relativa e, como o nome sugere, está intimamente relacionada à entropia, que por sua vez é um conceito central na teoria da informação. Vamos relembrar a definição de entropia para um caso discreto:

$$ H = -\sum_{i=1}^{N} p(x_i) \cdot \text{log }p(x_i) $$

Como você observou, a entropia por si só é apenas uma medida de uma única distribuição de probabilidade. Se modificarmos ligeiramente esta fórmula adicionando uma segunda distribuição, obtemos$KL$ divergência:

$$ D_{KL}(p||q) = \sum_{i=1}^{N} p(x_i)\cdot (\text{log }p(x_i) - \text{log }q(x_i)) $$

Onde $p$ é uma distribuição de dados e $q$ é a distribuição do modelo.

Como podemos ver, $KL$divergência é a maneira mais natural de comparar 2 distribuições. Além disso, é muito fácil de calcular. Este artigo fornece mais intuição sobre isso:

Essencialmente, o que estamos vendo com a divergência KL é a expectativa da diferença logarítmica entre a probabilidade dos dados na distribuição original com a distribuição aproximada. Novamente, se pensarmos em termos de$log_2$ podemos interpretar isso como "quantos bits de informação esperamos perder".

Entropia cruzada

A entropia cruzada é comumente usada no aprendizado de máquina como uma função de perda onde temos a camada de saída softmax (ou sigmóide), uma vez que representa uma distribuição preditiva sobre as classes. A saída one-hot representa uma distribuição de modelo$q$, enquanto rótulos verdadeiros representam uma distribuição de destino $p$. Nosso objetivo é empurrar$q$ para $p$o mais perto possível. Poderíamos tomar um erro médio quadrático sobre todos os valores, ou poderíamos somar as diferenças absolutas, mas a única medida que é motivada pela teoria da informação é a entropia cruzada. Ele fornece o número médio de bits necessários para codificar amostras distribuídas como$p$, usando $q$ como a distribuição de codificação.

Entropia cruzada com base na entropia e geralmente calcula a diferença entre duas distribuições de probabilidade e intimamente relacionada com $KL$divergência. A diferença é que ele calcula a entropia total entre as distribuições, enquanto$KL$divergência representa entropia relativa. A entropia de Corss pode ser definida da seguinte forma:

$$ H(p, q) = H(p) + D_{KL}(p \parallel q) $$

O primeiro termo nesta equação é a entropia da verdadeira distribuição de probabilidade $p$ que é omitido durante a otimização, uma vez que a entropia de $p$é constante. Portanto, minimizar a entropia cruzada é o mesmo que otimizar$KL$ divergência.

Log probabilidade

Também pode ser mostrado que maximizar a probabilidade (log) é equivalente a minimizar a entropia cruzada.

Limitações

Como você mencionou, $KL$a divergência não é simétrica. Mas, na maioria dos casos, isso não é crítico, uma vez que queremos estimar a distribuição do modelo empurrando-o para o real, mas não vice-versa. Há também uma versão simetrizada chamada divergência de Jensen-Shannon :$$ D_{JS}(p||q)=\frac{1}{2}D_{KL}(p||m)+\frac{1}{2}D_{KL}(q||m) $$ Onde $m=\frac{1}{2}(p+q)$.

A principal desvantagem de $KL$é que tanto a distribuição desconhecida quanto a distribuição do modelo devem ter suporte. Caso contrário, o$D_{KL}(p||q)$ torna-se $+\infty$ e $D_{JS}(p||q)$ torna-se $log2$

Em segundo lugar, deve-se notar que $KL$não é uma métrica, pois viola a desigualdade do triângulo. Ou seja, em alguns casos, ele não nos dirá se estamos indo na direção certa ao estimar a distribuição do nosso modelo. Aqui está um exemplo tirado desta resposta . Dadas duas distribuições discretas$p$ e $q$, nós calculamos $KL$ divergência e métrica de Wasserstein:

Como você pode ver, $KL$ a divergência permaneceu a mesma, enquanto a métrica de Wasserstein diminuiu.

Mas, conforme mencionado nos comentários, a métrica Wasserstein é altamente intratável em um espaço contínuo. Ainda podemos usá-lo aplicando a dualidade Kantorovich-Rubinstein usada em Wasserstein GAN . Você também pode encontrar mais informações sobre este tópico neste artigo .

As 2 desvantagens de $KL$pode ser mitigado adicionando ruído. Mais sobre isso neste artigo