Example #1
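Beam-search decoding for a Transformer encoder-decoder with an extended (copy)
vocabulary. The first step seeds beam_size hypotheses per batch element from
the BOS distribution; the encoder outputs are then tiled once per beam, and
each later step re-decodes the growing prefix and lets a project-specific
Beam helper choose the next tokens.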
    def beam_decode(self, batch, max_len, oov_nums):
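        # Assumes `torch` is imported at module level, and that `Beam` and
        # `subsequent_mask` come from the surrounding project
        # (`subsequent_mask` as in "The Annotated Transformer").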

        bos_token = self.data_utils.bos
        beam_size = self.args.beam_size
        vocab_size = self.data_utils.vocab_size

        src = batch['src'].long()
        src_mask = batch['src_mask']
        # Ids over the extended vocabulary: in-vocab ids plus temporary
        # per-batch ids for source OOV words, as in copy/pointer models.
        src_extended = batch['src_extended'].long()
        memory = self.model.encode(src, src_mask)
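        # memory: encoder output, [batch_size, src_len, d_model]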
        batch_size = src.size(0)

        # Project-specific Beam helper: tracks beam_size hypotheses per batch
        # element and (presumably) keeps up to n_best finished sequences.
        beam = Beam(self.data_utils.pad, bos_token, self.data_utils.eos,
                    beam_size, batch_size, self.args.n_best, True, max_len)

        # First decoding step: condition on the BOS token alone.
        ys = torch.full((batch_size, 1), bos_token,
                        dtype=torch.long, device=src.device)
        log_prob = self.model.decode(
            memory, src_mask, ys,
            subsequent_mask(ys.size(1)).type_as(src).expand(
                (ys.size(0), ys.size(1), ys.size(1))),
            src_extended, oov_nums)

        # log_prob: [batch_size, 1, vocab_size]; seed each beam with the
        # top beam_size tokens of the first step.
        top_prob, top_indices = torch.topk(log_prob, k=beam_size, dim=-1)
        top_prob = top_prob.view(-1, 1)
        top_indices = top_indices.view(-1, 1)
        beam.update_prob(top_prob.detach().cpu(), top_indices.detach().cpu())
        ys = top_indices  # [batch_size * beam_size, 1]
        # Repeat the encoder-side tensors beam_size times so that every
        # hypothesis attends to its own copy of the source:
        # [batch_size, src_len, d_model] -> [batch_size * beam_size, src_len, d_model]
        src = torch.repeat_interleave(src, beam_size, dim=0)
        src_mask = torch.repeat_interleave(src_mask, beam_size, dim=0)
        src_extended = torch.repeat_interleave(src_extended, beam_size, dim=0)
        memory = torch.repeat_interleave(memory, beam_size, dim=0)
        for t in range(1, max_len):
            # Re-decode the full prefix, passing the (now repeated)
            # extended-vocabulary ids and OOV counts as in the first step.
            log_prob = self.model.decode(
                memory, src_mask, ys,
                subsequent_mask(ys.size(1)).type_as(src).expand(
                    (ys.size(0), ys.size(1), ys.size(1))),
                src_extended, oov_nums)
            # Keep only the distribution for the last generated position.
            log_prob = log_prob[:, -1].unsqueeze(1)
            # Beam.advance scores the candidates and returns the token chosen
            # for each hypothesis; append it to the running prefixes.
            real_top = beam.advance(log_prob.detach().cpu())
            ys = torch.cat((ys, real_top.view(-1, 1).to(src.device)), dim=-1)

        # Return the decoded sequence(s) kept by the Beam helper.
        return [beam.seq[0]]
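
The method relies on a subsequent_mask helper that is not shown above. A
minimal sketch, assuming it matches the causal-mask helper popularized by
"The Annotated Transformer" (the .type_as(src) / .expand(...) calls above are
consistent with a [1, size, size] mask):

import torch

def subsequent_mask(size):
    # Causal mask of shape [1, size, size]: position i may attend only to
    # positions j <= i. Returned as a boolean tensor; the caller casts it
    # with .type_as(src) and expands it over the batch dimension.
    attn_shape = (1, size, size)
    upper = torch.triu(torch.ones(attn_shape, dtype=torch.uint8), diagonal=1)
    return upper == 0

And a hypothetical call site, only to show the expected inputs. The solver
object and batch fields mirror the attributes beam_decode reads (self.model,
batch['src'], batch['src_extended'], ...), but the 'oov_nums' field and the
loader are assumptions, not a confirmed API:

@torch.no_grad()
def decode_corpus(solver, loader, max_len=100):
    # solver: the object that owns beam_decode (model, args, data_utils).
    solver.model.eval()
    hypotheses = []
    for batch in loader:
        # 'oov_nums': per-batch count of source OOV words (hypothetical field).
        hypotheses.extend(solver.beam_decode(batch, max_len, batch['oov_nums']))
    return hypotheses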