def main(args):

    corpus = read_amr(args.in_amr).amrs
    print(f'Read {args.in_amr}')
    num_amrs = len(corpus)
    if args.indices:
        indices = list(map(int, args.indices))
    else:
        indices = list(range(num_amrs))
        shuffle(indices)

    # get one sample
    for index in indices:
        amr = corpus[index]

        # keep only sentences of moderate length; the original check
        # `7 > len(amr.tokens) > 10` was always False and filtered nothing
        if len(amr.tokens) < 7 or len(amr.tokens) > 10:
            continue

        # Get tokens aligned to nodes (alignment positions are 1-based)
        aligned_tokens = [
            amr.tokens[i - 1]
            for positions in amr.alignments.values()
            for i in positions
        ]

        # skip amr not meeting criteria
        if (
            args.has_nodes
            and not set(args.has_nodes) <= set(amr.nodes.values())
        ) or (
            args.has_edges
            and not set(args.has_edges) <= set(x[1][1:] for x in amr.edges)
        ) or (
            args.has_repeated_nodes
            and len(set(amr.nodes.values())) == len(amr.nodes.values())
        ) or (
            args.has_repeated_tokens
            and len(set(aligned_tokens)) == len(aligned_tokens)
        ):
            continue

        # convert IBM AMR format to the one used here
        # tokens, nodes, edges, alignments = convert_format(amr)

        print('\n'.join([
            x for x in amr.toJAMRString().split('\n')
            if not x.startswith('# ::edge')
        ]))
        print(index)

        # plot (plot_graph takes one 0-based token position per node)
        alignments = {k: v[0] - 1 for k, v in amr.alignments.items()}
        plot_graph(amr.tokens, amr.nodes, amr.edges, alignments)

        response = input('Quit [N/y]?')
        if response == 'y':
            break
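# Alignment format sketch (inferred from the indexing above, not from the
# reader's docs): amr.alignments maps a node id to a list of 1-based token
# positions, while plot_graph expects a single 0-based position per node:
#
#   amr.alignments == {'0.0.2': [3]}   # node '0.0.2' aligned to the 3rd token
#   {k: v[0] - 1 for k, v in amr.alignments.items()} == {'0.0.2': 2}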
def main():

    # Argument handling
    args = argument_parser()

    global entities_with_preds
    entities_with_preds = args.in_pred_entities.split(",")

    # Load AMR (replace some unicode characters)
    # TODO: unicode fixes and other normalizations should be applied more
    # transparently
    print(f'Reading {args.in_amr}')
    corpus = read_amr(args.in_amr, unicode_fixes=True)
    gold_amrs = corpus.amrs

    # sanity check AMRs
    print_corpus_info(gold_amrs)
    sanity_check_amr(gold_amrs)

    # Load propbank if provided
    # TODO: Use here XML propbank reader instead of txt reader
    propbank_args = None
    if args.in_propbank_args:
        propbank_args = read_propbank(args.in_propbank_args)

    # read/write multi-task (labeled shift) action
    # TODO: Add conditional if here
    multitask_words = process_multitask_words(
        [list(amr.tokens) for amr in gold_amrs],
        args.multitask_max_words,
        args.in_multitask_words,
        args.out_multitask_words,
        add_root=True
    )

    # run the oracle for the entire corpus
    stats = run_oracle(gold_amrs, args.copy_lemma_action, multitask_words)

    # print stats about actions
    sanity_check_actions(stats['sentence_tokens'], stats['oracle_actions'])

    # Save statistics
    write_tokenized_sentences(
        args.out_actions,
        stats['oracle_actions'],
        separator='\t'
    )
    write_tokenized_sentences(
        args.out_sentences,
        stats['sentence_tokens'],
        separator='\t'
    )

    # State machine stats for this sentence
    if args.out_rule_stats:
        with open(args.out_rule_stats, 'w') as fid:
            fid.write(json.dumps(stats['rules']))
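# Example invocation of the oracle above (hypothetical script name and file
# paths; the flags mirror the args attributes accessed in main):
#
#   python oracle.py \
#       --in-amr DATA/AMR2.0/aligned/cofill/train.txt \
#       --out-actions oracle/train.actions \
#       --out-sentences oracle/train.tokens \
#       --out-rule-stats oracle/train.rules.json \
#       --in-pred-entities person,thing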
def main():

    # Argument handling
    args = argument_parser()

    # Read

    # Load AMR (replace some unicode characters)
    if args.in_amr:
        corpus = read_amr(args.in_amr, unicode_fixes=True)
        amrs = corpus.amrs
    # Load tokens
    if args.in_tokens:
        sentences = read_tokenized_sentences(args.in_tokens, separator='\t')
    # Load actions i.e. oracle
    if args.in_actions:
        actions = read_tokenized_sentences(args.in_actions, separator='\t')
    # Load scored actions i.e. mined oracle
    if args.in_scored_actions:
        scored_actions = read_action_scores(args.in_scored_actions)
        # measure performance
        print_score_action_stats(scored_actions)
    # Load rule stats
    if args.in_rule_stats:
        rule_stats = read_rule_stats(args.in_rule_stats)

    # Modify

    # merge --in-actions and --in-scored-actions and store in --out-actions
    if args.merge_mined:
        # sanity checks
        assert args.in_tokens, "--merge-mined requires --in-tokens"
        assert args.in_actions, "--merge-mined requires --in-actions"
        assert args.in_rule_stats, "--merge-mined requires --in-rule-stats"
        assert args.out_rule_stats, "--merge-mined requires --out-rule-stats"
        if args.in_actions:
            assert len(actions) == len(scored_actions)
        print(f'Merging {args.out_actions} and {args.in_scored_actions}')

        # actions
        actions = merge_actions(actions, scored_actions)

    # fix actions split by whitespace arguments
    if args.fix_actions:
        actions = fix_actions_split_by_spaces(actions)

    # merge rules
    if args.merge_mined:
        out_rule_stats = merge_rules(
            sentences, actions, rule_stats, entity_rules=args.entity_rules
        )
        print(f'Merging {args.out_rule_stats} and {args.in_rule_stats}')

    # Write

    # actions
    if args.out_actions:
        dirname = os.path.dirname(args.out_actions)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        write_tokenized_sentences(
            args.out_actions,
            actions,
            separator='\t'
        )
        print(f'Wrote {args.out_actions}')

    # rule stats
    if args.out_rule_stats:
        write_rule_stats(args.out_rule_stats, out_rule_stats)
        print(f'Wrote {args.out_rule_stats}')

    # AMR
    if args.out_amr:
        with open(args.out_amr, 'w') as fid:
            for amr in amrs:
                fid.write(amr.toJAMRString())
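# Example merge invocation (hypothetical script name and paths; the required
# flag combination is confirmed by the asserts above):
#
#   python merge_oracle.py \
#       --merge-mined \
#       --in-tokens oracle/train.tokens \
#       --in-actions oracle/train.actions \
#       --in-scored-actions mined/train.scored_actions \
#       --in-rule-stats oracle/train.rules.json \
#       --out-actions merged/train.actions \
#       --out-rule-stats merged/train.rules.json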
def oracle(args):

    # Read AMR
    amrs = read_amr(args.in_aligned_amr, ibm_format=True)

    # broken annotations that we ignore in stats
    # 'DATA/AMR2.0/aligned/cofill/train.txt'
    ignore_indices = [
        8372,   # (49, ':time', 49), (49, ':condition', 49)
        17055,  # (3, ':mod', 7), (3, ':mod', 7)
        27076,  # '0.0.2.1.0.0' is on ::edges but not ::nodes
        # for AMR 3.0 data: DATA/AMR3.0/aligned/cofill/train.txt
        # self-loop:
        # "# ::edge vote-01 condition vote-01 0.0.2 0.0.2",
        # "# ::edge vote-01 time vote-01 0.0.2 0.0.2"
        9296,
    ]
    # NOTE we put the indices to ignore for both AMR2.0 and AMR3.0 in the
    # same list and use it for both oracles, since this does NOT change the
    # oracle actions; it only skips sanity checks and the stats displayed
    # after the oracle run

    # Initialize machine
    machine = AMRStateMachine(
        reduce_nodes=args.reduce_nodes,
        absolute_stack_pos=args.absolute_stack_positions,
        use_copy=args.use_copy
    )
    # Save machine config
    machine.save(args.out_machine_config)

    # initialize oracle
    oracle = AMROracle(
        reduce_nodes=args.reduce_nodes,
        absolute_stack_pos=args.absolute_stack_positions,
        use_copy=args.use_copy
    )

    # will store statistics and check AMR is recovered
    stats = Stats(ignore_indices, ngram_stats=False)
    stats_vocab = StatsForVocab(no_close=False)
    for idx, amr in tqdm(enumerate(amrs), desc='Oracle'):

        # debug
        # print(idx)  # 96 for AMR2.0 test data infinite loop
        # if idx == 96:
        #     breakpoint()

        # spawn new machine for this sentence
        machine.reset(amr.tokens)

        # initialize new oracle for this AMR
        oracle.reset(amr)

        # proceed left to right through the sentence generating nodes
        while not machine.is_closed:

            # get valid actions
            _ = machine.get_valid_actions()

            # oracle
            actions, scores = oracle.get_actions(machine)
            # actions = [a for a in actions if a in valid_actions]
            # most probable
            action = actions[np.argmax(scores)]

            # if it is node generation, keep track of original id in gold amr
            if isinstance(action, tuple):
                action, gold_node_id = action
                node_id = len(machine.action_history)
                oracle.node_map[gold_node_id] = node_id
                oracle.node_reverse_map[node_id] = gold_node_id

            # update machine
            machine.update(action)

            # update machine stats
            stats.update_machine_stats(machine)

            # update vocabulary
            stats_vocab.update(action, machine)

        # Sanity check: We recovered the full AMR
        stats.update_sentence_stats(oracle, machine)

        # do not write 'CLOSE' in the action sequences
        # this might change machine.action_history in place, but this machine
        # is already at its end
        close_action = stats.action_sequences[-1].pop()
        assert close_action == 'CLOSE'

    # display statistics
    stats.display()

    # save action sequences and tokens
    write_tokenized_sentences(args.out_actions, stats.action_sequences, '\t')
    write_tokenized_sentences(args.out_tokens, stats.tokens, '\t')

    # save action vocabulary stats
    # debug
    stats_vocab.display()
    if getattr(args, 'out_stats_vocab', None) is not None:
        stats_vocab.write(args.out_stats_vocab)
        print(f'Action vocabulary stats written to {args.out_stats_vocab}.*')
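# Minimal replay sketch: the saved action sequences can be fed back through a
# freshly reset machine to reproduce the oracle run (only methods used above
# are assumed to exist; note that 'CLOSE' was stripped before writing, so it
# must be re-appended):
#
#   machine.reset(tokens)
#   for action in action_sequence + ['CLOSE']:
#       machine.update(action)
#   assert machine.is_closed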
def get_propbank_name(amr_pred):
    items = amr_pred.split('-')
    prop_pred = '-'.join(items[:-1]) + '.' + items[-1]
    # keep hyphens for role-91 frames and known exceptions; otherwise
    # PropBank lemmas use underscores
    if not (prop_pred.endswith('.91') or prop_pred in ['have-half-life.01']):
        prop_pred = prop_pred.replace('-', '_')
    return prop_pred


if __name__ == '__main__':

    # Argument handling
    in_amr, in_propbank_json = sys.argv[1:]

    corpus = read_amr(in_amr)
    with open(in_propbank_json) as fid:
        propbank = json.loads(fid.read())

    pred_regex = re.compile('.+-[0-9]+$')

    amr_alerts = defaultdict(list)
    sid = 0
    num_preds = 0
    for amr in tqdm(corpus.amrs):
        predicate_ids = [
            k for k, v in amr.nodes.items() if pred_regex.match(v)
        ]
        num_preds += len(predicate_ids)
        for pred_id in predicate_ids:
            pred = get_propbank_name(amr.nodes[pred_id])
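# Usage sketch for get_propbank_name above (hypothetical inputs; the outputs
# follow directly from its logic):
#
#   get_propbank_name('want-01')            # -> 'want.01'
#   get_propbank_name('make-up-05')         # -> 'make_up.05'
#   get_propbank_name('be-located-at-91')   # -> 'be-located-at.91' (.91 keeps hyphens)
#   get_propbank_name('have-half-life-01')  # -> 'have-half-life.01' (listed exception)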
def main():

    # Argument handling
    args = argument_parser()

    # Load AMR (replace some unicode characters)
    corpus = read_amr(args.in_amr, unicode_fixes=True)

    # Load propbank
    propbank_args = None
    if args.in_propbank_args:
        propbank_args = read_propbank(args.in_propbank_args)

    # read/write multi-task (labeled shift) action
    multitask_words = process_multitask_words(
        [list(amr.tokens) for amr in corpus.amrs],
        args.multitask_max_words,
        args.in_multitask_words,
        args.out_multitask_words,
        add_root=True
    )

    # TODO: In the end, an oracle is just a parser with oracle info. This
    # could be turned into a loop similar to parser.py (or directly use that
    # and an AMROracleParser())
    print_log("amr", "Processing oracle")
    oracle = AMR_Oracle(args.entity_rules, verbose=args.verbose)
    oracle.runOracle(
        corpus.amrs,
        propbank_args,
        out_oracle=args.out_oracle,
        out_amr=args.out_amr,
        out_sentences=args.out_sentences,
        out_actions=args.out_actions,
        out_rule_stats=args.out_rule_stats,
        add_unaligned=0,
        no_whitespace_in_actions=args.no_whitespace_in_actions,
        multitask_words=multitask_words,
        copy_lemma_action=args.copy_lemma_action,
        addnode_count_cutoff=args.addnode_count_cutoff
    )

    # inform user
    for stat in oracle.stats:
        if args.verbose:
            print_log("amr", stat)
            print_log("amr", oracle.stats[stat].most_common(100))
            print_log("amr", "")

    if args.out_action_stats:
        # Store rule statistics
        with open(args.out_action_stats, 'w') as fid:
            fid.write(json.dumps(oracle.stats))

    # NOTE: use_addnode_rules and the entity_rule_* counters are assumed to
    # be module-level globals
    if use_addnode_rules:
        for x in entity_rule_totals:
            perc = entity_rule_stats[x] / entity_rule_totals[x]
            if args.verbose:
                print_log(
                    x, entity_rule_stats[x], '/', entity_rule_totals[x],
                    '=', f'{perc:.2f}'
                )
        perc = sum(entity_rule_stats.values()) / \
            sum(entity_rule_totals.values())
        print_log('Totals:', f'{perc:.2f}')
        print_log('Totals:', 'Failed Entity Predictions:')
def main(args):

    # sanity checks
    assert bool(args.in_aligned_amr) ^ bool(args.in_amr), \
        "Needs either --in-amr or --in-aligned-amr"
    if args.compare:
        assert args.in_aligned_amr, "--compare only with --in-aligned-amr"

    # files
    if args.in_amr:
        amrs = read_amr(args.in_amr, tokenize=args.tokenize)
    else:
        amrs = read_amr(
            args.in_aligned_amr, ibm_format=True, tokenize=args.tokenize
        )

    # normalize tokens for matching purposes, but keep the originals for
    # writing
    original_tokens = []
    for amr in amrs:
        original_tokens.append(amr.tokens)
        amr.tokens = normalize_tokens(amr.tokens)

    assert args.em_epochs > 0 or args.rule_prior_strength > 0, \
        "Either set --em-epochs > 0 or --rule-prior-strength > 0"

    # if not given, pick a random order
    if args.indices is None:
        indices = list(range(len(amrs)))
        if args.shuffle:
            shuffle(indices)
    else:
        indices = args.indices
    eval_indices = indices

    # Initialize aligner. This is an IBM model 1 using surface matching rules
    # as prior and graph vicinity rules for post-processing
    if args.in_checkpoint_json:
        amr_aligner = AMRAligner.from_checkpoint(args.in_checkpoint_json)
    else:
        amr_aligner = AMRAligner(
            rule_prior_strength=args.rule_prior_strength,
            force_align_ner=args.force_align_ner,
            not_align_tokens=IGNORE_TOKENS,
            ignore_nodes=IGNORE_NODES,
            ignore_node_regex=IGNORE_REGEX
        )

    # loop over EM epochs
    av_log_lik = None
    for epoch in range(args.em_epochs):
        if av_log_lik:
            bar_desc = \
                f'EM epoch {epoch+1}/{args.em_epochs} loglik {av_log_lik}'
        else:
            bar_desc = f'EM epoch {epoch+1}/{args.em_epochs}'
        for index in tqdm(indices, desc=bar_desc):
            # accumulate stats while fixing the posterior
            amr_aligner.update_counts(amrs[index], cache_key=index)
        # compute loglik
        av_log_lik = amr_aligner.train_loglik / amr_aligner.train_num_examples
        # update the model parameters
        amr_aligner.update_parameters()

    # save model
    if args.out_checkpoint_json:
        amr_aligner.save(args.out_checkpoint_json)

    # check some examples of alignment visually
    if args.visual_eval:
        visual_eval(
            amr_aligner, eval_indices, amrs, args.compare,
            args.alignment_format
        )

    # add or replace alignments
    if args.out_aligned_amr or args.compare:
        save_aligned(
            amrs, original_tokens, indices, amr_aligner,
            args.out_aligned_amr, args.compare, args.alignment_format
        )
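# Example aligner invocation (hypothetical script name and paths; the flag
# names --in-amr, --in-aligned-amr, --em-epochs and --rule-prior-strength are
# confirmed by the assert messages above):
#
#   python align.py \
#       --in-amr DATA/AMR2.0/train.txt \
#       --em-epochs 10 \
#       --rule-prior-strength 1.0 \
#       --out-aligned-amr DATA/AMR2.0/aligned/train.txt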