def data_loader_checker(para_file: str,
                        full_file: str,
                        example_file: str,
                        tokenizer,
                        data_source_type=None):
    """Run cached examples through a ``HotpotDataset`` DataLoader as a smoke test.

    Loads the paragraph-selection JSON, the raw JSON rows and the gzip-pickled
    ``Example`` list, and asserts that all three have the same length. It then
    iterates every batch produced by ``HotpotDataset.collate_fn`` and slices
    each item's answer span ``context_idxs[i][y1:y2]`` out of its input ids.

    Args:
        para_file: JSON file of selected paragraphs (only its length is checked).
        full_file: JSON file of raw rows (only its length is checked).
        example_file: gzip-compressed pickle holding a list of ``Example``.
        tokenizer: tokenizer exposing ``sep_token_id``.
        data_source_type: unused here; kept so all checkers share one signature.

    Raises:
        AssertionError: if the three inputs differ in length.
        KeyError: if a batch id has no matching example.
    """
    sel_para_data = json_loader(json_file_name=para_file)
    full_data = json_loader(json_file_name=full_file)
    # Close the gzip handle deterministically instead of leaking it.
    # NOTE: pickle can execute arbitrary code -- only load trusted example files.
    with gzip.open(example_file, 'rb') as example_fp:
        examples = pickle.load(example_fp)
    example_dict = {e.qas_id: e for e in examples}
    assert len(sel_para_data) == len(full_data) == len(examples)
    print('Number of examples = {}'.format(len(examples)))

    hotpotdata = HotpotDataset(examples=examples,
                               sep_token_id=tokenizer.sep_token_id,
                               sent_drop_ratio=0.25)
    dev_data_loader = DataLoader(dataset=hotpotdata,
                                 batch_size=8,
                                 shuffle=False,
                                 num_workers=5,
                                 collate_fn=HotpotDataset.collate_fn)

    # Give tqdm the loader itself (not enumerate(...)) so it can use
    # len(loader) and display a real progress total.
    for batch in tqdm(dev_data_loader):
        ids = batch['ids']
        input_ids = batch['context_idxs']
        y1 = batch['y1']
        y2 = batch['y2']
        for i in range(input_ids.shape[0]):
            inp_id_i = input_ids[i]
            y1_i = y1[i]
            y2_i = y2[i]
            orig_answer = example_dict[ids[i]].answer_text
            if y1_i > 0:
                ans_ids = inp_id_i[y1_i:y2_i]
                # NOTE(review): ans_ids / orig_answer are never compared or
                # reported; the actual answer check appears to be missing.
def case_to_feature_checker(para_file: str,
                            full_file: str,
                            example_file: str,
                            tokenizer,
                            data_source_type=None):
    """Check that every raw row has a cached ``Example``.

    Loads the paragraph-selection JSON, the raw JSON rows and the gzip-pickled
    ``Example`` list, and asserts that all three have the same length. It then
    prints the tokenizer's encoding of its separator token and looks up the
    example for every raw row.

    Args:
        para_file: JSON file of selected paragraphs (only its length is checked).
        full_file: JSON file of raw rows; each row's ``_id`` is read.
        example_file: gzip-compressed pickle holding a list of ``Example``.
        tokenizer: tokenizer exposing ``encode`` and ``sep_token``.
        data_source_type: if given, examples are looked up as
            ``<_id>_<data_source_type>``.

    Raises:
        AssertionError: if the three inputs differ in length.
        KeyError: if a row has no matching example.
    """
    sel_para_data = json_loader(json_file_name=para_file)
    full_data = json_loader(json_file_name=full_file)
    # Close the gzip handle deterministically instead of leaking it.
    # NOTE: pickle can execute arbitrary code -- only load trusted example files.
    with gzip.open(example_file, 'rb') as example_fp:
        examples = pickle.load(example_fp)
    example_dict = {e.qas_id: e for e in examples}
    assert len(sel_para_data) == len(full_data) == len(examples)
    print('Number of examples = {}'.format(len(examples)))
    sep_id = tokenizer.encode(tokenizer.sep_token)
    print(sep_id)

    for row in tqdm(full_data):
        key = row['_id']
        exam_key = key if data_source_type is None else key + '_' + data_source_type
        example_i: Example = example_dict[exam_key]
def hotpot_answer_neg_sents_tokenizer(split_para_file: str,
                                      full_file: str,
                                      tokenizer,
                                      cls_token='[CLS]',
                                      sep_token='[SEP]',
                                      is_roberta=False,
                                      data_source_type=None):
    """Tokenize raw rows into ``Example`` objects with answer spans and negative contexts.

    For every row of ``full_file`` this function does the following:
      * reads the row's entry in ``split_para_file`` (exactly 3 groups). The
        first element of each item in group 0 gives the selected paragraph
        titles; group 2 gives the negative paragraph titles.
      * tokenizes the question as ``cls question sep`` (``sep sep`` in
        RoBERTa mode).
      * tokenizes each selected paragraph sentence by sentence and records
        supporting-paragraph and supporting-sentence ids.
      * for non yes/no/noanswer answers, locates the answer's token ids inside
        each sentence flagged in ``answer_sent_flags`` and records
        ``(title, sent_idx, start, end)`` in ``answer_positions``.
      * tokenizes the negative paragraphs via ``negative_context_processing``.

    Args:
        split_para_file: JSON mapping row ``_id`` to a list of 3 paragraph
            groups (only groups 0 and 2 are read here).
        full_file: JSON file with the raw rows (``_id`` and ``type`` are read).
        tokenizer: tokenizer providing ``convert_tokens_to_ids`` and ``decode``.
        cls_token: token string placed before the question.
        sep_token: separator token string.
        is_roberta: if True, the query ends with two separators and every
            sentence gets a trailing separator.
        data_source_type: if given, the example id becomes
            ``<_id>_<data_source_type>``.

    Returns:
        A list of ``Example`` objects in the order of ``full_file``.

    Raises:
        AssertionError: if a split entry does not have 3 groups, or an answer
            expected in a flagged sentence cannot be located there.
    """
    split_para_rank_data = json_loader(json_file_name=split_para_file)
    full_data = json_loader(json_file_name=full_file)
    examples = []
    answer_not_found_count = 0
    for row in tqdm(full_data):
        key = row['_id']
        qas_type = row['type']
        sent_names = []
        sup_facts_sent_id = []
        para_names = []
        sup_para_id = []
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        split_paras = split_para_rank_data[key]
        assert len(split_paras) == 3
        # Group 0 holds the selected paragraphs; element [0] of each item is used as its title.
        selected_para_titles = [_[0] for _ in split_paras[0]]
        norm_question, norm_answer, selected_contexts, supporting_facts_filtered, yes_no_flag, answer_found_flag = \
            ranked_context_processing(row=row, tokenizer=tokenizer, selected_para_titles=selected_para_titles, is_roberta=is_roberta)
        # print(yes_no_flag, answer_found_flag)
        # Count span answers that could not be found in the selected context.
        if not answer_found_flag and not yes_no_flag:
            answer_not_found_count = answer_not_found_count + 1
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        query_tokens = [cls_token]
        query_words, query_sub_tokens = tokenize_text(text=norm_question,
                                                      tokenizer=tokenizer,
                                                      is_roberta=is_roberta)
        query_tokens += query_sub_tokens
        if is_roberta:
            query_tokens += [sep_token, sep_token]
        else:
            query_tokens += [sep_token]
        query_input_ids = tokenizer.convert_tokens_to_ids(query_tokens)
        assert len(query_tokens) == len(query_input_ids)
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # sent_id is a running sentence index across all selected paragraphs of this row.
        sent_to_id, sent_id = {}, 0
        ctx_token_list = []
        ctx_input_id_list = []
        sent_num = 0
        para_num = 0
        ctx_with_answer = False
        answer_positions = []  ## answer position
        # Stay empty for yes/no/noanswer answers; otherwise hold the answer
        # tokenization from the last paragraph processed.
        ans_sub_tokens = []
        ans_input_ids = []
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        for para_idx, para_tuple in enumerate(selected_contexts):
            para_num += 1
            title, sents, _, answer_sent_flags, supp_para_flag = para_tuple
            para_names.append(title)
            if supp_para_flag:
                sup_para_id.append(para_idx)
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            sent_tokens_list = []
            sent_input_id_list = []
            for local_sent_id, sent_text in enumerate(sents):
                sent_num += 1
                local_sent_name = (title, local_sent_id)
                sent_to_id[local_sent_name] = sent_id
                sent_names.append(local_sent_name)
                if local_sent_name in supporting_facts_filtered:
                    sup_facts_sent_id.append(sent_id)
                # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                sent_words, sent_sub_tokens = tokenize_text(
                    text=sent_text, tokenizer=tokenizer, is_roberta=is_roberta)
                if is_roberta:
                    sent_sub_tokens.append(sep_token)
                sub_input_ids = tokenizer.convert_tokens_to_ids(
                    sent_sub_tokens)
                assert len(sub_input_ids) == len(sent_sub_tokens)
                sent_tokens_list.append(sent_sub_tokens)
                sent_input_id_list.append(sub_input_ids)
                assert len(sent_sub_tokens) == len(sub_input_ids)
                sent_id = sent_id + 1
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            ctx_token_list.append(sent_tokens_list)
            ctx_input_id_list.append(sent_input_id_list)
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            # The answer is re-tokenized once per paragraph, so the stripped
            # fallback below only persists for the remaining sentences of this paragraph.
            if (norm_answer.strip() not in ['yes', 'no', 'noanswer'
                                            ]) and answer_found_flag:
                ans_words, ans_sub_tokens = tokenize_text(
                    text=norm_answer,
                    tokenizer=tokenizer,
                    is_roberta=is_roberta)
                ans_input_ids = tokenizer.convert_tokens_to_ids(ans_sub_tokens)
                for sup_sent_idx, supp_sent_flag in answer_sent_flags:
                    supp_sent_encode_ids = sent_input_id_list[sup_sent_idx]
                    if supp_sent_flag:
                        # A negative index means the answer ids were not found as a sub-list.
                        answer_start_idx = sub_list_match_idx(
                            target=ans_input_ids, source=supp_sent_encode_ids)
                        if answer_start_idx < 0:
                            # Retry with the whitespace-stripped answer text.
                            ans_words, ans_sub_tokens = tokenize_text(
                                text=norm_answer.strip(),
                                tokenizer=tokenizer,
                                is_roberta=is_roberta)
                            ans_input_ids = tokenizer.convert_tokens_to_ids(
                                ans_sub_tokens)
                            answer_start_idx = sub_list_match_idx(
                                target=ans_input_ids,
                                source=supp_sent_encode_ids)
                        answer_len = len(ans_input_ids)
                        assert answer_start_idx >= 0, "supp sent={} \n answer={} \n answer={} \n {} \n {}".format(
                            tokenizer.decode(supp_sent_encode_ids),
                            tokenizer.decode(ans_input_ids), norm_answer,
                            supp_sent_encode_ids, ans_sub_tokens)
                        ctx_with_answer = True
                        # answer_positions.append((para_idx, sup_sent_idx, answer_start_idx, answer_start_idx + answer_len))
                        # Positions are keyed by paragraph title and are
                        # token offsets within that single sentence.
                        answer_positions.append(
                            (title, sup_sent_idx, answer_start_idx,
                             answer_start_idx + answer_len))
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            assert len(para_names) == para_num
            assert len(sent_names) == sent_num
            assert len(ctx_token_list) == para_num and len(
                ctx_input_id_list) == para_num
        ###+++++++++++++++++++++++++++++++++++++++++++++++++++++++++diff the rankers
        # Suffix the id so examples built from different rankers/sources stay distinct.
        if data_source_type is not None:
            key = key + "_" + data_source_type
        ###+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # Group 2 of the split entry supplies the negative paragraph titles.
        negative_selected_titles = [_[0] for _ in split_paras[2]]
        neg_ctx_text, neg_ctx_tokens, neg_ctx_input_ids = negative_context_processing(
            row=row,
            tokenizer=tokenizer,
            is_roberta=is_roberta,
            sep_token=sep_token,
            selected_para_titles=negative_selected_titles)
        ###+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        example = Example(qas_id=key,
                          qas_type=qas_type,
                          ctx_text=selected_contexts,
                          ctx_tokens=ctx_token_list,
                          ctx_input_ids=ctx_input_id_list,
                          para_names=para_names,
                          sup_para_id=sup_para_id,
                          sent_names=sent_names,
                          para_num=para_num,
                          sent_num=sent_num,
                          sup_fact_id=sup_facts_sent_id,
                          question_text=norm_question,
                          question_tokens=query_tokens,
                          question_input_ids=query_input_ids,
                          answer_text=norm_answer,
                          answer_tokens=ans_sub_tokens,
                          answer_input_ids=ans_input_ids,
                          answer_positions=answer_positions,
                          ctx_with_answer=ctx_with_answer,
                          neg_ctx_text=neg_ctx_text,
                          neg_ctx_tokens=neg_ctx_tokens,
                          neg_ctx_input_ids=neg_ctx_input_ids)
        examples.append(example)
    print('Answer not found = {}'.format(answer_not_found_count))
    return examples
def trim_case_to_feature_checker(para_file: str,
                                 full_file: str,
                                 example_file: str,
                                 tokenizer,
                                 data_source_type=None,
                                 max_seq_len: int = 512):
    """Trace supporting paragraphs and answer spans through sentence drop and trimming.

    For each raw row:
      1. the cached ``Example`` is featurized with ``case_to_features``;
      2. ``example_sent_drop`` (ratio 0.25) is applied and the result is
         featurized again;
      3. the dropped features are trimmed to ``max_seq_len`` tokens with
         ``trim_input_span``.
    After each stage it prints the sentence count and the supporting
    paragraphs still inside the features. After trimming it prints the
    decoded answer spans. A summary of the over-length counts is printed at
    the end.

    Args:
        para_file: JSON file of selected paragraphs (only its length is checked).
        full_file: JSON file of raw rows (``_id`` and ``answer`` are read).
        example_file: gzip-compressed pickle holding a list of ``Example``.
        tokenizer: tokenizer exposing ``encode``, ``decode``, ``sep_token``
            and ``sep_token_id``.
        data_source_type: if given, examples are looked up as
            ``<_id>_<data_source_type>``.
        max_seq_len: token limit used for the over-length counts and for
            trimming (defaults to the previously hard-coded 512).

    Raises:
        AssertionError: if the three inputs differ in length.
    """
    sel_para_data = json_loader(json_file_name=para_file)
    full_data = json_loader(json_file_name=full_file)
    # Close the gzip handle deterministically instead of leaking it.
    # NOTE: pickle can execute arbitrary code -- only load trusted example files.
    with gzip.open(example_file, 'rb') as example_fp:
        examples = pickle.load(example_fp)
    example_dict = {e.qas_id: e for e in examples}
    assert len(sel_para_data) == len(full_data) == len(examples)
    print('Number of examples = {}'.format(len(examples)))
    sep_id = tokenizer.encode(tokenizer.sep_token)
    print(sep_id)

    over_len_count = 0
    drop_over_len_count = 0
    trim_over_len_count = 0
    max_sent_num = 0

    for row in tqdm(full_data):
        key = row['_id']
        exam_key = key if data_source_type is None else key + '_' + data_source_type
        example_i: Example = example_dict[exam_key]

        # Stage 1: the original example.
        doc_input_ids, query_spans, para_spans, sent_spans, ans_spans, ans_type_label = \
            case_to_features(case=example_i, train_dev=True)
        max_sent_num = max(max_sent_num, len(sent_spans))
        print('before drop sent {}'.format(len(sent_spans)))
        for supp_para_id in example_i.sup_para_id:
            if supp_para_id < len(para_spans):
                print('before drop', example_i.para_names[supp_para_id])
        if len(doc_input_ids) > max_seq_len:
            over_len_count += 1

        # Stage 2: after random sentence drop.
        drop_example_i = example_sent_drop(case=example_i, drop_ratio=0.25)
        doc_input_ids, query_spans, para_spans, sent_spans, ans_spans, ans_type_label = \
            case_to_features(case=drop_example_i, train_dev=True)
        print('after drop sent {}'.format(len(sent_spans)))
        for supp_para_id in drop_example_i.sup_para_id:
            if supp_para_id < len(para_spans):
                print('after drop', drop_example_i.para_names[supp_para_id])
        if len(doc_input_ids) > max_seq_len:
            drop_over_len_count += 1

        # Stage 3: after trimming the dropped features to the length limit.
        trim_doc_input_ids, trim_query_spans, trim_para_spans, trim_sent_spans, trim_ans_spans = trim_input_span(
            doc_input_ids,
            query_spans,
            para_spans,
            sent_spans,
            limit=max_seq_len,
            sep_token_id=tokenizer.sep_token_id,
            ans_spans=ans_spans)
        print('after trim {}'.format(len(trim_sent_spans)))
        if len(trim_doc_input_ids) > max_seq_len:
            # Previously ``+= 0``, so this counter could never move.
            trim_over_len_count += 1
        for supp_para_id in drop_example_i.sup_para_id:
            if supp_para_id < len(trim_para_spans):
                print('after trim', drop_example_i.para_names[supp_para_id])
        print('-' * 75)

        orig_answer = row['answer']
        exm_answer = example_i.answer_text
        for ans_idx, ans_span in enumerate(trim_ans_spans):
            decoded_ans = tokenizer.decode(trim_doc_input_ids[ans_span[0]:ans_span[1]])
            print('{} Orig\t{}\t{}\t{}\t{}'.format(ans_idx, orig_answer,
                                                   exm_answer, decoded_ans,
                                                   ans_type_label[0]))
        print('*' * 75)

    # Report the counters that were previously computed but never shown.
    print('Larger than {} count = {}'.format(max_seq_len, over_len_count))
    print('Larger than {} count after drop = {}'.format(max_seq_len, drop_over_len_count))
    print('Larger than {} count after trim = {}'.format(max_seq_len, trim_over_len_count))
    print('Max sent num = {}'.format(max_sent_num))
def sent_drop_case_to_feature_checker(para_file: str,
                                      full_file: str,
                                      example_file: str,
                                      tokenizer,
                                      data_source_type=None):
    """Collect length and answer-span statistics before and after sentence drop.

    For each raw row, the cached ``Example`` is featurized with
    ``case_to_features``, then ``example_sent_drop`` (ratio 1.0) is applied and
    the result is featurized again. The function prints:
      * the sum of answer-span counts after drop;
      * how many feature sequences exceed 512 tokens before and after drop;
      * the max query length (also printed each time a new maximum is seen);
      * the 99th and 97.5th percentiles of the query lengths.

    Args:
        para_file: JSON file of selected paragraphs (only its length is checked).
        full_file: JSON file of raw rows (``_id`` is read).
        example_file: gzip-compressed pickle holding a list of ``Example``.
        tokenizer: tokenizer exposing ``encode`` and ``sep_token``.
        data_source_type: if given, examples are looked up as
            ``<_id>_<data_source_type>``.

    Raises:
        AssertionError: if the three inputs differ in length.
    """
    sel_para_data = json_loader(json_file_name=para_file)
    full_data = json_loader(json_file_name=full_file)
    # Close the gzip handle deterministically instead of leaking it.
    # NOTE: pickle can execute arbitrary code -- only load trusted example files.
    with gzip.open(example_file, 'rb') as example_fp:
        examples = pickle.load(example_fp)
    example_dict = {e.qas_id: e for e in examples}
    assert len(sel_para_data) == len(full_data) == len(examples)
    print('Number of examples = {}'.format(len(examples)))
    sep_id = tokenizer.encode(tokenizer.sep_token)
    print(sep_id)

    ans_count_list = []
    # NOTE(review): nothing in this function updates these two counters any
    # more, so they are always reported as 0.
    one_supp_sent = 0
    miss_supp_count = 0
    larger_512 = 0
    drop_larger_512 = 0
    max_query_len = 0
    query_len_list = []
    for row in tqdm(full_data):
        key = row['_id']
        exam_key = key if data_source_type is None else key + '_' + data_source_type
        example_i: Example = example_dict[exam_key]
        doc_input_ids, query_spans, _, _, _, _ = \
            case_to_features(case=example_i, train_dev=True)
        if len(doc_input_ids) > 512:
            larger_512 += 1
        drop_example_i = example_sent_drop(case=example_i, drop_ratio=1.0)

        # Record each query length exactly once. Previously a new maximum was
        # appended a second time, which biased the percentiles upward.
        query_len = query_spans[0][1]
        query_len_list.append(query_len)
        if max_query_len < query_len:
            max_query_len = query_len
            print(max_query_len)

        drop_doc_input_ids, _, _, _, drop_ans_spans, _ = \
            case_to_features(case=drop_example_i, train_dev=True)
        if len(drop_doc_input_ids) > 512:
            drop_larger_512 += 1
        ans_count_list.append(len(drop_ans_spans))

    print('Sum of ans count = {}'.format(sum(ans_count_list)))
    print('One support sent count = {}'.format(one_supp_sent))
    print('Miss support sent count = {}'.format(miss_supp_count))
    print('Larger than 512 count = {}'.format(larger_512))
    print('Larger than 512 count after drop = {}'.format(drop_larger_512))
    print('Max query len = {}'.format(max_query_len))
    # np.percentile fails on an empty array, so skip the percentiles when
    # there were no rows.
    if query_len_list:
        query_len_array = np.array(query_len_list)
        print('99 = {}'.format(np.percentile(query_len_array, 99)))
        print('97.5 = {}'.format(np.percentile(query_len_array, 97.5)))
def consist_checker(para_file: str,
                    full_file: str,
                    example_file: str,
                    tokenizer,
                    data_source_type=None):
    """Compare cached ``Example`` objects against the raw rows they were built from.

    For each raw row, this function:
      * decodes the example's recorded answer positions;
      * prints rows whose example answer is ``'noanswer'``;
      * counts rows whose context lacks the answer (ignoring yes/no answers);
      * prints every sentence three ways for manual inspection: stored
        tokens, decoded input ids, and raw text.
    The no-answer count is printed at the end.

    Args:
        para_file: JSON file of selected paragraphs (only its length is checked).
        full_file: JSON file of raw rows (``_id``, ``context``, ``answer`` are read).
        example_file: gzip-compressed pickle holding a list of ``Example``.
        tokenizer: tokenizer exposing ``decode``.
        data_source_type: if given, examples are looked up as
            ``<_id>_<data_source_type>``.

    Raises:
        AssertionError: if the three inputs differ in length.
    """
    sel_para_data = json_loader(json_file_name=para_file)
    full_data = json_loader(json_file_name=full_file)
    # Close the gzip handle deterministically instead of leaking it.
    # NOTE: pickle can execute arbitrary code -- only load trusted example files.
    with gzip.open(example_file, 'rb') as example_fp:
        examples = pickle.load(example_fp)
    example_dict = {e.qas_id: e for e in examples}
    assert len(sel_para_data) == len(full_data) == len(examples)
    print('Number of examples = {}'.format(len(examples)))
    no_answer_count = 0
    for row in tqdm(full_data):
        key = row['_id']
        exam_key = key if data_source_type is None else key + '_' + data_source_type
        raw_context = row['context']
        raw_answer = row['answer']
        example_i: Example = example_dict[exam_key]
        exm_answer = example_i.answer_text
        exm_ctx_token_list = example_i.ctx_tokens
        exm_ctx_input_ids = example_i.ctx_input_ids
        para_names = example_i.para_names

        # answer_positions store the paragraph title; map it back to its index.
        para_name_to_idx = {name: idx for idx, name in enumerate(para_names)}
        encode_answer = ''
        for para_title, sent_i, start_i, end_i in example_i.answer_positions:
            sent_ids = exm_ctx_input_ids[para_name_to_idx[para_title]][sent_i]
            # Only the last position's decoding is kept for the report below.
            encode_answer = tokenizer.decode(sent_ids[start_i:end_i])
        if exm_answer == 'noanswer':
            print('{}\t{}\tencode:{}\t{}'.format(raw_answer, exm_answer,
                                                 encode_answer,
                                                 example_i.ctx_with_answer))
        if not example_i.ctx_with_answer and raw_answer not in ['yes', 'no']:
            no_answer_count += 1

        ctx_dict = dict(raw_context)
        contex_text = [ctx_dict[para_name] for para_name in para_names]
        for para_idx, ctx_token_list in enumerate(exm_ctx_token_list):
            ctx_inp_id_list = exm_ctx_input_ids[para_idx]
            orig_context = contex_text[para_idx]
            for sent_idx, sent_inp_ids in enumerate(ctx_inp_id_list):
                print(ctx_token_list[sent_idx])
                print(tokenizer.decode(sent_inp_ids))
                print(orig_context[sent_idx])
        print('*' * 75)

    print(no_answer_count)
# Example no. 7
# 0
def trim_case_to_feature_checker(para_rank_file: str,
                    full_file: str,
                    example_file: str,
                    tokenizer,
                    data_source_type=None):
    """Trace supporting facts and answer spans through sentence replacement and trimming.

    For each raw row:
      1. the cached ``Example`` is featurized with ``case_to_features``, and
         its answer spans and supporting paragraphs/sentences are printed;
      2. ``example_sent_replacement`` (ratio 0.25) is applied;
      3. the replaced example is featurized again and trimmed to 512 tokens
         with ``trim_input_span``;
      4. the surviving supporting sentences and decoded answer spans are printed.

    At the end it prints a ``Counter`` of how many supporting sentences survive
    per row. It also prints ``miss_count``: the number of rows where fewer
    supporting sentences survive replacement and trimming than the original
    example had.

    NOTE(review): if this shares a module with another function of the same
    name, the later definition wins -- confirm which one callers expect.

    Args:
        para_rank_file: JSON file; only its length is checked against the others.
        full_file: JSON file with the raw rows (``_id`` is read).
        example_file: gzip-compressed pickle of ``Example`` objects.
        tokenizer: tokenizer exposing ``encode``, ``decode``, ``sep_token``
            and ``sep_token_id``.
        data_source_type: if given, examples are looked up as
            ``<_id>_<data_source_type>``.

    Raises:
        AssertionError: if the three inputs differ in length, or if
            replacement changes the number of sentence names.
    """
    para_rank_data = json_loader(json_file_name=para_rank_file)
    full_data = json_loader(json_file_name=full_file)
    # NOTE(review): the gzip handle is never explicitly closed; consider
    # `with gzip.open(...) as f`. pickle must only load trusted files.
    examples = pickle.load(gzip.open(example_file, 'rb'))
    example_dict = {e.qas_id: e for e in examples}
    assert len(para_rank_data) == len(full_data) and len(full_data) == len(examples)
    print('Number of examples = {}'.format(len(examples)))
    # NOTE(review): of the counters below, only larger_512, max_sent_num,
    # supp_sent_num_list and miss_count are updated, and only the last two are
    # reported; the rest (and sep_id) are unused leftovers.
    no_answer_count = 0
    trim_no_answer_count = 0
    sep_id = tokenizer.encode(tokenizer.sep_token)
    # print(sep_id)
    ans_count_list = []
    trim_ans_count_list = []
    one_supp_sent = 0
    miss_supp_count = 0
    larger_512 = 0
    drop_larger_512 = 0
    trim_larger_512 = 0
    max_query_len = 0
    query_len_list = []
    max_sent_num = 0
    supp_sent_num_list = []

    miss_count = 0
    for row in tqdm(full_data):
        key = row['_id']
        if data_source_type is not None:
            exam_key = key + '_' + data_source_type
        else:
            exam_key = key
        example_i: Example = example_dict[exam_key]
        print('before replace in examples, sent num = {}'.format(len(example_i.sent_names)))
        doc_input_ids, query_spans, para_spans, sent_spans, ans_spans, ans_type_label = \
            case_to_features(case=example_i, train_dev=True)
        ####++++++
        print('before replacement ans = {}'.format(ans_spans))
        for idx, trim_ans_span in enumerate(ans_spans):
            decoded_ans = tokenizer.decode(doc_input_ids[trim_ans_span[0]:trim_ans_span[1]])
            print('{}\txxx{}xxx'.format(idx+1, decoded_ans))
        ####++++++
        supp_para_ids = example_i.sup_para_id
        supp_sent_ids = example_i.sup_fact_id
        # Baseline number of supporting sentences, compared against after trimming.
        supp_sent_num_i = len(supp_sent_ids)
        if len(sent_spans) > max_sent_num:
            max_sent_num = len(sent_spans)
        print('before replace sent {}'.format(len(sent_spans)))
        for supp_para_id in supp_para_ids:
            if supp_para_id < len(para_spans):
                print('before replace, supp para', example_i.para_names[supp_para_id])
        for supp_sent_id in supp_sent_ids:
            if supp_sent_id < len(example_i.sent_names):
                print('before replace, supp sent', example_i.sent_names[supp_sent_id])
        if len(doc_input_ids) > 512:
            larger_512 += 1
        #+++++++++++++++++++++++++++++++
        # Replacement must keep the number of sentence names unchanged.
        replace_example_i, replace_sent_ids_i = example_sent_replacement(case=example_i, replace_ratio=0.25)
        assert len(example_i.sent_names) == len(replace_example_i.sent_names)
        print('replacement ids', replace_sent_ids_i)
        supp_para_ids = replace_example_i.sup_para_id
        supp_sent_ids = replace_example_i.sup_fact_id
        for supp_para_id in supp_para_ids:
            if supp_para_id < len(replace_example_i.para_names):
                print('after replace, supp para', replace_example_i.para_names[supp_para_id])
        for supp_sent_id in supp_sent_ids:
            if supp_sent_id < len(replace_example_i.sent_names):
                print('after replace, supp sent', replace_example_i.sent_names[supp_sent_id])

        print('after replace in examples, sent num = {}'.format(len(replace_example_i.sent_names)))
        doc_input_ids, query_spans, para_spans, sent_spans, ans_spans, ans_type_label = \
            case_to_features(case=replace_example_i, train_dev=True)
        trim_doc_input_ids, trim_query_spans, trim_para_spans, trim_sent_spans, trim_ans_spans = trim_input_span(doc_input_ids, query_spans, para_spans, sent_spans,
                                                                                        limit=512, sep_token_id=tokenizer.sep_token_id, ans_spans=ans_spans)
        # Keep only the replaced example's supporting sentences whose index is
        # still within the trimmed sentence spans.
        supp_sent_ids = [x for x in supp_sent_ids if x < len(trim_sent_spans)]
        if len(supp_sent_ids) < supp_sent_num_i:
            miss_count = miss_count + 1
        supp_sent_num_list.append(len(supp_sent_ids))
        print('after trim replace sent {}'.format(len(trim_sent_spans)))
        for supp_sent_id in supp_sent_ids:
            if supp_sent_id < len(trim_sent_spans):
                print('after replace/trim, supp sent', replace_example_i.sent_names[supp_sent_id])
        #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        print('after trim replacement ans = {}'.format(trim_ans_spans))
        print('after trim = {}'.format(len(trim_doc_input_ids)))
        for idx, trim_ans_span in enumerate(trim_ans_spans):
            decoded_ans = tokenizer.decode(trim_doc_input_ids[trim_ans_span[0]:trim_ans_span[1]])
            print('{}\txxx{}xxx'.format(idx+1, decoded_ans))
        print('*' * 75)


    print(Counter(supp_sent_num_list))
    print(miss_count)