y = torch.from_numpy(train_label).long() for epoch in range(EPOCH): batches = gen_batch(x, y, batch_size) loss_sum = 0 for var_x, var_y in batches: pred = model(var_x) loss = loss_func(pred, var_y) optimizer.zero_grad() loss.backward() optimizer.step() loss_sum += loss.item() * pred.shape[0] print('epoch %d loss: %f' % (epoch, loss_sum / x.shape[0])) if (epoch + 1) % 50 == 0: torch.save(model.state_dict(), tmp_save_path % (epoch + 1)) torch.save(model.state_dict(), model_save_path) else: print('loading model...') model.load_state_dict(torch.load(model_save_path)) var_test_x = torch.from_numpy(test_data).float().to(DEVICE) var_test_y = torch.from_numpy(test_label).long().to(DEVICE) # 测试序列似乎还是太大,只能分批测试 test_pred = [] batches = gen_batch(var_test_x, var_test_y, batch_size=batch_size) for var_x, var_y in batches: pred = model(var_x) test_pred.extend(list(pred.cpu().data.numpy())) test_pred = np.array(test_pred) test_pred = np.argmax(test_pred, axis=1) classify_analysis(test_label, test_pred)
if __name__ == '__main__':
    # Script entry point: fit a k-nearest-neighbours classifier on each of the
    # two feature sets stored in lstm_ed.npz and report its classification
    # metrics against the training labels.
    path = './data/train-test-0.4.npz'
    train_data, test_data, train_label, test_label = load_train_test(path)

    # np.load on an .npz archive returns an NpzFile that holds the underlying
    # file open; using it as a context manager closes the handle once both
    # arrays have been read into memory.
    with np.load('lstm_ed.npz') as vectors:
        error_vector = vectors['error_vector']
        pred_vector = vectors['pred_vector']

    # NOTE(review): each classifier is fitted and evaluated on the same
    # vectors, so the reported numbers are resubstitution (training-set)
    # scores rather than held-out performance. lstm_ed.npz appears to hold
    # only train-aligned vectors (they are paired with train_label) -- confirm
    # before comparing these figures with test-set results.
    for features in (error_vector, pred_vector):
        knn = KNeighborsClassifier()
        knn.fit(features, train_label)
        classify_analysis(train_label, knn.predict(features))

    # Optional visualisation of one normal and one abnormal series (disabled):
    # plt.figure(figsize=(8, 3))
    # plt.subplot(2, 1, 1)
    # plt.plot(train_data[train_label == 0][0], color='black', label='ground truth')
    # plt.plot(pred_vector[train_label == 0][0], color='red', label='prediction')
    # plt.plot(error_vector[train_label == 0][0], color='blue', label='error')
    # plt.legend()
    # plt.title('Normal TS')
    # plt.tight_layout()
    # plt.subplot(2, 1, 2)
    # plt.plot(train_data[train_label == 1][0], color='black', label='ground truth')
    # plt.plot(pred_vector[train_label == 1][0], color='red', label='prediction')
    # plt.plot(error_vector[train_label == 1][0], color='blue', label='error')
    # plt.legend()