def model_pred(test_data):
    """Score the saved model on ``test_data`` and dump predictions to disk.

    Rebuilds the network, loads the weights stored at
    ``PATH/MODELNAME.h5``, evaluates it with a generator over
    ``test_data``, appends the scaled and rescaled scores to
    ``PATH/MODELNAME_prediction_scores.txt``, then saves the rescaled
    predictions and ground truth as ``.npy`` files and prints their mean
    squared difference.

    Args:
        test_data: A sequence of array-likes. Every element must support
            ``len()`` and the first one must expose ``.shape[1]``.

    Returns:
        None. All results are printed or written under ``PATH``.
    """
    # Keyword arguments common to the evaluation and prediction generators.
    generator_kwargs = dict(
        att_lstm_num=att_lstm_num,
        long_term_lstm_seq_len=long_term_lstm_seq_len,
        short_term_lstm_seq_len=short_term_lstm_seq_len,
        hist_feature_daynum=hist_feature_daynum,
        last_feature_num=last_feature_num,
        nbhd_size=nbhd_size,
        batchsize=BATCHSIZE,
    )
    output_prefix = PATH + '/' + MODELNAME

    # Generator used for scoring.
    test_gene = data_generator(test_data, **generator_kwargs)

    # Total length of all series minus ``empty_time`` entries per series.
    sample_count = (
        sum(len(series) for series in test_data) - empty_time * len(test_data)
    )
    # NOTE(review): floor division drops any partial batch; if the product is
    # not a multiple of BATCHSIZE the reshape of the predictions further down
    # may not receive enough elements -- confirm against data_generator.
    test_sep = sample_count * test_data[0].shape[1] // BATCHSIZE
    grid_shape = (sample_count, HEIGHT, WIDTH, DATACHANNEL)

    # Rebuild the architecture and restore the stored weights.
    modeler = STDN_models_noflow.models()
    model = modeler.stdn(
        att_lstm_num=att_lstm_num,
        att_lstm_seq_len=long_term_lstm_seq_len,
        lstm_seq_len=short_term_lstm_seq_len,
        feature_vec_len=feature_vec_len,
        cnn_flat_size=cnn_flat_size,
        nbhd_size=window_size,
        nbhd_type=DATACHANNEL,
        output_shape=DATACHANNEL,
        optimizer=OPTIMIZER,
        loss=LOSS,
    )
    model.load_weights(output_prefix + '.h5')

    # Report the score in scaled units and multiplied back by MAX_VALUE ** 2.
    scale_MSE = model.evaluate_generator(test_gene, steps=test_sep)
    rescale_MSE = scale_MSE * MAX_VALUE * MAX_VALUE
    print("Model scaled MSE", scale_MSE)
    print("Model rescaled MSE", rescale_MSE)
    with open(output_prefix + '_prediction_scores.txt', 'a') as wf:
        wf.writelines((
            "Keras MSE on testData, %f\n" % scale_MSE,
            "Rescaled MSE on testData, %f\n" % rescale_MSE,
        ))

    # A second generator, created with type='test', feeds the prediction pass.
    pred_gene = data_generator(test_data, type='test', **generator_kwargs)
    raw_pred = model.predict_generator(pred_gene, steps=test_sep, verbose=1)
    rescaled_pred = np.reshape(raw_pred, grid_shape) * MAX_VALUE
    np.save(output_prefix + '_prediction.npy', rescaled_pred)

    rescaled_truth = np.reshape(get_test_true(test_data), grid_shape) * MAX_VALUE
    np.save(output_prefix + '_groundtruth.npy', rescaled_truth)

    print(np.mean((rescaled_pred - rescaled_truth) ** 2))
def model_train(train_data, valid_data):
    """Fit the STDN model with generators and log a final validation score.

    Builds the network, trains it with ``fit_generator`` using CSV logging,
    best-only checkpointing to ``PATH/MODELNAME.h5``, a constant learning
    rate schedule and early stopping on ``val_loss``. It then evaluates the
    in-memory model on the validation generator and appends timestamps and
    scores to ``PATH/MODELNAME_prediction_scores.txt``.

    Args:
        train_data: Array-like exposing ``.shape[0]`` and ``.shape[1]``.
        valid_data: Array-like exposing ``.shape[0]`` and ``.shape[1]``.

    Returns:
        None. All results are written under ``PATH``.
    """
    output_prefix = PATH + '/' + MODELNAME

    # Callbacks, listed in the order they are handed to Keras.
    training_callbacks = [
        CSVLogger(output_prefix + '.log'),
        ModelCheckpoint(filepath=output_prefix + '.h5', verbose=1,
                        save_best_only=True),
        LearningRateScheduler(lambda epoch: LR),
        EarlyStopping(monitor='val_loss', patience=10, verbose=1, mode='auto'),
    ]

    # Both generators are configured identically; only the data differs.
    generator_kwargs = dict(
        att_lstm_num=att_lstm_num,
        long_term_lstm_seq_len=long_term_lstm_seq_len,
        short_term_lstm_seq_len=short_term_lstm_seq_len,
        hist_feature_daynum=hist_feature_daynum,
        last_feature_num=last_feature_num,
        nbhd_size=nbhd_size,
        batchsize=BATCHSIZE,
    )
    train_generator = data_generator(train_data, **generator_kwargs)
    val_generator = data_generator(valid_data, **generator_kwargs)

    # Steps per pass: (rows - empty_time) * columns, in whole batches.
    sep = (train_data.shape[0] - empty_time) * train_data.shape[1] // BATCHSIZE
    val_sep = (valid_data.shape[0] - empty_time) * valid_data.shape[1] // BATCHSIZE

    # Build and train the model.
    modeler = STDN_models_noflow.models()
    model = modeler.stdn(
        att_lstm_num=att_lstm_num,
        att_lstm_seq_len=long_term_lstm_seq_len,
        lstm_seq_len=short_term_lstm_seq_len,
        feature_vec_len=feature_vec_len,
        cnn_flat_size=cnn_flat_size,
        nbhd_size=window_size,
        nbhd_type=DATACHANNEL,
        output_shape=DATACHANNEL,
        optimizer=OPTIMIZER,
        loss=LOSS,
    )
    model.summary()
    model.fit_generator(
        train_generator,
        steps_per_epoch=sep,
        epochs=EPOCH,
        validation_data=val_generator,
        validation_steps=val_sep,
        callbacks=training_callbacks,
    )

    # Score the in-memory model on the same validation generator object that
    # was used during fitting.
    val_scale_MSE = model.evaluate_generator(val_generator, steps=val_sep)
    val_rescale_MSE = val_scale_MSE * MAX_VALUE * MAX_VALUE

    # NOTE(review): the last two labels say "trainData" although the values
    # come from the validation generator -- confirm the wording is intended.
    with open(output_prefix + '_prediction_scores.txt', 'a') as wf:
        wf.writelines([
            'train start time: {}\n'.format(StartTime),
            'train end time: {}\n'.format(
                datetime.datetime.now().strftime('%Y%m%d_%H%M%S')),
            "Keras MSE on trainData, %f\n" % val_scale_MSE,
            "Rescaled MSE on trainData, %f\n" % val_rescale_MSE,
        ])