ニューラルネットの線型層の初期化

Posted: 2022-09-26

前提 各層の入出力の分散が一定であることが望ましい.

ニューラルネットの第\( \ell \)番目の線型層について,入力の次元を\( n^\ell \),出力の次元を\( n^{\ell + 1} \),入力と重みをそれぞれ \begin{align*} x^\ell = \begin{pmatrix} x^\ell_1 \\ \vdots \\ x^\ell_{n^\ell} \end{pmatrix},\quad W^{\ell} = \begin{pmatrix} w_{11}^{\ell} & \cdots & w_{1 n^\ell}^{\ell} \\ \vdots & \ddots & \vdots \\ w_{n^{\ell + 1} 1}^{\ell} & \cdots & w_{n^{\ell + 1} n^\ell}^{\ell} \end{pmatrix} \end{align*} とする(※\( \ell \)は上付き添え字であり,累乗の指数ではない.).各\( w_{ij}^{\ell} \)の各成分は期待値\( 0 \)の同一の分布から独立に選ぶ.このとき任意の\( i,\,j,\,k \)について\( x_k \)と\( w_{ij} \)は独立になる.線型層の出力\( y^\ell := W^{\ell}x^\ell \)の第\( i \)成分の分散は,\( \mathcal{J}_{n^\ell} := \{1,\,\ldots,\,n^\ell\} \)と表すことにすると \begin{align} V\left[y^\ell_i\right] &= V \left[ \sum_{j \in \mathcal{J}_{n^\ell}} w_{ij}^{\ell} x_j^{\ell} \right] \notag \\ &= E\left[ { \left( \sum_{j \in \mathcal{J}_{n^\ell}} w_{ij}^{\ell} x_i^{\ell} \right) }^2 \right] - { \left( E\left[ \sum_{j \in \mathcal{J}_{n^\ell}} w_{ij}^{\ell} x_i^{\ell} \right] \right) }^2 \label{a} \end{align} となる.

(\{a})の右辺第1項について \begin{align} E\left[ { \left( \sum_{j \in \mathcal{J}_{n^\ell}} w_{ij}^{\ell} x_i^{\ell} \right) }^2 \right] &= \sum_{(j,\, j') \in \mathcal{J}_{n^\ell} \times \mathcal{J}_{n^\ell}} E\left[w_{ij}^{\ell} x_{j}^{\ell} w_{ij'}^{\ell} x_{j'}^{\ell}\right] \notag \\ &= \sum_{j \in \mathcal{J}_{n^\ell}} E\left[ { \left( w_{ij}^{\ell} \right) }^2 \right] E\left[{\left(x_j^{\ell}\right)}^2 \right] + \sum_{ \substack{ (j,\, j') \in \mathcal{J}_{n^\ell} \times \mathcal{J}_{n^\ell} \\ j \neq j' } } E\left[w_{ij}^{\ell}\right] E \left[w_{ij'}^{\ell} \right] E\left[x_j^{\ell} x_{j'}^{\ell} \right] \notag \\ &= \sum_{j \in \mathcal{J}_{n^\ell}} \left(V\left[w_{ij}\right] - \left(E[w_{ij}]\right)^2\right) E \left[{\left(x_j^{\ell}\right)}^2\right] + \sum_{ \substack{ (j,\, j') \in \mathcal{J}_{n^\ell} \times \mathcal{J}_{n^\ell} \\ j \neq j' } } 0 \cdot 0 \cdot E\left[x_j^{\ell} x_{j'}^{\ell} \right] \notag \\ &= V_{W}^{\ell} \sum_{j \in \mathcal{J}_{n^\ell}} E \left[{\left(x_j^{\ell}\right)}^2\right]. \label{w1} \end{align} ただし\( V_{W}^{\ell} = V\left[w_{i1}^{\ell}\right] = \cdots V\left[w_{n^\ell}^{\ell}\right] \)と置いた.

(\{a})の右辺第2項について \begin{align} {\left(\sum_{j \in \mathcal{J}_{n^\ell}} E\left[w_{ij}^{\ell}\right]E\left[x_j^{\ell}\right]\right)}^2 = {\left(\sum_{j \in \mathcal{J}_{n^\ell}} 0 \cdot E\left[x_j^{\ell}\right]\right)}^2 = 0 \label{w2} \end{align} が成り立つ.(\{w1})と\((\ref{w2})\)を合わせて \begin{align} V\left[y_i^{\ell}\right] = V_{W}^{\ell} \sum_{j \in \mathcal{J}_{n^\ell}} E \left[{\left(x_j^{\ell}\right)}^2\right]. \label{v} \end{align} という式を得る.

活性化関数を\( g \)とする.このとき\( x_i^{\ell + 1} := g(y_i^{\ell}) \)である.以下が成り立つと仮定する: \begin{gather} V\left[x_i^{\ell + 1}\right] \approx V\left[y_i^{\ell}\right] \text{ for all \(i \in \mathcal{J}_{n^{\ell + 1}}\)}, \tag{H1} \label{H1} \\ E\left[x_j^{\ell}\right] = 0 \text{ for all \(j \in \mathcal{J}_{n^\ell}\)}. \tag{H2} \label{ex0} \end{gather}

(\{H1})について,活性化関数として用いられるものは,通常\( 0 \) の周りで\( g' \approx 1 \)を満たす.すなわち\( g(y) \approx y + g(0) \)と書ける.したがって \begin{align*} V\left[x_i^{\ell + 1}\right] = V\left[g\left(y_i^{\ell + 1}\right)\right] \approx V\left[y_i^{\ell}+ g(0)\right] = V\left[y_i^{\ell}\right] \end{align*} が成り立つため,妥当な仮定である.(\{ex0})については \begin{align*} E\left[x_i^{\ell}\right] &= E\left[g\left( y_i^{\ell - 1} \right)\right] \\ &\approx E\left[ y_i^{\ell - 1} + g(0)\right] \\ &= E\left[\sum_{j \in \mathcal{J}_{n^\ell}} w_{ij}^{\ell - 1} x_j^{\ell - 1} + g(0)\right] \\ &= \sum_{j \in \mathcal{J}_{n^\ell}} E\left[w_{ij}^{\ell - 1}\right] E \left[ x_j^{\ell - 1} \right] + g(0) \\ &= \sum_{j \in \mathcal{J}_{n^\ell}} 0 \cdot E \left[ x_j^{\ell - 1} \right] + g(0)\\ &= g(0) \end{align*} となるので,\( 0 \)の周りで\( g' \approx 1 \)かつ\( g(0) = 0 \)となるような関数(\(\tanh\)など)を使う限り,妥当な仮定である.

これら仮定の下で(\{v})より \begin{gather*} V\left[x_i^{\ell + 1}\right] = V_{W}^{\ell} \sum_{j \in \mathcal{J}_{n^\ell}} V\left[x_j^{\ell}\right] \\ V\left[x_i^{\ell + 1}\right] = n^\ell V_{W}^{\ell} V\left[x_{1}^{\ell}\right] \end{gather*} したがって\( V\left[x_i^{\ell + 1}\right] = V\left[x_{1}^{\ell}\right] \\ \)とするためには \begin{gather*} V_{W}^{\ell} = \frac{1}{n^\ell}. \label{n} \end{gather*} とすればよい.逆伝播も考慮して \begin{align} V_W^{\ell} = \frac{2}{n^\ell + n^{\ell + 1}} \label{result} \end{align} と定める.この初期化手法はGlorot et al. (2010)で提案されたためGlorot initializationと呼ばれる.

\(W^{\ell}\)の各要素が\( \mathcal{U}(-a,\,a) \)(\(a\)は正の定数)から選ばれる場合を考える.このとき各要素の分散\( V^{\ell}_W \)は \begin{align} V^{\ell}_W = \frac{(a - (- a))^2}{12} = \frac{a^2}{3}. \label{uniform} \end{align} で与えられる.(\{result})を利用すると \begin{align*} \frac{a^2}{3} = \frac{2}{n^\ell + n^{\ell + 1}}. \end{align*} \(a\)は正だったことから \begin{align*} a = \sqrt{\frac{6}{n^\ell + n^{\ell + 1}}}. \end{align*} したがって\( W^{\ell} \)の各要素は \begin{gather*} \mathcal{U}\left( - \sqrt{\frac{6}{n^\ell + n^{\ell + 1}}},\, \sqrt{\frac{6}{n^\ell + n^{\ell + 1}}} \right) \end{gather*} からサンプリングすればよい.

(\{ex0})が成り立たない場合 例えば活性化関数が\( g(y) := \mathrm{ReLU}(y) = \max\{0,\,y\} \)のとき\( E\left[x_i^{\ell}\right] > 0 \) となる.適当な仮定をおいて,そのような場合にも(\{v})から\( V_{W}^{\ell} \)を求めたい.各\( x_i \)の確率密度関数\( f \) とし,\( f \)は\( y \)軸について対称であるとする.このとき \begin{align*} E\left[{\left(x_j^{\ell}\right)}^2\right] &= \int_{-\infty}^{\infty} {\left(\max\left\{0,\,y_j^{\ell - 1}\right\}\right)}^2 f\left(y_j^{\ell - 1}\right)dy_j^{\ell - 1} \\ &= \int_{0}^{\infty} {\left( y_j^{\ell - 1} \right)}^2 f\left(y_j^{\ell - 1}\right) dy_j^{\ell - 1} \\ &= \frac{1}{2} \int_{-\infty}^{\infty} {\left( y_j^{\ell - 1} \right)}^2 f\left(y_j^{\ell - 1}\right) dy_j^{\ell - 1} \\ &= \int_{-\infty}^{\infty} \frac{1}{2} {\left(y_j^{\ell - 1} - E\left[y_j^{\ell - 1}\right]\right)}^2 f\left(y_j^{\ell - 1}\right) dy_j^{\ell - 1} \\ &= \frac{1}{2} V\left[y_j^{\ell - 1}\right]. \end{align*} したがって(\{v})に代入して \begin{align*} V\left[y_i^{\ell}\right] = V_{W}^{\ell} \sum_{j \in \mathcal{J}_{n^\ell}} \frac{1}{2} V\left[y_j^{\ell - 1}\right] \end{align*} \(y_i^{\ell} = y_j^{\ell - 1}\)を仮定すると \begin{align*} V_{W}^{\ell} = \frac{2}{n^\ell}. \end{align*} この初期化手法はHe et al. (2015)で提案されたためHe initializationと呼ばれる.

【例】 \( W^{\ell} \)の各要素が\( \mathcal{U}(-a,\,a) \)(\(a\)は正の定数)から選ばれる場合,(\{uniform})を使って \begin{gather*} \frac{a^2}{3} = \frac{2}{n^\ell} \\ a = \sqrt{\frac{6}{n^\ell}} \end{gather*}

参考文献

  1. Glorot, X., & Bengio, Y. (2010). Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the thirteenth international conference on artificial intelligence and statistics (pp. 249–256). JMLR Workshop and Conference Proceedings.
  2. He, K., Zhang, X., Ren, S., & Sun, J. (2015). Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. In Proceedings of the IEEE international conference on computer vision (pp. 1026–1034).
  3. 岡谷貴之(2022).『深層学習 改訂第2版(機械学習プロフェッショナルシリーズ)』.講談社.