def trunct(self, seq):
    vocab_size = len(self.hps.vocab)
    # Truncate word ids: map ids outside the vocabulary to unkId.
    shape = seq.shape
    new_seq = tf_trunct(seq, vocab_size, self.hps.unkId)
    new_seq.set_shape(shape)
    return new_seq
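# --- Hedged sketch (not part of the original file) --------------------------
# `tf_trunct` is imported from elsewhere in this repo. Based on how it is
# called above (seq, vocab_size, unk_id) and the surrounding comments, a
# minimal TF1 implementation would replace any id outside the vocabulary
# with the unknown-token id. The name `tf_trunct_sketch` is illustrative.
def tf_trunct_sketch(seq, vocab_size, unk_id):
    unk = tf.fill(tf.shape(seq), unk_id)          # tensor of unk ids, same shape as seq
    return tf.where(seq >= vocab_size, unk, seq)  # keep in-vocab ids unchanged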
def decode_infer(self, inputs, state):
    # state['encoder']: [b * beam, l_s, e], state['decoder']: [b * beam, q', e]
    # q' = length of the output decoded so far
    # During inference, the following graph is constructed for beam search.
    with self.graph.as_default():
        config = self.bert_config
        target_sequence = inputs['target']  # [b * beam, q']
        vocab_size = len(self.hps.vocab_out)
        # Truncate word ids: map ids outside the output vocabulary to unkId.
        shape = target_sequence.shape
        unkid = self.hps.vocab_out[self.hps.unk]
        target_sequence = tf_trunct(target_sequence, vocab_size, unkid)
        target_sequence.set_shape(shape)
        target_length = inputs['target_length']
        target_seg_ids = tf.zeros_like(target_sequence, dtype=tf.int32,
                                       name='target_seg_ids_infer')
        tgt_mask = tf.sequence_mask(target_length,
                                    maxlen=tf.shape(target_sequence)[1],
                                    dtype=tf.float32)  # [b, q']
        out_dict_size = len(self.hps.vocab_out)
        with tf.variable_scope('bert', reuse=True):
            with tf.variable_scope('embeddings'), tf.device('/cpu:0'):
                # Perform embedding lookup on the target word ids.
                (tgt_embed, _) = embedding_lookup(
                    input_ids=target_sequence,
                    vocab_size=out_dict_size,  # output vocab size
                    embedding_size=config.hidden_size,
                    initializer_range=config.initializer_range,
                    word_embedding_name='word_embeddings',
                    use_one_hot_embeddings=False)
                # Add positional embeddings and token type embeddings, then
                # layer normalize and perform dropout.
                tgt_embed = embedding_postprocessor(
                    input_tensor=tgt_embed,
                    use_token_type=True,
                    token_type_ids=target_seg_ids,
                    token_type_vocab_size=config.type_vocab_size,
                    token_type_embedding_name='token_type_embeddings',
                    use_position_embeddings=True,
                    position_embedding_name='position_embeddings',
                    initializer_range=config.initializer_range,
                    max_position_embeddings=config.max_position_embeddings,
                    dropout_prob=config.hidden_dropout_prob)
        with tf.variable_scope('decode', reuse=True):
            # [b, q', e]
            masked_tgt_embed = tgt_embed * tf.expand_dims(tgt_mask, -1)
            dec_attn_bias = attention_bias(
                tf.shape(masked_tgt_embed)[1], "causal")
            # Shift right: prepend a zero embedding and drop the last position.
            decoder_input = tf.pad(
                masked_tgt_embed, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
            # Only the newest position is fed at each beam-search step.
            infer_decoder_input = decoder_input[:, -1:, :]
            infer_dec_attn_bias = dec_attn_bias[:, :, -1:, :]
            ret = transformer_decoder_three(infer_decoder_input,
                                            self.enc_output,
                                            self.topic_memory,
                                            infer_dec_attn_bias,
                                            self.enc_attn_bias,
                                            self.topic_attn_bias,
                                            self.hps,
                                            state=state['decoder'])
            all_att_weights1, all_att_weights2, decoder_output, decoder_state = ret
            decoder_output = decoder_output[:, -1, :]  # [b * beam, e]
            vocab_logits = tf.matmul(decoder_output, self.decoder_weights,
                                     transpose_b=True)  # [b * beam, v]
            vocab_probs = tf.nn.softmax(vocab_logits)
            vocab_size = out_dict_size  # output vocab size
            # source_id_oo was tiled before feeding, so the tile argument is 1.
            with tf.variable_scope('copy'):
                logits = calculate_two_copy_logits(
                    decoder_output, all_att_weights1, vocab_probs,
                    self.input_ids_oo, self.max_out_oovs, self.input_mask,
                    vocab_size, 1, all_att_weights2,
                    self.topic_words_ids, self.topic_words_mask)
            log_prob = tf.log(logits)  # [b * beam, v + v']
        return log_prob, {'encoder': state['encoder'], 'decoder': decoder_state}
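# --- Hedged usage sketch (not part of the original file) ---------------------
# During beam search, `decode_infer` is typically wrapped as the per-step
# scoring function: each call receives the tokens decoded so far plus the
# cached encoder/decoder state, and returns log-probabilities over the
# extended vocabulary (base vocab + copied OOVs). All names below except
# `decode_infer` are illustrative assumptions, not the repo's actual API.
def step_fn(self, decoded_so_far, decoded_lengths, state):
    inputs = {
        'target': decoded_so_far,         # [b * beam, q'] ids decoded so far
        'target_length': decoded_lengths  # [b * beam] current output lengths
    }
    log_prob, next_state = self.decode_infer(inputs, state)
    return log_prob, next_state          # [b * beam, v + v'], updated cache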