TensorFlowでmodelを保存して復元する

modelの保存と復元は、それぞれ以下のようにシンプルな設計で行える。

  • Save a model
saver = tf.train.Saver()
saver.save(sess, '../model/test_model')
  • Restore a model
saver = tf.train.Saver()
saver.restore(sess, '../model/test_model')

本記事では、実際にmodelを訓練して保存し、そのmodelを復元して、各々のテスト精度が一致していることを確認する。

処理の流れ

# coding: utf-8

import tensorflow as tf
import random
import os

from data_fizzbuzz import DataFizzBuzz

class Test:
    def __init__(self):
        pass

    def main(self):
        data = DataFizzBuzz().main()
        model = self.design_model(data)
        self.save_model(data, model)
        self.restore_model(data, model)

データを取得して、モデルをデザインして、モデルを訓練して保存して、モデルを復元する。

データ

例によってFizzBuzz問題のためのデータを利用する。 ここ にあるので、以下で解説するソースと同じ場所に置く。

モデル

def design_model(self, data):
    X  = tf.placeholder(tf.float32, [None, data[0].shape[1]])
    W1 = tf.Variable(tf.truncated_normal([data[0].shape[1], 100], stddev=0.01), name='W1')
    B1 = tf.Variable(tf.zeros([100]), name='B1')
    H1 = tf.nn.tanh(tf.matmul(X, W1) + B1)
    W2 = tf.Variable(tf.random_normal([100, data[1].shape[1]], stddev=0.01), name='W2')
    B2 = tf.Variable(tf.zeros([data[1].shape[1]]), name='B2')
    Y = tf.matmul(H1, W2) + B2
    Y_ = tf.placeholder(tf.float32, [None, data[1].shape[1]])

    tf.add_to_collection('vars', W1)
    tf.add_to_collection('vars', B1)
    tf.add_to_collection('vars', W2)
    tf.add_to_collection('vars', B2)

    model = {'X': X, 'Y': Y, 'Y_': Y_}

    return model

1層のFFNNを訓練する。 重み(W)とバイアス(B)についてはnameを付ける。 add_to_collection()で変数として登録する。

モデルの保存

def save_model(self, data, model):
    """
    # Data
    dataは、訓練データ、訓練ラベル、テストデータ、
    テストラベルの順に格納されているので、順に取り出す。
    """
    train_data  = data[0]
    train_label = data[1]
    test_data  = data[2]
    test_label = data[3]

    """
    # Model
    設計したモデルのうち、訓練で使うのは、X, Y, Y_のみなので、それらを取り出す。
    """
    X, Y, Y_ = model['X'], model['Y'], model['Y_']

    """
    # Functions
    訓練で利用する関数をそれぞれ定義する。
    """
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(Y, Y_))
    step = tf.train.AdamOptimizer(0.05).minimize(loss)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1)), tf.float32))

    """
    # Setting
    初期化。saverは、モデルを保存するためのインスタンス。
    """
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    """
    # Randamize data
    101〜1023のデータをランダマイズする。
    """
    p = np.random.permutation(range(len(train_data)))
    train_data, train_label = train_data[p], train_label[p]

    """
    # Training
    100バッチずつ訓練するエポックを1回繰り返す。
    訓練エラー、訓練精度、テスト精度を算出する。
    """
    for start in range(0, train_label.shape[0], 1):
        end = start + 100
        sess.run(step, feed_dict={X: train_data[start:end], Y_: train_label[start:end]})

        # Testing
        train_loss = sess.run(loss, feed_dict={X: train_data, Y_: train_label})
        train_accuracy = sess.run(accuracy, feed_dict={X: train_data, Y_: train_label})
        test_accuracy = sess.run(accuracy, feed_dict={X: test_data, Y_: test_label})

    """
    # Accuracy
    学習後、訓練エラー、訓練精度、テスト精度を標準出力する。
    """
    std_output = 'Train Loss: %s, \t Train Accuracy: %s, \t Test Accuracy: %s'
    print(std_output % (train_loss, train_accuracy, test_accuracy))

    """
    # Save a model
    既存の訓練モデルを削除して、今回訓練したモデルを保存する。
    """
    for f in os.listdir('../model/'):
        os.remove('../model/'+f)
    saver.save(sess, '../model/test_model')
    print('Saved a model.')

    sess.close()

modelディレクトリーには、以下が保存される。

  • checkpoint
  • test_model.data-00000-of-00001
  • test_model.index
  • test_model.meta

モデルの復元

def restore_model(self, data, model):
    # Data
    train_data  = data[0]
    train_label = data[1]
    test_data  = data[2]
    test_label = data[3]

    # Model
    X, Y, Y_ = model['X'], model['Y'], model['Y_']

    # Function
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1)), tf.float32))

    """
    # Setting
    SaverとSessionのインスタンスを生成し、モデルを復元する。
    """
    saver = tf.train.Saver()
    sess = tf.Session()
    saver.restore(sess, '../model/test_model')
    print('Restored a model')

    test_accuracy = sess.run(accuracy, feed_dict={X: test_data, Y_: test_label})
    print('Test Accuracy: %s' % test_accuracy)

復元は非常に簡単。 モデルの保存で使用したデータとモデルを渡している。 データは何でもいいが、モデルは同じものが必要。 初期化は不要。

実行結果

$ python test.py
Train Loss: 1.88192,     Train Accuracy: 0.781148,   Test Accuracy: 0.762376
Saved a model.
Restored a model
Test Accuracy: 0.762376

Test Accuracyが、モデルの保存とモデルの復元で一致している。 同じモデルとテストデータだと、同じ精度が出ることが確認できた。

その他

もしモデルの復元で失敗するときは、モデルの保存を一回した後で、モデルの復元だけ実行したり、 以下のように1行足してみると上手く行くかも。

# Setting
saver = tf.train.Saver()
sess = tf.Session()
saver = tf.train.import_meta_graph('../model/test_model.meta')  ## <- add
saver.restore(sess, '../model/test_model')