Example #1
0
  def test_update_scatter_gather_indices(self, mode):
    """Checks gather indices and mask for group_size=2 with two shuffles.

    The single batch row has validity [True, True, False] and two shuffles are
    configured for every mode. Six groups are expected (three per shuffle),
    and the third group of each shuffle is masked out.
    """
    shuffle_params = {
        'num_shuffles_train': 2,
        'num_shuffles_eval': 2,
        'num_shuffles_predict': 2,
    }
    # Groups produced by one shuffle; the expected output repeats it twice.
    one_shuffle = [
        [[0, 0], [0, 1]],
        [[0, 1], [0, 0]],
        [[0, 0], [0, 1]],
    ]
    expected_indices = [one_shuffle * 2]
    expected_mask = [[True, True, False] * 2]

    with tf.Graph().as_default():
      tf.compat.v1.set_random_seed(2)
      with tf.compat.v1.Session() as sess:
        ranking_model = model._GroupwiseRankingModel(None, group_size=2)
        # Keep graph-op creation order unchanged: with only a graph-level seed
        # set, TF picks op seeds deterministically from graph state.
        is_valid = tf.convert_to_tensor([[True, True, False]])
        ranking_model._update_scatter_gather_indices(is_valid, mode,
                                                     shuffle_params)

        gather_indices_t = ranking_model._feature_gather_indices
        mask_t = ranking_model._indices_mask
        self.assertEqual(gather_indices_t.get_shape().as_list(), [1, 6, 2, 2])
        self.assertEqual(mask_t.get_shape().as_list(), [1, 6])

        actual_indices, actual_mask = sess.run([gather_indices_t, mask_t])
        self.assertAllEqual(actual_indices, expected_indices)
        self.assertAllEqual(actual_mask, expected_mask)
Example #2
0
  def test_compute_logits(self, mode):
    """Checks compute_logits with and without shuffle params.

    The dummy scoring function adds the row count of its reshaped logits
    (batch_size * num_groups) to every logit. The expected values therefore
    show how many groups were scored: 3 without params, 6 with two shuffles.
    """
    group_size = 2
    params = {
        'num_shuffles_train': 2,
        'num_shuffles_eval': 2,
        'num_shuffles_predict': 2,
    }

    def _dummy_score_fn(context_features, group_features, mode, params, config):
      # Only the features are used; the remaining arguments are ignored.
      del mode, params, config
      # 'context': [batch_size * num_groups, 1]
      # 'example_f1': [batch_size * num_groups, group_size, 1]
      per_item = (tf.expand_dims(context_features['context'], axis=1) +
                  group_features['example_f1'])
      per_item = tf.reshape(per_item, [-1, group_size])
      num_rows = tf.cast(tf.shape(per_item)[0], tf.float32)
      # Offset by the row count so the result reveals the number of shuffles.
      return per_item + num_rows

    with tf.Graph().as_default():
      tf.compat.v1.set_random_seed(1)
      with tf.compat.v1.Session() as sess:
        ranking_model = model._GroupwiseRankingModel(
            _dummy_score_fn,
            group_size=group_size,
            transform_fn=feature.make_identity_transform_fn(['context']),
        )

        def _check(features, labels, run_params, expected_shape,
                   expected_logits):
          """Builds and runs compute_logits, then checks shape and values."""
          logits = sess.run(
              ranking_model.compute_logits(features, labels, mode, run_params,
                                           None))
          self.assertEqual(
              ranking_model._feature_gather_indices.get_shape().as_list(),
              expected_shape)
          self.assertAllEqual(logits, expected_logits)

        # batch_size = 1, list_size = 3, is_valid = [True, True, False]
        partial_features = {
            'context': [[1.]],
            'example_f1': [[[1.], [2.], [3.]]],
        }
        partial_labels = [[1., 0, -1]]
        # No params: 3 groups.
        _check(partial_features, partial_labels, None, [1, 3, 2, 2],
               [[5., 6., 0.]])
        # Params trigger two shuffles: 6 groups.
        _check(partial_features, partial_labels, params, [1, 6, 2, 2],
               [[8., 9., 0.]])

        # batch_size = 1, list_size = 3, is_valid = [True, True, True]
        full_features = {
            'context': [[1.]],
            'example_f1': [[[1.], [2.], [0.]]],
        }
        full_labels = [[1., 0, 1]]
        _check(full_features, full_labels, params, [1, 6, 2, 2],
               [[8., 9., 7.]])
Example #3
0
 def test_update_scatter_gather_indices_groupsize_1(self):
   """Checks gather indices and mask when each group holds a single item.

   With group_size=1, TRAIN mode and no params, three groups are expected for
   the validity row [True, True, False]. The group for the invalid third
   entry is masked out, and its gather index points at position 0.
   """
   with tf.Graph().as_default():
     tf.compat.v1.set_random_seed(1)
     with tf.compat.v1.Session() as sess:
       ranking_model = model._GroupwiseRankingModel(None, group_size=1)
       is_valid = tf.convert_to_tensor([[True, True, False]])
       ranking_model._update_scatter_gather_indices(
           is_valid, tf.estimator.ModeKeys.TRAIN, None)

       index_tensor = ranking_model._feature_gather_indices
       mask_tensor = ranking_model._indices_mask
       self.assertEqual(index_tensor.get_shape().as_list(), [1, 3, 1, 2])
       self.assertEqual(mask_tensor.get_shape().as_list(), [1, 3])

       index_value, mask_value = sess.run([index_tensor, mask_tensor])
       self.assertAllEqual(index_value, [[[[0, 0]], [[0, 1]], [[0, 0]]]])
       self.assertAllEqual(mask_value, [[True, True, False]])
Example #4
0
 def test_update_scatter_gather_indices_predict_no_shuffle(self):
     """Checks PREDICT mode with group_size=2 and no params.

     All three entries are valid. Three groups are expected, each pairing an
     entry with the next one (the last wraps around to the first), and none
     of them is masked.
     """
     expected_indices = [[
         [[0, 0], [0, 1]],
         [[0, 1], [0, 2]],
         [[0, 2], [0, 0]],
     ]]
     with tf.Graph().as_default():
         tf.compat.v1.set_random_seed(1)
         with tf.compat.v1.Session() as sess:
             ranking_model = model._GroupwiseRankingModel(None, group_size=2)
             is_valid = tf.convert_to_tensor(value=[[True, True, True]])
             ranking_model._update_scatter_gather_indices(
                 is_valid, tf.estimator.ModeKeys.PREDICT, None)

             gather_t = ranking_model._feature_gather_indices
             mask_t = ranking_model._indices_mask
             self.assertEqual(gather_t.get_shape().as_list(), [1, 3, 2, 2])
             self.assertEqual(mask_t.get_shape().as_list(), [1, 3])

             gathered, mask = sess.run([gather_t, mask_t])
             self.assertAllEqual(gathered, expected_indices)
             self.assertAllEqual(mask, [[True, True, True]])