2019年4月20日土曜日

CIFAR-100を畳み込みニューラルネットワークで認識

本当に久しぶりの投稿になってしまいました.投稿していない間,色々なことがありました.一番大きな出来事は,これまで勤めていた会社を退職し,フリーランスかつアルバイト採用のエンジニアとして働き始めました.現在は,主に測定機や組み込み機器を開発するメーカーでソフトウエア開発をしています.(個人でのご依頼もどんどん受け付けております.もう一回言いますね.個人でのご依頼もどんどん受け付けております.

ありがたいことに,出勤時間や出勤日の自由度が非常に高い会社で,平日でもこれまで出来なかったような音楽活動や,個人で受けた案件を進めたり,単純に遊んだり,昨日頑張った(個人の感想)から昼過ぎまで寝てから出勤したり,昨日頑張った(個人の感想)から昼過ぎまで寝て出勤しようと思ったけどそのまま夕方まで寝たり,ということが出来ています(とはいえ,出勤しないとその分自分の給料に跳ね返るのだが).社風やオフィス内の雰囲気も私好みでとても気に入っています.何より,アルバイト採用という立場ながら,憧れの組み込み業界で開発をすることが出来て,一応仕事の業界的にはキャリアアップも達成できたように思えるのがとても嬉しいです.(組み込みに詳しい風を装って合格したので,分からないことが多いのに今更質問することもできなくて少し焦っています

そんな訳で数ヶ月は,仕事以外の,完全に自分の趣味や興味で何かをすることが,音楽活動以外無くなってしまい,本ブログの更新も滞っていました.しかし,上記のような生活スタイルになって,そうした活動を再開するゆとりが出てきたのでちょっと再開してみますかと,色々始めています.

今回は,機械学習の話題.CIFAR-100のデータをCNNで識別してみます.

機械学習や画像処理の分野で多く用いられる畳み込みニューラルネットワーク(Convolutional neural network 以下CNN).その分野に詳しい方は今更説明するまでもないくらい有名かつ効果的な手法として普及しています.

入力した画像に写っているものが何か識別するタスクに多く用いられます.

https://postd.cc/how-do-convolutional-neural-networks-work/

CNNの仕組みはこちらのwebサイトが非常にわかりやすいので,そちらに譲ります.

特徴としては,教師データの画像の局所的な特徴(点がある,線があるといった抽象的な構造)をとらえて,それを認識に使うので,入力画像で認識したい物体の位置が違っていたり,角度が違っていても認識ができるというところにあります.従来のニューラルネットワークでは,いきなり画像全体の特徴を認識しようとするので,複雑な構造を発見したり,些細な違いを見分けることが困難でした.しかし,画像の局所的な特徴を認識して,それを組み合わせて徐々に複雑な特徴を認識していけば,複雑な画像でも認識できるわけです.

機械学習の練習やテストに用いられる画像データセットは,色々なものが知られています.一番有名なのが,NMISTです.


0〜9までの手書きの数字が学習用に60000枚,テスト(学習がうまくいったか,実際に入力してみる)用に10000枚がセットになっています.

他にも衣類(シャツや靴等)画像を学習するためのFASHON-MNISTもあります.

これらは白黒画像での認識ですが,カラー画像の練習に多く使われているのが,CIFAR-10です.


図の10種類に分類できる画像が,学習用に50000枚,テスト用に10000枚用意されています.

上記データセットの学習は,初歩的なニューラルネットワークやCNNでも非常に精度よく学習できるため,多くの実験例をweb上でみることができます.

ところで,CIFAR-10によく似たデータセットにCIFAR-100があります.


その名の通り,100種類に分類できる画像が学習用に50000枚,テスト用に10000枚用意されています.

このCIFAR-100の学習,やはり100種類の識別となると難しいのか,あまり学習を行っている例を見かけません.そこで,ちょっと行って見ました.色々試行錯誤はしてみましたが,とりあえず現状でのコードです.学習は非常に大変なので,Google Colabratoryを使います.作成したCNNのネットワークと,学習したパラメータをファイルに保存して,学習終了後にダウンロードして,ローカルでの認識処理に使います.学習するついでに,50000枚の学習画像の中から25枚,表示してみています.

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
#CIFAR100のラベル名
CIFAR100_LABELS_LIST = [
                        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
                        'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
                        'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
                        'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
                        'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
                        'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
                        'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
                        'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
                        'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
                        'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
                        'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
                        'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
                        'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
                        'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
                        'worm'
                        ]
#CIFAR-100 datasetの読み込み
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data(label_mode='fine')

train_labels_onehot=keras.utils.to_categorical(y_train,100)
test_labels_onehot=keras.utils.to_categorical(y_test,100)

#画像をfloat32(0.~1.)に変換
x_train=x_train.astype("float32")/255.0
x_test=x_test.astype("float32")/255.0

print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)

#学習画像を少し見てみる
plt.figure(figsize=(5,5))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(x_train[i+25])
    plt.xlabel(CIFAR100_LABELS_LIST[int(y_train[i+25])])

#ネットワーク作成
model=keras.Sequential()

model.add(keras.layers.Conv2D(filters=32,kernel_size=(3,3),padding='same',activation='relu',input_shape=(32,32,3)))
model.add(keras.layers.Conv2D(filters=32,kernel_size=(3,3),padding='same',activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=(2,2)))
model.add(keras.layers.Dropout(0.25))

model.add(keras.layers.Conv2D(filters=64,kernel_size=(3,3),padding='same',activation='relu'))
model.add(keras.layers.Conv2D(filters=64,kernel_size=(3,3),padding='same',activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=(2,2)))
model.add(keras.layers.Dropout(0.25))

model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(512,activation='relu'))
model.add(keras.layers.Dropout(0.5))
model.add(keras.layers.Dense(100,activation='softmax'))

model.compile(optimizer=tf.train.AdamOptimizer(),loss='categorical_crossentropy',metrics=["accuracy"])

#学習用データで学習してみる
model.fit(x_train[:,:,:,:],train_labels_onehot,epochs=200,batch_size=64)

#ネットワークをファイルに保存
model_json_str=model.to_json()
open('model.json','w').write(model_json_str)
#学習したパラメータをファイルに保存
model.save_weights('weights.h5')

plt.show()

データセットの一部を表示して見ると,こんな感じ.



学習回数は200回です.GPUを使っても結構時間がかかりますので,以下のようなセッション切れ対策が必要かもしれません.
Google Colaboratoryの90分セッション切れ対策【自動接続】

結果は画像のようになりました.


accuracyは80%までいきましたが,lossがそんなに下がっていないですね.
とりあえず,保存したモデルとパラメータをダウンロードして,ローカルマシンでテストしてみます.テスト用の10000枚をネットワークに入力して分類して見ます.Pythonのバージョンは3.7.1です.

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
#CIFAR100のラベル名
CIFAR100_LABELS_LIST = [
                        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
                        'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
                        'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
                        'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
                        'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
                        'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
                        'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
                        'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
                        'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
                        'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
                        'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
                        'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
                        'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
                        'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
                        'worm'
                        ]
#CIFAR-100 datasetの読み込み
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data(label_mode='fine')

train_labels_onehot=keras.utils.to_categorical(y_train,100)
test_labels_onehot=keras.utils.to_categorical(y_test,100)

#画像をfloat32(0.~1.)に変換
x_train=x_train.astype("float32")/255.0
x_test=x_test.astype("float32")/255.0

print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)

#学習画像を少し見てみる
plt.figure(figsize=(5,5))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    #ランダムな番号の画像を見る
    index=np.random.randint(0,49999)
    plt.imshow(x_train[index])
    plt.xlabel(CIFAR100_LABELS_LIST[int(y_train[index])])

#ネットワーク読み込み
model=keras.models.model_from_json(open('model.json').read())
#学習したパラメータ読み込み
model.load_weights('weights.h5')
model.summary()
model.compile(optimizer=tf.train.AdamOptimizer(),loss='categorical_crossentropy',metrics=["accuracy"])

#テスト画像を入力して識別
labels=model.predict(x_test[:,:,:,:])

#結果を少し見る
plt.figure(figsize=(5,5))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    #ランダムな番号の画像を見る
    index=np.random.randint(0,9999)
    true_index=np.argmax(test_labels_onehot[index])#正解
    predict_index=np.argmax(labels[index])#予測したインデックス
    plt.imshow(x_test[index])#画像表示
    plt.xlabel("{}({})".format(CIFAR100_LABELS_LIST[predict_index],CIFAR100_LABELS_LIST[true_index]),color=("green" if predict_index==true_index else "red"))#"予測したラベル(正解のラベル)"で表示.正解なら緑,間違っていれば赤で表示

#実際、テスト画像でどれほど正解しているのか?
correct=0
for i in range(10000):
    true_index=np.argmax(test_labels_onehot[i])#正解
    predict_index=np.argmax(labels[i])#予測したインデックス
    if true_index==predict_index:
        correct+=1

print("correct: "+str(correct)+" / 10000")

plt.show()


結果は以下のようになりました.画像名を緑色で表示しているのが正解した画像.


結構間違えてますね.でも,実際間違えているのを見ると,惜しい(自分が分類しても間違えるかもしれない)のもあったりします.

実際,10000枚のうち何枚正解したのか数えてみました.


いや何が80%なんですか.

やはり100種類に分類するというのは結構難しいようですね.

0 件のコメント:

コメントを投稿