Example #1
def play(args):

    sentences = read_tokenized_sentences(args.in_tokens, '\t')
    action_sequences = read_tokenized_sentences(args.in_actions, '\t')
    assert len(sentences) == len(action_sequences)

    # This will store the annotations to write
    annotations = []

    # Initialize machine
    machine = AMRStateMachine.from_config(args.in_machine_config)
    for index in tqdm(range(len(action_sequences)), desc='Machine'):

        # Reset the machine with this sentence's tokens
        machine.reset(sentences[index])

        # append the 'CLOSE' action if it is missing from the file
        if action_sequences[index][-1] != 'CLOSE':
            action_sequences[index].append('CLOSE')

        for action in action_sequences[index]:
            machine.update(action)

        assert machine.is_closed

        # store the AMR annotation for this sentence
        annotations.append(machine.get_annotation())

    with open(args.out_amr, 'w') as fid:
        for annotation in annotations:
            fid.write(annotation)
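For reference, a minimal sketch of how `play` might be driven. The `argparse.Namespace` wrapper and the file names below are illustrative assumptions, not part of the original script.

from argparse import Namespace

# Hypothetical arguments; the file names are placeholders.
example_args = Namespace(
    in_tokens='dev.tokens',            # tab-separated tokenized sentences (assumed)
    in_actions='dev.actions',          # tab-separated oracle action sequences (assumed)
    in_machine_config='machine.json',  # state-machine config file (assumed)
    out_amr='dev.amr',                 # where the AMR annotations are written
)
play(example_args)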
Example #2
def main():

    raise NotImplementedError(
        'Sorry, no standalone version yet, use action-pointer branch')

    # argument handling
    args = argument_parsing()

    # set inspector to use on action loop
    inspector = None
    if args.set_trace:
        inspector = breakpoint_inspector
    if args.step_by_step:
        inspector = simple_inspector

    # load parser
    start = time.time()
    parser = AMRParser.from_checkpoint(args.in_checkpoint, inspector=inspector)
    end = time.time()
    time_secs = timedelta(seconds=float(end - start))
    print(f'Total time taken to load parser: {time_secs}')

    # TODO: max batch sizes could be computed from max sentence length
    if args.service:

        # set ordered exit handlers
        signal.signal(signal.SIGINT, ordered_exit)
        signal.signal(signal.SIGTERM, ordered_exit)

        while True:
            sentence = input("Write sentence:\n")
            os.system('clear')
            if not sentence.strip():
                continue
            result = parser.parse_sentences(
                [sentence.split()],
                batch_size=args.batch_size,
                roberta_batch_size=args.roberta_batch_size,
            )
            #
            os.system('clear')
            print('\n')
            print(''.join(result[0]))

    else:

        # Parse sentences
        result = parser.parse_sentences(
            read_tokenized_sentences(args.in_tokenized_sentences),
            batch_size=args.batch_size,
            roberta_batch_size=args.roberta_batch_size)

        with open(args.out_amr, 'w') as fid:
            fid.write(''.join(result[0]))
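All of these examples rely on a `read_tokenized_sentences` helper. The following is only a sketch of what such a reader could look like, assuming one tokenized sentence per line with tokens joined by the given separator; the library's actual implementation may differ.

def read_tokenized_sentences(file_path, separator=' '):
    # Sketch only: assumed format is one sentence per line,
    # tokens joined by `separator`.
    sentences = []
    with open(file_path, encoding='utf-8') as fid:
        for line in fid:
            line = line.rstrip('\n')
            if line:
                sentences.append(line.split(separator))
    return sentences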
Example #3
def main():

    # Argument handling
    args = argument_parser()

    # Read
    # Load AMR (replace some unicode characters)
    if args.in_amr:
        corpus = read_amr(args.in_amr, unicode_fixes=True)
        amrs = corpus.amrs
    # Load tokens
    if args.in_tokens:
        sentences = read_tokenized_sentences(args.in_tokens, separator='\t')
    # Load actions i.e. oracle
    if args.in_actions:
        actions = read_tokenized_sentences(args.in_actions, separator='\t')
    # Load scored actions, i.e. mined oracle
    if args.in_scored_actions:
        scored_actions = read_action_scores(args.in_scored_actions)
        # measure performance
        print_score_action_stats(scored_actions)
    # Load rule stats
    if args.in_rule_stats:
        rule_stats = read_rule_stats(args.in_rule_stats)

    # Modify
    # merge --in-actions and --in-scored-actions and store in --out-actions
    if args.merge_mined:
        # sanity checks
        assert args.in_tokens, "--merge-mined requires --in-tokens"
        assert args.in_actions, "--merge-mined requires --in-actions"
        assert args.in_rule_stats, "--merge-mined requires --in-rule-stats"
        assert args.out_rule_stats, "--merge-mined requires --out-rule-stats"
        if args.in_actions:
            assert len(actions) == len(scored_actions)
        print(f'Merging {args.in_actions} and {args.in_scored_actions}')

        # actions
        actions = merge_actions(actions, scored_actions)

    # fix actions whose arguments were split on whitespace
    if args.fix_actions:
        actions = fix_actions_split_by_spaces(actions)

    # merge rules
    if args.merge_mined:
        out_rule_stats = merge_rules(sentences, actions, rule_stats, entity_rules=args.entity_rules)
        print(f'Merging {args.in_rule_stats} into {args.out_rule_stats}')

    # Write
    # actions
    if args.out_actions:
        dirname = os.path.dirname(args.out_actions)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        write_tokenized_sentences(
            args.out_actions,
            actions,
            separator='\t'
        )
        print(f'Wrote {args.out_actions}')

    # rule stats
    if args.out_rule_stats:
        write_rule_stats(args.out_rule_stats, out_rule_stats)
        print(f'Wrote {args.out_rule_stats}')

    # AMR
    if args.out_amr:
        with open(args.out_amr, 'w') as fid:
            for amr in amrs:
                fid.write(amr.toJAMRString())
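Example #3 writes action sequences back out through `write_tokenized_sentences`. A hedged sketch of a matching writer, assumed to simply mirror the reader above (one sentence per line, tokens joined by the separator); the real helper may differ.

def write_tokenized_sentences(file_path, sentences, separator=' '):
    # Sketch only: assumed to be the inverse of read_tokenized_sentences.
    with open(file_path, 'w', encoding='utf-8') as fid:
        for tokens in sentences:
            fid.write(separator.join(tokens) + '\n')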
Example #4
def main():

    # Argument handling
    args = argument_parser()

    # Get data
    sentences = read_tokenized_sentences(args.in_sentences,
                                         separator=args.separator)

    # Initialize logger/printer
    logger = Logger(step_by_step=args.step_by_step,
                    clear_print=args.clear_print,
                    pause_time=args.pause_time,
                    verbose=args.verbose)

    # generate rules to restrict action space by stack content
    if args.action_rules_from_stats:
        rule_stats = read_rule_stats(args.action_rules_from_stats)
        # map each token to a Counter over its possible predicate actions
        actions_by_stack_rules = {
            token: Counter(counter)
            for token, counter in rule_stats['possible_predicates'].items()
        }
    else:
        actions_by_stack_rules = None

    # Fake parser built from actions
    actions = read_tokenized_sentences(args.in_actions,
                                       separator=args.separator)
    assert len(sentences) == len(actions)
    parsing_model = FakeAMRParser(
        from_sent_act_pairs=zip(sentences, actions),
        machine_type=args.machine_type,
        logger=logger,
        actions_by_stack_rules=actions_by_stack_rules,
        no_whitespace_in_actions=args.no_whitespace_in_actions,
        entity_rules=args.entity_rules)

    # Get output AMR writer
    if args.out_amr:
        amr_write = writer(args.out_amr)
    if args.out_bio_tags:
        bio_write = writer(args.out_bio_tags)

    # Loop over sentences
    for sent_idx, tokens in tqdm(enumerate(sentences), desc='parsing'):

        # fast-forward until desired sentence number
        if args.offset and sent_idx < args.offset:
            continue

        # parse
        # NOTE: To simulate the real endpoint, input provided as a string of
        # whitespace separated tokens
        machine, bio_tags = parsing_model.parse_sentence(" ".join(tokens))

        # sanity check annotations
        dupes = get_duplicate_edges(machine.amr)
        if args.sanity_check and any(dupes):
            msg = yellow_font('WARNING:')
            print(f'{msg} duplicated edges in sent {sent_idx}', end=' ')
            print(dict(dupes))
            print(' '.join(machine.tokens))

        # store output AMR
        if args.out_bio_tags:
            tag_str = '\n'.join([f'{to} {ta}' for to, ta in bio_tags])
            tag_str += '\n\n'
            bio_write(tag_str)
        if args.out_amr:
            try:
                amr_write(machine.amr.toJAMRString())
            except InvalidAMRError as exception:
                print(f'\nFailed at sentence {sent_idx}\n')
                raise exception

    if getattr(parsing_model, "rule_violation", None):
        print(yellow_font("There were one or more action rule violations"))
        print(parsing_model.rule_violation)

    if args.action_rules_from_stats:
        print("Predict rules had following statistics")
        print(parsing_model.pred_counts)

    # close output writers
    if args.out_amr:
        amr_write()
    if args.out_bio_tags:
        bio_write()
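Example #4 gets its output handles from a `writer()` factory: the returned callable writes when given a string and closes the file when called with no arguments. Below is a sketch inferred from that usage; the real helper may differ.

def writer(file_path):
    # Sketch inferred from usage above: write on write(text), close on write().
    fid = open(file_path, 'w', encoding='utf-8')

    def write(text=None):
        if text is None:
            fid.close()
        else:
            fid.write(text)

    return write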
Example #5
        idw = i + 1
        total_heads += 1
        if idw in gold_heads and idw in hyp_heads:
            if hyp_heads[idw] == gold_heads[idw]:
                correct_heads += 1
            if (hyp_heads[idw] == gold_heads[idw]
                    and hyp_labels[idw] == gold_labels[idw]):
                correct_labels += 1

    return total_heads, correct_heads, correct_labels
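The counts returned by `compute_correct` are the usual ingredients of unlabeled and labeled attachment scores (UAS/LAS); a small arithmetic sketch with made-up counts follows.

# Hypothetical aggregation of per-sentence counts into attachment scores.
n_heads, n_correct_heads, n_correct_labels = 100, 87, 80  # example counts, not real data
uas = n_correct_heads / n_heads   # unlabeled attachment score
las = n_correct_labels / n_heads  # labeled attachment score
print(f'UAS {uas:.2%}  LAS {las:.2%}')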


if __name__ == '__main__':

    args = argument_parser()
    in_tokens = read_tokenized_sentences(args.in_tokens)
    in_actions = read_tokenized_sentences(args.in_actions, separator='\t')
    #in_actions = read_tokenized_sentences(args.in_actions)
    in_gold_actions = read_tokenized_sentences(args.in_gold_actions)

    assert len(in_tokens) == len(in_actions)
    assert len(in_gold_actions) == len(in_actions)

    num_sentences = len(in_tokens)
    total_heads = 0
    correct_heads = 0
    correct_labels = 0
    for index in tqdm(range(num_sentences)):

        # Compute correct for this sentence
        total, correct, labels = compute_correct(in_tokens[index],