こんにちは。Chief Research Officerの西鳥羽です。今回はPromptingと呼ばれる手法とその手法を手軽に扱えるフレームワークのOpenPromptの紹介を行います。
Promptingとは
Prompting(或いはprompt learning)とは、分類・NLIなどでラベルやカテゴリなどの形で定式化されているタスクを自然文生成の形にして解く手法です。例えば、「4番打者がホームランを打った」のような文に対してそのカテゴリを当てるとします。通常の分類問題で解く際は「首相官邸で公式の発表があった」は「政治」のカテゴリとか、「日経平均株価が上昇した」は「経済」のカテゴリなどの学習データを用いて分類器を学習し、「政治」「経済」「スポーツ」などのカテゴリを推定します。 一方、Promptingでは言語モデルを用いて推定を行います。先ほどの「4番打者がホームランを打った」の場合でしたら「4番打者がホームランを打った。この文書のカテゴリは」を入力として、言語モデルが次に推定する単語を用いてカテゴリを推定します。この次に来る単語が「スポーツです」となるようであればこの「4番打者がホームランを打った」という文のカテゴリは「スポーツ」と推定したとみなします。
このPromptingですが、通常の文書分類における推定の候補がラベルだけで済むのに比べて、Promptingでは単語となってしまい、学習が難しくなってしまいます。一見するとわざわざ難しくしてまでこの手法を用いる理由はありません。しかしながら、ここ最近研究が活発になってきています。理由としてはGPT3やT5など、大規模なモデルによる言語モデルの性能がとてもよく、それを用いるとPromptingによる推論でも良い精度が出るようになってきたからです。特に、few shot learningと呼ばれる学習データが少ない条件下やzero shot learningと呼ばれる学習データを用いない条件下での分類において、BERTを用いた分類よりも良い精度がでるようになりました。分類の学習データの作成はアノテーションを必要ととするのですが、その作成コストが高いため学習データをどう用意するかという点が課題になることがよくあります。一方、言語モデルの学習データはアノテーションを必要としないため、比較的容易に用意することができます。そのため、zero shot learningやfew shot learningが活用できる場面は多々あるため、その精度向上が見込めるということでPromptingを用いたzero shot learningや few shot learningの最近研究が進んでいます。
OpenPrompt
このように研究が盛んで様々な手法が登場しているPromptingですが、それを統一的に扱えるOpenPrompt というフレームワークも実装されています。今回はこれを用いてPromptingの学習と精度測定を試してみます。 SST-2というデータセットを代表的な手法であるLMBFFで学習してみます。
まず git レポジトリからインストールします。事前にscikit-learnおよびpytorchのインストールも必要ですがここでは割愛します。 Installing scikit-learn — scikit-learn 1.1.2 documentation や Start Locally | PyTorch などを参考にインストールしてもらえればと思います。
git clone https://github.com/thunlp/OpenPrompt.git cd OpenPrompt pip install -r requirements.txt python setup.py install
次にデータセットのダウンロードを行います。今回用いるSST-2はTextClassificationに含まれています。
cd datasets/ ./download_text_classification.sh
次に以下のコマンドで学習を行います。今回はレポジトリに登録されている設定ファイルを用います。この設定ファイルはPromptingの手法としてはLMBFFを用い、SST-2を解きます。学習データとしては各クラス16個の学習データを用いる設定となっています*1。
cd ../ mkdir logs python experiments/cli.py --config_yaml experiments/lmbff.yaml
上記コマンドを実行後、学習の経過が出力されます。最後に以下のように表示されて終了します。
[2022-08-29 19:35:47,908 INFO] trainer.inference_epoch test Performance: OrderedDict([('micro-f1', 0.926605504587156), ('accuracy', 0.926605504587156)])
今回学習したモデルのtestデータによる精度は 0.926605504587156
となります*2。
学習実行時に --config_yaml
オプションで指定したファイルが設定ファイルになります。このファイルにタスクの種類、学習データ、Promptingで用いるモデルの設定、few shot learningの設定など細かく指定することができます。詳しくはOpenPromptのドキュメントを参考にしていただければと思います。
まとめ
今回はPromptingという手法とそれを行うフレームワークのOpenPromptの紹介を行いました。OpenPromptの使用例としてLMBFFをSST-2を用いて学習する手順を紹介しました。 大量の学習データを用いてDeep Learningで精度の高い手法は今まで数多くでてきていますが、一方で学習データをあまり多く用意できない状況もまた多くあります。そのような状況で活用できる few shot learningの手法およびフレームワークとしてPromptingとOpenPromptに注目しております。