def inference_lm(self):
        """
        this is for pre-trained language model.
        main inference logic here: invoke transformer model to do inference,input is a sequence, output is also a sequence, get representation of masked token(s) and use a classifier
        to train the model.
        # idea of the hidden state of masked position(s):
        #   1) a batch of position index,
            2) one hot it, multiply with total sequence represenation,
            3)every where is 0 for the second dimension(sequence_length),
            4) only one place is 1,
            5) thus we can sum up without loss any information.
        :return:
        """
        # 1. input representation (token embedding + segment embedding + position embedding)
        token_embeddings = tf.nn.embedding_lookup(self.embedding,self.x_mask_lm)  # [batch_size,sequence_length,embed_size]
        self.input_representation_lm=tf.add(tf.add(token_embeddings,self.segment_embeddings_lm),self.position_embeddings_lm)  # [batch_size,sequence_length,embed_size]
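        # (All three embeddings are summed element-wise, so segment and position embeddings are
        #  assumed to share the token embedding's last dimension, i.e. embed_size == d_model.)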

        # 2. repeat the building block Nx times (multi-head attention followed by Add & Norm; feed-forward followed by Add & Norm)
        encoder_class=Encoder(self.d_model,self.d_k,self.d_v,self.sequence_length_lm,self.h,self.batch_size,self.num_layer,self.input_representation_lm,
                              self.input_representation_lm,dropout_keep_prob=self.dropout_keep_prob,use_residual_conn=self.use_residual_conn)
        h_lm = encoder_class.encoder_fn() # [batch_size,sequence_length,d_model]

        # 3. get the hidden state at the masked position(s) and project it to make a prediction.
        p_mask_lm_onehot=tf.one_hot(self.p_mask_lm,self.sequence_length_lm) # [batch_size, sequence_length_lm]
        p_mask_lm_expand=tf.expand_dims(p_mask_lm_onehot,axis=-1) # [batch_size, sequence_length_lm, 1]
        h_lm_multiply=tf.multiply(h_lm,p_mask_lm_expand)     # [batch_size,sequence_length,d_model]
        h_lm_representation=tf.reduce_sum(h_lm_multiply,axis=1) # [batch_size, d_model]
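        # Note: the one-hot multiply-and-sum above is one way to gather the masked positions.
        # A sketch of an equivalent gather (assuming self.p_mask_lm has shape [batch_size] and a
        # TF version where tf.gather supports batch_dims, i.e. >= 1.14):
        #   h_lm_representation = tf.gather(h_lm, self.p_mask_lm, axis=1, batch_dims=1)  # [batch_size, d_model]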

        # 4. project representation of masked token(s) to vocab size
        with tf.variable_scope("pre_training"):
            logits_lm = tf.layers.dense(h_lm_representation, self.vocab_size)   # shape:[None,self.vocab_size]
            logits_lm = tf.nn.dropout(logits_lm,keep_prob=self.dropout_keep_prob)  # shape:[None,self.vocab_size]
        return logits_lm # shape:[None,self.vocab_size]
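    # Sketch (not from the original code) of how logits_lm would typically be consumed, assuming a
    # hypothetical placeholder self.y_mask_lm holding the true token ids at the masked positions,
    # shape [batch_size]:
    #   loss_lm = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    #       labels=self.y_mask_lm, logits=logits_lm))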
    def inference(self):
        """
        this is for fine-tuning.
        main inference logic here: invoke transformer model to do inference,input is a sequence, output is also a sequence, get representation of masked token(s) and use a classifier
        to train the model.
        # idea of the hidden state of masked position(s):
        #   1) a batch of position index,
            2) one hot it, multiply with total sequence represenation,
            3)every where is 0 for the second dimension(sequence_length),
            4) only one place is 1,
            5) thus we can sum up without loss any information.
        :return:
        """
        # 1. input representation (token embedding + segment embedding + position embedding)
        token_embeddings = tf.nn.embedding_lookup(self.embedding,self.input_x)  # [batch_size,sequence_length,embed_size]
        self.input_representation=tf.add(tf.add(token_embeddings,self.segment_embeddings_lm),self.position_embeddings)  # [batch_size,sequence_length,embed_size]

        # 2. repeat the building block Nx times (multi-head attention followed by Add & Norm; feed-forward followed by Add & Norm)
        encoder_class=Encoder(self.d_model,self.d_k,self.d_v,self.sequence_length,self.h,self.batch_size,self.num_layer,self.input_representation,
                              self.input_representation,dropout_keep_prob=self.dropout_keep_prob,use_residual_conn=self.use_residual_conn)
        h= encoder_class.encoder_fn() # [batch_size,sequence_length,d_model]

        # 3. get the hidden state of the [CLS] token (first position), and project it to make a prediction.
        h_cls=h[:,0,:] # [batch_size,d_model]

        # 4. project the [CLS] representation to the number of classes
        with tf.variable_scope("fine_tuning"):
            logits = tf.layers.dense(h_cls, self.num_classes)   # shape:[None,self.num_classes]
            logits = tf.nn.dropout(logits,keep_prob=self.dropout_keep_prob)  # shape:[None,self.num_classes]
        return logits # shape:[None,self.num_classes]
Example #3
    def inference(self):
        """
        main inference logic: invoke the transformer encoder. The input is a sequence and the output is a sequence.
        input representation --> encoder stack --> task-specific projection
        :return:
        """
        # 1. input representation (token embedding + segment embedding + position embedding)
        token_embeddings = tf.nn.embedding_lookup(self.embedding,self.input_x)  # [batch_size,sequence_length,embed_size]
        self.input_representation=tf.add(tf.add(token_embeddings,self.segment_embeddings),self.position_embeddings)  # [batch_size,sequence_length,embed_size]

        # 2. repeat the building block Nx times (multi-head attention followed by Add & Norm; feed-forward followed by Add & Norm)
        encoder_class=Encoder(self.d_model,self.d_k,self.d_v,self.sequence_length,self.h,self.batch_size,self.num_layer,self.input_representation,
                              self.input_representation,dropout_keep_prob=self.dropout_keep_prob,use_residual_conn=self.use_residual_conn)
        h = encoder_class.encoder_fn() # [batch_size,sequence_length,d_model]

        # 3. get logits for different tasks by applying projection layer
        logits=self.project_tasks(h) # shape:[None,self.num_classes]
        return logits # shape:[None,self.num_classes]
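
# A minimal standalone sketch (NumPy, not part of the original code) of the masked-position
# extraction trick used in inference_lm above: one-hot the position indices, multiply with the
# sequence representation, and sum over the sequence axis. Shapes and values are made up.
import numpy as np

batch_size, seq_len, d_model = 2, 4, 3
h = np.arange(batch_size * seq_len * d_model, dtype=np.float32).reshape(batch_size, seq_len, d_model)
p_mask = np.array([2, 0])                             # masked position per example, shape [batch_size]

one_hot = np.eye(seq_len, dtype=np.float32)[p_mask]   # [batch_size, seq_len]
picked = (h * one_hot[:, :, None]).sum(axis=1)        # [batch_size, d_model]

# The sum loses no information because only one position per row is non-zero.
assert np.array_equal(picked, h[np.arange(batch_size), p_mask])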