コード例 #1
0
    def test_tcn_forecaster_xshard_input(self):
        """Fit, predict and evaluate a Seq2SeqForecaster on XShards input.

        The values returned by ``create_data`` are partitioned into XShards
        and passed through ``transform_shard`` with helpers that return dicts
        keyed by ``'x'`` (plus ``'y'`` for the train/val data). The
        fit/predict/evaluate cycle is then run once with ``distributed=True``
        and once with ``distributed=False``. Nothing is asserted about the
        results; the test only requires the calls to complete.
        """
        train_data, val_data, test_data = create_data()
        print("original", train_data[0].dtype)
        init_orca_context(cores=4, memory="2g")
        # Stop the context even when a step below raises, so a failure in
        # this test does not leave it running for whatever executes next.
        try:
            from zoo.orca.data import XShards

            def transform_to_dict(data):
                # Train/val shards carry both elements of the data.
                return {'x': data[0], 'y': data[1]}

            def transform_to_dict_x(data):
                # Test shards carry only the first element.
                return {'x': data[0]}

            train_data = XShards.partition(train_data).transform_shard(transform_to_dict)
            val_data = XShards.partition(val_data).transform_shard(transform_to_dict)
            test_data = XShards.partition(test_data).transform_shard(transform_to_dict_x)
            for distributed in [True, False]:
                forecaster = Seq2SeqForecaster(past_seq_len=24,
                                               future_seq_len=5,
                                               input_feature_num=1,
                                               output_feature_num=1,
                                               loss="mae",
                                               lr=0.01,
                                               distributed=distributed)
                forecaster.fit(train_data, epochs=2)
                # Return values are not inspected; only successful completion
                # of predict/evaluate is required.
                forecaster.predict(test_data)
                forecaster.evaluate(val_data)
        finally:
            stop_orca_context()
コード例 #2
0
 def test_tcn_forecaster_shape_error(self):
     """Check that fit raises AssertionError for this feature configuration.

     The forecaster is created with input_feature_num=2 and
     output_feature_num=4 and fitted on the first two elements of the
     training data. NOTE(review): presumably the configured feature counts
     do not match that data -- create_data is not visible here, confirm.
     """
     train_data, _, _ = create_data()
     model_config = {
         "future_seq_len": 5,
         "input_feature_num": 2,
         "output_feature_num": 4,
     }
     forecaster = Seq2SeqForecaster(**model_config)
     train_x, train_y = train_data[0], train_data[1]
     with pytest.raises(AssertionError):
         forecaster.fit(train_x, train_y, epochs=2)
コード例 #3
0
 def test_s2s_forecaster_fit_eva_pred(self):
     """Run fit, predict and evaluate end to end on a Seq2SeqForecaster.

     The only assertion is that the prediction has the same shape as
     ``test_data[1]``; the values returned by fit and evaluate are not
     checked.
     """
     train_data, _, test_data = create_data()
     train_x, train_y = train_data[0], train_data[1]
     test_x, test_y = test_data[0], test_data[1]

     forecaster = Seq2SeqForecaster(
         future_seq_len=5,
         input_feature_num=3,
         output_feature_num=2,
         lstm_layer_num=2,
     )
     forecaster.fit(train_x, train_y, epochs=10)

     prediction = forecaster.predict(test_x)
     assert prediction.shape == test_y.shape

     forecaster.evaluate(test_x, test_y)
コード例 #4
0
 def test_tcn_forecaster_shape_error(self):
     """Check that fit raises AssertionError for this feature configuration.

     The forecaster is created with input_feature_num=1 and
     output_feature_num=2 and fitted on the first value returned by
     create_data. NOTE(review): presumably the configured feature counts
     do not match that data -- create_data is not visible here, confirm.
     """
     train_data, _, _ = create_data()
     model_config = {
         "past_seq_len": 24,
         "future_seq_len": 5,
         "input_feature_num": 1,
         "output_feature_num": 2,
         "loss": "mae",
         "lr": 0.01,
     }
     forecaster = Seq2SeqForecaster(**model_config)
     with pytest.raises(AssertionError):
         forecaster.fit(train_data, epochs=2)
コード例 #5
0
 def test_tcn_forecaster_fit_eva_pred(self):
     """Run fit, predict and evaluate end to end on a Seq2SeqForecaster.

     fit and evaluate receive the data objects from create_data as a
     whole, predict receives only ``test_data[0]``. The only assertion is
     that the prediction has the same shape as ``test_data[1]``.
     """
     train_data, _, test_data = create_data()
     test_x, test_y = test_data[0], test_data[1]

     forecaster = Seq2SeqForecaster(
         past_seq_len=24,
         future_seq_len=5,
         input_feature_num=1,
         output_feature_num=1,
         loss="mae",
         lr=0.01,
     )
     forecaster.fit(train_data, epochs=2)

     prediction = forecaster.predict(test_x)
     assert prediction.shape == test_y.shape

     # The evaluation result itself is not checked.
     forecaster.evaluate(test_data)
コード例 #6
0
 def test_tcn_forecaster_runtime_error(self):
     """Expect RuntimeError from save/predict/evaluate when fit was not called.

     The forecaster is only constructed, never fitted; each of the three
     calls is checked separately with ``pytest.raises(RuntimeError)``.
     """
     _, _, test_data = create_data()
     forecaster = Seq2SeqForecaster(
         future_seq_len=5,
         input_feature_num=3,
         output_feature_num=2,
     )

     # Saving: the checkpoint path lives in a throwaway directory.
     with pytest.raises(RuntimeError), \
             tempfile.TemporaryDirectory() as tmp_dir_name:
         forecaster.save(os.path.join(tmp_dir_name, "ckpt"))

     # Predicting.
     with pytest.raises(RuntimeError):
         forecaster.predict(test_data[0])

     # Evaluating.
     with pytest.raises(RuntimeError):
         forecaster.evaluate(test_data[0], test_data[1])
コード例 #7
0
 def test_s2s_forecaster_save_restore(self):
     """Check that predictions are unchanged by a save/restore round trip.

     A fitted forecaster predicts on ``test_data[0]``, is saved to a
     checkpoint in a temporary directory, restored from it, and predicts
     again; both predictions must be almost equal.
     """
     train_data, _, test_data = create_data()
     forecaster = Seq2SeqForecaster(
         future_seq_len=5,
         input_feature_num=3,
         output_feature_num=2,
     )
     forecaster.fit(train_data[0], train_data[1], epochs=10)

     test_x = test_data[0]
     pred_before = forecaster.predict(test_x)
     with tempfile.TemporaryDirectory() as tmp_dir_name:
         ckpt_path = os.path.join(tmp_dir_name, "ckpt")
         forecaster.save(ckpt_path)
         forecaster.restore(ckpt_path)
         # Predict while the checkpoint directory still exists, as the
         # checkpoint is removed when the block exits.
         pred_after = forecaster.predict(test_x)

     np.testing.assert_almost_equal(pred_before, pred_after)
コード例 #8
0
 def test_tcn_forecaster_save_load(self):
     """Check that predictions are unchanged by a save/load round trip.

     A fitted forecaster predicts on ``test_data[0]``, is saved to a
     checkpoint in a temporary directory, loaded from it, and predicts
     again; both predictions must be almost equal.
     """
     train_data, _, test_data = create_data()
     forecaster = Seq2SeqForecaster(
         past_seq_len=24,
         future_seq_len=5,
         input_feature_num=1,
         output_feature_num=1,
         loss="mae",
         lr=0.01,
     )
     forecaster.fit(train_data, epochs=2)

     test_x = test_data[0]
     pred_before = forecaster.predict(test_x)
     with tempfile.TemporaryDirectory() as tmp_dir_name:
         ckpt_path = os.path.join(tmp_dir_name, "ckpt")
         forecaster.save(ckpt_path)
         forecaster.load(ckpt_path)
         # Predict while the checkpoint directory still exists, as the
         # checkpoint is removed when the block exits.
         pred_after = forecaster.predict(test_x)

     np.testing.assert_almost_equal(pred_before, pred_after)
コード例 #9
0
 def test_tcn_forecaster_onnx_methods(self):
     """Compare the ``*_with_onnx`` methods against predict/evaluate.

     After fitting a forecaster built with ``teacher_forcing=True``, the
     results of ``predict_with_onnx`` and ``evaluate_with_onnx`` must match
     those of ``predict`` and ``evaluate`` to 5 decimals. An ImportError
     raised anywhere in the comparison (e.g. ``onnx`` or ``onnxruntime``
     not installed) makes the test pass without checking anything.
     """
     train_data, _, test_data = create_data()
     forecaster = Seq2SeqForecaster(
         future_seq_len=5,
         input_feature_num=3,
         output_feature_num=2,
         teacher_forcing=True,
     )
     forecaster.fit(train_data[0], train_data[1], epochs=2)

     test_x, test_y = test_data[0], test_data[1]
     try:
         # Imported only to probe availability; the names are not used.
         import onnx
         import onnxruntime

         base_pred = forecaster.predict(test_x)
         onnx_pred = forecaster.predict_with_onnx(test_x)
         np.testing.assert_almost_equal(base_pred, onnx_pred, decimal=5)

         base_metric = forecaster.evaluate(test_x, test_y)
         onnx_metric = forecaster.evaluate_with_onnx(test_x, test_y)
         np.testing.assert_almost_equal(base_metric, onnx_metric, decimal=5)
     except ImportError:
         pass
コード例 #10
0
    def test_tcn_forecaster_distributed(self):
        """Compare distributed and local predictions of a Seq2SeqForecaster.

        A forecaster created with ``distributed=True`` is fitted and used for
        predict/evaluate, then converted with ``to_local()`` and used again;
        the two predictions must agree to 5 decimals. When ``onnx`` and
        ``onnxruntime`` can be imported, the ONNX prediction is compared
        against the distributed one as well. ``get_model()`` must return a
        ``torch.nn.Module`` both before and after ``to_local()``.
        """
        train_data, val_data, test_data = create_data()

        init_orca_context(cores=4, memory="2g")
        # Stop the context even when a step below raises, so a failure in
        # this test does not leave it running for whatever executes next.
        try:
            forecaster = Seq2SeqForecaster(past_seq_len=24,
                                           future_seq_len=5,
                                           input_feature_num=1,
                                           output_feature_num=1,
                                           loss="mae",
                                           lr=0.01,
                                           distributed=True)

            forecaster.fit(train_data, epochs=2)
            distributed_pred = forecaster.predict(test_data[0])
            # Evaluation results are not inspected anywhere in this test;
            # the calls only have to succeed.
            forecaster.evaluate(val_data)

            model = forecaster.get_model()
            assert isinstance(model, torch.nn.Module)

            forecaster.to_local()
            local_pred = forecaster.predict(test_data[0])
            forecaster.evaluate(val_data)

            np.testing.assert_almost_equal(distributed_pred, local_pred, decimal=5)

            try:
                # Imported only to probe availability; the ONNX checks are
                # silently skipped when an ImportError is raised.
                import onnx
                import onnxruntime
                local_pred_onnx = forecaster.predict_with_onnx(test_data[0])
                forecaster.evaluate_with_onnx(val_data)
                np.testing.assert_almost_equal(distributed_pred,
                                               local_pred_onnx,
                                               decimal=5)
            except ImportError:
                pass

            model = forecaster.get_model()
            assert isinstance(model, torch.nn.Module)
        finally:
            stop_orca_context()
コード例 #11
0
 def test_tcn_forecaster_onnx_methods(self):
     """Compare the ``*_with_onnx`` methods against predict/evaluate.

     After fitting, ``predict_with_onnx`` must match ``predict`` and
     ``evaluate_with_onnx`` must match ``evaluate`` to 5 decimals, the
     latter both with ``multioutput="raw_values"`` and with default
     arguments. An ImportError raised anywhere in the comparison (e.g.
     ``onnx`` or ``onnxruntime`` not installed) makes the test pass without
     checking anything.
     """
     train_data, _, test_data = create_data()
     forecaster = Seq2SeqForecaster(
         past_seq_len=24,
         future_seq_len=5,
         input_feature_num=1,
         output_feature_num=1,
         loss="mae",
         lr=0.01,
     )
     forecaster.fit(train_data, epochs=2)

     test_x = test_data[0]
     try:
         # Imported only to probe availability; the names are not used.
         import onnx
         import onnxruntime

         base_pred = forecaster.predict(test_x)
         onnx_pred = forecaster.predict_with_onnx(test_x)
         np.testing.assert_almost_equal(base_pred, onnx_pred, decimal=5)

         # Evaluation with multioutput="raw_values".
         raw_metric = forecaster.evaluate(test_data, multioutput="raw_values")
         raw_metric_onnx = forecaster.evaluate_with_onnx(
             test_data, multioutput="raw_values")
         np.testing.assert_almost_equal(raw_metric, raw_metric_onnx, decimal=5)

         # Evaluation with default arguments.
         base_metric = forecaster.evaluate(test_data)
         onnx_metric = forecaster.evaluate_with_onnx(test_data)
         np.testing.assert_almost_equal(base_metric, onnx_metric, decimal=5)
     except ImportError:
         pass