def decode(self):
    config = self.bert_config
    hidden_size = self.encoder_output.shape[2].value
    with tf.variable_scope('bert', reuse=True):
        with tf.variable_scope('embeddings'), tf.device('/cpu:0'):
            # Perform embedding lookup on the target word ids.
            (self.out_embed, self.bert_embeddings) = embedding_lookup(
                input_ids=self.output_ids,
                vocab_size=config.vocab_size,
                embedding_size=config.hidden_size,
                initializer_range=config.initializer_range,
                word_embedding_name='word_embeddings',
                use_one_hot_embeddings=False)

            # Add positional and token-type embeddings, then apply layer
            # normalization and dropout.
            self.out_embed = embedding_postprocessor(
                input_tensor=self.out_embed,
                use_token_type=True,
                token_type_ids=self.out_segment_ids,
                token_type_vocab_size=config.type_vocab_size,
                token_type_embedding_name='token_type_embeddings',
                use_position_embeddings=True,
                position_embedding_name='position_embeddings',
                initializer_range=config.initializer_range,
                max_position_embeddings=config.max_position_embeddings,
                dropout_prob=config.hidden_dropout_prob)

    with tf.variable_scope('decoder_1'):
        # Tie the output projection to the BERT word embeddings.
        self.decoder_weights = self.bert_embeddings
        self.masked_out_embed = self.out_embed * tf.expand_dims(self.output_mask, -1)
        # Shift the target embeddings right by one step (teacher forcing).
        self.decoder_input = tf.pad(self.masked_out_embed, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]

        # ################################################### decoding train - 1
        self.dec_attn_bias = attention_bias(tf.shape(self.masked_out_embed)[1], 'causal')
        self.all_att_weights, self.decoder_output_1 = transformer_decoder(
            self.decoder_input, self.encoder_output, self.dec_attn_bias,
            self.enc_attn_bias, self.hps, scope='decoder_1')
        # [b, l_t, e] => [b * l_t, v]
        self.decoder_output_1 = tf.reshape(self.decoder_output_1, [-1, hidden_size])
        self.vocab_logits = tf.matmul(self.decoder_output_1, self.decoder_weights,
                                      transpose_b=True)  # [b * l_t, v]
        self.vocab_probs = tf.nn.softmax(self.vocab_logits)  # [b * l_t, v]
        self.logits = self.vocab_probs
        self.pred_ids = tf.reshape(tf.argmax(self.logits, axis=-1), [self.batch_size, -1])
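# The causal bias used in decode() comes from this repo's `attention_bias`
# helper, which is defined elsewhere. Purely as an illustrative sketch (not the
# repo's implementation), a standard Transformer-style causal bias of shape
# [1, 1, l_t, l_t] can be built as below: each position may attend to itself
# and to earlier positions, while future positions receive a large negative
# value that vanishes after the softmax.
def _causal_attention_bias_sketch(length, neg_inf=-1e9):
    lower_triangle = tf.linalg.band_part(tf.ones([length, length]), -1, 0)  # 1 on/below the diagonal
    bias = neg_inf * (1.0 - lower_triangle)                                 # 0 where allowed, -1e9 where masked
    return tf.reshape(bias, [1, 1, length, length])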
def decode_infer_2(self):
    # Stage 2: word-level inference over the decoded (draft) sequence.
    # l_t = decoded sequence length.
    # At inference time the graph below is built for use with beam search.
    hidden_size = self.bert_config.hidden_size
    with self.graph.as_default():
        target_sequence = tf.squeeze(self.decode_seq, axis=1)
        # The draft may contain copied (extended-vocabulary) ids; map them back to UNK.
        target_sequence = self.trunct(target_sequence)
        target_length = self.decode_length
        target_seg_ids = tf.zeros_like(target_sequence, dtype=tf.int32,
                                       name='target_seg_ids_infer_2')
        tgt_mask = tf.sequence_mask(target_length,
                                    maxlen=tf.shape(target_sequence)[1],
                                    dtype=tf.float32)  # [b, q']
        is_training = self.is_training
        dec_model = modeling.BertModel(
            config=self.bert_config,
            is_training=is_training,
            input_ids=target_sequence,
            input_mask=tgt_mask,
            token_type_ids=target_seg_ids,
            scope='bert',
            reuse=tf.AUTO_REUSE,
            use_one_hot_embeddings=self.hps.use_tpu)  # use_one_hot_embeddings=Flags.tpu ?
        dec_output = dec_model.get_sequence_output()  # [b, l_t, h]
        tgt_embed = dec_output
        # NOTE: the manual embedding path (embedding_lookup + embedding_postprocessor,
        # as used in decode() and decode_2()) is intentionally disabled here;
        # BertModel above already embeds and encodes the target sequence.

        with tf.variable_scope('decoder_2', reuse=True):
            masked_tgt_embed = tgt_embed * tf.expand_dims(tgt_mask, -1)
            second_dec_attn_bias = attention_bias(tf.shape(masked_tgt_embed)[1], 'cloze_bias')
            # Shift the target embeddings right by one step.
            infer_decoder_input = tf.pad(masked_tgt_embed, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
            all_att_weights, decoder_output = transformer_decoder(
                infer_decoder_input, self.enc_output, second_dec_attn_bias,
                self.enc_attn_bias, self.hps, scope='decoder_2')
            # [b, l_t, e] => [b * l_t, v]
            decoder_output = tf.reshape(decoder_output, [-1, hidden_size])
            second_logits = tf.matmul(decoder_output, self.decoder_weights,
                                      transpose_b=True)  # [b * l_t, v]
            vocab_probs = tf.nn.softmax(second_logits)  # [b * l_t, v]
            vocab_size = len(self.hps.vocab)
            with tf.variable_scope('copy', reuse=tf.AUTO_REUSE):
                logits = calculate_final_logits(
                    decoder_output, all_att_weights, vocab_probs,
                    self.input_ids_oo, self.max_out_oovs, self.input_mask,
                    vocab_size, self.infer_tiled_len)  # [b * l_t, v + v']
            second_log_prob = tf.log(logits)  # [b * l_t, v + v']
            extend_vocab_size = tf.add(tf.constant(vocab_size), self.max_out_oovs)
            second_log_prob = tf.reshape(
                second_log_prob,
                [-1, tf.shape(target_sequence)[1], extend_vocab_size])  # [b, l_t, v + v']
            second_log_id = tf.argmax(second_log_prob, axis=-1)  # [b, l_t]
    return second_log_id
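# `calculate_final_logits` used in decode_infer_2 is this project's copy
# mechanism and is defined elsewhere in the repo. The function below is only a
# hedged sketch of a pointer-generator-style combination (the names `p_gen`,
# `attn`, and `src_ids_extended` are hypothetical, not this repo's API): the
# vocabulary distribution is padded to the extended vocabulary and mixed with
# attention mass scattered onto the source token ids.
def _copy_distribution_sketch(vocab_probs, attn, src_ids_extended, extended_vsize, p_gen):
    # vocab_probs: [N, v], attn: [N, l_s], src_ids_extended: [N, l_s], p_gen: [N, 1]
    vocab_part = p_gen * tf.pad(
        vocab_probs, [[0, 0], [0, extended_vsize - tf.shape(vocab_probs)[1]]])
    one_hot_src = tf.one_hot(src_ids_extended, depth=extended_vsize)        # [N, l_s, v + v']
    copy_part = (1.0 - p_gen) * tf.einsum('ns,nsv->nv', attn, one_hot_src)  # [N, v + v']
    return vocab_part + copy_part                                           # [N, v + v']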
def decode_2(self):
    config = self.bert_config
    hidden_size = self.encoder_output.shape[2].value
    # The draft may contain copied (extended-vocabulary) ids, so map them back to UNK first.
    draft = self.trunct(self.pred_ids)
    draft = tf.cast(draft, tf.int32)
    changed_ids = tf.concat([self.output_ids, draft], axis=-1)  # [b, 2 * l_t]
    change_segment_ids = tf.zeros_like(changed_ids, dtype=tf.int32, name='change_segment_ids')

    def calcu_id_len(input_tensor):
        # Length of each row = index of its first PAD (id 0) + 1. The small
        # increasing offset breaks ties so argmin picks the earliest zero.
        step_size = tf.constant(0.001)
        a = input_tensor
        res = tf.argmin(
            tf.cast(a, tf.float32) +
            tf.cast(tf.range(0, tf.shape(a)[-1]), tf.float32) * step_size, -1) + 1
        return res

    pred_ids_len = calcu_id_len(draft)  # [b,]
    pred_ids_mask_w_draft = tf.sequence_mask(
        pred_ids_len, maxlen=tf.shape(draft)[1], dtype=tf.float32)  # [b, l_t]
    pred_ids_mask_wo_draft = tf.zeros_like(draft, dtype=tf.float32)
    pred_ids_mask = tf.cond(self.feed_draft,
                            lambda: pred_ids_mask_w_draft,
                            lambda: pred_ids_mask_wo_draft)
    change_ids_mask = tf.concat([self.output_mask, pred_ids_mask], axis=-1)  # [b, 2 * l_t]
    transferred_mask = create_attention_mask_from_input_mask(
        changed_ids, change_ids_mask)  # [b, 2 * l_t, 2 * l_t]
    self.second_dec_attn_bias_w_draft = attention_bias(tf.shape(changed_ids)[1], 'mask_draft')
    self.second_dec_attn_bias_wo_draft = attention_bias(tf.shape(changed_ids)[1], 'mask_draft_warmup')
    self.second_dec_attn_bias = tf.cond(
        self.feed_draft,
        lambda: self.second_dec_attn_bias_w_draft,
        lambda: self.second_dec_attn_bias_wo_draft)  # [1, 1, 2 * l_t, 2 * l_t]
    self.second_dec_attn_bias = tf.tile(
        self.second_dec_attn_bias,
        [tf.shape(self.output_ids)[0], 1, 1, 1])  # [b, 1, 2 * l_t, 2 * l_t]
    self.second_dec_attn_bias = self.second_dec_attn_bias * tf.expand_dims(
        transferred_mask, 1)  # [b, 1, 2 * l_t, 2 * l_t]

    with tf.variable_scope('bert', reuse=True):
        with tf.variable_scope('embeddings'), tf.device('/cpu:0'):
            # Perform embedding lookup on the concatenated target + draft ids.
            (out_embed, bert_embeddings) = embedding_lookup(
                input_ids=changed_ids,
                vocab_size=config.vocab_size,
                embedding_size=config.hidden_size,
                initializer_range=config.initializer_range,
                word_embedding_name='word_embeddings',
                use_one_hot_embeddings=False)

            # Add positional and token-type embeddings, then apply layer
            # normalization and dropout.
            out_embed = embedding_postprocessor(
                input_tensor=out_embed,
                use_token_type=True,
                token_type_ids=change_segment_ids,
                token_type_vocab_size=config.type_vocab_size,
                token_type_embedding_name='token_type_embeddings',
                use_position_embeddings=True,
                position_embedding_name='position_embeddings',
                initializer_range=config.initializer_range,
                max_position_embeddings=config.max_position_embeddings,
                dropout_prob=config.hidden_dropout_prob)

    masked_out_embed = out_embed * tf.expand_dims(change_ids_mask, -1)
    # Shift the embeddings right by one step (teacher forcing).
    self.decoder_input = tf.pad(masked_out_embed, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]

    # ################################################### decoding train - 2
    with tf.variable_scope('decoder_2'):
        self.all_att_weights, self.decoder_output_2 = transformer_decoder(
            self.decoder_input, self.encoder_output,
            (1.0 - self.second_dec_attn_bias) * -1e9,  # convert the {0, 1} mask to an additive bias
            self.enc_attn_bias, self.hps, scope='decoder_2')
        # [b, 2 * l_t, e] => [b, l_t, e] => [b * l_t, v]
        target_len = tf.shape(self.output_ids)[1]
        # Keep only the ground-truth half for the attention weights and decoder output.
        self.all_att_weights[-1] = self.all_att_weights[-1][:, :target_len, :]  # [b, l_t, l_s]
        self.decoder_output_2 = self.decoder_output_2[:, :target_len, :]  # [b, l_t, e]
        self.decoder_output_2 = tf.reshape(self.decoder_output_2, [-1, hidden_size])
        self.second_logits = tf.matmul(self.decoder_output_2, self.decoder_weights,
                                       transpose_b=True)  # [b * l_t, v]
        self.vocab_probs_2 = tf.nn.softmax(self.second_logits)  # [b * l_t, v]
        self.second_logits = self.vocab_probs_2
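# `calcu_id_len` in decode_2 recovers each draft row's effective length by
# locating its first PAD (id 0); the tiny `range * 0.001` offset breaks ties so
# argmin lands on the earliest zero, and `+ 1` keeps that position inside the
# mask. Below is a more literal sketch (a hypothetical helper, assuming PAD id 0
# and ignoring the no-PAD edge case, which the argmin trick handles differently):
def _first_pad_length_sketch(ids):
    # ids: [b, l_t] int32; returns [b] lengths.
    is_pad = tf.cast(tf.equal(ids, 0), tf.int32)                  # 1 where PAD
    has_pad = tf.equal(tf.reduce_max(is_pad, axis=-1), 1)         # does the row contain a PAD?
    first_pad = tf.argmax(is_pad, axis=-1, output_type=tf.int32)  # index of the first PAD
    full_len = tf.fill(tf.shape(first_pad), tf.shape(ids)[-1])    # fallback: whole row
    return tf.where(has_pad, first_pad + 1, full_len)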