BERTでの語彙追加~add_tokenに気をつけろ!~

こんにちは。レトリバの飯田(@meshidenn)です。TSUNADE事業部 研究チームのリーダーをしており、マネジメントや論文調査、受託のPOCを行なっています。

みなさんは、BERTなどの学習済み言語モデルに対して語彙を追加したくなることはありませんか? 諸々の論文(こちらこちらこちら)により、特定ドメインやrare-wordの語彙を追加することによって、性能が上がることが知られています。

そこで、語彙を追加しようと思い、TransformersのTokenizerの仕様を見ると、add_tokens という関数があります。これを使えば、tokenizerに語彙を追加できるので、あとはembedding側にも新しい語彙を受け取れるようにすれば万事解決です!

とは、うまくいかないので、今回はこの辺りについて、ちょっとした解説をします。

add_tokensの問題点

なにがうまくいかないのかというと、語彙数が増えたとき、add_tokensを使うとかなり遅くなります。なぜでしょうか。 これは、add_tokensで追加された語彙に対しては、先にTrieを使って最長一致で分割を行ってから、通常のTokenize処理を行うためと考えられます。まずは、Tokenizerの中身を見てみましょう。なお、今回は、FastTokenizerではない方についてのお話ですが、大量に語彙を追加すると、add_tokensをすると遅くなるのは、FastTokenizerでも同様です。

add_tokensをした場合のTokenize処理

tokenizerの処理は、PreTrainedTokenizerクラスに書かれています。このクラスは、各モデルのTokenizerの親クラスです。以下、該当部分のみ抜粋したものを記載します。(全体はこちら)

def tokenize(self, text: TextInput, **kwargs) -> List[str]:
    ...
    no_split_token = set(self.unique_no_split_tokens)
    tokens = self.tokens_trie.split(text)
    # ["This is something", "<special_token_1>", "  else"]
    for i, token in enumerate(tokens):
        ....
    # ["This is something", "<special_token_1>", "else"]
    tokenized_text = []
    for token in tokens:
        # Need to skip eventual empty (fully stripped) tokens
        if not token:
            continue
        if token in no_split_token:
            tokenized_text.append(token)
        else:
            tokenized_text.extend(self._tokenize(token))
    # ["This", " is", " something", "<special_token_1>", "else"]
    return tokenized_text

コードに記載されているように、まず、 self.tokens_trie.split(text) によって、This is somthing <special_token_1> else という文字列が、["This is something", "<special_token_1>", " else"] というように、 <special_token_1> の部分だけ分割されます。そして、スペースなど諸々の不要な文字を削った後、no_split_tokenではない部分について、 self._tokenize() を使ってtokenizeしています。self._tokenize() の部分は、直ぐ下にありますが、このクラス内には特に実装がなく、下位のクラスで実装されています。

さて、名前からself.tokens_trie.split()はTrieを使って分割をしています。Trieは、高速に共通接頭辞検索を行えるデータ構造です。詳しくはこちらこちらの書籍を参照ください。Trieによる分割は、最長一致分割で行われます。このコードはすべてpythonで実装されています(実装はこちら)。こちらのsplit関数の中 を見ると、現在の文字が、Trieに辞書登録された語であるかどうかによって、status に入力がなされ、for文が周ります。よって、語彙が追加されれば追加されるほど、pythonのfor文が回ることになり、速度低下の原因になっていると考えられます。

なお、self.unique_no_split_tokens は add_tokensによって追加されるtokenであることが、コードを辿るとわかります。よって、追加された語彙はself._tokeize()で分割されないようになっています。

対応方法

どのようにすれば、速度低下が防げるでしょうか。少し調べてみると、vocab.txtに書き込むという方法があるようです。

実験

まず、英語の場合で見てみましょう。今回の実験には、医療系の検索タスクである、nfcorpusを使用しました。

まず、何も追加しなかった場合の速度です。(e_textsは、nfcorpusの本文を格納したlistです。)

In [-]: etk_a = BertTokenizer.from_pretrained("bert-base-uncased")

In [-]: %time print(etk.tokenize(e_texts[0]))
Out[-]: ['recent', 'studies', 'have', 'suggested', 'that', 'stat', '##ins', ',', 'an', 'established', 'drug', 'group', 'in', 'the', 'prevention', 'of', 'cardiovascular', 'mortality', ',', 'could', 'delay', 'or', 'prevent', 'breast', 'cancer'...]
CPU times: user 4.38 ms, sys: 0 ns, total: 4.38ms
Wall time: 4.38 ms

次に、nfcorpusに対して、space区切りでtokenizeした結果を追加した場合です。

In [-]: etk_a = BertTokenizer.from_pretrained("bert-base-uncased")

In [-]: etk_a.add_tokens(list(e_tk_vocab))
Out[-]: 55449

In [-]: %time print(etk_a .tokenize(e_texts[0]))
Out[-]: ['recent', 'studies', 'have', 'su', 'gg', 'este', '##d', 'that', 'statins,', 'an', 'est', 'abl', 'bl', 'is', '##hed', 'drug',...]
CPU times: user 320 ms, sys: 0 ns, total: 320 ms
Wall time: 320 ms

In [132]: "statins" in e_tk_vocab
Out[132]: True

かかった時間が、4.4msから320msになっており、かなり遅くなっていることがわかります。 また、追加した語彙である、 statins が tokenizeされていないことがわかります。

最後に、vocab.txtに書き込む方式です。

In [-]: etk.save_pretrained("./tmp_tk")
Out[-]:
('./tmp_tk/tokenizer_config.json',
 './tmp_tk/special_tokens_map.json',
 './tmp_tk/vocab.txt',
 './tmp_tk/added_tokens.json')
`
In [-]: with open("./tmp_tk/vocab.txt", "a") as f:
     ...:     for v in e_tk_vocab:
     ...:         print(v, file=f)
     ...:

In [-]: etk_f = BertTokenizer("./tmp_tk/vocab.txt", do_lower_case = True)

In [-]: %time print(etk_f.tokenize(e_texts[0]))
['recent', 'studies', 'have', 'suggested', 'that', 'statins', ',', 'an', 'established', 'drug', 'group', ...]
CPU times: user 4.8 ms, sys: 0 ns, total: 4.8 ms
Wall time: 4.82 ms

変わりない速度で、実行できています。また、追加している語彙である、statins も分割されずに残っています。

日本語だとどうなの?

日本語の場合も、概ね同じです。add_tokensで追加した場合は、前述のように、Trieで分割され、追加した語彙はそのままの状態で返されます。一方、ファイルから語彙を追加した場合は、 _tokenize()によって、分割されます。しかし若干違いがあります。BERTJapaneseTokenizerの_tokenize()を見てみると、内部でself.word_tokenizer が使われています。この部分は、BasicTokenizerとMecabTokenizerのどちらかを使用します。MecabTokenizerの場合には、実際には、fugashiという、Mecabpython wrapperが使用されています。そのため、 word_tokenizermecab を指定した場合、fugashiが分割するのであれば、新しく追加した語彙も分割されます。

余談ですが、英語でもdo_basic_tokenize=Trueとして、tokenizerを初期化すると、BasicTokenizerで分割されるような複合語の場合に、辞書に登録する方式を使用する場合と同様に分割されます。

実験

こちらの実験にはwikipediaの航空宇宙産業のページを使用しています。

まずは、東北大学から公開されている、cl-tohoku/bert-base-japaneseをそのまま使用した場合です。

In [-]: jp_text = "航空宇宙産業 (Aerospace Industry) とは、航空機や航空機の部品、ミサイル、ロケット、宇宙船を製造する産業である。この産業には、設計、製造、テスト、販売、整備などの工程がある。その規模が大きければ部分的に関わる企業、組織が存在する。本項では、エアロスペース・マニュファクチャー(英語: Aerospace manufacturer)についても述べる。"

In [-]: jtk = BertJapaneseTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")

In [-]: print(jtk.tokenize(jp_text))
['航空', '宇宙', '産業', '(', 'A', '##ero', '##sp', '##ace', 'Ind', '##ust', '##ry', ')', 'と', 'は', '、', '航空機', 'や', , '航空機', 'の', '部品', '、', ...]

次に、add_tokensをした場合です。追加された語彙が分割されずに残っています。

In [-]: jtk_a = BertJapaneseTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")

In [-]: jtk_a.add_tokens(["航空宇宙産業"])
Out[-]: 1

In [-]: print(jtk_a.tokenize(jp_text))
['航空宇宙産業', '(', 'A', '##ero', '##sp', '##ace', 'Ind', '##ust', '##ry', ')', 'と', 'は', '、', '航空機', 'や', '航空機', 'の', '部品', '、', ...]

最後に、vocab.txtに書き込む場合です。fugashiで分割されてしまうため、語彙を追加しても分割されていることがわかります。

In [-]: jtk_f = BertJapaneseTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")

In [-]: jtk_f.save_pretrained("./tmp_tk/")
Out[-]:
('./tmp_tk/tokenizer_config.json',
 './tmp_tk/special_tokens_map.json',
 './tmp_tk/vocab.txt',
 './tmp_tk/added_tokens.json')

In [-]: with open("./tmp_tk/vocab.txt", "a") as f:
    ...:     print("航空宇宙産業", file=f)
    ..:

In [-]: jtk_f = BertJapaneseTokenizer("./tmp_tk/vocab.txt", word_tokenizer_type="mecab")

In [-]: print(jtk_f.tokenize(jp_text))
['航空', '宇宙', '産業', '(', 'A', '##ero', '##sp', '##ace', 'Ind', '##ust', '##ry', ')', 'と', 'は', '、', '航空機', 'や', ...]

なお、defaultのword_tokenizerは、BasicTokenizerです。こちらを使うと以下のように追加した語彙はきちんと残っています。一方、##航空, ##機, ##の などと分割されていることから、かなり長いwordをsubwordに分割していることが見て取れます。

In [-]: jtk_f = BertJapaneseTokenizer("./tmp_tk/vocab.txt")

In [-]: print(jtk_f.tokenize(jp_text))
['航空宇宙産業', '(', 'A', '##ero', '##sp', '##ace', 'Ind', '##ust', '##ry', ')', 'と', '##は', '、', '航空機', '##や', '##航空', '##機', '##の', '##部', '##品', '、', ...]

まとめ

この記事では、TransformersにあるTokenizerに対して、語彙を追加する方法をご紹介しました。また、方法による違いについてご紹介しました。なお、FastTokenizerはこの限りではない箇所もありますので、ご注意ください。