def test_predict_base_model(self):
    """Predicts with the base model after wrapping it in GraphRegularization.

    A two-feature linear model with weights [1, -1] is wrapped in a
    `GraphRegularization` model, which is then compiled with SGD and MSE loss.
    A single example is predicted through the base model and the result is
    compared with the hand-computed dot product.
    """
    base = build_linear_functional_model(
        input_shape=(2,), weights=np.array([1.0, -1.0]))
    features = {FEATURE_NAME: tf.constant([[5.0, 3.0]])}

    # The wrapper is constructed and compiled but is not used for prediction.
    regularized = graph_regularization.GraphRegularization(base)
    sgd = tf.keras.optimizers.SGD(0.01)
    regularized.compile(optimizer=sgd, loss='MSE')

    # NOTE(review): prediction goes through `base`, not `regularized`; confirm
    # this is the intended subject of the test.
    predicted = base.predict(x=features, steps=1, batch_size=1)

    # Dot product of the example [5, 3] with the weights [1, -1].
    expected = [[5.0 * 1 + 3.0 * (-1.0)]]
    self.assertAllEqual(expected, predicted)
    def _create_and_compile_graph_reg_model(model_fn, weight, max_neighbors):
      """Builds a linear model plus a compiled graph-regularized wrapper of it.

      NOTE(review): as written, this helper is nested inside
      `test_predict_base_model` after its final assertion, so it is local to
      that test; confirm the intended indentation.

      Args:
        model_fn: Callable invoked as `model_fn((2,), weight)` to build the
          base linear regression model.
        weight: Initial weight value forwarded to `model_fn`.
        max_neighbors: Neighbor limit passed to
          `configs.make_graph_reg_config`.

      Returns:
        A `(base_model, graph_reg_model)` pair of `tf.keras.Model` instances.
        The second wraps the first and is compiled with SGD at
        `LEARNING_RATE` and MSE loss.
      """
      base_model = model_fn((2,), weight)
      reg_config = configs.make_graph_reg_config(
          max_neighbors=max_neighbors, multiplier=1)
      wrapped = graph_regularization.GraphRegularization(base_model, reg_config)
      sgd = tf.keras.optimizers.SGD(LEARNING_RATE)
      wrapped.compile(optimizer=sgd, loss='MSE')
      return base_model, wrapped