CNN for Sentence Classificationの実装
はおーてっつんでーす
youtubeの登録チャンネルはほぼV tuberのテツです。友達いないからフォローしてね!!!
前回の論文読みを終えて、せっかくなので実装してみようと言うことでchainerでやってみました。
ソースコードは以下のgithubのリンクからご確認ください!!
データセット
今回は大学の授業の課題もあり、sentence単位での青空文庫の著者推定をやりました!! 夏目漱石、江戸川乱歩、太宰治、芥川龍之介、宮沢賢治の5名の作品を僕がテキトーに12作ずつ選んで10作を訓練データに、2作をテストデータにしました。
データセットの詳細は以下の通りになってます。
train数(sentence単位) | test数(sentence単位) | |
---|---|---|
夏目漱石 | 25224 | 1962 |
江戸川乱歩 | 25607 | 2545 |
太宰治 | 11805 | 2683 |
芥川龍之介 | 3473 | 235 |
宮沢賢治 | 4300 | 643 |
芥川隆ノ介と宮沢賢治が作品が短く、データが少し偏っています。
word2vecは、乾研究室で公開しているwikipediaで事前学習したモデルを使用しました。
前処理
過学習抑制のために、先輩に教えていただいた、wildcard trainingを使用しました。(preprocessで行ってます) 簡単に説明しますと、単語ベクトルの一部をランダムで0にするって感じですね。気になった方は是非調べてみてください!!!
お先に結果
loss めっちゃ高いし過学習しとるやんけええええ
最後のpredictのconfusion matrix 表の見方は横が正解ラベルで縦が推測になってます。対角部分が高いと強いモデルになります。
データの少ない芥川隆ノ介、宮沢賢治がうまく推定できてないですね。芥川隆ノ介を夏目漱石に間違えがちです。 でも5クラスで75%近くだからそこそこ良い精度出てるんでない?? 実装間違ってたりしたら是非、コメントとかで教えてください!!!
モデルの工夫
論文通り、word2vecにない単語を全部1つのランダムなベクトルに置き換えました。 あと、全てのベクトルを最大単語数の大きさでパディングしても良いのですが、それだと計算の無駄が多いため、バッチごとの最大でバッチごとにパディングを行うように変更しました。これをやるだけで時間が1/10くらい変わります。 Two Channelのモデルだけ載せておきます。
class TwoChannel(chainer.Chain): def __init__(self, w2v_w, batch_size): self.w2v_w = w2v_w self.batch = batch_size super(TwoChannel, self).__init__() with self.init_scope(): self.embed1 = L.EmbedID( self.w2v_w.shape[0], self.w2v_w.shape[1], ignore_label=-1, initialW=self.w2v_w) self.embed1.disable_update() self.embed2 = L.EmbedID( self.w2v_w.shape[0], self.w2v_w.shape[1], ignore_label=-1, initialW=self.w2v_w) self.cnn_w3 = L.Convolution2D(None, 100, (3, 200)) self.cnn_w4 = L.Convolution2D(None, 100, (4, 200)) self.cnn_w5 = L.Convolution2D(None, 100, (5, 200)) self.fc = L.Linear(None, 5) def __call__(self, xs): x = concat_examples(xs, padding=-1) len_x = len(x[0]) h1 = F.reshape(self.embed1(x), (-1, 1, len_x, 200)) h2 = F.reshape(self.embed2(x), (-1, 1, len_x, 200)) sentence_vec = F.concat([h1, h2], axis=1) h_3 = F.max( F.tanh(self.cnn_w3(sentence_vec)), axis=2) h_4 = F.max( F.tanh(self.cnn_w4(sentence_vec)), axis=2) h_5 = F.max( F.tanh(self.cnn_w5(sentence_vec)), axis=2) concat = F.concat([h_3, h_4, h_5], axis=2) h2 = F.dropout(F.relu(concat), ratio=0.5) y = self.fc(h2) return y
NonStaticのモデルとStaticのモデルもGitHubの方に乗っているので、是非ご覧ください。
つまずいたとこ
可変長
一番つまずいたとこは、入力を可変長にしたことですね。今まで画像しかやってこなかったため脳死L.Classifierが使えなくて辛かったです。(自作text_classifier) chainer exampleのtext_classfierを色々といじくりまわし、頑張って理解しました、、(コピペ感あり) convert_seqとかconcat exsampleとか色々やることが多かったです。
エラー探し
なんだかんだこれに一番時間がかかった気がします。chainerなどのライブラリは簡単にdeeplearningのモデルが組める反面、中で全部やていただけるので
エラーの場所が全然わからんくて辛かった、、、
今までprintデバッグしてたぼくですが、今回からimport pdb; pdb.set_trace()
!!!!!!!!!!!!!!!
を使い始めました。実行途中で処理を止めて色々買う人できるやつです。(下にリンク貼りま)。ゆうてまだp とかしかつかってないけどな!!!
ちなみに一番アホだったエラーはインデントずれててforwardがないって言われたことですかね!!! 先輩(いつもの人ではない)にクッソ煽られました。
メモリエラー
これは大したことではないのですが、word2vecの重みがでかい!!!マジで 僕は研究室の余り物マシン(4G程度)を使ってたのですが、乗らなくて研究用のマシンを使いました。 絶対いけるもんだと思ってて、実装ミスだとひたすら自分に言い聞かせてました。
大きくつまずいたのはこんなもんですが、他にも色々エラー出ては直しの繰り返しでした。
今後できそうなこと
データの偏りでlossとか大きくなっちゃったのかなーって思うので、データの偏りによってlossの大きさとかかえれたらなーとか思ってます。 データの偏りはいろんなとこでありそうな問題なので解決策をチョロチョロと探していきたいなあと思いま!!!(オーバーサンプリングやfocal lossなど調べて理解したらまとめま!)
まとめ
今回は初めて1から自分で前処理からネットワークまで組みました!!最初に回った時は本当にドッキドキで1時間に1回は確認してました笑 今回は研究室の色々な人にエラーなどみてもらったりしてもらいました。ありがとうございました。 また論文を読んで実装とかしてみたいなあと思います。これからも頑張るマーン
デビルマーン