示例#1
0
 def test_build_reweighting_unmatched_anchors(self):
     """Builds losses, loss weights and a default hard example miner.

     Verifies that losses_builder.build returns a weighted softmax
     classification loss, a weighted L2 localization loss, a
     HardExampleMiner, and the classification/localization weights that
     were set in the config.
     """
     losses_text_proto = """
   localization_loss {
     weighted_l2 {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
   hard_example_miner {
   }
   classification_weight: 0.8
   localization_weight: 0.2
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     # The last two values returned by build() are not checked here.
     (classification_loss, localization_loss, classification_weight,
      localization_weight, hard_example_miner, _,
      _) = losses_builder.build(losses_proto)
     # assertIsInstance reports the offending type on failure, unlike
     # assertTrue(isinstance(...)) which only says "False is not true".
     self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
     self.assertIsInstance(classification_loss,
                           losses.WeightedSoftmaxClassificationLoss)
     self.assertIsInstance(localization_loss,
                           losses.WeightedL2LocalizationLoss)
     self.assertAlmostEqual(classification_weight, 0.8)
     self.assertAlmostEqual(localization_weight, 0.2)
示例#2
0
 def test_build_hard_example_miner_with_non_default_values(self):
     """Non-default hard_example_miner fields are forwarded to the miner.

     Every field set in the hard_example_miner config (including
     loss_type) is checked against the corresponding private attribute of
     the HardExampleMiner returned by losses_builder.build.
     """
     losses_text_proto = """
   localization_loss {
     weighted_l2 {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
   hard_example_miner {
     num_hard_examples: 32
     iou_threshold: 0.5
     loss_type: LOCALIZATION
     max_negatives_per_positive: 10
     min_negatives_per_image: 3
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, _, _, _, hard_example_miner, _, _ = losses_builder.build(
         losses_proto)
     self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
     self.assertEqual(hard_example_miner._num_hard_examples, 32)
     self.assertAlmostEqual(hard_example_miner._iou_threshold, 0.5)
     # loss_type is set in the config above but was previously unchecked;
     # LOCALIZATION maps to 'loc' (the same mapping the localization-only
     # hard example miner test asserts).
     self.assertEqual(hard_example_miner._loss_type, 'loc')
     self.assertEqual(hard_example_miner._max_negatives_per_positive, 10)
     self.assertEqual(hard_example_miner._min_negatives_per_image, 3)
示例#3
0
 def test_raise_error_on_empty_config(self):
     """build() raises ValueError for a config with only localization_loss.

     No classification_loss is configured in the proto below.
     """
     proto_text = """
   localization_loss {
     weighted_l2 {
     }
   }
 """
     config = losses_pb2.Loss()
     text_format.Merge(proto_text, config)
     self.assertRaises(ValueError, losses_builder.build, config)
示例#4
0
 def test_do_not_build_hard_example_miner_by_default(self):
     """No hard example miner is built when the config omits one."""
     losses_text_proto = """
   localization_loss {
     weighted_l2 {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, _, _, _, hard_example_miner = losses_builder.build(losses_proto)
     # assertIsNone checks identity with None rather than equality.
     self.assertIsNone(hard_example_miner)
示例#5
0
 def test_build_weighted_iou_localization_loss(self):
     """A weighted_iou config yields a WeightedIOULocalizationLoss."""
     losses_text_proto = """
   localization_loss {
     weighted_iou {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, localization_loss, _, _, _ = losses_builder.build(losses_proto)
     self.assertIsInstance(localization_loss,
                           losses.WeightedIOULocalizationLoss)
示例#6
0
 def test_build_weighted_sigmoid_classification_loss(self):
     """A weighted_sigmoid config yields a WeightedSigmoidClassificationLoss."""
     losses_text_proto = """
   classification_loss {
     weighted_sigmoid {
     }
   }
   localization_loss {
     weighted_l2 {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertIsInstance(classification_loss,
                           losses.WeightedSigmoidClassificationLoss)
示例#7
0
 def test_build_weighted_logits_softmax_classification_loss(self):
     """weighted_logits_softmax builds the softmax-against-logits loss."""
     losses_text_proto = """
   classification_loss {
     weighted_logits_softmax {
     }
   }
   localization_loss {
     weighted_l2 {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _, _, _ = losses_builder.build(
         losses_proto)
     self.assertIsInstance(
         classification_loss,
         losses.WeightedSoftmaxClassificationAgainstLogitsLoss)
 def test_build_weighted_smooth_l1_localization_loss_default_delta(self):
     """An empty weighted_smooth_l1 config gets a smooth L1 loss, delta 1.0."""
     losses_text_proto = """
   localization_loss {
     weighted_smooth_l1 {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, localization_loss, _, _, _ = losses_builder.build(losses_proto)
     self.assertIsInstance(localization_loss,
                           losses.WeightedSmoothL1LocalizationLoss)
     # No delta is set in the config, so this checks the builder's default.
     self.assertAlmostEqual(localization_loss._delta, 1.0)
示例#9
0
 def test_raise_error_when_both_focal_loss_and_hard_example_miner(self):
     """build() raises ValueError for focal loss plus a hard example miner.

     The config pairs a weighted_sigmoid_focal classification loss with a
     (default) hard_example_miner block.
     """
     proto_text = """
   localization_loss {
     weighted_l2 {
     }
   }
   classification_loss {
     weighted_sigmoid_focal {
     }
   }
   hard_example_miner {
   }
   classification_weight: 0.8
   localization_weight: 0.2
 """
     config = losses_pb2.Loss()
     text_format.Merge(proto_text, config)
     self.assertRaises(ValueError, losses_builder.build, config)
示例#10
0
 def test_build_weighted_sigmoid_focal_classification_loss(self):
     """An empty weighted_sigmoid_focal config yields a focal loss.

     The built loss is expected to have no alpha (None) and gamma 2.0.
     """
     losses_text_proto = """
   classification_loss {
     weighted_sigmoid_focal {
     }
   }
   localization_loss {
     weighted_l2 {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertIsInstance(classification_loss,
                           losses.SigmoidFocalClassificationLoss)
     # assertAlmostEqual(x, None) only passed via its equality short-cut and
     # would raise TypeError (not a test failure) for a float alpha.
     self.assertIsNone(classification_loss._alpha)
     self.assertAlmostEqual(classification_loss._gamma, 2.0)
示例#11
0
 def test_build_hard_example_miner_for_localization_loss(self):
     """loss_type LOCALIZATION builds a miner whose _loss_type is 'loc'."""
     losses_text_proto = """
   localization_loss {
     weighted_l2 {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
   hard_example_miner {
     loss_type: LOCALIZATION
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, _, _, _, hard_example_miner = losses_builder.build(losses_proto)
     self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
     self.assertEqual(hard_example_miner._loss_type, 'loc')
示例#12
0
 def test_anchorwise_output(self):
     """Sigmoid classification loss with anchorwise_output is per anchor.

     One image with two anchors and three classes is fed through the built
     loss, and the result is expected to have shape [1, 2].
     """
     # NOTE(review): another test in this collection is also named
     # test_anchorwise_output; if both end up in the same TestCase class,
     # the later definition silently replaces this one.
     losses_text_proto = """
   classification_loss {
     weighted_sigmoid {
       anchorwise_output: true
     }
   }
   localization_loss {
     weighted_l2 {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
     self.assertIsInstance(classification_loss,
                           losses.WeightedSigmoidClassificationLoss)
     predictions = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.5, 0.5]]])
     targets = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
     weights = tf.constant([[1.0, 1.0]])
     loss = classification_loss(predictions, targets, weights=weights)
     self.assertEqual(loss.shape, [1, 2])
示例#13
0
 def test_anchorwise_output(self):
     """Smooth L1 localization loss returns one value per anchor.

     No anchorwise_output option is set in the config. The built loss is
     applied to one image with two 4-coordinate boxes, and the result is
     expected to have shape [1, 2].
     """
     # NOTE(review): another test in this collection is also named
     # test_anchorwise_output; if both end up in the same TestCase class,
     # this later definition silently replaces the earlier one.
     losses_text_proto = """
   localization_loss {
     weighted_smooth_l1 {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, localization_loss, _, _, _ = losses_builder.build(losses_proto)
     self.assertIsInstance(localization_loss,
                           losses.WeightedSmoothL1LocalizationLoss)
     predictions = tf.constant([[[0.0, 0.0, 1.0, 1.0],
                                 [0.0, 0.0, 1.0, 1.0]]])
     targets = tf.constant([[[0.0, 0.0, 1.0, 1.0],
                             [0.0, 0.0, 1.0, 1.0]]])
     weights = tf.constant([[1.0, 1.0]])
     loss = localization_loss(predictions, targets, weights=weights)
     self.assertEqual(loss.shape, [1, 2])