Example #1
0
def compute_batched_losses(model, enc, ranker, ranker_tokenizer, input_ids, args, eval_mode=False, coordinator=None):
    """Run one batched train/eval step: generate queries from the passage
    clusters, score them with the ranker, and combine RL / SL / entropy losses.

    Returns (rl_loss_value, sl_loss_value, loss_tensor, mean_reward_tensor,
    out_decoded), or a 5-tuple of -1 sentinels when the raw input cannot be
    parsed into exactly `doc_count` passages.
    """
    batch_size = input_ids.shape[0]
    doc_count = 20  # passages per cluster (indices [:10] positive, rest negative per auxiliary_sl)

    input_ids_list, full_str, indiv_idx_pointer, passage_str_list, cluster_queries_str, _ = parse_raw(enc, input_ids, doc_count, args, batch_size)
    if cluster_queries_str == -1:
        # corrupt / unparsable raw input: propagate sentinel values to the caller
        return -1, -1, -1, -1, -1
    max_passage_len = input_ids_list.shape[2]

    # clone baseline indiv_idx_pointer for baselines that need to generate
    baseline_indiv_idx_pointer = None
    if args.use_baseline_2 or args.use_baseline_3:
        baseline_indiv_idx_pointer = indiv_idx_pointer[:]

    out, out_total_log_probs, common_vocab_dist_seq, indiv_vocab_dist_seq, indiv_out, attention_seq = multisource_batch_generate_sequence(
        model,
        input_ids_list,
        length=args.generation_length,
        temperature=args.temperature,
        top_k=args.top_k,
        sample=False if eval_mode else args.is_sampling, # argmax when evaluating.
        device=args.device,
        coordinator=coordinator,
        indiv_idx_pointer=indiv_idx_pointer,
        args=args
        )

    # convert generations to tokens
    out_decoded = [enc.decode(cut_seq_to_eos(s)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace') for s in out.cpu().numpy()] #TODO need batched here and below

    # if using individual generations for baseline 1
    indiv_out_decoded = []
    if args.use_baseline_1:
        for s in indiv_out:
            indiv_out_decoded.append([enc.decode(cut_seq_to_eos(w)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace') for w in s.cpu().numpy()])

    # Auxiliary SL
    if args.optimize_option == "rlsl" or args.optimize_option == "all" or args.optimize_option == 'sl' or eval_mode == True:
        sl_loss, ent_loss = auxiliary_sl(args, batch_size, attention_seq, out, common_vocab_dist_seq, indiv_vocab_dist_seq)
    else:
        sl_loss, ent_loss = torch.FloatTensor([0.0]).to(args.device), torch.FloatTensor([0.0]).to(args.device)

    # Main RL routine
    # BUGFIX: `reward` was previously assigned only inside the RL branch, so the
    # return statement below raised UnboundLocalError whenever the branch was
    # skipped (e.g. optimize_option == 'sl' during training). Default to a zero
    # reward vector so torch.mean(...) is well-defined in every configuration.
    reward = np.zeros(batch_size, dtype=np.float32)
    if args.optimize_option == "rlsl" or args.optimize_option == "all" or args.optimize_option == 'rl' or eval_mode == True:
        reward = compute_reward_and_baselines(batch_size, ranker, ranker_tokenizer, passage_str_list, max_passage_len, args, out_decoded, eval_mode, indiv_out_decoded, model, input_ids_list, enc, baseline_indiv_idx_pointer, coordinator, cluster_queries_str)
        rl_loss = torch.sum( - out_total_log_probs.squeeze(1) * torch.from_numpy(reward).float().to(args.device))
    else:
        rl_loss = torch.FloatTensor([0.0]).to(args.device)

    # Training options for different components of the loss function
    loss = aggregate_loss(args, rl_loss, sl_loss, ent_loss)
    return rl_loss.item(), sl_loss.item(), loss, torch.mean(torch.from_numpy(reward).float().to(args.device)), out_decoded
Example #2
0
def complete_eval_attention_weights(model, enc, input_ids, args, eval_mode, coordinator, genenc):
    """Generate queries for one eval batch while keeping the per-step attention
    and balancer weights for inspection.

    Returns (out_decoded, attention_seq, balancer_seq, reference_corr), or a
    13-element tuple of -1 sentinels when the raw input cannot be parsed.
    NOTE(review): the sentinel arity (13) differs from the success arity (4) —
    confirm that callers handle both shapes.
    """
    batch_size = input_ids.shape[0]
    doc_count = 20  # passages per cluster

    input_ids_list, full_str, indiv_idx_pointer, passage_str_list, cluster_queries_str, cluster_queries_str2 = parse_raw_non_retrieved(enc, input_ids, doc_count, args, batch_size)
    if cluster_queries_str == -1:
        return -1, -1, -1, -1, -1, -1, -1, -1,-1, -1, -1, -1, -1

    # Best-effort cosine similarity between the embeddings of the two reference
    # query strings; stays -1.0 when either lookup fails.
    # BUGFIX: the original used nested bare `except:` clauses, which also swallow
    # KeyboardInterrupt/SystemExit; narrowed to `except Exception`.
    reference_corr = -1.0
    if genenc:
        try:
            reference_corr = dot(genenc[cluster_queries_str[0]], genenc[cluster_queries_str2[0]]) / (
                        norm(genenc[cluster_queries_str[0]]) * norm(genenc[cluster_queries_str2[0]]))
        except Exception:
            pass

        try:
            # Some reference queries carry a "1.0 " relevance-score prefix;
            # retry with the prefix stripped (overrides the result above when it succeeds).
            reference_corr = dot(genenc[cluster_queries_str[0].split('1.0 ')[1]], genenc[cluster_queries_str2[0]]) / (
                    norm(genenc[cluster_queries_str[0].split('1.0 ')[1]]) * norm(genenc[cluster_queries_str2[0]]))
        except Exception:
            pass

    # (removed unused locals from the original: `max_passage_len` and the
    # `baseline_indiv_idx_pointer` clone were computed but never used here)

    out, out_total_log_probs, common_vocab_dist_seq, indiv_vocab_dist_seq, indiv_out, attention_seq, balancer_seq = multisource_batch_generate_sequence_for_attentions(
        model,
        input_ids_list,
        length=args.generation_length,
        temperature=args.temperature,
        top_k=args.top_k,
        sample=False if eval_mode else args.is_sampling, # argmax when evaluating.
        device=args.device,
        coordinator=coordinator,
        indiv_idx_pointer=indiv_idx_pointer,
        args=args,
        )

    # convert generations to tokens (splitted_decode keeps per-token boundaries)
    out_decoded = [enc.splitted_decode(cut_seq_to_eos(s)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace') for s in out.cpu().numpy()]
    assert isinstance(out_decoded, list)

    return out_decoded, attention_seq, balancer_seq, reference_corr
Example #3
0
def BaselineThree(model, input_ids_list, enc, ranker, ranker_tokenizer, passage_str_list, max_passage_len, args, coordinator):
    """Self-critic baseline: the ranker reward of the model's greedy output.

    Decoding runs under no_grad with sample=False, so the returned reward is
    deterministic and carries no gradient.
    """
    with torch.no_grad():
        generated, _, _, _, _ = multisource_generate_sequence(
            model,
            input_ids_list,
            length=args.generation_length,
            temperature=args.temperature,
            top_k=args.top_k,
            sample=False, # KEY PART
            device=args.device,
            coordinator=coordinator
        )

        # convert generations to tokens; keep only the first (single) sequence
        decoded = [
            enc.decode(cut_seq_to_eos(seq)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace')
            for seq in generated.cpu().numpy()
        ]
        baseline_text = decoded[0]

        # ranker scoring + compute rewards
        return RankerRewards(ranker, ranker_tokenizer, baseline_text, passage_str_list, max_passage_len, args)
Example #4
0
def BaselineTwoBatched(model, input_ids_list, enc, ranker, ranker_tokenizer, passage_str_list, max_passage_len, args, indiv_idx_pointer, eval_mode=False):
    """Batched naive baseline: ranker reward of greedy generations.

    Generation is done under no_grad (no coordinator, argmax decoding), so the
    baseline reward does not contribute gradients.
    """
    with torch.no_grad():
        # generate greedily; only the token sequences (first element) are needed
        generation = multisource_batch_generate_sequence(
            model,
            input_ids_list,
            length=args.generation_length,
            temperature=args.temperature,
            top_k=args.top_k,
            sample=False,  # argmax when evaluating.
            device=args.device,
            coordinator=None,
            indiv_idx_pointer=indiv_idx_pointer
        )
        sequences = generation[0]

        decoded = []
        for seq in sequences.cpu().numpy():
            text = enc.decode(cut_seq_to_eos(seq))
            decoded.append(text.encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace'))

        # ranker scoring + compute rewards
        return RankerRewardsBatched(ranker, ranker_tokenizer, passage_str_list, max_passage_len, args, out_decoded=decoded, eval_mode=eval_mode)
Example #5
0
def auxiliary_sl(args, batch_size, attention_seq, out, common_vocab_dist_seq, indiv_vocab_dist_seq):
    """Auxiliary supervised losses for a generated batch.

    For each decoded step of each batch element this:
      * pulls the common (mixture) vocab distribution toward each of the 10
        positive-cluster distributions via symmetric KL,
      * pushes it away from negative-cluster distributions, gated by the
        cosine similarity between the averaged positive/negative cluster
        distributions, and
      * accumulates the (negative) entropy of the source-attention weights.

    Returns (sl_loss, ent_loss) — 1-element tensors on args.device.
    NOTE(review): the logsumexp/exp usage suggests the vocab distributions are
    log-probabilities — confirm against the generator.
    """
    ent_loss = torch.FloatTensor([0.0]).to(args.device)
    positive_cluster_loss = torch.FloatTensor([0.0]).to(args.device)
    negative_cluster_loss = torch.FloatTensor([0.0]).to(args.device)

    for m in range(batch_size):
        batch_positive_cluster_loss = 0.0
        batch_negative_cluster_loss = 0.0
        step_ent_loss = 0.0
        # only score tokens up to the generated EOS for this batch element
        decode_len = len(cut_seq_to_eos(out[m].cpu().numpy()))
        for i in range(decode_len):
            # negative attention entropy: minimizing pushes attention toward uniform
            step_ent_loss += torch.sum(F.softmax(attention_seq[i], dim=-1) * F.log_softmax(attention_seq[i],
                                                                                           dim=-1))  # maximize entropy to be near uniform
            # symmetric KL between the common distribution and each positive-cluster one
            for j in range(10):
                batch_positive_cluster_loss += F.kl_div(common_vocab_dist_seq[i][m], torch.exp(indiv_vocab_dist_seq[i][m][j])) + F.kl_div(
                    indiv_vocab_dist_seq[i][m][j], torch.exp(common_vocab_dist_seq[i][m]))

            # average the cluster distributions in probability space (logsumexp - log(10))
            positive_cluster_avg_probs = torch.logsumexp(indiv_vocab_dist_seq[i][m][:10], 0) - torch.log \
                (torch.from_numpy(np.array(float(10))))
            # NOTE(review): index 10 is skipped here ([11:] and range(11, 20) below)
            # yet the average still divides by 10; if the negative cluster spans
            # indices 10-19 this is an off-by-one — confirm against the data layout.
            negative_cluster_avg_probs = torch.logsumexp(indiv_vocab_dist_seq[i][m][11:], 0) - torch.log \
                (torch.from_numpy(np.array(float(10))))
            sim_val = F.cosine_similarity(torch.exp(positive_cluster_avg_probs),
                                          torch.exp(negative_cluster_avg_probs), dim=0)

            del positive_cluster_avg_probs, negative_cluster_avg_probs

            for j in range(11, 20):
                negative_cluster_kl = F.kl_div(common_vocab_dist_seq[i][m],
                                               torch.exp(indiv_vocab_dist_seq[i][m][j])) + F.kl_div(
                    indiv_vocab_dist_seq[i][m][j], torch.exp(common_vocab_dist_seq[i][m]))
                # only repel negatives that sit too close to the positive side
                batch_negative_cluster_loss += negative_cluster_kl if sim_val * negative_cluster_kl < batch_positive_cluster_loss.item() / 10 else 0.0  # torch.FloatTensor([0.0]).to(args.device)

        positive_cluster_loss += batch_positive_cluster_loss
        negative_cluster_loss += batch_negative_cluster_loss
        # BUGFIX: the entropy term was previously accumulated once AFTER this loop,
        # so only the LAST batch element's step_ent_loss was counted (and
        # batch_size == 0 crashed on an undefined step_ent_loss). Accumulate per
        # batch element instead.
        ent_loss += step_ent_loss

    sl_loss = positive_cluster_loss - negative_cluster_loss
    return sl_loss, ent_loss
Example #6
0
def run_model() -> None:
    """CLI entry point: load a GPT-2 checkpoint and either batch-decode a test
    file or run an interactive prompt loop.

    All configuration comes from argparse; a JSON file given via ``--config``
    overrides the defaults, and explicit command-line flags override the JSON.
    When ``--test_file`` is set, generations (optionally with source/reference)
    are written to ``--output_file``; otherwise an interactive REPL is started.
    """
    print(socket.gethostname())

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        default='',
        help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--load_checkpoint", '-c', type=str, default='')
    parser.add_argument("--fp16", type=boolean_string, default=False)
    parser.add_argument("--test_file",
                        '-t',
                        type=str,
                        default=None,
                        help='input file for testing')
    parser.add_argument("--output_file",
                        '-o',
                        type=str,
                        default=None,
                        help='output file for testing')
    parser.add_argument("--normalize_data", type=boolean_string, default=True)
    parser.add_argument("--batch_size", '-b', type=int, default=256)
    parser.add_argument("--max_seq_length", type=int, default=512)
    parser.add_argument("--no_token_id", action='store_true')
    parser.add_argument("--no_attn_mask", action='store_true')
    parser.add_argument("--no_eos", action='store_true')

    # decoding options
    parser.add_argument("--generation_length", type=int, default=20)
    parser.add_argument("--temperature", type=float, default=1)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument('--unconditional',
                        action='store_true',
                        help='If true, unconditional generation.')
    parser.add_argument('--is_sampling',
                        action='store_true',
                        help='If true, sampling for generation.')
    parser.add_argument('--output_ref',
                        action='store_true',
                        help='If true, output ref')

    #BEAM
    parser.add_argument("--beam",
                        action='store_true',
                        help='If true, beam search')
    parser.add_argument("--beam_width", type=int, default=1)

    parser.add_argument('--use_gpu', action='store_true')
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument('--config', help='JSON config file')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--cstr_decode', action='store_true')
    parser.add_argument("--bonus", type=float, default=0.0)

    args = parser.parse_args()
    # pin the process to the requested GPU before torch initializes CUDA
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    if args.config is not None:
        # override argparse defaults by config JSON
        opts = json.load(open(args.config))
        for k, v in opts.items():
            if isinstance(v, str):
                # PHILLY ENV special cases: expand cluster placeholder paths
                if 'PHILLY_JOB_DIRECTORY' in v:
                    v = v.replace('PHILLY_JOB_DIRECTORY',
                                  os.environ['PHILLY_JOB_DIRECTORY'])
                elif 'PHILLY_LOG_DIRECTORY' in v:
                    v = v.replace('PHILLY_LOG_DIRECTORY',
                                  os.environ['PHILLY_LOG_DIRECTORY'])
            setattr(args, k, v)

        # command line should override config JSON
        argv = sys.argv[1:]
        overrides, _ = parser.parse_known_args(argv)
        for k, v in vars(overrides).items():
            # only flags explicitly present on the command line win over the JSON
            if f'--{k}' in argv:
                setattr(args, k, v)
        # setattr(args, 'local_rank', overrides.local_rank)


# do normal parsing

    # device selection + seeding for reproducible decoding
    device = torch.device(
        "cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")
    n_gpu = torch.cuda.device_count()
    args.device, args.n_gpu = device, n_gpu
    print(args)

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # build tokenizer + model from the checkpoint directory and freeze for eval
    config = GPT2Config.from_json_file(
        os.path.join(args.model_name_or_path, 'config.json'))
    enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    model = load_model(GPT2LMHeadModel(config),
                       args.load_checkpoint,
                       args,
                       verbose=True)
    model.to(device)
    model.eval()

    if args.test_file:
        # ---- batch decoding over a test file ----
        eval_dataloader = get_eval_list_same_length_with_order(
            args.test_file, enc, args.batch_size, True)

        model.eval()
        outs = []
        targets = []
        loss_all = []
        ppl_all = []
        sources = []
        conv_ids = []
        with torch.no_grad():
            with tqdm.tqdm(total=len(eval_dataloader), desc=f"Test") as pbar:
                for step, batch in enumerate(
                        tqdm.tqdm(eval_dataloader, desc="Iteration")):

                    # move tensors to the device; list entries (e.g. labels) stay as-is
                    new_batch = []
                    for t in batch:
                        if isinstance(t, list):
                            new_batch.append(t)
                        else:
                            new_batch.append(t.to(device))

                    input_ids, position_ids, token_ids, attn_masks, label_ids, context_len, conv_id = new_batch

                    if args.no_token_id:
                        token_ids = None
                    if args.no_eos:
                        input_ids = input_ids[:, :-1]
                    if args.no_attn_mask:
                        attn_masks = None
                    if args.beam:
                        out = beam_search_naive(model,
                                                input_ids,
                                                position_ids=position_ids,
                                                token_type_ids=token_ids,
                                                attn_masks=attn_masks,
                                                length=args.generation_length,
                                                beam_width=args.beam_width,
                                                device=args.device,
                                                use_bonus=args.cstr_decode,
                                                bonus=args.bonus,
                                                enc=enc)
                    else:
                        out = generate_sequence(model,
                                                input_ids,
                                                position_ids=position_ids,
                                                token_type_ids=token_ids,
                                                attn_masks=attn_masks,
                                                length=args.generation_length,
                                                start_token=None,
                                                temperature=args.temperature,
                                                top_k=args.top_k,
                                                sample=args.is_sampling,
                                                use_bonus=args.cstr_decode,
                                                bonus=args.bonus,
                                                enc=enc)

                    sources.extend(input_ids.cpu().numpy())
                    out = out.tolist()
                    outs.extend(out)
                    targets.extend(label_ids)
                    conv_ids.extend(conv_id.cpu().numpy())

                # map conversation id -> position so output can be re-ordered
                conv_id_map = {conv_ids[i]: i for i in range(len(conv_ids))}
                val_src = [
                    enc.decode(
                        cut_seq_to_eos(s)).encode('utf-8').decode('utf-8')
                    for s in sources
                ]
                #print(len(val_src),len(targets))

                val_set = [
                    enc.decode(s).encode('utf-8').decode('utf-8')
                    for s in targets
                ]
                gen = [
                    enc.decode(
                        cut_seq_to_eos(s)).encode('utf-8').decode('utf-8')
                    for s in outs
                ]

                # restore original conversation order for sources/references/generations
                val_src_orders = [
                    val_src[conv_id_map[i]] for i in sorted(conv_id_map)
                ]
                val_set_orders = [
                    val_set[conv_id_map[i]] for i in sorted(conv_id_map)
                ]
                gen_orders = [gen[conv_id_map[i]] for i in sorted(conv_id_map)]

                # print one sample (from the last batch) for a quick sanity check
                print("=" * 40 + " SAMPLE " + "=" * 40)
                src = enc.decode([
                    x for x in input_ids[-1].cpu().numpy() if x != 0
                ]).encode('utf-8').decode('utf-8')
                gt = val_set[-1]
                resp = gen[-1]
                print(
                    f"Source: \t {src} \n Oracle: \t {gt} \n Resp: \t {resp}\n"
                )
                if args.output_file:
                    # one generation per line; optionally "source\treference\tresponse"
                    with open(args.output_file + '.resp.txt', "w") as resp_f:
                        for i, r in enumerate(gen_orders):
                            r = re.sub("\n", "", r)
                            if args.output_ref:
                                # import pdb; pdb.set_trace()
                                resp_f.write(val_src_orders[i] + '\t' +
                                             val_set_orders[i] + '\t' + r +
                                             '\n')
                            else:
                                resp_f.write(r + '\n')
                print("=" * 80)

                sys.stdout.flush()

    else:
        # ---- interactive prompt loop ----
        generated = 0
        while True:
            raw_text = input("Model prompt >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            # append EOS so the model sees a complete context
            context_tokens = enc.encode(raw_text) + [EOS_ID]
            context_tokens = torch.tensor(context_tokens,
                                          device=device,
                                          dtype=torch.long).unsqueeze(
                                              0)  #.repeat(batch_size, 1)
            generated += 1
            position_ids = torch.arange(0,
                                        context_tokens.size(-1),
                                        dtype=torch.long,
                                        device=context_tokens.device)
            token_ids = None if args.no_token_id else torch.zeros_like(
                context_tokens, dtype=torch.long, device=context_tokens.device)
            if args.beam:
                out = beam_search_naive(model,
                                        context_tokens,
                                        position_ids=None,
                                        token_type_ids=token_ids,
                                        length=args.generation_length,
                                        beam_width=args.beam_width,
                                        device=args.device)
            else:
                out = generate_sequence(model,
                                        context_tokens,
                                        position_ids=None,
                                        token_type_ids=token_ids,
                                        length=args.generation_length,
                                        start_token=None,
                                        temperature=args.temperature,
                                        top_k=args.top_k,
                                        sample=args.is_sampling)
            out = out.tolist()
            text = enc.decode(cut_seq_to_eos(
                out[0])).encode('utf-8').decode('utf-8')
            print("=" * 40 + " RESPONSE " + str(generated) + " " + "=" * 40)
            print(text)
            print("=" * 80)
Example #7
0
def compute_losses(model, enc, ranker, ranker_tokenizer, input_ids, args, eval_mode=False, coordinator=None):
    """Single-example (non-batched) train/eval step: parse one passage cluster,
    generate a query, and compute RL / SL losses plus the ranker reward.

    Returns (rl_loss_value, sl_loss_value, loss_tensor, reward, out_decoded),
    or a 5-tuple of -1 sentinels when the input does not split into exactly 20
    passages.

    NOTE(review): optimize_option here uses None/'only_rl'/'only_sl'/... values,
    unlike the batched variant's 'rl'/'sl'/'rlsl'/'all' — confirm which
    convention the callers of this function use.
    """
    doc_count = 20

    # NOTE(review): these start at -1.0 as "not computed" sentinels, but
    # positive_cluster_loss / negative_cluster_loss are then ACCUMULATED into
    # below, so the -1.0 offset leaks into sl_loss — confirm this is intended.
    rl_loss, sl_loss, loss, positive_cluster_loss, negative_cluster_loss = torch.FloatTensor([-1.0]).to(args.device),\
                                                                            torch.FloatTensor([-1.0]).to(args.device),\
                                                                            torch.FloatTensor([-1.0]).to(args.device),\
                                                                            torch.FloatTensor([-1.0]).to(args.device),\
                                                                            torch.FloatTensor([-1.0]).to(args.device)


    # split the raw concatenated input into per-passage token sequences
    full_str = [enc.decode(s).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace') for s in input_ids.cpu().numpy()][0]
    split_str = [enc.encode(x) for x in full_str.split('<|endoftext|>') if len(x) > 2]
    split_tensors = [torch.tensor(x).unsqueeze(0).to(args.device) for x in split_str]
    input_ids_list = split_tensors[:doc_count]
    max_passage_len = max([len(x[0]) for x in input_ids_list])

    passage_str_list = [enc.decode(s).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace') for s in split_str[:doc_count]]

    # assert len(input_ids_list) == 20, "Input is corrupt. Total length is %d, and raw text: %s" % (len(input_ids_list), full_str)
    if len(input_ids_list) != 20:
        # corrupt input: propagate sentinel values to the caller
        return -1, -1, -1, -1, -1

    # generate
    out, out_total_log_probs, common_vocab_dist_seq, indiv_vocab_dist_seq, indiv_out = multisource_generate_sequence(
        model,
        input_ids_list,
        length=args.generation_length,
        temperature=args.temperature,
        top_k=args.top_k,
        sample=False if eval_mode else args.is_sampling, # argmax when evaluating.
        device=args.device,
        coordinator=coordinator
        )

    # convert generations to tokens (single example, so keep element [0])
    out_decoded = [enc.decode(cut_seq_to_eos(s)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace') for s in out.cpu().numpy()][0]
    indiv_out_decoded = [enc.decode(cut_seq_to_eos(s)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace') for s in indiv_out.cpu().numpy()]


    # Auxiliary SL: KL between common distribution and per-passage distributions
    if args.optimize_option == None or args.optimize_option == 'only_sl' or eval_mode == True:
        decode_len = len(cut_seq_to_eos(out[0].cpu().numpy()))
        for i in range(decode_len):
            # pull the common dist toward each of the 10 positive-cluster dists
            for j in range(10):
                positive_cluster_loss += F.kl_div(common_vocab_dist_seq[i][0], torch.exp(indiv_vocab_dist_seq[i][j]))

            # average the cluster distributions in probability space (logsumexp - log(10))
            positive_cluster_avg_probs = torch.logsumexp(indiv_vocab_dist_seq[i][:10], 0) - torch.log \
                (torch.from_numpy(np.array(float(10))))
            # NOTE(review): index 10 is skipped ([11:] and range(11, 20)); if the
            # negative cluster spans indices 10-19 this is an off-by-one — confirm.
            negative_cluster_avg_probs = torch.logsumexp(indiv_vocab_dist_seq[i][11:], 0) - torch.log \
                (torch.from_numpy(np.array(float(10))))
            sim_val = F.cosine_similarity(torch.exp(positive_cluster_avg_probs), torch.exp(negative_cluster_avg_probs), dim=0)

            # only repel negatives that sit too close to the positive side
            for j in range(11 ,20):
                negative_cluster_kl = F.kl_div(common_vocab_dist_seq[i][0], torch.exp(indiv_vocab_dist_seq[i][j]))
                negative_cluster_loss += negative_cluster_kl if sim_val * negative_cluster_kl < positive_cluster_loss.item( ) / 10  else torch.FloatTensor \
                    ([0.0]).to(args.device)

        sl_loss = positive_cluster_loss - negative_cluster_loss

    # Main RL routine
    reward = 0.0
    if args.optimize_option == None or args.optimize_option == 'only_rl' or eval_mode == True:
        # ranker scoring + compute rewards
        reward = RankerRewards(ranker, ranker_tokenizer, out_decoded, passage_str_list, max_passage_len, args)

        # if eval mode, only get pure reward of argmax generations
        if not eval_mode:
            # baseline 1: average rewards of individual samples
            if args.use_baseline_1:
                mean_indiv_rewards = BaselineOne(ranker, ranker_tokenizer, indiv_out_decoded, passage_str_list, max_passage_len, args)
                reward -= mean_indiv_rewards

            # baseline 2: naive average baseline reward
            if args.use_baseline_2:
                base_reward = BaselineTwo(model, input_ids_list, enc, ranker, ranker_tokenizer, passage_str_list, max_passage_len, args)
                reward -= base_reward

            # baseline 3: self-critic baseline reward
            if args.use_baseline_3:
                assert args.is_sampling == True, "To use self-critic, the main model should be sampled, not top-k"
                base_reward = BaselineThree(model, input_ids_list, enc, ranker, ranker_tokenizer, passage_str_list,
                                          max_passage_len, args)
                reward -= base_reward

            # baseline 4: multiple generation samples (top-k) and average of their rewards
            if args.use_baseline_4:
                base_reward = BaselineFour(model, input_ids_list, enc, ranker, ranker_tokenizer, passage_str_list,
                                            max_passage_len, args)
                reward -= base_reward

            # baseline 5: past training batches average baseline reward
            if args.use_baseline_5:
                base_reward = BaselineFive(model, input_ids_list, enc, ranker, ranker_tokenizer, passage_str_list,
                                            max_passage_len, args)
                reward -= base_reward

        # REINFORCE loss: -log p(sequence) * (reward - baselines)
        rl_loss = - out_total_log_probs * reward

    # Training options for different components of the loss function
    # (falls through with the -1.0 sentinel `loss` for unknown options)
    if args.optimize_option == None:
        loss = rl_loss + sl_loss
    elif args.optimize_option == 'only_rl':
        loss = rl_loss
    elif args.optimize_option == 'only_sl':
        loss = sl_loss
    elif args.optimize_option == 'only_sl_positive_cluster':
        loss = positive_cluster_loss
    elif args.optimize_option == 'only_sl_negative_cluster':
        loss = -negative_cluster_loss
    elif args.optimize_option == 'only_sl_iter':
        raise NotImplementedError

    return rl_loss.item(), sl_loss.item(), loss, reward, out_decoded
Example #8
0
def complete_eval_samples(model, enc, ranker, ranker_tokenizer, input_ids, args, eval_mode=False, coordinator=None, compute_baseline_reward=False, genenc=None, engine=None, compute_retrieval_baseline_reward=False):
    """Complete evaluation for one batch: generate queries, score them with the
    ranker, and optionally run 100-passage retrieval metrics.

    Modes:
      * compute_baseline_reward: score the reference queries (BaselineSixBatched).
      * compute_retrieval_baseline_reward: score retrieved questions — currently
        disabled, see the BUGFIX note in that branch.
      * default: generate with the model and compute RL/SL losses + rewards.
    Returns a 13-tuple of -1 sentinels when the raw input cannot be parsed.
    NOTE(review): the three return paths have different arities (8 / 13) —
    confirm callers dispatch on the mode they requested.
    """
    batch_size = input_ids.shape[0]
    doc_count = 20

    # input_ids_list, full_str, indiv_idx_pointer, passage_str_list, cluster_queries_str, cluster_queries_str2, retrieved_question = parse_raw(enc, input_ids, doc_count, args, batch_size)
    input_ids_list, full_str, indiv_idx_pointer, passage_str_list, cluster_queries_str, cluster_queries_str2 = parse_raw_non_retrieved(enc, input_ids, doc_count, args, batch_size)

    if cluster_queries_str == -1:
        return -1, -1, -1, -1, -1, -1, -1, -1,-1, -1, -1, -1, -1
    max_passage_len = input_ids_list.shape[2]

    # Best-effort cosine similarity between the embeddings of the two reference
    # query strings; stays -1.0 when either lookup fails.
    # BUGFIX: bare `except:` clauses narrowed to `except Exception` (they were
    # also swallowing KeyboardInterrupt/SystemExit).
    reference_corr = -1.0
    if genenc:
        try:
            reference_corr = dot(genenc[cluster_queries_str[0]], genenc[cluster_queries_str2[0]]) / (
                        norm(genenc[cluster_queries_str[0]]) * norm(genenc[cluster_queries_str2[0]]))
        except Exception:
            pass

        try:
            # some reference queries carry a "1.0 " score prefix; retry stripped
            reference_corr = dot(genenc[cluster_queries_str[0].split('1.0 ')[1]], genenc[cluster_queries_str2[0]]) / (
                    norm(genenc[cluster_queries_str[0].split('1.0 ')[1]]) * norm(genenc[cluster_queries_str2[0]]))
        except Exception:
            pass

    # clone baseline indiv_idx_pointer for baselines that need to generate
    baseline_indiv_idx_pointer = None
    if args.use_baseline_2 or args.use_baseline_3:
        baseline_indiv_idx_pointer = indiv_idx_pointer[:]

    # for computing reference baseline reward
    if compute_baseline_reward:
        assert isinstance(cluster_queries_str, list)
        reward = BaselineSixBatched(model, input_ids_list, enc, ranker, ranker_tokenizer, passage_str_list,
                                         max_passage_len, args, coordinator,
                                         cluster_queries_str=cluster_queries_str,
                                         eval_mode=eval_mode, return_all=True)

        # 100-passage retrieval evaluation
        MAP, RPREC, MRR, NDCG, MRR10 = 0.0, 0.0, 0.0, 0.0, 0.0
        if engine:
            MAP, RPREC, MRR, NDCG, MRR10 = retrieval_evaluation(ranker, ranker_tokenizer, passage_str_list, engine,
                                                                generated_query=cluster_queries_str, args=args)

        return reward, cluster_queries_str, cluster_queries_str2, MAP, RPREC, MRR, NDCG, MRR10
        # return base_reward, cluster_queries_str, passage_str_list, -1, -1, cluster_queries_str2
    elif compute_retrieval_baseline_reward:
        # BUGFIX: this branch used `retrieved_question`, which is only produced by
        # the parse_raw(...) call commented out at the top of this function, so it
        # always crashed with a NameError. Fail fast with a clear message until
        # retrieved-question parsing is re-enabled.
        raise NotImplementedError(
            "compute_retrieval_baseline_reward requires `retrieved_question` from "
            "parse_raw, which is currently disabled (see the commented-out call "
            "at the top of complete_eval_samples).")
    else:
        out, out_total_log_probs, common_vocab_dist_seq, indiv_vocab_dist_seq, indiv_out, attention_seq = multisource_batch_generate_sequence(
            model,
            input_ids_list,
            length=args.generation_length,
            temperature=args.temperature,
            top_k=args.top_k,
            sample=False if eval_mode else args.is_sampling, # argmax when evaluating.
            device=args.device,
            coordinator=coordinator,
            indiv_idx_pointer=indiv_idx_pointer,
            args=args
            )

        # convert generations to tokens
        out_decoded = [enc.decode(cut_seq_to_eos(s)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace') for s in out.cpu().numpy()] #TODO need batched here and below
        assert isinstance(out_decoded, list)
        # if using individual generations for baseline 1
        indiv_out_decoded = []
        if args.use_baseline_1:
            for s in indiv_out:
                indiv_out_decoded.append([enc.decode(cut_seq_to_eos(w)).encode('utf-8', 'backslashreplace').decode('utf-8', 'backslashreplace') for w in s.cpu().numpy()])

        # Auxiliary SL
        if args.optimize_option == "rlsl" or args.optimize_option == "all" or args.optimize_option == 'sl' or eval_mode == True:
            sl_loss, ent_loss = auxiliary_sl(args, batch_size, attention_seq, out, common_vocab_dist_seq, indiv_vocab_dist_seq)
        else:
            sl_loss, ent_loss = torch.FloatTensor([0.0]).to(args.device), torch.FloatTensor([0.0]).to(args.device)

        # Main RL routine
        if args.optimize_option == "rlsl" or args.optimize_option == "all" or args.optimize_option == 'rl' or eval_mode == True:
            reward = compute_reward_and_baselines(batch_size, ranker, ranker_tokenizer, passage_str_list, max_passage_len, args, out_decoded, eval_mode, indiv_out_decoded, model, input_ids_list, enc, baseline_indiv_idx_pointer, coordinator, cluster_queries_str)
            # since it is complete eval compute rl_loss with rprec (column 1 of the reward matrix)
            rl_loss = torch.sum(- out_total_log_probs.squeeze(1) * torch.from_numpy(reward[:, 1]).float().to(args.device))
        else:
            rl_loss = torch.FloatTensor([0.0]).to(args.device)

        # Training options for different components of the loss function
        loss = aggregate_loss(args, rl_loss, sl_loss, ent_loss)

        # 100-passage retrieval evaluation
        MAP, RPREC, MRR, NDCG, MRR10 = 0.0, 0.0, 0.0, 0.0, 0.0
        if engine:
            MAP, RPREC, MRR, NDCG, MRR10 = retrieval_evaluation(ranker, ranker_tokenizer, passage_str_list, engine,
                                                                generated_query=out_decoded, args=args)

        # return rl_loss.item(), sl_loss.item(), loss, torch.mean(torch.from_numpy(reward).float().to(args.device), axis=0), out_decoded, passage_str_list, attention_seq, reference_corr, MAP, RPREC, MRR, NDCG, MRR10
        return rl_loss.item(), sl_loss.item(), loss, reward, out_decoded, passage_str_list, attention_seq, reference_corr, MAP, RPREC, MRR, NDCG, MRR10