예제 #1
0
def _ratio_entry(count, total):
    """Build a {count, total, percentage} metrics dict.

    Percentage is None (not a ZeroDivisionError) when total is 0, matching
    the guard style used by the name recall/precision metrics.
    """
    return {
        'count': count,
        'total': total,
        'percentage': (count / total) if total else None,
    }


def main():
    """Evaluate generated captions against reference captions.

    Reads a JSON-lines file (one generation record per line, path taken
    from the ``file`` CLI argument), accumulates caption metrics — BLEU,
    ROUGE, CIDEr, METEOR, name recall/precision, rare-name coverage,
    entity overlap, and length/TTR/readability statistics — then writes
    them as ``<input-basename>_reported_metrics.json`` next to the input
    file and prints each metric.
    """
    args = docopt(__doc__, version='0.0.1')
    args = validate(args)

    # Optionally block until a remote ptvsd debugger attaches.
    if args['ptvsd']:
        address = ('0.0.0.0', args['ptvsd'])
        ptvsd.enable_attach(address)
        ptvsd.wait_for_attach()

    # NOTE(review): pickle.load is unsafe on untrusted files; acceptable
    # only because the counters file is produced by our own pipeline.
    with open(args['counters'], 'rb') as f:
        counters = pickle.load(f)

    # Combined article+caption word counts, used for "article rare name"
    # metrics; counters['caption'] alone drives the "caption rare" ones.
    full_counter = counters['context'] + counters['caption']

    bleu_scorer = BleuScorer(n=4)
    rouge_scorer = Rouge()
    rouge_scores = []
    cider_scorer = CiderScorer(n=4, sigma=6.0)
    meteor_scorer = Meteor()
    # Patch in a _stat implementation that talks to the METEOR subprocess
    # line-by-line so we can batch all examples into one EVAL request.
    meteor_scorer._stat = types.MethodType(_stat, meteor_scorer)
    meteor_scores = []
    eval_line = 'EVAL'
    meteor_scorer.lock.acquire()
    count = 0
    # Per-example name scores; collected but not included in the report.
    recalls, precisions = [], []
    rare_recall, rare_recall_total = 0, 0
    rare_precision, rare_precision_total = 0, 0
    full_recall, full_recall_total = 0, 0
    full_precision, full_precision_total = 0, 0
    full_rare_recall, full_rare_recall_total = 0, 0
    full_rare_precision, full_rare_precision_total = 0, 0
    lengths, gt_lengths = [], []
    n_uniques, gt_n_uniques = [], []

    gen_ttrs, cap_ttrs = [], []
    gen_flesch, cap_flesch = [], []

    ent_counter = defaultdict(int)

    with open(args['file']) as f:
        for line in tqdm(f):
            obj = json.loads(line)
            if args['use_processed']:
                caption = obj['caption']
                obj['caption_names'] = obj['processed_caption_names']
            else:
                caption = obj['raw_caption']

            generation = obj['generation']

            if obj['caption_names']:
                recalls.append(compute_recall(obj))
            if obj['generated_names']:
                precisions.append(compute_precision(obj))

            # Each compute_* helper returns (matched_count, total_count);
            # accumulate so the final percentage is micro-averaged.
            c, t = compute_full_recall(obj)
            full_recall += c
            full_recall_total += t

            c, t = compute_full_precision(obj)
            full_precision += c
            full_precision_total += t

            c, t = compute_rare_recall(obj, counters['caption'])
            rare_recall += c
            rare_recall_total += t

            c, t = compute_rare_precision(obj, counters['caption'])
            rare_precision += c
            rare_precision_total += t

            c, t = compute_rare_recall(obj, full_counter)
            full_rare_recall += c
            full_rare_recall_total += t

            c, t = compute_rare_precision(obj, full_counter)
            full_rare_precision += c
            full_rare_precision_total += t

            # Remove punctuation before length/overlap-style statistics.
            caption = re.sub(r'[^\w\s]', '', caption)
            generation = re.sub(r'[^\w\s]', '', generation)

            lengths.append(len(generation.split()))
            gt_lengths.append(len(caption.split()))

            n_uniques.append(len(set(generation.split())))
            gt_n_uniques.append(len(set(caption.split())))

            bleu_scorer += (generation, [caption])
            rouge_score = rouge_scorer.calc_score([generation], [caption])
            rouge_scores.append(rouge_score)
            cider_scorer += (generation, [caption])

            # METEOR examples are batched into one ' ||| '-separated line.
            stat = meteor_scorer._stat(generation, [caption])
            eval_line += ' ||| {}'.format(stat)
            count += 1

            gen_ttrs.append(obj['gen_np']['basic_ttr'])
            cap_ttrs.append(obj['caption_np']['basic_ttr'])
            gen_flesch.append(obj['gen_readability']['flesch_reading_ease'])
            cap_flesch.append(
                obj['caption_readability']['flesch_reading_ease'])

            compute_entities(obj, ent_counter)

    # Send the batched EVAL line; the subprocess replies with one score
    # per example followed by the aggregate METEOR score.
    meteor_scorer.meteor_p.stdin.write('{}\n'.format(eval_line).encode())
    meteor_scorer.meteor_p.stdin.flush()
    for _ in range(count):
        meteor_scores.append(
            float(meteor_scorer.meteor_p.stdout.readline().strip()))
    meteor_score = float(meteor_scorer.meteor_p.stdout.readline().strip())
    meteor_scorer.lock.release()

    bleu_score, _ = bleu_scorer.compute_score(option='closest')
    rouge_score = np.mean(np.array(rouge_scores))
    cider_score, _ = cider_scorer.compute_score()

    final_metrics = {
        'BLEU-1': bleu_score[0],
        'BLEU-2': bleu_score[1],
        'BLEU-3': bleu_score[2],
        'BLEU-4': bleu_score[3],
        'ROUGE': rouge_score,
        'METEOR': meteor_score,
        'CIDEr': cider_score,
        'All names - recall':
        _ratio_entry(full_recall, full_recall_total),
        'All names - precision':
        _ratio_entry(full_precision, full_precision_total),
        'Caption rare names - recall':
        _ratio_entry(rare_recall, rare_recall_total),
        'Caption rare names - precision':
        _ratio_entry(rare_precision, rare_precision_total),
        'Article rare names - recall':
        _ratio_entry(full_rare_recall, full_rare_recall_total),
        'Article rare names - precision':
        _ratio_entry(full_rare_precision, full_rare_precision_total),
        'Length - generation': sum(lengths) / len(lengths),
        'Length - reference': sum(gt_lengths) / len(gt_lengths),
        'Unique words - generation': sum(n_uniques) / len(n_uniques),
        'Unique words - reference': sum(gt_n_uniques) / len(gt_n_uniques),
        'Caption TTR': sum(cap_ttrs) / len(cap_ttrs),
        'Generation TTR': sum(gen_ttrs) / len(gen_ttrs),
        'Caption Flesch Reading Ease': sum(cap_flesch) / len(cap_flesch),
        'Generation Flesch Reading Ease': sum(gen_flesch) / len(gen_flesch),
        # Entity entries previously divided unguarded and crashed on a
        # zero total; _ratio_entry yields None instead, consistent with
        # the name metrics above.
        'Entity all - recall':
        _ratio_entry(ent_counter['n_caption_ent_matches'],
                     ent_counter['n_caption_ents']),
        'Entity all - precision':
        _ratio_entry(ent_counter['n_gen_ent_matches'],
                     ent_counter['n_gen_ents']),
        'Entity person - recall':
        _ratio_entry(ent_counter['n_caption_person_matches'],
                     ent_counter['n_caption_persons']),
        'Entity person - precision':
        _ratio_entry(ent_counter['n_gen_person_matches'],
                     ent_counter['n_gen_persons']),
        'Entity GPE - recall':
        _ratio_entry(ent_counter['n_caption_gpes_matches'],
                     ent_counter['n_caption_gpes']),
        'Entity GPE - precision':
        _ratio_entry(ent_counter['n_gen_gpes_matches'],
                     ent_counter['n_gen_gpes']),
        'Entity ORG - recall':
        _ratio_entry(ent_counter['n_caption_orgs_matches'],
                     ent_counter['n_caption_orgs']),
        'Entity ORG - precision':
        _ratio_entry(ent_counter['n_gen_orgs_matches'],
                     ent_counter['n_gen_orgs']),
        'Entity DATE - recall':
        _ratio_entry(ent_counter['n_caption_date_matches'],
                     ent_counter['n_caption_date']),
        'Entity DATE - precision':
        _ratio_entry(ent_counter['n_gen_date_matches'],
                     ent_counter['n_gen_date']),
    }

    serialization_dir = os.path.dirname(args['file'])
    filename = os.path.basename(args['file']).split('.')[0]
    if args['use_processed']:
        filename += '_processed'

    # BUG FIX: the output path previously ignored `filename` entirely
    # (a literal name with no placeholder), so every run wrote the same
    # file and the '_processed' suffix had no effect.
    output_file = os.path.join(serialization_dir,
                               f'{filename}_reported_metrics.json')
    with open(output_file, 'w') as file:
        json.dump(final_metrics, file, indent=4)

    for key, metric in final_metrics.items():
        print(f"{key}: {metric}")
예제 #2
0
def rouge():
    """Compute and print the ROUGE score of the module-level `hypo`
    hypothesis against the `ref1` reference."""
    print(Rouge().calc_score(hypo, ref1))