Exemplo n.º 1
0
def greedy_decode(model, src, src_mask, max_len):
    """Greedily decode an output sequence, one symbol per step.

    The source is encoded once. Then, for ``max_len`` steps, the decoder is
    run on everything produced so far and the highest-scoring entry of the
    generator output for the last position is appended to the sequence.

    Args:
        model: Object exposing ``encode``, ``decode`` and ``generator``.
        src: Source tensor; ``src[0][-1]`` is used as the first output symbol.
        src_mask: Mask forwarded unchanged to ``encode`` and ``decode``.
        max_len: Number of decoding steps to run.

    Returns:
        Tensor of shape ``(1, max_len + 1)`` with the same type as
        ``src.data``: the seed symbol followed by one chosen index per step.
    """
    memory = model.encode(src, src_mask)

    # Seed the output with the last element of the first source row.
    seed_symbol = src[0][-1].cpu().numpy().item()
    ys = torch.ones(1, 1).fill_(seed_symbol).type_as(src.data)

    for _ in range(max_len):
        decoder_input = Variable(ys)
        # subsequent_mask is defined elsewhere; presumably it hides later
        # positions from the decoder -- TODO confirm.
        tgt_mask = Variable(subsequent_mask(ys.size(1)).type_as(src.data))
        out = model.decode(memory, src_mask, decoder_input, tgt_mask)

        # Score only the most recent position and keep the best index.
        prob = model.generator(out[:, -1])
        best_index = torch.max(prob, dim=1)[1]
        # Only the first row is used, so a single sequence is decoded.
        next_word = best_index.data[0]

        appended = torch.ones(1, 1).type_as(src.data).fill_(next_word)
        ys = torch.cat([ys, appended], dim=1)
    return ys
Exemplo n.º 2
0
def choose_options(model, memory, src, src_mask, ys):
    """Expand one beam hypothesis into its best-scoring continuations.

    The decoder is run on the hypothesis' tokens, the generator output for
    the last position is ranked, and the top ``beam_search_number`` entries
    (a module-level setting) are each appended to a copy of the hypothesis.

    Args:
        model: Object exposing ``decode`` and ``generator``.
        memory: Encoder output, passed through to ``model.decode``.
        src: Source tensor; only its type is used (via ``type_as``).
        src_mask: Mask forwarded unchanged to ``model.decode``.
        ys: Pair ``(score, tokens)`` -- the hypothesis' running score and
            its ``(1, length)`` tensor of tokens generated so far.

    Returns:
        A list of ``[new_score, new_tokens]`` lists in descending order of
        generator score (ties: larger index first). ``new_score`` is the
        running score plus the generator value of the added index. At most
        ``beam_search_number`` entries are returned; fewer if the generator
        output has fewer entries than that.
    """
    score_so_far, tokens = ys[0], ys[1]
    out = model.decode(
        memory, src_mask, Variable(tokens),
        Variable(subsequent_mask(tokens.size(1)).type_as(src.data)))
    prob = model.generator(out[:, -1])

    # Pull the first row to Python in one transfer instead of calling
    # .item() once per vocabulary entry.
    scores = prob[0].tolist()
    # Sorting (score, index) tuples in reverse ranks by score and breaks
    # ties in favour of the larger index.
    ranked = sorted(
        ((score, index) for index, score in enumerate(scores)), reverse=True)

    result = []
    # Iterating the slice (rather than indexing a fixed count) avoids an
    # IndexError when fewer than beam_search_number options exist.
    for step_score, index in ranked[:beam_search_number]:
        new_token = torch.ones(1, 1).type_as(src.data).fill_(index)
        result.append([
            score_so_far + step_score,
            torch.cat([tokens, new_token], dim=1),
        ])
    return result
Exemplo n.º 3
0
 def make_std_mask(tgt, pad):
     """Create a mask to hide padding and future words.

     Args:
         tgt: Target tensor; positions equal to ``pad`` are masked out.
         pad: Padding value compared against each element of ``tgt``.

     Returns:
         The padding mask (``tgt != pad`` with an extra axis inserted at
         position -2) AND-ed with ``subsequent_mask(tgt.size(-1))``.
     """
     # True wherever the target holds something other than padding.
     pad_mask = (tgt != pad).unsqueeze(-2)
     seq_len = tgt.size(-1)
     # subsequent_mask is defined elsewhere; cast it to the padding mask's
     # type so the two can be combined with ``&``.
     future_mask = Variable(subsequent_mask(seq_len).type_as(pad_mask.data))
     return pad_mask & future_mask