def inference_lm(self):
    """
    This is for the pre-trained language model.
    Main inference logic: invoke the transformer model to do inference. Input is a sequence and
    output is also a sequence; get the representation of the masked token(s) and use a classifier
    to train the model.
    Idea for getting the hidden state of the masked position(s):
      1) take a batch of position indices, 2) one-hot them, 3) multiply with the full sequence
      representation; 4) along the second dimension (sequence_length) every entry is 0 except a
      single 1, 5) so we can sum over that dimension without losing any information.
    :return: logits_lm
    """
    # 1. input representation (token embedding + segment encoding + positional encoding)
    token_embeddings = tf.nn.embedding_lookup(self.embedding, self.x_mask_lm)  # [batch_size, sequence_length, embed_size]
    self.input_representation_lm = tf.add(tf.add(token_embeddings, self.segment_embeddings_lm), self.position_embeddings_lm)  # [batch_size, sequence_length, embed_size]

    # 2. repeat the building block N times (multi-head attention followed by Add & Norm; feed forward followed by Add & Norm)
    encoder_class = Encoder(self.d_model, self.d_k, self.d_v, self.sequence_length_lm, self.h, self.batch_size, self.num_layer,
                            self.input_representation_lm, self.input_representation_lm,
                            dropout_keep_prob=self.dropout_keep_prob, use_residual_conn=self.use_residual_conn)
    h_lm = encoder_class.encoder_fn()  # [batch_size, sequence_length, d_model]

    # 3. get the last hidden state of the masked position(s), and project it to make a prediction.
    p_mask_lm_onehot = tf.one_hot(self.p_mask_lm, self.sequence_length_lm)  # [batch_size, sequence_length_lm]
    p_mask_lm_expand = tf.expand_dims(p_mask_lm_onehot, axis=-1)            # [batch_size, sequence_length_lm, 1]
    h_lm_multiply = tf.multiply(h_lm, p_mask_lm_expand)                     # [batch_size, sequence_length, d_model]
    h_lm_representation = tf.reduce_sum(h_lm_multiply, axis=1)              # [batch_size, d_model]

    # 4. project the representation of the masked token(s) to vocab size
    with tf.variable_scope("pre_training"):
        logits_lm = tf.layers.dense(h_lm_representation, self.vocab_size)       # shape: [None, vocab_size]
        logits_lm = tf.nn.dropout(logits_lm, keep_prob=self.dropout_keep_prob)  # shape: [None, vocab_size]
    return logits_lm  # shape: [None, vocab_size]
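# A minimal, standalone sketch of the masked-position gather trick used in step 3 above:
# one-hot the position indices, multiply with the sequence representation, then sum over the
# sequence axis. The shapes and values here are made up for illustration; assumes TensorFlow 1.x.
import tensorflow as tf
import numpy as np

batch_size, sequence_length, d_model = 2, 4, 3
h_demo = tf.constant(np.arange(batch_size * sequence_length * d_model,
                               dtype=np.float32).reshape(batch_size, sequence_length, d_model))  # fake encoder output
p_mask_demo = tf.constant([1, 3])                                     # masked position index per example
p_onehot = tf.one_hot(p_mask_demo, sequence_length)                   # [batch_size, sequence_length]
p_expand = tf.expand_dims(p_onehot, axis=-1)                          # [batch_size, sequence_length, 1]
h_masked = tf.reduce_sum(tf.multiply(h_demo, p_expand), axis=1)       # [batch_size, d_model]

with tf.Session() as sess:
    print(sess.run(h_masked))  # row 0 equals h_demo[0, 1, :], row 1 equals h_demo[1, 3, :]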
def inference(self):
    """
    This is for fine-tuning.
    Main inference logic: invoke the transformer model to do inference. Input is a sequence and
    output is also a sequence; take the hidden state of the [CLS] token and use a classifier
    to train the model.
    :return: logits
    """
    # 1. input representation (token embedding + segment encoding + positional encoding)
    token_embeddings = tf.nn.embedding_lookup(self.embedding, self.input_x)  # [batch_size, sequence_length, embed_size]
    self.input_representation = tf.add(tf.add(token_embeddings, self.segment_embeddings_lm), self.position_embeddings)  # [batch_size, sequence_length, embed_size]

    # 2. repeat the building block N times (multi-head attention followed by Add & Norm; feed forward followed by Add & Norm)
    encoder_class = Encoder(self.d_model, self.d_k, self.d_v, self.sequence_length, self.h, self.batch_size, self.num_layer,
                            self.input_representation, self.input_representation,
                            dropout_keep_prob=self.dropout_keep_prob, use_residual_conn=self.use_residual_conn)
    h = encoder_class.encoder_fn()  # [batch_size, sequence_length, d_model]

    # 3. get the hidden state of the [CLS] token, and project it to make a prediction.
    h_cls = h[:, 0, :]  # [batch_size, d_model]

    # 4. project the [CLS] representation to the number of classes
    with tf.variable_scope("fine_tuning"):
        logits = tf.layers.dense(h_cls, self.num_classes)                 # shape: [None, num_classes]
        logits = tf.nn.dropout(logits, keep_prob=self.dropout_keep_prob)  # shape: [None, num_classes]
    return logits  # shape: [None, num_classes]
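# A minimal sketch of how the fine-tuning logits above would typically feed a classification
# loss and predictions. The label placeholder name `input_y`, the integer-id label format, and
# the L2 term are assumptions for illustration, not taken from this file; assumes TensorFlow 1.x.
def loss_fine_tuning(logits, input_y, l2_lambda=0.0001):
    # logits: [batch_size, num_classes]; input_y: [batch_size] integer class ids
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=input_y, logits=logits)
    loss = tf.reduce_mean(losses)
    # optional L2 regularization over trainable weights (biases excluded), a common choice
    l2_loss = l2_lambda * tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()
                                    if 'bias' not in v.name])
    return loss + l2_loss

# predictions = tf.argmax(logits, axis=1)  # [batch_size]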
def inference(self):
    """
    Main inference logic: invoke the transformer model to do inference. Input is a sequence and
    output is also a sequence.
    Pipeline: input representation --> encoder stack --> task-specific projection.
    :return: logits
    """
    # 1. input representation (token embedding + segment encoding + positional encoding)
    token_embeddings = tf.nn.embedding_lookup(self.embedding, self.input_x)  # [batch_size, sequence_length, embed_size]
    self.input_representation = tf.add(tf.add(token_embeddings, self.segment_embeddings), self.position_embeddings)  # [batch_size, sequence_length, embed_size]

    # 2. repeat the building block N times (multi-head attention followed by Add & Norm; feed forward followed by Add & Norm)
    encoder_class = Encoder(self.d_model, self.d_k, self.d_v, self.sequence_length, self.h, self.batch_size, self.num_layer,
                            self.input_representation, self.input_representation,
                            dropout_keep_prob=self.dropout_keep_prob, use_residual_conn=self.use_residual_conn)
    h = encoder_class.encoder_fn()  # [batch_size, sequence_length, d_model]

    # 3. get logits for the different tasks by applying a projection layer
    logits = self.project_tasks(h)  # shape: [None, num_classes]
    return logits  # shape: [None, num_classes]
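# `project_tasks` is not defined in this section. For illustration only, a plausible sketch of
# such a projection head: pool the encoder output (here, take the first position, as in the
# fine-tuning path above) and project it to num_classes. The pooling choice and scope name are
# assumptions; assumes TensorFlow 1.x.
def project_tasks_sketch(self, h):
    # h: [batch_size, sequence_length, d_model]
    h_pooled = h[:, 0, :]  # [batch_size, d_model], representation of the first token
    with tf.variable_scope("projection"):
        logits = tf.layers.dense(h_pooled, self.num_classes)              # [batch_size, num_classes]
        logits = tf.nn.dropout(logits, keep_prob=self.dropout_keep_prob)  # [batch_size, num_classes]
    return logits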