コード例 #1
0
 def test_tcn_forecaster_runtime_error(self):
     """Check that a never-fitted forecaster raises on save, predict and evaluate.

     NOTE(review): despite the ``tcn`` in its name, this test exercises
     ``LSTMForecaster``; the name is kept unchanged so the test ID stays stable.
     """
     # Only the test split is needed: fit() is never called in this test.
     _, _, test_data = create_data()
     forecaster = LSTMForecaster(past_seq_len=24,
                                 input_feature_num=2,
                                 output_feature_num=2,
                                 loss="mae",
                                 lr=0.01)
     with tempfile.TemporaryDirectory() as tmp_dir_name:
         ckpt_name = os.path.join(tmp_dir_name, "ckpt")
         # Keep pytest.raises around the single call expected to fail, so a
         # RuntimeError from setting up the checkpoint path cannot make the
         # test pass by accident.
         with pytest.raises(RuntimeError):
             forecaster.save(ckpt_name)
     with pytest.raises(RuntimeError):
         forecaster.predict(test_data[0])
     with pytest.raises(RuntimeError):
         forecaster.evaluate(test_data)
コード例 #2
0
 def test_tcn_forecaster_save_load(self):
     """Check that predictions survive a checkpoint save/load round trip.

     A forecaster is fitted for two epochs and its predictions on the test
     inputs are recorded. It is then saved to a temporary checkpoint, which is
     loaded into a freshly constructed forecaster. The predictions from both
     must match to ``np.testing.assert_almost_equal`` precision.

     NOTE(review): despite the ``tcn`` in its name, this test exercises
     ``LSTMForecaster``; the name is kept unchanged so the test ID stays stable.
     """
     # The validation split is not used here.
     train_data, _, test_data = create_data()
     # Shared constructor arguments, so the restored forecaster is built
     # exactly like the trained one.
     config = dict(past_seq_len=24,
                   input_feature_num=2,
                   output_feature_num=2,
                   loss="mae",
                   lr=0.01)
     forecaster = LSTMForecaster(**config)
     forecaster.fit(train_data, epochs=2)
     test_pred_save = forecaster.predict(test_data[0])
     with tempfile.TemporaryDirectory() as tmp_dir_name:
         ckpt_name = os.path.join(tmp_dir_name, "ckpt")
         forecaster.save(ckpt_name)
         # Load into a new instance: reloading into the trained object would
         # pass even if load() silently did nothing, because its weights
         # already equal the saved ones.
         restored = LSTMForecaster(**config)
         restored.load(ckpt_name)
         test_pred_load = restored.predict(test_data[0])
     np.testing.assert_almost_equal(test_pred_save, test_pred_load)