def __init__(self, src_dict, tgt_dict, machine_type, machine_rules=None, 
                 entity_rules=None):
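        """
        Set up an action indexer over the target vocabulary and a factory for
        new state machines of the given type.

        Assumed semantics (from usage below): src_dict / tgt_dict are
        vocabularies exposing .symbols; machine_type is e.g. 'AMR';
        machine_rules is a path to a rule-stats file (required for 'AMR');
        entity_rules is forwarded to the machine generator.
        """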

        # Get all actions indexed by prefix
        self.action_indexer = get_action_indexer(tgt_dict.symbols)

        # Load rule stats if provided
        if machine_rules is not None:
            assert os.path.isfile(machine_rules)
            rule_stats = read_rule_stats(machine_rules)
            self.get_new_state_machine = machine_generator(
                rule_stats['possible_predicates'],
                entity_rules=entity_rules
            )
        else:
            assert machine_type != 'AMR', \
                "AMR machine expects --machine-rules"
            rule_stats = None
            self.get_new_state_machine = machine_generator(None)

        # Store dictionaries and machine type
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.machine_type = machine_type
def main():
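    """
    Read AMR, tokens, actions and/or rule stats; optionally merge mined
    (scored) actions and rule stats; write the resulting actions, rule stats
    and/or AMR.
    """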

    # Argument handling
    args = argument_parser()

    # Read
    # Load AMR (replace some unicode characters)
    if args.in_amr:
        corpus = read_amr(args.in_amr, unicode_fixes=True)
        amrs = corpus.amrs
    # Load tokens    
    if args.in_tokens:
        sentences = read_tokenized_sentences(args.in_tokens, separator='\t')
    # Load actions i.e. oracle
    if args.in_actions:
        actions = read_tokenized_sentences(args.in_actions, separator='\t')
    # Load scored actions i.e. mined oracle     
    if args.in_scored_actions:
        scored_actions = read_action_scores(args.in_scored_actions)
        # measure performance
        print_score_action_stats(scored_actions)
    # Load rule stats
    if args.in_rule_stats:
        rule_stats = read_rule_stats(args.in_rule_stats)

    # Modify
    # merge --in-actions and --in-scored-actions and store in --out-actions
    if args.merge_mined:
        # sanity checks
        assert args.in_tokens, "--merge-mined requires --in-tokens"
        assert args.in_actions, "--merge-mined requires --in-actions"
        assert args.in_scored_actions, \
            "--merge-mined requires --in-scored-actions"
        assert args.in_rule_stats, "--merge-mined requires --in-rule-stats"
        assert args.out_rule_stats, "--merge-mined requires --out-rule-stats"
        assert len(actions) == len(scored_actions)
        print(f'Merging {args.in_actions} and {args.in_scored_actions} '
              f'into {args.out_actions}')

        # actions
        actions = merge_actions(actions, scored_actions)

    # fix actions split by whitespace arguments 
    if args.fix_actions:
        actions = fix_actions_split_by_spaces(actions)

    # merge rules
    if args.merge_mined:
        out_rule_stats = merge_rules(
            sentences, actions, rule_stats, entity_rules=args.entity_rules
        )
        print(f'Merging mined rules into {args.in_rule_stats}, '
              f'writing {args.out_rule_stats}')

    # Write
    # actions
    if args.out_actions:
        dirname = os.path.dirname(args.out_actions)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        write_tokenized_sentences(
            args.out_actions,
            actions,
            separator='\t'
        )
        print(f'Wrote {args.out_actions}')

    # rule stats
    if args.out_rule_stats:
        # out_rule_stats is only produced by --merge-mined above
        assert args.merge_mined, "--out-rule-stats requires --merge-mined"
        write_rule_stats(args.out_rule_stats, out_rule_stats)
        print(f'Wrote {args.out_rule_stats}')

    # AMR
    if args.out_amr:
        with open(args.out_amr, 'w') as fid:
            for amr in amrs:
                fid.write(amr.toJAMRString())
def make_binary_stack(args, target_vocab, input_prefix, output_prefix,
                      eos_idx, pad_idx, mask_predicates=False, allow_unk=False,
                      tokenize=None):
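    """
    Binarize state-machine features for each sentence/action pair: per-step
    word states (memory, memory_pos) and, if mask_predicates is set, masks of
    valid target actions (target_masks, active_logits).
    """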

    assert tokenize

    # involved files
    input_sentences = input_prefix + '.en'
    input_actions = input_prefix + '.actions'

    # The AMR state machine always expects rules
    if args.machine_type == 'AMR':

        assert args.machine_rules and os.path.isfile(args.machine_rules), \
            f'Missing {args.machine_rules}'

        # Read rules
        train_rule_stats = read_rule_stats(args.machine_rules)
        actions_by_stack_rules = train_rule_stats['possible_predicates']
    else:
        actions_by_stack_rules = None
        
    action_indexer = get_action_indexer(target_vocab.symbols)

    # initialize indices for each of variables
    # memory (stack, buffer, dead) (position in memory)
    stack_buffer_names = ['memory', 'memory_pos']
    # FIXME: These values are hard-coded elsewhere in code
    state_indices = [3, 4, 5]
    assert eos_idx not in state_indices, "Can not reuse EOS index"
    assert pad_idx not in state_indices, "Can not reuse PAD index"
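    # one binarized dataset builder per state variable (these appear to be
    # fairseq-style indexed datasets, writing .bin/.idx file pairs)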
    indexed_data = {}
    for name in stack_buffer_names:
        indexed_data[name] = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, name, "bin"),
            impl=args.dataset_impl,
        )

    if mask_predicates:
        # mask of target predictions 
        masks_path = dataset_dest_file(args, output_prefix, 'target_masks', "bin")
        indexed_target_masks = indexed_dataset.make_builder(
            masks_path,
            impl=args.dataset_impl,
        )

        # active indices 
        active_logits_path = dataset_dest_file(args, output_prefix, 'active_logits', "bin")
        indexed_active_logits = indexed_dataset.make_builder(
            active_logits_path,
            impl=args.dataset_impl,
        )

    # Returns a function that builds an initialized state machine for a given
    # sentence
    get_new_state_machine = machine_generator(
        actions_by_stack_rules,
        entity_rules=args.entity_rules
    )

    num_sents = 0
    missing_actions = Counter()
    with open(input_actions, 'r') as fid_act, \
         open(input_sentences, 'r') as fid_sent:

        # Loop over sentences
        for sentence in tqdm(fid_sent):

            # Get actions, tokens
            sent_tokens = tokenize(sentence)
            sent_actions = tokenize(fid_act.readline())

            # initialize a state machine for this sentence (batch size 1)
            state_machine = get_new_state_machine(
                sent_tokens,
                machine_type=args.machine_type
            )

            # collect target and source masks
            sent_data = {}
            for name in stack_buffer_names:
                sent_data[name] = []

            shape = (len(sent_actions), len(target_vocab.symbols))
            logits_mask = np.zeros(shape)
            active_logits = set()
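            # logits_mask has one row per action step and one column per
            # target-vocabulary symbol; active_logits accumulates the columns
            # that are valid at any step of this sentence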
            for action_idx, gold_action in enumerate(sent_actions):

                # active logits for this action
                if mask_predicates:

                    # Get total valid actions by expanding base ones
                    valid_actions, invalid_actions = state_machine.get_valid_actions()
                    valid_action_idx = (
                        action_indexer(valid_actions) 
                        - action_indexer(invalid_actions)
                    )
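                    # NOTE: action_indexer() is assumed to return a set of
                    # target-vocabulary indices, so this is a set difference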

                    # index of the gold action (fall back to <unk> if it is
                    # not in the target vocabulary)
                    if gold_action in target_vocab.symbols:
                        gold_action_index = target_vocab.symbols.index(gold_action)
                    else:
                        gold_action_index = target_vocab.symbols.index('<unk>')
                    # if the gold action is not among the valid actions, force
                    # it in and keep count of how often this happens
                    if gold_action_index not in valid_action_idx:
                        valid_action_idx.add(gold_action_index)
                        missing_actions.update([gold_action])

                    # if only one action is valid, also allow PAD so the mask
                    # is not a single-choice (delta) target during learning
                    if len(valid_action_idx) == 1:
                        valid_action_idx.add(pad_idx)

                    # mark the valid actions for this time step in the mask
                    logits_mask[action_idx, list(valid_action_idx)] = 1
                    active_logits |= valid_action_idx

                # stack and buffer 
                memory, memory_pos = get_word_states(
                    state_machine,
                    sent_tokens,
                    indices=state_indices
                )
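                # memory[i] encodes whether source token i is on the stack, in
                # the buffer or reduced (using state_indices above) and
                # memory_pos[i] its position there (assumed from
                # get_word_states)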

                # word states
                sent_data['memory'].append(torch.Tensor(memory))
                # note we use position 0 for reduced words
                sent_data['memory_pos'].append(
                    torch.Tensor(memory_pos)
                )

                # Update machine
                state_machine.applyAction(gold_action)

            for name in stack_buffer_names:
                # note that data needs to be stored as a 1d array
                indexed_data[name].add_item(
                    torch.stack(sent_data[name]).view(-1)
                )

            # valid action masks
            if mask_predicates:
                active_logits = list(active_logits)
                # reduce size to active items
                logits_mask = logits_mask[:, active_logits]
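                # only columns that were ever valid are kept; the active_logits
                # indices stored below allow mapping the reduced mask back to
                # full-vocabulary positions at load time (assumed)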
                indexed_target_masks.add_item(
                    torch.Tensor(logits_mask).view(-1)
                )

                # active indices 
                indexed_active_logits.add_item(torch.Tensor(
                    active_logits
                ))

            # update number of sents
            num_sents += 1
            if not num_sents % 100:
                print("\r%d sentences" % num_sents, end = '')

        print("")

    # close indexed data files
    for name in stack_buffer_names:
        output_file_idx = dataset_dest_file(args, output_prefix, name, "idx")
        indexed_data[name].finalize(output_file_idx)
    # close valid action mask
    if mask_predicates:
        target_mask_idx = dataset_dest_file(args, output_prefix, 'target_masks', "idx")
        indexed_target_masks.finalize(target_mask_idx)

        # active indices 
        active_logits_idx = dataset_dest_file(args, output_prefix, 'active_logits', "idx")
        indexed_active_logits.finalize(active_logits_idx)

    # inform about missing actions
    if missing_actions:
        print(yellow_font("There were missing actions"))
        print(missing_actions)
def main():
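    """
    For each sentence/action pair, replay the oracle on an AMRStateMachine and
    write per-step word states and valid-action masks to HDF5 files.
    """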

    # Argument handling
    args = argument_parser()

    # read rules
    train_rule_stats = read_rule_stats(args.in_rule_stats)
    assert 'action_vocabulary' in train_rule_stats
    assert 'possible_predicates' in train_rule_stats
    action_list = list(sorted(train_rule_stats['action_vocabulary'].keys()))

    # get all actions indexed by their root (the prefix before '(')
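    # e.g. an action like 'PRED(want-01)' would be grouped under the key 'PRED'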
    action_by_basic = defaultdict(list)
    for action in train_rule_stats['action_vocabulary'].keys():
        key = action.split('(')[0]
        action_by_basic[key].append(action)

    # open HDF5 writers for the output files
    write_out_states = h5_writer(args.out_word_states)
    write_out_valid_actions = h5_writer(args.out_valid_actions)

    # Read content
    # TODO: Point to LDC data
    sentences = readlines(args.in_sentences)
    actions = readlines(args.in_actions)
    assert len(sentences) == len(actions)

    # initialize the spacy lemmatizer outside the sentence loop for speed
    spacy_lemmatizer = get_spacy_lemmatizer()

    sent_idx = -1
    stats = {
        'missing_pred_count': Counter(),
        'missing_action_count': Counter(),
        'fan_out_count': Counter()
    }
    for sent_tokens, sent_actions in tqdm(
        zip(sentences, actions),
        desc='extracting oracle masks',
        total=len(actions)
    ):

        # keep count of sentence index
        sent_idx += 1
        if args.offset and sent_idx < args.offset:
            continue

        # Initialize state machine
        amr_state_machine = AMRStateMachine(
            sent_tokens,
            spacy_lemmatizer=spacy_lemmatizer
        )

        # process each action
        word_states_sent = []
        valid_actions_sent = []
        for raw_action in sent_actions:

            # Store states BEFORE ACTION
            # state of each word (buffer B, stack S, reduced X)
            word_states = get_word_states(amr_state_machine, sent_tokens)

            # Get actions valid for this state
            valid_actions = get_valid_actions(
                action_list,
                amr_state_machine,
                train_rule_stats,
                action_by_basic,
                raw_action,
                stats
            )

            # update info
            word_states_sent.append(word_states)
            valid_actions_sent.append(valid_actions)

            # Update machine
            amr_state_machine.applyAction(raw_action)

        # Write states for this sentence
        write_out_states(word_states_sent)
        write_out_valid_actions(valid_actions_sent)

    # inform the user about missing predicates and actions
    for miss in ['missing_pred_count', 'missing_action_count']:
        num_missing = len(stats[miss])
        if num_missing:
            alert_str = f'{num_missing} {miss} rule_stats'
            print(yellow_font(alert_str))

    # inform user about fan-out stats
    mode_fan_out = stats['fan_out_count'].most_common(1)[0][0]
    max_fan_out = max(stats['fan_out_count'].keys())
    alert_str = f'num_actions mode: {mode_fan_out} max: {max_fan_out}'
    print(alert_str)

    # Close output writers
    write_out_states()
    write_out_valid_actions()
def main():
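    """
    Replay gold action sequences over the input sentences with a fake parser
    and write the resulting AMR graphs and/or BIO tags.
    """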

    # Argument handling
    args = argument_parser()

    # Get data
    sentences = read_tokenized_sentences(args.in_sentences,
                                         separator=args.separator)

    # Initialize logger/printer
    logger = Logger(step_by_step=args.step_by_step,
                    clear_print=args.clear_print,
                    pause_time=args.pause_time,
                    verbose=args.verbose)

    # generate rules to restrict action space by stack content
    if args.action_rules_from_stats:
        rule_stats = read_rule_stats(args.action_rules_from_stats)
        actions_by_stack_rules = {
            token: Counter(counter)
            for token, counter in rule_stats['possible_predicates'].items()
        }

    else:
        actions_by_stack_rules = None
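    # actions_by_stack_rules maps a token to a Counter of candidate actions
    # observed for it in training; it is used below to restrict the parser's
    # action space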

    # Fake parser built from actions
    actions = read_tokenized_sentences(args.in_actions,
                                       separator=args.separator)
    assert len(sentences) == len(actions)
    parsing_model = FakeAMRParser(
        from_sent_act_pairs=zip(sentences, actions),
        machine_type=args.machine_type,
        logger=logger,
        actions_by_stack_rules=actions_by_stack_rules,
        no_whitespace_in_actions=args.no_whitespace_in_actions,
        entity_rules=args.entity_rules)

    # Get output AMR writer
    if args.out_amr:
        amr_write = writer(args.out_amr)
    if args.out_bio_tags:
        bio_write = writer(args.out_bio_tags)

    # Loop over sentences
    for sent_idx, tokens in tqdm(enumerate(sentences), desc='parsing'):

        # fast-forward until desired sentence number
        if args.offset and sent_idx < args.offset:
            continue

        # parse
        # NOTE: To simulate the real endpoint, the input is provided as a
        # string of whitespace-separated tokens
        machine, bio_tags = parsing_model.parse_sentence(" ".join(tokens))

        # sanity check annotations
        dupes = get_duplicate_edges(machine.amr)
        if args.sanity_check and any(dupes):
            msg = yellow_font('WARNING:')
            print(f'{msg} duplicated edges in sent {sent_idx}', end=' ')
            print(dict(dupes))
            print(' '.join(machine.tokens))

        # write outputs (BIO tags and/or AMR)
        if args.out_bio_tags:
            tag_str = '\n'.join([f'{to} {ta}' for to, ta in bio_tags])
            tag_str += '\n\n'
            bio_write(tag_str)
        if args.out_amr:
            try:
                amr_write(machine.amr.toJAMRString())
            except InvalidAMRError as exception:
                print(f'\nFailed at sentence {sent_idx}\n')
                raise exception

    if getattr(parsing_model, "rule_violation", None):
        print(yellow_font("There were one or more action rule violations"))
        print(parsing_model.rule_violation)

    if args.action_rules_from_stats:
        print("Predict rules had following statistics")
        print(parsing_model.pred_counts)

    # close output writers
    if args.out_amr:
        amr_write()
    if args.out_bio_tags:
        bio_write()