def test_organize_valid_indices(self):
  """Checks shuffled vs. deterministic organization of valid indices.

  In both expected outputs the entry labelled -1 in each row is ordered after
  the other entries of that row.
  """
  # A graph-level seed makes the shuffle reproducible. Keep the op
  # construction order below unchanged: when no op-level seed is given, TF
  # derives one deterministically from the graph seed, and the expected
  # shuffled values presumably depend on that derivation.
  tf.compat.v1.set_random_seed(1)
  labels = [[1.0, 0.0, -1.0], [-1.0, 1.0, 2.0]]
  is_valid = utils.is_label_valid(labels)
  shuffled_op = utils.shuffle_valid_indices(is_valid)
  organized_op = utils.organize_valid_indices(is_valid, shuffle=False)

  # Plain Python lists: building these creates no graph ops.
  expected_shuffled = [[[0, 1], [0, 0], [0, 2]],
                       [[1, 1], [1, 2], [1, 0]]]
  expected_organized = [[[0, 0], [0, 1], [0, 2]],
                        [[1, 1], [1, 2], [1, 0]]]

  with tf.compat.v1.Session() as sess:
    self.assertAllEqual(sess.run(shuffled_op), expected_shuffled)
    self.assertAllEqual(sess.run(organized_op), expected_organized)
def _form_group_indices_nd(is_valid, group_size, shuffle=False, seed=None):
  """Builds gather_nd / scatter_nd indices that address groups of list entries.

  Args:
    is_valid: A boolean `Tensor` of shape [batch_size, list_size] marking which
      entries are valid.
    group_size: A scalar int `Tensor` giving the number of entries per group.
    shuffle: A boolean; when True, valid indices are shuffled before groups
      are formed.
    seed: Random seed forwarded to the shuffle.

  Returns:
    A tuple `(indices, mask)`. `indices` has shape
    [batch_size, num_groups, group_size, 2] and can be passed to gather_nd or
    scatter_nd to address group features. `mask` has shape
    [batch_size, num_groups] and is True for valid groups.
  """
  with tf.name_scope(name='form_group_indices'):
    is_valid = tf.convert_to_tensor(value=is_valid)
    batch_size, list_size = tf.unstack(tf.shape(input=is_valid))
    valid_counts = tf.reduce_sum(
        input_tensor=tf.cast(is_valid, dtype=tf.int32), axis=1)
    window_indices, mask = _rolling_window_indices(
        list_size, group_size, valid_counts)

    # [batch_size, list_size, 2] coordinates with valid entries placed first
    # (shuffled among themselves only when `shuffle` is True). This function
    # does not pick a seed itself; `seed` is passed through so callers, e.g.
    # unit tests, can make the shuffle deterministic.
    organized = utils.organize_valid_indices(
        is_valid, shuffle=shuffle, seed=seed)

    # Pair each window position with its batch row so gather_nd can look it
    # up in `organized`.
    # [batch_size, num_groups, group_size, 1]
    list_positions = tf.expand_dims(window_indices, axis=3)
    batch_ids = tf.reshape(tf.range(batch_size), [-1, 1, 1, 1])
    batch_positions = batch_ids * tf.ones_like(list_positions)
    # [batch_size, num_groups, group_size, 2]
    gather_coords = tf.concat([batch_positions, list_positions], 3)
    return tf.gather_nd(organized, gather_coords), mask