Exemplo n.º 1
0
 def test_builder(self, backbone_type, input_size, has_att_heads):
     """Build a RetinaNet from config and check the attribute-head configs.

     Args:
       backbone_type: Value passed as `type` to `backbones.Backbone`.
       input_size: Indexable whose items 0 and 1 are used as the two
         middle dimensions of the input shape.
       has_att_heads: When true, two attribute heads are configured and
         their configs are checked after the model is built.
     """
     # Unspecified batch dimension, three channels.
     input_specs = tf.keras.layers.InputSpec(
         shape=[None, input_size[0], input_size[1], 3])

     att_heads = None
     if has_att_heads:
         # 'att1' relies on the config defaults; 'att2' sets type and size.
         att_heads = [
             retinanet_cfg.AttributeHead(name='att1'),
             retinanet_cfg.AttributeHead(
                 name='att2', type='classification', size=2),
         ]

     model_config = retinanet_cfg.RetinaNet(
         num_classes=2,
         backbone=backbones.Backbone(type=backbone_type),
         head=retinanet_cfg.RetinaNetHead(attribute_heads=att_heads))

     # The test passes as long as building does not raise; the returned
     # model itself is not inspected.
     factory.build_retinanet(
         input_specs=input_specs,
         model_config=model_config,
         l2_regularizer=tf.keras.regularizers.l2(5e-5))

     if not has_att_heads:
         return

     # After building, the head configs must still serialize to these dicts.
     expected_heads = (
         dict(name='att1', type='regression', size=1),
         dict(name='att2', type='classification', size=2),
     )
     for idx, expected in enumerate(expected_heads):
         self.assertEqual(
             model_config.head.attribute_heads[idx].as_dict(), expected)
Exemplo n.º 2
0
 def test_builder(self, backbone_type, input_size):
     """Check that a RetinaNet can be built for the given backbone.

     Args:
       backbone_type: Value passed as `type` to `backbones.Backbone`.
       input_size: Indexable whose items 0 and 1 are used as the two
         middle dimensions of the input shape.
     """
     # Unspecified batch dimension, three channels.
     spec_shape = [None, input_size[0], input_size[1], 3]
     input_specs = tf.keras.layers.InputSpec(shape=spec_shape)

     model_config = retinanet_cfg.RetinaNet(
         num_classes=2,
         backbone=backbones.Backbone(type=backbone_type))

     # No assertions: the test only requires that building does not raise.
     factory.build_retinanet(
         input_specs=input_specs,
         model_config=model_config,
         l2_regularizer=tf.keras.regularizers.l2(5e-5))
Exemplo n.º 3
0
    def build_model(self):
        """Build the RetinaNet model described by `self.task_config.model`.

        Returns:
          The object returned by `factory.build_retinanet`.
        """
        model_config = self.task_config.model

        # Prepend an unspecified batch dimension to the configured input size.
        input_specs = tf.keras.layers.InputSpec(
            shape=[None] + model_config.input_size)

        # The weight decay is halved to match the implementation of
        # tf.nn.l2_loss:
        # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
        # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
        # A falsy weight decay (None or 0) means no regularizer at all.
        weight_decay = self.task_config.losses.l2_weight_decay
        if weight_decay:
            l2_regularizer = tf.keras.regularizers.l2(weight_decay / 2.0)
        else:
            l2_regularizer = None

        return factory.build_retinanet(
            input_specs=input_specs,
            model_config=model_config,
            l2_regularizer=l2_regularizer)
Exemplo n.º 4
0
  def _build_model(self):

    if self._batch_size is None:
      raise ValueError('batch_size cannot be None for detection models.')
    input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
                                            self._input_image_size + [3])

    if isinstance(self.params.task.model, configs.maskrcnn.MaskRCNN):
      model = factory.build_maskrcnn(
          input_specs=input_specs, model_config=self.params.task.model)
    elif isinstance(self.params.task.model, configs.retinanet.RetinaNet):
      model = factory.build_retinanet(
          input_specs=input_specs, model_config=self.params.task.model)
    else:
      raise ValueError('Detection module not implemented for {} model.'.format(
          type(self.params.task.model)))

    return model
    def build_model(self):
        """Build the detection model selected by `self._params.task.model`.

        The built model is stored on `self._model` and also returned.

        Returns:
          The object returned by `factory.build_maskrcnn` or
          `factory.build_retinanet`, depending on the config type.

        Raises:
          ValueError: If `self._batch_size` is None, if the detection
            generator config has `use_batched_nms` disabled, or if the model
            config is neither a MaskRCNN nor a RetinaNet config.
        """
        # Both guards previously constructed the exception without raising
        # it, so invalid settings passed through silently.
        if self._batch_size is None:
            raise ValueError("batch_size can't be None for detection models")

        model_config = self._params.task.model
        if not model_config.detection_generator.use_batched_nms:
            raise ValueError('Only batched_nms is supported.')

        # Fixed batch size, configured image size, three channels.
        input_specs = tf.keras.layers.InputSpec(
            shape=[self._batch_size] + self._input_image_size + [3])

        if isinstance(model_config, configs.maskrcnn.MaskRCNN):
            self._model = factory.build_maskrcnn(
                input_specs=input_specs, model_config=model_config)
        elif isinstance(model_config, configs.retinanet.RetinaNet):
            self._model = factory.build_retinanet(
                input_specs=input_specs, model_config=model_config)
        else:
            raise ValueError(
                'Detection module not implemented for {} model.'.format(
                    type(model_config)))

        return self._model
Exemplo n.º 6
0
    def test_builder(self, backbone_type, input_size, has_attribute_heads):
        """Build a RetinaNet, pass it to the QAT builder, and check head configs.

        Args:
          backbone_type: Value passed as `type` to `backbones.Backbone`.
          input_size: Indexable whose items 0 and 1 are used as the two
            middle dimensions of the input shape.
          has_attribute_heads: When true, two attribute heads are configured
            and their configs are checked after both builders have run.
        """
        # Unspecified batch dimension, three channels.
        input_specs = tf.keras.layers.InputSpec(
            shape=[None, input_size[0], input_size[1], 3])

        heads_cfg = None
        if has_attribute_heads:
            # 'att1' relies on the config defaults; 'att2' sets type and size.
            heads_cfg = [
                retinanet_cfg.AttributeHead(name='att1'),
                retinanet_cfg.AttributeHead(
                    name='att2', type='classification', size=2),
            ]

        spinenet_mobile_cfg = backbones.SpineNetMobile(
            model_id='49',
            stochastic_depth_drop_rate=0.2,
            min_level=3,
            max_level=7,
            use_keras_upsampling_2d=True)
        backbone_cfg = backbones.Backbone(
            type=backbone_type, spinenet_mobile=spinenet_mobile_cfg)
        head_cfg = retinanet_cfg.RetinaNetHead(attribute_heads=heads_cfg)
        model_config = retinanet_cfg.RetinaNet(
            num_classes=2, backbone=backbone_cfg, head=head_cfg)

        l2_regularizer = tf.keras.regularizers.l2(5e-5)
        quantization_config = common.Quantization()

        # Build the float model first, then hand it to the QAT builder.
        # The QAT result is not inspected; it only has to build without error.
        model = factory.build_retinanet(
            input_specs=input_specs,
            model_config=model_config,
            l2_regularizer=l2_regularizer)
        qat_factory.build_qat_retinanet(
            model=model,
            quantization=quantization_config,
            model_config=model_config)

        if not has_attribute_heads:
            return

        # After both builds, the head configs must still serialize to these.
        expected_heads = (
            dict(name='att1', type='regression', size=1),
            dict(name='att2', type='classification', size=2),
        )
        for idx, expected in enumerate(expected_heads):
            self.assertEqual(
                model_config.head.attribute_heads[idx].as_dict(), expected)