Example #1
def gold_doc_hotpotqa_extraction_example():
    print('Pre-processing gold data...')
    longformer_tokenizer = get_hotpotqa_longformer_tokenizer(
        model_name=PRE_TAINED_LONFORMER_BASE)
    start_time = time()
    dev_data, _ = HOTPOT_DevData_Distractor(path=abs_hotpot_path)
    print('*' * 75)
    gold_dev_data = Gold_Hotpot_Train_Dev_Data_Collection(data=dev_data)
    print('Get {} gold dev-test records'.format(gold_dev_data.shape[0]))
    gold_dev_data.to_json(
        os.path.join(hotpot_path, 'gold_hotpot_dev_distractor_v1.json'))
    print('Runtime = {:.4f} seconds to get gold documents'.format(time() -
                                                                  start_time))
    gold_dev_data, gold_dev_data_res, gold_norm_dev_data = Hotpot_Test_Data_PreProcess(
        data=gold_dev_data, tokenizer=longformer_tokenizer)
    print('Get {} dev-test records, encode records {}, normalized records {}'.
          format(gold_dev_data.shape[0], gold_dev_data_res.shape[0],
                 gold_norm_dev_data.shape[0]))
    gold_dev_data.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_test_distractor_wiki_all_gold.json'))
    gold_dev_data_res.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_test_distractor_wiki_encoded_gold.json'))
    gold_norm_dev_data.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_test_distractor_wiki_tokenized_gold.json'))
    print('*' * 75)
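For reference, the raw HotpotQA distractor records returned by HOTPOT_DevData_Distractor follow the public dataset schema; the preprocessing and checker examples in this listing index into these fields. A minimal sketch of one record (field names from the official dataset, values illustrative only):

# One raw HotpotQA record, as the examples in this listing assume it.
# Field names match the public dataset; the values are illustrative.
example_record = {
    '_id': '5a8b57f25542995d1e6f1371',
    'question': 'Were Scott Derrickson and Ed Wood of the same nationality?',
    'answer': 'yes',
    'type': 'comparison',  # 'comparison' or 'bridge'
    'level': 'hard',
    # context: list of [title, [sentence_0, sentence_1, ...]] pairs
    'context': [['Scott Derrickson',
                 ['Scott Derrickson (born July 16, 1966) is an American director.']]],
    # supporting_facts: list of [title, sentence_index] pairs
    'supporting_facts': [['Scott Derrickson', 0], ['Ed Wood', 0]],
}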
Example #2
 def __init__(self, args: Namespace, fix_encoder=False):
     super().__init__()
     # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.tokenizer = get_hotpotqa_longformer_tokenizer(model_name=args.pretrained_cfg_name)
     # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     longEncoder = LongformerEncoder.init_encoder(cfg_name=args.pretrained_cfg_name, projection_dim=args.project_dim,
                                                  hidden_dropout=args.input_drop, attn_dropout=args.attn_drop,
                                                  seq_project=args.seq_project)
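     # resize to len(tokenizer): get_hotpotqa_longformer_tokenizer presumably adds
     # task-specific special tokens beyond the base Longformer vocabulary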
     longEncoder.resize_token_embeddings(len(self.tokenizer))
     # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     if args.frozen_layer_num > 0:
         modules = [longEncoder.embeddings, *longEncoder.encoder.layer[:args.frozen_layer_num]]
         for module in modules:
             for param in module.parameters():
                 param.requires_grad = False
         logging.info('Froze the first {} layers'.format(args.frozen_layer_num))
     # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.longformer = longEncoder #### LongFormer encoder
     self.hidden_size = longEncoder.get_out_size()
     self.doc_mlp = MLP(d_input=self.hidden_size, d_mid=4 * self.hidden_size, d_out=1) ## support document prediction
     self.sent_mlp = MLP(d_input=self.hidden_size, d_mid=4 * self.hidden_size, d_out=1) ## support sentence prediction
     self.fix_encoder = fix_encoder
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.hparams = args
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.graph_training = self.hparams.with_graph_training == 1
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.with_graph = self.hparams.with_graph == 1
     if self.with_graph:
         self.graph_encoder = TransformerModule(layer_num=self.hparams.layer_number, d_model=self.hidden_size,
                                                heads=self.hparams.heads)
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.mask_value = MASK_VALUE
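The MLP heads above come from the project itself; below is a minimal sketch of a compatible two-layer module, assuming the usual Linear-activation-Linear shape (the activation and dropout choices are guesses, not the repo's actual definition):

import torch.nn as nn

class MLP(nn.Module):
    """Feed-forward scorer d_input -> d_mid -> d_out (sketch only)."""

    def __init__(self, d_input: int, d_mid: int, d_out: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_input, d_mid),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_mid, d_out),
        )

    def forward(self, x):
        return self.net(x)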
Example #3
 def __init__(self, args: Namespace, fix_encoder=False):
     super().__init__()
     # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.tokenizer = get_hotpotqa_longformer_tokenizer(
         model_name=args.pretrained_cfg_name)
     # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     longEncoder = LongformerEncoder.init_encoder(
         cfg_name=args.pretrained_cfg_name,
         projection_dim=args.project_dim,
         hidden_dropout=args.input_drop,
         attn_dropout=args.attn_drop,
         seq_project=args.seq_project)
     longEncoder.resize_token_embeddings(len(self.tokenizer))
     # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     if args.frozen_layer_num > 0:
         modules = [
             longEncoder.embeddings,
             *longEncoder.encoder.layer[:args.frozen_layer_num]
         ]
         for module in modules:
             for param in module.parameters():
                 param.requires_grad = False
         logging.info('Froze the first {} layers'.format(
             args.frozen_layer_num))
     # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.longformer = longEncoder  #### LongFormer encoder
     self.hidden_size = longEncoder.get_out_size()
     self.answer_type_outputs = MLP(
         d_input=self.hidden_size, d_mid=4 * self.hidden_size,
         d_out=3)  ## yes, no, span question score
     self.answer_span_outputs = MLP(d_input=self.hidden_size,
                                    d_mid=4 * self.hidden_size,
                                    d_out=2)  ## span prediction score
     self.doc_mlp = MLP(d_input=self.hidden_size,
                        d_mid=4 * self.hidden_size,
                        d_out=1)  ## support document prediction
     self.sent_mlp = MLP(d_input=self.hidden_size,
                         d_mid=4 * self.hidden_size,
                         d_out=1)  ## support sentence prediction
     self.fix_encoder = fix_encoder
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.hparams = args
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.hop_model_name = self.hparams.hop_model_name  ## triple score
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.graph_training = (self.hparams.with_graph_training == 1)
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     if self.hop_model_name not in ['DotProduct', 'BiLinear']:
         self.hop_model_name = None
     else:
         self.hop_doc_dotproduct = DotProduct(
             args=self.hparams
         ) if self.hop_model_name == 'DotProduct' else None
         self.hop_doc_bilinear = BiLinear(
             args=self.hparams, project_dim=self.hidden_size
         ) if self.hop_model_name == 'BiLinear' else None
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.mask_value = MASK_VALUE
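DotProduct and BiLinear are the pairwise hop scorers selected via hop_model_name; below are minimal sketches under the assumption that both score pairs of document vectors (constructor signatures copied from the calls above, internals hypothetical):

import torch
import torch.nn as nn
from argparse import Namespace

class DotProduct(nn.Module):
    """Scores a (head, tail) document pair by inner product (sketch)."""

    def __init__(self, args: Namespace):
        super().__init__()

    def forward(self, head: torch.Tensor, tail: torch.Tensor) -> torch.Tensor:
        return (head * tail).sum(dim=-1)

class BiLinear(nn.Module):
    """Scores a (head, tail) document pair with a learned bilinear form (sketch)."""

    def __init__(self, args: Namespace, project_dim: int):
        super().__init__()
        self.bilinear = nn.Bilinear(project_dim, project_dim, 1)

    def forward(self, head: torch.Tensor, tail: torch.Tensor) -> torch.Tensor:
        return self.bilinear(head, tail).squeeze(-1)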
Example #4
def hotpotqa_preprocess_example():
    start_time = time()
    longformer_tokenizer = get_hotpotqa_longformer_tokenizer(
        model_name=PRE_TAINED_LONFORMER_BASE)
    dev_data, _ = HOTPOT_DevData_Distractor(path=abs_hotpot_path)
    # print('*' * 100)
    # dev_test_data, dev_test_data_res, norm_test_data_res = Hotpot_Test_Data_PreProcess(data=dev_data, tokenizer=longformer_tokenizer)
    # print('Get {} dev-test records, encode records {}, tokenized records {}'.format(dev_test_data.shape[0], dev_test_data_res.shape[0], norm_test_data_res.shape[0]))
    # dev_test_data.to_json(os.path.join(abs_distractor_wiki_path, 'hotpot_test_distractor_wiki_all.json'))
    # dev_test_data_res.to_json(os.path.join(abs_distractor_wiki_path, 'hotpot_test_distractor_wiki_encoded.json'))
    # norm_test_data_res.to_json(os.path.join(abs_distractor_wiki_path, 'hotpot_test_distractor_wiki_tokenized.json'))
    print('*' * 100)
    dev_data, _ = HOTPOT_DevData_Distractor(path=abs_hotpot_path)
    dev_data, dev_data_res, norm_dev_data_res = Hotpot_Train_Dev_Data_Preprocess(
        data=dev_data, tokenizer=longformer_tokenizer)
    print('Get {} dev records, encode records {} tokenized records {}'.format(
        dev_data.shape[0], dev_data_res.shape[0], norm_dev_data_res.shape[0]))
    dev_data.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_dev_distractor_wiki_all.json'))
    dev_data_res.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_dev_distractor_wiki_encoded.json'))
    norm_dev_data_res.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_dev_distractor_wiki_tokenized.json'))
    print('*' * 100)
    train_data, _ = HOTPOT_TrainData(path=abs_hotpot_path)
    train_data, train_data_res, norm_train_data_res = Hotpot_Train_Dev_Data_Preprocess(
        data=train_data, tokenizer=longformer_tokenizer)
    print('Get {} training records, encode records {} tokenized records {}'.
          format(train_data.shape[0], train_data_res.shape[0],
                 norm_train_data_res.shape[0]))
    train_data.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_train_distractor_wiki_all.json'))
    train_data_res.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_train_distractor_wiki_encoded.json'))
    norm_train_data_res.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_train_distractor_wiki_tokenized.json'))
    print('Runtime = {:.4f} seconds'.format(time() - start_time))
    print('*' * 100)
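Every split above is persisted with DataFrame.to_json, so a quick sanity check is to load one back with pandas (a sketch; check_saved_split is a hypothetical helper, and pd.read_json round-trips to_json's default orientation):

import os
import pandas as pd

def check_saved_split(dir_path: str, file_name: str) -> pd.DataFrame:
    """Reload a preprocessed split and report its size (sketch)."""
    df = pd.read_json(os.path.join(dir_path, file_name))
    print('{}: {} records, columns: {}'.format(file_name, df.shape[0],
                                               list(df.columns)))
    return df

# e.g. check_saved_split(abs_distractor_wiki_path, 'hotpot_dev_distractor_wiki_encoded.json')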
Example #5
def consistent_checker():
    tokenizer = get_hotpotqa_longformer_tokenizer()
    encoded_data = loadJSONData(PATH=distractor_wiki_path,
                                json_fileName=dev_processed_data_name)
    orig_data, _ = HOTPOT_DevData_Distractor()
    # orig_data, _ = HOTPOT_TrainData()
    col_names = []
    for col in encoded_data.columns:
        col_names.append(col)
        # print(col)
    #
    # doc_label
    # doc_ans_label
    # doc_num
    # doc_len
    # doc_start
    # doc_end
    # head_idx
    # tail_idx
    # sent_label
    # sent_ans_label
    # sent_num
    # sent_len
    # sent_start
    # sent_end
    # sent2doc
    # sentIndoc
    # doc_sent_num
    # ctx_encode
    # ctx_len
    # attn_mask
    # global_attn
    # token2sent
    # ans_mask
    # ans_pos_tups
    # ans_start
    # ans_end
    # answer_type
    def support_doc_checker(row, orig_row):
        doc_label = row['doc_label']
        answer_type = row['answer_type']
        ans_orig = orig_row['answer']

        print(doc_label)
        doc_label = row['doc_ans_label']
        print(doc_label)
        doc_idxes = [x[0] for x in enumerate(doc_label) if x[1] > 0]
        doc_labels = [doc_label[x] for x in doc_idxes]
        print(doc_labels)
        # flag = (doc_labels[0] == doc_labels[1]) and (doc_labels[0] == 1) and answer_type[0] > 0
        flag = answer_type[0] > 1 and ans_orig.strip() not in {'no'}
        # flag = (doc_labels[0] != doc_labels[1])

        orig_context = orig_row['context']
        ctx_titles = [orig_context[x][0] for x in doc_idxes]
        print('decode support doc title {}'.format(ctx_titles))
        support_fact = orig_row['supporting_facts']
        supptitle = list(set([x[0] for x in support_fact]))
        print('supp doc title {}'.format(supptitle))
        print('*' * 75)
        ctx_encode = row['ctx_encode']
        ctx_encode = torch.LongTensor(ctx_encode)
        doc_start = row['doc_start']
        doc_end = row['doc_end']
        for i in range(len(doc_label)):
            print('decode doc: \n{}'.format(
                tokenizer.decode(ctx_encode[doc_start[i]:(doc_end[i] + 1)])))
            print('orig_doc : \n{}'.format(orig_row['context'][i]))
            print('-' * 75)

        print(tokenizer.decode(ctx_encode[doc_start]))
        print(tokenizer.decode(ctx_encode[doc_end]))
        print(len(doc_label))

        return len(ctx_encode), flag

    def support_sent_checker(row, orig_row):
        sent_label = row['sent_label']
        sent_idxes = [x[0] for x in enumerate(sent_label) if x[1] > 0]
        sent2doc = row['sent2doc']
        sentIndoc = row['sentIndoc']

        sentidxPair = list(zip(sent2doc, sentIndoc))
        suppPair = [sentidxPair[x] for x in sent_idxes]

        orig_context = orig_row['context']
        decode_supp_sent = [(orig_context[x[0]][0], x[1]) for x in suppPair]
        print('decode supp sent {}'.format(decode_supp_sent))
        support_fact = orig_row['supporting_facts']
        print(support_fact)
        print('*' * 75)

        sent_start = row['sent_start']
        sent_end = row['sent_end']
        sent_pair = list(zip(sent_start, sent_end))
        supp_sent_pair = [sent_pair[x] for x in sent_idxes]
        ctx_encode = row['ctx_encode']
        ctx_encode = torch.LongTensor(ctx_encode)
        decode_supp_sent_text = [
            tokenizer.decode(ctx_encode[x[0]:(x[1] + 1)])
            for x in supp_sent_pair
        ]
        print('decode sents:\n{}'.format('\n'.join(decode_supp_sent_text)))
        orig_supp_sent = [orig_context[x[0]][1][x[1]] for x in suppPair]
        print('orig sents:\n{}'.format('\n'.join(orig_supp_sent)))
        print('*' * 75)
        return len(sent_start)

    def answer_checker(row, orig_row):
        orig_answer = orig_row['answer']
        ans_tups = row['ans_pos_tups']
        print(len(ans_tups))
        ctx_encode = row['ctx_encode']
        ctx_encode = torch.LongTensor(ctx_encode)
        ans_start = row['ans_start'][0]
        ans_end = row['ans_end'][0]
        decode_answer = tokenizer.decode(ctx_encode[ans_start:(ans_end + 1)])
        print('ori answer: {}'.format(orig_answer.strip()))
        print('dec answer: {}'.format(decode_answer.strip()))
        em, f1, prec, recall = answer_score(prediction=decode_answer,
                                            gold=orig_answer)
        print('em {} f1 {} prec {} recall {}'.format(em, f1, prec, recall))
        print('*' * 75)
        return em, f1, prec, recall, len(ans_tups)

    def doc_sent_ans_consistent(row, orig_row):
        answer_start = row['ans_start']
        answer_end = row['ans_end']

        sent_start = row['sent_start']
        sent_end = row['sent_end']

        doc_start = row['doc_start']
        doc_end = row['doc_end']

        doc_ans_label = row['doc_ans_label']
        doc_with_ans_idx = [x[0] for x in enumerate(doc_ans_label) if x[1] > 1]
        sent_ans_label = row['sent_ans_label']
        sent_with_ans_idx = [
            x[0] for x in enumerate(sent_ans_label) if x[1] > 1
        ]

        ctx_encode = row['ctx_encode']
        ctx_encode = torch.LongTensor(ctx_encode)

        answer_type = row['answer_type']
        if answer_type[0] == 0:
            ans_doc_start = doc_start[doc_with_ans_idx[0]]
            ans_doc_end = doc_end[doc_with_ans_idx[0]]

            ans_sent_start = sent_start[sent_with_ans_idx[0]]
            ans_sent_end = sent_end[sent_with_ans_idx[0]]
            # print('ans {}\n sent {}\n doc{}'.format((answer_start, answer_end),
            #                                          (ans_sent_start, ans_sent_end),
            #                                          (ans_doc_start, ans_doc_end)))

            flag1 = (answer_start[0] >= ans_sent_start) and (answer_end[0] <=
                                                             ans_sent_end)
            flag2 = (answer_start[0] >= ans_doc_start) and (answer_end[0] <=
                                                            ans_doc_end)
            flag3 = (ans_sent_start >= ans_doc_start) and (ans_sent_end <=
                                                           ans_doc_end)

            # print('ans {} sent {} doc {}'.format(flag1, flag2, flag3))
            # if not (flag1 and flag2 and flag3):
            #     print('wrong preprocess')
            print('ans {}\n sent {}\n doc{}\n'.format(
                tokenizer.decode(ctx_encode[answer_start[0]:(answer_end[0] +
                                                             1)]),
                tokenizer.decode(ctx_encode[ans_sent_start:(ans_sent_end +
                                                            1)]),
                tokenizer.decode(ctx_encode[ans_doc_start:(ans_doc_end + 1)])))

    # em_score = 0.0
    # f1_score = 0.0
    # ans_count_array = []
    # for row_idx, row in encoded_data.iterrows():
    #     # support_doc_checker(row, orig_data.iloc[row_idx])
    #     # support_sent_checker(row, orig_data.iloc[row_idx])
    #     em, f1, prec, recall, ans_count = answer_checker(row, orig_data.iloc[row_idx])
    #     em_score = em_score + em
    #     f1_score = f1_score + f1
    #     ans_count_array.append(ans_count)
    # print('em {} f1 {}'.format(em_score/encoded_data.shape[0], f1_score/encoded_data.shape[0]))
    # occurrences = dict(collections.Counter(ans_count_array))
    # for key, value in occurrences.items():
    #     print('{}\t{}'.format(key, value*1.0/encoded_data.shape[0]))
    # print(occurrences)
    #########################################
    # max_len = 0
    # equal_count = 0
    # for row_idx, row in encoded_data.iterrows():
    #     doc_len, equal_flag = support_doc_checker(row, orig_data.iloc[row_idx])
    #     if equal_flag:
    #         equal_count = equal_count + 1
    #     if max_len < doc_len:
    #         max_len = doc_len
    #     # sent_len = support_sent_checker(row, orig_data.iloc[row_idx])
    #     # if max_len < sent_len:
    #     #     max_len = sent_len
    # print(max_len)
    # print(equal_count, equal_count * 1.0/encoded_data.shape[0])
    #########################################
    for row_idx, row in encoded_data.iterrows():
        doc_sent_ans_consistent(row, orig_data.iloc[row_idx])
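answer_checker above delegates to a project answer_score helper; the standard HotpotQA answer metric is exact match plus token-level F1 over normalized strings. A self-contained sketch of that computation (the repo's answer_score may differ in details):

import re
import string
from collections import Counter

def normalize_answer(s: str) -> str:
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())

def answer_em_f1(prediction: str, gold: str):
    """Return (em, f1, precision, recall) for a single prediction (sketch)."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(gold).split()
    em = float(pred_tokens == gold_tokens)
    num_same = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if num_same == 0:
        return em, 0.0, 0.0, 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    return em, f1, precision, recall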
Example #6
def consistent_checker():
    tokenizer = get_hotpotqa_longformer_tokenizer()
    encoded_data = loadJSONData(PATH=distractor_wiki_path, json_fileName=dev_processed_data_name)
    orig_data_frame, _ = HOTPOT_DevData_Distractor()
    # orig_data_frame, _ = HOTPOT_TrainData()
    encoded_data['e_id'] = range(0, encoded_data.shape[0])
    dev_data_loader = get_val_data_loader(data_frame=encoded_data, tokenizer=tokenizer)
    # dev_data_loader = get_train_data_loader(data_frame=encoded_data, tokenizer=tokenizer)

    def answer_checker(row, orig_row):
        ans_start = row['ans_start'][0]
        ans_end = row['ans_end'][0]
        ctx_encode = row['ctx_encode'][0]
        answer = orig_row['answer']


        print('orig answer: {}\ndeco answer: {}'.format(
            answer, tokenizer.decode(ctx_encode[ans_start:(ans_end + 1)])))
        print('*' * 75)

    def doc_sent_checker(row, orig_row):
        doc_label = row['doc_labels'][0]
        doc_start = row['doc_start'][0]
        doc_end = row['doc_end'][0]
        ctx_encode = row['ctx_encode'][0]
        doc_num = doc_label.shape[0]
        pos_doc_idx = (doc_label > 0).nonzero().detach().tolist()
        supp_docs = orig_row['supporting_facts']
        global_attn = row['ctx_global_mask'][0]
        global_attn_mask_idxes = (global_attn == 1).nonzero(as_tuple=False).squeeze()
        print(global_attn_mask_idxes)

        print('global attn = {}'.format(global_attn))
        print(tokenizer.decode(ctx_encode[doc_start]))

        for doc_idx in pos_doc_idx:
            print('doc {}'.format(tokenizer.decode(ctx_encode[doc_start[doc_idx]:(doc_end[doc_idx] + 1)])))

        print(doc_label, doc_num)
        print('=' * 75)

        sent_label = row['sent_labels'][0]
        sent_start = row['sent_start'][0]
        sent_end = row['sent_end'][0]
        ctx_encode = row['ctx_encode'][0]
        sent_num = sent_label.shape[0]
        print(sent_label)
        pos_sent_idx = (sent_label > 0).nonzero().detach().tolist()
        print(pos_sent_idx)
        # for i in range(sent_num):
        #     print('sent end {}'.format(tokenizer.decode(ctx_encode[sent_end[i]:(sent_end[i] + 1)])))

        print('+' * 75)
        s2d_map = row['s2d_map'][0]
        print('sent_2_doc {} {}'.format(s2d_map, row['sent_lens'][0].shape))
        sInd_map = row['sInd_map'][0]
        # print(sInd_map)
        print('+' * 75)

        doc_head = row['head_idx']
        doc_tail = row['tail_idx']
        print(doc_head)
        print(doc_tail)

        context = orig_row['context']

        for sent_idx in pos_sent_idx:
            print('Sent\n {}'.format(tokenizer.decode(ctx_encode[sent_start[sent_idx]:(sent_end[sent_idx] + 1)])))
            # print('Pair {}'.format((s2d_map[sent_idx], sInd_map[sent_idx])))
            doc_idx = s2d_map[sent_idx][0].detach().item()
            sent_idx = sInd_map[sent_idx][0].detach().item()
            print('doc idx {} sent idx {}'.format(doc_idx, sent_idx))
            print('Sent pair\n {}'.format((context[doc_idx][0], sent_idx, context[doc_idx][1][sent_idx])))
        print(supp_docs)
        print('*' * 75)

    for batch_idx, batch in enumerate(dev_data_loader):
        row = batch
        orig_row = orig_data_frame.iloc[batch_idx]
        doc_sent_checker(row, orig_row)

        if batch_idx >= 100:
            break
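Note that the loop reads every batch field as row[...][0] and pairs batch_idx with orig_data_frame.iloc[batch_idx]; both only hold if the loader yields dict-style batches of size one in the original row order. A hypothetical stand-in illustrating that assumption (the real get_val_data_loader takes data_frame and tokenizer):

from torch.utils.data import DataLoader

def make_val_loader(dataset, collate_fn):
    """batch_size=1 and shuffle=False keep batch_idx aligned with the
    original DataFrame rows, so each field arrives as a 1-element batch."""
    return DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)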