Example #1
    def trunct(self, seq):
        """Map out-of-vocabulary word ids in seq to the unknown-token id."""
        vocab_size = len(self.hps.vocab)
        # Truncate word ids: replace ids >= vocab_size with unkId (not zero).
        shape = seq.shape
        new_seq = tf_trunct(seq, vocab_size, self.hps.unkId)
        new_seq.set_shape(shape)  # restore the static shape lost in tf_trunct
        return new_seq
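
tf_trunct itself is defined elsewhere in this repository. A minimal sketch of the behavior its call sites imply (replace any id at or above vocab_size with the unknown-token id), assuming plain TensorFlow 1.x ops and hypothetical argument names:

import tensorflow as tf

def tf_trunct(seq, vocab_size, unk_id):
    # Sketch, not the repository's implementation: ids outside the
    # vocabulary (>= vocab_size) are replaced by the unknown-token id.
    seq = tf.cast(seq, tf.int32)
    oov_mask = tf.greater_equal(seq, vocab_size)  # True where the id is OOV
    unk_ids = tf.fill(tf.shape(seq), tf.cast(unk_id, tf.int32))
    return tf.where(oov_mask, unk_ids, seq)       # pick unk_id in OOV slots

The callers then restore the static shape with set_shape, as in the example above.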
Example #2
    def decode_infer(self, inputs, state):
        # state['enc']: [b * beam, l_s, e], state['dec']: [b * beam, q', e]
        # q' = length of the decoder output so far
        # During inference, the following graph is constructed for beam-search decoding.
        with self.graph.as_default():
            config = self.bert_config

            target_sequence = inputs['target']  # [b * beam, q']
            vocab_size = len(self.hps.vocab_out)
            # Truncate word ids: map ids >= vocab_size to the unknown-token id.
            shape = target_sequence.shape
            unkid = self.hps.vocab_out[self.hps.unk]
            target_sequence = tf_trunct(target_sequence, vocab_size, unkid)
            target_sequence.set_shape(shape)  # restore the static shape lost in tf_trunct

            target_length = inputs['target_length']
            target_seg_ids = tf.zeros_like(target_sequence,
                                           dtype=tf.int32,
                                           name='target_seg_ids_infer')
            tgt_mask = tf.sequence_mask(target_length,
                                        maxlen=tf.shape(target_sequence)[1],
                                        dtype=tf.float32)  # [b, q']

            out_dict_size = len(self.hps.vocab_out)
            with tf.variable_scope('bert', reuse=True):
                with tf.variable_scope('embeddings'), tf.device('/cpu:0'):
                    # Perform embedding lookup on the target word ids.
                    (tgt_embed, _) = embedding_lookup(
                        input_ids=target_sequence,
                        vocab_size=out_dict_size,  # out vocab size
                        embedding_size=config.hidden_size,
                        initializer_range=config.initializer_range,
                        word_embedding_name='word_embeddings',
                        use_one_hot_embeddings=False)

                    # Add positional embeddings and token type embeddings, then layer
                    # normalize and perform dropout.
                    tgt_embed = embedding_postprocessor(
                        input_tensor=tgt_embed,
                        use_token_type=True,
                        token_type_ids=target_seg_ids,
                        token_type_vocab_size=config.type_vocab_size,
                        token_type_embedding_name='token_type_embeddings',
                        use_position_embeddings=True,
                        position_embedding_name='position_embeddings',
                        initializer_range=config.initializer_range,
                        max_position_embeddings=config.max_position_embeddings,
                        dropout_prob=config.hidden_dropout_prob)

            with tf.variable_scope('decode', reuse=True):
                # [b, q', e]
                masked_tgt_embed = tgt_embed * tf.expand_dims(tgt_mask, -1)
                dec_attn_bias = attention_bias(
                    tf.shape(masked_tgt_embed)[1], "causal")
                decoder_input = tf.pad(
                    masked_tgt_embed,
                    [[0, 0], [1, 0], [0, 0]])[:, :-1, :]  # shift targets right by one step

                infer_decoder_input = decoder_input[:, -1:, :]
                infer_dec_attn_bias = dec_attn_bias[:, :, -1:, :]

                ret = transformer_decoder_three(infer_decoder_input,
                                                self.enc_output,
                                                self.topic_memory,
                                                infer_dec_attn_bias,
                                                self.enc_attn_bias,
                                                self.topic_attn_bias,
                                                self.hps,
                                                state=state['decoder'])

                all_att_weights1, all_att_weights2, decoder_output, decoder_state = ret
                decoder_output = decoder_output[:, -1, :]  # [b * beam, e]
                vocab_logits = tf.matmul(decoder_output, self.decoder_weights,
                                         transpose_b=True)  # [b * beam, v]
                vocab_probs = tf.nn.softmax(vocab_logits)
                vocab_size = out_dict_size  # output vocab size
                # the OOV source ids were already tiled to b * beam before being
                # fed, so the tile-count argument below is set to 1
                with tf.variable_scope('copy'):
                    logits = calculate_two_copy_logits(
                        decoder_output, all_att_weights1, vocab_probs,
                        self.input_ids_oo, self.max_out_oovs, self.input_mask,
                        vocab_size, 1, all_att_weights2, self.topic_words_ids,
                        self.topic_words_mask)
                log_prob = tf.log(logits)  # [b * beam, v + v']
        return log_prob, {
            'encoder': state['encoder'],
            'decoder': decoder_state
        }
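
attention_bias here comes from the Transformer decoder utilities this model builds on (THUMT/tensor2tensor style). A minimal sketch of the conventional "causal" variant, under the assumption that it returns an additive mask of shape [1, 1, q', q'] with zeros on and below the diagonal and a large negative value above it:

import tensorflow as tf

def causal_attention_bias(length, neg_inf=-1e9):
    # Sketch of a standard causal bias: once added to the attention logits,
    # each position can attend only to itself and to earlier positions.
    lower_triangle = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
    bias = neg_inf * (1.0 - lower_triangle)
    return tf.reshape(bias, [1, 1, length, length])

Slicing this bias with [:, :, -1:, :], as decode_infer does, keeps only the row for the newest decode position, which is all that incremental beam-search decoding needs.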