Esempio n. 1
0
def build_classification_model(
        input_specs: tf.keras.layers.InputSpec,
        model_config: classification_cfg.ImageClassificationModel,
        l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
        skip_logits_layer: bool = False,
        backbone: Optional[tf.keras.Model] = None) -> tf.keras.Model:
    """Builds the classification model.

    Args:
      input_specs: Spec of the input tensor. Forwarded both to the backbone
        factory (when a backbone is built here) and to the classification
        model.
      model_config: Image classification model config. Its `backbone`,
        `norm_activation`, `num_classes`, `dropout_rate`,
        `kernel_initializer` and `add_head_batch_norm` fields are read.
      l2_regularizer: Optional kernel regularizer. Passed to the backbone
        factory (when a backbone is built here) and to the classification
        model as `kernel_regularizer`.
      skip_logits_layer: Forwarded unchanged to `ClassificationModel`.
      backbone: Optional pre-built backbone. When None, a backbone is built
        from `model_config.backbone` via `backbones.factory.build_backbone`.

    Returns:
      The constructed `classification_model.ClassificationModel`.
    """
    norm_activation_config = model_config.norm_activation
    # Compare against None explicitly: a caller-provided backbone must never be
    # replaced just because the object happens to evaluate as falsy.
    if backbone is None:
        backbone = backbones.factory.build_backbone(
            input_specs=input_specs,
            backbone_config=model_config.backbone,
            norm_activation_config=norm_activation_config,
            l2_regularizer=l2_regularizer)

    model = classification_model.ClassificationModel(
        backbone=backbone,
        num_classes=model_config.num_classes,
        input_specs=input_specs,
        dropout_rate=model_config.dropout_rate,
        kernel_initializer=model_config.kernel_initializer,
        kernel_regularizer=l2_regularizer,
        add_head_batch_norm=model_config.add_head_batch_norm,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        skip_logits_layer=skip_logits_layer)
    return model
Esempio n. 2
0
  def test_sync_bn_multiple_devices(self, strategy, use_sync_bn):
    """Runs a ResNet-50 classifier forward pass under a distribution strategy.

    The backbone is built with `use_sync_bn` so both synchronized and
    per-replica batch norm are exercised on the devices `strategy` covers.
    """
    batch = np.random.rand(64, 128, 128, 3)

    tf.keras.backend.set_image_data_format('channels_last')

    with strategy.scope():
      classifier = classification_model.ClassificationModel(
          backbone=backbones.ResNet(model_id=50, use_sync_bn=use_sync_bn),
          num_classes=1000,
          dropout_rate=0.2)
      classifier(batch)
Esempio n. 3
0
  def test_serialize_deserialize(self):
    """Round-trips a classifier through get_config() / from_config()."""

    tf.keras.backend.set_image_data_format('channels_last')
    original = classification_model.ClassificationModel(
        backbone=backbones.ResNet(model_id=50), num_classes=1000)

    restored = classification_model.ClassificationModel.from_config(
        original.get_config())

    # The rebuilt model must also be serializable to JSON.
    restored.to_json()

    # A faithful round trip reproduces the original config exactly.
    self.assertAllEqual(original.get_config(), restored.get_config())
Esempio n. 4
0
  def test_mobilenet_network_creation(self, mobilenet_model_id,
                                      filter_size_scale):
    """Checks that a MobileNet-backed classifier emits [batch, classes] logits."""
    batch_size = 2
    images = np.random.rand(batch_size, 224, 224, 3)

    tf.keras.backend.set_image_data_format('channels_last')

    mobilenet = backbones.MobileNet(
        model_id=mobilenet_model_id, filter_size_scale=filter_size_scale)

    num_classes = 1001
    classifier = classification_model.ClassificationModel(
        backbone=mobilenet, num_classes=num_classes, dropout_rate=0.2)

    predictions = classifier(images)
    self.assertAllEqual([batch_size, num_classes], predictions.numpy().shape)
Esempio n. 5
0
  def test_data_format_gpu(self, strategy, data_format, input_dim):
    """Test for different data formats on GPU devices.

    Builds a ResNet-50 classifier for `data_format` and runs a forward pass on
    a random batch laid out in that format. The global Keras image data format
    is restored when the test finishes.
    """
    if data_format == 'channels_last':
      inputs = np.random.rand(2, 128, 128, input_dim)
    else:
      inputs = np.random.rand(2, input_dim, 128, 128)
    input_specs = tf.keras.layers.InputSpec(shape=inputs.shape)

    # set_image_data_format mutates process-wide Keras state. Register a
    # cleanup that restores the current value so a 'channels_first' run does
    # not leak into tests that run afterwards.
    self.addCleanup(tf.keras.backend.set_image_data_format,
                    tf.keras.backend.image_data_format())
    tf.keras.backend.set_image_data_format(data_format)

    with strategy.scope():
      backbone = backbones.ResNet(model_id=50, input_specs=input_specs)

      model = classification_model.ClassificationModel(
          backbone=backbone,
          num_classes=1000,
          input_specs=input_specs,
      )
      _ = model(inputs)
Esempio n. 6
0
  def test_revnet_network_creation(self):
    """Checks parameter counts and logits shape of a RevNet-56 classifier."""
    images = np.random.rand(2, 224, 224, 3)

    tf.keras.backend.set_image_data_format('channels_last')

    revnet = backbones.RevNet(model_id=56)
    self.assertEqual(revnet.count_params(), 19473792)

    num_classes = 1000
    classifier = classification_model.ClassificationModel(
        backbone=revnet,
        num_classes=num_classes,
        dropout_rate=0.2,
        add_head_batch_norm=True)
    self.assertEqual(classifier.count_params(), 22816104)

    predictions = classifier(images)
    self.assertAllEqual([2, num_classes], predictions.numpy().shape)
Esempio n. 7
0
  def test_resnet_network_creation(
      self, input_size, resnet_model_id, activation):
    """Checks parameter counts and logits shape of a ResNet classifier.

    NOTE(review): the expected parameter counts below are those of a
    ResNet-50 backbone with a 1000-way head; presumably the parameterized
    cases only use model_id=50 — confirm against the decorator.
    """
    images = np.random.rand(2, input_size, input_size, 3)

    tf.keras.backend.set_image_data_format('channels_last')

    resnet = backbones.ResNet(model_id=resnet_model_id, activation=activation)
    self.assertEqual(resnet.count_params(), 23561152)

    num_classes = 1000
    classifier = classification_model.ClassificationModel(
        backbone=resnet, num_classes=num_classes, dropout_rate=0.2)
    self.assertEqual(classifier.count_params(), 25610152)

    predictions = classifier(images)
    self.assertAllEqual([2, num_classes], predictions.numpy().shape)
Esempio n. 8
0
def _quantize_apply_with_config(annotated_model, quantization,
                                default_scheme_fn=None):
  """Runs tfmot `quantize_apply` with the scheme selected by `quantization`.

  Args:
    annotated_model: A model annotated with tfmot `quantize_annotate_*`.
    quantization: The Quantization config. `change_num_bits`,
      `num_bits_weight` and `num_bits_activation` are read.
    default_scheme_fn: Optional zero-argument callable that builds the scheme
      used when `quantization.change_num_bits` is False. When None,
      `quantize_apply` is called without a `scheme` argument, so tfmot's own
      default applies.

  Returns:
    The model returned by `tfmot.quantization.keras.quantize_apply`.
  """
  if quantization.change_num_bits:
    return tfmot.quantization.keras.quantize_apply(
        annotated_model,
        scheme=n_bit_schemes.DefaultNBitQuantizeScheme(
            num_bits_weight=quantization.num_bits_weight,
            num_bits_activation=quantization.num_bits_activation))
  if default_scheme_fn is None:
    return tfmot.quantization.keras.quantize_apply(annotated_model)
  # The default scheme is only instantiated on this path, as before.
  return tfmot.quantization.keras.quantize_apply(
      annotated_model, scheme=default_scheme_fn())


def build_qat_classification_model(
    model: tf.keras.Model,
    quantization: common.Quantization,
    input_specs: tf.keras.layers.InputSpec,
    model_config: configs.image_classification.ImageClassificationModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Apply model optimization techniques.

  Optionally restores `model` from `quantization.pretrained_original_checkpoint`.
  It then quantizes its backbone and rebuilds a `ClassificationModel` around
  the quantized backbone, copying the remaining (head) weights from `model`.
  Finally it quantizes the Dense, Dropout and GlobalAveragePooling2D layers of
  the rebuilt model.

  Args:
    model: The model applying model optimization techniques.
    quantization: The Quantization config.
    input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
    model_config: The model config.
    l2_regularizer: tf.keras.regularizers.Regularizer object. Defaults to None.

  Returns:
    model: The model that applied optimization techniques.
  """
  original_checkpoint = quantization.pretrained_original_checkpoint
  if original_checkpoint:
    ckpt = tf.train.Checkpoint(
        model=model,
        **model.checkpoint_items)
    status = ckpt.read(original_checkpoint)
    status.expect_partial().assert_existing_objects_matched()

  # Lets (de)serialization inside tfmot resolve the 'L2' regularizer name.
  scope_dict = {
      'L2': tf.keras.regularizers.l2,
  }
  with tfmot.quantization.keras.quantize_scope(scope_dict):
    annotated_backbone = tfmot.quantization.keras.quantize_annotate_model(
        model.backbone)
    backbone = _quantize_apply_with_config(
        annotated_backbone, quantization,
        default_scheme_fn=schemes.Default8BitQuantizeScheme)

  norm_activation_config = model_config.norm_activation
  backbone_optimized_model = classification_model.ClassificationModel(
      backbone=backbone,
      num_classes=model_config.num_classes,
      input_specs=input_specs,
      dropout_rate=model_config.dropout_rate,
      kernel_regularizer=l2_regularizer,
      add_head_batch_norm=model_config.add_head_batch_norm,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon)
  # Copy every non-backbone layer's weights from the float model. The backbone
  # is skipped because it has been replaced by its quantized counterpart.
  # Layers are paired positionally, so both models must list them in the
  # same order.
  for from_layer, to_layer in zip(
      model.layers, backbone_optimized_model.layers):
    if from_layer is not model.backbone:
      to_layer.set_weights(from_layer.get_weights())

  with tfmot.quantization.keras.quantize_scope(scope_dict):
    def apply_quantization_to_dense(layer):
      """Annotates head layers for quantization; leaves others unchanged."""
      if isinstance(layer, (tf.keras.layers.Dense,
                            tf.keras.layers.Dropout,
                            tf.keras.layers.GlobalAveragePooling2D)):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
      return layer

    annotated_model = tf.keras.models.clone_model(
        backbone_optimized_model,
        clone_function=apply_quantization_to_dense,
    )

    # No default scheme here: without n-bit quantization the head uses
    # tfmot's default `quantize_apply` scheme.
    optimized_model = _quantize_apply_with_config(annotated_model, quantization)

  return optimized_model