ニューラルネットの線型層の初期化
前提 各層の入出力の分散が一定であることが望ましい.
ニューラルネットの第\( \ell \)番目の線型層について,入力の次元を\( n^\ell \),出力の次元を\( n^{\ell + 1} \),入力と重みをそれぞれ \begin{align*} x^\ell = \begin{pmatrix} x^\ell_1 \\ \vdots \\ x^\ell_{n^\ell} \end{pmatrix},\quad W^{\ell} = \begin{pmatrix} w_{11}^{\ell} & \cdots & w_{1 n^\ell}^{\ell} \\ \vdots & \ddots & \vdots \\ w_{n^{\ell + 1} 1}^{\ell} & \cdots & w_{n^{\ell + 1} n^\ell}^{\ell} \end{pmatrix} \end{align*} とする(※\( \ell \)は上付き添え字であり,累乗の指数ではない.).各\( w_{ij}^{\ell} \)の各成分は期待値\( 0 \)の同一の分布から独立に選ぶ.このとき任意の\( i,\,j,\,k \)について\( x_k \)と\( w_{ij} \)は独立になる.線型層の出力\( y^\ell := W^{\ell}x^\ell \)の第\( i \)成分の分散は,\( \mathcal{J}_{n^\ell} := \{1,\,\ldots,\,n^\ell\} \)と表すことにすると \begin{align} V\left[y^\ell_i\right] &= V \left[ \sum_{j \in \mathcal{J}_{n^\ell}} w_{ij}^{\ell} x_j^{\ell} \right] \notag \\ &= E\left[ { \left( \sum_{j \in \mathcal{J}_{n^\ell}} w_{ij}^{\ell} x_i^{\ell} \right) }^2 \right] - { \left( E\left[ \sum_{j \in \mathcal{J}_{n^\ell}} w_{ij}^{\ell} x_i^{\ell} \right] \right) }^2 \label{a} \end{align} となる.
(\{a})の右辺第1項について \begin{align} E\left[ { \left( \sum_{j \in \mathcal{J}_{n^\ell}} w_{ij}^{\ell} x_i^{\ell} \right) }^2 \right] &= \sum_{(j,\, j') \in \mathcal{J}_{n^\ell} \times \mathcal{J}_{n^\ell}} E\left[w_{ij}^{\ell} x_{j}^{\ell} w_{ij'}^{\ell} x_{j'}^{\ell}\right] \notag \\ &= \sum_{j \in \mathcal{J}_{n^\ell}} E\left[ { \left( w_{ij}^{\ell} \right) }^2 \right] E\left[{\left(x_j^{\ell}\right)}^2 \right] + \sum_{ \substack{ (j,\, j') \in \mathcal{J}_{n^\ell} \times \mathcal{J}_{n^\ell} \\ j \neq j' } } E\left[w_{ij}^{\ell}\right] E \left[w_{ij'}^{\ell} \right] E\left[x_j^{\ell} x_{j'}^{\ell} \right] \notag \\ &= \sum_{j \in \mathcal{J}_{n^\ell}} \left(V\left[w_{ij}\right] - \left(E[w_{ij}]\right)^2\right) E \left[{\left(x_j^{\ell}\right)}^2\right] + \sum_{ \substack{ (j,\, j') \in \mathcal{J}_{n^\ell} \times \mathcal{J}_{n^\ell} \\ j \neq j' } } 0 \cdot 0 \cdot E\left[x_j^{\ell} x_{j'}^{\ell} \right] \notag \\ &= V_{W}^{\ell} \sum_{j \in \mathcal{J}_{n^\ell}} E \left[{\left(x_j^{\ell}\right)}^2\right]. \label{w1} \end{align} ただし\( V_{W}^{\ell} = V\left[w_{i1}^{\ell}\right] = \cdots V\left[w_{n^\ell}^{\ell}\right] \)と置いた.
(\{a})の右辺第2項について \begin{align} {\left(\sum_{j \in \mathcal{J}_{n^\ell}} E\left[w_{ij}^{\ell}\right]E\left[x_j^{\ell}\right]\right)}^2 = {\left(\sum_{j \in \mathcal{J}_{n^\ell}} 0 \cdot E\left[x_j^{\ell}\right]\right)}^2 = 0 \label{w2} \end{align} が成り立つ.(\{w1})と\((\ref{w2})\)を合わせて \begin{align} V\left[y_i^{\ell}\right] = V_{W}^{\ell} \sum_{j \in \mathcal{J}_{n^\ell}} E \left[{\left(x_j^{\ell}\right)}^2\right]. \label{v} \end{align} という式を得る.
活性化関数を\( g \)とする.このとき\( x_i^{\ell + 1} := g(y_i^{\ell}) \)である.以下が成り立つと仮定する: \begin{gather} V\left[x_i^{\ell + 1}\right] \approx V\left[y_i^{\ell}\right] \text{ for all \(i \in \mathcal{J}_{n^{\ell + 1}}\)}, \tag{H1} \label{H1} \\ E\left[x_j^{\ell}\right] = 0 \text{ for all \(j \in \mathcal{J}_{n^\ell}\)}. \tag{H2} \label{ex0} \end{gather}
(\{H1})について,活性化関数として用いられるものは,通常\( 0 \) の周りで\( g' \approx 1 \)を満たす.すなわち\( g(y) \approx y + g(0) \)と書ける.したがって \begin{align*} V\left[x_i^{\ell + 1}\right] = V\left[g\left(y_i^{\ell + 1}\right)\right] \approx V\left[y_i^{\ell}+ g(0)\right] = V\left[y_i^{\ell}\right] \end{align*} が成り立つため,妥当な仮定である.(\{ex0})については \begin{align*} E\left[x_i^{\ell}\right] &= E\left[g\left( y_i^{\ell - 1} \right)\right] \\ &\approx E\left[ y_i^{\ell - 1} + g(0)\right] \\ &= E\left[\sum_{j \in \mathcal{J}_{n^\ell}} w_{ij}^{\ell - 1} x_j^{\ell - 1} + g(0)\right] \\ &= \sum_{j \in \mathcal{J}_{n^\ell}} E\left[w_{ij}^{\ell - 1}\right] E \left[ x_j^{\ell - 1} \right] + g(0) \\ &= \sum_{j \in \mathcal{J}_{n^\ell}} 0 \cdot E \left[ x_j^{\ell - 1} \right] + g(0)\\ &= g(0) \end{align*} となるので,\( 0 \)の周りで\( g' \approx 1 \)かつ\( g(0) = 0 \)となるような関数(\(\tanh\)など)を使う限り,妥当な仮定である.
これら仮定の下で(\{v})より \begin{gather*} V\left[x_i^{\ell + 1}\right] = V_{W}^{\ell} \sum_{j \in \mathcal{J}_{n^\ell}} V\left[x_j^{\ell}\right] \\ V\left[x_i^{\ell + 1}\right] = n^\ell V_{W}^{\ell} V\left[x_{1}^{\ell}\right] \end{gather*} したがって\( V\left[x_i^{\ell + 1}\right] = V\left[x_{1}^{\ell}\right] \\ \)とするためには \begin{gather*} V_{W}^{\ell} = \frac{1}{n^\ell}. \label{n} \end{gather*} とすればよい.逆伝播も考慮して \begin{align} V_W^{\ell} = \frac{2}{n^\ell + n^{\ell + 1}} \label{result} \end{align} と定める.この初期化手法はGlorot et al. (2010)で提案されたためGlorot initializationと呼ばれる.
例 \(W^{\ell}\)の各要素が\( \mathcal{U}(-a,\,a) \)(\(a\)は正の定数)から選ばれる場合を考える.このとき各要素の分散\( V^{\ell}_W \)は \begin{align} V^{\ell}_W = \frac{(a - (- a))^2}{12} = \frac{a^2}{3}. \label{uniform} \end{align} で与えられる.(\{result})を利用すると \begin{align*} \frac{a^2}{3} = \frac{2}{n^\ell + n^{\ell + 1}}. \end{align*} \(a\)は正だったことから \begin{align*} a = \sqrt{\frac{6}{n^\ell + n^{\ell + 1}}}. \end{align*} したがって\( W^{\ell} \)の各要素は \begin{gather*} \mathcal{U}\left( - \sqrt{\frac{6}{n^\ell + n^{\ell + 1}}},\, \sqrt{\frac{6}{n^\ell + n^{\ell + 1}}} \right) \end{gather*} からサンプリングすればよい.
(\{ex0})が成り立たない場合 例えば活性化関数が\( g(y) := \mathrm{ReLU}(y) = \max\{0,\,y\} \)のとき\( E\left[x_i^{\ell}\right] > 0 \) となる.適当な仮定をおいて,そのような場合にも(\{v})から\( V_{W}^{\ell} \)を求めたい.各\( x_i \)の確率密度関数\( f \) とし,\( f \)は\( y \)軸について対称であるとする.このとき \begin{align*} E\left[{\left(x_j^{\ell}\right)}^2\right] &= \int_{-\infty}^{\infty} {\left(\max\left\{0,\,y_j^{\ell - 1}\right\}\right)}^2 f\left(y_j^{\ell - 1}\right)dy_j^{\ell - 1} \\ &= \int_{0}^{\infty} {\left( y_j^{\ell - 1} \right)}^2 f\left(y_j^{\ell - 1}\right) dy_j^{\ell - 1} \\ &= \frac{1}{2} \int_{-\infty}^{\infty} {\left( y_j^{\ell - 1} \right)}^2 f\left(y_j^{\ell - 1}\right) dy_j^{\ell - 1} \\ &= \int_{-\infty}^{\infty} \frac{1}{2} {\left(y_j^{\ell - 1} - E\left[y_j^{\ell - 1}\right]\right)}^2 f\left(y_j^{\ell - 1}\right) dy_j^{\ell - 1} \\ &= \frac{1}{2} V\left[y_j^{\ell - 1}\right]. \end{align*} したがって(\{v})に代入して \begin{align*} V\left[y_i^{\ell}\right] = V_{W}^{\ell} \sum_{j \in \mathcal{J}_{n^\ell}} \frac{1}{2} V\left[y_j^{\ell - 1}\right] \end{align*} \(y_i^{\ell} = y_j^{\ell - 1}\)を仮定すると \begin{align*} V_{W}^{\ell} = \frac{2}{n^\ell}. \end{align*} この初期化手法はHe et al. (2015)で提案されたためHe initializationと呼ばれる.
【例】 \( W^{\ell} \)の各要素が\( \mathcal{U}(-a,\,a) \)(\(a\)は正の定数)から選ばれる場合,(\{uniform})を使って \begin{gather*} \frac{a^2}{3} = \frac{2}{n^\ell} \\ a = \sqrt{\frac{6}{n^\ell}} \end{gather*}
参考文献
- Glorot, X., & Bengio, Y. (2010). Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the thirteenth international conference on artificial intelligence and statistics (pp. 249–256). JMLR Workshop and Conference Proceedings.
- He, K., Zhang, X., Ren, S., & Sun, J. (2015). Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. In Proceedings of the IEEE international conference on computer vision (pp. 1026–1034).
- 岡谷貴之(2022).『深層学習 改訂第2版(機械学習プロフェッショナルシリーズ)』.講談社.