def test_tcn_forecaster_xshard_input(self): train_data, val_data, test_data = create_data() print("original", train_data[0].dtype) init_orca_context(cores=4, memory="2g") from zoo.orca.data import XShards def transform_to_dict(data): return {'x': data[0], 'y': data[1]} def transform_to_dict_x(data): return {'x': data[0]} train_data = XShards.partition(train_data).transform_shard(transform_to_dict) val_data = XShards.partition(val_data).transform_shard(transform_to_dict) test_data = XShards.partition(test_data).transform_shard(transform_to_dict_x) for distributed in [True, False]: forecaster = Seq2SeqForecaster(past_seq_len=24, future_seq_len=5, input_feature_num=1, output_feature_num=1, loss="mae", lr=0.01, distributed=distributed) forecaster.fit(train_data, epochs=2) distributed_pred = forecaster.predict(test_data) distributed_eval = forecaster.evaluate(val_data) stop_orca_context()
def test_tcn_forecaster_shape_error(self): train_data, val_data, test_data = create_data() forecaster = Seq2SeqForecaster(future_seq_len=5, input_feature_num=2, output_feature_num=4) with pytest.raises(AssertionError): forecaster.fit(train_data[0], train_data[1], epochs=2)
def test_s2s_forecaster_fit_eva_pred(self): train_data, val_data, test_data = create_data() forecaster = Seq2SeqForecaster(future_seq_len=5, input_feature_num=3, output_feature_num=2, lstm_layer_num=2) train_mse = forecaster.fit(train_data[0], train_data[1], epochs=10) test_pred = forecaster.predict(test_data[0]) assert test_pred.shape == test_data[1].shape test_mse = forecaster.evaluate(test_data[0], test_data[1])
def test_tcn_forecaster_shape_error(self): train_data, val_data, test_data = create_data() forecaster = Seq2SeqForecaster(past_seq_len=24, future_seq_len=5, input_feature_num=1, output_feature_num=2, loss="mae", lr=0.01) with pytest.raises(AssertionError): forecaster.fit(train_data, epochs=2)
def test_tcn_forecaster_fit_eva_pred(self): train_data, val_data, test_data = create_data() forecaster = Seq2SeqForecaster(past_seq_len=24, future_seq_len=5, input_feature_num=1, output_feature_num=1, loss="mae", lr=0.01) train_loss = forecaster.fit(train_data, epochs=2) test_pred = forecaster.predict(test_data[0]) assert test_pred.shape == test_data[1].shape test_mse = forecaster.evaluate(test_data)
def test_tcn_forecaster_runtime_error(self): train_data, val_data, test_data = create_data() forecaster = Seq2SeqForecaster(future_seq_len=5, input_feature_num=3, output_feature_num=2) with pytest.raises(RuntimeError): with tempfile.TemporaryDirectory() as tmp_dir_name: ckpt_name = os.path.join(tmp_dir_name, "ckpt") forecaster.save(ckpt_name) with pytest.raises(RuntimeError): forecaster.predict(test_data[0]) with pytest.raises(RuntimeError): forecaster.evaluate(test_data[0], test_data[1])
def test_s2s_forecaster_save_restore(self): train_data, val_data, test_data = create_data() forecaster = Seq2SeqForecaster(future_seq_len=5, input_feature_num=3, output_feature_num=2) train_mse = forecaster.fit(train_data[0], train_data[1], epochs=10) with tempfile.TemporaryDirectory() as tmp_dir_name: ckpt_name = os.path.join(tmp_dir_name, "ckpt") test_pred_save = forecaster.predict(test_data[0]) forecaster.save(ckpt_name) forecaster.restore(ckpt_name) test_pred_restore = forecaster.predict(test_data[0]) np.testing.assert_almost_equal(test_pred_save, test_pred_restore)
def test_tcn_forecaster_save_load(self): train_data, val_data, test_data = create_data() forecaster = Seq2SeqForecaster(past_seq_len=24, future_seq_len=5, input_feature_num=1, output_feature_num=1, loss="mae", lr=0.01) train_mse = forecaster.fit(train_data, epochs=2) with tempfile.TemporaryDirectory() as tmp_dir_name: ckpt_name = os.path.join(tmp_dir_name, "ckpt") test_pred_save = forecaster.predict(test_data[0]) forecaster.save(ckpt_name) forecaster.load(ckpt_name) test_pred_load = forecaster.predict(test_data[0]) np.testing.assert_almost_equal(test_pred_save, test_pred_load)
def test_tcn_forecaster_onnx_methods(self): train_data, val_data, test_data = create_data() forecaster = Seq2SeqForecaster(future_seq_len=5, input_feature_num=3, output_feature_num=2, teacher_forcing=True) forecaster.fit(train_data[0], train_data[1], epochs=2) try: import onnx import onnxruntime pred = forecaster.predict(test_data[0]) pred_onnx = forecaster.predict_with_onnx(test_data[0]) np.testing.assert_almost_equal(pred, pred_onnx, decimal=5) mse = forecaster.evaluate(test_data[0], test_data[1]) mse_onnx = forecaster.evaluate_with_onnx(test_data[0], test_data[1]) np.testing.assert_almost_equal(mse, mse_onnx, decimal=5) except ImportError: pass
def test_tcn_forecaster_distributed(self): train_data, val_data, test_data = create_data() init_orca_context(cores=4, memory="2g") forecaster = Seq2SeqForecaster(past_seq_len=24, future_seq_len=5, input_feature_num=1, output_feature_num=1, loss="mae", lr=0.01, distributed=True) forecaster.fit(train_data, epochs=2) distributed_pred = forecaster.predict(test_data[0]) distributed_eval = forecaster.evaluate(val_data) model = forecaster.get_model() assert isinstance(model, torch.nn.Module) forecaster.to_local() local_pred = forecaster.predict(test_data[0]) local_eval = forecaster.evaluate(val_data) np.testing.assert_almost_equal(distributed_pred, local_pred, decimal=5) try: import onnx import onnxruntime local_pred_onnx = forecaster.predict_with_onnx(test_data[0]) local_eval_onnx = forecaster.evaluate_with_onnx(val_data) np.testing.assert_almost_equal(distributed_pred, local_pred_onnx, decimal=5) except ImportError: pass model = forecaster.get_model() assert isinstance(model, torch.nn.Module) stop_orca_context()
def test_tcn_forecaster_onnx_methods(self): train_data, val_data, test_data = create_data() forecaster = Seq2SeqForecaster(past_seq_len=24, future_seq_len=5, input_feature_num=1, output_feature_num=1, loss="mae", lr=0.01) forecaster.fit(train_data, epochs=2) try: import onnx import onnxruntime pred = forecaster.predict(test_data[0]) pred_onnx = forecaster.predict_with_onnx(test_data[0]) np.testing.assert_almost_equal(pred, pred_onnx, decimal=5) mse = forecaster.evaluate(test_data, multioutput="raw_values") mse_onnx = forecaster.evaluate_with_onnx(test_data, multioutput="raw_values") np.testing.assert_almost_equal(mse, mse_onnx, decimal=5) mse = forecaster.evaluate(test_data) mse_onnx = forecaster.evaluate_with_onnx(test_data) np.testing.assert_almost_equal(mse, mse_onnx, decimal=5) except ImportError: pass