Example #1
  def update_state(self, inputs, outputs):
    """Function that updates the metric state at each example.

    Args:
      inputs: A dictionary containing input tensors.
      outputs: A dictionary containing output tensors.

    Returns:
      Update op.
    """
    detections_score = tf.reshape(
        outputs[standard_fields.DetectionResultFields.objects_score], [-1])
    detections_class = tf.reshape(
        outputs[standard_fields.DetectionResultFields.objects_class], [-1])
    num_detections = tf.shape(detections_score)[0]
    detections_instance_mask = tf.reshape(
        outputs[
            standard_fields.DetectionResultFields.instance_segments_voxel_mask],
        [num_detections, -1])
    gt_class = tf.reshape(inputs[standard_fields.InputDataFields.objects_class],
                          [-1])
    num_gt = tf.shape(gt_class)[0]
    gt_voxel_instance_ids = tf.reshape(
        inputs[standard_fields.InputDataFields.object_instance_id_voxels], [-1])
    gt_instance_masks = tf.transpose(
        tf.one_hot(gt_voxel_instance_ids - 1, depth=num_gt, dtype=tf.float32))
    for c in self.class_range:
      gt_mask_c = tf.equal(gt_class, c)
      num_gt_c = tf.math.reduce_sum(tf.cast(gt_mask_c, dtype=tf.int32))
      gt_instance_masks_c = tf.boolean_mask(gt_instance_masks, gt_mask_c)
      detections_mask_c = tf.equal(detections_class, c)
      num_detections_c = tf.math.reduce_sum(
          tf.cast(detections_mask_c, dtype=tf.int32))
      if num_detections_c == 0:
        continue
      det_scores_c = tf.boolean_mask(detections_score, detections_mask_c)
      det_instance_mask_c = tf.boolean_mask(detections_instance_mask,
                                            detections_mask_c)
      det_scores_c, sorted_indices = tf.math.top_k(
          det_scores_c, k=num_detections_c)
      det_instance_mask_c = tf.gather(det_instance_mask_c, sorted_indices)
      tp_c = tf.zeros([num_detections_c], dtype=tf.int32)
      if num_gt_c > 0:
        ious_c = instance_segmentation_utils.points_mask_iou(
            masks1=gt_instance_masks_c, masks2=det_instance_mask_c)
        max_overlap_gt_ids = tf.cast(
            tf.math.argmax(ious_c, axis=0), dtype=tf.int32)
        is_gt_box_detected = tf.zeros([num_gt_c], dtype=tf.int32)
        for i in tf.range(num_detections_c):
          gt_id = max_overlap_gt_ids[i]
          if (ious_c[gt_id, i] > self.iou_threshold and
              is_gt_box_detected[gt_id] == 0):
            tp_c = tf.maximum(
                tf.one_hot(i, num_detections_c, dtype=tf.int32), tp_c)
            is_gt_box_detected = tf.maximum(
                tf.one_hot(gt_id, num_gt_c, dtype=tf.int32), is_gt_box_detected)
      self.tp[c] = tf.concat([self.tp[c], tp_c], axis=0)
      self.scores[c] = tf.concat([self.scores[c], det_scores_c], axis=0)
      self.num_gt[c] += num_gt_c
    return tf.no_op()
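The one-hot call above does the heavy lifting: 1-based voxel instance ids (0 = background) become per-instance binary masks, because tf.one_hot maps the out-of-range index -1 to an all-zero row. A minimal standalone sketch with made-up shapes:

import tensorflow as tf

gt_voxel_instance_ids = tf.constant([1, 2, 2, 0, 1])  # 0 means background
num_gt = 2
# one_hot(ids - 1) maps background (id 0 -> index -1) to an all-zero row, so
# transposing yields one [num_voxels] binary mask per ground-truth instance.
gt_instance_masks = tf.transpose(
    tf.one_hot(gt_voxel_instance_ids - 1, depth=num_gt, dtype=tf.float32))
# [[1. 0. 0. 0. 1.]
#  [0. 1. 1. 0. 0.]]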
Example #2
    def _build_select_slate_op(self):
        p_no_click = self._prob_no_click_ph
        p = self._doc_affinity_scores_ph
        q = self._net_outputs.q_values[0]
        with tf.name_scope('select_slate'):
            self._output_slate = self._select_slate_fn(self._slate_size,
                                                       p_no_click, p, q)

        # Debug print of the selected slate (tf.Print is the deprecated TF1
        # predecessor of tf.print).
        self._output_slate = tf.Print(
            self._output_slate,
            [tf.constant('cp 1'), self._output_slate, p, q],
            summarize=10000)
        self._output_slate = tf.reshape(self._output_slate,
                                        (self._slate_size, ))

        self._action_counts = tf.get_variable(
            'action_counts',
            shape=[self._num_candidates],
            initializer=tf.zeros_initializer())
        output_slate = tf.reshape(self._output_slate, [-1])
        output_one_hot = tf.one_hot(output_slate, self._num_candidates)
        update_ops = []
        for i in range(self._slate_size):
            update_ops.append(
                tf.assign_add(self._action_counts, output_one_hot[i]))
        self._select_action_update_op = tf.group(*update_ops)
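A hedged TF2-style sketch of the action-count bookkeeping above (the original uses TF1 variables and tf.assign_add); shapes and indices are hypothetical:

import tensorflow as tf

num_candidates = 6
action_counts = tf.Variable(tf.zeros([num_candidates]))
output_slate = tf.constant([2, 5, 2])  # hypothetical selected document indices
# Each slate position contributes a one-hot increment to its candidate's count.
action_counts.assign_add(
    tf.reduce_sum(tf.one_hot(output_slate, num_candidates), axis=0))
# action_counts is now [0. 0. 2. 0. 0. 1.]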
    def _build_train_op(self):
        """Builds the training op for Rainbow.

    Returns:
      train_op: An op performing one step of training.
    """

        replay_action_one_hot = tf.one_hot(self._replay.actions,
                                           self.num_actions,
                                           1.,
                                           0.,
                                           name='action_one_hot')
        replay_chosen_q = tf.reduce_sum(self._replay_qs *
                                        replay_action_one_hot,
                                        axis=1,
                                        name='replay_chosen_q')

        target = tf.stop_gradient(self._build_target_q_op())
        loss = tf.losses.huber_loss(target,
                                    replay_chosen_q,
                                    reduction=tf.losses.Reduction.NONE)

        update_priorities_op = self._replay.tf_set_priority(
            self._replay.indices, tf.sqrt(loss + 1e-10))

        target_priorities = self._replay.tf_get_priority(self._replay.indices)
        target_priorities = tf.math.add(target_priorities, 1e-10)
        target_priorities = 1.0 / tf.sqrt(target_priorities)
        target_priorities /= tf.reduce_max(target_priorities)

        weighted_loss = target_priorities * loss

        with tf.control_dependencies([update_priorities_op]):
            return self.optimizer.minimize(
                tf.reduce_mean(weighted_loss)), weighted_loss
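The priority-to-weight conversion above is standard prioritized-replay importance sampling: weights 1/sqrt(p), normalized so the largest weight is 1. A minimal sketch with made-up priorities:

import tensorflow as tf

priorities = tf.constant([4.0, 1.0, 0.25])
weights = 1.0 / tf.sqrt(priorities + 1e-10)
weights /= tf.reduce_max(weights)
# weights is approximately [0.25 0.5 1.0]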
Example #4
 def _generate_goal(self, observation, state):
     """Generate a goal."""
     batch_size = tf.shape(tf.nest.flatten(observation)[0])[0]
     # Generate a goal with the same batch size as the observation.
     samples = tf.random.categorical(tf.math.log([self._p_goal]),
                                     batch_size)
     samples_onehot = tf.one_hot(indices=samples, depth=self._num_of_goals)
     samples_onehot = tf.reshape(samples_onehot, [batch_size, -1])
     return samples_onehot
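A standalone sketch of the sampling pattern above: draw batch_size indices from a categorical distribution, then one-hot encode them. The probability vector is a hypothetical stand-in for self._p_goal:

import tensorflow as tf

p_goal = [0.25, 0.25, 0.5]  # hypothetical goal distribution
batch_size = 4
samples = tf.random.categorical(tf.math.log([p_goal]), batch_size)  # [1, 4]
goals = tf.one_hot(samples, depth=len(p_goal))                      # [1, 4, 3]
goals = tf.reshape(goals, [batch_size, -1])                         # [4, 3]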
def classification_loss_using_mask_iou_func_unbatched(
        embeddings, instance_ids, sampled_embeddings, sampled_instance_ids,
        sampled_class_labels, sampled_logits, similarity_strategy,
        is_balanced):
    """Classification loss using mask iou.

  Args:
    embeddings: A tf.float32 tensor of size [n, f].
    instance_ids: A tf.int32 tensor of size [n].
    sampled_embeddings: A tf.float32 tensor of size [num_samples, f].
    sampled_instance_ids: A tf.int32 tensor of size [num_samples].
    sampled_class_labels: A tf.int32 tensor of size [num_samples, 1].
    sampled_logits: A tf.float32 tensor of size [num_samples, num_classes].
    similarity_strategy: Defines the method for computing similarity between
                         embedding vectors. Possible values are 'dotproduct' and
                         'distance'.
    is_balanced: If True, the per-voxel losses are re-weighted to have equal
      total weight for foreground vs. background voxels.

  Returns:
    A tf.float32 loss scalar tensor.
  """
    predicted_soft_masks = metric_learning_utils.embedding_centers_to_soft_masks(
        embedding=embeddings,
        centers=sampled_embeddings,
        similarity_strategy=similarity_strategy)
    predicted_masks = tf.cast(tf.greater(predicted_soft_masks, 0.5),
                              dtype=tf.float32)
    gt_masks = tf.cast(tf.equal(tf.expand_dims(sampled_instance_ids, axis=1),
                                tf.expand_dims(instance_ids, axis=0)),
                       dtype=tf.float32)
    pairwise_iou = instance_segmentation_utils.points_mask_pairwise_iou(
        masks1=predicted_masks, masks2=gt_masks)
    num_classes = sampled_logits.get_shape().as_list()[1]
    sampled_class_labels_one_hot = tf.one_hot(
        indices=tf.reshape(sampled_class_labels, [-1]), depth=num_classes)
    sampled_class_labels_one_hot_fg = sampled_class_labels_one_hot[:, 1:]
    iou_coefs = tf.tile(tf.reshape(pairwise_iou, [-1, 1]),
                        [1, num_classes - 1])
    sampled_class_labels_one_hot_fg *= iou_coefs
    sampled_class_labels_one_hot_bg = tf.maximum(
        1.0 - tf.math.reduce_sum(
            sampled_class_labels_one_hot_fg, axis=1, keepdims=True), 0.0)
    sampled_class_labels_one_hot = tf.concat(
        [sampled_class_labels_one_hot_bg, sampled_class_labels_one_hot_fg],
        axis=1)
    params = {}
    if is_balanced:
        weights = loss_utils.get_balanced_loss_weights_multiclass(
            labels=tf.expand_dims(sampled_instance_ids, axis=1))
        params['weights'] = weights
    return classification_loss_fn(logits=sampled_logits,
                                  labels=sampled_class_labels_one_hot,
                                  **params)
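The ground-truth masks in this function come from a broadcasting trick rather than tf.one_hot: comparing each sampled instance id against every voxel's id yields one binary mask per sample. A small sketch with illustrative ids:

import tensorflow as tf

instance_ids = tf.constant([1, 1, 2, 3])    # per-voxel instance ids
sampled_instance_ids = tf.constant([1, 3])  # ids of the sampled centers
gt_masks = tf.cast(
    tf.equal(tf.expand_dims(sampled_instance_ids, axis=1),
             tf.expand_dims(instance_ids, axis=0)), tf.float32)
# [[1. 1. 0. 0.]
#  [0. 0. 0. 1.]]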
  def _build_train_op(self):
    """Builds a training op.

    Returns:
      train_op: An op performing one step of training.
    """
    replay_action_one_hot = tf.one_hot(
        self._replay.actions, self.num_actions, 1., 0., name='action_one_hot')
    replay_chosen_q = tf.reduce_sum(
        self._replay_qs * replay_action_one_hot,
        axis=1,
        name='replay_chosen_q')

    target = tf.stop_gradient(self._build_target_q_op())
    loss = tf.losses.huber_loss(
        target, replay_chosen_q, reduction=tf.losses.Reduction.NONE)
    return self.optimizer.minimize(tf.reduce_mean(loss))
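The replay_chosen_q pattern above is the classic one-hot Q-value selection: multiply the Q-matrix by a one-hot action mask and sum out the action axis. A minimal sketch with made-up values:

import tensorflow as tf

q_values = tf.constant([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0]])  # [batch, num_actions]
actions = tf.constant([2, 0])              # actions taken at each step
chosen_q = tf.reduce_sum(q_values * tf.one_hot(actions, 3, 1., 0.), axis=1)
# chosen_q is [3. 4.]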
def classification_loss_fn(logits, labels, num_valid_voxels=None, weights=1.0):
    """Semantic segmentation cross entropy loss."""
    logits_rank = len(logits.get_shape().as_list())
    labels_rank = len(labels.get_shape().as_list())
    if logits_rank != labels_rank:
        raise ValueError('Logits and labels should have the same rank.')
    if logits_rank != 2 and logits_rank != 3:
        raise ValueError(
            'Logits and labels should have either 2 or 3 dimensions.')
    if logits_rank == 2:
        if num_valid_voxels is not None:
            raise ValueError(
                '`num_valid_voxels` should be None if not using batched logits.'
            )
    elif logits_rank == 3:
        if num_valid_voxels is None:
            raise ValueError(
                '`num_valid_voxels` cannot be None if using batched logits.')
    if logits_rank == 3:
        if (isinstance(weights, tf.Tensor)
                and len(weights.get_shape().as_list()) == 3):
            use_weights = True
        else:
            use_weights = False
        batch_size = logits.get_shape().as_list()[0]
        logits_list = []
        labels_list = []
        weights_list = []
        for i in range(batch_size):
            num_valid_voxels_i = num_valid_voxels[i]
            logits_list.append(logits[i, 0:num_valid_voxels_i, :])
            labels_list.append(labels[i, 0:num_valid_voxels_i, :])
            if use_weights:
                weights_list.append(weights[i, 0:num_valid_voxels_i, :])
        logits = tf.concat(logits_list, axis=0)
        labels = tf.concat(labels_list, axis=0)
        if use_weights:
            weights = tf.concat(weights_list, axis=0)
    weights = tf.convert_to_tensor(weights, dtype=tf.float32)
    if labels.get_shape().as_list()[-1] == 1:
        num_classes = logits.get_shape().as_list()[-1]
        labels = tf.one_hot(tf.reshape(labels, shape=[-1]), num_classes)
    losses = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.stop_gradient(labels), logits=logits)
    return tf.reduce_mean(losses * tf.reshape(weights, [-1]))
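The label-expansion branch near the end is worth isolating: labels of shape [..., 1] holding class indices are flattened and one-hot encoded before the softmax cross entropy. A minimal sketch with illustrative values:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, 0.1],
                      [0.1, 0.2, 3.0]])  # [n, num_classes]
labels = tf.constant([[0], [2]])         # [n, 1] integer class ids
labels_one_hot = tf.one_hot(tf.reshape(labels, [-1]), 3)
losses = tf.nn.softmax_cross_entropy_with_logits(
    labels=tf.stop_gradient(labels_one_hot), logits=logits)
loss = tf.reduce_mean(losses)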
 def loss_fn():
     """Loss function."""
     num_classes = logits.get_shape().as_list()[-1]
     if num_classes is None:
         raise ValueError('Number of classes is unknown.')
     masked_logits = tf.boolean_mask(logits, background_mask)
     masked_weights = tf.pow(
         1.0 - tf.reshape(tf.nn.softmax(masked_logits)[:, 0], [-1, 1]),
         gamma)
     num_points = tf.shape(masked_logits)[0]
     masked_weights = masked_weights * tf.cast(
         num_points, dtype=tf.float32) / tf.reduce_sum(masked_weights)
     masked_labels_one_hot = tf.one_hot(
         indices=tf.boolean_mask(labels, background_mask), depth=num_classes)
     loss = classification_loss_fn(logits=masked_logits,
                                   labels=masked_labels_one_hot,
                                   weights=masked_weights)
     return loss
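The weighting in this loss_fn is focal-style: points the model already assigns to background (class 0) with high confidence are down-weighted by (1 - p_background)^gamma. A sketch with hypothetical logits and gamma:

import tensorflow as tf

gamma = 2.0
masked_logits = tf.constant([[4.0, 0.0],
                             [0.0, 0.0]])
p_background = tf.nn.softmax(masked_logits)[:, 0]
weights = tf.pow(1.0 - tf.reshape(p_background, [-1, 1]), gamma)
# First point: p_bg ~ 0.98 -> weight ~ 0.0003; second: p_bg = 0.5 -> 0.25.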
 def loss_fn():
     """Loss function."""
     num_classes = logits.get_shape().as_list()[-1]
     if num_classes is None:
         raise ValueError('Number of classes is unknown.')
     labels_one_hot = tf.one_hot(indices=(labels - 1),
                                 depth=(num_classes - 1))
     inverse_distance_coef = tf.maximum(
         tf.minimum(
             1.0 -
             normalized_center_distance / max_positive_normalized_distance,
             1.0), 0.0)
     labels_one_hot = tf.reshape(inverse_distance_coef,
                                 [-1, 1]) * labels_one_hot
     background_label = 1.0 - tf.math.reduce_sum(
         labels_one_hot, axis=1, keepdims=True)
     labels_one_hot = tf.concat([background_label, labels_one_hot], axis=1)
     loss = classification_loss_fn(logits=logits,
                                   labels=labels_one_hot,
                                   **params)
     return loss
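This loss_fn builds soft labels: each foreground one-hot row is scaled by a coefficient in [0, 1] and the background column absorbs the remaining mass, so every row still sums to 1. A small sketch with made-up coefficients:

import tensorflow as tf

labels = tf.constant([1, 2])        # 1-based foreground class ids
num_classes = 3                     # background + 2 foreground classes
coef = tf.constant([[0.8], [0.3]])  # e.g. inverse-distance coefficients
fg = tf.one_hot(labels - 1, num_classes - 1) * coef
bg = 1.0 - tf.reduce_sum(fg, axis=1, keepdims=True)
soft_labels = tf.concat([bg, fg], axis=1)
# [[0.2 0.8 0. ]
#  [0.7 0.  0.3]]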
Example #10
    def _build_train_op(self):
        """Builds a training op.

    Returns:
      train_op: An op performing one step of training from replay data.
    """
        replay_next_target_value = tf.reduce_max(
            self._replay_next_target_net_outputs.q_values, 1)
        replay_target_value = tf.reduce_max(
            self._replay_target_net_outputs.q_values, 1)

        replay_action_one_hot = tf.one_hot(self._replay.actions,
                                           self.num_actions,
                                           1.,
                                           0.,
                                           name='action_one_hot')
        replay_chosen_q = tf.reduce_sum(self._replay_net_outputs.q_values *
                                        replay_action_one_hot,
                                        axis=1,
                                        name='replay_chosen_q')
        replay_target_chosen_q = tf.reduce_sum(
            self._replay_target_net_outputs.q_values * replay_action_one_hot,
            axis=1,
            name='replay_target_chosen_q')

        augmented_rewards = self._replay.rewards - self.alpha * (
            replay_target_value - replay_target_chosen_q)

        target = (augmented_rewards +
                  self.cumulative_gamma * replay_next_target_value *
                  (1. - tf.cast(self._replay.terminals, tf.float32)))
        target = tf.stop_gradient(target)

        loss = tf.losses.huber_loss(target,
                                    replay_chosen_q,
                                    reduction=tf.losses.Reduction.NONE)
        if self.summary_writer is not None:
            with tf.variable_scope('Losses'):
                tf.summary.scalar('HuberLoss', tf.reduce_mean(loss))
        return self.optimizer.minimize(tf.reduce_mean(loss))
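The reward augmentation above penalizes the reward by the advantage gap of the taken action under the target network, alpha * (max_a Q_target - Q_target(a)). A minimal sketch with made-up numbers:

import tensorflow as tf

alpha = 0.9
rewards = tf.constant([1.0])
target_q = tf.constant([[2.0, 5.0]])  # target-network Q-values
action = tf.constant([0])
chosen_q = tf.reduce_sum(target_q * tf.one_hot(action, 2), axis=1)  # [2.0]
augmented = rewards - alpha * (tf.reduce_max(target_q, 1) - chosen_q)
# 1.0 - 0.9 * (5.0 - 2.0) = -1.7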
Example #11
 def set_element(v, i, x):
     mask = tf.one_hot(i, tf.shape(input=v)[0])
     v_new = tf.ones_like(v) * x
     return tf.where(tf.equal(mask, 1), v_new, v)
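A quick usage sketch: set_element replaces element i of a 1-D tensor without in-place mutation, by selecting between v and a constant tensor under a one-hot mask.

import tensorflow as tf

v = tf.constant([1.0, 2.0, 3.0])
set_element(v, 1, 9.0)  # -> [1. 9. 3.]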
Example #12
 def _encode_action(self, action):
     if tensor_spec.is_discrete(self._action_spec):
         return tf.one_hot(indices=action, depth=self._num_actions)
     else:
         return action
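A sketch of the discrete branch, with a hypothetical depth: integer actions become one-hot rows, while continuous actions would pass through unchanged.

import tensorflow as tf

action = tf.constant([1, 3])
tf.one_hot(indices=action, depth=4)  # [[0. 1. 0. 0.], [0. 0. 0. 1.]]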
Example #13
    def update_state(self, inputs, outputs):
        """Function that updates the metric state at each example.

    Args:
      inputs: A dictionary containing input tensors.
      outputs: A dictionary containing output tensors.

    Returns:
      Update op.
    """
        detections_score = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_score], [-1])
        detections_class = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_class], [-1])
        detections_length = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_length],
            [-1])
        detections_height = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_height],
            [-1])
        detections_width = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_width], [-1])
        detections_center = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_center],
            [-1, 3])
        detections_rotation_matrix = tf.reshape(
            outputs[
                standard_fields.DetectionResultFields.objects_rotation_matrix],
            [-1, 3, 3])
        gt_class = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_class], [-1])
        gt_length = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_length], [-1])
        gt_height = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_height], [-1])
        gt_width = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_width], [-1])
        gt_center = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_center], [-1, 3])
        gt_rotation_matrix = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_rotation_matrix],
            [-1, 3, 3])
        for c in self.class_range:
            gt_mask_c = tf.equal(gt_class, c)
            num_gt_c = tf.math.reduce_sum(tf.cast(gt_mask_c, dtype=tf.int32))
            gt_length_c = tf.boolean_mask(gt_length, gt_mask_c)
            gt_height_c = tf.boolean_mask(gt_height, gt_mask_c)
            gt_width_c = tf.boolean_mask(gt_width, gt_mask_c)
            gt_center_c = tf.boolean_mask(gt_center, gt_mask_c)
            gt_rotation_matrix_c = tf.boolean_mask(gt_rotation_matrix,
                                                   gt_mask_c)
            detections_mask_c = tf.equal(detections_class, c)
            num_detections_c = tf.math.reduce_sum(
                tf.cast(detections_mask_c, dtype=tf.int32))
            if num_detections_c == 0:
                continue
            det_length_c = tf.boolean_mask(detections_length,
                                           detections_mask_c)
            det_height_c = tf.boolean_mask(detections_height,
                                           detections_mask_c)
            det_width_c = tf.boolean_mask(detections_width, detections_mask_c)
            det_center_c = tf.boolean_mask(detections_center,
                                           detections_mask_c)
            det_rotation_matrix_c = tf.boolean_mask(detections_rotation_matrix,
                                                    detections_mask_c)
            det_scores_c = tf.boolean_mask(detections_score, detections_mask_c)
            det_scores_c, sorted_indices = tf.math.top_k(det_scores_c,
                                                         k=num_detections_c)
            det_length_c = tf.gather(det_length_c, sorted_indices)
            det_height_c = tf.gather(det_height_c, sorted_indices)
            det_width_c = tf.gather(det_width_c, sorted_indices)
            det_center_c = tf.gather(det_center_c, sorted_indices)
            det_rotation_matrix_c = tf.gather(det_rotation_matrix_c,
                                              sorted_indices)
            tp_c = tf.zeros([num_detections_c], dtype=tf.int32)
            if num_gt_c > 0:
                ious_c = box_ops.iou3d(
                    boxes1_length=gt_length_c,
                    boxes1_height=gt_height_c,
                    boxes1_width=gt_width_c,
                    boxes1_center=gt_center_c,
                    boxes1_rotation_matrix=gt_rotation_matrix_c,
                    boxes2_length=det_length_c,
                    boxes2_height=det_height_c,
                    boxes2_width=det_width_c,
                    boxes2_center=det_center_c,
                    boxes2_rotation_matrix=det_rotation_matrix_c)
                max_overlap_gt_ids = tf.cast(tf.math.argmax(ious_c, axis=0),
                                             dtype=tf.int32)
                is_gt_box_detected = tf.zeros([num_gt_c], dtype=tf.int32)
                for i in tf.range(num_detections_c):
                    gt_id = max_overlap_gt_ids[i]
                    if (ious_c[gt_id, i] > self.iou_threshold
                            and is_gt_box_detected[gt_id] == 0):
                        tp_c = tf.maximum(
                            tf.one_hot(i, num_detections_c, dtype=tf.int32),
                            tp_c)
                        is_gt_box_detected = tf.maximum(
                            tf.one_hot(gt_id, num_gt_c, dtype=tf.int32),
                            is_gt_box_detected)
            self.tp[c] = tf.concat([self.tp[c], tp_c], axis=0)
            self.scores[c] = tf.concat([self.scores[c], det_scores_c], axis=0)
            self.num_gt[c] += num_gt_c
        return tf.no_op()
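The matching loop shared by this metric and Example #1 is a greedy true-positive assignment: each detection (in descending score order) claims its best-overlapping ground-truth box, and one-hot vectors mark detections as TP and ground-truth boxes as claimed. A minimal eager-mode sketch with made-up IoUs:

import tensorflow as tf

ious = tf.constant([[0.7, 0.6],   # [num_gt, num_detections]
                    [0.1, 0.2]])
iou_threshold = 0.5
num_gt, num_det = 2, 2
tp = tf.zeros([num_det], dtype=tf.int32)
claimed = tf.zeros([num_gt], dtype=tf.int32)
best_gt = tf.cast(tf.argmax(ious, axis=0), tf.int32)
for i in tf.range(num_det):
  g = best_gt[i]
  if ious[g, i] > iou_threshold and claimed[g] == 0:
    tp = tf.maximum(tf.one_hot(i, num_det, dtype=tf.int32), tp)
    claimed = tf.maximum(tf.one_hot(g, num_gt, dtype=tf.int32), claimed)
# tp -> [1 0]: the second detection's best ground-truth box is already claimed.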
Example #14
    def forward_pass(self, data):
        """Computes the query logits for the given episode `data`."""

        if self.film_init == 'scratch':
            self.film_selector = None
        elif self.film_init == 'imagenet':
            # Note: this makes the assumption that the first set of learned FiLM
            # parameters corresponds to the ImageNet dataset. Otherwise, the
            # following line should be changed appropriately.
            self.film_selector = 0
        elif self.film_init in ['blender', 'blender_hard']:
            dataset_logits = functional_backbones.dataset_classifier(
                data.support_images)
            if self.film_init == 'blender_hard':
                # Select only the argmax entry.
                self.film_selector = tf.one_hot(
                    tf.math.argmax(dataset_logits, axis=-1),
                    depth=tf.shape(dataset_logits)[1])
            else:
                # Take a convex combination.
                self.film_selector = tf.nn.softmax(dataset_logits, axis=-1)

        if self.num_steps:
            # Initial forward pass, required for the `unused_op` below and for
            # placing variables in tf.trainable_variables() so the block below
            # can pick them up.
            loss = self._compute_losses(data, compute_on_query=False)['loss']

            # Pick out the variables to optimize.
            self.opt_vars = []
            for var in tf.trainable_variables():
                if '_for_film_learner' in var.name:
                    self.opt_vars.append(var)
            tf.logging.info('FiLMLearner will optimize vars: {}'.format(
                self.opt_vars))

        for i in range(self.num_steps):
            if i == 0:
                # Re-initialize the variables to optimize for the new episode, to ensure
                # the FiLM parameters aren't re-used across tasks of a given dataset.
                vars_reset = tf.variables_initializer(var_list=self.opt_vars)
                # Adam-related variables are created when minimize() is called.
                # We create an unused op here to put all Adam variables under
                # the 'adam_opt' name scope and create a reset op to
                # reinitialize these variables before the first finetune step.
                with tf.variable_scope('adam_opt', reuse=tf.AUTO_REUSE):
                    unused_op = self.opt.minimize(loss, var_list=self.opt_vars)
                adam_reset = tf.variables_initializer(self.opt.variables())

                with tf.control_dependencies([vars_reset, adam_reset, loss] +
                                             self.opt_vars):
                    print_op = tf.no_op()
                    if self.debug_log:
                        print_op = tf.print(
                            ['step: %d' % i, self.opt_vars[0][0],
                             'loss:', loss],
                            summarize=-1)

                    with tf.control_dependencies([print_op]):
                        # Get the train op.
                        results = self._get_train_op(data)
                        (train_op, loss, query_loss, acc,
                         query_acc) = (results['train_op'], results['loss'],
                                       results['query_loss'], results['acc'],
                                       results['query_acc'])

            else:
                with tf.control_dependencies([train_op, loss, acc] +
                                             self.opt_vars +
                                             [query_loss, query_acc] *
                                             int(self.debug_log)):

                    print_op = tf.no_op()
                    if self.debug_log:
                        print_list = [
                            '################',
                            'step: %d' % i,
                            self.opt_vars[0][0],
                            'support loss:',
                            loss,
                            'query loss:',
                            query_loss,
                            'support acc:',
                            acc,
                            'query acc:',
                            query_acc,
                        ]
                        print_op = tf.print(print_list)

                    with tf.control_dependencies([print_op]):
                        # Get the train op (the loss is returned just for printing).
                        results = self._get_train_op(data)
                        (train_op, loss, query_loss, acc,
                         query_acc) = (results['train_op'], results['loss'],
                                       results['query_loss'], results['acc'],
                                       results['query_acc'])

        # Training is now over, compute the final query logits.
        dependency_list = ([] if not self.num_steps else
                           [train_op] + self.opt_vars)
        with tf.control_dependencies(dependency_list):
            results = self._compute_losses(data, compute_on_query=True)
            (loss, query_loss, query_logits, acc,
             query_acc) = (results['loss'], results['query_loss'],
                           results['query_logits'], results['acc'],
                           results['query_acc'])

            print_op = tf.no_op()
            if self.debug_log:
                print_op = tf.print([
                    'Done training',
                    'support loss:',
                    loss,
                    'query loss:',
                    query_loss,
                    'support acc:',
                    acc,
                    'query acc:',
                    query_acc,
                ])
            with tf.control_dependencies([print_op]):
                query_logits = tf.identity(query_logits)

        return query_logits
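The two 'blender' modes differ only in how the dataset classifier's logits become a FiLM selector: 'blender_hard' keeps a one-hot argmax, 'blender' a softmax convex combination. A sketch with hypothetical logits:

import tensorflow as tf

dataset_logits = tf.constant([[0.1, 2.0, 0.5]])
hard_selector = tf.one_hot(
    tf.math.argmax(dataset_logits, axis=-1), depth=tf.shape(dataset_logits)[1])
soft_selector = tf.nn.softmax(dataset_logits, axis=-1)
# hard_selector -> [[0. 1. 0.]]; soft_selector -> [[~0.11 ~0.73 ~0.16]]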
Example #15
def npair_loss_func(embeddings,
                    instance_ids,
                    num_samples,
                    valid_mask=None,
                    max_instance_id=None,
                    similarity_strategy='dotproduct',
                    loss_strategy='softmax'):
  """N-pair metric learning loss for learning feature embeddings.

  Args:
    embeddings: A tf.float32 tensor of size [batch_size, n, f].
    instance_ids: A tf.int32 tensor of size [batch_size, n].
    num_samples: An int determining the number of samples.
    valid_mask: A tf.bool tensor of size [batch_size, n] that is True when an
      element is valid and False if it needs to be ignored. By default the value
      is None which means it is not applied.
    max_instance_id: If set, instance ids larger than that value will be
      ignored. If not set, it will be computed from instance_ids tensor.
    similarity_strategy: Defines the method for computing similarity between
                         embedding vectors. Possible values are 'dotproduct' and
                         'distance'.
    loss_strategy: Defines the type of loss including 'softmax' or 'sigmoid'.

  Returns:
    A tf.float32 scalar loss tensor.
  """
  batch_size = embeddings.get_shape().as_list()[0]
  if batch_size is None:
    raise ValueError('Unknown batch size at graph construction time.')
  if max_instance_id is None:
    max_instance_id = tf.reduce_max(instance_ids)
  sampled_embeddings, sampled_instance_ids, _ = sampling_utils.balanced_sample(
      features=embeddings,
      instance_ids=instance_ids,
      num_samples=num_samples,
      valid_mask=valid_mask,
      max_instance_id=max_instance_id)
  losses = []
  for i in range(batch_size):
    sampled_instance_ids_i = sampled_instance_ids[i, :]
    sampled_embeddings_i = sampled_embeddings[i, :, :]
    min_ids_i = tf.math.reduce_min(sampled_instance_ids_i)
    max_ids_i = tf.math.reduce_max(sampled_instance_ids_i)
    target_i = tf.one_hot(
        sampled_instance_ids_i,
        depth=(max_instance_id + 1),
        dtype=tf.float32)

    # pylint: disable=cell-var-from-loop
    def npair_loss_i():
      return metric_learning_losses.npair_loss(
          embedding=sampled_embeddings_i,
          target=target_i,
          similarity_strategy=similarity_strategy,
          loss_strategy=loss_strategy)
    # pylint: enable=cell-var-from-loop

    loss_i = tf.cond(
        max_ids_i > min_ids_i, npair_loss_i,
        lambda: tf.constant(0.0, dtype=tf.float32))
    losses.append(loss_i)
  return tf.math.reduce_mean(losses)
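The tf.cond guard is the subtle part: a sample whose instance ids are all identical (min == max) would make the n-pair loss degenerate, so its contribution is replaced with zero. A sketch of the guard with a stand-in loss:

import tensorflow as tf

sampled_ids = tf.constant([3, 3, 3])  # degenerate sample: one instance only
loss = tf.cond(
    tf.reduce_max(sampled_ids) > tf.reduce_min(sampled_ids),
    lambda: tf.constant(1.0),  # stand-in for the real n-pair loss
    lambda: tf.constant(0.0, dtype=tf.float32))
# loss == 0.0 here, since min == max.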
Example #16
 def one_hot_encode(x, scale):
     x = tf.squeeze(x, axis=1)
     x = tf.cast(x, tf.int32)
     return tf.one_hot(x, scale, axis=1)
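A usage sketch: a [batch, 1] float tensor of class indices becomes a [batch, scale] one-hot tensor (axis=1 puts classes on the channel axis, which for a 1-D input is the last axis).

import tensorflow as tf

x = tf.constant([[2.0], [0.0]])
one_hot_encode(x, scale=3)
# [[0. 0. 1.]
#  [1. 0. 0.]]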