def test_tcn_forecaster_runtime_error(self):
    """Calling save/predict/evaluate on a never-fitted forecaster raises RuntimeError.

    Each operation is checked in its own ``pytest.raises`` block. A failure in
    one call therefore cannot mask the others. The temporary-directory setup
    sits outside the expected-error scope, so only ``save`` itself is allowed
    to produce the RuntimeError.
    """
    # Only the test split is needed: this forecaster is never fitted.
    _, _, test_data = create_data()
    forecaster = TCNForecaster(past_seq_len=24,
                               future_seq_len=5,
                               input_feature_num=1,
                               output_feature_num=1,
                               kernel_size=3,
                               lr=0.01)
    with tempfile.TemporaryDirectory() as tmp_dir_name:
        ckpt_name = os.path.join(tmp_dir_name, "ckpt")
        # Keep the raises scope tight: only the save call is expected to fail.
        with pytest.raises(RuntimeError):
            forecaster.save(ckpt_name)
    with pytest.raises(RuntimeError):
        forecaster.predict(test_data[0])
    with pytest.raises(RuntimeError):
        forecaster.evaluate(test_data[0], test_data[1])
def test_tcn_forecaster_save_restore(self):
    """Predictions from a trained forecaster survive a save/restore round trip.

    The checkpoint is restored into a *fresh* forecaster built with the same
    configuration. Restoring into the original object would pass trivially,
    because its in-memory weights are already trained. Using a separate
    instance makes the test fail unless the checkpoint actually carries the
    trained state.
    """
    # The validation split is not used by this test.
    train_data, _, test_data = create_data()
    # Shared so the restored forecaster is built with exactly the same config.
    forecaster_kwargs = dict(past_seq_len=24,
                             future_seq_len=5,
                             input_feature_num=1,
                             output_feature_num=1,
                             kernel_size=3,
                             lr=0.01)
    forecaster = TCNForecaster(**forecaster_kwargs)
    forecaster.fit(train_data[0], train_data[1], epochs=2)
    with tempfile.TemporaryDirectory() as tmp_dir_name:
        ckpt_name = os.path.join(tmp_dir_name, "ckpt")
        test_pred_save = forecaster.predict(test_data[0])
        forecaster.save(ckpt_name)
        restored_forecaster = TCNForecaster(**forecaster_kwargs)
        restored_forecaster.restore(ckpt_name)
        test_pred_restore = restored_forecaster.predict(test_data[0])
    # Both prediction arrays are in memory; the checkpoint dir is no longer needed.
    np.testing.assert_almost_equal(test_pred_save, test_pred_restore)