def data_loader_checker(para_file: str, full_file: str, example_file: str, tokenizer, data_source_type=None):
    sel_para_data = json_loader(json_file_name=para_file)
    full_data = json_loader(json_file_name=full_file)
    examples = pickle.load(gzip.open(example_file, 'rb'))
    example_dict = {e.qas_id: e for e in examples}
    assert len(sel_para_data) == len(full_data) and len(full_data) == len(examples)
    print('Number of examples = {}'.format(len(examples)))
    # hotpotdata = HotpotTestDataset(examples=examples, sep_token_id=tokenizer.sep_token_id)
    # dev_data_loader = DataLoader(dataset=hotpotdata, batch_size=8,
    #                              shuffle=False,
    #                              num_workers=5,
    #                              collate_fn=HotpotTestDataset.collate_fn)
    hotpotdata = HotpotDataset(examples=examples, sep_token_id=tokenizer.sep_token_id, sent_drop_ratio=0.25)
    dev_data_loader = DataLoader(dataset=hotpotdata,
                                 batch_size=8,
                                 shuffle=False,
                                 num_workers=5,
                                 collate_fn=HotpotDataset.collate_fn)
    for batch_idx, batch in tqdm(enumerate(dev_data_loader)):
        # print(batch_idx)
        x = batch_idx
        ids = batch['ids']
        input_ids = batch['context_idxs']
        y1 = batch['y1']
        y2 = batch['y2']
        # print(batch['q_type'])
        batch_size = input_ids.shape[0]
        # print(y1.shape)
        # print(y2.shape)
        for i in range(batch_size):
            inp_id_i = input_ids[i]
            y1_i = y1[i]
            y2_i = y2[i]
            # print(y1_i, y2_i)
            orig_answer = example_dict[ids[i]].answer_text
            if y1_i > 0:
                ans_ids = inp_id_i[y1_i:y2_i]
                # Sanity check (completion of the truncated original): the gold span
                # should decode back to something close to the original answer text.
                decoded_ans = tokenizer.decode(ans_ids)
                print('{}\tdecoded span: {}\torig answer: {}'.format(ids[i], decoded_ans, orig_answer))
def case_to_feature_checker(para_file: str, full_file: str, example_file: str, tokenizer, data_source_type=None):
    sel_para_data = json_loader(json_file_name=para_file)
    full_data = json_loader(json_file_name=full_file)
    examples = pickle.load(gzip.open(example_file, 'rb'))
    example_dict = {e.qas_id: e for e in examples}
    assert len(sel_para_data) == len(full_data) and len(full_data) == len(examples)
    print('Number of examples = {}'.format(len(examples)))
    no_answer_count = 0
    sep_id = tokenizer.encode(tokenizer.sep_token)
    print(sep_id)
    ans_count_list = []
    for row in tqdm(full_data):
        key = row['_id']
        if data_source_type is not None:
            exam_key = key + '_' + data_source_type
        else:
            exam_key = key
        example_i: Example = example_dict[exam_key]
def hotpot_answer_neg_sents_tokenizer(split_para_file: str, full_file: str, tokenizer, cls_token='[CLS]',
                                      sep_token='[SEP]', is_roberta=False, data_source_type=None):
    split_para_rank_data = json_loader(json_file_name=split_para_file)
    full_data = json_loader(json_file_name=full_file)
    examples = []
    answer_not_found_count = 0
    for row in tqdm(full_data):
        key = row['_id']
        qas_type = row['type']
        sent_names = []
        sup_facts_sent_id = []
        para_names = []
        sup_para_id = []
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        split_paras = split_para_rank_data[key]
        assert len(split_paras) == 3
        selected_para_titles = [_[0] for _ in split_paras[0]]
        norm_question, norm_answer, selected_contexts, supporting_facts_filtered, yes_no_flag, answer_found_flag = \
            ranked_context_processing(row=row, tokenizer=tokenizer,
                                      selected_para_titles=selected_para_titles, is_roberta=is_roberta)
        # print(yes_no_flag, answer_found_flag)
        if not answer_found_flag and not yes_no_flag:
            answer_not_found_count = answer_not_found_count + 1
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        query_tokens = [cls_token]
        query_words, query_sub_tokens = tokenize_text(text=norm_question, tokenizer=tokenizer, is_roberta=is_roberta)
        query_tokens += query_sub_tokens
        if is_roberta:
            query_tokens += [sep_token, sep_token]
        else:
            query_tokens += [sep_token]
        query_input_ids = tokenizer.convert_tokens_to_ids(query_tokens)
        assert len(query_tokens) == len(query_input_ids)
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        sent_to_id, sent_id = {}, 0
        ctx_token_list = []
        ctx_input_id_list = []
        sent_num = 0
        para_num = 0
        ctx_with_answer = False
        answer_positions = []  ## answer position
        ans_sub_tokens = []
        ans_input_ids = []
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        for para_idx, para_tuple in enumerate(selected_contexts):
            para_num += 1
            title, sents, _, answer_sent_flags, supp_para_flag = para_tuple
            para_names.append(title)
            if supp_para_flag:
                sup_para_id.append(para_idx)
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            sent_tokens_list = []
            sent_input_id_list = []
            for local_sent_id, sent_text in enumerate(sents):
                sent_num += 1
                local_sent_name = (title, local_sent_id)
                sent_to_id[local_sent_name] = sent_id
                sent_names.append(local_sent_name)
                if local_sent_name in supporting_facts_filtered:
                    sup_facts_sent_id.append(sent_id)
                # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
                sent_words, sent_sub_tokens = tokenize_text(text=sent_text, tokenizer=tokenizer, is_roberta=is_roberta)
                if is_roberta:
                    sent_sub_tokens.append(sep_token)
                sub_input_ids = tokenizer.convert_tokens_to_ids(sent_sub_tokens)
                assert len(sub_input_ids) == len(sent_sub_tokens)
                sent_tokens_list.append(sent_sub_tokens)
                sent_input_id_list.append(sub_input_ids)
                assert len(sent_sub_tokens) == len(sub_input_ids)
                sent_id = sent_id + 1
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            ctx_token_list.append(sent_tokens_list)
            ctx_input_id_list.append(sent_input_id_list)
            # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
            if (norm_answer.strip() not in ['yes', 'no', 'noanswer']) and answer_found_flag:
                ans_words, ans_sub_tokens = tokenize_text(text=norm_answer, tokenizer=tokenizer, is_roberta=is_roberta)
                ans_input_ids = tokenizer.convert_tokens_to_ids(ans_sub_tokens)
                for sup_sent_idx, supp_sent_flag in answer_sent_flags:
                    supp_sent_encode_ids = sent_input_id_list[sup_sent_idx]
                    if supp_sent_flag:
                        answer_start_idx = sub_list_match_idx(target=ans_input_ids, source=supp_sent_encode_ids)
                        if answer_start_idx < 0:
                            ans_words, ans_sub_tokens = tokenize_text(text=norm_answer.strip(),
                                                                      tokenizer=tokenizer, is_roberta=is_roberta)
                            ans_input_ids = tokenizer.convert_tokens_to_ids(ans_sub_tokens)
                            answer_start_idx = sub_list_match_idx(target=ans_input_ids, source=supp_sent_encode_ids)
                        answer_len = len(ans_input_ids)
                        assert answer_start_idx >= 0, "supp sent={} \n answer={} \n answer={} \n {} \n {}".format(
                            tokenizer.decode(supp_sent_encode_ids), tokenizer.decode(ans_input_ids),
                            norm_answer, supp_sent_encode_ids, ans_sub_tokens)
                        ctx_with_answer = True
                        # answer_positions.append((para_idx, sup_sent_idx, answer_start_idx, answer_start_idx + answer_len))
                        answer_positions.append((title, sup_sent_idx, answer_start_idx, answer_start_idx + answer_len))
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        assert len(para_names) == para_num
        assert len(sent_names) == sent_num
        assert len(ctx_token_list) == para_num and len(ctx_input_id_list) == para_num
        ###+++++++++++++++++++++++++++++++++++++++++++++++++++++++++ diff the rankers
        if data_source_type is not None:
            key = key + "_" + data_source_type
        ###+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        negative_selected_titles = [_[0] for _ in split_paras[2]]
        neg_ctx_text, neg_ctx_tokens, neg_ctx_input_ids = negative_context_processing(
            row=row, tokenizer=tokenizer, is_roberta=is_roberta, sep_token=sep_token,
            selected_para_titles=negative_selected_titles)
        ###+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        example = Example(qas_id=key,
                          qas_type=qas_type,
                          ctx_text=selected_contexts,
                          ctx_tokens=ctx_token_list,
                          ctx_input_ids=ctx_input_id_list,
                          para_names=para_names,
                          sup_para_id=sup_para_id,
                          sent_names=sent_names,
                          para_num=para_num,
                          sent_num=sent_num,
                          sup_fact_id=sup_facts_sent_id,
                          question_text=norm_question,
                          question_tokens=query_tokens,
                          question_input_ids=query_input_ids,
                          answer_text=norm_answer,
                          answer_tokens=ans_sub_tokens,
                          answer_input_ids=ans_input_ids,
                          answer_positions=answer_positions,
                          ctx_with_answer=ctx_with_answer,
                          neg_ctx_text=neg_ctx_text,
                          neg_ctx_tokens=neg_ctx_tokens,
                          neg_ctx_input_ids=neg_ctx_input_ids)
        examples.append(example)
    print('Answer not found = {}'.format(answer_not_found_count))
    return examples
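# The answer-span search above relies on `sub_list_match_idx`, which is defined elsewhere in
# this repo. The sketch below only illustrates the behaviour assumed by the calls above: return
# the start index of the first occurrence of `target` as a contiguous sub-list of `source`, or
# -1 when no match exists. The name `_sub_list_match_idx_sketch` is illustrative, not the repo's.
def _sub_list_match_idx_sketch(target, source):
    """Return the first index where `target` occurs as a contiguous slice of `source`, else -1."""
    if len(target) == 0 or len(target) > len(source):
        return -1
    for start in range(len(source) - len(target) + 1):
        if source[start:start + len(target)] == target:
            return start
    return -1

# Example: matching tokenized answer ids against a supporting-sentence encoding.
# _sub_list_match_idx_sketch([5, 6], [1, 5, 6, 2]) -> 1
# _sub_list_match_idx_sketch([7], [1, 5, 6, 2])    -> -1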
def trim_case_to_feature_checker(para_file: str, full_file: str, example_file: str, tokenizer, data_source_type=None):
    sel_para_data = json_loader(json_file_name=para_file)
    full_data = json_loader(json_file_name=full_file)
    examples = pickle.load(gzip.open(example_file, 'rb'))
    example_dict = {e.qas_id: e for e in examples}
    assert len(sel_para_data) == len(full_data) and len(full_data) == len(examples)
    print('Number of examples = {}'.format(len(examples)))
    no_answer_count = 0
    trim_no_answer_count = 0
    sep_id = tokenizer.encode(tokenizer.sep_token)
    print(sep_id)
    ans_count_list = []
    trim_ans_count_list = []
    one_supp_sent = 0
    miss_supp_count = 0
    larger_512 = 0
    drop_larger_512 = 0
    trim_larger_512 = 0
    max_query_len = 0
    query_len_list = []
    max_sent_num = 0
    for row in tqdm(full_data):
        key = row['_id']
        if data_source_type is not None:
            exam_key = key + '_' + data_source_type
        else:
            exam_key = key
        example_i: Example = example_dict[exam_key]
        doc_input_ids, query_spans, para_spans, sent_spans, ans_spans, ans_type_label = \
            case_to_features(case=example_i, train_dev=True)
        supp_para_ids = example_i.sup_para_id
        if len(sent_spans) > max_sent_num:
            max_sent_num = len(sent_spans)
        # trim_doc_input_ids, trim_query_spans, trim_para_spans, trim_sent_spans, trim_ans_spans = trim_input_span(
        #     doc_input_ids, query_spans, para_spans, sent_spans,
        #     limit=512, sep_token_id=tokenizer.sep_token_id, ans_spans=ans_spans)
        # print('before drop sent {}\n{}'.format(len(sent_spans), sent_spans))
        print('before drop sent {}'.format(len(sent_spans)))
        for supp_para_id in supp_para_ids:
            if supp_para_id < len(para_spans):
                print('before drop', example_i.para_names[supp_para_id])
        # # print(len(doc_input_ids))
        # # print('orig', doc_input_ids)
        # # print(len(sent_spans))
        if len(doc_input_ids) > 512:
            larger_512 += 1
        # # print('orig', example_i.ctx_input_ids)
        drop_example_i = example_sent_drop(case=example_i, drop_ratio=0.25)
        # # print('drop', drop_example_i.ctx_input_ids)
        # query_len_list.append(query_spans[0][1])
        # if max_query_len < query_spans[0][1]:
        #     max_query_len = query_spans[0][1]
        # query_len_list.append(query_spans[0][1])
        # print(max_query_len)
        # print('orig q ids {}'.format(example_i.question_input_ids))
        # print('drop q ids {}'.format(drop_example_i.question_input_ids))
        # supp_para_names = list(set([x[0] for x in row['supporting_facts']]))
        # exam_para_names = [example_i.para_names[x] for x in example_i.sup_para_id]
        # drop_exam_para_names = [drop_example_i.para_names[x] for x in drop_example_i.sup_para_id]
        # print('drop', example_i.para_names)
        # print(drop_example_i.para_names)
        # print('orig {}'.format(supp_para_names))
        # print('exam {}'.format(exam_para_names))
        # print('drop exam {}'.format(drop_exam_para_names))
        # # # print(example_i.sent_num, drop_example_i.sent_num)
        # orig_supp_count = len(row['supporting_facts'])
        # if drop_example_i.sent_num < orig_supp_count:
        #     miss_supp_count += 1
        # if drop_example_i.sent_num < 2:
        #     one_supp_sent += 1
        # sel_para_names = sel_para_data[key]
        # print('selected para names: ', sel_para_names)
        # print('example para names: ', example_i.para_names)
        doc_input_ids, query_spans, para_spans, sent_spans, ans_spans, ans_type_label = \
            case_to_features(case=drop_example_i, train_dev=True)
        # print('after drop sent {}\n{}'.format(len(sent_spans), sent_spans))
        print('after drop sent {}'.format(len(sent_spans)))
        supp_para_ids = drop_example_i.sup_para_id
        for supp_para_id in supp_para_ids:
            if supp_para_id < len(para_spans):
                print('after drop', drop_example_i.para_names[supp_para_id])
        # print(len(drop_doc_input_ids))
        # # print('drop', drop_doc_input_ids)
        # print(len(drop_sent_spans))
        if len(doc_input_ids) > 512:
            drop_larger_512 += 1
        # # print(type(doc_input_ids), type(query_spans), type(para_spans), type(sent_spans), type(ans_spans))
        # # orig_query = row['question']
        # # query_input_ids = doc_input_ids[query_spans[0][0]:query_spans[0][1]]
        # # decoded_query = tokenizer.decode(query_input_ids)
        # # # print('Orig query = {}'.format(orig_query))
        # # # print('Decoded query = {}'.format(decoded_query))
        # # # print('para number {}'.format(len(para_spans)))
        # # # print('sent number {}'.format(len(sent_spans)))
        # # # print('ans_spans number {}'.format(len(ans_spans)))
        # # orig_answer = row['answer']
        # # exm_answer = example_i.answer_text
        # # print('{}\t{}'.format(exm_answer, ans_type_label))
        # #
        # # assert len(example_i.sup_para_id) == len(drop_example_i.sup_para_id)
        # # assert len(example_i.sup_fact_id) == len(drop_example_i.sup_fact_id)
        # # ##+++++++
        # all_sents = []
        # ctx_dict = dict(row['context'])
        # contex_text = []
        # for para_name in example_i.para_names:
        #     contex_text.append(ctx_dict[para_name])
        #     all_sents += ctx_dict[para_name]
        # ans_count_list.append(len(ans_spans))
        # if ans_type_label == 2 and len(ans_spans) == 0:
        #     no_answer_count = no_answer_count + 1
        # # print('orig ans {}'.format(ans_spans))
        trim_doc_input_ids, trim_query_spans, trim_para_spans, trim_sent_spans, trim_ans_spans = trim_input_span(
            doc_input_ids, query_spans, para_spans, sent_spans,
            limit=512, sep_token_id=tokenizer.sep_token_id, ans_spans=ans_spans)
        # print('after trim {}\n{}'.format(len(trim_sent_spans), trim_sent_spans))
        print('after trim {}'.format(len(trim_sent_spans)))
        if len(trim_doc_input_ids) > 512:
            trim_larger_512 += 1
        supp_para_ids = drop_example_i.sup_para_id
        for supp_para_id in supp_para_ids:
            if supp_para_id < len(trim_para_spans):
                print('after trim', drop_example_i.para_names[supp_para_id])
        # print('trim ans {}'.format(ans_spans))
        # print('*' * 75)
        # trim_ans_count_list.append(len(trim_ans_spans))
        # if ans_type_label == 2 and len(trim_ans_spans) == 0:
        #     trim_no_answer_count = trim_no_answer_count + 1
        # for s_idx, sent_span in enumerate(sent_spans):
        #     sent_inp_ids = doc_input_ids[sent_span[0]:sent_span[1]]
        #     # print(sent_inp_ids)
        #     decoded_sent = tokenizer.decode(sent_inp_ids)
        #     print('{} orig sent: {}'.format(s_idx, all_sents[s_idx]))
        #     print('{} deco sent: {}'.format(s_idx, decoded_sent))
        #     print('$' * 10)
        print('-' * 75)
        orig_answer = row['answer']
        exm_answer = example_i.answer_text
        for ans_idx, ans_span in enumerate(trim_ans_spans):
            # print(ans_span)
            # print(len(doc_input_ids))
            # if ans_span[0] < 0 or ans_span[0] >= len(doc_input_ids) or ans_span[1] >= len(doc_input_ids):
            #     print(ans_span)
            #     print(len(doc_input_ids))
            #     print(ans_span[1])
            ans_inp_ids = trim_doc_input_ids[ans_span[0]:ans_span[1]]
            decoded_ans = tokenizer.decode(ans_inp_ids)
            print('{} Orig\t{}\t{}\t{}\t{}'.format(ans_idx, orig_answer, exm_answer, decoded_ans, ans_type_label[0]))
        print('*' * 75)
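# The checkers above and below exercise `trim_input_span` from this repo. Its real implementation
# lives elsewhere; the sketch here only illustrates the contract assumed by these call sites:
# truncate the flattened document to `limit` token ids (ending on `sep_token_id`), keep spans
# that still start inside the limit, and clip spans that cross it. Treating spans as half-open
# `(start, end)` pairs is an assumption, and `_trim_input_span_sketch` is an illustrative name.
def _trim_input_span_sketch(doc_input_ids, query_spans, para_spans, sent_spans,
                            limit, sep_token_id, ans_spans=None):
    ans_spans = ans_spans or []
    if len(doc_input_ids) <= limit:
        return doc_input_ids, query_spans, para_spans, sent_spans, ans_spans
    # Reserve the last position for the separator token.
    trim_doc_input_ids = list(doc_input_ids[:limit - 1]) + [sep_token_id]

    def clip(spans):
        clipped = []
        for start, end in spans:
            if start >= limit:
                continue  # span falls entirely beyond the limit: drop it
            clipped.append((start, min(end, limit)))
        return clipped

    return (trim_doc_input_ids, clip(query_spans), clip(para_spans),
            clip(sent_spans), clip(ans_spans))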
def sent_drop_case_to_feature_checker(para_file: str, full_file: str, example_file: str, tokenizer, data_source_type=None):
    sel_para_data = json_loader(json_file_name=para_file)
    full_data = json_loader(json_file_name=full_file)
    examples = pickle.load(gzip.open(example_file, 'rb'))
    example_dict = {e.qas_id: e for e in examples}
    assert len(sel_para_data) == len(full_data) and len(full_data) == len(examples)
    print('Number of examples = {}'.format(len(examples)))
    no_answer_count = 0
    sep_id = tokenizer.encode(tokenizer.sep_token)
    print(sep_id)
    ans_count_list = []
    one_supp_sent = 0
    miss_supp_count = 0
    larger_512 = 0
    drop_larger_512 = 0
    max_query_len = 0
    query_len_list = []
    for row in tqdm(full_data):
        key = row['_id']
        if data_source_type is not None:
            exam_key = key + '_' + data_source_type
        else:
            exam_key = key
        example_i: Example = example_dict[exam_key]
        doc_input_ids, query_spans, para_spans, sent_spans, ans_spans, ans_type_label = \
            case_to_features(case=example_i, train_dev=True)
        # # print(len(doc_input_ids))
        # # print('orig', doc_input_ids)
        # # print(len(sent_spans))
        if len(doc_input_ids) > 512:
            larger_512 += 1
        # print('orig', example_i.ctx_input_ids)
        drop_example_i = example_sent_drop(case=example_i, drop_ratio=1.0)
        # print('drop', drop_example_i.ctx_input_ids)
        query_len_list.append(query_spans[0][1])
        if max_query_len < query_spans[0][1]:
            max_query_len = query_spans[0][1]
        print(max_query_len)
        # print('orig q ids {}'.format(example_i.question_input_ids))
        # print('drop q ids {}'.format(drop_example_i.question_input_ids))
        # supp_para_names = list(set([x[0] for x in row['supporting_facts']]))
        # exam_para_names = [example_i.para_names[x] for x in example_i.sup_para_id]
        # drop_exam_para_names = [drop_example_i.para_names[x] for x in drop_example_i.sup_para_id]
        # print('drop', example_i.para_names)
        # print(drop_example_i.para_names)
        # print('orig {}'.format(supp_para_names))
        # print('exam {}'.format(exam_para_names))
        # print('drop exam {}'.format(drop_exam_para_names))
        # # # print(example_i.sent_num, drop_example_i.sent_num)
        # orig_supp_count = len(row['supporting_facts'])
        # if drop_example_i.sent_num < orig_supp_count:
        #     miss_supp_count += 1
        # if drop_example_i.sent_num < 2:
        #     one_supp_sent += 1
        # sel_para_names = sel_para_data[key]
        # print('selected para names: ', sel_para_names)
        # print('example para names: ', example_i.para_names)
        doc_input_ids, query_spans, para_spans, sent_spans, ans_spans, ans_type_label = \
            case_to_features(case=drop_example_i, train_dev=True)
        # print(len(drop_doc_input_ids))
        # # print('drop', drop_doc_input_ids)
        # print(len(drop_sent_spans))
        if len(doc_input_ids) > 512:
            drop_larger_512 += 1
        # print(type(doc_input_ids), type(query_spans), type(para_spans), type(sent_spans), type(ans_spans))
        # orig_query = row['question']
        # query_input_ids = doc_input_ids[query_spans[0][0]:query_spans[0][1]]
        # decoded_query = tokenizer.decode(query_input_ids)
        # # print('Orig query = {}'.format(orig_query))
        # # print('Decoded query = {}'.format(decoded_query))
        # # print('para number {}'.format(len(para_spans)))
        # # print('sent number {}'.format(len(sent_spans)))
        # # print('ans_spans number {}'.format(len(ans_spans)))
        orig_answer = row['answer']
        exm_answer = example_i.answer_text
        # print('{}\t{}'.format(exm_answer, ans_type_label))
        #
        # assert len(example_i.sup_para_id) == len(drop_example_i.sup_para_id)
        # assert len(example_i.sup_fact_id) == len(drop_example_i.sup_fact_id)
        # ##+++++++
        all_sents = []
        ctx_dict = dict(row['context'])
        contex_text = []
        for para_name in example_i.para_names:
            contex_text.append(ctx_dict[para_name])
            all_sents += ctx_dict[para_name]
        # for s_idx, sent_span in enumerate(sent_spans):
        #     sent_inp_ids = doc_input_ids[sent_span[0]:sent_span[1]]
        #     # print(sent_inp_ids)
        #     decoded_sent = tokenizer.decode(sent_inp_ids)
        #     print('{} orig sent: {}'.format(s_idx, all_sents[s_idx]))
        #     print('{} deco sent: {}'.format(s_idx, decoded_sent))
        #     print('$' * 10)
        # print('-' * 75)
        # for ans_idx, ans_span in enumerate(ans_spans):
        #     # print(ans_span)
        #     # print(len(doc_input_ids))
        #     # if ans_span[0] < 0 or ans_span[0] >= len(doc_input_ids) or ans_span[1] >= len(doc_input_ids):
        #     #     print(ans_span)
        #     #     print(len(doc_input_ids))
        #     ans_inp_ids = doc_input_ids[ans_span[0]:ans_span[1]]
        #     decoded_ans = tokenizer.decode(ans_inp_ids)
        #     print('{} Orig\t{}\t{}\t{}\t{}'.format(ans_idx, orig_answer, exm_answer, decoded_ans, ans_type_label[0]))
        #     print('*' * 75)
        #
        # # for p_idx, para_span in enumerate(para_spans):
        # #     para_inp_ids = doc_input_ids[para_span[0]:para_span[1]]
        # #     decoded_para = tokenizer.decode(para_inp_ids)
        # #     print('{} orig para: {}'.format(p_idx, contex_text[p_idx]))
        # #     print('{} deco para: {}'.format(p_idx, decoded_para))
        # #     print('-' * 75)
        # ans_count_list.append(len(ans_spans))
    print('Sum of ans count = {}'.format(sum(ans_count_list)))
    print('One support sent count = {}'.format(one_supp_sent))
    print('Miss support sent count = {}'.format(miss_supp_count))
    print('Larger than 512 count = {}'.format(larger_512))
    print('Larger than 512 count after drop = {}'.format(drop_larger_512))
    print('Max query len = {}'.format(max_query_len))
    query_len_array = np.array(query_len_list)
    print('99 = {}'.format(np.percentile(query_len_array, 99)))
    print('97.5 = {}'.format(np.percentile(query_len_array, 97.5)))
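# `example_sent_drop` (used with drop_ratio / sent_drop_ratio above) comes from the repo's
# example utilities. The standalone sketch below only illustrates the idea as it is exercised
# here: randomly discard a fraction of the non-supporting sentences while always keeping the
# supporting facts, returning the indices of the kept sentences. How the real code rebuilds
# Example objects and re-indexes sentence ids is not reproduced here.
import random


def _sentence_drop_sketch(num_sents, supporting_sent_ids, drop_ratio, seed=None):
    """Return sorted indices of the sentences kept after random sentence dropping."""
    rng = random.Random(seed)
    supporting = set(supporting_sent_ids)
    kept = [idx for idx in range(num_sents)
            if idx in supporting or rng.random() >= drop_ratio]
    return kept

# Example: with drop_ratio=1.0 (as in sent_drop_case_to_feature_checker) only the supporting
# sentences survive, e.g. _sentence_drop_sketch(5, [1, 3], 1.0) -> [1, 3].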
def consist_checker(para_file: str, full_file: str, example_file: str, tokenizer, data_source_type=None):
    sel_para_data = json_loader(json_file_name=para_file)
    full_data = json_loader(json_file_name=full_file)
    examples = pickle.load(gzip.open(example_file, 'rb'))
    example_dict = {e.qas_id: e for e in examples}
    assert len(sel_para_data) == len(full_data) and len(full_data) == len(examples)
    print('Number of examples = {}'.format(len(examples)))
    no_answer_count = 0
    for row in tqdm(full_data):
        key = row['_id']
        if data_source_type is not None:
            exam_key = key + '_' + data_source_type
        else:
            exam_key = key
        raw_question = row['question']
        raw_context = row['context']
        raw_answer = row['answer']
        example_i: Example = example_dict[exam_key]
        exm_question = example_i.question_text
        exm_answer = example_i.answer_text
        exm_context = example_i.ctx_text
        exm_ctx_token_list = example_i.ctx_tokens
        exm_ctx_input_ids = example_i.ctx_input_ids
        # print('{}\t{}'.format(key, exam_key))
        # print('raw question:', raw_question)
        # print('exm question:', exm_question)
        # print('raw answer:', raw_answer)
        # print('exm answer:', exm_answer)
        answer_positions = example_i.answer_positions
        para_names = example_i.para_names
        para_name_dict = dict([(x[1], x[0]) for x in enumerate(para_names)])
        encode_answer = ''
        for para_i, sent_i, start_i, end_i in answer_positions:
            para_idx = para_name_dict[para_i]
            sent_ids = exm_ctx_input_ids[para_idx][sent_i]
            encode_answer = tokenizer.decode(sent_ids[start_i:end_i])
        # if raw_answer in ['yes', 'no']:
        #     print('{}\t{}\t{}\t{}'.format(raw_answer, exm_answer, encode_answer, example_i.ctx_with_answer))
        if exm_answer in ['noanswer']:
            print('{}\t{}\tencode:{}\t{}'.format(raw_answer, exm_answer, encode_answer, example_i.ctx_with_answer))
        if not example_i.ctx_with_answer and raw_answer not in ['yes', 'no']:
            no_answer_count = no_answer_count + 1
        ctx_dict = dict(raw_context)
        contex_text = []
        for para_name in para_names:
            contex_text.append(ctx_dict[para_name])
        for para_idx, ctx_token_list in enumerate(exm_ctx_token_list):
            ctx_inp_id_list = exm_ctx_input_ids[para_idx]
            orig_context = contex_text[para_idx]
            for sent_idx, sent_inp_ids in enumerate(ctx_inp_id_list):
                print(ctx_token_list[sent_idx])
                print(tokenizer.decode(sent_inp_ids))
                print(orig_context[sent_idx])
                print('*' * 75)
        # if exm_answer.strip() in ['noanswer']:
        #     print('raw answer:', raw_answer)
        #     print('exm answer:', exm_answer)
        #     no_answer_count = no_answer_count + 1
        #     print('raw context:', raw_context)
        #     print('*' * 75)
        #     print('exm context:', exm_context)
        #     print('*' * 75)
        #     print('exm tokens: ', exm_ctx_token_list)
        #     print('*' * 75)
        #     for x in exm_ctx_input_ids:
        #         for y in x:
        #             print('exm decode: ', tokenizer.decode(y))
    print(no_answer_count)
    return
def trim_case_to_feature_checker(para_rank_file: str, full_file: str, example_file: str, tokenizer, data_source_type=None):
    para_rank_data = json_loader(json_file_name=para_rank_file)
    full_data = json_loader(json_file_name=full_file)
    examples = pickle.load(gzip.open(example_file, 'rb'))
    example_dict = {e.qas_id: e for e in examples}
    assert len(para_rank_data) == len(full_data) and len(full_data) == len(examples)
    print('Number of examples = {}'.format(len(examples)))
    no_answer_count = 0
    trim_no_answer_count = 0
    sep_id = tokenizer.encode(tokenizer.sep_token)
    # print(sep_id)
    ans_count_list = []
    trim_ans_count_list = []
    one_supp_sent = 0
    miss_supp_count = 0
    larger_512 = 0
    drop_larger_512 = 0
    trim_larger_512 = 0
    max_query_len = 0
    query_len_list = []
    max_sent_num = 0
    supp_sent_num_list = []
    miss_count = 0
    for row in tqdm(full_data):
        key = row['_id']
        if data_source_type is not None:
            exam_key = key + '_' + data_source_type
        else:
            exam_key = key
        example_i: Example = example_dict[exam_key]
        print('before replace in examples, sent num = {}'.format(len(example_i.sent_names)))
        doc_input_ids, query_spans, para_spans, sent_spans, ans_spans, ans_type_label = \
            case_to_features(case=example_i, train_dev=True)
        ####++++++
        print('before replacement ans = {}'.format(ans_spans))
        for idx, trim_ans_span in enumerate(ans_spans):
            decoded_ans = tokenizer.decode(doc_input_ids[trim_ans_span[0]:trim_ans_span[1]])
            print('{}\txxx{}xxx'.format(idx + 1, decoded_ans))
        ####++++++
        supp_para_ids = example_i.sup_para_id
        supp_sent_ids = example_i.sup_fact_id
        supp_sent_num_i = len(supp_sent_ids)
        if len(sent_spans) > max_sent_num:
            max_sent_num = len(sent_spans)
        print('before replace sent {}'.format(len(sent_spans)))
        for supp_para_id in supp_para_ids:
            if supp_para_id < len(para_spans):
                print('before replace, supp para', example_i.para_names[supp_para_id])
        for supp_sent_id in supp_sent_ids:
            if supp_sent_id < len(example_i.sent_names):
                print('before replace, supp sent', example_i.sent_names[supp_sent_id])
        if len(doc_input_ids) > 512:
            larger_512 += 1
        # +++++++++++++++++++++++++++++++
        replace_example_i, replace_sent_ids_i = example_sent_replacement(case=example_i, replace_ratio=0.25)
        assert len(example_i.sent_names) == len(replace_example_i.sent_names)
        print('replacement ids', replace_sent_ids_i)
        supp_para_ids = replace_example_i.sup_para_id
        supp_sent_ids = replace_example_i.sup_fact_id
        for supp_para_id in supp_para_ids:
            if supp_para_id < len(replace_example_i.para_names):
                print('after replace, supp para', replace_example_i.para_names[supp_para_id])
        for supp_sent_id in supp_sent_ids:
            if supp_sent_id < len(replace_example_i.sent_names):
                print('after replace, supp sent', replace_example_i.sent_names[supp_sent_id])
        print('after replace in examples, sent num = {}'.format(len(replace_example_i.sent_names)))
        doc_input_ids, query_spans, para_spans, sent_spans, ans_spans, ans_type_label = \
            case_to_features(case=replace_example_i, train_dev=True)
        trim_doc_input_ids, trim_query_spans, trim_para_spans, trim_sent_spans, trim_ans_spans = trim_input_span(
            doc_input_ids, query_spans, para_spans, sent_spans,
            limit=512, sep_token_id=tokenizer.sep_token_id, ans_spans=ans_spans)
        supp_sent_ids = [x for x in supp_sent_ids if x < len(trim_sent_spans)]
        if len(supp_sent_ids) < supp_sent_num_i:
            miss_count = miss_count + 1
        supp_sent_num_list.append(len(supp_sent_ids))
        print('after trim replace sent {}'.format(len(trim_sent_spans)))
        for supp_sent_id in supp_sent_ids:
            if supp_sent_id < len(trim_sent_spans):
                print('after replace/trim, supp sent', replace_example_i.sent_names[supp_sent_id])
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        print('after trim replacement ans = {}'.format(trim_ans_spans))
        print('after trim = {}'.format(len(trim_doc_input_ids)))
        for idx, trim_ans_span in enumerate(trim_ans_spans):
            decoded_ans = tokenizer.decode(trim_doc_input_ids[trim_ans_span[0]:trim_ans_span[1]])
            print('{}\txxx{}xxx'.format(idx + 1, decoded_ans))
        print('*' * 75)
    print(Counter(supp_sent_num_list))
    print(miss_count)
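# A minimal driver sketch for the checkers above. The file paths and the tokenizer choice are
# placeholders, not the repo's actual configuration; the project presumably wires these through
# its own argument parsing and data directories.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('roberta-large')   # assumed backbone
    para_file = 'path/to/selected_paras.json'                    # hypothetical paths
    full_file = 'path/to/hotpot_dev_fullwiki.json'
    example_file = 'path/to/examples.pkl.gz'

    # Run one of the sanity checkers over the cached examples.
    consist_checker(para_file=para_file, full_file=full_file,
                    example_file=example_file, tokenizer=tokenizer)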