def decode_sample(transducer,
                  src_sentence,
                  max_len=100,
                  trg_bos=BOS_IDX,
                  trg_eos=EOS_IDX):
    '''
    Decode by ancestral sampling: draw each target symbol from the model's
    predictive distribution instead of taking the argmax.
    src_sentence: [seq_len]
    return: (list of sampled target indices, list of per-step attentions)
    '''
    assert not isinstance(transducer, HardMonoTransducer)
    if isinstance(transducer, HMMTransducer):
        return decode_sample_hmm(transducer,
                                 src_sentence,
                                 max_len=max_len,
                                 trg_bos=trg_bos,
                                 trg_eos=trg_eos)
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)

    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
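    # Sampling loop: run the decoder one step at a time, draw the next symbol
    # from its predictive distribution, and stop at EOS or after max_len steps.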
    for _ in range(max_len):
        word_logprob, hidden, attn = transducer.decode_step(
            enc_hs, src_mask, input_, hidden)
        # sample_n is deprecated; sampling from logits is equivalent to
        # Categorical(word_logprob.exp()) and avoids the exp round trip
        word = Categorical(logits=word_logprob).sample()
        attns.append(attn)
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
    return output, attns


def decode_sample_hmm(transducer,
                      src_sentence,
                      max_len=100,
                      trg_bos=BOS_IDX,
                      trg_eos=EOS_IDX):
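    '''
    Ancestral sampling for an HMM transducer: the forward (alpha) variable
    over source positions is maintained in log space, and each target symbol
    is sampled from the induced distribution over the vocabulary.
    src_sentence: [seq_len]
    '''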
    transducer.eval()
    src_mask = dummy_mask(src_sentence)
    enc_hs = transducer.encode(src_sentence)
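    # number of source positions, i.e. the number of HMM alignment states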
    T = src_mask.shape[0]

    output, attns = [], []
    hidden = transducer.dec_rnn.get_init_hx(1)
    input_ = torch.tensor([trg_bos], device=DEVICE)
    input_ = transducer.dropout(transducer.trg_embed(input_))
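    # Forward-algorithm decoding: each decode_step returns transition
    # log-probabilities `trans` ([1, T, T]) over source positions and emission
    # log-probabilities `emiss` ([1, T, V]) over the target vocabulary; the
    # forward variable is kept in log space throughout.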
    for idx in range(max_len):
        trans, emiss, hidden = transducer.decode_step(enc_hs, src_mask, input_,
                                                      hidden)
        if idx == 0:
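            # first step: the first row of the transition matrix gives the
            # initial alignment distribution over source positions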
            initial = trans[:, 0].unsqueeze(1)
            attns.append(initial)
            forward = initial
        else:
            attns.append(trans)
            # log-space forward update; probability-space equivalent:
            #   forward = torch.bmm(forward, trans)
            forward = forward + trans.transpose(1, 2)
            forward = forward.logsumexp(dim=-1, keepdim=True).transpose(1, 2)

        # marginalise over alignments to get log-scores of the next symbol
        # (proportional to its conditional probability); probability-space
        # equivalent: wordprob = torch.bmm(forward, emiss)
        log_wordprob = forward + emiss.transpose(1, 2)
        log_wordprob = log_wordprob.logsumexp(dim=-1)
        # greedy alternative: word = torch.max(log_wordprob, dim=-1)[1]
        # sample_n is deprecated; Categorical normalises the logits itself
        word = Categorical(logits=log_wordprob).sample()
        if word == trg_eos:
            break
        input_ = transducer.dropout(transducer.trg_embed(word))
        output.append(word.item())
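        # condition the forward variable on the sampled symbol by adding its
        # emission log-probability at every source position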
        word_idx = word.view(-1, 1).expand(1, T).unsqueeze(-1)
        word_emiss = torch.gather(emiss, -1, word_idx).view(1, 1, T)
        forward = forward + word_emiss
    return output, attns
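

# Usage sketch (hedged): assuming `transducer` is a trained model and `src` is
# a 1-D LongTensor of source symbol indices on DEVICE, one sampled decoding
# could be drawn with:
#     with torch.no_grad():
#         output, attns = decode_sample(transducer, src)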