Ejemplo n.º 1
0
 def test_losses(self):
     """Checks the train_lib loss objects against their legacy_fn versions.

     Random predictions and targets of shape [64, 4] are drawn under seed
     1111, then:
       * train_lib.BoxLoss must equal legacy_fn._box_loss;
       * train_lib.FocalLoss (reduction NONE) must equal legacy_fn.focal_loss
         element-wise;
       * train_lib.BoxIouLoss ('ciou') must reproduce a hard-coded regression
         value. Its direct comparison with legacy_fn._box_iou_loss is
         currently disabled (see the TODO below).
     """
     # Seed before anything else: the expected CIoU value asserted at the end
     # is only reproducible for this seed and this order of random draws.
     tf.random.set_seed(1111)

     box_loss_v2 = train_lib.BoxLoss()
     # Single pyramid level (3), one scale and one aspect ratio per location.
     anchor_kwargs = dict(
         min_level=3,
         max_level=3,
         num_scales=1,
         aspect_ratios=[1.0],
         anchor_scale=1.0,
         image_size=32,
     )
     box_iou_loss_v2 = train_lib.BoxIouLoss(iou_loss_type='ciou',
                                            **anchor_kwargs)
     alpha, gamma = 0.25, 1.5
     focal_loss_v2 = train_lib.FocalLoss(
         alpha, gamma, reduction=tf.keras.losses.Reduction.NONE)

     predictions = tf.random.normal([64, 4])
     targets = tf.random.normal([64, 4])
     num_positives = tf.constant(4.0)
     # Regression value recorded for the seeded inputs above.
     expected_ciou_loss = 4.924635

     # The train_lib losses receive `[num_positives, targets]` as their first
     # argument and the predictions as their second.
     legacy_box = legacy_fn._box_loss(predictions, targets, num_positives)
     v2_box = box_loss_v2([num_positives, targets], predictions)
     self.assertEqual(legacy_box, v2_box)

     legacy_focal = legacy_fn.focal_loss(predictions, targets, alpha, gamma,
                                         num_positives)
     v2_focal = focal_loss_v2([num_positives, targets], predictions)
     self.assertAllEqual(legacy_focal, v2_focal)

     # TODO(tanmingxing): Re-enable this test after fixing this failing test.
     # self.assertEqual(
     #     legacy_fn._box_iou_loss(predictions, targets, num_positives,
     #                             'ciou'),
     #     box_iou_loss_v2([num_positives, targets], predictions))
     ciou_loss = box_iou_loss_v2([num_positives, targets], predictions)
     self.assertAlmostEqual(ciou_loss.numpy(), expected_ciou_loss, places=5)
Ejemplo n.º 2
0
 def test_losses(self):
     """Checks train_lib losses against their legacy_fn counterparts.

     Uses constant inputs (predictions of ones and targets of zeros, both of
     shape [8]) and a plain Python float for num_positives. Verifies that
     BoxLoss, BoxIouLoss('ciou') and FocalLoss (reduction NONE) produce the
     same values as legacy_fn._box_loss, legacy_fn._box_iou_loss and
     legacy_fn.focal_loss respectively.
     """
     alpha, gamma = 0.25, 1.5
     box_loss_v2 = train_lib.BoxLoss()
     box_iou_loss_v2 = train_lib.BoxIouLoss('ciou')
     focal_loss_v2 = train_lib.FocalLoss(
         alpha, gamma, reduction=tf.keras.losses.Reduction.NONE)

     predictions = tf.ones([8])
     targets = tf.zeros([8])
     num_positives = 4.0

     # The train_lib losses receive `[num_positives, targets]` as their first
     # argument and the predictions as their second.
     legacy_box = legacy_fn._box_loss(predictions, targets, num_positives)
     v2_box = box_loss_v2([num_positives, targets], predictions)
     self.assertEqual(legacy_box, v2_box)

     legacy_iou = legacy_fn._box_iou_loss(predictions, targets, num_positives,
                                          'ciou')
     v2_iou = box_iou_loss_v2([num_positives, targets], predictions)
     self.assertEqual(legacy_iou, v2_iou)

     legacy_focal = legacy_fn.focal_loss(predictions, targets, alpha, gamma,
                                         num_positives)
     v2_focal = focal_loss_v2([num_positives, targets], predictions)
     self.assertAllEqual(legacy_focal, v2_focal)