def _box_center_distance_loss_on_voxel_tensors_unbatched(
    inputs_1, outputs_1, loss_type, delta, is_balanced, is_intermediate):
  """Computes huber loss on predicted object centers for each voxel.

  Args:
    inputs_1: Dict of unbatched input voxel tensors, keyed by
      `standard_fields.InputDataFields` names.
    outputs_1: Dict of unbatched predicted voxel tensors, keyed by
      `standard_fields.DetectionResultFields` names.
    loss_type: Loss type forwarded to `_box_center_distance_loss`.
    delta: Huber delta forwarded to `_box_center_distance_loss`.
    is_balanced: Whether the loss is balanced; forwarded unchanged.
    is_intermediate: If True, supervise the intermediate center predictions
      instead of the final ones.

  Returns:
    A scalar float32 loss tensor; 0.0 when no voxel is valid.
  """
  inputs_1, outputs_1, valid_mask = _get_voxels_valid_inputs_outputs(
      inputs_1=inputs_1, outputs_1=outputs_1)

  # `is_intermediate` is a Python bool, so the prediction key can be chosen
  # at trace time, outside the cond branch.
  if is_intermediate:
    center_key = (standard_fields.DetectionResultFields
                  .intermediate_object_center_voxels)
  else:
    center_key = standard_fields.DetectionResultFields.object_center_voxels

  def _compute_loss():
    """Computes the center-distance loss over the valid voxels."""
    return _box_center_distance_loss(
        loss_type=loss_type,
        is_balanced=is_balanced,
        input_boxes_center=inputs_1[
            standard_fields.InputDataFields.object_center_voxels],
        input_boxes_instance_id=inputs_1[
            standard_fields.InputDataFields.object_instance_id_voxels],
        output_boxes_center=outputs_1[center_key],
        delta=delta)

  # Guard against an all-invalid mask: fall back to a zero loss instead of
  # computing over empty tensors.
  return tf.cond(
      tf.reduce_any(valid_mask), _compute_loss,
      lambda: tf.constant(0.0, dtype=tf.float32))
def convert_to_simclr_episode(support_images=None,
                              support_labels=None,
                              support_class_ids=None,
                              query_images=None,
                              query_labels=None,
                              query_class_ids=None):
  """Convert a single episode into a SimCLR Episode.

  The returned query set is built from *support* images (each kept support
  image is paired with its differently-augmented copy in the support set),
  and support labels are replaced by per-image unique labels, as expected
  by a contrastive (SimCLR-style) objective.

  Args:
    support_images: Tensor of support-set images.
    support_labels: 1-D integer tensor of support-set labels.
      (assumed 1-D from the `[tf.newaxis, :]` indexing below — TODO confirm)
    support_class_ids: 1-D tensor of support-set class ids.
    query_images: Tensor of query-set images (only used to decide which
      support images to keep; not returned).
    query_labels: 1-D integer tensor of query-set labels.
    query_class_ids: 1-D tensor of query-set class ids (unused after the
      reassignment below).

  Returns:
    Tuple of (support_images, support_labels, support_class_ids,
    query_images, query_labels, query_class_ids) for the SimCLR episode.
  """
  # If there were k query examples of class c, keep the first k support
  # examples of class c as 'simclr' queries. We do this by assigning an
  # id for each image in the query set, implemented as label*1e5+x+1, where
  # x is the number of images of the same label with a lower index within
  # the query set. We do the same for the support set, which gives us a
  # mapping between query and support images which is injective (as long
  # as there's enough support-set images of each class).
  #
  # note: assumes max support label is 10000 - max_images_per_class
  # The cumsum-of-equality trick: entry i of `query_idx_within_class` is the
  # 1-based rank of image i among images sharing its label.
  query_idx_within_class = tf.cast(
      tf.equal(query_labels[tf.newaxis, :], query_labels[:, tf.newaxis]),
      tf.int32)
  query_idx_within_class = tf.linalg.diag_part(
      tf.cumsum(query_idx_within_class, axis=1))
  query_uid = query_labels * 10000 + query_idx_within_class
  support_idx_within_class = tf.cast(
      tf.equal(support_labels[tf.newaxis, :], support_labels[:, tf.newaxis]),
      tf.int32)
  support_idx_within_class = tf.linalg.diag_part(
      tf.cumsum(support_idx_within_class, axis=1))
  support_uid = support_labels * 10000 + support_idx_within_class

  # compute which support-set images have matches in the query set, and
  # discard the rest to produce the new query set.
  support_keep = tf.reduce_any(
      tf.equal(support_uid[:, tf.newaxis], query_uid[tf.newaxis, :]), axis=1)
  query_images = tf.boolean_mask(support_images, support_keep)

  # NOTE: order matters here — `support_labels` is first replaced by a
  # per-image unique label (0..n-1), and `query_labels` is then the masked
  # subset of those NEW labels, so each query image shares its label with
  # exactly its corresponding support image.
  support_labels = tf.range(
      tf.shape(support_labels)[0], dtype=support_labels.dtype)
  query_labels = tf.boolean_mask(support_labels, support_keep)
  query_class_ids = tf.boolean_mask(support_class_ids, support_keep)

  # Finally, apply SimCLR augmentation to all images.
  # Note simclr only blurs one image.
  query_images = simclr_augment(query_images, blur=True)
  support_images = simclr_augment(support_images)
  return (support_images, support_labels, support_class_ids, query_images,
          query_labels, query_class_ids)
def _box_corner_distance_loss_on_object_tensors(inputs, outputs, loss_type,
                                                delta, is_balanced):
  """Computes huber loss on object corner locations.

  Args:
    inputs: Dict of input object tensors (class, instance id, size, center,
      rotation matrix), keyed by `standard_fields.InputDataFields` names.
    outputs: Dict of predicted object tensors, keyed by
      `standard_fields.DetectionResultFields` names.
    loss_type: Loss type forwarded to `_box_corner_distance_loss`.
    delta: Huber delta forwarded to `_box_corner_distance_loss`.
    is_balanced: Whether the loss is balanced; forwarded unchanged.

  Returns:
    A scalar float32 loss tensor; 0.0 when no object is valid.
  """
  # An object is valid when both its class and its instance id are positive.
  valid_mask_class = tf.greater(
      tf.reshape(inputs[standard_fields.InputDataFields.objects_class], [-1]),
      0)
  valid_mask_instance = tf.greater(
      tf.reshape(
          inputs[standard_fields.InputDataFields.objects_instance_id], [-1]),
      0)
  valid_mask = tf.logical_and(valid_mask_class, valid_mask_instance)

  def fn():
    """Computes the corner-distance loss over the valid objects."""
    # Mask into shallow copies rather than mutating the caller's `inputs` /
    # `outputs` dicts in place; the original in-place masking leaked the
    # filtered tensors back to the caller as a side effect.
    masked_inputs = dict(inputs)
    masked_outputs = dict(outputs)
    for field in standard_fields.get_input_object_fields():
      if field in masked_inputs:
        masked_inputs[field] = tf.boolean_mask(masked_inputs[field],
                                               valid_mask)
    for field in standard_fields.get_output_object_fields():
      if field in masked_outputs:
        masked_outputs[field] = tf.boolean_mask(masked_outputs[field],
                                                valid_mask)
    return _box_corner_distance_loss(
        loss_type=loss_type,
        is_balanced=is_balanced,
        input_boxes_length=masked_inputs[
            standard_fields.InputDataFields.objects_length],
        input_boxes_height=masked_inputs[
            standard_fields.InputDataFields.objects_height],
        input_boxes_width=masked_inputs[
            standard_fields.InputDataFields.objects_width],
        input_boxes_center=masked_inputs[
            standard_fields.InputDataFields.objects_center],
        input_boxes_rotation_matrix=masked_inputs[
            standard_fields.InputDataFields.objects_rotation_matrix],
        input_boxes_instance_id=masked_inputs[
            standard_fields.InputDataFields.objects_instance_id],
        output_boxes_length=masked_outputs[
            standard_fields.DetectionResultFields.objects_length],
        output_boxes_height=masked_outputs[
            standard_fields.DetectionResultFields.objects_height],
        output_boxes_width=masked_outputs[
            standard_fields.DetectionResultFields.objects_width],
        output_boxes_center=masked_outputs[
            standard_fields.DetectionResultFields.objects_center],
        output_boxes_rotation_matrix=masked_outputs[
            standard_fields.DetectionResultFields.objects_rotation_matrix],
        delta=delta)

  # Fall back to a zero loss when there are no valid objects.
  return tf.cond(tf.reduce_any(valid_mask), fn,
                 lambda: tf.constant(0.0, dtype=tf.float32))
def _box_size_regression_loss_on_voxel_tensors_unbatched(
    inputs_1, outputs_1, loss_type, delta, is_balanced, is_intermediate):
  """Computes regression loss on predicted object size for each voxel.

  Args:
    inputs_1: Dict of unbatched input voxel tensors, keyed by
      `standard_fields.InputDataFields` names.
    outputs_1: Dict of unbatched predicted voxel tensors, keyed by
      `standard_fields.DetectionResultFields` names.
    loss_type: Loss type forwarded to `_box_size_regression_loss`.
    delta: Huber delta forwarded to `_box_size_regression_loss`.
    is_balanced: Whether the loss is balanced; forwarded unchanged.
    is_intermediate: If True, supervise the intermediate size predictions
      instead of the final ones.

  Returns:
    A scalar float32 loss tensor; 0.0 when no voxel is valid.
  """
  inputs_1, outputs_1, valid_mask = _get_voxels_valid_inputs_outputs(
      inputs_1=inputs_1, outputs_1=outputs_1)

  # `is_intermediate` is a Python bool, so the (length, height, width)
  # prediction keys can be chosen at trace time, outside the cond branch.
  result_fields = standard_fields.DetectionResultFields
  if is_intermediate:
    length_key = result_fields.intermediate_object_length_voxels
    height_key = result_fields.intermediate_object_height_voxels
    width_key = result_fields.intermediate_object_width_voxels
  else:
    length_key = result_fields.object_length_voxels
    height_key = result_fields.object_height_voxels
    width_key = result_fields.object_width_voxels

  def _compute_loss():
    """Computes the size-regression loss over the valid voxels."""
    return _box_size_regression_loss(
        loss_type=loss_type,
        is_balanced=is_balanced,
        input_boxes_length=inputs_1[
            standard_fields.InputDataFields.object_length_voxels],
        input_boxes_height=inputs_1[
            standard_fields.InputDataFields.object_height_voxels],
        input_boxes_width=inputs_1[
            standard_fields.InputDataFields.object_width_voxels],
        input_boxes_instance_id=inputs_1[
            standard_fields.InputDataFields.object_instance_id_voxels],
        output_boxes_length=outputs_1[length_key],
        output_boxes_height=outputs_1[height_key],
        output_boxes_width=outputs_1[width_key],
        delta=delta)

  # Guard against an all-invalid mask: fall back to a zero loss instead of
  # computing over empty tensors.
  return tf.cond(tf.reduce_any(valid_mask), _compute_loss,
                 lambda: tf.constant(0.0, dtype=tf.float32))