Example #1
import torch as t

# Embedding, TransformerDecoder, Masker and BeamSteper are project-local
# modules; their imports are omitted in this excerpt.


class TokenDecoder(t.nn.Module):
    """
    Token-level Transformer decoder.
    """
    def __init__(self, input_size, feed_forward_size, hidden_size, dropout, num_head, num_layer, vocab_size,
                 padding_idx, max_length=50, share_weight=True, bos_id=3, eos_id=4, use_low_rank=False):
        super(TokenDecoder, self).__init__()
        self.max_length = max_length
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.vocab_size = vocab_size
        self.embedding = Embedding(
            vocab_size, input_size, padding_idx, max_length, scale_word_embedding=share_weight)
        self.transformer_decoder = TransformerDecoder(
            input_size, feed_forward_size, hidden_size, dropout, num_head, num_layer, use_low_rank)
        self.output_layer = t.nn.Linear(input_size, vocab_size, bias=False)
        self.switch_layer = t.nn.Linear(input_size, 4, bias=False)
        t.nn.init.xavier_normal_(self.switch_layer.weight)
        if share_weight:
            self.output_layer.weight = self.embedding.word_embedding.weight
        else:
            t.nn.init.xavier_normal_(self.output_layer.weight)

    def forward(self, token_id, encoder_output, token_mask, self_attention_mask, dot_attention_mask):
        net = self.embedding(token_id)
        net.masked_fill_(~token_mask.unsqueeze(-1), 0.0)
        net = self.transformer_decoder(net, token_mask, encoder_output, self_attention_mask, dot_attention_mask)
        switch = self.switch_layer(net)
        net = self.output_layer(net)
        return net, switch

    def forward_one_step(self, token_id, encoder_output, token_mask, self_attention_mask, dot_attention_mask):
        net = self.embedding(token_id)
        net.masked_fill_(~token_mask.unsqueeze(-1), 0.0)
        net = self.transformer_decoder(net, token_mask, encoder_output, self_attention_mask, dot_attention_mask)
        net = self.output_layer(net)[:, -1:, :]
        return net

    def beam_search_decode(self, encoder_output, feature_mask, beam_size, best_k=5, lp_eps=0.0):
        batch_size, feature_length, _ = encoder_output.size()
        device = encoder_output.device
        self.beam_steper = BeamSteper(
            batch_size=batch_size, beam_size=beam_size, bos_id=self.bos_id, eos_id=self.eos_id,
            vocab_size=self.vocab_size, device=device, k_best=best_k, lp_eps=lp_eps
        )

        # tile the feature mask and encoder output across the beam dimension:
        # (B, ...) -> (B * beam_size, ...)
        beam_feature_mask = feature_mask.unsqueeze(1).repeat(1, beam_size, 1).view(batch_size*beam_size, -1)
        beam_encoder_output = encoder_output.unsqueeze(1).repeat(1, beam_size, 1, 1).view(batch_size*beam_size, feature_length, -1)
        with t.no_grad():
            for i in range(self.max_length):
                if i == 0:
                    # first step: every beam starts from the same <bos> token,
                    # so the decoder runs on the un-tiled encoder output
                    token_id = self.beam_steper.get_first_step_token()
                    length = self.beam_steper.get_first_step_length()
                    token_mask = Masker.get_mask(length)
                    self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
                    self_attention_mask = Masker.get_forward_mask(self_attention_mask)
                    dot_attention_mask = Masker.get_dot_mask(token_mask, feature_mask)
                    last_prob = self.beam_decode_step(
                        token_id, encoder_output, token_mask, self_attention_mask, dot_attention_mask)
                    if_continue = self.beam_steper.first_step(last_prob)
                    if not if_continue:
                        break
                else:
                    # later steps: decode all beams of all batch items in
                    # parallel against the beam-tiled encoder output
                    token_id = self.beam_steper.token_container
                    token_id = token_id.view(batch_size * beam_size, -1)
                    length = self.beam_steper.length_container
                    length = length.view(batch_size * beam_size)
                    token_mask = Masker.get_mask(length)
                    self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
                    self_attention_mask = Masker.get_forward_mask(self_attention_mask)
                    dot_attention_mask = Masker.get_dot_mask(token_mask, beam_feature_mask)
                    last_prob = self.beam_decode_step(
                        token_id, beam_encoder_output, token_mask, self_attention_mask, dot_attention_mask
                    )
                    if_continue = self.beam_steper.step(last_prob.view(batch_size, beam_size, -1))
                    if not if_continue:
                        break
        output_token = self.beam_steper.batch_best_saver.batch
        return output_token

    def greedy_decode(self, encoder_output, feature_mask):
        """
        Batched greedy decode.
        """
        batch_size = encoder_output.size(0)
        device = encoder_output.device
        token_id = t.full((batch_size, 1), fill_value=self.bos_id, dtype=t.long, device=device)
        length = t.LongTensor([1] * batch_size).to(device)
        with t.no_grad():
            for i in range(self.max_length):
                try:
                    token_mask = Masker.get_mask(length)
                    self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
                    self_attention_mask = Masker.get_forward_mask(self_attention_mask)
                    dot_attention_mask = Masker.get_dot_mask(token_mask, feature_mask)
                    last_prob, last_token_id = self.decode_step(
                        token_id, encoder_output, token_mask, self_attention_mask, dot_attention_mask,
                        topk=1, return_last=True)
                    token_id = t.cat([token_id, last_token_id], dim=1)
                    # only grow the valid length of sequences that have not emitted <eos> yet
                    for index, token in enumerate(last_token_id.squeeze(1)):
                        if token != self.eos_id:
                            length[index] += 1
                except Exception:
                    # TODO: terminate the loop explicitly instead of relying on an exception
                    break
        return token_id

    def beam_decode_step(self, token_id, encoder_output, token_mask, self_attention_mask, dot_attention_mask):
        net, _ = self.forward(token_id, encoder_output, token_mask, self_attention_mask, dot_attention_mask)
        net = t.nn.functional.log_softmax(net, -1)
        # log-probabilities over the vocabulary for the last decoded position: (B, V)
        return net[:, -1, :]

    def decode_step(self, token_id, encoder_output, token_mask, self_attention_mask, dot_attention_mask, topk=1,
                    return_last=True):
        # token_id:            (B, Lt)
        # encoder_output:      (B, Lf, H)
        # token_mask:          (B, Lt)
        # self_attention_mask: (B, 1, Lt)
        # dot_attention_mask:  (B, 1, Lf)
        net, _ = self.forward(token_id, encoder_output, token_mask, self_attention_mask, dot_attention_mask)
        net = t.nn.functional.log_softmax(net, -1)
        # top-k log-probabilities and their token indices at every position
        probs, indices = t.topk(net, topk)
        if return_last:
            return probs[:, -1, :], indices[:, -1, :]
        else:
            return probs, indices
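
The decoders above rely on a project-local Masker helper whose implementation is not part of this excerpt. As a rough illustration only, the stand-ins below sketch what get_mask, get_dot_mask and get_forward_mask could look like, with shapes inferred from the call sites; the real Masker in the project may differ.

import torch as t


def get_mask(length, max_len=None):
    # (B,) valid lengths -> (B, Lt) boolean padding mask (hypothetical stand-in)
    max_len = max_len if max_len is not None else int(length.max())
    positions = t.arange(max_len, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)


def get_dot_mask(query_mask, key_mask):
    # (B, Lq) x (B, Lk) -> (B, Lq, Lk) attention mask (hypothetical stand-in)
    return query_mask.unsqueeze(-1) & key_mask.unsqueeze(1)


def get_forward_mask(self_attention_mask):
    # add a causal (lower-triangular) constraint on top of a self-attention mask
    size = self_attention_mask.size(-1)
    causal = t.tril(t.ones(size, size, device=self_attention_mask.device)).bool()
    return self_attention_mask & causal


# quick check with dummy lengths
lengths = t.tensor([3, 1])
token_mask = get_mask(lengths)                                                 # (2, 3)
self_attention_mask = get_forward_mask(get_dot_mask(token_mask, token_mask))   # (2, 3, 3)
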
Example #4
import torch as t

# Embedding, TransformerDecoder, Masker and BeamSteper are project-local
# modules; their imports are omitted in this excerpt.


class TokenDecoder(t.nn.Module):
    """
    Token-level Transformer decoder.
    """
    def __init__(self,
                 input_size,
                 feed_forward_size,
                 hidden_size,
                 dropout,
                 num_head,
                 num_layer,
                 vocab_size,
                 padding_idx,
                 max_length=50,
                 share_weight=True,
                 bos_id=3,
                 eos_id=4):
        super(TokenDecoder, self).__init__()
        self.max_length = max_length
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.vocab_size = vocab_size
        self.embedding = Embedding(vocab_size,
                                   input_size,
                                   padding_idx,
                                   max_length,
                                   scale_word_embedding=share_weight)
        self.transformer_decoder = TransformerDecoder(input_size,
                                                      feed_forward_size,
                                                      hidden_size, dropout,
                                                      num_head, num_layer)
        # NOTE: 1 / (input_size ** -0.5) evaluates to sqrt(input_size), which is
        # an unusually large eps for LayerNorm (typical values are around 1e-5)
        self.layer_norm = t.nn.LayerNorm(input_size,
                                         eps=1 / (input_size**-0.5))
        self.output_layer = t.nn.Linear(input_size, vocab_size, bias=False)
        if share_weight:
            self.output_layer.weight = self.embedding.word_embedding.weight
        else:
            t.nn.init.xavier_normal_(self.output_layer.weight)

    def forward(self, token_id, encoder_output, token_mask,
                self_attention_mask, dot_attention_mask):
        net = self.embedding(token_id)
        net.masked_fill_(token_mask.unsqueeze(-1) == 0, 0.0)
        net = self.transformer_decoder(net, token_mask.unsqueeze(-1),
                                       encoder_output, self_attention_mask,
                                       dot_attention_mask)
        net = self.layer_norm(net)
        net = self.output_layer(net)
        return net

    def beam_search_decode(self, encoder_output, dot_attention_mask,
                           beam_size):
        batch_size = encoder_output.size(0)
        device = encoder_output.device
        feature_length = encoder_output.size(1)

        self.beam_steper = BeamSteper(batch_size, beam_size, self.bos_id,
                                      self.eos_id, self.vocab_size, device)
        # tile the encoder output and its attention mask across the beam
        # dimension: (B, ...) -> (B * beam_size, ...)
        encoder_output = encoder_output.unsqueeze(1).repeat(
            1, beam_size, 1, 1).view(batch_size * beam_size, feature_length,
                                     -1)
        dot_attention_mask = dot_attention_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1).view(batch_size * beam_size, 1, feature_length)
        with t.no_grad():
            for i in range(self.max_length):
                try:
                    length = self.beam_steper.length_container
                    token_mask = Masker.get_mask(length)
                    self_attention_mask = Masker.get_dot_mask(
                        token_mask, token_mask)

                    token_id = self.beam_steper.get_first_step_token(
                    ) if i == 0 else self.beam_steper.token_container.view(
                        batch_size * beam_size, -1)
                    # beam_decode_step is expected to return the top-k
                    # log-probabilities for the last position (it is not
                    # defined in this excerpt of the class)
                    last_prob = self.beam_decode_step(token_id,
                                                      encoder_output,
                                                      token_mask,
                                                      self_attention_mask,
                                                      dot_attention_mask,
                                                      topk=beam_size,
                                                      return_last=True)
                    self.beam_steper.step(last_prob)
                except Exception:
                    # decoding stops once a step can no longer proceed
                    break
        return self.beam_steper.batch_best_saver

    def greedy_decode(self, encoder_output, dot_attention_mask):
        """
        Batched greedy decode.
        """
        batch_size = encoder_output.size(0)
        device = encoder_output.device
        token_id = t.full((batch_size, 1),
                          fill_value=self.bos_id,
                          dtype=t.long,
                          device=device)
        length = t.LongTensor([1] * batch_size).to(device)
        with t.no_grad():
            for i in range(self.max_length):
                try:
                    token_mask = Masker.get_mask(length)
                    self_attention_mask = Masker.get_dot_mask(
                        token_mask, token_mask)
                    last_prob, last_token_id = self.decode_step(
                        token_id,
                        encoder_output,
                        token_mask,
                        self_attention_mask,
                        dot_attention_mask,
                        topk=1,
                        return_last=True)
                    token_id = t.cat([token_id, last_token_id], dim=1)
                    # only grow the valid length of sequences that have not
                    # emitted <eos> yet
                    for index, token in enumerate(last_token_id.squeeze(1)):
                        if token != self.eos_id:
                            length[index] += 1
                except Exception:
                    # TODO: terminate the loop explicitly instead of relying
                    # on an exception
                    break
        return token_id

    def decode_step(self,
                    token_id,
                    encoder_output,
                    token_mask,
                    self_attention_mask,
                    dot_attention_mask,
                    topk=1,
                    return_last=True):
        # token_id:            (B, Lt)
        # encoder_output:      (B, Lf, H)
        # token_mask:          (B, Lt)
        # self_attention_mask: (B, 1, Lt)
        # dot_attention_mask:  (B, 1, Lf)
        net = self.forward(token_id, encoder_output, token_mask,
                           self_attention_mask, dot_attention_mask)
        net = t.nn.functional.log_softmax(net, -1)
        # top-k log-probabilities and their token indices at every position
        probs, indices = t.topk(net, topk)
        if return_last:
            return probs[:, -1, :], indices[:, -1, :]
        else:
            return probs, indices
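
BeamSteper is likewise a project-local class that is not included here. For orientation only, the sketch below shows how a single beam-expansion step is commonly implemented: accumulate per-beam scores, keep the top beam_size continuations out of beam_size * vocab_size candidates, reorder the token histories to follow the surviving beams, and append the chosen tokens. The function and its shapes are hypothetical, and details such as <eos> handling and the lp_eps length penalty used in Example #1 are omitted.

import torch as t


def beam_expand_step(tokens, scores, last_log_probs):
    # tokens:         (B, beam, Lt)  token history of every beam
    # scores:         (B, beam)      accumulated log-probability of every beam
    # last_log_probs: (B, beam, V)   log-probabilities of the next token
    B, beam, V = last_log_probs.size()
    # score of every (beam, next-token) candidate
    candidate_scores = scores.unsqueeze(-1) + last_log_probs      # (B, beam, V)
    # keep the best `beam` continuations over all beam * V candidates
    top_scores, top_idx = candidate_scores.view(B, beam * V).topk(beam, dim=-1)
    beam_idx = t.div(top_idx, V, rounding_mode='floor')           # source beam
    token_idx = top_idx % V                                       # chosen token
    # reorder histories to follow the surviving beams, then append the new token
    gather_idx = beam_idx.unsqueeze(-1).repeat(1, 1, tokens.size(-1))
    tokens = t.gather(tokens, 1, gather_idx)
    tokens = t.cat([tokens, token_idx.unsqueeze(-1)], dim=-1)
    return tokens, top_scores


# usage with dummy tensors (bos_id=3 as in the decoders above)
B, beam, V = 2, 3, 10
tokens = t.full((B, beam, 1), 3, dtype=t.long)
scores = t.zeros(B, beam)
last_log_probs = t.log_softmax(t.randn(B, beam, V), dim=-1)
tokens, scores = beam_expand_step(tokens, scores, last_log_probs)
print(tokens.shape, scores.shape)   # torch.Size([2, 3, 2]) torch.Size([2, 3])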