def decode_greedy_pointergenerator_transformer(
        transducer, src_sentence, max_len=100,
        trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    """
    Greedy decoding for a PointerGeneratorTransformer.

    src_sentence: [seq_len]

    Returns (decoded token ids without the leading BOS, attention list).
    NOTE(review): the attention list is never filled here — it is returned
    empty for interface parity with the other greedy decoders.
    """
    assert isinstance(transducer, PointerGeneratorTransformer)
    transducer.eval()

    src_mask = (dummy_mask(src_sentence) == 0).transpose(0, 1)
    enc_hs = transducer.encode(src_sentence, src_mask)

    decoded = [trg_bos]  # running prefix, seeded with BOS
    attn_list = []
    for _ in range(max_len):
        # transformer decoding re-reads the whole prefix every step
        prefix = torch.tensor(decoded, device=DEVICE).view(len(decoded), 1)
        trg_mask = (dummy_mask(prefix) == 0).transpose(0, 1)
        logprob = transducer.decode(enc_hs, src_mask, prefix, trg_mask,
                                    src_sentence.transpose(0, 1))[-1]
        word = torch.max(logprob, dim=1)[1]
        if word == trg_eos:
            break
        decoded.append(word.item())
    return decoded[1:], attn_list
def decode(self, mode, write_fp, decode_fn):
    """Decode every instance of the `mode` split and write a TSV report.

    Writes `{write_fp}.{mode}.tsv` with columns
    prediction / target / loss / edit distance, one row per instance.

    mode: dataset split name (e.g. 'dev' or 'test')
    write_fp: output file prefix
    decode_fn: callable `(model, src) -> (prediction, extra)`; its internal
        state is reset before and after the run.
    """
    self.model.eval()
    cnt = 0
    sampler, nb_instance = self.iterate_instance(mode)
    decode_fn.reset()
    # Dead commented-out code removed; f-strings restored for consistency
    # with the rest of the file (which already requires Python >= 3.6).
    # The emitted bytes are identical to the previous str.format calls.
    with open(f'{write_fp}.{mode}.tsv', 'w') as fp:
        fp.write('prediction\ttarget\tloss\tdist\n')
        for src, trg in tqdm(sampler(), total=nb_instance):
            pred, _ = decode_fn(self.model, src)
            # edit distance against the target without BOS/EOS
            dist = util.edit_distance(pred, trg.view(-1).tolist()[1:-1])
            src_mask = dummy_mask(src)
            trg_mask = dummy_mask(trg)
            data = (src, src_mask, trg, trg_mask)
            loss = self.model.get_loss(data).item()
            trg = self.data.decode_target(trg)[1:-1]
            pred = self.data.decode_target(pred)
            fp.write(f'{" ".join(pred)}\t{" ".join(trg)}\t{loss}\t{dist}\n')
            cnt += 1
    decode_fn.reset()
    self.logger.info(f'finished decoding {cnt} {mode} instance')
def decode_sample(transducer, src_sentence, max_len=100,
                  trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    """
    Ancestral sampling from a soft-attention transducer.

    src_sentence: [seq_len]

    Returns (sampled token ids, per-step attention list).
    """
    assert not isinstance(transducer, HardMonoTransducer)
    if isinstance(transducer, HMMTransducer):
        # BUG FIX: forward the caller-supplied trg_bos/trg_eos instead of
        # always passing the global BOS_IDX/EOS_IDX.
        return decode_sample_hmm(transducer, src_sentence, max_len=max_len,
                                 trg_bos=trg_bos, trg_eos=trg_eos)
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for _ in range(max_len):
        word_logprob, hidden, attn = transducer.decode_step(
            enc_hs, src_mask, input_, hidden)
        # FIX: Categorical.sample_n is deprecated; .sample() draws a single
        # sample with the same shape as sample_n(1)[0].
        word = Categorical(word_logprob.exp()).sample()
        attns.append(attn)
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
    return output, attns
def main():
    """Lemmatize a tab-separated (CoNLL-U-style) file.

    Reads `opt.in_file`, runs the loaded model on (word, tags) from columns
    1 and 5, writes the prediction into column 2 (lemma), and passes through
    comment lines and rows with fewer than two columns unchanged.
    """
    opt = get_args()
    decode_fn = Decoder(opt.decode, max_len=opt.max_len,
                        beam_size=opt.beam_size)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIX: close the checkpoint file handle (was `torch.load(open(...))`,
    # which leaked the open file).
    with open(opt.model, mode="rb") as model_fp:
        model = torch.load(model_fp, map_location=device)
    model = model.to(device)
    trg_i2c = {i: c for c, i in model.trg_c2i.items()}

    def decode_trg(seq):
        # map predicted ids back to target characters
        return [trg_i2c[i] for i in seq]

    maybe_mkdir(opt.out_file)
    with open(opt.in_file, "r", encoding="utf-8") as in_fp, open(
        opt.out_file, "w", encoding="utf-8"
    ) as out_fp:
        # FIX: iterate the file lazily instead of readlines() — avoids
        # materializing the whole input in memory; behavior is identical.
        for line in in_fp:
            toks = line.strip().split("\t")
            if len(toks) < 2 or line[0] == "#":
                # pass through comments and non-token rows
                out_fp.write(line)
                continue
            word, tags = toks[1], toks[5]
            word, tags = list(word), tags.split(";")
            src = encode(model, word, tags, device)
            src_mask = dummy_mask(src)
            pred, _ = decode_fn(model, src, src_mask)
            pred = unpack_batch(pred)[0]
            pred_out = "".join(decode_trg(pred))
            # write the predicted lemma into column 2
            toks[2] = pred_out
            out_fp.write("\t".join(toks) + "\n")
def decode(self, mode, write_fp, decode_fn):
    """Decode every instance of the `mode` split and write a TSV report.

    Writes `{write_fp}.{mode}.tsv` with columns
    prediction / target / loss / edit distance, one row per instance.
    `decode_fn` is reset before and after the run.
    """
    self.model.eval()
    sampler, nb_instance = self.iterate_instance(mode)
    decode_fn.reset()
    cnt = 0
    with open(f"{write_fp}.{mode}.tsv", "w") as fp:
        fp.write("prediction\ttarget\tloss\tdist\n")
        for src, trg in tqdm(sampler(), total=nb_instance):
            pred, _ = decode_fn(self.model, src)
            # edit distance against the target without BOS/EOS
            dist = util.edit_distance(pred, trg.view(-1).tolist()[1:-1])
            batch = (src, dummy_mask(src), trg, dummy_mask(trg))
            loss = self.model.get_loss(batch).item()
            trg_str = " ".join(self.data.decode_target(trg)[1:-1])
            pred_str = " ".join(self.data.decode_target(pred))
            fp.write(f"{pred_str}\t{trg_str}\t{loss}\t{dist}\n")
            cnt += 1
    decode_fn.reset()
    self.logger.info(f"finished decoding {cnt} {mode} instance")
def decode_greedy(transducer, src_sentence, max_len=100,
                  trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    """
    Greedy decoding, dispatching on the transducer type.

    src_sentence: [seq_len]

    Returns (decoded token ids, per-step attention list).
    """
    # BUG FIX: the dispatch calls below previously passed the global
    # BOS_IDX/EOS_IDX, silently ignoring caller-supplied trg_bos/trg_eos.
    if isinstance(transducer, HardMonoTransducer):
        return decode_greedy_mono(transducer, src_sentence, max_len=max_len,
                                  trg_bos=trg_bos, trg_eos=trg_eos)
    if isinstance(transducer, HMMTransducer):
        return decode_greedy_hmm(transducer, src_sentence, max_len=max_len,
                                 trg_bos=trg_bos, trg_eos=trg_eos)
    if isinstance(transducer, PointerGeneratorTransformer):
        return decode_greedy_pointergenerator_transformer(
            transducer, src_sentence, max_len=max_len,
            trg_bos=trg_bos, trg_eos=trg_eos)
    if isinstance(transducer, Transformer):
        return decode_greedy_transformer(transducer, src_sentence,
                                         max_len=max_len,
                                         trg_bos=trg_bos, trg_eos=trg_eos)
    # default: step-by-step RNN decoding
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for _ in range(max_len):
        word_logprob, hidden, attn = transducer.decode_step(
            enc_hs, src_mask, input_, hidden)
        word = torch.max(word_logprob, dim=1)[1]
        attns.append(attn)
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
    return output, attns
def decode_sample_hmm(transducer, src_sentence, max_len=100,
                      trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    """
    Ancestral sampling from an HMMTransducer.

    Maintains the HMM forward variable in log space and samples each output
    token from the marginal word distribution.

    src_sentence: [seq_len]

    Returns (sampled token ids, per-step transition/attention tensors).
    """
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    T = src_mask.shape[0]  # source length
    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for idx in range(max_len):
        trans, emiss, hidden = transducer.decode_step(enc_hs, src_mask,
                                                      input_, hidden)
        if idx == 0:
            # initial forward variable: transition out of position 0
            initial = trans[:, 0].unsqueeze(1)
            attns.append(initial)
            forward = initial
        else:
            attns.append(trans)
            # log-space equivalent of: forward = torch.bmm(forward, trans)
            forward = forward + trans.transpose(1, 2)
            forward = forward.logsumexp(dim=-1, keepdim=True).transpose(1, 2)
        # log-space equivalent of: wordprob = torch.bmm(forward, emiss)
        log_wordprob = forward + emiss.transpose(1, 2)
        log_wordprob = log_wordprob.logsumexp(dim=-1)
        # FIX: Categorical.sample_n is deprecated; .sample() draws a single
        # sample with the same shape as sample_n(1)[0].
        word = Categorical(log_wordprob.exp()).sample()
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
        # condition the forward variable on the emitted word
        word_idx = word.view(-1, 1).expand(1, T).unsqueeze(-1)
        word_emiss = torch.gather(emiss, -1, word_idx).view(1, 1, T)
        forward = forward + word_emiss
    return output, attns
def main():
    """Generate inflections for (lemma, tags) pairs read from `opt.in_file`.

    Writes one TSV row per instance: lemma, prediction, tags (without the
    leading language tag).
    """
    opt = get_args()
    decode_fn = Decoder(opt.decode, max_len=opt.max_len,
                        beam_size=opt.beam_size)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIX: close the checkpoint file handle (was `torch.load(open(...))`,
    # which leaked the open file).
    with open(opt.model, mode="rb") as model_fp:
        model = torch.load(model_fp, map_location=device)
    model = model.to(device)
    trg_i2c = {i: c for c, i in model.trg_c2i.items()}

    def decode_trg(seq):
        # map predicted ids back to target characters
        return [trg_i2c[i] for i in seq]

    maybe_mkdir(opt.out_file)
    with open(opt.out_file, "w", encoding="utf-8") as fp:
        for lemma, tags in read_file(opt.in_file, opt.lang):
            src = encode(model, lemma, tags, device)
            src_mask = dummy_mask(src)
            pred, _ = decode_fn(model, src, src_mask)
            pred = unpack_batch(pred)[0]
            pred_out = "".join(decode_trg(pred))
            # tags[0] is the language tag — drop it from the output
            fp.write(f'{"".join(lemma)}\t{pred_out}\t{";".join(tags[1:])}\n')
def decode_greedy_transformer(transducer, src_sentence, src_mask,
                              max_len=100, trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    """
    Batched greedy decoding for a vanilla Transformer.

    src_sentence: [seq_len]

    Returns (decoded tensor including the leading BOS row, None).
    """
    assert isinstance(transducer, Transformer)
    transducer.eval()

    src_mask = (src_mask == 0).transpose(0, 1)
    enc_hs = transducer.encode(src_sentence, src_mask)
    _, bs = src_sentence.shape

    # prefix starts as a single BOS row, one column per batch element
    prefix = torch.tensor([trg_bos] * bs, device=DEVICE).view(1, bs)
    done = None  # per-batch "has emitted EOS" flags
    for _ in range(max_len):
        step_mask = (dummy_mask(prefix) == 0).transpose(0, 1)
        logprob = transducer.decode(enc_hs, src_mask, prefix, step_mask)[-1]
        next_word = torch.max(logprob, dim=1)[1]
        prefix = torch.cat((prefix, next_word.view(1, bs)))
        hit_eos = next_word == trg_eos
        done = hit_eos if done is None else (done | hit_eos)
        if done.all().item():
            break
    return prefix, None
def decode_greedy_mono(transducer, src_sentence, max_len=100,
                       trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    """
    Greedy decoding for a HardMonoTransducer with a monotonic attention
    pointer that advances on STEP tokens.

    src_sentence: [seq_len] tensor, or a tuple whose first element holds
    the source sequence.

    Returns (decoded token ids, per-step attention list).
    """
    assert isinstance(transducer, HardMonoTransducer)
    transducer.eval()

    # source length: tuple inputs carry the sequence in their first element
    if isinstance(src_sentence, tuple):
        seq_len = src_sentence[0].shape[0]
    else:
        seq_len = src_sentence.shape[0]

    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)

    decoded, attn_list = [], []
    attn_pos = 0
    hidden = transducer.dec_rnn.get_init_hx(1)
    step_in = transducer.dropout(
        transducer.trg_embed(torch.tensor([trg_bos], device=DEVICE)))
    for _ in range(max_len):
        word_logprob, hidden, attn = transducer.decode_step(
            enc_hs, src_mask, step_in, hidden, attn_pos)
        word = torch.max(word_logprob, dim=1)[1]
        attn_list.append(attn)
        if word == STEP_IDX:
            # a STEP token advances the attention pointer, clamped to the
            # last source position
            attn_pos = min(attn_pos + 1, seq_len - 1)
        if word == trg_eos:
            break
        step_in = transducer.dropout(transducer.trg_embed(word))
        decoded.append(word.item())
    return decoded, attn_list
def decode_beam_hmm(transducer, src_sentence, max_len=50, nb_beam=5,
                    norm=True, trg_bos=BOS_IDX, trg_eos=EOS_IDX,
                    return_top_beams=False):
    """
    Beam search for an HMMTransducer, carrying the HMM forward variable
    (log space) inside each beam.

    src_sentence: [seq_len]

    Returns the top `nb_beam` BeamHMM objects if `return_top_beams`,
    otherwise (best token-id sequence, its attention list).
    """
    def score(beam):
        """Rank beams by (optionally length-normalized) negative logprob."""
        assert isinstance(beam, BeamHMM)
        if norm:
            return -beam.log_prob / beam.seq_len
        return -beam.log_prob

    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    T = src_mask.shape[0]  # source length
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    seq_len, log_prob, partial_sent, attn, forward = 1, 0, '', [], None
    beams = [BeamHMM(seq_len, log_prob, hidden, input_, partial_sent,
                     attn, forward)]
    finish_beams = []
    for _ in range(max_len):
        next_beams = []
        for beam in sorted(beams, key=score)[:nb_beam]:
            trans, emiss, hidden = transducer.decode_step(
                enc_hs, src_mask, beam.input, beam.hidden)
            if beam.seq_len == 1:
                assert beam.forward is None
                # initial forward variable: transition out of position 0
                initial = trans[:, 0].unsqueeze(1)
                attn = initial
                forward = initial
            else:
                assert beam.forward is not None
                attn = trans
                # log-space equivalent of: forward = torch.bmm(forward, trans)
                forward = beam.forward + trans.transpose(1, 2)
                forward = forward.logsumexp(dim=-1,
                                            keepdim=True).transpose(1, 2)
            seq_len = beam.seq_len + 1
            next_attn = beam.attn + [attn]
            # log-space equivalent of: wordprob = torch.bmm(forward, emiss)
            log_wordprob = forward + emiss.transpose(1, 2)
            log_wordprob = log_wordprob.logsumexp(dim=-1)
            topk_word = torch.topk(log_wordprob, nb_beam, dim=-1)[1]
            for word in topk_word.view(nb_beam, 1):
                next_input = transducer.dropout(transducer.trg_embed(word))
                next_output = str(word.item())
                # condition the forward variable on the candidate word
                word_idx = word.view(-1, 1).expand(1, T).unsqueeze(-1)
                word_emiss = torch.gather(emiss, -1, word_idx).view(1, 1, T)
                next_forward = forward + word_emiss
                log_prob = torch.logsumexp(next_forward, dim=-1).item()
                # BUG FIX: use a fresh `new_beam` instead of reassigning
                # `beam` — the old code clobbered the parent beam, so later
                # top-k words extended the previous candidate's sentence
                # rather than the parent's (cf. decode_beam_search).
                if word == trg_eos:
                    new_beam = BeamHMM(seq_len, log_prob, None, None,
                                       beam.partial_sent, next_attn,
                                       next_forward)
                    finish_beams.append(new_beam)
                else:
                    sent = f'{beam.partial_sent} {next_output}'
                    new_beam = BeamHMM(seq_len, log_prob, hidden, next_input,
                                       sent, next_attn, next_forward)
                    next_beams.append(new_beam)
        beams = next_beams
    # fall back to unfinished beams if nothing reached EOS
    finish_beams = finish_beams if finish_beams else next_beams
    sorted_beams = sorted(finish_beams, key=score)
    if return_top_beams:
        return sorted_beams[:nb_beam]
    else:
        max_output = sorted_beams[0]
        return list(map(int, max_output.partial_sent.split())), max_output.attn
def decode_beam_mono(transducer, src_sentence, max_len=50, nb_beam=5,
                     norm=True, trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    """
    Beam search for a HardMonoTransducer; each beam carries a monotonic
    attention position advanced by STEP tokens.

    src_sentence: [seq_len] tensor, or a tuple whose first element holds
    the source sequence.

    Returns (best token-id sequence, its attention list).
    """
    assert isinstance(transducer, HardMonoTransducer)

    def score(beam):
        """Rank beams by (optionally length-normalized) negative logprob."""
        assert isinstance(beam, BeamHard)
        if norm:
            return -beam.log_prob / beam.seq_len
        return -beam.log_prob

    transducer.eval()
    if isinstance(src_sentence, tuple):
        seq_len = src_sentence[0].shape[0]
    else:
        seq_len = src_sentence.shape[0]
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    start = BeamHard(1, 0, hidden, input_, '', [], 0)
    beams = [start]
    finish_beams = []
    for _ in range(max_len):
        next_beams = []
        for beam in sorted(beams, key=score)[:nb_beam]:
            word_logprob, hidden, attn = transducer.decode_step(
                enc_hs, src_mask, beam.input, beam.hidden, beam.attn_pos)
            topk_log_prob, topk_word = word_logprob.topk(nb_beam)
            topk_log_prob = topk_log_prob.view(nb_beam, 1)
            topk_word = topk_word.view(nb_beam, 1)
            # BUG FIX: use a fresh `new_beam` instead of reassigning `beam` —
            # the old code clobbered the parent beam, so every top-k word
            # after the first extended the previously built candidate
            # instead of the parent (cf. decode_beam_search).
            for log_prob, word in zip(topk_log_prob, topk_word):
                if word == trg_eos:
                    new_beam = BeamHard(beam.seq_len + 1,
                                        beam.log_prob + log_prob.item(),
                                        None, None, beam.partial_sent,
                                        beam.attn + [attn], beam.attn_pos)
                    finish_beams.append(new_beam)
                else:
                    # STEP advances the attention pointer (clamped to the
                    # last source position)
                    shift = (1 if word == STEP_IDX
                             and beam.attn_pos + 1 < seq_len else 0)
                    new_beam = BeamHard(
                        beam.seq_len + 1,
                        beam.log_prob + log_prob.item(),
                        hidden,
                        transducer.dropout(transducer.trg_embed(word)),
                        ' '.join([beam.partial_sent, str(word.item())]),
                        beam.attn + [attn],
                        beam.attn_pos + shift)
                    next_beams.append(new_beam)
        beams = next_beams
    # fall back to unfinished beams if nothing reached EOS
    finish_beams = finish_beams if finish_beams else next_beams
    max_output = sorted(finish_beams, key=score)[0]
    return list(map(int, max_output.partial_sent.split())), max_output.attn
def decode_beam_transformer(
    transducer,
    src_sentence,
    max_len=50,
    nb_beam=5,
    norm=True,
    trg_bos=BOS_IDX,
    trg_eos=EOS_IDX,
):
    """
    Beam search for a vanilla Transformer (single sentence).

    src_sentence: [seq_len]

    Returns (best token-id sequence, empty attention list).
    """
    assert isinstance(transducer, Transformer)

    def score(beam):
        """Rank beams by (optionally length-normalized) negative logprob."""
        assert isinstance(beam, Beam)
        return -beam.log_prob / beam.seq_len if norm else -beam.log_prob

    transducer.eval()
    src_mask = (dummy_mask(src_sentence) == 0).transpose(0, 1)
    enc_hs = transducer.encode(src_sentence, src_mask)

    bos = torch.tensor([trg_bos], device=DEVICE).view(1, 1)
    beams = [Beam(1, 0, None, bos, "", None)]
    finished = []
    for _ in range(max_len):
        expanded = []
        for beam in sorted(beams, key=score)[:nb_beam]:
            trg_mask = (dummy_mask(beam.input) == 0).transpose(0, 1)
            logprob = transducer.decode(enc_hs, src_mask, beam.input,
                                        trg_mask)[-1]
            cand_lp, cand_w = logprob.topk(nb_beam)
            pairs = zip(cand_lp.view(nb_beam, 1), cand_w.view(nb_beam, 1))
            for lp, w in pairs:
                if w == trg_eos:
                    finished.append(Beam(
                        beam.seq_len + 1,
                        beam.log_prob + lp.item(),
                        None,
                        None,
                        beam.partial_sent,
                        None,
                    ))
                else:
                    expanded.append(Beam(
                        beam.seq_len + 1,
                        beam.log_prob + lp.item(),
                        None,
                        torch.cat((beam.input, w.view(1, 1))),
                        " ".join([beam.partial_sent, str(w.item())]),
                        None,
                    ))
        beams = expanded
    # fall back to unfinished beams if nothing reached EOS
    finished = finished if finished else expanded
    best = sorted(finished, key=score)[0]
    return list(map(int, best.partial_sent.split())), []
def decode_beam_search(
    transducer,
    src_sentence,
    max_len=50,
    nb_beam=5,
    norm=True,
    trg_bos=BOS_IDX,
    trg_eos=EOS_IDX,
):
    """
    Beam search, dispatching on the transducer type.

    src_sentence: [seq_len]

    Returns (best token-id sequence, its attention list).
    """
    # BUG FIX: the dispatch calls below previously passed the global
    # BOS_IDX/EOS_IDX, silently ignoring caller-supplied trg_bos/trg_eos.
    if isinstance(transducer, HardMonoTransducer):
        return decode_beam_mono(
            transducer,
            src_sentence,
            max_len=max_len,
            nb_beam=nb_beam,
            norm=norm,
            trg_bos=trg_bos,
            trg_eos=trg_eos,
        )
    if isinstance(transducer, HMMTransducer):
        return decode_beam_hmm(
            transducer,
            src_sentence,
            max_len=max_len,
            nb_beam=nb_beam,
            norm=norm,
            trg_bos=trg_bos,
            trg_eos=trg_eos,
        )
    if isinstance(transducer, Transformer):
        return decode_beam_transformer(
            transducer,
            src_sentence,
            max_len=max_len,
            nb_beam=nb_beam,
            norm=norm,
            trg_bos=trg_bos,
            trg_eos=trg_eos,
        )

    def score(beam):
        """Rank beams by (optionally length-normalized) negative logprob."""
        assert isinstance(beam, Beam)
        if norm:
            return -beam.log_prob / beam.seq_len
        return -beam.log_prob

    # default: step-by-step RNN beam search
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    start = Beam(1, 0, hidden, input_, "", [])
    beams = [start]
    finish_beams = []
    for _ in range(max_len):
        next_beams = []
        for beam in sorted(beams, key=score)[:nb_beam]:
            word_logprob, hidden, attn = transducer.decode_step(
                enc_hs, src_mask, beam.input, beam.hidden)
            topk_log_prob, topk_word = word_logprob.topk(nb_beam)
            topk_log_prob = topk_log_prob.view(nb_beam, 1)
            topk_word = topk_word.view(nb_beam, 1)
            for log_prob, word in zip(topk_log_prob, topk_word):
                if word == trg_eos:
                    new_beam = Beam(
                        beam.seq_len + 1,
                        beam.log_prob + log_prob.item(),
                        None,
                        None,
                        beam.partial_sent,
                        beam.attn + [attn],
                    )
                    finish_beams.append(new_beam)
                else:
                    new_beam = Beam(
                        beam.seq_len + 1,
                        beam.log_prob + log_prob.item(),
                        hidden,
                        transducer.dropout(transducer.trg_embed(word)),
                        " ".join([beam.partial_sent, str(word.item())]),
                        beam.attn + [attn],
                    )
                    next_beams.append(new_beam)
        beams = next_beams
    # fall back to unfinished beams if nothing reached EOS
    finish_beams = finish_beams if finish_beams else next_beams
    max_output = sorted(finish_beams, key=score)[0]
    return list(map(int, max_output.partial_sent.split())), max_output.attn
def decode_beam_transformer(
    transducer,
    src_sentence,
    src_mask,
    max_len=50,
    nb_beam=5,
    trg_bos=BOS_IDX,
    trg_eos=EOS_IDX,
):
    """
    Batched beam search for a vanilla Transformer.

    src_sentence: [seq_len]

    Returns (list of best token-id sequences per batch element — empty list
    if a batch element never emitted EOS — and None).

    NOTE(review): this file's `Beam` has 4 fields
    (log_prob, hidden, input, partial_sent), unlike the 6-field Beam used by
    the single-sentence decoders; relies on the helpers get_topk_beam_idx /
    gather_logprob / gather_output defined elsewhere in the file.
    """
    assert isinstance(transducer, Transformer)
    transducer.eval()
    src_mask = (src_mask == 0).transpose(0, 1)
    enc_hs = transducer.encode(src_sentence, src_mask)
    _, bs = src_sentence.shape
    input_ = torch.tensor([trg_bos] * bs, device=DEVICE)
    input_ = input_.view(1, bs)
    output = input_
    start = Beam(0, None, input_, output)
    beams = [start]
    # one list of (score, sequence) candidates per batch element
    finish_beams = [list() for _ in range(bs)]
    for i in range(max_len):
        cur_len = i + 2  # bos & the current prediction
        next_beams = []
        for beam in beams:
            trg_mask = dummy_mask(beam.input)
            trg_mask = (trg_mask == 0).transpose(0, 1)
            word_logprob = transducer.decode(enc_hs, src_mask, beam.input,
                                             trg_mask)
            word_logprob = word_logprob[-1]
            topk_log_prob, topk_word = word_logprob.topk(nb_beam)
            topk_log_prob = topk_log_prob.split(1, dim=1)
            topk_word = topk_word.split(1, dim=1)
            for log_prob, word in zip(topk_log_prob, topk_word):
                log_prob = log_prob.squeeze(1)
                word = word.squeeze(1)
                # accumulated log-probability of the extended hypothesis
                log_prob = beam.log_prob + log_prob
                input_ = torch.cat((beam.input, word.view(1, bs)))
                output = torch.cat((beam.partial_sent, word.view(1, bs)))
                if (word == trg_eos).any():
                    # record finished hypotheses per batch element, then
                    # knock their score down so they stop being expanded
                    batch_idx = (word == trg_eos).nonzero().view(-1).tolist()
                    for j in batch_idx:
                        score = log_prob[j] / cur_len  # length-normalized
                        seq = output[:, j].tolist()
                        log_prob[j] = -1e6
                        finish_beams[j].append((score, seq))
                new_beam = Beam(
                    log_prob,
                    None,
                    input_,
                    output,
                )
                next_beams.append(new_beam)
        # prune: keep the top nb_beam hypotheses per batch element by
        # gathering log-probs / inputs / outputs along the beam axis
        beam_idx = get_topk_beam_idx(next_beams, nb_beam)
        log_prob = gather_logprob(next_beams, bs, beam_idx)
        input_ = torch.stack([b.input for b in next_beams],
                             dim=-1)[torch.arange(cur_len).view(-1, 1, 1),
                                     torch.arange(bs).view(1, -1, 1),
                                     beam_idx, ]
        output = gather_output(next_beams, cur_len, bs, beam_idx)
        beams = [
            Beam(lp.squeeze(-1), None, ip.squeeze(-1), op.squeeze(-1))
            for lp, ip, op in zip(
                log_prob.split(1, dim=-1),
                input_.split(1, dim=-1),
                output.split(1, dim=-1),
            )
        ]
    # best finished hypothesis per batch element (max by score tuple)
    return [max(b)[1] if b else [] for b in finish_beams], None