tensorflowで画像分類モデル(VGG19)を用いて絵画の画像データから芸術様式(写実主義、シュルレアリスム、印象派など)を予測するモデルをつくった
芸術様式の横文字はアートスタイルとかアートジャンルなどがヒットしたが、ジャンルは素材(写真、彫刻など)のことを指すようにも見受けられたので芸術様式と表記する
環境
- python: 3.7
- tensorflow: ^2.0
- tf_notification_callback
データセット
kaggleにあるデータセットを利用
all_data_infoにtrain/test
の情報と、画像名とgenre、styleなどが入っている
こんな感じの画像データが入っている

手法
転移学習を行う
データ読み込み
import pandas as pd
import tensorflow as tf
df = pd.read_csv('./all_data_info.csv.zip')
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
rescale=1./255,
rotation_range=40,
zoom_range=[0.75, 1.25], # ランダムにzoomをしてみる
horizontal_flip=True,
vertical_flip=True,
)
IMG_HEIGHT, IMG_WIDTH = 224, 224
BATCH_SIZE = 48
train_generator = datagen.flow_from_dataframe(
dataframe=df[df['in_train']==True],
directory='./train',
x_col='new_filename',
y_col='style',
batch_size=BATCH_SIZE,
target_size=(IMG_HEIGHT, IMG_WIDTH)
)
test_generator = datagen.flow_from_dataframe(
dataframe=df[df['in_train']==False],
directory='./test',
x_col='new_filename',
y_col='style',
batch_size=BATCH_SIZE,
target_size=(IMG_HEIGHT, IMG_WIDTH)
)
学習
モデル定義
学習済みのモデルに対していくつか層を足してみる
# model init
classes = list(df['style'])
input_tensor = tf.keras.layers.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3))
with tf.device('/device:XLA_GPU:0'):
vgg19 = tf.keras.applications.VGG19(
weights='imagenet',
include_top=False,
input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
input_tensor=input_tensor
)
# add some layer for output
x = vgg19.output
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(724, activation='relu')(x)
x = tf.keras.layers.Dense(512, activation='relu')(x)
predictions = tf.keras.layers.Dense(len(classes),
activation='softmax')(x)
model = tf.keras.Model(inputs=vgg19.input, outputs=predictions)
# 20層までfreeze
for layer in model.layers[:22]:
layer.trainable = False
# 20層以降、学習させる
for layer in model.layers[22:]:
layer.trainable = True
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
学習&モデル保存
コールバック関数を定義しておくcheckpoint保存用と、slack通知用
モデルの保存はコールバックで行われる
dpath = './models'
fpath = os.path.join(dpath, 'weights.{epoch:02d}-{loss:.2f}-{accuracy:.2f}-{val_loss:.2f}-{val_accuracy:.2f}.hdf5')
cp_cb = tf.keras.callbacks.ModelCheckpoint(filepath=fpath, monitor='val_loss',
verbose=1, save_best_only=True, mode='auto',
save_freq='epoch'
)
# notify to slack
from tf_notification_callback import SlackCallback
slack_cb = SlackCallback(webhookURL='https://hooks.slack.com/services/Txxxxx/Bxxxx/bxxxxxxxxxxx',
channel='tf-nortify',
modelName='vgg19 Model',
loss_metrics=['loss', 'val_loss'],
acc_metrics=['accuracy', 'val_accuracy'],
getSummary=True)
N_TRAIN = len(df[df['in_train']==True])
N_TEST = len(df[df['in_train']==False])
N_EPOCHS = 20
history = model.fit(
train_generator,
steps_per_epoch=N_TRAIN // BATCH_SIZE,
epochs=N_EPOCHS,
validation_data=test_generator,
validation_steps=N_TEST // BATCH_SIZE,
callbacks=[cp_cb, slack_cb],
)
データ量の多い順にクラス数10まで絞って学習したら、精度40%くらいだった
...
1066/1066 [==============================] - 10213s 10s/step - loss: 1.5340 - accuracy: 0.4436 - val_loss: 1.6417 - val_accuracy: 0.4040
今後はチューニングと、データ探索をしてみたい