Cross entropyやsoftmaxに関するメモ

Posted: 2020-03-05 (Updated: 2022-10-07)

背景にある発想

\(K\)-面サイコロについて、出目が\( k \)となる確率を\( q_k \) とする。このとき各\( q_k \) について\( 0 \leq p \leq 1 \)かつ\( \sum_{k = 1}^K q_k =1 \) である。したがって\( \boldsymbol{q} = {\begin{pmatrix} q_1 & \cdots & q_K\end{pmatrix}}^\top \)は\((K - 1)\)-標準単体(standard simplex) \( \varDelta^{K -1} = \{ \boldsymbol{x} \in \mathbb{R}^K \mid \sum_{i = 1}^K x_i = 1,\, x_1 \geq 0,\, \ldots,\, x_K \geq 0 \} \)に含まれる。この\( \boldsymbol{q} \)がサイコロのモデルになる。

全部で\( N \)回このサイコロを振る。出目が\( k \)となった回数を\( n_k \) とし、\( p_k := n_k / N \)とする。このとき\( \boldsymbol{p} = {\begin{pmatrix}p_1 & \cdots & p_K\end{pmatrix}}^\top \)は観測によって得られた確率である。最尤推定により\( \boldsymbol{q} \)を求めるには以下を求める。 \begin{align*} \operatorname*{arg\,max}_{\boldsymbol{q} \in \varDelta^{K - 1}} \prod_{k = 1}^K q_k^{Np_k} \end{align*} 負の対数尤度を使って以下の問題を解いても同じである。 \begin{align*} \operatorname*{arg\,min}_{\boldsymbol{q} \in \varDelta^{K - 1}} -\sum_{k = 1}^K p_k \log q_k \end{align*} この目的関数を\(\boldsymbol{p}\)と\(\boldsymbol{q}\)のクロスエントロピー(cross entropy)と呼ぶ。 \begin{align*} \operatorname*{CrossEntropy}(\boldsymbol{p}, \boldsymbol{q}) = - \sum_{k = 1}^K p_k \log q_k \end{align*}

\(\varDelta^{K - 1}\)の要素とは限らない尤度を表すスコア\( \boldsymbol{a} \in \mathbb{R}^K \) が与えられた場合、ソフトマックス(softmax)を用いて正規化するのが一般的である。 \begin{align*} \softmax(\boldsymbol{a}) = \frac{1}{\sum_{k = 1}^K \exp(a_k)} \begin{pmatrix} \exp(a_1) \\ \vdots \\ \exp(a_K) \end{pmatrix} \end{align*} ソフトマックスの値域は\( \operatorname{relint} \varDelta^{K - 1} = \{ \boldsymbol{x} \in \mathbb{R}^K \mid \sum_{i = 1}^K x_i = 1,\, x_1 \gt 0,\, \ldots,\, x_K \gt 0 \} \subset \varDelta^{K - 1} \)であるので、その出力を確率と見なすことができるのである。

誤差逆伝播

\(\boldsymbol{q} = \operatorname{softmax}(\boldsymbol{a})\)が成り立つとき、 \begin{align*} \frac{\partial q_k}{\partial a_i} &= \frac{ \left( \frac{ \partial }{ \partial a_i } \exp(a_k) \right) \sum_{j = 1}^K \exp(a_j) - \exp(a_k) \frac{ \partial }{ \partial a_i } \sum_{j = 1}^K \exp(a_j) } { (\sum_{j = 1}^K \exp (a_j))^2 } \\ & = \frac{ \delta_{ik} \exp(a_k) \sum_{j = 1}^K \exp(a_j) - \exp(a_k) \exp(a_i) } { (\sum_{j = 1}^K \exp (a_j))^2 } \\ & = \delta_{ki} \frac{\exp(a_k)}{\sum_{j = 1}^K \exp(a_j)} - \frac{\exp(a_k) \exp(a_i)}{(\sum_{j = 1}^K \exp(a_j))^2} \\ & = q_k( \delta_{ki} - q_i) \\ \end{align*}

クロスエントロピー\( H := \operatorname{CrossEntropy}(\boldsymbol{p}, \boldsymbol{q}) \) の\( a_k \)に関する偏導関数は以下のように計算できる。 \begin{align*} \frac{\partial H}{\partial a_i} & = - \sum_{k = 1}^K p_k \frac{\partial \log q_k}{\partial q_k} \frac{\partial q_k}{\partial a_i} \\ & = - p_i \frac{q_i(1 - q_i)}{q_i} - \sum_{k \in \{1,\,\ldots,\,K\} \backslash \{i\}} p_k \frac{- q_k q_i}{q_k}\\ & = - p_i (1 - q_i) + q_i \sum_{k = 1}^K p_k - p_i q_i \\ & = - p_i (1 - q_i) + q_i (1 - p_i ) \\ & = q_i - p_i \end{align*}