예제 #1
0
def apply_unbatched_loss_on_object_tensors(inputs, outputs, unbatched_loss_fn):
    """Evaluates `unbatched_loss_fn` on every example and averages the results.

    The batch is split into single-example dictionaries with
    `batch_utils.get_batch_size_1_input_objects` and
    `batch_utils.get_batch_size_1_output_objects`. The loss function is only
    evaluated (inside a `tf.cond`) for an example whose input
    `objects_length` tensor and output `objects_length` tensor both have a
    non-zero first dimension; any other example contributes a constant
    0.0 (tf.float32) loss.

    Args:
      inputs: A dictionary of tf.Tensors with the batched input data. The
        batch size is taken as `len(...)` of its `objects_length` entry.
      outputs: A dictionary of tf.Tensors with the batched network output.
      unbatched_loss_fn: Callable invoked as
        `unbatched_loss_fn(inputs_1=..., outputs_1=...)` on single-example
        dictionaries. Its result must be usable as the other `tf.cond` branch
        next to a scalar tf.float32 constant (presumably a scalar float32
        loss -- confirm against callers).

    Returns:
      A tf.Tensor holding the mean of the per-example losses.
    """
    input_length_key = standard_fields.InputDataFields.objects_length
    output_length_key = standard_fields.DetectionResultFields.objects_length

    def _loss_for_example(b):
        """Builds the (conditionally evaluated) loss for batch entry `b`."""
        example_inputs = batch_utils.get_batch_size_1_input_objects(
            inputs=inputs, b=b)
        example_outputs = batch_utils.get_batch_size_1_output_objects(
            outputs=outputs, b=b)
        has_input_objects = tf.greater(
            tf.shape(example_inputs[input_length_key])[0], 0)
        has_output_objects = tf.greater(
            tf.shape(example_outputs[output_length_key])[0], 0)
        # Skip the loss computation entirely when either side is empty.
        return tf.cond(
            tf.logical_and(has_input_objects, has_output_objects),
            lambda: unbatched_loss_fn(inputs_1=example_inputs,
                                      outputs_1=example_outputs),
            lambda: tf.constant(0.0, dtype=tf.float32))

    num_examples = len(inputs[input_length_key])
    per_example_losses = [_loss_for_example(b) for b in range(num_examples)]
    return tf.reduce_mean(tf.stack(per_example_losses))
def box_corner_distance_loss_on_object_tensors(inputs,
                                               outputs,
                                               loss_type,
                                               delta=1.0,
                                               is_balanced=False):
    """Computes regression loss on object corner locations using object tensors.

    The per-example loss `_box_corner_distance_loss_on_object_tensors` is
    applied to each example of the batch through
    `apply_unbatched_loss_on_object_tensors`, which averages the per-example
    results (examples with no input objects or no output objects contribute
    0.0).

    Args:
      inputs: A dictionary of tf.Tensors with our input data.
      outputs: A dictionary of tf.Tensors with the network output.
      loss_type: Loss type, forwarded to
        `_box_corner_distance_loss_on_object_tensors`.
      delta: float, the value at which the Huber loss function changes from
        quadratic to linear.
      is_balanced: If True, the per-voxel losses are re-weighted to have equal
        total weight for each object instance.

    Returns:
      localization_loss: A tf.float32 scalar corresponding to localization loss.
    """
    def fn(inputs_1, outputs_1):
        return _box_corner_distance_loss_on_object_tensors(
            inputs=inputs_1,
            outputs=outputs_1,
            loss_type=loss_type,
            delta=delta,
            is_balanced=is_balanced)

    # Reuse the shared per-example driver instead of duplicating its loop.
    return apply_unbatched_loss_on_object_tensors(
        inputs=inputs, outputs=outputs, unbatched_loss_fn=fn)