def beam_search(src, model, SRC, TRG, opt):
    """Beam-search decode `src` into a target-language string.

    Maintains opt.k hypotheses, extends them one token per step up to
    opt.max_len, and applies length-normalised scoring (alpha = 0.7) once
    every beam has produced <eos>. Returns the best hypothesis as a
    space-joined string (without <sos>/<eos>).

    NOTE(review): assumes CUDA is available (`.cuda()` below) — confirm.
    """
    outputs, e_outputs, log_scores = init_vars(src, model, SRC, TRG, opt)
    eos_tok = TRG.vocab.stoi['<eos>']
    src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    ind = None
    for i in range(2, opt.max_len):
        trg_mask = nopeak_mask(i, opt)
        out = model.out(
            model.decoder(outputs[:, :i], e_outputs, src_mask, trg_mask))
        out = F.softmax(out, dim=-1)
        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i,
                                             opt.k)
        # Occurrences of end symbols for all hypotheses.
        ones = (outputs == eos_tok).nonzero()
        sentence_lengths = torch.zeros(len(outputs), dtype=torch.long).cuda()
        for vec in ones:
            row = vec[0]  # renamed from `i` — it shadowed the step index
            if sentence_lengths[row] == 0:  # first <eos> not yet recorded
                sentence_lengths[row] = vec[1]  # position of first <eos>
        num_finished_sentences = len([s for s in sentence_lengths if s > 0])
        if num_finished_sentences == opt.k:
            alpha = 0.7  # length-normalisation exponent
            div = 1 / (sentence_lengths.type_as(log_scores)**alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break
    # Fall back to beam 0 when no beam finished; otherwise use the
    # length-normalised winner. The two original branches differed only by
    # this index, so they are merged.
    best = 0 if ind is None else ind
    # BUGFIX: the original indexed `.nonzero()[0]` unconditionally and
    # raised IndexError when the chosen beam never emitted <eos>; fall back
    # to the full sequence length instead (as the L4/L10 variants do).
    eos_positions = (outputs[best] == eos_tok).nonzero()
    length = eos_positions[0] if len(eos_positions) != 0 else opt.max_len
    return ' '.join([TRG.vocab.itos[tok] for tok in outputs[best][1:length]])
def init_vars(src, model, SRC, TRG, opt):
    """Run the encoder and the first decoding step, seeding opt.k beams.

    Returns:
        outputs:    (opt.k, opt.max_len) LongTensor holding <sos> plus the
                    top-k first tokens.
        e_outputs:  encoder output replicated once per beam.
        log_scores: (1, opt.k) tensor of log-probabilities of the k first
                    tokens.
    """
    init_tok = TRG.vocab.stoi['<sos>']
    src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    # Encoder pass over the source sentence.
    e_output = model.encoder(src, src_mask)
    # Decoder input starts as the single <sos> token.
    outputs = torch.LongTensor([[init_tok]])
    trg_mask = nopeak_mask(1, opt)
    # BUGFIX: the original called `.cuda()` on these four tensors
    # unconditionally right after an `if opt.device == 0` guard, which
    # crashed on CPU-only runs and made the guard useless. All device
    # transfers are now gated consistently.
    if opt.device == 0:
        src_mask = src_mask.cuda()
        trg_mask = trg_mask.cuda()
        outputs = outputs.cuda()
        e_output = e_output.cuda()
    out = model.out(model.decoder(outputs, e_output, src_mask, trg_mask))
    out = F.softmax(out, dim=-1)
    # Top-k first tokens and their probabilities.
    probs, ix = out[:, -1].data.topk(opt.k)
    # (The original also built unused `pred_strings` here for a commented-out
    # print; that dead code is removed.)
    log_scores = torch.Tensor([math.log(prob)
                               for prob in probs.data[0]]).unsqueeze(0)
    outputs = torch.zeros(opt.k, opt.max_len).long()
    if opt.device == 0:
        outputs = outputs.cuda()
    outputs[:, 0] = init_tok
    outputs[:, 1] = ix[0]
    # Replicate the encoder output once per beam.
    e_outputs = torch.zeros(opt.k, e_output.size(-2), e_output.size(-1))
    if opt.device == 0:
        e_outputs = e_outputs.cuda()
    e_outputs[:, :] = e_output[0]
    return outputs, e_outputs, log_scores
def init_vars(src, model, SRC, TRG, opt, recursive=False, query_mask=None):
    """Encode `src` and take the first decoding step, seeding opt.k beams.

    Encoder modes:
        recursive:  `src` is a sequence of nodes; each node is encoded with
                    the encodings of its children (bottom-up), and the last
                    node's encoding is used. No source mask in this mode.
        query_mask: standard padding mask plus an extra query mask.
        default:    standard padding-masked encoding.

    Returns (outputs, e_outputs, log_scores) as expected by beam_search.
    """
    init_tok = TRG.vocab.stoi['<sos>']
    if recursive:
        # Bottom-up encoding: each node sees its children's encodings.
        src_mask = None
        encoded = []
        for node in src:
            children = [encoded[idx] for idx in node[1]]
            encoded.append(model.encoder([node[0], children], src_mask))
        e_output = encoded[-1]
    elif query_mask is not None:
        src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
        e_output = model.encoder(src, src_mask, query_mask=query_mask)
    else:
        src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
        e_output = model.encoder(src, src_mask)
    # First decoder step: feed only the <sos> token.
    outputs = torch.LongTensor([[init_tok]])
    if opt.device == 0:
        outputs = outputs.cuda()
    trg_mask = nopeak_mask(1, opt)
    dec_out = model.out(model.decoder(outputs, e_output, src_mask, trg_mask))
    dec_out = F.softmax(dec_out, dim=-1)
    # Seed the beams with the k most probable first tokens.
    probs, ix = dec_out[:, -1].data.topk(opt.k)
    log_scores = torch.Tensor(
        [math.log(p) for p in probs.data[0]]).unsqueeze(0)
    outputs = torch.zeros(opt.k, opt.max_len).long()
    if opt.device == 0:
        outputs = outputs.cuda()
    outputs[:, 0] = init_tok
    outputs[:, 1] = ix[0]
    # Replicate the encoder output once per beam.
    e_outputs = torch.zeros(opt.k, e_output.size(-2), e_output.size(-1))
    if opt.device == 0:
        e_outputs = e_outputs.cuda()
    e_outputs[:, :] = e_output[0]
    return outputs, e_outputs, log_scores
def beam_search(src, model, SRC, TRG, opt, recursive=False, query_mask=None):
    """Beam-search decode `src`; returns the best hypothesis as a string.

    Supports the recursive and query-masked encoder modes of init_vars.

    NOTE(review): assumes CUDA is available (`.cuda()` below) — confirm.
    """
    # BUGFIX: the original did not forward `recursive` to init_vars, so a
    # recursive source hit the tensor-mask branch there and failed.
    outputs, e_outputs, log_scores = init_vars(src, model, SRC, TRG, opt,
                                               recursive=recursive,
                                               query_mask=query_mask)
    eos_tok = TRG.vocab.stoi['<eos>']
    if recursive:
        src_mask = None
    else:
        src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    ind = None
    for i in range(2, opt.max_len):
        trg_mask = nopeak_mask(i, opt)
        out = model.out(
            model.decoder(outputs[:, :i], e_outputs, src_mask, trg_mask))
        out = F.softmax(out, dim=-1)
        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i,
                                             opt.k)
        # Occurrences of end symbols for all hypotheses.
        ones = torch.nonzero(torch.eq(outputs, eos_tok), as_tuple=False)
        sentence_lengths = torch.zeros(len(outputs), dtype=torch.long).cuda()
        for vec in ones:
            row = vec[0]  # renamed from `i` — it shadowed the step index
            if sentence_lengths[row] == 0:  # first <eos> not yet recorded
                sentence_lengths[row] = vec[1]  # position of first <eos>
        num_finished_sentences = len([s for s in sentence_lengths if s > 0])
        if num_finished_sentences == opt.k:
            alpha = 0.7  # length-normalisation exponent
            div = 1 / (sentence_lengths.type_as(log_scores)**alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break
    # Beam 0 when no beam finished; length-normalised winner otherwise.
    # The two original branches differed only by this index — merged.
    best = 0 if ind is None else ind
    try:
        length = torch.nonzero(torch.eq(outputs[best], eos_tok),
                               as_tuple=False)[0]
    except IndexError:  # hypothesis never produced <eos>
        length = opt.max_len
    # `' '.join(...)` replaces the original's manual `+= ' ' + tok` loop
    # followed by .strip(); the result is identical.
    return ' '.join(TRG.vocab.itos[tok] for tok in outputs[best][1:length])
def beam_search(src, model, SRC, TRG, opt):
    """Beam-search decode `src`, additionally building a query dictionary.

    Unlike the other variants, this one expects k_best_outputs to also
    return per-step prediction strings and a term->weight dict, which are
    folded into `query`. Returns (query_list, query, query_tokens) where
    query_list holds all opt.k hypotheses as strings.
    """
    outputs, e_outputs, log_scores = init_vars(src, model, SRC, TRG, opt)
    eos_tok = TRG.vocab.stoi['<eos>']
    src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    ind = None
    query = {}  # term -> weight, merged from counts and k_best_outputs dict
    query_tokens = []  # flat list of every predicted token, in order seen
    for i in range(2, opt.max_len):
        trg_mask = nopeak_mask(i, opt)
        src_mask = src_mask.cuda()
        trg_mask = trg_mask.cuda()
        out = model.out(
            model.decoder(outputs[:, :i], e_outputs, src_mask, trg_mask))
        # print (outputs.size())
        # print (out.size())
        out = F.softmax(out, dim=-1)
        # print("output data shape")
        # print(out.data.shape)
        outputs, log_scores, pred_strings, pred_strings_dict = k_best_outputs(
            outputs, out, log_scores, i, opt.k, TRG)
        # This part is another way of forming the query dictionary
        for pred_string in pred_strings:
            pred_string_splitted = pred_string.split()
            for st in pred_string_splitted:
                # NOTE(review): setdefault(st, 1.0) followed by +1 makes a
                # first occurrence count 2.0 — confirm this is intended.
                query.setdefault(st, 1.0)
                query[st] = query[st] + 1
            query_tokens.extend(pred_string_splitted)
        # Merge the per-step term weights, keeping the maximum per term.
        for term in pred_strings_dict:
            if term in query:
                if pred_strings_dict[term] > query[term]:
                    query[term] = pred_strings_dict[term]
            else:
                query[term] = pred_strings_dict[term]
        # Stop once every beam contains an <eos> token.
        if (outputs == eos_tok).nonzero().size(0) == opt.k:
            alpha = 0.7  # length-normalisation exponent
            div = 1 / ((outputs == eos_tok).nonzero()[:, 1].type_as(log_scores)
                       **alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break
    # print("query")
    # print(query)
    # if ind is None:
    #     length = (outputs[0] == eos_tok).nonzero()[0]
    #     return ' '.join([TRG.vocab.itos[tok] for tok in outputs[0][1:length]])
    #
    # else:
    #     length = (outputs[ind] == eos_tok).nonzero()[0]
    #     return ' '.join([TRG.vocab.itos[tok] for tok in outputs[ind][1:length]])
    # NOTE(review): both branches below are identical and never use `ind`,
    # so the length-normalised winner computed above is effectively unused.
    if ind is None:
        query_list = []
        # print("value of k is " + str(opt.k))
        for i in np.arange(opt.k):
            if eos_tok in outputs[i]:
                length = (outputs[i] == eos_tok).nonzero()[0]
            else:
                length = opt.max_len  # beam never emitted <eos>
            query_list.append(' '.join(
                [TRG.vocab.itos[tok] for tok in outputs[i][1:length]]))
        return query_list, query, query_tokens
        # if (outputs[0]==eos_tok).nonzero().size(0) >= 1: #
        # length = (outputs[0]==eos_tok).nonzero()[0]
        # return ' '.join([TRG.vocab.itos[tok] for tok in outputs[0][1:length]])
        # else:
        # return ' '
    else:
        # if (outputs[ind] == eos_tok).nonzero().size(0) >= 1:
        # length = (outputs[ind]==eos_tok).nonzero()[0]
        # return ' '.join([TRG.vocab.itos[tok] for tok in outputs[ind][1:length]])
        # else:
        # return ' '
        query_list = []
        # print("value of k is " + str(opt.k))
        for i in np.arange(opt.k):
            if eos_tok in outputs[i]:
                length = (outputs[i] == eos_tok).nonzero()[0]
            else:
                length = opt.max_len  # beam never emitted <eos>
            query_list.append(' '.join(
                [TRG.vocab.itos[tok] for tok in outputs[i][1:length]]))
        return query_list, query, query_tokens
def beam_search(src, model, vocab, opt):
    """Beam-search decode with the policy decoder head.

    Returns all opt.k hypotheses as a list of strings (in the beam order
    produced by k_best_outputs), each without its <sos>/<eos> tokens.

    NOTE(review): assumes CUDA is available (`.cuda()` below) — confirm.
    """
    outputs, e_outputs, log_scores = init_vars(src, model, vocab, opt)
    eos_tok = vocab.stoi['<eos>']
    src_mask = (src != vocab.stoi['<pad>']).unsqueeze(-2)
    ind = None
    for i in range(2, opt.max_len):
        trg_mask = nopeak_mask(i, opt)
        out = model.out(
            model.decoder(outputs[:, :i], e_outputs, src_mask, trg_mask,
                          policy=True))
        out = F.softmax(out, dim=-1)
        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i,
                                             opt.k)
        # Occurrences of end symbols for all hypotheses.
        ones = (outputs == eos_tok).nonzero()
        sentence_lengths = torch.zeros(len(outputs), dtype=torch.long).cuda()
        for vec in ones:
            row = vec[0]  # renamed from `i` — it shadowed the step index
            if sentence_lengths[row] == 0:  # first <eos> not yet recorded
                sentence_lengths[row] = vec[1]  # position of first <eos>
        num_finished_sentences = len([s for s in sentence_lengths if s > 0])
        if num_finished_sentences == opt.k:
            alpha = 0.7  # length-normalisation exponent
            div = 1 / (sentence_lengths.type_as(log_scores)**alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break
    # The original's `if ind is None` / `else` branches were byte-identical
    # (`ind` never used in them) — collapsed into one path. Also guard beams
    # that never emitted <eos>: the original raised IndexError there.
    result = []
    for idx in range(opt.k):
        eos_positions = (outputs[idx] == eos_tok).nonzero()
        length = eos_positions[0] if len(eos_positions) != 0 else opt.max_len
        result.append(' '.join(
            [vocab.itos[tok] for tok in outputs[idx][1:length]]))
    return result
def beam_search(src, model, src_vocab, trg_vocab, opt):
    """Beam-search decode `src`; returns the best hypothesis as a string.

    Puts the model in eval mode first. The winning beam is sliced at its
    first <eos>; <pad>/<eos> tokens are filtered out. Falls back to the
    full sequence when the slice fails (e.g. no <eos> produced).
    """
    model.eval()
    outputs, e_outputs, log_scores = init_vars(src, model, src_vocab,
                                               trg_vocab, opt)
    eos_tok = trg_vocab.eos_idx
    src_mask = (src != src_vocab.pad_idx).unsqueeze(-2)
    ind = None
    for i in range(2, opt.max_trg_len):
        trg_mask = nopeak_mask(i, opt)
        # NOTE(review): unlike the sibling variants, no softmax is applied
        # before k_best_outputs here — presumably it normalises internally;
        # verify against that helper.
        out = model.out(
            model.decoder(outputs[:, :i], e_outputs, src_mask, trg_mask))
        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i,
                                             opt.beam_size)
        # Occurrences of end symbols for all hypotheses, as (row, col) pairs.
        x, y = (outputs == eos_tok).nonzero(as_tuple=True)
        ones = list(zip(x.detach().cpu().numpy(), y.detach().cpu().numpy()))
        sentence_lengths = torch.zeros(len(outputs),
                                       dtype=torch.long).to(opt.device)
        for row, col in ones:  # unpacked directly; the old `i =` shadowed
            if sentence_lengths[row] == 0:  # first <eos> not yet recorded
                sentence_lengths[row] = col  # position of first <eos>
        num_finished_sentences = len([s for s in sentence_lengths if s > 0])
        if num_finished_sentences == opt.beam_size:
            alpha = 0.7  # length-normalisation exponent
            div = 1 / (sentence_lengths.type_as(log_scores)**alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break
    pad_token = trg_vocab.pad_idx
    # BUGFIX: the two bare `except:` clauses below were narrowed to
    # `except Exception` — bare except also swallowed KeyboardInterrupt and
    # SystemExit. The fallback behaviour itself is unchanged.
    if ind is None:
        length = (outputs[0] == eos_tok).nonzero(as_tuple=True)[0]
        outputs = outputs.detach().cpu().numpy()
        try:
            return ' '.join(
                [trg_vocab.itos[tok] for tok in outputs[0][1:length]])
        except Exception:  # slice failed (no <eos> / non-scalar length)
            return ' '.join([trg_vocab.itos[tok] for tok in outputs[0][1:]])
    else:
        length = (outputs[ind] == eos_tok).nonzero(as_tuple=True)[0]
        outputs = outputs.detach().cpu().numpy()
        try:
            return ' '.join([
                trg_vocab.itos[tok] for tok in outputs[ind][1:length]
                if tok != pad_token and tok != eos_tok
            ])
        except Exception:  # slice failed (no <eos> / non-scalar length)
            return ' '.join([
                trg_vocab.itos[tok] for tok in outputs[ind][1:]
                if tok != pad_token and tok != eos_tok
            ])
def beam_search(src, model, SRC, TRG, opt):
    """Beam-search decode `src` with either a transformer or a naive RNN
    seq2seq model, selected by opt.nmt_model_type.

    Returns the best hypothesis as a space-joined string. When the chosen
    beam never produced <eos>, an <eos> column is appended to every beam
    and the full hypothesis is emitted.
    """
    eos_tok = TRG.vocab.stoi['<eos>']
    pad_tok = SRC.vocab.stoi['<pad>']
    ind = None
    if opt.nmt_model_type == 'rnn_naive_model':
        # Pad the source out to opt.max_len so the RNN decoder input width
        # is fixed.
        tensor_to_fill_max_len = torch.full((1, opt.max_len - src.shape[1]),
                                            pad_tok).to(opt.device)
        src = torch.cat((src, tensor_to_fill_max_len), dim=1)
        outputs, encoder_outputs, log_scores = init_vars(
            src, model, SRC, TRG, opt)
        encoder_hidden = encoder_outputs
        decoder_hidden = encoder_hidden
    elif opt.nmt_model_type == 'transformer':
        src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
        outputs, e_outputs, log_scores = init_vars(src, model, SRC, TRG, opt)
    # [SRC.vocab.itos[i] for i in src.tolist()[0]] to debug
    # init_vars already filled <sos> and the most probable first tokens.
    for i in range(2, opt.max_len):
        if opt.nmt_model_type == 'transformer':
            # Keep increasing the size of the decoded sentence.
            trg_mask = nopeak_mask(i, opt)
            out = model.out(
                model.decoder(outputs[:, :i], e_outputs, src_mask, trg_mask))
            out = F.softmax(out, dim=-1)
        elif opt.nmt_model_type == 'rnn_naive_model':
            decoder_input = torch.zeros(opt.k,
                                        src.shape[1]).long().to(opt.device)
            decoder_input[:, :i] = outputs[:, :i]
            # Decode each beam separately, carrying its own hidden state.
            for j in range(opt.k):
                out_piece, decoder_hidden_piece = model.decoder(
                    decoder_input[j, :], decoder_hidden[j, :].unsqueeze(0),
                    encoder_outputs[j, :].unsqueeze(0))
                if j == 0:
                    out = out_piece[:i, :].unsqueeze(0)
                    decoder_hidden_carry = decoder_hidden_piece[:i, :]
                else:
                    out = torch.cat([out, out_piece[:i, :].unsqueeze(0)],
                                    dim=0)
                    decoder_hidden_carry = torch.cat(
                        [decoder_hidden_carry, decoder_hidden_piece[:i, :]],
                        dim=0)
            out = F.softmax(out, dim=-1)
            decoder_hidden = decoder_hidden_carry
        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i,
                                             opt.k, TRG, SRC)
        # Occurrences of end symbols for all hypotheses.
        ones = (outputs == eos_tok).nonzero()
        sentence_lengths = torch.zeros(len(outputs),
                                       dtype=torch.long).to(opt.device)
        for vec in ones:
            row = vec[0]  # renamed from `i` — it shadowed the step index
            if sentence_lengths[row] == 0:  # first <eos> not yet recorded
                sentence_lengths[row] = vec[1]  # position of first <eos>
        num_finished_sentences = len([s for s in sentence_lengths if s > 0])
        if num_finished_sentences == opt.k:
            alpha = 0.7  # length-normalisation exponent
            div = 1 / (sentence_lengths.type_as(log_scores)**alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break
    # Beam 0 when no beam finished; length-normalised winner otherwise.
    # The two original branches differed only by this index — merged.
    best = 0 if ind is None else ind
    length = (outputs[best] == eos_tok).nonzero()
    if len(length) != 0:
        return ' '.join(
            [TRG.vocab.itos[tok] for tok in outputs[best][1:length[0]]])
    # No <eos> in the chosen beam: append one <eos> per beam and emit the
    # whole hypothesis.
    # BUGFIX: the appended column was hard-coded to 3 rows
    # (torch.ones([3, 1])), which broke the torch.cat along dim=1 for any
    # opt.k != 3 since `outputs` has opt.k rows — use opt.k.
    eos_column = torch.ones([opt.k, 1],
                            dtype=torch.int64).to(opt.device) * eos_tok
    outputs = torch.cat((outputs, eos_column), dim=1)
    return ' '.join([TRG.vocab.itos[tok] for tok in outputs[best][1:]])