    def _train_and_check_params(self, example, max_neighbors, weight, bias,
                                expected_grad_from_weight,
                                expected_grad_from_bias):
        """Runs training for one step and verifies gradient-based updates."""
        def embedding_fn(features, unused_mode):
            # Computes y = w*x
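            # The weight variable is shared with the base linear regressor via
            # `AUTO_REUSE`, so gradients from the graph regularization loss
            # flow into the regressor's own weights.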
            with tf.compat.v1.variable_scope(
                    tf.compat.v1.get_variable_scope(),
                    reuse=tf.compat.v1.AUTO_REUSE,
                    auxiliary_name_scope=False):
                weight_tensor = tf.reshape(
                    tf.compat.v1.get_variable(
                        WEIGHT_VARIABLE,
                        shape=[2, 1],
                        partitioner=tf.compat.v1.fixed_size_partitioner(1)),
                    shape=[-1, 2])

            x_tensor = tf.reshape(features[FEATURE_NAME], shape=[-1, 2])
            return tf.reduce_sum(
                tf.multiply(weight_tensor, x_tensor), axis=1, keepdims=True)

        def optimizer_fn():
            return tf.compat.v1.train.GradientDescentOptimizer(LEARNING_RATE)

        base_est = self.build_linear_regressor(weight=weight,
                                               weight_shape=[2, 1],
                                               bias=bias,
                                               bias_shape=[1])

        graph_reg_config = nsl_configs.make_graph_reg_config(
            max_neighbors=max_neighbors, multiplier=1)
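        # The wrapped estimator adds `multiplier` times the graph
        # regularization loss (which penalizes the distance between each
        # sample's embedding and those of its neighbors) to the supervised
        # loss during training.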
        graph_reg_est = nsl_estimator.add_graph_regularization(
            base_est,
            embedding_fn,
            optimizer_fn,
            graph_reg_config=graph_reg_config)

        input_fn = single_example_input_fn(example,
                                           input_shape=[2],
                                           max_neighbors=max_neighbors)
        graph_reg_est.train(input_fn=input_fn, steps=1)

        # Compute the new bias and weight values based on the gradients.
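        # One SGD step updates each parameter p to p - LEARNING_RATE * dL/dp;
        # the expected gradients supplied by the caller are assumed to include
        # both the supervised and the graph regularization terms.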
        expected_bias = bias - LEARNING_RATE * (expected_grad_from_bias)
        expected_weight = weight - LEARNING_RATE * (expected_grad_from_weight)

        # Check that the parameters of the linear regressor have the correct
        # values.
        self.assertAllClose(expected_bias,
                            graph_reg_est.get_variable_value(BIAS_VARIABLE))
        self.assertAllClose(expected_weight,
                            graph_reg_est.get_variable_value(WEIGHT_VARIABLE))
    def _train_and_check_eval_results(self, train_example, test_example,
                                      max_neighbors, weight, bias):
        """Verifies evaluation results for the graph-regularized model."""
        def embedding_fn(features, unused_mode):
            # Computes y = w*x
            with tf.compat.v1.variable_scope(
                    tf.compat.v1.get_variable_scope(),
                    reuse=tf.compat.v1.AUTO_REUSE,
                    auxiliary_name_scope=False):
                weight_tensor = tf.reshape(
                    tf.compat.v1.get_variable(
                        WEIGHT_VARIABLE,
                        shape=[2, 1],
                        partitioner=tf.compat.v1.fixed_size_partitioner(1)),
                    shape=[-1, 2])

            x_tensor = tf.reshape(features[FEATURE_NAME], shape=[-1, 2])
            return tf.reduce_sum(
                tf.multiply(weight_tensor, x_tensor), axis=1, keepdims=True)

        def optimizer_fn():
            return tf.compat.v1.train.GradientDescentOptimizer(LEARNING_RATE)

        base_est = self.build_linear_regressor(weight=weight,
                                               weight_shape=[2, 1],
                                               bias=bias,
                                               bias_shape=[1])

        graph_reg_config = nsl_configs.make_graph_reg_config(
            max_neighbors=max_neighbors, multiplier=1)
        graph_reg_est = nsl_estimator.add_graph_regularization(
            base_est,
            embedding_fn,
            optimizer_fn,
            graph_reg_config=graph_reg_config)

        train_input_fn = single_example_input_fn(train_example,
                                                 input_shape=[2],
                                                 max_neighbors=max_neighbors)
        graph_reg_est.train(input_fn=train_input_fn, steps=1)

        # Evaluating the graph-regularized model should yield the same results
        # as evaluating the base model because model parameters are shared.
        eval_input_fn = single_example_input_fn(test_example,
                                                input_shape=[2],
                                                max_neighbors=0)
        graph_eval_results = graph_reg_est.evaluate(input_fn=eval_input_fn)
        base_eval_results = base_est.evaluate(input_fn=eval_input_fn)
        self.assertAllClose(base_eval_results, graph_eval_results)
    def _create_and_compile_graph_reg_model(
            self, model_fn, weight, max_neighbors):
        """Creates and compiles a graph-regularized model.

        Args:
          model_fn: A function that builds a linear regression model.
          weight: Initial value for the weights variable in the linear
            regressor.
          max_neighbors: The maximum number of neighbors for graph
            regularization.

        Returns:
          A pair containing the unregularized model and the graph-regularized
          model as `tf.keras.Model` instances.
        """
        model = model_fn((2,), weight)
        graph_reg_config = configs.make_graph_reg_config(
            max_neighbors=max_neighbors, multiplier=1)
        graph_reg_model = graph_regularization.GraphRegularization(
            model, graph_reg_config)
        graph_reg_model.compile(
            optimizer=tf.keras.optimizers.SGD(LEARNING_RATE), loss='MSE')
        return model, graph_reg_model
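        # Example usage (with a hypothetical `build_linear_regression_model`
        # builder that maps an input shape and initial weights to a
        # `tf.keras.Model`):
        #   _, graph_reg_model = self._create_and_compile_graph_reg_model(
        #       build_linear_regression_model, weight=[[1.0], [-1.0]],
        #       max_neighbors=1)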
    def test_graph_reg_wrapper_no_training(self):
        """Test that predictions are unaffected when there is no training."""
        # Base model: y = x + 2
        base_est = self.build_linear_regressor(weight=[[1.0]],
                                               weight_shape=[1, 1],
                                               bias=[2.0],
                                               bias_shape=[1])

        def embedding_fn(features, unused_mode):
            # Apply the same model, i.e., y = x + 2.
            # Use broadcasting to do element-wise addition.
            return tf.math.add(features[FEATURE_NAME], [2.0])

        graph_reg_config = nsl_configs.make_graph_reg_config(max_neighbors=1)
        graph_reg_est = nsl_estimator.add_graph_regularization(
            base_est, embedding_fn, graph_reg_config=graph_reg_config)

        # Consider only one neighbor for the input sample.
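        # Neighbor features use the 'NL_nbr_<i>_<feature_name>' naming
        # convention; 'NL_nbr_<i>_weight' carries the corresponding edge
        # weight.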
        example = """
                features {
                  feature {
                    key: "x"
                    value: { float_list { value: [ 1.0 ] } }
                  }
                  feature {
                    key: "NL_nbr_0_x"
                    value: { float_list { value: [ 2.0 ] } }
                  }
                  feature {
                    key: "NL_nbr_0_weight"
                    value: { float_list { value: 1.0 } }
                  }
               }
              """

        input_fn = single_example_input_fn(example,
                                           input_shape=[1],
                                           max_neighbors=0)
        predictions = graph_reg_est.predict(input_fn=input_fn)
        predicted_scores = [x['predictions'] for x in predictions]
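        # The base model computes y = x + 2, so an input of 1.0 should yield a
        # prediction of 3.0; the graph regularization wrapper leaves inference
        # unchanged.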
        self.assertAllClose([[3.0]], predicted_scores)
    def test_graph_reg_wrapper_saving_batch_statistics(self):
        """Verifies that batch statistics in batch-norm layers are saved."""
        def optimizer_fn():
            return tf.compat.v1.train.GradientDescentOptimizer(0.005)

        def embedding_fn(features, mode):
            input_layer = features[FEATURE_NAME]
            with tf.compat.v1.variable_scope(
                    'hidden_layer', reuse=tf.compat.v1.AUTO_REUSE):
                hidden_layer = tf.compat.v1.layers.dense(input_layer,
                                                         units=4,
                                                         activation=tf.nn.relu)
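                # The `training` flag makes batch normalization use batch
                # statistics and update its moving averages during training;
                # at eval/predict time the accumulated moving statistics are
                # used instead.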
                batch_norm_layer = tf.compat.v1.layers.batch_normalization(
                    hidden_layer,
                    training=(mode == tf.estimator.ModeKeys.TRAIN))
            return batch_norm_layer

        def model_fn(features, labels, mode, params=None, config=None):
            del params, config
            embeddings = embedding_fn(features, mode)
            with tf.compat.v1.variable_scope('logit', reuse=tf.AUTO_REUSE):
                logits = tf.compat.v1.layers.dense(embeddings, units=1)
            predictions = tf.argmax(logits, 1)
            if mode == tf.estimator.ModeKeys.PREDICT:
                return tf.estimator.EstimatorSpec(mode=mode,
                                                  predictions={
                                                      'logits': logits,
                                                      'predictions':
                                                      predictions
                                                  })

            loss = tf.compat.v1.losses.sigmoid_cross_entropy(labels, logits)
            if mode == tf.estimator.ModeKeys.EVAL:
                return tf.estimator.EstimatorSpec(mode=mode, loss=loss)

            optimizer = optimizer_fn()
            train_op = optimizer.minimize(
                loss, global_step=tf.compat.v1.train.get_global_step())
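            # Batch normalization's moving-average updates are registered in
            # the UPDATE_OPS collection; grouping them with the train op
            # ensures the batch statistics are updated during training.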
            update_ops = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.UPDATE_OPS)
            train_op = tf.group(train_op, *update_ops)
            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=loss,
                                              train_op=train_op)

        def input_fn():
            nbr_feature = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, FEATURE_NAME)
            nbr_weight = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0,
                                         NBR_WEIGHT_SUFFIX)
            features = {
                FEATURE_NAME: tf.constant([[0.1, 0.9], [0.8, 0.2]]),
                nbr_feature: tf.constant([[0.11, 0.89], [0.81, 0.21]]),
                nbr_weight: tf.constant([[0.9], [0.8]]),
            }
            labels = tf.constant([[1], [0]])
            return tf.data.Dataset.from_tensor_slices(
                (features, labels)).batch(2)

        base_est = tf.estimator.Estimator(model_fn, model_dir=self.model_dir)
        graph_reg_config = nsl_configs.make_graph_reg_config(max_neighbors=1,
                                                             multiplier=1)
        graph_reg_est = nsl_estimator.add_graph_regularization(
            base_est,
            embedding_fn,
            optimizer_fn,
            graph_reg_config=graph_reg_config)
        graph_reg_est.train(input_fn, steps=1)

        moving_mean = graph_reg_est.get_variable_value(
            'hidden_layer/batch_normalization/moving_mean')
        moving_variance = graph_reg_est.get_variable_value(
            'hidden_layer/batch_normalization/moving_variance')
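        # moving_mean starts at zeros and moving_variance at ones, so both
        # should have shifted away from their initial values after a single
        # training step.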
        self.assertNotAllClose(moving_mean, np.zeros(moving_mean.shape))
        self.assertNotAllClose(moving_variance, np.ones(moving_variance.shape))