def build_classification_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: classification_cfg.ImageClassificationModel,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
    skip_logits_layer: bool = False,
    backbone: Optional[tf.keras.Model] = None) -> tf.keras.Model:
  """Builds the classification model.

  Args:
    input_specs: `tf.keras.layers.InputSpec` describing the model input.
    model_config: The image-classification model config; supplies the backbone
      config, number of classes, dropout rate, initializer, head batch-norm
      flag, and normalization settings.
    l2_regularizer: Optional regularizer applied to the backbone (when built
      here) and to the classification head kernels.
    skip_logits_layer: If True, the final logits layer is omitted.
    backbone: Optional pre-built backbone. When not provided, one is built
      from `model_config.backbone` via the backbone factory.

  Returns:
    A `ClassificationModel` wrapping the backbone and classification head.
  """
  norm_cfg = model_config.norm_activation
  # Only build a backbone when the caller did not supply one.
  if not backbone:
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=l2_regularizer)
  return classification_model.ClassificationModel(
      backbone=backbone,
      num_classes=model_config.num_classes,
      input_specs=input_specs,
      dropout_rate=model_config.dropout_rate,
      kernel_initializer=model_config.kernel_initializer,
      kernel_regularizer=l2_regularizer,
      add_head_batch_norm=model_config.add_head_batch_norm,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      skip_logits_layer=skip_logits_layer)
def test_sync_bn_multiple_devices(self, strategy, use_sync_bn):
  """Test for sync bn on TPU and GPU devices."""
  images = np.random.rand(64, 128, 128, 3)
  tf.keras.backend.set_image_data_format('channels_last')
  # Build and run the classifier under the distribution strategy so that
  # (sync) batch norm is exercised across the strategy's replicas.
  with strategy.scope():
    resnet = backbones.ResNet(model_id=50, use_sync_bn=use_sync_bn)
    classifier = classification_model.ClassificationModel(
        backbone=resnet,
        num_classes=1000,
        dropout_rate=0.2,
    )
    _ = classifier(images)
def test_serialize_deserialize(self):
  """Validate the classification net can be serialized and deserialized."""
  tf.keras.backend.set_image_data_format('channels_last')
  net = classification_model.ClassificationModel(
      backbone=backbones.ResNet(model_id=50), num_classes=1000)
  restored = classification_model.ClassificationModel.from_config(
      net.get_config())
  # Validate that the config can be forced to JSON.
  _ = restored.to_json()
  # If the serialization was successful, the new config should match the old.
  self.assertAllEqual(net.get_config(), restored.get_config())
def test_mobilenet_network_creation(self, mobilenet_model_id,
                                    filter_size_scale):
  """Test for creation of a MobileNet classifier."""
  tf.keras.backend.set_image_data_format('channels_last')
  num_classes = 1001
  images = np.random.rand(2, 224, 224, 3)
  classifier = classification_model.ClassificationModel(
      backbone=backbones.MobileNet(
          model_id=mobilenet_model_id, filter_size_scale=filter_size_scale),
      num_classes=num_classes,
      dropout_rate=0.2,
  )
  # A forward pass should yield one logit vector per input image.
  self.assertAllEqual([2, num_classes], classifier(images).numpy().shape)
def test_data_format_gpu(self, strategy, data_format, input_dim):
  """Test for different data formats on GPU devices."""
  # Channel axis position depends on the requested data format.
  shape = ((2, 128, 128, input_dim) if data_format == 'channels_last'
           else (2, input_dim, 128, 128))
  inputs = np.random.rand(*shape)
  input_specs = tf.keras.layers.InputSpec(shape=inputs.shape)
  tf.keras.backend.set_image_data_format(data_format)
  with strategy.scope():
    classifier = classification_model.ClassificationModel(
        backbone=backbones.ResNet(model_id=50, input_specs=input_specs),
        num_classes=1000,
        input_specs=input_specs,
    )
    _ = classifier(inputs)
def test_revnet_network_creation(self):
  """Test for creation of a RevNet-56 classifier."""
  images = np.random.rand(2, 224, 224, 3)
  tf.keras.backend.set_image_data_format('channels_last')
  revnet = backbones.RevNet(model_id=56)
  # Pin the backbone parameter count to catch accidental architecture changes.
  self.assertEqual(revnet.count_params(), 19473792)
  num_classes = 1000
  classifier = classification_model.ClassificationModel(
      backbone=revnet,
      num_classes=num_classes,
      dropout_rate=0.2,
      add_head_batch_norm=True,
  )
  self.assertEqual(classifier.count_params(), 22816104)
  predictions = classifier(images)
  self.assertAllEqual([2, num_classes], predictions.numpy().shape)
def test_resnet_network_creation(
    self, input_size, resnet_model_id, activation):
  """Test for creation of a ResNet-50 classifier."""
  images = np.random.rand(2, input_size, input_size, 3)
  tf.keras.backend.set_image_data_format('channels_last')
  resnet = backbones.ResNet(model_id=resnet_model_id, activation=activation)
  # Pin the backbone parameter count to catch accidental architecture changes.
  self.assertEqual(resnet.count_params(), 23561152)
  num_classes = 1000
  classifier = classification_model.ClassificationModel(
      backbone=resnet,
      num_classes=num_classes,
      dropout_rate=0.2,
  )
  self.assertEqual(classifier.count_params(), 25610152)
  predictions = classifier(images)
  self.assertAllEqual([2, num_classes], predictions.numpy().shape)
def build_qat_classification_model(
    model: tf.keras.Model,
    quantization: common.Quantization,
    input_specs: tf.keras.layers.InputSpec,
    model_config: configs.image_classification.ImageClassificationModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Apply model optimization techniques.

  Quantization-aware-training (QAT) wrapping happens in two stages: first the
  backbone is annotated and quantized on its own, then the classification head
  layers are annotated and quantized on a rebuilt model.

  Args:
    model: The model applying model optimization techniques.
    quantization: The Quantization config.
    input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
    model_config: The model config.
    l2_regularizer: tf.keras.regularizers.Regularizer object. Default to None.

  Returns:
    model: The model that applied optimization techniques.
  """
  # Optionally restore pretrained float weights into `model` before any
  # quantization wrapping, so QAT starts from the pretrained checkpoint.
  original_checkpoint = quantization.pretrained_original_checkpoint
  if original_checkpoint:
    ckpt = tf.train.Checkpoint(
        model=model,
        **model.checkpoint_items)
    status = ckpt.read(original_checkpoint)
    status.expect_partial().assert_existing_objects_matched()

  # Custom-object scope used while cloning/quantizing so layer configs that
  # reference 'L2' deserialize to `tf.keras.regularizers.l2`.
  scope_dict = {
      'L2': tf.keras.regularizers.l2,
  }
  # Stage 1: quantize the backbone only.
  with tfmot.quantization.keras.quantize_scope(scope_dict):
    annotated_backbone = tfmot.quantization.keras.quantize_annotate_model(
        model.backbone)
    if quantization.change_num_bits:
      # Custom bit widths for weights/activations from the quantization config.
      backbone = tfmot.quantization.keras.quantize_apply(
          annotated_backbone,
          scheme=n_bit_schemes.DefaultNBitQuantizeScheme(
              num_bits_weight=quantization.num_bits_weight,
              num_bits_activation=quantization.num_bits_activation))
    else:
      backbone = tfmot.quantization.keras.quantize_apply(
          annotated_backbone, scheme=schemes.Default8BitQuantizeScheme())

  # Rebuild the classification model around the quantized backbone, mirroring
  # the original model's head configuration.
  norm_activation_config = model_config.norm_activation
  backbone_optimized_model = classification_model.ClassificationModel(
      backbone=backbone,
      num_classes=model_config.num_classes,
      input_specs=input_specs,
      dropout_rate=model_config.dropout_rate,
      kernel_regularizer=l2_regularizer,
      add_head_batch_norm=model_config.add_head_batch_norm,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon)
  # Copy weights layer-by-layer from the original model, skipping the backbone
  # (its weights were already carried into the quantized backbone above).
  # NOTE(review): this assumes both models enumerate layers in the same order —
  # they are built from the same config, but verify if the head changes.
  for from_layer, to_layer in zip(
      model.layers, backbone_optimized_model.layers):
    if from_layer != model.backbone:
      to_layer.set_weights(from_layer.get_weights())

  # Stage 2: annotate and quantize the head layers of the rebuilt model.
  with tfmot.quantization.keras.quantize_scope(scope_dict):
    def apply_quantization_to_dense(layer):
      # Annotate only the head layer types; everything else (including the
      # already-quantized backbone) is passed through unchanged.
      if isinstance(layer, (tf.keras.layers.Dense,
                            tf.keras.layers.Dropout,
                            tf.keras.layers.GlobalAveragePooling2D)):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
      return layer

    annotated_model = tf.keras.models.clone_model(
        backbone_optimized_model,
        clone_function=apply_quantization_to_dense,
    )

    if quantization.change_num_bits:
      optimized_model = tfmot.quantization.keras.quantize_apply(
          annotated_model,
          scheme=n_bit_schemes.DefaultNBitQuantizeScheme(
              num_bits_weight=quantization.num_bits_weight,
              num_bits_activation=quantization.num_bits_activation))
    else:
      optimized_model = tfmot.quantization.keras.quantize_apply(
          annotated_model)
  return optimized_model