Example #1
0
 def test_tcn_forecaster_onnx_methods(self):
     """Check that ONNX inference agrees with native forecaster inference.

     NOTE(review): despite ``tcn`` in the name, this test builds an
     ``LSTMForecaster``.

     The forecaster is trained for 2 epochs. Then ``predict`` and
     ``evaluate`` (with ``multioutput="raw_values"`` and with the default)
     are compared against their ``*_with_onnx`` counterparts to 5 decimal
     places. If ``onnx`` or ``onnxruntime`` is not installed, the
     comparison is skipped and the test passes.
     """
     train_data, _val_data, test_data = create_data()
     forecaster = LSTMForecaster(past_seq_len=24,
                                 input_feature_num=2,
                                 output_feature_num=2,
                                 loss="mae",
                                 lr=0.01)
     forecaster.fit(train_data, epochs=2)

     # Only the optional-dependency probe belongs in the try: an ImportError
     # raised from inside the forecaster's ONNX methods is a real failure and
     # must not be silently swallowed.
     try:
         import onnx  # noqa: F401  (availability probe only)
         import onnxruntime  # noqa: F401  (availability probe only)
     except ImportError:
         # ONNX support is optional; skip the comparison when it is absent.
         return

     pred = forecaster.predict(test_data[0])
     pred_onnx = forecaster.predict_with_onnx(test_data[0])
     np.testing.assert_almost_equal(pred, pred_onnx, decimal=5)

     # Per-output metrics must match.
     mse = forecaster.evaluate(test_data, multioutput="raw_values")
     mse_onnx = forecaster.evaluate_with_onnx(test_data,
                                              multioutput="raw_values")
     np.testing.assert_almost_equal(mse, mse_onnx, decimal=5)

     # Metrics with the default multioutput setting must match as well.
     mse = forecaster.evaluate(test_data)
     mse_onnx = forecaster.evaluate_with_onnx(test_data)
     np.testing.assert_almost_equal(mse, mse_onnx, decimal=5)
Example #2
0
    def test_tcn_forecaster_distributed(self):
        """Check that a distributed forecaster gives the same results locally.

        NOTE(review): despite ``tcn`` in the name, this test builds an
        ``LSTMForecaster``.

        Steps:

        - Start an Orca context (4 cores, 2g memory).
        - Train a ``distributed=True`` forecaster for 2 epochs and predict.
        - Convert it with ``to_local()``.
        - Assert the local predictions match the distributed ones to 5
          decimal places.
        - Assert ``get_model()`` returns a ``torch.nn.Module`` both before
          and after the conversion.

        When ``onnx`` and ``onnxruntime`` are installed, the ONNX
        prediction and evaluation paths are compared as well. The Orca
        context is stopped even if an assertion fails.
        """
        train_data, val_data, test_data = create_data()
        init_orca_context(cores=4, memory="2g")
        # Always tear down the context, even on failure; otherwise a failing
        # assertion leaks it into subsequent tests.
        try:
            forecaster = LSTMForecaster(past_seq_len=24,
                                        input_feature_num=2,
                                        output_feature_num=2,
                                        loss="mae",
                                        lr=0.01,
                                        distributed=True)

            forecaster.fit(train_data, epochs=2)
            distributed_pred = forecaster.predict(test_data[0])
            # Smoke check: distributed evaluation must run without error.
            forecaster.evaluate(val_data)

            model = forecaster.get_model()
            assert isinstance(model, torch.nn.Module)

            forecaster.to_local()
            local_pred = forecaster.predict(test_data[0])
            local_eval = forecaster.evaluate(val_data)

            np.testing.assert_almost_equal(distributed_pred, local_pred,
                                           decimal=5)

            # Keep the try limited to the optional-dependency probe so that
            # an ImportError from the forecaster itself is not swallowed.
            try:
                import onnx  # noqa: F401  (availability probe only)
                import onnxruntime  # noqa: F401  (availability probe only)
            except ImportError:
                # ONNX support is optional; skip the ONNX comparison.
                pass
            else:
                local_pred_onnx = forecaster.predict_with_onnx(test_data[0])
                local_eval_onnx = forecaster.evaluate_with_onnx(val_data)
                np.testing.assert_almost_equal(distributed_pred,
                                               local_pred_onnx, decimal=5)
                # Previously computed but never checked.
                np.testing.assert_almost_equal(local_eval, local_eval_onnx,
                                               decimal=5)

            # The model must still be a torch module after to_local().
            model = forecaster.get_model()
            assert isinstance(model, torch.nn.Module)
        finally:
            stop_orca_context()