def decode_sample(transducer, src_sentence, max_len=100, trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    """Sample one target sequence from the transducer, token by token.

    Args:
        transducer: a (non-hard-monotonic) transducer model; HMM models are
            dispatched to ``decode_sample_hmm``.
        src_sentence: [seq_len] tensor of source token indices.
        max_len: maximum number of decoding steps.
        trg_bos: begin-of-sequence token id fed as the first decoder input.
        trg_eos: end-of-sequence token id that terminates sampling.

    Returns:
        (output, attns): ``output`` is the list of sampled token ids with the
        EOS token excluded; ``attns`` is the list of per-step attention
        tensors (the EOS step's attention IS appended before the break).
    """
    assert not isinstance(transducer, HardMonoTransducer)
    if isinstance(transducer, HMMTransducer):
        # Bug fix: forward the caller's trg_bos/trg_eos instead of the
        # module-level constants, which ignored any custom markers.
        return decode_sample_hmm(transducer, src_sentence, max_len=max_len,
                                 trg_bos=trg_bos, trg_eos=trg_eos)
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    # Seed the decoder with the BOS embedding (dropout is a no-op in eval mode).
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for _ in range(max_len):
        word_logprob, hidden, attn = transducer.decode_step(
            enc_hs, src_mask, input_, hidden)
        # .sample() replaces the deprecated .sample_n(1)[0]; both return one
        # draw per batch row (shape [1] here).
        word = Categorical(word_logprob.exp()).sample()
        attns.append(attn)
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
    return output, attns
def decode_sample_hmm(transducer, src_sentence, max_len=100, trg_bos=BOS_IDX, trg_eos=EOS_IDX):
    """Sample one target sequence from an HMM transducer.

    Runs the HMM forward algorithm in log space, marginalizing over source
    alignments at every step, and samples the next target token from the
    resulting marginal word distribution.

    Args:
        transducer: an HMM transducer whose ``decode_step`` returns
            (transition, emission, hidden) log-probability tensors.
        src_sentence: [seq_len] tensor of source token indices.
        max_len: maximum number of decoding steps.
        trg_bos: begin-of-sequence token id fed as the first decoder input.
        trg_eos: end-of-sequence token id that terminates sampling.

    Returns:
        (output, attns): sampled token ids (EOS excluded) and the per-step
        transition/initial tensors collected as attention proxies.
    """
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
    T = src_mask.shape[0]  # source length
    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    # Seed the decoder with the BOS embedding (dropout is a no-op in eval mode).
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
    for idx in range(max_len):
        trans, emiss, hidden = transducer.decode_step(enc_hs, src_mask,
                                                      input_, hidden)
        if idx == 0:
            # First step: the forward variable is the initial alignment
            # distribution (row 0 of the transition matrix).
            initial = trans[:, 0].unsqueeze(1)
            attns.append(initial)
            forward = initial
        else:
            attns.append(trans)
            # Log-space equivalent of forward = bmm(forward, trans).
            forward = forward + trans.transpose(1, 2)
            forward = forward.logsumexp(dim=-1, keepdim=True).transpose(1, 2)
        # Log-space equivalent of wordprob = bmm(forward, emiss):
        # marginalize the emission distribution over source positions.
        log_wordprob = forward + emiss.transpose(1, 2)
        log_wordprob = log_wordprob.logsumexp(dim=-1)
        # logits= normalizes via log-softmax — same distribution as
        # Categorical(log_wordprob.exp()) but immune to exp over/underflow;
        # .sample() replaces the deprecated .sample_n(1)[0].
        word = Categorical(logits=log_wordprob).sample()
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
        # Condition the forward variable on the emitted word: add each source
        # position's log-probability of emitting the sampled token.
        word_idx = word.view(-1, 1).expand(1, T).unsqueeze(-1)
        word_emiss = torch.gather(emiss, -1, word_idx).view(1, 1, T)
        forward = forward + word_emiss
    return output, attns