def _build_localization_loss(loss_config):
    """Creates the localization loss object selected by `loss_config`.

    The populated member of the `localization_loss` oneof decides which class
    from `losses` is instantiated.

    Args:
      loss_config: A losses_pb2.LocalizationLoss object.

    Returns:
      A `losses.WeightedL2LocalizationLoss`,
      `losses.WeightedSmoothL1LocalizationLoss` (built with the configured
      `weighted_smooth_l1.delta`) or `losses.WeightedIOULocalizationLoss`
      instance, depending on which oneof member is set.

    Raises:
      ValueError: If `loss_config` is not a losses_pb2.LocalizationLoss, or if
        the oneof does not name one of the three supported losses.
    """
    if not isinstance(loss_config, losses_pb2.LocalizationLoss):
        raise ValueError('loss_config not of type losses_pb2.LocalizationLoss.')

    # Factories are wrapped in lambdas so that nothing is looked up or
    # constructed until the matching entry has actually been selected.
    factories = {
        'weighted_l2': lambda: losses.WeightedL2LocalizationLoss(),
        'weighted_smooth_l1': lambda: losses.WeightedSmoothL1LocalizationLoss(
            loss_config.weighted_smooth_l1.delta),
        'weighted_iou': lambda: losses.WeightedIOULocalizationLoss(),
    }

    selected = loss_config.WhichOneof('localization_loss')
    make_loss = factories.get(selected)
    if make_loss is None:
        raise ValueError('Empty loss config.')
    return make_loss()
def testReturnsCorrectLoss(self):
    """Checks that WeightedIOULocalizationLoss evaluates to 2.0 here.

    The first two predicted rows are identical to their targets; only the
    third row differs, and it carries a weight of 2.0.
    """
    # NOTE(review): presumably only the third (mismatched) pair contributes,
    # scaled by its weight of 2.0 -- confirm against the loss implementation.
    predicted_boxes = tf.constant([[
        [1.5, 0, 2.4, 1],
        [0, 0, 1, 1],
        [0, 0, .5, .25],
    ]])
    groundtruth_boxes = tf.constant([[
        [1.5, 0, 2.4, 1],
        [0, 0, 1, 1],
        [50, 50, 500.5, 100.25],
    ]])
    box_weights = [[1.0, .5, 2.0]]

    iou_loss = losses.WeightedIOULocalizationLoss()(
        predicted_boxes, groundtruth_boxes, weights=box_weights)

    expected = 2.0
    with self.test_session() as sess:
        actual = sess.run(iou_loss)
        self.assertAllClose(actual, expected)