Example 1
0
def mnv2_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
    """Returns the MobileNetV2 + DeepLabV3+ Cityscapes config with QAT enabled.

    Starts from the float baseline experiment and rebuilds its task as a
    quantization-aware `SemanticSegmentationTask`.
    """
    experiment = semantic_segmentation.mnv2_deeplabv3plus_cityscapes()
    # Round-trip the task config through a dict so QAT settings can be added.
    experiment.task = SemanticSegmentationTask.from_args(
        quantization=common.Quantization(), **experiment.task.as_dict())
    return experiment
Example 2
0
 def test_deeplabv3_builder(self, backbone_type, input_size, weight_decay):
     """Builds a float DeepLabV3+ model and verifies QAT wrapping succeeds."""
     specs = tf.keras.layers.InputSpec(
         shape=[None, input_size[0], input_size[1], 3])
     seg_model_cfg = semantic_segmentation_cfg.SemanticSegmentationModel(
         num_classes=21,
         backbone=backbones.Backbone(
             type=backbone_type,
             mobilenet=backbones.MobileNet(
                 model_id='MobileNetV2', output_stride=16)),
         decoder=decoders.Decoder(
             type='aspp',
             aspp=decoders.ASPP(
                 level=4,
                 num_filters=256,
                 dilation_rates=[],
                 spp_layer_version='v1',
                 output_tensor=True)),
         head=semantic_segmentation_cfg.SegmentationHead(
             level=4,
             low_level=2,
             num_convs=1,
             upsample_factor=2,
             use_depthwise_convolution=True))
     # Only attach an L2 regularizer when a non-zero decay was requested.
     regularizer = None
     if weight_decay:
         regularizer = tf.keras.regularizers.l2(weight_decay)
     float_model = factory.build_segmentation_model(
         input_specs=specs,
         model_config=seg_model_cfg,
         l2_regularizer=regularizer)
     # Should not raise: the QAT factory must accept the float model.
     _ = qat_factory.build_qat_segmentation_model(
         model=float_model,
         quantization=common.Quantization(),
         input_specs=specs)
def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
    """Returns the MobileNetV2 ImageNet classification config with QAT enabled.

    Starts from the float baseline experiment and rebuilds its task as a
    quantization-aware `ImageClassificationTask`.
    """
    experiment = image_classification.image_classification_imagenet_mobilenet()
    # Round-trip the task config through a dict so QAT settings can be added.
    qat_task = ImageClassificationTask.from_args(
        quantization=common.Quantization(), **experiment.task.as_dict())
    experiment.task = qat_task
    return experiment
Example 4
0
def image_classification_imagenet() -> cfg.ExperimentConfig:
    """Returns the ResNet ImageNet classification config with QAT enabled.

    Starts from the float baseline experiment, rebuilds its task as a
    quantization-aware `ImageClassificationTask`, and disables XLA in the
    runtime config.
    """
    experiment = image_classification.image_classification_imagenet()
    # Round-trip the task config through a dict so QAT settings can be added.
    experiment.task = ImageClassificationTask.from_args(
        quantization=common.Quantization(), **experiment.task.as_dict())
    experiment.runtime = cfg.RuntimeConfig(enable_xla=False)
    return experiment
Example 5
0
def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig:
    """Returns the COCO RetinaNet (mobile SpineNet) config with QAT enabled.

    Starts from the float baseline experiment, rebuilds its task as a
    quantization-aware `RetinaNetTask`, and pins the backbone to an explicit
    SpineNet-mobile configuration.
    """
    experiment = retinanet.retinanet_spinenet_mobile_coco()
    qat_task = RetinaNetTask.from_args(
        quantization=common.Quantization(), **experiment.task.as_dict())
    # Override the backbone with an explicit SpineNet-mobile-49 spec.
    qat_task.model.backbone = backbones.Backbone(
        type='spinenet_mobile',
        spinenet_mobile=backbones.SpineNetMobile(
            model_id='49',
            stochastic_depth_drop_rate=0.2,
            min_level=3,
            max_level=7,
            use_keras_upsampling_2d=True))
    experiment.task = qat_task
    return experiment
Example 6
0
    def test_builder(self, backbone_type, input_size, weight_decay):
        """Builds a float classifier and verifies QAT wrapping succeeds."""
        specs = tf.keras.layers.InputSpec(
            shape=[None, input_size[0], input_size[1], 3])
        classifier_cfg = classification_cfg.ImageClassificationModel(
            num_classes=2,
            backbone=backbones.Backbone(type=backbone_type))
        # Only attach an L2 regularizer when a non-zero decay was requested.
        regularizer = None
        if weight_decay:
            regularizer = tf.keras.regularizers.l2(weight_decay)
        float_model = factory.build_classification_model(
            input_specs=specs,
            model_config=classifier_cfg,
            l2_regularizer=regularizer)

        # Should not raise: the QAT factory must accept the float model.
        _ = qat_factory.build_qat_classification_model(
            model=float_model,
            input_specs=specs,
            quantization=common.Quantization(),
            model_config=classifier_cfg,
            l2_regularizer=regularizer)
Example 7
0
    def test_builder(self, backbone_type, input_size, has_attribute_heads):
        """Builds a RetinaNet, applies QAT, and checks attribute-head configs."""
        specs = tf.keras.layers.InputSpec(
            shape=[None, input_size[0], input_size[1], 3])
        attribute_heads = None
        if has_attribute_heads:
            attribute_heads = [
                retinanet_cfg.AttributeHead(name='att1'),
                retinanet_cfg.AttributeHead(
                    name='att2', type='classification', size=2),
            ]
        retinanet_config = retinanet_cfg.RetinaNet(
            num_classes=2,
            backbone=backbones.Backbone(
                type=backbone_type,
                spinenet_mobile=backbones.SpineNetMobile(
                    model_id='49',
                    stochastic_depth_drop_rate=0.2,
                    min_level=3,
                    max_level=7,
                    use_keras_upsampling_2d=True)),
            head=retinanet_cfg.RetinaNetHead(
                attribute_heads=attribute_heads))
        float_model = factory.build_retinanet(
            input_specs=specs,
            model_config=retinanet_config,
            l2_regularizer=tf.keras.regularizers.l2(5e-5))

        # Should not raise: the QAT factory must accept the float model.
        _ = qat_factory.build_qat_retinanet(
            model=float_model,
            quantization=common.Quantization(),
            model_config=retinanet_config)
        if has_attribute_heads:
            # Applying QAT must leave the attribute-head configs untouched.
            self.assertEqual(
                retinanet_config.head.attribute_heads[0].as_dict(),
                dict(name='att1', type='regression', size=1))
            self.assertEqual(
                retinanet_config.head.attribute_heads[1].as_dict(),
                dict(name='att2', type='classification', size=2))
Example 8
0
    def test_builder(self, backbone_type, decoder_type, input_size,
                     quantize_detection_head, quantize_detection_decoder):
        """Builds a RetinaNet and verifies selective QAT on head and decoder.

        Checks that `build_qat_retinanet` replaces the detection head with
        its quantized counterpart only when `quantize_detection_head` is set,
        and rewrites an FPN decoder only when `quantize_detection_decoder`
        is set.
        """
        num_classes = 2
        input_specs = tf.keras.layers.InputSpec(
            shape=[None, input_size[0], input_size[1], 3])

        if backbone_type == 'spinenet_mobile':
            backbone_config = backbones.Backbone(
                type=backbone_type,
                spinenet_mobile=backbones.SpineNetMobile(
                    model_id='49',
                    stochastic_depth_drop_rate=0.2,
                    min_level=3,
                    max_level=7,
                    use_keras_upsampling_2d=True))
        elif backbone_type == 'mobilenet':
            backbone_config = backbones.Backbone(
                type=backbone_type,
                mobilenet=backbones.MobileNet(
                    model_id='MobileNetV2', filter_size_scale=1.0))
        else:
            raise ValueError(
                'backbone_type {} is not supported'.format(backbone_type))

        if decoder_type == 'identity':
            decoder_config = decoders.Decoder(type=decoder_type)
        elif decoder_type == 'fpn':
            decoder_config = decoders.Decoder(
                type=decoder_type,
                fpn=decoders.FPN(
                    num_filters=128,
                    use_separable_conv=True,
                    use_keras_layer=True))
        else:
            raise ValueError(
                'decoder_type {} is not supported'.format(decoder_type))

        model_config = retinanet_cfg.RetinaNet(
            num_classes=num_classes,
            input_size=[input_size[0], input_size[1], 3],
            backbone=backbone_config,
            decoder=decoder_config,
            head=retinanet_cfg.RetinaNetHead(attribute_heads=None,
                                             use_separable_conv=True))

        l2_regularizer = tf.keras.regularizers.l2(5e-5)
        # Build the original float32 retinanet model.
        model = factory.build_retinanet(input_specs=input_specs,
                                        model_config=model_config,
                                        l2_regularizer=l2_regularizer)

        # Call the model with a dummy input to build the head part.
        dummy_input = tf.zeros([1] + model_config.input_size)
        model(dummy_input, training=True)

        # Build the QAT model from the original model with quantization config.
        qat_model = qat_factory.build_qat_retinanet(
            model=model,
            quantization=common.Quantization(
                quantize_detection_decoder=quantize_detection_decoder,
                quantize_detection_head=quantize_detection_head),
            model_config=model_config)

        if quantize_detection_head:
            # The head becomes a RetinaNetHeadQuantized when quantization
            # is applied to the detection head.
            self.assertIsInstance(
                qat_model.head,
                qat_dense_prediction_heads.RetinaNetHeadQuantized)
        else:
            # The head stays a plain RetinaNetHead when the head is not
            # quantized.
            self.assertIsInstance(qat_model.head,
                                  dense_prediction_heads.RetinaNetHead)
            self.assertNotIsInstance(
                qat_model.head,
                qat_dense_prediction_heads.RetinaNetHeadQuantized)

        # Bug fix: the original compared against 'FPN', but decoder_type is
        # lowercase ('identity' / 'fpn' — see the branch above), so these
        # decoder assertions were dead code and never ran.
        if decoder_type == 'fpn':
            if quantize_detection_decoder:
                # The FPN decoder becomes a generic Keras functional model
                # after quantization is applied.
                self.assertNotIsInstance(qat_model.decoder, fpn.FPN)
            else:
                self.assertIsInstance(qat_model.decoder, fpn.FPN)