def __init__(self,
             delg_global_features=True,
             delg_gem_power=3.0,
             delg_embedding_layer_dim=2048,
             block3_strides=True,
             iou=1.0):
    """Sets up the DELG (or DELF) backbone used for extraction.

    Args:
      delg_global_features: If True, build a DELG model with a DELG-like
        global feature head; otherwise build a plain DELF model.
      delg_gem_power: Generalized Mean (GeM) pooling power for the DELG model.
        Only used when `delg_global_features` is True.
      delg_embedding_layer_dim: Size of the FC whitening (embedding) layer.
        Only used when `delg_global_features` is True.
      block3_strides: bool, whether to add strides to the output of block3.
      iou: IOU for non-max suppression.
    """
    # Extra block3 strides double the overall stride of the feature map.
    if block3_strides:
        self._stride_factor = 2.0
    else:
        self._stride_factor = 1.0
    self._iou = iou

    # Backbone without a global head: plain DELF.
    if not delg_global_features:
        self._model = delf_model.Delf(block3_strides=block3_strides, name='DELF')
        return

    # Backbone with a DELG global feature head.
    self._model = delg_model.Delg(
        block3_strides=block3_strides,
        name='DELG',
        gem_power=delg_gem_power,
        embedding_layer_dim=delg_embedding_layer_dim,
    )
def test_build_model(self, block3_strides):
    """Checks output shapes of the DELG backbone and attention classifiers.

    Builds a DELG model with dimensionality reduction, runs the backbone and
    the attention branch on random images, and verifies the prelogit and
    logit shapes produced by both classifier heads.
    """
    image_size = 321
    num_classes = 1000
    batch_size = 2
    input_shape = (batch_size, image_size, image_size, 3)

    model = delg_model.Delg(
        block3_strides=block3_strides, use_dim_reduction=True)
    model.init_classifiers(num_classes)

    images = tf.random.uniform(input_shape, minval=-1.0, maxval=1.0, seed=0)
    # For integer dtypes `maxval` is exclusive, so pass `num_classes` (not
    # `num_classes - 1`) to let the sampled labels cover every class id.
    labels = tf.random.uniform((batch_size,),
                               minval=0,
                               maxval=model.num_classes,
                               dtype=tf.int64)

    # Global branch: backbone prelogits -> descriptor classifier.
    blocks = {}
    desc_prelogits = model.backbone(
        images, intermediates_dict=blocks, training=False)
    desc_logits = model.desc_classification(desc_prelogits, labels)
    self.assertAllEqual(desc_prelogits.shape, (batch_size, 2048))
    self.assertAllEqual(desc_logits.shape, (batch_size, num_classes))

    # Local branch: attention on the block3 intermediate features.
    features = blocks['block3']
    attn_prelogits, _, _ = model.attention(features)
    attn_logits = model.attn_classification(attn_prelogits)
    self.assertAllEqual(attn_prelogits.shape, (batch_size, 1024))
    self.assertAllEqual(attn_logits.shape, (batch_size, num_classes))
def test_forward_pass(self, block3_strides):
    """Runs a complete DELG forward pass and checks every output shape."""
    batch_size = 2
    image_size = 321
    num_classes = 1000
    local_feature_dim = 64
    input_shape = (batch_size, image_size, image_size, 3)

    # 16 is the resnet50 reduction factor; block3 strides halve the map again.
    # For positive ints, (n // 16) // 2 == n // 32, so one division suffices.
    reduction = 32 if block3_strides else 16
    feature_map_size = image_size // reduction

    model = delg_model.Delg(
        block3_strides=block3_strides,
        use_dim_reduction=True,
        reduced_dimension=local_feature_dim,
    )
    model.init_classifiers(num_classes)

    images = tf.random.uniform(input_shape, minval=-1.0, maxval=1.0, seed=0)

    # Full forward pass: global feature, attention scores, local features.
    global_feature, attn_scores, local_features = model.build_call(images)

    spatial = (batch_size, feature_map_size, feature_map_size)
    self.assertAllEqual(global_feature.shape, (batch_size, 2048))
    self.assertAllEqual(attn_scores.shape, spatial + (1,))
    self.assertAllEqual(local_features.shape, spatial + (local_feature_dim,))
def create_model(num_classes):
    """Builds the flag-selected DELF/DELG model and initializes its classifiers.

    Args:
      num_classes: Number of classes passed to `init_classifiers`.

    Returns:
      The constructed model with classifiers initialized.
    """
    if not FLAGS.delg_global_features:
        model = delf_model.Delf(block3_strides=FLAGS.block3_strides, name='DELF')
    else:
        model = delg_model.Delg(
            block3_strides=FLAGS.block3_strides,
            name='DELG',
            gem_power=FLAGS.delg_gem_power,
            embedding_layer_dim=FLAGS.delg_embedding_layer_dim,
            scale_factor_init=FLAGS.delg_scale_factor_init,
            arcface_margin=FLAGS.delg_arcface_margin,
        )
    model.init_classifiers(num_classes)
    return model
def test_train_step(self, block3_strides):
    """Runs one DELG training step combining global, attention and
    reconstruction losses, with gradient clipping and an SGD update."""
    image_size = 321
    num_classes = 1000
    batch_size = 2
    clip_val = 10.0
    input_shape = (batch_size, image_size, image_size, 3)

    model = delg_model.Delg(
        block3_strides=block3_strides, use_dim_reduction=True)
    model.init_classifiers(num_classes)

    optimizer = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9)

    images = tf.random.uniform(input_shape, minval=0.0, maxval=1.0, seed=0)
    # For integer dtypes `maxval` is exclusive, so pass `num_classes` (not
    # `num_classes - 1`) to let the sampled labels cover every class id.
    labels = tf.random.uniform((batch_size,),
                               minval=0,
                               maxval=model.num_classes,
                               dtype=tf.int64)

    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def compute_loss(labels, predictions):
        # Unreduced per-example losses, averaged over the global batch size.
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(
            per_example_loss, global_batch_size=batch_size)

    with tf.GradientTape() as gradient_tape:
        (desc_prelogits, attn_prelogits, _, backbone_blocks,
         dim_expanded_features, _) = model.global_and_local_forward_pass(images)
        # Global loss: apply the descriptor classifier.
        desc_logits = model.desc_classification(desc_prelogits, labels)
        desc_loss = compute_loss(labels, desc_logits)

        # Attention loss: apply the attention block classifier.
        attn_logits = model.attn_classification(attn_prelogits)
        attn_loss = compute_loss(labels, attn_logits)

        # Reconstruction loss between the dimension-expanded features and the
        # backbone's block3 output; stop_gradient keeps block3 as a fixed
        # target so this loss does not update the backbone through it.
        block3 = tf.stop_gradient(backbone_blocks['block3'])
        reconstruction_loss = tf.math.reduce_mean(
            tf.keras.losses.MSE(block3, dim_expanded_features))

        # Accumulate all losses and backpropagate through the descriptor,
        # attention and reconstruction branches together.
        total_loss = desc_loss + attn_loss + reconstruction_loss

    gradients = gradient_tape.gradient(total_loss, model.trainable_weights)
    clipped, _ = tf.clip_by_global_norm(gradients, clip_norm=clip_val)
    optimizer.apply_gradients(zip(clipped, model.trainable_weights))
def __init__(self,
             multi_scale_pool_type='None',
             normalize_global_descriptor=False,
             input_scales_tensor=None,
             delg_global_features=False,
             delg_gem_power=3.0,
             delg_embedding_layer_dim=2048):
    """Sets up the model used for global feature extraction.

    Args:
      multi_scale_pool_type: Type of multi-scale pooling to perform.
      normalize_global_descriptor: Whether to L2-normalize the global
        descriptor.
      input_scales_tensor: If None, the exported function to use is
        ExtractFeatures, where an "input_scales" input end-point is added to
        the exported model. Otherwise, this 1D tensor of floats is hard-coded
        as the desired input scales, to be used with
        ExtractFeaturesFixedScales.
      delg_global_features: Whether the model uses a DELG-like global feature
        head.
      delg_gem_power: Power for Generalized Mean pooling in the DELG model.
        Only used when `delg_global_features` is True.
      delg_embedding_layer_dim: Size of the FC whitening (embedding) layer.
        Only used when `delg_global_features` is True.
    """
    self._multi_scale_pool_type = multi_scale_pool_type
    self._normalize_global_descriptor = normalize_global_descriptor
    # With no hard-coded scales, store an empty list instead of None.
    self._input_scales_tensor = (
        [] if input_scales_tensor is None else input_scales_tensor)

    # Extraction backbone (no block3 strides): DELF unless a DELG global
    # head is requested.
    if not delg_global_features:
        self._model = delf_model.Delf(block3_strides=False, name='DELF')
        return
    self._model = delg_model.Delg(
        block3_strides=False,
        name='DELG',
        gem_power=delg_gem_power,
        embedding_layer_dim=delg_embedding_layer_dim,
    )
def create_model(num_classes):
    """Builds the flag-selected DELF/DELG model and initializes its classifiers.

    Both architectures receive the same block3-stride and local-feature
    dimensionality-reduction (autoencoder) settings; DELG additionally gets
    its GeM, embedding and ArcFace options.

    Args:
      num_classes: Number of classes passed to `init_classifiers`.

    Returns:
      The constructed model with classifiers initialized.
    """
    use_delg = FLAGS.delg_global_features

    # Options shared by both architectures.
    shared_kwargs = {
        'block3_strides': FLAGS.block3_strides,
        'use_dim_reduction': FLAGS.use_autoencoder,
        'reduced_dimension': FLAGS.autoencoder_dimensions,
        'dim_expand_channels': FLAGS.local_feature_map_channels,
    }

    if use_delg:
        model = delg_model.Delg(
            name='DELG',
            gem_power=FLAGS.delg_gem_power,
            embedding_layer_dim=FLAGS.delg_embedding_layer_dim,
            scale_factor_init=FLAGS.delg_scale_factor_init,
            arcface_margin=FLAGS.delg_arcface_margin,
            **shared_kwargs,
        )
    else:
        model = delf_model.Delf(name='DELF', **shared_kwargs)

    model.init_classifiers(num_classes)
    return model