def gold_doc_hotpotqa_extraction_example():
    """Extract the gold supporting documents from the HotpotQA distractor dev
    set, persist that gold subset, then run the test-data preprocessing
    pipeline on it and dump the raw / encoded / tokenized frames as JSON.
    """
    print('Pre-processing gold data...')
    tokenizer = get_hotpotqa_longformer_tokenizer(
        model_name=PRE_TAINED_LONFORMER_BASE)
    started_at = time()
    dev_data, _ = HOTPOT_DevData_Distractor(path=abs_hotpot_path)
    print('*' * 75)
    # Keep only the gold (supporting) documents of each dev record.
    gold_dev_data = Gold_Hotpot_Train_Dev_Data_Collection(data=dev_data)
    print('Get {} gold dev-test records'.format(gold_dev_data.shape[0]))
    gold_dev_data.to_json(
        os.path.join(hotpot_path, 'gold_hotpot_dev_distractor_v1.json'))
    print('Runtime = {:.4f} seconds to get gold documents'.format(
        time() - started_at))
    # Tokenize/encode the gold subset with the test-data pipeline.
    gold_dev_data, gold_dev_data_res, gold_norm_dev_data = \
        Hotpot_Test_Data_PreProcess(data=gold_dev_data, tokenizer=tokenizer)
    print('Get {} dev-test records, encode records {}, normalized records {}'.
          format(gold_dev_data.shape[0], gold_dev_data_res.shape[0],
                 gold_norm_dev_data.shape[0]))
    # Dump the three processed frames next to the other distractor-wiki files.
    dumps = (
        (gold_dev_data, 'hotpot_test_distractor_wiki_all_gold.json'),
        (gold_dev_data_res, 'hotpot_test_distractor_wiki_encoded_gold.json'),
        (gold_norm_dev_data, 'hotpot_test_distractor_wiki_tokenized_gold.json'),
    )
    for frame, file_name in dumps:
        frame.to_json(os.path.join(abs_distractor_wiki_path, file_name))
    print('*' * 75)
def __init__(self, args: Namespace, fix_encoder=False):
    """Build the Longformer-based supporting doc/sentence scorer.

    Args:
        args: hyper-parameter namespace; reads pretrained_cfg_name,
            project_dim, input_drop, attn_drop, seq_project,
            frozen_layer_num, with_graph_training, with_graph,
            layer_number and heads.
        fix_encoder: stored on the instance; presumably downstream code
            uses it to keep the encoder weights fixed — TODO confirm.
    """
    super().__init__()
    self.tokenizer = get_hotpotqa_longformer_tokenizer(
        model_name=args.pretrained_cfg_name)
    encoder = LongformerEncoder.init_encoder(
        cfg_name=args.pretrained_cfg_name,
        projection_dim=args.project_dim,
        hidden_dropout=args.input_drop,
        attn_dropout=args.attn_drop,
        seq_project=args.seq_project)
    # The tokenizer may carry extra special tokens, so the embedding
    # table has to be grown to match it.
    encoder.resize_token_embeddings(len(self.tokenizer))
    if args.frozen_layer_num > 0:
        # Freeze the embeddings plus the first N transformer layers.
        frozen_blocks = [encoder.embeddings]
        frozen_blocks.extend(encoder.encoder.layer[:args.frozen_layer_num])
        for block in frozen_blocks:
            for weight in block.parameters():
                weight.requires_grad = False
        logging.info('Frozen the first {} layers'.format(args.frozen_layer_num))
    self.longformer = encoder  # LongFormer encoder
    self.hidden_size = encoder.get_out_size()
    # One-logit MLP scorers over token representations.
    self.doc_mlp = MLP(d_input=self.hidden_size,
                       d_mid=4 * self.hidden_size,
                       d_out=1)  # support document prediction
    self.sent_mlp = MLP(d_input=self.hidden_size,
                        d_mid=4 * self.hidden_size,
                        d_out=1)  # support sentence prediction
    self.fix_encoder = fix_encoder
    self.hparams = args
    self.graph_training = self.hparams.with_graph_training == 1
    self.with_graph = self.hparams.with_graph == 1
    if self.with_graph:
        # Optional transformer over the document/sentence graph.
        self.graph_encoder = TransformerModule(
            layer_num=self.hparams.layer_number,
            d_model=self.hidden_size,
            heads=self.hparams.heads)
    self.mask_value = MASK_VALUE
def __init__(self, args: Namespace, fix_encoder=False):
    """Build the Longformer-based reader (answer type/span + doc/sent scorers).

    Args:
        args: hyper-parameter namespace; reads pretrained_cfg_name,
            project_dim, input_drop, attn_drop, seq_project,
            frozen_layer_num, hop_model_name and with_graph_training.
        fix_encoder: stored on the instance; presumably downstream code
            uses it to keep the encoder weights fixed — TODO confirm.
    """
    super().__init__()
    self.tokenizer = get_hotpotqa_longformer_tokenizer(
        model_name=args.pretrained_cfg_name)
    longEncoder = LongformerEncoder.init_encoder(
        cfg_name=args.pretrained_cfg_name,
        projection_dim=args.project_dim,
        hidden_dropout=args.input_drop,
        attn_dropout=args.attn_drop,
        seq_project=args.seq_project)
    # The tokenizer may carry extra special tokens, so grow the embedding
    # table to match it.
    longEncoder.resize_token_embeddings(len(self.tokenizer))
    if args.frozen_layer_num > 0:
        # Freeze the embeddings plus the first N transformer layers.
        modules = [longEncoder.embeddings,
                   *longEncoder.encoder.layer[:args.frozen_layer_num]]
        for module in modules:
            for param in module.parameters():
                param.requires_grad = False
        logging.info('Frozen the first {} layers'.format(args.frozen_layer_num))
    self.longformer = longEncoder  # LongFormer encoder
    self.hidden_size = longEncoder.get_out_size()
    self.answer_type_outputs = MLP(d_input=self.hidden_size,
                                   d_mid=4 * self.hidden_size,
                                   d_out=3)  # yes, no, span question score
    self.answer_span_outputs = MLP(d_input=self.hidden_size,
                                   d_mid=4 * self.hidden_size,
                                   d_out=2)  # span prediction score
    self.doc_mlp = MLP(d_input=self.hidden_size,
                       d_mid=4 * self.hidden_size,
                       d_out=1)  # support document prediction
    self.sent_mlp = MLP(d_input=self.hidden_size,
                        d_mid=4 * self.hidden_size,
                        d_out=1)  # support sentence prediction
    self.fix_encoder = fix_encoder
    self.hparams = args
    self.hop_model_name = self.hparams.hop_model_name  # triple score
    self.graph_training = (self.hparams.with_graph_training == 1)
    # BUGFIX: the original only assigned hop_doc_dotproduct/hop_doc_bilinear
    # inside the else-branch, so an unsupported hop_model_name left the
    # attributes undefined (latent AttributeError on later access).  Define
    # both unconditionally; behavior for valid configs is unchanged.
    self.hop_doc_dotproduct = None
    self.hop_doc_bilinear = None
    if self.hop_model_name not in ['DotProduct', 'BiLinear']:
        self.hop_model_name = None
    elif self.hop_model_name == 'DotProduct':
        self.hop_doc_dotproduct = DotProduct(args=self.hparams)
    else:  # 'BiLinear'
        self.hop_doc_bilinear = BiLinear(args=self.hparams,
                                         project_dim=self.hidden_size)
    self.mask_value = MASK_VALUE
def hotpotqa_preprocess_example():
    """Tokenize/encode the HotpotQA distractor dev and train splits with the
    Longformer tokenizer and dump three JSON artifacts per split: the raw
    frame (*_all), the encoded frame (*_encoded) and the tokenized frame
    (*_tokenized).
    """
    start_time = time()
    longformer_tokenizer = get_hotpotqa_longformer_tokenizer(
        model_name=PRE_TAINED_LONFORMER_BASE)
    # Test-set preprocessing branch kept for reference:
    # dev_data, _ = HOTPOT_DevData_Distractor(path=abs_hotpot_path)
    # dev_test_data, dev_test_data_res, norm_test_data_res = Hotpot_Test_Data_PreProcess(data=dev_data, tokenizer=longformer_tokenizer)
    # print('Get {} dev-test records, encode records {}, tokenized records {}'.format(dev_test_data.shape[0], dev_test_data_res.shape[0], norm_test_data_res.shape[0]))
    # dev_test_data.to_json(os.path.join(abs_distractor_wiki_path, 'hotpot_test_distractor_wiki_all.json'))
    # dev_test_data_res.to_json(os.path.join(abs_distractor_wiki_path, 'hotpot_test_distractor_wiki_encoded.json'))
    # norm_test_data_res.to_json(os.path.join(abs_distractor_wiki_path, 'hotpot_test_distractor_wiki_tokenized.json'))
    print('*' * 100)
    # FIX: the original loaded the dev data twice; the first load was dead
    # (only consumed by the commented-out branch above) and immediately
    # overwritten here, so the redundant read has been removed.
    dev_data, _ = HOTPOT_DevData_Distractor(path=abs_hotpot_path)
    dev_data, dev_data_res, norm_dev_data_res = Hotpot_Train_Dev_Data_Preprocess(
        data=dev_data, tokenizer=longformer_tokenizer)
    print('Get {} dev records, encode records {} tokenized records {}'.format(
        dev_data.shape[0], dev_data_res.shape[0], norm_dev_data_res.shape[0]))
    dev_data.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_dev_distractor_wiki_all.json'))
    dev_data_res.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_dev_distractor_wiki_encoded.json'))
    norm_dev_data_res.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_dev_distractor_wiki_tokenized.json'))
    print('*' * 100)
    train_data, _ = HOTPOT_TrainData(path=abs_hotpot_path)
    train_data, train_data_res, norm_train_data_res = Hotpot_Train_Dev_Data_Preprocess(
        data=train_data, tokenizer=longformer_tokenizer)
    print('Get {} training records, encode records {} tokenized records {}'.
          format(train_data.shape[0], train_data_res.shape[0],
                 norm_train_data_res.shape[0]))
    train_data.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_train_distractor_wiki_all.json'))
    train_data_res.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_train_distractor_wiki_encoded.json'))
    norm_train_data_res.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_train_distractor_wiki_tokenized.json'))
    print('Runtime = {:.4f} seconds'.format(time() - start_time))
    print('*' * 100)
def consistent_checker():
    """Manual sanity checker for the preprocessed (encoded) HotpotQA dev data.

    Loads the encoded dev frame and the original distractor dev records, then
    decodes the stored token spans back into text so they can be compared —
    by eye and via flags/scores — against the original documents, supporting
    sentences and answers.  Only the doc/sent/answer span-containment loop at
    the bottom is active; the other nested helpers and the commented-out
    loops are alternative checks kept for debugging.
    """
    tokenizer = get_hotpotqa_longformer_tokenizer()
    encoded_data = loadJSONData(PATH=distractor_wiki_path,
                                json_fileName=dev_processed_data_name)
    orig_data, _ = HOTPOT_DevData_Distractor()
    # orig_data, _ = HOTPOT_TrainData()
    col_names = []
    for col in encoded_data.columns:
        col_names.append(col)
        # print(col)
    # Columns of the encoded frame:
    #   doc_label, doc_ans_label, doc_num, doc_len, doc_start, doc_end,
    #   head_idx, tail_idx, sent_label, sent_ans_label, sent_num, sent_len,
    #   sent_start, sent_end, sent2doc, sentIndoc, doc_sent_num, ctx_encode,
    #   ctx_len, attn_mask, global_attn, token2sent, ans_mask, ans_pos_tups,
    #   ans_start, ans_end, answer_type

    def support_doc_checker(row, orig_row):
        """Decode one row's supporting-document spans and print them next to
        the original context; returns (context token count, debug flag)."""
        doc_label = row['doc_label']
        answer_type = row['answer_type']
        ans_orig = orig_row['answer']
        print(doc_label)
        # NOTE(review): doc_label is deliberately re-bound to the
        # answer-aware document labels for the rest of this helper.
        doc_label = row['doc_ans_label']
        print(doc_label)
        # Indices (and label values) of documents with a positive label.
        doc_idxes = [x[0] for x in enumerate(doc_label) if x[1] > 0]
        doc_labels = [doc_label[x] for x in doc_idxes]
        print(doc_labels)
        # flag = (doc_labels[0] == doc_labels[1]) and (doc_labels[0] == 1) and answer_type[0] > 0
        flag = answer_type[0] > 1 and ans_orig.strip() not in {'no'}
        # flag = (doc_labels[0] != doc_labels[1])
        orig_context = orig_row['context']
        ctx_titles = [orig_context[x][0] for x in doc_idxes]
        print('decode support doc title {}'.format(ctx_titles))
        support_fact = orig_row['supporting_facts']
        supptitle = list(set([x[0] for x in support_fact]))
        print('supp doc title {}'.format(supptitle))
        print('*' * 75)
        ctx_encode = row['ctx_encode']
        ctx_encode = torch.LongTensor(ctx_encode)
        doc_start = row['doc_start']
        doc_end = row['doc_end']
        for i in range(len(doc_label)):
            # Decode the inclusive token span [doc_start[i], doc_end[i]]
            # of document i and show it next to the original document.
            print('decode doc: \n{}'.format(
                tokenizer.decode(ctx_encode[doc_start[i]:(doc_end[i] + 1)])))
            print('orig_doc : \n{}'.format(orig_row['context'][i]))
            print('-' * 75)
        # List-indexing the tensor gathers every doc boundary token at once.
        print(tokenizer.decode(ctx_encode[doc_start]))
        print(tokenizer.decode(ctx_encode[doc_end]))
        print(len(doc_label))
        return len(ctx_encode), flag

    def support_sent_checker(row, orig_row):
        """Decode one row's supporting-sentence spans and print them next to
        the original sentences; returns the sentence count."""
        sent_label = row['sent_label']
        sent_idxes = [x[0] for x in enumerate(sent_label) if x[1] > 0]
        # Map flat sentence index -> (document index, sentence-in-document).
        sent2doc = row['sent2doc']
        sentIndoc = row['sentIndoc']
        sentidxPair = list(zip(sent2doc, sentIndoc))
        suppPair = [sentidxPair[x] for x in sent_idxes]
        orig_context = orig_row['context']
        decode_supp_sent = [(orig_context[x[0]][0], x[1]) for x in suppPair]
        print('decode supp sent {}'.format(decode_supp_sent))
        support_fact = orig_row['supporting_facts']
        print(support_fact)
        print('*' * 75)
        sent_start = row['sent_start']
        sent_end = row['sent_end']
        sent_pair = list(zip(sent_start, sent_end))
        supp_sent_pair = [sent_pair[x] for x in sent_idxes]
        ctx_encode = row['ctx_encode']
        ctx_encode = torch.LongTensor(ctx_encode)
        decode_supp_sent_text = [
            tokenizer.decode(ctx_encode[x[0]:(x[1] + 1)])
            for x in supp_sent_pair
        ]
        print('decode sents:\n{}'.format('\n'.join(decode_supp_sent_text)))
        orig_supp_sent = [orig_context[x[0]][1][x[1]] for x in suppPair]
        print('orig sents:\n{}'.format('\n'.join(orig_supp_sent)))
        print('*' * 75)
        return len(sent_start)

    def answer_checker(row, orig_row):
        """Decode the stored answer span and score it (EM/F1/P/R) against the
        original answer; returns (em, f1, prec, recall, answer-span count)."""
        orig_answer = orig_row['answer']
        ans_tups = row['ans_pos_tups']
        print(len(ans_tups))
        ctx_encode = row['ctx_encode']
        ctx_encode = torch.LongTensor(ctx_encode)
        # Only the first stored answer span is checked.
        ans_start = row['ans_start'][0]
        ans_end = row['ans_end'][0]
        decode_answer = tokenizer.decode(ctx_encode[ans_start:(ans_end + 1)])
        print('ori answer: {}'.format(orig_answer.strip()))
        print('dec answer: {}'.format(decode_answer.strip()))
        em, f1, prec, recall = answer_score(prediction=decode_answer,
                                            gold=orig_answer)
        print('em {} f1 {} prec {} recall {}'.format(em, f1, prec, recall))
        print('*' * 75)
        return em, f1, prec, recall, len(ans_tups)

    def doc_sent_ans_consistent(row, orig_row):
        """Check span containment (answer within its supporting sentence,
        both within the supporting document) and print the decoded spans."""
        answer_start = row['ans_start']
        answer_end = row['ans_end']
        sent_start = row['sent_start']
        sent_end = row['sent_end']
        doc_start = row['doc_start']
        doc_end = row['doc_end']
        doc_ans_label = row['doc_ans_label']
        # Label value > 1 marks the doc/sentence containing the answer span.
        doc_with_ans_idx = [x[0] for x in enumerate(doc_ans_label) if x[1] > 1]
        sent_ans_label = row['sent_ans_label']
        sent_with_ans_idx = [
            x[0] for x in enumerate(sent_ans_label) if x[1] > 1
        ]
        ctx_encode = row['ctx_encode']
        ctx_encode = torch.LongTensor(ctx_encode)
        answer_type = row['answer_type']
        # answer_type[0] == 0 presumably marks span answers (vs yes/no) —
        # TODO confirm against the preprocessing code.
        if answer_type[0] == 0:
            ans_doc_start = doc_start[doc_with_ans_idx[0]]
            ans_doc_end = doc_end[doc_with_ans_idx[0]]
            ans_sent_start = sent_start[sent_with_ans_idx[0]]
            ans_sent_end = sent_end[sent_with_ans_idx[0]]
            # print('ans {}\n sent {}\n doc{}'.format((answer_start, answer_end),
            #                                         (ans_sent_start, ans_sent_end),
            #                                         (ans_doc_start, ans_doc_end)))
            # Containment flags; only consumed by the commented-out checks.
            flag1 = (answer_start[0] >= ans_sent_start) and (answer_end[0] <= ans_sent_end)
            flag2 = (answer_start[0] >= ans_doc_start) and (answer_end[0] <= ans_doc_end)
            flag3 = (ans_sent_start >= ans_doc_start) and (ans_sent_end <= ans_doc_end)
            # print('ans {} sent {} doc {}'.format(flag1, flag2, flag3))
            # if not (flag1 and flag2 and flag3):
            #     print('wrong preprocess')
            print('ans {}\n sent {}\n doc{}\n'.format(
                tokenizer.decode(ctx_encode[answer_start[0]:(answer_end[0] + 1)]),
                tokenizer.decode(ctx_encode[ans_sent_start:(ans_sent_end + 1)]),
                tokenizer.decode(ctx_encode[ans_doc_start:(ans_doc_end + 1)])))

    # Alternative check: answer EM/F1 over the whole frame.
    # em_score = 0.0
    # f1_score = 0.0
    # ans_count_array = []
    # for row_idx, row in encoded_data.iterrows():
    #     # support_doc_checker(row, orig_data.iloc[row_idx])
    #     # support_sent_checker(row, orig_data.iloc[row_idx])
    #     em, f1, prec, recall, ans_count = answer_checker(row, orig_data.iloc[row_idx])
    #     em_score = em_score + em
    #     f1_score = f1_score + f1
    #     ans_count_array.append(ans_count)
    # print('em {} f1 {}'.format(em_score/encoded_data.shape[0], f1_score/encoded_data.shape[0]))
    # occurrences = dict(collections.Counter(ans_count_array))
    # for key, value in occurrences.items():
    #     print('{}\t{}'.format(key, value*1.0/encoded_data.shape[0]))
    # print(occurrences)
    #########################################
    # Alternative check: max context length and debug-flag ratio.
    # max_len = 0
    # equal_count = 0
    # for row_idx, row in encoded_data.iterrows():
    #     doc_len, equal_flag = support_doc_checker(row, orig_data.iloc[row_idx])
    #     if equal_flag:
    #         equal_count = equal_count + 1
    #     if max_len < doc_len:
    #         max_len = doc_len
    #     # sent_len = support_sent_checker(row, orig_data.iloc[row_idx])
    #     # if max_len < sent_len:
    #     #     max_len = sent_len
    # print(max_len)
    # print(equal_count, equal_count * 1.0/encoded_data.shape[0])
    #########################################
    # Active check: span containment for every encoded row.
    for row_idx, row in encoded_data.iterrows():
        doc_sent_ans_consistent(row, orig_data.iloc[row_idx])
def consistent_checker():
    """Manual sanity checker for the batched (DataLoader) view of the encoded
    HotpotQA dev data.

    Loads the encoded dev frame, wraps it in the validation DataLoader, and
    for the first ~100 batches decodes the document / sentence spans back to
    text and prints them next to the original records, so the batching and
    masking can be inspected by eye.  ``answer_checker`` is an additional
    helper kept for manual use; only ``doc_sent_checker`` is called below.

    NOTE(review): batches appear to be used with batch size 1 here — each
    field is indexed with [0] and matched to orig_data_frame.iloc[batch_idx];
    verify against the data-loader configuration.
    """
    tokenizer = get_hotpotqa_longformer_tokenizer()
    encoded_data = loadJSONData(PATH=distractor_wiki_path,
                                json_fileName=dev_processed_data_name)
    orig_data_frame, _ = HOTPOT_DevData_Distractor()
    # orig_data_frame, _ = HOTPOT_TrainData()
    # Attach an example id column expected by the data loader.
    encoded_data['e_id'] = range(0, encoded_data.shape[0])
    dev_data_loader = get_val_data_loader(data_frame=encoded_data,
                                          tokenizer=tokenizer)
    # dev_data_loader = get_train_data_loader(data_frame=encoded_data, tokenizer=tokenizer)

    def answer_checker(row, orig_row):
        """Decode the batched answer span and print it next to the original
        answer.  (Not called in the active loop below.)"""
        ans_start = row['ans_start'][0]
        ans_end = row['ans_end'][0]
        ctx_encode = row['ctx_encode'][0]
        answer = orig_row['answer']
        print('orig answer: {}\ndeco answer: {}'.format(
            answer, tokenizer.decode(ctx_encode[ans_start:(ans_end+1)])))
        print('*' * 75)

    def doc_sent_checker(row, orig_row):
        """Decode the positive document and sentence spans of one batch and
        print them next to the original supporting facts."""
        doc_label = row['doc_labels'][0]
        doc_start = row['doc_start'][0]
        doc_end = row['doc_end'][0]
        ctx_encode = row['ctx_encode'][0]
        doc_num = doc_label.shape[0]
        # Indices of positively-labelled (supporting) documents.
        pos_doc_idx = (doc_label > 0).nonzero().detach().tolist()
        supp_docs = orig_row['supporting_facts']
        global_attn = row['ctx_global_mask'][0]
        global_attn_mask_idxes = (global_attn == 1).nonzero(as_tuple=False).squeeze()
        print(global_attn_mask_idxes)
        print('global attn = {}'.format(global_attn))
        # Tensor-indexing with doc_start gathers all doc boundary tokens.
        print(tokenizer.decode(ctx_encode[doc_start]))
        for doc_idx in pos_doc_idx:
            print('doc {}'.format(
                tokenizer.decode(ctx_encode[doc_start[doc_idx]:(doc_end[doc_idx] + 1)])))
        print(doc_label, doc_num)
        print('=' * 75)
        sent_label = row['sent_labels'][0]
        sent_start = row['sent_start'][0]
        sent_end = row['sent_end'][0]
        ctx_encode = row['ctx_encode'][0]
        sent_num = sent_label.shape[0]
        print(sent_label)
        # Indices of positively-labelled (supporting) sentences.
        pos_sent_idx = (sent_label > 0).nonzero().detach().tolist()
        print(pos_sent_idx)
        # for i in range(sent_num):
        #     print('sent end {}'.format(tokenizer.decode(ctx_encode[sent_end[i]:(sent_end[i] + 1)])))
        print('+' * 75)
        # Flat sentence index -> owning document / position inside document.
        s2d_map = row['s2d_map'][0]
        print('sent_2_doc {} {}'.format(s2d_map, row['sent_lens'][0].shape))
        sInd_map = row['sInd_map'][0]
        print(sInd_map)
        print('+' * 75)
        doc_head = row['head_idx']
        doc_tail = row['tail_idx']
        print(doc_head)
        print(doc_tail)
        context = orig_row['context']
        for sent_idx in pos_sent_idx:
            print('Sent\n {}'.format(
                tokenizer.decode(ctx_encode[sent_start[sent_idx]:(sent_end[sent_idx] + 1)])))
            # print('Pair {}'.format((s2d_map[sent_idx], sInd_map[sent_idx])))
            # NOTE(review): the loop variable sent_idx is re-bound below to
            # the in-document sentence index; all uses of the flat index
            # happen before this point, so the loop itself is unaffected.
            doc_idx = s2d_map[sent_idx][0].detach().item()
            sent_idx = sInd_map[sent_idx][0].detach().item()
            print('doc idx {} sent idx {}'.format(doc_idx, sent_idx))
            print('Sent pair\n {}'.format(
                (context[doc_idx][0], sent_idx, context[doc_idx][1][sent_idx])))
        print(supp_docs)
        print('*' * 75)

    # Inspect only the first ~100 batches.
    for batch_idx, batch in enumerate(dev_data_loader):
        row = batch
        orig_row = orig_data_frame.iloc[batch_idx]
        doc_sent_checker(row, orig_row)
        if batch_idx >= 100:
            break