Example #1

from collections import Counter, defaultdict
from tqdm import tqdm

def merge_rules(sentences, actions, rule_stats, entity_rules=None):

    # rules restricting the action space by stack content; convert the stored
    # count dicts into Counters in place
    actions_by_stack_rules = rule_stats['possible_predicates']
    for token, counter in actions_by_stack_rules.items():
        actions_by_stack_rules[token] = Counter(counter)

    spacy_lemmatizer = get_spacy_lemmatizer()

    # map each token (or comma-joined merged tokens) to a Counter over the
    # predicate node names the oracle produced for it
    possible_predicates = defaultdict(Counter)
    for index, sentence_actions in tqdm(enumerate(actions), desc='merge rules'):

        tokens = sentences[index]

        # Initialize machine
        state_machine = AMRStateMachine(
            tokens,
            actions_by_stack_rules=actions_by_stack_rules,
            spacy_lemmatizer=spacy_lemmatizer,
            entity_rules=entity_rules
        )

        for action in sentence_actions:
            # NOTE: in the oracle, possible predicates are collected before
            # the PRED/COPY decision (tryConfirm action), so we have to take
            # all of them into account
            position, mpositions = \
                state_machine.get_top_of_stack(positions=True)
            if action.startswith('PRED'):
                # extract the node name from PRED(node)
                node = action[5:-1]
                possible_predicates[tokens[position]].update([node])
                if mpositions:
                    mtokens = ','.join([tokens[p] for p in mpositions])
                    possible_predicates[mtokens].update([node])

            elif action == 'COPY_LEMMA':
                lemma, _ = state_machine.get_top_of_stack(lemma=True)
                node = lemma
                possible_predicates[tokens[position]].update([node])
                if mpositions:
                    mtokens = ','.join([tokens[p] for p in mpositions])
                    possible_predicates[mtokens].update([node])

            elif action == 'COPY_SENSE01':
                lemma, _ = state_machine.get_top_of_stack(lemma=True)
                node = f'{lemma}-01'
                possible_predicates[tokens[position]].update([node])
                if mpositions:
                    mtokens = ','.join([tokens[p] for p in mpositions])
                    possible_predicates[mtokens].update([node])

            # execute action
            state_machine.applyAction(action)

    # note: this updates the input rule_stats in place
    out_rule_stats = rule_stats
    new_possible_predicates = merge_both_rules(possible_predicates,
                                               actions_by_stack_rules)
    out_rule_stats['possible_predicates'] = new_possible_predicates

    return out_rule_stats
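
# merge_both_rules is called above but not defined in this excerpt; below is
# a minimal sketch, assuming it simply adds the two token -> Counter maps
# together. This is an assumption, not the repository's actual implementation.
def merge_both_rules(possible_predicates, actions_by_stack_rules):
    # counts gathered while replaying the oracle actions
    merged = defaultdict(Counter)
    for token, counter in possible_predicates.items():
        merged[token].update(counter)
    # fold in the counts already present in the training rule stats
    for token, counter in actions_by_stack_rules.items():
        merged[token].update(counter)
    # plain dict of Counters, matching the input format
    return dict(merged)
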
Example #2

import json
from collections import Counter, defaultdict
from tqdm import tqdm

def get_amr(tokens, actions, entity_rules):

    # play state machine to get AMR
    state_machine = AMRStateMachine(tokens, entity_rules=entity_rules)
    for action in actions:
        # FIXME: it is unclear that this will always be the right option;
        # manual exploration of the dev set yielded 4 cases and it works for
        # all 4
        if action == "<unk>":
            action = f'PRED({state_machine.get_top_of_stack()[0].lower()})'
        state_machine.applyAction(action)

    # sanity check: force close
    if not state_machine.is_closed:
        alert_str = yellow_font('Machine not closed!')
        print(alert_str)
        state_machine.CLOSE()

    # TODO: probably wasting resources here
    amr_str = state_machine.amr.toJAMRString()
    return AMR.get_amr_line(amr_str.split('\n'))
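
# Hypothetical usage sketch for get_amr: readlines is assumed to yield one
# list of tokens / actions per line (as its use in main() below suggests);
# the argument names are illustrative.
def write_amrs_sketch(in_sentences, in_actions, entity_rules=None):
    sentences = readlines(in_sentences)
    actions = readlines(in_actions)
    assert len(sentences) == len(actions)
    for sent_tokens, sent_actions in zip(sentences, actions):
        print(get_amr(sent_tokens, sent_actions, entity_rules))
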
def main():

    # Argument handling
    args = argument_parser()

    # read rules
    train_rule_stats = read_rule_stats(args.in_rule_stats)
    assert 'action_vocabulary' in train_rule_stats
    assert 'possible_predicates' in train_rule_stats
    action_list = list(sorted(train_rule_stats['action_vocabulary'].keys()))

    # get all actions indexed by action root
    action_by_basic = defaultdict(list)
    for action in train_rule_stats['action_vocabulary'].keys():
        key = action.split('(')[0]
        action_by_basic[key].append(action)

    # get writers for the output files
    write_out_states = h5_writer(args.out_word_states)
    write_out_valid_actions = h5_writer(args.out_valid_actions)

    # Read content
    # TODO: Point to LDC data
    sentences = readlines(args.in_sentences)
    actions = readlines(args.in_actions)
    assert len(sentences) == len(actions)

    # initialize spacy lemmatizer out of the sentence loop for speed
    spacy_lemmatizer = get_spacy_lemmatizer()

    sent_idx = -1
    stats = {
        'missing_pred_count': Counter(),
        'missing_action_count': Counter(),
        'fan_out_count': Counter()
    }
    for sent_tokens, sent_actions in tqdm(
        zip(sentences, actions),
        desc='extracting oracle masks',
        total=len(actions)
    ):

        # keep count of sentence index
        sent_idx += 1
        if args.offset and sent_idx < args.offset:
            continue

        # Initialize state machine
        amr_state_machine = AMRStateMachine(
            sent_tokens,
            spacy_lemmatizer=spacy_lemmatizer
        )

        # process each action
        word_states_sent = []
        valid_actions_sent = []
        for raw_action in sent_actions:

            # Store states BEFORE ACTION
            # state of each word (buffer B, stack S, reduced X)
            word_states = get_word_states(amr_state_machine, sent_tokens)

            # Get actions valid for this state
            valid_actions = get_valid_actions(
                action_list,
                amr_state_machine,
                train_rule_stats,
                action_by_basic,
                raw_action,
                stats
            )

            # update info
            word_states_sent.append(word_states)
            valid_actions_sent.append(valid_actions)

            # Update machine
            amr_state_machine.applyAction(raw_action)

        # Write states for this sentence
        write_out_states(word_states_sent)
        write_out_valid_actions(valid_actions_sent)

    # inform user about missing predicates and actions
    for miss in ['missing_pred_count', 'missing_action_count']:
        num_missing = len(stats[miss])
        if num_missing:
            alert_str = f'{num_missing} {miss} not in rule_stats'
            print(yellow_font(alert_str))

    # inform user about fan-out stats
    mode_fan_out = stats['fan_out_count'].most_common(1)[0][0]
    max_fan_out = max(stats['fan_out_count'].keys())
    alert_str = f'num_actions mode: {mode_fan_out} max: {max_fan_out}'
    print(alert_str)

    # Close file
    write_out_states()
    write_out_valid_actions()
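
# The h5_writer and writer helpers used above are not shown here. They follow
# a closure protocol: calling the returned function with data appends one
# record, calling it with no arguments closes the file. A minimal sketch of
# that protocol, using pickle rather than HDF5 so it stays self-contained
# (the real helper presumably writes HDF5):
import pickle

def h5_writer_sketch(file_path):
    fid = open(file_path, 'wb')

    def write(data=None):
        if data is None:
            # no argument signals end of writing: close the file
            fid.close()
        else:
            pickle.dump(data, fid)

    return write
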
    def parse_sentence(self, sentence_str):
        """
        sentence_str is a string with whitespace separated tokens
        """

        # simulated actions given by a parsing model
        assert sentence_str in self.actions_by_sentence, \
            "Fake parser has no actions for sentence: %s" % sentence_str
        actions = self.actions_by_sentence[sentence_str]
        tokens = sentence_str.split()

        # Initialize state machine
        if self.machine_type == 'AMR':
            state_machine = AMRStateMachine(
                tokens,
                actions_by_stack_rules=self.actions_by_stack_rules,
                spacy_lemmatizer=self.spacy_lemmatizer,
                entity_rules=self.entity_rules)
        elif self.machine_type == 'dep-parsing':
            state_machine = DepParsingStateMachine(tokens)
        else:
            raise ValueError(f'Unknown machine type: {self.machine_type}')

        # this will store the AMR parse as BIO tags (PRED, ADDNODE)
        bio_alignments = {}

        # execute parsing model
        while not state_machine.is_closed:

            # Print state (pause if solicited)
            self.logger.update(self.sent_idx, state_machine)

            if len(actions) <= state_machine.time_step:
                # if the machine is not properly closed, force a CLOSE
                print(
                    yellow_font(
                        f'machine not closed at step {state_machine.time_step}'
                    ))
                raw_action = 'CLOSE'
            else:
                # get action from model
                raw_action = actions[state_machine.time_step]

            # restrict action space according to machine restrictions and
            # statistics
            if self.machine_type == 'AMR':
                raw_action = restrict_action(state_machine, raw_action,
                                             self.pred_counts,
                                             self.rule_violation)

                # update bio tags from AMR
                bio_alignments.update(
                    get_bio_from_machine(state_machine, raw_action))

            # Update state machine
            state_machine.applyAction(raw_action)

        # build bio tags
        bio_tags = get_bio_tags(state_machine, bio_alignments)

        # count one sentence more
        self.sent_idx += 1

        return state_machine, bio_tags
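
    def parse_file_sketch(self, sentences_path):
        # Hypothetical helper, not part of the original class: parse every
        # sentence in a file with parse_sentence above. The one-sentence-
        # per-line file format is an assumption.
        results = []
        with open(sentences_path) as fid:
            for line in fid:
                results.append(self.parse_sentence(line.rstrip('\n')))
        return results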
    def runOracle(self,
                  gold_amrs,
                  propbank_args=None,
                  out_oracle=None,
                  out_amr=None,
                  out_sentences=None,
                  out_actions=None,
                  out_rule_stats=None,
                  add_unaligned=0,
                  no_whitespace_in_actions=False,
                  multitask_words=None,
                  copy_lemma_action=False,
                  addnode_count_cutoff=None):

        print_log("oracle", "Parsing data")
        # deep copy of gold AMRs
        self.gold_amrs = [gold_amr.copy() for gold_amr in gold_amrs]

        # report inconsistencies in annotations
        alert_inconsistencies(self.gold_amrs)

        # open all files (if paths provided) and get writers to them
        oracle_write = writer(out_oracle)
        amr_write = writer(out_amr)
        sentence_write = writer(out_sentences, add_return=True)
        actions_write = writer(out_actions, add_return=True)

        # This will store overall stats
        self.stats = {
            'possible_predicates': Counter(),
            'action_vocabulary': Counter(),
            'addnode_counts': Counter()
        }

        # node names that may correspond to unaligned tokens
        included_unaligned = [
            '-',
            'and',
            'multi-sentence',
            'person',
            'cause-01',
            'you',
            'more',
            'imperative',
            '1',
            'thing',
        ]

        # initialize spacy lemmatizer out of the sentence loop for speed
        spacy_lemmatizer = None
        if copy_lemma_action:
            spacy_lemmatizer = get_spacy_lemmatizer()

        # Loop over gold AMRs
        for sent_idx, gold_amr in tqdm(enumerate(self.gold_amrs),
                                       desc='computing oracle',
                                       total=len(self.gold_amrs)):

            if self.verbose:
                print("New Sentence " + str(sent_idx) + "\n\n\n")

            # TODO: describe what this pre-processing does
            gold_amr = preprocess_amr(gold_amr, add_unaligned,
                                      included_unaligned)

            # Initialize state machine
            tr = AMRStateMachine(gold_amr.tokens,
                                 verbose=self.verbose,
                                 add_unaligned=add_unaligned,
                                 spacy_lemmatizer=spacy_lemmatizer,
                                 entity_rules=self.entity_rules)
            self.transitions.append(tr)
            self.amrs.append(tr.amr)

            # Loop over potential actions
            while tr.buffer or tr.stack:

                if self.tryMerge(tr, tr.amr, gold_amr):
                    action = 'MERGE'

                elif self.tryEntity(tr, tr.amr, gold_amr):
                    action = f'ADDNODE({self.entity_type})'

                elif self.tryDependent(tr, tr.amr, gold_amr):
                    edge = self.new_edge[1:] \
                        if self.new_edge.startswith(':') else self.new_edge
                    action = f'DEPENDENT({self.new_node},{edge})'
                    self.dep_id = None

                elif self.tryConfirm(tr, tr.amr, gold_amr):
                    # if --copy-lemma-action, check whether the lemma or its
                    # first sense equals the node name and use the
                    # corresponding action
                    if copy_lemma_action:
                        lemma, _ = tr.get_top_of_stack(lemma=True)
                        if lemma == self.new_node:
                            action = 'COPY_LEMMA'
                        elif f'{lemma}-01' == self.new_node:
                            action = 'COPY_SENSE01'
                        else:
                            action = f'PRED({self.new_node})'
                    else:
                        action = f'PRED({self.new_node})'

                elif self.tryIntroduce(tr, tr.amr, gold_amr):
                    action = 'INTRODUCE'

                elif self.tryLA(tr, tr.amr, gold_amr):
                    if self.new_edge == 'root':
                        action = f'LA({self.new_edge})'
                    else:
                        action = f'LA({self.new_edge[1:]})'

                elif self.tryRA(tr, tr.amr, gold_amr):
                    if self.new_edge == 'root':
                        action = f'RA({self.new_edge})'
                    else:
                        action = f'RA({self.new_edge[1:]})'

                elif self.tryReduce(tr, tr.amr, gold_amr):
                    action = 'REDUCE'

                elif self.trySWAP(tr, tr.amr, gold_amr):
                    action = 'UNSHIFT'

                elif tr.buffer:
                    action = 'SHIFT'

                else:
                    # no action applies; empty the machine and stop
                    tr.stack = []
                    tr.buffer = []
                    break

                # Store stats
                # get token(s) at the top of the stack
                token, merged_tokens = tr.get_top_of_stack()
                action_label = action.split('(')[0]

                # check the action has no invalid chars and normalize
                # TODO: --no-whitespace-in-actions is being deprecated
                if no_whitespace_in_actions and action_label == 'PRED':
                    assert '_' not in action, \
                        "--no-whitespace-in-actions prohibits use of _ in actions"
                    if ' ' in action:
                        action = action.replace(' ', '_')

                # Add prediction to top of the buffer
                if action == 'SHIFT' and multitask_words is not None:
                    action = label_shift(tr, multitask_words)

                # APPLY ACTION
                tr.applyAction(action)

            # Close machine
            tr.CLOSE(training=True,
                     gold_amr=gold_amr,
                     use_addnonde_rules=use_addnode_rules)

            # update files
            if out_oracle:
                # only build the oracle string if we are writing it
                oracle_write(str(tr))
            # JAMR format AMR
            amr_write(tr.amr.toJAMRString())
            # Tokens and actions
            # extra tag to be reduced at start
            tokens = tr.amr.tokens
            actions = tr.actions

            # Update action count
            self.stats['action_vocabulary'].update(actions)
            # remove the artificial root node (id -1) added in pre-processing
            del gold_amr.nodes[-1]
            addnode_actions = [a for a in actions if a.startswith('ADDNODE')]
            self.stats['addnode_counts'].update(addnode_actions)

            # separator
            if no_whitespace_in_actions:
                sep = " "
            else:
                sep = "\t"
            tokens = sep.join(tokens)
            actions = sep.join(actions)
            # Write
            sentence_write(tokens)
            actions_write(actions)

        print_log("oracle", "Done")

        # close files if open
        oracle_write()
        amr_write()
        sentence_write()
        actions_write()

        # build index dictionaries for actions, predicates and edge labels,
        # reserving index 0 for padding
        self.labelsO2idx = {'<pad>': 0}
        self.labelsA2idx = {'<pad>': 0}
        self.pred2idx = {'<pad>': 0}
        self.action2idx = {'<pad>': 0}

        for tr in self.transitions:
            for a in tr.actions:
                a = AMRStateMachine.readAction(a)[0]
                self.action2idx.setdefault(a, len(self.action2idx))
            for p in tr.predicates:
                self.pred2idx.setdefault(p, len(self.pred2idx))
            for l in tr.labels:
                self.labelsO2idx.setdefault(l, len(self.labelsO2idx))
            for l in tr.labelsA:
                self.labelsA2idx.setdefault(l, len(self.labelsA2idx))

        self.stats["action2idx"] = self.action2idx
        self.stats["pred2idx"] = self.pred2idx
        self.stats["labelsO2idx"] = self.labelsO2idx
        self.stats["labelsA2idx"] = self.labelsA2idx

        # Compute the word dictionary

        self.char2idx = {'<unk>': 0}
        self.word2idx = {'<unk>': 0, '<eof>': 1, '<ROOT>': 2, '<unaligned>': 3}
        self.node2idx = {}
        self.word_counter = Counter()

        for amr in self.gold_amrs:
            for tok in amr.tokens:
                self.word_counter[tok] += 1
                self.word2idx.setdefault(tok, len(self.word2idx))
                for ch in tok:
                    self.char2idx.setdefault(ch, len(self.char2idx))
            for n in amr.nodes:
                self.node2idx.setdefault(amr.nodes[n], len(self.node2idx))

        self.stats["char2idx"] = self.char2idx
        self.stats["word2idx"] = self.word2idx
        self.stats["node2idx"] = self.node2idx
        self.stats["word_counter"] = self.word_counter

        self.stats['possible_predicates'] = self.possiblePredicates

        if addnode_count_cutoff:
            self.stats['addnode_blacklist'] = [
                a for a, c in self.stats['addnode_counts'].items()
                if c <= addnode_count_cutoff
            ]
            num_addnode_blackl = len(self.stats['addnode_blacklist'])
            num_addnode = len(self.stats['addnode_counts'])
            print(f'{num_addnode_blackl}/{num_addnode} blacklisted ADDNODES')
            del self.stats['addnode_counts']

        # Save state machine rule stats
        if out_rule_stats:
            with open(out_rule_stats, 'w') as fid:
                fid.write(json.dumps(self.stats))
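
# Hypothetical driver sketch for runOracle: the AMR_Oracle class name and its
# constructor arguments are assumptions based on the method body, not code
# from the repository.
def run_oracle_sketch(gold_amrs):
    oracle = AMR_Oracle(verbose=False)
    oracle.runOracle(
        gold_amrs,                          # list of gold AMR objects
        out_sentences='train.tokens',
        out_actions='train.actions',
        out_rule_stats='train.rules.json',
        copy_lemma_action=True
    )
    return oracle.stats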