CTC 誤差関数を完全に理解したい(後編)

こんにちは、リサーチャーの古谷(@kk_fry_)です。 私は普段、音声認識の研究開発をしています。 前回の記事から、End-to-End 音声認識で用いられる Connectionist Temporal Classification (CTC) 誤差関数の解説をしています。 本記事はその後編となります。

前回は、CTC 誤差関数の定義と計算方法を解説し、その偏導関数を導出しました。 今回は、CTC 誤差関数の勾配降下法による学習について解説し、その解釈を考えてみます。

本記事の目的

本記事の目的は以下のようになります。

  • CTC 誤差関数の計算方法を解説(前回)
  • CTC 誤差関数の偏導関数について解説(今回のメイン)
  • CTC 誤差関数の勾配降下法の解釈(今回のおまけ)

この流れで書いていきますが、最初に色々と準備をします。

参考文献

以下の 2 つの文献を参考にしています。

CTC 誤差関数の原著論文:Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks (Alex Graves et al.) 参考テキスト(7章):Supervised Sequence Labelling with Recurrent Neural Networks (Alex Graves)

準備:記号の表記

本記事で用いる記号の定義です。基本的には前述の参考文献の記号に従っていますが、一部変更してあります。

U : 正解ラベル長(文字数、音素数など。ラベルとは、文字や音素など、出力する単位のこと)

T : 入力長(音響特徴量の時間フレーム数)。なお、U\le T でなければいけない

L : ラベルの集合

L'=L\cup {\mbox{blank}} : ブランクを含むラベルの文字集合

L^{\le T} : T 以下の、L の要素からなる列

\mathcal{B}: L'^T \mapsto L^{\le T} : CTC の縮約規則に従う縮約関数。ブランクがなくなる。逆関数\mathcal{B}^{-1} だがこちらは一対多になる

N=\{1, \dots, |L|\} : ラベルに用いられる文字の番号の集合

\boldsymbol{x} = (x_1, \dots , x_T) \in (\mathbb{R}^ m)^ T : 音響特徴量。m は特徴量の次元。 m 次元対数メルスペクトルなど

\boldsymbol{z} = (z_1, \dots , z_U) \in L^ U : 正解ラベル列

S : データセット(\boldsymbol{x},\boldsymbol{z}) の集合

\boldsymbol{l} : ブランクを含まないラベル列の例

\boldsymbol{l}' : ラベル列 \boldsymbol{l} の両端と各ラベルの間にブランクを追加したブランク入りラベル列。 |\boldsymbol{l}'|=2|\boldsymbol{l}|+1 である

\boldsymbol{y} \in (\mathbb{R}^n)^ T : 音響モデル出力。時刻の添字は右上に付く( y^ t が時刻 t における出力)。なお、n はラベルに使われる文字の集合の要素数 |L|+1 である(+1 はブランクに対応)

y_k^ t : 時刻 t におけるラベル k の確率

\pi \in L'^ T : ブランクを含む縮約前ラベル列の例

lab(\boldsymbol{l}, k) : \boldsymbol{l}′_s=k となる添字 s の集合(k\in L'。ラベル列 \boldsymbol{l} の両端と間にブランクを挿入したラベル列 \boldsymbol{l}' を用いていることに注意)

CTC 誤差関数の偏導関数

音声認識における教師あり学習では、正解ラベルが出力される確率を最大化したいです。

そのため、誤差関数は音響特徴量 \boldsymbol{x} から正解ラベル \boldsymbol{z} が出力される確率 P(\boldsymbol{z}|\boldsymbol{x}) の負の対数とします。

この値を、データセット S に含まれるすべての音声・テキスト対に対して合計した値が CTC 誤差関数となります。


\begin{aligned}
\mathcal{L}_{\mathrm{CTC}}(S)=-\sum_{(\boldsymbol{x}, \boldsymbol{z}) \in S}\log P(\boldsymbol{z}|\boldsymbol{x})
\end{aligned}

勾配を計算する際には、1 つのデータ (\boldsymbol{x}, \boldsymbol{z}) について計算したものを合計すれば良いので、以下では 1 つのデータに対する誤差関数を考えます。


\begin{aligned}\mathcal{L}_{\mathrm{CTC}}(\{(\boldsymbol{x}, \boldsymbol{z})\})=-\log P(\boldsymbol{z}|\boldsymbol{x})
\end{aligned}

P(\boldsymbol{z} | \boldsymbol{x}) を音響モデルの出力値 y_k^ t偏微分すると以下のようになることを前回の記事で確認しました。


\begin{aligned}
\frac{\partial}{\partial y_k^t}P(\boldsymbol{z} | \boldsymbol{x})=\frac{1}{{y_k^t}^2}\sum_{s\in lab(\boldsymbol{z},k)}\alpha_t(s|\boldsymbol{x},\boldsymbol{z})\beta_t(s|\boldsymbol{x},\boldsymbol{z})
\end{aligned}

なお、\alpha_t(s|\boldsymbol{x},\boldsymbol{z}), \beta_t(s|\boldsymbol{x},\boldsymbol{z}) については前回の記事をご覧ください。

これを用いて計算すると、誤差関数の偏導関数が以下のようになります。


\begin{aligned}
\frac{\partial}{\partial y_k^t}\mathcal{L}_{\mathrm{CTC}}(\{(\boldsymbol{x}, \boldsymbol{z})\})=-\frac{1}{P(\boldsymbol{z}| \boldsymbol{x})}\frac{\partial}{\partial y_k^t}P(\boldsymbol{z}| \boldsymbol{x})=-\frac{1}{P(\boldsymbol{z}| \boldsymbol{x})}\frac{1}{{y_k^t}^2}\sum_{s\in lab(\boldsymbol{z},k)}\alpha_t(s|\boldsymbol{x},\boldsymbol{z})\beta_t(s|\boldsymbol{x},\boldsymbol{z})
\end{aligned}

y_k^ t の値は、時刻 t におけるラベルが k である確率を表すので、softmax 関数の出力値であることが普通です。

softmax 関数を通る前の値を u_k^ t として、CTC 誤差関数を u_k^ t偏微分した値を考えてみましょう。softmax 関数の定義より


\begin{aligned}
y_k^t=\frac{\exp(u_k^t)}{\sum_{k'} \exp(u_{k'}^t)}
\end{aligned}

であるので、


\begin{aligned}
\frac{\partial y_{k'}^t}{\partial u_k^t}=y_{k'}^t\delta_{kk'}-y_{k'}^ty_k^t
\end{aligned}

となります。ただし、\delta_{kk'}クロネッカーのデルタです。途中式についてはこちらの記事などを参照してください。

この式を用いて、CTC 誤差関数の偏導関数を計算していきます。合成関数の微分(チェーンルール)により、


\begin{aligned}
\frac{\partial}{\partial u_k^t}\mathcal{L}_{\mathrm{CTC}}(S)=\sum_{k'}\frac{\partial}{\partial y_{k'}^t}\mathcal{L}_{\mathrm{CTC}}(S)\frac{\partial y_{k'}^t}{\partial u_k^t}
\end{aligned}

\begin{aligned}
=-\sum_{k'}\frac{1}{P(\boldsymbol{z}| \boldsymbol{x})}\frac{1}{{y_{k'}^t}^2}\left ( \sum_{s\in lab(\boldsymbol{z},k')}\alpha_t(s|\boldsymbol{z})\beta_t(s|\boldsymbol{z})\right )\left (y_{k'}^t\delta_{kk'}-y_{k'}^ty_k^t\right )
\end{aligned}

となります。ここで、クロネッカーのデルタがついている項については、k'=k の場合だけ考えれば良いので、


\begin{aligned}
-\frac{1}{P(\boldsymbol{z}|\boldsymbol{x})}\frac{1}{y_k^t}\sum_{s\in lab(\boldsymbol{z},k)}\alpha_t(s|\boldsymbol{z})\beta_t(s|\boldsymbol{z})
\end{aligned}

となります。また、クロネッカーのデルタがつかない項の計算で、以下の等式を用います。


\begin{aligned}
\sum_{k'}\sum_{s\in lab(\boldsymbol{z},k')}
\frac{\alpha_t(s|\boldsymbol{z})\beta_t(s|\boldsymbol{z})}{y_{k'}^t}=P(\boldsymbol{z}| \boldsymbol{x})
\end{aligned}

これは、lab(\boldsymbol{z}, k') が、z_s=k' となる s の集合なので、この二重総和は時刻 t におけるすべてのラベルに関する総和となるため成立します。前回の記事にも似たような式があります。これにより、クロネッカーのデルタがつかない項について、


\begin{aligned}
\sum_{k'}\frac{1}{P(\boldsymbol{z}| \boldsymbol{x})}\frac{1}{{y_{k'}^t}^2}y_{k'}^ty_k^t\sum_{s\in lab(\boldsymbol{z},k')}
\alpha_t(s|\boldsymbol{z})\beta_t(s|\boldsymbol{z})
\end{aligned}

\begin{aligned}
=\frac{1}{P(\boldsymbol{z}| \boldsymbol{x})}y_k^t\sum_{k'}\sum_{s\in lab(\boldsymbol{z},k')}
\frac{\alpha_t(s|\boldsymbol{z})\beta_t(s|\boldsymbol{z})}{y_{k'}^t}=y_k^t
\end{aligned}

となります。したがって、


\begin{aligned}
\frac{\partial}{\partial u_k^t}\mathcal{L}_{\mathrm{CTC}}(\{(\boldsymbol{x}, \boldsymbol{z})\})=y_k^t-\frac{1}{P(\boldsymbol{z}| \boldsymbol{x})}\frac{1}{y_{k}^t}\sum_{s\in lab(\boldsymbol{z},k)}\alpha_t(s|\boldsymbol{z})\beta_t(s|\boldsymbol{z})
\end{aligned}

となります。この式に出てくる値はすべて CTC 誤差関数の値を計算するときに計算済みです。これを用いることで、softmax 関数を通す前の状態から勾配を計算することができます。

アンダーフローを防ぐ計算方法

\alpha_t(s|\boldsymbol{z}),\beta_t(s|\boldsymbol{z}) の値をそのまま計算しようとすると、非常に小さい値になってアンダーフローを引き起こすおそれがあります。そのため、実際に計算する際には、以下のように正規化して計算します。


\begin{aligned}
\hat\alpha_t(s|\boldsymbol{z}):=\frac{\alpha_t(s|\boldsymbol{z})}{\sum_{s'} \alpha_t(s'|\boldsymbol{z})}
\end{aligned}

\begin{aligned}
\hat\beta_t(s|\boldsymbol{z}):=\frac{\beta_t(s|\boldsymbol{z})}{\sum_{s'} \beta_t(s'|\boldsymbol{z})}
\end{aligned}

さらに、P(\boldsymbol{z} | \boldsymbol{x}) の代わりに


\begin{aligned}
Z_t:=\sum_{s=1}^{|\boldsymbol{z}'|}\frac{\hat\alpha_t(s|\boldsymbol{z})\hat\beta_t(s|\boldsymbol{z})}{y_{\boldsymbol{z}_s'}^t}
\end{aligned}

を用いて


\begin{aligned}
\frac{\partial}{\partial u_k^t}\mathcal{L}_{\mathrm{CTC}}(\{(\boldsymbol{x}, \boldsymbol{z})\})=y_k^t-\frac{1}{Z_t}\frac{1}{y_{k}^t}\sum_{s\in lab(\boldsymbol{z},k)}\hat\alpha_t(s|\boldsymbol{z})\hat\beta_t(s|\boldsymbol{z})
\end{aligned}

と計算することで、正規化を分母分子で打ち消すことができます。

CTC 誤差関数の勾配降下法の解釈


\begin{aligned}
\frac{\partial}{\partial u_k^t}\mathcal{L}_{\mathrm{CTC}}(\{(\boldsymbol{x}, \boldsymbol{z})\})=y_k^t-\frac{1}{P(\boldsymbol{z}| \boldsymbol{x})}\frac{1}{y_{k}^t}\sum_{s\in lab(\boldsymbol{z},k)}\alpha_t(s|\boldsymbol{z})\beta_t(s|\boldsymbol{z})
\end{aligned}

の意味について考えてみましょう。一部、感覚的な話が出てきますがご了承ください。


\begin{aligned}
\frac{\alpha_t(s|\boldsymbol{z})\beta_t(s|\boldsymbol{z})}{y_{\boldsymbol{l}'_s}^t}
\end{aligned}

は、前回の記事に載せたグラフにおいて、頂点 (s,t) を通って始点から終点まで行く確率です。

したがって、


\begin{aligned}
\frac{1}{y_{k}^t}\sum_{s\in lab(\boldsymbol{z},k)}\alpha_t(s|\boldsymbol{z})\beta_t(s|\boldsymbol{z})
\end{aligned}

は、\boldsymbol{l}'_s=k となるような全ての添字 s に対する、頂点 (s, t) を通って始点から終点へ行く確率の総和となります。そのため、


\begin{aligned}
\frac{1}{P(\boldsymbol{z}| \boldsymbol{x})}\frac{1}{y_{k}^t}\sum_{s\in lab(\boldsymbol{z},k)}\alpha_t(s|\boldsymbol{z})\beta_t(s|\boldsymbol{z})
\end{aligned}

は、グラフにおいて始点から終点へ行く、つまり音響特徴量 \boldsymbol{x} から正解ラベル列 \boldsymbol{z} が実現するという条件の下で、縮約前ラベル列において時刻 t のラベルが k である条件付き確率になります。

k が正解ラベル列に存在しないラベルのとき、誤差の第2項が 0 になるため、誤差の勾配の値は y_k^ t>0 となり、u_k^ t の値を小さくするように作用します。

k が正解ラベル列の序盤のみに存在するラベルの場合、第2項の値は t が小さいときに大きくなり、t が大きいときには小さくなります。すでに y_k^ t の値がそのようになっている場合は、誤差の勾配は 0 に近い値になります。

k が正解ラベル列の序盤のみに存在するラベルの場合で、t が小さいのに y_k^ t の値が小さい場合、誤差の勾配が負になり、u_k^ t の値を大きくするように作用します。同様に、t が大きいのに y_k^ t の値が大きい場合、誤差の勾配が正になり、u_k^ t の値を小さくするように作用します。

k がブランクの場合、第2項はそこそこ大きな値になります。なぜなら、lab(\boldsymbol{z},\mathrm{blank}) は「正解ラベル列 \boldsymbol{z} の両端と間にブランクを挿入した文字列のうち、ブランクになっている部分の添字」であり、添字全体にまんべんなく、過半数存在するからです。したがって、誤差の勾配が負になりやすく、u_{\mathrm{blank}}^ t は大きな値になりやすいです。

したがって、基本的にはブランクの出力確率が大きくなりつつ、それだけだと P(\boldsymbol{z}|\boldsymbol{x}) の値が小さくなりすぎて第2項が大きくなるので、勾配降下法によって適度な位置で適切なラベルを出力する確率が上がっていき、ちょうどいいところを目指してくれる、という形になります。

CTC 誤差関数を用いた勾配降下法による学習はこのように解釈することができるかなと思います。

おわりに

今回は、前回の記事に引き続き、CTC 誤差関数を用いた勾配降下法による学習について数式を追ってみました。

PyTorch などの機械学習ライブラリには CTC 誤差関数が実装されているので、数式を理解せずとも音響モデルの学習は可能ですが、ブラックボックスのまま使うよりも、理解して使う方が楽しいのではないかと思います。

それでは皆様、よき音声認識ライフを。