def test_tcn_forecaster_onnx_methods(self):
    """Check that the ONNX-backed predict/evaluate match the default path.

    The test is skipped (rather than silently passing) when ``onnx`` or
    ``onnxruntime`` is not installed.
    """
    # Check optional dependencies before training, so a missing package is
    # reported as a skip instead of a pass and no time is spent fitting.
    # Keeping this outside a try block also stops an ImportError raised by
    # the forecaster's own ONNX code from being swallowed.
    pytest.importorskip("onnx")
    pytest.importorskip("onnxruntime")

    train_data, _, test_data = create_data()
    forecaster = TCNForecaster(past_seq_len=24,
                               future_seq_len=5,
                               input_feature_num=1,
                               output_feature_num=1,
                               kernel_size=4,
                               num_channels=[16, 16],
                               lr=0.01)
    forecaster.fit(train_data[0], train_data[1], epochs=2)
    x_test, y_test = test_data[0], test_data[1]

    pred = forecaster.predict(x_test)
    pred_onnx = forecaster.predict_with_onnx(x_test)
    np.testing.assert_almost_equal(pred, pred_onnx, decimal=5)

    # Metrics requested with multioutput="raw_values" must agree ...
    mse = forecaster.evaluate(x_test, y_test, multioutput="raw_values")
    mse_onnx = forecaster.evaluate_with_onnx(x_test,
                                             y_test,
                                             multioutput="raw_values")
    np.testing.assert_almost_equal(mse, mse_onnx, decimal=5)

    # ... and so must the metrics computed with the default multioutput.
    mse = forecaster.evaluate(x_test, y_test)
    mse_onnx = forecaster.evaluate_with_onnx(x_test, y_test)
    np.testing.assert_almost_equal(mse, mse_onnx, decimal=5)
# Example #2
 def test_tcn_forecaster_save_restore(self):
     """Round-trip a fitted TCNForecaster through save and restore.

     Predictions on the test inputs made before ``save`` must equal the
     predictions made after ``restore`` from the same checkpoint path.
     """
     train_data, _, test_data = create_data()
     x_test = test_data[0]
     model = TCNForecaster(past_seq_len=24,
                           future_seq_len=5,
                           input_feature_num=1,
                           output_feature_num=1,
                           kernel_size=3,
                           lr=0.01)
     model.fit(train_data[0], train_data[1], epochs=2)

     before = model.predict(x_test)
     with tempfile.TemporaryDirectory() as workdir:
         checkpoint = os.path.join(workdir, "ckpt")
         model.save(checkpoint)
         model.restore(checkpoint)
         # Predict while the checkpoint directory still exists.
         after = model.predict(x_test)
     np.testing.assert_almost_equal(before, after)
 def test_tcn_forecaster_runtime_error(self):
     """An unfitted TCNForecaster must refuse to save, predict or evaluate.

     Each of ``save``, ``predict`` and ``evaluate`` is expected to raise
     ``RuntimeError`` when ``fit`` has never been called.
     """
     _, _, test_data = create_data()
     x_test, y_test = test_data[0], test_data[1]
     unfitted = TCNForecaster(past_seq_len=24,
                              future_seq_len=5,
                              input_feature_num=1,
                              output_feature_num=1,
                              kernel_size=3,
                              lr=0.01)

     with pytest.raises(RuntimeError), \
             tempfile.TemporaryDirectory() as workdir:
         unfitted.save(os.path.join(workdir, "ckpt"))

     with pytest.raises(RuntimeError):
         unfitted.predict(x_test)

     with pytest.raises(RuntimeError):
         unfitted.evaluate(x_test, y_test)
# Example #4
 def test_tcn_forecaster_fit_eva_pred(self):
     """Fit a TCNForecaster briefly, then predict and evaluate on test data.

     The prediction must have the same shape as the test targets; the
     ``evaluate`` call only has to complete without raising.
     """
     train_data, _, test_data = create_data()
     x_train, y_train = train_data[0], train_data[1]
     x_test, y_test = test_data[0], test_data[1]

     model = TCNForecaster(past_seq_len=24,
                           future_seq_len=5,
                           input_feature_num=1,
                           output_feature_num=1,
                           kernel_size=3,
                           lr=0.01)
     model.fit(x_train, y_train, epochs=2)

     assert model.predict(x_test).shape == y_test.shape
     model.evaluate(x_test, y_test)