Beispiel #1
0
 def test_organize_valid_indices(self):
   """Checks that valid entries are moved to the front of each list.

   Both `utils.shuffle_valid_indices` and
   `utils.organize_valid_indices(..., shuffle=False)` are expected to return
   [row, col] index pairs with the valid entries of each row first. In the
   expected values below, the -1.0 labels are the invalid entries. With
   `shuffle=False` the valid entries keep their original column order.
   """
   # Graph-level seed so the shuffle below is reproducible. NOTE(review): the
   # expected shuffled order presumably also depends on the order in which the
   # random ops are created in this graph — do not reorder the op
   # construction lines without regenerating the expected values.
   tf.compat.v1.set_random_seed(1)
   labels = [[1.0, 0.0, -1.0], [-1.0, 1.0, 2.0]]
   is_valid = utils.is_label_valid(labels)
   shuffled_indices = utils.shuffle_valid_indices(is_valid)
   organized_indices = utils.organize_valid_indices(is_valid, shuffle=False)
   # Graph-mode evaluation via a TF1-style session.
   with tf.compat.v1.Session() as sess:
     shuffled_indices = sess.run(shuffled_indices)
     # Row 0: valid columns 0 and 1 in seeded random order, invalid column 2
     # last. Row 1: valid columns 1 and 2 first, invalid column 0 last.
     self.assertAllEqual(shuffled_indices,
                         [[[0, 1], [0, 0], [0, 2]], [[1, 1], [1, 2], [1, 0]]])
     organized_indices = sess.run(organized_indices)
     # Without shuffling, valid columns keep their original relative order.
     self.assertAllEqual(organized_indices,
                         [[[0, 0], [0, 1], [0, 2]], [[1, 1], [1, 2], [1, 0]]])
Beispiel #2
0
def _form_group_indices_nd(is_valid, group_size, shuffle=False, seed=None):
    """Forms the indices for groups for gather_nd or scatter_nd.

    The (batch, list) index pairs of each list are first organized so that
    valid entries come first (optionally shuffled). The window positions
    returned by `_rolling_window_indices` are then mapped through that
    organized ordering to produce per-group index pairs.

    Args:
      is_valid: A boolean `Tensor` for entry validity with shape [batch_size,
        list_size].
      group_size: A scalar int `Tensor` for the number of examples in a group.
      shuffle: A boolean that indicates whether valid indices should be
        shuffled when forming group indices.
      seed: Random seed for shuffle. It is forwarded to
        `utils.organize_valid_indices`; presumably only relevant when
        `shuffle` is True.

    Returns:
      A tuple of Tensors (indices, mask). The first has shape [batch_size,
      num_groups, group_size, 2] and it can be used in gather_nd or scatter_nd
      for group features. The second has the shape of [batch_size, num_groups]
      with value True for valid groups.
    """
    with tf.name_scope(name='form_group_indices'):
        is_valid = tf.convert_to_tensor(value=is_valid)
        batch_size, list_size = tf.unstack(tf.shape(input=is_valid))
        # Number of valid entries in each list, shape [batch_size].
        num_valid_entries = tf.reduce_sum(
            input_tensor=tf.cast(is_valid, dtype=tf.int32), axis=1)
        # rw_indices: [batch_size, num_groups, group_size] positions into the
        # organized list built below. mask: [batch_size, num_groups].
        rw_indices, mask = _rolling_window_indices(list_size, group_size,
                                                   num_valid_entries)
        # [batch_size, list_size, 2] (batch, list) index pairs with the valid
        # entries placed first. They are shuffled only when `shuffle` is True.
        # No seed is fixed here: the caller's `seed` (default None) is
        # forwarded, so tests that need a deterministic shuffle must pass one.
        organized_indices = utils.organize_valid_indices(is_valid,
                                                         shuffle=shuffle,
                                                         seed=seed)
        # Prefix every window position with its batch index, giving gather_nd
        # indices of shape [batch_size, num_groups, group_size, 2].
        group_indices_nd = tf.expand_dims(rw_indices, axis=3)
        batch_ids = (tf.reshape(tf.range(batch_size), [-1, 1, 1, 1]) *
                     tf.ones_like(group_indices_nd))
        group_indices_nd = tf.concat([batch_ids, group_indices_nd], 3)

        # Translate each window position into the actual (batch, list) pair.
        indices = tf.gather_nd(organized_indices, group_indices_nd)
        return indices, mask