def build_logits(self, features, mode=None):
        """ Building graph of KD Student

        Args:
            features (`OrderedDict`): A dict mapping raw input to tensors
            mode (`bool): tell the model whether it is under training
        Returns:
            logits (`list`): logits for all the layers, list of shape of [None, num_labels]
            label_ids (`Tensor`): label_ids, shape of [None]
        """
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        preprocessor = preprocessors.get_preprocessor(
            self.config.pretrain_model_name_or_path,
            user_defined_config=self.config)
        bert_backbone = model_zoo.get_pretrained_model(
            self.config.pretrain_model_name_or_path)

        if mode != tf.estimator.ModeKeys.PREDICT:
            teacher_logits, input_ids, input_mask, segment_ids, label_ids = preprocessor(
                features)
        else:
            teacher_logits, input_ids, input_mask, segment_ids = preprocessor(
                features)
            label_ids = None

        teacher_n_layers = int(
            teacher_logits.shape[1]) // self.config.num_labels - 1
        self.teacher_logits = [
            teacher_logits[:, i * self.config.num_labels:(i + 1) *
                           self.config.num_labels]
            for i in range(teacher_n_layers + 1)
        ]

        if self.config.train_probes:
            bert_model = bert_backbone.bert
            embedding_output = bert_model.embeddings([input_ids, segment_ids],
                                                     training=is_training)
            attention_mask = layers.get_attn_mask_bert(input_ids, input_mask)
            all_hidden_outputs, all_att_outputs = bert_model.encoder(
                [embedding_output, attention_mask], training=is_training)

            # Get student probes
            logits = layers.HiddenLayerProbes(
                self.config.num_labels,
                kernel_initializer=layers.get_initializer(0.02),
                name="probes")([embedding_output, all_hidden_outputs])
        else:
            _, pooled_output = bert_backbone(
                [input_ids, input_mask, segment_ids], mode=mode)
            pooled_output = tf.layers.dropout(pooled_output,
                                              rate=self.config.dropout_rate,
                                              training=is_training)
            logits = layers.Dense(
                self.config.num_labels,
                kernel_initializer=layers.get_initializer(0.02),
                name='app/ez_dense')(pooled_output)
            logits = [logits]

        return logits, label_ids
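
Below is a minimal distillation-loss sketch (illustrative, not code from this repo) showing how the per-layer student logits returned above can be matched against the cached teacher logits with a mean-squared-error term; the function name and averaging are assumptions.

import tensorflow as tf

def probe_distillation_loss(student_logits, teacher_logits):
    # Both arguments: equal-length lists of [batch, num_labels] tensors.
    losses = []
    for student, teacher in zip(student_logits, teacher_logits):
        # stop_gradient keeps the teacher targets fixed during backprop.
        losses.append(
            tf.reduce_mean(tf.square(student - tf.stop_gradient(teacher))))
    return tf.add_n(losses) / float(len(losses))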
    def build_logits(self, features, mode=None):
        """ Building graph of KD Teacher

        Args:
            features (`OrderedDict`): A dict mapping raw input to tensors
            mode (`bool): tell the model whether it is under training
        Returns:
            logits (`list`): logits for all the layers, list of shape of [None, num_labels]
            label_ids (`Tensor`): label_ids, shape of [None]
        """
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        preprocessor = preprocessors.get_preprocessor(
            self.config.pretrain_model_name_or_path,
            user_defined_config=self.config)
        bert_backbone = model_zoo.get_pretrained_model(
            self.config.pretrain_model_name_or_path)

        # Serialize raw text to get input tensors
        input_ids, input_mask, segment_ids, label_id = preprocessor(features)

        if self.config.train_probes:
            # Get BERT all hidden states
            bert_model = bert_backbone.bert
            embedding_output = bert_model.embeddings([input_ids, segment_ids],
                                                     training=is_training)
            attention_mask = layers.get_attn_mask_bert(input_ids, input_mask)
            all_hidden_outputs, all_att_outputs = bert_model.encoder(
                [embedding_output, attention_mask], training=is_training)

            # Get teacher probes
            logits = layers.HiddenLayerProbes(
                self.config.num_labels,
                kernel_initializer=layers.get_initializer(0.02),
                name="probes")([embedding_output, all_hidden_outputs])
            self.tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           "probes/")
        else:
            _, pooled_output = bert_backbone(
                [input_ids, input_mask, segment_ids], mode=mode)
            pooled_output = tf.layers.dropout(pooled_output,
                                              rate=self.config.dropout_rate,
                                              training=is_training)
            logits = layers.Dense(
                self.config.num_labels,
                kernel_initializer=layers.get_initializer(0.02),
                name='app/ez_dense')(pooled_output)
            logits = [logits]

        if mode == tf.estimator.ModeKeys.PREDICT:
            return {
                "input_ids": input_ids,
                "input_mask": input_mask,
                "segment_ids": segment_ids,
                "label_id": label_id,
                "logits": tf.concat(logits, axis=-1)
            }
        else:
            return logits, label_id
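
For reference, a self-contained round trip (shapes illustrative) showing how the teacher's packed PREDICT export, tf.concat(logits, axis=-1), is recovered by the student's num_labels-wide slicing above:

import tensorflow as tf

num_labels = 3
per_layer = [tf.fill([2, num_labels], float(i)) for i in range(13)]  # embeddings + 12 layers
packed = tf.concat(per_layer, axis=-1)  # shape [2, 13 * num_labels]
n_layers = int(packed.shape[1]) // num_labels - 1
unpacked = [packed[:, i * num_labels:(i + 1) * num_labels]
            for i in range(n_layers + 1)]  # 13 slices of shape [2, num_labels]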
Example #3
    def build_logits(self, features, mode=None):
        preprocessor = preprocessors.get_preprocessor(
            self.pretrain_model_name_or_path)

        model = model_zoo.get_pretrained_model(
            self.pretrain_model_name_or_path)

        global_step = tf.train.get_or_create_global_step()

        tnews_dense = layers.Dense(
            15,
            kernel_initializer=layers.get_initializer(0.02),
            name='tnews_dense')

        ocemotion_dense = layers.Dense(
            7,
            kernel_initializer=layers.get_initializer(0.02),
            name='ocemotion_dense')

        ocnli_dense = layers.Dense(
            3,
            kernel_initializer=layers.get_initializer(0.02),
            name='ocnli_dense')

        input_ids, input_mask, segment_ids, label_ids = preprocessor(features)

        outputs = model([input_ids, input_mask, segment_ids], mode=mode)
        pooled_output = outputs[1]

        if mode == tf.estimator.ModeKeys.TRAIN:
            pooled_output = tf.nn.dropout(pooled_output, keep_prob=0.9)

        logits = tf.case(
            [
                (tf.equal(tf.mod(global_step, 3), 0),
                 lambda: tnews_dense(pooled_output)),
                (tf.equal(tf.mod(global_step, 3), 1),
                 lambda: ocemotion_dense(pooled_output)),
                (tf.equal(tf.mod(global_step, 3), 2),
                 lambda: ocnli_dense(pooled_output)),
            ],
            exclusive=True)

        if mode == tf.estimator.ModeKeys.PREDICT:
            ret = {
                "tnews_logits": tnews_dense(pooled_output),
                "ocemotion_logits": ocemotion_dense(pooled_output),
                "ocnli_logits": ocnli_dense(pooled_output),
                "label_ids": label_ids
            }
            return ret

        return logits, label_ids
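
A tiny self-contained check (values illustrative) of the scheduling above: which head fires at a step is determined purely by global_step % 3, and exclusive=True asserts that exactly one predicate is true.

import tensorflow as tf

step = tf.constant(7, dtype=tf.int64)
task = tf.case([
    (tf.equal(tf.mod(step, 3), 0), lambda: tf.constant("tnews")),
    (tf.equal(tf.mod(step, 3), 1), lambda: tf.constant("ocemotion")),
    (tf.equal(tf.mod(step, 3), 2), lambda: tf.constant("ocnli")),
], exclusive=True)

with tf.Session() as sess:
    print(sess.run(task))  # b'ocemotion', since 7 % 3 == 1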
    def build_logits(self, features, mode=None):
        preprocessor = preprocessors.get_preprocessor(
            self.pretrain_model_name_or_path)

        model = model_zoo.get_pretrained_model(
            self.pretrain_model_name_or_path)

        global_step = tf.train.get_or_create_global_step()

        tnews_dense = layers.Dense(
            15,
            kernel_initializer=layers.get_initializer(0.02),
            name='tnews_dense')

        ocemotion_dense = layers.Dense(
            7,
            kernel_initializer=layers.get_initializer(0.02),
            name='ocemotion_dense')

        ocnli_dense = layers.Dense(
            3,
            kernel_initializer=layers.get_initializer(0.02),
            name='ocnli_dense')

        input_ids, input_mask, segment_ids, label_ids = preprocessor(features)

        outputs_tnews = model([input_ids[0], input_mask[0], segment_ids[0]],
                              mode=mode)
        pooled_output_tnews = outputs_tnews[1]
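        # NOTE: tf.nn.dropout's keep_prob is the probability of *keeping* each
        # unit, so keep_prob=0.2 below drops 80% of activations (the ocnli
        # head further down keeps 50%).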
        if mode == tf.estimator.ModeKeys.TRAIN:
            pooled_output_tnews = tf.nn.dropout(pooled_output_tnews,
                                                keep_prob=0.2)
        logits_tnews = tnews_dense(pooled_output_tnews)

        outputs_ocemotion = model(
            [input_ids[1], input_mask[1], segment_ids[1]], mode=mode)
        pooled_output_ocemotion = outputs_ocemotion[1]
        if mode == tf.estimator.ModeKeys.TRAIN:
            pooled_output_ocemotion = tf.nn.dropout(pooled_output_ocemotion,
                                                    keep_prob=0.2)
        logits_ocemotion = ocemotion_dense(pooled_output_ocemotion)

        outputs_ocnli = model([input_ids[2], input_mask[2], segment_ids[2]],
                              mode=mode)
        pooled_output_ocnli = outputs_ocnli[1]
        if mode == tf.estimator.ModeKeys.TRAIN:
            pooled_output_ocnli = tf.nn.dropout(pooled_output_ocnli,
                                                keep_prob=0.5)
        logits_ocnli = ocnli_dense(pooled_output_ocnli)

        return [logits_tnews, logits_ocemotion,
                logits_ocnli], [label_ids[0], label_ids[1], label_ids[2]]
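
A minimal joint-loss sketch for the three (logits, labels) pairs returned above, assuming sparse integer labels; the function name and per-task weights are illustrative.

import tensorflow as tf

def multitask_loss(logits_list, labels_list, task_weights=(1.0, 1.0, 1.0)):
    total = tf.constant(0.0)
    for logits, labels, weight in zip(logits_list, labels_list, task_weights):
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(labels, tf.int32), logits=logits)
        total += weight * tf.reduce_mean(cross_entropy)
    return total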
Example #5
    def build_logits(self, features, mode=None):
        preprocessor = preprocessors.get_preprocessor(
            self.pretrain_model_name_or_path,
            user_defined_config=self.user_defined_config)

        model = model_zoo.get_pretrained_model(
            self.pretrain_model_name_or_path)

        dense = layers.Dense(self.num_labels,
                             kernel_initializer=layers.get_initializer(0.02),
                             name='dense')

        input_ids, input_mask, segment_ids, label_ids = preprocessor(features)

        outputs = model([input_ids, input_mask, segment_ids], mode=mode)
        pooled_output = outputs[1]

        if mode == tf.estimator.ModeKeys.TRAIN:
            pooled_output = tf.nn.dropout(pooled_output, keep_prob=0.9)

        logits = dense(pooled_output)

        if mode == tf.estimator.ModeKeys.PREDICT:
            return logits

        return logits, label_ids
    def __init__(self, config, **kwargs):
        super(AlbertBackbone, self).__init__(**kwargs)

        self.embeddings = layers.AlbertEmbeddings(config, name="embeddings")
        self.embedding_hidden_mapping_in = layers.Dense(
            config.hidden_size,
            kernel_initializer=layers.get_initializer(
                config.initializer_range),
            name="encoder/embedding_hidden_mapping_in",
        )
        self.encoder = AlbertEncoder(config, name="encoder/transformer")
        self.pooler = layers.Dense(units=config.hidden_size,
                                   activation='tanh',
                                   kernel_initializer=layers.get_initializer(
                                       config.initializer_range),
                                   name="pooler/dense")
    def __init__(self, config, **kwargs):
        super(AlbertSelfOutput, self).__init__(**kwargs)
        self.dense = layers.Dense(config.hidden_size,
                                  kernel_initializer=layers.get_initializer(
                                      config.initializer_range),
                                  name="dense")
        self.dropout = layers.Dropout(config.hidden_dropout_prob)
Example #8
    def build_logits(self, features, mode=None):
        """ Building BERT text match graph

        Args:
            features (`OrderedDict`): A dict mapping raw input to tensors
            mode (`str`): a `tf.estimator.ModeKeys` value telling the model whether it is training
        Returns:
            logits (`Tensor`): The output after the last dense layer. Shape of [None, num_labels]
            label_ids (`Tensor`): label_ids, shape of [None]
        """
        bert_preprocessor = preprocessors.get_preprocessor(
            self.config.pretrain_model_name_or_path,
            user_defined_config=self.config)
        input_ids, input_mask, segment_ids, label_ids = bert_preprocessor(
            features)

        bert_backbone = model_zoo.get_pretrained_model(
            self.config.pretrain_model_name_or_path)
        _, pool_output = bert_backbone([input_ids, input_mask, segment_ids],
                                       mode=mode)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        pool_output = tf.layers.dropout(pool_output,
                                        rate=self.config.dropout_rate,
                                        training=is_training)
        logits = layers.Dense(self.config.num_labels,
                              kernel_initializer=layers.get_initializer(0.02),
                              name='app/ez_dense')(pool_output)

        self.check_and_init_from_checkpoint(mode)
        return logits, label_ids
    def build(self, input_shape):
        self.output_weights = self.add_weight(
            shape=[self.hidden_size, self.patch_feature_size],
            initializer=layers.get_initializer(self.initializer_range),
            trainable=True,
            name="output_weights")

        super(ImageBertMPMHead, self).build(input_shape)
Example #10
    def build_logits(self, features, mode=None):
        """ Building DAM text match graph

        Args:
            features (`OrderedDict`): A dict mapping raw input to tensors
            mode (`str`): a `tf.estimator.ModeKeys` value telling the model whether it is training
        Returns:
            logits (`Tensor`): The output after the last dense layer. Shape of [None, num_labels]
            label_ids (`Tensor`): label_ids, shape of [None]
        """
        text_preprocessor = DeepTextPreprocessor(self.config, mode=mode)
        text_a_indices, text_a_masks, text_b_indices, text_b_masks, label_ids = text_preprocessor(
            features)
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        word_embeddings = self._add_word_embeddings(
            vocab_size=text_preprocessor.vocab.size,
            embed_size=self.config.embedding_size,
            pretrained_word_embeddings=text_preprocessor.
            pretrained_word_embeddings,
            trainable=not self.config.fix_embedding)
        a_embeds = tf.nn.embedding_lookup(word_embeddings, text_a_indices)
        b_embeds = tf.nn.embedding_lookup(word_embeddings, text_b_indices)

        dam_output_features = layers.DAMEncoder(self.config.hidden_size)(
            [a_embeds, b_embeds, text_a_masks, text_b_masks],
            training=is_training)

        dam_output_features = tf.layers.dropout(
            dam_output_features,
            rate=0.2,
            training=is_training,
            name='dam_out_features_dropout')
        dam_output_features = layers.Dense(
            self.config.hidden_size,
            activation=tf.nn.relu,
            kernel_initializer=layers.get_initializer(0.02),
            name='dam_out_features_projection')(dam_output_features)

        logits = layers.Dense(self.config.num_labels,
                              kernel_initializer=layers.get_initializer(0.02),
                              name='output_layer')(dam_output_features)

        self.check_and_init_from_checkpoint(mode)
        return logits, label_ids
Example #11
    def __init__(self, config, **kwargs):
        super(BertHAEBackbone, self).__init__(config, **kwargs)
        self.config = config
        self.encoder = layers.Encoder(config, name="encoder")
        self.pooler = layers.Dense(
            units=config.hidden_size,
            activation='tanh',
            kernel_initializer=layers.get_initializer(config.initializer_range),
            name="pooler/dense")
    def __init__(self, config, **kwargs):
        super(AlbertIntermediate, self).__init__(**kwargs)
        self.dense = layers.Dense(config.intermediate_size,
                                  activation=layers.gelu_new,
                                  kernel_initializer=layers.get_initializer(
                                      config.initializer_range),
                                  name="dense")

        self.dense_output = AlbertOutput(config, name="output")
        self.dropout = layers.Dropout(config.hidden_dropout_prob)
    def _add_word_embeddings(self, vocab_size, embed_size,
                             pretrained_word_embeddings=None, trainable=False):
        with tf.name_scope("input_representations"):
            if pretrained_word_embeddings is not None:
                tf.logging.info("Initialize word embedding from pretrained")
                word_embedding_initializer = tf.constant_initializer(
                    pretrained_word_embeddings)
            else:
                word_embedding_initializer = layers.get_initializer(0.02)
            word_embeddings = tf.get_variable(
                "word_embeddings", [vocab_size, embed_size],
                dtype=tf.float32,
                initializer=word_embedding_initializer,
                trainable=trainable)
        return word_embeddings
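
The same pattern reduced to a standalone sketch (sizes and values illustrative): a lookup table initialized from pretrained vectors when they are available, frozen via trainable=False.

import numpy as np
import tensorflow as tf

pretrained = np.random.randn(100, 16).astype(np.float32)  # stand-in vectors
table = tf.get_variable("word_embeddings_demo", [100, 16],
                        dtype=tf.float32,
                        initializer=tf.constant_initializer(pretrained),
                        trainable=False)
ids = tf.constant([[1, 2, 3]])
embeds = tf.nn.embedding_lookup(table, ids)  # shape [1, 3, 16]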
Example #14
    def __init__(self, config, **kwargs):
        super(VideoBertEmbeddings, self).__init__(**kwargs)

        self.clip_feature_size = config.clip_feature_size
        self.clip_size = config.clip_size
        self.hidden_size = config.hidden_size
        self.initializer_range = config.initializer_range
        self.max_clip_position_embeddings = config.max_clip_position_embeddings

        self.LayerNorm = layers.LayerNormalization
        self.dropout = layers.Dropout(config.hidden_dropout_prob)
        self.initializer = layers.get_initializer(self.initializer_range)
Example #15
    def __init__(self, config, **kwargs):
        super(ImageBertBackbone, self).__init__(**kwargs)
        self.num_hidden_layers = config.num_hidden_layers
        self.embeddings = layers.BertEmbeddings(config, name="embeddings")
        self.image_embeddings = ImageEmbeddings(config,
                                                name="image_embeddings")
        self.encoder = layers.Encoder(config, name="encoder")
        self.pooler = layers.Dense(units=config.hidden_size,
                                   activation='tanh',
                                   kernel_initializer=layers.get_initializer(
                                       config.initializer_range),
                                   name="pooler/dense")
Example #16
    def build_logits(self, features, mode=None):
        bert_preprocessor = preprocessors.get_preprocessor(
            self.pretrain_model_name_or_path,
            app_model_name="pretrain_language_model",
            user_defined_config=self.user_defined_config)

        if _APP_FLAGS.distribution_strategy == "WhaleStrategy" or \
                self.config.distribution_strategy == "WhaleStrategy":
            tf.logging.info("*********Calling Whale Encoder***********")
            model = model_zoo.get_pretrained_model(
                self.pretrain_model_name_or_path,
                enable_whale=True,
                input_sequence_length=_APP_FLAGS.input_sequence_length)
        else:
            model = model_zoo.get_pretrained_model(
                self.pretrain_model_name_or_path,
                input_sequence_length=_APP_FLAGS.input_sequence_length)

        if _APP_FLAGS.loss == "mlm+nsp" or _APP_FLAGS.loss == "mlm+sop":
            input_ids, input_mask, segment_ids, masked_lm_positions, \
            masked_lm_ids, masked_lm_weights, next_sentence_labels = bert_preprocessor(features)

            lm_logits, nsp_logits, _ = model(
                [input_ids, input_mask, segment_ids],
                masked_lm_positions=masked_lm_positions,
                output_features=False,
                mode=mode)

            return (lm_logits, nsp_logits), (masked_lm_ids, masked_lm_weights,
                                             next_sentence_labels)

        elif _APP_FLAGS.loss == "mlm":

            task_1_dense = layers.Dense(
                2,
                kernel_initializer=layers.get_initializer(0.02),
                name='task_1_dense')

            input_ids, input_mask, segment_ids, masked_lm_positions, \
            masked_lm_ids, masked_lm_weights, task_1_label = bert_preprocessor(features)

            lm_logits, _, pooled_output = model(
                [input_ids, input_mask, segment_ids],
                masked_lm_positions=masked_lm_positions,
                output_features=False,
                mode=mode)

            task_1_logits = task_1_dense(pooled_output)

            return (lm_logits,
                    task_1_logits), (masked_lm_ids, masked_lm_weights,
                                     task_1_label)
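
For context, a hedged sketch of the standard BERT masked-LM loss over the (lm_logits, masked_lm_ids, masked_lm_weights) triple returned above; this is not code from this repo, and it assumes lm_logits is flattened to [batch * num_masked, vocab_size].

import tensorflow as tf

def masked_lm_loss(lm_logits, masked_lm_ids, masked_lm_weights):
    ids = tf.reshape(masked_lm_ids, [-1])
    weights = tf.reshape(tf.cast(masked_lm_weights, tf.float32), [-1])
    log_probs = tf.nn.log_softmax(lm_logits, axis=-1)
    one_hot = tf.one_hot(ids, depth=tf.shape(lm_logits)[-1], dtype=tf.float32)
    per_position = -tf.reduce_sum(log_probs * one_hot, axis=-1)
    # Average over real masked positions only; weights zero out padding.
    return tf.reduce_sum(weights * per_position) / (tf.reduce_sum(weights) + 1e-5)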
Example #17
    def __init__(self, config, **kwargs):
        super(ImageEmbeddings, self).__init__(**kwargs)

        self.patch_feature_size = config.patch_feature_size
        self.hidden_size = config.hidden_size
        self.initializer_range = config.initializer_range
        self.patch_type_vocab_size = config.patch_type_vocab_size
        self.max_patch_position_embeddings = config.max_patch_position_embeddings

        self.LayerNorm = layers.LayerNormalization
        self.dropout_input = layers.Dropout(config.hidden_dropout_prob)
        self.dropout_output = layers.Dropout(config.hidden_dropout_prob)
        self.initializer = layers.get_initializer(self.initializer_range)
Example #18
    def build_logits(self, features, mode=None):
        text_preprocessor = DeepTextPreprocessor(self.config, mode=mode)
        text_a_indices, text_a_masks, text_b_indices, text_b_masks, label_ids = text_preprocessor(
            features)
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        word_embeddings = self._add_word_embeddings(
            vocab_size=text_preprocessor.vocab.size,
            embed_size=self.config.embedding_size,
            pretrained_word_embeddings=text_preprocessor.
            pretrained_word_embeddings,
            trainable=not self.config.fix_embedding)
        a_embeds = tf.nn.embedding_lookup(word_embeddings, text_a_indices)
        b_embeds = tf.nn.embedding_lookup(word_embeddings, text_b_indices)

        hcnn_output_features = layers.HybridCNNEncoder(
            num_filters=self.config.hidden_size,
            l2_reg=self.config.l2_reg,
            filter_size=self.config.filter_size)(
                [a_embeds, b_embeds, text_a_masks, text_b_masks])

        hcnn_output_features = tf.layers.dropout(
            hcnn_output_features,
            rate=0.2,
            training=is_training,
            name='dam_out_features_dropout')
        hcnn_output_features = layers.Dense(
            self.config.hidden_size,
            activation=tf.nn.relu,
            kernel_initializer=layers.get_initializer(0.02),
            name='dam_out_features_projection')(hcnn_output_features)

        logits = layers.Dense(self.config.num_labels,
                              kernel_initializer=layers.get_initializer(0.02),
                              name='output_layer')(hcnn_output_features)

        self.check_and_init_from_checkpoint(mode)
        return logits, label_ids
Example #19
    def build_logits(self, features, mode=None):

        preprocessor = preprocessors.get_preprocessor(
            self.pretrain_model_name_or_path, user_defined_config=self.config)
        model = model_zoo.get_pretrained_model(
            self.pretrain_model_name_or_path)
        dense = layers.Dense(self.num_labels,
                             kernel_initializer=layers.get_initializer(0.02),
                             name='dense')

        input_ids, input_mask, segment_ids, label_ids = preprocessor(features)
        outputs = model([input_ids, input_mask, segment_ids], mode=mode)
        pooled_output = outputs[1]
        logits = dense(pooled_output)
        return logits, label_ids
Example #20
    def __init__(self, config, **kwargs):
        # super().__init__ must run before sub-layers are assigned on a
        # tf.keras Layer; pop the custom flag first so the base class does
        # not receive it.
        enable_whale = kwargs.pop('enable_whale', False)
        super(BertBackbone, self).__init__(config, **kwargs)

        self.embeddings = layers.BertEmbeddings(config, name="embeddings")
        if not enable_whale:
            self.encoder = layers.Encoder(config, name="encoder")
        else:
            self.encoder = layers.Encoder_whale(config, name="encoder")

        self.pooler = layers.Dense(units=config.hidden_size,
                                   activation='tanh',
                                   kernel_initializer=layers.get_initializer(
                                       config.initializer_range),
                                   name="pooler/dense")
    def build_logits(self, features, mode=None):
        # Preprocess the raw data into the features the model needs,
        # e.g. input_ids, input_mask, segment_ids
        preprocessor = preprocessors.get_preprocessor(self.pretrain_model_name_or_path,
                                                      user_defined_config=self.user_defined_config)

        # Build the network backbone
        model = model_zoo.get_pretrained_model(self.pretrain_model_name_or_path)

        dense = layers.Dense(self.num_labels, kernel_initializer=layers.get_initializer(0.02), name='dense')

        input_ids, input_mask, segment_ids, label_ids = preprocessor(features)

        _, pooled_output = model([input_ids, input_mask, segment_ids], mode=mode)

        logits = dense(pooled_output)

        return logits, label_ids
    def build_logits(self, features, mode=None):
        preprocessor = preprocessors.get_preprocessor(
            self.pretrain_model_name_or_path)
        model = model_zoo.get_pretrained_model(
            self.pretrain_model_name_or_path)

        dense = layers.Dense(self.num_labels,
                             kernel_initializer=layers.get_initializer(0.02),
                             name='dense')

        input_ids, input_mask, segment_ids, label_ids = preprocessor(features)
        outputs = model([input_ids, input_mask, segment_ids], mode=mode)
        pooled_output = outputs[1]
        logits = dense(pooled_output)

        if mode == tf.estimator.ModeKeys.PREDICT:
            ret = {"logits": logits}
            return ret

        return logits, label_ids
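
Illustrative post-processing for the PREDICT dict above: softmax the exported logits into probabilities outside the graph (plain NumPy; the function name is assumed).

import numpy as np

def decode_predictions(ret):
    logits = np.asarray(ret["logits"])
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs.argmax(axis=-1), probs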
Example #23
    def build_logits(self, features, mode=None):
        """ Building DAM text match graph

        Args:
            features (`OrderedDict`): A dict mapping raw input to tensors
            mode (`str`): a `tf.estimator.ModeKeys` value telling the model whether it is training
        Returns:
            logits (`Tensor`): The output after the last dense layer. Shape of [None, num_labels]
            label_ids (`Tensor`): label_ids, shape of [None]
        """
        text_preprocessor = DeepTextPreprocessor(self.config, mode=mode)
        text_indices, text_masks, _, _, label_ids = text_preprocessor(features)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        word_embeddings = self._add_word_embeddings(
            vocab_size=text_preprocessor.vocab.size,
            embed_size=self.config.embedding_size,
            pretrained_word_embeddings=text_preprocessor.
            pretrained_word_embeddings,
            trainable=not self.config.fix_embedding)
        text_embeds = tf.nn.embedding_lookup(word_embeddings, text_indices)

        output_features = layers.TextCNNEncoder(
            num_filters=self.config.num_filters,
            filter_sizes=self.config.filter_sizes,
            embed_size=self.config.embedding_size,
            max_seq_len=self.config.sequence_length,
        )([text_embeds, text_masks], training=is_training)

        output_features = tf.layers.dropout(output_features,
                                            rate=self.config.dropout_rate,
                                            training=is_training,
                                            name='output_features')

        logits = layers.Dense(self.config.num_labels,
                              kernel_initializer=layers.get_initializer(0.02),
                              name='output_layer')(output_features)

        self.check_and_init_from_checkpoint(mode)
        return logits, label_ids
    def build_logits(self, features, mode=None):

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        preprocessor = preprocessors.get_preprocessor(
            self.pretrain_model_name_or_path, user_defined_config=self.config)
        bert_backbone = model_zoo.get_pretrained_model(
            self.config.pretrain_model_name_or_path)
        dense = layers.Dense(self.num_labels,
                             kernel_initializer=layers.get_initializer(0.02),
                             name='dense')

        input_ids, input_mask, segment_ids, label_ids, domains, weights = preprocessor(
            features)

        self.domains = domains
        self.weights = weights
        hidden_size = bert_backbone.config.hidden_size
        self.domain_logits = dict()

        bert_model = bert_backbone.bert
        embedding_output = bert_model.embeddings([input_ids, segment_ids],
                                                 training=is_training)
        attention_mask = layers.get_attn_mask_bert(input_ids, input_mask)
        encoder_outputs = bert_model.encoder(
            [embedding_output, attention_mask], training=is_training)
        encoder_outputs = encoder_outputs[0]
        pooled_output = bert_model.pooler(encoder_outputs[-1][:, 0])

        if mode == tf.estimator.ModeKeys.TRAIN:
            pooled_output = tf.nn.dropout(pooled_output, keep_prob=0.9)

        with tf.variable_scope("mft", reuse=tf.AUTO_REUSE):
            # add domain network
            logits = dense(pooled_output)
            domains = tf.squeeze(domains)

            domain_embedded_matrix = tf.get_variable(
                "domain_projection", [num_domains, hidden_size],
                initializer=tf.truncated_normal_initializer(stddev=0.02))
            domain_embedded = tf.nn.embedding_lookup(domain_embedded_matrix,
                                                     domains)

            for layer_index in layer_indexes:
                content_tensor = tf.reduce_mean(encoder_outputs[layer_index],
                                                axis=1)
                content_tensor_with_domains = domain_embedded + content_tensor

                domain_weights = tf.get_variable(
                    "domain_weights", [num_domains, hidden_size],
                    initializer=tf.truncated_normal_initializer(stddev=0.02))
                domain_bias = tf.get_variable(
                    "domain_bias", [num_domains],
                    initializer=tf.zeros_initializer())

                current_domain_logits = tf.matmul(content_tensor_with_domains,
                                                  domain_weights,
                                                  transpose_b=True)
                current_domain_logits = tf.nn.bias_add(current_domain_logits,
                                                       domain_bias)

                self.domain_logits["domain_logits_" +
                                   str(layer_index)] = current_domain_logits
        return logits, label_ids
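
Finally, a hedged sketch of how the per-layer self.domain_logits collected above could feed an auxiliary domain-classification loss; the simple averaging and the function name are assumptions, not necessarily the repo's actual formulation.

import tensorflow as tf

def domain_auxiliary_loss(domain_logits_dict, domains):
    losses = []
    for logits in domain_logits_dict.values():
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(domains, tf.int32), logits=logits)
        losses.append(tf.reduce_mean(cross_entropy))
    return tf.add_n(losses) / float(len(losses))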