Example No. 1
 def _gap_test_template(self, times, values):
     random_model = RandomStateSpaceModel(
         state_dimension=1,
         state_noise_dimension=1,
         configuration=state_space_model.StateSpaceModelConfiguration(
             num_features=1))
     random_model.initialize_graph()
     input_fn = input_pipeline.WholeDatasetInputFn(
         input_pipeline.NumpyReader({
             feature_keys.TrainEvalFeatures.TIMES:
             times,
             feature_keys.TrainEvalFeatures.VALUES:
             values
         }))
     features, _ = input_fn()
     times = features[feature_keys.TrainEvalFeatures.TIMES]
     values = features[feature_keys.TrainEvalFeatures.VALUES]
     model_outputs = random_model.get_batch_loss(
         features={
             feature_keys.TrainEvalFeatures.TIMES: times,
             feature_keys.TrainEvalFeatures.VALUES: values
         },
         mode=None,
         state=math_utils.replicate_state(
             start_state=random_model.get_start_state(),
             batch_size=array_ops.shape(times)[0]))
     with self.cached_session() as session:
         variables.global_variables_initializer().run()
         coordinator = coordinator_lib.Coordinator()
         queue_runner_impl.start_queue_runners(session, coord=coordinator)
         model_outputs.loss.eval()
         coordinator.request_stop()
         coordinator.join()
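
For reference, here is a minimal standalone sketch of the NumpyReader / WholeDatasetInputFn pattern that all of the examples on this page share. It uses the public tf.contrib.timeseries aliases rather than the internal test modules, and assumes TensorFlow 1.x with queue runners (tf.contrib was removed in 2.x); treat the aliases and the session boilerplate as assumptions, not part of the original test.

 import numpy as np
 import tensorflow as tf

 data = {
     tf.contrib.timeseries.TrainEvalFeatures.TIMES: np.arange(10),
     tf.contrib.timeseries.TrainEvalFeatures.VALUES:
         np.arange(10, dtype=np.float32)[:, None],
 }
 # NumpyReader wraps in-memory arrays; WholeDatasetInputFn emits the whole
 # series as a single batch of shape [1, series_length, ...].
 input_fn = tf.contrib.timeseries.WholeDatasetInputFn(
     tf.contrib.timeseries.NumpyReader(data))
 features, _ = input_fn()

 with tf.Session() as session:
     session.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])
     coordinator = tf.train.Coordinator()
     tf.train.start_queue_runners(session, coord=coordinator)
     evaled = session.run(features)
     coordinator.request_stop()
     coordinator.join()
 print(evaled[tf.contrib.timeseries.TrainEvalFeatures.TIMES].shape)  # (1, 10)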
Example No. 2
 def test_exogenous_input(self):
     """Test that no errors are raised when using exogenous features."""
     dtype = dtypes.float64
     times = [1, 2, 3, 4, 5, 6]
     values = [[0.01], [5.10], [5.21], [0.30], [5.41], [0.50]]
     feature_a = [["off"], ["on"], ["on"], ["off"], ["on"], ["off"]]
     sparse_column_a = feature_column.sparse_column_with_keys(
         column_name="feature_a", keys=["on", "off"])
     one_hot_a = layers.one_hot_column(sparse_id_column=sparse_column_a)
     regressor = estimators.StructuralEnsembleRegressor(
         periodicities=[],
         num_features=1,
         moving_average_order=0,
         exogenous_feature_columns=[one_hot_a],
         dtype=dtype)
     features = {
         TrainEvalFeatures.TIMES: times,
         TrainEvalFeatures.VALUES: values,
         "feature_a": feature_a
     }
     train_input_fn = input_pipeline.RandomWindowInputFn(
         input_pipeline.NumpyReader(features), window_size=6, batch_size=1)
     regressor.train(input_fn=train_input_fn, steps=1)
     eval_input_fn = input_pipeline.WholeDatasetInputFn(
         input_pipeline.NumpyReader(features))
     evaluation = regressor.evaluate(input_fn=eval_input_fn, steps=1)
     predict_input_fn = input_pipeline.predict_continuation_input_fn(
         evaluation,
         times=[[7, 8, 9]],
         exogenous_features={"feature_a": [[["on"], ["off"], ["on"]]]})
     regressor.predict(input_fn=predict_input_fn)
Example No. 3
 def _equivalent_to_single_model_test_template(self, model_generator):
     with self.cached_session() as session:
         random_model = RandomStateSpaceModel(
             state_dimension=5,
             state_noise_dimension=4,
             configuration=state_space_model.StateSpaceModelConfiguration(
                 dtype=dtypes.float64, num_features=1))
         random_model.initialize_graph()
         series_length = 10
         model_data = random_model.generate(
             number_of_series=1,
             series_length=series_length,
             model_parameters=random_model.random_model_parameters())
         input_fn = input_pipeline.WholeDatasetInputFn(
             input_pipeline.NumpyReader(model_data))
         features, _ = input_fn()
         model_outputs = random_model.get_batch_loss(
             features=features,
             mode=None,
             state=math_utils.replicate_state(
                 start_state=random_model.get_start_state(),
                 batch_size=array_ops.shape(
                     features[feature_keys.TrainEvalFeatures.TIMES])[0]))
         variables.global_variables_initializer().run()
         compare_outputs_evaled_fn = model_generator(
             random_model, model_data)
         coordinator = coordinator_lib.Coordinator()
         queue_runner_impl.start_queue_runners(session, coord=coordinator)
         compare_outputs_evaled = compare_outputs_evaled_fn(session)
         model_outputs_evaled = session.run(
             (model_outputs.end_state, model_outputs.predictions))
         coordinator.request_stop()
         coordinator.join()
         model_posteriors, model_predictions = model_outputs_evaled
         (_, compare_posteriors,
          compare_predictions) = compare_outputs_evaled
         (model_posterior_mean, model_posterior_var,
          model_from_time) = model_posteriors
         (compare_posterior_mean, compare_posterior_var,
          compare_from_time) = compare_posteriors
         self.assertAllClose(model_posterior_mean,
                             compare_posterior_mean[0])
         self.assertAllClose(model_posterior_var, compare_posterior_var[0])
         self.assertAllClose(model_from_time, compare_from_time)
         self.assertEqual(sorted(model_predictions.keys()),
                          sorted(compare_predictions.keys()))
         for prediction_name in model_predictions:
             if prediction_name == "loss":
                 # Chunking means that losses will be different; skip testing them.
                 continue
             # Compare the last chunk to the corresponding tail of the
             # un-chunked model predictions.
             last_prediction_chunk = compare_predictions[prediction_name][-1]
             comparison_values = last_prediction_chunk.shape[0]
             model_prediction = (
                 model_predictions[prediction_name][0, -comparison_values:])
             self.assertAllClose(model_prediction, last_prediction_chunk)
Example No. 4
 def test_structural_ensemble_numpy_input(self):
     numpy_data = {
         "times": numpy.arange(50),
         "values": numpy.random.normal(size=[50])
     }
     estimators.StructuralEnsembleRegressor(
         num_features=1,
         periodicities=[],
         model_dir=self.get_temp_dir(),
         config=_SeedRunConfig()).train(
             input_pipeline.WholeDatasetInputFn(
                 input_pipeline.NumpyReader(numpy_data)),
             steps=1)
Example No. 5
    def _test_pass_to_next(self, read_offset, step, correct_offset):
        stub_model = StubTimeSeriesModel(correct_offset=correct_offset)
        data = self._make_test_data(length=100 + read_offset,
                                    cut_start=None,
                                    cut_end=None,
                                    offset=100.,
                                    step=step)
        init_input_fn = input_pipeline.WholeDatasetInputFn(
            input_pipeline.NumpyReader(
                {k: v[:-read_offset]
                 for k, v in data.items()}))
        result_input_fn = input_pipeline.WholeDatasetInputFn(
            input_pipeline.NumpyReader(
                {k: v[read_offset:]
                 for k, v in data.items()}))

        chainer = state_management.ChainingStateManager(
            state_saving_interval=1)
        stub_model.initialize_graph()
        chainer.initialize_graph(model=stub_model)
        init_model_outputs = chainer.define_loss(
            model=stub_model,
            features=init_input_fn()[0],
            mode=estimator_lib.ModeKeys.TRAIN)
        result_model_outputs = chainer.define_loss(
            model=stub_model,
            features=result_input_fn()[0],
            mode=estimator_lib.ModeKeys.TRAIN)
        with self.cached_session() as session:
            variables.global_variables_initializer().run()
            coordinator = coordinator_lib.Coordinator()
            queue_runner_impl.start_queue_runners(session, coord=coordinator)
            init_model_outputs.loss.eval()
            returned_loss = result_model_outputs.loss.eval()
            coordinator.request_stop()
            coordinator.join()
            return returned_loss
Example No. 6
 def _whole_dataset_input_fn_test_template(self, time_series_reader,
                                           num_features, num_samples):
     result, _ = input_pipeline.WholeDatasetInputFn(time_series_reader)()
     with self.cached_session() as session:
         session.run(variables.local_variables_initializer())
         coordinator = coordinator_lib.Coordinator()
         queue_runner_impl.start_queue_runners(session, coord=coordinator)
         features = session.run(result)
         coordinator.request_stop()
         coordinator.join()
     self.assertEqual("int64", features[TrainEvalFeatures.TIMES].dtype)
     self.assertAllEqual(
         numpy.arange(num_samples, dtype=numpy.int64)[None, :],
         features[TrainEvalFeatures.TIMES])
     for feature_number in range(num_features):
         self.assertAllEqual(
             features[TrainEvalFeatures.TIMES] * 2. + feature_number,
             features[TrainEvalFeatures.VALUES][:, :, feature_number])
Example No. 7
 def test_loop_unrolling(self):
     """Tests running/restoring from a checkpoint with static unrolling."""
     model = TimeDependentStateSpaceModel(
         # Unroll during training, but not evaluation
         static_unrolling_window_size_threshold=2)
     estimator = estimators.StateSpaceRegressor(model=model)
     times = numpy.arange(100)
     values = numpy.arange(100)
     dataset = {
         feature_keys.TrainEvalFeatures.TIMES: times,
         feature_keys.TrainEvalFeatures.VALUES: values
     }
     train_input_fn = input_pipeline.RandomWindowInputFn(
         input_pipeline.NumpyReader(dataset), batch_size=16, window_size=2)
     eval_input_fn = input_pipeline.WholeDatasetInputFn(
         input_pipeline.NumpyReader(dataset))
     estimator.train(input_fn=train_input_fn, max_steps=1)
     estimator.evaluate(input_fn=eval_input_fn, steps=1)
Example No. 8
 def _time_dependency_test_template(self, model_type):
     """Test that a time-dependent observation model influences predictions."""
     model = model_type()
     estimator = estimators.StateSpaceRegressor(
         model=model,
         optimizer=gradient_descent.GradientDescentOptimizer(0.1))
     values = numpy.reshape([1., 2., 3., 4.], newshape=[1, 4, 1])
     input_fn = input_pipeline.WholeDatasetInputFn(
         input_pipeline.NumpyReader({
             feature_keys.TrainEvalFeatures.TIMES: [[0, 1, 2, 3]],
             feature_keys.TrainEvalFeatures.VALUES:
             values
         }))
     estimator.train(input_fn=input_fn, max_steps=1)
     predicted_values = estimator.evaluate(input_fn=input_fn,
                                           steps=1)["mean"]
     # Throw out the first value so we don't test the prior
     self.assertAllEqual(values[1:], predicted_values[1:])
Example No. 9
    def dry_run_train_helper(self,
                             sample_every,
                             period,
                             num_samples,
                             model_type,
                             model_args,
                             num_features=1):
        numpy.random.seed(1)
        dtype = dtypes.float32
        features = self.simple_data(sample_every,
                                    dtype=dtype,
                                    period=period,
                                    num_samples=num_samples,
                                    num_features=num_features)
        model = model_type(
            configuration=(state_space_model.StateSpaceModelConfiguration(
                num_features=num_features,
                dtype=dtype,
                covariance_prior_fn=lambda _: 0.)),
            **model_args)

        class _RunConfig(estimator_lib.RunConfig):
            @property
            def tf_random_seed(self):
                return 4

        estimator = estimators.StateSpaceRegressor(model, config=_RunConfig())
        train_input_fn = input_pipeline.RandomWindowInputFn(
            input_pipeline.NumpyReader(features),
            num_threads=1,
            shuffle_seed=1,
            batch_size=16,
            window_size=16)
        eval_input_fn = input_pipeline.WholeDatasetInputFn(
            input_pipeline.NumpyReader(features))
        estimator.train(input_fn=train_input_fn, max_steps=1)
        first_evaluation = estimator.evaluate(input_fn=eval_input_fn, steps=1)
        estimator.train(input_fn=train_input_fn, max_steps=3)
        second_evaluation = estimator.evaluate(input_fn=eval_input_fn, steps=1)
        self.assertLess(second_evaluation["loss"], first_evaluation["loss"])
Example No. 10
    def test_exact_posterior_recovery_no_transition_noise(self):
        with self.cached_session() as session:
            stub_model, data, true_params = self._get_single_model()
            input_fn = input_pipeline.WholeDatasetInputFn(
                input_pipeline.NumpyReader(data))
            features, _ = input_fn()
            model_outputs = stub_model.get_batch_loss(
                features=features,
                mode=None,
                state=math_utils.replicate_state(
                    start_state=stub_model.get_start_state(),
                    batch_size=array_ops.shape(
                        features[feature_keys.TrainEvalFeatures.TIMES])[0]))
            variables.global_variables_initializer().run()
            coordinator = coordinator_lib.Coordinator()
            queue_runner_impl.start_queue_runners(session, coord=coordinator)
            posterior_mean, posterior_var, posterior_times = session.run(
                # Feed the true model parameters so that this test doesn't depend on
                # the generated parameters being close to the variable initializations
                # (an alternative would be training steps to fit the noise values,
                # which would be slow).
                model_outputs.end_state,
                feed_dict=true_params)
            coordinator.request_stop()
            coordinator.join()

            self.assertAllClose(numpy.zeros([1, 4, 4]),
                                posterior_var,
                                atol=1e-2)
            self.assertAllClose(numpy.dot(
                numpy.linalg.matrix_power(
                    stub_model.transition,
                    data[feature_keys.TrainEvalFeatures.TIMES].shape[1]),
                true_params[stub_model.prior_state_mean]),
                                posterior_mean[0],
                                rtol=1e-1)
            self.assertAllClose(
                math_utils.batch_end_time(
                    features[feature_keys.TrainEvalFeatures.TIMES]).eval(),
                posterior_times)
Example No. 11
 def test_no_periodicity(self):
     """Test that no errors are raised when periodicites is None."""
     dtype = dtypes.float64
     times = [1, 2, 3, 4, 5, 6]
     values = [[0.01], [5.10], [5.21], [0.30], [5.41], [0.50]]
     regressor = estimators.StructuralEnsembleRegressor(
         periodicities=None,
         num_features=1,
         moving_average_order=0,
         dtype=dtype)
     features = {
         TrainEvalFeatures.TIMES: times,
         TrainEvalFeatures.VALUES: values
     }
     train_input_fn = input_pipeline.RandomWindowInputFn(
         input_pipeline.NumpyReader(features), window_size=6, batch_size=1)
     regressor.train(input_fn=train_input_fn, steps=1)
     eval_input_fn = input_pipeline.WholeDatasetInputFn(
         input_pipeline.NumpyReader(features))
     evaluation = regressor.evaluate(input_fn=eval_input_fn, steps=1)
     predict_input_fn = input_pipeline.predict_continuation_input_fn(
         evaluation, times=[[7, 8, 9]])
     regressor.predict(input_fn=predict_input_fn)
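
The same train / evaluate / predict-continuation flow as in the example above, written against the public tf.contrib.timeseries aliases instead of the internal modules. This is a sketch assuming TensorFlow 1.x; the synthetic data, the periodicity, and the step counts are illustrative only.

 import numpy as np
 import tensorflow as tf

 data = {
     tf.contrib.timeseries.TrainEvalFeatures.TIMES: np.arange(200),
     tf.contrib.timeseries.TrainEvalFeatures.VALUES:
         np.sin(np.arange(200) / 10.).astype(np.float32)[:, None],
 }
 reader = tf.contrib.timeseries.NumpyReader(data)
 # Random fixed-size windows for training, the whole series for evaluation.
 train_input_fn = tf.contrib.timeseries.RandomWindowInputFn(
     reader, batch_size=4, window_size=32)
 eval_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader)

 estimator = tf.contrib.timeseries.StructuralEnsembleRegressor(
     periodicities=[63], num_features=1, dtype=tf.float32)
 estimator.train(input_fn=train_input_fn, steps=100)
 evaluation = estimator.evaluate(input_fn=eval_input_fn, steps=1)
 # Continue forecasting past the end of the evaluated series.
 (predictions,) = tuple(estimator.predict(
     input_fn=tf.contrib.timeseries.predict_continuation_input_fn(
         evaluation, steps=20)))
 print(predictions["mean"].shape)  # (20, 1)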
Example No. 12
 def test_state_override(self):
     test_start_state = (numpy.array([[2, 3, 4]]),
                         (numpy.array([2]), numpy.array([[3., 5.]])))
     data = {
         feature_keys.FilteringFeatures.TIMES: numpy.arange(5),
         feature_keys.FilteringFeatures.VALUES: numpy.zeros(shape=[5, 3])
     }
     features, _ = input_pipeline.WholeDatasetInputFn(
         input_pipeline.NumpyReader(data))()
     features[feature_keys.FilteringFeatures.STATE_TUPLE] = test_start_state
     stub_model = _StateOverrideModel()
     chainer = state_management.ChainingStateManager()
     stub_model.initialize_graph()
     chainer.initialize_graph(model=stub_model)
     model_outputs = chainer.define_loss(model=stub_model,
                                         features=features,
                                         mode=estimator_lib.ModeKeys.EVAL)
     with train.MonitoredSession() as session:
         end_state = session.run(model_outputs.end_state)
     nest.assert_same_structure(test_start_state, end_state)
     for expected, received in zip(nest.flatten(test_start_state),
                                   nest.flatten(end_state)):
         self.assertAllEqual(expected, received)
Example No. 13
def _train_on_generated_data(
    generate_fn, generative_model, train_iterations, seed,
    learning_rate=0.1, ignore_params_fn=lambda _: (),
    derived_param_test_fn=lambda _: (),
    train_input_fn_type=input_pipeline.WholeDatasetInputFn,
    train_state_manager=state_management.PassthroughStateManager()):
  """The training portion of parameter recovery tests."""
  random_seed.set_random_seed(seed)
  generate_graph = ops.Graph()
  with generate_graph.as_default():
    with session.Session(graph=generate_graph):
      generative_model.initialize_graph()
      time_series_reader, true_parameters = generate_fn(generative_model)
      true_parameters = {
          tensor.name: value for tensor, value in true_parameters.items()}
  eval_input_fn = input_pipeline.WholeDatasetInputFn(time_series_reader)
  eval_state_manager = state_management.PassthroughStateManager()
  true_parameter_eval_graph = ops.Graph()
  with true_parameter_eval_graph.as_default():
    generative_model.initialize_graph()
    ignore_params = ignore_params_fn(generative_model)
    feature_dict, _ = eval_input_fn()
    eval_state_manager.initialize_graph(generative_model)
    feature_dict[TrainEvalFeatures.VALUES] = math_ops.cast(
        feature_dict[TrainEvalFeatures.VALUES], generative_model.dtype)
    model_outputs = eval_state_manager.define_loss(
        model=generative_model,
        features=feature_dict,
        mode=estimator_lib.ModeKeys.EVAL)
    with session.Session(graph=true_parameter_eval_graph) as sess:
      variables.global_variables_initializer().run()
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(sess, coord=coordinator)
      true_param_loss = model_outputs.loss.eval(feed_dict=true_parameters)
      true_transformed_params = {
          param: param.eval(feed_dict=true_parameters)
          for param in derived_param_test_fn(generative_model)}
      coordinator.request_stop()
      coordinator.join()

  saving_hook = _SavingTensorHook(
      tensors=true_parameters.keys(),
      every_n_iter=train_iterations - 1)

  class _RunConfig(estimator_lib.RunConfig):

    @property
    def tf_random_seed(self):
      return seed

  estimator = estimators.TimeSeriesRegressor(
      model=generative_model,
      config=_RunConfig(),
      state_manager=train_state_manager,
      optimizer=adam.AdamOptimizer(learning_rate))
  train_input_fn = train_input_fn_type(time_series_reader=time_series_reader)
  trained_loss = (estimator.train(
      input_fn=train_input_fn,
      max_steps=train_iterations,
      hooks=[saving_hook]).evaluate(
          input_fn=eval_input_fn, steps=1))["loss"]
  logging.info("Final trained loss: %f", trained_loss)
  logging.info("True parameter loss: %f", true_param_loss)
  return (ignore_params, true_parameters, true_transformed_params,
          trained_loss, true_param_loss, saving_hook,
          true_parameter_eval_graph)
Example No. 14
 def test_savedmodel_state_override(self):
     random_model = RandomStateSpaceModel(
         state_dimension=5,
         state_noise_dimension=4,
         configuration=state_space_model.StateSpaceModelConfiguration(
             exogenous_feature_columns=[
                 layers.real_valued_column("exogenous")
             ],
             dtype=dtypes.float64,
             num_features=1))
     estimator = estimators.StateSpaceRegressor(
         model=random_model,
         optimizer=gradient_descent.GradientDescentOptimizer(0.1))
     combined_input_fn = input_pipeline.WholeDatasetInputFn(
         input_pipeline.NumpyReader({
             feature_keys.FilteringFeatures.TIMES: [1, 2, 3, 4],
             feature_keys.FilteringFeatures.VALUES: [1., 2., 3., 4.],
             "exogenous": [-1., -2., -3., -4.]
         }))
     estimator.train(combined_input_fn, steps=1)
     export_location = estimator.export_saved_model(
         self.get_temp_dir(),
         estimator.build_raw_serving_input_receiver_fn())
     with ops.Graph().as_default() as graph:
         random_model.initialize_graph()
         with self.session(graph=graph) as session:
             variables.global_variables_initializer().run()
             evaled_start_state = session.run(
                 random_model.get_start_state())
     evaled_start_state = [
         state_element[None, ...] for state_element in evaled_start_state
     ]
     with ops.Graph().as_default() as graph:
         with self.session(graph=graph) as session:
             signatures = loader.load(session, [tag_constants.SERVING],
                                      export_location)
             first_split_filtering = saved_model_utils.filter_continuation(
                 continue_from={
                     feature_keys.FilteringResults.STATE_TUPLE:
                     evaled_start_state
                 },
                 signatures=signatures,
                 session=session,
                 features={
                     feature_keys.FilteringFeatures.TIMES: [1, 2],
                     feature_keys.FilteringFeatures.VALUES: [1., 2.],
                     "exogenous": [[-1.], [-2.]]
                 })
             second_split_filtering = saved_model_utils.filter_continuation(
                 continue_from=first_split_filtering,
                 signatures=signatures,
                 session=session,
                 features={
                     feature_keys.FilteringFeatures.TIMES: [3, 4],
                     feature_keys.FilteringFeatures.VALUES: [3., 4.],
                     "exogenous": [[-3.], [-4.]]
                 })
             combined_filtering = saved_model_utils.filter_continuation(
                 continue_from={
                     feature_keys.FilteringResults.STATE_TUPLE:
                     evaled_start_state
                 },
                 signatures=signatures,
                 session=session,
                 features={
                     feature_keys.FilteringFeatures.TIMES: [1, 2, 3, 4],
                     feature_keys.FilteringFeatures.VALUES:
                     [1., 2., 3., 4.],
                     "exogenous": [[-1.], [-2.], [-3.], [-4.]]
                 })
             split_predict = saved_model_utils.predict_continuation(
                 continue_from=second_split_filtering,
                 signatures=signatures,
                 session=session,
                 steps=1,
                 exogenous_features={"exogenous": [[[-5.]]]})
             combined_predict = saved_model_utils.predict_continuation(
                 continue_from=combined_filtering,
                 signatures=signatures,
                 session=session,
                 steps=1,
                 exogenous_features={"exogenous": [[[-5.]]]})
     for state_key, combined_state_value in combined_filtering.items():
         if state_key == feature_keys.FilteringResults.TIMES:
             continue
         self.assertAllClose(combined_state_value,
                             second_split_filtering[state_key])
     for prediction_key, combined_value in combined_predict.items():
         self.assertAllClose(combined_value, split_predict[prediction_key])
Example No. 15
    def _fit_restore_fit_test_template(self, estimator_fn, dtype):
        """Tests restoring previously fit models."""
        model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        exogenous_feature_columns = (
            feature_column.numeric_column("exogenous"), )
        first_estimator = estimator_fn(model_dir, exogenous_feature_columns)
        times = numpy.arange(20, dtype=numpy.int64)
        values = numpy.arange(20, dtype=dtype.as_numpy_dtype)
        exogenous = numpy.arange(20, dtype=dtype.as_numpy_dtype)
        features = {
            feature_keys.TrainEvalFeatures.TIMES: times,
            feature_keys.TrainEvalFeatures.VALUES: values,
            "exogenous": exogenous
        }
        train_input_fn = input_pipeline.RandomWindowInputFn(
            input_pipeline.NumpyReader(features),
            shuffle_seed=2,
            num_threads=1,
            batch_size=16,
            window_size=16)
        eval_input_fn = input_pipeline.RandomWindowInputFn(
            input_pipeline.NumpyReader(features),
            shuffle_seed=3,
            num_threads=1,
            batch_size=16,
            window_size=16)
        first_estimator.train(input_fn=train_input_fn, steps=1)
        first_evaluation = first_estimator.evaluate(input_fn=eval_input_fn,
                                                    steps=1)
        first_loss_before_fit = first_evaluation["loss"]
        self.assertAllEqual(first_loss_before_fit,
                            first_evaluation["average_loss"])
        self.assertAllEqual([], first_loss_before_fit.shape)
        first_estimator.train(input_fn=train_input_fn, steps=1)
        first_loss_after_fit = first_estimator.evaluate(input_fn=eval_input_fn,
                                                        steps=1)["loss"]
        self.assertAllEqual([], first_loss_after_fit.shape)
        second_estimator = estimator_fn(model_dir, exogenous_feature_columns)
        second_estimator.train(input_fn=train_input_fn, steps=1)
        whole_dataset_input_fn = input_pipeline.WholeDatasetInputFn(
            input_pipeline.NumpyReader(features))
        whole_dataset_evaluation = second_estimator.evaluate(
            input_fn=whole_dataset_input_fn, steps=1)
        exogenous_values_ten_steps = {
            "exogenous":
            numpy.arange(10, dtype=dtype.as_numpy_dtype)[None, :, None]
        }
        predict_input_fn = input_pipeline.predict_continuation_input_fn(
            evaluation=whole_dataset_evaluation,
            exogenous_features=exogenous_values_ten_steps,
            steps=10)
        # Also tests that limit_epochs in predict_continuation_input_fn prevents
        # infinite iteration
        (estimator_predictions, ) = list(
            second_estimator.predict(input_fn=predict_input_fn))
        self.assertAllEqual([10, 1], estimator_predictions["mean"].shape)
        input_receiver_fn = (
            first_estimator.build_raw_serving_input_receiver_fn())
        export_location = first_estimator.export_saved_model(
            self.get_temp_dir(), input_receiver_fn)
        with ops.Graph().as_default():
            with session.Session() as sess:
                signatures = loader.load(sess, [tag_constants.SERVING],
                                         export_location)
                # Test that prediction and filtering can continue from evaluation output
                saved_prediction = saved_model_utils.predict_continuation(
                    continue_from=whole_dataset_evaluation,
                    steps=10,
                    exogenous_features=exogenous_values_ten_steps,
                    signatures=signatures,
                    session=sess)
                # Saved model predictions should be the same as Estimator predictions
                # starting from the same evaluation.
                for (prediction_key,
                     prediction_value) in estimator_predictions.items():
                    self.assertAllClose(
                        prediction_value,
                        numpy.squeeze(saved_prediction[prediction_key],
                                      axis=0))
                first_filtering = saved_model_utils.filter_continuation(
                    continue_from=whole_dataset_evaluation,
                    features={
                        feature_keys.FilteringFeatures.TIMES:
                        times[None, -1] + 2,
                        feature_keys.FilteringFeatures.VALUES:
                        values[None, -1] + 2.,
                        "exogenous": values[None, -1, None] + 12.
                    },
                    signatures=signatures,
                    session=sess)
                # Test that prediction and filtering can continue from filtering output
                second_saved_prediction = saved_model_utils.predict_continuation(
                    continue_from=first_filtering,
                    steps=1,
                    exogenous_features={
                        "exogenous":
                        numpy.arange(1, dtype=dtype.as_numpy_dtype)[None, :,
                                                                    None]
                    },
                    signatures=signatures,
                    session=sess)
                self.assertEqual(
                    times[-1] + 3,
                    numpy.squeeze(second_saved_prediction[
                        feature_keys.PredictionResults.TIMES]))
                saved_model_utils.filter_continuation(
                    continue_from=first_filtering,
                    features={
                        feature_keys.FilteringFeatures.TIMES: times[-1] + 3,
                        feature_keys.FilteringFeatures.VALUES: values[-1] + 3.,
                        "exogenous": values[-1, None] + 13.
                    },
                    signatures=signatures,
                    session=sess)

                # Test cold starting
                six.assertCountEqual(
                    self, [
                        feature_keys.FilteringFeatures.TIMES,
                        feature_keys.FilteringFeatures.VALUES, "exogenous"
                    ],
                    signatures.signature_def[feature_keys.SavedModelLabels.
                                             COLD_START_FILTER].inputs.keys())
                batch_numpy_times = numpy.tile(
                    numpy.arange(30, dtype=numpy.int64)[None, :], (10, 1))
                batch_numpy_values = numpy.ones([10, 30, 1])
                state = saved_model_utils.cold_start_filter(
                    signatures=signatures,
                    session=sess,
                    features={
                        feature_keys.FilteringFeatures.TIMES:
                        batch_numpy_times,
                        feature_keys.FilteringFeatures.VALUES:
                        batch_numpy_values,
                        "exogenous": 10. + batch_numpy_values
                    })
                predict_times = numpy.tile(
                    numpy.arange(30, 45, dtype=numpy.int64)[None, :], (10, 1))
                predictions = saved_model_utils.predict_continuation(
                    continue_from=state,
                    times=predict_times,
                    exogenous_features={
                        "exogenous":
                        numpy.tile(
                            numpy.arange(15, dtype=dtype.as_numpy_dtype),
                            (10, ))[None, :, None]
                    },
                    signatures=signatures,
                    session=sess)
                self.assertAllEqual([10, 15, 1], predictions["mean"].shape)