Example #1
def hotpotqa_preprocess_example():
    start_time = time()
    tokenizer = get_hotpotqa_longformer_tokenizer(
        model_name=PRE_TAINED_LONFORMER_BASE)
    longformer_tokenizer = LongformerQATensorizer(tokenizer=tokenizer,
                                                  max_length=-1)
    dev_data, _ = HOTPOT_DevData_Distractor()
    print('*' * 75)
    dev_test_data = Hotpot_Test_Data_PreProcess(data=dev_data,
                                                tokenizer=longformer_tokenizer)
    print('Loaded {} dev-test records'.format(dev_test_data.shape[0]))
    dev_test_data.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_test_distractor_wiki_tokenized.json'))
    print('*' * 75)
    dev_data, _ = HOTPOT_DevData_Distractor()
    dev_data = Hotpot_Dev_Data_Preprocess(data=dev_data,
                                          tokenizer=longformer_tokenizer)
    print('Loaded {} dev records'.format(dev_data.shape[0]))
    dev_data.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_dev_distractor_wiki_tokenized.json'))
    print('*' * 75)
    train_data, _ = HOTPOT_TrainData()
    train_data = Hotpot_Train_Data_Preprocess(data=train_data,
                                              tokenizer=longformer_tokenizer)
    print('Loaded {} training records'.format(train_data.shape[0]))
    train_data.to_json(
        os.path.join(abs_distractor_wiki_path,
                     'hotpot_train_distractor_wiki_tokenized.json'))
    print('Runtime = {:.4f} seconds'.format(time() - start_time))
    print('*' * 75)
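
# A minimal verification sketch (not part of the original repo): after the
# preprocessing above finishes, the three tokenized JSON files can be read
# back with pandas to confirm record counts. Assumes the same
# abs_distractor_wiki_path used above is in scope.
def verify_tokenized_files_sketch():
    import os
    import pandas as pd
    for name in ['hotpot_test_distractor_wiki_tokenized.json',
                 'hotpot_dev_distractor_wiki_tokenized.json',
                 'hotpot_train_distractor_wiki_tokenized.json']:
        df = pd.read_json(os.path.join(abs_distractor_wiki_path, name))
        print('{}: {} records'.format(name, df.shape[0]))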
def test_data_loader_checker():
    file_path = '../data/hotpotqa/distractor_qa'
    dev_file_name = 'hotpot_dev_distractor_wiki_tokenized.json'
    from torch.utils.data import DataLoader
    batch_size = 1
    data_frame = read_train_dev_data_frame(PATH=file_path,
                                           json_fileName=dev_file_name)
    longtokenizer = get_hotpotqa_longformer_tokenizer()
    hotpot_tensorizer = LongformerQATensorizer(tokenizer=longtokenizer,
                                               max_length=4096)
    start_time = time()
    test_dataloader = DataLoader(HotpotDevDataset(
        data_frame=data_frame, hotpot_tensorizer=hotpot_tensorizer),
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=1,
                                 collate_fn=HotpotDevDataset.collate_fn)
    for batch_idx, sample in enumerate(test_dataloader):
        sd_mask = sample['sd_mask']
        # print(sd_mask)
        # print(sd_mask[0])
        print(sample['doc_lens'])
        print(sample['sent_lens'])

        ss_mask = sample['ss_mask']
        # print(ss_mask[0].detach().tolist())
        print(ss_mask.shape)
        print(ss_mask[0].sum(dim=1))
        print(sd_mask.shape)
        break
    print('Runtime = {:.4f} seconds'.format(time() - start_time))
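
# The sd_mask / ss_mask tensors inspected above are padding-style masks. A
# self-contained toy sketch of the standard length-to-mask construction
# (an illustration of the general technique, not the dataset's exact code):
def length_to_mask_sketch():
    import torch
    lens = torch.tensor([3, 1, 2])  # e.g. sentences per document
    max_len = int(lens.max())
    # Each row is True up to that row's length, False over the padding.
    mask = torch.arange(max_len).unsqueeze(0) < lens.unsqueeze(1)
    print(mask)  # [[True, True, True], [True, False, False], [True, True, False]]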
def get_model(args):
    start_time = time()
    tokenizer = get_hotpotqa_longformer_tokenizer(
        model_name=args.pretrained_cfg_name)
    longEncoder = LongformerEncoder.init_encoder(
        cfg_name=args.pretrained_cfg_name,
        projection_dim=args.project_dim,
        hidden_dropout=args.input_drop,
        attn_dropout=args.attn_drop,
        seq_project=args.seq_project)
    longEncoder.resize_token_embeddings(len(tokenizer))
    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    if args.frozen_layer_num > 0:
        modules = [
            longEncoder.embeddings,
            *longEncoder.encoder.layer[:args.frozen_layer_num]
        ]
        for module in modules:
            for param in module.parameters():
                param.requires_grad = False
        logging.info('Froze the first {} layers'.format(
            args.frozen_layer_num))
    logging.info(
        'Loading encoder takes {:.4f} seconds'.format(time() - start_time))
    model = LongformerHotPotQAModel(longformer=longEncoder,
                                    num_labels=args.num_labels)
    logging.info(
        'Constructing reasonModel completed in {:.4f} seconds'.format(
            time() - start_time))
    return model
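
# Sanity-check sketch for the freezing logic above: count trainable versus
# total parameters on a toy module. The same counting works for the
# LongformerEncoder once its embeddings and first layers are frozen.
def count_trainable_params_sketch():
    import torch.nn as nn
    toy = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    for param in toy[0].parameters():
        param.requires_grad = False  # freeze the first layer, as get_model does
    trainable = sum(p.numel() for p in toy.parameters() if p.requires_grad)
    total = sum(p.numel() for p in toy.parameters())
    print('{}/{} parameters trainable'.format(trainable, total))  # 72/144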
def data_loader_checker():
    file_path = '../data/hotpotqa/distractor_qa'
    dev_file_name = 'hotpot_dev_distractor_wiki_tokenized.json'
    from torch.utils.data import DataLoader
    batch_size = 6

    data_frame = read_train_dev_data_frame(PATH=file_path,
                                           json_fileName=dev_file_name)
    for col in data_frame.columns:
        print(col)
    longtokenizer = get_hotpotqa_longformer_tokenizer()
    hotpot_tensorizer = LongformerQATensorizer(tokenizer=longtokenizer,
                                               max_length=4096)
    start_time = time()
    dev_dataloader = DataLoader(HotpotDevDataset(
        data_frame=data_frame, hotpot_tensorizer=hotpot_tensorizer),
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=6,
                                collate_fn=HotpotDevDataset.collate_fn)

    # Iterate over the full loader once just to time a complete pass.
    for batch_idx, sample in enumerate(dev_dataloader):
        _ = sample['doc_start']  # touch a field so the batch is materialized
        # print(sample['doc_start'].shape)
        # print(sample['sent_start'].shape)
    print('Runtime = {:.4f} seconds'.format(time() - start_time))
Example #5
def performance_collection(folder_name):
    print('Loading tokenizer')
    tokenizer = get_hotpotqa_longformer_tokenizer(
        model_name=PRE_TAINED_LONFORMER_BASE, do_lower_case=True)
    json_file_names = get_all_json_files(file_path=folder_name)
    json_file_names = [x for x in json_file_names if x != 'config.json']
    print('{} json files have been found'.format(len(json_file_names)))
    max_sp_sent_f1 = 0
    max_metric_res = None
    max_json_file_name = None
    for json_file_name in json_file_names:
        # 'config.json' was already filtered out above, so no inner guard.
        data_frame_i = load_data_frame_align_with_dev(
            file_path=folder_name, json_fileName=json_file_name)
        metrics_i = convert2leadBoard(data=data_frame_i, tokenizer=tokenizer)
        if max_sp_sent_f1 < metrics_i['sp_f1']:
            max_sp_sent_f1 = metrics_i['sp_f1']
            max_metric_res = metrics_i
            max_json_file_name = json_file_name
        print_metrics(name=json_file_name, metrics=metrics_i)
        print('*' * 75)
    print('+' * 75)
    print_metrics(name=max_json_file_name, metrics=max_metric_res)
    print('+' * 75)
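
# The loop above is a best-by-metric scan; the same selection can be written
# with max() over a dict. A toy sketch with hypothetical file names and scores:
def best_checkpoint_sketch():
    runs = {'epoch_1.json': {'sp_f1': 0.71}, 'epoch_2.json': {'sp_f1': 0.78}}
    best = max(runs, key=lambda name: runs[name]['sp_f1'])
    print(best, runs[best])  # epoch_2.json {'sp_f1': 0.78}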
Example #6
def get_model(args):
    start_time = time()
    tokenizer = get_hotpotqa_longformer_tokenizer(
        model_name=args.pretrained_cfg_name)
    longEncoder = LongformerEncoder.init_encoder(
        cfg_name=args.pretrained_cfg_name,
        projection_dim=args.project_dim,
        hidden_dropout=args.input_drop,
        attn_dropout=args.attn_drop,
        seq_project=args.seq_project)
    longEncoder.resize_token_embeddings(len(tokenizer))
    model = LongformerHotPotQAModel(longformer=longEncoder,
                                    num_labels=args.num_labels,
                                    args=args)
    logging.info(
        'Constructing reasonModel completed in {:.4f} seconds'.format(
            time() - start_time))
    return model
Example #7
def get_test_data_loader(args):
    data_frame = read_train_dev_data_frame(file_path=args.data_path,
                                           json_fileName=args.dev_data_name)
    batch_size = args.test_batch_size
    data_size = data_frame.shape[0]
    tokenizer = get_hotpotqa_longformer_tokenizer(
        model_name=args.pretrained_cfg_name, do_lower_case=True)
    hotpot_tensorizer = LongformerQATensorizer(tokenizer=tokenizer,
                                               max_length=args.max_ctx_len)
    dataloader = DataLoader(HotpotDevDataset(
        data_frame=data_frame,
        hotpot_tensorizer=hotpot_tensorizer,
        max_sent_num=args.max_sent_num,
        global_mask_type=args.global_mask_type),
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=max(1, args.cpu_num // 2),
                            collate_fn=HotpotDevDataset.collate_fn)
    return dataloader, data_size
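
# Hedged usage sketch: get_test_data_loader only reads attribute-style args,
# so an argparse.Namespace is enough to drive it. All values below are
# illustrative assumptions, not the repo's defaults.
def test_loader_usage_sketch():
    from argparse import Namespace
    args = Namespace(data_path='../data/hotpotqa/distractor_qa',
                     dev_data_name='hotpot_dev_distractor_wiki_tokenized.json',
                     test_batch_size=8,
                     pretrained_cfg_name=PRE_TAINED_LONFORMER_BASE,
                     max_ctx_len=4096,
                     max_sent_num=150,
                     global_mask_type='query',
                     cpu_num=8)
    loader, size = get_test_data_loader(args=args)
    print('dev set: {} examples, {} batches'.format(size, len(loader)))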
def get_train_data_loader(args):
    data_frame = read_train_dev_data_frame(file_path=args.data_path,
                                           json_fileName=args.train_data_name)
    batch_size = args.batch_size
    #####################################################
    training_data_shuffle = args.training_shuffle == 1
    #####################################################
    data_size = data_frame.shape[0]
    if args.train_data_filtered == 1:
        data_frame = data_frame[data_frame['level'] != 'easy']
        logging.info('Filtered out easy cases: {} -> {} records'.format(
            data_size, data_frame.shape[0]))
    elif args.train_data_filtered == 2:
        data_frame = data_frame[data_frame['level'] == 'hard']
        logging.info(
            'Filtered out easy and medium cases: {} -> {} records'.format(
                data_size, data_frame.shape[0]))
    else:
        logging.info('Using all {} training records'.format(data_size))

    data_size = data_frame.shape[0]
    tokenizer = get_hotpotqa_longformer_tokenizer(
        model_name=args.pretrained_cfg_name, do_lower_case=True)
    hotpot_tensorizer = LongformerQATensorizer(tokenizer=tokenizer,
                                               max_length=args.max_ctx_len)
    dataloader = DataLoader(HotpotTrainDataset(
        data_frame=data_frame,
        hotpot_tensorizer=hotpot_tensorizer,
        max_doc_num=args.max_doc_num,
        max_sent_num=args.max_sent_num,
        global_mask_type=args.global_mask_type,
        training_shuffle=training_data_shuffle),
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=max(1, args.cpu_num // 2),
                            collate_fn=HotpotTrainDataset.collate_fn)
    return dataloader, data_size
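
# Toy check of the train_data_filtered switch above on a throwaway frame:
# mode 1 drops the 'easy' rows, mode 2 keeps only the 'hard' rows.
def level_filter_sketch():
    import pandas as pd
    df = pd.DataFrame({'level': ['easy', 'medium', 'hard', 'hard']})
    print(df[df['level'] != 'easy'].shape[0])   # 3  (filter mode 1)
    print(df[df['level'] == 'hard'].shape[0])   # 2  (filter mode 2)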
Example #9
def main(model_args):
    args = get_config(PATH=model_args.model_path,
                      config_json_name=model_args.model_config_name)
    args.check_point = model_args.model_name
    args.data_path = model_args.data_path
    args.test_batch_size = model_args.test_batch_size
    args.doc_threshold = model_args.doc_threshold
    args.save_path = model_args.model_path
    args.cuda = torch.cuda.is_available()
    ###################
    if args.data_path is None:
        raise ValueError('data_path must be provided.')
    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    set_logger(args)
    ########+++++++++++++++++++++++++++++
    abs_path = os.path.abspath(args.data_path)
    args.data_path = abs_path
    ########+++++++++++++++++++++++++++++
    # Write logs to checkpoint and console
    if args.cuda:
        if args.gpu_num > 1:
            device_ids, used_memory = gpu_setting(args.gpu_num)
        else:
            device_ids, used_memory = gpu_setting()
        if used_memory > 100:
            logging.info('Using memory = {}'.format(used_memory))
        if device_ids is not None:
            if len(device_ids) > args.gpu_num:
                device_ids = device_ids[:args.gpu_num]
            device = torch.device('cuda:{}'.format(device_ids[0]))
        else:
            device = torch.device('cuda:0')
        logging.info('Set CUDA with device idxes = {}'.format(device_ids))
        logging.info('CUDA device: {}'.format(device))
        logging.info('GPU setting')
    else:
        device_ids = None
        device = torch.device('cpu')
        logging.info('CPU setting')
    ########+++++++++++++++++++++++++++++
    logging.info('Loading development data...')
    test_data_loader, _ = get_test_data_loader(args=args)
    logging.info('Loading data completed')
    logging.info('*' * 75)
    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    logging.info('Loading Model...')
    model = get_model(args=args).to(device)
    ##+++++++++++
    model_path = args.save_path
    model_file_name = args.check_point
    hotpot_qa_model_name = os.path.join(model_path, model_file_name)
    model = load_model(model=model, PATH=hotpot_qa_model_name)
    model = model.to(device)
    if device_ids is not None:
        if len(device_ids) > 1:
            model = DataParallel(model,
                                 device_ids=device_ids,
                                 output_device=device)
            logging.info('Data Parallel model setting')
    ##+++++++++++
    logging.info('Model Parameter Configuration:')
    for name, param in model.named_parameters():
        logging.info('Parameter {}: {}, requires_grad = {}'.format(
            name, str(param.size()), str(param.requires_grad)))
    logging.info('*' * 75)
    logging.info("Model hype-parameter information...")
    for key, value in vars(args).items():
        logging.info('Hype-parameter\t{} = {}'.format(key, value))
    logging.info('*' * 75)
    logging.info("Model hype-parameter information...")
    for key, value in vars(model_args).items():
        logging.info('Hype-parameter\t{} = {}'.format(key, value))
    logging.info('*' * 75)
    logging.info('projection_dim = {}'.format(args.project_dim))
    logging.info('Multi-task encoding')
    logging.info('*' * 75)
    logging.info('Loading tokenizer')
    tokenizer = get_hotpotqa_longformer_tokenizer()
    logging.info('*' * 75)
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    # logging.info('Multi-task encoding')
    # metric_dict = multi_task_decoder(model=model, device=device, test_data_loader=test_data_loader, args=args)
    # answer_type_acc = metric_dict['answer_type_acc']
    # logging.info('*' * 75)
    # logging.info('Answer type prediction accuracy: {}'.format(answer_type_acc))
    # logging.info('*' * 75)
    # for key, value in metric_dict.items():
    #     if key.endswith('metrics'):
    #         logging.info('{} prediction'.format(key))
    #         log_metrics('Valid', 'final', value)
    # logging.info('*' * 75)
    # ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    # dev_data_frame = metric_dict['res_dataframe']
    # ##################################################
    # leadboard_metric, res_data_frame = convert2leadBoard(data=dev_data_frame, tokenizer=tokenizer)
    # ##=================================================
    # logging.info('*' * 75)
    # log_metrics('Evaluation', step='leadboard', metrics=leadboard_metric)
    # logging.info('*' * 75)
    # date_time_str = get_date_time()
    # dev_result_name = os.path.join(args.save_path,
    #                                date_time_str + '_mt_evaluation.json')
    # res_data_frame.to_json(dev_result_name, orient='records')
    # logging.info('Saving {} record results to {}'.format(res_data_frame.shape, dev_result_name))
    # logging.info('*' * 75)
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    logging.info('Hierarchical encoding')
    metric_dict = hierartical_decoder(model=model,
                                      device=device,
                                      test_data_loader=test_data_loader,
                                      doc_topk=model_args.doc_topk,
                                      args=args)
    answer_type_acc = metric_dict['answer_type_acc']
    logging.info('*' * 75)
    logging.info('Answer type prediction accuracy: {}'.format(answer_type_acc))
    logging.info('*' * 75)
    for key, value in metric_dict.items():
        if key.endswith('metrics'):
            logging.info('{} prediction'.format(key))
            log_metrics('Valid', 'final', value)
    logging.info('*' * 75)
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    topk_dev_data_frame = metric_dict['topk_dataframe']
    ##################################################
    topk_leadboard_metric, topk_res_data_frame = convert2leadBoard(
        data=topk_dev_data_frame, tokenizer=tokenizer)
    ##=================================================
    log_metrics('Topk Evaluation',
                step='leadboard',
                metrics=topk_leadboard_metric)
    date_time_str = get_date_time()
    topk_dev_result_name = os.path.join(
        args.save_path, date_time_str + '_topk_hi_evaluation.json')
    topk_res_data_frame.to_json(topk_dev_result_name, orient='records')
    logging.info('Saving {} record results to {}'.format(
        topk_res_data_frame.shape, topk_dev_result_name))
    logging.info('*' * 75)
    ##=================================================
    thresh_dev_data_frame = metric_dict['thresh_dataframe']
    ##################################################
    thresh_leadboard_metric, thresh_res_data_frame = convert2leadBoard(
        data=thresh_dev_data_frame, tokenizer=tokenizer)
    log_metrics('Thresh Evaluation',
                step='leadboard',
                metrics=thresh_leadboard_metric)
    ##=================================================
    date_time_str = get_date_time()
    thresh_dev_result_name = os.path.join(
        args.save_path, date_time_str + '_thresh_hi_evaluation.json')
    thresh_res_data_frame.to_json(thresh_dev_result_name, orient='records')
    logging.info('Saving {} record results to {}'.format(
        thresh_res_data_frame.shape, thresh_dev_result_name))
    logging.info('*' * 75)
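
# The CUDA branch in main() reduces to the usual PyTorch device idiom when a
# single GPU is enough; a minimal sketch of that fallback:
def device_selection_sketch():
    import torch
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('running on {}'.format(device))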
def answer_consistent_checker():
    file_path = '../data/hotpotqa/distractor_qa'
    dev_file_name = 'hotpot_dev_distractor_wiki_tokenized.json'
    from torch.utils.data import DataLoader
    batch_size = 1

    data_frame = read_train_dev_data_frame(PATH=file_path,
                                           json_fileName=dev_file_name)
    print(data_frame['answer_len'].max())
    # for col in data_frame.columns:
    #     print(col)
    longtokenizer = get_hotpotqa_longformer_tokenizer()
    hotpot_tensorizer = LongformerQATensorizer(tokenizer=longtokenizer,
                                               max_length=4096)
    start_time = time()
    dev_dataloader = DataLoader(HotpotTrainDataset(
        data_frame=data_frame, hotpot_tensorizer=hotpot_tensorizer),
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=14,
                                collate_fn=HotpotTrainDataset.collate_fn)
    max_seq_len = 0
    average_seq_len = 0
    count = 0
    max_answer_len = 0
    for batch_idx, sample in enumerate(dev_dataloader):
        # if batch_idx % 1000 == 0:
        #     print(batch_idx)
        ctx_encode = sample['ctx_encode']
        ctx_encode_lens = sample['doc_lens']

        answer_start = sample['ans_start'].squeeze(dim=-1)
        answer_end = sample['ans_end'].squeeze(dim=-1)
        doc_start = sample['doc_start'].squeeze(dim=-1)
        doc_end = sample['doc_end'].squeeze(dim=-1)
        sent_start = sample['sent_start'].squeeze(dim=-1)
        batch_size = ctx_encode.shape[0]
        for id in range(batch_size):
            # doc_token_num = ctx_encode_lens[id].sum().data.item()
            doc_token_num = doc_end[id].detach().tolist()[-1]
            if max_seq_len < doc_token_num:
                max_seq_len = doc_token_num
            average_seq_len = average_seq_len + doc_token_num
            count = count + 1
            doc_start_i = doc_start[id]
            sent_start_i = sent_start[id]
            ctx_encode_i = ctx_encode[id]
            ans_start_i = answer_start[id].data.item()
            ans_end_i = answer_end[id].data.item()
            if max_answer_len < (ans_end_i - ans_start_i) + 1:
                max_answer_len = (ans_end_i - ans_start_i) + 1
            decode_answer = longtokenizer.decode(
                ctx_encode_i[ans_start_i:(ans_end_i + 1)])

            # print('{}\t{}'.format(batch_idx, decode_answer))
            # if '<p>' in decode_answer or '<d>' in decode_answer or '<q>' in decode_answer or '</q>' in decode_answer:
            #     print('index = {}'.format(batch_idx))
            #     print('decode answer {}'.format(decode_answer))
            #     print('Decode Query {}'.format(longtokenizer.decode(ctx_encode_i[:doc_start_i[0]])))
            # print('decode answer {}'.format(decode_answer))

    print('max seq len: {} average seq len: {}, {}'.format(
        max_seq_len, average_seq_len / count, count))
    print('max answer len: {}'.format(max_answer_len))
    return
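
# The span decoding above uses inclusive end indices, hence the recurring
# "+ 1" when slicing. A toy tensor makes the convention explicit:
def inclusive_span_sketch():
    import torch
    ctx_encode_i = torch.tensor([0, 713, 16, 5, 864, 2])
    ans_start_i, ans_end_i = 3, 4                      # inclusive bounds
    span = ctx_encode_i[ans_start_i:(ans_end_i + 1)]   # tensor([5, 864])
    print(span, 'length', (ans_end_i - ans_start_i) + 1)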
def data_consistent_checker(train=True):
    file_path = '../data/hotpotqa/distractor_qa'
    from torch.utils.data import DataLoader
    batch_size = 2
    longtokenizer = get_hotpotqa_longformer_tokenizer()
    hotpot_tensorizer = LongformerQATensorizer(tokenizer=longtokenizer,
                                               max_length=4096)
    if train:
        dev_file_name = 'hotpot_train_distractor_wiki_tokenized.json'
        data_frame = read_train_dev_data_frame(PATH=file_path,
                                               json_fileName=dev_file_name)
        start_time = time()
        dev_dataloader = DataLoader(HotpotTrainDataset(
            data_frame=data_frame, hotpot_tensorizer=hotpot_tensorizer),
                                    batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=1,
                                    collate_fn=HotpotTrainDataset.collate_fn)
    else:
        dev_file_name = 'hotpot_dev_distractor_wiki_tokenized.json'
        data_frame = read_train_dev_data_frame(PATH=file_path,
                                               json_fileName=dev_file_name)
        start_time = time()
        dev_dataloader = DataLoader(HotpotDevDataset(
            data_frame=data_frame, hotpot_tensorizer=hotpot_tensorizer),
                                    batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=1,
                                    collate_fn=HotpotDevDataset.collate_fn)

    batch_data_frame = data_frame.head(batch_size)
    print(batch_data_frame.shape)
    for idx, row in batch_data_frame.iterrows():
        context = row['context']
        supp_fact_filtered = row['supp_facts_filtered']
        # for supp, sen_idx in supp_fact_filtered:
        #     print('Support doc: {}, sent id: {}'.format(supp, sen_idx))
        print('Query {}'.format(row['question']))
        for doc_idx, doc in enumerate(context):
            # print('doc {}: title = {} \n text = {}'.format(doc_idx + 1, doc[0], '\n'.join(doc[1])))
            print('doc {}: title = {}'.format(doc_idx + 1, doc[0]))
            for supp, sen_idx in supp_fact_filtered:
                if doc[0] == supp:
                    print('supp fact doc {}: sent = {} text = {}'.format(
                        doc_idx, sen_idx, doc[1][sen_idx]))
        print('*' * 70)
        print('Original answer = {}'.format(row['norm_answer']))
        print('=' * 70)
    print('+' * 70)
    print('\n' * 3)

    for batch_idx, sample in enumerate(dev_dataloader):
        # for key, value in sample.items():
        #     print(key)
        ctx_encode = sample['ctx_encode']
        ctx_marker_mask = sample['marker']
        global_atten = sample['ctx_global_mask']
        atten_mask = sample['ctx_attn_mask']
        sup_sent_labels = sample['sent_labels'].squeeze(dim=-1)
        sent2doc_map = sample['s2d_map']
        sentIndoc_map = sample['sInd_map']
        sent_start = sample['sent_start']
        sent_end = sample['sent_end']
        # print('sent num = {}'.format(sent_end.shape[1]))
        answer_start = sample['ans_start'].squeeze(dim=-1)
        answer_end = sample['ans_end'].squeeze(dim=-1)
        doc_start = sample['doc_start'].squeeze(dim=-1)
        token2sent_map = sample['t2s_map'].squeeze(dim=-1)
        if train:
            head_idx = sample['head_idx'].squeeze(dim=-1)
            tail_idx = sample['tail_idx'].squeeze(dim=-1)

        for id in range(batch_size):
            ctx_marker_i = ctx_marker_mask[id]
            supp_idxes = (sup_sent_labels[id] > 0).nonzero().squeeze()
            doc_idxes = sent2doc_map[id][supp_idxes].detach().tolist()
            sent_idxes = sentIndoc_map[id][supp_idxes].detach().tolist()
            doc_start_i = doc_start[id]
            doc_sent_pairs = list(zip(doc_idxes, sent_idxes))
            sent_start_i = sent_start[id]
            sent_end_i = sent_end[id]
            ctx_encode_i = ctx_encode[id]
            token2sent_map_i = token2sent_map[id]

            # print('token to sentence {}'.format(token2sent_map_i.max()))
            max_sent_num = token2sent_map_i.max().data.item()
            for sent_id in range(max_sent_num):
                sent_token_idxes = (
                    token2sent_map_i == sent_id).nonzero().squeeze()
                print('sent {} text = {}'.format(
                    sent_id,
                    longtokenizer.decode(ctx_encode_i[sent_token_idxes])))

            if train:
                print('head doc idx = {}'.format(head_idx[id]))
                print('tail doc idx = {}'.format(tail_idx[id]))

            global_atten_i = global_atten[id]
            global_atten_i_indexes = (global_atten_i > 0).nonzero().squeeze()
            global_atten_text = longtokenizer.decode(
                ctx_encode_i[global_atten_i_indexes])
            print('global attention text: {}'.format(global_atten_text))

            atten_i = atten_mask[id]
            atten_i_indexes = (atten_i > 0).nonzero().squeeze()
            atten_text = longtokenizer.decode(ctx_encode_i[atten_i_indexes])
            # print('attention text: {}'.format(atten_text))
            print('x' * 75)
            # print('decode text: {}'.format(longtokenizer.decode(ctx_encode_i)))

            ans_start_i = answer_start[id].data.item()
            ans_end_i = answer_end[id].data.item()
            #
            print('Decode Query {}'.format(
                longtokenizer.decode(ctx_encode_i[:doc_start_i[0]])))
            print('Decode Answer {}'.format(
                longtokenizer.decode(ctx_encode_i[ans_start_i:(ans_end_i +
                                                               1)])))

            ctx_marker_i_indexes = (ctx_marker_i > 0).nonzero().squeeze()
            print('Decode marker text = {}'.format(
                longtokenizer.decode(ctx_encode_i[ctx_marker_i_indexes])))
            for ss_id, x in enumerate(doc_sent_pairs):
                supp_idx = supp_idxes[ss_id]
                start_i, end_i = sent_start_i[supp_idx], sent_end_i[supp_idx] + 1
                print('doc {}, sent {}, text {}'.format(
                    x[0], x[1],
                    longtokenizer.decode(ctx_encode_i[start_i:end_i])))
            print('=' * 70)
        break
    return
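
# The checker above repeatedly turns a binary label vector into indices with
# (labels > 0).nonzero().squeeze(); a toy sketch of that pattern:
def label_to_index_sketch():
    import torch
    sup_sent_labels = torch.tensor([0, 1, 0, 2])
    supp_idxes = (sup_sent_labels > 0).nonzero().squeeze()
    print(supp_idxes)  # tensor([1, 3])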
def data_loader_consistent_checker(train=True):
    file_path = '../data/hotpotqa/distractor_qa'
    if train:
        dev_file_name = 'hotpot_train_distractor_wiki_tokenized.json'
    else:
        dev_file_name = 'hotpot_dev_distractor_wiki_tokenized.json'
    data_frame = read_train_dev_data_frame(PATH=file_path,
                                           json_fileName=dev_file_name)
    longtokenizer = get_hotpotqa_longformer_tokenizer()
    hotpot_tensorizer = LongformerQATensorizer(tokenizer=longtokenizer,
                                               max_length=4096)
    start_time = time()
    from torch.utils.data import DataLoader
    batch_size = 1
    if train:
        dev_dataloader = DataLoader(HotpotTrainDataset(
            data_frame=data_frame, hotpot_tensorizer=hotpot_tensorizer),
                                    batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=1,
                                    collate_fn=HotpotTrainDataset.collate_fn)
    else:
        dev_dataloader = DataLoader(HotpotDevDataset(
            data_frame=data_frame, hotpot_tensorizer=hotpot_tensorizer),
                                    batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=1,
                                    collate_fn=HotpotDevDataset.collate_fn)

    head_rows = data_frame.head(batch_size)
    print(type(head_rows))
    for idx, row in head_rows.iterrows():
        context = row['context']
        supp_fact_filtered = row['supp_facts_filtered']
        for supp, sen_idx in supp_fact_filtered:
            print('Support doc: {}, sent id: {}'.format(supp, sen_idx))
            print('-' * 70)
        print()
        print('Query {}'.format(row['question']))
        for doc_idx, doc in enumerate(context):
            print('doc {}: title = {} \n text = {}'.format(
                doc_idx + 1, doc[0], ' '.join(doc[1])))
            print('-' * 70)
        print('*' * 70)
        print()
        print('Original answer = {}'.format(row['norm_answer']))
        print('=' * 70)
    print('+' * 70)
    print('\n' * 5)
    for batch_idx, sample in enumerate(dev_dataloader):
        ctx_encode = sample['ctx_encode']
        doc_start = sample['doc_start'].squeeze(dim=-1)
        sent_start = sample['sent_start'].squeeze(dim=-1)
        answer_start = sample['ans_start'].squeeze(dim=-1)
        answer_end = sample['ans_end'].squeeze(dim=-1)
        if train:
            head_idx = sample['head_idx'].squeeze(dim=-1)
            tail_idx = sample['tail_idx'].squeeze(dim=-1)
        sent_lens = sample['sent_lens'].squeeze(dim=-1)
        attention = sample['ctx_attn_mask'].squeeze(dim=-1)
        global_attention = sample['ctx_global_mask']
        print('global attention {}'.format(global_attention))
        marker = sample['marker'].squeeze(dim=-1)

        doc_num = doc_start.shape[1]
        print('doc num: {}'.format(doc_start.shape))
        print('marker {}'.format(marker))
        print('marker shape {}'.format(marker.shape))

        for idx in range(ctx_encode.shape[0]):
            ctx_i = ctx_encode[idx]
            marker_i = marker[idx]

            marker_idx = marker_i.nonzero().squeeze()
            print('marker text {}'.format(
                longtokenizer.decode(ctx_i[marker_idx])))
            print('*' * 75)
            attention_i = attention[idx]
            attn_idx = (attention_i == 1).nonzero().squeeze()
            print('attn text {}'.format(longtokenizer.decode(ctx_i[attn_idx])))
            sent_start_i = sent_start[idx]
            doc_start_i = doc_start[idx]
            if train:
                head_i = head_idx[idx].data.item()
                tail_i = tail_idx[idx].data.item()
            ans_start_i = answer_start[idx].data.item()
            ans_end_i = answer_end[idx].data.item()

            print('Decode Query {}'.format(
                longtokenizer.decode(ctx_i[:doc_start_i[0]])))
            print('*' * 75)
            print('Decoded answer = {}'.format(
                hotpot_tensorizer.to_string(ctx_i[ans_start_i:(ans_end_i +
                                                               1)])))
            print('*' * 75)
            # print(ans_start_i)

            doc_marker = longtokenizer.decode(ctx_i[doc_start_i])
            print('doc_marker: {}'.format(doc_marker))

            sent_marker = longtokenizer.decode(ctx_i[sent_start_i])
            print('doc: {}\nsent: {}\n{}\n{}'.format(doc_marker, sent_marker,
                                                     sent_start_i.shape,
                                                     sent_lens[idx]))
            print('*' * 75)

            for k in range(doc_num):
                if k < doc_num - 1:
                    # doc_k = hotpot_tensorizer.to_string(ctx_i[doc_start_i[k]:doc_start_i[k+1]])
                    doc_k = longtokenizer.decode(
                        ctx_i[doc_start_i[k]:doc_start_i[k + 1]])
                else:
                    # doc_k = hotpot_tensorizer.to_string(ctx_i[doc_start_i[k]:])
                    doc_k = longtokenizer.decode(ctx_i[doc_start_i[k]:])
                # print(doc_marker)
                print('Supp doc {}: text = {}'.format(k + 1, doc_k))
                if train:
                    if k == head_i:
                        print('=' * 70)
                        print('Head positive doc {}: text: {}'.format(
                            head_i + 1, doc_k))
                        print('=' * 70)
                    if k == tail_i:
                        print('=' * 70)
                        print('Tail positive doc {}: text: {}'.format(
                            tail_i + 1, doc_k))
                        print('=' * 70)
                    print('-' * 70)
            print('*' * 70)
            print()
        # print(ctx_encode.shape)
        break
    print('Runtime = {:.4f} seconds'.format(time() - start_time))
Example #13
def main(args):
    set_seeds(args.rand_seed)
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/valid/test modes must be chosen.')

    if args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained reasonModel?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    ########+++++++++++++++++++++++++++++
    abs_path = os.path.abspath(args.data_path)
    args.data_path = abs_path
    ########+++++++++++++++++++++++++++++
    # Write logs to checkpoint and console
    set_logger(args)
    if args.cuda:
        if args.do_debug:
            if args.gpu_num > 1:
                device_ids, used_memory = gpu_setting(args.gpu_num)
            else:
                device_ids, used_memory = gpu_setting()
            if used_memory > 100:
                logging.info('Using memory = {}'.format(used_memory))
            if device_ids is not None:
                if len(device_ids) > args.gpu_num:
                    device_ids = device_ids[:args.gpu_num]
                device = torch.device('cuda:{}'.format(device_ids[0]))
            else:
                device = torch.device('cuda:0')
            logging.info('Set CUDA with device idxes = {}'.format(device_ids))
            logging.info('CUDA device: {}'.format(device))
        else:
            if args.gpu_num > 1:
                logging.info("Using GPU!")
                available_device_count = torch.cuda.device_count()
                logging.info('GPU number is {}'.format(available_device_count))
                if args.gpu_num > available_device_count:
                    args.gpu_num = available_device_count
                # ++++++++++++++++++++++++++++++++++
                device_ids, used_memory = gpu_setting(args.gpu_num)
                # ++++++++++++++++++++++++++++++++++
                device = torch.device("cuda:{}".format(device_ids[0]))
                # ++++++++++++++++++++++++++++++++++
            else:
                device = torch.device("cuda:0")
                device_ids = None
                logging.info('Single GPU setting')
    else:
        device = torch.device('cpu')
        device_ids = None
        logging.info('CPU setting')

    logging.info('Device = {}, Device ids = {}'.format(device, device_ids))

    logging.info('Loading training data...')
    train_data_loader, train_data_size = get_train_data_loader(args=args)
    estimated_max_steps = args.epoch * (
        (train_data_size // args.batch_size) + 1)
    if estimated_max_steps > args.max_steps:
        args.max_steps = estimated_max_steps
    logging.info('Loading development data...')
    dev_data_loader, _ = get_dev_data_loader(args=args)
    logging.info('Loading data completed')
    logging.info('*' * 75)
    tokenizer = get_hotpotqa_longformer_tokenizer()
    logging.info('*' * 75)
    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    if args.do_train:
        # Set training configuration
        start_time = time()
        logging.info('Loading Model...')
        model = get_model(args=args).to(device)
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate,
                                     weight_decay=args.weight_decay)
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        if device_ids is not None:
            if len(device_ids) > 1:
                model = DataParallel(model,
                                     device_ids=device_ids,
                                     output_device=device)
                logging.info('Data Parallel model setting')
        logging.info('Model Parameter Configuration:')
        for name, param in model.named_parameters():
            logging.info('Parameter {}: {}, requires_grad = {}'.format(
                name, str(param.size()), str(param.requires_grad)))
        logging.info('*' * 75)
        logging.info("Model hype-parameter information...")
        for key, value in vars(args).items():
            logging.info('Hype-parameter\t{} = {}'.format(key, value))
        logging.info('*' * 75)
        logging.info('batch_size = {}'.format(args.batch_size))
        logging.info('projection_dim = {}'.format(args.project_dim))
        logging.info('learning_rate = {}'.format(args.learning_rate))
        logging.info('Start training...')
        train_all_steps(model=model,
                        optimizer=optimizer,
                        dev_dataloader=dev_data_loader,
                        device=device,
                        train_dataloader=train_data_loader,
                        tokenizer=tokenizer,
                        args=args)
        logging.info('Completed training in {:.4f} seconds'.format(time() -
                                                                   start_time))
        logging.info('Evaluating on Valid Dataset...')
        metric_dict = test_all_steps(model=model,
                                     tokenizer=tokenizer,
                                     test_data_loader=dev_data_loader,
                                     args=args)
        answer_type_acc = metric_dict['answer_type_acc']
        logging.info('*' * 75)
        logging.info(
            'Answer type prediction accuracy: {}'.format(answer_type_acc))
        logging.info('*' * 75)
        for key, value in metric_dict.items():
            if key.endswith('metrics'):
                logging.info('{} prediction'.format(key))
                log_metrics('Valid', 'final', value)
        logging.info('*' * 75)
        ##++++++++++++++++++++++++++++++++++++++++++++++++++++
        ##++++++++++++++++++++++++++++++++++++++++++++++++++++
        model_save_path = save_check_point(model=model,
                                           optimizer=optimizer,
                                           step='all_step',
                                           loss=None,
                                           eval_metric=None,
                                           args=args)
        logging.info('Saving the model to {}'.format(model_save_path))
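
# Quick check of the step-budget arithmetic used above before training starts:
def step_budget_sketch():
    epoch, batch_size, train_data_size = 3, 8, 100
    estimated_max_steps = epoch * ((train_data_size // batch_size) + 1)
    print(estimated_max_steps)  # 39: 13 batches per epoch, 3 epochs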
Example #14
def hotpot_test_prediction(model, test_data_loader, args):
    model.eval()
    ###########################################################
    start_time = time()
    step = 0
    N = 0
    total_steps = len(test_data_loader)
    # **********************************************************
    answer_type_predicted = []
    answer_span_predicted = []
    supp_sent_predicted = []
    supp_doc_predicted = []
    # **********************************************************
    with torch.no_grad():
        for test_sample in test_data_loader:
            if args.cuda:
                sample = dict()
                for key, value in test_sample.items():
                    sample[key] = value.cuda()
            else:
                sample = test_sample
            output = model(sample)
            N = N + sample['ctx_encode'].shape[0]
            # ++++++++++++++++++
            answer_type_res = output['yn_score']
            if len(answer_type_res.shape) > 1:
                answer_type_res = answer_type_res.squeeze(dim=-1)
            answer_types = torch.argmax(answer_type_res, dim=-1)
            answer_type_predicted += answer_types.detach().tolist()
            # +++++++++++++++++++
            start_logits, end_logits = output['span_score']
            predicted_span_start = torch.argmax(start_logits, dim=-1)
            predicted_span_end = torch.argmax(end_logits, dim=-1)
            predicted_span_start = predicted_span_start.detach().tolist()
            predicted_span_end = predicted_span_end.detach().tolist()
            predicted_span_pair = list(
                zip(predicted_span_start, predicted_span_end))
            answer_span_predicted += predicted_span_pair
            # ++++++++++++++++++
            supp_doc_res = output['doc_score']
            doc_lens = sample['doc_lens']
            doc_mask = doc_lens.masked_fill(doc_lens > 0, 1)
            supp_doc_pred_i = supp_doc_prediction(scores=supp_doc_res,
                                                  mask=doc_mask,
                                                  pred_num=2)
            supp_doc_predicted += supp_doc_pred_i
            # ++++++++++++++++++
            supp_sent_res = output['sent_score']
            sent_lens = sample['sent_lens']
            sent_mask = sent_lens.masked_fill(sent_lens > 0, 1)
            sent_fact_doc_idx, sent_fact_sent_idx = sample['fact_doc'], sample[
                'fact_sent']
            supp_sent_pred_i = supp_sent_prediction(
                scores=supp_sent_res,
                mask=sent_mask,
                doc_fact=sent_fact_doc_idx,
                sent_fact=sent_fact_sent_idx,
                pred_num=2,
                threshold=args.sent_threshold)
            supp_sent_predicted += supp_sent_pred_i
            # +++++++++++++++++++
            step += 1
            if step % args.test_log_steps == 0:
                logging.info(
                    'Testing the reasonModel... {}/{} in {:.4f} seconds'.
                    format(step, total_steps,
                           time() - start_time))
    ##=================================================
    logging.info('Testing complete...')
    logging.info('Loading tokenizer')
    tokenizer = get_hotpotqa_longformer_tokenizer(
        model_name=PRE_TAINED_LONFORMER_BASE, do_lower_case=True)
    logging.info('Loading preprocessed data...')
    data = read_train_dev_data_frame(file_path=args.data_path,
                                     json_fileName=args.test_data_name)
    data['answer_prediction'] = answer_type_predicted
    data['answer_span_prediction'] = answer_span_predicted
    data['supp_doc_prediction'] = supp_doc_predicted
    data['supp_sent_prediction'] = supp_sent_predicted

    def row_process(row):
        answer_prediction = row['answer_prediction']
        answer_span_predicted = row['answer_span_prediction']
        span_start, span_end = answer_span_predicted
        encode_ids = row['ctx_encode']
        if answer_prediction > 0:
            predicted_answer = 'yes' if answer_prediction == 1 else 'no'
        else:
            predicted_answer = tokenizer.decode(
                encode_ids[span_start:(span_end + 1)],
                skip_special_tokens=True)

        ctx_contents = row['context']
        supp_doc_prediction = row['supp_doc_prediction']
        supp_doc_titles = [ctx_contents[idx][0] for idx in supp_doc_prediction]
        supp_sent_prediction = row['supp_sent_prediction']
        supp_sent_pairs = [(ctx_contents[pair_idx[0]][0], pair_idx[1])
                           for pair_idx in supp_sent_prediction]
        return predicted_answer, supp_doc_titles, supp_sent_pairs

    res_names = ['answer', 'sp_doc', 'sp']
    data[res_names] = data.apply(lambda row: pd.Series(row_process(row)),
                                 axis=1)
    return data
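
# The doc/sent masks above come from masked_fill over the length tensors, and
# supp_doc_prediction presumably ranks scores under that mask. A hedged toy
# sketch of masked top-k selection (an assumption about its internals, not
# the repo's exact code):
def masked_topk_sketch():
    import torch
    doc_lens = torch.tensor([[120, 96, 0, 0]])
    doc_mask = doc_lens.masked_fill(doc_lens > 0, 1)           # [[1, 1, 0, 0]]
    scores = torch.tensor([[0.9, 0.4, 0.7, 0.2]])
    # Padded positions are pushed to -inf so top-k never selects them.
    masked_scores = scores.masked_fill(doc_mask == 0, float('-inf'))
    topk_idx = torch.topk(masked_scores, k=2, dim=-1).indices
    print(topk_idx)  # tensor([[0, 1]])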