CTC 誤差関数を完全に理解したい(前編)

こんにちは、リサーチャーの古谷(@kk_fry_)です。 私は普段、音声認識の研究開発をしています。 今回は、End-to-End 音声認識で用いられる Connectionist Temporal Classification (CTC) 誤差関数の計算式と微分について数式を追ってみたので、細かく書いてみようと思います。

記事が長くなったので、前後編に分けています。前編の今回は CTC 誤差関数の定義と計算方法を解説し、その偏導関数を導出します。 後編では、勾配の計算について解説し、解釈を考えてみる予定です。

本記事の目的

本記事の目的は以下のようになります。

  • CTC 誤差関数の計算方法を解説(今回はこれがメイン)
  • CTC 誤差関数の偏導関数について解説(今回はちょこっと覗くだけ)
  • CTC 誤差関数の勾配降下法の解釈(次回やりたい)

この流れで書いていきますが、最初に色々と準備をします。

参考文献

以下の 2 つの文献を参考にしています。

CTC 誤差関数の原著論文:Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks (Alex Graves et al.)

参考テキスト(7章):Supervised Sequence Labelling with Recurrent Neural Networks (Alex Graves)

準備:記号の表記

本記事で用いる記号の定義です。基本的には前述の参考文献の記号に従っていますが、一部変更してあります。

U : 正解ラベル長(文字数、音素数など。ラベルとは、文字や音素など、出力する単位のこと)

T : 入力長(音響特徴量の時間フレーム数)。なお、U\le T でなければいけない

L : ラベルの集合

L'=L\cup {\mbox{blank}} : ブランクを含むラベルの文字集合

L^{\le T} : T 以下の、L の要素からなる列

\mathcal{B}: L'^T \mapsto L^{\le T} : CTC の縮約規則に従う縮約関数。ブランクがなくなる。逆関数\mathcal{B}^{-1} だがこちらは一対多になる

N={1, \dots, |L|} : ラベルに用いられる文字の番号の集合

\boldsymbol{x} = (x_1, \dots , x_T) \in (\mathbb{R}^ m)^ T : 音響特徴量。m は特徴量の次元。 m 次元対数メルスペクトルなど

\boldsymbol{z} = (z_1, \dots , z_U) \in L^ U : 正解ラベル列

S : データセット(\boldsymbol{x},\boldsymbol{z}) の集合

\boldsymbol{l} : ブランクを含まないラベル列の例

\boldsymbol{l}' : ラベル列 \boldsymbol{l} の両端と各ラベルの間にブランクを追加したブランク入りラベル列。 |\boldsymbol{l}'|=2|\boldsymbol{l}|+1 である

\boldsymbol{y} \in (\mathbb{R}^n)^ T : 音響モデル出力。時刻の添字は右上に付く( y^ t が時刻 t における出力)。なお、n はラベルに使われる文字の集合の要素数 |L|+1 である(+1 はブランクに対応)

y_k^ t : 時刻 t におけるラベル k の確率

\pi \in L'^ T : ブランクを含む縮約前ラベル列の例

CTC 誤差関数の定義

CTC の考え方

まず、音声認識システムと音響モデルについて定義し、CTC 誤差関数の定義を紹介します。

音声認識システムとは、音声を入力として、文字列を出力するシステムです。

CTC 誤差関数を用いる音声認識システムでは、ニューラルネットワークを用いた音響モデルを使用します。この音響モデルは、対数メルスペクトルなどの音響特徴量の列を入力として受け取り、音素や文字などのラベルの列を出力します。

音響モデルには、RNN などの可変長入力に対応したニューラルネットワークを使用しますが、そのまま使うと入力長と同じ長さのラベル列が出力されてしまいます。一般的に、音声フレームの数よりも音素や文字の数の方が少なくなるので、対処が必要です。

そこで、CTC の考え方では、出力ラベルの種類に blank を追加し、以下の処理を行う縮約関数 \mathcal{B}: L'^T \mapsto L^{\le T} を定義します。以下、ラベル列の例において blank-(ハイフン)で記述します。

  • まず、入力において同じラベルが連続している箇所を 1 つにまとめる(例:aa--aaa-bb-a-a-b-
  • 次に、ブランクを削除する(例:a-a-b-aab

この縮約によって、音響モデルの出力を短くすることができます。

したがって、音響モデルは「縮約すると正解ラベル列になるような、ブランクを含みうるラベル列(音声フレームと同じ長さ)」を出力することを目指します。

具体的な計算方法

音響モデルの入力は、音響特徴量 \boldsymbol{x} = (x_1, \dots , x_T) \in (\mathbb{R}^ m)^ T です。これは、対数メルスペクトルなどです。T は音声フレーム数です。

RNN などで構成される音響モデルは、この音響特徴量を受け取り、出力確率 \boldsymbol{y} \in (\mathbb{R}^ n)^ T を出力します。n はブランクを含む出力ラベルの種類数であり、y_k^ t が、第 t フレームにおけるラベル k の実現確率を表します。

このとき、ある(ブランクを含む)ラベル列 \pi の実現確率は


\begin{aligned}
P(\pi | \boldsymbol{x})=\prod_{t=1}^ T y_{\pi_t}^ t
\end{aligned}

となります。縮約関数を考慮すると、縮約後のラベル列 \boldsymbol{l} の実現確率は


\begin{aligned}
P(\boldsymbol{l} | \boldsymbol{x})=\sum_{\pi \in \mathcal{B}^{-1}(\boldsymbol{l})}P(\pi | \boldsymbol{x})
\end{aligned}

となります。

音声認識システムは、この P(\boldsymbol{l} | \boldsymbol{x}) が最大となるようなラベル \boldsymbol{l} を出力します。そのため、正解系列 \boldsymbol{z} における確率 P(\boldsymbol{z} | \boldsymbol{x}) が大きくなるようにニューラルネットワークのパラメータを調整したいです。

そのために、あるラベル列 \boldsymbol{l} に対する確率 P(\boldsymbol{l} | \boldsymbol{x}) を計算したいです。(ゆくゆくはこれを微分したいです。)

ここで、例として "cat" というラベル列について考えます。このとき、以下のようなグラフ(オートマトン)を用いて確率を計算することができます。

f:id:furuya1223:20210624105402j:plain
"cat" に関わる部分のみを抜粋したグラフ

このグラフは、始点(左上の二重丸)からスタートして、確率的に遷移し、たどり着いた頂点に対応するラベルを出力していくものです。特に、ラベル列の例("cat")に関わる部分のみ抜粋しています。ブランクはハイフンで表記しています。

各頂点には、「前の時刻からその頂点へ遷移する確率」が割り当てられています。これが、音響モデルの出力 y_k^ t となります。画像では、時刻 t の部分だけ、この確率を記載しています。

始点からスタートし、このグラフ上の経路をたどって、終点(右下の 2 つの二重丸のいずれか)で終了した場合に、"cat" が出力されます。終点は、"cat" の "t" またはその後のブランクの 2 通りあります。

なお、非ブランクラベルから別の非ブランクラベルに遷移することがあるので、c→a など、ブランクをまたぐ辺が存在します。

このグラフにおいて、始点から終点までの経路の一つを辿る確率が、その経路に対応するブランクありラベル列 \pi を出力する確率 P(\pi | \boldsymbol{x}) となります。このラベル列は、縮約によって "cat" になります。

そして、始点から終点まで移動する確率(全ての経路の確率の和)が、P(\texttt{"cat"} | \boldsymbol{x}) になります。

例えば、ブランクありラベル列 \pi の例として "-caa--t--" を考えると、この列は下図の経路に該当します。したがって、下図に記載されている確率を全て掛けると P(\texttt{"-caa--t--"} | \boldsymbol{x}) が得られます。

f:id:furuya1223:20210624105452j:plain
"-caa--t-" を出力する経路とそれに関わる確率

このグラフにおいて、始点から頂点 (s, t) に到達する確率を \alpha_t(s|\boldsymbol{x},\boldsymbol{l}) 、逆向きに見たときに終点から頂点 (s, t) に到達する確率を \beta_t(s|\boldsymbol{x},\boldsymbol{l}) とします1s はこの図の上から何段目の頂点かを表し、t は時刻を表します。 これらの確率は動的計画法で効率的に計算できます2

ここで、ブランク無しラベル列 \boldsymbol{l} の前後と各ラベルの間にブランクを挿入したラベル列 \boldsymbol{l}' を考えます。例えば、\boldsymbol{l} が "cat" のとき、 \boldsymbol{l}' は "-c-a-t-" です。 このとき、任意の時刻 t に対して、以下の等式が成り立ちます。


\begin{aligned}
P(\boldsymbol{l} | \boldsymbol{x})=\sum_{s=1}^{|\boldsymbol{l}'|}\frac{\alpha_t(s|\boldsymbol{x},\boldsymbol{l})\beta_t(s|\boldsymbol{x},\boldsymbol{l})}{y_{\boldsymbol{l}_s'}^t}
\end{aligned}

この等式は、「始点から終点まで移動する」という事象を「始点から終点まで、頂点 (1, t) を通って移動する」「……頂点 (2, t) を通って……」……という排反な事象に分割し、確率の和の法則を適用して得られます。なお、\alpha_t(s|\boldsymbol{x},\boldsymbol{l})\beta_t(s|\boldsymbol{x},\boldsymbol{l}) だと y_{\boldsymbol{l}'_s}^ t が 2 回掛けられることになるため、y_{\boldsymbol{l}'_s}^t で割っています。

入力特徴量 \boldsymbol{x} に対する正解ラベル列 \boldsymbol{z} の実現確率 P(\boldsymbol{z} | \boldsymbol{x}) が、最大化したいものであり、 -\log P(\boldsymbol{z} | \boldsymbol{x}) をデータセット  S の全てのデータ (\boldsymbol{x}, \boldsymbol{z}) について合計した値が、最小化したい CTC 誤差関数になります。


\begin{aligned}
\mathcal{L}_{\mathrm{CTC}}(S)=-\sum_{(\boldsymbol{x}, \boldsymbol{z}) \in S}\log P(\boldsymbol{z}|\boldsymbol{x})
\end{aligned}

勾配降下法による学習のために、確率 P(\boldsymbol{l}|\boldsymbol{x}) を音響モデル出力 y_k^ t微分したいです。このとき、\boldsymbol{l}'_s=k となる s のみ考慮すればよいです。そのような s の集合を lab(\boldsymbol{l},k) と書くことにします。

また、\alpha_t(s|\boldsymbol{x},\boldsymbol{l}), \beta_t(s|\boldsymbol{x},\boldsymbol{l})動的計画法における漸化式を考えると、それぞれ 「y_{\boldsymbol{l}'_s}^ t によらない値 \times y_{\boldsymbol{l}'_s}^ t」の形になっているため、


\begin{aligned}
\frac{\alpha_t(s|\boldsymbol{x},\boldsymbol{l})\beta_t(s|\boldsymbol{x},\boldsymbol{l})}{{y_{\boldsymbol{l}_s'}^t}^2}
\end{aligned}

y_{\boldsymbol{l}'_s}^ t によらない値となります。したがって、シグマの中身はこの値に y_{\boldsymbol{l}'_s}^ t が掛かっている値、つまり y_{\boldsymbol{l}'_s}^ t の定数倍であるため、偏微分の値は以下のようになります(係数だけが残る形です)。


\begin{aligned}
\frac{\partial}{\partial y_k^t}P(\boldsymbol{l} | \boldsymbol{x})=\frac{1}{{y_k^t}^2}\sum_{s\in lab(\boldsymbol{l},k)}\alpha_t(s|\boldsymbol{x},\boldsymbol{l})\beta_t(s|\boldsymbol{x},\boldsymbol{l})
\end{aligned}

この偏導関数を用いて、勾配降下法に基づくパラメータの更新を行うのですが、詳しい計算は後編の記事で見ていきましょう。

おわりに

本記事では、CTC 誤差関数の定義・計算方法・偏導関数について解説しました。 後編では、softmax 関数を考慮して計算される勾配や、勾配降下法の解釈などについて解説する予定です。 後編の記事とあわせて、原著論文や参考テキストの式がかなり理解できるようになる予定ですので、ご期待ください。


  1. これらはそれぞれ前向き確率・後ろ向き確率と呼ばれます。また、原著論文参考テキスト\beta_t(s|\boldsymbol{x},\boldsymbol{l}) の定義が異なっていますが、原著論文の定義を採用しています。

  2. 動的計画法の漸化式については原著論文の (6), (7), (10), (11) 式をご参照ください。