def image_text_matching(self, num_detections, detection_boxes,
                        detection_classes, detection_scores,
                        detection_features, caption_ids, caption_tag_ids,
                        caption_tag_features, caption_length):
  """Predicts the matching score of the given image-text pair.

  The detections and the caption are packed into a single BERT input of
  1 + max_detections + 1 + max_caption_len + 1 tokens, i.e.
  [CLS] detections [SEP] caption [SEP].

  Args:
    num_detections: A [batch] int tensor.
    detection_boxes: A [batch, max_detections, 4] float tensor.
    detection_classes: A [batch, max_detections] string tensor.
    detection_scores: A [batch, max_detections] float tensor.
    detection_features: A [batch, max_detections, dims] float tensor.
    caption_ids: A [batch, max_caption_len] int tensor.
    caption_tag_ids: A [batch, max_caption_len] int tensor.
    caption_tag_features: A [batch, max_caption_len, dims] float tensor.
    caption_length: A [batch] int tensor.

  Returns:
    pooled_output: A [batch, dims] float tensor, the pooled [CLS] feature.
    embedding_table: A [vocab_size, dims] float tensor.
  """
  (input_ids, input_masks, input_tag_masks,
   input_tag_features) = self.create_bert_input_tensors(
       num_detections, detection_boxes, detection_classes, detection_scores,
       detection_features, caption_ids, caption_tag_ids, caption_tag_features,
       caption_length)

  # Encode the packed image-text sequence with BERT.
  bert_model = bert_modeling.BertModel(self._bert_config,
                                       self._is_training,
                                       input_ids=input_ids,
                                       input_mask=input_masks,
                                       input_tag_mask=input_tag_masks,
                                       input_tag_feature=input_tag_features,
                                       scope='bert')
  return bert_model.get_pooled_output(), bert_model.get_embedding_table()
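# A minimal, self-contained sketch of the packed input mask that
# `create_bert_input_tensors` is assumed to produce, based on the
# 1 + max_detections + 1 + max_caption_len + 1 shape documented above.
# The helper name and the exact packing below are illustrative assumptions,
# not the actual implementation.
def _sketch_packed_input_mask(num_detections, max_detections, caption_length,
                              max_caption_len):
  """Builds a [batch, 1 + max_detections + 1 + max_caption_len + 1] mask."""
  batch = tf.shape(num_detections)[0]
  special = tf.ones([batch, 1], dtype=tf.bool)  # [CLS] / [SEP] positions.
  detection_masks = tf.sequence_mask(num_detections, maxlen=max_detections)
  caption_masks = tf.sequence_mask(caption_length, maxlen=max_caption_len)
  # [CLS] + detections + [SEP] + caption + [SEP].
  return tf.concat(
      [special, detection_masks, special, caption_masks, special], axis=1)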
def generate_adversarial_masks(self, choice_ids, choice_lengths,
                               question_lengths, labels, hard=True):
  """Adversarially samples shortcut-token masks using Gumbel-Softmax.

  A second BERT model scores every token of each answer choice, and a
  relaxed one-hot sample over the choice tokens (question tokens are
  excluded) selects the probable shortcut word.

  Returns:
    a_sample: The sampled per-token masks.
    choice_shortcut_logits: The per-token shortcut logits.
    sequence_output: The BERT sequence features of the choices.
    temperature: The Gumbel-Softmax temperature, a scalar tensor.
  """
  options = self._model_proto
  is_training = self._is_training

  batch_size = choice_ids.shape[0]
  max_choice_len = tf.shape(choice_ids)[-1]

  with tf.variable_scope('adversarial'):
    # Flatten the per-choice sequences into the batch dimension.
    input_ids = tf.reshape(choice_ids, [batch_size * NUM_CHOICES, -1])
    input_masks = tf.reshape(
        tf.sequence_mask(choice_lengths, maxlen=max_choice_len),
        [batch_size * NUM_CHOICES, -1])

    # Optionally condition the adversary on the ground-truth label: row 1 of
    # the table marks the correct choice, row 0 the distractors.
    label_embeddings = None
    if options.use_label_embedding:
      full_label_embeddings = tf.get_variable(
          name='label_embedding',
          shape=[2, self._bert2_config.hidden_size],
          initializer=bert_modeling.create_initializer(
              self._bert_config.initializer_range))
      one_hot_labels = tf.one_hot(labels, NUM_CHOICES, on_value=1, off_value=0)
      label_embeddings = tf.nn.embedding_lookup(full_label_embeddings,
                                                one_hot_labels)
      label_embeddings = tf.reshape(
          label_embeddings,
          [batch_size * NUM_CHOICES, 1, self._bert2_config.hidden_size])

    bert_model = bert_modeling.BertModel(self._bert2_config,
                                         self._is_training,
                                         input_ids=input_ids,
                                         input_mask=input_masks,
                                         input_tag_feature=label_embeddings,
                                         scope='bert')
    sequence_output = bert_model.get_sequence_output()
    sequence_output = tf.reshape(
        sequence_output,
        [batch_size, NUM_CHOICES, -1, self._bert2_config.hidden_size])

    # Per-token shortcut logits, scaled to control sampling sharpness.
    choice_shortcut_logits = slim.fully_connected(sequence_output,
                                                  num_outputs=1,
                                                  activation_fn=None,
                                                  scope='logits')
    choice_shortcut_logits = tf.multiply(
        options.adversarial_logits_scale,
        tf.squeeze(choice_shortcut_logits, -1))

    # Gumbel-Softmax to get the probable shortcut. Only tokens belonging to
    # the choice (not the question) are eligible.
    choice_masks = tf.logical_and(
        tf.sequence_mask(choice_lengths, maxlen=max_choice_len),
        tf.logical_not(
            tf.sequence_mask(question_lengths, maxlen=max_choice_len)))
    choice_masks = tf.cast(choice_masks, tf.float32)

    temperature = tf.Variable(options.temperature_init_value,
                              name='adversarial/temperature_var',
                              trainable=options.temperature_trainable,
                              dtype=tf.float32)
    temperature = tf.maximum(temperature, EPSILON)

    tf.summary.histogram('shortcut/logits', choice_shortcut_logits)
    tf.summary.scalar('metrics/temperature', temperature)

    # Push ineligible tokens to -INF before sampling.
    choice_shortcut_logits = choice_shortcut_logits - INF * (1.0 - choice_masks)
    tf.summary.histogram('shortcut/probas',
                         tf.nn.softmax(choice_shortcut_logits))

    a_sample = RelaxedOneHotCategorical(temperature,
                                        logits=choice_shortcut_logits,
                                        allow_nan_stats=False).sample()
    if hard:
      # Straight-through estimator: discrete one-hot in the forward pass,
      # gradients flow through the relaxed sample.
      k = tf.shape(choice_shortcut_logits)[-1]
      a_hard_sample = tf.cast(tf.one_hot(tf.argmax(a_sample, -1), k),
                              a_sample.dtype)
      a_sample = tf.stop_gradient(a_hard_sample - a_sample) + a_sample

  # Returns the mask sampled from the distribution.
  return a_sample, choice_shortcut_logits, sequence_output, temperature
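# A minimal, isolated sketch of the straight-through Gumbel-Softmax trick
# used in the `hard=True` branch above. The tensorflow_probability import
# path is an assumption about where RelaxedOneHotCategorical comes from;
# `logits` and `temperature` are illustrative.
def _sketch_straight_through_gumbel(logits, temperature=0.5):
  """Returns a one-hot forward sample whose gradient follows the relaxed one."""
  import tensorflow_probability as tfp  # Assumed provider of the distribution.
  relaxed = tfp.distributions.RelaxedOneHotCategorical(
      temperature, logits=logits).sample()
  k = tf.shape(logits)[-1]
  hard = tf.cast(tf.one_hot(tf.argmax(relaxed, -1), k), relaxed.dtype)
  # Forward pass emits the discrete `hard` sample; the backward pass
  # differentiates through `relaxed`, so the sampling stays trainable.
  return tf.stop_gradient(hard - relaxed) + relaxed
# Example (hypothetical): mask = _sketch_straight_through_gumbel(shortcut_logits)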