深層学習の4、RNN LSTM stateful

概要

RNNやLSTMのstatefulとは、各Batchのサンプルの状態が次のバッチのサンプルのための初期状態として再利用されるのであれば、Trueに設定するという。計算の便宜を図ってbatch_size毎に分けて処理するが、株価、天気など時系列データに対して、状態を引き継ぐのであれば、stateful=Trueのケースに当たる。stateful=Trueの場合、以下で対応する。

・batch_size引数をモデルの最初のLayerに渡して、batch_sizeを明記する。例えば、サンプル数が32、time_stepsが10、input_dim(入力データの特徴数)が16の場合には,batch_size=32と明記する。
・RNN、LSTMでstateful=Trueに指定する.
・model.fit()を呼ぶときにshuffle=Falseに指定する。
・model.reset_states()を実行して、モデルの全てのLayerの状態を更新する。または、layer.reset_states()を実行して、特定のLayerの状態を更新する。
・model.predict()、model.train_on_batch()、model.predict_classes()関数はいずれもmodel.reset_states()を実行してstatefulに指定したLayerの状態を更新する。
・model.fit()を呼ぶときに、以下例のようにbatch毎にmodel.reset_states()を実行する。

注意点は、model.reset_states()は0にリセットではなく、前回のbatchの最終状態値を次回のbatchの初期状態値として引き継ぐ。

実装の例

print('Build STATEFUL model...')
model = Sequential()
model.add(LSTM(10, batch_input_shape=(1, 1, 1), return_sequences=False, stateful=True))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
...
class ResetStatesCallback(Callback):
    def __init__(self):
        self.counter = 0

    def on_batch_begin(self, batch, logs={}):
        if self.counter % max_len == 0:
            self.model.reset_states()
        self.counter += 1
...        
model.fit(x, y, callbacks=[ResetStatesCallback()], batch_size=1, shuffle=False)
...
print('Train...')
for epoch in range(15):
    for i in range(len(X_train)):
        tr_loss, tr_acc = model.train_on_batch(np.expand_dims(np.expand_dims(X_train[i][j], axis=1), axis=1),
                                               np.array([y_true]))
        model.reset_states()
...
    for i in range(len(X_test)):
        te_loss, te_acc = model.test_on_batch(np.expand_dims(np.expand_dims(X_test[i][j], axis=1), axis=1),
                                              y_test[i])
        model.reset_states()
...
        y_pred = model.predict_on_batch(np.expand_dims(np.expand_dims(X_test[i][j], axis=1), axis=1))
        model.reset_states()

【追記】

model.fit()のcallbackについて、以下アラームが出る。

RuntimeWarning: Method (on_train_batch_begin) is slow compared to the batch update. Check your callbacks.

これは、model.fit()のcallbackがbatch updateより遅いという、なおかつ、KERASの解釈では「RNNをstatefulにするとは,各バッチのサンプルの状態が,次のバッチのサンプルのための初期状態として再利用されるということを意味する」とあるので、model.fit()を以下のように修正する。

for i in range(NUM_EPOCHS):
    for j in range(BATCH_SIZE*NUM_BATCHES):
        model.fit(x, y, batch_size=1, epochs=1, shuffle=False)
        model.reset_states()

ただし、batchが順番に処理してGPU/TPUの並行処理ができなくなるので、処理時間が余儀なく伸びる。精度と処理時間の兼ね合いから課題として、stateless、batch_sizeを増やすのと、statefulにするのと、どちらがよいのかを案件ごとに実験することが大切だと考える。

参考文献

Keras FAQ
・「Deep Learning with Keras」、by Antonio Gulli氏

ロボット・ドローン部品お探しなら
ROBOT翔・電子部品ストア