Ejemplo n.º 1
0
def deepmac_proto_to_params(deepmac_config):
    """Converts a DeepMAC config proto into a DeepMACParams named tuple.

    Only the classification loss is taken from the config. The losses
    builder also needs a localization loss, so a placeholder weighted L2
    localization loss is attached and whatever it builds is discarded.

    Args:
      deepmac_config: DeepMAC configuration proto whose fields are copied
        into the returned tuple.

    Returns:
      A DeepMACParams instance.
    """
    loss_proto = losses_pb2.Loss()
    # Placeholder localization loss so that losses_builder.build does not
    # throw; only the classification loss it returns is used below.
    loss_proto.localization_loss.weighted_l2.CopyFrom(
        losses_pb2.WeightedL2LocalizationLoss())
    loss_proto.classification_loss.CopyFrom(deepmac_config.classification_loss)
    built_losses = losses_builder.build(loss_proto)
    classification_loss, _, _, _, _, _, _ = built_losses

    cfg = deepmac_config
    return DeepMACParams(
        dim=cfg.dim,
        classification_loss=classification_loss,
        task_loss_weight=cfg.task_loss_weight,
        pixel_embedding_dim=cfg.pixel_embedding_dim,
        allowed_masked_classes_ids=cfg.allowed_masked_classes_ids,
        mask_size=cfg.mask_size,
        mask_num_subsamples=cfg.mask_num_subsamples,
        use_xy=cfg.use_xy,
        network_type=cfg.network_type,
        use_instance_embedding=cfg.use_instance_embedding,
        num_init_channels=cfg.num_init_channels,
        predict_full_resolution_masks=cfg.predict_full_resolution_masks,
        postprocess_crop_size=cfg.postprocess_crop_size)
Ejemplo n.º 2
0
 def test_build_hard_example_miner_for_localization_loss(self):
     """A LOCALIZATION hard_example_miner config builds a 'loc' miner."""
     losses_text_proto = """
   localization_loss {
     weighted_l2 {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
   classification_in_image_level_loss {
     weighted_sigmoid {
     }
   }
   hard_example_miner {
     loss_type: LOCALIZATION
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, _, _, _, _, _, hard_example_miner = losses_builder.build(
         losses_proto)
     # assertIsInstance reports the actual type on failure, unlike
     # assertTrue(isinstance(...)).
     self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
     self.assertEqual(hard_example_miner._loss_type, 'loc')
Ejemplo n.º 3
0
 def test_build_reweighting_unmatched_anchors(self):
     """Checks built loss types, the hard example miner and loss weights."""
     config_text = """
   localization_loss {
     weighted_l2 {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
   hard_example_miner {
   }
   classification_weight: 0.8
   localization_weight: 0.2
 """
     loss_config = losses_pb2.Loss()
     text_format.Merge(config_text, loss_config)
     built = losses_builder.build(loss_config)
     (classification_loss, localization_loss, classification_weight,
      localization_weight, hard_example_miner, _, _) = built
     self.assertIsInstance(classification_loss,
                           losses.WeightedSoftmaxClassificationLoss)
     self.assertIsInstance(localization_loss,
                           losses.WeightedL2LocalizationLoss)
     self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
     self.assertAlmostEqual(classification_weight, 0.8)
     self.assertAlmostEqual(localization_weight, 0.2)
Ejemplo n.º 4
0
 def test_anchorwise_output(self):
     """Anchorwise weighted smooth L1 loss yields one value per anchor."""
     losses_text_proto = """
   localization_loss {
     weighted_smooth_l1 {
       anchorwise_output: true
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
   classification_in_image_level_loss {
     weighted_sigmoid {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, localization_loss, _, _, _, _, _ = losses_builder.build(
         losses_proto)
     self.assertIsInstance(localization_loss,
                           losses.WeightedSmoothL1LocalizationLoss)
     predictions = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
     targets = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
     weights = tf.constant([[1.0, 1.0]])
     loss = localization_loss(predictions, targets, weights=weights)
     # With anchorwise_output enabled, the expected loss has shape [1, 2]:
     # 1 image x 2 anchors.
     self.assertEqual(loss.shape, [1, 2])
 def test_build_hard_example_miner_with_non_default_values(self):
     """Every non-default hard_example_miner field reaches the miner."""
     losses_text_proto = """
   localization_loss {
     weighted_l2 {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
   hard_example_miner {
     num_hard_examples: 32
     iou_threshold: 0.5
     loss_type: LOCALIZATION
     max_negatives_per_positive: 10
     min_negatives_per_image: 3
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, _, _, _, hard_example_miner = losses_builder.build(losses_proto)
     self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
     self.assertEqual(hard_example_miner._num_hard_examples, 32)
     self.assertAlmostEqual(hard_example_miner._iou_threshold, 0.5)
     # loss_type is configured above but was never checked; LOCALIZATION
     # maps to 'loc' (as asserted in the localization hard-miner test).
     self.assertEqual(hard_example_miner._loss_type, 'loc')
     self.assertEqual(hard_example_miner._max_negatives_per_positive, 10)
     self.assertEqual(hard_example_miner._min_negatives_per_image, 3)
Ejemplo n.º 6
0
 def test_build_weighted_sigmoid_focal_loss_non_default(self):
     """Non-default alpha/gamma reach SigmoidFocalClassificationLoss."""
     losses_text_proto = """
   classification_loss {
     weighted_sigmoid_focal {
       alpha: 0.25
       gamma: 3.0
     }
   }
   localization_loss {
     weighted_l2 {
     }
   }
   classification_in_image_level_loss {
     weighted_sigmoid {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _, _, _ = losses_builder.build(
         losses_proto)
     self.assertIsInstance(classification_loss,
                           losses.SigmoidFocalClassificationLoss)
     self.assertAlmostEqual(classification_loss._alpha, 0.25)
     self.assertAlmostEqual(classification_loss._gamma, 3.0)
 def test_build_all_loss_parameters(self):
     """Checks every builder output: loss types, weights and the miner."""
     losses_text_proto = """
   localization_loss {
     weighted_l2 {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
   hard_example_miner {
   }
   classification_weight: 0.8
   localization_weight: 0.2
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     (classification_loss, localization_loss, classification_weight,
      localization_weight,
      hard_example_miner) = losses_builder.build(losses_proto)
     self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
     self.assertIsInstance(classification_loss,
                           losses.WeightedSoftmaxClassificationLoss)
     self.assertIsInstance(localization_loss,
                           losses.WeightedL2LocalizationLoss)
     self.assertAlmostEqual(classification_weight, 0.8)
     self.assertAlmostEqual(localization_weight, 0.2)
 def test_raise_error_on_empty_config(self):
     """Building from a config with only a localization loss raises."""
     config_text = """
   localization_loss {
     weighted_l2 {
     }
   }
 """
     loss_config = losses_pb2.Loss()
     text_format.Merge(config_text, loss_config)
     self.assertRaises(ValueError, losses_builder.build, loss_config)
Ejemplo n.º 9
0
def temporal_offset_proto_to_params(temporal_offset_config):
    """Converts CenterNet.TemporalOffsetEstimation proto to TemporalOffsetParams.

    Args:
      temporal_offset_config: CenterNet.TemporalOffsetEstimation proto; its
        `localization_loss` and `task_loss_weight` fields are read.

    Returns:
      center_net_meta_arch.TemporalOffsetParams holding the built
      localization loss and the task loss weight.
    """
    loss_proto = losses_pb2.Loss()
    # losses_builder.build throws without a classification loss, so a
    # placeholder weighted sigmoid loss is attached and its result dropped.
    # TODO(yuhuic): update the loss builder so this placeholder classification
    # loss is no longer required.
    loss_proto.classification_loss.weighted_sigmoid.CopyFrom(
        losses_pb2.WeightedSigmoidClassificationLoss())
    loss_proto.localization_loss.CopyFrom(
        temporal_offset_config.localization_loss)
    built_losses = losses_builder.build(loss_proto)
    _, localization_loss, _, _, _, _, _ = built_losses
    return center_net_meta_arch.TemporalOffsetParams(
        localization_loss=localization_loss,
        task_loss_weight=temporal_offset_config.task_loss_weight)
 def test_do_not_build_hard_example_miner_by_default(self):
     """Without a hard_example_miner section, no miner is built."""
     losses_text_proto = """
   localization_loss {
     weighted_l2 {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, _, _, _, hard_example_miner = losses_builder.build(losses_proto)
     # Identity check against None; assertEqual would go through __eq__.
     self.assertIsNone(hard_example_miner)
Ejemplo n.º 11
0
def object_detection_proto_to_params(od_config):
    """Converts CenterNet.ObjectDetection proto to ObjectDetectionParams.

    Args:
      od_config: CenterNet.ObjectDetection proto; its `localization_loss`
        and the scale/offset/task loss weights are read.

    Returns:
      center_net_meta_arch.ObjectDetectionParams with the built localization
      loss and the loss weights copied from `od_config`.
    """
    loss_proto = losses_pb2.Loss()
    # losses_builder.build throws without a classification loss, so a
    # placeholder weighted sigmoid loss is attached and its result dropped.
    # TODO(yuhuic): update the loss builder so this placeholder classification
    # loss is no longer required.
    loss_proto.classification_loss.weighted_sigmoid.CopyFrom(
        losses_pb2.WeightedSigmoidClassificationLoss())
    loss_proto.localization_loss.CopyFrom(od_config.localization_loss)
    built_losses = losses_builder.build(loss_proto)
    _, localization_loss, _, _, _, _, _ = built_losses
    return center_net_meta_arch.ObjectDetectionParams(
        localization_loss=localization_loss,
        scale_loss_weight=od_config.scale_loss_weight,
        offset_loss_weight=od_config.offset_loss_weight,
        task_loss_weight=od_config.task_loss_weight)
Ejemplo n.º 12
0
def mask_proto_to_params(mask_config):
    """Converts CenterNet.MaskEstimation proto to MaskParams.

    Args:
      mask_config: CenterNet.MaskEstimation proto; its classification loss,
        task loss weight, mask size, score threshold and heatmap bias init
        are read.

    Returns:
      center_net_meta_arch.MaskParams populated from `mask_config`.
    """
    loss_proto = losses_pb2.Loss()
    # Placeholder localization loss so that losses_builder.build does not
    # throw; only the classification loss it returns is used.
    loss_proto.localization_loss.weighted_l2.CopyFrom(
        losses_pb2.WeightedL2LocalizationLoss())
    loss_proto.classification_loss.CopyFrom(mask_config.classification_loss)
    built_losses = losses_builder.build(loss_proto)
    classification_loss, _, _, _, _, _, _ = built_losses
    return center_net_meta_arch.MaskParams(
        classification_loss=classification_loss,
        task_loss_weight=mask_config.task_loss_weight,
        mask_height=mask_config.mask_height,
        mask_width=mask_config.mask_width,
        score_threshold=mask_config.score_threshold,
        heatmap_bias_init=mask_config.heatmap_bias_init)
Ejemplo n.º 13
0
 def test_build_weighted_sigmoid_classification_loss(self):
   """A weighted_sigmoid config builds WeightedSigmoidClassificationLoss."""
   losses_text_proto = """
     classification_loss {
       weighted_sigmoid {
       }
     }
     localization_loss {
       weighted_l2 {
       }
     }
   """
   losses_proto = losses_pb2.Loss()
   text_format.Merge(losses_text_proto, losses_proto)
   classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
   self.assertIsInstance(classification_loss,
                         losses.WeightedSigmoidClassificationLoss)
Ejemplo n.º 14
0
def object_center_proto_to_params(oc_config):
    """Converts CenterNet.ObjectCenter proto to ObjectCenterParams.

    Args:
      oc_config: CenterNet.ObjectCenter proto; its classification loss,
        loss weight, heatmap bias init, min box overlap IoU and max box
        predictions are read.

    Returns:
      center_net_meta_arch.ObjectCenterParams populated from `oc_config`.
    """
    loss_proto = losses_pb2.Loss()
    # Placeholder localization loss so that losses_builder.build does not
    # throw; only the classification loss it returns is used.
    # TODO(yuhuic): update the loss builder so this placeholder localization
    # loss is no longer required.
    loss_proto.localization_loss.weighted_l2.CopyFrom(
        losses_pb2.WeightedL2LocalizationLoss())
    loss_proto.classification_loss.CopyFrom(oc_config.classification_loss)
    built_losses = losses_builder.build(loss_proto)
    classification_loss, _, _, _, _, _, _ = built_losses
    return center_net_meta_arch.ObjectCenterParams(
        classification_loss=classification_loss,
        object_center_loss_weight=oc_config.object_center_loss_weight,
        heatmap_bias_init=oc_config.heatmap_bias_init,
        min_box_overlap_iou=oc_config.min_box_overlap_iou,
        max_box_predictions=oc_config.max_box_predictions)
Ejemplo n.º 15
0
 def test_build_weighted_softmax_classification_loss(self):
   """A weighted_softmax config builds WeightedSoftmaxClassificationLoss."""
   config_text = """
     classification_loss {
       weighted_softmax {
       }
     }
     localization_loss {
       weighted_l2 {
       }
     }
   """
   loss_config = losses_pb2.Loss()
   text_format.Merge(config_text, loss_config)
   built = losses_builder.build(loss_config)
   classification_loss, _, _, _, _, _, _ = built
   self.assertIsInstance(
       classification_loss, losses.WeightedSoftmaxClassificationLoss)
 def test_build_weighted_iou_localization_loss(self):
     """A weighted_iou config builds WeightedIOULocalizationLoss."""
     losses_text_proto = """
   localization_loss {
     weighted_iou {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, localization_loss, _, _, _ = losses_builder.build(losses_proto)
     self.assertIsInstance(localization_loss,
                           losses.WeightedIOULocalizationLoss)
Ejemplo n.º 17
0
def tracking_proto_to_params(tracking_config):
    """Converts CenterNet.TrackEstimation proto to TrackParams.

    Args:
      tracking_config: CenterNet.TrackEstimation proto; its classification
        loss, track-id count, ReID embedding size, FC layer count and task
        loss weight are read.

    Returns:
      center_net_meta_arch.TrackParams populated from `tracking_config`.
    """
    loss_proto = losses_pb2.Loss()
    # Placeholder localization loss so that losses_builder.build does not
    # throw; only the classification loss it returns is used.
    # TODO(yuhuic): update the loss builder so this placeholder localization
    # loss is no longer required.
    loss_proto.localization_loss.weighted_l2.CopyFrom(
        losses_pb2.WeightedL2LocalizationLoss())
    loss_proto.classification_loss.CopyFrom(tracking_config.classification_loss)
    built_losses = losses_builder.build(loss_proto)
    classification_loss, _, _, _, _, _, _ = built_losses
    return center_net_meta_arch.TrackParams(
        num_track_ids=tracking_config.num_track_ids,
        reid_embed_size=tracking_config.reid_embed_size,
        classification_loss=classification_loss,
        num_fc_layers=tracking_config.num_fc_layers,
        task_loss_weight=tracking_config.task_loss_weight)
Ejemplo n.º 18
0
 def test_build_bootstrapped_sigmoid_classification_loss(self):
   """A bootstrapped_sigmoid config builds the bootstrapped sigmoid loss."""
   config_text = """
     classification_loss {
       bootstrapped_sigmoid {
         alpha: 0.5
       }
     }
     localization_loss {
       weighted_l2 {
       }
     }
   """
   loss_config = losses_pb2.Loss()
   text_format.Merge(config_text, loss_config)
   built = losses_builder.build(loss_config)
   classification_loss, _, _, _, _, _, _ = built
   self.assertIsInstance(
       classification_loss, losses.BootstrappedSigmoidClassificationLoss)
Ejemplo n.º 19
0
 def test_build_weighted_smooth_l1_localization_loss_default_delta(self):
   """An empty weighted_smooth_l1 config gives delta == 1.0."""
   losses_text_proto = """
     localization_loss {
       weighted_smooth_l1 {
       }
     }
     classification_loss {
       weighted_softmax {
       }
     }
   """
   losses_proto = losses_pb2.Loss()
   text_format.Merge(losses_text_proto, losses_proto)
   _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
   self.assertIsInstance(localization_loss,
                         losses.WeightedSmoothL1LocalizationLoss)
   self.assertAlmostEqual(localization_loss._delta, 1.0)
Ejemplo n.º 20
0
 def test_build_weighted_sigmoid_focal_classification_loss(self):
   """An empty weighted_sigmoid_focal config gives alpha None, gamma 2.0."""
   losses_text_proto = """
     classification_loss {
       weighted_sigmoid_focal {
       }
     }
     localization_loss {
       weighted_l2 {
       }
     }
   """
   losses_proto = losses_pb2.Loss()
   text_format.Merge(losses_text_proto, losses_proto)
   classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
   self.assertIsInstance(classification_loss,
                         losses.SigmoidFocalClassificationLoss)
   # assertAlmostEqual(x, None) raises TypeError (not a test failure) when x
   # is a number; assertIsNone expresses the intent and fails cleanly.
   self.assertIsNone(classification_loss._alpha)
   self.assertAlmostEqual(classification_loss._gamma, 2.0)
Ejemplo n.º 21
0
 def test_build_dice_loss(self):
     """A weighted dice config builds the loss with squared normalization."""
     losses_text_proto = """
   classification_loss {
     weighted_dice_classification_loss {
       squared_normalization: true
     }
   }
   localization_loss {
     l1_localization_loss {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _, _, _ = losses_builder.build(
         losses_proto)
     self.assertIsInstance(classification_loss,
                           losses.WeightedDiceClassificationLoss)
     # A bare `assert` is stripped under `python -O`; use the TestCase API.
     self.assertTrue(classification_loss._squared_normalization)
Ejemplo n.º 22
0
 def test_build_hard_example_miner_for_classification_loss(self):
   """A CLASSIFICATION hard_example_miner config builds a 'cls' miner."""
   losses_text_proto = """
     localization_loss {
       weighted_l2 {
       }
     }
     classification_loss {
       weighted_softmax {
       }
     }
     hard_example_miner {
       loss_type: CLASSIFICATION
     }
   """
   losses_proto = losses_pb2.Loss()
   text_format.Merge(losses_text_proto, losses_proto)
   _, _, _, _, hard_example_miner, _, _ = losses_builder.build(losses_proto)
   self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
   self.assertEqual(hard_example_miner._loss_type, 'cls')
Ejemplo n.º 23
0
 def test_raise_error_when_both_focal_loss_and_hard_example_miner(self):
     """Sigmoid focal loss combined with a hard example miner is rejected."""
     config_text = """
   localization_loss {
     weighted_l2 {
     }
   }
   classification_loss {
     weighted_sigmoid_focal {
     }
   }
   hard_example_miner {
   }
   classification_weight: 0.8
   localization_weight: 0.2
 """
     loss_config = losses_pb2.Loss()
     text_format.Merge(config_text, loss_config)
     self.assertRaises(ValueError, losses_builder.build, loss_config)
Ejemplo n.º 24
0
 def test_build_penalty_reduced_logistic_focal_loss(self):
     """alpha/beta from the config reach PenaltyReducedLogisticFocalLoss."""
     config_text = """
   classification_loss {
     penalty_reduced_logistic_focal_loss {
       alpha: 2.0
       beta: 4.0
     }
   }
   localization_loss {
     l1_localization_loss {
     }
   }
 """
     loss_config = losses_pb2.Loss()
     text_format.Merge(config_text, loss_config)
     built = losses_builder.build(loss_config)
     classification_loss, _, _, _, _, _, _ = built
     self.assertIsInstance(
         classification_loss, losses.PenaltyReducedLogisticFocalLoss)
     self.assertAlmostEqual(classification_loss._alpha, 2.0)
     self.assertAlmostEqual(classification_loss._beta, 4.0)
Ejemplo n.º 25
0
 def test_build_weighted_l2_localization_loss(self):
     """A weighted_l2 config builds WeightedL2LocalizationLoss."""
     losses_text_proto = """
   localization_loss {
     weighted_l2 {
     }
   }
   classification_loss {
     weighted_softmax {
     }
   }
   classification_in_image_level_loss {
     weighted_sigmoid {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, localization_loss, _, _, _, _, _ = losses_builder.build(
         losses_proto)
     self.assertIsInstance(localization_loss,
                           losses.WeightedL2LocalizationLoss)
Ejemplo n.º 26
0
 def test_anchorwise_output(self):
   """Anchorwise sigmoid classification loss keeps a per-class loss."""
   losses_text_proto = """
     classification_loss {
       weighted_sigmoid {
         anchorwise_output: true
       }
     }
     localization_loss {
       weighted_l2 {
       }
     }
   """
   losses_proto = losses_pb2.Loss()
   text_format.Merge(losses_text_proto, losses_proto)
   classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
   self.assertIsInstance(classification_loss,
                         losses.WeightedSigmoidClassificationLoss)
   predictions = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.5, 0.5]]])
   targets = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
   weights = tf.constant([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]])
   loss = classification_loss(predictions, targets, weights=weights)
   # Expected unreduced shape: 1 image x 2 anchors x 3 classes.
   self.assertEqual(loss.shape, [1, 2, 3])
Ejemplo n.º 27
0
 def test_build_weighted_softmax_classification_loss_with_logit_scale(self):
     """A weighted_softmax config with logit_scale still builds softmax loss."""
     losses_text_proto = """
   classification_loss {
     weighted_softmax {
       logit_scale: 2.0
     }
   }
   localization_loss {
     weighted_l2 {
     }
   }
   classification_in_image_level_loss {
     weighted_sigmoid {
     }
   }
 """
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _, _, _ = losses_builder.build(
         losses_proto)
     self.assertIsInstance(classification_loss,
                           losses.WeightedSoftmaxClassificationLoss)