TensorFlowでmodelを保存して復元する
modelの保存と復元は、それぞれ以下のようにシンプルな設計で行える。
- Save a model
saver = tf.train.Saver()
saver.save(sess, '../model/test_model')
- Restore a model
saver = tf.train.Saver()
saver.restore(sess, '../model/test_model')
本記事では、実際にmodelを訓練して保存し、そのmodelを復元して、各々のテスト精度が一致していることを確認する。
処理の流れ
# coding: utf-8 import tensorflow as tf import random import os from data_fizzbuzz import DataFizzBuzz class Test: def __init__(self): pass def main(self): data = DataFizzBuzz().main() model = self.design_model(data) self.save_model(data, model) self.restore_model(data, model)
データを取得して、モデルをデザインして、モデルを訓練して保存して、モデルを復元する。
データ
例によって、FizzBuzz問題のためのデータを利用する。 ここ にあるので、以下で解説するソースと同じ場所に置く。
モデル
def design_model(self, data): X = tf.placeholder(tf.float32, [None, data[0].shape[1]]) W1 = tf.Variable(tf.truncated_normal([data[0].shape[1], 100], stddev=0.01), name='W1') B1 = tf.Variable(tf.zeros([100]), name='B1') H1 = tf.nn.tanh(tf.matmul(X, W1) + B1) W2 = tf.Variable(tf.random_normal([100, data[1].shape[1]], stddev=0.01), name='W2') B2 = tf.Variable(tf.zeros([data[1].shape[1]]), name='B2') Y = tf.matmul(H1, W2) + B2 Y_ = tf.placeholder(tf.float32, [None, data[1].shape[1]]) tf.add_to_collection('vars', W1) tf.add_to_collection('vars', B1) tf.add_to_collection('vars', W2) tf.add_to_collection('vars', B2) model = {'X': X, 'Y': Y, 'Y_': Y_} return model
1層のFFNNを訓練する。 重み(W)とバイアス(B)についてはnameを付ける。 add_to_collection()で変数として登録する。
モデルの保存
def save_model(self, data, model): """ # Data dataは、訓練データ、訓練ラベル、テストデータ、 テストラベルの順に格納されているので、順に取り出す。 """ train_data = data[0] train_label = data[1] test_data = data[2] test_label = data[3] """ # Model 設計したモデルのうち、訓練で使うのは、X, Y, Y_のみなので、それらを取り出す。 """ X, Y, Y_ = model['X'], model['Y'], model['Y_'] """ # Functions 訓練で利用する関数をそれぞれ定義する。 """ loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(Y, Y_)) step = tf.train.AdamOptimizer(0.05).minimize(loss) accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1)), tf.float32)) """ # Setting 初期化。saverは、モデルを保存するためのインスタンス。 """ saver = tf.train.Saver() sess = tf.Session() sess.run(tf.global_variables_initializer()) """ # Randamize data 101〜1023のデータをランダマイズする。 """ p = np.random.permutation(range(len(train_data))) train_data, train_label = train_data[p], train_label[p] """ # Training 100バッチずつ訓練するエポックを1回繰り返す。 訓練エラー、訓練精度、テスト精度を算出する。 """ for start in range(0, train_label.shape[0], 1): end = start + 100 sess.run(step, feed_dict={X: train_data[start:end], Y_: train_label[start:end]}) # Testing train_loss = sess.run(loss, feed_dict={X: train_data, Y_: train_label}) train_accuracy = sess.run(accuracy, feed_dict={X: train_data, Y_: train_label}) test_accuracy = sess.run(accuracy, feed_dict={X: test_data, Y_: test_label}) """ # Accuracy 学習後、訓練エラー、訓練精度、テスト精度を標準出力する。 """ std_output = 'Train Loss: %s, \t Train Accuracy: %s, \t Test Accuracy: %s' print(std_output % (train_loss, train_accuracy, test_accuracy)) """ # Save a model 既存の訓練モデルを削除して、今回訓練したモデルを保存する。 """ for f in os.listdir('../model/'): os.remove('../model/'+f) saver.save(sess, '../model/test_model') print('Saved a model.') sess.close()
modelディレクトリーには、以下が保存される。
- checkpoint
- test_model.data-00000-of-00001
- test_model.index
- test_model.meta
モデルの復元
def restore_model(self, data, model): # Data train_data = data[0] train_label = data[1] test_data = data[2] test_label = data[3] # Model X, Y, Y_ = model['X'], model['Y'], model['Y_'] # Function accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1)), tf.float32)) """ # Setting SaverとSessionのインスタンスを生成し、モデルを復元する。 """ saver = tf.train.Saver() sess = tf.Session() saver.restore(sess, '../model/test_model') print('Restored a model') test_accuracy = sess.run(accuracy, feed_dict={X: test_data, Y_: test_label}) print('Test Accuracy: %s' % test_accuracy)
復元は非常に簡単。 モデルの保存で使用したデータとモデルを渡している。 データは何でもいいが、モデルは同じものが必要。 初期化は不要。
実行結果
$ python test.py Train Loss: 1.88192, Train Accuracy: 0.781148, Test Accuracy: 0.762376 Saved a model. Restored a model Test Accuracy: 0.762376
Test Accuracyが、モデルの保存とモデルの復元で一致している。 同じモデルとテストデータだと、同じ精度が出ることが確認できた。
その他
もしモデルの復元で失敗するときは、モデルの保存を一回した後で、モデルの復元だけ実行したり、 以下のように1行足してみると上手く行くかも。
# Setting saver = tf.train.Saver() sess = tf.Session() saver = tf.train.import_meta_graph('../model/test_model.meta') ## <- add saver.restore(sess, '../model/test_model')