Example #1
    def testLinearlySeparableBinaryDataNoKernels(self):
        """Tests classifier w/o kernels (log. regression) for lin-separable data."""

        feature1 = layers.real_valued_column('feature1')
        feature2 = layers.real_valued_column('feature2')

        logreg_classifier = kernel_estimators.KernelLinearClassifier(
            feature_columns=[feature1, feature2])
        logreg_classifier.fit(input_fn=_linearly_separable_binary_input_fn,
                              steps=100)

        metrics = logreg_classifier.evaluate(
            input_fn=_linearly_separable_binary_input_fn, steps=1)
        # Since the data is linearly separable, the classifier should have small
        # loss and perfect accuracy.
        self.assertLess(metrics['loss'], 0.1)
        self.assertEqual(metrics['accuracy'], 1.0)

        # As a result, it should assign higher probability to class 1 for the 1st
        # and 3rd example and higher probability to class 0 for the second example.
        logreg_prob_predictions = list(
            logreg_classifier.predict_proba(
                input_fn=_linearly_separable_binary_input_fn))
        self.assertGreater(logreg_prob_predictions[0][1], 0.5)
        self.assertGreater(logreg_prob_predictions[1][0], 0.5)
        self.assertGreater(logreg_prob_predictions[2][1], 0.5)
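# The `_linearly_separable_binary_input_fn` helper referenced above is not shown
# on this page. A minimal sketch of an input_fn consistent with the assertions
# (three examples with two scalar features each, labels [1, 0, 1], separable by
# a straight line) might look like:
def _linearly_separable_binary_input_fn():
    # Hypothetical data; any linearly separable points with these labels work.
    return {
        'feature1': constant_op.constant([[2.0], [-1.0], [3.0]]),
        'feature2': constant_op.constant([[1.0], [-2.0], [2.0]]),
    }, constant_op.constant([[1], [0], [1]])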
Example #2
    def testInvalidNumberOfClasses(self):
        """ValueError raised when the kernel mappers provided have invalid type."""

        feature = layers.real_valued_column('feature')
        with self.assertRaises(ValueError):
            _ = kernel_estimators.KernelLinearClassifier(
                feature_columns=[feature], n_classes=1)
Example #3
    def testMulticlassDataWithAndWithoutKernels(self):
        """Tests classifier w/ and w/o kernels on multiclass data."""
        feature_column = layers.real_valued_column('feature', dimension=4)

        # Metrics for linear classifier (no kernels).
        linear_classifier = kernel_estimators.KernelLinearClassifier(
            feature_columns=[feature_column], n_classes=3)
        linear_classifier.fit(input_fn=test_data.iris_input_multiclass_fn,
                              steps=50)
        linear_metrics = linear_classifier.evaluate(
            input_fn=test_data.iris_input_multiclass_fn, steps=1)
        linear_loss = linear_metrics['loss']
        linear_accuracy = linear_metrics['accuracy']

        # Using kernel mappers allows the model to discover non-linearities in the
        # data (via RBF kernel approximation), which reduces the loss and increases
        # accuracy.
        kernel_mappers = {
            feature_column: [
                RandomFourierFeatureMapper(input_dim=4,
                                           output_dim=50,
                                           stddev=1.0,
                                           name='rffm')
            ]
        }
        kernel_linear_classifier = kernel_estimators.KernelLinearClassifier(
            feature_columns=[], n_classes=3, kernel_mappers=kernel_mappers)
        kernel_linear_classifier.fit(
            input_fn=test_data.iris_input_multiclass_fn, steps=50)
        kernel_linear_metrics = kernel_linear_classifier.evaluate(
            input_fn=test_data.iris_input_multiclass_fn, steps=1)
        kernel_linear_loss = kernel_linear_metrics['loss']
        kernel_linear_accuracy = kernel_linear_metrics['accuracy']
        self.assertLess(kernel_linear_loss, linear_loss)
        self.assertGreater(kernel_linear_accuracy, linear_accuracy)
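# For reference, the mapper configured above approximates an RBF kernel by
# projecting each 4-d input into a 50-d feature space. A minimal standalone
# sketch, assuming the contrib mapper's `map` method:
import tensorflow as tf

rffm = RandomFourierFeatureMapper(input_dim=4, output_dim=50, stddev=1.0,
                                  name='rffm')
x = tf.placeholder(tf.float32, shape=[None, 4])  # batch of 4-d features
mapped_x = rffm.map(x)                           # shape [None, 50]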
Example #4
def _add_bias_column(feature_columns, columns_to_tensors, bias_variable,
                     columns_to_variables):
    """Adds a fake bias feature column filled with all 1s."""
    # TODO(b/31008490): Move definition to a common constants place.
    bias_column_name = "tf_virtual_bias_column"
    if any(col.name == bias_column_name for col in feature_columns):
        raise ValueError("%s is a reserved column name." % bias_column_name)
    if not feature_columns:
        raise ValueError("feature_columns can't be empty.")

    # Loop through input tensors until we can figure out batch_size.
    batch_size = None
    for column in columns_to_tensors.values():
        if isinstance(column, tuple):
            column = column[0]
        if isinstance(column, sparse_tensor.SparseTensor):
            shape = tensor_util.constant_value(column.dense_shape)
            if shape is not None:
                batch_size = shape[0]
                break
        else:
            batch_size = array_ops.shape(column)[0]
            break
    if batch_size is None:
        raise ValueError("Could not infer batch size from input features.")

    bias_column = layers.real_valued_column(bias_column_name)
    columns_to_tensors[bias_column] = array_ops.ones([batch_size, 1],
                                                     dtype=dtypes.float32)
    columns_to_variables[bias_column] = [bias_variable]
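# Note that the reserved-name guard above relies on string equality (`==`)
# rather than identity (`is`): equal Python strings are not guaranteed to be
# the same object. A minimal illustration:
reserved = "tf_virtual_bias_column"
candidate = "".join(["tf_virtual_", "bias_column"])
print(candidate == reserved)  # True: same contents
print(candidate is reserved)  # False in CPython: distinct string objects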
Example #5
    def testInvalidKernelMapper(self):
        """ValueError raised when the kernel mappers provided have invalid type."""
        class DummyKernelMapper(object):
            def __init__(self):
                pass

        feature = layers.real_valued_column('feature')
        kernel_mappers = {feature: [DummyKernelMapper()]}
        with self.assertRaises(ValueError):
            _ = kernel_estimators.KernelLinearClassifier(
                feature_columns=[feature], kernel_mappers=kernel_mappers)
Example #6
    def testClassifierWithAndWithoutKernelsNoRealValuedColumns(self):
        """Tests kernels have no effect for non-real valued columns ."""
        def input_fn():
            return {
                'price':
                constant_op.constant([[0.4], [0.6], [0.3]]),
                'country':
                sparse_tensor.SparseTensor(values=['IT', 'US', 'GB'],
                                           indices=[[0, 0], [1, 3], [2, 1]],
                                           dense_shape=[3, 5]),
            }, constant_op.constant([[1], [0], [1]])

        price = layers.real_valued_column('price')
        country = layers.sparse_column_with_hash_bucket('country',
                                                        hash_bucket_size=5)

        linear_classifier = kernel_estimators.KernelLinearClassifier(
            feature_columns=[price, country])
        linear_classifier.fit(input_fn=input_fn, steps=100)
        linear_metrics = linear_classifier.evaluate(input_fn=input_fn, steps=1)
        linear_loss = linear_metrics['loss']
        linear_accuracy = linear_metrics['accuracy']

        kernel_mappers = {
            country: [RandomFourierFeatureMapper(2, 30, 0.6, 1, 'rffm')]
        }

        kernel_linear_classifier = kernel_estimators.KernelLinearClassifier(
            feature_columns=[price, country], kernel_mappers=kernel_mappers)
        kernel_linear_classifier.fit(input_fn=input_fn, steps=100)
        kernel_linear_metrics = kernel_linear_classifier.evaluate(
            input_fn=input_fn, steps=1)
        kernel_linear_loss = kernel_linear_metrics['loss']
        kernel_linear_accuracy = kernel_linear_metrics['accuracy']

        # The kernel mapping is applied to a non-real-valued feature column and so
        # it should have no effect on the model. The loss and accuracy of the
        # "kernelized" model should match the loss and accuracy of the initial model
        # (without kernels).
        self.assertAlmostEqual(linear_loss, kernel_linear_loss, delta=0.01)
        self.assertAlmostEqual(linear_accuracy,
                               kernel_linear_accuracy,
                               delta=0.01)
Example #7
    def testVariablesWithAndWithoutKernels(self):
        """Tests variables w/ and w/o kernel."""
        multi_dim_feature = layers.real_valued_column('multi_dim_feature',
                                                      dimension=2)

        linear_classifier = kernel_estimators.KernelLinearClassifier(
            feature_columns=[multi_dim_feature])
        linear_classifier.fit(input_fn=_linearly_inseparable_binary_input_fn,
                              steps=50)
        linear_variables = linear_classifier.get_variable_names()
        self.assertIn('linear/multi_dim_feature/weight', linear_variables)
        self.assertIn('linear/bias_weight', linear_variables)
        linear_weights = linear_classifier.get_variable_value(
            'linear/multi_dim_feature/weight')
        linear_bias = linear_classifier.get_variable_value(
            'linear/bias_weight')

        kernel_mappers = {
            multi_dim_feature:
            [RandomFourierFeatureMapper(2, 30, 0.6, 1, 'rffm')]
        }
        kernel_linear_classifier = kernel_estimators.KernelLinearClassifier(
            feature_columns=[], kernel_mappers=kernel_mappers)
        kernel_linear_classifier.fit(
            input_fn=_linearly_inseparable_binary_input_fn, steps=50)
        kernel_linear_variables = kernel_linear_classifier.get_variable_names()
        self.assertIn('linear/multi_dim_feature_MAPPED/weight',
                      kernel_linear_variables)
        self.assertIn('linear/bias_weight', kernel_linear_variables)
        kernel_linear_weights = kernel_linear_classifier.get_variable_value(
            'linear/multi_dim_feature_MAPPED/weight')
        kernel_linear_bias = kernel_linear_classifier.get_variable_value(
            'linear/bias_weight')

        # The feature column used for linear classification (no kernels) has
        # dimension 2, so the model learns a 2-dimensional weight vector (plus a
        # scalar bias). In the kernelized model, the features are mapped to a
        # 30-dimensional feature space, so the weight variable also has
        # dimension 30.
        self.assertEqual(2, len(linear_weights))
        self.assertEqual(1, len(linear_bias))
        self.assertEqual(30, len(kernel_linear_weights))
        self.assertEqual(1, len(kernel_linear_bias))
Example #8
    def testLinearlyInseparableBinaryDataWithAndWithoutKernels(self):
        """Tests classifier w/ and w/o kernels on non-linearly-separable data."""
        multi_dim_feature = layers.real_valued_column('multi_dim_feature',
                                                      dimension=2)

        # The data points are not linearly separable, so there will be at least one
        # misclassified sample (accuracy < 0.8). In fact, the loss is minimized for
        # w1=w2=0.0, in which case each example incurs a loss of ln(2). The overall
        # (average) loss should then be ln(2) and the logits should be approximately
        # 0.0 for each sample.
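        # (With logit 0.0, sigmoid gives probability 0.5 for every example, so the
        # per-example cross-entropy is -ln(0.5) = ln(2) ~= 0.693.)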
        logreg_classifier = kernel_estimators.KernelLinearClassifier(
            feature_columns=[multi_dim_feature])
        logreg_classifier.fit(input_fn=_linearly_inseparable_binary_input_fn,
                              steps=50)
        logreg_metrics = logreg_classifier.evaluate(
            input_fn=_linearly_inseparable_binary_input_fn, steps=1)
        logreg_loss = logreg_metrics['loss']
        logreg_accuracy = logreg_metrics['accuracy']
        logreg_predictions = logreg_classifier.predict(
            input_fn=_linearly_inseparable_binary_input_fn, as_iterable=False)
        self.assertAlmostEqual(logreg_loss, np.log(2), places=3)
        self.assertLess(logreg_accuracy, 0.8)
        self.assertAllClose(logreg_predictions['logits'],
                            [[0.0], [0.0], [0.0], [0.0]])

        # Using kernel mappers allows the model to discover non-linearities in the
        # data. Mapping the data to a higher-dimensional feature space using
        # approximate RBF kernels substantially reduces the loss and leads to
        # perfect classification accuracy.
        kernel_mappers = {
            multi_dim_feature:
            [RandomFourierFeatureMapper(2, 30, 0.6, 1, 'rffm')]
        }
        kernelized_logreg_classifier = kernel_estimators.KernelLinearClassifier(
            feature_columns=[], kernel_mappers=kernel_mappers)
        kernelized_logreg_classifier.fit(
            input_fn=_linearly_inseparable_binary_input_fn, steps=50)
        kernelized_logreg_metrics = kernelized_logreg_classifier.evaluate(
            input_fn=_linearly_inseparable_binary_input_fn, steps=1)
        kernelized_logreg_loss = kernelized_logreg_metrics['loss']
        kernelized_logreg_accuracy = kernelized_logreg_metrics['accuracy']
        self.assertLess(kernelized_logreg_loss, 0.2)
        self.assertEqual(kernelized_logreg_accuracy, 1.0)
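# The `_linearly_inseparable_binary_input_fn` helper is likewise not shown. A
# minimal sketch consistent with the assertions above (four 2-d examples with
# balanced, XOR-style labels, so no linear separator exists and zero weights
# minimize the loss) might be:
def _linearly_inseparable_binary_input_fn():
    # Hypothetical XOR-style data; the balanced labels make w1=w2=0 optimal.
    return {
        'multi_dim_feature':
        constant_op.constant([[1.0, 1.0], [-1.0, -1.0], [1.0, -1.0],
                              [-1.0, 1.0]]),
    }, constant_op.constant([[1], [1], [0], [0]])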
Example #9
def _dnn_tree_combined_model_fn(
        features,
        labels,
        mode,
        head,
        dnn_hidden_units,
        dnn_feature_columns,
        tree_learner_config,
        num_trees,
        tree_examples_per_layer,
        config=None,
        dnn_optimizer="Adagrad",
        dnn_activation_fn=nn.relu,
        dnn_dropout=None,
        dnn_input_layer_partitioner=None,
        dnn_input_layer_to_tree=True,
        dnn_steps_to_train=10000,
        predict_with_tree_only=False,
        tree_feature_columns=None,
        tree_center_bias=False,
        dnn_to_tree_distillation_param=None,
        use_core_versions=False,
        output_type=model.ModelBuilderOutputType.MODEL_FN_OPS,
        override_global_step_value=None):
    """DNN and GBDT combined model_fn.

  Args:
    features: `dict` of `Tensor` objects.
    labels: Labels used to train on.
    mode: Mode we are in. (TRAIN/EVAL/INFER)
    head: A `Head` instance.
    dnn_hidden_units: List of hidden units per layer.
    dnn_feature_columns: An iterable containing all the feature columns
      used by the model's DNN.
    tree_learner_config: A config for the tree learner.
    num_trees: Number of trees to grow model to after training DNN.
    tree_examples_per_layer: Number of examples to accumulate before
      growing the tree a layer. This value has a big impact on model
      quality and should be set equal to the number of examples in
      training dataset if possible. It can also be a function that computes
      the number of examples based on the depth of the layer that's
      being built.
    config: `RunConfig` of the estimator.
    dnn_optimizer: string, `Optimizer` object, or callable that defines the
      optimizer to use for training the DNN. If `None`, will use the Adagrad
      optimizer with default learning rate of 0.001.
    dnn_activation_fn: Activation function applied to each layer of the DNN.
      If `None`, will use `tf.nn.relu`.
    dnn_dropout: When not `None`, the probability to drop out a given
      unit in the DNN.
    dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
      Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
    dnn_input_layer_to_tree: Whether to provide the DNN's input layer
      as a feature to the tree.
    dnn_steps_to_train: Number of steps to train the DNN for before switching
      to GBDT.
    predict_with_tree_only: Whether to use only the tree model output as the
      final prediction.
    tree_feature_columns: An iterable containing all the feature columns
      used by the model's boosted trees. If dnn_input_layer_to_tree is
      set to True, these features are in addition to dnn_feature_columns.
    tree_center_bias: Whether a separate tree should be created for
      first fitting the bias.
    dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the
      float defines the weight of the distillation loss, and the loss_fn, for
      computing distillation loss, takes dnn_logits, tree_logits and weight
      tensor. If the entire tuple is None, no distillation will be applied. If
      only the loss_fn is None, we will take the sigmoid/softmax cross entropy
      loss by default. When distillation is applied, `predict_with_tree_only`
      will be set to True.
    use_core_versions: Whether feature columns and loss are from the core (as
      opposed to contrib) version of tensorflow.
    output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
      (new interface).
    override_global_step_value: If set, the global step value will be reset to
      this value after training is done. This is particularly useful for
      hyperparameter tuning, which can't recognize early stopping due to the
      number of trees. If None, the global step is not overridden.

  Returns:
    A `ModelFnOps` or `EstimatorSpec` object, depending on `output_type`.
  Raises:
    ValueError: if inputs are not valid.
  """
    if not isinstance(features, dict):
        raise ValueError("features should be a dictionary of `Tensor`s. "
                         "Given type: {}".format(type(features)))

    if not dnn_feature_columns:
        raise ValueError("dnn_feature_columns must be specified")

    if dnn_to_tree_distillation_param:
        if not predict_with_tree_only:
            logging.warning(
                "Setting predict_with_tree_only to True since distillation "
                "is specified.")
            predict_with_tree_only = True

    # Build DNN Logits.
    dnn_parent_scope = "dnn"
    dnn_partitioner = dnn_input_layer_partitioner or (
        partitioned_variables.min_max_variable_partitioner(
            max_partitions=config.num_ps_replicas, min_slice_size=64 << 20))

    if (output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC
            and not use_core_versions):
        raise ValueError("You must use core versions with Estimator Spec")
    global_step = training_util.get_global_step()

    with variable_scope.variable_scope(dnn_parent_scope,
                                       values=tuple(six.itervalues(features)),
                                       partitioner=dnn_partitioner):

        with variable_scope.variable_scope(
                "input_from_feature_columns",
                values=tuple(six.itervalues(features)),
                partitioner=dnn_partitioner) as input_layer_scope:
            if use_core_versions:
                input_layer = feature_column_lib.input_layer(
                    features=features,
                    feature_columns=dnn_feature_columns,
                    weight_collections=[dnn_parent_scope])
            else:
                input_layer = layers.input_from_feature_columns(
                    columns_to_tensors=features,
                    feature_columns=dnn_feature_columns,
                    weight_collections=[dnn_parent_scope],
                    scope=input_layer_scope)

        def dnn_logits_fn():
            """Builds the logits from the input layer."""
            previous_layer = input_layer
            for layer_id, num_hidden_units in enumerate(dnn_hidden_units):
                with variable_scope.variable_scope(
                        "hiddenlayer_%d" % layer_id,
                        values=(previous_layer, )) as hidden_layer_scope:
                    net = layers.fully_connected(
                        previous_layer,
                        num_hidden_units,
                        activation_fn=dnn_activation_fn,
                        variables_collections=[dnn_parent_scope],
                        scope=hidden_layer_scope)
                    if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN:
                        net = layers.dropout(net,
                                             keep_prob=(1.0 - dnn_dropout))
                _add_hidden_layer_summary(net, hidden_layer_scope.name)
                previous_layer = net
            with variable_scope.variable_scope(
                    "logits", values=(previous_layer, )) as logits_scope:
                dnn_logits = layers.fully_connected(
                    previous_layer,
                    head.logits_dimension,
                    activation_fn=None,
                    variables_collections=[dnn_parent_scope],
                    scope=logits_scope)
            _add_hidden_layer_summary(dnn_logits, logits_scope.name)
            return dnn_logits

        if predict_with_tree_only and mode == model_fn.ModeKeys.INFER:
            dnn_logits = array_ops.constant(0.0)
            dnn_train_op_fn = control_flow_ops.no_op
        elif predict_with_tree_only and mode == model_fn.ModeKeys.EVAL:
            dnn_logits = control_flow_ops.cond(
                global_step > dnn_steps_to_train,
                lambda: array_ops.constant(0.0), dnn_logits_fn)
            dnn_train_op_fn = control_flow_ops.no_op
        else:
            dnn_logits = dnn_logits_fn()

            def dnn_train_op_fn(loss):
                """Returns the op to optimize the loss."""
                return optimizers.optimize_loss(
                    loss=loss,
                    global_step=training_util.get_global_step(),
                    learning_rate=_DNN_LEARNING_RATE,
                    optimizer=_get_optimizer(dnn_optimizer),
                    name=dnn_parent_scope,
                    variables=ops.get_collection(
                        ops.GraphKeys.TRAINABLE_VARIABLES,
                        scope=dnn_parent_scope),
                    # Empty summaries to prevent optimizers from logging training_loss.
                    summaries=[])

    # Build Tree Logits.
    with ops.device(global_step.device):
        ensemble_handle = model_ops.tree_ensemble_variable(
            stamp_token=0,
            tree_ensemble_config="",  # Initialize an empty ensemble.
            name="ensemble_model")

    tree_features = features.copy()
    if dnn_input_layer_to_tree:
        tree_features["dnn_input_layer"] = input_layer
        tree_feature_columns.append(
            layers.real_valued_column("dnn_input_layer"))
    gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
        is_chief=config.is_chief,
        num_ps_replicas=config.num_ps_replicas,
        ensemble_handle=ensemble_handle,
        center_bias=tree_center_bias,
        examples_per_layer=tree_examples_per_layer,
        learner_config=tree_learner_config,
        feature_columns=tree_feature_columns,
        logits_dimension=head.logits_dimension,
        features=tree_features,
        use_core_columns=use_core_versions)

    with ops.name_scope("gbdt"):
        predictions_dict = gbdt_model.predict(mode)
        tree_logits = predictions_dict["predictions"]

        def _tree_train_op_fn(loss):
            """Returns the op to optimize the loss."""
            if dnn_to_tree_distillation_param:
                loss_weight, loss_fn = dnn_to_tree_distillation_param
                # pylint: disable=protected-access
                if use_core_versions:
                    weight_tensor = head_lib._weight_tensor(
                        features, head._weight_column)
                else:
                    weight_tensor = head_lib._weight_tensor(
                        features, head.weight_column_name)
                # pylint: enable=protected-access
                dnn_logits_fixed = array_ops.stop_gradient(dnn_logits)

                if loss_fn is None:
                    # Create a loss_fn matching the default loss_fn previously
                    # used by multi_class_head.
                    n_classes = 2 if head.logits_dimension == 1 else head.logits_dimension
                    loss_fn = distillation_loss.create_dnn_to_tree_cross_entropy_loss_fn(
                        n_classes)

                dnn_to_tree_distillation_loss = loss_weight * loss_fn(
                    dnn_logits_fixed, tree_logits, weight_tensor)
                summary.scalar("dnn_to_tree_distillation_loss",
                               dnn_to_tree_distillation_loss)
                loss += dnn_to_tree_distillation_loss

            update_op = gbdt_model.train(loss, predictions_dict, labels)
            with ops.control_dependencies(
                [update_op]), (ops.colocate_with(global_step)):
                update_op = state_ops.assign_add(global_step, 1).op
                return update_op

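    # Choose the logits handed to the head: with predict_with_tree_only, the
    # tree logits are used on their own (during EVAL, only once the DNN
    # training phase has finished); otherwise the DNN and tree logits are
    # summed.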
    if predict_with_tree_only:
        if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.INFER:
            tree_train_logits = tree_logits
        else:
            tree_train_logits = control_flow_ops.cond(
                global_step > dnn_steps_to_train, lambda: tree_logits,
                lambda: dnn_logits)
    else:
        tree_train_logits = dnn_logits + tree_logits

    def _no_train_op_fn(loss):
        """Returns a no-op."""
        del loss
        return control_flow_ops.no_op()

    if tree_center_bias:
        num_trees += 1
    finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()

    if output_type == model.ModelBuilderOutputType.MODEL_FN_OPS:
        model_fn_ops = head.create_model_fn_ops(features=features,
                                                mode=mode,
                                                labels=labels,
                                                train_op_fn=_no_train_op_fn,
                                                logits=tree_train_logits)
        if mode != model_fn.ModeKeys.TRAIN:
            return model_fn_ops
        dnn_train_op = head.create_model_fn_ops(features=features,
                                                mode=mode,
                                                labels=labels,
                                                train_op_fn=dnn_train_op_fn,
                                                logits=dnn_logits).train_op
        tree_train_op = head.create_model_fn_ops(
            features=tree_features,
            mode=mode,
            labels=labels,
            train_op_fn=_tree_train_op_fn,
            logits=tree_train_logits).train_op

        # Add the hooks
        model_fn_ops.training_hooks.extend([
            trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train,
                                        tree_train_op),
            trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
                                          finalized_trees,
                                          override_global_step_value)
        ])
        return model_fn_ops

    elif output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC:
        fusion_spec = head.create_estimator_spec(features=features,
                                                 mode=mode,
                                                 labels=labels,
                                                 train_op_fn=_no_train_op_fn,
                                                 logits=tree_train_logits)
        if mode != model_fn.ModeKeys.TRAIN:
            return fusion_spec
        dnn_spec = head.create_estimator_spec(features=features,
                                              mode=mode,
                                              labels=labels,
                                              train_op_fn=dnn_train_op_fn,
                                              logits=dnn_logits)
        tree_spec = head.create_estimator_spec(features=tree_features,
                                               mode=mode,
                                               labels=labels,
                                               train_op_fn=_tree_train_op_fn,
                                               logits=tree_train_logits)

        training_hooks = [
            trainer_hooks.SwitchTrainOp(dnn_spec.train_op, dnn_steps_to_train,
                                        tree_spec.train_op),
            trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
                                          finalized_trees,
                                          override_global_step_value)
        ]
        fusion_spec = fusion_spec._replace(training_hooks=training_hooks +
                                           list(fusion_spec.training_hooks))
        return fusion_spec
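# A minimal sketch (not part of the source) of how a model_fn like the one above
# is typically bound to its configuration and wrapped in a contrib Estimator.
# `my_head`, `my_feature_columns`, and `my_learner_config` are hypothetical
# objects constructed elsewhere.
from tensorflow.contrib.learn.python.learn.estimators import estimator


def make_combined_estimator(my_head, my_feature_columns, my_learner_config,
                            model_dir=None):
    """Hypothetical wrapper binding the combined model_fn to an Estimator."""

    def _model_fn(features, labels, mode, config):
        return _dnn_tree_combined_model_fn(
            features=features,
            labels=labels,
            mode=mode,
            head=my_head,
            dnn_hidden_units=[64, 32],
            dnn_feature_columns=my_feature_columns,
            tree_learner_config=my_learner_config,
            num_trees=10,
            tree_examples_per_layer=1000,
            config=config)

    return estimator.Estimator(model_fn=_model_fn, model_dir=model_dir)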
Example #10
 def test_savedmodel_state_override(self):
     random_model = RandomStateSpaceModel(
         state_dimension=5,
         state_noise_dimension=4,
         configuration=state_space_model.StateSpaceModelConfiguration(
             exogenous_feature_columns=[
                 layers.real_valued_column("exogenous")
             ],
             dtype=dtypes.float64,
             num_features=1))
     estimator = estimators.StateSpaceRegressor(
         model=random_model,
         optimizer=gradient_descent.GradientDescentOptimizer(0.1))
     combined_input_fn = input_pipeline.WholeDatasetInputFn(
         input_pipeline.NumpyReader({
             feature_keys.FilteringFeatures.TIMES: [1, 2, 3, 4],
             feature_keys.FilteringFeatures.VALUES: [1., 2., 3., 4.],
             "exogenous": [-1., -2., -3., -4.]
         }))
     estimator.train(combined_input_fn, steps=1)
     export_location = estimator.export_saved_model(
         self.get_temp_dir(),
         estimator.build_raw_serving_input_receiver_fn())
     with ops.Graph().as_default() as graph:
         random_model.initialize_graph()
         with self.session(graph=graph) as session:
             variables.global_variables_initializer().run()
             evaled_start_state = session.run(
                 random_model.get_start_state())
     evaled_start_state = [
         state_element[None, ...] for state_element in evaled_start_state
     ]
     with ops.Graph().as_default() as graph:
         with self.session(graph=graph) as session:
             signatures = loader.load(session, [tag_constants.SERVING],
                                      export_location)
             first_split_filtering = saved_model_utils.filter_continuation(
                 continue_from={
                     feature_keys.FilteringResults.STATE_TUPLE:
                     evaled_start_state
                 },
                 signatures=signatures,
                 session=session,
                 features={
                     feature_keys.FilteringFeatures.TIMES: [1, 2],
                     feature_keys.FilteringFeatures.VALUES: [1., 2.],
                     "exogenous": [[-1.], [-2.]]
                 })
             second_split_filtering = saved_model_utils.filter_continuation(
                 continue_from=first_split_filtering,
                 signatures=signatures,
                 session=session,
                 features={
                     feature_keys.FilteringFeatures.TIMES: [3, 4],
                     feature_keys.FilteringFeatures.VALUES: [3., 4.],
                     "exogenous": [[-3.], [-4.]]
                 })
             combined_filtering = saved_model_utils.filter_continuation(
                 continue_from={
                     feature_keys.FilteringResults.STATE_TUPLE:
                     evaled_start_state
                 },
                 signatures=signatures,
                 session=session,
                 features={
                     feature_keys.FilteringFeatures.TIMES: [1, 2, 3, 4],
                     feature_keys.FilteringFeatures.VALUES:
                     [1., 2., 3., 4.],
                     "exogenous": [[-1.], [-2.], [-3.], [-4.]]
                 })
             split_predict = saved_model_utils.predict_continuation(
                 continue_from=second_split_filtering,
                 signatures=signatures,
                 session=session,
                 steps=1,
                 exogenous_features={"exogenous": [[[-5.]]]})
             combined_predict = saved_model_utils.predict_continuation(
                 continue_from=combined_filtering,
                 signatures=signatures,
                 session=session,
                 steps=1,
                 exogenous_features={"exogenous": [[[-5.]]]})
     for state_key, combined_state_value in combined_filtering.items():
         if state_key == feature_keys.FilteringResults.TIMES:
             continue
         self.assertAllClose(combined_state_value,
                             second_split_filtering[state_key])
     for prediction_key, combined_value in combined_predict.items():
         self.assertAllClose(combined_value, split_predict[prediction_key])