Beispiel #1
0
 def UnitY(angle):
     """Build the 3x3 matrix for a rotation about the y axis.

     Args:
       angle: rotation angle in radians. Presumably a scalar, since the nine
         packed entries are reshaped to [3, 3] -- confirm against callers.

     Returns:
       A [3, 3] tensor laid out row-major as
       [[cos, 0, sin], [0, 1, 0], [-sin, 0, cos]].
     """
     cos_a = tf.cos(angle)
     sin_a = tf.sin(angle)
     # Entries are listed row by row; reshape turns the flat list into 3x3.
     entries = [
         cos_a, 0., sin_a,
         0., 1., 0.,
         -sin_a, 0., cos_a,
     ]
     return tf.reshape(entries, shape=[3, 3])
Beispiel #2
0
 def _UnitZ(angle):
     """Build the 3x3 matrix for a rotation about the z axis.

     Args:
       angle: rotation angle in radians. Presumably a scalar, since the nine
         packed entries are reshaped to [3, 3] -- confirm against callers.

     Returns:
       A [3, 3] tensor laid out row-major as
       [[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]].
     """
     cos_a = tf.cos(angle)
     sin_a = tf.sin(angle)
     # Entries are listed row by row; reshape turns the flat list into 3x3.
     entries = [
         cos_a, -sin_a, 0.,
         sin_a, cos_a, 0.,
         0., 0., 1.,
     ]
     return tf.reshape(entries, shape=[3, 3])
Beispiel #3
0
 def _rotate(bbox, theta):
     """Right-multiply `bbox` by the 2x2 rotation matrix for `theta`.

     Args:
       bbox: tensor that is matrix-multiplied with a [2, 2] matrix, so its last
         dimension must be 2 (presumably rows of (x, y) points -- confirm).
       theta: rotation angle in radians. Presumably a scalar, since the four
         packed entries are reshaped to (2, 2).

     Returns:
       `bbox @ R`, where R = [[cos, -sin], [sin, cos]].
     """
     cos_t = tf.cos(theta)
     sin_t = tf.sin(theta)
     entries = [
         cos_t, -sin_t,
         sin_t, cos_t,
     ]
     rotation_matrix = tf.reshape(entries, shape=(2, 2))
     # The points are on the left, so each row of bbox is multiplied by R.
     return tf.matmul(bbox, rotation_matrix)
Beispiel #4
0
 def _UnitX(angle):
     """Build the 3x3 matrix for a rotation about the x axis.

     Args:
       angle: rotation angle in radians. Presumably a scalar, since the nine
         packed entries are reshaped to [3, 3] -- confirm against callers.

     Returns:
       A [3, 3] tensor laid out row-major as
       [[1, 0, 0], [0, cos, -sin], [0, sin, cos]].
     """
     cos_a = tf.cos(angle)
     sin_a = tf.sin(angle)
     # Entries are listed row by row; reshape turns the flat list into 3x3.
     entries = [
         1., 0., 0.,
         0., cos_a, -sin_a,
         0., sin_a, cos_a,
     ]
     return tf.reshape(entries, shape=[3, 3])
Beispiel #5
0
def BatchMakeRotationMatrix(yaw, clockwise=False):
    """Create a batch of 3x3 z-axis rotation matrices from yaw angles.

    Args:
      yaw: float tensor representing a yaw angle in radians.
      clockwise: Whether to have the rotation be applied clockwise (True) or
        counter-clockwise (False). Clockwise is implemented by negating yaw.
        Defaults to counter-clockwise to maintain same semantics to
        MakeRotationMatrix.

    Returns:
      A [N, 3, 3] tensor where each matrix is
      [[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]] for the (possibly negated)
      yaw. N is whatever remains after flattening all leading dimensions.
    """
    angle = -yaw if clockwise else yaw

    c = tf.cos(angle)
    s = tf.sin(angle)
    zeros = tf.zeros_like(c)
    ones = tf.ones_like(c)

    # Nine entries per angle, stacked on a new trailing axis in row-major
    # order, then folded into 3x3 matrices.
    entries = [
        c, -s, zeros,
        s, c, zeros,
        zeros, zeros, ones,
    ]
    return tf.reshape(tf.stack(entries, axis=-1), [-1, 3, 3])
Beispiel #6
0
def BBoxCorners(bboxes):
    """Compute the 8 corner points of each 7-DOF bounding box.

    Args:
      bboxes: A [batch, num_boxes, 7] floating point bounding box
        representation ([x, y, z, dx, dy, dz, phi]).

    Returns:
      A [batch, num_boxes, 8, 3] floating point Tensor containing
        the corner (x, y, z) points for every bounding box. The first four
        corners have z offset +0.5 * dz (top), the last four -0.5 * dz
        (bottom).
    """
    batch, nb, _ = py_utils.GetShape(bboxes, 3)

    # Split the 7-DOF vector into center, extents and heading.
    location = bboxes[:, :, :3]
    dimensions = bboxes[:, :, 3:6]
    phi_world = bboxes[:, :, 6]

    # Code adapted from vale/soapbox codebase.
    #
    # Corners of a unit cube centered at the origin (normalized box frame).
    # Columns follow the [length, width, height] order of the dimensions.
    unit_corners = tf.constant([
        [0.5, 0.5, 0.5],  # top
        [-0.5, 0.5, 0.5],  # top
        [-0.5, -0.5, 0.5],  # top
        [0.5, -0.5, 0.5],  # top
        [0.5, 0.5, -0.5],  # bottom
        [-0.5, 0.5, -0.5],  # bottom
        [-0.5, -0.5, -0.5],  # bottom
        [0.5, -0.5, -0.5],  # bottom
    ])

    # Per-box rotation about the z axis, built from the heading phi:
    # [[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]].
    cos_phi = tf.cos(phi_world)
    sin_phi = tf.sin(phi_world)
    zeros = tf.zeros_like(cos_phi)
    ones = tf.ones_like(cos_phi)
    rotation_entries = [
        cos_phi, -sin_phi, zeros,
        sin_phi, cos_phi, zeros,
        zeros, zeros, ones,
    ]
    rotations_world = tf.reshape(
        tf.stack(rotation_entries, axis=2), [batch, nb, 3, 3])

    # Scale the unit corners by each box's dimensions (still axis-aligned).
    scaled_corners = tf.einsum('bni,ji->bnji', dimensions, unit_corners)

    # Rotate every corner by its box's rotation matrix.
    rotated_corners = tf.einsum('bnij,bnkj->bnki', rotations_world,
                                scaled_corners)

    # Shift by the box center; the inserted axis broadcasts over the 8 corners.
    return rotated_corners + tf.reshape(location, (batch, nb, 1, 3))
Beispiel #7
0
def BatchMakeRotationMatrix(yaw):
    """Create a batch of 3x3 z-axis rotation matrices from yaw angles.

    Args:
      yaw: float tensor representing a yaw angle in radians.

    Returns:
      A [N, 3, 3] tensor where each matrix is
      [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]. Note the sign placement:
      this is the transpose of the [[cos, -sin], [sin, cos]] layout. N is
      whatever remains after flattening all leading dimensions.
    """
    c = tf.cos(yaw)
    s = tf.sin(yaw)
    zeros = tf.zeros_like(c)
    ones = tf.ones_like(c)

    # Nine entries per angle, stacked on a new trailing axis in row-major
    # order, then folded into 3x3 matrices.
    entries = [
        c, s, zeros,
        -s, c, zeros,
        zeros, zeros, ones,
    ]
    return tf.reshape(tf.stack(entries, axis=-1), [-1, 3, 3])
Beispiel #8
0
def BBoxCorners2D(bboxes):
    """Compute the 4 corner points of each 5-DOF (planar) bounding box.

    Args:
      bboxes: A [..., 5] floating point bounding box representation ([x, y, dx,
        dy, phi]).

    Returns:
      A [..., 4, 2] floating point Tensor containing
        the corner (x, y) points for every bounding box.
    """
    # All dimensions except the trailing 5-vector; reused to build the
    # output shapes below.
    leading_shape = py_utils.GetShape(bboxes)[:-1]

    # Split the 5-DOF vector into center, extents and heading.
    location = bboxes[..., :2]
    dimensions = bboxes[..., 2:4]
    phi_world = bboxes[..., 4]

    # Corners of a unit square centered at the origin.
    unit_corners = tf.constant([
        [0.5, 0.5],
        [-0.5, 0.5],
        [-0.5, -0.5],
        [0.5, -0.5],
    ])

    # Per-box 2x2 rotation [[cos, -sin], [sin, cos]] built from phi.
    cos_phi = tf.cos(phi_world)
    sin_phi = tf.sin(phi_world)
    rotation_entries = tf.stack([cos_phi, -sin_phi, sin_phi, cos_phi], axis=-1)
    rotations_world = tf.reshape(rotation_entries, leading_shape + [2, 2])

    # Scale the unit corners by each box's dx/dy (still axis-aligned).
    scaled_corners = tf.einsum('...i,ji->...ji', dimensions, unit_corners)

    # Rotate every corner by its box's rotation matrix.
    rotated_corners = tf.einsum('...ij,...kj->...ki', rotations_world,
                                scaled_corners)

    # Shift by the box center; the inserted axis broadcasts over the 4 corners.
    return rotated_corners + tf.reshape(location, leading_shape + [1, 2])
  def _XYZFromRangeImage(self,
                         lidar_image,
                         lidar_image_mask,
                         extrinsics,
                         inclinations,
                         pixel_pose=None,
                         frame_pose=None):
    """Extract the cartesian coordinates from the range image.

    Args:
       lidar_image: [H, W, C] range image Tensor.
       lidar_image_mask: [H, W] boolean indicating which 2d coordinates in the
         lidar image are present.
       extrinsics: [4, 4] float matrix representing transformation matrix to
         world coordinates.
       inclinations: [V] beam inclinations vector.
       pixel_pose: [64, 2650, 4, 4] tensor representing per pixel pose of GBR.
       frame_pose: [4, 4] matrix representing vehicle to world transformation.

    Returns:
      [H, W, 3] range image cartesian coordinates.
    """
    height, width, channels = py_utils.GetShape(lidar_image, 3)

    conversion_dtype = tf.float32
    lidar_image = tf.cast(lidar_image, conversion_dtype)
    extrinsics = tf.cast(extrinsics, conversion_dtype)
    inclinations = tf.cast(inclinations, conversion_dtype)
    inclinations = tf.reverse(inclinations, axis=[-1])

    az_correction = py_utils.HasShape(
        tf.atan2(extrinsics[1, 0], extrinsics[0, 0]), [])
    ratios = (tf.cast(tf.range(width, 0, -1), dtype=conversion_dtype) -
              .5) / tf.cast(width, conversion_dtype)
    ratios = py_utils.HasShape(ratios, [width])

    azimuth = (ratios * 2. - 1.) * np.pi - az_correction[..., tf.newaxis]
    azimuth = py_utils.HasShape(azimuth, [width])

    lidar_image_mask = lidar_image_mask[..., tf.newaxis]
    lidar_image_mask = tf.tile(lidar_image_mask, [1, 1, channels])
    lidar_image = tf.where(lidar_image_mask, lidar_image,
                           tf.zeros_like(lidar_image))
    lidar_image_range = lidar_image[..., 0]

    azimuth = py_utils.HasShape(azimuth[tf.newaxis, ...], [1, width])
    inclinations = py_utils.HasShape(inclinations[..., tf.newaxis], [height, 1])

    cos_azimuth = tf.cos(azimuth)
    sin_azimuth = tf.sin(azimuth)
    cos_incl = tf.cos(inclinations)
    sin_incl = tf.sin(inclinations)

    x = cos_azimuth * cos_incl * lidar_image_range
    y = sin_azimuth * cos_incl * lidar_image_range
    z = sin_incl * lidar_image_range

    lidar_image_points = tf.stack([x, y, z], -1)
    lidar_image_points = py_utils.HasShape(lidar_image_points,
                                           [height, width, 3])
    rotation = extrinsics[0:3, 0:3]
    translation = extrinsics[0:3, 3][tf.newaxis, ...]

    # Transform the image points in cartesian coordinates to
    # the world coordinate system using the extrinsics matrix.
    #
    # We first flatten the points, apply rotation, then
    # reshape to restore the original input and then apply
    # translation.
    lidar_image_points = tf.matmul(
        tf.reshape(lidar_image_points, [-1, 3]), rotation, transpose_b=True)
    lidar_image_points = tf.reshape(lidar_image_points, [height, width, 3])
    lidar_image_points += translation

    lidar_image_points = py_utils.HasShape(lidar_image_points,
                                           [height, width, 3])
    # GBR uses per pixel pose.
    if pixel_pose is not None:
      pixel_pose_rotation = pixel_pose[..., 0:3, 0:3]
      pixel_pose_translation = pixel_pose[..., 0:3, 3]
      lidar_image_points = tf.einsum(
          'hwij,hwj->hwi', pixel_pose_rotation,
          lidar_image_points) + pixel_pose_translation
      if frame_pose is None:
        raise ValueError('frame_pose must be set when pixel_pose is set.')
      # To vehicle frame corresponding to the given frame_pose
      # [4, 4]
      world_to_vehicle = tf.linalg.inv(frame_pose)
      world_to_vehicle_rotation = world_to_vehicle[0:3, 0:3]
      world_to_vehicle_translation = world_to_vehicle[0:3, 3]
      # [H, W, 3]
      lidar_image_points = tf.einsum(
          'ij,hwj->hwi', world_to_vehicle_rotation,
          lidar_image_points) + world_to_vehicle_translation[tf.newaxis,
                                                             tf.newaxis, :]

    return lidar_image_points
Beispiel #10
0
    def ComputeLoss(self, theta, predictions, input_batch):
        """Computes loss and other metrics for the given predictions.

        Args:
          theta: A `.NestedMap` object containing variable values of this task.
          predictions: The output of `ComputePredictions`. This method reads
            `residuals` ([b, nx, ny, nz, na, 7]), `classification_logits`
            ([b, nx, ny, nz, na, num_classes]) and, when
            `p.direction_classifier_weight > 0`, `predicted_dir`. na is the
            number of anchor boxes per cell. The residuals [..., :7] are
            (dx, dy, dz, dw, dl, dh, dt).
          input_batch: The input batch from which we access the groundtruth.
            This method reads `assigned_cls_mask`, `assigned_reg_mask`,
            `assigned_gt_labels`, `anchor_localization_residuals`,
            `anchor_bboxes` and, for the direction loss, `assigned_gt_bbox`.

        Returns:
          Two dicts defined as BaseTask.ComputeLoss: a metrics dict of
          (value, weight) pairs keyed by 'loss', 'loss/class', 'loss/reg', ...
          and a per-example dict with the predicted residuals and logits.
        """
        p = self.params
        predicted_residuals = py_utils.HasShape(
            predictions.residuals, [-1, -1, -1, -1, p.num_anchors, 7])
        predicted_class_logits = py_utils.HasShape(
            predictions.classification_logits,
            [-1, -1, -1, -1, p.num_anchors, p.num_classes])
        bs, nx, ny, nz, na, _ = py_utils.GetShape(predicted_class_logits, 6)

        # Compute class and regression weights.
        # class_weights stays rank 5 ([bs, nx, ny, nz, na]) until it is applied
        # to the class loss; reg_weights is expanded to rank 6 right away.
        class_weights = input_batch.assigned_cls_mask
        class_weights = py_utils.HasShape(class_weights, [bs, nx, ny, nz, na])
        reg_weights = input_batch.assigned_reg_mask
        reg_weights = py_utils.HasShape(reg_weights, [bs, nx, ny, nz, na])
        reg_weights = tf.expand_dims(reg_weights, -1)

        if p.loss_norm_type == LossNormType.NORM_BY_NUM_POSITIVES:
            # Compute number of positive anchors per example.
            foreground_mask = py_utils.HasShape(input_batch.assigned_reg_mask,
                                                [bs, nx, ny, nz, na])
            # Sum to get the number of foreground anchors for each example.
            loss_normalization = tf.reduce_sum(foreground_mask,
                                               axis=[1, 2, 3, 4])
            # Clamp to at least 1 so examples without positives do not divide
            # by zero.
            loss_normalization = tf.maximum(loss_normalization,
                                            tf.ones_like(loss_normalization))

            # Reshape for broadcasting. Each weight tensor needs a normalizer
            # of its own rank: dividing the rank-5 class_weights by a rank-6
            # [bs, 1, 1, 1, 1, 1] tensor would broadcast to
            # [bs, bs, nx, ny, nz, na] and mix normalizers across examples.
            class_weights /= tf.reshape(loss_normalization,
                                        [bs, 1, 1, 1, 1])
            reg_weights /= tf.reshape(loss_normalization,
                                      [bs, 1, 1, 1, 1, 1])

        # Classification loss.
        assigned_gt_labels = py_utils.HasShape(input_batch.assigned_gt_labels,
                                               [bs, nx, ny, nz, na])
        class_loss = py_utils.SigmoidCrossEntropyFocalLoss(
            logits=predicted_class_logits,
            labels=tf.one_hot(assigned_gt_labels, p.num_classes),
            alpha=p.focal_loss_alpha,
            gamma=p.focal_loss_gamma)
        # Add the class axis so the per-anchor weight applies to every class.
        class_loss *= class_weights[..., tf.newaxis]
        class_loss_sum = tf.reduce_sum(class_loss)

        # Regression loss.
        anchor_localization_residuals = py_utils.HasShape(
            input_batch.anchor_localization_residuals, [bs, nx, ny, nz, na, 7])

        # Location and dimensions loss.
        reg_loc_and_dims_loss = self._utils.ScaledHuberLoss(
            predictions=py_utils.HasShape(predicted_residuals[..., :6],
                                          [bs, nx, ny, nz, na, 6]),
            labels=anchor_localization_residuals[..., :6],
            delta=1 / (3.**2))

        # Rotation residual error; uses the shape-checked residuals.
        rot_delta = (predicted_residuals[..., 6:] -
                     anchor_localization_residuals[..., 6:])

        if p.use_atan2_heading_loss:
            # Rotation loss on the delta wrapped into [-pi, pi] via atan2.
            atan2_of_delta = tf.atan2(tf.sin(rot_delta), tf.cos(rot_delta))
            reg_rot_loss = self._utils.ScaledHuberLoss(
                predictions=atan2_of_delta,
                labels=tf.zeros_like(atan2_of_delta),
                delta=1 / (3.**2))
        else:
            # Rotation loss with SmoothL1(sin(delta)).
            reg_rot_loss = self._utils.ScaledHuberLoss(
                predictions=tf.sin(rot_delta),
                labels=tf.zeros_like(rot_delta),
                delta=1 / (3.**2))

        # Direction loss
        if p.direction_classifier_weight > 0.0:
            # The target rotations are in the assigned_gt_bbox tensor,
            # which already has assigned a gt bounding box to every anchor.
            rot_target = input_batch.assigned_gt_bbox[..., 6]
            # If rotation is > 0, the class is 1, else it is 0.
            rot_dir = tf.cast(rot_target > 0., tf.int32)

            # Compute one-hot labels as a target.
            rot_dir_onehot = tf.one_hot(rot_dir, 2)

            # Manually handle loss reduction.
            dir_loss = tf.losses.softmax_cross_entropy(
                onehot_labels=rot_dir_onehot,
                logits=predictions.predicted_dir,
                weights=tf.squeeze(reg_weights, axis=-1),
                reduction=tf.losses.Reduction.NONE)
            # Reduce across all dimensions (we'll divide by the batch size below).
            dir_loss_sum = tf.reduce_sum(dir_loss)
        else:
            dir_loss_sum = 0.0

        # Compute loss contribution from location and dimension separately.
        reg_loc_loss = reg_loc_and_dims_loss[..., :3] * reg_weights
        reg_loc_loss_sum = tf.reduce_sum(reg_loc_loss)

        reg_dim_loss = reg_loc_and_dims_loss[..., 3:6] * reg_weights
        reg_dim_loss_sum = tf.reduce_sum(reg_dim_loss)

        # Compute rotation loss contribution.
        reg_rot_loss *= reg_weights
        reg_rot_loss_sum = tf.reduce_sum(reg_rot_loss)

        # Num. predictions.
        # TODO(zhifengc): Consider other normalization factors. E.g., # of bboxes.
        preds = tf.cast(bs, class_loss_sum.dtype)

        # Normalize all of the components by batch size.
        reg_loc_loss = reg_loc_loss_sum / preds
        reg_dim_loss = reg_dim_loss_sum / preds
        reg_rot_loss = reg_rot_loss_sum / preds
        class_loss = class_loss_sum / preds
        dir_loss = dir_loss_sum / preds

        # Compute total localization regression loss.
        reg_loss = (p.location_loss_weight * reg_loc_loss +
                    p.dimension_loss_weight * reg_dim_loss +
                    p.rotation_loss_weight * reg_rot_loss)

        # Apply weights to normalized class losses.
        loss = (class_loss * p.classification_loss_weight +
                reg_loss * p.localization_loss_weight +
                dir_loss * p.direction_classifier_weight)

        metrics_dict = {
            'loss': (loss, preds),
            'loss/class': (class_loss, preds),
            'loss/reg': (reg_loss, preds),
            'loss/reg/rot': (reg_rot_loss, preds),
            'loss/reg/loc': (reg_loc_loss, preds),
            'loss/reg/dim': (reg_dim_loss, preds),
            'loss/dir': (dir_loss, preds),
        }

        # Calculate dimension errors
        # The lower angle bound depends on which heading loss is in use.
        min_angle_rad = -np.pi if p.use_atan2_heading_loss else 0
        gt_bboxes = self._utils_3d.ResidualsToBBoxes(
            input_batch.anchor_bboxes,
            anchor_localization_residuals,
            min_angle_rad=min_angle_rad,
            max_angle_rad=np.pi)
        predicted_bboxes = self._utils_3d.ResidualsToBBoxes(
            input_batch.anchor_bboxes,
            predicted_residuals,
            min_angle_rad=min_angle_rad,
            max_angle_rad=np.pi)
        dimension_errors_dict = self._BBoxDimensionErrors(
            gt_bboxes, predicted_bboxes, reg_weights)
        metrics_dict.update(dimension_errors_dict)

        per_example_dict = {
            'residuals': predicted_residuals,
            'classification_logits': predicted_class_logits,
        }

        return metrics_dict, per_example_dict
Beispiel #11
0
    def ComputeLoss(self, theta, predictions, input_batch):
        """Compute loss for the sparse detector model v1.

        Args:
          theta: A `.NestedMap` object containing variable values of this task.
          predictions: A `.NestedMap` object containing residuals and
            classification_logits.
          input_batch: A `.NestedMap` expected to contain cell_center_xyz,
            cell_points_xyz, cell_feature, anchor_bboxes,
            anchor_localization_residuals, assigned_gt_labels, and
            assigned_cls_mask. See class doc string for details. This method
            also reads assigned_reg_mask.

        Returns:
          Two dicts:

          - A dict containing str keys and (metric, weight) pairs as values,
            where one of the keys is expected to be 'loss'.
          - A dict containing arbitrary tensors describing something about each
            training example, where the first dimension of each tensor is the
            batch index.
        """
        p = self.params

        batch_size, num_centers = py_utils.GetShape(
            input_batch.cell_center_xyz, 2)

        # Assert shapes of inputs.
        anchor_bboxes = py_utils.HasShape(
            input_batch.anchor_bboxes,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])
        anchor_localization_residuals = py_utils.HasShape(
            input_batch.anchor_localization_residuals,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])
        predicted_residuals = py_utils.HasShape(
            predictions.residuals,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 7])

        assigned_gt_labels = py_utils.HasShape(
            input_batch.assigned_gt_labels,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center])
        predicted_classification_logits = py_utils.HasShape(
            predictions.classification_logits, [
                batch_size, num_centers, p.num_anchor_bboxes_per_center,
                p.num_classes
            ])

        # assigned_cls_mask is for weighting the classification loss.
        # Ignored targets will have their mask = 0; this happens when their IOU is
        # not high enough to be a foreground object and not low enough to be
        # background.
        class_weights = py_utils.HasShape(
            input_batch.assigned_cls_mask,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center])
        class_weights = tf.reshape(
            class_weights,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 1])

        # Broadcast per class loss weights. For each anchor, there are num_classes
        # prediction heads, we weight the outputs of these heads by the per class
        # loss weights.
        per_class_loss_weight = tf.constant([[[p.per_class_loss_weight]]],
                                            dtype=tf.float32)
        per_class_loss_weight = py_utils.HasShape(per_class_loss_weight,
                                                  [1, 1, 1, p.num_classes])
        class_weights *= per_class_loss_weight
        class_weights = py_utils.HasShape(class_weights, [
            batch_size, num_centers, p.num_anchor_bboxes_per_center,
            p.num_classes
        ])

        # We use assigned_reg_mask for masking the regression loss.
        # Only foreground objects will have assigned_reg_mask = 1.
        reg_weights = py_utils.HasShape(
            input_batch.assigned_reg_mask,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center])
        reg_weights = tf.reshape(
            reg_weights,
            [batch_size, num_centers, p.num_anchor_bboxes_per_center, 1])

        if p.loss_norm_type == LossNormType.NORM_BY_NUM_POS_PER_CENTER:
            # Compute number of positive anchors per center.
            foreground_mask = py_utils.HasShape(
                input_batch.assigned_reg_mask,
                [batch_size, num_centers, p.num_anchor_bboxes_per_center])

            # Sum over the anchor axis to get the number of foreground anchors
            # for each center; clamp to at least 1 to avoid dividing by zero.
            loss_normalization = tf.reduce_sum(foreground_mask, axis=2)
            loss_normalization = tf.maximum(loss_normalization,
                                            tf.ones_like(loss_normalization))

            # Reshape for broadcasting against the rank-4 weight tensors.
            loss_normalization = tf.reshape(loss_normalization,
                                            [batch_size, num_centers, 1, 1])

            # Normalize so that the loss is independent of # centers.
            loss_normalization *= num_centers
            class_weights /= loss_normalization
            reg_weights /= loss_normalization

        classification_loss = py_utils.SigmoidCrossEntropyFocalLoss(
            logits=predicted_classification_logits,
            labels=tf.one_hot(assigned_gt_labels, p.num_classes),
            alpha=p.focal_loss_alpha,
            gamma=p.focal_loss_gamma)

        # Apply mask.
        classification_loss *= class_weights

        # TODO(jngiam): Consider normalizing by num_foreground_anchors for each
        # example instead. This would match the 1/N_positive normalization in
        # point pillars.

        # Reduce sum over centers, boxes and classes.
        classification_loss = tf.reduce_sum(classification_loss,
                                            axis=[1, 2, 3])

        # Reduce mean over batch.
        classification_loss = tf.reduce_mean(classification_loss)

        # Localization regression loss with Huber loss (SmoothL1).
        regression_loc_and_dims_loss = self._utils_3d.ScaledHuberLoss(
            labels=anchor_localization_residuals[..., :6],
            predictions=predicted_residuals[..., :6],
            delta=p.huber_loss_delta)

        # TODO(jngiam): Consider other methods for rotation loss such as softmax
        # binning.
        # For the rotation loss, we use SmoothL1(sine(delta)), this enables the
        # rotation loss to be the same independent of direction.
        rotation_delta = (predicted_residuals[..., 6:] -
                          anchor_localization_residuals[..., 6:])
        if p.use_atan2_heading_loss:
            # Wrap the delta into [-pi, pi] via atan2 before the Huber loss.
            atan2_of_delta = tf.atan2(tf.sin(rotation_delta),
                                      tf.cos(rotation_delta))
            # Use the same helper object as every other ScaledHuberLoss call
            # in this method.
            # NOTE(review): delta is hard-coded here while the other calls use
            # p.huber_loss_delta -- confirm whether that is intentional.
            regression_rotation_loss = self._utils_3d.ScaledHuberLoss(
                predictions=atan2_of_delta,
                labels=tf.zeros_like(atan2_of_delta),
                delta=1 / (3.**2))
        else:
            regression_rotation_loss = self._utils_3d.ScaledHuberLoss(
                labels=tf.zeros_like(rotation_delta),
                predictions=tf.sin(rotation_delta),
                delta=p.huber_loss_delta)

        reg_loc_loss = regression_loc_and_dims_loss[..., :3]
        reg_dim_loss = regression_loc_and_dims_loss[..., 3:6]

        # Decode both groundtruth and predicted residuals against the anchors;
        # used for the corner loss, the dimension-error metrics and the
        # per-example outputs.
        gt_bboxes = self._utils_3d.ResidualsToBBoxes(
            anchor_bboxes,
            anchor_localization_residuals,
            min_angle_rad=-np.pi,
            max_angle_rad=np.pi)
        predicted_bboxes = self._utils_3d.ResidualsToBBoxes(
            anchor_bboxes,
            predicted_residuals,
            min_angle_rad=-np.pi,
            max_angle_rad=np.pi)

        # Apply mask to individual losses.
        #
        # And then reduce sum over centers, boxes, residuals, and batch
        # and divide by the batch_size.
        regression_rotation_loss *= reg_weights
        reg_rot_loss = tf.reduce_sum(regression_rotation_loss) / batch_size

        reg_loc_loss *= reg_weights
        reg_loc_loss = tf.reduce_sum(reg_loc_loss) / batch_size

        reg_dim_loss *= reg_weights
        reg_dim_loss = tf.reduce_sum(reg_dim_loss) / batch_size

        # Do not create corner loss graph if weight is 0.0
        # TODO(bcyang): Remove condition after fixing corner loss NaN issue
        if p.corner_loss_weight != 0.0:
            reg_corner_loss = self._utils_3d.CornerLoss(
                gt_bboxes=gt_bboxes, predicted_bboxes=predicted_bboxes)
            # Add a trailing axis so the rank-4 reg_weights can be applied.
            reg_corner_loss = tf.expand_dims(reg_corner_loss, axis=-1)

            reg_corner_loss *= reg_weights
            reg_corner_loss = tf.reduce_sum(reg_corner_loss) / batch_size
        else:
            reg_corner_loss = 0.0

        # Sum components of regression loss.
        regression_loss = (p.location_loss_weight * reg_loc_loss +
                           p.dimension_loss_weight * reg_dim_loss +
                           p.rotation_loss_weight * reg_rot_loss +
                           p.corner_loss_weight * reg_corner_loss)

        # Compute total loss.
        total_loss = (p.loss_weight_localization * regression_loss +
                      p.loss_weight_classification * classification_loss)

        metrics_dict = py_utils.NestedMap({
            'loss': (total_loss, batch_size),
            'loss/regression': (regression_loss, batch_size),
            'loss/regression/loc': (reg_loc_loss, batch_size),
            'loss/regression/dim': (reg_dim_loss, batch_size),
            'loss/regression/rot': (reg_rot_loss, batch_size),
            'loss/regression/corner': (reg_corner_loss, batch_size),
            'loss/classification': (classification_loss, batch_size),
        })

        # Calculate dimension errors
        dimension_errors_dict = self._BBoxDimensionErrors(
            gt_bboxes, predicted_bboxes, reg_weights)
        metrics_dict.update(dimension_errors_dict)

        per_example_dict = py_utils.NestedMap({
            'residuals': predicted_residuals,
            'classification_logits': predicted_classification_logits,
            'predicted_bboxes': predicted_bboxes,
            'gt_bboxes': gt_bboxes,
            'reg_weights': reg_weights,
        })

        return metrics_dict, per_example_dict