Exemplo n.º 1
0
def greedy(lm_model, lm_vocab, device, s, max_len):
    """Greedily extend every prompt in ``s`` and return the first result as text.

    At each step ``lm_model.work`` is run on the whole batch and, for every
    sentence, the entry of ``pred`` at that sentence's last position is mapped
    to a token with ``lm_vocab.idx2token``. A sentence whose next token is
    ``"<eos>"`` is moved to the result list (the ``"<eos>"`` itself is not
    appended); every other sentence grows by one token. Decoding stops after
    ``max_len`` steps or as soon as no sentence is left.

    Args:
        lm_model: object exposing ``work(x)`` returning ``(probs, pred)``.
        lm_vocab: vocabulary exposing ``idx2token(index)``.
        device: target handed to ``.to(...)`` for the model input.
        s: batch of prompts, each a list of tokens.
        max_len: maximum number of decoding steps.

    Returns:
        The tokens of the first collected sentence joined into one string.
        If that string contains ``"<bos>"``, only the segment that follows
        the first ``"<bos>"`` (``split("<bos>")[1]``) is returned.
    """
    active = s
    x, _ = s2t(active, lm_vocab)
    x = x.to(device)
    res = []
    for _step in range(max_len):
        _probs, pred = lm_model.work(x)
        # pred is indexed as [position, sentence]; take each sentence's last position.
        next_tk = [
            lm_vocab.idx2token(pred[len(sent) - 1, i].item())
            for i, sent in enumerate(active)
        ]

        still_active = []
        for sent, t in zip(active, next_tk):
            if t == "<eos>":
                res.append(sent)
            else:
                still_active.append(sent + [t])
        active = still_active
        if not active:
            break
        x, _ = s2t(active, lm_vocab)
        x = x.to(device)
    # Unfinished sentences are returned as they stand. Tracking ``active``
    # (instead of a variable first bound inside the loop) keeps this valid
    # when ``max_len`` is 0: the prompts are then passed through untouched.
    res += active

    r = ''.join(res[0])
    if "<bos>" in r:
        return r.split("<bos>")[1]
    return r
Exemplo n.º 2
0
def top_k(s, max_len):
    """Extend every entry of ``s`` by ``max_len`` tokens using top-k sampling.

    Uses the module-level names ``lm_model``, ``lm_vocab``, ``gpu`` and ``k``.
    For each entry, the ``k`` largest values of its probability row are
    renormalised to sum to 1 and one candidate is drawn from them with
    ``torch.multinomial``.

    Args:
        s: batch of prompts.
        max_len: number of decoding steps; there is no early stop.

    Returns:
        The extended batch, after printing each entry.
    """
    state = None
    x, _ = s2t(s, lm_vocab)
    x = x.cuda(gpu)
    for step in range(max_len):
        probs, _, state = lm_model.work_incremental(x, state)
        drawn = []
        for col, sent in enumerate(s):
            # The first call reads the row of the prompt's last position;
            # every later call reads row 0.
            row = len(sent) - 1 if step == 0 else 0
            top_ps, top_ids = torch.topk(probs[row, col], k=k)
            top_ps = top_ps / torch.sum(top_ps)
            pick = torch.multinomial(top_ps, num_samples=1)
            drawn.append(lm_vocab.idx2token(top_ids[pick].item()))
        # NOTE(review): the token is concatenated directly (``sent + t``), which
        # only works if each entry of ``s`` is a str -- confirm against callers;
        # list-of-token prompts would need ``sent + [t]``.
        s = [sent + t for sent, t in zip(s, drawn)]

        x, _ = s2t(s, lm_vocab)
        x = x.cuda(gpu)

    for sent in s:
        print(sent)
    return s
Exemplo n.º 3
0
def top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    """Fill a template by top-k sampling, using incremental decoding.

    Uses the module-level names ``lm_model``, ``lm_vocab``, ``gpu`` and ``k``.
    One step is run per row of ``inp_ys_tpl``. For each active sentence the
    template token of the current row is looked up: anything other than the
    placeholders ``"<c0>"``, ``"<c1>"``, ``"<c2>"`` is copied as-is, while a
    placeholder is replaced by a token sampled from the ``k`` most probable
    candidates (renormalised to sum to 1). A sentence whose next token is
    ``"<eos>"`` is moved to the result list without that token.

    Args:
        enc, src_padding_mask: forwarded unchanged to
            ``lm_model.work_incremental``.
        inp_ys_tpl: template token ids, indexed ``[step, sentence]``.
        inp_ys_seg, inp_ys_pos: sliced to the first ``step + 1`` rows and
            forwarded to the model together with the template.
        s: batch of prompts, each a list of tokens.

    Returns:
        List of token lists: finished sentences in the order they ended,
        followed by the sentences still active when decoding stopped.
    """
    incremental_state = None
    active = s
    inp_y, _ = s2t(active, lm_vocab)
    inp_y = inp_y.cuda(gpu)
    res = []
    placeholders = ("<c0>", "<c1>", "<c2>")
    for step in range(inp_ys_tpl.size(0)):
        probs, _, incremental_state = lm_model.work_incremental(
            enc, src_padding_mask, inp_y, inp_ys_tpl[0:step + 1, :],
            inp_ys_seg[0:step + 1, :], inp_ys_pos[0:step + 1, :],
            incremental_state)
        next_tk = []
        # NOTE(review): ``i`` indexes the (possibly shrunken) active list while
        # ``inp_ys_tpl`` keeps its original columns -- confirm the alignment is
        # still correct once a sentence has finished in a batch larger than 1.
        for i, sent in enumerate(active):
            ctk = lm_vocab.idx2token(inp_ys_tpl[step, i].item())
            if ctk not in placeholders:
                next_tk.append(ctk)
                continue

            # The first call reads the row of the prompt's last position;
            # every later call reads row 0.
            row = len(sent) - 1 if step == 0 else 0
            ps, top_idx = torch.topk(probs[row, i], k=k)
            ps = ps / torch.sum(ps)
            # Draw one position from the renormalised weights; the result is
            # an index into ``top_idx``, not a vocabulary id.
            sampled = torch.multinomial(ps, num_samples=1)
            next_tk.append(lm_vocab.idx2token(top_idx[sampled].item()))

        still_active = []
        keep = [1] * len(active)
        for pos, (sent, t) in enumerate(zip(active, next_tk)):
            if t == "<eos>":
                res.append(sent)
                keep[pos] = 0
            else:
                still_active.append(sent + [t])
        active = still_active
        if not active:
            break
        inp_y, _ = s2t(active, lm_vocab)
        inp_y = inp_y.cuda(gpu)
        # Mask of the sentences that survived this step; presumably consumed by
        # the model to prune its cached state -- verify in work_incremental.
        incremental_state["bidx"] = torch.BoolTensor(keep).cuda(gpu)
    # Tracking ``active`` (rather than a name first bound inside the loop)
    # keeps this valid when the template has zero rows.
    res += active

    return res
Exemplo n.º 4
0
def top_p_inc(lm_model, lm_vocab, device, s, k, p, max_len):
    """Extend the prompts with ``top_p_sampling`` and return the first result as text.

    At each step ``lm_model.work_incremental`` is called with the current
    batch and the state returned by the previous call. For every active
    sentence, ``top_p_sampling(logits, k, p)`` supplies candidate weights and
    ids; the weights are renormalised to sum to 1 and one candidate is drawn
    with ``torch.multinomial``. A sentence whose next token is ``"<eos>"`` is
    moved to the result list without that token. Decoding stops after
    ``max_len`` steps or once no sentence is left.

    Args:
        lm_model: object exposing ``work_incremental(x, incremental_state)``
            returning ``(probs, pred, incremental_state)``.
        lm_vocab: vocabulary exposing ``idx2token(index)``.
        device: target handed to ``.to(...)`` for tensors built here.
        s: batch of prompts, each a list of tokens.
        k, p: forwarded unchanged to ``top_p_sampling``.
        max_len: maximum number of decoding steps.

    Returns:
        The tokens of the first collected sentence joined into one string.
        If that string contains ``"<bos>"``, only the segment that follows
        the first ``"<bos>"`` (``split("<bos>")[1]``) is returned.
    """
    incremental_state = None
    active = s
    x, _ = s2t(active, lm_vocab)
    x = x.to(device)
    res = []
    for step in range(max_len):
        probs, _, incremental_state = lm_model.work_incremental(
            x, incremental_state)
        next_tk = []
        for i, sent in enumerate(active):
            # Only the row differs between the first and later steps: the
            # first call reads the prompt's last position, later calls row 0.
            row = len(sent) - 1 if step == 0 else 0
            ps, cand_idx = top_p_sampling(probs[row, i], k, p)
            ps = ps / torch.sum(ps)
            sampled = torch.multinomial(ps, num_samples=1)
            next_tk.append(lm_vocab.idx2token(cand_idx[sampled].item()))

        still_active = []
        keep = [1] * len(active)
        for pos, (sent, t) in enumerate(zip(active, next_tk)):
            if t == "<eos>":
                res.append(sent)
                keep[pos] = 0
            else:
                still_active.append(sent + [t])
        active = still_active
        if not active:
            break
        x, _ = s2t(active, lm_vocab)
        x = x.to(device)
        # Mask of the sentences that survived this step; presumably consumed by
        # the model to prune its cached state -- verify in work_incremental.
        incremental_state["bidx"] = torch.BoolTensor(keep).to(device)
    # Tracking ``active`` (rather than a name first bound inside the loop)
    # keeps this valid when ``max_len`` is 0: the prompts pass through as-is.
    res += active

    r = ''.join(res[0])
    if "<bos>" in r:
        return r.split("<bos>")[1]
    return r
Exemplo n.º 5
0
def greedy(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    """Fill a template greedily, using incremental decoding.

    Uses the module-level names ``lm_model``, ``lm_vocab`` and ``gpu``.
    One step is run per row of ``inp_ys_tpl``. For each active sentence the
    template token of the current row is looked up: anything other than the
    placeholders ``"<c0>"``, ``"<c1>"``, ``"<c2>"`` is copied as-is, while a
    placeholder is replaced by the token the model's ``pred`` output gives
    for that sentence. A sentence whose next token is ``"<eos>"`` is moved to
    the result list without that token.

    Args:
        enc, src_padding_mask: forwarded unchanged to
            ``lm_model.work_incremental``.
        inp_ys_tpl: template token ids, indexed ``[step, sentence]``.
        inp_ys_seg, inp_ys_pos: sliced to the first ``step + 1`` rows and
            forwarded to the model together with the template.
        s: batch of prompts, each a list of tokens.

    Returns:
        List of token lists: finished sentences in the order they ended,
        followed by the sentences still active when decoding stopped.
        The elapsed time in seconds is printed before returning.
    """
    start = time.time()
    incremental_state = None
    active = s
    inp_y, _ = s2t(active, lm_vocab)
    inp_y = inp_y.cuda(gpu)
    res = []
    placeholders = ("<c0>", "<c1>", "<c2>")
    for step in range(inp_ys_tpl.size(0)):
        probs, pred, incremental_state = lm_model.work_incremental(
            enc, src_padding_mask, inp_y, inp_ys_tpl[0:step + 1, :],
            inp_ys_seg[0:step + 1, :], inp_ys_pos[0:step + 1, :],
            incremental_state)
        next_tk = []
        # NOTE(review): ``i`` indexes the (possibly shrunken) active list while
        # ``inp_ys_tpl`` keeps its original columns -- confirm the alignment is
        # still correct once a sentence has finished in a batch larger than 1.
        for i, sent in enumerate(active):
            ctk = lm_vocab.idx2token(inp_ys_tpl[step, i].item())
            if ctk not in placeholders:
                next_tk.append(ctk)
                continue

            # Read this sentence's prediction into its own name. Rebinding
            # ``pred`` here would replace the batch tensor with a single
            # element and break the lookup for the next placeholder sentence
            # in the same step.
            row = len(sent) - 1 if step == 0 else 0
            best = pred[row, i]
            next_tk.append(lm_vocab.idx2token(best.item()))

        still_active = []
        keep = [1] * len(active)
        for pos, (sent, t) in enumerate(zip(active, next_tk)):
            if t == "<eos>":
                res.append(sent)
                keep[pos] = 0
            else:
                still_active.append(sent + [t])
        active = still_active
        if not active:
            break
        inp_y, _ = s2t(active, lm_vocab)
        inp_y = inp_y.cuda(gpu)
        # Mask of the sentences that survived this step; presumably consumed by
        # the model to prune its cached state -- verify in work_incremental.
        incremental_state["bidx"] = torch.BoolTensor(keep).cuda(gpu)
    # Tracking ``active`` (rather than a name first bound inside the loop)
    # keeps this valid when the template has zero rows.
    res += active

    print(time.time()-start)
    return res
Exemplo n.º 6
0
def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    """Fill a template by top-k sampling, re-running the model on the full batch each step.

    Uses the module-level names ``lm_model``, ``lm_vocab``, ``gpu`` and ``k``.
    One step is run per row of ``inp_ys_tpl``. For each active sentence the
    template token of the current row is looked up: anything other than the
    placeholder ``"<c1>"`` is copied as-is, while ``"<c1>"`` is replaced by a
    token sampled from the ``k`` most probable candidates (renormalised to
    sum to 1) at the sentence's last position. A sentence whose next token is
    ``"<eos>"`` is moved to the result list without that token.

    Args:
        enc, src_padding_mask: forwarded unchanged to ``lm_model.work``.
        inp_ys_tpl: template token ids, indexed ``[step, sentence]``.
        inp_ys_seg, inp_ys_pos: sliced to the first ``step + 1`` rows and
            forwarded to the model together with the template.
        s: batch of prompts, each a list of tokens.

    Returns:
        List of token lists: finished sentences in the order they ended,
        followed by the sentences still active when decoding stopped.
    """
    active = s
    inp_y, _ = s2t(active, lm_vocab)
    inp_y = inp_y.cuda(gpu)

    res = []
    for step in range(inp_ys_tpl.size(0)):
        probs, _ = lm_model.work(enc, src_padding_mask, inp_y,
                                 inp_ys_tpl[0:step + 1, :],
                                 inp_ys_seg[0:step + 1, :],
                                 inp_ys_pos[0:step + 1, :])
        next_tk = []
        # NOTE(review): ``i`` indexes the (possibly shrunken) active list while
        # ``inp_ys_tpl`` keeps its original columns -- confirm the alignment is
        # still correct once a sentence has finished in a batch larger than 1.
        for i, sent in enumerate(active):
            ctk = lm_vocab.idx2token(inp_ys_tpl[step, i].item())
            if ctk != "<c1>":
                next_tk.append(ctk)
                continue
            ps, top_idx = torch.topk(probs[len(sent) - 1, i], k=k)
            ps = ps / torch.sum(ps)
            # ``sampled`` is a position within ``top_idx``, not a vocabulary id.
            sampled = torch.multinomial(ps, num_samples=1)
            next_tk.append(lm_vocab.idx2token(top_idx[sampled].item()))

        still_active = []
        for sent, t in zip(active, next_tk):
            if t == "<eos>":
                res.append(sent)
            else:
                still_active.append(sent + [t])
        active = still_active
        if not active:
            break
        inp_y, _ = s2t(active, lm_vocab)
        inp_y = inp_y.cuda(gpu)

    # Tracking ``active`` (rather than a name first bound inside the loop)
    # keeps this valid when the template has zero rows.
    res += active

    return res
Exemplo n.º 7
0
def top_g(lm_model, lm_vocab, device, s, max_len):
    """Extend every prompt by ``max_len`` tokens using ``top_g_sampling`` and print the results.

    At each step ``lm_model.work`` is run on the whole batch; for every
    sentence, ``top_g_sampling`` is applied to the probability row of its
    last position, the returned weights are renormalised to sum to 1, and one
    candidate is drawn with ``torch.multinomial``. There is no early stop.

    Args:
        lm_model: object exposing ``work(x)`` returning ``(probs, pred)``.
        lm_vocab: vocabulary exposing ``idx2token(index)``.
        device: target handed to ``.to(...)`` for the model input.
        s: batch of prompts, each a list of tokens.
        max_len: number of decoding steps.

    Returns:
        None. Each extended sentence is printed.
    """
    x, _ = s2t(s, lm_vocab)
    x = x.to(device)
    for _step in range(max_len):
        probs, _pred = lm_model.work(x)
        drawn = []
        for col, sent in enumerate(s):
            last_row = probs[len(sent) - 1, col]
            weights, cand_ids = top_g_sampling(last_row)
            weights = weights / torch.sum(weights)
            # ``pick`` is a position within ``cand_ids``, not a vocabulary id.
            pick = torch.multinomial(weights, num_samples=1)
            drawn.append(lm_vocab.idx2token(cand_ids[pick].item()))
        s = [sent + [tok] for sent, tok in zip(s, drawn)]

        x, _ = s2t(s, lm_vocab)
        x = x.to(device)

    for sent in s:
        print(sent)
Exemplo n.º 8
0
def beam_search(lm_model, lm_vocab, device, s, max_len):
    """Run ``beam_decode`` on the first prompt of ``s``.

    The batch is converted with ``s2t``; only column 0 of the result, cut to
    the length of the first prompt, is passed on.

    Args:
        lm_model, lm_vocab, device, max_len: forwarded unchanged to
            ``beam_decode``.
        s: batch of prompts; only ``s[0]`` is decoded.

    Returns:
        Whatever ``beam_decode`` returns.
    """
    batch, _ = s2t(s, lm_vocab)
    first = s[0]
    # Rows beyond len(first) are dropped -- presumably padding; verify in s2t.
    first_ids = batch[:len(first), 0]
    return beam_decode(lm_model, lm_vocab, device, first, first_ids, max_len)
Exemplo n.º 9
0
def beam_search(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s):
    """Run ``beam_decode`` for the first prompt of ``s``.

    Uses the module-level name ``lm_vocab``. The whole batch is converted
    with ``s2t`` and passed on together with ``s[0]``.

    Args:
        enc, src_padding_mask, ys_tpl, ys_seg, ys_pos: forwarded unchanged to
            ``beam_decode``.
        s: batch of prompts.

    Returns:
        Whatever ``beam_decode`` returns.
    """
    batch, _ = s2t(s, lm_vocab)
    first = s[0]
    return beam_decode(first, batch, enc, src_padding_mask,
                       ys_tpl, ys_seg, ys_pos)