読者です 読者をやめる 読者になる 読者になる

損失関数がNaNになる問題

TensorFlowでDeep Learningを実行している途中で、損失関数がNaNになる問題が発生した。

Epoch:  10,  Train Loss: 85.6908,   Train Accuracy: 0.996,      Test Error: 90.7068,  Test Accuracy: 0.985238
Epoch:  20,  Train Loss: 42.9642,   Train Accuracy: 0.998286,   Test Error: 121.561,  Test Accuracy: 0.98619
Epoch:  30,  Train Loss: 0.945895,  Train Accuracy: 1.0,        Test Error: 102.041,  Test Accuracy: 0.990476
Epoch:  40,  Train Loss: nan,       Train Accuracy: 0.101429,   Test Error: nan,      Test Accuracy: 0.1
Epoch:  50,  Train Loss: nan,       Train Accuracy: 0.0941429,  Test Error: nan,      Test Accuracy: 0.1
Epoch:  60,  Train Loss: nan,       Train Accuracy: 0.0968571,  Test Error: nan,      Test Accuracy: 0.1
Epoch:  70,  Train Loss: nan,       Train Accuracy: 0.0881429,  Test Error: nan,      Test Accuracy: 0.1
Epoch:  80,  Train Loss: nan,       Train Accuracy: 0.0931429,  Test Error: nan,      Test Accuracy: 0.1
Epoch:  90,  Train Loss: nan,       Train Accuracy: 0.0997143,  Test Error: nan,      Test Accuracy: 0.1
Epoch: 100,  Train Loss: nan,       Train Accuracy: 0.0997143,  Test Error: nan,      Test Accuracy: 0.1

原因は、損失関数に指定している交差エントロピーtf.log(Y)にあった。

cross_entropy = -tf.reduce_sum(Y_*tf.log(Y))

tf.log(Y)はln(x)(自然対数のlog)であり、xが0になるとき-∞になるため、NaNとなっていた。

f:id:Shoto:20170507182523p:plain

ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装では、4章の「ニューラルネットワークの学習」で、 極少の数値を加算することで、この問題に対処する方法を提示している。

def cross_entropy_error(y, t):
    delta = 1e-7
    return -np.sum(t * np.log(y + delta))

しかし、TensorFlowを利用したいので、上記の問題を解決しているsoftmax_cross_entropy_with_logits()を用いて、 cross_entropyを書き換える。

#cross_entropy = -tf.reduce_sum(Y_*tf.log(Y))
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(Y, Y_))

これにより、NaNになる問題が解決できた。 しかし、学習が進み、cross_entropyが低くなるに連れて、Yは1に近づき、tf.log(Y)は0に近づくため、 cross_entropyがNaNになる確率は低くなるはずなのだが、NaNになる点は謎のままである。

Epoch:  10,  Train Error: 1.46717,  Train Accuracy: 0.995,     Test Error: 1.47448,   Test Accuracy: 0.987619
Epoch:  20,  Train Error: 1.46428,  Train Accuracy: 0.997286,  Test Error: 1.47456,   Test Accuracy: 0.985714
Epoch:  30,  Train Error: 1.46262,  Train Accuracy: 0.998715,  Test Error: 1.47142,   Test Accuracy: 0.990476
Epoch:  40,  Train Error: 1.46272,  Train Accuracy: 0.998429,  Test Error: 1.47249,   Test Accuracy: 0.989047
Epoch:  50,  Train Error: 1.46235,  Train Accuracy: 0.998857,  Test Error: 1.47399,   Test Accuracy: 0.987619
Epoch:  60,  Train Error: 1.46462,  Train Accuracy: 0.996715,  Test Error: 1.47435,   Test Accuracy: 0.987619
Epoch:  70,  Train Error: 1.46261,  Train Accuracy: 0.998572,  Test Error: 1.47196,   Test Accuracy: 0.989524
Epoch:  80,  Train Error: 1.46215,  Train Accuracy: 0.999,     Test Error: 1.47119,   Test Accuracy: 0.99
Epoch:  90,  Train Error: 1.46547,  Train Accuracy: 0.995715,  Test Error: 1.4784,    Test Accuracy: 0.982857
Epoch: 100,  Train Error: 1.46201,  Train Accuracy: 0.999143,  Test Error: 1.47057,   Test Accuracy: 0.990476

参考文献