def eval_sent_for_sampler(model,
                          token_indexers,
                          vocab,
                          vc_ss_training_sampler,
                          nei_only=False):
    """
    Run ``model`` over the sampler's sentence list and store per-sentence scores in the sampler.

    Every item of ``vc_ss_training_sampler.sent_list`` is tagged with the 'ss'
    task label, read through a ``VCSS_Reader`` and fed to ``model`` in
    unshuffled batches under ``torch.no_grad()``. For each row of the model
    output a score dict is built and handed to
    ``vc_ss_training_sampler.assign_score_direct(pid, item)``.

    The score dict holds the three raw scores from output columns 0-2
    (``score_s``, ``score_r``, ``score_nei``), their softmax (``prob_s``,
    ``prob_r``, ``prob_nei``), and the selection values derived from column 3:
    ``score`` is the negated column-3 value and ``prob`` is one minus its
    sigmoid.

    :param model: callable taking a batch and returning a 2-D tensor with at
        least four columns per row. It is switched to eval mode here and is
        not switched back.
    :param token_indexers: token indexers passed to ``VCSS_Reader``.
    :param vocab: vocabulary the batch iterator is indexed with.
    :param vc_ss_training_sampler: object exposing ``sent_list`` and
        ``assign_score_direct(pid, item)``.
    :param nei_only: when true, only items whose ``claim_label`` is
        'NOT ENOUGH INFO' are scored (to save computation). The task label is
        still assigned to the whole list beforehand.
    :return: None
    """
    max_len = 80
    batch_size = 64
    ir_index = 3  # irrelevant index is three. 0,1,2,3

    data_reader = VCSS_Reader(token_indexers=token_indexers,
                              lazy=True,
                              max_l=max_len)

    data_list = vc_ss_training_sampler.sent_list
    vc_ss.data_wrangler.assign_task_label(data_list, 'ss')

    if nei_only:  # If we only choose sent for nei example to save computation.
        print("Eval NEI only.")
        data_list = [
            item for item in data_list
            if item['claim_label'] == 'NOT ENOUGH INFO'
        ]

    print("Whole ss training size:", len(data_list))

    data_instance = data_reader.read(data_list)

    biterator = BasicIterator(batch_size=batch_size)
    biterator.index_with(vocab)

    data_iter = biterator(data_instance, shuffle=False, num_epochs=1)

    print("Eval for whole training set")
    print(datetime.datetime.now())

    with torch.no_grad():
        model.eval()

        for batch_idx, batch in enumerate(data_iter):
            full_out = model(batch)

            # Convert the whole batch to Python floats once instead of doing
            # several per-element tensor-to-float conversions for every row.
            logit_rows = full_out.tolist()
            ir_probs = torch.sigmoid(full_out[:, ir_index]).tolist()
            pids = batch['pid']

            for i, row in enumerate(logit_rows):
                score_s, score_r, score_nei = row[:3]
                ir_score = row[ir_index]

                # Softmax over the first three scores. Subtracting the maximum
                # first keeps math.exp from overflowing on large scores and
                # does not change the resulting probabilities.
                top = max(score_s, score_r, score_nei)
                exp_s = math.exp(score_s - top)
                exp_r = math.exp(score_r - top)
                exp_nei = math.exp(score_nei - top)
                total = exp_s + exp_r + exp_nei

                item = {
                    'score_s': score_s,
                    'score_r': score_r,
                    'score_nei': score_nei,
                    'prob_s': exp_s / total,
                    'prob_r': exp_r / total,
                    'prob_nei': exp_nei / total,
                    # important value, remember to set this to negative
                    'score': -ir_score,
                    'prob': 1 - ir_probs[i],  # important value
                }

                vc_ss_training_sampler.assign_score_direct(pids[i], item)

            if batch_idx % 10000 == 0:
                print(batch_idx, end=' ')
                print(datetime.datetime.now())

    print()
def train_fever_std_ema_v1(resume_model=None, do_analysis=False):
    """
    Jointly train verification (vc) and sentence selection (ss) with one model, keeping an EMA of its weights.

    This method is created on 26 Nov 2018 08:50 with the purpose of training vc and ss all together.

    Per epoch the function re-samples vc and ss training data from the
    training sampler, trains on the shuffled mix of both, and every ``mod``
    iterations loads the EMA weights into a cloned model, evaluates it on the
    dev set (ss first, then vc on the selected sentences), prints the FEVER
    scores and saves the EMA weights when the strict score improves or the
    epoch index is below 7.

    :param resume_model: if not None, passed to ``torch.load`` and the result
        is loaded into the model as its state dict before training starts.
    :param do_analysis: if true, an ``analysis_aux`` directory is created under
        the experiment directory and intermediate train/dev data is dumped
        there as jsonl files.
    :return: None
    """

    num_epoch = 200
    # NOTE(review): `seed` only appears in the checkpoint file name below;
    # nothing in this function seeds an RNG with it.
    seed = 12
    batch_size = 32
    lazy = True
    train_prob_threshold = 0.02
    train_sample_top_k = 8
    dev_prob_threshold = 0.1
    dev_sample_top_k = 5
    top_k_doc = 5

    # Per-epoch upper probability for the negative ss sampler; any epoch not
    # listed explicitly falls back to 0.1.
    schedule_sample_dict = defaultdict(lambda: 0.1)

    # Initial value handed to compute_mixing_loss as ss_for_vc_prob; it is
    # overridden inside the batch loop from epoch 1 on.
    ratio_ss_for_vc = 0.2

    schedule_sample_dict.update({
        0: 0.1,
        1: 0.1,  # 200k + 400K
        2: 0.1,
        3: 0.1,  # 200k + 200k ~ 200k + 100k
        4: 0.1,
        5: 0.1,  # 200k + 100k
        6: 0.1  # 20k + 20k
    })

    # Epoch at which the whole training sentence list is re-scored, and the
    # epochs at which only NEI examples are re-scored.
    eval_full_epoch = 1
    eval_nei_epoches = [2, 3, 4, 5, 6, 7]

    neg_only = False
    debug = False

    # The experiment name is built with the initial ratio_ss_for_vc (0.2).
    experiment_name = f"vc_ss_v17_ratio_ss_for_vc:{ratio_ss_for_vc}|t_prob:{train_prob_threshold}|top_k:{train_sample_top_k}_scheduled_neg_sampler"
    # resume_model = None

    print("Do EMA:")

    print("Dev prob threshold:", dev_prob_threshold)
    print("Train prob threshold:", train_prob_threshold)
    print("Train sample top k:", train_sample_top_k)

    # Get upstream sentence document retrieval data
    dev_doc_upstream_file = config.RESULT_PATH / "doc_retri/std_upstream_data_using_pageview/dev_doc.jsonl"
    train_doc_upstream_file = config.RESULT_PATH / "doc_retri/std_upstream_data_using_pageview/train_doc.jsonl"

    complete_upstream_dev_data = get_full_list(config.T_FEVER_DEV_JSONL,
                                               dev_doc_upstream_file,
                                               pred=True,
                                               top_k=top_k_doc)

    complete_upstream_train_data = get_full_list(config.T_FEVER_TRAIN_JSONL,
                                                 train_doc_upstream_file,
                                                 pred=False,
                                                 top_k=top_k_doc)
    if debug:
        # Debug runs only keep the first 1000 items of each split.
        complete_upstream_dev_data = complete_upstream_dev_data[:1000]
        complete_upstream_train_data = complete_upstream_train_data[:1000]

    print("Dev size:", len(complete_upstream_dev_data))
    print("Train size:", len(complete_upstream_train_data))

    # Prepare Data
    token_indexers = {
        'tokens':
        SingleIdTokenIndexer(namespace='tokens'),  # This is the raw tokens
        'elmo_chars': ELMoTokenCharactersIndexer(
            namespace='elmo_characters')  # This is the elmo_characters
    }

    # Data Reader
    dev_fever_data_reader = VCSS_Reader(token_indexers=token_indexers,
                                        lazy=lazy,
                                        max_l=260)
    train_fever_data_reader = VCSS_Reader(token_indexers=token_indexers,
                                          lazy=lazy,
                                          max_l=260)

    # Load Vocabulary
    biterator = BasicIterator(batch_size=batch_size)

    vocab, weight_dict = load_vocab_embeddings(config.DATA_ROOT /
                                               "vocab_cache" / "nli_basic")

    # Register the selection labels and move "hidden" to index -2.
    vocab.add_token_to_namespace('true', namespace='labels')
    vocab.add_token_to_namespace('false', namespace='labels')
    vocab.add_token_to_namespace("hidden", namespace="labels")
    vocab.change_token_with_index_to_namespace("hidden",
                                               -2,
                                               namespace='labels')

    print(vocab.get_token_to_index_vocabulary('labels'))
    print(vocab.get_vocab_size('tokens'))

    biterator.index_with(vocab)
    # Reader and prepare end

    vc_ss_training_sampler = VCSSTrainingSampler(complete_upstream_train_data)
    vc_ss_training_sampler.show_info()

    # Build Model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu",
                          index=0)
    # NOTE(review): device_num is not used anywhere below in this function.
    device_num = -1 if device.type == 'cpu' else 0

    model = Model(rnn_size_in=(1024 + 300 + 1, 1024 + 450 + 1),
                  rnn_size_out=(450, 450),
                  weight=weight_dict['glove.840B.300d'],
                  vocab_size=vocab.get_vocab_size('tokens'),
                  mlp_d=900,
                  embedding_dim=300,
                  max_l=300,
                  num_of_class=4)

    print("Model Max length:", model.max_l)
    if resume_model is not None:
        model.load_state_dict(torch.load(resume_model))
    model.display()
    model.to(device)

    # The clone is the model that EMA weights are loaded into for every
    # evaluation; `model` itself is the one being optimized.
    cloned_empty_model = copy.deepcopy(model)
    ema: EMA = EMA(parameters=model.named_parameters())

    # Create Log File
    file_path_prefix, date = save_tool.gen_file_prefix(f"{experiment_name}")
    # Save the source code.
    script_name = os.path.basename(__file__)
    with open(os.path.join(file_path_prefix, script_name),
              'w') as out_f, open(__file__, 'r') as it:
        out_f.write(it.read())
        out_f.flush()

    analysis_dir = None
    if do_analysis:
        analysis_dir = Path(file_path_prefix) / "analysis_aux"
        analysis_dir.mkdir()
    # Save source code end.

    # Starting parameter setup
    best_dev = -1
    iteration = 0

    start_lr = 0.0001
    # Only parameters with requires_grad set are handed to the optimizer.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=start_lr)
    criterion = nn.CrossEntropyLoss()
    # parameter setup end

    for i_epoch in range(num_epoch):
        print("Resampling...")
        # This is for train
        # This samples candidate data for vc from the result of ss.
        # We need to do this after each epoch.
        if i_epoch == eval_full_epoch:  # full re-scoring only at eval_full_epoch (1)
            print("We now need to eval the whole training set.")
            print("Be patient and hope good luck!")
            load_ema_to_model(cloned_empty_model, ema)
            eval_sent_for_sampler(cloned_empty_model, token_indexers, vocab,
                                  vc_ss_training_sampler)

        elif i_epoch in eval_nei_epoches:  # epochs 2-7: re-score NEI examples only
            print("We now need to eval the NEI training set.")
            print("Be patient and hope good luck!")
            load_ema_to_model(cloned_empty_model, ema)
            eval_sent_for_sampler(cloned_empty_model,
                                  token_indexers,
                                  vocab,
                                  vc_ss_training_sampler,
                                  nei_only=True)

        train_data_with_candidate_sample_list = vc_ss.data_wrangler.sample_sentences_for_vc_with_nei(
            config.T_FEVER_TRAIN_JSONL, vc_ss_training_sampler.sent_list,
            train_prob_threshold, train_sample_top_k)
        # We initialize the prob for each sentence so the sampler can work, but we will need to run the model for dev data to work.

        train_selection_dict = paired_selection_score_dict(
            vc_ss_training_sampler.sent_list)

        cur_train_vc_data = adv_simi_sample_with_prob_v1_1(
            config.T_FEVER_TRAIN_JSONL,
            train_data_with_candidate_sample_list,
            train_selection_dict,
            tokenized=True)

        if do_analysis:
            # Customized analysis output
            common.save_jsonl(
                vc_ss_training_sampler.sent_list, analysis_dir /
                f"E_{i_epoch}_whole_train_sent_{save_tool.get_cur_time_str()}.jsonl"
            )
            common.save_jsonl(
                train_data_with_candidate_sample_list, analysis_dir /
                f"E_{i_epoch}_sampled_train_sent_{save_tool.get_cur_time_str()}.jsonl"
            )
            common.save_jsonl(
                cur_train_vc_data, analysis_dir /
                f"E_{i_epoch}_train_vc_data_{save_tool.get_cur_time_str()}.jsonl"
            )

        print(f"E{i_epoch} VC_data:", len(cur_train_vc_data))

        # This is for sample negative candidate data for ss
        # The upper rate comes from the per-epoch schedule (no decay is applied).
        neg_sample_upper_prob = schedule_sample_dict[i_epoch]
        print("Neg Sampler upper rate:", neg_sample_upper_prob)
        # print("Rate decreasing")
        # neg_sample_upper_prob -= decay_r
        neg_sample_upper_prob = max(0.000, neg_sample_upper_prob)

        cur_train_ss_data = vc_ss_training_sampler.sample_for_ss(
            neg_only=neg_only, upper_prob=neg_sample_upper_prob)

        if i_epoch >= 1:  # from epoch 1 on we subsample pos and neg examples for selection
            # new_ss_data = []
            pos_ss_data = []
            neg_ss_data = []
            # Items whose selection_label is neither 'true' nor 'false' are dropped here.
            for item in cur_train_ss_data:
                if item['selection_label'] == 'true':
                    pos_ss_data.append(item)
                elif item['selection_label'] == 'false':
                    neg_ss_data.append(item)

            # With n = min(#pos, #neg): keep int(0.5 * n) shuffled positives
            # and n shuffled negatives.
            ss_sample_size = min(len(pos_ss_data), len(neg_ss_data))
            random.shuffle(pos_ss_data)
            random.shuffle(neg_ss_data)
            cur_train_ss_data = pos_ss_data[:int(
                ss_sample_size * 0.5)] + neg_ss_data[:ss_sample_size]
            random.shuffle(cur_train_ss_data)

        vc_ss_training_sampler.show_info(cur_train_ss_data)
        print(f"E{i_epoch} SS_data:", len(cur_train_ss_data))

        vc_ss.data_wrangler.assign_task_label(cur_train_ss_data, 'ss')
        vc_ss.data_wrangler.assign_task_label(cur_train_vc_data, 'vc')

        # Both tasks are mixed into one shuffled list and trained together.
        vs_ss_train_list = cur_train_ss_data + cur_train_vc_data
        random.shuffle(vs_ss_train_list)
        print(f"E{i_epoch} Total ss+vc:", len(vs_ss_train_list))
        vc_ss_instance = train_fever_data_reader.read(vs_ss_train_list)

        train_iter = biterator(vc_ss_instance, shuffle=True, num_epochs=1)

        for i, batch in tqdm(enumerate(train_iter)):
            # Re-enter train mode on every step.
            # NOTE(review): presumably because the evaluation helpers below
            # may change module modes - confirm.
            model.train()
            out = model(batch)

            # From epoch 1 on the ratio is 0.8 instead of the initial 0.2
            # (the experiment name above still carries the initial value).
            if i_epoch >= 1:
                ratio_ss_for_vc = 0.8

            loss = compute_mixing_loss(
                model,
                out,
                batch,
                criterion,
                vc_ss_training_sampler,
                ss_for_vc_prob=ratio_ss_for_vc)  # Important change

            # No decay
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            iteration += 1

            # EMA update
            ema(model.named_parameters())

            # Dev evaluation interval: every 10000 iterations for epochs
            # 0-8, every 2000 iterations afterwards.
            if i_epoch < 9:
                mod = 10000
                # mod = 100
            else:
                mod = 2000

            if iteration % mod == 0:

                # This is the code for eval:
                # evaluation always runs on the EMA weights, not on `model`.
                load_ema_to_model(cloned_empty_model, ema)

                # Step 1: score dev sentences with the ss task.
                vc_ss.data_wrangler.assign_task_label(
                    complete_upstream_dev_data, 'ss')
                dev_ss_instance = dev_fever_data_reader.read(
                    complete_upstream_dev_data)
                eval_ss_iter = biterator(dev_ss_instance,
                                         num_epochs=1,
                                         shuffle=False)
                scored_dev_sent_data = hidden_eval_ss(
                    cloned_empty_model, eval_ss_iter,
                    complete_upstream_dev_data)

                # for vc
                # Step 2: pick sentences from the scored data and run the vc task on them.
                filtered_dev_list = vc_ss.data_wrangler.sample_sentences_for_vc_with_nei(
                    config.T_FEVER_DEV_JSONL, scored_dev_sent_data,
                    dev_prob_threshold, dev_sample_top_k)

                dev_selection_dict = paired_selection_score_dict(
                    scored_dev_sent_data)
                ready_dev_list = select_sent_with_prob_for_eval(
                    config.T_FEVER_DEV_JSONL,
                    filtered_dev_list,
                    dev_selection_dict,
                    tokenized=True)

                vc_ss.data_wrangler.assign_task_label(ready_dev_list, 'vc')
                dev_vc_instance = dev_fever_data_reader.read(ready_dev_list)
                eval_vc_iter = biterator(dev_vc_instance,
                                         num_epochs=1,
                                         shuffle=False)
                eval_dev_result_list = hidden_eval_vc(cloned_empty_model,
                                                      eval_vc_iter,
                                                      ready_dev_list)

                # Scoring
                # The dev jsonl is reloaded from disk on every evaluation.
                eval_mode = {'check_sent_id_correct': True, 'standard': True}
                strict_score, acc_score, pr, rec, f1 = c_scorer.fever_score(
                    eval_dev_result_list,
                    common.load_jsonl(config.T_FEVER_DEV_JSONL),
                    mode=eval_mode,
                    verbose=False)
                print("Fever Score(Strict/Acc./Precision/Recall/F1):",
                      strict_score, acc_score, pr, rec, f1)

                print(f"Dev:{strict_score}/{acc_score}")

                if do_analysis:
                    # Customized analysis output
                    common.save_jsonl(
                        scored_dev_sent_data, analysis_dir /
                        f"E_{i_epoch}_scored_dev_sent_{save_tool.get_cur_time_str()}.jsonl"
                    )
                    common.save_jsonl(
                        eval_dev_result_list, analysis_dir /
                        f"E_{i_epoch}_eval_vc_output_data_{save_tool.get_cur_time_str()}.jsonl"
                    )

                need_save = False
                if strict_score > best_dev:
                    best_dev = strict_score
                    need_save = True

                # Save on every evaluation during epochs 0-6, afterwards only
                # when the strict dev score improved.
                if need_save or i_epoch < 7:
                    # save_path = os.path.join(
                    #     file_path_prefix,
                    #     f'i({iteration})_epoch({i_epoch})_dev({strict_score})_lacc({acc_score})_seed({seed})'
                    # )

                    # torch.save(model.state_dict(), save_path)

                    # Only the EMA weights are written; saving the raw model
                    # state is commented out above.
                    ema_save_path = os.path.join(
                        file_path_prefix,
                        f'ema_i({iteration})_epoch({i_epoch})_dev({strict_score})_lacc({acc_score})_p({pr})_r({rec})_f1({f1})_seed({seed})'
                    )

                    save_ema_to_file(ema, ema_save_path)
def analysis_model(model_path, num_of_class=4):
    """
    Evaluate saved EMA weights on the FEVER dev set and dump the intermediate results.

    The function builds the model, loads weights through
    ``load_ema_to_model(model, model_path)``, scores the dev sentences with
    the 'ss' task, selects sentences from those scores, runs the 'vc' task on
    the selection and prints the FEVER scores. Three jsonl files are written
    to the current working directory along the way:
    ``dev_scored_sent_data.jsonl``, ``dev_scored_sent_data_after_sample.jsonl``
    and ``dev_nli_results.jsonl``.

    :param model_path: weight source handed unchanged to ``load_ema_to_model``.
    :param num_of_class: ``num_of_class`` the model is built with. Defaults to
        4, the value ``train_fever_std_ema_v1`` in this file builds its model
        with, so weights saved by that training fit this model. Pass another
        value for weights trained with a different number of classes.
    :return: None
    """
    batch_size = 32
    dev_prob_threshold = 0.1
    dev_sample_top_k = 5
    top_k_doc = 5

    dev_doc_upstream_file = config.RESULT_PATH / "doc_retri/std_upstream_data_using_pageview/dev_doc.jsonl"

    complete_upstream_dev_data = get_full_list(config.T_FEVER_DEV_JSONL,
                                               dev_doc_upstream_file,
                                               pred=True,
                                               top_k=top_k_doc)

    print("Dev size:", len(complete_upstream_dev_data))

    # Prepare Data
    token_indexers = {
        # This is the raw tokens
        'tokens': SingleIdTokenIndexer(namespace='tokens'),
        # This is the elmo_characters
        'elmo_chars': ELMoTokenCharactersIndexer(namespace='elmo_characters'),
    }

    # Data Reader
    dev_fever_data_reader = VCSS_Reader(token_indexers=token_indexers,
                                        lazy=True,
                                        max_l=260)

    # Load Vocabulary
    biterator = BasicIterator(batch_size=batch_size)

    vocab, weight_dict = load_vocab_embeddings(config.DATA_ROOT /
                                               "vocab_cache" / "nli_basic")

    # Register the selection labels and move "hidden" to index -2.
    vocab.add_token_to_namespace('true', namespace='labels')
    vocab.add_token_to_namespace('false', namespace='labels')
    vocab.add_token_to_namespace("hidden", namespace="labels")
    vocab.change_token_with_index_to_namespace("hidden",
                                               -2,
                                               namespace='labels')

    print(vocab.get_token_to_index_vocabulary('labels'))
    print(vocab.get_vocab_size('tokens'))

    biterator.index_with(vocab)
    # Reader and prepare end

    # Build Model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu",
                          index=0)

    model = Model(rnn_size_in=(1024 + 300 + 1, 1024 + 450 + 1),
                  rnn_size_out=(450, 450),
                  weight=weight_dict['glove.840B.300d'],
                  vocab_size=vocab.get_vocab_size('tokens'),
                  mlp_d=900,
                  embedding_dim=300,
                  max_l=300,
                  num_of_class=num_of_class)

    print("Model Max length:", model.max_l)

    model.display()
    model.to(device)

    # The freshly built model is only a container for the saved weights, so
    # they are loaded into it directly; no second copy of the model is needed.
    load_ema_to_model(model, model_path)

    # Step 1: score dev sentences with the ss task.
    vc_ss.data_wrangler.assign_task_label(complete_upstream_dev_data, 'ss')
    dev_ss_instance = dev_fever_data_reader.read(complete_upstream_dev_data)
    eval_ss_iter = biterator(dev_ss_instance, num_epochs=1, shuffle=False)
    scored_dev_sent_data = hidden_eval_ss(model, eval_ss_iter,
                                          complete_upstream_dev_data)

    common.save_jsonl(scored_dev_sent_data, "dev_scored_sent_data.jsonl")

    # Step 2 (for vc): pick sentences from the scored data.
    filtered_dev_list = vc_ss.data_wrangler.sample_sentences_for_vc_with_nei(
        config.T_FEVER_DEV_JSONL, scored_dev_sent_data, dev_prob_threshold,
        dev_sample_top_k)
    common.save_jsonl(filtered_dev_list,
                      "dev_scored_sent_data_after_sample.jsonl")

    dev_selection_dict = paired_selection_score_dict(scored_dev_sent_data)
    ready_dev_list = select_sent_with_prob_for_eval(config.T_FEVER_DEV_JSONL,
                                                    filtered_dev_list,
                                                    dev_selection_dict,
                                                    tokenized=True)

    # Step 3: run the vc task on the selected sentences.
    vc_ss.data_wrangler.assign_task_label(ready_dev_list, 'vc')
    dev_vc_instance = dev_fever_data_reader.read(ready_dev_list)
    eval_vc_iter = biterator(dev_vc_instance, num_epochs=1, shuffle=False)
    eval_dev_result_list = hidden_eval_vc(model, eval_vc_iter, ready_dev_list)

    common.save_jsonl(eval_dev_result_list, "dev_nli_results.jsonl")

    # Scoring
    eval_mode = {'check_sent_id_correct': True, 'standard': True}
    strict_score, acc_score, pr, rec, f1 = c_scorer.fever_score(
        eval_dev_result_list,
        common.load_jsonl(config.T_FEVER_DEV_JSONL),
        mode=eval_mode,
        verbose=False)
    print("Fever Score(Strict/Acc./Precision/Recall/F1):", strict_score,
          acc_score, pr, rec, f1)

    print(f"Dev:{strict_score}/{acc_score}")