def _build_localization_loss(loss_config):
  """Create the localization loss selected by a LocalizationLoss proto.

  The `localization_loss` oneof of the message decides which loss class is
  instantiated. Only the smooth-L1 variant reads an extra field (`delta`).

  Args:
    loss_config: A losses_pb2.LocalizationLoss message.

  Returns:
    A newly constructed loss object from the `losses` module.

  Raises:
    ValueError: If `loss_config` is not a losses_pb2.LocalizationLoss, or if
      the oneof is unset or set to a case not handled here.
  """
  if not isinstance(loss_config, losses_pb2.LocalizationLoss):
    raise ValueError('loss_config not of type losses_pb2.LocalizationLoss.')

  # Constructors are wrapped in lambdas so only the selected loss is built,
  # and only the selected sub-message is read.
  builders = {
      'weighted_l2': lambda: losses.WeightedL2LocalizationLoss(),
      'weighted_smooth_l1': lambda: losses.WeightedSmoothL1LocalizationLoss(
          loss_config.weighted_smooth_l1.delta),
      'weighted_iou': lambda: losses.WeightedIOULocalizationLoss(),
  }

  selected_case = loss_config.WhichOneof('localization_loss')
  builder = builders.get(selected_case)
  if builder is None:
    raise ValueError('Empty loss config.')
  return builder()
Example #2
0
def _build_localization_loss(loss_config):
    """Builds a localization loss based on the loss config.

    Args:
    loss_config: A yaml.LocalizationLoss object.

    Returns:
    Loss based on the config.

    Raises:
    ValueError: On invalid loss_config.
    """

    if 'weighted_l2' in loss_config:
        config = loss_config.weighted_l2
        if len(config.code_weight) == 0:
            code_weight = None
        else:
            code_weight = config.code_weight
        return losses.WeightedL2LocalizationLoss(code_weight)

    if 'weighted_smooth_l1' in loss_config:
        config = loss_config.weighted_smooth_l1
        if len(config.code_weight) == 0:
            code_weight = None
        else:
            code_weight = config.code_weight
        return losses.WeightedSmoothL1LocalizationLoss(config.sigma,
                                                       code_weight)
    else:
        raise ValueError('Empty loss config.')
Example #3
0
    def testReturnsCorrectAnchorwiseLoss(self):
        """Unit weights and a constant error of 1 give a loss of 2 per anchor."""
        batch_size, num_anchors, code_size = 3, 16, 4
        box_shape = [batch_size, num_anchors, code_size]

        predictions = tf.ones(box_shape)
        targets = tf.zeros(box_shape)
        anchor_weights = tf.ones([batch_size, num_anchors])
        l2_loss = losses.WeightedL2LocalizationLoss()
        per_anchor_loss = l2_loss(predictions, targets, weights=anchor_weights)

        # One value per (image, anchor). Presumably 0.5 * (4 * 1**2) == 2.
        expected = 2 * np.ones((batch_size, num_anchors))

        with self.test_session() as sess:
            self.assertAllClose(sess.run(per_anchor_loss), expected)
Example #4
0
    def testReturnsCorrectLoss(self):
        """Zero-weighted anchors do not contribute to the summed L2 loss."""
        batch_size = 3
        num_anchors = 10
        code_size = 4
        prediction_tensor = tf.ones([batch_size, num_anchors, code_size])
        target_tensor = tf.zeros([batch_size, num_anchors, code_size])
        # Only the first 5 anchors of each of the 3 images carry weight.
        weights = tf.constant(
            [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
             [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]], tf.float32)
        loss_op = losses.WeightedL2LocalizationLoss()
        loss = loss_op(prediction_tensor, target_tensor, weights=weights)
        # The loss op may return one value per anchor ([batch_size,
        # num_anchors]). Summing makes the comparison against a scalar total
        # valid in that case, and is a no-op if the loss is already a scalar.
        loss = tf.reduce_sum(loss)

        # 3 images * 5 weighted anchors * 4 coordinates with error 1, halved.
        expected_loss = (3 * 5 * 4) / 2.0
        with self.test_session() as sess:
            loss_output = sess.run(loss)
            self.assertAllClose(loss_output, expected_loss)
Example #5
0
  def testReturnsCorrectNanLoss(self):
    """NaN target coordinates are skipped when ignore_nan_targets=True."""
    batch_size = 3
    num_anchors = 10
    code_size = 4
    # Floor division: `code_size / 2` is the float 2.0 under Python 3, which
    # TensorFlow rejects as a shape dimension.
    half_code_size = code_size // 2
    prediction_tensor = tf.ones([batch_size, num_anchors, code_size])
    # First half of every target encoding is 0, second half is NaN.
    target_tensor = tf.concat([
        tf.zeros([batch_size, num_anchors, half_code_size]),
        tf.ones([batch_size, num_anchors, half_code_size]) * np.nan
    ], axis=2)
    weights = tf.ones([batch_size, num_anchors])
    loss_op = losses.WeightedL2LocalizationLoss()
    loss = loss_op(prediction_tensor, target_tensor, weights=weights,
                   ignore_nan_targets=True)
    loss = tf.reduce_sum(loss)

    # Equals 30: 30 anchors x 2 non-NaN coordinates with error 1, halved.
    expected_loss = (3 * 5 * 4) / 2.0
    with self.test_session() as sess:
      loss_output = sess.run(loss)
      self.assertAllClose(loss_output, expected_loss)
Example #6
0
def _build_localization_loss(loss_config):
    """Build the localization loss chosen by a LocalizationLoss proto.

    Both supported cases read the selected sub-message's ``code_weight``
    field; an empty field is forwarded to the loss constructor as ``None``.

    Args:
        loss_config: A losses_pb2.LocalizationLoss object.

    Returns:
        A loss instance from the ``losses`` module.

    Raises:
        ValueError: If ``loss_config`` has the wrong type, or if the
            ``localization_loss`` oneof is unset or set to an unsupported case.
    """
    if not isinstance(loss_config, losses_pb2.LocalizationLoss):
        raise ValueError(
            'loss_config not of type losses_pb2.LocalizationLoss.')

    selected_case = loss_config.WhichOneof('localization_loss')
    if selected_case not in ('weighted_l2', 'weighted_smooth_l1'):
        raise ValueError('Empty loss config.')

    sub_config = getattr(loss_config, selected_case)
    weights = sub_config.code_weight if len(sub_config.code_weight) else None

    if selected_case == 'weighted_l2':
        return losses.WeightedL2LocalizationLoss(weights)
    return losses.WeightedSmoothL1LocalizationLoss(sub_config.sigma, weights)