def test_update_scatter_gather_indices(self, mode):
    """Check gather indices and mask for group_size=2 with two shuffles.

    One list of three entries (the last one invalid) is fed to
    `_update_scatter_gather_indices` with every `num_shuffles_*` set to 2.
    Both the static shapes and the evaluated values are pinned.
    """
    shuffle_params = {
        'num_shuffles_train': 2,
        'num_shuffles_eval': 2,
        'num_shuffles_predict': 2,
    }
    expected_indices_shape = [1, 6, 2, 2]
    expected_mask_shape = [1, 6]
    expected_indices = [[
        # shuffle 1.
        [[0, 0], [0, 1]],
        [[0, 1], [0, 0]],
        [[0, 0], [0, 1]],
        # shuffle 2.
        [[0, 1], [0, 0]],
        [[0, 0], [0, 1]],
        [[0, 1], [0, 0]],
    ]]
    expected_mask = [[True, True, False, True, True, False]]

    with tf.Graph().as_default():
        # Seed first. NOTE(review): the pinned values below presumably depend
        # on this seed and on op-construction order -- keep both unchanged.
        tf.compat.v1.set_random_seed(2)
        with tf.compat.v1.Session() as sess:
            ranking_model = model._GroupwiseRankingModel(None, group_size=2)
            is_valid = tf.convert_to_tensor([[True, True, False]])
            ranking_model._update_scatter_gather_indices(
                is_valid, mode, shuffle_params)

            # Static shapes are available before anything is evaluated.
            gather_op = ranking_model._feature_gather_indices
            self.assertEqual(
                gather_op.get_shape().as_list(), expected_indices_shape)
            mask_op = ranking_model._indices_mask
            self.assertEqual(
                mask_op.get_shape().as_list(), expected_mask_shape)

            fetches = [
                ranking_model._feature_gather_indices,
                ranking_model._indices_mask,
            ]
            feature_gather_indices, indices_mask = sess.run(fetches)
            self.assertAllEqual(feature_gather_indices, expected_indices)
            self.assertAllEqual(indices_mask, expected_mask)
def test_compute_logits(self, mode):
    """Check `compute_logits` for group_size=2, with and without params.

    Three scenarios are run against the same model instance, in order:
    no params, shuffle params with one invalid entry, and shuffle params
    with all entries valid.
    """
    group_size = 2
    shuffle_params = {
        'num_shuffles_train': 2,
        'num_shuffles_eval': 2,
        'num_shuffles_predict': 2,
    }

    def _dummy_score_fn(context_features, group_features, mode, params,
                        config):
        del mode, params, config  # Unused by this dummy scorer.
        # 'context': [batch_size * num_groups, 1]
        # 'example_f1': [batch_size * num_groups, group_size, 1]
        expanded_context = tf.expand_dims(
            context_features['context'], axis=1)
        scores = expanded_context + group_features['example_f1']
        scores = tf.reshape(scores, [-1, group_size])
        # Add the shape of the logits to differentiate number of shuffles.
        num_rows = tf.cast(tf.shape(scores)[0], tf.float32)
        return scores + num_rows

    with tf.Graph().as_default():
        tf.compat.v1.set_random_seed(1)
        with tf.compat.v1.Session() as sess:
            ranking_model = model._GroupwiseRankingModel(
                _dummy_score_fn,
                group_size=group_size,
                transform_fn=feature.make_identity_transform_fn(['context']),
            )

            def _check(features, labels, params, indices_shape, want_logits):
                # Build the op, evaluate it, then verify shape and values --
                # in that order, once per scenario.
                logits_op = ranking_model.compute_logits(
                    features, labels, mode, params, None)
                logits = sess.run(logits_op)
                gather_op = ranking_model._feature_gather_indices
                self.assertEqual(
                    gather_op.get_shape().as_list(), indices_shape)
                self.assertAllEqual(logits, want_logits)

            # batch_size = 1, list_size = 3, is_valid = [True, True, False]
            features = {
                'context': [[1.]],
                'example_f1': [[[1.], [2.], [3.]]],
            }
            labels = [[1., 0, -1]]
            # No params.
            _check(features, labels, None, [1, 3, 2, 2], [[5., 6., 0.]])
            # Trigger params.
            _check(features, labels, shuffle_params, [1, 6, 2, 2],
                   [[8., 9., 0.]])

            # batch_size = 1, list_size = 3, is_valid = [True, True, True]
            features = {
                'context': [[1.]],
                'example_f1': [[[1.], [2.], [0.]]],
            }
            labels = [[1., 0, 1]]
            _check(features, labels, shuffle_params, [1, 6, 2, 2],
                   [[8., 9., 7.]])
def test_update_scatter_gather_indices_groupsize_1(self):
    """Check gather indices and mask for group_size=1 in TRAIN mode.

    Input is one list of three entries with the last one invalid and no
    params. The expected indices point the invalid slot at entry 0.
    """
    expected_indices_shape = [1, 3, 1, 2]
    expected_mask_shape = [1, 3]
    expected_indices = [[[[0, 0]], [[0, 1]], [[0, 0]]]]
    expected_mask = [[True, True, False]]

    with tf.Graph().as_default():
        tf.compat.v1.set_random_seed(1)
        with tf.compat.v1.Session() as sess:
            ranking_model = model._GroupwiseRankingModel(None, group_size=1)
            is_valid = tf.convert_to_tensor([[True, True, False]])
            ranking_model._update_scatter_gather_indices(
                is_valid, tf.estimator.ModeKeys.TRAIN, None)

            # Static shapes are checked before evaluating the tensors.
            gather_op = ranking_model._feature_gather_indices
            self.assertEqual(
                gather_op.get_shape().as_list(), expected_indices_shape)
            mask_op = ranking_model._indices_mask
            self.assertEqual(
                mask_op.get_shape().as_list(), expected_mask_shape)

            fetches = [
                ranking_model._feature_gather_indices,
                ranking_model._indices_mask,
            ]
            feature_gather_indices, indices_mask = sess.run(fetches)
            self.assertAllEqual(feature_gather_indices, expected_indices)
            self.assertAllEqual(indices_mask, expected_mask)
def test_update_scatter_gather_indices_predict_no_shuffle(self):
    """Check gather indices and mask for group_size=2 in PREDICT mode.

    Input is one list of three valid entries and no params. The expected
    indices pair each entry with its successor, wrapping around at the end.
    """
    expected_indices_shape = [1, 3, 2, 2]
    expected_mask_shape = [1, 3]
    expected_indices = [[
        [[0, 0], [0, 1]],
        [[0, 1], [0, 2]],
        [[0, 2], [0, 0]],
    ]]
    expected_mask = [[True, True, True]]

    with tf.Graph().as_default():
        tf.compat.v1.set_random_seed(1)
        with tf.compat.v1.Session() as sess:
            ranking_model = model._GroupwiseRankingModel(None, group_size=2)
            is_valid = tf.convert_to_tensor(value=[[True, True, True]])
            ranking_model._update_scatter_gather_indices(
                is_valid, tf.estimator.ModeKeys.PREDICT, None)

            # Static shapes are checked before evaluating the tensors.
            gather_op = ranking_model._feature_gather_indices
            self.assertEqual(
                gather_op.get_shape().as_list(), expected_indices_shape)
            mask_op = ranking_model._indices_mask
            self.assertEqual(
                mask_op.get_shape().as_list(), expected_mask_shape)

            fetches = [
                ranking_model._feature_gather_indices,
                ranking_model._indices_mask,
            ]
            feature_gather_indices, indices_mask = sess.run(fetches)
            self.assertAllEqual(feature_gather_indices, expected_indices)
            self.assertAllEqual(indices_mask, expected_mask)