Example 1
def main(args):
    err = None
    metrics = {}
    bleu4, rouge_l = 0.0, 0.0
    alpha = args.ab
    beta = args.ab
    bleu_eval = BLEUWithBonus(4, alpha=alpha, beta=beta)
    rouge_eval = RougeL(alpha=alpha, beta=beta, gamma=1.2)
    try:
        pred_result = read_file(args.pred_file)
        ref_result = read_file(args.ref_file, is_ref=True)
        for qid, results in ref_result.items():
            cand_result = pred_result.get(qid, {})
            #pred_answers = cand_result.get('answers', [EMPTY])[0]
            pred_answers = cand_result.get('answers', [])
            if not pred_answers:
                pred_answers = EMPTY
            else:
                pred_answers = pred_answers[0]
            pred_yn_label = None
            ref_entities = None
            ref_answers = results.get('answers', [])
            if not ref_answers:
                continue
            if results['question_type'] == 'ENTITY':
                ref_entities = set(
                        itertools.chain(*results.get('entity_answers', [[]])))
                if not ref_entities:
                    ref_entities = None
            if results['question_type'] == 'YES_NO':
                cand_yesno = cand_result.get('yesno_answers', [])
                pred_yn_label = None if len(cand_yesno) == 0 \
                        else cand_yesno[0]
            bleu_eval.add_inst(
                    pred_answers,
                    ref_answers,
                    yn_label=pred_yn_label,
                    yn_ref=results['yesno_answers'],
                    entity_ref=ref_entities)
            rouge_eval.add_inst(
                    pred_answers,
                    ref_answers,
                    yn_label=pred_yn_label,
                    yn_ref=results['yesno_answers'],
                    entity_ref=ref_entities)
        bleu4 = bleu_eval.score()[-1]
        rouge_l = rouge_eval.score()
    except ValueError as ve:
        err = ve
    except AssertionError as ae:
        err = ae
    # to keep compatibility with the leaderboard evaluation
    metrics['errorMsg'] = 'success' if err is None else str(err)
    metrics['errorCode'] = 0 if err is None else 1
    metrics['data'] = [
            {'type': 'BOTH', 'name': 'ROUGE-L', 'value': round(rouge_l * 100, 2)},
            {'type': 'BOTH', 'name': 'BLEU-4', 'value': round(bleu4 * 100, 2)},
            ]
    print json.dumps(metrics, ensure_ascii=False).encode('utf8')
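Both examples call a read_file helper that is not shown on this page. The sketch below is a plausible reconstruction, assuming a DuReader-style JSON-lines file with one object per question; the field names (answers, yesno_answers, question_type, entity_answers) mirror how main() consumes the result, but the on-disk layout itself is an assumption.

import json

def read_file(file_name, is_ref=False):
    """Hypothetical loader: one JSON object per line, indexed by question_id.

    The selected keys mirror the fields accessed in main(); the actual
    format used by the real evaluation script may differ.
    """
    keys = ['answers', 'yesno_answers', 'question_type']
    if is_ref:
        keys.append('entity_answers')
    results = {}
    with open(file_name) as fin:
        for line in fin:
            obj = json.loads(line.strip())
            results[obj['question_id']] = {
                k: obj[k] for k in keys if k in obj
            }
    return results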
Example 2
def main(args):
    err = None
    metrics = {}
    bleu4, rouge_l = 0.0, 0.0
    alpha = args.alpha  # default 1.0
    beta = args.beta  # default 1.0
    bleu_eval = BLEUWithBonus(4, alpha=alpha, beta=beta)
    rouge_eval = RougeLWithBonus(alpha=alpha, beta=beta, gamma=1.2)
    pred_result = read_file(args.pred_file)
    ref_result = read_file(args.ref_file, is_ref=True)
    bleu4, rouge_l = calc_metrics(pred_result, ref_result, bleu_eval,
                                  rouge_eval)
    metrics = {
        'ROUGE-L': round(rouge_l * 100, 2),
        'BLEU-4': round(bleu4 * 100, 2),
    }
    print json.dumps(metrics, ensure_ascii=False).encode('utf8')
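Example 2 moves the per-question loop into a calc_metrics helper. The sketch below simply factors that loop out of Example 1, so the logic is grounded in the code above; the exact signature of the real helper and the empty-prediction fallback ('' instead of the EMPTY constant) are assumptions.

import itertools

def calc_metrics(pred_result, ref_result, bleu_eval, rouge_eval):
    """Sketch: feed every (prediction, reference) pair to both scorers and
    return corpus-level BLEU-4 and ROUGE-L. Mirrors the loop in Example 1."""
    for qid, results in ref_result.items():
        ref_answers = results.get('answers', [])
        if not ref_answers:
            continue
        cand_result = pred_result.get(qid, {})
        pred_answers = cand_result.get('answers', [])
        pred_answers = pred_answers[0] if pred_answers else ''
        # the entity bonus only applies to ENTITY questions
        ref_entities = None
        if results['question_type'] == 'ENTITY':
            ref_entities = set(
                itertools.chain(*results.get('entity_answers', [[]]))) or None
        # the yes/no bonus only applies to YES_NO questions
        pred_yn_label = None
        if results['question_type'] == 'YES_NO':
            cand_yesno = cand_result.get('yesno_answers', [])
            pred_yn_label = cand_yesno[0] if cand_yesno else None
        for scorer in (bleu_eval, rouge_eval):
            scorer.add_inst(
                pred_answers,
                ref_answers,
                yn_label=pred_yn_label,
                yn_ref=results['yesno_answers'],
                entity_ref=ref_entities)
    bleu4 = bleu_eval.score()[-1]
    rouge_l = rouge_eval.score()
    return bleu4, rouge_l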
Example 3
def evaluate_batch(model,
                   num_batches,
                   eval_file,
                   sess,
                   data_type,
                   handle,
                   str_handle,
                   args,
                   logger,
                   result_prefix=None):
    losses = []
    pred_answers, ref_answers = [], []
    padded_p_len = args.max_p_len
    for i in range(num_batches):
        qa_id, loss, start_probs, end_probs = sess.run(
            [model.qa_id, model.loss, model.logits1, model.logits2],
            feed_dict={handle: str_handle} if handle is not None else None)
        losses.append(loss)
        for qid, start_prob, end_prob in zip(qa_id, start_probs, end_probs):
            best_p_idx, best_span, best_score = None, None, 0
            sample = eval_file[str(qid)]
            # the probabilities of all passages are concatenated per sample,
            # so the passage window restarts at 0 for every sample
            start, end = 0, 0
            for p_idx, passage_len in enumerate(sample['passages_len']):
                if p_idx >= args.max_p_num:
                    continue
                # find the best answer within each passage
                end = start + passage_len
                answer_span, score = find_best_answer_for_passage(
                    start_prob[start:end], end_prob[start:end], passage_len,
                    args.max_a_len)
                # convert the span to offsets in the concatenated passages
                answer_span[0] += start
                answer_span[1] += start
                # keep the highest-scoring span across passages
                if score > best_score:
                    best_score = score
                    best_p_idx = p_idx
                    best_span = answer_span
                # advance the window to the next passage
                start = end
            # best_span = [start_prob, end_prob]
            # best_answer = sample['passages'][best_span[0]: best_span[1] + 1]
            # recover the answer tokens from the best span
            if best_p_idx is None or best_span is None:
                best_answer = ''
            else:
                best_answer = ''.join(
                    sample['passages'][best_span[0]:best_span[1] + 1])
            # TODO: include the question tokens
            pred_answers.append({
                'question_id': sample['question_id'],
                'question_type': sample['question_type'],
                'answers': [best_answer],
                'yesno_answers': []
            })
            # reference (gold-standard) answers
            # if 'answers' in sample and len(sample['answers']) > 0:
            if 'answers' in sample:
                ref_answers.append({
                    'question_id': sample['question_id'],
                    'question_type': sample['question_type'],
                    'answers': sample['answers'],
                    'yesno_answers': []
                })

    if result_prefix is not None:
        result_file = os.path.join(args.result_dir, result_prefix + '.json')
        with open(result_file, 'w') as fout:
            for pred_answer in pred_answers:
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
        logger.info('Saving {} results to {}'.format(result_prefix,
                                                     result_file))

    avg_loss = np.mean(losses)
    bleu4, rouge_l = 0, 0
    if len(ref_answers) > 0:
        # key: question ID, value: answers
        pred_dict, ref_dict, bleu_rouge = {}, {}, {}
        for pred, ref in zip(pred_answers, ref_answers):
            question_id = ref['question_id']
            if len(ref['answers']) > 0:
                # join the answer tokens into a single space-separated string
                pred_dict[question_id] = {
                    'answers': mrc_eval.normalize(pred['answers']),
                    'yesno_answers': []
                }
                ref_dict[question_id] = {
                    'question_type': ref['question_type'],
                    'answers': mrc_eval.normalize(ref['answers']),
                    'yesno_answers': []
                }
        bleu_eval = BLEUWithBonus(4, alpha=1.0, beta=1.0)
        rouge_eval = RougeLWithBonus(alpha=1.0, beta=1.0, gamma=1.2)
        bleu4, rouge_l = mrc_eval.calc_metrics(pred_dict, ref_dict, bleu_eval,
                                               rouge_eval)
        bleu_rouge['Bleu-4'] = bleu4
        bleu_rouge['Rouge-L'] = rouge_l
    else:
        bleu_rouge = None

    loss_sum = tf.Summary(value=[
        tf.Summary.Value(tag="{}/loss".format(data_type),
                         simple_value=avg_loss),
    ])
    bleu_sum = tf.Summary(value=[
        tf.Summary.Value(tag="{}/bleu_4".format(data_type),
                         simple_value=bleu4),
    ])
    rouge_sum = tf.Summary(value=[
        tf.Summary.Value(tag="{}/rouge_l".format(data_type),
                         simple_value=rouge_l),
    ])
    return avg_loss, bleu_rouge, [loss_sum, bleu_sum, rouge_sum]
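Example 3 depends on a find_best_answer_for_passage helper that is not included here. The sketch below follows the common DuReader-baseline approach of scoring every span by start_prob[s] * end_prob[e] with the length capped at max_a_len; the scoring rule and the return convention (a mutable [start, end] list plus a score) are assumptions inferred from how the result is used above.

def find_best_answer_for_passage(start_probs, end_probs, passage_len,
                                 max_a_len):
    """Sketch: pick the best answer span inside one passage.

    Every span [s, e] with e >= s and e - s + 1 <= max_a_len is scored as
    start_probs[s] * end_probs[e]; the highest-scoring span and its score
    are returned. The real helper may score spans differently.
    """
    best_span, best_score = [0, 0], 0.0
    search_len = min(passage_len, len(start_probs))
    for s in range(search_len):
        for offset in range(min(max_a_len, search_len - s)):
            e = s + offset
            score = start_probs[s] * end_probs[e]
            if score > best_score:
                best_score = score
                best_span = [s, e]
    return best_span, best_score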