Code example #1
File: trainer.py  Project: arkhycat/mrqa
    def get_features(self, train_folder, debug=False):
        # Assumes module-level imports of os and pickle, plus the project's
        # read_squad_examples / convert_examples_to_features helpers.
        pickled_folder = self.args.pickled_folder + "_{}_{}".format(
            self.args.bert_model, str(self.args.skip_no_ans))
        # Ensure the cache directory exists before reading/writing pickles
        os.makedirs(pickled_folder, exist_ok=True)

        features_lst = []

        files = [f for f in os.listdir(train_folder) if f.endswith(".gz")]
        names = [f.split(".")[0] for f in files]
        print("Number of data sets:{}".format(len(files)))

        for filename in files:
            data_name = filename.split(".")[0]
            # Check whether pkl file already exists
            pickle_file_name = '{}.pkl'.format(data_name)
            pickle_file_path = os.path.join(pickled_folder, pickle_file_name)
            if os.path.exists(pickle_file_path):
                with open(pickle_file_path, 'rb') as pkl_f:
                    print("Loading {} file as pkl...".format(data_name))
                    features_lst.append(pickle.load(pkl_f))
            else:
                print("processing {} file".format(data_name))
                file_path = os.path.join(train_folder, filename)

                train_examples = read_squad_examples(file_path, debug=debug)
                train_features = convert_examples_to_features(
                    examples=train_examples,
                    tokenizer=self.tokenizer,
                    max_seq_length=self.args.max_seq_length,
                    max_query_length=self.args.max_query_length,
                    doc_stride=self.args.doc_stride,
                    is_training=True,
                    skip_no_ans=self.args.skip_no_ans)

                features_lst.append(train_features)

                # Save features as pickle (for reuse & fast loading);
                # only rank 0 writes in a distributed run
                if not debug and self.args.rank == 0:
                    with open(pickle_file_path, 'wb') as pkl_f:
                        print("Saving {} features as pkl file...".format(
                            data_name))
                        pickle.dump(train_features, pkl_f)

        return features_lst, names
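
A minimal usage sketch for this cache-aware loader (hypothetical: assumes a Trainer built with the project's args and tokenizer, and a folder of .gz MRQA training files; all names here are illustrative):

# Hypothetical usage; Trainer construction follows the project's setup.
trainer = Trainer(args, tokenizer)
features_lst, names = trainer.get_features("data/train", debug=False)
print(names)             # dataset names derived from the .gz filenames
print(len(features_lst)) # one feature list per dataset
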
Code example #2
import collections
import json

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

# Project-local helpers assumed in scope: read_squad_examples,
# convert_examples_to_features, write_predictions, read_answers, evaluate

def eval_qa(model, file_path, prediction_file, args, tokenizer, batch_size=50):
    eval_examples = read_squad_examples(file_path, debug=False)

    eval_features = convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        max_query_length=args.max_query_length,
        doc_stride=args.doc_stride,
        is_training=False
    )
    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)

    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)
    sampler = SequentialSampler(eval_data)
    eval_loader = DataLoader(eval_data, sampler=sampler, batch_size=batch_size)

    # Per-feature raw model outputs, consumed below by write_predictions
    RawResult = collections.namedtuple("RawResult",
                                       ["unique_id", "start_logits", "end_logits"])

    model.eval()
    all_results = []
    example_index = -1  # running index into eval_features across batches
    for _, batch in enumerate(eval_loader):
        input_ids, input_mask, seg_ids = batch
        # Dynamically truncate the batch to its longest non-padded sequence
        # (assumes the padding token id is 0)
        seq_len = torch.sum(torch.sign(input_ids), 1)
        max_len = torch.max(seq_len)

        input_ids = input_ids[:, :max_len].clone()
        input_mask = input_mask[:, :max_len].clone()
        seg_ids = seg_ids[:, :max_len].clone()

        if args.use_cuda:
            input_ids = input_ids.cuda(args.gpu, non_blocking=True)
            input_mask = input_mask.cuda(args.gpu, non_blocking=True)
            seg_ids = seg_ids.cuda(args.gpu, non_blocking=True)

        with torch.no_grad():
            batch_start_logits, batch_end_logits = model(input_ids, seg_ids, input_mask)
            # The final batch may be smaller than the requested batch_size;
            # use a separate name to avoid shadowing the parameter
            cur_batch_size = batch_start_logits.size(0)
        for i in range(cur_batch_size):
            example_index += 1
            start_logits = batch_start_logits[i].detach().cpu().tolist()
            end_logits = batch_end_logits[i].detach().cpu().tolist()
            eval_feature = eval_features[example_index]
            unique_id = int(eval_feature.unique_id)
            all_results.append(RawResult(unique_id=unique_id,
                                         start_logits=start_logits,
                                         end_logits=end_logits))

    preds = write_predictions(eval_examples, eval_features, all_results,
                              n_best_size=20, max_answer_length=30, do_lower_case=args.do_lower_case,
                              output_prediction_file=prediction_file)

    answers = read_answers(file_path)
    preds_dict = json.loads(preds)
    metrics = evaluate(answers, preds_dict, skip_no_answer=args.skip_no_ans)

    return metrics
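
A hedged invocation sketch (assumes model, tokenizer, and args were set up as in the training code; the file paths are placeholders):

# Hypothetical call; `args` must carry max_seq_length, max_query_length,
# doc_stride, use_cuda, gpu, do_lower_case, and skip_no_ans.
metrics = eval_qa(model,
                  file_path="data/dev/SQuAD.jsonl.gz",
                  prediction_file="predictions/SQuAD.json",
                  args=args,
                  tokenizer=tokenizer,
                  batch_size=32)
print(metrics)  # evaluation metrics such as exact match and F1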