Example #1
    def beam_search_decode(self, encoder_output, dot_attention_mask,
                           beam_size):
        batch_size = encoder_output.size(0)
        device = encoder_output.device
        feature_length = encoder_output.size(1)

        self.beam_steper = BeamSteper(batch_size, beam_size, self.bos_id,
                                      self.eos_id, self.vocab_size, device)
        encoder_output = encoder_output.unsqueeze(1).repeat(
            1, beam_size, 1, 1).view(batch_size * beam_size, feature_length,
                                     -1)
        dot_attention_mask = dot_attention_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1).view(batch_size * beam_size, 1, feature_length)
        with t.no_grad():
            for i in range(self.max_length):
                try:
                    # Build the self-attention mask from the current hypothesis
                    # lengths held by the beam stepper.
                    length = self.beam_steper.length_container
                    token_mask = Masker.get_mask(length)
                    self_attention_mask = Masker.get_dot_mask(
                        token_mask, token_mask)

                    # First step starts from the BOS tokens; later steps feed the
                    # hypotheses accumulated so far, flattened to (batch * beam, step).
                    if i == 0:
                        token_id = self.beam_steper.get_first_step_token()
                    else:
                        token_id = self.beam_steper.token_container.view(
                            batch_size * beam_size, -1)
                    last_prob = self.beam_decode_step(token_id,
                                                      encoder_output,
                                                      token_mask,
                                                      self_attention_mask,
                                                      dot_attention_mask,
                                                      topk=beam_size,
                                                      return_last=True)
                    self.beam_steper.step(last_prob)
                except Exception:
                    # Any failure inside a step ends the search early; the best
                    # hypotheses collected so far are still returned.
                    break
        return self.beam_steper.batch_best_saver
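
The decode loop above (and the variant in the next example) relies on the same batched-beam trick: the encoder output and its mask are repeated beam_size times along a new axis and then flattened, so every hypothesis gets its own copy of the acoustic features under a single (batch * beam) leading dimension. A minimal, self-contained PyTorch sketch of that repeat-and-view pattern (illustrative only, not code from the repo; the shapes are made up):

import torch

# Illustrative sketch: tiling a (batch, time, dim) encoder output so that each
# beam hypothesis gets its own copy of the features.
batch_size, beam_size, feature_length, dim = 2, 3, 5, 8
encoder_output = torch.randn(batch_size, feature_length, dim)

tiled = encoder_output.unsqueeze(1)              # (batch, 1, time, dim)
tiled = tiled.repeat(1, beam_size, 1, 1)         # (batch, beam, time, dim)
tiled = tiled.view(batch_size * beam_size, feature_length, dim)

print(tiled.shape)  # torch.Size([6, 5, 8])
# Rows 0..2 are copies of sample 0 and rows 3..5 copies of sample 1, which lines
# up with how the token_container is flattened to (batch * beam, step) in the loop.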

Example #2

    def beam_search_decode(self, encoder_output, feature_mask, beam_size, best_k=5, lp_eps=0.0):
        batch_size, feature_length, _ = encoder_output.size()
        device = encoder_output.device
        self.beam_steper = BeamSteper(
            batch_size=batch_size, beam_size=beam_size, bos_id=self.bos_id, eos_id=self.eos_id,
            vocab_size=self.vocab_size, device=device, k_best=best_k, lp_eps=lp_eps
        )

        beam_feature_mask = feature_mask.unsqueeze(1).repeat(
            1, beam_size, 1).view(batch_size * beam_size, -1)
        beam_encoder_output = encoder_output.unsqueeze(1).repeat(
            1, beam_size, 1, 1).view(batch_size * beam_size, feature_length, -1)
        with t.no_grad():
            for i in range(self.max_length):
                if i == 0:
                    # First step: every beam starts from the BOS token, so the
                    # masks come from the initial lengths and the decoder runs
                    # on the un-tiled encoder output.
                    token_id = self.beam_steper.get_first_step_token()
                    length = self.beam_steper.get_first_step_length()
                    token_mask = Masker.get_mask(length)
                    self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
                    self_attention_mask = Masker.get_forward_mask(self_attention_mask)
                    dot_attention_mask = Masker.get_dot_mask(token_mask, feature_mask)
                    last_prob = self.beam_decode_step(
                        token_id, encoder_output, token_mask, self_attention_mask, dot_attention_mask)
                    if_continue = self.beam_steper.first_step(last_prob)
                    if not if_continue:
                        break
                else:
                    # Later steps: feed the hypotheses accumulated so far,
                    # flattened to (batch * beam, current_length), against the
                    # beam-tiled encoder output and feature mask.
                    token_id = self.beam_steper.token_container
                    token_id = token_id.view(batch_size * beam_size, -1)
                    length = self.beam_steper.length_container
                    length = length.view(batch_size * beam_size)
                    token_mask = Masker.get_mask(length)
                    self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
                    self_attention_mask = Masker.get_forward_mask(self_attention_mask)
                    dot_attention_mask = Masker.get_dot_mask(token_mask, beam_feature_mask)
                    last_prob = self.beam_decode_step(
                        token_id, beam_encoder_output, token_mask, self_attention_mask, dot_attention_mask
                    )
                    if_continue = self.beam_steper.step(last_prob.view(batch_size, beam_size, -1))
                    if not if_continue:
                        break
        output_token = self.beam_steper.batch_best_saver.batch
        return output_token
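
The Masker helpers are defined elsewhere in the repository, so their exact behaviour is not visible here. Reading the calls above, they appear to build a padding mask from the hypothesis lengths (get_mask), a pairwise query/key attention mask (get_dot_mask), and a causal restriction for decoder self-attention (get_forward_mask); the dot mask between the token mask and the (tiled) feature mask then serves as the encoder-decoder attention mask. The sketch below is an assumed plain-PyTorch equivalent for illustration only; names, shapes, and dtypes are guesses, not the repo's actual implementation.

import torch

# Assumed, illustrative equivalents of the masks used above; the real Masker
# class may differ in names, shapes, and dtype.

def padding_mask(lengths, max_len=None):
    # 1 where a position holds a real token, 0 where it is padding.
    max_len = int(max_len or lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return (steps.unsqueeze(0) < lengths.unsqueeze(1)).long()   # (batch, max_len)

def dot_mask(query_mask, key_mask):
    # Query position i may attend to key position j only if both are real tokens.
    return query_mask.unsqueeze(-1) * key_mask.unsqueeze(1)     # (batch, q_len, k_len)

def forward_mask(self_attention_mask):
    # Additionally forbid attending to future positions (causal mask).
    q_len = self_attention_mask.size(1)
    causal = torch.tril(torch.ones(q_len, q_len, dtype=torch.long,
                                   device=self_attention_mask.device))
    return self_attention_mask * causal.unsqueeze(0)

lengths = torch.tensor([3, 2])
token_mask = padding_mask(lengths)                               # (2, 3)
self_attn_mask = forward_mask(dot_mask(token_mask, token_mask))  # (2, 3, 3)
print(self_attn_mask[0])  # lower-triangular mask for the first sequence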