こんにちは。レトリバのリサーチャーの木村@big_wingです。
今回は今年のICML2021で発表されたSparseBERT: Rethinking the Importance Analysis in Self-attentionを紹介します。
- SparseBERTの概要
- どの位置のattentionが重要か?
- 理論解析: 対角要素のattentionなしでもUniversal Approximabilityが成立
- SparseBert: end-to-endでattentionマスクを学習する
- まとめ
SparseBERTの概要
紹介する論文は西鳥羽が以前のブログで紹介したBig Birdと同様に、Transformerにおけるattentionのスパース化についての論文です。Transformerは入力系列長の2乗の時間計算量、空間計算量を必要とするためスパース化したい気持ちになり、様々なattentionのスパース化についての手法が提案されてきました。しかし、これらのattentionのスパース化手法はいずれも事前にのattentionマスク (どこのマスを計算するか)を人手で決め打ちで固定する必要がありました。ここでは入力系列長です。以下の図1は様々なattentionマスクの例です。薄い青色のマスがに該当しattentionを計算し、白のマスがに該当しattentionの計算は行いません。
本論文の貢献は以下の通りです。
- どの位置のattentionが重要であるかを実験的に確認した。先行研究で重要であると考えられていた対角要素の位置のattentionが重要でないという結果を示した。一方で最初のトークンである[CLS]や最後のトークンである[SEP]は先行研究同様に重要であることを示した。
上記対角要素の位置のattentionがなくてもUniversal Approximability(任意の連続関数を近似できる)が成立することを理論的に示した。
attentionマスクをend-to-endで学習するアルゴリズムを提案した。
これらについて順に紹介していきたいと思います。
どの位置のattentionが重要か?
self-attention層の行列表現
最初にself-attention層の行列表現について紹介します。 self-attention層の出力は以下のように書くことができます。
ここでは入力、はマルチヘッドの数、はソフトマックス関数で、 はそれぞれ番目のヘッドのクエリ、キー、バリュー、アウトプットに対する重み行列です。
attentionマスク構造の探索問題として定式化
論文ではどの位置のattentionが重要であるかということを実験的に確認しています。 著者らはこの問題をのattentionマスク構造の探索問題として定式化し、さらに連続緩和を行い問題を解きました。具体的には、
の形式で各ヘッドの位置ごとにattentionの存在確率行列を導入しました。ここでは番目のヘッドにおけるの位置のattentionマスクの存在確率を表し、はシグモイド関数です。この行列を用いてattentionマスク構造探索問題の連続緩和を行いました。ここでは要素ごとの乗算であるアダマール積です。とするとこれは従来のself-attention層と一致します。
実験結果
論文ではBERT-baseを用いてどの位置のattentionが重要であるかを実験的に確認しています。上記のような連続緩和を行い、MLM(Masked Language Model)とNSP(Next Sentence Prediction)を用いて事前学習を行い、を学習しています。以下の図2は学習後のの分布です。この実験においてマルチヘッド数はとしていて、図はこれら12個のの平均です。の値が小さければその位置のattentionは重要ではなく、逆に大きければ重要であるということが示唆されます。
図2(a)から以下のことがわかります。
特に1.が先行研究と矛盾するような結果となっていて興味深いです。 図1のように既存のattentionマスク構造はいずれも対角要素の位置のattentionを考慮しています。この実験結果から先行研究では重要であると考えられていた対角要素の位置にあるattentionが実はそれほど重要ではない可能性が考えられます。
このことを確認するために論文では実際の様々な自然言語処理タスクに対し、対角要素の位置にあるattentionを取り除いた場合の結果を確認しています。以下の表1がその結果です。 GLUEベンチマークデータセットに対して(ours)が著者らによる通常学習、(no diag-attention)が各ヘッドの対角要素の位置のattentionを除いたもの、(randomly-dropped)がattentionをランダムに除いたものになっています。表1からわかるように対角要素の位置にあるattentionを除いてもほとんど精度が悪化していないことが確認できます。
理論解析: 対角要素のattentionなしでもUniversal Approximabilityが成立
先行研究においてはtransformerの理論的な保証の一つとしてUniversal Approximability(任意の連続関数を近似する)があることが挙げられます。先行研究のUniversal Approximabilityの証明においては対角要素の位置にあるattentionが重要な役割を果たしていました。紹介論文において著者らは対角要素の位置にあるattentionがない場合においても同様にUniversal Approximabilityが成立することを証明しています。
SparseBert: end-to-endでattentionマスクを学習する
最後に論文ではend-to-endでattentionマスクを学習する手法を提案しています。これが論文のタイトルにあるSparseBERTです。 途中紹介したattentionの存在確率行列を学習する手法では、最終的なattentionマスクを得るために一旦を学習した後に、適当な閾値でバイナリ化する必要があります。 SparseBERTは、attentionの存在確率行列を以下のように定式化します。
ここで、は互いに独立に同一のガンベル分布に従う確率変数では温度パラメタです。を0に近づける事で はの2値を値に持つ離散分布となります。この性質によってend-to-endでattentionマスクを直接学習することができます。 また学習する際のロス関数を、
と定義します。ここでは事前学習についてのロスではattentionマスクのスパース性に関するハイパーパラメタです。以下の図3が学習アルゴリズムです。図のはと置き換えて下さい。BERTの学習パラメタとattentionマスクのパラメタを同時に更新し、収束するまで繰り返します。
最後にSparseBERTとattentionマスク手法の精度比較を行った実験結果を紹介します。実験はGLUEデータセットを用いて行っています。結果を以下の図4に示します。赤と緑がSparseBERTで、赤が上記のSparseBERT、緑はattentionマスクの構造に制約を入れたSparseBERTです。結果を見ると、SparseBERTが常に性能が最もよいというわけではありませんが、いくつかのタスクでは既存のattentionマスクより精度、スパース性がともに向上しています。
まとめ
今回はattentionマスクをend-to-endで学習するSparseBERTについて紹介しました。先行研究において重要視されていた対角要素の位置にあるattentionの重要性の検証や、対角要素の位置にあるattentionがなくてもUniversal Approximabilityが成り立つことの証明などはとても興味深かったです。 また今回の実験的な検証は自然言語処理におけるタスクで検証していましたが、これが音声認識やゲノムなどのタスクにおいても同様の結果となるのかは気になりました。