def decode(self, trg: torch.Tensor, trg_mask: torch.Tensor,
           memory: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
    """
    Function that decodes the target sequence.

    :param trg: Our vectorized target sequence. [Batch_size x trg_seq_len]
    :param trg_mask: Mask to be passed to the SelfAttention block when
        decoding the trg sequence -> check SelfAttention.
    :param memory: Our src sequence encoded by the encoding function.
        This will be used as a memory over the source sequence.
    :param src_mask: Mask to be passed to the SelfAttention block when
        computing attention over the source sequence/memory.
        -> check SelfAttention.

    :returns:
        - The log probabilities of the next word over the entire target
          sequence. [Batch_size x trg_seq_len x vocab_size]
    """
    tokens = self.token_embedding(trg)
    b, t, e = tokens.size()
    positions = self.pos_embedding(
        torch.arange(t, device=d()))[None, :, :].expand(b, t, e)
    x = tokens + positions
    trg_mask = trg_mask & subsequent_mask(t).type_as(trg_mask)
    for block in self.decoding_blocks:
        x = block(x, memory, src_mask, trg_mask)
    x = self.toprobs(x.view(b * t, e)).view(b, t, self.vocab_size)
    return F.log_softmax(x, dim=2)
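Every snippet in this collection calls subsequent_mask without defining it. A minimal sketch, assuming the standard Annotated Transformer definition (an assumption; the repos above may differ in dtype or exact shape):

import torch

def subsequent_mask(size):
    """Return a (1, size, size) mask that is True at and below the diagonal,
    so position i can attend only to positions j <= i."""
    attn_shape = (1, size, size)
    upper = torch.triu(torch.ones(attn_shape, dtype=torch.uint8), diagonal=1)
    return upper == 0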
def init_vars(device, src, src_mask, utter_type, model, tokenizer):
    start_index = tokenizer.cls_token_id
    mem, utter_mask, token_features, token_mask = model.encode(
        src, src_mask, utter_type)
    ys = torch.ones(1, 1).fill_(start_index).long().to(device)
    trg_mask = subsequent_mask(ys.size(1)).type_as(ys)
    coverage = torch.zeros_like(utter_mask).contiguous().float()
    token_coverage = torch.zeros_like(token_mask).contiguous().float()
    vocab_dist, tgt_attn_dist, p_gen, next_cov, next_tok_cov, tok_utter_index = model.decode(
        ys, trg_mask, mem, utter_mask, token_features, token_mask,
        coverage, token_coverage)
    index_1 = torch.arange(0, Config.batch_size).long()
    if Config.pointer_gen:
        vocab_dist_ = p_gen * vocab_dist
        attn_dist_ = (1 - p_gen) * tgt_attn_dist
        token_attn_indices = src[tok_utter_index, index_1, :]
        final_dist = vocab_dist_.scatter_add(1, token_attn_indices, attn_dist_)
    else:
        final_dist = vocab_dist
    final_dist = torch.log(final_dist)
    log_scores, ix = final_dist.topk(Config.beam_size)
    outputs = torch.zeros(Config.beam_size,
                          Config.max_decode_output_length).long()
    outputs = outputs.to(device)
    outputs[:, 0] = start_index
    outputs[:, 1] = ix[0]
    e_outputs = torch.zeros(Config.beam_size, mem.size(-2), mem.size(-1))
    e_outputs = e_outputs.to(device)
    e_outputs[:, :] = mem[0]
    return (outputs, e_outputs, log_scores, utter_mask, next_cov,
            next_tok_cov, mem, utter_mask, token_features, token_mask)
def forward(self, batch_data):
    src = batch_data['src']
    src_mask = batch_data['src_mask']
    rel_ids = batch_data['rel_ids']
    extra_vocab = batch_data['extra_vocab']
    expanded_x = batch_data['expanded_x']
    semantic_mask = batch_data['semantic_mask']
    batch_size = src.size(0)
    memory = self.model.encode(src, src_mask, rel_ids)
    ys = torch.ones(batch_size, 1).fill_(self.start_pos).type_as(src.data)
    for i in range(self.max_tgt_len - 1):
        tgt_mask = Variable(subsequent_mask(ys.size(1)).type_as(src.data))
        tgt_mask = tgt_mask.unsqueeze(1)
        tgt_embed = self.model.tgt_embed(Variable(ys))
        decoder_outputs, decoder_attn = self.model.decode(
            memory, src_mask, tgt_embed, tgt_mask)
        prob = self.model.generator(
            (memory, decoder_outputs[:, -1].unsqueeze(1),
             decoder_attn[:, :, -1].unsqueeze(2),
             tgt_embed[:, -1].unsqueeze(1),
             extra_vocab, expanded_x, semantic_mask))
        prob = prob.squeeze(1)
        _, next_word = torch.max(prob, dim=1)
        ys = torch.cat([ys, next_word.unsqueeze(1).type_as(src.data)], dim=1)
    return ys, extra_vocab
def make_std_mask(tgt, pad):
    "Create a mask to hide padding and future words."
    # noinspection PyUnresolvedReferences
    tgt_mask = (tgt != pad).unsqueeze(-2)
    tgt_mask = tgt_mask & Variable(
        subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    return tgt_mask
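A hedged usage sketch for make_std_mask: the (batch, 1, T) padding mask broadcasts against the (1, T, T) causal mask. The pad id of 0 below is an assumption for illustration:

import torch

tgt = torch.tensor([[5, 7, 9, 0],
                    [4, 0, 0, 0]])  # (batch=2, T=4), 0 = pad id (assumed)
mask = make_std_mask(tgt, pad=0)    # shape (2, 4, 4)
# Row i of each (4, 4) slice is True only at non-pad positions j with j <= i.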
def get_batch_data(src, tgt, pad, rel_ids=None, extra_vocab=None,
                   expanded_x=None, semantic_mask=None):
    # Avoid mutable default arguments; fall back to fresh empty containers.
    rel_ids = [] if rel_ids is None else rel_ids
    extra_vocab = {} if extra_vocab is None else extra_vocab
    expanded_x = [] if expanded_x is None else expanded_x
    semantic_mask = [] if semantic_mask is None else semantic_mask
    src_mask = (src != pad).unsqueeze(-2).unsqueeze(1)
    tgt_mask = (tgt != pad).unsqueeze(-2).unsqueeze(1)
    tgt_mask = tgt_mask & Variable(
        subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    batch = {
        'src': src,
        'tgt': tgt,
        'src_mask': src_mask,
        'tgt_mask': tgt_mask,
        'rel_ids': rel_ids,
        'extra_vocab': extra_vocab,
        'expanded_x': expanded_x,
        'semantic_mask': semantic_mask
    }
    return batch
def greedy_decode(self, seq, attn_masks, max_len, vocab):
    start_symbol = vocab.token_to_idx['[CLS]']
    # end_symbol = vocab.token_to_idx['[SEP]']
    # required attn_masks shape for BERT: [batch_size, sequence_length] (zzingae)
    cont_reps, _ = self.bert(seq, attention_mask=attn_masks)
    # required attn_masks shape for decoder: [batch_size, 1, sequence_length] (zzingae)
    attn_masks = attn_masks.unsqueeze(1)
    ys = torch.ones(seq.shape[0], 1).fill_(start_symbol).type_as(seq.data)
    for i in range(max_len):
        tgt_mask = subsequent_mask(ys.shape[1]).repeat(
            seq.shape[0], 1, 1).type_as(seq.data)
        output = self.decoder(self.tgt_embed(ys), cont_reps, attn_masks,
                              tgt_mask)
        log_prob = self.generator(output[:, -1, :])
        _, next_word = torch.max(log_prob, dim=1)
        ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
        if i == (max_len - 1):
            logits = self.generator(output)
    return logits, ys
def greedy(args, net, src, src_mask, src_vocab, tgt_vocab):
    # src: torch.LongTensor (bsz, slen)
    # src_mask: torch.ByteTensor (bsz, 1, slen)
    slen = src.size(1)
    max_len = min(args.max_len, int(args.gen_a * slen + args.gen_b))
    src_lan = src_vocab.stoi["<" + args.src_lan.upper() + ">"]
    tgt_lan = tgt_vocab.stoi["<" + args.tgt_lan.upper() + ">"]
    with torch.no_grad():
        net.eval()
        bsz = src.size(0)
        enc_out = net.encode(src=src, src_mask=src_mask, src_lang=src_lan)
        generated = src.new(bsz, max_len)
        generated.fill_(tgt_vocab.stoi[config.PAD])
        generated[:, 0].fill_(tgt_vocab.stoi[config.BOS])
        generated = generated.long()
        cur_len = 1
        gen_len = src.new_ones(bsz).long()
        unfinished_sents = src.new_ones(bsz).long()
        cache = {'cur_len': cur_len - 1}
        while cur_len < max_len:
            x = generated[:, cur_len - 1].unsqueeze(-1)
            tgt_mask = (generated[:, :cur_len] !=
                        tgt_vocab.stoi[config.PAD]).unsqueeze(-2)
            tgt_mask = tgt_mask & Variable(
                subsequent_mask(cur_len).type_as(tgt_mask.data))
            logit = net.decode(
                enc_out, src_mask, x,
                tgt_mask[:, cur_len - 1, :].unsqueeze(-2),
                cache=cache, tgt_lang=tgt_lan,
                decoder_lang_id=config.LANG2IDS[args.tgt_lan.upper()])
            scores = net.generator(logit).exp().squeeze()
            next_words = torch.topk(scores, 1)[1].squeeze()
            assert next_words.size() == (bsz, )
            generated[:, cur_len] = next_words * unfinished_sents + \
                tgt_vocab.stoi[config.PAD] * (1 - unfinished_sents)
            gen_len.add_(unfinished_sents)
            unfinished_sents.mul_(
                next_words.ne(tgt_vocab.stoi[config.EOS]).long())
            cur_len = cur_len + 1
            cache['cur_len'] = cur_len - 1
            if unfinished_sents.max() == 0:
                break
        if cur_len == max_len:
            generated[:, -1].masked_fill_(unfinished_sents.bool(),
                                          tgt_vocab.stoi[config.EOS])
        return generated[:, :cur_len], gen_len
def make_std_mask(data, pad):
    """ Create a mask to hide padding and future words. """
    data_mask = (data != pad).unsqueeze(-2)
    data_mask = data_mask & Variable(
        subsequent_mask(data.size(-1)).type_as(data_mask.data))
    return data_mask
def make_std_mask(tgt, pad):
    """ Create a mask to hide padding and future words. """
    tgt_mask = (tgt != pad).unsqueeze(-2)
    tgt_mask = tgt_mask & Variable(
        subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    return tgt_mask
def test(model, config):
    max_sentences = config.get("max_sentences", 1e9)
    max_tokens = config.get("max_tokens", 1e9)
    corpus_prefix = Path(config['corpus_prefix']) / "subword"
    model_path = corpus_prefix / "spm.model"
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(str(model_path))
    test_src = load_corpus(corpus_prefix / Path(config["test_source"]).name,
                           tokenizer)
    num_test_sents = len(test_src)
    eos_id = tokenizer.eos_id()
    test_ids = list(range(num_test_sents))
    test_itr = create_batch_itr(test_src,
                                max_tokens=max_tokens,
                                max_sentences=max_sentences,
                                shuffle=False)
    test_itr = tqdm(test_itr, desc='test')
    for batch_ids in test_itr:
        src_batch = make_batch(test_src, batch_ids, eos_id)
        src_mask = padding_mask(src_batch, eos_id)
        src_encode = model.encode(src_batch, src_mask, train=False)
        trg_ids = [np.array([tokenizer.PieceToId('<s>')] * len(batch_ids))]
        eos_ids = np.array([eos_id] * len(batch_ids))
        while (trg_ids[-1] != eos_ids).any():
            if len(trg_ids) > config['generation_limit']:
                print("Warning: Sentence generation did not finish in",
                      config['generation_limit'], "iterations.",
                      file=sys.stderr)
                trg_ids.append(eos_ids)
                break
            trg_mask = [
                subsequent_mask(len(trg_ids))
                for _ in padding_mask(trg_ids, eos_id)
            ]
            out = model.decode(src_encode, trg_ids, src_mask, trg_mask,
                               train=False)
            y = TF.pick(out, [out.shape()[0] - 1], 0)
            y = np.array(y.argmax(1))
            trg_ids.append(y)
        hyp = [
            hyp_sent[:np.where(hyp_sent == eos_id)[0][0]]
            for hyp_sent in np.array(trg_ids).T
        ]
        for ids in hyp:
            sent = tokenizer.DecodeIds(ids.tolist())
            print(sent)
def beam_search(src, net, src_vocab, tgt_vocab, args):
    """
    This implementation of beam search is problematic.
    log_scores should not include scores of special tokens like <pad>, <s> etc.
    """
    outputs, e_outputs, log_scores = init_vars(src, net, src_vocab, tgt_vocab,
                                               args)
    eos_tok = tgt_vocab.stoi[config.EOS]
    src_mask = (src != src_vocab.stoi[config.PAD]).unsqueeze(-2)
    ind = None
    src_len = src.size(-1)
    max_len = int(min(args.max_len, args.max_ratio * src_len))
    for i in range(1, max_len):
        tgt_mask = subsequent_mask(i)
        if args.use_cuda:
            tgt_mask = tgt_mask.cuda()
        out = net.generator.proj(
            net.decode(e_outputs, src_mask, outputs[:, :i], tgt_mask))
        out = F.softmax(out, dim=-1)
        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i,
                                             args.k)
        # Occurrences of end symbols for all input sentences.
        ones = (outputs == eos_tok).nonzero()
        sentence_lengths = torch.zeros(len(outputs), dtype=torch.long,
                                       device=outputs.device)
        for vec in ones:
            # Use a separate name here: reusing `i` would shadow the decoding
            # position and corrupt subsequent iterations.
            beam_idx = vec[0]
            if sentence_lengths[beam_idx] == 0:
                # First end symbol has not been found yet.
                sentence_lengths[beam_idx] = vec[1]  # Position of first end symbol
        num_finished_sentences = len([s for s in sentence_lengths if s > 0])
        if num_finished_sentences == args.k:
            alpha = args.length_penalty
            div = 1 / (sentence_lengths.type_as(log_scores) ** alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break
    if ind is None:
        try:
            length = (outputs[0] == eos_tok).nonzero()[0]
            return ' '.join(
                [tgt_vocab.itos[tok] for tok in outputs[0][1:length]])
        except IndexError:
            return ' '.join([tgt_vocab.itos[tok] for tok in outputs[0]])
    else:
        length = (outputs[ind] == eos_tok).nonzero()[0]
        return ' '.join(
            [tgt_vocab.itos[tok] for tok in outputs[ind][1:length]])
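The docstring above flags that log_scores should exclude special tokens. One possible fix (a sketch, not the repo's code) is to zero the probabilities of special tokens before k_best_outputs picks the top candidates, so their log score becomes -inf:

def mask_special_tokens(probs, special_ids):
    """Zero out special-token probabilities so top-k never selects them."""
    probs = probs.clone()
    probs[..., special_ids] = 0.0
    return probs

# e.g. inside the decode loop, before k_best_outputs (illustrative):
# out = mask_special_tokens(out, [tgt_vocab.stoi[config.PAD],
#                                 tgt_vocab.stoi[config.BOS]])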
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len - 1):
        out = model.decode(
            memory, src_mask, Variable(ys),
            Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = torch.cat(
            [ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys
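This is the canonical Annotated Transformer greedy decoder. A hedged usage sketch (the model instance, token ids, and pad id below are assumptions for illustration):

import torch

src = torch.LongTensor([[1, 5, 8, 9, 2]])  # (1, src_len); ids are illustrative
src_mask = (src != 0).unsqueeze(-2)        # (1, 1, src_len), assuming pad id 0
ys = greedy_decode(model, src, src_mask, max_len=20, start_symbol=1)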
def make_std_mask(tgt, pad):
    # pad_tgt_mask.size(): (batch_size, 1, max_len-1)
    # pad_tgt_mask is the [padding] mask for each sentence
    pad_tgt_mask = (tgt != pad).unsqueeze(-2)
    length = tgt.size(-1)
    # sub_tgt_mask.size(): (max_len-1, max_len-1)
    # sub_tgt_mask is the [future word] mask
    sub_tgt_mask = subsequent_mask(length).type_as(pad_tgt_mask.data)
    # total_tgt_mask.size(): (batch_size, max_len-1, max_len-1)
    # total_tgt_mask masks both padding and future words
    total_tgt_mask = pad_tgt_mask & sub_tgt_mask
    return total_tgt_mask
def make_mask(target):
    """Create a target mask.

    The mask shape is [batch_size, num_steps, sequence_length].
    The mask blocks out both PADs and access to next tokens.

    Args:
        target: (torch.Tensor) [batch_size, sequence_length].

    Returns:
        target_mask: (torch.Tensor) [batch_size, num_steps, sequence_length].
    """
    target_mask = (target != constants.PAD).unsqueeze(-2)
    target_mask = target_mask & Var(
        utils.subsequent_mask(target.size(-1)).type_as(target_mask.data))
    return target_mask
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1, dtype=torch.int64).fill_(start_symbol)
    for i in range(max_len - 1):
        out = model.decode(
            memory.to(device), src_mask,
            Variable(ys).to(device),
            Variable(subsequent_mask(ys.size(1)).type(
                torch.LongTensor)).to(device))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        if next_word == 2:  # id 2 is presumably the end-of-sequence token
            break
        ys = torch.cat(
            [ys, torch.ones(1, 1, dtype=torch.int64).fill_(next_word)], dim=1)
    return ys
def greedy_decode(tree_transformer_model, batch, max_len, start_pos):
    memory, _ = tree_transformer_model.encode(batch.code, batch.par_matrix,
                                              batch.bro_matrix,
                                              batch.re_par_ids,
                                              batch.re_bro_ids)
    ys = torch.ones(1, 1).fill_(start_pos).type_as(batch.code.data)
    for i in range(max_len - 1):
        # memory, code_mask, comment, comment_mask
        out, _, _ = tree_transformer_model.decode(
            memory, batch.code_mask, Variable(ys),
            Variable(subsequent_mask(ys.size(1)).type_as(batch.code.data)))
        prob = tree_transformer_model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = torch.cat(
            [ys, torch.ones(1, 1).type_as(batch.code.data).fill_(next_word)],
            dim=1)
    return ys
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """
    Run prediction on the given data with a trained model.
    """
    # First encode the source with the encoder.
    memory = model.encode(src, src_mask)
    # Initialize the prediction as a 1x1 tensor holding the start symbol
    # ('BOS') id, with the same type as the input data (LongTensor).
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    # Iterate over the output positions.
    for i in range(max_len - 1):
        # Decode to get the hidden representation.
        out = model.decode(
            memory, src_mask, Variable(ys),
            Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        # Map the hidden representation to a log_softmax distribution
        # over the vocabulary.
        prob = model.generator(out[:, -1])
        # Take the id of the most probable word at the current position.
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        # Append the predicted token id to the previous predictions.
        ys = torch.cat(
            [ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys
def init_vars(src, net, src_vocab, tgt_vocab, args):
    # src: torch.LongTensor, (1, src_len)
    assert src.dim() == 2
    src_len = src.size(-1)
    max_len = int(min(args.max_len, args.max_ratio * src_len))
    init_tok = tgt_vocab.stoi[config.BOS]
    src_mask = (src != src_vocab.stoi[config.PAD]).unsqueeze(-2)
    e_output = net.encode(src, src_mask)
    outputs = torch.LongTensor([[init_tok]])
    if args.use_cuda:
        outputs = outputs.cuda()
    tgt_mask = subsequent_mask(1)
    if args.use_cuda:
        tgt_mask = tgt_mask.cuda()
    out = net.generator.proj(
        net.decode(e_output, src_mask, outputs, tgt_mask))
    out = F.softmax(out, dim=-1)
    probs, ix = out[:, -1].data.topk(args.k)
    log_scores = torch.Tensor([math.log(prob)
                               for prob in probs.data[0]]).unsqueeze(0)
    outputs = torch.zeros(args.k, max_len).long()
    if args.use_cuda:
        outputs = outputs.cuda()
    outputs[:, 0] = init_tok
    outputs[:, 1] = ix[0]
    e_outputs = torch.zeros(args.k, e_output.size(-2), e_output.size(-1))
    if args.use_cuda:
        e_outputs = e_outputs.cuda()
    e_outputs[:, :] = e_output[0]
    return outputs, e_outputs, log_scores
def train(model, optimizer, config, best_valid):
    max_epoch = config.get("max_epoch", int(1e9))
    max_iteration = config.get("max_iteration", int(1e9))
    max_sentences = config.get("max_sentences", 1e9)
    max_tokens = config.get("max_tokens", 1e9)
    update_freq = config.get('update_freq', 1)
    optimizer.add(model)
    corpus_prefix = Path(config['corpus_prefix']) / "subword"
    model_path = corpus_prefix / "spm.model"
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(str(model_path))
    train_src = load_corpus(corpus_prefix / Path(config["train_source"]).name,
                            tokenizer)
    train_trg = load_corpus(corpus_prefix / Path(config["train_target"]).name,
                            tokenizer)
    train_src, train_trg = clean_corpus(train_src, train_trg, config)
    dev_src = load_corpus(corpus_prefix / Path(config["dev_source"]).name,
                          tokenizer)
    dev_trg = load_corpus(corpus_prefix / Path(config["dev_target"]).name,
                          tokenizer)
    dev_src, dev_trg = clean_corpus(dev_src, dev_trg, config)
    num_train_sents = len(train_src)
    num_dev_sents = len(dev_src)
    eos_id = tokenizer.eos_id()
    epoch = 0
    iteration = 0
    while epoch < max_epoch and iteration < max_iteration:
        epoch += 1
        g = Graph()
        Graph.set_default(g)
        train_itr = create_batch_itr(train_src, train_trg, max_tokens,
                                     max_sentences, shuffle=True)
        train_itr = tqdm(train_itr, desc='train epoch {}'.format(epoch))
        train_loss = 0.
        itr_loss = 0.
        itr_tokens = 0
        itr_sentences = 0
        optimizer.reset_gradients()
        for step, batch_ids in enumerate(train_itr):
            src_batch = make_batch(train_src, batch_ids, eos_id)
            trg_batch = make_batch(train_trg, batch_ids, eos_id)
            src_mask = padding_mask(src_batch, eos_id)
            trg_mask = [
                x | subsequent_mask(len(trg_batch) - 1)
                for x in padding_mask(trg_batch[:-1], eos_id)
            ]
            itr_tokens += len(src_batch) * len(src_batch[0])
            itr_sentences += len(batch_ids)
            g.clear()
            loss = model.loss(src_batch, trg_batch, src_mask, trg_mask)
            loss /= update_freq
            loss.backward()
            loss_val = loss.to_float()
            train_loss += loss_val * update_freq * len(batch_ids)
            itr_loss += loss_val
            # with open('graph.dot', 'w') as f:
            #     print(g.dump("dot"), end="", file=f)
            if (step + 1) % update_freq == 0:
                step_num = optimizer.get_epoch() + 1
                new_scale = config['d_model'] ** (-0.5) * \
                    min(step_num ** (-0.5),
                        step_num * config['warmup_steps'] ** (-1.5))
                optimizer.set_learning_rate_scaling(new_scale)
                optimizer.update()
                optimizer.reset_gradients()
                iteration += 1
                train_itr.set_postfix(
                    itr=("%d" % iteration),
                    loss=("%.3lf" % itr_loss),
                    wpb=("%d" % itr_tokens),
                    spb=("%d" % itr_sentences),
                    lr=optimizer.get_learning_rate_scaling())
                itr_loss = 0.
                itr_tokens = 0
                itr_sentences = 0
                if iteration >= max_iteration:
                    break
        print("\ttrain loss = %.4f" % (train_loss / num_train_sents))
        g.clear()
        valid_loss = 0.
        valid_itr = create_batch_itr(dev_src, dev_trg, max_tokens,
                                     max_sentences, shuffle=False)
        valid_itr = tqdm(valid_itr, desc='valid epoch {}'.format(epoch))
        for batch_ids in valid_itr:
            src_batch = make_batch(dev_src, batch_ids, eos_id)
            trg_batch = make_batch(dev_trg, batch_ids, eos_id)
            src_mask = padding_mask(src_batch, eos_id)
            trg_mask = [
                x | subsequent_mask(len(trg_batch) - 1)
                for x in padding_mask(trg_batch[:-1], eos_id)
            ]
            loss = model.loss(src_batch, trg_batch, src_mask, trg_mask,
                              train=False)
            valid_loss += loss.to_float() * len(batch_ids)
            valid_itr.set_postfix(loss=loss.to_float())
        print("\tvalid loss = %.4f" % (valid_loss / num_dev_sents))
        if valid_loss < best_valid:
            best_valid = valid_loss
            print('\tsaving model/optimizer ... ', end="", flush=True)
            prefix = config['model_prefix']
            model.save(prefix + '.model')
            optimizer.save(prefix + '.optimizer')
            with Path(prefix).with_suffix('.valid').open('w') as f:
                f.write(str(best_valid))
            print('done.')
def make_std_mask(comment, pad):
    comment_mask = (comment != pad).unsqueeze(-2)
    tgt_mask = comment_mask & Variable(
        subsequent_mask(comment.size(-1)).type_as(comment_mask.data))
    return tgt_mask
def generate_beam(args, net, src, src_mask, src_vocab, tgt_vocab):
    """
    Decode a sentence given initial start.
    `x`:
        - LongTensor(bs, slen)
            <EOS> W1 W2 W3 <EOS> <PAD>
            <EOS> W1 W2 W3 W4 <EOS>
    `lengths`:
        - LongTensor(bs) [5, 6]
    """
    max_len = args.max_len
    beam_size = args.beam_size
    length_penalty = args.length_penalty
    early_stopping = args.early_stopping
    with torch.no_grad():
        net.eval()
        # batch size / number of words
        n_words = net.generator.n_vocab
        bs = src.size(0)
        # calculate encoder output
        src_enc = net.encode(src=src, src_mask=src_mask)
        src_len = src_mask.view(bs, -1).sum(dim=-1).long()
        # check inputs
        assert src_enc.size(0) == src_len.size(0)
        assert beam_size >= 1
        # expand to beam size the source latent representations / source lengths
        src_enc = src_enc.unsqueeze(1).expand(
            (bs, beam_size) + src_enc.shape[1:]).contiguous().view(
                (bs * beam_size, ) + src_enc.shape[1:])
        src_len = src_len.unsqueeze(1).expand(bs,
                                              beam_size).contiguous().view(-1)
        src_mask = src_mask.unsqueeze(1).expand(
            (bs, beam_size) + src_mask.shape[1:]).contiguous().view(
                (bs * beam_size, ) + src_mask.shape[1:])
        # generated sentences (batch with beam current hypotheses)
        generated = src_len.new(max_len, bs * beam_size)  # upcoming output
        generated.fill_(tgt_vocab.stoi[config.PAD])  # fill upcoming output with <PAD>
        generated[0].fill_(tgt_vocab.stoi[config.BOS])  # we use <EOS> for <BOS> everywhere
        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(beam_size, max_len, length_penalty, early_stopping)
            for _ in range(bs)
        ]
        # scores for each sentence in the beam
        beam_scores = src_enc.new(bs, beam_size).fill_(0)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)
        # current position
        cur_len = 1
        # cache compute states
        cache = {'cur_len': 0}
        # done sentences
        done = [False for _ in range(bs)]
        while cur_len < max_len:
            # compute word scores
            x = generated[cur_len - 1].unsqueeze(-1)
            tgt_mask = (generated[:cur_len] !=
                        tgt_vocab.stoi[config.PAD]).transpose(0, 1).unsqueeze(-2)
            tgt_mask = tgt_mask & Variable(
                subsequent_mask(cur_len).type_as(tgt_mask.data))
            tensor = net.decode(src_enc, src_mask, x,
                                tgt_mask[:, cur_len - 1, :].unsqueeze(-2),
                                cache)
            assert tensor.size() == (bs * beam_size, 1, config.d_model)
            tensor = tensor.data.view(bs * beam_size, -1)  # (bs * beam_size, dim)
            scores = net.generator(tensor)  # (bs * beam_size, n_words)
            assert scores.size() == (bs * beam_size, n_words)
            # select next words with scores
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (bs * beam_size, n_words)
            _scores = _scores.view(bs, beam_size * n_words)  # (bs, beam_size * n_words)
            next_scores, next_words = torch.topk(_scores, 2 * beam_size,
                                                 dim=1, largest=True,
                                                 sorted=True)
            assert next_scores.size() == next_words.size() == (bs, 2 * beam_size)
            # next batch beam content
            # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []
            # for each sentence
            for sent_id in range(bs):
                # if we are done with this sentence
                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(
                    next_scores[sent_id].max().item())
                if done[sent_id]:
                    next_batch_beam.extend(
                        [(0, tgt_vocab.stoi[config.PAD], 0)] * beam_size)  # pad the batch
                    continue
                # next sentence beam content
                next_sent_beam = []
                # next words for this sentence
                for idx, value in zip(next_words[sent_id],
                                      next_scores[sent_id]):
                    # get beam and word IDs
                    beam_id = idx // n_words
                    word_id = idx % n_words
                    # end of sentence, or next word
                    if word_id == tgt_vocab.stoi[config.EOS] or cur_len + 1 == max_len:
                        generated_hyps[sent_id].add(
                            generated[:cur_len,
                                      sent_id * beam_size + beam_id].clone(),
                            value.item())
                    else:
                        next_sent_beam.append(
                            (value, word_id, sent_id * beam_size + beam_id))
                    # the beam for next step is full
                    if len(next_sent_beam) == beam_size:
                        break
                # update next beam content (parenthesized so the comparison,
                # not the assert value, is conditional)
                assert len(next_sent_beam) == (
                    0 if cur_len + 1 == max_len else beam_size)
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, tgt_vocab.stoi[config.PAD], 0)
                                      ] * beam_size  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == beam_size * (sent_id + 1)
            # sanity check / prepare next batch
            assert len(next_batch_beam) == bs * beam_size
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = generated.new([x[1] for x in next_batch_beam])
            beam_idx = src_len.new([x[2] for x in next_batch_beam])
            # re-order batch and internal states
            generated = generated[:, beam_idx]
            generated[cur_len] = beam_words
            for k in cache.keys():
                if k != 'cur_len':
                    cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
            # update current length
            cur_len = cur_len + 1
            cache['cur_len'] = cur_len - 1
            # stop when we are done with each sentence
            if all(done):
                break
        # select the best hypotheses
        tgt_len = src_len.new(bs)
        best = []
        for i, hypotheses in enumerate(generated_hyps):
            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            best.append(best_hyp)
        # generate target batch
        decoded = src_len.new(tgt_len.max().item(),
                              bs).fill_(tgt_vocab.stoi[config.PAD])
        for i, hypo in enumerate(best):
            decoded[:tgt_len[i] - 1, i] = hypo
            # terminate with <EOS> (not <PAD>), matching the comment above
            decoded[tgt_len[i] - 1, i] = tgt_vocab.stoi[config.EOS]
        return decoded.transpose(0, 1), tgt_len
def _model_decode(self, comment, memory, code_mask):
    comment_mask = subsequent_mask(comment.size(1)).cuda()
    dec_output, _, _ = self.model.decode(memory, code_mask, comment,
                                         comment_mask)
    return self.model.generator(dec_output)
def _get_scores(args, net, active_func, src, src_mask, indices, src_vocab,
                tgt_vocab):
    net.eval()
    slen = src.size(1)
    max_len = min(args.max_len, int(slen * args.gen_a + args.gen_b))
    average_by_length = (active_func != "tte")
    result = []
    with torch.no_grad():
        bsz = src.size(0)
        enc_out = net.encode(src=src, src_mask=src_mask)
        generated = src.new(bsz, max_len)
        generated.fill_(tgt_vocab.stoi[config.PAD])
        generated[:, 0].fill_(tgt_vocab.stoi[config.BOS])
        generated = generated.long()
        max_gen_len = src_mask.sum(dim=-1).float() * args.gen_a + args.gen_b
        max_gen_len = max_gen_len.long().view(bsz, )
        max_gen_len = torch.clamp(max_gen_len, min=0, max=max_len)
        cur_len = 1
        gen_len = src.new_ones(bsz).long()
        unfinished_sents = src.new_ones(bsz).long()
        query_scores = src.new_zeros(bsz).float()
        cache = {'cur_len': cur_len - 1}
        while cur_len < max_len:
            x = generated[:, cur_len - 1].unsqueeze(-1)
            tgt_mask = (generated[:, :cur_len] !=
                        tgt_vocab.stoi[config.PAD]).unsqueeze(-2)
            tgt_mask = tgt_mask & Variable(
                subsequent_mask(cur_len).type_as(tgt_mask.data))
            logit = net.decode(enc_out, src_mask, x,
                               tgt_mask[:, cur_len - 1, :].unsqueeze(-2),
                               cache)
            scores = net.generator(logit).exp().view(bsz, -1).data
            # Compute the acquisition (query) score: the smaller the score,
            # the more uncertain the model is about the sentence.
            if active_func == "lc":
                q_scores, _ = torch.topk(scores, 1, dim=-1)
                q_scores = -(1.0 - q_scores).squeeze()
            elif active_func == "margin":
                q_scores, _ = torch.topk(scores, 2, dim=-1)
                q_scores = q_scores[:, 0] - q_scores[:, 1]
            elif active_func == "te" or active_func == "tte":
                q_scores = -torch.distributions.categorical.Categorical(
                    probs=scores).entropy()
            else:
                q_scores = scores.new_zeros(bsz)
            q_scores = q_scores.view(bsz)
            assert q_scores.size() == (bsz, ), q_scores
            query_scores = query_scores + unfinished_sents.float() * q_scores
            next_words = torch.topk(scores, 1)[1].squeeze()
            next_words = next_words.view(bsz)
            assert next_words.size() == (bsz, )
            generated[:, cur_len] = next_words * unfinished_sents + \
                tgt_vocab.stoi[config.PAD] * (1 - unfinished_sents)
            gen_len.add_(unfinished_sents)
            unfinished_sents.mul_(
                next_words.ne(tgt_vocab.stoi[config.EOS]).long())
            unfinished_sents.mul_(gen_len.le(max_gen_len).long())
            cur_len = cur_len + 1
            cache['cur_len'] = cur_len - 1
            if unfinished_sents.max() == 0:
                break
        if cur_len == max_len:
            generated[:, -1].masked_fill_(unfinished_sents.bool(),
                                          tgt_vocab.stoi[config.EOS])
        translated = gen_batch2str(generated[:, :cur_len], gen_len, tgt_vocab)
        if average_by_length:
            query_scores = query_scores / gen_len.float()
        query_scores = query_scores.cpu().numpy().tolist()
        indices = indices.tolist()
        assert len(query_scores) == len(indices)
        assert len(query_scores) == len(translated)
        for i in range(len(query_scores)):
            result.append((query_scores[i], indices[i], translated[i]))
    return result
def make_std_mask(tgt, pad_token_idx):
    target_mask = (tgt != pad_token_idx).unsqueeze(-2)
    # make look-ahead mask - grows by one position at a time
    target_mask = target_mask & Variable(
        subsequent_mask(tgt.size(-1)).type_as(target_mask.data))
    return target_mask.squeeze()