def build_qat_segmentation_model(
    model: tf.keras.Model, quantization: common.Quantization,
    input_specs: tf.keras.layers.InputSpec) -> tf.keras.Model:
  """Applies quantization aware training for segmentation model.

  Args:
    model: The model applying quantization aware training.
    quantization: The Quantization config.
    input_specs: The shape specifications of input tensor.

  Returns:
    The model that applied optimization techniques.
  """
  # Restore pretrained float weights first, if a checkpoint is configured.
  pretrained_ckpt_path = quantization.pretrained_original_checkpoint
  if pretrained_ckpt_path is not None:
    checkpoint = tf.train.Checkpoint(model=model, **model.checkpoint_items)
    read_status = checkpoint.read(pretrained_ckpt_path)
    read_status.expect_partial().assert_existing_objects_matched()

  # Rebuild the model in a quantization-compatible form.
  model = qat_segmentation_model.SegmentationModelQuantized(
      model.backbone, model.decoder, model.head, input_specs)

  quantize_scope_map = {'L2': tf.keras.regularizers.l2}

  # Quantize the backbone (a tf.keras.Model) first.
  with tfmot.quantization.keras.quantize_scope(quantize_scope_map):
    backbone_annotated = tfmot.quantization.keras.quantize_annotate_model(
        model.backbone)
    backbone_quantized = tfmot.quantization.keras.quantize_apply(
        backbone_annotated, scheme=schemes.Default8BitQuantizeScheme())

  backbone_optimized_model = qat_segmentation_model.SegmentationModelQuantized(
      backbone_quantized, model.decoder, model.head, input_specs)

  # Copy weights of every layer except the backbone, which was replaced above.
  for src_layer, dst_layer in zip(model.layers,
                                  backbone_optimized_model.layers):
    if src_layer != model.backbone:
      dst_layer.set_weights(src_layer.get_weights())

  with tfmot.quantization.keras.quantize_scope(quantize_scope_map):

    def _annotate_segmentation_layers(layer):
      # Only the segmentation-specific layers are annotated for quantization;
      # everything else is cloned unchanged.
      if isinstance(layer, (segmentation_heads.SegmentationHead,
                            nn_layers.SpatialPyramidPooling, aspp.ASPP)):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
      return layer

    annotated_model = tf.keras.models.clone_model(
        backbone_optimized_model,
        clone_function=_annotate_segmentation_layers,
    )
    optimized_model = tfmot.quantization.keras.quantize_apply(
        annotated_model, scheme=schemes.Default8BitQuantizeScheme())
  return optimized_model
def build_qat_retinanet(
    model: tf.keras.Model, quantization: common.Quantization,
    model_config: configs.retinanet.RetinaNet) -> tf.keras.Model:
  """Applies quantization aware training for RetinaNet model.

  NOTE(review): a second `build_qat_retinanet` with an identical signature is
  defined later in this file and shadows this one at import time, so this
  version (which never quantizes the detection decoder) appears to be dead
  code — confirm and remove one of the two definitions.

  Args:
    model: The model applying quantization aware training.
    quantization: The Quantization config.
    model_config: The model config.

  Returns:
    The model that applied optimization techniques.
  """
  # Restore pretrained float weights before quantizing, if configured.
  original_checkpoint = quantization.pretrained_original_checkpoint
  if original_checkpoint is not None:
    ckpt = tf.train.Checkpoint(
        model=model, **model.checkpoint_items)
    status = ckpt.read(original_checkpoint)
    status.expect_partial().assert_existing_objects_matched()

  # Custom objects that must be resolvable while (de)serializing the
  # annotated model inside quantize_apply.
  scope_dict = {
      'L2': tf.keras.regularizers.l2,
  }
  with tfmot.quantization.keras.quantize_scope(scope_dict):
    annotated_backbone = tfmot.quantization.keras.quantize_annotate_model(
        model.backbone)
    optimized_backbone = tfmot.quantization.keras.quantize_apply(
        annotated_backbone, scheme=schemes.Default8BitQuantizeScheme())

  # Optionally swap the head for its quantized counterpart, rebuilt from the
  # original head's config (weights are NOT copied here).
  head = model.head
  if quantization.quantize_detection_head:
    if not isinstance(head, dense_prediction_heads.RetinaNetHead):
      raise ValueError('Currently only supports RetinaNetHead.')
    head = (
        dense_prediction_heads_qat.RetinaNetHeadQuantized.from_config(
            head.get_config()))

  # Reassemble the detection model around the quantized backbone; the decoder
  # and detection generator are reused as-is.
  optimized_model = retinanet_model.RetinaNetModel(
      optimized_backbone,
      model.decoder,
      head,
      model.detection_generator,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=model_config.anchor.num_scales,
      aspect_ratios=model_config.anchor.aspect_ratios,
      anchor_size=model_config.anchor.anchor_size)
  return optimized_model
def build_qat_classification_model(
    model: tf.keras.Model,
    quantization: common.Quantization,
    input_specs: tf.keras.layers.InputSpec,
    model_config: configs.image_classification.ImageClassificationModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Apply model optimization techniques.

  Args:
    model: The model applying model optimization techniques.
    quantization: The Quantization config.
    input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
    model_config: The model config.
    l2_regularizer: tf.keras.regularizers.Regularizer object. Default to None.

  Returns:
    model: The model that applied optimization techniques.
  """
  # Load pretrained float weights before any quantization rewriting.
  pretrained_ckpt_path = quantization.pretrained_original_checkpoint
  if pretrained_ckpt_path:
    checkpoint = tf.train.Checkpoint(model=model, **model.checkpoint_items)
    read_status = checkpoint.read(pretrained_ckpt_path)
    read_status.expect_partial().assert_existing_objects_matched()

  quantize_scope_map = {'L2': tf.keras.regularizers.l2}

  # Quantize the backbone, either with the configurable n-bit scheme or the
  # default 8-bit one.
  with tfmot.quantization.keras.quantize_scope(quantize_scope_map):
    backbone_annotated = tfmot.quantization.keras.quantize_annotate_model(
        model.backbone)
    if quantization.change_num_bits:
      backbone = tfmot.quantization.keras.quantize_apply(
          backbone_annotated,
          scheme=n_bit_schemes.DefaultNBitQuantizeScheme(
              num_bits_weight=quantization.num_bits_weight,
              num_bits_activation=quantization.num_bits_activation))
    else:
      backbone = tfmot.quantization.keras.quantize_apply(
          backbone_annotated, scheme=schemes.Default8BitQuantizeScheme())

  # Rebuild the classifier around the quantized backbone.
  norm_activation_config = model_config.norm_activation
  backbone_optimized_model = classification_model.ClassificationModel(
      backbone=backbone,
      num_classes=model_config.num_classes,
      input_specs=input_specs,
      dropout_rate=model_config.dropout_rate,
      kernel_regularizer=l2_regularizer,
      add_head_batch_norm=model_config.add_head_batch_norm,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon)

  # Copy weights of every layer except the backbone, which was replaced.
  for src_layer, dst_layer in zip(model.layers,
                                  backbone_optimized_model.layers):
    if src_layer != model.backbone:
      dst_layer.set_weights(src_layer.get_weights())

  with tfmot.quantization.keras.quantize_scope(quantize_scope_map):

    def _annotate_head_layers(layer):
      # Only the classifier-head layer types are annotated for quantization.
      if isinstance(layer, (tf.keras.layers.Dense, tf.keras.layers.Dropout,
                            tf.keras.layers.GlobalAveragePooling2D)):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
      return layer

    annotated_model = tf.keras.models.clone_model(
        backbone_optimized_model,
        clone_function=_annotate_head_layers,
    )
    if quantization.change_num_bits:
      optimized_model = tfmot.quantization.keras.quantize_apply(
          annotated_model,
          scheme=n_bit_schemes.DefaultNBitQuantizeScheme(
              num_bits_weight=quantization.num_bits_weight,
              num_bits_activation=quantization.num_bits_activation))
    else:
      # NOTE: no explicit scheme here, matching the original behavior of
      # relying on quantize_apply's default scheme.
      optimized_model = tfmot.quantization.keras.quantize_apply(
          annotated_model)
  return optimized_model
def build_qat_retinanet(
    model: tf.keras.Model, quantization: common.Quantization,
    model_config: configs.retinanet.RetinaNet) -> tf.keras.Model:
  """Applies quantization aware training for RetinaNet model.

  NOTE(review): this redefines the earlier `build_qat_retinanet` in this file
  and shadows it at import time; the earlier version appears to be dead code —
  confirm and remove the duplicate.

  Args:
    model: The model applying quantization aware training.
    quantization: The Quantization config.
    model_config: The model config.

  Returns:
    The model that applied optimization techniques.
  """
  # Restore pretrained float weights before quantizing, if configured.
  original_checkpoint = quantization.pretrained_original_checkpoint
  if original_checkpoint is not None:
    ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
    status = ckpt.read(original_checkpoint)
    status.expect_partial().assert_existing_objects_matched()

  # Custom objects that must be resolvable while (de)serializing the
  # annotated models inside quantize_apply / clone_model.
  scope_dict = {
      'L2': tf.keras.regularizers.l2,
      'BatchNormalizationWrapper': qat_nn_layers.BatchNormalizationWrapper,
  }
  with tfmot.quantization.keras.quantize_scope(scope_dict):
    annotated_backbone = tfmot.quantization.keras.quantize_annotate_model(
        model.backbone)
    optimized_backbone = tfmot.quantization.keras.quantize_apply(
        annotated_backbone, scheme=schemes.Default8BitQuantizeScheme())

    decoder = model.decoder
    if quantization.quantize_detection_decoder:
      if not isinstance(decoder, fpn.FPN):
        raise ValueError('Currently only supports FPN.')
      decoder = tf.keras.models.clone_model(
          decoder,
          clone_function=_clone_function_for_fpn,
      )
      decoder = tfmot.quantization.keras.quantize_apply(decoder)
      decoder = tfmot.quantization.keras.remove_input_range(decoder)

  # Optionally swap the head for its quantized counterpart, rebuilt from the
  # original head's config (weights are copied after the model is built).
  head = model.head
  if quantization.quantize_detection_head:
    if not isinstance(head, dense_prediction_heads.RetinaNetHead):
      raise ValueError('Currently only supports RetinaNetHead.')
    head = (
        dense_prediction_heads_qat.RetinaNetHeadQuantized.from_config(
            head.get_config()))

  optimized_model = retinanet_model.RetinaNetModel(
      optimized_backbone,
      decoder,
      head,
      model.detection_generator,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=model_config.anchor.num_scales,
      aspect_ratios=model_config.anchor.aspect_ratios,
      anchor_size=model_config.anchor.anchor_size)

  if quantization.quantize_detection_head:
    # Call the model with a dummy input to build the head part, so its
    # variables exist before the pretrained weights are copied in.
    # (Fixed misspelled local `dummpy_input` -> `dummy_input`.)
    dummy_input = tf.zeros([1] + model_config.input_size)
    optimized_model(dummy_input, training=True)
    helper.copy_original_weights(model.head, optimized_model.head)
  return optimized_model