def _box_center_distance_loss_on_voxel_tensors_unbatched(
    inputs_1, outputs_1, loss_type, delta, is_balanced, is_intermediate):
  """Center-distance loss between true and predicted voxel centers.

  Operates on a single (unbatched) example. Voxels are first filtered with
  `_get_voxels_valid_inputs_outputs`. If no voxel is valid the result is a
  float32 zero. Otherwise `_box_center_distance_loss` is evaluated on the
  ground-truth voxel centers and on either the intermediate or the final
  predicted voxel centers.

  Args:
    inputs_1: Dict of input voxel tensors for one example.
    outputs_1: Dict of predicted voxel tensors for one example.
    loss_type: Forwarded to `_box_center_distance_loss`.
    delta: Forwarded to `_box_center_distance_loss`.
    is_balanced: Forwarded to `_box_center_distance_loss`.
    is_intermediate: If True, use the intermediate predicted centers instead
      of the final ones.

  Returns:
    A loss tensor; a float32 0.0 when no voxel is valid.
  """
  inputs_1, outputs_1, valid_mask = _get_voxels_valid_inputs_outputs(
      inputs_1=inputs_1, outputs_1=outputs_1)

  def _center_loss():
    """True branch of the cond: loss over the valid voxels."""
    result_fields = standard_fields.DetectionResultFields
    input_fields = standard_fields.InputDataFields
    if is_intermediate:
      predicted_centers = outputs_1[
          result_fields.intermediate_object_center_voxels]
    else:
      predicted_centers = outputs_1[result_fields.object_center_voxels]
    true_centers = inputs_1[input_fields.object_center_voxels]
    instance_ids = inputs_1[input_fields.object_instance_id_voxels]
    return _box_center_distance_loss(
        loss_type=loss_type,
        is_balanced=is_balanced,
        input_boxes_center=true_centers,
        input_boxes_instance_id=instance_ids,
        output_boxes_center=predicted_centers,
        delta=delta)

  def _zero_loss():
    """False branch of the cond: no valid voxel, so no loss."""
    return tf.constant(0.0, dtype=tf.float32)

  any_valid_voxel = tf.reduce_any(valid_mask)
  return tf.cond(any_valid_voxel, _center_loss, _zero_loss)
# Example #2
# 0
    def convert_to_simclr_episode(support_images=None,
                                  support_labels=None,
                                  support_class_ids=None,
                                  query_images=None,
                                  query_labels=None,
                                  query_class_ids=None):
        """Convert a single episode into a SimCLR episode.

        Every support image becomes its own class, so the new support labels
        are `0 .. num_support - 1`. The new query set is built from support
        images. If the original query set held k images of class c, the first
        k support images of class c (in support order) are kept as queries.
        Their images, labels (support indices) and class ids form the new
        query set. SimCLR augmentation is then applied to both sets.

        Args:
          support_images: Support-set images.
          support_labels: Episode labels of the support images (1-D, as
            implied by the pairwise-equality indexing below).
          support_class_ids: Class ids of the support images. They are
            returned unchanged and masked to build the query class ids.
          query_images: Not read; replaced by the selected support images.
          query_labels: Episode labels of the query images (1-D). Only used
            to count how many query images each class has.
          query_class_ids: Not read; replaced by the masked support class ids.

        Returns:
          A tuple `(support_images, support_labels, support_class_ids,
          query_images, query_labels, query_class_ids)` describing the
          SimCLR episode.
        """

        def _within_set_uid(labels):
            """Returns labels * 10000 + 1-based rank of each image in its class."""
            same_label = tf.cast(
                tf.equal(labels[tf.newaxis, :], labels[:, tf.newaxis]),
                tf.int32)
            # Diagonal of the row-wise cumsum: for image i, the number of
            # images j <= i sharing its label, i.e. x + 1 where x counts the
            # same-label images at a lower index.
            rank_within_class = tf.linalg.diag_part(
                tf.cumsum(same_label, axis=1))
            # Match the labels' dtype so the addition below also works for
            # labels that are not int32 (no-op for int32 labels).
            rank_within_class = tf.cast(rank_within_class, labels.dtype)
            return labels * 10000 + rank_within_class

        # Give the x-th query image of class c and the x-th support image of
        # class c the same uid, label * 10000 + x + 1. This maps query images
        # injectively onto support images, provided there are enough support
        # images of each class. Uids are collision-free as long as every class
        # has at most 10000 images in a set and label * 10000 fits in the
        # labels' dtype.
        query_uid = _within_set_uid(query_labels)
        support_uid = _within_set_uid(support_labels)

        # Keep the support images whose uid also occurs in the query set;
        # they become the new query set.
        support_keep = tf.reduce_any(
            tf.equal(support_uid[:, tf.newaxis], query_uid[tf.newaxis, :]),
            axis=1)
        query_images = tf.boolean_mask(support_images, support_keep)

        # Each support image is its own class. A query's label is the index
        # of the support image it was taken from.
        support_labels = tf.range(tf.shape(support_labels)[0],
                                  dtype=support_labels.dtype)
        query_labels = tf.boolean_mask(support_labels, support_keep)
        query_class_ids = tf.boolean_mask(support_class_ids, support_keep)

        # Finally, apply SimCLR augmentation to all images. Only the query
        # side is requested with blur=True; the support side uses
        # simclr_augment's default.
        query_images = simclr_augment(query_images, blur=True)
        support_images = simclr_augment(support_images)

        return (support_images, support_labels, support_class_ids,
                query_images, query_labels, query_class_ids)
def _box_corner_distance_loss_on_object_tensors(inputs, outputs, loss_type,
                                                delta, is_balanced):
    """Computes a corner-distance loss on the valid objects of an example.

    An object is valid when both its class and its instance id are greater
    than zero. The object fields of `inputs` and `outputs` are restricted to
    the valid objects and forwarded to `_box_corner_distance_loss`. The loss
    kind is selected by `loss_type`. When no object is valid, the result is
    a float32 zero.

    The caller's `inputs` and `outputs` dicts are not modified; masking is
    applied to shallow copies.

    Args:
      inputs: Dict of input tensors keyed by
        `standard_fields.InputDataFields`.
      outputs: Dict of predicted tensors keyed by
        `standard_fields.DetectionResultFields`.
      loss_type: Forwarded to `_box_corner_distance_loss`.
      delta: Forwarded to `_box_corner_distance_loss`.
      is_balanced: Forwarded to `_box_corner_distance_loss`.

    Returns:
      A loss tensor; a float32 0.0 when there is no valid object.
    """
    input_fields = standard_fields.InputDataFields
    output_fields = standard_fields.DetectionResultFields

    valid_mask_class = tf.greater(
        tf.reshape(inputs[input_fields.objects_class], [-1]), 0)
    valid_mask_instance = tf.greater(
        tf.reshape(inputs[input_fields.objects_instance_id], [-1]), 0)
    valid_mask = tf.logical_and(valid_mask_class, valid_mask_instance)

    def _masked_copy(tensors, fields):
        """Returns a shallow copy of `tensors` with `fields` masked."""
        masked = dict(tensors)
        for field in fields:
            if field in masked:
                masked[field] = tf.boolean_mask(masked[field], valid_mask)
        return masked

    def fn():
        # Mask copies rather than the arguments themselves. Otherwise the
        # caller's dicts would be rewritten with tensors created inside this
        # tf.cond branch as a side effect of computing the loss.
        valid_inputs = _masked_copy(
            inputs, standard_fields.get_input_object_fields())
        valid_outputs = _masked_copy(
            outputs, standard_fields.get_output_object_fields())
        return _box_corner_distance_loss(
            loss_type=loss_type,
            is_balanced=is_balanced,
            input_boxes_length=valid_inputs[input_fields.objects_length],
            input_boxes_height=valid_inputs[input_fields.objects_height],
            input_boxes_width=valid_inputs[input_fields.objects_width],
            input_boxes_center=valid_inputs[input_fields.objects_center],
            input_boxes_rotation_matrix=valid_inputs[
                input_fields.objects_rotation_matrix],
            input_boxes_instance_id=valid_inputs[
                input_fields.objects_instance_id],
            output_boxes_length=valid_outputs[output_fields.objects_length],
            output_boxes_height=valid_outputs[output_fields.objects_height],
            output_boxes_width=valid_outputs[output_fields.objects_width],
            output_boxes_center=valid_outputs[output_fields.objects_center],
            output_boxes_rotation_matrix=valid_outputs[
                output_fields.objects_rotation_matrix],
            delta=delta)

    return tf.cond(tf.reduce_any(valid_mask), fn,
                   lambda: tf.constant(0.0, dtype=tf.float32))
def _box_size_regression_loss_on_voxel_tensors_unbatched(
        inputs_1, outputs_1, loss_type, delta, is_balanced, is_intermediate):
    """Size-regression loss (length, height, width) for one unbatched example.

    Voxels are filtered with `_get_voxels_valid_inputs_outputs`. If none is
    valid the result is a float32 zero. Otherwise `_box_size_regression_loss`
    compares the ground-truth voxel sizes with either the intermediate or the
    final predicted voxel sizes.

    Args:
      inputs_1: Dict of input voxel tensors for one example.
      outputs_1: Dict of predicted voxel tensors for one example.
      loss_type: Forwarded to `_box_size_regression_loss`.
      delta: Forwarded to `_box_size_regression_loss`.
      is_balanced: Forwarded to `_box_size_regression_loss`.
      is_intermediate: If True, use the intermediate predicted sizes instead
        of the final ones.

    Returns:
      A loss tensor; a float32 0.0 when no voxel is valid.
    """
    inputs_1, outputs_1, valid_mask = _get_voxels_valid_inputs_outputs(
        inputs_1=inputs_1, outputs_1=outputs_1)

    def _size_loss():
        """True branch of the cond: loss over the valid voxels."""
        gt = standard_fields.InputDataFields
        pred = standard_fields.DetectionResultFields
        if is_intermediate:
            size_keys = (pred.intermediate_object_length_voxels,
                         pred.intermediate_object_height_voxels,
                         pred.intermediate_object_width_voxels)
        else:
            size_keys = (pred.object_length_voxels,
                         pred.object_height_voxels,
                         pred.object_width_voxels)
        pred_length, pred_height, pred_width = [
            outputs_1[key] for key in size_keys]
        return _box_size_regression_loss(
            loss_type=loss_type,
            is_balanced=is_balanced,
            input_boxes_length=inputs_1[gt.object_length_voxels],
            input_boxes_height=inputs_1[gt.object_height_voxels],
            input_boxes_width=inputs_1[gt.object_width_voxels],
            input_boxes_instance_id=inputs_1[gt.object_instance_id_voxels],
            output_boxes_length=pred_length,
            output_boxes_height=pred_height,
            output_boxes_width=pred_width,
            delta=delta)

    def _zero_loss():
        """False branch of the cond: no valid voxel, so no loss."""
        return tf.constant(0.0, dtype=tf.float32)

    any_valid_voxel = tf.reduce_any(valid_mask)
    return tf.cond(any_valid_voxel, _size_loss, _zero_loss)