def decode_greedy_pointergenerator_transformer(transducer,
                                               src_sentence,
                                               max_len=100,
                                               trg_bos=BOS_IDX,
                                               trg_eos=EOS_IDX):
    '''
    src_sentence: [seq_len, 1] (a batch dimension is required; the source
    indices are transposed and fed to decode for the copy distribution)
    '''
    assert isinstance(transducer, PointerGeneratorTransformer)
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    src_mask = (src_mask == 0).transpose(0, 1)
    enc_hs = transducer.encode(src_sentence, src_mask)

    output, attns = [trg_bos], []  # attns stays empty for this decoder

    for _ in range(max_len):
        output_tensor = torch.tensor(output,
                                     device=DEVICE).view(len(output), 1)
        trg_mask = dummy_mask(output_tensor)
        trg_mask = (trg_mask == 0).transpose(0, 1)

        word_logprob = transducer.decode(enc_hs, src_mask,
                                         output_tensor, trg_mask,
                                         src_sentence.transpose(0, 1))
        word_logprob = word_logprob[-1]

        word = torch.max(word_logprob, dim=1)[1]
        if word == trg_eos:
            break
        output.append(word.item())
    return output[1:], attns
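A minimal usage sketch (names are hypothetical: model is a trained
PointerGeneratorTransformer, src_ids a list of source indices; BOS_IDX,
EOS_IDX and DEVICE are the module-level constants used above):

src = torch.tensor(src_ids, device=DEVICE).view(-1, 1)  # [seq_len, 1]
with torch.no_grad():
    pred, _ = decode_greedy_pointergenerator_transformer(model, src)
# pred is a list of target indices with the leading BOS already stripped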
Example 2
    def decode(self, mode, write_fp, decode_fn):
        self.model.eval()
        cnt = 0
        sampler, nb_instance = self.iterate_instance(mode)
        decode_fn.reset()
        # str.format() used instead of f-strings (fix by alexander kahanek)
        with open('{0}.{1}.tsv'.format(write_fp, mode), 'w') as fp:
            fp.write('prediction\ttarget\tloss\tdist\n')
            for src, trg in tqdm(sampler(), total=nb_instance):
                pred, _ = decode_fn(self.model, src)
                dist = util.edit_distance(pred, trg.view(-1).tolist()[1:-1])

                src_mask = dummy_mask(src)
                trg_mask = dummy_mask(trg)
                data = (src, src_mask, trg, trg_mask)
                loss = self.model.get_loss(data).item()

                trg = self.data.decode_target(trg)[1:-1]
                pred = self.data.decode_target(pred)
                fp.write('{0}\t{1}\t{2}\t{3}\n'.format(
                    " ".join(pred), " ".join(trg), loss, dist))
                cnt += 1
        decode_fn.reset()
        self.logger.info('finished decoding {0} {1} instances'.format(
            cnt, mode))
def decode_sample(transducer,
                  src_sentence,
                  max_len=100,
                  trg_bos=BOS_IDX,
                  trg_eos=EOS_IDX):
    '''
    src_sentence: [seq_len]
    '''
    assert not isinstance(transducer, HardMonoTransducer)
    if isinstance(transducer, HMMTransducer):
        return decode_sample_hmm(transducer,
                                 src_sentence,
                                 max_len=max_len,
                                 trg_bos=trg_bos,
                                 trg_eos=trg_eos)
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)

    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for _ in range(max_len):
        word_logprob, hidden, attn = transducer.decode_step(
            enc_hs, src_mask, input_, hidden)
        # Categorical.sample_n is deprecated; sample() draws one token
        # (logits= avoids the explicit exp())
        word = Categorical(logits=word_logprob).sample()
        attns.append(attn)
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
    return output, attns
def main():
    opt = get_args()

    decode_fn = Decoder(opt.decode, max_len=opt.max_len, beam_size=opt.beam_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.load(opt.model, map_location=device)
    model = model.to(device)

    trg_i2c = {i: c for c, i in model.trg_c2i.items()}

    def decode_trg(seq):
        return [trg_i2c[i] for i in seq]

    maybe_mkdir(opt.out_file)
    with open(opt.in_file, "r", encoding="utf-8") as in_fp, open(
        opt.out_file, "w", encoding="utf-8"
    ) as out_fp:
        for line in in_fp:
            toks = line.strip().split("\t")
            if len(toks) < 2 or line.startswith("#"):  # pass through
                out_fp.write(line)
                continue
            # word, lemma, tags = toks[1], toks[2], toks[5]
            word, tags = toks[1], toks[5]
            word, tags = list(word), tags.split(";")
            src = encode(model, word, tags, device)
            src_mask = dummy_mask(src)
            pred, _ = decode_fn(model, src, src_mask)
            pred = unpack_batch(pred)[0]
            pred_out = "".join(decode_trg(pred))
            # write lemma
            toks[2] = pred_out
            out_fp.write("\t".join(toks) + "\n")
Example 5
    def decode(self, mode, write_fp, decode_fn):
        self.model.eval()
        cnt = 0
        sampler, nb_instance = self.iterate_instance(mode)
        decode_fn.reset()
        with open(f"{write_fp}.{mode}.tsv", "w") as fp:
            fp.write("prediction\ttarget\tloss\tdist\n")
            for src, trg in tqdm(sampler(), total=nb_instance):
                pred, _ = decode_fn(self.model, src)
                dist = util.edit_distance(pred, trg.view(-1).tolist()[1:-1])

                src_mask = dummy_mask(src)
                trg_mask = dummy_mask(trg)
                data = (src, src_mask, trg, trg_mask)
                loss = self.model.get_loss(data).item()

                trg = self.data.decode_target(trg)[1:-1]
                pred = self.data.decode_target(pred)
                fp.write(f'{" ".join(pred)}\t{" ".join(trg)}\t{loss}\t{dist}\n')
                cnt += 1
        decode_fn.reset()
        self.logger.info(f"finished decoding {cnt} {mode} instances")
def decode_greedy(transducer,
                  src_sentence,
                  max_len=100,
                  trg_bos=BOS_IDX,
                  trg_eos=EOS_IDX):
    '''
    src_sentence: [seq_len]
    '''
    if isinstance(transducer, HardMonoTransducer):
        return decode_greedy_mono(transducer,
                                  src_sentence,
                                  max_len=max_len,
                                  trg_bos=trg_bos,
                                  trg_eos=trg_eos)
    if isinstance(transducer, HMMTransducer):
        return decode_greedy_hmm(transducer,
                                 src_sentence,
                                 max_len=max_len,
                                 trg_bos=trg_bos,
                                 trg_eos=trg_eos)
    if isinstance(transducer, PointerGeneratorTransformer):
        return decode_greedy_pointergenerator_transformer(transducer,
                                                          src_sentence,
                                                          max_len=max_len,
                                                          trg_bos=trg_bos,
                                                          trg_eos=trg_eos)
    if isinstance(transducer, Transformer):
        return decode_greedy_transformer(transducer,
                                         src_sentence,
                                         max_len=max_len,
                                         trg_bos=trg_bos,
                                         trg_eos=trg_eos)
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)

    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for _ in range(max_len):
        word_logprob, hidden, attn = transducer.decode_step(
            enc_hs, src_mask, input_, hidden)
        word = torch.max(word_logprob, dim=1)[1]
        attns.append(attn)
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
    return output, attns
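The fall-through branch above is the per-step greedy decoder for the soft
attention RNN transducer; the earlier branches only dispatch on the model
class. A usage sketch (hypothetical model and src; src must match whatever
shape the chosen model's encode expects):

with torch.no_grad():
    pred, attns = decode_greedy(model, src)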
def decode_sample_hmm(transducer,
                      src_sentence,
                      max_len=100,
                      trg_bos=BOS_IDX,
                      trg_eos=EOS_IDX):
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    T = src_mask.shape[0]

    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for idx in range(max_len):
        trans, emiss, hidden = transducer.decode_step(enc_hs, src_mask, input_,
                                                      hidden)
        if idx == 0:
            initial = trans[:, 0].unsqueeze(1)
            attns.append(initial)
            forward = initial
        else:
            attns.append(trans)
            # forward = torch.bmm(forward, trans)
            forward = forward + trans.transpose(1, 2)
            forward = forward.logsumexp(dim=-1, keepdim=True).transpose(1, 2)

        # wordprob = torch.bmm(forward, emiss)
        log_wordprob = forward + emiss.transpose(1, 2)
        log_wordprob = log_wordprob.logsumexp(dim=-1)
        # word = torch.max(log_wordprob, dim=-1)[1]
        # Categorical.sample_n is deprecated; sample() draws one token
        word = Categorical(logits=log_wordprob).sample()
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
        word_idx = word.view(-1, 1).expand(1, T).unsqueeze(-1)
        word_emiss = torch.gather(emiss, -1, word_idx).view(1, 1, T)
        forward = forward + word_emiss
    return output, attns
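The commented-out torch.bmm lines mark the probability-space HMM forward
update; the live code performs the same update in log space for numerical
stability. A standalone sanity check of that equivalence (a sketch with
made-up shapes, not repo code):

import torch

T = 4
forward = torch.randn(1, 1, T).log_softmax(dim=-1)  # log alpha: [bs, 1, T]
trans = torch.randn(1, T, T).log_softmax(dim=-1)    # log transitions: [bs, T, T]

prob_space = torch.bmm(forward.exp(), trans.exp()).log()
log_space = (forward + trans.transpose(1, 2)) \
    .logsumexp(dim=-1, keepdim=True).transpose(1, 2)
assert torch.allclose(prob_space, log_space, atol=1e-6)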
def main():
    opt = get_args()

    decode_fn = Decoder(opt.decode, max_len=opt.max_len, beam_size=opt.beam_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.load(opt.model, map_location=device)
    model = model.to(device)

    trg_i2c = {i: c for c, i in model.trg_c2i.items()}

    def decode_trg(seq):
        return [trg_i2c[i] for i in seq]

    maybe_mkdir(opt.out_file)
    with open(opt.out_file, "w", encoding="utf-8") as fp:
        for lemma, tags in read_file(opt.in_file, opt.lang):
            src = encode(model, lemma, tags, device)
            src_mask = dummy_mask(src)
            pred, _ = decode_fn(model, src, src_mask)
            pred = unpack_batch(pred)[0]
            pred_out = "".join(decode_trg(pred))
            fp.write(f'{"".join(lemma)}\t{pred_out}\t{";".join(tags[1:])}\n')
Example 9
def decode_greedy_transformer(transducer,
                              src_sentence,
                              src_mask,
                              max_len=100,
                              trg_bos=BOS_IDX,
                              trg_eos=EOS_IDX):
    """
    src_sentence: [seq_len]
    """
    assert isinstance(transducer, Transformer)
    transducer.eval()
    src_mask = (src_mask == 0).transpose(0, 1)
    enc_hs = transducer.encode(src_sentence, src_mask)

    _, bs = src_sentence.shape
    output = torch.tensor([trg_bos] * bs, device=DEVICE)
    output = output.view(1, bs)

    finished = None
    for _ in range(max_len):
        trg_mask = dummy_mask(output)
        trg_mask = (trg_mask == 0).transpose(0, 1)

        word_logprob = transducer.decode(enc_hs, src_mask, output, trg_mask)
        word_logprob = word_logprob[-1]

        word = torch.max(word_logprob, dim=1)[1]
        output = torch.cat((output, word.view(1, bs)))

        if finished is None:
            finished = word == trg_eos
        else:
            finished = finished | (word == trg_eos)

        if finished.all().item():
            break
    return output, None
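Unlike the single-sequence decoders above, this batched version returns the
raw [len, bs] index tensor, including BOS and anything emitted after each
column's EOS. A trimming sketch (an assumption about the caller, not code
from this listing):

def trim_column(column, trg_eos=EOS_IDX):
    out = []
    for idx in column[1:]:  # drop the leading BOS
        if idx == trg_eos:
            break
        out.append(idx)
    return out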
def decode_greedy_mono(transducer,
                       src_sentence,
                       max_len=100,
                       trg_bos=BOS_IDX,
                       trg_eos=EOS_IDX):
    '''
    src_sentence: [seq_len]
    '''
    assert isinstance(transducer, HardMonoTransducer)
    attn_pos = 0
    transducer.eval()
    if isinstance(src_sentence, tuple):
        seq_len = src_sentence[0].shape[0]
    else:
        seq_len = src_sentence.shape[0]
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)

    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for _ in range(max_len):
        word_logprob, hidden, attn = transducer.decode_step(
            enc_hs, src_mask, input_, hidden, attn_pos)
        word = torch.max(word_logprob, dim=1)[1]
        attns.append(attn)
        if word == STEP_IDX:
            attn_pos += 1
            if attn_pos == seq_len:
                attn_pos = seq_len - 1
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
    return output, attns
def decode_beam_hmm(transducer,
                    src_sentence,
                    max_len=50,
                    nb_beam=5,
                    norm=True,
                    trg_bos=BOS_IDX,
                    trg_eos=EOS_IDX,
                    return_top_beams=False):
    def score(beam):
        '''
        compute score based on logprob
        '''
        assert isinstance(beam, BeamHMM)
        if norm:
            return -beam.log_prob / beam.seq_len
        return -beam.log_prob

    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    T = src_mask.shape[0]

    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))

    seq_len, log_prob, partial_sent, attn, forward = 1, 0, '', [], None
    beam = BeamHMM(seq_len, log_prob, hidden, input_, partial_sent, attn,
                   forward)
    beams = [beam]
    finish_beams = []

    for _ in range(max_len):
        next_beams = []
        for beam in sorted(beams, key=score)[:nb_beam]:
            trans, emiss, hidden = transducer.decode_step(
                enc_hs, src_mask, beam.input, beam.hidden)

            if beam.seq_len == 1:
                assert beam.forward is None
                initial = trans[:, 0].unsqueeze(1)
                attn = initial
                forward = initial
            else:
                assert beam.forward is not None
                attn = trans
                # forward = torch.bmm(forward, trans)
                forward = beam.forward + trans.transpose(1, 2)
                forward = forward.logsumexp(dim=-1,
                                            keepdim=True).transpose(1, 2)

            seq_len = beam.seq_len + 1
            next_attn = beam.attn + [attn]

            # wordprob = torch.bmm(forward, emiss)
            log_wordprob = forward + emiss.transpose(1, 2)
            log_wordprob = log_wordprob.logsumexp(dim=-1)
            topk_word = torch.topk(log_wordprob, nb_beam, dim=-1)[1]
            for word in topk_word.view(nb_beam, 1):
                next_input = transducer.dropout(transducer.trg_embed(word))
                next_output = str(word.item())
                word_idx = word.view(-1, 1).expand(1, T).unsqueeze(-1)
                word_emiss = torch.gather(emiss, -1, word_idx).view(1, 1, T)
                next_forward = forward + word_emiss

                log_prob = torch.logsumexp(next_forward, dim=-1).item()

                if word == trg_eos:
                    sent = beam.partial_sent
                    # bind to a fresh name so the parent beam is not clobbered
                    new_beam = BeamHMM(seq_len, log_prob, None, None, sent,
                                       next_attn, next_forward)
                    finish_beams.append(new_beam)
                else:
                    sent = f'{beam.partial_sent} {next_output}'
                    new_beam = BeamHMM(seq_len, log_prob, hidden, next_input,
                                       sent, next_attn, next_forward)
                    next_beams.append(new_beam)
        beams = next_beams
    finish_beams = finish_beams if finish_beams else next_beams
    sorted_beams = sorted(finish_beams, key=score)
    if return_top_beams:
        return sorted_beams[:nb_beam]
    else:
        max_output = sorted_beams[0]
        return list(map(int, max_output.partial_sent.split())), max_output.attn
def decode_beam_mono(transducer,
                     src_sentence,
                     max_len=50,
                     nb_beam=5,
                     norm=True,
                     trg_bos=BOS_IDX,
                     trg_eos=EOS_IDX):
    assert isinstance(transducer, HardMonoTransducer)

    def score(beam):
        '''
        compute score based on logprob
        '''
        assert isinstance(beam, BeamHard)
        if norm:
            return -beam.log_prob / beam.seq_len
        return -beam.log_prob

    transducer.eval()
    if isinstance(src_sentence, tuple):
        seq_len = src_sentence[0].shape[0]
    else:
        seq_len = src_sentence.shape[0]
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)

    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    start = BeamHard(1, 0, hidden, input_, '', [], 0)
    beams = [start]
    finish_beams = []
    for _ in range(max_len):
        next_beams = []
        for beam in sorted(beams, key=score)[:nb_beam]:
            word_logprob, hidden, attn = transducer.decode_step(
                enc_hs, src_mask, beam.input, beam.hidden, beam.attn_pos)
            topk_log_prob, topk_word = word_logprob.topk(nb_beam)
            topk_log_prob = topk_log_prob.view(nb_beam, 1)
            topk_word = topk_word.view(nb_beam, 1)
            for log_prob, word in zip(topk_log_prob, topk_word):
                if word == trg_eos:
                    # bind to a fresh name so the parent beam is not clobbered
                    new_beam = BeamHard(beam.seq_len + 1,
                                        beam.log_prob + log_prob.item(), None,
                                        None, beam.partial_sent,
                                        beam.attn + [attn], beam.attn_pos)
                    finish_beams.append(new_beam)
                    # early-exit variant (disabled):
                    # if len(finish_beams) == 10 * K:
                    #     max_output = sorted(finish_beams, key=score)[0]
                    #     return (list(map(int, max_output.partial_sent.split())),
                    #             max_output.attn)
                else:
                    shift = (1 if word == STEP_IDX
                             and beam.attn_pos + 1 < seq_len else 0)
                    new_beam = BeamHard(
                        beam.seq_len + 1, beam.log_prob + log_prob.item(),
                        hidden, transducer.dropout(transducer.trg_embed(word)),
                        ' '.join([beam.partial_sent,
                                  str(word.item())]), beam.attn + [attn],
                        beam.attn_pos + shift)
                    next_beams.append(new_beam)
        beams = next_beams
    finish_beams = finish_beams if finish_beams else next_beams
    max_output = sorted(finish_beams, key=score)[0]
    return list(map(int, max_output.partial_sent.split())), max_output.attn
Example 13
def decode_beam_transformer(
    transducer,
    src_sentence,
    max_len=50,
    nb_beam=5,
    norm=True,
    trg_bos=BOS_IDX,
    trg_eos=EOS_IDX,
):
    """
    src_sentence: [seq_len]
    """
    assert isinstance(transducer, Transformer)

    def score(beam):
        """
        compute score based on logprob
        """
        assert isinstance(beam, Beam)
        if norm:
            return -beam.log_prob / beam.seq_len
        return -beam.log_prob

    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    src_mask = (src_mask == 0).transpose(0, 1)
    enc_hs = transducer.encode(src_sentence, src_mask)

    input_ = torch.tensor([trg_bos], device=DEVICE).view(1, 1)
    start = Beam(1, 0, None, input_, "", None)
    beams = [start]
    finish_beams = []
    for _ in range(max_len):
        next_beams = []
        for beam in sorted(beams, key=score)[:nb_beam]:
            trg_mask = dummy_mask(beam.input)
            trg_mask = (trg_mask == 0).transpose(0, 1)

            word_logprob = transducer.decode(enc_hs, src_mask, beam.input,
                                             trg_mask)
            word_logprob = word_logprob[-1]

            topk_log_prob, topk_word = word_logprob.topk(nb_beam)
            topk_log_prob = topk_log_prob.view(nb_beam, 1)
            topk_word = topk_word.view(nb_beam, 1)
            for log_prob, word in zip(topk_log_prob, topk_word):
                if word == trg_eos:
                    new_beam = Beam(
                        beam.seq_len + 1,
                        beam.log_prob + log_prob.item(),
                        None,
                        None,
                        beam.partial_sent,
                        None,
                    )
                    finish_beams.append(new_beam)
                else:
                    new_beam = Beam(
                        beam.seq_len + 1,
                        beam.log_prob + log_prob.item(),
                        None,
                        torch.cat((beam.input, word.view(1, 1))),
                        " ".join([beam.partial_sent,
                                  str(word.item())]),
                        None,
                    )
                    next_beams.append(new_beam)
        beams = next_beams
    finish_beams = finish_beams if finish_beams else next_beams
    max_output = sorted(finish_beams, key=score)[0]
    return list(map(int, max_output.partial_sent.split())), []
Example 14
def decode_beam_search(
    transducer,
    src_sentence,
    max_len=50,
    nb_beam=5,
    norm=True,
    trg_bos=BOS_IDX,
    trg_eos=EOS_IDX,
):
    """
    src_sentence: [seq_len]
    """

    if isinstance(transducer, HardMonoTransducer):
        return decode_beam_mono(
            transducer,
            src_sentence,
            max_len=max_len,
            nb_beam=nb_beam,
            norm=norm,
            trg_bos=trg_bos,
            trg_eos=trg_eos,
        )

    if isinstance(transducer, HMMTransducer):
        return decode_beam_hmm(
            transducer,
            src_sentence,
            max_len=max_len,
            nb_beam=nb_beam,
            norm=norm,
            trg_bos=trg_bos,
            trg_eos=trg_eos,
        )

    if isinstance(transducer, Transformer):
        return decode_beam_transformer(
            transducer,
            src_sentence,
            max_len=max_len,
            nb_beam=nb_beam,
            norm=norm,
            trg_bos=trg_bos,
            trg_eos=trg_eos,
        )

    def score(beam):
        """
        compute score based on logprob
        """
        assert isinstance(beam, Beam)
        if norm:
            return -beam.log_prob / beam.seq_len
        return -beam.log_prob

    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)

    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    start = Beam(1, 0, hidden, input_, "", [])
    beams = [start]
    finish_beams = []
    for _ in range(max_len):
        next_beams = []
        for beam in sorted(beams, key=score)[:nb_beam]:
            word_logprob, hidden, attn = transducer.decode_step(
                enc_hs, src_mask, beam.input, beam.hidden)
            topk_log_prob, topk_word = word_logprob.topk(nb_beam)
            topk_log_prob = topk_log_prob.view(nb_beam, 1)
            topk_word = topk_word.view(nb_beam, 1)
            for log_prob, word in zip(topk_log_prob, topk_word):
                if word == trg_eos:
                    new_beam = Beam(
                        beam.seq_len + 1,
                        beam.log_prob + log_prob.item(),
                        None,
                        None,
                        beam.partial_sent,
                        beam.attn + [attn],
                    )
                    finish_beams.append(new_beam)
                    # early-exit variant (disabled):
                    # if len(finish_beams) == 10 * K:
                    #     max_output = sorted(finish_beams, key=score)[0]
                    #     return (list(map(int, max_output.partial_sent.split())),
                    #             max_output.attn)
                else:
                    new_beam = Beam(
                        beam.seq_len + 1,
                        beam.log_prob + log_prob.item(),
                        hidden,
                        transducer.dropout(transducer.trg_embed(word)),
                        " ".join([beam.partial_sent,
                                  str(word.item())]),
                        beam.attn + [attn],
                    )
                    next_beams.append(new_beam)
        beams = next_beams
    finish_beams = finish_beams if finish_beams else next_beams
    max_output = sorted(finish_beams, key=score)[0]
    return list(map(int, max_output.partial_sent.split())), max_output.attn
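Beam-search usage mirrors the greedy dispatcher (hypothetical names again);
with norm=True, finished beams are ranked by length-normalized log
probability:

with torch.no_grad():
    pred, attns = decode_beam_search(model, src, max_len=50, nb_beam=5)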
Example 15
def decode_beam_transformer(
    transducer,
    src_sentence,
    src_mask,
    max_len=50,
    nb_beam=5,
    trg_bos=BOS_IDX,
    trg_eos=EOS_IDX,
):
    """
    src_sentence: [seq_len]
    """
    assert isinstance(transducer, Transformer)

    transducer.eval()
    src_mask = (src_mask == 0).transpose(0, 1)
    enc_hs = transducer.encode(src_sentence, src_mask)

    _, bs = src_sentence.shape
    input_ = torch.tensor([trg_bos] * bs, device=DEVICE)
    input_ = input_.view(1, bs)
    output = input_
    start = Beam(0, None, input_, output)
    beams = [start]
    finish_beams = [list() for _ in range(bs)]
    for i in range(max_len):
        cur_len = i + 2  # bos & the current prediction
        next_beams = []
        for beam in beams:
            trg_mask = dummy_mask(beam.input)
            trg_mask = (trg_mask == 0).transpose(0, 1)

            word_logprob = transducer.decode(enc_hs, src_mask, beam.input,
                                             trg_mask)
            word_logprob = word_logprob[-1]

            topk_log_prob, topk_word = word_logprob.topk(nb_beam)
            topk_log_prob = topk_log_prob.split(1, dim=1)
            topk_word = topk_word.split(1, dim=1)
            for log_prob, word in zip(topk_log_prob, topk_word):
                log_prob = log_prob.squeeze(1)
                word = word.squeeze(1)

                log_prob = beam.log_prob + log_prob
                input_ = torch.cat((beam.input, word.view(1, bs)))
                output = torch.cat((beam.partial_sent, word.view(1, bs)))

                if (word == trg_eos).any():
                    batch_idx = (word == trg_eos).nonzero().view(-1).tolist()
                    for j in batch_idx:
                        score = log_prob[j] / cur_len
                        seq = output[:, j].tolist()
                        log_prob[j] = -1e6
                        finish_beams[j].append((score, seq))

                new_beam = Beam(
                    log_prob,
                    None,
                    input_,
                    output,
                )
                next_beams.append(new_beam)

        beam_idx = get_topk_beam_idx(next_beams, nb_beam)
        log_prob = gather_logprob(next_beams, bs, beam_idx)
        input_ = torch.stack([b.input for b in next_beams], dim=-1)
        # keep only the surviving beams: [cur_len, bs, nb_beam]
        input_ = input_[torch.arange(cur_len).view(-1, 1, 1),
                        torch.arange(bs).view(1, -1, 1),
                        beam_idx]
        output = gather_output(next_beams, cur_len, bs, beam_idx)
        beams = [
            Beam(lp.squeeze(-1), None, ip.squeeze(-1), op.squeeze(-1))
            for lp, ip, op in zip(
                log_prob.split(1, dim=-1),
                input_.split(1, dim=-1),
                output.split(1, dim=-1),
            )
        ]
    return [max(b)[1] if b else [] for b in finish_beams], None
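The batched beam search relies on three helpers that are not shown in this
listing: get_topk_beam_idx, gather_logprob and gather_output. The versions
below are reconstructions inferred purely from the call sites and shapes
above; treat them as assumptions, not the original implementations:

import torch

def get_topk_beam_idx(next_beams, nb_beam):
    # candidate scores: list of [bs] tensors -> [bs, n_candidates]
    log_prob = torch.stack([b.log_prob for b in next_beams], dim=-1)
    # indices of the nb_beam best candidates per batch element: [bs, nb_beam]
    return log_prob.topk(nb_beam, dim=-1)[1]

def gather_logprob(next_beams, bs, beam_idx):
    log_prob = torch.stack([b.log_prob for b in next_beams], dim=-1)
    return torch.gather(log_prob, -1, beam_idx)  # [bs, nb_beam]

def gather_output(next_beams, cur_len, bs, beam_idx):
    # partial outputs: list of [cur_len, bs] tensors -> [cur_len, bs, n_cand]
    output = torch.stack([b.partial_sent for b in next_beams], dim=-1)
    return output[torch.arange(cur_len).view(-1, 1, 1),
                  torch.arange(bs).view(1, -1, 1),
                  beam_idx]  # [cur_len, bs, nb_beam]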