SparseBERTの紹介

こんにちは。レトリバのリサーチャーの木村@big_wingです。

今回は今年のICML2021で発表されたSparseBERT: Rethinking the Importance Analysis in Self-attentionを紹介します。

SparseBERTの概要

紹介する論文は西鳥羽が以前のブログで紹介したBig Birdと同様に、Transformerにおけるattentionのスパース化についての論文です。Transformerは入力系列長の2乗の時間計算量、空間計算量を必要とするためスパース化したい気持ちになり、様々なattentionのスパース化についての手法が提案されてきました。しかし、これらのattentionのスパース化手法はいずれも事前に M=\lbrack 0,1\rbrack ^{n\times n}のattentionマスク (どこのマスを計算するか)を人手で決め打ちで固定する必要がありました。ここで nは入力系列長です。以下の図1は様々なattentionマスクの例です。薄い青色のマスが M_{i,j}=1に該当しattentionを計算し、白のマスが M_{i,j}=0に該当しattentionの計算は行いません。

f:id:retrieva:20210727132442p:plain
図1 既存のattentionマスク。薄い青のマスの箇所のみを計算する。図は紹介論文より引用

本論文の貢献は以下の通りです。

  • どの位置のattentionが重要であるかを実験的に確認した。先行研究で重要であると考えられていた対角要素の位置のattentionが重要でないという結果を示した。一方で最初のトークンである[CLS]や最後のトークンである[SEP]は先行研究同様に重要であることを示した。
  • 上記対角要素の位置のattentionがなくてもUniversal Approximability(任意の連続関数を近似できる)が成立することを理論的に示した。

  • attentionマスクをend-to-endで学習するアルゴリズムを提案した。

これらについて順に紹介していきたいと思います。

どの位置のattentionが重要か?

self-attention層の行列表現

最初にself-attention層の行列表現について紹介します。 self-attention層の出力は以下のように書くことができます。


\begin{align}
Attn(X) &=  X + \sum_{k=1}^{H} \sigma(XW^k_{Q}(XW^k_{K})^{\top})XW^k_{V}W^{k\top}_{O} \\
&=  X + \sum_{k=1}^{H} A^k(X)V^k_{X}W^{k\top}_{O}

\end{align}

ここでX\in \mathbb{R}^{n\times d}は入力、Hはマルチヘッドの数、 \sigmaはソフトマックス関数で、 W^{k}_{Q}, W^{k}_{K},  W^{k}_{V}, W^{k}_{O} \in \mathbb{R}^{d\times d_{h}} はそれぞれk番目のヘッドのクエリ、キー、バリュー、アウトプットに対する重み行列です (d_{h}=d/H)

attentionマスク構造の探索問題として定式化

論文ではどの位置のattentionが重要であるかということを実験的に確認しています。 著者らはこの問題を \lbrack 0,1\rbrack ^{n\times n}のattentionマスク構造の探索問題として定式化し、さらに連続緩和を行い問題を解きました。具体的には、


\begin{gather}
P^k_{i,j} = sigmoid(\alpha^k_{i,j}) \in [0,1]\\
\alpha^k_{i,j} = \alpha^k_{j,i}
\end{gather}

の形式で各ヘッドの位置ごとにattentionの存在確率行列を導入しました。ここでP^{k}_{j,i} k番目のヘッドにおけるの位置(i,j)のattentionマスクの存在確率を表し、sigmoidシグモイド関数です。この行列を用いてattentionマスク構造探索問題の連続緩和を行いました。ここで \odotは要素ごとの乗算であるアダマール積です。P^{k}_{j,i}=1とするとこれは従来のself-attention層と一致します。


\begin{gather}
Attn(X) = X + \sum_{k=1}^{H} (P^k \odot A^k(X))V^k_{X}W^{k\top}_{O} \\

\end{gather}

実験結果

論文ではBERT-baseを用いてどの位置のattentionが重要であるかを実験的に確認しています。上記のような連続緩和を行い、MLM(Masked Language Model)とNSP(Next Sentence Prediction)を用いて事前学習を行い、 P^{k}_{i,j}を学習しています。以下の図2は学習後の P^{k}_{i,j}の分布です。この実験においてマルチヘッド数は H=12としていて、図はこれら12個の P^{k}_{i,j}の平均です。 P^{k}_{i,j}の値が小さければその位置のattentionは重要ではなく、逆に大きければ重要であるということが示唆されます。

f:id:retrieva:20210727170856p:plain:w500
図2 学習後の P_{i,j}の分布。図は紹介論文より引用

図2(a)から以下のことがわかります。

  1. 対角要素の位置にある P_{ij}の値が最も小さい。

  2. 対角要素に隣接する位置の P_{ij}の値は大きい。

  3. 最初のトークンである[CLS]や最後のトークンである[SEP]の P_{ij}の値も大きい。

  4. それ以外の位置の P_{ij}の値は同程度。

特に1.が先行研究と矛盾するような結果となっていて興味深いです。 図1のように既存のattentionマスク構造はいずれも対角要素の位置のattentionを考慮しています。この実験結果から先行研究では重要であると考えられていた対角要素の位置にあるattentionが実はそれほど重要ではない可能性が考えられます。

このことを確認するために論文では実際の様々な自然言語処理タスクに対し、対角要素の位置にあるattentionを取り除いた場合の結果を確認しています。以下の表1がその結果です。 GLUEベンチマークデータセットに対して(ours)が著者らによる通常学習、(no diag-attention)が各ヘッドの対角要素の位置のattentionを除いたもの、(randomly-dropped)がattentionをランダムに除いたものになっています。表1からわかるように対角要素の位置にあるattentionを除いてもほとんど精度が悪化していないことが確認できます。

f:id:retrieva:20210727175738p:plain
表1 GLUEデータセットに対する精度結果。表は紹介論文より引用

理論解析: 対角要素のattentionなしでもUniversal Approximabilityが成立

先行研究においてはtransformerの理論的な保証の一つとしてUniversal Approximability(任意の連続関数を近似する)があることが挙げられます。先行研究のUniversal Approximabilityの証明においては対角要素の位置にあるattentionが重要な役割を果たしていました。紹介論文において著者らは対角要素の位置にあるattentionがない場合においても同様にUniversal Approximabilityが成立することを証明しています。

SparseBert: end-to-endでattentionマスクを学習する

最後に論文ではend-to-endでattentionマスクを学習する手法を提案しています。これが論文のタイトルにあるSparseBERTです。 途中紹介したattentionの存在確率行列P_{i,j}を学習する手法では、最終的なattentionマスクを得るために一旦P_{i,j}を学習した後に、適当な閾値でバイナリ化する必要があります。 SparseBERTは、attentionの存在確率行列P_{i,j}を以下のように定式化します。


\begin{gather}
P_{i,j} = sigmoid( ( \alpha_{i,j} + G_1 + G_2 ) / \tau) \\
\end{gather}

ここでG_{1}G_{2}は互いに独立に同一のガンベル分布に従う確率変数で\tauは温度パラメタです。\tauを0に近づける事でP_{i,j}\{0,1\}の2値を値に持つ離散分布となります。この性質によってend-to-endでattentionマスクを直接学習することができます。 また学習する際のロス関数を、


\begin{gather}
\mathcal{L} = l(BERT(X,A(X)\odot P(\alpha); w)) + \lambda \|P(\alpha)\|_{1} \\
\end{gather}

と定義します。ここでl(BERT(X,A(X)\odot P(\alpha); w)) は事前学習についてのロスで\lambdaはattentionマスクのスパース性に関するハイパーパラメタです。以下の図3が学習アルゴリズムです。図のM_{i,j}P_{i,j}と置き換えて下さい。BERTの学習パラメタとattentionマスクのパラメタを同時に更新し、収束するまで繰り返します。

f:id:retrieva:20210727202349p:plain:w500
図3 SparseBERTの学習アルゴリズム

最後にSparseBERTとattentionマスク手法の精度比較を行った実験結果を紹介します。実験はGLUEデータセットを用いて行っています。結果を以下の図4に示します。赤と緑がSparseBERTで、赤が上記のSparseBERT、緑はattentionマスクの構造に制約を入れたSparseBERTです。結果を見ると、SparseBERTが常に性能が最もよいというわけではありませんが、いくつかのタスクでは既存のattentionマスクより精度、スパース性がともに向上しています。

f:id:retrieva:20210727205744p:plain
図4 SparseBERTと他のattentionマスクの精度比較

まとめ

今回はattentionマスクをend-to-endで学習するSparseBERTについて紹介しました。先行研究において重要視されていた対角要素の位置にあるattentionの重要性の検証や、対角要素の位置にあるattentionがなくてもUniversal Approximabilityが成り立つことの証明などはとても興味深かったです。 また今回の実験的な検証は自然言語処理におけるタスクで検証していましたが、これが音声認識やゲノムなどのタスクにおいても同様の結果となるのかは気になりました。