tetsuのlog

テツがひとりでに読んだ論文とか行き詰まったところをshareする場

CNN for Sentence Classificationの実装

はおーてっつんでーす

youtubeの登録チャンネルはほぼV tuberのテツです。友達いないからフォローしてね!!!

前回の論文読みを終えて、せっかくなので実装してみようと言うことでchainerでやってみました。

tetsu316.hatenablog.com

ソースコードは以下のgithubのリンクからご確認ください!!

github.com

データセット

今回は大学の授業の課題もあり、sentence単位での青空文庫の著者推定をやりました!! 夏目漱石江戸川乱歩太宰治芥川龍之介宮沢賢治の5名の作品を僕がテキトーに12作ずつ選んで10作を訓練データに、2作をテストデータにしました。

データセットの詳細は以下の通りになってます。

train数(sentence単位) test数(sentence単位)
夏目漱石 25224 1962
江戸川乱歩 25607 2545
太宰治 11805 2683
芥川龍之介 3473 235
宮沢賢治 4300 643

芥川隆ノ介と宮沢賢治が作品が短く、データが少し偏っています。

word2vecは、乾研究室で公開しているwikipediaで事前学習したモデルを使用しました。

前処理

  • ファイル読み込み
  • いらないところ削る + 「。」で文区切り
  • mecabによる分かち書き
  • word2vecのインデックスに変更

過学習抑制のために、先輩に教えていただいた、wildcard trainingを使用しました。(preprocessで行ってます) 簡単に説明しますと、単語ベクトルの一部をランダムで0にするって感じですね。気になった方は是非調べてみてください!!!

お先に結果

f:id:tetsu316:20190122014229p:plain
accuracy

f:id:tetsu316:20190122014235p:plain
loss

loss めっちゃ高いし過学習しとるやんけええええ

最後のpredictのconfusion matrix 表の見方は横が正解ラベルで縦が推測になってます。対角部分が高いと強いモデルになります。

f:id:tetsu316:20190122191029p:plain
confusion matrix

データの少ない芥川隆ノ介、宮沢賢治がうまく推定できてないですね。芥川隆ノ介を夏目漱石に間違えがちです。 でも5クラスで75%近くだからそこそこ良い精度出てるんでない?? 実装間違ってたりしたら是非、コメントとかで教えてください!!!

モデルの工夫

論文通り、word2vecにない単語を全部1つのランダムなベクトルに置き換えました。 あと、全てのベクトルを最大単語数の大きさでパディングしても良いのですが、それだと計算の無駄が多いため、バッチごとの最大でバッチごとにパディングを行うように変更しました。これをやるだけで時間が1/10くらい変わります。 Two Channelのモデルだけ載せておきます。

class TwoChannel(chainer.Chain):
    def __init__(self, w2v_w, batch_size):
        self.w2v_w = w2v_w
        self.batch = batch_size
        super(TwoChannel, self).__init__()
        with self.init_scope():
            self.embed1 = L.EmbedID(
                self.w2v_w.shape[0], self.w2v_w.shape[1],
                ignore_label=-1, initialW=self.w2v_w)
            self.embed1.disable_update()
            self.embed2 = L.EmbedID(
                self.w2v_w.shape[0], self.w2v_w.shape[1],
                ignore_label=-1, initialW=self.w2v_w)
            self.cnn_w3 = L.Convolution2D(None, 100, (3, 200))
            self.cnn_w4 = L.Convolution2D(None, 100, (4, 200))
            self.cnn_w5 = L.Convolution2D(None, 100, (5, 200))
            self.fc = L.Linear(None, 5)

    def __call__(self, xs):
        x = concat_examples(xs, padding=-1)
        len_x = len(x[0])
        h1 = F.reshape(self.embed1(x), (-1, 1, len_x, 200))
        h2 = F.reshape(self.embed2(x), (-1, 1, len_x, 200))
        sentence_vec = F.concat([h1, h2], axis=1)
        h_3 = F.max(
            F.tanh(self.cnn_w3(sentence_vec)), axis=2)
        h_4 = F.max(
            F.tanh(self.cnn_w4(sentence_vec)), axis=2)
        h_5 = F.max(
            F.tanh(self.cnn_w5(sentence_vec)), axis=2)
        concat = F.concat([h_3, h_4, h_5], axis=2)
        h2 = F.dropout(F.relu(concat), ratio=0.5)
        y = self.fc(h2)
        return y

NonStaticのモデルとStaticのモデルもGitHubの方に乗っているので、是非ご覧ください。

つまずいたとこ

可変長

一番つまずいたとこは、入力を可変長にしたことですね。今まで画像しかやってこなかったため脳死L.Classifierが使えなくて辛かったです。(自作text_classifier) chainer exampleのtext_classfierを色々といじくりまわし、頑張って理解しました、、(コピペ感あり) convert_seqとかconcat exsampleとか色々やることが多かったです。

エラー探し

なんだかんだこれに一番時間がかかった気がします。chainerなどのライブラリは簡単にdeeplearningのモデルが組める反面、中で全部やていただけるので エラーの場所が全然わからんくて辛かった、、、 今までprintデバッグしてたぼくですが、今回からimport pdb; pdb.set_trace()!!!!!!!!!!!!!!! を使い始めました。実行途中で処理を止めて色々買う人できるやつです。(下にリンク貼りま)。ゆうてまだp とかしかつかってないけどな!!!

docs.python.jp

ちなみに一番アホだったエラーはインデントずれててforwardがないって言われたことですかね!!! 先輩(いつもの人ではない)にクッソ煽られました。

メモリエラー

これは大したことではないのですが、word2vecの重みがでかい!!!マジで 僕は研究室の余り物マシン(4G程度)を使ってたのですが、乗らなくて研究用のマシンを使いました。 絶対いけるもんだと思ってて、実装ミスだとひたすら自分に言い聞かせてました。

大きくつまずいたのはこんなもんですが、他にも色々エラー出ては直しの繰り返しでした。

今後できそうなこと

データの偏りでlossとか大きくなっちゃったのかなーって思うので、データの偏りによってlossの大きさとかかえれたらなーとか思ってます。 データの偏りはいろんなとこでありそうな問題なので解決策をチョロチョロと探していきたいなあと思いま!!!(オーバーサンプリングやfocal lossなど調べて理解したらまとめま!)

まとめ

今回は初めて1から自分で前処理からネットワークまで組みました!!最初に回った時は本当にドッキドキで1時間に1回は確認してました笑 今回は研究室の色々な人にエラーなどみてもらったりしてもらいました。ありがとうございました。 また論文を読んで実装とかしてみたいなあと思います。これからも頑張るマーン

ビルマーン