def compute_batched_losses(model, enc, ranker, ranker_tokenizer, input_ids, args, eval_mode=False, coordinator=None):
    batch_size = input_ids.shape[0]
    doc_count = 20
    input_ids_list, full_str, indiv_idx_pointer, passage_str_list, cluster_queries_str, _ = parse_raw(enc, input_ids, doc_count, args, batch_size)
    if cluster_queries_str == -1:
        return -1, -1, -1, -1, -1
    max_passage_len = input_ids_list.shape[2]

    # clone baseline indiv_idx_pointer for baselines that need to generate
    baseline_indiv_idx_pointer = None
    if args.use_baseline_2 or args.use_baseline_3:
        baseline_indiv_idx_pointer = indiv_idx_pointer[:]

    out, out_total_log_probs, common_vocab_dist_seq, indiv_vocab_dist_seq, indiv_out, attention_seq = multisource_batch_generate_sequence(
        model,
        input_ids_list,
        length=args.generation_length,
        temperature=args.temperature,
        top_k=args.top_k,
        sample=False if eval_mode else args.is_sampling,  # argmax when evaluating.
        device=args.device,
        coordinator=coordinator,
        indiv_idx_pointer=indiv_idx_pointer,
        args=args
    )

    # convert generations to tokens
    out_decoded = [enc.decode(cut_seq_to_eos(s)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace')
                   for s in out.cpu().numpy()]  # TODO: batch this decoding and the one below

    # if using individual generations for baseline 1
    indiv_out_decoded = []
    if args.use_baseline_1:
        for s in indiv_out:
            indiv_out_decoded.append([enc.decode(cut_seq_to_eos(w)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace')
                                      for w in s.cpu().numpy()])

    # Auxiliary SL
    if args.optimize_option in ("rlsl", "all", "sl") or eval_mode:
        sl_loss, ent_loss = auxiliary_sl(args, batch_size, attention_seq, out, common_vocab_dist_seq, indiv_vocab_dist_seq)
    else:
        sl_loss, ent_loss = torch.FloatTensor([0.0]).to(args.device), torch.FloatTensor([0.0]).to(args.device)

    # Main RL routine
    reward = np.zeros(batch_size)  # keeps the final reward mean defined when the RL branch is skipped
    if args.optimize_option in ("rlsl", "all", "rl") or eval_mode:
        reward = compute_reward_and_baselines(batch_size, ranker, ranker_tokenizer, passage_str_list, max_passage_len,
                                              args, out_decoded, eval_mode, indiv_out_decoded, model, input_ids_list,
                                              enc, baseline_indiv_idx_pointer, coordinator, cluster_queries_str)
        rl_loss = torch.sum(-out_total_log_probs.squeeze(1) * torch.from_numpy(reward).float().to(args.device))
    else:
        rl_loss = torch.FloatTensor([0.0]).to(args.device)

    # Training options for different components of the loss function
    loss = aggregate_loss(args, rl_loss, sl_loss, ent_loss)
    return rl_loss.item(), sl_loss.item(), loss, torch.mean(torch.from_numpy(reward).float().to(args.device)), out_decoded
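
# A minimal, self-contained sketch of the REINFORCE objective used in
# compute_batched_losses above: the batch loss is sum_i(-log p(y_i) * reward_i),
# so sequences with higher reward get their log-probability pushed up.
# The log-probs and rewards below are made-up stand-ins for
# out_total_log_probs and the ranker reward; this is illustrative only.
def _reinforce_loss_sketch():
    import torch
    log_probs = torch.tensor([-3.2, -1.7, -2.5], requires_grad=True)  # stand-in for out_total_log_probs
    rewards = torch.tensor([0.8, -0.1, 0.3])                          # stand-in for the ranker rewards
    loss = torch.sum(-log_probs * rewards)
    loss.backward()  # grad of each log-prob is -reward, as in the policy-gradient update
    return loss.item()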
def complete_eval_attention_weights(model, enc, input_ids, args, eval_mode, coordinator, genenc):
    batch_size = input_ids.shape[0]
    doc_count = 20
    input_ids_list, full_str, indiv_idx_pointer, passage_str_list, cluster_queries_str, cluster_queries_str2 = parse_raw_non_retrieved(enc, input_ids, doc_count, args, batch_size)
    if cluster_queries_str == -1:
        return -1, -1, -1, -1
    max_passage_len = input_ids_list.shape[2]

    # cosine similarity between the two reference query embeddings, when available;
    # the second lookup strips a leading '1.0 ' weight prefix from the key
    reference_corr = -1.0
    if genenc:
        try:
            reference_corr = dot(genenc[cluster_queries_str[0]], genenc[cluster_queries_str2[0]]) / (
                norm(genenc[cluster_queries_str[0]]) * norm(genenc[cluster_queries_str2[0]]))
        except Exception:
            pass
        try:
            reference_corr = dot(genenc[cluster_queries_str[0].split('1.0 ')[1]], genenc[cluster_queries_str2[0]]) / (
                norm(genenc[cluster_queries_str[0].split('1.0 ')[1]]) * norm(genenc[cluster_queries_str2[0]]))
        except Exception:
            pass

    # clone baseline indiv_idx_pointer for baselines that need to generate
    baseline_indiv_idx_pointer = None
    if args.use_baseline_2 or args.use_baseline_3:
        baseline_indiv_idx_pointer = indiv_idx_pointer[:]

    out, out_total_log_probs, common_vocab_dist_seq, indiv_vocab_dist_seq, indiv_out, attention_seq, balancer_seq = multisource_batch_generate_sequence_for_attentions(
        model,
        input_ids_list,
        length=args.generation_length,
        temperature=args.temperature,
        top_k=args.top_k,
        sample=False if eval_mode else args.is_sampling,  # argmax when evaluating.
        device=args.device,
        coordinator=coordinator,
        indiv_idx_pointer=indiv_idx_pointer,
        args=args,
    )

    # convert generations to tokens
    out_decoded = [enc.splitted_decode(cut_seq_to_eos(s)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace')
                   for s in out.cpu().numpy()]
    assert isinstance(out_decoded, list)
    return out_decoded, attention_seq, balancer_seq, reference_corr
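
# reference_corr above is plain cosine similarity between two query embeddings,
# computed with numpy's dot/norm. A minimal sketch with toy vectors (the vectors
# stand in for genenc[...] lookups and are made up for illustration):
def _cosine_similarity_sketch():
    import numpy as np
    a = np.array([1.0, 2.0, 0.5])  # stand-in for the first query embedding
    b = np.array([0.8, 1.9, 0.7])  # stand-in for the second query embedding
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))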
def BaselineThree(model, input_ids_list, enc, ranker, ranker_tokenizer, passage_str_list, max_passage_len, args, coordinator):
    with torch.no_grad():
        base_out, base_out_total_log_probs, base_common_vocab_dist_seq, base_indiv_vocab_dist_seq, base_indiv_out = multisource_generate_sequence(
            model,
            input_ids_list,
            length=args.generation_length,
            temperature=args.temperature,
            top_k=args.top_k,
            sample=False,  # KEY PART
            device=args.device,
            coordinator=coordinator
        )
        # convert generations to tokens
        base_out_decoded = [enc.decode(cut_seq_to_eos(s)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace')
                            for s in base_out.cpu().numpy()][0]
        # ranker scoring + compute rewards
        base_reward = RankerRewards(ranker, ranker_tokenizer, base_out_decoded, passage_str_list, max_passage_len, args)
    return base_reward
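
# Self-critical baseline in one picture: BaselineThree scores a greedy (argmax)
# decode with the same ranker, and its reward is subtracted from the sampled
# sequence's reward, so only samples that beat the model's own greedy output
# receive a positive learning signal. The numbers below are illustrative only.
def _self_critic_advantage_sketch(sampled_reward=0.62, greedy_reward=0.55):
    return sampled_reward - greedy_reward  # > 0: reinforce the sample; < 0: suppress it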
def BaselineTwoBatched(model, input_ids_list, enc, ranker, ranker_tokenizer, passage_str_list, max_passage_len, args, indiv_idx_pointer, eval_mode=False):
    with torch.no_grad():
        # generate
        out, out_total_log_probs, common_vocab_dist_seq, indiv_vocab_dist_seq, indiv_out = multisource_batch_generate_sequence(
            model,
            input_ids_list,
            length=args.generation_length,
            temperature=args.temperature,
            top_k=args.top_k,
            sample=False,  # argmax when evaluating.
            device=args.device,
            coordinator=None,
            indiv_idx_pointer=indiv_idx_pointer
        )
        base_out_decoded = [enc.decode(cut_seq_to_eos(s)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace')
                            for s in out.cpu().numpy()]
        # ranker scoring + compute rewards
        base_reward = RankerRewardsBatched(ranker, ranker_tokenizer, passage_str_list, max_passage_len, args,
                                           out_decoded=base_out_decoded, eval_mode=eval_mode)
    return base_reward
def auxiliary_sl(args, batch_size, attention_seq, out, common_vocab_dist_seq, indiv_vocab_dist_seq):
    ent_loss, positive_cluster_loss, negative_cluster_loss = (torch.FloatTensor([0.0]).to(args.device),
                                                              torch.FloatTensor([0.0]).to(args.device),
                                                              torch.FloatTensor([0.0]).to(args.device))
    for m in range(batch_size):
        batch_positive_cluster_loss = 0.0
        batch_negative_cluster_loss = 0.0
        step_ent_loss = 0.0
        decode_len = len(cut_seq_to_eos(out[m].cpu().numpy()))
        for i in range(decode_len):
            # negative attention entropy; minimizing it pushes attention toward uniform
            step_ent_loss += torch.sum(F.softmax(attention_seq[i], dim=-1) * F.log_softmax(attention_seq[i], dim=-1))
            # positive cluster (sources 0-9): symmetric KL between consensus and each source
            for j in range(10):
                batch_positive_cluster_loss += (F.kl_div(common_vocab_dist_seq[i][m], torch.exp(indiv_vocab_dist_seq[i][m][j]))
                                                + F.kl_div(indiv_vocab_dist_seq[i][m][j], torch.exp(common_vocab_dist_seq[i][m])))
            # log of the average probability within each cluster (log-mean-exp)
            positive_cluster_avg_probs = torch.logsumexp(indiv_vocab_dist_seq[i][m][:10], 0) - np.log(10.0)
            negative_cluster_avg_probs = torch.logsumexp(indiv_vocab_dist_seq[i][m][11:], 0) - np.log(10.0)
            sim_val = F.cosine_similarity(torch.exp(positive_cluster_avg_probs),
                                          torch.exp(negative_cluster_avg_probs), dim=0)
            del positive_cluster_avg_probs, negative_cluster_avg_probs
            # negative cluster (sources 11-19): only penalize sources whose similarity-scaled
            # KL is still below the average positive-cluster KL
            for j in range(11, 20):
                negative_cluster_kl = (F.kl_div(common_vocab_dist_seq[i][m], torch.exp(indiv_vocab_dist_seq[i][m][j]))
                                       + F.kl_div(indiv_vocab_dist_seq[i][m][j], torch.exp(common_vocab_dist_seq[i][m])))
                if sim_val * negative_cluster_kl < batch_positive_cluster_loss.item() / 10:
                    batch_negative_cluster_loss += negative_cluster_kl
        positive_cluster_loss += batch_positive_cluster_loss
        negative_cluster_loss += batch_negative_cluster_loss
        # entropy loss
        ent_loss += step_ent_loss
    sl_loss = positive_cluster_loss - negative_cluster_loss
    return sl_loss, ent_loss
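
# Note on the F.kl_div convention relied on above: its first argument must be
# log-probabilities and its second probabilities, and it computes KL(target || input).
# That is why the code passes one distribution's log-probs directly and exponentiates
# the other. A minimal sketch of the symmetric ("both directions") term used per
# cluster member, with random toy distributions standing in for
# common_vocab_dist_seq / indiv_vocab_dist_seq:
def _symmetric_kl_sketch():
    import torch
    import torch.nn.functional as F
    p_log = F.log_softmax(torch.randn(50), dim=-1)  # stand-in: consensus log-probs
    q_log = F.log_softmax(torch.randn(50), dim=-1)  # stand-in: one source's log-probs
    sym_kl = F.kl_div(p_log, torch.exp(q_log)) + F.kl_div(q_log, torch.exp(p_log))
    return sym_kl.item()  # 0 iff the two distributions are identical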
def run_model():
    print(socket.gethostname())
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='',
                        help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--load_checkpoint", '-c', type=str, default='')
    parser.add_argument("--fp16", type=boolean_string, default=False)
    parser.add_argument("--test_file", '-t', type=str, default=None, help='input file for testing')
    parser.add_argument("--output_file", '-o', type=str, default=None, help='output file for testing')
    parser.add_argument("--normalize_data", type=boolean_string, default=True)
    parser.add_argument("--batch_size", '-b', type=int, default=256)
    parser.add_argument("--max_seq_length", type=int, default=512)
    parser.add_argument("--no_token_id", action='store_true')
    parser.add_argument("--no_attn_mask", action='store_true')
    parser.add_argument("--no_eos", action='store_true')
    parser.add_argument("--generation_length", type=int, default=20)
    parser.add_argument("--temperature", type=float, default=1)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
    parser.add_argument('--is_sampling', action='store_true', help='If true, sampling for generation.')
    parser.add_argument('--output_ref', action='store_true', help='If true, output ref')
    # BEAM
    parser.add_argument("--beam", action='store_true', help='If true, beam search')
    parser.add_argument("--beam_width", type=int, default=1)
    parser.add_argument('--use_gpu', action='store_true')
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument('--config', help='JSON config file')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--cstr_decode', action='store_true')
    parser.add_argument("--bonus", type=float, default=0.0)
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    if args.config is not None:
        # override argparse defaults by config JSON
        opts = json.load(open(args.config))
        for k, v in opts.items():
            if isinstance(v, str):
                # PHILLY ENV special cases
                if 'PHILLY_JOB_DIRECTORY' in v:
                    v = v.replace('PHILLY_JOB_DIRECTORY', os.environ['PHILLY_JOB_DIRECTORY'])
                elif 'PHILLY_LOG_DIRECTORY' in v:
                    v = v.replace('PHILLY_LOG_DIRECTORY', os.environ['PHILLY_LOG_DIRECTORY'])
            setattr(args, k, v)

        # command line should override config JSON
        argv = sys.argv[1:]
        overrides, _ = parser.parse_known_args(argv)
        for k, v in vars(overrides).items():
            if f'--{k}' in argv:
                setattr(args, k, v)

    device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")
    n_gpu = torch.cuda.device_count()
    args.device, args.n_gpu = device, n_gpu
    print(args)

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    config = GPT2Config.from_json_file(os.path.join(args.model_name_or_path, 'config.json'))
    enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    model = load_model(GPT2LMHeadModel(config), args.load_checkpoint, args, verbose=True)
    model.to(device)
    model.eval()

    if args.test_file:
        eval_dataloader = get_eval_list_same_length_with_order(args.test_file, enc, args.batch_size, True)
        outs = []
        targets = []
        sources = []
        conv_ids = []
        with torch.no_grad():
            for step, batch in enumerate(tqdm.tqdm(eval_dataloader, desc="Test")):
                new_batch = []
                for t in batch:
                    if isinstance(t, list):
                        new_batch.append(t)
                    else:
                        new_batch.append(t.to(device))
                input_ids, position_ids, token_ids, attn_masks, label_ids, context_len, conv_id = new_batch
                if args.no_token_id:
                    token_ids = None
                if args.no_eos:
                    input_ids = input_ids[:, :-1]
                if args.no_attn_mask:
                    attn_masks = None
                if args.beam:
                    out = beam_search_naive(model, input_ids, position_ids=position_ids,
                                            token_type_ids=token_ids, attn_masks=attn_masks,
                                            length=args.generation_length, beam_width=args.beam_width,
                                            device=args.device, use_bonus=args.cstr_decode,
                                            bonus=args.bonus, enc=enc)
                else:
                    out = generate_sequence(model, input_ids, position_ids=position_ids,
                                            token_type_ids=token_ids, attn_masks=attn_masks,
                                            length=args.generation_length, start_token=None,
                                            temperature=args.temperature, top_k=args.top_k,
                                            sample=args.is_sampling, use_bonus=args.cstr_decode,
                                            bonus=args.bonus, enc=enc)
                sources.extend(input_ids.cpu().numpy())
                out = out.tolist()
                outs.extend(out)
                targets.extend(label_ids)
                conv_ids.extend(conv_id.cpu().numpy())

        conv_id_map = {conv_ids[i]: i for i in range(len(conv_ids))}
        val_src = [enc.decode(cut_seq_to_eos(s)).encode('utf-8').decode('utf-8') for s in sources]
        val_set = [enc.decode(s).encode('utf-8').decode('utf-8') for s in targets]
        gen = [enc.decode(cut_seq_to_eos(s)).encode('utf-8').decode('utf-8') for s in outs]

        val_src_orders = [val_src[conv_id_map[i]] for i in sorted(conv_id_map)]
        val_set_orders = [val_set[conv_id_map[i]] for i in sorted(conv_id_map)]
        gen_orders = [gen[conv_id_map[i]] for i in sorted(conv_id_map)]

        print("=" * 40 + " SAMPLE " + "=" * 40)
        src = enc.decode([x for x in input_ids[-1].cpu().numpy() if x != 0]).encode('utf-8').decode('utf-8')
        gt = val_set[-1]
        resp = gen[-1]
        print(f"Source: \t {src} \n Oracle: \t {gt} \n Resp: \t {resp}\n")

        if args.output_file:
            with open(args.output_file + '.resp.txt', "w") as resp_f:
                for i, r in enumerate(gen_orders):
                    r = re.sub("\n", "", r)
                    if args.output_ref:
                        resp_f.write(val_src_orders[i] + '\t' + val_set_orders[i] + '\t' + r + '\n')
                    else:
                        resp_f.write(r + '\n')
        print("=" * 80)
        sys.stdout.flush()
    else:
        generated = 0
        while True:
            raw_text = input("Model prompt >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text) + [EOS_ID]
            context_tokens = torch.tensor(context_tokens, device=device, dtype=torch.long).unsqueeze(0)
            generated += 1
            position_ids = torch.arange(0, context_tokens.size(-1), dtype=torch.long, device=context_tokens.device)
            token_ids = None if args.no_token_id else torch.zeros_like(context_tokens, dtype=torch.long,
                                                                       device=context_tokens.device)
            if args.beam:
                out = beam_search_naive(model, context_tokens, position_ids=None, token_type_ids=token_ids,
                                        length=args.generation_length, beam_width=args.beam_width,
                                        device=args.device)
            else:
                out = generate_sequence(model, context_tokens, position_ids=None, token_type_ids=token_ids,
                                        length=args.generation_length, start_token=None,
                                        temperature=args.temperature, top_k=args.top_k,
                                        sample=args.is_sampling)
            out = out.tolist()
            text = enc.decode(cut_seq_to_eos(out[0])).encode('utf-8').decode('utf-8')
            print("=" * 40 + " RESPONSE " + str(generated) + " " + "=" * 40)
            print(text)
            print("=" * 80)
def compute_losses(model, enc, ranker, ranker_tokenizer, input_ids, args, eval_mode=False, coordinator=None):
    doc_count = 20
    rl_loss, sl_loss, loss, positive_cluster_loss, negative_cluster_loss = (torch.FloatTensor([-1.0]).to(args.device),
                                                                            torch.FloatTensor([-1.0]).to(args.device),
                                                                            torch.FloatTensor([-1.0]).to(args.device),
                                                                            torch.FloatTensor([-1.0]).to(args.device),
                                                                            torch.FloatTensor([-1.0]).to(args.device))
    full_str = [enc.decode(s).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace')
                for s in input_ids.cpu().numpy()][0]
    split_str = [enc.encode(x) for x in full_str.split('<|endoftext|>') if len(x) > 2]
    split_tensors = [torch.tensor(x).unsqueeze(0).to(args.device) for x in split_str]
    input_ids_list = split_tensors[:doc_count]
    if len(input_ids_list) != 20:
        # input is corrupt (fewer than doc_count passages); skip this example
        return -1, -1, -1, -1, -1
    max_passage_len = max(len(x[0]) for x in input_ids_list)
    passage_str_list = [enc.decode(s).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace')
                        for s in split_str[:doc_count]]

    # generate
    out, out_total_log_probs, common_vocab_dist_seq, indiv_vocab_dist_seq, indiv_out = multisource_generate_sequence(
        model,
        input_ids_list,
        length=args.generation_length,
        temperature=args.temperature,
        top_k=args.top_k,
        sample=False if eval_mode else args.is_sampling,  # argmax when evaluating.
        device=args.device,
        coordinator=coordinator
    )

    # convert generations to tokens
    out_decoded = [enc.decode(cut_seq_to_eos(s)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace')
                   for s in out.cpu().numpy()][0]
    indiv_out_decoded = [enc.decode(cut_seq_to_eos(s)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace')
                         for s in indiv_out.cpu().numpy()]

    # Auxiliary SL
    if args.optimize_option is None or args.optimize_option == 'only_sl' or eval_mode:
        decode_len = len(cut_seq_to_eos(out[0].cpu().numpy()))
        for i in range(decode_len):
            for j in range(10):
                positive_cluster_loss += F.kl_div(common_vocab_dist_seq[i][0], torch.exp(indiv_vocab_dist_seq[i][j]))
            # log of the average probability within each cluster (log-mean-exp)
            positive_cluster_avg_probs = torch.logsumexp(indiv_vocab_dist_seq[i][:10], 0) - np.log(10.0)
            negative_cluster_avg_probs = torch.logsumexp(indiv_vocab_dist_seq[i][11:], 0) - np.log(10.0)
            sim_val = F.cosine_similarity(torch.exp(positive_cluster_avg_probs),
                                          torch.exp(negative_cluster_avg_probs), dim=0)
            for j in range(11, 20):
                negative_cluster_kl = F.kl_div(common_vocab_dist_seq[i][0], torch.exp(indiv_vocab_dist_seq[i][j]))
                if sim_val * negative_cluster_kl < positive_cluster_loss.item() / 10:
                    negative_cluster_loss += negative_cluster_kl
        sl_loss = positive_cluster_loss - negative_cluster_loss

    # Main RL routine
    reward = 0.0
    if args.optimize_option is None or args.optimize_option == 'only_rl' or eval_mode:
        # ranker scoring + compute rewards
        reward = RankerRewards(ranker, ranker_tokenizer, out_decoded, passage_str_list, max_passage_len, args)
        # in eval mode, only report the pure reward of the argmax generation
        if not eval_mode:
            # baseline 1: average reward of the individual (per-source) samples
            if args.use_baseline_1:
                mean_indiv_rewards = BaselineOne(ranker, ranker_tokenizer, indiv_out_decoded,
                                                 passage_str_list, max_passage_len, args)
                reward -= mean_indiv_rewards
            # baseline 2: naive average baseline reward
            if args.use_baseline_2:
                base_reward = BaselineTwo(model, input_ids_list, enc, ranker, ranker_tokenizer,
                                          passage_str_list, max_passage_len, args)
                reward -= base_reward
            # baseline 3: self-critic baseline reward
            if args.use_baseline_3:
                assert args.is_sampling, "To use self-critic, the main model should be sampled, not top-k"
                base_reward = BaselineThree(model, input_ids_list, enc, ranker, ranker_tokenizer,
                                            passage_str_list, max_passage_len, args, coordinator)
                reward -= base_reward
            # baseline 4: multiple generation samples (top-k) and the average of their rewards
            if args.use_baseline_4:
                base_reward = BaselineFour(model, input_ids_list, enc, ranker, ranker_tokenizer,
                                           passage_str_list, max_passage_len, args)
                reward -= base_reward
            # baseline 5: average baseline reward over past training batches
            if args.use_baseline_5:
                base_reward = BaselineFive(model, input_ids_list, enc, ranker, ranker_tokenizer,
                                           passage_str_list, max_passage_len, args)
                reward -= base_reward
        # loss
        rl_loss = -out_total_log_probs * reward

    # Training options for different components of the loss function
    if args.optimize_option is None:
        loss = rl_loss + sl_loss
    elif args.optimize_option == 'only_rl':
        loss = rl_loss
    elif args.optimize_option == 'only_sl':
        loss = sl_loss
    elif args.optimize_option == 'only_sl_positive_cluster':
        loss = positive_cluster_loss
    elif args.optimize_option == 'only_sl_negative_cluster':
        loss = -negative_cluster_loss
    elif args.optimize_option == 'only_sl_iter':
        raise NotImplementedError

    return rl_loss.item(), sl_loss.item(), loss, reward, out_decoded
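
# Baseline 1 above subtracts the mean reward of the individual (per-source)
# generations from the joint generation's reward, acting as a control variate
# intended to reduce the variance of the policy gradient. A toy sketch of that
# arithmetic (the numbers are made up for illustration):
def _mean_indiv_baseline_sketch():
    indiv_rewards = [0.41, 0.55, 0.38, 0.60]  # stand-ins for per-source ranker scores
    joint_reward = 0.57                        # stand-in for the joint generation's score
    return joint_reward - sum(indiv_rewards) / len(indiv_rewards)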
def complete_eval_samples(model, enc, ranker, ranker_tokenizer, input_ids, args, eval_mode=False, coordinator=None,
                          compute_baseline_reward=False, genenc=None, engine=None,
                          compute_retrieval_baseline_reward=False):
    batch_size = input_ids.shape[0]
    doc_count = 20
    if compute_retrieval_baseline_reward:
        # NOTE: this branch needs retrieved_question, which (per the original commented-out
        # call) only parse_raw returns; parse_raw_non_retrieved does not.
        input_ids_list, full_str, indiv_idx_pointer, passage_str_list, cluster_queries_str, cluster_queries_str2, retrieved_question = parse_raw(enc, input_ids, doc_count, args, batch_size)
    else:
        input_ids_list, full_str, indiv_idx_pointer, passage_str_list, cluster_queries_str, cluster_queries_str2 = parse_raw_non_retrieved(enc, input_ids, doc_count, args, batch_size)
    if cluster_queries_str == -1:
        return -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1
    max_passage_len = input_ids_list.shape[2]

    # cosine similarity between the two reference query embeddings, when available;
    # the second lookup strips a leading '1.0 ' weight prefix from the key
    reference_corr = -1.0
    if genenc:
        try:
            reference_corr = dot(genenc[cluster_queries_str[0]], genenc[cluster_queries_str2[0]]) / (
                norm(genenc[cluster_queries_str[0]]) * norm(genenc[cluster_queries_str2[0]]))
        except Exception:
            pass
        try:
            reference_corr = dot(genenc[cluster_queries_str[0].split('1.0 ')[1]], genenc[cluster_queries_str2[0]]) / (
                norm(genenc[cluster_queries_str[0].split('1.0 ')[1]]) * norm(genenc[cluster_queries_str2[0]]))
        except Exception:
            pass

    # clone baseline indiv_idx_pointer for baselines that need to generate
    baseline_indiv_idx_pointer = None
    if args.use_baseline_2 or args.use_baseline_3:
        baseline_indiv_idx_pointer = indiv_idx_pointer[:]

    # for computing the reference baseline reward
    if compute_baseline_reward:
        assert isinstance(cluster_queries_str, list)
        reward = BaselineSixBatched(model, input_ids_list, enc, ranker, ranker_tokenizer, passage_str_list,
                                    max_passage_len, args, coordinator,
                                    cluster_queries_str=cluster_queries_str, eval_mode=eval_mode, return_all=True)
        # 100-passage retrieval evaluation
        MAP, RPREC, MRR, NDCG, MRR10 = 0.0, 0.0, 0.0, 0.0, 0.0
        if engine:
            MAP, RPREC, MRR, NDCG, MRR10 = retrieval_evaluation(ranker, ranker_tokenizer, passage_str_list, engine,
                                                                generated_query=cluster_queries_str, args=args)
        return reward, cluster_queries_str, cluster_queries_str2, MAP, RPREC, MRR, NDCG, MRR10
    elif compute_retrieval_baseline_reward:
        assert isinstance(retrieved_question, list)
        reward = RankerRewardsBatched(ranker, ranker_tokenizer, passage_str_list, max_passage_len, args,
                                      out_decoded=retrieved_question, eval_mode=True, return_all=True)
        # 100-passage retrieval evaluation
        MAP, RPREC, MRR, NDCG, MRR10 = 0.0, 0.0, 0.0, 0.0, 0.0
        if engine:
            MAP, RPREC, MRR, NDCG, MRR10 = retrieval_evaluation(ranker, ranker_tokenizer, passage_str_list, engine,
                                                                generated_query=retrieved_question, args=args)
        return reward, retrieved_question, MAP, RPREC, MRR, NDCG, MRR10
    else:
        out, out_total_log_probs, common_vocab_dist_seq, indiv_vocab_dist_seq, indiv_out, attention_seq = multisource_batch_generate_sequence(
            model,
            input_ids_list,
            length=args.generation_length,
            temperature=args.temperature,
            top_k=args.top_k,
            sample=False if eval_mode else args.is_sampling,  # argmax when evaluating.
            device=args.device,
            coordinator=coordinator,
            indiv_idx_pointer=indiv_idx_pointer,
            args=args
        )

        # convert generations to tokens
        out_decoded = [enc.decode(cut_seq_to_eos(s)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace')
                       for s in out.cpu().numpy()]  # TODO: batch this decoding and the one below
        assert isinstance(out_decoded, list)

        # if using individual generations for baseline 1
        indiv_out_decoded = []
        if args.use_baseline_1:
            for s in indiv_out:
                indiv_out_decoded.append([enc.decode(cut_seq_to_eos(w)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace')
                                          for w in s.cpu().numpy()])

        # Auxiliary SL
        if args.optimize_option in ("rlsl", "all", "sl") or eval_mode:
            sl_loss, ent_loss = auxiliary_sl(args, batch_size, attention_seq, out, common_vocab_dist_seq, indiv_vocab_dist_seq)
        else:
            sl_loss, ent_loss = torch.FloatTensor([0.0]).to(args.device), torch.FloatTensor([0.0]).to(args.device)

        # Main RL routine
        # assumed reward matrix shape (batch, metrics); keeps the return defined when the RL branch is skipped
        reward = np.zeros((batch_size, 2))
        if args.optimize_option in ("rlsl", "all", "rl") or eval_mode:
            reward = compute_reward_and_baselines(batch_size, ranker, ranker_tokenizer, passage_str_list,
                                                  max_passage_len, args, out_decoded, eval_mode, indiv_out_decoded,
                                                  model, input_ids_list, enc, baseline_indiv_idx_pointer,
                                                  coordinator, cluster_queries_str)
            # since this is a complete eval, compute rl_loss against RPREC (column 1 of the reward matrix)
            rl_loss = torch.sum(-out_total_log_probs.squeeze(1) * torch.from_numpy(reward[:, 1]).float().to(args.device))
        else:
            rl_loss = torch.FloatTensor([0.0]).to(args.device)

        # Training options for different components of the loss function
        loss = aggregate_loss(args, rl_loss, sl_loss, ent_loss)

        # 100-passage retrieval evaluation
        MAP, RPREC, MRR, NDCG, MRR10 = 0.0, 0.0, 0.0, 0.0, 0.0
        if engine:
            MAP, RPREC, MRR, NDCG, MRR10 = retrieval_evaluation(ranker, ranker_tokenizer, passage_str_list, engine,
                                                                generated_query=out_decoded, args=args)
        return rl_loss.item(), sl_loss.item(), loss, reward, out_decoded, passage_str_list, attention_seq, reference_corr, MAP, RPREC, MRR, NDCG, MRR10
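
# For reference, reciprocal rank as used in the MRR / MRR@10 numbers returned by
# the retrieval evaluation above: 1 / rank of the first relevant passage, and 0
# if no relevant passage appears (within the cutoff for MRR@10). A minimal
# sketch over a toy ranking; the helper name and arguments are illustrative:
def _reciprocal_rank_sketch(ranked_ids, relevant_ids, cutoff=None):
    for rank, pid in enumerate(ranked_ids[:cutoff], start=1):
        if pid in relevant_ids:
            return 1.0 / rank
    return 0.0

# e.g. _reciprocal_rank_sketch(['p7', 'p3', 'p1'], {'p3'}) == 0.5,
# and with cutoff=1 it returns 0.0 because 'p3' falls outside the top-1.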