Example #1
def decode_greedy_transformer(transducer,
                              src_sentence,
                              max_len=100,
                              trg_bos=BOS_IDX,
                              trg_eos=EOS_IDX):
    '''
    src_sentence: [seq_len]
    '''
    assert isinstance(transducer, Transformer)
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    src_mask = (src_mask == 0).transpose(0, 1)
    enc_hs = transducer.encode(src_sentence, src_mask)
    output, attns = [trg_bos], []

    gen_prob = None
    for _ in range(max_len):
        output_tensor = torch.tensor(output,
                                     device=DEVICE).view(len(output), 1)
        trg_mask = dummy_mask(output_tensor)
        trg_mask = (trg_mask == 0).transpose(0, 1)

        src_mask_i = src_mask[:len(output)]
        enc_hs_i = enc_hs[:, :len(output)]
        src_sentence_i = src_sentence[:, :len(output)]
        dec_hs, attn_weights, embed_tgt = transducer.decode(
            enc_hs_i, src_mask_i, output_tensor, trg_mask)

        t_output = transducer.final_out(dec_hs)

        if not transducer.use_copy:
            word_logprob = F.log_softmax(t_output, dim=-1)
    else:
        # pointer-generator: mix the generation distribution with a copy
        # distribution over the source tokens
        word_logprob, gen_prob = transducer.source_weighted_output(
            src_sentence_i, t_output, attn_weights, enc_hs_i, dec_hs,
            embed_tgt)

        # keep only the distribution for the last decoding step
        word_logprob = word_logprob[-1]

        word = torch.max(word_logprob, dim=1)[1]
        if word == trg_eos:
            break
        output.append(word.item())

    return output[1:], gen_prob, attns
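The copy branch above delegates to `transducer.source_weighted_output`, which is defined elsewhere in this codebase. For orientation only, here is a minimal self-contained sketch of the pointer-generator idea it implements; the function name `mix_copy_probs` and all shapes are assumptions, not the repository's actual API:

import torch
import torch.nn.functional as F

def mix_copy_probs(gen_logits, attn, src_ids, p_gen):
    """Hypothetical pointer-generator mix.

    gen_logits: [trg_len, batch, vocab] generation scores
    attn:       [trg_len, batch, src_len] attention weights (rows sum to 1)
    src_ids:    [src_len, batch] source token ids (int64)
    p_gen:      [trg_len, batch, 1] probability of generating vs. copying
    """
    gen_prob = F.softmax(gen_logits, dim=-1)
    copy_prob = torch.zeros_like(gen_prob)
    # scatter attention mass onto the vocabulary ids of the source tokens
    index = src_ids.transpose(0, 1).unsqueeze(0).expand_as(attn)
    copy_prob.scatter_add_(-1, index, attn)
    mixed = p_gen * gen_prob + (1 - p_gen) * copy_prob
    return torch.log(mixed + 1e-12)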
Example #2
def decode_sample(transducer,
                  src_sentence,
                  max_len=100,
                  trg_bos=BOS_IDX,
                  trg_eos=EOS_IDX):
    '''
    src_sentence: [seq_len]
    '''
    assert not isinstance(transducer, HardMonoTransducer)
    if isinstance(transducer, HMMTransducer):
        return decode_sample_hmm(transducer,
                                 src_sentence,
                                 max_len=max_len,
                                 trg_bos=trg_bos,
                                 trg_eos=trg_eos)
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)

    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for _ in range(max_len):
        word_logprob, hidden, attn = transducer.decode_step(
            enc_hs, src_mask, input_, hidden)
        word = Categorical(word_logprob.exp()).sample()
        attns.append(attn)
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
    return output, attns
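The sampling step can be tried in isolation; this standalone snippet draws a token id from a fake decoder distribution the same way the loop above does:

import torch
from torch.distributions import Categorical

log_probs = torch.log_softmax(torch.randn(1, 10), dim=-1)  # fake decoder output
word = Categorical(log_probs.exp()).sample()  # tensor of shape [1]
print(word.item())  # sampled token id in [0, 10)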
Example #3
def decode_sample_hmm(transducer,
                      src_sentence,
                      max_len=100,
                      trg_bos=BOS_IDX,
                      trg_eos=EOS_IDX):
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    T = src_mask.shape[0]

    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for idx in range(max_len):
        trans, emiss, hidden = transducer.decode_step(enc_hs, src_mask, input_,
                                                      hidden)
        if idx == 0:
            initial = trans[:, 0].unsqueeze(1)
            attns.append(initial)
            forward = initial
        else:
            attns.append(trans)
            # log-space equivalent of forward = torch.bmm(forward, trans)
            forward = forward + trans.transpose(1, 2)
            forward = forward.logsumexp(dim=-1, keepdim=True).transpose(1, 2)

        # log-space equivalent of wordprob = torch.bmm(forward, emiss)
        log_wordprob = forward + emiss.transpose(1, 2)
        log_wordprob = log_wordprob.logsumexp(dim=-1)
        # sampling variant; the greedy choice would be log_wordprob.argmax(-1)
        word = Categorical(log_wordprob.exp()).sample()
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
        word_idx = word.view(-1, 1).expand(1, T).unsqueeze(-1)
        word_emiss = torch.gather(emiss, -1, word_idx).view(1, 1, T)
        forward = forward + word_emiss
    return output, attns
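The `torch.bmm` comments mark the probability-space form of the HMM forward recursion; the code works in log space for numerical stability. A standalone sanity check of that equivalence, on random inputs:

import torch

B, T = 1, 4
forward = torch.randn(B, 1, T).log_softmax(dim=-1)  # log alpha
trans = torch.randn(B, T, T).log_softmax(dim=-1)    # log transition matrix

# probability space: alpha' = alpha @ P
prob_space = torch.bmm(forward.exp(), trans.exp())

# log space: broadcast-add, then logsumexp over the previous state
log_space = (forward + trans.transpose(1, 2)).logsumexp(dim=-1, keepdim=True)
log_space = log_space.transpose(1, 2)

print(torch.allclose(prob_space.log(), log_space, atol=1e-5))  # True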
Example #4
def decode_greedy_mono(transducer,
                       src_sentence,
                       max_len=100,
                       trg_bos=BOS_IDX,
                       trg_eos=EOS_IDX):
    '''
    src_sentence: [seq_len]
    '''
    assert isinstance(transducer, HardMonoTransducer)
    attn_pos = 0
    transducer.eval()
    if isinstance(src_sentence, tuple):
        seq_len = src_sentence[0].shape[0]
    else:
        seq_len = src_sentence.shape[0]
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)

    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for _ in range(max_len):
        word_logprob, hidden, attn = transducer.decode_step(
            enc_hs, src_mask, input_, hidden, attn_pos)
        word = torch.max(word_logprob, dim=1)[1]
        attns.append(attn)
        if word == STEP_IDX:
            # a step action advances the hard monotonic attention pointer,
            # clamped to the last source position
            attn_pos += 1
            if attn_pos == seq_len:
                attn_pos = seq_len - 1
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
    return output, attns
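For intuition, the `STEP_IDX` handling above means the hard attention pointer advances only on an explicit step action and never runs past the end of the source. A toy trace of that update rule, with `STEP` as an arbitrary stand-in for `STEP_IDX`:

STEP = -1  # stand-in for STEP_IDX
seq_len = 3
attn_pos = 0
for action in [5, STEP, 7, STEP, STEP, STEP, 9]:
    if action == STEP:
        attn_pos = min(attn_pos + 1, seq_len - 1)
print(attn_pos)  # 2: clamped at the last source position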
Example #5
    def decode(self, mode, write_fp, decode_fn):
        self.model.eval()
        sampler, nb_instance = self.iterate_instance(mode)
        decode_fn.reset()

        outputs = []
        for src, trg in tqdm(sampler(), total=nb_instance):
            pred, gen_prob_vals, _ = decode_fn(self.model, src)
            if gen_prob_vals:
                p_gen, copy_prob, gen_prob = gen_prob_vals
            else:
                p_gen, copy_prob, gen_prob = None, None, None
            dist = util.edit_distance(pred, trg.view(-1).tolist()[1:-1])

            src_mask = dummy_mask(src)
            trg_mask = dummy_mask(trg)
            data = (src, src_mask, trg, trg_mask)
            loss = self.model.get_loss(data).item()
            src = self.data.decode_source(src)
            trg = self.data.decode_target(trg)[1:-1]
            pred = self.data.decode_target(pred)
            outputs.append(
                [pred, trg, loss, dist, src, p_gen, copy_prob, gen_prob])

        with open(f'{write_fp}.{mode}.tsv', 'w', encoding='utf-8') as fp:
            fp.write('prediction\ttarget\tloss\tdist\n')
            for pred, trg, loss, dist, _, _, _, _ in outputs:
                fp.write(
                    f'{" ".join(pred)}\t{" ".join(trg)}\t{loss}\t{dist}\n')

        with open(f'{write_fp}.{mode}_gh.tsv', 'w', encoding='utf-8') as fp:
            fp.write('target\tprediction\n')
            for pred, trg, _, _, _, _, _, _ in outputs:
                fp.write(f'{" ".join(trg)}\t{" ".join(pred)}\n')

        with open(f'{write_fp}.{mode}_src_pred.tsv', 'w',
                  encoding='utf-8') as fp:
            for pred, trg, _, _, src, _, _, _ in outputs:
                fp.write(f'{"".join(src[1:-1])}\t{" ".join(pred)}\n')

        with open(f'{write_fp}.{mode}_copy-probs.tsv', 'w',
                  encoding='utf-8') as fp:
            fp.write('source\ttarget\tprediction\n')
            for pred, trg, _, _, src, p_gen, copy_prob, gen_prob in outputs:
                fp.write(
                    f'{" ".join(src)}\t{" ".join(trg)}\t{" ".join(pred)}\n')
                fp.write(f'p_gen\t{p_gen}\n')
                fp.write(f'copy_prob\t{copy_prob}\n')
                fp.write(f'gen_prob\t{gen_prob}\n')

        decode_fn.reset()
        self.logger.info(f'finished decoding {len(outputs)} {mode} instances')
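`util.edit_distance` comes from the surrounding repository and is not shown here. Assuming it is a standard Levenshtein distance over token lists, a reference implementation would look like this sketch:

def edit_distance(a, b):
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        curr = [i]
        for j, y in enumerate(b, 1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (x != y)))  # substitution
        prev = curr
    return prev[-1]

assert edit_distance(list('kitten'), list('sitting')) == 3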
Example #6
def decode_beam_hmm(transducer,
                    src_sentence,
                    max_len=50,
                    nb_beam=5,
                    norm=True,
                    trg_bos=BOS_IDX,
                    trg_eos=EOS_IDX,
                    return_top_beams=False):
    def score(beam):
        '''
        compute score based on logprob
        '''
        assert isinstance(beam, BeamHMM)
        if norm:
            return -beam.log_prob / beam.seq_len
        return -beam.log_prob

    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    T = src_mask.shape[0]

    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))

    seq_len, log_prob, partial_sent, attn, forward = 1, 0, '', [], None
    beam = BeamHMM(seq_len, log_prob, hidden, input_, partial_sent, attn,
                   forward)
    beams = [beam]
    finish_beams = []

    for _ in range(max_len):
        next_beams = []
        for beam in sorted(beams, key=score)[:nb_beam]:
            trans, emiss, hidden = transducer.decode_step(
                enc_hs, src_mask, beam.input, beam.hidden)

            if beam.seq_len == 1:
                assert beam.forward is None
                initial = trans[:, 0].unsqueeze(1)
                attn = initial
                forward = initial
            else:
                assert beam.forward is not None
                attn = trans
                # log-space equivalent of forward = torch.bmm(forward, trans)
                forward = beam.forward + trans.transpose(1, 2)
                forward = forward.logsumexp(dim=-1,
                                            keepdim=True).transpose(1, 2)

            seq_len = beam.seq_len + 1
            next_attn = beam.attn + [attn]

            # log-space equivalent of wordprob = torch.bmm(forward, emiss)
            log_wordprob = forward + emiss.transpose(1, 2)
            log_wordprob = log_wordprob.logsumexp(dim=-1)
            topk_word = torch.topk(log_wordprob, nb_beam, dim=-1)[1]
            for word in topk_word.view(nb_beam, 1):
                next_input = transducer.dropout(transducer.trg_embed(word))
                next_output = str(word.item())
                word_idx = word.view(-1, 1).expand(1, T).unsqueeze(-1)
                word_emiss = torch.gather(emiss, -1, word_idx).view(1, 1, T)
                next_forward = forward + word_emiss

                log_prob = torch.logsumexp(next_forward, dim=-1).item()

                if word == trg_eos:
                    sent = beam.partial_sent
                    beam = BeamHMM(seq_len, log_prob, None, None, sent,
                                   next_attn, next_forward)
                    finish_beams.append(beam)
                else:
                    sent = f'{beam.partial_sent} {next_output}'
                    beam = BeamHMM(seq_len, log_prob, hidden, next_input, sent,
                                   next_attn, next_forward)
                    next_beams.append(beam)
        beams = next_beams
    finish_beams = finish_beams if finish_beams else next_beams
    sorted_beams = sorted(finish_beams, key=score)
    if return_top_beams:
        return sorted_beams[:nb_beam]
    else:
        max_output = sorted_beams[0]
        return list(map(int, max_output.partial_sent.split())), max_output.attn
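`BeamHMM` is constructed with (seq_len, log_prob, hidden, input, partial_sent, attn, forward). A hypothetical stand-in illustrates the scoring convention: beams are ranked by negative log-probability, and `norm=True` divides by hypothesis length so longer outputs are not unfairly penalized:

from collections import namedtuple

Beam = namedtuple('Beam', 'seq_len log_prob partial_sent')  # hypothetical stand-in

def score(beam, norm=True):
    # lower is better; normalization divides by hypothesis length
    return -beam.log_prob / beam.seq_len if norm else -beam.log_prob

beams = [Beam(3, -2.1, '12 7'), Beam(5, -2.8, '12 7 9 4')]
best = sorted(beams, key=score)[0]
print(best.partial_sent)  # '12 7 9 4': wins after length normalization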
Example #7
def decode_beam_mono(transducer,
                     src_sentence,
                     max_len=50,
                     nb_beam=5,
                     norm=True,
                     trg_bos=BOS_IDX,
                     trg_eos=EOS_IDX):
    assert isinstance(transducer, HardMonoTransducer)

    def score(beam):
        '''
        compute score based on logprob
        '''
        assert isinstance(beam, BeamHard)
        if norm:
            return -beam.log_prob / beam.seq_len
        return -beam.log_prob

    transducer.eval()
    if isinstance(src_sentence, tuple):
        seq_len = src_sentence[0].shape[0]
    else:
        seq_len = src_sentence.shape[0]
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)

    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    start = BeamHard(1, 0, hidden, input_, '', [], 0)
    beams = [start]
    finish_beams = []
    for _ in range(max_len):
        next_beams = []
        for beam in sorted(beams, key=score)[:nb_beam]:
            word_logprob, hidden, attn = transducer.decode_step(
                enc_hs, src_mask, beam.input, beam.hidden, beam.attn_pos)
            topk_log_prob, topk_word = word_logprob.topk(nb_beam)
            topk_log_prob = topk_log_prob.view(nb_beam, 1)
            topk_word = topk_word.view(nb_beam, 1)
            for log_prob, word in zip(topk_log_prob, topk_word):
                if word == trg_eos:
                    beam = BeamHard(beam.seq_len + 1,
                                    beam.log_prob + log_prob.item(), None,
                                    None, beam.partial_sent,
                                    beam.attn + [attn], beam.attn_pos)
                    finish_beams.append(beam)
                else:
                    shift = 1 if word == STEP_IDX and beam.attn_pos + 1 < seq_len else 0
                    beam = BeamHard(
                        beam.seq_len + 1, beam.log_prob + log_prob.item(),
                        hidden, transducer.dropout(transducer.trg_embed(word)),
                        ' '.join([beam.partial_sent,
                                  str(word.item())]), beam.attn + [attn],
                        beam.attn_pos + shift)
                    next_beams.append(beam)
        beams = next_beams
    finish_beams = finish_beams if finish_beams else next_beams
    max_output = sorted(finish_beams, key=score)[0]
    return list(map(int, max_output.partial_sent.split())), max_output.attn
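Both beam decoders expand hypotheses with `topk` over the next-word distribution. A standalone sketch of that expansion pattern, on a fake decoder step:

import torch

nb_beam = 5
word_logprob = torch.log_softmax(torch.randn(1, 20), dim=-1)  # fake decoder step
topk_log_prob, topk_word = word_logprob.topk(nb_beam)
for log_prob, word in zip(topk_log_prob.view(nb_beam, 1),
                          topk_word.view(nb_beam, 1)):
    print(word.item(), log_prob.item())  # candidate token id and its log-prob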