    def decode(self):
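        """Stage-1 (draft) decoding at training time: embed the shifted target
        sequence with the shared BERT embeddings and run a causal Transformer
        decoder over the encoder output."""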
        config = self.bert_config
        hidden_size = self.encoder_output.shape[2].value
        with tf.variable_scope('bert', reuse=True):
            with tf.variable_scope('embeddings'), tf.device('/cpu:0'):
                # Perform embedding lookup on the target word ids.
                (self.out_embed, self.bert_embeddings) = embedding_lookup(
                    input_ids=self.output_ids,
                    vocab_size=config.vocab_size,
                    embedding_size=config.hidden_size,
                    initializer_range=config.initializer_range,
                    word_embedding_name='word_embeddings',
                    use_one_hot_embeddings=False)

                # Add positional embeddings and token type embeddings, then layer
                # normalize and perform dropout.
                self.out_embed = embedding_postprocessor(
                    input_tensor=self.out_embed,
                    use_token_type=True,
                    token_type_ids=self.out_segment_ids,
                    token_type_vocab_size=config.type_vocab_size,
                    token_type_embedding_name='token_type_embeddings',
                    use_position_embeddings=True,
                    position_embedding_name='position_embeddings',
                    initializer_range=config.initializer_range,
                    max_position_embeddings=config.max_position_embeddings,
                    dropout_prob=config.hidden_dropout_prob)

        with tf.variable_scope('decoder_1'):
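            # Tie the output projection to the shared BERT word-embedding matrix.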
            self.decoder_weights = self.bert_embeddings
            self.masked_out_embed = self.out_embed * tf.expand_dims(self.output_mask, -1)
            self.decoder_input = tf.pad(self.masked_out_embed, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]  # shift right by one step (prepend a zero "start" embedding)
            # ################################################### decoding train - 1
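            # Causal (lower-triangular) bias so each position attends only to earlier positions.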
            self.dec_attn_bias = attention_bias(tf.shape(self.masked_out_embed)[1], 'causal')
            self.all_att_weights, self.decoder_output_1 = transformer_decoder(self.decoder_input, self.encoder_output,
                                                                              self.dec_attn_bias, self.enc_attn_bias,
                                                                              self.hps, scope='decoder_1')
            # [b, l_t, e] => [b * l_t, e]
            self.decoder_output_1 = tf.reshape(self.decoder_output_1, [-1, hidden_size])
            # Project onto the tied embedding matrix: [b * l_t, e] x [v, e]^T => [b * l_t, v]
            self.vocab_logits = tf.matmul(self.decoder_output_1, self.decoder_weights, transpose_b=True)
            self.vocab_probs = tf.nn.softmax(self.vocab_logits)  # [b * l_t, v]
            self.logits = self.vocab_probs  # note: already softmaxed, despite the name
            self.pred_ids = tf.reshape(tf.argmax(self.logits, axis=-1), [self.batch_size, -1])  # greedy draft ids [b, l_t]
    def decode_infer_2(self):
        # Stage 2: word-level inference over the decoded (draft) sequence.
        # l_t = decoded sequence length.
        # During inference, the graph below is built for use with beam search.
        hidden_size = self.bert_config.hidden_size
        with self.graph.as_default():
            target_sequence = tf.squeeze(self.decode_seq, axis=1)
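            # Map copy-extended (OOV) ids in the draft back to UNK before feeding them to BERT.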
            target_sequence = self.trunct(target_sequence)
            target_length = self.decode_length
            target_seg_ids = tf.zeros_like(target_sequence,
                                           dtype=tf.int32,
                                           name='target_seg_ids_infer_2')
            tgt_mask = tf.sequence_mask(target_length,
                                        maxlen=tf.shape(target_sequence)[1],
                                        dtype=tf.float32)  # [b, q']

            is_training = self.is_training
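            # Re-encode the draft with the shared BERT model (scope 'bert') to get contextual target embeddings.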
            dec_model = modeling.BertModel(
                config=self.bert_config,
                is_training=is_training,
                input_ids=target_sequence,
                input_mask=tgt_mask,
                token_type_ids=target_seg_ids,
                scope='bert',
                reuse=tf.AUTO_REUSE,
                use_one_hot_embeddings=self.hps.use_tpu
            )  # use_one_hot_embeddings=Flags.tpu ?

            dec_output = dec_model.get_sequence_output()  # [b, l_t, h]
            tgt_embed = dec_output

            # with tf.variable_scope('bert', reuse=True):
            #     with tf.variable_scope('embeddings'), tf.device('/cpu:0'):
            #         # Perform embedding lookup on the target word ids.
            #         (tgt_embed, _) = embedding_lookup(
            #             input_ids=target_sequence,
            #             vocab_size=config.vocab_size,
            #             embedding_size=config.hidden_size,
            #             initializer_range=config.initializer_range,
            #             word_embedding_name='word_embeddings',
            #             use_one_hot_embeddings=False)
            #
            #         # Add positional embeddings and token type embeddings, then layer
            #         # normalize and perform dropout.
            #         tgt_embed = embedding_postprocessor(
            #             input_tensor=tgt_embed,
            #             use_token_type=True,
            #             token_type_ids=target_seg_ids,
            #             token_type_vocab_size=config.type_vocab_size,
            #             token_type_embedding_name='token_type_embeddings',
            #             use_position_embeddings=True,
            #             position_embedding_name='position_embeddings',
            #             initializer_range=config.initializer_range,
            #             max_position_embeddings=config.max_position_embeddings,
            #             dropout_prob=config.hidden_dropout_prob)

            with tf.variable_scope('decoder_2', reuse=True):
                masked_tgt_embed = tgt_embed * tf.expand_dims(tgt_mask, -1)
                second_dec_attn_bias = attention_bias(
                    tf.shape(masked_tgt_embed)[1], 'cloze_bias')
                infer_decoder_input = tf.pad(
                    masked_tgt_embed,
                    [[0, 0], [1, 0], [0, 0]])[:, :-1, :]  # shift right by one step (prepend a zero "start" embedding)
                all_att_weights, decoder_output = transformer_decoder(
                    infer_decoder_input,
                    self.enc_output,
                    second_dec_attn_bias,
                    self.enc_attn_bias,
                    self.hps,
                    scope='decoder_2')
                # [b, l_t, e] => [b * l_t, e]
                decoder_output = tf.reshape(decoder_output, [-1, hidden_size])
                # Project onto the tied embedding matrix: [b * l_t, e] x [v, e]^T => [b * l_t, v]
                second_logits = tf.matmul(decoder_output, self.decoder_weights,
                                          transpose_b=True)
                vocab_probs = tf.nn.softmax(second_logits)  # [b * l_t, v]
                vocab_size = len(self.hps.vocab)
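                # Copy mechanism: mix the vocabulary distribution with the attention weights
                # over the source so source tokens (including per-example OOVs) can be copied,
                # extending the output vocabulary from v to v + v'.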
                with tf.variable_scope('copy', reuse=tf.AUTO_REUSE):
                    logits = calculate_final_logits(
                        decoder_output, all_att_weights, vocab_probs,
                        self.input_ids_oo, self.max_out_oovs, self.input_mask,
                        vocab_size, self.infer_tiled_len)  # [b * l_t, v + v']
                second_log_prob = tf.log(logits)  # log-probabilities over the extended vocabulary
                # (b, l_t, v + v')
                extend_vocab_size = tf.add(tf.constant(vocab_size),
                                           self.max_out_oovs)
                second_log_prob = tf.reshape(
                    second_log_prob,
                    [-1, tf.shape(target_sequence)[1], extend_vocab_size])
                second_log_id = tf.argmax(second_log_prob, axis=-1)  # (b, l_t)
        return second_log_id

    def decode_2(self):
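        """Stage-2 decoding at training time: concatenate the ground-truth targets
        with the (UNK-mapped) draft, re-embed, and decode with a draft-aware
        attention mask; only the ground-truth half of the decoder output is kept."""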
        config = self.bert_config
        hidden_size = self.encoder_output.shape[2].value
        draft = self.trunct(self.pred_ids)  # as the draft may have copy words, we transform them to UNK first
        draft = tf.cast(draft, tf.int32)
        changed_ids = tf.concat([self.output_ids, draft], axis=-1)  # [b, 2 * l_t]
        change_segment_ids = tf.zeros_like(changed_ids, dtype=tf.int32, name='change_segment_ids')

        def calcu_id_len(input_tensor):
            # Estimate each draft's length: argmin finds the first occurrence of the
            # smallest id (assumed to be the pad id 0); the tiny increasing offset
            # breaks ties in favour of the earliest position.
            step_size = tf.constant(0.001)
            a = input_tensor
            res = tf.argmin(tf.cast(a, tf.float32) + tf.cast(tf.range(0, tf.shape(a)[-1]), tf.float32) * step_size,
                            -1) + 1
            return res

        pred_ids_len = calcu_id_len(draft)  # [b,]
        pred_ids_mask_w_draft = tf.sequence_mask(pred_ids_len,
                                                 maxlen=tf.shape(draft)[1],
                                                 dtype=tf.float32)  # [b, l_t]
        pred_ids_mask_wo_draft = tf.zeros_like(draft, dtype=tf.float32)
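        # feed_draft switches between conditioning on the draft and masking it out entirely (warm-up).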
        pred_ids_mask = tf.cond(self.feed_draft, lambda: pred_ids_mask_w_draft,
                                lambda: pred_ids_mask_wo_draft)
        change_ids_mask = tf.concat([self.output_mask, pred_ids_mask], axis=-1)  # [b, 2 * l_t]

        transferred_mask = create_attention_mask_from_input_mask(changed_ids, change_ids_mask)  # [b, 2 * l_t, 2 * l_t]

        self.second_dec_attn_bias_w_draft = attention_bias(tf.shape(changed_ids)[1], 'mask_draft')
        self.second_dec_attn_bias_wo_draft = attention_bias(tf.shape(changed_ids)[1], 'mask_draft_warmup')
        self.second_dec_attn_bias = tf.cond(self.feed_draft, lambda: self.second_dec_attn_bias_w_draft,
                                            lambda: self.second_dec_attn_bias_wo_draft)  # [1, 1, 2 * l_t, 2 *l_t]
        self.second_dec_attn_bias = tf.tile(self.second_dec_attn_bias,
                                            [tf.shape(self.output_ids)[0], 1, 1, 1])  # [b, 1, 2 * l_t, 2 * l_t]

        self.second_dec_attn_bias = self.second_dec_attn_bias * tf.expand_dims(transferred_mask,
                                                                               1)  # [b, 1, 2 * l_t, 2 * l_t]

        with tf.variable_scope('bert', reuse=True):
            with tf.variable_scope('embeddings'), tf.device('/cpu:0'):
                # Perform embedding lookup on the target word ids.
                (out_embed, bert_embeddings) = embedding_lookup(
                    input_ids=changed_ids,
                    vocab_size=config.vocab_size,
                    embedding_size=config.hidden_size,
                    initializer_range=config.initializer_range,
                    word_embedding_name='word_embeddings',
                    use_one_hot_embeddings=False)

                # Add positional embeddings and token type embeddings, then layer
                # normalize and perform dropout.
                out_embed = embedding_postprocessor(
                    input_tensor=out_embed,
                    use_token_type=True,
                    token_type_ids=change_segment_ids,
                    token_type_vocab_size=config.type_vocab_size,
                    token_type_embedding_name='token_type_embeddings',
                    use_position_embeddings=True,
                    position_embedding_name='position_embeddings',
                    initializer_range=config.initializer_range,
                    max_position_embeddings=config.max_position_embeddings,
                    dropout_prob=config.hidden_dropout_prob)

        masked_out_embed = out_embed * tf.expand_dims(change_ids_mask, -1)
        self.decoder_input = tf.pad(masked_out_embed, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]  # shift right by one step (prepend a zero "start" embedding)
        # ################################################### decoding train - 2
        with tf.variable_scope('decoder_2'):
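            # Convert the {0, 1} visibility mask into an additive bias: 0 where attention is allowed, -1e9 where blocked.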
            self.all_att_weights, self.decoder_output_2 = transformer_decoder(self.decoder_input, self.encoder_output,
                                                                              (1.0 - self.second_dec_attn_bias) * -1e9,
                                                                              self.enc_attn_bias,
                                                                              self.hps, scope='decoder_2')
            # [b, 2 * l_t, e] => [b, l_t, e] => [b * l_t, e] => [b * l_t, v]
            target_len = tf.shape(self.output_ids)[1]
            # Keep only the ground-truth half for the attention weights and decoder output.
            self.all_att_weights[-1] = self.all_att_weights[-1][:, :target_len, :]  # [b, l_t, l_s]
            self.decoder_output_2 = self.decoder_output_2[:, :target_len, :]  # [b, l_t, e]
            self.decoder_output_2 = tf.reshape(self.decoder_output_2, [-1, hidden_size])
            self.second_logits = tf.matmul(self.decoder_output_2, self.decoder_weights, transpose_b=True)  # (b*l_t, v)
            self.vocab_probs_2 = tf.nn.softmax(self.second_logits)  # [b * l_t, v]
            self.second_logits = self.vocab_probs_2  # note: already softmaxed, despite the name