def apply_unbatched_loss_on_object_tensors(inputs, outputs, unbatched_loss_fn):
  """Evaluates `unbatched_loss_fn` per example and averages the results.

  The batch is split into batch-size-1 slices with `batch_utils`. For each
  slice the loss function is only evaluated when the leading dimension of the
  `objects_length` tensor is non-zero in both the input and the output slice;
  otherwise that example contributes a constant 0.0.

  Args:
    inputs: A dictionary of input tensors. The entry keyed by
      `standard_fields.InputDataFields.objects_length` must support `len()`,
      which is used as the batch size.
    outputs: A dictionary of network output tensors.
    unbatched_loss_fn: A callable invoked as
      `unbatched_loss_fn(inputs_1=..., outputs_1=...)` on a single example.
      It is used as a `tf.cond` branch alongside a tf.float32 scalar
      constant.

  Returns:
    A tensor holding the mean of the per-example losses.
  """
  num_examples = len(inputs[standard_fields.InputDataFields.objects_length])

  def _single_example_loss(example_index):
    """Returns the loss of one example, or 0.0 if it has no objects."""
    example_inputs = batch_utils.get_batch_size_1_input_objects(
        inputs=inputs, b=example_index)
    example_outputs = batch_utils.get_batch_size_1_output_objects(
        outputs=outputs, b=example_index)
    input_is_nonempty = tf.greater(
        tf.shape(
            example_inputs[standard_fields.InputDataFields.objects_length])[0],
        0)
    output_is_nonempty = tf.greater(
        tf.shape(example_outputs[
            standard_fields.DetectionResultFields.objects_length])[0], 0)
    both_nonempty = tf.logical_and(input_is_nonempty, output_is_nonempty)
    # Skip the loss function entirely for empty examples; the else-branch
    # supplies a zero so the example still counts in the batch mean.
    return tf.cond(
        both_nonempty,
        lambda: unbatched_loss_fn(
            inputs_1=example_inputs, outputs_1=example_outputs),
        lambda: tf.constant(0.0, dtype=tf.float32))

  per_example_losses = [
      _single_example_loss(example_index)
      for example_index in range(num_examples)
  ]
  return tf.reduce_mean(tf.stack(per_example_losses))
def box_corner_distance_loss_on_object_tensors(inputs,
                                               outputs,
                                               loss_type,
                                               delta=1.0,
                                               is_balanced=False):
  """Computes regression loss on object corner locations using object tensors.

  The per-example loss is computed by
  `_box_corner_distance_loss_on_object_tensors` and averaged over the batch by
  `apply_unbatched_loss_on_object_tensors`, which substitutes 0.0 for examples
  whose input or output `objects_length` tensor has a zero leading dimension.

  Args:
    inputs: A dictionary of tf.Tensors with our input data.
    outputs: A dictionary of tf.Tensors with the network output.
    loss_type: Loss type, forwarded unchanged to the per-example loss.
    delta: float, the point where the huber loss function changes from
      quadratic to linear. Forwarded unchanged to the per-example loss.
    is_balanced: Balancing flag forwarded unchanged to the per-example loss.
      According to the original documentation, if True the losses are
      re-weighted to have equal total weight for each object instance.

  Returns:
    localization_loss: A tf.float32 scalar corresponding to localization loss.
  """

  def fn(inputs_1, outputs_1):
    """Binds the loss hyper-parameters for a single (unbatched) example."""
    return _box_corner_distance_loss_on_object_tensors(
        inputs=inputs_1,
        outputs=outputs_1,
        loss_type=loss_type,
        delta=delta,
        is_balanced=is_balanced)

  # Reuse the shared batching helper rather than repeating its per-example
  # loop here, so both code paths handle empty examples the same way.
  return apply_unbatched_loss_on_object_tensors(
      inputs=inputs, outputs=outputs, unbatched_loss_fn=fn)