U-NetでPASCAL VOC 2012のImage Segmentation

この記事の内容

この記事では、U-Netを使って、PASCAL VOC 2012データセットのセグメンテーションをしてみます。PyTorchで実装したソースコードは以下で公開しています。

https://github.com/shioili/unet_voc/tree/master

はじめに

前回の記事では、Deep Learning向けのGPUマシンを作成しました。 cinnamell.hatenablog.com

今回は、これを使って、画像のセグメンテーションタスクを学習させてみます。セグメンテーションとは、画像に写っているオブジェクト(例えば、犬、猫、車、など)を、ピクセル単位でクラス分けするタスクです。例えば、下の写真を入力とすると、飛行機や、車、人などを識別して、それぞれに色を塗った画像を出力します。

f:id:cinnamell:20191214114739j:plain

f:id:cinnamell:20191214114754p:plain

 セグメンテーションの学習に使えるデータセットは、上の画像のように、写真と、それに対応するセグメンテーションの正解データから構成されます。いくつか有名なオープンデータセットがあるのですが、今回はPASCAL VOC 2012を使います。

セグメンテーションに使うネットワークは、とりあえずU-Netを実装してみることにします。U-Netは比較的単純なネットワークなので、実装が容易だと考えたからです。U-Netの詳細については、ネットに詳しい記事がいっぱいあるので、そちらをご覧ください(手抜き)。

学習

学習にはPASCAL VOC 2012には、セグメンテーションデータがついている画像が2913枚用意されています。今回は、このうち、ランダムに選んだ2713枚を学習に使います。その他の200枚は、テスト用に使います。

学習時には、前処理として、画像から256x256サイズをランダムに切り出して(Random Crop)使います。画像のサイズを揃えておかないと、バッチを作ることができないのが1つの理由ですが、ランダムに切り出すことでデータ量を水増しする効果も期待できます(ただし、ランダムにクロップしない場合に比べて、本当に汎化性能が良くなるのかはわかりません。)

学習時のLossの遷移は下のグラフのようになりました。ちなみに、390 epoch学習するのに、およそ9時間ほどかかりました。

f:id:cinnamell:20191214120854p:plain

ちゃんと学習できてますね。ただ、ここまで来るまでにいくつかハマりポイントがありました。

1つ目はBatch Normalizationです。U-Netの元論文では、Normalizationに関する記述は何もありません(その代わり、初期値の与え方について記述がある)。しかし、実際には、畳み込み層の後にBatch Normを入れてやらないと、学習がほとんど進みませんでした。Batch Normを入れる知恵は、こちらの記事を参考にさせていただきました。

qiita.com2つ目は、Random Cropのサイズです。上で書いたとおり、学習時には、画像を256x256に切り出して使っています。最初、128x128で切り出していたのですが、これでは全く学習できませんでした。256x256なら学習できます。U-Netは完全に畳み込みだけで構成されたネットワークなので、クロップのサイズはあまり関係がないようにも思うのですが・・・結局、なぜ256x256にする必要があるのか、理由はよく分かっていません。小さいサイズではPaddingが悪さをするのでしょうか。

テスト

学習したモデルで200枚のテスト画像を評価してみました。平均Accuracyは80%ということで、まずまずセグメンテーションできていると言えるのではないでしょうか。もっとも、なんのオブジェクトも属しないバックグラウンドも含めたAccuracyなので、実際に絵を見てみると、まだ識別ミスが多いなーという印象です。

f:id:cinnamell:20191214122726p:plain

f:id:cinnamell:20191214122730p:plain

最後に

とりあえず、それなりにセグメンテーションができていることが確認できたのは満足です。ただ、まだ精度に課題がありますので、次はPSPNetあたりを試してみたいところです。