예제 #1
0
 def test_deeplabv3_builder(self, backbone_type, input_size, weight_decay):
     """Builds an ASPP-decoder segmentation model and wraps it for QAT.

     Args:
       backbone_type: Backbone type name for `backbones.Backbone`; paired
         with a MobileNetV2 (output stride 16) sub-config.
       input_size: Spatial input size; `input_size[0]` and `input_size[1]`
         fill dimensions 1 and 2 of the `[None, ..., ..., 3]` input spec.
       weight_decay: L2 regularization strength; a falsy value builds the
         model without a regularizer.
     """
     num_classes = 21
     input_specs = tf.keras.layers.InputSpec(
         shape=[None, input_size[0], input_size[1], 3])
     model_config = semantic_segmentation_cfg.SemanticSegmentationModel(
         num_classes=num_classes,
         backbone=backbones.Backbone(type=backbone_type,
                                     mobilenet=backbones.MobileNet(
                                         model_id='MobileNetV2',
                                         output_stride=16)),
         decoder=decoders.Decoder(type='aspp',
                                  aspp=decoders.ASPP(level=4,
                                                     num_filters=256,
                                                     dilation_rates=[],
                                                     spp_layer_version='v1',
                                                     output_tensor=True)),
         head=semantic_segmentation_cfg.SegmentationHead(
             level=4,
             low_level=2,
             num_convs=1,
             upsample_factor=2,
             use_depthwise_convolution=True))
     l2_regularizer = (tf.keras.regularizers.l2(weight_decay)
                       if weight_decay else None)
     model = factory.build_segmentation_model(input_specs=input_specs,
                                              model_config=model_config,
                                              l2_regularizer=l2_regularizer)
     quantization_config = common.Quantization()
     qat_model = qat_factory.build_qat_segmentation_model(
         model=model,
         quantization=quantization_config,
         input_specs=input_specs)
     # Check what the factory produced, not merely that it ran without raising.
     self.assertIsInstance(qat_model, tf.keras.Model)
예제 #2
0
 def build_model(self) -> tf.keras.Model:
   """Builds semantic segmentation model with QAT.

   The float model comes from the parent task. When
   `task_config.quantization` is set, it is wrapped by
   `factory.build_qat_segmentation_model` using an input spec of
   `[None] + task_config.model.input_size`.

   Returns:
     The QAT-wrapped model when quantization is configured, otherwise the
     unmodified model from `super().build_model()`.
   """
   model = super().build_model()
   if not self.task_config.quantization:
     # Nothing to wrap; train the plain float model.
     return model
   # Leading None leaves the batch dimension unspecified. `list()` lets a
   # tuple-valued `input_size` work too (list + tuple raises TypeError).
   input_specs = tf.keras.layers.InputSpec(
       shape=[None] + list(self.task_config.model.input_size))
   return factory.build_qat_segmentation_model(
       model, self.task_config.quantization, input_specs)
예제 #3
0
 def _build_model(self):
   """Builds the serving model, wrapped for QAT when quantization is set.

   Returns:
     The QAT-wrapped segmentation model when `params.task.quantization` is
     set, otherwise the float model from the parent module's `_build_model`.
   """
   model = super()._build_model()
   quantization = self.params.task.quantization
   if not quantization:
     # No quantization config: the checkpoint being served is presumably a
     # float model, so build the matching float graph instead of handing
     # None to the QAT factory (NOTE(review): confirm the factory cannot
     # accept None).
     return model
   # Input spec is [batch, *image_size, 3]; `list()` tolerates a tuple-valued
   # image size.
   input_specs = tf.keras.layers.InputSpec(
       shape=[self._batch_size] + list(self._input_image_size) + [3])
   return qat_factory.build_qat_segmentation_model(
       model, quantization, input_specs)