def build_classification_model(
        input_specs: tf.keras.layers.InputSpec,
        model_config: classification_cfg.ImageClassificationModel,
        l2_regularizer: tf.keras.regularizers.Regularizer = None,
        skip_logits_layer: bool = False) -> tf.keras.Model:
    """Assembles an image classification model from its config.

    Args:
        input_specs: Input spec forwarded to both the backbone factory and
            the classification model.
        model_config: Model config supplying the backbone config, the
            norm/activation settings, `num_classes`, `dropout_rate` and
            `add_head_batch_norm`.
        l2_regularizer: Regularizer passed to the backbone factory and used
            as the classifier's `kernel_regularizer`. Defaults to None.
        skip_logits_layer: Forwarded unchanged to `ClassificationModel`.

    Returns:
        The constructed `ClassificationModel`.
    """
    norm_cfg = model_config.norm_activation

    feature_extractor = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=l2_regularizer)

    # Everything besides the backbone comes straight from the config or from
    # this function's own arguments.
    classifier_kwargs = dict(
        num_classes=model_config.num_classes,
        input_specs=input_specs,
        dropout_rate=model_config.dropout_rate,
        kernel_regularizer=l2_regularizer,
        add_head_batch_norm=model_config.add_head_batch_norm,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon,
        skip_logits_layer=skip_logits_layer)
    return classification_model.ClassificationModel(
        backbone=feature_extractor, **classifier_kwargs)
  def test_mobilenet_network_creation(self, mobilenet_model_id,
                                      filter_size_scale):
    """Checks parameter count and logits shape of a MobileNet classifier."""
    # Expected total parameter count of the classifier, keyed by
    # (model_id, filter_size_scale).
    expected_param_counts = {
        ('MobileNetV1', 1.0): 4254889,
        ('MobileNetV1', 0.75): 2602745,
        ('MobileNetV2', 1.0): 3540265,
        ('MobileNetV2', 0.75): 2664345,
        ('MobileNetV3Large', 1.0): 5508713,
        ('MobileNetV3Large', 0.75): 4013897,
        ('MobileNetV3Small', 1.0): 2555993,
        ('MobileNetV3Small', 0.75): 2052577,
        ('MobileNetV3EdgeTPU', 1.0): 4131593,
        ('MobileNetV3EdgeTPU', 0.75): 3019569,
    }
    num_classes = 1001
    batch = np.random.rand(2, 224, 224, 3)

    tf.keras.backend.set_image_data_format('channels_last')

    classifier = classification_model.ClassificationModel(
        backbone=backbones.MobileNet(
            model_id=mobilenet_model_id,
            filter_size_scale=filter_size_scale),
        num_classes=num_classes,
        dropout_rate=0.2,
    )

    actual_count = classifier.count_params()
    expected_count = expected_param_counts[(mobilenet_model_id,
                                            filter_size_scale)]
    self.assertEqual(actual_count, expected_count)

    # A batch of two images must yield one logit vector per image.
    output = classifier(batch)
    self.assertAllEqual([2, num_classes], output.numpy().shape)
    def test_sync_bn_multiple_devices(self, strategy, use_sync_bn):
        """Runs a ResNet-50 classifier forward pass inside `strategy.scope()`.

        Smoke test only: it passes as long as building and calling the model
        with the given `use_sync_bn` setting does not raise.
        """
        batch = np.random.rand(64, 128, 128, 3)
        tf.keras.backend.set_image_data_format('channels_last')

        with strategy.scope():
            classifier = classification_model.ClassificationModel(
                backbone=backbones.ResNet(
                    model_id=50, use_sync_bn=use_sync_bn),
                num_classes=1000,
                dropout_rate=0.2,
            )
            # The output itself is not inspected.
            classifier(batch)
  def test_serialize_deserialize(self):
    """Checks that the classifier survives a config round trip."""
    tf.keras.backend.set_image_data_format('channels_last')

    original = classification_model.ClassificationModel(
        backbone=backbones.ResNet(model_id=50), num_classes=1000)
    restored = classification_model.ClassificationModel.from_config(
        original.get_config())

    # The rebuilt model must be convertible to JSON; the result is discarded.
    restored.to_json()

    # A successful round trip leaves the config unchanged.
    self.assertAllEqual(original.get_config(), restored.get_config())
    def test_data_format_gpu(self, strategy, data_format, input_dim):
        """Runs a ResNet-50 classifier forward pass for one data format.

        The random input batch places the `input_dim` channels last when
        `data_format` is 'channels_last', and right after the batch axis
        otherwise. Smoke test only: the model output is not inspected.
        """
        channels_last = data_format == 'channels_last'
        batch_shape = ((2, 128, 128, input_dim) if channels_last
                       else (2, input_dim, 128, 128))
        batch = np.random.rand(*batch_shape)
        spec = tf.keras.layers.InputSpec(shape=batch.shape)

        tf.keras.backend.set_image_data_format(data_format)

        with strategy.scope():
            classifier = classification_model.ClassificationModel(
                backbone=backbones.ResNet(model_id=50, input_specs=spec),
                num_classes=1000,
                input_specs=spec,
            )
            classifier(batch)
Esempio n. 6
0
    def test_mobilenet_network_creation(self, mobilenet_model_id,
                                        filter_size_scale):
        """Checks the logits shape produced by a MobileNet classifier."""
        num_classes = 1001
        batch = np.random.rand(2, 224, 224, 3)

        tf.keras.backend.set_image_data_format('channels_last')

        classifier = classification_model.ClassificationModel(
            backbone=backbones.MobileNet(
                model_id=mobilenet_model_id,
                filter_size_scale=filter_size_scale),
            num_classes=num_classes,
            dropout_rate=0.2,
        )

        # A batch of two images must yield one logit vector per image.
        output = classifier(batch)
        self.assertAllEqual([2, num_classes], output.numpy().shape)
    def test_revnet_network_creation(self):
        """Checks parameter counts and logits shape of a RevNet-56 classifier."""
        num_classes = 1000
        batch = np.random.rand(2, 224, 224, 3)

        tf.keras.backend.set_image_data_format('channels_last')

        feature_extractor = backbones.RevNet(model_id=56)
        self.assertEqual(feature_extractor.count_params(), 19473792)

        classifier = classification_model.ClassificationModel(
            backbone=feature_extractor,
            num_classes=num_classes,
            dropout_rate=0.2,
            add_head_batch_norm=True,
        )
        # Total for the full classifier, backbone included.
        self.assertEqual(classifier.count_params(), 22816104)

        output = classifier(batch)
        self.assertAllEqual([2, num_classes], output.numpy().shape)
    def test_resnet_network_creation(self, input_size, resnet_model_id,
                                     activation):
        """Checks parameter counts and logits shape of a ResNet classifier.

        The expected parameter counts are fixed constants, so
        `resnet_model_id` is presumably always 50 (ResNet-50) -- confirm
        against the test's parameterization.
        """
        num_classes = 1000
        batch = np.random.rand(2, input_size, input_size, 3)

        tf.keras.backend.set_image_data_format('channels_last')

        feature_extractor = backbones.ResNet(
            model_id=resnet_model_id, activation=activation)
        self.assertEqual(feature_extractor.count_params(), 23561152)

        classifier = classification_model.ClassificationModel(
            backbone=feature_extractor,
            num_classes=num_classes,
            dropout_rate=0.2,
        )
        # Total for the full classifier, backbone included.
        self.assertEqual(classifier.count_params(), 25610152)

        output = classifier(batch)
        self.assertAllEqual([2, num_classes], output.numpy().shape)
Esempio n. 9
0
def build_qat_classification_model(
    model: tf.keras.Model,
    quantization: common.Quantization,
    input_specs: tf.keras.layers.InputSpec,
    model_config: configs.image_classification.ImageClassificationModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
    """Applies quantization to a classification model.

    The backbone is quantized first, a new `ClassificationModel` is built
    around the quantized backbone, the non-backbone weights are copied over
    from `model`, and finally the Dense, Dropout and GlobalAveragePooling2D
    layers of the new model are annotated and quantized.

    Args:
        model: The model to apply the optimization techniques to. Its
            `backbone`, `layers` and `checkpoint_items` attributes are read.
        quantization: The Quantization config. Supplies
            `pretrained_original_checkpoint`, `change_num_bits`,
            `num_bits_weight` and `num_bits_activation`.
        input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
        model_config: The model config.
        l2_regularizer: tf.keras.regularizers.Regularizer object. Defaults to
            None.

    Returns:
        The model with the optimization techniques applied.
    """
    # Optionally restore the float model from a checkpoint before quantizing.
    # A partial restore is tolerated, but every object that already exists in
    # `model` must find a match in the checkpoint.
    original_checkpoint = quantization.pretrained_original_checkpoint
    if original_checkpoint:
        ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
        status = ckpt.read(original_checkpoint)
        status.expect_partial().assert_existing_objects_matched()

    # Custom objects made available inside the quantize scopes below;
    # presumably needed so the name 'L2' resolves while layers are
    # re-created -- TODO confirm.
    scope_dict = {
        'L2': tf.keras.regularizers.l2,
    }
    # Step 1: quantize the backbone, with either the configured bit widths or
    # the 8-bit scheme from `schemes`.
    with tfmot.quantization.keras.quantize_scope(scope_dict):
        annotated_backbone = tfmot.quantization.keras.quantize_annotate_model(
            model.backbone)
        if quantization.change_num_bits:
            backbone = tfmot.quantization.keras.quantize_apply(
                annotated_backbone,
                scheme=n_bit_schemes.DefaultNBitQuantizeScheme(
                    num_bits_weight=quantization.num_bits_weight,
                    num_bits_activation=quantization.num_bits_activation))
        else:
            backbone = tfmot.quantization.keras.quantize_apply(
                annotated_backbone, scheme=schemes.Default8BitQuantizeScheme())

    # Step 2: rebuild the classifier around the quantized backbone, using the
    # settings from `model_config`.
    norm_activation_config = model_config.norm_activation
    backbone_optimized_model = classification_model.ClassificationModel(
        backbone=backbone,
        num_classes=model_config.num_classes,
        input_specs=input_specs,
        dropout_rate=model_config.dropout_rate,
        kernel_regularizer=l2_regularizer,
        add_head_batch_norm=model_config.add_head_batch_norm,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon)
    # Step 3: copy the weights of every layer except the backbone from the
    # original model. Layers are paired by position, so this assumes both
    # models list their layers in the same order -- TODO confirm.
    for from_layer, to_layer in zip(model.layers,
                                    backbone_optimized_model.layers):
        if from_layer != model.backbone:
            to_layer.set_weights(from_layer.get_weights())

    # Step 4: annotate and quantize the remaining layers of the new model.
    with tfmot.quantization.keras.quantize_scope(scope_dict):

        # Clone function: marks Dense, Dropout and GlobalAveragePooling2D
        # layers for quantization and returns every other layer unchanged.
        def apply_quantization_to_dense(layer):
            if isinstance(layer,
                          (tf.keras.layers.Dense, tf.keras.layers.Dropout,
                           tf.keras.layers.GlobalAveragePooling2D)):
                return tfmot.quantization.keras.quantize_annotate_layer(layer)
            return layer

        annotated_model = tf.keras.models.clone_model(
            backbone_optimized_model,
            clone_function=apply_quantization_to_dense,
        )

        if quantization.change_num_bits:
            optimized_model = tfmot.quantization.keras.quantize_apply(
                annotated_model,
                scheme=n_bit_schemes.DefaultNBitQuantizeScheme(
                    num_bits_weight=quantization.num_bits_weight,
                    num_bits_activation=quantization.num_bits_activation))

        else:
            # NOTE(review): no explicit scheme is passed here, so this relies
            # on quantize_apply's default, whereas the backbone branch above
            # passes schemes.Default8BitQuantizeScheme() -- confirm that the
            # difference is intended.
            optimized_model = tfmot.quantization.keras.quantize_apply(
                annotated_model)

    return optimized_model