レトリバセミナーで関数型(Elixir)でConvolutional Neural Networkを実装した話をしました

レトリバ製品企画部の田村(@masatam81)です。
1/15にレトリバセミナーにて「Functional CNN in Elixir」というタイトルでお話しました。

retrieva.connpass.com

セミナーの内容としては、趣味や20%ルールで書いていたElixirでのDeep LearningフレームワークでConvolution2D/MaxPooling2Dまで使えるようになったので、関数型言語ならではの苦労や工夫・学びの共有といった感じでした。本記事は、そのフォローアップ記事です。当日はDeep LearningやElixirに詳しくない人もいたため、それらの簡単な説明も含めて、なるべく伝わるように話しました。

本記事ではDNNやCNN、Elixirの簡単な説明やConvolution2D、MaxPooling2D以外のレイヤーの話は省きます。その辺りの話は下記のYouTubeや関連書籍を見て下さい。動画ではNIFやMatrexに関しても少しだけ話しています。


Functional CNN in Elixir

スライドはこちらになります。

www.slideshare.net

また、本記事ではConvolution2D, MaxPooling2DのElixirでの実装で工夫した点にフォーカスして書きたいと思います。 セミナーでは最初に defstruct で構造体を作り、 defprotocol を利用して各レイヤーをLayerモジュールの継承関係のように扱う例を説明してから、それを関数型らしく修正した話をしましたが、今回の引用コードは修正後のコードを示します。

実装のベースになっている知識はオライリー・ジャパンの「ゼロから作るDeep Learning」です。元の処理に関する詳細な説明は、そちらをご覧下さい。

www.oreilly.co.jp

畳み込みの準備

まず、Convolutionなど畳み込み処理によく使われるim2colという手法の説明と実装です。

im2colとは

ConvolutionやPoolingの処理では指定されたn × m(n == mが多い)のカーネルを順に掛けていって、カーネル内の各要素、もしくはカーネル全体に対して処理を行います。このカーネルに対する処理を行いやすくするために、カーネル内の各要素をフラットに並べて積む処理がim2colです。f:id:ret_tamura:20200121132841p:plain

また、Convolutionの重みとの行列積やバイアスの足し合わせを容易にするため、チャンネル(4次元のリスト=テンソルのバッチの次の2次元目をチャンネルと呼びます)を後ろにつなぎ合わせたりします。

f:id:ret_tamura:20200121133236p:plain

Elixirでの実装

これをElixirで、単方向リストで処理しやすい形で実装したのですが、通常のやり方で多次元リストを任意の形に変換するのはなかなか面倒です。

下記コードではforでリストの形を作っているのですが、Elixirのforとジェネレーターの組み合わせは他のプログラミング言語と違って、ジェネレーター(x <- ... のようなもの)で作成される要素分のリストを生成します。それを用いて、チャンネルの位置でim2colの行列を結合、もしくはそれぞれのチャンネルに吐いています。つまり、一般的な「ゼロから作るDeep Learning」などでは(バッチ数 × 被カーネル行 × 被カーネル列, カーネルの要素数 × 入力チャンネル数)になりますが、Elixirのリストでの使い勝手を考慮して、(バッチ数, 被カーネル行, 被カーネル行, カーネルの要素数 × チャンネル数)になっています。下記は行列に部分に対してim2colをかけるコードです。

  defp matrix_filtering_impl(list, filter_height, filter_width, stride, padding, map_func) do
    list = if padding == 0, do: list, else: pad(list, padding)
    org_h = length(list)
    org_w = length(hd(list))
    out_h = div(org_h - kernel_height, stride) + 1
    out_w = div(org_w - kernel_width, stride) + 1

    for y <- for(i <- 0..(out_h - 1), do: i * stride) |> Enum.filter(&(&1 < org_h)) do
      for x <- for(i <- 0..(out_w - 1), do: i * stride) |> Enum.filter(&(&1 < org_w)) do
        list
        |> Enum.drop(y)
        |> Enum.take(kernel_height)
        |> Enum.map(&(Enum.drop(&1, x) |> Enum.take(kernel_width)))
        |> List.flatten()
        |> map_func.()
      end
    end
  end

out_h = div(org_h - kernel_height, stride) + 1ストライド幅とカーネルの行数、元の行列の行数を元に、その行列にカーネルが何回かけられるかを算出しています。本記事内では被カーネル行と表現させて頂きます。 out_w は同様に列側に何回かけられるかで、被カーネル列と便宜的に呼びます。

        |> Enum.drop(y)
        |> Enum.take(kernel_height)
        |> Enum.map(&(Enum.drop(&1, x) |> Enum.take(kernel_width)))
        |> List.flatten()

この部分をy=1, x=1の場合で説明します。 まずカーネルの開始位置を基準に、

  • 開始位置より前の行をdrop(Enum.drop(y))

f:id:ret_tamura:20200121143540p:plain

  • カーネルの行数分だけtakeで残す(Enum.take(kernel_height))

f:id:ret_tamura:20200121143842p:plain

  • 各行に対して開始位置より前のカラムをdrop(Enum.map(&(Enum.drop(&1, x))

f:id:ret_tamura:20200121144115p:plain

  • 最後にカーネルの列数分だけtakeで残す(|> Enum.take(kernel_width))))

f:id:ret_tamura:20200121144323p:plain

という処理を行い、最後にflattenで1列に並び替えています。

その後に map_func.() を呼んでいるのは、Poolingの場合、この時点で処理をしてしまえばカーネルで切り出しながら処理を終わらせられるためです。

以降はこのim2colをカーネルに対する処理の準備として実行します。言語特性上、処理を簡単にさせたいため、私の書いたim2colではチャンネルは結合するかそのままか選択できるように書いています。

forward処理

推論のforwardとパラメーターを誤差逆伝播するbackwardでは、forward同士・backward同士で実装が似ているため、処理ごとに分けて説明します。まずは推論に用いるforward処理です。forward処理は関数型、と言ってもあまり違いがない気もします。

MaxPooling2D

MaxPooling2Dには下記の特徴があります。

  • カーネル内の最大の値のみを残す
  • チャンネル数に変更はない

チャンネルに変更はないので、チャンネル内の行列から、行列に適応されるカーネル数のデータが出力されます。

Elixirでの実装

  def forward(x, [:max_pooling2d, fix_params]) do
    res =
      NN.matrix_filtering(
        x,
        fix_params[:pool_height],
        fix_params[:pool_width],
        fix_params[:stride],
        fix_params[:padding],
        fn list -> Enum.max(list) end,
        :normal
      )

    mask =
      NN.matrix_filtering(
        x,
        fix_params[:pool_height],
        fix_params[:pool_width],
        fix_params[:stride],
        fix_params[:padding],
        fn list -> NN.argmax(list) end,
        :normal
      )

    [_, _, height, width] = NN.shape(x)

    {res, [mask: mask, original_height: height, original_width: width]}
  end

前述のim2colでカーネルで取り出したフラットなリストに関数を当てられるようにしていたので、 Enum.max でリストの最大値のみを取り出しています。カーネルのそれぞれの位置に1つの数値が残るので、これだけでOKです。最後の :normal はチャンネルを結合(同じバッチの全てのチャンネルが繋がり1チャンネルになる)するかのatom:merge 以外はそのままになっています。 セミナーでは一般的なim2colを使ったMaxPooling2Dの動きを説明しましたが、厳密には私の書いたMaxPooling2Dの動きは下のようにreshapeの要らないシンプルな動きになっています。

f:id:ret_tamura:20200122133642p:plain

maskに関してはbackwardで使用します。MaxPoolingでは最大値のデータ以外は以降の推論に影響しないため、パラメータの修正もそのデータの部分のみで良いため、"カーネルを平らに並べた何番目のデータを使ったか"を残しています。

Convolution2D

Convolution2Dは2つのパラメータがあります。

  • カーネルに対する重み(出力チャンネル数, 入力チャンネル数 × カーネル行 × カーネル列)
  • チャンネルに対するバイアス(1つのチャンネルに1つの数値)

バイアスは対象のチャンネルの行列、全てのデータに同じものが足されます。im2colの説明で"カーネルに対する処理を行いやすくするため"と書きましたが、Convolution2Dの処理はかなり楽になります(ただし、後処理が必要になりますが)。

まず、im2colでチャンネルが結合したカーネルで行列の特定位置を抜いたリストが作られます。チャンネルをわかりやすく言うと、RGBの画像データのカーネルに対応する同じ位置のRのデータ、Gのデータ、Bのデータで、これらが並んだ状態になっています。今回のim2colをかけると、バッチ数Nとして(N, 被カーネル行, 被カーネル列, 入力チャンネル × カーネル行 × カーネル列)のデータになります。処理の都合上(行列積は行列同士にすると一気に計算できる)、一旦前の3つのサイズをまとめた形でreshapeします。重みの転置とのreshape後の行列の積は(N × 被カーネル行 × 被カーネル列, 出力チャンネル)になるので、それぞれのカーネルの位置の全チャンネル情報から新しいチャンネル数分のカーネル処理後が出力されるようなイメージになります。

さらにバイアスの形は出力チャンネル数分の値のリストなので、そのまま足せば、カーネルで取得した値に重みを掛けてバイアスを足す、という処理が完了します。

後処理が必要と書いたのは、ここまで結果が(N × 結果カーネル行 × 結果カーネル列, チャンネル数)という形になってるためです。なので、一旦reshapeにて(N, カーネル行, カーネル列, チャンネル数)という形に解釈を変えます(基本的にデータの並びは変わらない)。すると下の図のようになるので、

f:id:ret_tamura:20200122165548p:plain
2チャンネルの入力に、出力チャンネル2で(2, 2)のカーネルの例

あとは、transpose(0, 3, 1, 2)で並び替えればOKです。

Elixirでの実装

  def forward(x, [:conv2d, params, fix_params]) do
    col =
      NN.matrix_filtering(
        x,
        fix_params[:filter_height],
        fix_params[:filter_width],
        fix_params[:stride],
        fix_params[:padding]
      )

    size = NN.data_size(col)
    target_size = List.last(NN.shape(col))
    col = NN.reshape(col, [div(size, target_size), target_size])

    res =
      NN.dot_nt(col, params[:weight])
      |> NN.add(params[:bias])
      |> NN.reshape(
        get_conv_forward_shape(
          x,
          fix_params[:filter_height],
          fix_params[:filter_width],
          fix_params[:padding],
          fix_params[:stride],
          length(params[:weight])
        )
      )
      |> NN.transpose(0, 3, 1, 2)

    {res, [col: col]}
  end

backward処理

backward処理は関数型ではなかなか大変です。Elixirの基本データ構造は単方向リストで不変性(一度メモリ上で定義されたら破棄されるまで変更されない)という特性があります。ElixirでCNNを書こうとしてbackwardで行き詰まった人もいると思います。Convolution2DもMaxPooling2Dもカーネルの特定の位置に該当する元のデータでの位置に逆伝播しなくてはならない上、Convolutionでは同じ位置が複数回使われることが多いので集計が必要になります。

この問題を解決するため、次の(モデル上入力に近い)レイヤーに渡すための処理はマップを使って、戻るべき位置のindexへの値を集計した上で、集計を元にテンソルを構築する形を取りました。

MaxPooling2D

まずはよりシンプルなMaxPooling2Dに関して説明します。MaxPooling2Dにはパラメータがないので、基本的に次のレイヤーに戻す処理のみです。カーネルの対象の数値の中で最も大きい値のみが採用されるため、最大値のあったデータの位置のみに逆伝播され、他は0になります。また、カーネルがオーバーラップしている場合は加算された値が伝播します。

f:id:ret_tamura:20200122173917p:plain

Elixirでの実装

これを下記のようなMapを使ったコードで実現しました。

  def backward(dout, [:max_pooling2d, fix_params, config]) do
    Enum.zip(config[:mask], dout)
    |> Enum.map(fn {mask_batch, dout_batch} ->
      Enum.zip(mask_batch, dout_batch)
      |> Enum.map(fn {mask_channel, dout_channel} ->
        {_, result} =
          Enum.zip(mask_channel, dout_channel)
          |> Enum.reduce(
            {0, %{}},
            fn {masks, list}, {idx1, map1} ->
              {_, result_map} =
                Enum.zip(masks, list)
                |> Enum.reduce(
                  {0, map1},
                  fn {mask, val}, {idx2, map} ->
                    idx =
                      (div(mask, fix_params[:pool_width]) + idx1) * config[:original_width] +
                        idx2 +
                        rem(mask, fix_params[:pool_width])

                    {idx2 + fix_params[:stride], Map.update(map, idx, val, fn v -> v + val end)}
                  end
                )

              {idx1 + fix_params[:stride], result_map}
            end
          )

        for y <- 0..(config[:original_height] - 1) do
          for x <- 0..(config[:original_width] - 1) do
            idx = y * config[:original_width] + x
            Map.get(result, idx, 0.0)
          end
        end
      end)
    end)
  end

ポイントとしては、下記のコードでbackwardの入力の位置から、入力のどこの位置に値を戻すかを決定しています。

                  fn {mask, val}, {idx2, map} ->
                    idx =
                      (div(mask, fix_params[:pool_width]) + idx1) * config[:original_width] +
                        idx2 +
                        rem(mask, fix_params[:pool_width])

                    {idx2 + fix_params[:stride], Map.update(map, idx, val, fn v -> v + val end)}
                  end

それを、元のshapeの多次元リストをforで構築しながら値を設定しています。

        for y <- 0..(config[:original_height] - 1) do
          for x <- 0..(config[:original_width] - 1) do
            idx = y * config[:original_width] + x
            Map.get(result, idx, 0.0)
          end
        end

Convolution2D

Convolution2Dはパラメータが2つあるので、それぞれの更新のための処理もあります。 パラメータ部分は一旦backwardへの入力をim2colの結果の形に戻してやる必要がある以外はAffine(全結合層)と大差ないので簡単です。

Elixirでの実装

コードが長いので全体は省略して、要点だけ抜き出して説明します。

パラメータの更新のための処理は結構スッキリまとまりました。

    dout_t =
      NN.transpose(dout, 0, 2, 3, 1)
      |> NN.reshape([div(NN.data_size(dout), length(params[:weight])), length(params[:weight])])

    db = NN.sum(dout_t, :col)

    dw =
      NN.dot_tn(config[:col], dout_t)
      |> NN.transpose()

問題の次の層へ戻す処理ですが、基本的にはMaxPooling2Dと同じで、Mapで集計しています。違いとしては、重みを掛けてあげる必要があるのと、カーネルの全てのカラムに値を戻しつつ集計する必要がある(MaxPooling2Dでも最大値が複数カーネルで同じ位置なら必要ですが)ことでしょうか。

MaxPooling2Dではpadding部分が最大値になることは稀なのであまり影響ありませんが、Convolution2Dでは下記の最後のforの処理が結構上手く動いてると思います。

      for c <- 0..max_c do
        for h <- 0..(max_h - fix_params[:padding]) do
          for w <- 0..(max_w - fix_params[:padding]) do
            Map.get(channel_map, {c, h, w}, 0.0)
          end
        end
      end

paddingで増えた位置、forwardのinputの(-1, -1)のようなindexで集計されて入るものの、結果を構築する際には無視されるため、集計時点では戻る位置のindexがはみ出ないか確認せず一律全てのカーネルが当たったとされる位置を集計すれば良い、と処理が容易になっています。

速度の問題

Mapを使った集計処理はElixirの単方向リストでの処理で実現するという意味では目的を達成しているものの、処理速度はかなり厳しいものがあり、せっかく実装したものの、対象のbackward処理はNIF等で最優先で改善する必要があると思います。

現状の処理速度としては、下記の構成でMNISTの60000の訓練データに対し1エポック4時間弱 f:id:ret_tamura:20200122180817p:plain

同様に下記の構成でCIFAR10の50000の訓練データに対し、1エポック100〜150時間かかります。 f:id:ret_tamura:20200122181019p:plain

なお、計測したマシンはMacBook Pro(2017)でCore i5 2.3GHz、16GB LPDDR3、MacOS Mojave 10.14.6でElixirの環境は1.9.1、Erlang/OTP 21です。

フレームワークの今後

NIFでC++/OpenCLでの高速化(CUDAでなくても速くしたい気持ちがあり)を考えてはいますが、自分の勉強のためとエコシステムが整った時に高速化できそう、という目論見で書き始めましたが、趣味のフレームワークで終わりそうな気がします。

とはいえ、実際に書いてみての学びは非常に大きかったので、ゆっくりメンテナンスしつつ、他の層や手法も実装できればと思います。Define by Runに関数型で挑戦するというのも面白いかも知れません。やはり、違う言語で書く場合、足りないものが多く、大変ではありますが、その分本質的な理解が求められるので非常に勉強になりました。