PyTorchで線画着色
はじめに
高校の冬休み中に何かやりたいと思い、準備自体は12月から始めていたのですが線画着色に挑戦しました。
PaintsChainerの登場から線画着色を行っている方は多く、n番煎じではありますがどうぞお付き合いください。PyTorchのアドベントカレンダーに参加しようと思いましたが、まだまだ未熟なため遠慮しました....。
初めに結果から(左が本物で右がGeneratorによる偽物)
未学習データでここまで綺麗になったのでうれしいです。
画像は [かがちさく様] : https://www.pixiv.net/member_illust.php?mode=medium&illust_id=72418691 からお借りしました。
実装 :
使用データ
nico-opendataで選別したカラーイラスト1万と数千枚。漫画系とモノクロイラストは使いませんでした。それをOpenCVで線画に変えてtrainデータとしました。
学習方法
PaintsChainerがどのように学習させたか詳しくわからなかったのでPix2Pixの学習方法で行いました。以前実装したDCGANのTrain部分をほぼそのまま使いました。
lossに関してはDiscriminatorにおいては普通のGANLoss、GeneratorにおいてはGANLossと、本物と偽物とのL1LossとL2Lossを足して10倍したものを使いました。L1Lossだけで十分です。無駄でした。
ネットワーク
GeneratorにはUNet、Discriminatorには普通のCNNを用いました。
Generator
UnetのUpにConvTransposeかPixelShufflerかで迷ったのですが使い慣れてるConvTransposeにしました。ほかのGANでPixelShufflerの方が良い結果が出たりしていたので比較も行いたかったですがConvTransposeで。
最初は出力層にTanh入れてなかったのですが、色が一辺倒になってしまったのでTanhを入れることで表現力上げました。Sigmoidでもよかったかも。
コード(Gistとかで)を張るとさらに冗長になってしまいそうなので以下を参照してください。
PainTorch/unet.py at master · reppy4620/PainTorch · GitHub
Discriminator
普通のCNNです。なんの変哲もありません。特に言うことはないです。
PainTorch/discriminator.py at master · reppy4620/PainTorch · GitHub
感想
過学習を起こそうと思って120epochくらい回したら、OpenCVで生成した線画にしか反応しなくなってしまい、実質カラーイラストの塗りなおしネットワークになってしまいましたが、塗りなおしネットワークとしてはとても満足できる出力をしてくれました。
まあ、思い付きで始めた割に結構うまくいってよかったと思います。 薄っぺらい内容となってしまいましたが冬休み報告を兼ねた投稿を〆させていただきます。
ありがとうございました。