def _create_centernet_model(
    model_id: int = 52,
    num_hourglasses: int = 2) -> centernet_model.CenterNetModel:
  """Builds a Hourglass-backed CenterNet model used to load TF1 weights.

  The backbone is built for a fixed input spec of shape [1, 512, 512, 3], and
  the detection generator is created with its default constructor arguments.

  Args:
    model_id: Value forwarded to `backbones.Hourglass(model_id=...)`.
    num_hourglasses: Value forwarded to
      `backbones.Hourglass(num_hourglasses=...)`.

  Returns:
    A `centernet_model.CenterNetModel` that combines the backbone, a
    `CenterNetHead`, and a `CenterNetDetectionGenerator`.
  """
  hourglass_cfg = backbones.Hourglass(
      model_id=model_id, num_hourglasses=num_hourglasses)
  task_cfg = centernet.CenterNetTask(
      model=centernet.CenterNetModel(
          backbone=backbones.Backbone(type="hourglass",
                                      hourglass=hourglass_cfg)))
  net_cfg = task_cfg.model

  hourglass_backbone = factory.build_backbone(
      input_specs=tf.keras.layers.InputSpec(shape=[1, 512, 512, 3]),
      backbone_config=net_cfg.backbone,
      norm_activation_config=net_cfg.norm_activation)

  # Per-task output lengths as reported by the task config.
  output_lengths = task_cfg.get_output_length_dict()
  prediction_head = centernet_head.CenterNetHead(
      input_specs=hourglass_backbone.output_specs,
      task_outputs=output_lengths,
      input_levels=net_cfg.head.input_levels)

  assembled = centernet_model.CenterNetModel(
      backbone=hourglass_backbone,
      head=prediction_head,
      detection_generator=detection_generator.CenterNetDetectionGenerator())
  logging.info("Successfully created centernet model.")
  return assembled
Ejemplo n.º 2
0
    def testBuildCenterNet(self):
        """Builds a default Hourglass CenterNet and checks its raw outputs.

        A zero batch of five 512x512x3 images is fed through the model. The
        raw output must contain exactly the three configured tasks. Each task
        must hold two entries, and the first entry must have shape
        (5, 128, 128, <configured channel count>).
        """
        head_channels = {
            'ct_heatmaps': 90,
            'ct_offset': 2,
            'ct_size': 2,
        }

        hourglass_backbone = hourglass.build_hourglass(
            input_specs=tf.keras.layers.InputSpec(shape=[None, 512, 512, 3]),
            backbone_config=backbones.Backbone(type='hourglass'),
            norm_activation_config=common.NormActivation(use_sync_bn=True))

        prediction_head = centernet_head.CenterNetHead(
            task_outputs=head_channels,
            input_specs=hourglass_backbone.output_specs,
            input_levels=['2_0', '2'])
        generator = detection_generator.CenterNetDetectionGenerator()

        network = centernet_model.CenterNetModel(
            backbone=hourglass_backbone,
            head=prediction_head,
            detection_generator=generator)

        raw = network(tf.zeros((5, 512, 512, 3)))['raw_output']
        self.assertLen(raw, 3)
        for task_name in head_channels:
            self.assertLen(raw[task_name], 2)
        for task_name, channels in head_channels.items():
            self.assertEqual(raw[task_name][0].shape,
                             (5, 128, 128, channels))
Ejemplo n.º 3
0
    def build_model(self):
        """Builds a CenterNet model from `self.task_config`.

        Constructs the backbone, the CenterNet head and the detection
        generator described by the task config, then wires them into a
        `centernet_model.CenterNetModel`. As a side effect, this method stores
        the ratio between the configured input height and the backbone output
        height in `self._net_down_scale`.

        Returns:
          The assembled `centernet_model.CenterNetModel`.

        Raises:
          ValueError: If the backbone output spec for the last head input
            level is neither rank 3 nor rank 4, or its height is unknown.
        """
        model_config = self.task_config.model
        # Prepend an unknown batch dimension. Unpacking works whether
        # `input_size` is a list or a tuple.
        input_specs = tf.keras.layers.InputSpec(
            shape=[None, *model_config.input_size])

        l2_weight_decay = self.task_config.weight_decay
        # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
        # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
        # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
        l2_regularizer = (tf.keras.regularizers.l2(l2_weight_decay / 2.0)
                          if l2_weight_decay else None)

        backbone = factory.build_backbone(
            input_specs=input_specs,
            backbone_config=model_config.backbone,
            norm_activation_config=model_config.norm_activation,
            l2_regularizer=l2_regularizer)

        task_outputs = self.task_config.get_output_length_dict()
        head_config = model_config.head
        head = centernet_head.CenterNetHead(
            input_specs=backbone.output_specs,
            task_outputs=task_outputs,
            input_levels=head_config.input_levels,
            heatmap_bias=head_config.heatmap_bias)

        # backbone.output_specs is a dict keyed by level. The down-scale
        # factor is derived from the last level the head consumes.
        last_level = head_config.input_levels[-1]
        backbone_output_spec = backbone.output_specs[last_level]
        spec_rank = len(backbone_output_spec)
        # Height is taken at index 1 (rank 4, leading batch dim) or index 0
        # (rank 3, no batch dim); this assumes a channels-last layout.
        if spec_rank == 4:
            bb_output_height = backbone_output_spec[1]
        elif spec_rank == 3:
            bb_output_height = backbone_output_spec[0]
        else:
            raise ValueError(
                f'Expected the backbone output spec for level {last_level!r} '
                f'to have rank 3 or 4, got rank {spec_rank}: '
                f'{backbone_output_spec}.')
        if bb_output_height is None:
            # Fail fast with a clear message instead of a TypeError below.
            raise ValueError(
                f'Backbone output height for level {last_level!r} is unknown '
                f'({backbone_output_spec}); cannot compute the network '
                f'down-scale factor.')
        self._net_down_scale = int(model_config.input_size[0] /
                                   bb_output_height)

        dg_config = model_config.detection_generator
        detect_generator_obj = detection_generator.CenterNetDetectionGenerator(
            max_detections=dg_config.max_detections,
            peak_error=dg_config.peak_error,
            peak_extract_kernel_size=dg_config.peak_extract_kernel_size,
            class_offset=dg_config.class_offset,
            net_down_scale=self._net_down_scale,
            input_image_dims=model_config.input_size[0],
            use_nms=dg_config.use_nms,
            nms_pre_thresh=dg_config.nms_pre_thresh,
            nms_thresh=dg_config.nms_thresh)

        model = centernet_model.CenterNetModel(
            backbone=backbone,
            head=head,
            detection_generator=detect_generator_obj)

        return model