def _create_centernet_model(model_id: int = 52,
                            num_hourglasses: int = 2
                            ) -> centernet_model.CenterNetModel:
  """Assembles a CenterNet Keras model intended as a target for TF1 weights.

  Args:
    model_id: Forwarded as `model_id` to the `backbones.Hourglass` config.
    num_hourglasses: Forwarded as `num_hourglasses` to the
      `backbones.Hourglass` config.

  Returns:
    A `centernet_model.CenterNetModel` made of an hourglass backbone built via
    `factory.build_backbone`, a `centernet_head.CenterNetHead`, and a
    default-constructed `detection_generator.CenterNetDetectionGenerator`.
  """
  # Describe the model through a task config so the head's per-task output
  # sizes can be read back from it below.
  hourglass_config = backbones.Hourglass(
      model_id=model_id, num_hourglasses=num_hourglasses)
  backbone_spec = backbones.Backbone(type="hourglass",
                                     hourglass=hourglass_config)
  task_config = centernet.CenterNetTask(
      model=centernet.CenterNetModel(backbone=backbone_spec))
  model_config = task_config.model

  # The input shape is fixed at [1, 512, 512, 3]; presumably
  # (batch, height, width, channels) -- TODO confirm a batch of 1 is all the
  # weight-loading path needs.
  fixed_input = tf.keras.layers.InputSpec(shape=[1, 512, 512, 3])
  backbone = factory.build_backbone(
      input_specs=fixed_input,
      backbone_config=model_config.backbone,
      norm_activation_config=model_config.norm_activation)

  output_lengths = task_config.get_output_length_dict()
  head = centernet_head.CenterNetHead(
      input_specs=backbone.output_specs,
      task_outputs=output_lengths,
      input_levels=model_config.head.input_levels)

  generator = detection_generator.CenterNetDetectionGenerator()
  centernet = centernet_model.CenterNetModel(
      backbone=backbone, head=head, detection_generator=generator)
  logging.info("Successfully created centernet model.")
  return centernet
Ejemplo n.º 2
0
    def testBuildCenterNet(self):
        """Builds CenterNet on the default hourglass backbone and checks outputs.

        A zero batch is run through the model. For every head, the raw output
        must hold len(levels) entries, and the first entry must have a
        128x128 spatial size (512 / 4) with that head's channel count.
        """
        spec = tf.keras.layers.InputSpec(shape=[None, 512, 512, 3])
        backbone = hourglass.build_hourglass(
            input_specs=spec,
            backbone_config=backbones.Backbone(type='hourglass'),
            norm_activation_config=common.NormActivation(use_sync_bn=True))

        # Output channel count for each CenterNet head.
        head_channels = {
            'ct_heatmaps': 90,
            'ct_offset': 2,
            'ct_size': 2,
        }
        levels = ['2_0', '2']

        head = centernet_head.CenterNetHead(
            task_outputs=head_channels,
            input_specs=backbone.output_specs,
            input_levels=levels)

        generator = detection_generator.CenterNetDetectionGenerator()
        model = centernet_model.CenterNetModel(
            backbone=backbone, head=head, detection_generator=generator)

        batch = 5
        predictions = model(tf.zeros((batch, 512, 512, 3)))
        raw = predictions['raw_output']

        self.assertLen(raw, len(head_channels))
        # Check every head's length before any shape, matching the
        # original assertion order.
        for name in head_channels:
            self.assertLen(raw[name], len(levels))
        for name, channels in head_channels.items():
            self.assertEqual(raw[name][0].shape, (batch, 128, 128, channels))
Ejemplo n.º 3
0
 def test_hourglass(self):
     """Runs zeros through the default hourglass backbone and checks shapes.

     For a (2, 512, 512, 3) input, outputs '2_0' and '2' must both have
     shape (2, 128, 128, 256).
     """
     spec = tf.keras.layers.InputSpec(shape=[None, 512, 512, 3])
     config = backbones.Backbone(type='hourglass')
     norm_config = common.NormActivation(use_sync_bn=True)
     backbone = hourglass.build_hourglass(
         input_specs=spec,
         backbone_config=config,
         norm_activation_config=norm_config)

     features = backbone(np.zeros((2, 512, 512, 3), dtype=np.float32))
     expected_shape = (2, 128, 128, 256)
     for level in ('2_0', '2'):
         self.assertEqual(features[level].shape, expected_shape)
Ejemplo n.º 4
0
class CenterNetModel(hyperparams.Config):
  """Config for centernet model.

  Collects the model-level hyperparameters: class count, instance cap,
  input size, and the backbone / head / detection-generator /
  norm-activation sub-configs. By default it selects an hourglass backbone
  (model_id=52) and synchronized batch norm (momentum 0.1, epsilon 1e-5).
  """
  # Default number of object classes.
  num_classes: int = 90
  # Presumably the maximum number of object instances per image that inputs
  # are padded/truncated to -- TODO confirm against the data pipeline.
  max_num_instances: int = 128
  # Empty list by default (fresh list per instance via default_factory);
  # presumably [height, width, channels] when set -- TODO confirm.
  input_size: List[int] = dataclasses.field(default_factory=list)
  # NOTE(review): the defaults below are config *instances*, not
  # default_factory fields, and no `@dataclasses.dataclass` decorator is
  # visible on this class here. Confirm how hyperparams.Config turns these
  # annotations into fields and whether the default objects are shared
  # between configs. Plain Python 3.11+ dataclasses reject unhashable
  # defaults unless they are wrapped in `dataclasses.field(default_factory=...)`.
  backbone: backbones.Backbone = backbones.Backbone(
      type='hourglass', hourglass=backbones.Hourglass(model_id=52))
  head: CenterNetHead = CenterNetHead()
  # pylint: disable=line-too-long
  detection_generator: CenterNetDetectionGenerator = CenterNetDetectionGenerator()
  # Normalization/activation settings: momentum 0.1, epsilon 1e-5, and
  # synchronized batch norm enabled.
  norm_activation: common.NormActivation = common.NormActivation(
      norm_momentum=0.1, norm_epsilon=1e-5, use_sync_bn=True)