# Example #1
# 0
def main(args):
    """Interactively print and plot AMR graphs from a corpus, one at a time.

    Reads the corpus at ``args.in_amr`` and walks it either in the order given
    by ``args.indices`` or in a shuffled order. AMRs with fewer than 7 or more
    than 10 tokens are skipped, as are AMRs failing any of the optional
    ``has_*`` filters. For every remaining AMR, its JAMR string (without the
    ``# ::edge`` lines) and its corpus index are printed, the graph is plotted
    and the user is asked whether to quit.

    Args:
        args: parsed command-line arguments. Reads ``in_amr``, ``indices``,
            ``has_nodes``, ``has_edges``, ``has_repeated_nodes`` and
            ``has_repeated_tokens``.
    """

    corpus = read_amr(args.in_amr).amrs
    print(f'Read {args.in_amr}')
    num_amrs = len(corpus)
    if args.indices:
        indices = list(map(int, args.indices))
    else:
        indices = list(range(num_amrs))
        shuffle(indices)

    # get one sample
    for index in indices:

        amr = corpus[index]

        # Only show sentences of 7 to 10 tokens. The previous chained
        # comparison ``7 > n > 10`` could never be true, so nothing was
        # ever filtered out.
        if not 7 <= len(amr.tokens) <= 10:
            continue

        # Get tokens aligned to nodes (alignment positions are 1-based,
        # hence ``i - 1``)
        aligned_tokens = [
            amr.tokens[i - 1]
            for token_positions in amr.alignments.values()
            for i in token_positions
        ]

        # skip amr not meeting criteria
        node_labels = list(amr.nodes.values())
        if args.has_nodes and not set(args.has_nodes) <= set(node_labels):
            continue
        if args.has_edges and not set(args.has_edges) <= {
                edge[1][1:] for edge in amr.edges}:
            continue
        if (args.has_repeated_nodes
                and len(set(node_labels)) == len(node_labels)):
            continue
        if (args.has_repeated_tokens
                and len(set(aligned_tokens)) == len(aligned_tokens)):
            continue

        print('\n'.join([
            line for line in amr.toJAMRString().split('\n')
            if not line.startswith('# ::edge')
        ]))
        print(index)

        # plot; only the first aligned token of each node is used (0-based)
        alignments = {k: v[0] - 1 for k, v in amr.alignments.items()}
        plot_graph(amr.tokens, amr.nodes, amr.edges, alignments)

        response = input('Quit [N/y]?')
        if response == 'y':
            break
def main():
    """Run the oracle over an AMR corpus and write its outputs.

    Reads the AMR corpus with unicode fixes, prints corpus info and sanity
    checks it, builds the multi-task (labeled shift) word list, runs the
    oracle, sanity checks the resulting actions and writes the oracle actions
    and sentence tokens as tab-separated files. If requested, the rule
    statistics are dumped as JSON.

    Side effect: sets the module-level ``entities_with_preds`` from the
    comma-separated ``args.in_pred_entities``.
    """
    args = argument_parser()

    global entities_with_preds
    entities_with_preds = args.in_pred_entities.split(",")

    # Load AMR (replace some unicode characters)
    # TODO: unicode fixes and other normalizations should be applied more
    # transparently
    print(f'Reading {args.in_amr}')
    gold_amrs = read_amr(args.in_amr, unicode_fixes=True).amrs

    # sanity check AMRS
    print_corpus_info(gold_amrs)
    sanity_check_amr(gold_amrs)

    # Load propbank if provided. The result is not used further in this
    # function, but the file is still read.
    # TODO: Use here XML propbank reader instead of txt reader
    propbank_args = (
        read_propbank(args.in_propbank_args)
        if args.in_propbank_args else None
    )

    # read/write multi-task (labeled shift) action
    # TODO: Add conditional if here
    token_lists = [list(gold_amr.tokens) for gold_amr in gold_amrs]
    multitask_words = process_multitask_words(
        token_lists,
        args.multitask_max_words,
        args.in_multitask_words,
        args.out_multitask_words,
        add_root=True,
    )

    # run the oracle for the entire corpus
    stats = run_oracle(gold_amrs, args.copy_lemma_action, multitask_words)
    sentence_tokens = stats['sentence_tokens']
    oracle_actions = stats['oracle_actions']

    # print stats about actions
    sanity_check_actions(sentence_tokens, oracle_actions)

    # Save actions and tokens
    write_tokenized_sentences(args.out_actions, oracle_actions,
                              separator='\t')
    write_tokenized_sentences(args.out_sentences, sentence_tokens,
                              separator='\t')

    # State machine rule stats
    if not args.out_rule_stats:
        return
    with open(args.out_rule_stats, 'w') as fid:
        fid.write(json.dumps(stats['rules']))
# Example #3
# 0
def main():
    """Read, merge/fix and re-write oracle artifacts.

    Optional inputs are:
    - an AMR corpus (``--in-amr``);
    - tokenized sentences (``--in-tokens``);
    - oracle actions (``--in-actions``);
    - scored (mined) actions (``--in-scored-actions``);
    - rule statistics (``--in-rule-stats``).

    Operations:
      * ``--merge-mined``: merge the scored actions into the oracle actions.
        Then build new rule statistics with ``merge_rules`` from the
        sentences, the resulting actions and the input rule statistics.
      * ``--fix-actions``: repair actions that were split on whitespace.

    Optional outputs are the actions, the rule statistics and the AMR corpus
    as JAMR strings.

    Raises:
        AssertionError: if a requested operation or output lacks an input
            it depends on. All checks run before anything is read or written.
    """

    # Argument handling
    args = argument_parser()

    # Sanity checks up front. Previously, a missing input surfaced only later
    # as UnboundLocalError/NameError, possibly after other files had already
    # been written.
    if args.merge_mined:
        assert args.in_tokens, "--merge-mined requires --in-tokens"
        assert args.in_actions, "--merge-mined requires --in-actions"
        assert args.in_scored_actions, \
            "--merge-mined requires --in-scored-actions"
        assert args.in_rule_stats, "--merge-mined requires --in-rule-stats"
        assert args.out_rule_stats, "--merge-mined requires --out-rule-stats"
    if args.fix_actions:
        assert args.in_actions, "--fix-actions requires --in-actions"
    if args.out_actions:
        assert args.in_actions, "--out-actions requires --in-actions"
    if args.out_rule_stats:
        assert args.merge_mined, "--out-rule-stats requires --merge-mined"
    if args.out_amr:
        assert args.in_amr, "--out-amr requires --in-amr"

    # Read
    amrs = sentences = actions = scored_actions = rule_stats = None
    # Load AMR (replace some unicode characters)
    if args.in_amr:
        amrs = read_amr(args.in_amr, unicode_fixes=True).amrs
    # Load tokens
    if args.in_tokens:
        sentences = read_tokenized_sentences(args.in_tokens, separator='\t')
    # Load actions i.e. oracle
    if args.in_actions:
        actions = read_tokenized_sentences(args.in_actions, separator='\t')
    # Load scored actions i.e. mined oracle
    if args.in_scored_actions:
        scored_actions = read_action_scores(args.in_scored_actions)
        # measure performance
        print_score_action_stats(scored_actions)
    # Load rule stats
    if args.in_rule_stats:
        rule_stats = read_rule_stats(args.in_rule_stats)

    # Modify
    # merge --in-actions and --in-scored-actions and store in --out-actions
    if args.merge_mined:
        assert len(actions) == len(scored_actions)
        print(f'Merging {args.out_actions} and {args.in_scored_actions}')
        actions = merge_actions(actions, scored_actions)

    # fix actions split by whitespace arguments
    if args.fix_actions:
        actions = fix_actions_split_by_spaces(actions)

    # merge rules (must run after merging/fixing, it uses the final actions)
    out_rule_stats = None
    if args.merge_mined:
        out_rule_stats = merge_rules(sentences, actions, rule_stats,
                                     entity_rules=args.entity_rules)
        print(f'Merging {args.out_rule_stats} and {args.in_rule_stats}')

    # Write
    # actions
    if args.out_actions:
        dirname = os.path.dirname(args.out_actions)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        write_tokenized_sentences(args.out_actions, actions, separator='\t')
        print(f'Wrote {args.out_actions}')

    # rule stats
    if args.out_rule_stats:
        write_rule_stats(args.out_rule_stats, out_rule_stats)
        print(f'Wrote {args.out_rule_stats}')

    # AMR
    if args.out_amr:
        with open(args.out_amr, 'w') as fid:
            for amr in amrs:
                fid.write(amr.toJAMRString())
# Example #4
# 0
def oracle(args):
    """Run the AMR oracle over an aligned corpus and write its action sequences.

    For every AMR read from ``args.in_aligned_amr``, the shared
    ``AMRStateMachine`` is reset with the sentence tokens and driven until it
    is closed. At each step it applies the highest-scoring action proposed by
    ``AMROracle``. Statistics (``Stats``) and action-vocabulary counts
    (``StatsForVocab``) are collected along the way.

    Args:
        args: parsed arguments. Reads:
            - ``in_aligned_amr``, ``reduce_nodes``,
              ``absolute_stack_positions`` and ``use_copy``;
            - ``out_machine_config``, ``out_actions`` and ``out_tokens``;
            - ``out_stats_vocab``, if present.

    Writes the machine config, the action sequences (with the final
    ``CLOSE`` removed) and the sentence tokens as tab-separated files. If
    ``out_stats_vocab`` is set, it also writes the action vocabulary
    statistics.
    """

    # Read AMR
    amrs = read_amr(args.in_aligned_amr, ibm_format=True)

    # broken annotations that we ignore in stats
    # 'DATA/AMR2.0/aligned/cofill/train.txt'
    ignore_indices = [
        8372,  # (49, ':time', 49), (49, ':condition', 49)
        17055,  # (3, ':mod', 7), (3, ':mod', 7)
        27076,  # '0.0.2.1.0.0' is on ::edges but not ::nodes
        # for AMR 3.0 data: DATA/AMR3.0/aligned/cofill/train.txt
        # self-loop:
        # "# ::edge vote-01 condition vote-01 0.0.2 0.0.2",
        # "# ::edge vote-01 time vote-01 0.0.2 0.0.2"
        9296,
    ]
    # NOTE: indices to ignore for both AMR 2.0 and AMR 3.0 share one list and
    # are used for both oracles. This does NOT change the oracle actions; it
    # only skips sanity checks and displayed stats after the oracle run.

    # Initialize machine
    machine = AMRStateMachine(reduce_nodes=args.reduce_nodes,
                              absolute_stack_pos=args.absolute_stack_positions,
                              use_copy=args.use_copy)
    # Save machine config
    machine.save(args.out_machine_config)

    # initialize oracle
    # NOTE: this local name shadows the enclosing function ``oracle``
    oracle = AMROracle(reduce_nodes=args.reduce_nodes,
                       absolute_stack_pos=args.absolute_stack_positions,
                       use_copy=args.use_copy)

    # will store statistics and check AMR is recovered
    stats = Stats(ignore_indices, ngram_stats=False)
    stats_vocab = StatsForVocab(no_close=False)
    for idx, amr in tqdm(enumerate(amrs), desc='Oracle'):

        # debug
        # print(idx)    # 96 for AMR2.0 test data infinite loop
        # if idx == 96:
        #     breakpoint()

        # reuse the single machine: reset it with this sentence's tokens
        machine.reset(amr.tokens)

        # initialize new oracle for this AMR
        oracle.reset(amr)

        # proceed left to right through the sentence generating nodes, until
        # the machine reports it is closed
        while not machine.is_closed:

            # get valid actions
            # NOTE(review): the result is discarded (the filtering below is
            # commented out). Keep this call only if it has side effects on
            # the machine — confirm.
            _ = machine.get_valid_actions()

            # oracle
            actions, scores = oracle.get_actions(machine)
            # actions = [a for a in actions if a in valid_actions]
            # most probable
            action = actions[np.argmax(scores)]

            # if it is node generation, keep track of original id in gold amr.
            # Node-generating actions come as (action, gold_node_id). The
            # machine node id is the current length of the action history,
            # read BEFORE machine.update(); presumably this is the position
            # the new action will occupy — confirm.
            if isinstance(action, tuple):
                action, gold_node_id = action
                node_id = len(machine.action_history)
                oracle.node_map[gold_node_id] = node_id
                oracle.node_reverse_map[node_id] = gold_node_id

            # update machine,
            machine.update(action)

            # update machine stats
            stats.update_machine_stats(machine)

            # update vocabulary
            stats_vocab.update(action, machine)

        # Sanity check: We recovered the full AMR
        stats.update_sentence_stats(oracle, machine)

        # do not write 'CLOSE' in the action sequences
        # this might change the machine.action_history in place, but it is the
        # end of this machine already
        close_action = stats.action_sequences[-1].pop()
        assert close_action == 'CLOSE'

    # display statistics
    stats.display()

    # save action sequences and tokens
    write_tokenized_sentences(args.out_actions, stats.action_sequences, '\t')
    write_tokenized_sentences(args.out_tokens, stats.tokens, '\t')

    # save action vocabulary stats
    # debug

    stats_vocab.display()
    # ``out_stats_vocab`` is optional: args may not define it at all
    if getattr(args, 'out_stats_vocab', None) is not None:
        stats_vocab.write(args.out_stats_vocab)
        print(f'Action vocabulary stats written in {args.out_stats_vocab}.*')
# Example #5
# 0
def get_propbank_name(amr_pred):
    """Convert an AMR predicate label into its PropBank frame name.

    The last dash-separated field becomes the sense after a dot, e.g.
    ``run-01`` -> ``run.01``. Any dashes left in the lemma are replaced by
    underscores (``look-up-05`` -> ``look_up.05``). Names ending in ``.91``
    and ``have-half-life.01`` are exceptions and keep their dashes.
    """
    *lemma_parts, sense = amr_pred.split('-')
    frame_name = f"{'-'.join(lemma_parts)}.{sense}"
    keeps_dashes = (
        frame_name.endswith('.91') or frame_name == 'have-half-life.01'
    )
    return frame_name if keeps_dashes else frame_name.replace('-', '_')


if __name__ == '__main__':

    # Argument handling: <AMR file> <PropBank JSON file>
    in_amr, in_propbank_json = sys.argv[1:]

    corpus = read_amr(in_amr)
    with open(in_propbank_json) as fid:
        propbank = json.loads(fid.read())

    # matches node labels ending in "-<digits>" (predicate senses like run-01)
    pred_regex = re.compile('.+-[0-9]+$')

    # NOTE(review): propbank, amr_alerts, sid and pred are assigned but never
    # used in this block. The loop body looks truncated: a check of ``pred``
    # against ``propbank`` seems to be missing. Confirm against the upstream
    # script.
    amr_alerts = defaultdict(list)
    sid = 0
    num_preds = 0
    for amr in tqdm(corpus.amrs):
        # ids of nodes whose label matches pred_regex
        predicate_ids = [
            k for k, v in amr.nodes.items() if pred_regex.match(v)
        ]
        num_preds += len(predicate_ids)
        for pred_id in predicate_ids:
            # PropBank-style frame name, e.g. run-01 -> run.01
            pred = get_propbank_name(amr.nodes[pred_id])
def main():
    """Run the rule-based ``AMR_Oracle`` over a corpus and report statistics.

    Steps:
    1. Read the AMR corpus with unicode fixes and, optionally, the PropBank
       arguments.
    2. Build the multi-task (labeled shift) word list.
    3. Run ``AMR_Oracle.runOracle`` with the output paths from the command
       line.
    4. In verbose mode, log the 100 most common entries of every oracle
       statistic.
    5. Optionally dump ``oracle.stats`` as JSON.

    NOTE(review): ``use_addnode_rules``, ``entity_rule_stats`` and
    ``entity_rule_totals`` are not defined in this function. They are
    presumably module-level globals — confirm.
    """
    args = argument_parser()

    # Load AMR (replace some unicode characters)
    amrs = read_amr(args.in_amr, unicode_fixes=True).amrs

    # Load propbank
    propbank_args = None
    if args.in_propbank_args:
        propbank_args = read_propbank(args.in_propbank_args)

    # read/write multi-task (labeled shift) action
    token_lists = [list(sentence_amr.tokens) for sentence_amr in amrs]
    multitask_words = process_multitask_words(
        token_lists,
        args.multitask_max_words,
        args.in_multitask_words,
        args.out_multitask_words,
        add_root=True,
    )

    # TODO: At the end, an oracle is just a parser with oracle info. This could
    # be turned into a loop similar to parser.py (or directly use that and a
    # AMROracleParser())
    print_log("amr", "Processing oracle")
    oracle = AMR_Oracle(args.entity_rules, verbose=args.verbose)
    run_options = {
        'out_oracle': args.out_oracle,
        'out_amr': args.out_amr,
        'out_sentences': args.out_sentences,
        'out_actions': args.out_actions,
        'out_rule_stats': args.out_rule_stats,
        'add_unaligned': 0,
        'no_whitespace_in_actions': args.no_whitespace_in_actions,
        'multitask_words': multitask_words,
        'copy_lemma_action': args.copy_lemma_action,
        'addnode_count_cutoff': args.addnode_count_cutoff,
    }
    oracle.runOracle(amrs, propbank_args, **run_options)

    # inform user
    if args.verbose:
        for stat_name in oracle.stats:
            print_log("amr", stat_name)
            print_log("amr", oracle.stats[stat_name].most_common(100))
            print_log("amr", "")

    if args.out_action_stats:
        # Store rule statistics
        with open(args.out_action_stats, 'w') as fid:
            fid.write(json.dumps(oracle.stats))

    if not use_addnode_rules:
        return

    for rule in entity_rule_totals:
        hits = entity_rule_stats[rule]
        total = entity_rule_totals[rule]
        perc = hits / total
        if args.verbose:
            print_log(rule, hits, '/', total, '=', f'{perc:.2f}')
    all_hits = sum(entity_rule_stats.values())
    all_totals = sum(entity_rule_totals.values())
    perc = all_hits / all_totals
    print_log('Totals:', f'{perc:.2f}')
    print_log('Totals:', 'Failed Entity Predictions:')
# Example #7
# 0
def main(args):
    """Align AMR nodes to sentence tokens with an EM-trained aligner.

    Steps:
    1. Read AMRs from exactly one of ``--in-amr`` or ``--in-aligned-amr``.
    2. Normalize tokens for matching, keeping the originals for writing.
    3. Run ``--em-epochs`` EM epochs over the selected indices. The aligner
       is either freshly initialized or loaded from ``--in-checkpoint-json``.
    4. Optionally:
       - save the model;
       - show alignments for visual inspection;
       - write or compare the resulting alignments.

    Args:
        args: parsed command-line arguments.

    Raises:
        AssertionError: on inconsistent argument combinations. These checks
            depend only on ``args`` and run before any data is read.
    """

    # sanity checks (argument-only, so fail fast before reading the corpus)
    assert bool(args.in_aligned_amr) ^ bool(args.in_amr), \
        "Needs either --in-amr or --in-aligned-amr"
    if args.compare:
        assert args.in_aligned_amr, "--compare only with --in-aligned-amr"
    assert args.em_epochs > 0 or args.rule_prior_strength > 0, \
        "Either set --em-epochs > 0 or --rule-prior-strength > 0"

    # files
    if args.in_amr:
        amrs = read_amr(args.in_amr, tokenize=args.tokenize)
    else:
        amrs = read_amr(args.in_aligned_amr,
                        ibm_format=True,
                        tokenize=args.tokenize)

    # normalize tokens for matching purposes, but keep the original for writing
    original_tokens = []
    for amr in amrs:
        original_tokens.append(amr.tokens)
        amr.tokens = normalize_tokens(amr.tokens)

    # if not given, use every AMR (in random order if --shuffle)
    if args.indices is None:
        indices = list(range(len(amrs)))
        if args.shuffle:
            shuffle(indices)
    else:
        indices = args.indices

    # Initialize aligner. This is an IBM model 1 using surface matching rules
    # as prior and graph vicinity rules post-processing
    if args.in_checkpoint_json:
        amr_aligner = AMRAligner.from_checkpoint(args.in_checkpoint_json)
    else:
        amr_aligner = AMRAligner(rule_prior_strength=args.rule_prior_strength,
                                 force_align_ner=args.force_align_ner,
                                 not_align_tokens=IGNORE_TOKENS,
                                 ignore_nodes=IGNORE_NODES,
                                 ignore_node_regex=IGNORE_REGEX)

    # loop over EM epochs
    av_log_lik = None
    for epoch in range(args.em_epochs):
        bar_desc = f'EM epoch {epoch+1}/{args.em_epochs}'
        # compare against None: a log-likelihood of exactly 0.0 is falsy and
        # would otherwise be dropped from the progress bar
        if av_log_lik is not None:
            bar_desc = f'{bar_desc} loglik {av_log_lik}'
        for index in tqdm(indices, desc=bar_desc):
            # accumulate stats while fixing the posterior
            amr_aligner.update_counts(amrs[index], cache_key=index)

        # compute loglik
        av_log_lik = amr_aligner.train_loglik / amr_aligner.train_num_examples

        # update the model parameters
        amr_aligner.update_parameters()

    # save model
    if args.out_checkpoint_json:
        amr_aligner.save(args.out_checkpoint_json)

    # check some examples of alignment visually (same indices as training)
    if args.visual_eval:
        visual_eval(amr_aligner, indices, amrs, args.compare,
                    args.alignment_format)

    # add or replace alignments
    if args.out_aligned_amr or args.compare:
        save_aligned(amrs, original_tokens, indices, amr_aligner,
                     args.out_aligned_amr, args.compare, args.alignment_format)