def _positions_center_origin(height, width):
  """Returns image coordinates where the origin at the image center."""
  h = tf.range(0.0, height, 1)
  w = tf.range(0.0, width, 1)
  center_h = tf.cast(height, tf.float32) / 2.0 - 0.5
  center_w = tf.cast(width, tf.float32) / 2.0 - 0.5
  return tf.stack(tf.meshgrid(h - center_h, w - center_w, indexing='ij'), -1)
 def body(var_img, mean_color):
   x0 = tf.random.uniform([], 0, width, dtype=tf.int32)
   y0 = tf.random.uniform([], 0, height, dtype=tf.int32)
   dx = tf.random.uniform([], min_size, max_size, dtype=tf.int32)
   dy = tf.random.uniform([], min_size, max_size, dtype=tf.int32)
   x = tf.range(width)
   x_mask = (x0 <= x) & (x < x0+dx)
   y = tf.range(height)
   y_mask = (y0 <= y) & (y < y0+dy)
   mask = x_mask & y_mask[:, tf.newaxis]
   mask = tf.cast(mask[:, :, tf.newaxis], image_2.dtype)
   result = var_img * (1 - mask) + mean_color * mask
   return result
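This snippet is a closure excerpt: width, height, min_size, max_size, mean_color, and image_2 come from the enclosing scope. A minimal self-contained sketch of the same rectangle-masking trick, under those assumptions (hypothetical function and argument names):

import tensorflow as tf

def random_rectangle_fill(image, min_size, max_size, fill_value):
  # Sketch only: fills a random axis-aligned rectangle of `image` with
  # `fill_value`, using the same range/broadcast masking as above.
  height, width = tf.shape(image)[0], tf.shape(image)[1]
  x0 = tf.random.uniform([], 0, width, dtype=tf.int32)
  y0 = tf.random.uniform([], 0, height, dtype=tf.int32)
  dx = tf.random.uniform([], min_size, max_size, dtype=tf.int32)
  dy = tf.random.uniform([], min_size, max_size, dtype=tf.int32)
  x_mask = (x0 <= tf.range(width)) & (tf.range(width) < x0 + dx)
  y_mask = (y0 <= tf.range(height)) & (tf.range(height) < y0 + dy)
  mask = tf.cast((x_mask & y_mask[:, tf.newaxis])[:, :, tf.newaxis], image.dtype)
  return image * (1 - mask) + fill_value * mask

# e.g. random_rectangle_fill(tf.ones([8, 8, 3]), 2, 4, fill_value=0.5)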
Example #3
  def _build_train_op(self):
    """Builds the training op for Rainbow.

    Returns:
      train_op: An op performing one step of training.
    """
    target_distribution = tf.stop_gradient(self._build_target_distribution())

    # size of indices: batch_size x 1.
    indices = tf.range(tf.shape(self._replay_logits)[0])[:, None]
    # size of reshaped_actions: batch_size x 2.
    reshaped_actions = tf.concat([indices, self._replay.actions[:, None]], 1)
    # For each element of the batch, fetch the logits for its selected action.
    chosen_action_logits = tf.gather_nd(self._replay_logits, reshaped_actions)

    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=target_distribution,
        logits=chosen_action_logits)

    optimizer = tf.train.AdamOptimizer(
        learning_rate=self.learning_rate,
        epsilon=self.optimizer_epsilon)

    update_priorities_op = self._replay.tf_set_priority(
        self._replay.indices, tf.sqrt(loss + 1e-10))

    target_priorities = self._replay.tf_get_priority(self._replay.indices)
    target_priorities = tf.math.add(target_priorities, 1e-10)
    target_priorities = 1.0 / tf.sqrt(target_priorities)
    target_priorities /= tf.reduce_max(target_priorities)

    weighted_loss = target_priorities * loss

    with tf.control_dependencies([update_priorities_op]):
      return optimizer.minimize(tf.reduce_mean(weighted_loss)), weighted_loss
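The priority logic above follows the standard prioritized-replay recipe: new priorities are sqrt(loss), and importance weights are 1/sqrt(priority) normalized by their maximum. A toy illustration of just that weighting, with made-up values:

import tensorflow as tf

loss = tf.constant([0.04, 0.25, 1.0])        # per-example losses (assumed)
priorities = tf.sqrt(loss + 1e-10)           # what tf_set_priority stores
weights = 1.0 / tf.sqrt(priorities + 1e-10)  # importance weights
weights /= tf.reduce_max(weights)            # max-normalize to (0, 1]
weighted_loss = weights * loss               # the quantity that is minimized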
Example #4
  def _build_target_distribution(self):
    self._reshape_networks()
    batch_size = tf.shape(self._replay.rewards)[0]
    # size of rewards: batch_size x 1
    rewards = self._replay.rewards[:, None]
    # size of tiled_support: batch_size x num_atoms
    tiled_support = tf.tile(self.support, [batch_size])
    tiled_support = tf.reshape(tiled_support, [batch_size, self.num_atoms])

    is_terminal_multiplier = 1. - tf.cast(self._replay.terminals, tf.float32)
    # Incorporate terminal state to discount factor.
    # size of gamma_with_terminal: batch_size x 1
    gamma_with_terminal = self.cumulative_gamma * is_terminal_multiplier
    gamma_with_terminal = gamma_with_terminal[:, None]

    # size of target_support: batch_size x num_atoms
    target_support = rewards + gamma_with_terminal * tiled_support
    # size of next_probabilities: batch_size  x num_actions x num_atoms
    next_probabilities = tf.contrib.layers.softmax(
        self._replay_next_logits)

    # size of next_qt: 1 x num_actions
    next_qt = tf.reduce_sum(self.support * next_probabilities, 2)
    # size of next_qt_argmax: 1 x batch_size
    next_qt_argmax = tf.argmax(
        next_qt + self._replay.next_legal_actions, axis=1)[:, None]
    batch_indices = tf.range(tf.to_int64(batch_size))[:, None]
    # size of next_qt_argmax: batch_size x 2
    next_qt_argmax = tf.concat([batch_indices, next_qt_argmax], axis=1)
    # size of next_probabilities: batch_size x num_atoms
    next_probabilities = tf.gather_nd(next_probabilities, next_qt_argmax)
    return project_distribution(target_support, next_probabilities,
                                self.support)
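For intuition, the target support is the distributional Bellman backup r + gamma * z applied atom-wise, with gamma zeroed at terminal states. A toy version with made-up numbers:

import tensorflow as tf

support = tf.linspace(-1.0, 1.0, 5)                 # fixed atom locations z
rewards = tf.constant([[0.5], [1.0]])               # batch_size x 1
gamma_with_terminal = tf.constant([[0.99], [0.0]])  # zero at terminal states
# batch_size x num_atoms; each row is r + gamma * z.
target_support = rewards + gamma_with_terminal * support[tf.newaxis, :]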
Example #5
    def _mine(self, x_in, y_in):
        """Mutual Infomation Neural Estimator.

        Implement mutual information neural estimator from
        Belghazi et al "Mutual Information Neural Estimation"
        http://proceedings.mlr.press/v80/belghazi18a/belghazi18a.pdf
        'DV':  sup_T E_P(T) - log E_Q(exp(T))
        where P is the joint distribution of X and Y, and Q is the product
         marginal distribution of P. DV is a lower bound for
         KLD(P||Q)=MI(X, Y).

        """
        y_in_tran = transpose2(y_in, 1, 0)
        # tf.random.shuffle() has no gradient defined, so use tf.gather()
        y_shuffle_tran = tf.gather(
            y_in_tran, tf.random.shuffle(tf.range(tf.shape(y_in_tran)[0])))
        y_shuffle = transpose2(y_shuffle_tran, 1, 0)

        # propagate the forward pass
        T_xy, _ = self._network([x_in, y_in])
        T_x_y, _ = self._network([x_in, y_shuffle])

        # compute the negative loss (maximize loss == minimize -loss)
        mean_exp_T_x_y = tf.reduce_mean(tf.math.exp(T_x_y), axis=1)
        loss = tf.reduce_mean(T_xy, axis=1) - tf.math.log(mean_exp_T_x_y)
        loss = tf.squeeze(loss, axis=-1)  # Mutual Information

        return loss
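A minimal sketch of the DV bound itself, with the critic outputs as inputs (hypothetical helper; the example above uses a learned network and per-batch reductions instead):

import tensorflow as tf

def dv_bound(t_joint, t_marginal):
  # E_P[T] - log E_Q[exp(T)], a lower bound on MI(X, Y).
  return (tf.reduce_mean(t_joint)
          - tf.math.log(tf.reduce_mean(tf.math.exp(t_marginal))))

# e.g. t_joint = critic(x, y); t_marginal = critic(x, shuffled_y)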
Example #6
  def update_state(self, inputs, outputs):
    """Function that updates the metric state at each example.

    Args:
      inputs: A dictionary containing input tensors.
      outputs: A dictionary containing output tensors.

    Returns:
      Update op.
    """
    detections_score = tf.reshape(
        outputs[standard_fields.DetectionResultFields.objects_score], [-1])
    detections_class = tf.reshape(
        outputs[standard_fields.DetectionResultFields.objects_class], [-1])
    num_detections = tf.shape(detections_score)[0]
    detections_instance_mask = tf.reshape(
        outputs[
            standard_fields.DetectionResultFields.instance_segments_voxel_mask],
        [num_detections, -1])
    gt_class = tf.reshape(inputs[standard_fields.InputDataFields.objects_class],
                          [-1])
    num_gt = tf.shape(gt_class)[0]
    gt_voxel_instance_ids = tf.reshape(
        inputs[standard_fields.InputDataFields.object_instance_id_voxels], [-1])
    gt_instance_masks = tf.transpose(
        tf.one_hot(gt_voxel_instance_ids - 1, depth=num_gt, dtype=tf.float32))
    for c in self.class_range:
      gt_mask_c = tf.equal(gt_class, c)
      num_gt_c = tf.math.reduce_sum(tf.cast(gt_mask_c, dtype=tf.int32))
      gt_instance_masks_c = tf.boolean_mask(gt_instance_masks, gt_mask_c)
      detections_mask_c = tf.equal(detections_class, c)
      num_detections_c = tf.math.reduce_sum(
          tf.cast(detections_mask_c, dtype=tf.int32))
      if num_detections_c == 0:
        continue
      det_scores_c = tf.boolean_mask(detections_score, detections_mask_c)
      det_instance_mask_c = tf.boolean_mask(detections_instance_mask,
                                            detections_mask_c)
      det_scores_c, sorted_indices = tf.math.top_k(
          det_scores_c, k=num_detections_c)
      det_instance_mask_c = tf.gather(det_instance_mask_c, sorted_indices)
      tp_c = tf.zeros([num_detections_c], dtype=tf.int32)
      if num_gt_c > 0:
        ious_c = instance_segmentation_utils.points_mask_iou(
            masks1=gt_instance_masks_c, masks2=det_instance_mask_c)
        max_overlap_gt_ids = tf.cast(
            tf.math.argmax(ious_c, axis=0), dtype=tf.int32)
        is_gt_box_detected = tf.zeros([num_gt_c], dtype=tf.int32)
        for i in tf.range(num_detections_c):
          gt_id = max_overlap_gt_ids[i]
          if (ious_c[gt_id, i] > self.iou_threshold and
              is_gt_box_detected[gt_id] == 0):
            tp_c = tf.maximum(
                tf.one_hot(i, num_detections_c, dtype=tf.int32), tp_c)
            is_gt_box_detected = tf.maximum(
                tf.one_hot(gt_id, num_gt_c, dtype=tf.int32), is_gt_box_detected)
      self.tp[c] = tf.concat([self.tp[c], tp_c], axis=0)
      self.scores[c] = tf.concat([self.scores[c], det_scores_c], axis=0)
      self.num_gt[c] += num_gt_c
    return tf.no_op()
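The inner loop is a greedy matching: detections are visited in decreasing score order, and a detection counts as a true positive only if its best-overlapping ground truth clears the IoU threshold and has not been claimed yet. A toy trace with assumed IoUs (runs eagerly):

import tensorflow as tf

ious = tf.constant([[0.9, 0.6], [0.2, 0.1]])  # [num_gt, num_detections]
iou_threshold = 0.5
num_gt, num_det = 2, 2
best_gt = tf.cast(tf.argmax(ious, axis=0), tf.int32)  # [0, 0]
tp = tf.zeros([num_det], dtype=tf.int32)
claimed = tf.zeros([num_gt], dtype=tf.int32)
for i in tf.range(num_det):
  g = best_gt[i]
  if ious[g, i] > iou_threshold and claimed[g] == 0:
    tp = tf.maximum(tf.one_hot(i, num_det, dtype=tf.int32), tp)
    claimed = tf.maximum(tf.one_hot(g, num_gt, dtype=tf.int32), claimed)
# tp == [1, 0]: the second detection's best ground truth was already claimed.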
Example #7
 def broken():
     # Using tf.random.uniform here avoids TF optimizations around constants in
     # graph mode (which can change the exception type vs. eager mode).
     one_hundred = tf.random.uniform(shape=(),
                                     minval=100,
                                     maxval=101,
                                     dtype=tf.int32)
     return self.evaluate(tf.range(10)[one_hundred])
Example #8
def knn_graph_from_points(points, num_valid_points, k,
                          distance_upper_bound, mask=None):
  """Returns the distances and indices of the neighbors of each point.

  Note that each point will have at least k neighbors unless the number of
  points is less than k. In that case, the python function that is wrapped in
  py_function will raise a value error.

  Args:
    points: A tf.float32 tensor of size [batch_size, N, D] where D is the point
      dimensions.
    num_valid_points: A tf.int32 tensor of size [batch_size] containing the
      number of valid points in each batch example.
    k: Number of neighbors for each point.
    distance_upper_bound: Only build the graph using points that are closer than
      this distance.
    mask: A tf.bool tensor of size [batch_size, N], or None. If None, it is
      ignored. Otherwise, knn is applied only to the points where the mask is
      True; the points where the mask is False will have themselves as their
      neighbors.

  Returns:
    distances: A tf.float32 tensor of size [batch_size, N, k].
    indices: A tf.int32 tensor of size [batch_size, N, k].

  Raises:
    ValueError: If batch_size is unknown.
  """
  if points.get_shape().as_list()[0] is None:
    raise ValueError('Batch size is unknown.')
  batch_size = points.get_shape().as_list()[0]
  num_points = tf.shape(points)[1]

  def fn_knn_graph_from_points_unbatched(i):
    """Computes knn graph for example i in the batch."""
    num_valid_points_i = num_valid_points[i]
    points_i = points[i, :num_valid_points_i, :]
    if mask is None:
      mask_i = None
    else:
      mask_i = mask[i, :num_valid_points_i]
    distances_i, indices_i = knn_graph_from_points_unbatched(
        points=points_i,
        k=k,
        distance_upper_bound=distance_upper_bound,
        mask=mask_i)
    distances_i = tf.pad(
        distances_i, paddings=[[0, num_points - num_valid_points_i], [0, 0]])
    indices_i = tf.pad(
        indices_i, paddings=[[0, num_points - num_valid_points_i], [0, 0]])
    return distances_i, indices_i

  distances, indices = tf.map_fn(
      fn=fn_knn_graph_from_points_unbatched,
      elems=tf.range(batch_size),
      dtype=(tf.float32, tf.int32))

  return distances, indices
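A hedged usage sketch, assuming this function and its unbatched helper are importable from the same module:

import tensorflow as tf

# Two examples, up to 4 points each, in 3-D; the second has 3 valid points.
points = tf.random.uniform([2, 4, 3])
num_valid_points = tf.constant([4, 3], dtype=tf.int32)
distances, indices = knn_graph_from_points(
    points=points, num_valid_points=num_valid_points, k=2,
    distance_upper_bound=10.0)
# distances: [2, 4, 2] tf.float32; indices: [2, 4, 2] tf.int32.
# Rows past num_valid_points are zero padding.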
Example #9
def embedding_regularization_loss(inputs,
                                  outputs,
                                  lambda_coef=0.0001,
                                  regularization_type='unit_length',
                                  is_intermediate=False):
  """Classification loss with an iou threshold.

  Args:
    inputs: A dictionary that contains
      num_valid_voxels - A tf.int32 tensor of size [batch_size].
      instance_ids - A tf.int32 tensor of size [batch_size, n].
    outputs: A dictionary that contains
      embeddings - A tf.float32 tensor of size [batch_size, n, f].
    lambda_coef: Regularization loss coefficient.
    regularization_type: Regularization loss type. Supported values are 'msq'
      and 'unit_length'. 'msq' stands for 'mean square' which penalizes the
      embedding vectors if they have a length far from zero. 'unit_length'
      penalizes the embedding vectors if they have a length far from one.
    is_intermediate: True if applied to intermediate predictions;
      otherwise, False.

  Returns:
    A tf.float32 scalar loss tensor.
  """
  instance_ids_key = standard_fields.InputDataFields.object_instance_id_voxels
  num_voxels_key = standard_fields.InputDataFields.num_valid_voxels
  if is_intermediate:
    embedding_key = (
        standard_fields.DetectionResultFields
        .intermediate_instance_embedding_voxels)
  else:
    embedding_key = (
        standard_fields.DetectionResultFields.instance_embedding_voxels)
  if instance_ids_key not in inputs:
    raise ValueError('instance_ids is missing in inputs.')
  if embedding_key not in outputs:
    raise ValueError('embedding is missing in outputs.')
  if num_voxels_key not in inputs:
    raise ValueError('num_voxels is missing in inputs.')
  batch_size = inputs[num_voxels_key].get_shape().as_list()[0]
  if batch_size is None:
    raise ValueError('batch_size is not defined at graph construction time.')
  num_valid_voxels = inputs[num_voxels_key]
  num_voxels = tf.shape(inputs[instance_ids_key])[1]
  valid_mask = tf.less(
      tf.tile(tf.expand_dims(tf.range(num_voxels), axis=0), [batch_size, 1]),
      tf.expand_dims(num_valid_voxels, axis=1))
  valid_mask = tf.reshape(valid_mask, [-1])
  embedding_dims = outputs[embedding_key].get_shape().as_list()[-1]
  if embedding_dims is None:
    raise ValueError(
        'Embedding dimension is unknown at graph construction time.')
  embedding = tf.reshape(outputs[embedding_key], [-1, embedding_dims])
  embedding = tf.boolean_mask(embedding, valid_mask)
  return metric_learning_losses.regularization_loss(
      embedding=embedding,
      lambda_coef=lambda_coef,
      regularization_type=regularization_type)
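The tf.range/tf.less construction is a recurring pattern in this codebase for turning per-example lengths into a boolean validity mask; a toy illustration:

import tensorflow as tf

num_valid = tf.constant([2, 4])  # valid entries per batch example
n = 5
valid_mask = tf.less(
    tf.tile(tf.expand_dims(tf.range(n), axis=0), [2, 1]),
    tf.expand_dims(num_valid, axis=1))
# [[ True,  True, False, False, False],
#  [ True,  True,  True,  True, False]]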
Example #10
    def permute_dims(representation):
        representation = np.array(representation)

        for j in range(representation.shape[1]):
            permutation_index = tf.random.shuffle(
                tf.range(representation.shape[0]))
            representation[:, j] = representation[permutation_index, j]

        return representation
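Because this snippet mixes NumPy fancy indexing with TensorFlow shuffling, a pure-TF variant that shuffles each column independently could look like the sketch below (an assumed alternative, not the original code; it requires a statically known number of columns):

import tensorflow as tf

def permute_dims_tf(representation):
  # Shuffles each column of a [batch, dims] tensor independently.
  columns = []
  for j in range(representation.shape[1]):
    perm = tf.random.shuffle(tf.range(tf.shape(representation)[0]))
    columns.append(tf.gather(representation[:, j], perm))
  return tf.stack(columns, axis=1)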
Example #11
def per_voxel_point_sample_segment_func(data, segment_ids, num_segments,
                                        num_samples_per_voxel):
    """Samples features from the points within each voxel.

  Args:
    data: A tf.float32 tensor of size [N, F].
    segment_ids: A tf.int32 tensor of size [N].
    num_segments: Number of segments.
    num_samples_per_voxel: Number of features to sample per voxel. If a voxel
      has fewer points than this, the point features are padded with 0.

  Returns:
    A tf.float32 tensor of size [num_segments, num_samples_per_voxel, F].
  """
    num_channels = data.get_shape().as_list()[1]
    if num_channels is None:
        raise ValueError('num_channels is None.')
    n = tf.shape(segment_ids)[0]

    def _body_fn(i, indices_range, indices):
        """Computes the indices of the i-th point feature in each segment."""
        indices_i = tf.math.unsorted_segment_max(data=indices_range,
                                                 segment_ids=segment_ids,
                                                 num_segments=num_segments)
        indices_i_positive_mask = tf.greater(indices_i, 0)
        indices_i_positive = tf.boolean_mask(indices_i,
                                             indices_i_positive_mask)
        boolean_mask = tf.scatter_nd(
            indices=tf.cast(
                tf.expand_dims(indices_i_positive - 1, axis=1), dtype=tf.int64),
            updates=tf.ones_like(indices_i_positive, dtype=tf.int32),
            shape=(n,))
        indices_range *= (1 - boolean_mask)
        indices_i *= tf.cast(indices_i_positive_mask, dtype=tf.int32)
        indices_i = tf.pad(tf.expand_dims(indices_i, axis=1),
                           paddings=[[0, 0],
                                     [i, num_samples_per_voxel - i - 1]])
        indices += indices_i
        i = i + 1
        return i, indices_range, indices

    cond = lambda i, indices_range, indices: i < num_samples_per_voxel

    (_, _, indices) = tf.while_loop(
        cond=cond,
        body=_body_fn,
        loop_vars=(tf.constant(0, dtype=tf.int32), tf.range(n) + 1,
                   tf.zeros([num_segments, num_samples_per_voxel],
                            dtype=tf.int32)))

    data = tf.pad(data, paddings=[[1, 0], [0, 0]])
    voxel_features = tf.gather(data, tf.reshape(indices, [-1]))
    return tf.reshape(voxel_features,
                      [num_segments, num_samples_per_voxel, num_channels])
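The while loop works because tf.math.unsorted_segment_max over the 1-based index range selects the largest not-yet-used point index in each voxel, after which that index is zeroed out. One step on toy data:

import tensorflow as tf

segment_ids = tf.constant([0, 0, 1, 1, 1])
indices_range = tf.range(5) + 1  # 1-based so that 0 can mean "already used"
picked = tf.math.unsorted_segment_max(
    data=indices_range, segment_ids=segment_ids, num_segments=2)
# picked == [2, 5]: the highest remaining point index in each segment.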
Example #12
    def convert_to_simclr_episode(support_images=None,
                                  support_labels=None,
                                  support_class_ids=None,
                                  query_images=None,
                                  query_labels=None,
                                  query_class_ids=None):
        """Convert a single episode into a SimCLR Episode."""

        # If there were k query examples of class c, keep the first k support
        # examples of class c as 'simclr' queries.  We do this by assigning an
        # id to each image in the query set, implemented as label*10000+x+1,
        # where x is the number of images of the same label with a lower index
        # within the query set.  We do the same for the support set, which
        # gives us a mapping between query and support images which is
        # injective (as long as there are enough support-set images of each
        # class).
        #
        # Note: assumes the max support label is 10000 - max_images_per_class.
        query_idx_within_class = tf.cast(
            tf.equal(query_labels[tf.newaxis, :], query_labels[:, tf.newaxis]),
            tf.int32)
        query_idx_within_class = tf.linalg.diag_part(
            tf.cumsum(query_idx_within_class, axis=1))
        query_uid = query_labels * 10000 + query_idx_within_class
        support_idx_within_class = tf.cast(
            tf.equal(support_labels[tf.newaxis, :],
                     support_labels[:, tf.newaxis]), tf.int32)
        support_idx_within_class = tf.linalg.diag_part(
            tf.cumsum(support_idx_within_class, axis=1))
        support_uid = support_labels * 10000 + support_idx_within_class

        # compute which support-set images have matches in the query set, and
        # discard the rest to produce the new query set.
        support_keep = tf.reduce_any(tf.equal(support_uid[:, tf.newaxis],
                                              query_uid[tf.newaxis, :]),
                                     axis=1)
        query_images = tf.boolean_mask(support_images, support_keep)

        support_labels = tf.range(tf.shape(support_labels)[0],
                                  dtype=support_labels.dtype)
        query_labels = tf.boolean_mask(support_labels, support_keep)
        query_class_ids = tf.boolean_mask(support_class_ids, support_keep)

        # Finally, apply SimCLR augmentation to all images.
        # Note simclr only blurs one image.
        query_images = simclr_augment(query_images, blur=True)
        support_images = simclr_augment(support_images)

        return (support_images, support_labels, support_class_ids,
                query_images, query_labels, query_class_ids)
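The per-class index trick is reusable: for a label vector, the diagonal of cumsum(equality matrix, axis=1) counts, for each position, how many elements up to and including it share its label. A small demonstration with assumed labels:

import tensorflow as tf

labels = tf.constant([3, 1, 3, 3, 1])
same = tf.cast(tf.equal(labels[tf.newaxis, :], labels[:, tf.newaxis]), tf.int32)
idx_within_class = tf.linalg.diag_part(tf.cumsum(same, axis=1))
# idx_within_class == [1, 1, 2, 3, 2]
uid = labels * 10000 + idx_within_class  # unique per (label, occurrence)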
Example #13
def identity_knn_graph_unbatched(points, k):
    """Returns each points as its own neighbor k times.

  Args:
    points: A tf.float32 tensor of [N, D] where D is the point dimensions.
    k: Number of neighbors for each point.

  Returns:
    distances: A tf.float32 tensor of [N, k]. Distances is all zeros since
      each point is returned as its own neighbor.
    indices: A tf.int32 tensor of [N, k]. Each row will contain values that
      are identical to the index of that row.
  """
    num_points = tf.shape(points)[0]
    indices = tf.expand_dims(tf.range(num_points), axis=1)
    indices = tf.tile(indices, [1, k])
    distances = tf.zeros([num_points, k], dtype=tf.float32)
    return distances, indices
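For instance, a 4-point cloud with k=2 yields all-zero distances and each row repeating its own index:

import tensorflow as tf

points = tf.random.uniform([4, 3])
distances, indices = identity_knn_graph_unbatched(points, k=2)
# distances: zeros of shape [4, 2]
# indices: [[0, 0], [1, 1], [2, 2], [3, 3]]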
Example #14
def points_offset_in_voxels(points, grid_cell_size):
  """Converts points into offsets in voxel grid.

  Args:
    points: A tf.float32 tensor of size [batch_size, N, 3].
    grid_cell_size: The size of the grid cells in x, y, z dimensions in the
      voxel grid. It should be either a tf.float32 tensor, a numpy array or a
      list of size [3].

  Returns:
    voxel_xyz_offsets: A tf.float32 tensor of size [batch_size, N, 3].
  """
  batch_size = points.get_shape().as_list()[0]

  def fn(i):
    return _points_offset_in_voxels_unbatched(
        points=points[i, :, :], grid_cell_size=grid_cell_size)

  return tf.map_fn(fn=fn, elems=tf.range(batch_size), dtype=tf.float32)
Example #15
def sparse_voxel_grid_to_pointcloud(voxel_features, segment_ids,
                                    num_valid_voxels, num_valid_points):
    """Convert voxel features back to points given their segment ids.

  Args:
    voxel_features: A tf.float32 tensor of size [batch_size, N', F].
    segment_ids: A size [batch_size, N] tf.int32 tensor of IDs for each point
      indicating which (flattened) voxel cell its data was mapped to.
    num_valid_voxels: A tf.int32 tensor of size [batch_size] containing the
      number of valid voxels in each batch example.
    num_valid_points: A tf.int32 tensor of size [batch_size] containing the
      number of valid points in each batch example.

  Returns:
    point_features: A tf.float32 tensor of size [batch_size, N, F].

  Raises:
    ValueError: If batch_size is unknown at graph construction time.
  """
    batch_size = voxel_features.shape[0]
    if batch_size is None:
        raise ValueError('batch_size is unknown at graph construction time.')
    num_points = tf.shape(segment_ids)[1]

    def fn(i):
        num_valid_voxels_i = num_valid_voxels[i]
        num_valid_points_i = num_valid_points[i]
        voxel_features_i = voxel_features[i, :num_valid_voxels_i, :]
        segment_ids_i = segment_ids[i, :num_valid_points_i]
        point_features = tf.gather(voxel_features_i, segment_ids_i)
        point_features_rank = len(point_features.get_shape().as_list())
        point_features_paddings = [[0, num_points - num_valid_points_i]]
        for _ in range(point_features_rank - 1):
            point_features_paddings.append([0, 0])
        point_features = tf.pad(point_features,
                                paddings=point_features_paddings)
        return point_features

    return tf.map_fn(fn=fn, elems=tf.range(batch_size), dtype=tf.float32)
Example #16
def identity_knn_graph(points, num_valid_points, k):  # pylint: disable=unused-argument
    """Returns each points as its own neighbor k times.

  Args:
    points: A tf.float32 tensor of size [num_batches, N, D] where D is the point
      dimensions.
    num_valid_points: A tf.int32 tensor of size [num_batches] containing the
      number of valid points in each batch example.
    k: Number of neighbors for each point.

  Returns:
    distances: A tf.float32 tensor of [num_batches, N, k]. Distances is all
      zeros since each point is returned as its own neighbor.
    indices: A tf.int32 tensor of [num_batches, N, k]. Each row will contain
      values that are identical to the index of that row.
  """
    num_batches = points.get_shape()[0]
    num_points = tf.shape(points)[1]
    indices = tf.expand_dims(tf.range(num_points), axis=1)
    indices = tf.tile(tf.expand_dims(indices, axis=0), [num_batches, 1, k])
    distances = tf.zeros([num_batches, num_points, k], dtype=tf.float32)
    return distances, indices
Example #17
def compute_episode_stats(episode):
    """Computes various episode stats: way, shots, and class IDs.

  Args:
    episode: An EpisodeDataset.

  Returns:
    way: An int constant tensor. The number of classes in the episode.
    shots: An int 1D tensor: The number of support examples per class.
    class_ids: An int 1D tensor: (absolute) class IDs.
  """
    # The train labels of the next episode.
    train_labels = episode.train_labels
    # Compute way.
    episode_classes, _ = tf.unique(train_labels)
    way = tf.size(episode_classes)
    # Compute shots.
    class_ids = tf.reshape(tf.range(way), [way, 1])
    class_labels = tf.reshape(train_labels, [1, -1])
    is_equal = tf.equal(class_labels, class_ids)
    shots = tf.reduce_sum(tf.cast(is_equal, tf.int32), axis=1)
    # Compute class_ids.
    class_ids, _ = tf.unique(episode.train_class_ids)
    return way, shots, class_ids
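With train labels [0, 0, 1, 2, 2, 2] this yields way == 3 and shots == [2, 1, 3]; a toy rerun of the same computation (note it assumes labels are relabeled to 0..way-1, as in the function above):

import tensorflow as tf

train_labels = tf.constant([0, 0, 1, 2, 2, 2])
way = tf.size(tf.unique(train_labels)[0])  # 3
class_ids = tf.reshape(tf.range(way), [way, 1])
is_equal = tf.equal(tf.reshape(train_labels, [1, -1]), class_ids)
shots = tf.reduce_sum(tf.cast(is_equal, tf.int32), axis=1)  # [2, 1, 3]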
Example #18
def pointcloud_to_sparse_voxel_grid(points, features, num_valid_points,
                                    grid_cell_size, voxels_pad_or_clip_size,
                                    segment_func):
  """Converts a pointcloud into a voxel grid.

  This function calls the `pointcloud_to_sparse_voxel_grid_unbatched`
  function above in a while loop to map a batch of points to a batch of voxels.

  Args:
    points: A tf.float32 tensor of size [batch_size, N, 3].
    features: A tf.float32 tensor of size [batch_size, N, F].
    num_valid_points: A tf.int32 tensor of size [num_batches] containing the
      number of valid points in each batch example.
    grid_cell_size: A tf.float32 tensor of size [3].
    voxels_pad_or_clip_size: Number of target voxels to pad or clip to. If None,
      it will not perform the padding.
    segment_func: A tensorflow function that operates on segments. Examples are
      one of tf.math.unsorted_segment_{min/max/mean/prod/sum}.

  Returns:
    voxel_features: A tf.float32 tensor of size [batch_size, N', F]
      or [batch_size, N', G, F] where G is the number of points sampled per
      voxel.
    voxel_indices: A tf.int32 tensor of size [batch_size, N', 3].
    num_valid_voxels: A tf.int32 tensor of size [batch_size].
    segment_ids: A size [batch_size, N] tf.int32 tensor of IDs for each point
      indicating which (flattened) voxel cell its data was mapped to.
    voxel_start_location: A size [batch_size, 3] tf.float32 tensor of voxel
      start locations.

  Raises:
    ValueError: If pooling method is unknown.
  """
  batch_size = points.get_shape().as_list()[0]
  if batch_size is None:
    batch_size = tf.shape(points)[0]
  num_points = tf.shape(points)[1]

  def fn(i):
    """Map function."""
    num_valid_points_i = num_valid_points[i]
    points_i = points[i, :num_valid_points_i, :]
    features_i = features[i, :num_valid_points_i, :]
    voxel_features_i, voxel_indices_i, segment_ids_i, voxel_start_location_i = (
        pointcloud_to_sparse_voxel_grid_unbatched(
            points=points_i,
            features=features_i,
            grid_cell_size=grid_cell_size,
            segment_func=segment_func))
    num_valid_voxels_i = tf.shape(voxel_features_i)[0]
    (voxel_features_i, voxel_indices_i, num_valid_voxels_i,
     segment_ids_i) = _pad_or_clip_voxels(
         voxel_features=voxel_features_i,
         voxel_indices=voxel_indices_i,
         num_valid_voxels=num_valid_voxels_i,
         segment_ids=segment_ids_i,
         voxels_pad_or_clip_size=voxels_pad_or_clip_size)
    segment_ids_i = tf.pad(
        segment_ids_i, paddings=[[0, num_points - num_valid_points_i]])
    return (voxel_features_i, voxel_indices_i, num_valid_voxels_i,
            segment_ids_i, voxel_start_location_i)

  return tf.map_fn(
      fn=fn,
      elems=tf.range(batch_size),
      dtype=(tf.float32, tf.int32, tf.int32, tf.int32, tf.float32))
Example #19
def prepare_scannet_frame_dataset(inputs,
                                  min_pixel_depth=0.3,
                                  max_pixel_depth=6.0,
                                  valid_object_classes=None):
  """Maps the fields from loaded input to standard fields.

  Args:
    inputs: A dictionary of input tensors.
    min_pixel_depth: Pixels with depth values less than this are pruned.
    max_pixel_depth: Pixels with depth values more than this are pruned.
    valid_object_classes: List of valid object classes. If None, it is ignored.

  Returns:
    A dictionary of input tensors with standard field names.
  """
  prepared_inputs = {}
  if 'cameras/rgbd_camera/intrinsics/K' not in inputs:
    raise ValueError('Intrinsic matrix is missing.')
  if 'cameras/rgbd_camera/extrinsics/R' not in inputs:
    raise ValueError('Extrinsic rotation matrix is missing.')
  if 'cameras/rgbd_camera/extrinsics/t' not in inputs:
    raise ValueError('Extrinsics translation is missing.')
  if 'cameras/rgbd_camera/depth_image' not in inputs:
    raise ValueError('Depth image is missing.')
  if 'cameras/rgbd_camera/color_image' not in inputs:
    raise ValueError('Color image is missing.')
  if 'frame_name' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_image_name] = inputs['frame_name']
  camera_intrinsics = inputs['cameras/rgbd_camera/intrinsics/K']
  depth_image = inputs['cameras/rgbd_camera/depth_image']
  image_height = tf.shape(depth_image)[0]
  image_width = tf.shape(depth_image)[1]
  x, y = tf.meshgrid(
      tf.range(image_width), tf.range(image_height), indexing='xy')
  x = tf.reshape(tf.cast(x, dtype=tf.float32) + 0.5, [-1, 1])
  y = tf.reshape(tf.cast(y, dtype=tf.float32) + 0.5, [-1, 1])
  point_positions = projections.image_frame_to_camera_frame(
      image_frame=tf.concat([x, y], axis=1),
      camera_intrinsics=camera_intrinsics)
  rotate_world_to_camera = inputs['cameras/rgbd_camera/extrinsics/R']
  translate_world_to_camera = inputs['cameras/rgbd_camera/extrinsics/t']
  point_positions = projections.to_world_frame(
      camera_frame_points=point_positions,
      rotate_world_to_camera=rotate_world_to_camera,
      translate_world_to_camera=translate_world_to_camera)
  prepared_inputs[standard_fields.InputDataFields
                  .point_positions] = point_positions * tf.reshape(
                      depth_image, [-1, 1])
  depth_values = tf.reshape(depth_image, [-1])
  valid_depth_mask = tf.logical_and(
      tf.greater_equal(depth_values, min_pixel_depth),
      tf.less_equal(depth_values, max_pixel_depth))
  prepared_inputs[standard_fields.InputDataFields.point_colors] = tf.reshape(
      tf.cast(inputs['cameras/rgbd_camera/color_image'], dtype=tf.float32),
      [-1, 3])
  prepared_inputs[standard_fields.InputDataFields.point_colors] *= (2.0 / 255.0)
  prepared_inputs[standard_fields.InputDataFields.point_colors] -= 1.0
  prepared_inputs[
      standard_fields.InputDataFields.point_positions] = tf.boolean_mask(
          prepared_inputs[standard_fields.InputDataFields.point_positions],
          valid_depth_mask)
  prepared_inputs[
      standard_fields.InputDataFields.point_colors] = tf.boolean_mask(
          prepared_inputs[standard_fields.InputDataFields.point_colors],
          valid_depth_mask)
  if 'cameras/rgbd_camera/semantic_image' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.object_class_points] = tf.cast(
            tf.reshape(inputs['cameras/rgbd_camera/semantic_image'], [-1, 1]),
            dtype=tf.int32)
    prepared_inputs[
        standard_fields.InputDataFields.object_class_points] = tf.boolean_mask(
            prepared_inputs[
                standard_fields.InputDataFields.object_class_points],
            valid_depth_mask)
  if 'cameras/rgbd_camera/instance_image' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.object_instance_id_points] = tf.cast(
            tf.reshape(inputs['cameras/rgbd_camera/instance_image'], [-1]),
            dtype=tf.int32)
    prepared_inputs[standard_fields.InputDataFields
                    .object_instance_id_points] = tf.boolean_mask(
                        prepared_inputs[standard_fields.InputDataFields
                                        .object_instance_id_points],
                        valid_depth_mask)

  if valid_object_classes is not None:
    valid_objects_mask = tf.cast(
        tf.zeros_like(
            prepared_inputs[
                standard_fields.InputDataFields.object_class_points],
            dtype=tf.int32),
        dtype=tf.bool)
    for object_class in valid_object_classes:
      valid_objects_mask = tf.logical_or(
          valid_objects_mask,
          tf.equal(
              prepared_inputs[
                  standard_fields.InputDataFields.object_class_points],
              object_class))
    valid_objects_mask = tf.cast(
        valid_objects_mask,
        dtype=prepared_inputs[
            standard_fields.InputDataFields.object_class_points].dtype)
    prepared_inputs[standard_fields.InputDataFields
                    .object_class_points] *= valid_objects_mask
  return prepared_inputs
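The meshgrid/intrinsics pattern above is standard pinhole unprojection: pixel centers (x + 0.5, y + 0.5) are lifted to camera rays via the inverse intrinsic matrix and then scaled by depth. A compact self-contained sketch under that assumption (hypothetical helper, not this module's API):

import tensorflow as tf

def unproject_depth(depth_image, intrinsics_k):
  # Maps a [H, W] float32 depth image to camera-frame points of shape [H*W, 3].
  height, width = tf.shape(depth_image)[0], tf.shape(depth_image)[1]
  x, y = tf.meshgrid(tf.range(width), tf.range(height), indexing='xy')
  x = tf.reshape(tf.cast(x, tf.float32) + 0.5, [-1])
  y = tf.reshape(tf.cast(y, tf.float32) + 0.5, [-1])
  pixels = tf.stack([x, y, tf.ones_like(x)], axis=0)  # [3, H*W] homogeneous
  rays = tf.linalg.inv(intrinsics_k) @ pixels         # [3, H*W] camera rays
  return tf.transpose(rays) * tf.reshape(depth_image, [-1, 1])

# e.g. k = tf.constant([[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]])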
Example #20
 def broken(sess):
     index = tf.placeholder(tf.int32, name='index')
     slice_op = tf.range(10)[index]
     sess.run(slice_op, feed_dict={index: 11})
Example #21
  def on_predict_batch_end(self, batch, logs=None):
    """Write mesh summaries of semantics groundtruth and prediction point clouds at the end of each validation batch."""
    inputs = logs['inputs']
    outputs = logs['outputs']
    if self._metric:
      for metric in self._metric:
        metric.update_state(inputs=inputs, outputs=outputs)

    if batch <= self.num_qualitative_examples:
      # point cloud visualization
      vertices = tf.reshape(
          inputs[standard_fields.InputDataFields.point_positions], [-1, 3])
      num_valid_points = tf.squeeze(
          inputs[standard_fields.InputDataFields.num_valid_points])
      logits = outputs[
          standard_fields.DetectionResultFields.object_semantic_points]
      num_classes = logits.get_shape().as_list()[-1]
      logits = tf.reshape(logits, [-1, num_classes])
      gt_semantic_class = tf.reshape(
          inputs[standard_fields.InputDataFields.object_class_points], [-1])

      vertices = vertices[:num_valid_points, :]
      logits = logits[:num_valid_points, :]
      gt_semantic_class = gt_semantic_class[:num_valid_points]
      max_num_points = tf.math.minimum(self.max_num_points_qualitative,
                                       num_valid_points)
      sample_indices = tf.random.shuffle(
          tf.range(num_valid_points))[:max_num_points]
      vertices = tf.gather(vertices, sample_indices)
      logits = tf.gather(logits, sample_indices)
      gt_semantic_class = tf.gather(gt_semantic_class, sample_indices)
      semantic_class = tf.math.argmax(logits, axis=1)
      pred_colors = tf.gather(self._pascal_color_map, semantic_class, axis=0)
      gt_colors = tf.gather(self._pascal_color_map, gt_semantic_class, axis=0)

      if standard_fields.InputDataFields.point_colors in inputs:
        point_colors = (tf.reshape(
            inputs[standard_fields.InputDataFields.point_colors], [-1, 3]) +
                        1.0) * 255.0 / 2.0
        point_colors = point_colors[:num_valid_points, :]
        point_colors = tf.gather(point_colors, sample_indices)
        point_colors = tf.math.minimum(point_colors, 255.0)
        point_colors = tf.math.maximum(point_colors, 0.0)
        point_colors = tf.cast(point_colors, dtype=tf.uint8)
      else:
        point_colors = tf.ones_like(vertices, dtype=tf.uint8) * 128

      # add points and colors for predicted objects
      if standard_fields.DetectionResultFields.objects_length in outputs:
        box_corners = box_utils.get_box_corners_3d(
            boxes_length=outputs[
                standard_fields.DetectionResultFields.objects_length],
            boxes_height=outputs[
                standard_fields.DetectionResultFields.objects_height],
            boxes_width=outputs[
                standard_fields.DetectionResultFields.objects_width],
            boxes_rotation_matrix=outputs[
                standard_fields.DetectionResultFields.objects_rotation_matrix],
            boxes_center=outputs[
                standard_fields.DetectionResultFields.objects_center])
        box_points = box_utils.get_box_as_dotted_lines(box_corners)

        objects_class = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_class], [-1])
        box_colors = tf.gather(self._pascal_color_map, objects_class, axis=0)
        box_colors = tf.repeat(
            box_colors[:, tf.newaxis, :], box_points.shape[1], axis=1)
        box_points = tf.reshape(box_points, [-1, 3])
        box_colors = tf.reshape(box_colors, [-1, 3])
        pred_vertices = tf.concat([vertices, box_points], axis=0)
        pred_colors = tf.concat([pred_colors, box_colors], axis=0)
      else:
        pred_vertices = vertices

      # add points and colors for gt objects
      if standard_fields.InputDataFields.objects_length in inputs:
        box_corners = box_utils.get_box_corners_3d(
            boxes_length=tf.reshape(
                inputs[standard_fields.InputDataFields.objects_length],
                [-1, 1]),
            boxes_height=tf.reshape(
                inputs[standard_fields.InputDataFields.objects_height],
                [-1, 1]),
            boxes_width=tf.reshape(
                inputs[standard_fields.InputDataFields.objects_width], [-1, 1]),
            boxes_rotation_matrix=tf.reshape(
                inputs[standard_fields.InputDataFields.objects_rotation_matrix],
                [-1, 3, 3]),
            boxes_center=tf.reshape(
                inputs[standard_fields.InputDataFields.objects_center],
                [-1, 3]))
        box_points = box_utils.get_box_as_dotted_lines(box_corners)

        objects_class = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_class], [-1])
        box_colors = tf.gather(self._pascal_color_map, objects_class, axis=0)
        box_colors = tf.repeat(
            box_colors[:, tf.newaxis, :], box_points.shape[1], axis=1)

        box_points = tf.reshape(box_points, [-1, 3])
        box_colors = tf.reshape(box_colors, [-1, 3])
        gt_vertices = tf.concat([vertices, box_points], axis=0)
        gt_colors = tf.concat([gt_colors, box_colors], axis=0)
      else:
        gt_vertices = vertices
      if batch == 1:
        logging.info('writing point cloud (shape %s) to summary.',
                     gt_vertices.shape)
      if standard_fields.InputDataFields.camera_image_name in inputs:
        camera_image_name = str(inputs[
            standard_fields.InputDataFields.camera_image_name].numpy()[0])
      else:
        camera_image_name = str(batch)
      logging.info(camera_image_name)
      with self._val_mesh_writer.as_default():
        mesh_summary.mesh(
            name=(self.split + '_points/' + camera_image_name),
            vertices=tf.expand_dims(vertices, axis=0),
            faces=None,
            colors=tf.expand_dims(point_colors, axis=0),
            config_dict=self._mesh_config_dict,
            step=self._val_step,
        )
        mesh_summary.mesh(
            name=(self.split + '_predictions/' + camera_image_name),
            vertices=tf.expand_dims(pred_vertices, axis=0),
            faces=None,
            colors=tf.expand_dims(pred_colors, axis=0),
            config_dict=self._mesh_config_dict,
            step=self._val_step,
        )
        mesh_summary.mesh(
            name=(self.split + '_ground_truth/' + camera_image_name),
            vertices=tf.expand_dims(gt_vertices, axis=0),
            faces=None,
            colors=tf.expand_dims(gt_colors, axis=0),
            config_dict=self._mesh_config_dict,
            step=self._val_step,
        )
      if batch == self.num_qualitative_examples:
        self._val_mesh_writer.flush()
Example #22
def classification_loss_using_mask_iou(inputs,
                                       outputs,
                                       num_samples,
                                       max_instance_id=None,
                                       similarity_strategy='distance',
                                       is_balanced=True,
                                       is_intermediate=False):
    """Classification loss with an iou threshold.

  Args:
    inputs: A dictionary that contains
      num_valid_voxels - A tf.int32 tensor of size [batch_size].
      instance_ids - A tf.int32 tensor of size [batch_size, n].
      class_labels - A tf.int32 tensor of size [batch_size, n]. It is assumed
        that the background voxels are assigned to class 0.
    outputs: A dictionary that contains
      embeddings - A tf.float32 tensor of size [batch_size, n, f].
      logits - A tf.float32 tensor of size [batch_size, n, num_classes]. It is
        assumed that background is class 0.
    num_samples: An int determining the number of samples.
    max_instance_id: If set, instance ids larger than that value will be
      ignored. If not set, it will be computed from instance_ids tensor.
    similarity_strategy: Defines the method for computing similarity between
                         embedding vectors. Possible values are 'dotproduct' and
                         'distance'.
    is_balanced: If True, the per-voxel losses are re-weighted to have equal
      total weight for foreground vs. background voxels.
    is_intermediate: True if applied to intermediate predictions;
      otherwise, False.

  Returns:
    A tf.float32 scalar loss tensor.
  """
    instance_ids_key = standard_fields.InputDataFields.object_instance_id_voxels
    class_labels_key = standard_fields.InputDataFields.object_class_voxels
    num_voxels_key = standard_fields.InputDataFields.num_valid_voxels
    if is_intermediate:
        embedding_key = (standard_fields.DetectionResultFields.
                         intermediate_instance_embedding_voxels)
        logits_key = (standard_fields.DetectionResultFields.
                      intermediate_object_semantic_voxels)
    else:
        embedding_key = (
            standard_fields.DetectionResultFields.instance_embedding_voxels)
        logits_key = standard_fields.DetectionResultFields.object_semantic_voxels
    if instance_ids_key not in inputs:
        raise ValueError('instance_ids is missing in inputs.')
    if class_labels_key not in inputs:
        raise ValueError('class_labels is missing in inputs.')
    if num_voxels_key not in inputs:
        raise ValueError('num_voxels is missing in inputs.')
    if embedding_key not in outputs:
        raise ValueError('embedding is missing in outputs.')
    if logits_key not in outputs:
        raise ValueError('logits is missing in outputs.')
    batch_size = inputs[num_voxels_key].get_shape().as_list()[0]
    if batch_size is None:
        raise ValueError(
            'batch_size is not defined at graph construction time.')
    num_valid_voxels = inputs[num_voxels_key]
    num_voxels = tf.shape(inputs[instance_ids_key])[1]
    valid_mask = tf.less(
        tf.tile(tf.expand_dims(tf.range(num_voxels), axis=0), [batch_size, 1]),
        tf.expand_dims(num_valid_voxels, axis=1))
    return classification_loss_using_mask_iou_func(
        embeddings=outputs[embedding_key],
        logits=outputs[logits_key],
        instance_ids=tf.reshape(inputs[instance_ids_key], [batch_size, -1]),
        class_labels=inputs[class_labels_key],
        num_samples=num_samples,
        valid_mask=valid_mask,
        max_instance_id=max_instance_id,
        similarity_strategy=similarity_strategy,
        is_balanced=is_balanced)
Example #23
    def update_state(self, inputs, outputs):
        """Function that updates the metric state at each example.

    Args:
      inputs: A dictionary containing input tensors.
      outputs: A dictionary containing output tensors.

    Returns:
      Update op.
    """
        detections_score = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_score], [-1])
        detections_class = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_class], [-1])
        detections_length = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_length],
            [-1])
        detections_height = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_height],
            [-1])
        detections_width = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_width], [-1])
        detections_center = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_center],
            [-1, 3])
        detections_rotation_matrix = tf.reshape(
            outputs[
                standard_fields.DetectionResultFields.objects_rotation_matrix],
            [-1, 3, 3])
        gt_class = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_class], [-1])
        gt_length = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_length], [-1])
        gt_height = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_height], [-1])
        gt_width = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_width], [-1])
        gt_center = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_center], [-1, 3])
        gt_rotation_matrix = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_rotation_matrix],
            [-1, 3, 3])
        for c in self.class_range:
            gt_mask_c = tf.equal(gt_class, c)
            num_gt_c = tf.math.reduce_sum(tf.cast(gt_mask_c, dtype=tf.int32))
            gt_length_c = tf.boolean_mask(gt_length, gt_mask_c)
            gt_height_c = tf.boolean_mask(gt_height, gt_mask_c)
            gt_width_c = tf.boolean_mask(gt_width, gt_mask_c)
            gt_center_c = tf.boolean_mask(gt_center, gt_mask_c)
            gt_rotation_matrix_c = tf.boolean_mask(gt_rotation_matrix,
                                                   gt_mask_c)
            detections_mask_c = tf.equal(detections_class, c)
            num_detections_c = tf.math.reduce_sum(
                tf.cast(detections_mask_c, dtype=tf.int32))
            if num_detections_c == 0:
                continue
            det_length_c = tf.boolean_mask(detections_length,
                                           detections_mask_c)
            det_height_c = tf.boolean_mask(detections_height,
                                           detections_mask_c)
            det_width_c = tf.boolean_mask(detections_width, detections_mask_c)
            det_center_c = tf.boolean_mask(detections_center,
                                           detections_mask_c)
            det_rotation_matrix_c = tf.boolean_mask(detections_rotation_matrix,
                                                    detections_mask_c)
            det_scores_c = tf.boolean_mask(detections_score, detections_mask_c)
            det_scores_c, sorted_indices = tf.math.top_k(det_scores_c,
                                                         k=num_detections_c)
            det_length_c = tf.gather(det_length_c, sorted_indices)
            det_height_c = tf.gather(det_height_c, sorted_indices)
            det_width_c = tf.gather(det_width_c, sorted_indices)
            det_center_c = tf.gather(det_center_c, sorted_indices)
            det_rotation_matrix_c = tf.gather(det_rotation_matrix_c,
                                              sorted_indices)
            tp_c = tf.zeros([num_detections_c], dtype=tf.int32)
            if num_gt_c > 0:
                ious_c = box_ops.iou3d(
                    boxes1_length=gt_length_c,
                    boxes1_height=gt_height_c,
                    boxes1_width=gt_width_c,
                    boxes1_center=gt_center_c,
                    boxes1_rotation_matrix=gt_rotation_matrix_c,
                    boxes2_length=det_length_c,
                    boxes2_height=det_height_c,
                    boxes2_width=det_width_c,
                    boxes2_center=det_center_c,
                    boxes2_rotation_matrix=det_rotation_matrix_c)
                max_overlap_gt_ids = tf.cast(tf.math.argmax(ious_c, axis=0),
                                             dtype=tf.int32)
                is_gt_box_detected = tf.zeros([num_gt_c], dtype=tf.int32)
                for i in tf.range(num_detections_c):
                    gt_id = max_overlap_gt_ids[i]
                    if (ious_c[gt_id, i] > self.iou_threshold
                            and is_gt_box_detected[gt_id] == 0):
                        tp_c = tf.maximum(
                            tf.one_hot(i, num_detections_c, dtype=tf.int32),
                            tp_c)
                        is_gt_box_detected = tf.maximum(
                            tf.one_hot(gt_id, num_gt_c, dtype=tf.int32),
                            is_gt_box_detected)
            self.tp[c] = tf.concat([self.tp[c], tp_c], axis=0)
            self.scores[c] = tf.concat([self.scores[c], det_scores_c], axis=0)
            self.num_gt[c] += num_gt_c
        return tf.no_op()
Example #24
def npair_loss(inputs,
               outputs,
               num_samples,
               max_instance_id=None,
               similarity_strategy='distance',
               loss_strategy='softmax',
               is_intermediate=False):
  """N-pair metric learning loss for learning feature embeddings.

  Args:
    inputs: A dictionary that contains
      instance_ids - A tf.int32 tensor of size [batch_size, n].
      valid_mask - A tf.bool tensor of size [batch_size, n] that is True when an
        element is valid and False if it needs to be ignored. By default the
        value is None which means it is not applied.
    outputs: A dictionary that contains
      embeddings - A tf.float32 tensor of size [batch_size, n, f].
    num_samples: An int determining the number of samples.
    max_instance_id: If set, instance ids larger than that value will be
      ignored. If not set, it will be computed from instance_ids tensor.
    similarity_strategy: Defines the method for computing similarity between
                         embedding vectors. Possible values are 'dotproduct' and
                         'distance'.
    loss_strategy: Defines the type of loss including 'softmax' or 'sigmoid'.
    is_intermediate: True if applied to intermediate predictions;
      otherwise, False.

  Returns:
    A tf.float32 scalar loss tensor.
  """
  instance_ids_key = standard_fields.InputDataFields.object_instance_id_voxels
  num_voxels_key = standard_fields.InputDataFields.num_valid_voxels
  if is_intermediate:
    embedding_key = (
        standard_fields.DetectionResultFields
        .intermediate_instance_embedding_voxels)
  else:
    embedding_key = (
        standard_fields.DetectionResultFields.instance_embedding_voxels)
  if instance_ids_key not in inputs:
    raise ValueError('object_instance_id_voxels is missing in inputs.')
  if num_voxels_key not in inputs:
    raise ValueError('num_voxels is missing in inputs.')
  if embedding_key not in outputs:
    raise ValueError('embedding key is missing in outputs.')
  batch_size = inputs[num_voxels_key].get_shape().as_list()[0]
  if batch_size is None:
    raise ValueError('batch_size is not defined at graph construction time.')
  num_valid_voxels = inputs[num_voxels_key]
  num_voxels = tf.shape(inputs[instance_ids_key])[1]
  valid_mask = tf.less(
      tf.tile(tf.expand_dims(tf.range(num_voxels), axis=0), [batch_size, 1]),
      tf.expand_dims(num_valid_voxels, axis=1))
  return npair_loss_func(
      embeddings=outputs[embedding_key],
      instance_ids=tf.reshape(inputs[instance_ids_key], [batch_size, -1]),
      num_samples=num_samples,
      valid_mask=valid_mask,
      max_instance_id=max_instance_id,
      similarity_strategy=similarity_strategy,
      loss_strategy=loss_strategy)