Example #1
    def build_graph(parameters):
        """Build the graph for the test case."""

        input_tensor = tf.compat.v1.placeholder(
            dtype=tf.int32, name="input", shape=parameters["input_shape"])
        if parameters["index_type"] is None:
            output = tf.unique(input_tensor)
        else:
            output = tf.unique(input_tensor, parameters["index_type"])

        return [input_tensor], output
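A minimal usage sketch (TF1-style graph mode; the parameters dict below is hypothetical but follows the keys the function reads):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
params = {"input_shape": [None], "index_type": tf.int64}
inputs, (values, indices) = build_graph(params)
with tf.compat.v1.Session() as sess:
    y, idx = sess.run([values, indices],
                      feed_dict={inputs[0]: [1, 1, 2, 4, 4, 7]})
    # y -> [1 2 4 7], idx -> [0 0 1 2 2 3]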
Example #2
 def _apply_sparse(self, cache):
   """"""
   
   x_tm1, g_t, idxs = cache['x_tm1'], cache['g_t'], cache['idxs']
   idxs, idxs_ = tf.unique(idxs)
   g_t_ = tf.unsorted_segment_sum(g_t, idxs_, tf.size(idxs))
   updates = cache['updates']
   
   if self.mu > 0:
     m_t, t_m = self._sparse_moving_average(x_tm1, idxs, g_t_, 'm', beta=self.mu)
     m_t_ = tf.gather(m_t, idxs)
     m_bar_t_ = (1-self.gamma) * m_t_ + self.gamma * g_t_
     updates.extend([m_t, t_m])
   else:
     m_bar_t_ = g_t_
   
   if self.nu > 0:
     v_t, t_v = self._sparse_moving_average(x_tm1, idxs, g_t_**2, 'v', beta=self.nu)
     v_t_ = tf.gather(v_t, idxs)
     v_bar_t_ = tf.sqrt(v_t_ + self.epsilon)
     updates.extend([v_t, t_v])
   else:
     v_bar_t_ = 1
   
   s_t_ = self.learning_rate * m_bar_t_ / v_bar_t_
   cache['s_t'] = s_t_
   cache['g_t'] = g_t_
   cache['idxs'] = idxs
   return cache
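The first two lines are the standard dedup recipe for sparse updates: tf.unique yields the distinct indices plus a segment id for every original entry, and unsorted_segment_sum folds gradients with the same index together. A standalone sketch of just that step, with made-up values:

import tensorflow as tf

idxs = tf.constant([3, 1, 3])             # row 3 appears twice
g_t = tf.constant([[1.0], [2.0], [5.0]])  # one gradient row per entry
uniq, segments = tf.unique(idxs)          # uniq=[3, 1], segments=[0, 1, 0]
g_sum = tf.math.unsorted_segment_sum(g_t, segments, tf.size(uniq))
# g_sum -> [[6.], [2.]]: the two gradients for row 3 are summed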
Example #3
    def way(self):
        """Compute the way of the episode.

    Returns:
      way: An int constant tensor. The number of classes in the episode.
    """
        episode_classes, _ = tf.unique(self.labels)
        return tf.size(episode_classes)
Example #4
 def preproc_segment_ids(self, segment_ids):
     assert len(segment_ids.shape.as_list()) == 5
     B,T,H,W,C = segment_ids.shape.as_list()
     assert B <= 128 and segment_ids.dtype == tf.uint8, "max hash value must be < 256**4 / 2 = 2^31"
     # add batch values to make the hashing unique
     b_inds = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), [B,1,1,1,1]), [1,T,H,W,C])
     segment_ids = tf.concat([b_inds, tf.cast(segment_ids, tf.int32)], axis=-1)
     segment_ids = object_id_hash(segment_ids, dtype_out=tf.int32, val=256)
     _, segment_ids = tf.unique(tf.reshape(segment_ids, [-1]))
     segment_ids = tf.reshape(segment_ids, [B,T,H,W])
     segment_ids = segment_ids - tf.reduce_min(segment_ids, axis=[1,2,3], keepdims=True)
     return segment_ids
Example #5
 def _apply_sparse(self, cache):
   """"""
   
   g_t, idxs = cache['g_t'], cache['idxs']
   idxs, idxs_ = tf.unique(idxs)
   g_t_ = tf.unsorted_segment_sum(g_t, idxs_, tf.size(idxs))
   
   cache['g_t'] = g_t_
   cache['idxs'] = idxs
   cache['s_t'] = self.learning_rate * g_t_
   
   return cache
Example #6
def compute_unique_class_ids(class_ids):
    """Computes the unique class IDs of the episode containing `class_ids`.

  Args:
    class_ids: A 1D tensor representing class IDs, one per example in an
      episode.

  Returns:
    A 1D tensor of the unique class IDs whose size is equal to the way of an
    episode.
  """
    return tf.unique(class_ids)[0]
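Applied to concrete labels (a sketch, eager mode assumed), this yields the same quantity that Example #3 wraps in tf.size:

import tensorflow as tf

class_ids = tf.constant([4, 1, 4, 2, 1])
unique_ids = compute_unique_class_ids(class_ids)  # -> [4, 1, 2]
way = tf.size(unique_ids)                         # -> 3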
Example #7
def penalize_used(logits, output):

    # Penalize the logits at every vocabulary index that appears in output.
    change_tensor = tf.zeros_like(logits, dtype=logits.dtype)
    unique = tf.unique(output[0])[0]
    ones = tf.ones_like(unique, dtype=unique.dtype)
    indices = tf.expand_dims(unique, 1)

    updates = tf.scatter_nd(indices, ones, [logits.shape[1]])

    bool_tensor = tf.expand_dims(tf.cast(updates, tf.bool), 0)

    return tf.compat.v1.where(bool_tensor, logits * 0.85, logits)
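A toy invocation (hypothetical shapes: batch of 1, vocabulary of 6; note that logits.shape[1] must be statically known because it is used as the tf.scatter_nd shape):

import tensorflow as tf

logits = tf.ones([1, 6])
output = tf.constant([[0, 2, 2, 5]])  # token ids generated so far
penalized = penalize_used(logits, output)
# columns 0, 2 and 5 drop to 0.85; the repeated 2 is collapsed by tf.unique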
Example #8
def accumulate_sparse_gradients(grad):
    """Accumulates repeated indices of a sparse gradient update.

  Args:
    grad: a tf.IndexedSlices gradient

  Returns:
    grad_indices: unique indices
    grad_values: gradient values corresponding to the indices
  """

    grad_indices, grad_segments = tf.unique(grad.indices)
    grad_values = tf.unsorted_segment_sum(grad.values, grad_segments,
                                          tf.size(grad_indices))
    return grad_indices, grad_values
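A sketch of the function on a hand-built tf.IndexedSlices (the values are hypothetical):

import tensorflow as tf

grad = tf.IndexedSlices(
    values=tf.constant([[10.0], [20.0], [30.0]]),
    indices=tf.constant([0, 2, 0]),    # row 0 is updated twice
    dense_shape=tf.constant([3, 1]))
grad_indices, grad_values = accumulate_sparse_gradients(grad)
# grad_indices -> [0, 2]; grad_values -> [[40.], [20.]]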
Example #9
def triangles_to_edges(faces):
  """Computes mesh edges from triangles."""
  # collect edges from triangles
  edges = tf.concat([faces[:, 0:2],
                     faces[:, 1:3],
                     tf.stack([faces[:, 2], faces[:, 0]], axis=1)], axis=0)
  # those edges are sometimes duplicated (within the mesh) and sometimes
  # single (at the mesh boundary).
  # sort & pack edges as single tf.int64
  receivers = tf.reduce_min(edges, axis=1)
  senders = tf.reduce_max(edges, axis=1)
  packed_edges = tf.bitcast(tf.stack([senders, receivers], axis=1), tf.int64)
  # remove duplicates and unpack
  unique_edges = tf.bitcast(tf.unique(packed_edges)[0], tf.int32)
  senders, receivers = tf.unstack(unique_edges, axis=1)
  # create two-way connectivity
  return (tf.concat([senders, receivers], axis=0),
          tf.concat([receivers, senders], axis=0))
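The tf.bitcast trick packs each (sender, receiver) pair of int32 ids into a single int64 key, so tf.unique can deduplicate edges in one pass. A quick check on two triangles sharing an edge (a sketch; faces must be int32 for the bitcast to be valid):

import tensorflow as tf

faces = tf.constant([[0, 1, 2],
                     [2, 1, 3]], dtype=tf.int32)
senders, receivers = triangles_to_edges(faces)
# 5 undirected edges -> 10 directed ones; the shared edge (1, 2)
# survives deduplication exactly once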
Example #10
def _dedup_tensor(sp_tensor: tf.SparseTensor) -> tf.SparseTensor:
  """Dedup values of a SparseTensor along each row.

  Args:
    sp_tensor: A 2D SparseTensor to be deduped.
  Returns:
    A deduped SparseTensor of shape [batch_size, max_len], where max_len is
    the maximum number of unique values for a row in the Tensor.
  """
  string_batch_index = tf.as_string(sp_tensor.indices[:, 0])

  # tf.unique only works on 1D tensors. To avoid deduping across examples,
  # prepend each feature value with the example index. This requires casting
  # to and from strings for non-string features.
  string_values = sp_tensor.values
  original_dtype = sp_tensor.values.dtype
  if original_dtype != tf.string:
    string_values = tf.as_string(sp_tensor.values)
  index_and_value = tf.strings.join([string_batch_index, string_values],
                                    separator='|')
  unique_index_and_value, _ = tf.unique(index_and_value)

  # split is a shape [tf.size(values), 2] tensor. The first column contains
  # indices and the second column contains the feature value (we assume no
  # feature contains | so we get exactly 2 values from the string split).
  split = tf.string_split(unique_index_and_value, delimiter='|')
  split = tf.reshape(split.values, [-1, 2])
  string_indices = split[:, 0]
  values = split[:, 1]

  indices = tf.reshape(
      tf.string_to_number(string_indices, out_type=tf.int32), [-1])
  if original_dtype != tf.string:
    values = tf.string_to_number(values, out_type=original_dtype)
  values = tf.reshape(values, [-1])
  # Convert example indices into SparseTensor indices, e.g.
  # [0, 0, 0, 1, 3, 3] -> [[0,0], [0,1], [0,2], [1,0], [3,0], [3,1]]
  batch_size = tf.to_int32(sp_tensor.dense_shape[0])
  new_indices, max_len = _example_index_to_sparse_index(indices, batch_size)
  return tf.SparseTensor(
      indices=tf.to_int64(new_indices),
      values=values,
      dense_shape=[tf.to_int64(batch_size), max_len])
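The row-prefix trick can be exercised in isolation (a sketch covering only the steps up to tf.unique, since _example_index_to_sparse_index is not shown here):

import tensorflow as tf

sp = tf.sparse.from_dense(tf.constant([[1, 1, 2], [2, 2, 2]]))
row = tf.as_string(sp.indices[:, 0])
val = tf.as_string(sp.values)
joined = tf.strings.join([row, val], separator='|')
uniq, _ = tf.unique(joined)
# uniq -> [b'0|1', b'0|2', b'1|2']: duplicates collapse per row, not globally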
Example #11
def block_delete_msa(protein, config):
    """Sample MSA by deleting contiguous blocks.

  Jumper et al. (2021) Suppl. Alg. 1 "MSABlockDeletion"

  Arguments:
    protein: batch dict containing the msa
    config: ConfigDict with parameters

  Returns:
    updated protein
  """
    num_seq = shape_helpers.shape_list(protein['msa'])[0]
    block_num_seq = tf.cast(
        tf.floor(tf.cast(num_seq, tf.float32) * config.msa_fraction_per_block),
        tf.int32)

    if config.randomize_num_blocks:
        nb = tf.random.uniform([], 0, config.num_blocks + 1, dtype=tf.int32)
    else:
        nb = config.num_blocks

    del_block_starts = tf.random.uniform([nb], 0, num_seq, dtype=tf.int32)
    del_blocks = del_block_starts[:, None] + tf.range(block_num_seq)
    del_blocks = tf.clip_by_value(del_blocks, 0, num_seq - 1)
    del_indices = tf.unique(tf.sort(tf.reshape(del_blocks, [-1])))[0]

    # Make sure we keep the original sequence
    sparse_diff = tf.sets.difference(
        tf.range(1, num_seq)[None], del_indices[None])
    keep_indices = tf.squeeze(tf.sparse.to_dense(sparse_diff), 0)
    keep_indices = tf.concat([[0], keep_indices], axis=0)

    for k in _MSA_FEATURE_NAMES:
        if k in protein:
            protein[k] = tf.gather(protein[k], keep_indices)

    return protein
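The keep-index computation can be checked on its own (a sketch with hypothetical values; index 0 is re-added so the original query sequence is never deleted):

import tensorflow as tf

num_seq = 6
del_indices = tf.constant([[1, 4]], dtype=tf.int32)  # rows sampled for deletion
sparse_diff = tf.sets.difference(tf.range(1, num_seq)[None], del_indices)
keep_indices = tf.squeeze(tf.sparse.to_dense(sparse_diff), 0)
keep_indices = tf.concat([[0], keep_indices], axis=0)
# keep_indices -> [0, 2, 3, 5]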
Example #12
def unique(inputs):
    """Return the unique elements of the input tensor."""
    return tf.unique(inputs)
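For reference, the documented contract of tf.unique itself on a concrete vector (first-occurrence order is preserved):

import tensorflow as tf

x = tf.constant([1, 1, 2, 4, 4, 4, 7, 8])
y, idx = tf.unique(x)
# y   -> [1, 2, 4, 7, 8]
# idx -> [0, 0, 1, 2, 2, 2, 3, 4], so tf.gather(y, idx) reconstructs x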
Example #13
 def unique_labels(self):
     return tf.unique(
         tf.concat((self.support_labels, self.query_labels), -1))[0]
Example #14
def merge_boxes_with_multiple_labels(boxes,
                                     classes,
                                     confidences,
                                     num_classes,
                                     quantization_bins=10000):
  """Merges boxes with same coordinates and returns K-hot encoded classes.

  Args:
    boxes: A tf.float32 tensor with shape [N, 4] holding N boxes. Only
      normalized coordinates are allowed.
    classes: A tf.int32 tensor with shape [N] holding class indices.
      The class index starts at 0.
    confidences: A tf.float32 tensor with shape [N] holding class confidences.
    num_classes: total number of classes to use for K-hot encoding.
    quantization_bins: the number of bins used to quantize the box coordinate.

  Returns:
    merged_boxes: A tf.float32 tensor with shape [N', 4] holding boxes,
      where N' <= N.
    class_encodings: A tf.int32 tensor with shape [N', num_classes] holding
      K-hot encodings for the merged boxes.
    confidence_encodings: A tf.float32 tensor with shape [N', num_classes]
      holding encodings of confidences for the merged boxes.
    merged_box_indices: A tf.int32 tensor with shape [N'] holding original
      indices of the boxes.
  """
  boxes_shape = tf.shape(boxes)
  classes_shape = tf.shape(classes)
  confidences_shape = tf.shape(confidences)
  box_class_shape_assert = shape_utils.assert_shape_equal_along_first_dimension(
      boxes_shape, classes_shape)
  box_confidence_shape_assert = (
      shape_utils.assert_shape_equal_along_first_dimension(
          boxes_shape, confidences_shape))
  box_dimension_assert = tf.assert_equal(boxes_shape[1], 4)
  box_normalized_assert = shape_utils.assert_box_normalized(boxes)

  with tf.control_dependencies(
      [box_class_shape_assert, box_confidence_shape_assert,
       box_dimension_assert, box_normalized_assert]):
    quantized_boxes = tf.to_int64(boxes * (quantization_bins - 1))
    ymin, xmin, ymax, xmax = tf.unstack(quantized_boxes, axis=1)
    hashcodes = (
        ymin +
        xmin * quantization_bins +
        ymax * quantization_bins * quantization_bins +
        xmax * quantization_bins * quantization_bins * quantization_bins)
    unique_hashcodes, unique_indices = tf.unique(hashcodes)
    num_boxes = tf.shape(boxes)[0]
    num_unique_boxes = tf.shape(unique_hashcodes)[0]
    merged_box_indices = tf.unsorted_segment_min(
        tf.range(num_boxes), unique_indices, num_unique_boxes)
    merged_boxes = tf.gather(boxes, merged_box_indices)
    unique_indices = tf.to_int64(unique_indices)
    classes = tf.to_int64(classes)

    def map_box_encodings(i):
      """Produces box K-hot and score encodings for each class index."""
      box_mask = tf.equal(
          unique_indices, i * tf.ones(num_boxes, dtype=tf.int64))
      box_mask = tf.reshape(box_mask, [-1])
      box_indices = tf.boolean_mask(classes, box_mask)
      box_confidences = tf.boolean_mask(confidences, box_mask)
      box_class_encodings = tf.sparse_to_dense(
          box_indices, [num_classes], tf.constant(1, dtype=tf.int64),
          validate_indices=False)
      box_confidence_encodings = tf.sparse_to_dense(
          box_indices, [num_classes], box_confidences, validate_indices=False)
      return box_class_encodings, box_confidence_encodings

    # Important to avoid int32 here since there is no GPU kernel for int32.
    # int64 and float32 are fine.
    class_encodings, confidence_encodings = tf.map_fn(
        map_box_encodings,
        tf.range(tf.to_int64(num_unique_boxes)),
        back_prop=False,
        dtype=(tf.int64, tf.float32))

    merged_boxes = tf.reshape(merged_boxes, [-1, 4])
    class_encodings = tf.cast(class_encodings, dtype=tf.int32)
    class_encodings = tf.reshape(class_encodings, [-1, num_classes])
    confidence_encodings = tf.reshape(confidence_encodings, [-1, num_classes])
    merged_box_indices = tf.reshape(merged_box_indices, [-1])
    return (merged_boxes, class_encodings, confidence_encodings,
            merged_box_indices)
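Why the hashcode is collision-free: every quantized coordinate lies in [0, quantization_bins), so the sum above is a base-quantization_bins positional encoding that is unique per (ymin, xmin, ymax, xmax) tuple. A quick arithmetic check (plain Python, hypothetical coordinates):

bins = 10000
ymin, xmin, ymax, xmax = 12, 34, 56, 78
code = ymin + xmin * bins + ymax * bins**2 + xmax * bins**3
assert code < bins**4 <= 2**63 - 1  # unique per box, and it fits in int64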
Example #15
    def _build_loss(self):
        """Builds the loss tensor, to be minimized by the optimizer."""
        self.reader = reader.DataReader(
            self.data_dir,
            self.batch_size,
            self.img_height,
            self.img_width,
            SEQ_LENGTH,
            1,  # num_scales
            self.file_extension,
            self.random_scale_crop,
            reader.FLIP_RANDOM,
            self.random_color,
            self.imagenet_norm,
            self.shuffle,
            self.input_file,
            queue_size=self.queue_size)

        (self.image_stack, self.image_stack_norm, self.seg_stack,
         self.intrinsic_mat, _) = self.reader.read_data()
        if self.learn_intrinsics:
            self.intrinsic_mat = None
        if self.intrinsic_mat is None and not self.learn_intrinsics:
            raise RuntimeError(
                'Could not read intrinsic matrix. Turn '
                'learn_intrinsics on to learn it instead of loading '
                'it.')
        self.export('self.image_stack', self.image_stack)

        object_masks = []
        for i in range(self.batch_size):
            object_ids = tf.unique(tf.reshape(self.seg_stack[i], [-1]))[0]
            object_masks_i = []
            for j in range(SEQ_LENGTH):
                current_seg = self.seg_stack[i, :, :, j * 3]  # (H, W)

                def process_obj_mask(obj_id):
                    """Create a mask for obj_id, skipping the background mask."""
                    mask = tf.logical_and(
                        tf.equal(current_seg, obj_id),  # pylint: disable=cell-var-from-loop
                        tf.not_equal(tf.cast(0, tf.uint8), obj_id))
                    # Leave out very small masks, which are most often errors.
                    size = tf.reduce_sum(tf.to_int32(mask))
                    mask = tf.logical_and(mask,
                                          tf.greater(size, MIN_OBJECT_AREA))
                    if not self.boxify:
                        return mask
                    # Complete the mask to its bounding box.
                    binary_obj_masks_y = tf.reduce_any(mask,
                                                       axis=1,
                                                       keepdims=True)
                    binary_obj_masks_x = tf.reduce_any(mask,
                                                       axis=0,
                                                       keepdims=True)
                    return tf.logical_and(binary_obj_masks_y,
                                          binary_obj_masks_x)

                object_mask = tf.map_fn(  # (N, H, W)
                    process_obj_mask, object_ids, dtype=tf.bool)
                object_mask = tf.reduce_any(object_mask, axis=0)
                object_masks_i.append(object_mask)
            object_masks.append(tf.stack(object_masks_i, axis=-1))

        self.seg_stack = tf.cast(tf.stack(object_masks, axis=0), tf.float32)
        tf.summary.image('Masks', self.seg_stack)

        with tf.variable_scope(DEPTH_SCOPE):
            # Organized by ...[i][scale].  Note that the order is flipped in
            # variables in build_loss() below.
            self.disp = {}
            self.depth = {}

            # Parabolic ramp-up of the noise over LAYER_NORM_NOISE_RAMPUP_STEPS steps.
            # We stop at 0.5 because this is the value above which the multiplicative
            # noise we use can become negative. Further experimentation is needed to
            # find if non-negativity is indeed needed.
            noise_stddev = 0.5 * tf.square(
                tf.minimum(
                    tf.cast(self.global_step, tf.float32) /
                    float(LAYER_NORM_NOISE_RAMPUP_STEPS), 1.0))

            def _normalizer_fn(x, is_train, name='bn'):
                return randomized_layer_normalization.normalize(
                    x, is_train=is_train, name=name, stddev=noise_stddev)

            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=tf.AUTO_REUSE):
                for i in range(SEQ_LENGTH):
                    image = self.image_stack_norm[:, :, :, 3 * i:3 * (i + 1)]
                    self.depth[
                        i] = depth_prediction_net.depth_prediction_resnet18unet(
                            image, True, self.weight_reg, _normalizer_fn)
                    self.disp[i] = 1.0 / self.depth[i]

        with tf.name_scope('compute_loss'):
            self.reconstr_loss = 0
            self.smooth_loss = 0
            self.ssim_loss = 0
            self.depth_consistency_loss = 0

            # Smoothness.
            if self.smooth_weight > 0:
                for i in range(SEQ_LENGTH):
                    disp_smoothing = self.disp[i]
                    # Perform depth normalization, dividing by the mean.
                    mean_disp = tf.reduce_mean(disp_smoothing,
                                               axis=[1, 2, 3],
                                               keep_dims=True)
                    disp_input = disp_smoothing / mean_disp
                    self.smooth_loss += _depth_smoothness(
                        disp_input, self.image_stack[:, :, :,
                                                     3 * i:3 * (i + 1)])

            self.rot_loss = 0.0
            self.trans_loss = 0.0

            def add_result_to_loss_and_summaries(endpoints, i, j):
                tf.summary.image(
                    'valid_mask%d%d' % (i, j),
                    tf.expand_dims(endpoints['depth_proximity_weight'], -1))

                self.depth_consistency_loss += endpoints['depth_error']
                self.reconstr_loss += endpoints['rgb_error']
                self.ssim_loss += 0.5 * endpoints['ssim_error']
                self.rot_loss += endpoints['rotation_error']
                self.trans_loss += endpoints['translation_error']

            self.motion_smoothing = 0.0
            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=tf.AUTO_REUSE):
                for i in range(SEQ_LENGTH - 1):
                    j = i + 1
                    depth_i = self.depth[i][:, :, :, 0]
                    depth_j = self.depth[j][:, :, :, 0]
                    image_j = self.image_stack[:, :, :, 3 * j:3 * (j + 1)]
                    image_i = self.image_stack[:, :, :, i * 3:(i + 1) * 3]
                    # We select a pair of consecutive images (and their respective
                    # predicted depth maps). Now we have the network predict a motion
                    # field that connects the two. We feed the pair of images into the
                    # network, once in forward order and then in reverse order. The
                    # results are fed into the loss calculation. The following losses are
                    # calculated:
                    # - RGB and SSIM photometric consistency.
                    # - Cycle consistency of rotations and translations for every pixel.
                    # - L1 smoothness of the disparity and the motion field.
                    # - Depth consistency
                    rot, trans, trans_res, mat = motion_prediction_net.motion_field_net(
                        images=tf.concat([image_i, image_j], axis=-1),
                        weight_reg=self.weight_reg)
                    inv_rot, inv_trans, inv_trans_res, inv_mat = (
                        motion_prediction_net.motion_field_net(
                            images=tf.concat([image_j, image_i], axis=-1),
                            weight_reg=self.weight_reg))

                    if self.learn_intrinsics:
                        intrinsic_mat = 0.5 * (mat + inv_mat)
                    else:
                        intrinsic_mat = self.intrinsic_mat[:, 0, :, :]

                    def dilate(x):
                        # Dilation by n pixels is roughly max pooling by 2 * n + 1.
                        p = self.foreground_dilation * 2 + 1
                        return tf.nn.max_pool(x, [1, p, p, 1], [1] * 4, 'SAME')

                    trans += trans_res * dilate(self.seg_stack[:, :, :,
                                                               j:j + 1])
                    inv_trans += inv_trans_res * dilate(
                        self.seg_stack[:, :, :, i:i + 1])

                    tf.summary.image('trans%d%d' % (i, i + 1), trans)
                    tf.summary.image('trans%d%d' % (i + 1, i), inv_trans)

                    tf.summary.image('trans_res%d%d' % (i + 1, i),
                                     inv_trans_res)
                    tf.summary.image('trans_res%d%d' % (i, i + 1), trans_res)

                    self.motion_smoothing += _smoothness(trans)
                    self.motion_smoothing += _smoothness(inv_trans)
                    tf.summary.scalar(
                        'trans_stdev',
                        tf.sqrt(0.5 * tf.reduce_mean(
                            tf.square(trans) + tf.square(inv_trans))))

                    transformed_depth_j = transform_depth_map.using_motion_vector(
                        depth_j, trans, rot, intrinsic_mat)

                    add_result_to_loss_and_summaries(
                        consistency_losses.rgbd_and_motion_consistency_loss(
                            transformed_depth_j, image_j, depth_i, image_i,
                            rot, trans, inv_rot, inv_trans), i, j)

                    transformed_depth_i = transform_depth_map.using_motion_vector(
                        depth_i, inv_trans, inv_rot, intrinsic_mat)

                    add_result_to_loss_and_summaries(
                        consistency_losses.rgbd_and_motion_consistency_loss(
                            transformed_depth_i, image_i, depth_j, image_j,
                            inv_rot, inv_trans, rot, trans), j, i)

            # Build the total loss as composed of L1 reconstruction, SSIM, smoothing
            # and object size constraint loss as appropriate.
            self.reconstr_loss *= self.reconstr_weight
            self.export('self.reconstr_loss', self.reconstr_loss)
            self.total_loss = self.reconstr_loss
            if self.smooth_weight > 0:
                self.smooth_loss *= self.smooth_weight
                self.total_loss += self.smooth_loss
                self.export('self.smooth_loss', self.smooth_loss)
            if self.ssim_weight > 0:
                self.ssim_loss *= self.ssim_weight
                self.total_loss += self.ssim_loss
                self.export('self.ssim_loss', self.ssim_loss)

            if self.motion_smoothing_weight > 0:
                self.motion_smoothing *= self.motion_smoothing_weight
                self.total_loss += self.motion_smoothing
                self.export('self.motion_sm_loss', self.motion_smoothing)

            if self.depth_consistency_loss_weight:
                self.depth_consistency_loss *= self.depth_consistency_loss_weight
                self.total_loss += self.depth_consistency_loss
                self.export('self.depth_consistency_loss',
                            self.depth_consistency_loss)

            self.rot_loss *= self.rotation_consistency_weight
            self.trans_loss *= self.translation_consistency_weight
            self.export('rot_loss', self.rot_loss)
            self.export('trans_loss', self.trans_loss)

            self.total_loss += self.rot_loss
            self.total_loss += self.trans_loss

            self.export('self.total_loss', self.total_loss)
Example #16
  def _get_richer_data(self, fake_data):
    inputs_tf = fake_data.inputs.input_ids
    labels_tf = fake_data.is_fake_tokens
    lens_tf = tf.reduce_sum(fake_data.inputs.input_mask, 1)
    #retrieve the basic config
    V = self._bert_config.vocab_size
    #sub: 10%, del + ins: 5%
    N = int(self._config.max_predictions_per_seq * self._config.rich_prob)
    B, L = modeling.get_shape_list(inputs_tf)
    nlms = 0
    bilm = None
    if self._config.use_bilm:
      with open(self._config.bilm_file, 'rb') as f:
        bilm = tf.constant(np.load(f), tf.int32)
      _, nlms = modeling.get_shape_list(bilm)
    #make multiple partitions for edit op
    splits_list = []
    for i in range(B):
      one = tf.random.uniform([N * 4], 1, lens_tf[i], tf.int32)
      one, _ = tf.unique(one)
      one = tf.cond(tf.less(tf.shape(one)[0], N * 2 + 1),
                    lambda: tf.expand_dims(tf.range(1, N * 2 + 2), 0),
                    lambda: tf.sort(tf.reshape(one[: N * 2 + 1], [1, N * 2 + 1])))
      splits_list.append(one[:, 2::2])
    splits_tf = tf.concat(splits_list, 0)
    splits_up = tf.concat([splits_tf, tf.expand_dims(tf.constant([L] * B, tf.int32), 1)], 1)
    splits_lo = tf.concat([tf.expand_dims(tf.constant([0] * B, tf.int32), 1), splits_tf], 1)
    size_splits = splits_up - splits_lo
    #update the inputs and labels giving random insertion and deletion
    new_labels_list = []
    new_inputs_list = []
    for i in range(B):
      inputs_splits = tf.split(inputs_tf[i, :], size_splits[i, :])
      labels_splits = tf.split(labels_tf[i, :], size_splits[i, :])
      one_inputs = []
      one_labels = []
      size_split = len(inputs_splits)
      inputs_end = inputs_splits[-1]
      labels_end = labels_splits[-1]
      for j in range(size_split-1):
        inputs = inputs_splits[j]
      labels = labels_splits[j] #label 1 for substitution
        rand_op = random.randint(2, self._config.num_preds - 1) 
        if rand_op == 2: #label 2 for insertion
          if bilm is None: #noise
            insert_tok = tf.random.uniform([1], 1, V, tf.int32)
          else: #2-gram prediction
            insert_tok = tf.expand_dims(bilm[inputs[-1], random.randint(0, nlms-1)], 0)
          is_end_valid = tf.less_equal(2, tf.shape(inputs_end)[0])
          inputs = tf.cond(is_end_valid, lambda: tf.concat([inputs, insert_tok], 0), lambda: inputs)
          labels = tf.cond(is_end_valid, lambda: tf.concat([labels, tf.constant([2])], 0), lambda: labels)
          inputs_end = tf.cond(is_end_valid, lambda: inputs_end[:-1], lambda: inputs_end)
          labels_end = tf.cond(is_end_valid, lambda: labels_end[:-1], lambda: labels_end)
        elif rand_op == 3: #label 3 for deletion
          labels = tf.concat([labels[:-2], tf.constant([3])], 0)
          inputs = inputs[:-1]
          inputs_end = tf.concat([inputs_end, tf.constant([0])], 0)
          labels_end = tf.concat([labels_end, tf.constant([0])], 0)
        elif rand_op == 4: #label 4 for swapping
          labels = tf.concat([labels[:-1], tf.constant([4])], 0)
          inputs = tf.concat([inputs[:-2], [inputs[-1]], [inputs[-2]]], 0)
        one_labels.append(labels)
        one_inputs.append(inputs)
      one_inputs.append(inputs_end)
      one_labels.append(labels_end)
      one_inputs_tf = tf.concat(one_inputs, 0)
      one_labels_tf = tf.concat(one_labels, 0)
      one_inputs_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1), lambda: inputs_tf[i, :], lambda: one_inputs_tf)
      one_labels_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1), lambda: labels_tf[i, :], lambda: one_labels_tf)
      new_inputs_list.append(tf.expand_dims(one_inputs_tf, 0))
      new_labels_list.append(tf.expand_dims(one_labels_tf, 0))

    new_inputs_tf = tf.concat(new_inputs_list, 0)
    new_labels_tf = tf.concat(new_labels_list, 0)
    new_input_mask = tf.cast(tf.not_equal(new_inputs_tf, 0), tf.int32)
    updated_inputs = pretrain_data.get_updated_inputs(
        fake_data.inputs, input_ids=new_inputs_tf, input_mask=new_input_mask)
    RicherData = collections.namedtuple("RicherData", [
        "inputs", "is_fake_tokens", "sampled_tokens"])
    return RicherData(inputs=updated_inputs, is_fake_tokens=new_labels_tf,
                     sampled_tokens=fake_data.sampled_tokens)
Example #17
 def num_matched_rows(self):
   """Returns number (int32 scalar tensor) of matched rows."""
   unique_rows, _ = tf.unique(self.matched_row_indices())
   return tf.size(unique_rows)
Example #18
    def _preprocess_image(self, sample):
        """Preprocesses the image and label.

    Args:
      sample: A sample containing image and label.

    Returns:
      sample: Sample with preprocessed image and label.

    Raises:
      ValueError: Ground truth label not provided during training.
    """
        image = sample[common.IMAGE]
        label = sample[common.LABELS_CLASS]

        if not self.strong_weak:
            if not self.output_valid:
                original_image, image, label = input_preprocess.preprocess_image_and_label(
                    image=image,
                    label=label,
                    crop_height=self.crop_size[0],
                    crop_width=self.crop_size[1],
                    min_resize_value=self.min_resize_value,
                    max_resize_value=self.max_resize_value,
                    resize_factor=self.resize_factor,
                    min_scale_factor=self.min_scale_factor,
                    max_scale_factor=self.max_scale_factor,
                    scale_factor_step_size=self.scale_factor_step_size,
                    ignore_label=self.ignore_label,
                    is_training=self.is_training,
                    model_variant=self.model_variant)
            else:
                original_image, image, label, valid = input_preprocess.preprocess_image_and_label(
                    image=image,
                    label=label,
                    crop_height=self.crop_size[0],
                    crop_width=self.crop_size[1],
                    min_resize_value=self.min_resize_value,
                    max_resize_value=self.max_resize_value,
                    resize_factor=self.resize_factor,
                    min_scale_factor=self.min_scale_factor,
                    max_scale_factor=self.max_scale_factor,
                    scale_factor_step_size=self.scale_factor_step_size,
                    ignore_label=self.ignore_label,
                    is_training=self.is_training,
                    model_variant=self.model_variant,
                    output_valid=self.output_valid)
                sample['valid'] = valid
        else:
            original_image, image, label, strong, valid = input_preprocess.preprocess_image_and_label(
                image=image,
                label=label,
                crop_height=self.crop_size[0],
                crop_width=self.crop_size[1],
                min_resize_value=self.min_resize_value,
                max_resize_value=self.max_resize_value,
                resize_factor=self.resize_factor,
                min_scale_factor=self.min_scale_factor,
                max_scale_factor=self.max_scale_factor,
                scale_factor_step_size=self.scale_factor_step_size,
                ignore_label=self.ignore_label,
                is_training=self.is_training,
                model_variant=self.model_variant,
                strong_weak=self.strong_weak)
            sample['strong'] = strong
            sample['valid'] = valid

        sample[common.IMAGE] = image

        if not self.is_training and self.output_original:
            # Original image is only used during visualization.
            sample[common.ORIGINAL_IMAGE] = original_image

        if label is not None:
            sample[common.LABEL] = label

        # Remove common.LABEL_CLASS key in the sample since it is only used to
        # derive label and not used in training and evaluation.
        sample.pop(common.LABELS_CLASS, None)

        # Convert segmentation map to multi-class label
        if self.with_cls and label is not None:
            base = tf.linalg.LinearOperatorIdentity(
                num_rows=self.num_of_classes - 1, dtype=tf.float32)
            base = base.to_dense()
            zero_filler = tf.zeros([1, self.num_of_classes - 1], tf.float32)
            base = tf.concat([zero_filler, base], axis=0)

            cls = tf.unique(tf.reshape(label, shape=[-1]))[0]
            select = tf.less(cls, self.ignore_label)
            cls = tf.boolean_mask(cls, select)
            cls_label = tf.reduce_sum(tf.gather(base, cls, axis=0), axis=0)
            sample['cls_label'] = tf.stop_gradient(cls_label)

        if self.cls_only:
            del sample[common.LABEL]

        return sample
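The multi-hot derivation at the end can be traced on a toy label map (a sketch with 4 classes including background; ignore_label assumed to be 255, and tf.eye stands in for the LinearOperatorIdentity used above):

import tensorflow as tf

num_classes, ignore_label = 4, 255
base = tf.concat([tf.zeros([1, num_classes - 1]),
                  tf.eye(num_classes - 1)], axis=0)  # background row is all zeros
label = tf.constant([[0, 2, 2], [3, 255, 0]])
cls = tf.unique(tf.reshape(label, [-1]))[0]
cls = tf.boolean_mask(cls, tf.less(cls, ignore_label))
cls_label = tf.reduce_sum(tf.gather(base, cls, axis=0), axis=0)
# cls_label -> [0., 1., 1.]: classes 2 and 3 are present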
Example #19
def follow_mention(batch_entities,
                   relation_st_qry,
                   relation_en_qry,
                   entity_word_ids,
                   entity_word_masks,
                   ent2ment_ind,
                   ent2ment_val,
                   ment2ent_map,
                   word_emb_table,
                   word_weights,
                   mips_search_fn,
                   tf_db,
                   hidden_size,
                   mips_config,
                   qa_config,
                   is_training,
                   ensure_index=None):
  """Sparse implementation of the relation follow operation.

  Args:
    batch_entities: [batch_size, num_entities] SparseTensor of incoming entities
      and their scores.
    relation_st_qry: [batch_size, dim] Tensor representing start query vectors
      for dense retrieval.
    relation_en_qry: [batch_size, dim] Tensor representing end query vectors
      for dense retrieval.
    entity_word_ids: [num_entities, max_entity_len] Tensor holding word ids of
      each entity.
    entity_word_masks: [num_entities, max_entity_len] Tensor with masks into
      word ids above.
    ent2ment_ind: [num_entities, num_mentions] RaggedTensor mapping entities to
      mention indices which co-occur with them.
    ent2ment_val: [num_entities, num_mentions] RaggedTensor mapping entities to
      mention scores which co-occur with them.
    ment2ent_map: [num_mentions] Tensor mapping mentions to their entities.
    word_emb_table: [vocab_size, dim] Tensor of word embeddings.  (?)
    word_weights: [vocab_size, 1] Tensor of word weights.  (?)
    mips_search_fn: Function which accepts a dense query vector and returns the
      top-k indices closest to it (from the tf_db).
    tf_db: [num_mentions, 2 * dim] Tensor of mention representations.
    hidden_size: Scalar dimension of word embeddings.
    mips_config: MIPSConfig object.
    qa_config: QAConfig object.
    is_training: Boolean.
    ensure_index: [batch_size] Tensor of mention ids. Only needed if
      `is_training` is True.  (? each example only one ensure entity?)

  Returns:
    ret_mentions_ids: [batch_size, k] Tensor of retrieved mention ids.
    ret_mentions_scs: [batch_size, k] Tensor of retrieved mention scores.
    ret_entities_ids: [batch_size, k] Tensor of retrieved entities ids.
  """
  if qa_config.entity_score_threshold is not None:
    # Remove the entities which have scores lower than the threshold.
    mask = tf.greater(batch_entities.values, qa_config.entity_score_threshold)
    batch_entities = tf.sparse.retain(batch_entities, mask)
  batch_size = batch_entities.dense_shape[0]  # number of the batches
  batch_ind = batch_entities.indices[:, 0]  # the list of the batch ids
  entity_ind = batch_entities.indices[:, 1]  # the list of the entity ids
  entity_scs = batch_entities.values  # the list of the scores of each entity

  # Obtain BOW embeddings for the given set of entities.
  # [NNZ, dim]  NNZ (number of non-zero entries) = len(entity_ind)
  batch_entity_emb = model_utils.entity_emb(entity_ind, entity_word_ids,
                                            entity_word_masks, word_emb_table,
                                            word_weights)
  batch_entity_emb = batch_entity_emb * tf.expand_dims(entity_scs, axis=1)
  # [batch_size, dim]
  uniq_batch_ind, uniq_idx = tf.unique(batch_ind)
  agg_emb = tf.unsorted_segment_sum(batch_entity_emb, uniq_idx,
                                    tf.shape(uniq_batch_ind)[0])
  batch_bow_emb = tf.scatter_nd(
      tf.expand_dims(uniq_batch_ind, 1), agg_emb,
      tf.stack([batch_size, hidden_size], axis=0))
  batch_bow_emb.set_shape([None, hidden_size])
  if qa_config.projection_dim is not None:
    with tf.variable_scope("projection"):
      batch_bow_emb = contrib_layers.fully_connected(
          batch_bow_emb,
          qa_config.projection_dim,
          activation_fn=tf.nn.tanh,
          reuse=tf.AUTO_REUSE,
          scope="bow_projection")
  # Each instance in a batch has only one vector as its embedding.

  # Ragged sparse search.
  # (num_batch x num_entities) * (num_entities x num_mentions)
  # [batch_size x num_mentions] sparse
  sp_mention_vec = model_utils.sparse_ragged_mul(
      batch_entities,
      ent2ment_ind,
      ent2ment_val,
      batch_size,
      mips_config.num_mentions,
      qa_config.sparse_reduce_fn,  # max or sum
      threshold=qa_config.entity_score_threshold,
      fix_values_to_one=qa_config.fix_sparse_to_one)
  if is_training and qa_config.ensure_answer_sparse:
    ensure_indices = tf.stack([tf.range(batch_size), ensure_index], axis=-1)
    sp_ensure_vec = tf.SparseTensor(
        tf.cast(ensure_indices, tf.int64),
        tf.ones([batch_size]),
        dense_shape=[batch_size, mips_config.num_mentions])
    sp_mention_vec = tf.sparse.add(sp_mention_vec, sp_ensure_vec)
    sp_mention_vec = tf.SparseTensor(
        indices=sp_mention_vec.indices,
        values=tf.minimum(1., sp_mention_vec.values),
        dense_shape=sp_mention_vec.dense_shape)

  # Dense scam search.
  # [batch_size, 2 * dim]
  # Construct query embeddings (dual encoder: [subject; relation]).
  scam_qrys = tf.concat(
      [batch_bow_emb + relation_st_qry, batch_bow_emb + relation_en_qry],
      axis=1)
  with tf.device("/cpu:0"):
    # [batch_size, num_neighbors]
    _, ret_mention_ids = mips_search_fn(scam_qrys)
    if is_training and qa_config.ensure_answer_dense:
      ret_mention_ids = model_utils.ensure_values_in_mat(
          ret_mention_ids, ensure_index, tf.int32)
    # [batch_size, num_neighbors, 2 * dim]
    ret_mention_emb = tf.gather(tf_db, ret_mention_ids)

  if qa_config.l2_normalize_db:
    ret_mention_emb = tf.nn.l2_normalize(ret_mention_emb, axis=2)
  # [batch_size, 1, num_neighbors]
  ret_mention_scs = tf.matmul(
      tf.expand_dims(scam_qrys, 1), ret_mention_emb, transpose_b=True)
  # [batch_size, num_neighbors]
  ret_mention_scs = tf.squeeze(ret_mention_scs, 1)
  # [batch_size, num_mentions] sparse
  dense_mention_vec = model_utils.convert_search_to_vector(
      ret_mention_scs, ret_mention_ids, tf.cast(batch_size, tf.int32),
      mips_config.num_neighbors, mips_config.num_mentions)

  # Combine sparse and dense search.
  if (is_training and qa_config.train_with_sparse) or (
      (not is_training) and qa_config.predict_with_sparse):
    # [batch_size, num_mentions] sparse
    if qa_config.sparse_strategy == "dense_first":
      ret_mention_vec = model_utils.sp_sp_matmul(dense_mention_vec,
                                                 sp_mention_vec)
    elif qa_config.sparse_strategy == "sparse_first":
      with tf.device("/cpu:0"):
        ret_mention_vec = model_utils.rescore_sparse(sp_mention_vec, tf_db,
                                                     scam_qrys)
    else:
      raise ValueError("Unrecognized sparse_strategy %s" %
                       qa_config.sparse_strategy)
  else:
    # [batch_size, num_mentions] sparse
    ret_mention_vec = dense_mention_vec

  # Get entity scores and ids.
  # [batch_size, num_entities] sparse
  entity_indices = tf.cast(
      tf.gather(ment2ent_map, ret_mention_vec.indices[:, 1]), tf.int64)
  ret_entity_vec = tf.SparseTensor(
      indices=tf.concat(
          [ret_mention_vec.indices[:, 0:1],
           tf.expand_dims(entity_indices, 1)],
          axis=1),
      values=ret_mention_vec.values,
      dense_shape=[batch_size, qa_config.num_entities])

  return ret_entity_vec, ret_mention_vec, dense_mention_vec, sp_mention_vec
Example #20
def follow_fact(
    batch_facts,
    relation_st_qry,
    relation_en_qry,
    fact2fact_ind,
    fact2fact_val,
    fact2ent_ind,
    fact2ent_val,
    fact_mips_search_fn,
    tf_fact_db,
    fact_mips_config,
    qa_config,
    is_training,
    hop_id=0,
    is_printing=True,
):
  """Sparse implementation of the relation follow operation.

  Args:
    batch_facts: [batch_size, num_facts] SparseTensor of incoming facts and
      their scores.
    relation_st_qry: [batch_size, dim] Tensor representing start query vectors
      for dense retrieval.
    relation_en_qry: [batch_size, dim] Tensor representing end query vectors
      for dense retrieval.
    fact2fact_ind: [num_facts, num_facts] RaggedTensor mapping facts to entity
      indices which co-occur with them.
    fact2fact_val: [num_facts, num_facts] RaggedTensor mapping facts to entity
      scores which co-occur with them.
    fact2ent_ind: [num_facts, num_entities] RaggedTensor mapping facts to entity
      indices which co-occur with them.
    fact2ent_val: [num_facts, num_entities] RaggedTensor mapping facts to entity
      scores which co-occur with them.
    fact_mips_search_fn: Function which accepts a dense query vector and returns
      the top-k indices closest to it (from the tf_fact_db).
    tf_fact_db: [num_facts, 2 * dim] Tensor of fact representations.
    fact_mips_config: MIPS Config object.
    qa_config: QAConfig object.
    is_training: Boolean.
    hop_id: int, the current hop id.
    is_printing: if print results for debugging.

  Returns:
    ret_entities: [batch_size, num_entities] Tensor of retrieved entities.
    ret_facts: [batch_size, num_facts] Tensor of retrieved facts.
    dense_fact_vec: [batch_size, num_facts] Tensor of retrieved facts (dense).
    sp_fact_vec: [batch_size, num_facts] Tensor of retrieved facts (sparse).
  """
  num_facts = fact_mips_config.num_facts
  batch_size = batch_facts.dense_shape[0]  # number of examples in a batch
  example_ind = batch_facts.indices[:, 0]  # the list of the example ids
  fact_ind = batch_facts.indices[:, 1]  # the list of the fact ids
  fact_scs = batch_facts.values  # the list of the scores of each fact
  uniq_original_example_ind, uniq_local_example_idx = tf.unique(example_ind)
  # uniq_original_example_ind: local to original example id
  # uniq_local_example_idx: a list of local example id
  # tf.shape(uniq_original_example_ind)[0] = num_examples
  if qa_config.fact_score_threshold is not None:
    # Remove the facts which have scores lower than the threshold.
    mask = tf.greater(batch_facts.values, qa_config.fact_score_threshold)
    batch_facts = tf.sparse.retain(batch_facts, mask)
  # Sparse: Ragged sparse search from the current facts to the next facts.
  # (num_batch x num_facts) X (num_facts x num_facts)
  # [batch_size x num_facts] sparse
  if hop_id > 0:
    sp_fact_vec = model_utils.sparse_ragged_mul(
        batch_facts,
        fact2fact_ind,
        fact2fact_val,
        batch_size,
        num_facts,
        "sum",  # Note: check this.
        threshold=None,
        fix_values_to_one=True)
    # Note: find a better way for this.
    mask = tf.greater(sp_fact_vec.values, 3)  # 1/0.2 = 5
    sp_fact_vec = tf.sparse.retain(sp_fact_vec, mask)
  else:
    # For the first hop, we use the initial facts themselves,
    # because the sparse retrieval was already done from the question.
    sp_fact_vec = batch_facts

  # Note: Remove the previous hop's facts
  # Note: Limit the number of fact followers.

  # Dense: Aggregate the facts in each batch as a single fact embedding vector.
  fact_embs = tf.gather(tf_fact_db, fact_ind)  # len(fact_ind) X 2dim
  # Note: check, does mean make sense?
  # sum if it was softmaxed
  # mean..
  del fact_scs  # Not used for now.
  # fact_embs = fact_embs * tf.expand_dims(fact_scs, axis=1)  #batch_fact.values
  ### Start of debugging w/ tf.Print ###
  if is_printing:
    fact_embs = tf.compat.v1.Print(
        input_=fact_embs,
        data=[tf.shape(batch_facts.indices)[0], batch_facts.indices],
        message="\n\n###\n batch_facts.indices and total #facts at hop %d \n" %
        hop_id,
        first_n=10,
        summarize=50)
    fact_embs = tf.compat.v1.Print(
        input_=fact_embs,
        data=[
            batch_facts.values,
        ],
        message="batch_facts.values at hop %d \n" % hop_id,
        first_n=10,
        summarize=25)
    fact_embs = tf.compat.v1.Print(
        input_=fact_embs,
        data=[tf.shape(sp_fact_vec.indices)[0], sp_fact_vec.indices],
        message="\n Sparse Fact Results @ hop %d \n" % hop_id +
        " sp_fact_vec.indices at hop %d \n" % hop_id,
        first_n=10,
        summarize=50)
    fact_embs = tf.compat.v1.Print(
        input_=fact_embs,
        data=[
            sp_fact_vec.values,
        ],
        message="sp_fact_vec.values at hop %d \n" % hop_id,
        first_n=10,
        summarize=25)
  ### End of debugging w/ tf.Print ###

  agg_emb = tf.math.unsorted_segment_mean(
      fact_embs, uniq_local_example_idx,
      tf.shape(uniq_original_example_ind)[0])
  batch_fact_emb = tf.scatter_nd(
      tf.expand_dims(uniq_original_example_ind, 1), agg_emb,
      tf.stack([batch_size, 2 * qa_config.projection_dim], axis=0))
  # Each instance in a batch has only one vector as the overall fact emb.
  batch_fact_emb.set_shape([None, 2 * qa_config.projection_dim])

  # Note: Normalize the embeddings if they are not from SoftMax.
  # batch_fact_emb = tf.nn.l2_normalize(batch_fact_emb, axis=1)

  # Dense scam search.
  # [batch_size, 2 * dim]
  # Note: reform query embeddings.
  scam_qrys = batch_fact_emb + tf.concat([relation_st_qry, relation_en_qry],
                                         axis=1)
  with tf.device("/cpu:0"):
    # [batch_size, num_neighbors]
    _, ret_fact_ids = fact_mips_search_fn(scam_qrys)
    # [batch_size, num_neighbors, 2 * dim]
    ret_fact_emb = tf.gather(tf_fact_db, ret_fact_ids)

  if qa_config.l2_normalize_db:
    ret_fact_emb = tf.nn.l2_normalize(ret_fact_emb, axis=2)
  # [batch_size, 1, num_neighbors]
  # The score of a fact is its inner product with qry.
  ret_fact_scs = tf.matmul(
      tf.expand_dims(scam_qrys, 1), ret_fact_emb, transpose_b=True)
  # [batch_size, num_neighbors]
  ret_fact_scs = tf.squeeze(ret_fact_scs, 1)
  # [batch_size, num_facts] sparse
  dense_fact_vec = model_utils.convert_search_to_vector(
      ret_fact_scs, ret_fact_ids, tf.cast(batch_size, tf.int32),
      fact_mips_config.num_neighbors, fact_mips_config.num_facts)

  # Combine sparse and dense search.
  if (is_training and qa_config.train_with_sparse) or (
      (not is_training) and qa_config.predict_with_sparse):
    # [batch_size, num_mentions] sparse
    if qa_config.sparse_strategy == "dense_first":
      ret_fact_vec = model_utils.sp_sp_matmul(dense_fact_vec, sp_fact_vec)
    elif qa_config.sparse_strategy == "sparse_first":
      with tf.device("/cpu:0"):
        ret_fact_vec = model_utils.rescore_sparse(sp_fact_vec, tf_fact_db,
                                                  scam_qrys)
    else:
      raise ValueError("Unrecognized sparse_strategy %s" %
                       qa_config.sparse_strategy)
  else:
    # [batch_size, num_facts] sparse
    ret_fact_vec = dense_fact_vec

  # # Scaling facts with SoftMax.
  ret_fact_vec = tf.sparse.reorder(ret_fact_vec)
  # max_ip_scores = tf.reduce_max(ret_fact_vec.values)
  # min_ip_scores = tf.reduce_min(ret_fact_vec.values)
  # range_ip_scores = max_ip_scores - min_ip_scores
  # scaled_values = (ret_fact_vec.values - min_ip_scores) / range_ip_scores
  scaled_facts = tf.SparseTensor(
      indices=ret_fact_vec.indices,
      values=ret_fact_vec.values / tf.reduce_max(ret_fact_vec.values),
      dense_shape=ret_fact_vec.dense_shape)
  # ret_fact_vec_sf = tf.sparse.softmax(scaled_facts)
  ret_fact_vec_sf = scaled_facts

  # Remove the facts which have scores lower than the threshold.
  mask = tf.greater(ret_fact_vec_sf.values, 0.5)  # must be larger than half of the max
  ret_fact_vec_sf_fitered = tf.sparse.retain(ret_fact_vec_sf, mask)

  # Note: add a soft way to score (all) the entities based on the facts.
  # Note: maybe use the pre-computed (tf-idf) similarity score here. e2e
  # Retrieve entities before Fact-SoftMaxing
  ret_entities_nosc = model_utils.sparse_ragged_mul(
      ret_fact_vec_sf,  # Use the non-filtered scores of the retrieved facts.
      fact2ent_ind,
      fact2ent_val,
      batch_size,
      qa_config.num_entities,
      "sum",
      threshold=qa_config.fact_score_threshold,
      fix_values_to_one=True)

  ret_entities = tf.SparseTensor(
      indices=ret_entities_nosc.indices,
      values=ret_entities_nosc.values / tf.reduce_max(ret_entities_nosc.values),
      dense_shape=ret_entities_nosc.dense_shape)

  ### Start of debugging w/ tf.Print ###
  if is_printing:
    tmp_vals = ret_entities.values

    tmp_vals = tf.compat.v1.Print(
        input_=tmp_vals,
        data=[tf.shape(ret_fact_vec.indices)[0], ret_fact_vec.indices],
        message="\n\n-rescored- ret_fact_vec.indices at hop %d \n" % hop_id,
        first_n=10,
        summarize=51)
    tmp_vals = tf.compat.v1.Print(
        input_=tmp_vals,
        data=[
            ret_fact_vec.values,
        ],
        message="-rescored- ret_fact_vec.values at hop %d \n" % hop_id,
        first_n=10,
        summarize=25)
    tmp_vals = tf.compat.v1.Print(
        input_=tmp_vals,
        data=[
            ret_fact_vec_sf.values,
        ],
        message="ret_fact_vec_sf.values at hop %d \n" % hop_id,
        first_n=10,
        summarize=25)
    tmp_vals = tf.compat.v1.Print(
        input_=tmp_vals,
        data=[
            tf.shape(ret_fact_vec_sf_fitered.values),
            ret_fact_vec_sf_fitered.values,
        ],
        message="ret_fact_vec_sf_fitered.values at hop %d \n" % hop_id,
        first_n=10,
        summarize=25)
    ret_entities = tf.SparseTensor(
        indices=ret_entities.indices,
        values=tmp_vals,
        dense_shape=ret_entities.dense_shape)
  ### End of debugging w/ tf.Print ###

  return ret_entities, ret_fact_vec_sf_fitered, None, None
Example #21
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # do not use TF 2.0 behavior

# 1. argmin/argmax
a = tf.constant([5, 2, 1, 4, 3], dtype=tf.int32)
b = tf.constant([4, 5, 1, 3, 2])
c = tf.constant([[5, 4, 2], [3, 2, 4]])  # 2-D

# dimension: the axis to reduce (0 for a vector)
min_index = tf.arg_min(a, dimension=0)  # 1-D input
max_index = tf.arg_max(b, dimension=0)  # 1-D input
max_index2 = tf.arg_max(c, dimension=1)  # 2-D input
#
sess = tf.Session()
print(sess.run(min_index))  # 2
print(sess.run(max_index))  # 1
print(sess.run(max_index2))  # [0 2]

# 2. unique/setdiff1d

c = tf.constant(['a', 'b', 'a', 'c', 'b'])
# unique
cstr, cidx = tf.unique(c)
print(sess.run(cstr))  # [b'a' b'b' b'c']
print(sess.run(cidx))  # [0 1 0 2 1]

# setdiff1d : [5,2,1,4,3] - [1,3,2]
d = tf.constant([1, 3, 2], dtype=tf.int32)
set_result, set_idx = tf.setdiff1d(a, d)
print(sess.run(set_result))  # [5 4]
print(sess.run(set_idx))  # [0 3]
Example #22
def multiply2n_ragged(tensor1, tensor2):
    # This function multiplies two ragged tensors of rank 2. The outermost
    # dimensions of the two tensors must be equal.
    # Set up variables and constants.
    outerloop_counter = tf.constant(0, dtype=tf.int32)
    carry_on = tf.constant(0, dtype=tf.int32)
    taValues = tf.TensorArray(tf.float32,
                              size=0,
                              dynamic_size=True,
                              clear_after_read=False,
                              infer_shape=False)
    taL2Splits = tf.TensorArray(tf.int32,
                                size=0,
                                dynamic_size=True,
                                clear_after_read=False,
                                infer_shape=False)
    taL1Splits = tf.TensorArray(tf.int32,
                                size=0,
                                dynamic_size=True,
                                clear_after_read=False,
                                infer_shape=False)
    taL1Splits = taL1Splits.write(
        0, [0])  ## required initialization for the L1 split only
    innerloop_processing_graphed = tf.function(innerloop_processing)
    generateL1Tensor_writeback_graphed = tf.function(
        generateL1Tensor_writeback)

    def outerloop_cond(counter, input1, input2, taValues, taL2Splits,
                       taL1Splits, carry_on):
        value = tf.shape(input1[2])[0] - 1
        return counter < value  ## value is the length of the outermost dimension; stop once the counter reaches it

    def outloop_body(counter, input1, input2, taValues, taL2Splits, taL1Splits,
                     carry_on):
        l1_comp_begin = input1[2][
            counter]  ## this is begin position of the current row in the outer split  ( ie. the ith value in the outer row split tensor )
        l1_comp_end = input1[2][
            counter +
            1]  ## this is end position of the current row in the outer split   (ie. the ith + 1 value in the outer row split tensor)
        l1_comp2_begin = input2[2][
            counter]  ## we do the same for the second components
        l1_comp2_end = input2[2][
            counter + 1]  ## we do the same for the second components
        comp = innerloop_processing_graphed(
            l1_comp_begin, l1_comp_end, input1
        )  ## now retrieve the data to be processed for the selected rows from vector1
        comp2 = innerloop_processing_graphed(
            l1_comp2_begin, l1_comp2_end, input2)  ## do the same for vector 2

        #comp2 = tf.transpose(comp2) ### desired operation
        multiply = tf.matmul(comp, comp2)  #### This is the desired operation

        myshape = tf.shape(
            multiply
        )  ## calculate the shape of the result in order to prepare to write the result in a ragged tensor format.
        offset = tf.cond(
            taValues.size() > 0, lambda: tf.shape(taValues.concat())[0],
            lambda: 0
        )  ### this is a hack, TensorArray.concat returns an error if the array is empty. Thus we check before calling this.
        #print11=tf.print("=================Final Shape is : " ,myshape[0] , " X " ,myshape[1] )
        l2v = generateL1Tensor_writeback_graphed(
            offset, myshape[1], myshape[0]
        )  # generate the inner row split of the result for the current element
        taL2Splits = taL2Splits.write(
            counter, l2v)  # write back the inner row split to a TensorArray
        taValues = taValues.write(
            counter, tf.reshape(multiply, [-1])
        )  # write back the actual ragged tensor elements in another TensorArray
        carry_on = carry_on + myshape[
            0]  ## required to calculate the outer row split
        taL1Splits = taL1Splits.write(
            counter + 1, [carry_on])  ## This is the outermost row split.
        with tf.control_dependencies(
            [comp, comp2, myshape, l2v, carry_on, multiply]):
            counter = counter + 1
        return counter, input1, input2, taValues, taL2Splits, taL1Splits, carry_on

    with tf.name_scope("RaggedMultiply"):
        outerloop_finalcounter, _, _, ta1, ta2, ta3, _ = tf.while_loop(
            outerloop_cond,
            outloop_body, [
                outerloop_counter, tensor1, tensor2, taValues, taL2Splits,
                taL1Splits, carry_on
            ],
            back_prop=True)
    uinquie_ta2, _ = tf.unique(
        ta2.concat()
    )  # this is required since some values might be duplicated in the row split itself
    t1 = ta1.concat()
    t3 = ta3.concat()
    #with  tf.control_dependencies([t1 , uinquie_ta2 ,t3  ]):
    final_values = t1, uinquie_ta2, t3
    return final_values
Beispiel #23
0
      shape_utils.assert_shape_equal_along_first_dimension(
          boxes_shape, confidences_shape))
  box_dimension_assert = tf.assert_equal(boxes_shape[1], 4)
  box_normalized_assert = shape_utils.assert_box_normalized(boxes)

  with tf.control_dependencies(
      [box_class_shape_assert, box_confidence_shape_assert,
       box_dimension_assert, box_normalized_assert]):
    quantized_boxes = tf.to_int64(boxes * (quantization_bins - 1))
    ymin, xmin, ymax, xmax = tf.unstack(quantized_boxes, axis=1)
    hashcodes = (
        ymin +
        xmin * quantization_bins +
        ymax * quantization_bins * quantization_bins +
        xmax * quantization_bins * quantization_bins * quantization_bins)
    unique_hashcodes, unique_indices = tf.unique(hashcodes)
    num_boxes = tf.shape(boxes)[0]
    num_unique_boxes = tf.shape(unique_hashcodes)[0]
    merged_box_indices = tf.unsorted_segment_min(
        tf.range(num_boxes), unique_indices, num_unique_boxes)
    merged_boxes = tf.gather(boxes, merged_box_indices)
    unique_indices = tf.to_int64(unique_indices)
    classes = tf.to_int64(classes)

    def map_box_encodings(i):
      """Produces box K-hot and score encodings for each class index."""
      box_mask = tf.equal(
          unique_indices, i * tf.ones(num_boxes, dtype=tf.int64))
      box_mask = tf.reshape(box_mask, [-1])
      box_indices = tf.boolean_mask(classes, box_mask)
      box_confidences = tf.boolean_mask(confidences, box_mask)
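
The hashing trick above is easy to demo in isolation. A minimal sketch, assuming quantization_bins = 4 for readability (the real code would use far more bins), in eager TF 2.x:

import tensorflow as tf

bins = 4
boxes = tf.constant([[0., 0., 1., 1.],
                     [0., 0., 1., 1.],    # exact duplicate of box 0
                     [0., 0., .5, .5]])
ymin, xmin, ymax, xmax = tf.unstack(
    tf.cast(boxes * (bins - 1), tf.int64), axis=1)
hashes = ymin + xmin * bins + ymax * bins**2 + xmax * bins**3
uniq, idx = tf.unique(hashes)            # idx maps every box to its group
merged = tf.math.unsorted_segment_min(   # keep the first box of each group
    tf.range(tf.shape(boxes)[0]), idx, tf.size(uniq))
print(merged.numpy())                    # [0 2]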
Beispiel #24
0
def detection_loss(cls_outputs, box_outputs, labels, params):
    """Computes total detection loss.

  Computes total detection loss including box and class losses from all levels.

  Args:
    cls_outputs: an OrderedDict with keys representing levels and values
      representing logits in [batch_size, height, width, num_anchors].
    box_outputs: an OrderedDict with keys representing levels and values
      representing box regression targets in [batch_size, height, width,
      num_anchors * 4].
    labels: the dictionary returned from the dataloader that includes
      groundtruth targets.
    params: the dictionary including training parameters specified in
      the default_hparams function in this file.

  Returns:
    total_loss: a float tensor representing the total loss, reduced over
      class and box losses from all levels.
    cls_loss: a float tensor representing the total class loss.
    box_loss: a float tensor representing the total box regression loss.
    box_iou_loss: a float tensor representing the total box IoU loss.
  """
    # Sum all positives in a batch for normalization and avoid zero
    # num_positives_sum, which would lead to inf loss during training
    num_positives_sum = tf.reduce_sum(labels['mean_num_positives']) + 1.0
    levels = cls_outputs.keys()

    cls_losses = []
    box_losses = []
    sumrule = {}
    if params.get('sumrule'):
        sumrule = params['sumrule']
        # Because of cls_targets -= 1 (so that the background class becomes -1
        # and actual classes then start from 0), we need to subtract 1 from
        # the sumrule keys and values as well.
        _sumrule = {}
        for k, v in sumrule.items():
            _sumrule[k - 1] = [vv - 1 for vv in v]
        sumrule = _sumrule

    def table_lookup(values, old_onehot, cls_targets_at_level):
        for val in values:
            if sumrule.get(val):
                new_val = sumrule[val]
                #prob = 1.0/len(new_val)
                prob = 0.5  # sigmoid cross entropy is tried first, so use 0.5; with softmax this should be 1.0/len(new_val)
                if len(new_val) == 1:
                    # leaf node, prob = 1.0
                    prob = 1.0
                _matching_onehot = old_onehot[np.where(
                    cls_targets_at_level == val)]
                _matching_onehot[:, new_val] = prob
                _matching_onehot[:, val] = 0
                old_onehot[np.where(
                    cls_targets_at_level == val)] = _matching_onehot
        return old_onehot

    for level in levels:
        # Onehot encoding for classification labels.
        _cls_targets_at_level = tf.one_hot(labels['cls_targets_%d' % level],
                                           params['num_classes'])
        if params.get('sumrule'):
            unique_labels, _ = tf.unique(
                tf.reshape(labels['cls_targets_%d' % level], [-1]))
            # refine the one-hot labels so that each label is mapped to its finest leaves
            cls_targets_at_level = tf.numpy_function(
                table_lookup, [
                    unique_labels, _cls_targets_at_level,
                    labels['cls_targets_%d' % level]
                ], _cls_targets_at_level.dtype)
            cls_targets_at_level = tf.reshape(cls_targets_at_level,
                                              _cls_targets_at_level.shape)
        else:
            cls_targets_at_level = _cls_targets_at_level

        if params['data_format'] == 'channels_first':
            bs, _, width, height, _ = cls_targets_at_level.get_shape().as_list(
            )
            cls_targets_at_level = tf.reshape(cls_targets_at_level,
                                              [bs, -1, width, height])
        else:
            bs, width, height, _, _ = cls_targets_at_level.get_shape().as_list(
            )
            cls_targets_at_level = tf.reshape(cls_targets_at_level,
                                              [bs, width, height, -1])
        box_targets_at_level = labels['box_targets_%d' % level]

        cls_loss = focal_loss(cls_outputs[level],
                              cls_targets_at_level,
                              params['alpha'],
                              params['gamma'],
                              normalizer=num_positives_sum,
                              label_smoothing=params['label_smoothing'])

        if params['data_format'] == 'channels_first':
            cls_loss = tf.reshape(
                cls_loss, [bs, -1, width, height, params['num_classes']])
        else:
            cls_loss = tf.reshape(
                cls_loss, [bs, width, height, -1, params['num_classes']])
        cls_loss *= tf.cast(
            tf.expand_dims(tf.not_equal(labels['cls_targets_%d' % level], -2),
                           -1), tf.float32)
        cls_losses.append(tf.reduce_sum(cls_loss))

        if params['box_loss_weight']:
            box_losses.append(
                _box_loss(box_outputs[level],
                          box_targets_at_level,
                          num_positives_sum,
                          delta=params['delta']))

    if params['iou_loss_type']:
        input_anchors = anchors.Anchors(params['min_level'],
                                        params['max_level'],
                                        params['num_scales'],
                                        params['aspect_ratios'],
                                        params['anchor_scale'],
                                        params['image_size'])
        box_output_list = [tf.reshape(box_outputs[i], [-1, 4]) for i in levels]
        box_outputs = tf.concat(box_output_list, axis=0)
        box_target_list = [
            tf.reshape(labels['box_targets_%d' % level], [-1, 4])
            for level in levels
        ]
        box_targets = tf.concat(box_target_list, axis=0)
        anchor_boxes = tf.tile(input_anchors.boxes, [params['batch_size'], 1])
        box_outputs = anchors.decode_box_outputs(box_outputs, anchor_boxes)
        box_targets = anchors.decode_box_outputs(box_targets, anchor_boxes)
        box_iou_loss = _box_iou_loss(box_outputs, box_targets,
                                     num_positives_sum,
                                     params['iou_loss_type'])

    else:
        box_iou_loss = 0

    # Sum per level losses to total loss.
    cls_loss = tf.add_n(cls_losses)
    box_loss = tf.add_n(box_losses) if box_losses else 0

    total_loss = (cls_loss + params['box_loss_weight'] * box_loss +
                  params['iou_loss_weight'] * box_iou_loss)

    return total_loss, cls_loss, box_loss, box_iou_loss
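
The sumrule refinement in table_lookup is easiest to see in plain numpy. A minimal sketch with a made-up three-class hierarchy where parent class 2 maps to leaf classes 0 and 1:

import numpy as np

sumrule = {2: [0, 1]}                 # hypothetical parent -> leaves mapping
targets = np.array([2, 0])
onehot = np.eye(3)[targets]           # [[0, 0, 1], [1, 0, 0]]
for val in np.unique(targets):
    if val in sumrule:
        leaves = sumrule[val]
        prob = 1.0 if len(leaves) == 1 else 0.5   # same rule as above
        rows = onehot[targets == val]
        rows[:, leaves] = prob        # spread the mass over the leaves
        rows[:, val] = 0              # clear the parent entry
        onehot[targets == val] = rows
print(onehot)                         # [[0.5, 0.5, 0.], [1., 0., 0.]]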
Beispiel #25
0
    def _build(self, all_anchors, gt_boxes, im_shape):
        """
        We compare anchors to ground truth and, using the minibatch size and
        the different config settings (clobber, foreground fraction, etc.),
        end up with training targets *only* for the elements we want to use
        in the batch, while everything else is ignored.

        Basically, it first generates the targets for all (valid) anchors and
        then subsamples the positive (foreground) and negative (background)
        ones based on the number of samples of each type that we want.

        Args:
            all_anchors:
                A Tensor with all the bounding boxes coords of the anchors.
                Its shape should be (num_anchors, 4).
            gt_boxes:
                A Tensor with the ground truth bounding boxes of the image of
                the batch being processed. Its shape should be (num_gt, 5).
                The last dimension is used for the label.
            im_shape:
                Shape of the original image (height, width), used to define
                anchor targets with respect to gt_boxes.

        Returns:
            Tuple of the tensors of:
                labels: (1, 0, -1) for each anchor.
                    Shape (num_anchors, 1)
                bbox_targets: 4d bbox targets as specified by paper.
                    Shape (num_anchors, 4)
                max_overlaps: Max IoU overlap with ground truth boxes.
                    Shape (num_anchors, 1)
        """
        # Keep only the coordinates of gt_boxes
        gt_boxes = gt_boxes[:, :4]
        all_anchors = all_anchors[:, :4]

        # Only keep anchors inside the image
        (x_min_anchor, y_min_anchor, x_max_anchor,
         y_max_anchor) = tf.unstack(all_anchors, axis=1)

        anchor_filter = tf.logical_and(
            tf.logical_and(
                tf.greater_equal(x_min_anchor, -self._allowed_border),
                tf.greater_equal(y_min_anchor, -self._allowed_border)),
            tf.logical_and(
                tf.less(x_max_anchor, im_shape[1] + self._allowed_border),
                tf.less(y_max_anchor, im_shape[0] + self._allowed_border)))

        # We (force) reshape the filter so that we can use it as a boolean mask
        anchor_filter = tf.reshape(anchor_filter, [-1])
        # Filter anchors.
        anchors = tf.boolean_mask(all_anchors,
                                  anchor_filter,
                                  name='filter_anchors')

        # Generate array with the labels for all_anchors.
        labels = tf.fill((tf.gather(tf.shape(all_anchors), [0])), -1)
        labels = tf.boolean_mask(labels, anchor_filter, name='filter_labels')

        # Intersection over union (IoU) overlap between the anchors and the
        # ground truth boxes.
        overlaps = bbox_overlap_tf(tf.to_float(anchors), tf.to_float(gt_boxes))

        # Generate array with the IoU value of the closest GT box for each
        # anchor.
        max_overlaps = tf.reduce_max(overlaps, axis=1)
        if not self._clobber_positives:
            # Assign bg labels first so that positive labels can clobber them.
            # First we get an array with True where IoU is less than
            # self._negative_overlap
            negative_overlap_nonzero = tf.less(max_overlaps,
                                               self._negative_overlap)

            # Finally we set 0 at True indices
            labels = tf.where(condition=negative_overlap_nonzero,
                              x=tf.zeros(tf.shape(labels)),
                              y=tf.to_float(labels))
        # Get the value of the max IoU for the closest anchor for each gt.
        gt_max_overlaps = tf.reduce_max(overlaps, axis=0)

        # Find all the indices that match (at least one, but could be more).
        gt_argmax_overlaps = tf.squeeze(tf.equal(overlaps, gt_max_overlaps))
        gt_argmax_overlaps = tf.where(gt_argmax_overlaps)[:, 0]
        # Eliminate duplicates indices.
        gt_argmax_overlaps, _ = tf.unique(gt_argmax_overlaps)
        # Order the indices for sparse_to_dense compatibility
        gt_argmax_overlaps, _ = tf.nn.top_k(gt_argmax_overlaps,
                                            k=tf.shape(gt_argmax_overlaps)[-1])
        gt_argmax_overlaps = tf.reverse(gt_argmax_overlaps, [0])

        # Foreground label: for each ground truth, the anchor with the highest
        # overlap. When the argmax matches several anchors we use all of them
        # (for consistency). We set 1 at the gt_argmax_overlaps_cond indices.
        gt_argmax_overlaps_cond = tf.sparse_to_dense(gt_argmax_overlaps,
                                                     tf.shape(
                                                         labels,
                                                         out_type=tf.int64),
                                                     True,
                                                     default_value=False)

        labels = tf.where(condition=gt_argmax_overlaps_cond,
                          x=tf.ones(tf.shape(labels)),
                          y=tf.to_float(labels))

        # Foreground label: above threshold Intersection over Union (IoU)
        # First we get an array with True where IoU is greater or equal than
        # self._positive_overlap
        positive_overlap_inds = tf.greater_equal(max_overlaps,
                                                 self._positive_overlap)
        # Finally we set 1 at True indices
        labels = tf.where(condition=positive_overlap_inds,
                          x=tf.ones(tf.shape(labels)),
                          y=labels)

        if self._clobber_positives:
            # Assign background labels last so that negative labels can clobber
            # positives. First we get an array with True where IoU is less than
            # self._negative_overlap
            negative_overlap_nonzero = tf.less(max_overlaps,
                                               self._negative_overlap)
            # Finally we set 0 at True indices
            labels = tf.where(condition=negative_overlap_nonzero,
                              x=tf.zeros(tf.shape(labels)),
                              y=labels)

        # Subsample positive labels if we have too many
        def subsample_positive():
            # Shuffle the foreground indices
            disable_fg_inds = tf.random_shuffle(fg_inds, seed=self._seed)
            # Select the indices that we have to ignore, this is
            # `tf.shape(fg_inds)[0] - num_fg` because we want to get only
            # `num_fg` foreground labels.
            disable_place = (tf.shape(fg_inds)[0] - num_fg)
            disable_fg_inds = disable_fg_inds[:disable_place]
            # Order the indices for sparse_to_dense compatibility
            disable_fg_inds, _ = tf.nn.top_k(disable_fg_inds,
                                             k=tf.shape(disable_fg_inds)[-1])
            disable_fg_inds = tf.reverse(disable_fg_inds, [0])
            disable_fg_inds = tf.sparse_to_dense(disable_fg_inds,
                                                 tf.shape(labels,
                                                          out_type=tf.int64),
                                                 True,
                                                 default_value=False)
            # Put -1 to ignore the anchors in the selected indices
            return tf.where(condition=tf.squeeze(disable_fg_inds),
                            x=tf.to_float(tf.fill(tf.shape(labels), -1)),
                            y=labels)

        num_fg = tf.to_int32(self._foreground_fraction * self._minibatch_size)
        # Get foreground indices, get True in the indices where we have a one.
        fg_inds = tf.equal(labels, 1)
        # We get only the indices where we have True.
        fg_inds = tf.squeeze(tf.where(fg_inds), axis=1)
        fg_inds_size = tf.size(fg_inds)
        # Condition to check whether we have too many positive labels.
        subsample_positive_cond = fg_inds_size > num_fg
        # Check the condition and subsample the positive labels if needed.
        labels = tf.cond(subsample_positive_cond,
                         true_fn=subsample_positive,
                         false_fn=lambda: labels)

        # Subsample negative labels if we have too many
        def subsample_negative():
            # Shuffle the background indices
            disable_bg_inds = tf.random_shuffle(bg_inds, seed=self._seed)

            # Select the indices that we have to ignore, this is
            # `tf.shape(bg_inds)[0] - num_bg` because we want to get only
            # `num_bg` background labels.
            disable_place = (tf.shape(bg_inds)[0] - num_bg)
            disable_bg_inds = disable_bg_inds[:disable_place]
            # Order the indices for sparse_to_dense compatibility
            disable_bg_inds, _ = tf.nn.top_k(disable_bg_inds,
                                             k=tf.shape(disable_bg_inds)[-1])
            disable_bg_inds = tf.reverse(disable_bg_inds, [0])
            disable_bg_inds = tf.sparse_to_dense(disable_bg_inds,
                                                 tf.shape(labels,
                                                          out_type=tf.int64),
                                                 True,
                                                 default_value=False)
            # Put -1 to ignore the anchors in the selected indices
            return tf.where(condition=tf.squeeze(disable_bg_inds),
                            x=tf.to_float(tf.fill(tf.shape(labels), -1)),
                            y=labels)

        # Recalculate the foreground indices after (maybe) disabling some of them

        # Get foreground indices, get True in the indices where we have a one.
        fg_inds = tf.equal(labels, 1)
        # We get only the indices where we have True.
        fg_inds = tf.squeeze(tf.where(fg_inds), axis=1)
        fg_inds_size = tf.size(fg_inds)

        num_bg = tf.to_int32(self._minibatch_size - fg_inds_size)
        # Get background indices, get True in the indices where we have a zero.
        bg_inds = tf.equal(labels, 0)
        # We get only the indices where we have True.
        bg_inds = tf.squeeze(tf.where(bg_inds), axis=1)
        bg_inds_size = tf.size(bg_inds)
        # Condition to check whether we have too many negative labels.
        subsample_negative_cond = bg_inds_size > num_bg
        # Check the condition and subsample the negative labels if needed.
        labels = tf.cond(subsample_negative_cond,
                         true_fn=subsample_negative,
                         false_fn=lambda: labels)

        # Return bbox targets with shape (anchors.shape[0], 4).

        # Find the closest gt box for each anchor.
        argmax_overlaps = tf.argmax(overlaps, axis=1)
        # Eliminate duplicates.
        argmax_overlaps_unique, _ = tf.unique(argmax_overlaps)
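        # NOTE: argmax_overlaps_unique is computed here but not used below;
        # the gather uses the per-anchor argmax directly.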
        # Filter the gt_boxes.
        # We get only the indices where we have "inside anchors".
        anchor_filter_inds = tf.where(anchor_filter)
        gt_boxes = tf.gather(gt_boxes, argmax_overlaps)

        bbox_targets = encode_tf(anchors, gt_boxes)

        # For the anchors that aren't foreground, we ignore the bbox_targets.
        anchor_foreground_filter = tf.equal(labels, 1)
        bbox_targets = tf.where(condition=anchor_foreground_filter,
                                x=bbox_targets,
                                y=tf.zeros_like(bbox_targets))

        # We unroll "inside anchors" value for all anchors (for shape
        # compatibility).

        # We complete the missed indices with zeros
        # (because scatter_nd has zeros as default).
        bbox_targets = tf.scatter_nd(indices=tf.to_int32(anchor_filter_inds),
                                     updates=bbox_targets,
                                     shape=tf.shape(all_anchors))

        labels_scatter = tf.scatter_nd(indices=tf.to_int32(anchor_filter_inds),
                                       updates=labels,
                                       shape=[tf.shape(all_anchors)[0]])
        # We have to put -1 to ignore the indices with 0 generated by
        # scatter_nd, otherwise it will be considered as background.
        labels = tf.where(condition=anchor_filter,
                          x=labels_scatter,
                          y=tf.to_float(tf.fill(tf.shape(labels_scatter), -1)))

        max_overlaps = tf.scatter_nd(indices=tf.to_int32(anchor_filter_inds),
                                     updates=max_overlaps,
                                     shape=[tf.shape(all_anchors)[0]])

        return labels, bbox_targets, max_overlaps
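
The unique -> top_k -> reverse pattern that appears three times above exists because tf.sparse_to_dense expects sorted, duplicate-free indices, while tf.unique returns them in first-occurrence order. A minimal sketch (eager TF 2.x, made-up indices):

import tensorflow as tf

inds = tf.constant([5, 2, 5, 7])
uniq, _ = tf.unique(inds)                     # [5, 2, 7], first-occurrence order
desc, _ = tf.nn.top_k(uniq, k=tf.size(uniq))  # [7, 5, 2], descending
asc = tf.reverse(desc, [0])                   # [2, 5, 7], ascending as required
dense = tf.compat.v1.sparse_to_dense(
    asc, output_shape=[10], sparse_values=True, default_value=False)
print(dense.numpy())  # [False False True False False True False True False False]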