Example 1
 def evaluate(self):
     self.model.eval()
     sents = [g.metadata['snt'] for g in self.graphs_gold]
     graphs_gen = self.inference.parse_sents(sents,
                                             return_penman=True,
                                             disable_progress=False,
                                             pbar_desc='%-14s' %
                                             'Evaluating:')
     assert len(graphs_gen) == len(self.graphs_gold)
     # Detect bad graphs. In Penman 1.2.0, metadata does not impact penman.Graph.__eq__()
     num_bad = sum(g == Inference.invalid_graph for g in graphs_gen)
     print('Out of %d graphs, %d did not generate properly.' %
           (len(graphs_gen), num_bad))
     # Save the final graphs
     print('Generated graphs written to', self.dev_pred_path)
     penman.dump(graphs_gen, self.dev_pred_path, indent=6, model=amr_model)
     # Run smatch
     try:
         gold_entries = get_entries(self.dev_gold_path)
         test_entries = get_entries(self.dev_pred_path)
         precision, recall, f_score = compute_smatch(
             test_entries, gold_entries)
         print('SMATCH -> P: %.3f,  R: %.3f,  F: %.3f' %
               (precision, recall, f_score))
     except Exception:
         logger.exception('Failed to compute smatch score.')
         precision, recall, f_score = 0, 0, 0
     return f_score
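
A note on the scoring helpers: get_entries and compute_smatch are not part of penman itself. A minimal standalone sketch of the scoring step, assuming they come from amrlib's smatch_enhanced module (the file paths are hypothetical):

    # Hedged sketch, not verbatim from the example above.
    from amrlib.evaluate.smatch_enhanced import get_entries, compute_smatch

    gold_entries = get_entries('dev-gold.txt')   # hypothetical path
    test_entries = get_entries('dev-pred.txt')   # hypothetical path
    precision, recall, f_score = compute_smatch(test_entries, gold_entries)
    print('SMATCH -> P: %.3f,  R: %.3f,  F: %.3f' % (precision, recall, f_score))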
Example 2
 def wikify_file(self, infpath, outfpath):
     print('Loading', infpath)
     pgraphs = penman.load(infpath)
     winfo_list = self.find_wiki_nodes_for_graphs(pgraphs)
     print('Running BLINK to get wiki values')
     winfo_list = self.predict_blink(winfo_list)
     print('Adding and saving graphs to', outfpath)
     pgraphs = self.add_wiki_to_graphs(pgraphs, winfo_list)
     penman.dump(pgraphs, outfpath, indent=6)
Example 3
def wiki_remove_file(indir, infn, outdir, outfn):
    graphs = []
    inpath = os.path.join(indir, infn)
    entries = load_amr_entries(inpath)
    for entry in tqdm(entries, ncols=100):
        graph = _process_entry(entry)
        graphs.append(graph)
    outpath = os.path.join(outdir, outfn)
    print('Saving file to ', outpath)
    penman.dump(graphs, outpath, indent=6)
Example 4
def annotate_file(indir, infn, outdir, outfn):
    load_spacy()
    graphs = []
    inpath = os.path.join(indir, infn)
    entries = load_amr_entries(inpath)
    pool = multiprocessing.Pool()
    #for pen in tqdm(map(_process_entry, entries), total=len(entries)):
    for pen in tqdm(pool.imap(_process_entry, entries), total=len(entries)):
        graphs.append(pen)
    pool.close()
    pool.join()
    infn = infn[:-3] if infn.endswith('.gz') else infn  # strip .gz if needed
    outpath = os.path.join(outdir, outfn)
    print('Saving file to ', outpath)
    penman.dump(graphs, outpath, indent=6)
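
The Pool.imap/tqdm combination above is a general pattern for applying a per-entry transform in parallel with a progress bar. A minimal self-contained sketch, with a trivial stand-in for _process_entry:

    # Hedged sketch of the parallel-map-with-progress pattern.
    import multiprocessing
    from tqdm import tqdm

    def _square(x):   # stand-in for _process_entry
        return x * x

    if __name__ == '__main__':   # guard required for multiprocessing on spawn platforms
        entries = list(range(1000))
        with multiprocessing.Pool() as pool:
            # total= is needed because imap returns an iterator with no len()
            results = list(tqdm(pool.imap(_square, entries),
                                total=len(entries), ncols=100))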
Example 5
 def load_eval_data(self):
     print('Loading eval data from ', self.config['dev'])
     self.inference = Inference(model=self.model,
                                tokenizer=self.tokenizer,
                                device=self.device,
                                num_beams=self.config['eval_beam_size'],
                                batch_size=self.config['eval_batch_sents'],
                                config=self.config)
     self.graphs_gold = read_raw_amr_data(
         self.config['dev'],
         use_recategorization=self.config['use_recategorization'],
         dereify=self.config['dereify'],
         remove_wiki=self.config['remove_wiki'])
     penman.dump(self.graphs_gold,
                 self.dev_gold_path,
                 indent=6,
                 model=amr_model)
Example 6
    print('%d generated graphs do not deserialize out of %d = %.1f%%' %
          (len(bad_graphs), num_non_clipped, pct))
    print()

    # Save the reference, omitting any clipped or bad
    ref_fpath = os.path.join(out_dir, ref_out_fn)
    print('Saving', ref_fpath)
    skipped = 0
    with open(ref_fpath, 'w') as f:
        for i, graph in enumerate(ref_in_graphs):
            if i in bad_graphs or i in clip_index_set:
                skipped += 1
                continue
            f.write(graph + '\n\n')
    print('Skipped writing %d as either bad or clipped' % skipped)
    print('Wrote a total of %d reference AMR graphs' % (len(ref_in_graphs) - skipped))
    print()

    # Save the generated
    gen_fpath = os.path.join(out_dir, gen_out_fn)
    print('Saving', gen_fpath)
    penman.dump(gen_out_graphs, gen_fpath, indent=6, model=NoOpModel())
    print('Wrote a total of %d generated AMR graphs' % len(gen_out_graphs))
    print()

    # Score the resultant files
    print('Scoring the above files with SMATCH')
    gold_entries = get_entries(ref_fpath)
    test_entries = get_entries(gen_fpath)
    precision, recall, f_score = compute_smatch(test_entries, gold_entries)
    print('SMATCH -> P: %.3f,  R: %.3f,  F: %.3f' % (precision, recall, f_score))
Example 7
    # Convert to penman and add lemmas
    print('Annotating')
    load_spacy()  # do this in the main process to prevent doing it multiple times
    graphs = []
    annotate = partial(add_lemmas, snt_key='snt',
                       verify_tok_key=None)  # no existing tok key
    with Pool() as pool:
        for graph in pool.imap(annotate, entries):
            if graph is not None:
                graphs.append(graph)
    print('%d graphs left with the same tokenization length' % len(graphs))

    # Run the aligner
    print('Aligning Graphs')
    new_graphs = []
    keep_keys = ('id', 'snt', 'tokens', 'lemmas', 'rbw_alignments')
    for graph in graphs:
        aligner = RBWAligner.from_penman_w_json(
            graph, align_str_name='rbw_alignments')
        pgraph = aligner.get_penman_graph()
        pgraph.metadata = {
            k: v
            for k, v in pgraph.metadata.items() if k in keep_keys
        }
        new_graphs.append(pgraph)

    # Save the graphs
    print('Saving to', out_fname)
    penman.dump(new_graphs, out_fname, model=NoOpModel(), indent=6)
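
The annotation and alignment helpers in Example 7 come from amrlib rather than penman. A sketch of the assumed import paths and call sequence for a single entry (treat the paths and the entry_str variable as assumptions, not verified against a specific amrlib version):

    # Hedged sketch; import paths assumed from amrlib's documentation.
    from amrlib.graph_processing.annotator import add_lemmas, load_spacy
    from amrlib.alignments.rbw_aligner import RBWAligner

    load_spacy()                                   # load the spaCy model once
    graph = add_lemmas(entry_str, snt_key='snt')   # entry_str: raw AMR text, hypothetical
    aligner = RBWAligner.from_penman_w_json(graph)
    pgraph = aligner.get_penman_graph()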
Example 8
 def wikify_file(self, infn, outfn):
     new_graphs = []
     for graph in tqdm(penman.load(infn)):
         new_graph = self.wikify_graph(graph)
         new_graphs.append(new_graph)
     penman.dump(new_graphs, outfn, indent=6)
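
Every example above bottoms out in the same small penman API. A minimal round trip, independent of amrlib:

    # Hedged sketch of the core penman calls the snippets share.
    import penman
    from penman.models.noop import NoOpModel

    graph = penman.decode('(w / want-01 :ARG0 (b / boy))')
    penman.dump([graph], 'out.amr', indent=6)   # serialize a list of graphs
    graphs = penman.load('out.amr')             # read them back
    # NoOpModel skips relation normalization on output
    penman.dump(graphs, 'out_noop.amr', indent=6, model=NoOpModel())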
Example 9
    tokenizer = inference.tokenizer
    config = inference.config

    # Load the data
    print('Loading the dataset')
    graphs_gold = read_raw_amr_data(
        test_fns,
        use_recategorization=config['use_recategorization'],
        dereify=config['dereify'],
        remove_wiki=config['remove_wiki'])
    graphs_gold = graphs_gold[:max_entries]
    sents = [g.metadata['snt'] for g in graphs_gold]

    # Create the gold test file
    os.makedirs(os.path.dirname(gold_path), exist_ok=True)
    penman.dump(graphs_gold, gold_path, indent=4, model=amr_model)

    # Run the inference
    print('Generating/testing')
    graphs_gen = inference.parse_sents(sents,
                                       return_penman=True,
                                       disable_progress=False)
    assert len(graphs_gen) == len(graphs_gold)

    # Detect bad graphs
    # In Penman 1.2.0, metadata does not impact penman.Graph.__eq__()
    num_bad = sum(g == Inference.invalid_graph for g in graphs_gen)
    print('Out of %d graphs, %d did not generate properly.' %
          (len(graphs_gen), num_bad))

    # Save the final graphs
Example 10
    graph_fn = 'amrlib/data/alignments/test_w_surface.txt'
    graph_ns_fn = 'amrlib/data/alignments/test_no_surface.txt'

    os.makedirs(os.path.dirname(graph_fn), exist_ok=True)

    # Loop through the files and load all entries
    entries = []
    print('Loading data from', corp_dir)
    fpaths = [os.path.join(corp_dir, fn) for fn in os.listdir(corp_dir)]
    for fpath in fpaths:
        entries += load_raw_amr(fpath)
    print('Loaded {:,} entries'.format(len(entries)))

    # Check for the penman decode/re-encode issue and strip some metadata
    good_graphs = []
    good_graphs_ns = []
    for entry in entries:
        # Create a version with No Surface alignments
        entry_ns = strip_surface_alignments(entry)
        graph, is_good = test_for_decode_encode_issue(entry)
        graph_ns, is_good_ns = test_for_decode_encode_issue(entry_ns)
        if is_good and is_good_ns:
            good_graphs.append(mod_graph_meta(graph))
            good_graphs_ns.append(mod_graph_meta(graph_ns))

    # Save the collated data
    print('Saving {:,} good graphs to {:} and {:}'.format(
        len(good_graphs), graph_fn, graph_ns_fn))
    penman.dump(good_graphs, graph_fn, indent=6)
    penman.dump(good_graphs_ns, graph_ns_fn, indent=6)
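
test_for_decode_encode_issue is internal to amrlib; a rough approximation of the idea, assuming the check is simply whether a graph survives a decode/re-encode round trip (metadata is excluded from equality, per the comments in Examples 1 and 9):

    # Hedged approximation, not amrlib's actual implementation.
    import penman

    def roundtrips_cleanly(entry):
        graph = penman.decode(entry)
        try:
            reencoded = penman.encode(graph)
            return graph, penman.decode(reencoded) == graph
        except Exception:
            return graph, False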
Example 11
        for i, graph in enumerate(ref_in_graphs):
            if i in bad_graphs or i in clip_index_set:
                skipped += 1
                continue
            # Add a test index so we can identify the graph
            f.write('# ::test_id %d\n' % i)
            f.write(graph + '\n\n')
    print('Skipped writing %d as either bad or clipped' % skipped)
    print('Wrote a total of %d reference AMR graphs' %
          (len(ref_in_graphs) - skipped))
    print()

    # Save the generated
    gen_fpath = os.path.join(test_dir, gen_out_fn)
    print('Saving', gen_fpath)
    penman.dump(gen_out_graphs, gen_fpath, indent=6)
    print('Wrote a total of %d generated AMR graphs' % len(gen_out_graphs))

    # Print some info
    print()
    print('Clipped: ', sorted(clip_index_set))
    print('Bad graphs: ', sorted(bad_graphs))
    print()

    # Score the resultant files
    print('Scoring the above files with SMATCH')
    gold_entries = get_entries(ref_fpath)
    test_entries = get_entries(gen_fpath)
    precision, recall, f_score = compute_smatch(test_entries, gold_entries)
    print('SMATCH -> P: %.3f,  R: %.3f,  F: %.3f' %
          (precision, recall, f_score))