位置エンコーディング

Posted: 2023-03-10

定義 \(D\) を 正の偶数とする. \(\mathbb{R}^D\) への位置エンコーディング \(f\) は, \(\mathbb{N} = \{0,\,1,\,\ldots\}\) から \(\mathbb{R}^D\) への写像で, 像 \(f(t)\) の第 \(i\) 成分 \({f(t)}_i\) が \begin{align*} {f(t)}_i = \begin{cases} \sin (\omega_{\varphi(i)}\, t) & \text{if \(i \equiv 1 \bmod 2\)}, \\ \cos (\omega_{\varphi(i)}\, t) & \text{otherwise}, \\ \end{cases} \end{align*} ただし, \begin{align*} \varphi(i) = 2 \left\lfloor \frac{i - 1}{2} \right\rfloor,\quad \omega_k = \frac{1}{10000^{ \frac{k}{D} } } \end{align*} のように定める.

\begin{align*} \varphi(1) = 0,\,\varphi(2) = 0,\,\varphi(3) = 2,\,\varphi(4) = 2,\,\varphi(5) = 4,\,\varphi(6) = 4,\,\ldots \end{align*}

\(D = 128\) のとき \begin{gather*} f(0) = \begin{pmatrix} \sin (0) \\ \cos (0) \\ \vdots \\ \sin (0) \\ \cos (0) \end{pmatrix} = \begin{pmatrix} 0 \\ 1 \\ \vdots \\ 0 \\ 1 \end{pmatrix} ,\\[4pt] f(1) = \begin{pmatrix} \sin (\omega_{\varphi(1)}) \\ \cos (\omega_{\varphi(2)}) \\ \vdots \\ \sin (\omega_{\varphi(127)}) \\ \cos (\omega_{\varphi(128)}) \end{pmatrix} = \begin{pmatrix} \sin {\frac{1}{10000^{0/128}}} \\ \cos {\frac{1}{10000^{0/128}}} \\ \vdots \\ \sin {\frac{1}{10000^{126/128}}} \\ \cos {\frac{1}{10000^{126/128}}} \\ \end{pmatrix} \approx \begin{pmatrix} 0.8414 \\ 0.5403 \\ \vdots \\ 0.0001 \\ 0.9999 \end{pmatrix} ,\\[4pt] f(2) = \begin{pmatrix} \sin (\omega_{\varphi(1)} \cdot 2) \\ \cos (\omega_{\varphi(2)} \cdot 2) \\ \vdots \\ \sin (\omega_{\varphi(127)} \cdot 2) \\ \cos (\omega_{\varphi(128)} \cdot 2) \end{pmatrix} = \begin{pmatrix} \sin {\frac{2}{10000^{0/128}}} \\ \cos {\frac{2}{10000^{0/128}}} \\ \vdots \\ \sin {\frac{2}{10000^{126/128}}} \\ \cos {\frac{2}{10000^{126/128}}} \\ \end{pmatrix} \approx \begin{pmatrix} 0.9092 \\ -0.4161 \\ \vdots \\ 0.0002 \\ 0.9999 \end{pmatrix} \end{gather*} となる.

なぜ三角関数を使うのか?

三角関数を用いる理由は,各成分の周期を変えることにより, 位置に固有のベクトルを生成できるからである. 下の 図 1 と 図 2 を見ると分かるように, \(\sin (\omega\, t) \) は \(\omega\) が小さくなるほど,周期 \((= 2\pi / \omega)\) が大きくなる. すなわち,成分を縦に並べたとき,下の方にある成分ほどゆっくり変化する.
図 1:\(y = \sin x\) のグラフ.周期は \(2\pi\).
図 2:\(\displaystyle y = \sin \frac{1}{2}x\) のグラフ.周期は \(\displaystyle \frac{2\pi }{1 / 2} = 4\pi\).
ヒートマップを使って可視化すると 図 3 のようになる.
図 3:\(D = 128,\,t = 0,\,\ldots,\,50\) のヒートマップ.
ヒートマップの作成に使用した Python スクリプト.
import itertools

import numpy
import seaborn
from matplotlib import pyplot


def main() -> None:
    D = 128
    T = 51
    data = numpy.zeros((T, D), dtype=float)
    for (t, d) in itertools.product(range(T), range(D)):
        f = numpy.sin if d % 2 == 0 else numpy.cos
        omega = 1 / (10_000 ** (2 * numpy.floor(d / 2) / D))
        data[t, d] = f(omega * t)
    h = seaborn.heatmap(data.T, vmin=-1.0, cbar_kws={'ticks': [-1.0, -0.5, 0.0, 0.5, 1.0]})
    h.set_xlabel('t (position)')
    h.set_ylabel('i (dimension)')
    xticklabels = [t for t in range(T) if t % 10 == 0]
    h.set_xticks([x + 0.5 for x in xticklabels])
    h.set_xticklabels(xticklabels)
    yticklabels = [1] + [y for y in range(D + 1) if y % 16 == 0 and y > 0]
    h.set_yticks([y - 0.5 for y in yticklabels])
    h.set_yticklabels(yticklabels)
    pyplot.savefig('../images/heatmap.svg', bbox_inches='tight', transparent=True)
    pyplot.show()


if __name__ == '__main__':
    main()

なぜ \(\sin\) と \(\cos\) の両方を使うのか?また,なぜ \(\omega\) の添え字が2個ずつ同じになっているのか?

位置の変位を線型変換に対応させたいから.

We chose this function because we hypothesized it would allow the model to easily learn to attend by relative positions, since for any fixed offset \(k\), \(PE_{ \mathit{pos} + k}\) can be represented as a linear function of \(PE_\mathit{pos}\).

証明 \begin{align*} f(t + t') &= \begin{pmatrix} \sin (\omega_{\varphi(1)}\, (t + t')) \\ \cos (\omega_{\varphi(2)}\, (t + t')) \\ \vdots \\ \sin (\omega_{\varphi(D - 1)}\, (t + t')) \\ \cos (\omega_{\varphi(D)}\, (t + t')) \end{pmatrix} \\[4pt] &= \begin{pmatrix} \sin (\omega_0 \, t + \omega_0 \, t') \\ \cos (\omega_0 \, t + \omega_0 \, t') \\ \vdots \\ \sin (\omega_{D/2}\, t + \omega_{D/2}\, t') \\ \cos (\omega_{D/2}\, t + \omega_{D/2}\, t') \end{pmatrix} \\[4pt] &= \begin{pmatrix} \sin (\omega_0\, t) \cos (\omega_0\, t') + \cos (\omega_0\, t) \sin (\omega_0\, t') \\ \cos (\omega_0\, t) \cos (\omega_0\, t') - \sin (\omega_0\, t) \sin (\omega_0\, t') \\ \vdots \\ \sin (\omega_{D/2}\, t) \cos (\omega_{D/2}\, t') + \cos (\omega_{D/2}\, t) \sin (\omega_{D/2}\, t') \\ \cos (\omega_{D/2}\, t) \cos (\omega_{D/2}\, t') - \sin (\omega_{D/2}\, t) \sin (\omega_{D/2}\, t') \\ \end{pmatrix} \\[4pt] &= \begin{pmatrix} \cos(\omega_0\,t') & \sin(\omega_0\,t') & \cdots & 0 & 0 \\ -\sin(\omega_0\,t') & \cos(\omega_0\,t') & \cdots & 0 & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & \cdots & \cos(\omega_{D/2}\,t') & \sin(\omega_{D/2}\,t') \\ 0 & 0 & \cdots & -\sin(_{D/2}\,t') & \cos(\omega_{D/2}\,t') \\ \end{pmatrix} \begin{pmatrix} \sin (\omega_0\, t) \\ \cos (\omega_0\, t) \\ \vdots \\ \sin (\omega_{D/2}\, t) \\ \cos (\omega_{D/2}\, t) \end{pmatrix} \\[4pt] &= A(t') f(t). \end{align*} ここで \begin{gather} A(t') = \begin{pmatrix} R_0\,(t) & O & \cdots & O \\ O & R_2\,(t) & \cdots & O \\ \vdots & \vdots & \ddots & \vdots \\ O & O & \cdots & R_{D/2}\,(t) \\ \end{pmatrix}, \notag \\[4pt] R_k(t') = \begin{pmatrix} \cos (\omega_k\,t) & \sin (\omega_k\, t) \\ -\sin (\omega_k\,t) & \cos (\omega_k\, t) \\ \end{pmatrix}. \tag*{∎} \end{gather}

\(R_k (t')\) は \(\mathbb{R}^2\) の要素を \(- \omega_k\, t'\) だけ回転させる線型変換である. この結果を得るために \(\sin\), \(\cos\) の両方を用い,また \(\omega\) の添え字を2個ずつ同じにしているのである.

参考文献