def decode_greedy_transformer(transducer, src_sentence, max_len=100, trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    '''Greedy (argmax) decoding with a Transformer transducer.

    Re-runs the decoder over the full prefix at every step (no cache) and
    stops at `trg_eos` or after `max_len` steps.

    src_sentence: source tensor fed to `transducer.encode`.
        NOTE(review): the docstring said `[seq_len]`, but the slice
        `src_sentence[:, :len(output)]` below indexes two dimensions —
        presumably the actual shape is `[seq_len, 1]`; confirm at call site.
    Returns: (output token ids without the leading BOS, gen_prob, attns).
        NOTE(review): `attns` is never appended to here, so it is always
        returned as an empty list.
    '''
    assert isinstance(transducer, Transformer)
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    # Convert padding mask to the boolean "masked" convention and put the
    # batch dimension first, as expected by nn.Transformer-style modules.
    src_mask = (src_mask == 0).transpose(0, 1)
    enc_hs = transducer.encode(src_sentence, src_mask)
    # Decoded prefix starts with BOS; it is stripped before returning.
    output, attns = [trg_bos], []
    gen_prob = None
    for _ in range(max_len):
        # Shape the prefix as [trg_len, 1] (sequence-first, batch of one).
        output_tensor = torch.tensor(output, device=DEVICE).view(len(output), 1)
        trg_mask = dummy_mask(output_tensor)
        trg_mask = (trg_mask == 0).transpose(0, 1)
        # NOTE(review): these truncate the *source*-side tensors to the
        # current *target* prefix length — looks intentional for the copy
        # mechanism below, but verify it matches source_weighted_output's
        # expectations.
        src_mask_i = src_mask[:len(output)]
        enc_hs_i = enc_hs[:, :len(output)]
        src_sentence_i = src_sentence[:, :len(output)]
        dec_hs, attn_weights, embed_tgt = transducer.decode(
            enc_hs_i, src_mask_i, output_tensor, trg_mask)
        t_output = transducer.final_out(dec_hs)
        if not transducer.use_copy:
            word_logprob = F.log_softmax(t_output, dim=-1)
        else:
            # Pointer/copy model: mixes generation logits with attention
            # over the source; already returns log-probabilities.
            word_logprob, gen_prob = transducer.source_weighted_output(
                src_sentence_i, t_output, attn_weights, enc_hs_i, dec_hs, embed_tgt)
        # Keep only the distribution of the last (newest) position.
        word_logprob = word_logprob[-1]
        word = torch.max(word_logprob, dim=1)[1]
        if word == trg_eos:
            break
        output.append(word.item())
    return output[1:], gen_prob, attns
def decode_sample(transducer, src_sentence, max_len=100, trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    '''Sample one target sequence from a soft-attention RNN transducer.

    At each step the next token is drawn from the model's categorical
    output distribution (ancestral sampling), stopping at `trg_eos` or
    after `max_len` steps. HMM transducers are delegated to
    `decode_sample_hmm`; hard-monotonic transducers are not supported.

    src_sentence: source tensor fed to `transducer.encode`.
    Returns: (sampled token ids, per-step attention weights).
    '''
    assert not isinstance(transducer, HardMonoTransducer)
    if isinstance(transducer, HMMTransducer):
        # BUG FIX: forward the caller's trg_bos/trg_eos instead of always
        # resetting them to the module-level defaults, which silently
        # ignored any override passed by the caller.
        return decode_sample_hmm(transducer, src_sentence, max_len=max_len,
                                 trg_bos=trg_bos, trg_eos=trg_eos)
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for _ in range(max_len):
        word_logprob, hidden, attn = transducer.decode_step(
            enc_hs, src_mask, input_, hidden)
        # `sample_n(n)` is deprecated in torch.distributions;
        # `sample((n,))` is the documented equivalent.
        word = Categorical(word_logprob.exp()).sample((1,))[0]
        attns.append(attn)
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
    return output, attns
def decode_sample_hmm(transducer, src_sentence, max_len=100, trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    '''Sample one target sequence from a neural HMM transducer.

    Maintains the HMM forward (alpha) recursion in log space over the T
    source positions and, at each step, samples the next token from the
    marginal word distribution `logsumexp_j(forward_j + emiss_j)`.
    `Categorical` renormalizes, so the marginal need not be normalized.

    src_sentence: source tensor fed to `transducer.encode`.
    Returns: (sampled token ids, per-step transition/attention matrices).
    '''
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    T = src_mask.shape[0]  # number of source positions (HMM states)
    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for idx in range(max_len):
        trans, emiss, hidden = transducer.decode_step(enc_hs, src_mask,
                                                      input_, hidden)
        if idx == 0:
            # Initialization: alpha_0 is the transition row out of state 0.
            initial = trans[:, 0].unsqueeze(1)
            attns.append(initial)
            forward = initial
        else:
            attns.append(trans)
            # Log-space matrix-vector product: alpha' = alpha @ trans.
            forward = forward + trans.transpose(1, 2)
            forward = forward.logsumexp(dim=-1, keepdim=True).transpose(1, 2)
        # Marginal next-word log-scores, summed over source states.
        log_wordprob = forward + emiss.transpose(1, 2)
        log_wordprob = log_wordprob.logsumexp(dim=-1)
        # `sample_n(n)` is deprecated in torch.distributions;
        # `sample((n,))` is the documented equivalent.
        word = Categorical(log_wordprob.exp()).sample((1,))[0]
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
        # Fold the emission probability of the chosen word into alpha.
        word_idx = word.view(-1, 1).expand(1, T).unsqueeze(-1)
        word_emiss = torch.gather(emiss, -1, word_idx).view(1, 1, T)
        forward = forward + word_emiss
    return output, attns
def decode_greedy_mono(transducer, src_sentence, max_len=100, trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    '''Greedy decoding for a hard monotonic transducer.

    The decoder attends to exactly one source position at a time
    (`attn_pos`); emitting the special STEP token advances that pointer,
    clamped to the last source position. Decoding stops at `trg_eos` or
    after `max_len` steps.

    src_sentence: source tensor (or tuple of tensors) fed to
        `transducer.encode`.
    Returns: (predicted token ids, per-step attention weights).
    '''
    assert isinstance(transducer, HardMonoTransducer)
    transducer.eval()
    # Source length bounds how far the monotonic pointer may advance.
    src_seq = src_sentence[0] if isinstance(src_sentence, tuple) else src_sentence
    seq_len = src_seq.shape[0]
    attn_pos = 0
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    bos = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(bos))
    for _ in range(max_len):
        word_logprob, hidden, attn = transducer.decode_step(
            enc_hs, src_mask, input_, hidden, attn_pos)
        word = word_logprob.argmax(dim=1)
        attns.append(attn)
        if word == STEP_IDX:
            # Advance the pointer, never past the final source position.
            attn_pos = min(attn_pos + 1, seq_len - 1)
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
    return output, attns
def decode(self, mode, write_fp, decode_fn):
    '''Decode every instance of the `mode` split and dump results to TSV.

    Writes four files with the prefix `write_fp`:
      * `.{mode}.tsv`            prediction / target / loss / edit distance
      * `.{mode}_gh.tsv`         target / prediction pairs
      * `.{mode}_src_pred.tsv`   joined source / prediction pairs
      * `.{mode}_copy-probs.tsv` predictions plus p_gen / copy / gen probs

    mode: dataset split name (e.g. 'dev', 'test').
    write_fp: path prefix for the output files.
    decode_fn: callable `decode_fn(model, src)` returning
        `(pred, gen_prob_vals, attns)` and exposing `reset()`.
    '''
    self.model.eval()
    sampler, nb_instance = self.iterate_instance(mode)
    decode_fn.reset()
    outputs = []
    # BUG FIX: `cnt` used to be incremented once per instance inside EACH
    # of the four write loops below, so the final log reported 4x the true
    # number of decoded instances. Count once per decoded instance instead.
    cnt = 0
    for src, trg in tqdm(sampler(), total=nb_instance):
        pred, gen_prob_vals, _ = decode_fn(self.model, src)
        if gen_prob_vals:
            p_gen, copy_prob, gen_prob = gen_prob_vals
        else:
            p_gen, copy_prob, gen_prob = None, None, None
        # Edit distance against the target with BOS/EOS stripped.
        dist = util.edit_distance(pred, trg.view(-1).tolist()[1:-1])
        src_mask = dummy_mask(src)
        trg_mask = dummy_mask(trg)
        data = (src, src_mask, trg, trg_mask)
        loss = self.model.get_loss(data).item()
        src = self.data.decode_source(src)
        trg = self.data.decode_target(trg)[1:-1]
        pred = self.data.decode_target(pred)
        outputs.append(
            [pred, trg, loss, dist, src, p_gen, copy_prob, gen_prob])
        cnt += 1
    with open(f'{write_fp}.{mode}.tsv', 'w', encoding='utf-8') as fp:
        fp.write('prediction\ttarget\tloss\tdist\n')
        for pred, trg, loss, dist, _, _, _, _ in outputs:
            fp.write(
                f'{" ".join(pred)}\t{" ".join(trg)}\t{loss}\t{dist}\n')
    with open(f'{write_fp}.{mode}_gh.tsv', 'w', encoding='utf-8') as fp:
        fp.write('target\tprediction\n')
        for pred, trg, _, _, _, _, _, _ in outputs:
            fp.write(f'{" ".join(trg)}\t{" ".join(pred)}\n')
    with open(f'{write_fp}.{mode}_src_pred.tsv', 'w', encoding='utf-8') as fp:
        for pred, trg, _, _, src, _, _, _ in outputs:
            # Source is joined without spaces and stripped of BOS/EOS.
            fp.write(f'{"".join(src[1:-1])}\t{" ".join(pred)}\n')
    with open(f'{write_fp}.{mode}_copy-probs.tsv', 'w', encoding='utf-8') as fp:
        fp.write('source\ttarget\tprediction\tdist\n')
        for pred, trg, _, _, src, p_gen, copy_prob, gen_prob in outputs:
            fp.write(
                f'{" ".join(src)}\t{" ".join(trg)}\t{" ".join(pred)}\n')
            fp.write('p_gen')
            fp.write(f'{p_gen}\n')
            fp.write('copy_prob')
            fp.write(f'{copy_prob}\n')
            fp.write('gen_prob')
            fp.write(f'{gen_prob}\n')
    decode_fn.reset()
    self.logger.info(f'finished decoding {cnt} {mode} instance')
def decode_beam_hmm(transducer, src_sentence, max_len=50, nb_beam=5,
                    norm=True, trg_bos=BOS_IDX, trg_eos=EOS_IDX,
                    return_top_beams=False):
    '''Beam search for a neural HMM transducer.

    Each beam carries the HMM forward (alpha) vector in log space; a
    hypothesis' score is the log marginal `logsumexp(alpha)` (optionally
    length-normalized). Expands the top `nb_beam` beams by their top
    `nb_beam` words each step, up to `max_len` steps.

    Returns: the top `nb_beam` beams if `return_top_beams`, otherwise
        (best token-id list, its per-step attention matrices).
    '''
    def score(beam):
        '''Beam score: negated (optionally length-normalized) log-prob.'''
        assert isinstance(beam, BeamHMM)
        if norm:
            return -beam.log_prob / beam.seq_len
        return -beam.log_prob

    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    T = src_mask.shape[0]  # number of source positions (HMM states)
    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    seq_len, log_prob, partial_sent, attn, forward = 1, 0, '', [], None
    beam = BeamHMM(seq_len, log_prob, hidden, input_, partial_sent, attn,
                   forward)
    beams = [beam]
    finish_beams = []
    for _ in range(max_len):
        next_beams = []
        for beam in sorted(beams, key=score)[:nb_beam]:
            trans, emiss, hidden = transducer.decode_step(
                enc_hs, src_mask, beam.input, beam.hidden)
            if beam.seq_len == 1:
                # First expansion: alpha is the transition row of state 0.
                assert beam.forward is None
                initial = trans[:, 0].unsqueeze(1)
                attn = initial
                forward = initial
            else:
                assert beam.forward is not None
                attn = trans
                # Log-space matrix-vector product: alpha' = alpha @ trans.
                forward = beam.forward + trans.transpose(1, 2)
                forward = forward.logsumexp(dim=-1,
                                            keepdim=True).transpose(1, 2)
            seq_len = beam.seq_len + 1
            next_attn = beam.attn + [attn]
            # Marginal next-word log-scores, summed over source states.
            log_wordprob = forward + emiss.transpose(1, 2)
            log_wordprob = log_wordprob.logsumexp(dim=-1)
            topk_word = torch.topk(log_wordprob, nb_beam, dim=-1)[1]
            for word in topk_word.view(nb_beam, 1):
                next_input = transducer.dropout(transducer.trg_embed(word))
                next_output = str(word.item())
                # Fold the chosen word's emission probability into alpha.
                word_idx = word.view(-1, 1).expand(1, T).unsqueeze(-1)
                word_emiss = torch.gather(emiss, -1, word_idx).view(1, 1, T)
                next_forward = forward + word_emiss
                log_prob = torch.logsumexp(next_forward, dim=-1).item()
                # BUG FIX: the original reassigned the loop variable `beam`
                # here, so later candidate words of the SAME parent read a
                # sibling candidate's `partial_sent` instead of the
                # parent's, corrupting every hypothesis after the first.
                if word == trg_eos:
                    candidate = BeamHMM(seq_len, log_prob, None, None,
                                        beam.partial_sent, next_attn,
                                        next_forward)
                    finish_beams.append(candidate)
                else:
                    sent = f'{beam.partial_sent} {next_output}'
                    candidate = BeamHMM(seq_len, log_prob, hidden,
                                        next_input, sent, next_attn,
                                        next_forward)
                    next_beams.append(candidate)
        beams = next_beams
        # NOTE(review): this aliases finish_beams to the live candidate
        # list whenever nothing has finished yet — kept as-is, but it
        # means unfinished beams can be ranked at the end; confirm intent.
        finish_beams = finish_beams if finish_beams else next_beams
    sorted_beams = sorted(finish_beams, key=score)
    if return_top_beams:
        return sorted_beams[:nb_beam]
    else:
        max_output = sorted_beams[0]
        return list(map(int, max_output.partial_sent.split())), max_output.attn
def decode_beam_mono(transducer, src_sentence, max_len=50, nb_beam=5,
                     norm=True, trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    '''Beam search for a hard monotonic transducer.

    Each beam tracks its own attention pointer (`attn_pos`), which a STEP
    token advances (clamped to the last source position). Expands the top
    `nb_beam` beams by their top `nb_beam` words each step, up to
    `max_len` steps; scores are (optionally length-normalized) negated
    log-probabilities.

    Returns: (best token-id list, its per-step attention weights).
    '''
    assert isinstance(transducer, HardMonoTransducer)

    def score(beam):
        '''Beam score: negated (optionally length-normalized) log-prob.'''
        assert isinstance(beam, BeamHard)
        if norm:
            return -beam.log_prob / beam.seq_len
        return -beam.log_prob

    transducer.eval()
    # Source length bounds how far the monotonic pointer may advance.
    if isinstance(src_sentence, tuple):
        seq_len = src_sentence[0].shape[0]
    else:
        seq_len = src_sentence.shape[0]
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    start = BeamHard(1, 0, hidden, input_, '', [], 0)
    beams = [start]
    finish_beams = []
    for _ in range(max_len):
        next_beams = []
        for beam in sorted(beams, key=score)[:nb_beam]:
            word_logprob, hidden, attn = transducer.decode_step(
                enc_hs, src_mask, beam.input, beam.hidden, beam.attn_pos)
            topk_log_prob, topk_word = word_logprob.topk(nb_beam)
            topk_log_prob = topk_log_prob.view(nb_beam, 1)
            topk_word = topk_word.view(nb_beam, 1)
            # BUG FIX: the original reassigned the loop variable `beam`
            # inside this loop, so after the first candidate every later
            # candidate extended the SIBLING candidate (double-counting
            # seq_len/log_prob and corrupting partial_sent/attn_pos)
            # instead of the parent beam.
            for log_prob, word in zip(topk_log_prob, topk_word):
                if word == trg_eos:
                    candidate = BeamHard(beam.seq_len + 1,
                                         beam.log_prob + log_prob.item(),
                                         None, None, beam.partial_sent,
                                         beam.attn + [attn], beam.attn_pos)
                    finish_beams.append(candidate)
                else:
                    # STEP advances the pointer unless already at the end.
                    shift = 1 if word == STEP_IDX and beam.attn_pos + 1 < seq_len else 0
                    candidate = BeamHard(
                        beam.seq_len + 1, beam.log_prob + log_prob.item(),
                        hidden,
                        transducer.dropout(transducer.trg_embed(word)),
                        ' '.join([beam.partial_sent, str(word.item())]),
                        beam.attn + [attn], beam.attn_pos + shift)
                    next_beams.append(candidate)
        beams = next_beams
        # NOTE(review): aliases finish_beams to the live candidate list
        # whenever nothing has finished yet — kept as-is; confirm intent.
        finish_beams = finish_beams if finish_beams else next_beams
    max_output = sorted(finish_beams, key=score)[0]
    return list(map(int, max_output.partial_sent.split())), max_output.attn