Example #1
0
 def test_raise_error_on_empty_box_coder(self):
     """Checks that building from a blank BoxCoder config raises ValueError."""
     blank_text_proto = """
 """
     proto = box_coder_pb2.BoxCoder()
     # Merging whitespace-only text leaves the message without any coder set.
     text_format.Merge(blank_text_proto, proto)
     self.assertRaises(ValueError, box_coder_builder.build, proto)
Example #2
0
 def test_build_square_box_coder_with_defaults(self):
     """Checks an empty square_box_coder config builds a default SquareBoxCoder.

     With no scale fields set in the config, the built coder is expected to
     expose scale factors of [10.0, 10.0, 5.0].
     """
     box_coder_text_proto = """
   square_box_coder {
   }
 """
     box_coder_proto = box_coder_pb2.BoxCoder()
     text_format.Merge(box_coder_text_proto, box_coder_proto)
     box_coder_object = box_coder_builder.build(box_coder_proto)
     # assertIsInstance reports the actual type on failure, whereas
     # assertTrue(isinstance(...)) only reports "False is not true".
     self.assertIsInstance(box_coder_object, square_box_coder.SquareBoxCoder)
     self.assertEqual(box_coder_object._scale_factors, [10.0, 10.0, 5.0])
Example #3
0
 def test_build_mean_stddev_box_coder(self):
     """Checks a mean_stddev_box_coder config builds a MeanStddevBoxCoder."""
     box_coder_text_proto = """
   mean_stddev_box_coder {
   }
 """
     box_coder_proto = box_coder_pb2.BoxCoder()
     text_format.Merge(box_coder_text_proto, box_coder_proto)
     box_coder_object = box_coder_builder.build(box_coder_proto)
     # assertIsInstance reports the actual type on failure, whereas
     # assertTrue(isinstance(...)) only reports "False is not true".
     self.assertIsInstance(box_coder_object,
                           mean_stddev_box_coder.MeanStddevBoxCoder)
Example #4
0
 def test_build_keypoint_box_coder_with_defaults(self):
     """Checks an empty keypoint_box_coder config builds a default coder.

     The built object must be a KeypointBoxCoder whose scale factors equal
     [10.0, 10.0, 5.0, 5.0].
     """
     text_proto = """
   keypoint_box_coder {
   }
 """
     proto = box_coder_pb2.BoxCoder()
     text_format.Merge(text_proto, proto)
     coder = box_coder_builder.build(proto)

     expected_type = keypoint_box_coder.KeypointBoxCoder
     expected_scales = [10.0, 10.0, 5.0, 5.0]
     self.assertIsInstance(coder, expected_type)
     self.assertEqual(coder._scale_factors, expected_scales)
Example #5
0
 def test_build_square_box_coder_with_non_default_parameters(self):
     """Checks explicit square_box_coder scales reach the built coder.

     The config sets y_scale, x_scale and length_scale; the built
     SquareBoxCoder is expected to expose them as [6.0, 3.0, 7.0].
     """
     box_coder_text_proto = """
   square_box_coder {
     y_scale: 6.0
     x_scale: 3.0
     length_scale: 7.0
   }
 """
     box_coder_proto = box_coder_pb2.BoxCoder()
     text_format.Merge(box_coder_text_proto, box_coder_proto)
     box_coder_object = box_coder_builder.build(box_coder_proto)
     # assertIsInstance reports the actual type on failure, whereas
     # assertTrue(isinstance(...)) only reports "False is not true".
     self.assertIsInstance(box_coder_object, square_box_coder.SquareBoxCoder)
     self.assertEqual(box_coder_object._scale_factors, [6.0, 3.0, 7.0])
Example #6
0
 def test_build_faster_rcnn_box_coder_with_non_default_parameters(self):
     """Checks explicit faster_rcnn_box_coder scales reach the built coder.

     The config sets y/x/height/width scales; the built FasterRcnnBoxCoder
     must expose them as [6.0, 3.0, 7.0, 8.0].
     """
     text_proto = """
   faster_rcnn_box_coder {
     y_scale: 6.0
     x_scale: 3.0
     height_scale: 7.0
     width_scale: 8.0
   }
 """
     proto = box_coder_pb2.BoxCoder()
     text_format.Merge(text_proto, proto)
     coder = box_coder_builder.build(proto)

     expected_type = faster_rcnn_box_coder.FasterRcnnBoxCoder
     expected_scales = [6.0, 3.0, 7.0, 8.0]
     self.assertIsInstance(coder, expected_type)
     self.assertEqual(coder._scale_factors, expected_scales)
Example #7
0
 def test_build_keypoint_box_coder_with_non_default_parameters(self):
     """Checks explicit keypoint_box_coder settings reach the built coder.

     The config sets num_keypoints plus y/x/height/width scales; the built
     KeypointBoxCoder must report 6 keypoints and scale factors of
     [6.0, 3.0, 7.0, 8.0].
     """
     text_proto = """
   keypoint_box_coder {
     num_keypoints: 6
     y_scale: 6.0
     x_scale: 3.0
     height_scale: 7.0
     width_scale: 8.0
   }
 """
     proto = box_coder_pb2.BoxCoder()
     text_format.Merge(text_proto, proto)
     coder = box_coder_builder.build(proto)

     expected_type = keypoint_box_coder.KeypointBoxCoder
     expected_scales = [6.0, 3.0, 7.0, 8.0]
     self.assertIsInstance(coder, expected_type)
     self.assertEqual(coder._num_keypoints, 6)
     self.assertEqual(coder._scale_factors, expected_scales)
def _build_ssd_model(ssd_config,
                     is_training,
                     add_summaries,
                     add_background_class=True):
    """Assembles an SSDMetaArch from the pieces described by an SSD config.

    Every component (feature extractor, box coder, matcher, similarity
    calculator, box predictor, anchor generator, image resizer,
    post-processing functions and losses) is created by passing the matching
    sub-message of `ssd_config` to its builder. The remaining scalar options
    are read straight from `ssd_config` and forwarded to the meta-architecture.

    Args:
      ssd_config: A ssd.proto object containing the config for the desired
        SSDMetaArch.
      is_training: True if this model is being built for training purposes.
        Forwarded to the feature extractor builder, the box predictor builder
        and SSDMetaArch.
      add_summaries: Whether to add tf summaries in the model.
      add_background_class: Whether to add an implicit background class to
        one-hot encodings of groundtruth labels. Set to false if using
        groundtruth labels with an explicit background class or using
        multiclass scores instead of truth in the case of distillation.

    Returns:
      SSDMetaArch based on the config.

    Raises:
      ValueError: Nothing in this function raises directly.
        NOTE(review): presumably propagated from the component builders
        (e.g. an unrecognized feature extractor type) -- confirm against the
        builders.
    """
    # Components are built in a fixed order; keep it in case any builder has
    # order-dependent side effects.
    extractor = _build_ssd_feature_extractor(
        feature_extractor_config=ssd_config.feature_extractor,
        is_training=is_training)

    coder = box_coder_builder.build(ssd_config.box_coder)
    target_matcher = matcher_builder.build(ssd_config.matcher)
    similarity_calc = sim_calc.build(ssd_config.similarity_calculator)
    predictor = box_predictor_builder.build(
        hyperparams_builder.build,
        ssd_config.box_predictor,
        is_training,
        ssd_config.num_classes)
    anchors = anchor_generator_builder.build(ssd_config.anchor_generator)
    resizer_fn = image_resizer_builder.build(ssd_config.image_resizer)
    nms_fn, score_fn = post_processing_builder.build(
        ssd_config.post_processing)
    (cls_loss, loc_loss, cls_weight, loc_weight, example_miner,
     example_sampler) = losses_builder.build(ssd_config.loss)

    # The leading arguments are positional; their order must match the
    # SSDMetaArch constructor signature.
    return ssd_meta_arch.SSDMetaArch(
        is_training,
        anchors,
        predictor,
        coder,
        extractor,
        target_matcher,
        similarity_calc,
        ssd_config.encode_background_as_zeros,
        ssd_config.negative_class_weight,
        resizer_fn,
        nms_fn,
        score_fn,
        cls_loss,
        loc_loss,
        cls_weight,
        loc_weight,
        ssd_config.normalize_loss_by_num_matches,
        example_miner,
        add_summaries=add_summaries,
        normalize_loc_loss_by_codesize=ssd_config.normalize_loc_loss_by_codesize,
        freeze_batchnorm=ssd_config.freeze_batchnorm,
        inplace_batchnorm_update=ssd_config.inplace_batchnorm_update,
        add_background_class=add_background_class,
        random_example_sampler=example_sampler)