def greedy(lm_model, lm_vocab, device, s, max_len):
    # Greedy decoding: at each step take the most likely next token for every
    # sentence in the batch until "<eos>" is produced or max_len is reached.
    x, m = s2t(s, lm_vocab)
    x = x.to(device)
    res = []
    for l in range(max_len):
        probs, pred = lm_model.work(x)
        next_tk = []
        for i in range(len(s)):
            next_tk.append(lm_vocab.idx2token(pred[len(s[i]) - 1, i].item()))

        s_ = []
        for sent, t in zip(s, next_tk):
            if t == "<eos>":
                res.append(sent)
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        x, m = s2t(s, lm_vocab)
        x = x.to(device)

    res += s_
    r = ''.join(res[0])
    if "<bos>" in r:
        return r.split("<bos>")[1]
    else:
        return r
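# Hypothetical usage sketch (not part of the original file): it assumes a
# trained model and matching vocab have already been loaded elsewhere, and
# that each prompt is a list of tokens starting with "<bos>".
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   prompts = [["<bos>", "今", "天"]]
#   print(greedy(lm_model, lm_vocab, device, prompts, max_len=50))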
def top_k(s, max_len):
    incremental_state = None
    x, m = s2t(s, lm_vocab)
    x = x.cuda(gpu)
    for l in range(max_len):
        probs, pred, incremental_state = lm_model.work_incremental(x, incremental_state)
        next_tk = []
        for i in range(len(s)):
            if l == 0:
                logits = probs[len(s[i]) - 1, i]
            else:
                logits = probs[0, i]
            ps, idx = torch.topk(logits, k=k)
            ps = ps / torch.sum(ps)
            sampled = torch.multinomial(ps, num_samples=1)
            sampled_idx = idx[sampled]
            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))
        s = [sent + t for sent, t in zip(s, next_tk)]
        x, m = s2t(s, lm_vocab)
        x = x.cuda(gpu)
    for i in s:
        print(i)
    return s
def top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    # Top-k sampling conditioned on the encoder output and the format template
    # (inp_ys_tpl / inp_ys_seg / inp_ys_pos), using incremental decoding.
    start = time.time()
    incremental_state = None
    inp_y, m = s2t(s, lm_vocab)
    inp_y = inp_y.cuda(gpu)
    res = []
    for l in range(inp_ys_tpl.size(0)):
        probs, pred, incremental_state = lm_model.work_incremental(
            enc, src_padding_mask, inp_y, inp_ys_tpl[0:l + 1, :],
            inp_ys_seg[0:l + 1, :], inp_ys_pos[0:l + 1, :], incremental_state)
        next_tk = []
        for i in range(len(s)):
            ctk = lm_vocab.idx2token(inp_ys_tpl[l, i].item())
            # Non-placeholder template tokens (punctuation, separators, ...) are
            # copied through verbatim; only <c0>/<c1>/<c2> slots are generated.
            if ctk not in ("<c0>", "<c1>", "<c2>"):
                next_tk.append(ctk)
                continue
            if l == 0:
                logits = probs[len(s[i]) - 1, i]
            else:
                logits = probs[0, i]
            ps, idx = torch.topk(logits, k=k)
            ps = ps / torch.sum(ps)
            # Draw one position from the renormalized top-k weights; the result
            # indexes into idx, which holds the vocabulary ids.
            sampled = torch.multinomial(ps, num_samples=1)
            sampled_idx = idx[sampled]
            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))

        s_ = []
        bidx = [1] * len(s)
        for idx, (sent, t) in enumerate(zip(s, next_tk)):
            if t == "<eos>":
                res.append(sent)
                bidx[idx] = 0
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        inp_y, m = s2t(s, lm_vocab)
        inp_y = inp_y.cuda(gpu)
        # Mask out finished sentences in the cached incremental state.
        bidx = torch.BoolTensor(bidx).cuda(gpu)
        incremental_state["bidx"] = bidx

    res += s_
    # print(time.time() - start)
    return res
def top_p_inc(lm_model, lm_vocab, device, s, k, p, max_len):
    # Nucleus (top-p) sampling with incremental decoding.
    start = time.time()
    incremental_state = None
    x, m = s2t(s, lm_vocab)
    x = x.to(device)
    res = []
    for l in range(max_len):
        probs, pred, incremental_state = lm_model.work_incremental(x, incremental_state)
        next_tk = []
        for i in range(len(s)):
            # At the first step the model scores the whole prefix; afterwards
            # only the newly generated position is returned.
            if l == 0:
                logits = probs[len(s[i]) - 1, i]
            else:
                logits = probs[0, i]
            ps, idx = top_p_sampling(logits, k, p)
            ps = ps / torch.sum(ps)
            sampled = torch.multinomial(ps, num_samples=1)
            sampled_idx = idx[sampled]
            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))

        s_ = []
        bidx = [1] * len(s)
        for idx, (sent, t) in enumerate(zip(s, next_tk)):
            if t == "<eos>":
                res.append(sent)
                bidx[idx] = 0
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        x, m = s2t(s, lm_vocab)
        x = x.to(device)
        bidx = torch.BoolTensor(bidx).to(device)
        incremental_state["bidx"] = bidx

    res += s_
    r = ''.join(res[0])
    if "<bos>" in r:
        return r.split("<bos>")[1]
    else:
        return r
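# `top_p_sampling` is called above but not defined in this file. The sketch
# below shows one plausible implementation matching the call site
# (ps, idx = top_p_sampling(logits, k, p)): nucleus filtering applied on top of
# the k most likely tokens. It assumes `logits` is already a probability
# distribution, consistent with the callers renormalizing via ps / torch.sum(ps)
# rather than applying a softmax, and that `torch` is imported at module level
# as elsewhere in this file. The actual repo implementation may differ.
def _top_p_sampling_sketch(logits, k, p):
    ps, idx = torch.topk(logits, k=k)             # k best tokens, sorted descending
    mass_before = torch.cumsum(ps, dim=-1) - ps   # probability mass strictly before each token
    keep = mass_before < p                        # smallest prefix whose mass reaches p
    keep[0] = True                                # always keep the single best token
    return ps[keep], idx[keep]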
def greedy(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    # Greedy decoding conditioned on the encoder output and the format template
    # (inp_ys_tpl / inp_ys_seg / inp_ys_pos), using incremental decoding.
    start = time.time()
    incremental_state = None
    inp_y, m = s2t(s, lm_vocab)
    inp_y = inp_y.cuda(gpu)
    res = []
    for l in range(inp_ys_tpl.size(0)):
        probs, pred, incremental_state = lm_model.work_incremental(
            enc, src_padding_mask, inp_y, inp_ys_tpl[0:l + 1, :],
            inp_ys_seg[0:l + 1, :], inp_ys_pos[0:l + 1, :], incremental_state)
        next_tk = []
        for i in range(len(s)):
            ctk = lm_vocab.idx2token(inp_ys_tpl[l, i].item())
            # Non-placeholder template tokens are copied through verbatim.
            if ctk not in ("<c0>", "<c1>", "<c2>"):
                next_tk.append(ctk)
                continue
            if l == 0:
                pred_i = pred[len(s[i]) - 1, i]
            else:
                pred_i = pred[0, i]
            next_tk.append(lm_vocab.idx2token(pred_i.item()))

        s_ = []
        bidx = [1] * len(s)
        for idx, (sent, t) in enumerate(zip(s, next_tk)):
            if t == "<eos>":
                res.append(sent)
                bidx[idx] = 0
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        inp_y, m = s2t(s, lm_vocab)
        inp_y = inp_y.cuda(gpu)
        bidx = torch.BoolTensor(bidx).cuda(gpu)
        incremental_state["bidx"] = bidx

    res += s_
    # for i in res:
    #     print(''.join(i))
    print(time.time() - start)
    return res
def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    # Top-k sampling conditioned on the encoder output and the format template
    # (non-incremental: the decoder re-scores the full prefix at every step).
    inp_y, m = s2t(s, lm_vocab)
    inp_y = inp_y.cuda(gpu)
    start = time.time()
    res = []
    for l in range(inp_ys_tpl.size(0)):
        probs, pred = lm_model.work(enc, src_padding_mask, inp_y,
                                    inp_ys_tpl[0:l + 1, :],
                                    inp_ys_seg[0:l + 1, :],
                                    inp_ys_pos[0:l + 1, :])
        next_tk = []
        for i in range(len(s)):
            ctk = lm_vocab.idx2token(inp_ys_tpl[l, i].item())
            # Only <c1> slots are generated; all other template tokens are copied.
            if ctk != "<c1>":
                next_tk.append(ctk)
                continue
            logits = probs[len(s[i]) - 1, i]
            ps, idx = torch.topk(logits, k=k)
            ps = ps / torch.sum(ps)
            sampled = torch.multinomial(ps, num_samples=1)
            sampled_idx = idx[sampled]
            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))

        s_ = []
        for sent, t in zip(s, next_tk):
            if t == "<eos>":
                res.append(sent)
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        inp_y, m = s2t(s, lm_vocab)
        inp_y = inp_y.cuda(gpu)

    res += s_
    # print(time.time() - start)
    return res
def top_g(lm_model, lm_vocab, device, s, max_len):
    # Sampling from the candidate set returned by top_g_sampling (defined elsewhere);
    # generates for exactly max_len steps and prints the resulting sentences.
    x, m = s2t(s, lm_vocab)
    x = x.to(device)
    for l in range(max_len):
        probs, pred = lm_model.work(x)
        next_tk = []
        for i in range(len(s)):
            logits = probs[len(s[i]) - 1, i]
            ps, idx = top_g_sampling(logits)
            ps = ps / torch.sum(ps)
            sampled = torch.multinomial(ps, num_samples=1)
            sampled_idx = idx[sampled]
            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))
        s = [sent + [t] for sent, t in zip(s, next_tk)]
        x, m = s2t(s, lm_vocab)
        x = x.to(device)
    for i in s:
        print(i)
def beam_search(lm_model, lm_vocab, device, s, max_len):
    x, m = s2t(s, lm_vocab)
    return beam_decode(lm_model, lm_vocab, device, s[0], x[:len(s[0]), 0], max_len)
def beam_search(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s):
    x, m = s2t(s, lm_vocab)
    return beam_decode(s[0], x, enc, src_padding_mask, ys_tpl, ys_seg, ys_pos)