Example #1
    def call(self, inputs, input_mask=None, segment_ids=None, training=False):

        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            input_mask = inputs[1] if len(inputs) > 1 else input_mask
            segment_ids = inputs[2] if len(inputs) > 2 else segment_ids
        else:
            input_ids = inputs

        input_shape = layers.get_shape_list(input_ids)
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if input_mask is None:
            input_mask = tf.ones(shape=[batch_size, seq_length],
                                 dtype=tf.int32)

        if segment_ids is None:
            segment_ids = tf.zeros(shape=[batch_size, seq_length],
                                   dtype=tf.int32)

        embedding_output = self.embeddings([input_ids, segment_ids],
                                           training=training)
        attention_mask = layers.get_attn_mask_bert(input_ids, input_mask)
        encoder_outputs = self.encoder([embedding_output, attention_mask],
                                       training=training)
        pooled_output = self.pooler(encoder_outputs[0][-1][:, 0])
        outputs = (encoder_outputs[0][-1], pooled_output)
        return outputs
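A minimal usage sketch (not part of the original source) of how a backbone with this call signature might be invoked; the already-built `model` instance is assumed:

import tensorflow as tf

# Hypothetical usage: `model` is assumed to be an instance of the backbone
# class above (e.g. loaded from the model zoo).
batch_size, seq_length = 2, 16
input_ids = tf.zeros([batch_size, seq_length], dtype=tf.int32)
input_mask = tf.ones([batch_size, seq_length], dtype=tf.int32)
segment_ids = tf.zeros([batch_size, seq_length], dtype=tf.int32)

# Inputs may be passed as a list; a missing mask defaults to all ones and
# missing segment ids default to all zeros inside call().
sequence_output, pooled_output = model([input_ids, input_mask, segment_ids],
                                       training=False)
# sequence_output: [batch_size, seq_length, hidden_size]
# pooled_output:   [batch_size, hidden_size]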
Example #2
    def call(self, inputs,
             input_mask=None,
             segment_ids=None,
             history_answer_marker=None,
             training=False):

        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            input_mask = inputs[1] if len(inputs) > 1 else input_mask
            segment_ids = inputs[2] if len(inputs) > 2 else segment_ids
            history_answer_marker = inputs[3] if len(inputs) > 3 else history_answer_marker
        else:
            input_ids = inputs

        input_shape = layers.get_shape_list(input_ids)
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if input_mask is None:
            input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

        if segment_ids is None:
            segment_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

        if history_answer_marker is None:
            history_answer_marker = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

        with tf.variable_scope("embeddings"):
            # Perform embedding lookup on the word ids.
            (embedding_output, embedding_table) = embedding_lookup(
                input_ids=input_ids,
                vocab_size=self.config.vocab_size,
                embedding_size=self.config.hidden_size,
                initializer_range=self.config.initializer_range,
                word_embedding_name="word_embeddings",
                use_one_hot_embeddings=False)

            # Add positional embeddings and token type embeddings, then layer
            # normalize and perform dropout.
            embedding_output = embedding_postprocessor(
                input_tensor=embedding_output,
                use_token_type=True,
                token_type_ids=segment_ids,
                token_type_vocab_size=self.config.type_vocab_size,
                token_type_embedding_name="token_type_embeddings",
                use_position_embeddings=True,
                position_embedding_name="position_embeddings",
                initializer_range=self.config.initializer_range,
                max_position_embeddings=self.config.max_position_embeddings,
                use_history_answer_embedding=True,
                history_answer_marker=history_answer_marker,
                history_answer_embedding_vocab_size=2,
                history_answer_embedding_name='history_answer_embedding',
                dropout_prob=self.config.hidden_dropout_prob)

        attention_mask = layers.get_attn_mask_bert(input_ids, input_mask)
        encoder_outputs = self.encoder([embedding_output, attention_mask], training=training)
        pooled_output = self.pooler(encoder_outputs[0][-1][:, 0])
        outputs = (encoder_outputs[0][-1], pooled_output)
        return outputs
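The extra `history_answer_marker` input is a per-token tensor with the same shape as `input_ids` and values in {0, 1} (it is embedded with a vocabulary of size 2). The sketch below uses hypothetical values, not data from the source, to illustrate what such a marker could look like for one sequence:

import tensorflow as tf

# Hypothetical marker for one sequence of length 8: positions 3-5 are tokens
# that appeared in a previous answer (value 1), all other positions are 0.
history_answer_marker = tf.constant([[0, 0, 0, 1, 1, 1, 0, 0]], dtype=tf.int32)

# It is passed as the fourth positional input; when omitted, call() falls
# back to an all-zero marker, i.e. "no history answer" at every token.
# outputs = model([input_ids, input_mask, segment_ids, history_answer_marker])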
Example #3
    def build_logits(self, features, mode=None):
        """ Building graph of KD Student

        Args:
            features (`OrderedDict`): A dict mapping raw input to tensors
            mode (`bool): tell the model whether it is under training
        Returns:
            logits (`list`): logits for all the layers, list of shape of [None, num_labels]
            label_ids (`Tensor`): label_ids, shape of [None]
        """
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        preprocessor = preprocessors.get_preprocessor(
            self.config.pretrain_model_name_or_path,
            user_defined_config=self.config)
        bert_backbone = model_zoo.get_pretrained_model(
            self.config.pretrain_model_name_or_path)

        if mode != tf.estimator.ModeKeys.PREDICT:
            teacher_logits, input_ids, input_mask, segment_ids, label_ids = preprocessor(
                features)
        else:
            teacher_logits, input_ids, input_mask, segment_ids = preprocessor(
                features)
            label_ids = None

        teacher_n_layers = int(
            teacher_logits.shape[1]) // self.config.num_labels - 1
        self.teacher_logits = [
            teacher_logits[:, i * self.config.num_labels:(i + 1) *
                           self.config.num_labels]
            for i in range(teacher_n_layers + 1)
        ]

        if self.config.train_probes:
            bert_model = bert_backbone.bert
            embedding_output = bert_model.embeddings([input_ids, segment_ids],
                                                     training=is_training)
            attention_mask = layers.get_attn_mask_bert(input_ids, input_mask)
            all_hidden_outputs, all_att_outputs = bert_model.encoder(
                [embedding_output, attention_mask], training=is_training)

            # Get teacher Probes
            logits = layers.HiddenLayerProbes(
                self.config.num_labels,
                kernel_initializer=layers.get_initializer(0.02),
                name="probes")([embedding_output, all_hidden_outputs])
        else:
            _, pooled_output = bert_backbone(
                [input_ids, input_mask, segment_ids], mode=mode)
            pooled_output = tf.layers.dropout(pooled_output,
                                              rate=self.config.dropout_rate,
                                              training=is_training)
            logits = layers.Dense(
                self.config.num_labels,
                kernel_initializer=layers.get_initializer(0.02),
                name='app/ez_dense')(pooled_output)
            logits = [logits]

        return logits, label_ids
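The slicing above assumes the teacher's per-layer probe logits are packed side by side along the last axis of `teacher_logits`. A small arithmetic sketch with hypothetical sizes shows how the slices line up:

import numpy as np

# Hypothetical sizes: 2 labels, probes on the embedding output plus 12
# encoder layers, so the packed teacher logits have 13 * 2 = 26 columns.
num_labels = 2
packed = np.arange(26).reshape(1, 26)          # stands in for teacher_logits

teacher_n_layers = packed.shape[1] // num_labels - 1    # 12
per_layer = [packed[:, i * num_labels:(i + 1) * num_labels]
             for i in range(teacher_n_layers + 1)]      # 13 slices
assert len(per_layer) == teacher_n_layers + 1
assert per_layer[0].shape == (1, num_labels)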
Example #4
    def build_logits(self, features, mode=None):
        """ Building graph of KD Teacher

        Args:
            features (`OrderedDict`): A dict mapping raw input to tensors
            mode (`bool): tell the model whether it is under training
        Returns:
            logits (`list`): logits for all the layers, list of shape of [None, num_labels]
            label_ids (`Tensor`): label_ids, shape of [None]
        """
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        preprocessor = preprocessors.get_preprocessor(
            self.config.pretrain_model_name_or_path,
            user_defined_config=self.config)
        bert_backbone = model_zoo.get_pretrained_model(
            self.config.pretrain_model_name_or_path)

        # Serialize raw text to get input tensors
        input_ids, input_mask, segment_ids, label_id = preprocessor(features)

        if self.config.train_probes:
            # Get BERT all hidden states
            bert_model = bert_backbone.bert
            embedding_output = bert_model.embeddings([input_ids, segment_ids],
                                                     training=is_training)
            attention_mask = layers.get_attn_mask_bert(input_ids, input_mask)
            all_hidden_outputs, all_att_outputs = bert_model.encoder(
                [embedding_output, attention_mask], training=is_training)

            # Get teacher Probes
            logits = layers.HiddenLayerProbes(
                self.config.num_labels,
                kernel_initializer=layers.get_initializer(0.02),
                name="probes")([embedding_output, all_hidden_outputs])
            self.tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           "probes/")
        else:
            _, pooled_output = bert_backbone(
                [input_ids, input_mask, segment_ids], mode=mode)
            pooled_output = tf.layers.dropout(pooled_output,
                                              rate=self.config.dropout_rate,
                                              training=is_training)
            logits = layers.Dense(
                self.config.num_labels,
                kernel_initializer=layers.get_initializer(0.02),
                name='app/ez_dense')(pooled_output)
            logits = [logits]

        if mode == tf.estimator.ModeKeys.PREDICT:
            return {
                "input_ids": input_ids,
                "input_mask": input_mask,
                "segment_ids": segment_ids,
                "label_id": label_id,
                "logits": tf.concat(logits, axis=-1)
            }
        else:
            return logits, label_id
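In PREDICT mode the per-layer probe logits are concatenated along the last axis; that packed tensor is what the student's `build_logits` later unpacks (see the slicing sketch above). A shape-only sketch under assumed sizes:

import tensorflow as tf

# Hypothetical sizes: batch of 4, 2 labels, 13 probed layers
# (embedding output plus 12 encoder layers).
logits = [tf.zeros([4, 2]) for _ in range(13)]

packed = tf.concat(logits, axis=-1)   # shape [4, 26]
# The student later recovers the 13 per-layer slices from these 26 columns.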
Example #5
    def build_logits(self, features, mode=None):

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        preprocessor = preprocessors.get_preprocessor(
            self.pretrain_model_name_or_path, user_defined_config=self.config)
        bert_backbone = model_zoo.get_pretrained_model(
            self.config.pretrain_model_name_or_path)
        dense = layers.Dense(self.num_labels,
                             kernel_initializer=layers.get_initializer(0.02),
                             name='dense')

        input_ids, input_mask, segment_ids, label_ids, domains, weights = preprocessor(
            features)

        self.domains = domains
        self.weights = weights
        hidden_size = bert_backbone.config.hidden_size
        self.domain_logits = dict()

        bert_model = bert_backbone.bert
        embedding_output = bert_model.embeddings([input_ids, segment_ids],
                                                 training=is_training)
        attention_mask = layers.get_attn_mask_bert(input_ids, input_mask)
        encoder_outputs = bert_model.encoder(
            [embedding_output, attention_mask], training=is_training)
        encoder_outputs = encoder_outputs[0]
        pooled_output = bert_model.pooler(encoder_outputs[-1][:, 0])

        if mode == tf.estimator.ModeKeys.TRAIN:
            pooled_output = tf.nn.dropout(pooled_output, keep_prob=0.9)

        with tf.variable_scope("mft", reuse=tf.AUTO_REUSE):
            # add domain network; num_domains and layer_indexes are defined
            # elsewhere in the original module (not shown in this snippet)
            logits = dense(pooled_output)
            domains = tf.squeeze(domains)

            domain_embedded_matrix = tf.get_variable(
                "domain_projection", [num_domains, hidden_size],
                initializer=tf.truncated_normal_initializer(stddev=0.02))
            domain_embedded = tf.nn.embedding_lookup(domain_embedded_matrix,
                                                     domains)

            for layer_index in layer_indexes:
                content_tensor = tf.reduce_mean(encoder_outputs[layer_index],
                                                axis=1)
                content_tensor_with_domains = domain_embedded + content_tensor

                domain_weights = tf.get_variable(
                    "domain_weights", [num_domains, hidden_size],
                    initializer=tf.truncated_normal_initializer(stddev=0.02))
                domain_bias = tf.get_variable(
                    "domain_bias", [num_domains],
                    initializer=tf.zeros_initializer())

                current_domain_logits = tf.matmul(content_tensor_with_domains,
                                                  domain_weights,
                                                  transpose_b=True)
                current_domain_logits = tf.nn.bias_add(current_domain_logits,
                                                       domain_bias)

                self.domain_logits["domain_logits_" +
                                   str(layer_index)] = current_domain_logits
        return logits, label_ids
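For the domain branch, a shape sketch (hypothetical sizes; `num_domains` and `layer_indexes` come from the surrounding module) of how `content_tensor_with_domains` is projected onto per-domain logits:

import tensorflow as tf

# Hypothetical sizes: batch of 4, hidden size 8, 3 domains.
content_tensor_with_domains = tf.zeros([4, 8])
domain_weights = tf.zeros([3, 8])
domain_bias = tf.zeros([3])

domain_logits = tf.nn.bias_add(
    tf.matmul(content_tensor_with_domains, domain_weights, transpose_b=True),
    domain_bias)
# domain_logits: [4, 3], one score per domain for each example in the batch.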