Example #1
def validate_official(args, data_loader, model, global_stats, offsets, texts,
                      answers):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.

    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    clean_id_file = open(os.path.join(DATA_DIR, "clean_qids.txt"), "w+")
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    bad_examples = 0
    for ex in data_loader:
        ex_id, batch_size = ex[-1], ex[0].size(0)
        chosen_offset = ex[-2]
        pred_s, pred_e, _ = model.predict(ex)

        for i in range(batch_size):
            # Skip predictions that fall outside the context's token offsets.
            if (pred_s[i][0] >= len(offsets[ex_id[i]])
                    or pred_e[i][0] >= len(offsets[ex_id[i]])):
                bad_examples += 1
                continue
            if args.use_sentence_selector:
                s_offset = chosen_offset[i][pred_s[i][0]][0]
                e_offset = chosen_offset[i][pred_e[i][0]][1]
            else:
                s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
                e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][s_offset:e_offset]

            # Compute metrics
            ground_truths = answers[ex_id[i]]
            exact_match.update(
                utils.metric_max_over_ground_truths(utils.exact_match_score,
                                                    prediction, ground_truths))
            f1_example = utils.metric_max_over_ground_truths(
                utils.f1_score, prediction, ground_truths)
            f1.update(f1_example)

            # Record the ids of examples the model gets at least partially right.
            if f1_example != 0:
                clean_id_file.write(ex_id[i] + "\n")

        examples += batch_size

    clean_id_file.close()
    logger.info('dev valid official: Epoch = %d | EM = %.2f | ' %
                (global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))
    logger.info('Bad Offset Examples during official eval: %d' % bad_examples)
    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
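The validators on this page delegate all scoring to `utils.metric_max_over_ground_truths`, `utils.exact_match_score`, and `utils.f1_score`, which are not shown here. The following is a minimal sketch of those helpers, following the official SQuAD v1.1 evaluation script; the project's own `utils` module likely mirrors it but may differ in detail.

# Sketch of the SQuAD-style scoring helpers (bodies follow the official
# SQuAD v1.1 evaluation script, not necessarily this project's utils module).
import re
import string
from collections import Counter


def normalize_answer(s):
    """Lowercase, drop punctuation and articles, and collapse whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        return ''.join(ch for ch in text if ch not in set(string.punctuation))

    return white_space_fix(remove_articles(remove_punc(s.lower())))


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score the prediction against every accepted answer and keep the best."""
    return max(metric_fn(prediction, gt) for gt in ground_truths)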
Example #2
def validate_official(args,
                      data_loader,
                      model,
                      global_stats,
                      offsets,
                      texts,
                      answers,
                      mode="dev"):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.

    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    for ex in data_loader:
        ex_id, batch_size = ex[-1], ex[0].size(0)
        pred_s, pred_e, _ = model.predict(ex)

        for i in range(batch_size):
            s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
            e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][s_offset:e_offset]

            # Compute metrics
            ground_truths = answers[ex_id[i]]
            exact_match.update(
                utils.metric_max_over_ground_truths(utils.exact_match_score,
                                                    prediction, ground_truths))
            f1.update(
                utils.metric_max_over_ground_truths(utils.f1_score, prediction,
                                                    ground_truths))

        examples += batch_size

    logger.info(mode + ' valid official: Epoch = %d | EM = %.2f | ' %
                (global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))

    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
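`utils.AverageMeter` and `utils.Timer` are also assumed rather than shown. The versions below are an approximation following the common PyTorch-examples pattern; the project's classes may track additional state (for example, pause/resume on the timer).

import time


class AverageMeter(object):
    """Keeps a running average of every value passed to update()."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class Timer(object):
    """Measures wall-clock time elapsed since construction."""

    def __init__(self):
        self.start = time.time()

    def time(self):
        return time.time() - self.start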
Example #3
def evaluate(dataset_file, prediction_file, regex=False):
    print("-" * 50)
    print("Dataset: %s" % dataset_file)
    print("Predictions: %s" % prediction_file)

    answers = []
    with open(dataset_file) as f:
        for line in f:
            data = json.loads(line)
            answer = [normalize(a) for a in data["answer"]]
            answers.append(answer)

    predictions = []
    with open(prediction_file) as f:
        for line in f:
            data = json.loads(line)
            prediction = normalize(data[0]["span"])
            predictions.append(prediction)

    exact_match = 0
    for i in range(len(predictions)):
        match_fn = regex_match_score if regex else exact_match_score
        exact_match += metric_max_over_ground_truths(match_fn, predictions[i],
                                                     answers[i])
    total = len(predictions)
    exact_match = 100.0 * exact_match / total
    print({"exact_match": exact_match})
Example #4
File: eval.py  Project: athiwatp/DrQA
def evaluate(dataset_file, prediction_file, regex=False):
    print('-' * 50)
    print('Dataset: %s' % dataset_file)
    print('Predictions: %s' % prediction_file)

    answers = []
    with open(dataset_file) as f:
        for line in f:
            data = json.loads(line)
            answer = [normalize(a) for a in data['answer']]
            answers.append(answer)

    predictions = []
    with open(prediction_file) as f:
        for line in f:
            data = json.loads(line)
            prediction = normalize(data[0]['span'])
            predictions.append(prediction)

    exact_match = 0
    for i in range(len(predictions)):
        match_fn = regex_match_score if regex else exact_match_score
        exact_match += metric_max_over_ground_truths(
            match_fn, predictions[i], answers[i]
        )
    total = len(predictions)
    exact_match = 100.0 * exact_match / total
    print({'exact_match': exact_match})
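Examples #3 and #4 switch between `exact_match_score` and `regex_match_score` depending on whether the dataset stores literal answer strings or regular-expression patterns. The regex variant is not shown on this page; a plausible sketch (an assumption, not the project's exact code) is:

import re


def regex_match_score(prediction, pattern):
    """Treat the ground-truth answer as a regex and test it against the prediction."""
    try:
        compiled = re.compile(pattern,
                              flags=re.IGNORECASE + re.UNICODE + re.MULTILINE)
    except re.error:
        return False
    return compiled.match(prediction) is not None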
Example #5
def get_rank(prediction_, answer_, use_regex_=False):
    match_fn = regex_match_score if use_regex_ else exact_match_score
    for rank_, entry in enumerate(prediction_):
        exact_match = metric_max_over_ground_truths(
            match_fn, normalize(entry['span']), answer_)
        if exact_match:
            return rank_ + 1
    return 1000
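A small, made-up usage example for `get_rank` (assuming `get_rank` and the scoring helpers above are in scope): `prediction_` is a ranked list of candidate spans and `answer_` the accepted answers; the sentinel value 1000 means no candidate matched.

# Hypothetical candidates and gold answers, only for illustration.
candidates = [{'span': 'George W. Bush'}, {'span': 'Barack Obama'}]
gold = ['Barack Obama', 'Obama']

print(get_rank(candidates, gold))      # 2: the second candidate is an exact match
print(get_rank(candidates[:1], gold))  # 1000: no candidate matched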
Example #6
File: train.py  Project: athiwatp/DrQA
def validate_official(args, data_loader, model, global_stats,
                      offsets, texts, answers):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.

    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    for ex in data_loader:
        ex_id, batch_size = ex[-1], ex[0].size(0)
        pred_s, pred_e, _ = model.predict(ex)

        for i in range(batch_size):
            s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
            e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][s_offset:e_offset]

            # Compute metrics
            ground_truths = answers[ex_id[i]]
            exact_match.update(utils.metric_max_over_ground_truths(
                utils.exact_match_score, prediction, ground_truths))
            f1.update(utils.metric_max_over_ground_truths(
                utils.f1_score, prediction, ground_truths))

        examples += batch_size

    logger.info('dev valid official: Epoch = %d | EM = %.2f | ' %
                (global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))

    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
Example #7
def compute_expected_metric(args, data_loader, model, global_stats, offsets,
                            texts, answers):
    scores = {}
    preds = {}
    for ex in data_loader:
        ex_id, batch_size = ex[-1], ex[0].size(0)
        chosen_offset = ex[-2]
        tup, score_s, score_e = model.predict_probs(ex)
        pred_s, pred_e, _ = tup
        for i in range(batch_size):
            if args.use_sentence_selector:
                s_offset = chosen_offset[i][pred_s[i][0]][0]
                e_offset = chosen_offset[i][pred_e[i][0]][1]
            else:
                s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
                e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][s_offset:e_offset]
            ground_truths = answers[ex_id[i]]

            beam = get_y_pred_beam(score_s[i].numpy(), score_e[i].numpy(),
                                   BEAM_SIZE)
            total_prob = sum(x[2] for x in beam)
            score = 0.0
            for (start, end, prob) in beam:
                if args.use_sentence_selector:
                    s_offset = chosen_offset[i][start][0]
                    e_offset = chosen_offset[i][end][1]
                else:
                    s_offset = offsets[ex_id[i]][start][0]
                    e_offset = offsets[ex_id[i]][end][1]
                phrase = texts[ex_id[i]][s_offset:e_offset]
                # Score each span in the beam (not just the argmax prediction)
                # and weight it by its normalized probability.
                cur_f1 = utils.metric_max_over_ground_truths(
                    utils.f1_score, phrase, ground_truths)
                score += prob / total_prob * cur_f1
            scores[ex_id[i]] = score
            preds[ex_id[i]] = prediction
    return scores, preds
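`get_y_pred_beam` and `BEAM_SIZE` are not defined on this page. Below is a hedged sketch of what such a beam function could look like, enumerating the highest-probability (start, end) spans under a maximum span length; both the signature and the `max_len` default are assumptions.

import numpy as np


def get_y_pred_beam(score_s, score_e, beam_size, max_len=15):
    """Return the top `beam_size` spans as (start, end, probability) triples."""
    candidates = []
    for start in range(len(score_s)):
        for end in range(start, min(start + max_len, len(score_e))):
            candidates.append(
                (start, end, float(score_s[start] * score_e[end])))
    candidates.sort(key=lambda x: x[2], reverse=True)
    return candidates[:beam_size]


# Toy usage with made-up probabilities:
beam = get_y_pred_beam(np.array([0.7, 0.2, 0.1]),
                       np.array([0.1, 0.6, 0.3]), beam_size=3)
# -> roughly [(0, 1, 0.42), (0, 2, 0.21), (1, 1, 0.12)]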
Example #8
def validate_adversarial(args, model, global_stats, mode="dev"):
    # Create a dataloader for each adversarial dev set, load their JSONs, and
    # run the official evaluation on each.

    for idx, dataset_file in enumerate(args.adv_dev_json):

        predictions = {}

        logger.info("Validating Adversarial Dataset %s" % dataset_file)
        exs = utils.load_data(args, args.adv_dev_file[idx])
        logger.info('Num dev examples = %d' % len(exs))
        ## Create dataloader
        dev_dataset = reader_data.ReaderDataset(exs,
                                                model,
                                                single_answer=False)
        if args.sort_by_len:
            dev_sampler = reader_data.SortedBatchSampler(dev_dataset.lengths(),
                                                         args.test_batch_size,
                                                         shuffle=False)
        else:
            dev_sampler = torch.utils.data.sampler.SequentialSampler(
                dev_dataset)
        if args.use_sentence_selector:
            dev_batcher = reader_vector.sentence_batchifier(
                model, single_answer=False)
            #batching_function = dev_batcher.batchify
            batching_function = reader_vector.batchify
        else:
            batching_function = reader_vector.batchify
        dev_loader = torch.utils.data.DataLoader(
            dev_dataset,
            batch_size=args.test_batch_size,
            sampler=dev_sampler,
            num_workers=args.data_workers,
            collate_fn=batching_function,
            pin_memory=args.cuda,
        )

        texts = utils.load_text(dataset_file)
        offsets = {ex['id']: ex['offsets'] for ex in exs}
        answers = utils.load_answers(dataset_file)

        eval_time = utils.Timer()
        f1 = utils.AverageMeter()
        exact_match = utils.AverageMeter()

        examples = 0
        bad_examples = 0
        for ex in dev_loader:
            ex_id, batch_size = ex[-1], ex[0].size(0)
            chosen_offset = ex[-2]
            pred_s, pred_e, _ = model.predict(ex)

            for i in range(batch_size):
                # Skip predictions that fall outside the context's token offsets.
                if (pred_s[i][0] >= len(offsets[ex_id[i]])
                        or pred_e[i][0] >= len(offsets[ex_id[i]])):
                    bad_examples += 1
                    continue
                if args.use_sentence_selector:
                    s_offset = chosen_offset[i][pred_s[i][0]][0]
                    e_offset = chosen_offset[i][pred_e[i][0]][1]
                else:
                    s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
                    e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
                prediction = texts[ex_id[i]][s_offset:e_offset]

                if args.select_k > 1:
                    prediction = ""
                    offset_subset = chosen_offset[i][pred_s[i][0]:pred_e[i][0]]
                    for o in offset_subset:
                        prediction += texts[ex_id[i]][o[0]:o[1]] + " "
                    prediction = prediction.strip()

                predictions[ex_id[i]] = prediction

                ground_truths = answers[ex_id[i]]
                exact_match.update(
                    utils.metric_max_over_ground_truths(
                        utils.exact_match_score, prediction, ground_truths))
                f1.update(
                    utils.metric_max_over_ground_truths(
                        utils.f1_score, prediction, ground_truths))

            examples += batch_size

        logger.info(
            'dev valid official for dev file %s : Epoch = %d | EM = %.2f | ' %
            (dataset_file, global_stats['epoch'], exact_match.avg * 100) +
            'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
            (f1.avg * 100, examples, eval_time.time()))

        orig_f1_score = 0.0
        orig_exact_match_score = 0.0
        adv_f1_scores = {}  # Map from original ID to F1 score
        adv_exact_match_scores = {}  # Map from original ID to exact match score
        adv_ids = {}
        all_ids = set()  # Set of all original IDs
        f1 = exact_match = 0
        dataset = json.load(open(dataset_file))['data']
        for article in dataset:
            for paragraph in article['paragraphs']:
                for qa in paragraph['qas']:
                    orig_id = qa['id'].split('-')[0]
                    all_ids.add(orig_id)
                    if qa['id'] not in predictions:
                        message = ('Unanswered question ' + qa['id'] +
                                   ' will receive score 0.')
                        # logger.info(message)
                        continue
                    ground_truths = list(
                        map(lambda x: x['text'], qa['answers']))
                    prediction = predictions[qa['id']]
                    cur_exact_match = utils.metric_max_over_ground_truths(
                        utils.exact_match_score, prediction, ground_truths)
                    cur_f1 = utils.metric_max_over_ground_truths(
                        utils.f1_score, prediction, ground_truths)
                    if orig_id == qa['id']:
                        # This is an original example
                        orig_f1_score += cur_f1
                        orig_exact_match_score += cur_exact_match
                        if orig_id not in adv_f1_scores:
                            # Haven't seen adversarial example yet, so use original for adversary
                            adv_ids[orig_id] = orig_id
                            adv_f1_scores[orig_id] = cur_f1
                            adv_exact_match_scores[orig_id] = cur_exact_match
                    else:
                        # This is an adversarial example
                        if (orig_id not in adv_f1_scores
                                or adv_ids[orig_id] == orig_id
                                or adv_f1_scores[orig_id] > cur_f1):
                            # Override if the stored entry is still the original
                            # example or this adversary has a lower F1 (worst case).
                            adv_ids[orig_id] = qa['id']
                            adv_f1_scores[orig_id] = cur_f1
                            adv_exact_match_scores[orig_id] = cur_exact_match
        orig_f1 = 100.0 * orig_f1_score / len(all_ids)
        orig_exact_match = 100.0 * orig_exact_match_score / len(all_ids)
        adv_exact_match = 100.0 * sum(
            adv_exact_match_scores.values()) / len(all_ids)
        adv_f1 = 100.0 * sum(adv_f1_scores.values()) / len(all_ids)
        logger.info(
            "For the file %s Original Exact Match: %.4f ; Original F1: %.4f | "
            % (dataset_file, orig_exact_match, orig_f1) +
            "Adversarial Exact Match: %.4f ; Adversarial F1: %.4f" %
            (adv_exact_match, adv_f1))
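To make the pairing logic above concrete, here is a toy illustration (the ids, the '-adv-*' suffix convention, and the scores are all made up): each original question keeps the lowest-F1 adversarial variant, falling back to the original only when no adversary was seen.

# Hypothetical per-question F1 scores keyed by question id.
per_example_f1 = {
    'q1': 1.0,             # original question
    'q1-adv-turk0': 0.8,   # adversarial paraphrase
    'q1-adv-turk1': 0.3,   # adversarial paraphrase (worst case)
}

adv_f1, adv_src = {}, {}
for qid, score in per_example_f1.items():
    oid = qid.split('-')[0]
    if qid == oid:
        # Original question: placeholder until an adversary shows up.
        if oid not in adv_f1:
            adv_f1[oid], adv_src[oid] = score, oid
    elif oid not in adv_f1 or adv_src[oid] == oid or adv_f1[oid] > score:
        # Adversarial variant: keep the worst (lowest-F1) adversary seen so far.
        adv_f1[oid], adv_src[oid] = score, qid

print(adv_f1)  # {'q1': 0.3} -> worst-case adversarial F1 for question q1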
Example #9
def validate_official(args, data_loader, model, global_stats, offsets, texts,
                      questions, answers):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.

    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    em_false = {}  # cid -> (context, [(qid, question, answer)...])
    predictions = {}  # qid -> prediction
    for ex in data_loader:
        ex_id, batch_size = ex[-1], ex[0].size(0)
        pred_s, pred_e, _ = model.predict(ex)

        for i in range(batch_size):
            s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
            e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][1][s_offset:e_offset]
            cid = texts[ex_id[i]][0]
            predictions[ex_id[i]] = prediction

            # Compute metrics
            ground_truths = answers[ex_id[i]]

            em_score = utils.metric_max_over_ground_truths(
                utils.exact_match_score, prediction, ground_truths)
            if em_score < 1:
                if cid not in em_false:
                    em_false[cid] = {
                        'text': texts[ex_id[i]][1],
                        'qa': [{
                            'qid': ex_id[i],
                            'question': questions[ex_id[i]],
                            'answers': answers[ex_id[i]],
                            'prediction': prediction
                        }]
                    }
                else:
                    em_false[cid]['qa'].append({
                        'qid': ex_id[i],
                        'question': questions[ex_id[i]],
                        'answers': answers[ex_id[i]],
                        'prediction': prediction
                    })

            exact_match.update(em_score)
            f1.update(
                utils.metric_max_over_ground_truths(utils.f1_score, prediction,
                                                    ground_truths))

        examples += batch_size

    logger.info('dev valid official: Epoch = %d | EM = %.2f | ' %
                (global_stats['epoch'], exact_match.avg * 100) +
                'F1 = %.2f | examples = %d | valid time = %.2f (s)' %
                (f1.avg * 100, examples, eval_time.time()))

    return {
        'exact_match': exact_match.avg * 100,
        'f1': f1.avg * 100
    }, em_false, predictions
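A possible follow-up for Example #9 (the file names and surrounding variables such as `dev_loader` and `stats` are hypothetical): the returned `em_false` map and the raw predictions can be written out for manual error analysis.

import json

# Assumes args, dev_loader, model, stats, offsets, texts, questions, answers
# are already set up as in the training/validation driver.
results, em_false, predictions = validate_official(
    args, dev_loader, model, stats, offsets, texts, questions, answers)

with open('em_failures.json', 'w') as f:
    json.dump(em_false, f, indent=2)
with open('predictions.json', 'w') as f:
    json.dump(predictions, f, indent=2)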
Example #10
        question = data['question']
        questions.append(question)
        answer = [normalize(a) for a in data['answer']]
        answers.append(answer)

    doc_ids_ans = []
    prediction_file = os.path.join(args.prediction_dir, PREDICTION_FILE)
    for line, answer in zip(open(prediction_file, encoding=ENCODING), answers):
        data = json.loads(line)
        if len(data):
            prediction = normalize(data[0]['span'])
            id_ans = data[0]['doc_id']
        else:
            prediction = ''
            id_ans = ''
        exact_match = metric_max_over_ground_truths(regex_match_score,
                                                    prediction, answer)
        doc_ids_ans.append((id_ans, exact_match))

    query_doc_dict = {}
    for line in extract_lines(args.log_file, 'question_d:', ' ]'):
        question, sec_strings = line.split(', query:')
        start_index = sec_strings.index('doc_ids:') + len('doc_ids:')
        end_index = sec_strings.index(', doc_scores:')
        doc_id_strings = sec_strings[start_index:end_index]
        top_doc_ids = ast.literal_eval(doc_id_strings)
        query_doc_dict[question.strip()] = top_doc_ids

    ranks = []
    right_ranks = []
    for question, id_ans in zip(questions, doc_ids_ans):
        doc_id, ans = id_ans