def backward_decode(model,
                    embed,
                    src,
                    src_mask,
                    max_len,
                    start_symbol=2,
                    raw=False,
                    return_term=-1):
    """Greedy-decode with `model`, also building a differentiable "soft"
    embedding sequence (probability-weighted mixture of embedding rows).

    Args:
        model: seq2seq model exposing ``encode``/``decode`` and a
            ``generator`` with ``scaled_forward``.
        embed: embedding module; ``embed.weight`` is used to back-project
            output distributions into embedding space.
        src: source tensor; token ids of shape (batch, src_len) when
            ``raw`` is False — assumed, confirm against callers.
        src_mask: source attention mask.
        max_len: decoding runs for ``max_len + 1`` steps (the original
            ``max_len + 2 - 1``).
        start_symbol: BOS token id used to seed every batch row.
        raw: when True, ``src`` is already embedded and is fed to the
            encoder without applying ``embed``.
        return_term: selects the return value —
            -1 -> (ret_back, ys); 0 -> ret_back; 1 -> ys;
             2 -> (ret_back, summed perplexity); anything else -> None.

    Returns:
        See ``return_term``; ``ret_back`` is (batch, steps + 1, emb_dim),
        ``ys`` holds the greedily decoded token ids.
    """
    # Encode; embed the source unless the caller already did (raw=True).
    if not raw:
        memory = model.encode(embed(src.to(device)), src_mask)
    else:
        memory = model.encode(src.to(device), src_mask)

    # Seed decoding with the start symbol for every batch row.
    ys = torch.ones(src.shape[0], 1,
                    dtype=torch.int64).fill_(start_symbol).to(device)
    ret_back = embed(ys).float()
    # Initialize unconditionally (only returned when return_term == 2).
    ppls = 0
    for i in range(max_len + 2 - 1):
        out = model.decode(
            memory, src_mask, embed(Variable(ys)),
            Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        # Sharpened (scale=10) distribution for the soft back-projection.
        prob = model.generator.scaled_forward(out[:, -1], scale=10.0)
        if return_term == 2:
            # Perplexity is measured on the unscaled distribution.
            ppls += perplexity(
                model.generator.scaled_forward(out[:, -1], scale=1.0))
        # Probability-weighted mixture of embedding rows ("soft" token).
        back = torch.matmul(prob, embed.weight.data.float())
        _, next_word = torch.max(prob, dim=1)
        ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)
        ret_back = torch.cat([ret_back, back.unsqueeze(1)], dim=1)

    # Explicit dispatch replaces the original chained conditional return.
    if return_term == -1:
        return ret_back, ys
    if return_term == 0:
        return ret_back
    if return_term == 1:
        return ys
    if return_term == 2:
        return ret_back, ppls
    return None
def prob_backward(model,
                  embed,
                  src,
                  src_mask,
                  max_len,
                  start_symbol=2,
                  raw=False):
    """Greedy-decode step by step and stack the generator's output
    distribution from every step.

    ``raw=True`` means ``src`` is already embedded; otherwise ``embed``
    is applied before encoding. Returns a tensor of shape
    (batch, max_len + 1, vocab_size).
    """
    # Encode the source, embedding it first unless the caller already did.
    encoder_input = src.to(device) if raw else embed(src.to(device))
    memory = model.encode(encoder_input, src_mask)

    # Every batch row starts from the start symbol.
    tokens = torch.full((src.shape[0], 1), start_symbol,
                        dtype=torch.int64).to(device)
    step_probs = []
    for _ in range(max_len + 2 - 1):
        causal = Variable(
            subsequent_mask(tokens.size(1)).type_as(src.data))
        out = model.decode(memory, src_mask, embed(Variable(tokens)), causal)
        dist = model.generator(out[:, -1])
        step_probs.append(dist.unsqueeze(1))
        # Greedy choice: extend the hypothesis with the argmax token.
        next_tok = torch.max(dist, dim=1)[1]
        tokens = torch.cat([tokens, next_tok.unsqueeze(-1)], dim=1)
    return torch.cat(step_probs, dim=1)
def reconstruct(model, src, max_len, start_symbol=2):
    """Greedily re-decode `src` through the model (no attention masks on
    the memory) and return the sequence of probability-weighted
    target-embedding mixtures, seeded with the start symbol's embedding.

    NOTE(review): unlike backward_decode, the raw token tensor is passed
    to model.decode without an explicit embed() — presumably the model
    embeds targets internally here; confirm against the model definition.
    """
    memory = model.encoder(model.src_embed[1](src), None)
    tokens = torch.ones(src.shape[0], 1).fill_(start_symbol).long().to(device)
    soft_seq = model.tgt_embed[0].pure_emb(tokens).float()
    # Hoist the embedding table used for the soft back-projection.
    emb_table = model.tgt_embed[0].lut.weight.data.float()
    for _ in range(max_len - 1):
        causal = Variable(subsequent_mask(tokens.size(1)).type_as(src.data))
        out = model.decode(memory, None, Variable(tokens), causal)
        dist = model.generator(out[:, -1])
        # Probability-weighted mixture of target-embedding rows.
        mix = torch.matmul(dist, emb_table)
        next_tok = torch.max(dist, dim=1)[1]
        tokens = torch.cat([tokens, next_tok.unsqueeze(-1)], dim=1)
        soft_seq = torch.cat([soft_seq, mix.unsqueeze(1)], dim=1)
    return soft_seq