Code example #1
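    # Assumed context (not shown in this snippet): `bert_modeling` is the
    # repo's TensorFlow BERT implementation providing `BertModel`.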
    def image_text_matching(self, num_detections, detection_boxes,
                            detection_classes, detection_scores,
                            detection_features, caption_ids, caption_tag_ids,
                            caption_tag_features, caption_length):
        """Predicts the matching score of the given image-text pair.

    Args:
      num_detections: A [batch] int tensor.
      detection_classes: A [batch, max_detections] string tensor.
      detection_features: A [batch, max_detections, dims] float tensor.
      caption_ids: A [batch, max_caption_len] int tensor.
      caption_tag_ids: A [batch, max_caption_len] int tensor.
      caption_tag_features: A [batch, max_caption_len, dims] float tensor.
      caption_length: A [batch] int tensor.

    Returns:
      bert_feature: A [batch, 1 + max_detections + 1 + max_caption_len + 1, dims] float tensor.
      embedding_table: A [vocab_size, dims] float tensor.
    """
        (input_ids, input_masks, input_tag_masks,
         input_tag_features) = self.create_bert_input_tensors(
             num_detections, detection_boxes, detection_classes,
             detection_scores, detection_features, caption_ids,
             caption_tag_ids, caption_tag_features, caption_length)
        bert_model = bert_modeling.BertModel(
            self._bert_config,
            self._is_training,
            input_ids=input_ids,
            input_mask=input_masks,
            input_tag_mask=input_tag_masks,
            input_tag_feature=input_tag_features,
            scope='bert')
        return bert_model.get_pooled_output(), bert_model.get_embedding_table()
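The pooled output above comes from a single BERT pass over a fused image-text sequence. Below is a shape-only NumPy sketch of the presumed input layout, inferred from the 1 + max_detections + 1 + max_caption_len + 1 sequence length documented in the original docstring; the exact [CLS]/[SEP] placement and the internals of `create_bert_input_tensors` are assumptions, not shown in the snippet.

    # Shape-only sketch of the presumed BERT input layout (an assumption):
    # [CLS] + detections + [SEP] + caption + [SEP].
    import numpy as np

    batch, max_detections, max_caption_len, dims = 2, 10, 16, 768

    cls_tok = np.zeros([batch, 1, dims])                  # [CLS] slot
    det_feats = np.zeros([batch, max_detections, dims])   # detection features
    sep_tok = np.zeros([batch, 1, dims])                  # [SEP] slot
    cap_feats = np.zeros([batch, max_caption_len, dims])  # caption features
    seq = np.concatenate([cls_tok, det_feats, sep_tok, cap_feats, sep_tok],
                         axis=1)
    assert seq.shape[1] == 1 + max_detections + 1 + max_caption_len + 1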
Code example #2
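    # Assumed context (not shown in this snippet): `tf` is TensorFlow 1.x,
    # `slim` is tf.contrib.slim, `bert_modeling` is the repo's BERT
    # implementation, `RelaxedOneHotCategorical` presumably comes from
    # tensorflow_probability (or tf.contrib.distributions), and NUM_CHOICES,
    # EPSILON, and INF are module-level constants.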
    def generate_adversarial_masks(self,
                                   choice_ids,
                                   choice_lengths,
                                   question_lengths,
                                   labels,
                                   hard=True):
        """Masked language modeling."""
        options = self._model_proto
        is_training = self._is_training

        batch_size = choice_ids.shape[0]
        max_choice_len = tf.shape(choice_ids)[-1]

        with tf.variable_scope('adversarial'):
            input_ids = tf.reshape(choice_ids, [batch_size * NUM_CHOICES, -1])
            input_masks = tf.reshape(
                tf.sequence_mask(choice_lengths, maxlen=max_choice_len),
                [batch_size * NUM_CHOICES, -1])

            label_embeddings = None
            if options.use_label_embedding:
                full_label_embeddings = tf.get_variable(
                    name='label_embedding',
                    shape=[2, self._bert2_config.hidden_size],
                    initializer=bert_modeling.create_initializer(
                        self._bert2_config.initializer_range))
                one_hot_labels = tf.one_hot(labels,
                                            NUM_CHOICES,
                                            on_value=1,
                                            off_value=0)
                label_embeddings = tf.nn.embedding_lookup(
                    full_label_embeddings, one_hot_labels)
                label_embeddings = tf.reshape(label_embeddings, [
                    batch_size * NUM_CHOICES, 1, self._bert2_config.hidden_size
                ])

            bert_model = bert_modeling.BertModel(
                self._bert2_config,
                self._is_training,
                input_ids=input_ids,
                input_mask=input_masks,
                input_tag_feature=label_embeddings,
                scope='bert')
            sequence_output = bert_model.get_sequence_output()
            sequence_output = tf.reshape(
                sequence_output,
                [batch_size, NUM_CHOICES, -1, self._bert2_config.hidden_size])

            choice_shortcut_logits = slim.fully_connected(sequence_output,
                                                          num_outputs=1,
                                                          activation_fn=None,
                                                          scope='logits')
            choice_shortcut_logits = tf.multiply(
                options.adversarial_logits_scale,
                tf.squeeze(choice_shortcut_logits, -1))

        # Mask that keeps answer tokens only: positions inside the choice
        # sequence but past the question prefix. Gumbel-Softmax then samples
        # the probable shortcut token from these positions.
        choice_masks = tf.logical_and(
            tf.sequence_mask(choice_lengths, maxlen=max_choice_len),
            tf.logical_not(
                tf.sequence_mask(question_lengths, maxlen=max_choice_len)))
        choice_masks = tf.cast(choice_masks, tf.float32)

        temperature = tf.Variable(options.temperature_init_value,
                                  name='adversarial/temperature_var',
                                  trainable=options.temperature_trainable,
                                  dtype=tf.float32)
        temperature = tf.maximum(temperature, EPSILON)

        tf.summary.histogram('shortcut/logits', choice_shortcut_logits)
        tf.summary.scalar('metrics/temperature', temperature)

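        # Push logits of non-answer (question/padding) positions toward -INF
        # so they cannot be sampled as the shortcut.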
        choice_shortcut_logits = choice_shortcut_logits - \
            INF * (1.0 - choice_masks)
        tf.summary.histogram('shortcut/probas',
                             tf.nn.softmax(choice_shortcut_logits))

        a_sample = RelaxedOneHotCategorical(temperature,
                                            logits=choice_shortcut_logits,
                                            allow_nan_stats=False).sample()

        if hard:
            k = tf.shape(choice_shortcut_logits)[-1]
            a_hard_sample = tf.cast(tf.one_hot(tf.argmax(a_sample, -1), k),
                                    a_sample.dtype)
            a_sample = tf.stop_gradient(a_hard_sample - a_sample) + a_sample

        # Returns the mask sampled from the distribution.
        return a_sample, choice_shortcut_logits, sequence_output, temperature
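The `hard=True` branch is the straight-through Gumbel-Softmax trick: the forward pass uses the discrete one-hot argmax of the relaxed sample, while `tf.stop_gradient` makes gradients flow through the soft sample. A self-contained sketch of the same trick with TensorFlow Probability follows; the toy logits and temperature are illustrative, not values from the repo.

    # Straight-through Gumbel-Softmax, isolated from the model above.
    import tensorflow as tf
    import tensorflow_probability as tfp

    logits = tf.constant([[2.0, 0.5, -1.0]])  # toy shortcut logits
    temperature = 0.5

    # Soft, differentiable sample from the relaxed categorical distribution.
    soft = tfp.distributions.RelaxedOneHotCategorical(
        temperature, logits=logits).sample()

    # Discrete one-hot value in the forward pass; the stop_gradient identity
    # means gradients still follow the soft sample.
    k = tf.shape(logits)[-1]
    hard = tf.one_hot(tf.argmax(soft, axis=-1), k, dtype=soft.dtype)
    sample = tf.stop_gradient(hard - soft) + soft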