示例#1
0
def merge_rules(sentences, actions, rule_stats, entity_rules=None):
    """Replay action sequences and merge the predicates they yield into rule_stats.

    Each action sequence is run through an ``AMRStateMachine`` built over the
    matching sentence. Before every PRED / COPY_LEMMA / COPY_SENSE01 action
    the resulting node name is counted against the token on top of the stack
    and, when the top of the stack holds merged tokens, against their
    comma-joined surface form as well.

    Args:
        sentences: indexable collection; ``sentences[i]`` is the token
            sequence that ``actions[i]`` was produced for.
        actions: iterable of per-sentence sequences of action strings.
        rule_stats: dict holding a ``'possible_predicates'`` mapping. It is
            modified in place: its values are converted to ``Counter`` and
            the entry is finally replaced by the merged rules.
        entity_rules: forwarded unchanged to ``AMRStateMachine``.

    Returns:
        The same ``rule_stats`` object, with ``'possible_predicates'`` set to
        the result of ``merge_both_rules``.
    """
    # Existing rules that restrict the action space by stack content. The
    # values are normalised to Counter in place (same dict object as the one
    # stored in rule_stats).
    stack_rules = rule_stats['possible_predicates']
    for key in stack_rules:
        stack_rules[key] = Counter(stack_rules[key])

    # Built once, outside the sentence loop, and shared by every machine
    lemmatizer = get_spacy_lemmatizer()

    observed = defaultdict(Counter)
    for sent_idx, oracle_actions in tqdm(enumerate(actions), desc='merge rules'):

        tokens = sentences[sent_idx]
        machine = AMRStateMachine(
            tokens,
            actions_by_stack_rules=stack_rules,
            spacy_lemmatizer=lemmatizer,
            entity_rules=entity_rules
        )

        for action in oracle_actions:
            # NOTE: At the oracle, possible predicates are collected before
            # the PRED/COPY decision (tryConfirm action), so all three action
            # types have to be taken into account here.
            position, mpositions = machine.get_top_of_stack(positions=True)

            # Work out which node name (if any) this action introduces
            yields_node = True
            if action.startswith('PRED'):
                # strip the leading "PRED(" and the trailing ")"
                node = action[5:-1]
            elif action == 'COPY_LEMMA':
                node, _ = machine.get_top_of_stack(lemma=True)
            elif action == 'COPY_SENSE01':
                lemma, _ = machine.get_top_of_stack(lemma=True)
                node = f'{lemma}-01'
            else:
                yields_node = False

            if yields_node:
                observed[tokens[position]][node] += 1
                if mpositions:
                    merged_key = ','.join(tokens[p] for p in mpositions)
                    observed[merged_key][node] += 1

            # advance the machine only after the stack has been inspected
            machine.applyAction(action)

    rule_stats['possible_predicates'] = merge_both_rules(observed, stack_rules)
    return rule_stats
    def runOracle(self,
                  gold_amrs,
                  propbank_args=None,
                  out_oracle=None,
                  out_amr=None,
                  out_sentences=None,
                  out_actions=None,
                  out_rule_stats=None,
                  add_unaligned=0,
                  no_whitespace_in_actions=False,
                  multitask_words=None,
                  copy_lemma_action=False,
                  addnode_count_cutoff=None):
        """Compute oracle action sequences for the gold AMRs and collect stats.

        For every gold AMR a fresh ``AMRStateMachine`` is driven to completion
        by trying the oracle tests in a fixed priority order (merge, entity,
        dependent, confirm, introduce, LA, RA, reduce, swap, shift). The
        resulting tokens, actions and AMRs are sent to the given writers and
        vocabulary / count statistics are accumulated in ``self.stats``.

        Args:
            gold_amrs: iterable of gold AMR objects; each is copied with its
                own ``copy()`` before processing.
            propbank_args: not used by this method.
            out_oracle: target handed to ``writer`` for the string form of
                each state machine; only written when truthy.
            out_amr: target handed to ``writer`` for the JAMR-format AMRs.
            out_sentences: target handed to ``writer`` for the token lines.
            out_actions: target handed to ``writer`` for the action lines.
            out_rule_stats: if truthy, path where ``self.stats`` is dumped
                as JSON.
            add_unaligned: forwarded to ``preprocess_amr`` and to the
                state machine.
            no_whitespace_in_actions: if true, PRED actions must not contain
                ``_``, spaces inside them are replaced by ``_``, and tokens /
                actions are joined with a space instead of a tab.
            multitask_words: if not None, SHIFT actions are relabelled through
                ``label_shift``.
            copy_lemma_action: if true, emit COPY_LEMMA / COPY_SENSE01 in
                place of PRED when the node equals the lemma / ``lemma-01``.
            addnode_count_cutoff: if truthy, ADDNODE actions seen at most this
                many times are stored in ``stats['addnode_blacklist']`` and
                ``stats['addnode_counts']`` is dropped.

        Returns:
            None. Results are stored on ``self`` (``gold_amrs``, ``stats``,
            the ``*2idx`` dictionaries, ``word_counter``) and appended to
            ``self.transitions`` / ``self.amrs``.
        """

        print_log("oracle", "Parsing data")
        # work on copies of the gold AMRs (each made by the AMR's own copy())
        self.gold_amrs = [gold_amr.copy() for gold_amr in gold_amrs]

        # print about inconsistencies in annotations
        alert_inconsistencies(self.gold_amrs)

        # open all files (if paths provided) and get writers to them
        oracle_write = writer(out_oracle)
        amr_write = writer(out_amr)
        sentence_write = writer(out_sentences, add_return=True)
        actions_write = writer(out_actions, add_return=True)

        # This will store overall stats
        self.stats = {
            'possible_predicates': Counter(),
            'action_vocabulary': Counter(),
            'addnode_counts': Counter()
        }

        # unaligned tokens
        included_unaligned = [
            '-',
            'and',
            'multi-sentence',
            'person',
            'cause-01',
            'you',
            'more',
            'imperative',
            '1',
            'thing',
        ]

        # initialize spacy lemmatizer out of the sentence loop for speed
        spacy_lemmatizer = None
        if copy_lemma_action:
            spacy_lemmatizer = get_spacy_lemmatizer()

        # Loop over gold AMRs
        for sent_idx, gold_amr in tqdm(enumerate(self.gold_amrs),
                                       desc='computing oracle',
                                       total=len(self.gold_amrs)):

            if self.verbose:
                print("New Sentence " + str(sent_idx) + "\n\n\n")

            # TODO: Describe what is this pre-processing
            gold_amr = preprocess_amr(gold_amr, add_unaligned,
                                      included_unaligned)

            # Initialize state machine
            tr = AMRStateMachine(gold_amr.tokens,
                                 verbose=self.verbose,
                                 add_unaligned=add_unaligned,
                                 spacy_lemmatizer=spacy_lemmatizer,
                                 entity_rules=self.entity_rules)
            self.transitions.append(tr)
            self.amrs.append(tr.amr)

            # Loop over potential actions; the first oracle test that
            # succeeds decides the action, so the order below matters.
            while tr.buffer or tr.stack:

                if self.tryMerge(tr, tr.amr, gold_amr):
                    action = 'MERGE'

                elif self.tryEntity(tr, tr.amr, gold_amr):
                    action = f'ADDNODE({self.entity_type})'

                elif self.tryDependent(tr, tr.amr, gold_amr):
                    # drop the leading ':' of the edge label, if present
                    edge = self.new_edge[1:] \
                        if self.new_edge.startswith(':') else self.new_edge
                    action = f'DEPENDENT({self.new_node},{edge})'
                    self.dep_id = None

                elif self.tryConfirm(tr, tr.amr, gold_amr):
                    # default is an explicit PRED; with --copy-lemma-action
                    # use a copy action when the lemma or its first sense
                    # equals the node name
                    action = f'PRED({self.new_node})'
                    if copy_lemma_action:
                        lemma, _ = tr.get_top_of_stack(lemma=True)
                        if lemma == self.new_node:
                            action = 'COPY_LEMMA'
                        elif f'{lemma}-01' == self.new_node:
                            action = 'COPY_SENSE01'

                elif self.tryIntroduce(tr, tr.amr, gold_amr):
                    action = 'INTRODUCE'

                elif self.tryLA(tr, tr.amr, gold_amr):
                    # every label except 'root' has its first character
                    # stripped
                    if self.new_edge == 'root':
                        action = f'LA({self.new_edge})'
                    else:
                        action = f'LA({self.new_edge[1:]})'

                elif self.tryRA(tr, tr.amr, gold_amr):
                    if self.new_edge == 'root':
                        action = f'RA({self.new_edge})'
                    else:
                        action = f'RA({self.new_edge[1:]})'

                elif self.tryReduce(tr, tr.amr, gold_amr):
                    action = 'REDUCE'

                elif self.trySWAP(tr, tr.amr, gold_amr):
                    action = 'UNSHIFT'

                elif tr.buffer:
                    action = 'SHIFT'

                else:
                    # nothing applies and the buffer is empty: give up on
                    # this sentence by emptying the machine
                    tr.stack = []
                    tr.buffer = []
                    break

                # get token(s) at the top of the stack
                # NOTE(review): both values are unused in this method; the
                # call is kept only because it cannot be verified from here
                # that it has no side effects.
                token, merged_tokens = tr.get_top_of_stack()
                action_label = action.split('(')[0]

                # check action has not invalid chars and normalize
                # TODO: --no-whitespace-in-actions being deprecated
                if no_whitespace_in_actions and action_label == 'PRED':
                    assert '_' not in action, \
                        "--no-whitespace-in-actions prohibits use of _ in actions"
                    # The test must look at the full action: action_label is
                    # always 'PRED' in this branch and never contains a
                    # space, so checking it left the normalization dead.
                    if ' ' in action:
                        action = action.replace(' ', '_')

                # Add prediction to top of the buffer
                if action == 'SHIFT' and multitask_words is not None:
                    action = label_shift(tr, multitask_words)

                # APPLY ACTION
                tr.applyAction(action)

            # Close machine
            # NOTE(review): `use_addnode_rules` is not defined inside this
            # method, presumably a module-level flag; confirm. The keyword
            # spelling `use_addnonde_rules` has to match CLOSE's signature,
            # which is not visible here, so it is left untouched.
            tr.CLOSE(training=True,
                     gold_amr=gold_amr,
                     use_addnonde_rules=use_addnode_rules)

            # update files
            if out_oracle:
                # to avoid printing
                oracle_write(str(tr))
            # JAMR format AMR
            amr_write(tr.amr.toJAMRString())
            # Tokens and actions
            # extra tag to be reduced at start
            tokens = tr.amr.tokens
            actions = tr.actions

            # Update action count
            self.stats['action_vocabulary'].update(actions)
            # drop the node stored under key -1 of the pre-processed AMR
            # (presumably a root placeholder added by preprocess_amr; confirm)
            del gold_amr.nodes[-1]
            addnode_actions = [a for a in actions if a.startswith('ADDNODE')]
            self.stats['addnode_counts'].update(addnode_actions)

            # separator
            sep = " " if no_whitespace_in_actions else "\t"
            tokens = sep.join(tokens)
            actions = sep.join(actions)
            # Write
            sentence_write(tokens)
            actions_write(actions)

        print_log("oracle", "Done")

        # close files if open
        oracle_write()
        amr_write()
        sentence_write()
        actions_write()

        # Index dictionaries built from every state machine held in
        # self.transitions; index 0 is reserved for padding
        self.labelsO2idx = {'<pad>': 0}
        self.labelsA2idx = {'<pad>': 0}
        self.pred2idx = {'<pad>': 0}
        self.action2idx = {'<pad>': 0}

        for tr in self.transitions:
            for a in tr.actions:
                a = AMRStateMachine.readAction(a)[0]
                self.action2idx.setdefault(a, len(self.action2idx))
            for p in tr.predicates:
                self.pred2idx.setdefault(p, len(self.pred2idx))
            for l in tr.labels:
                self.labelsO2idx.setdefault(l, len(self.labelsO2idx))
            for l in tr.labelsA:
                self.labelsA2idx.setdefault(l, len(self.labelsA2idx))

        self.stats["action2idx"] = self.action2idx
        self.stats["pred2idx"] = self.pred2idx
        self.stats["labelsO2idx"] = self.labelsO2idx
        self.stats["labelsA2idx"] = self.labelsA2idx

        # Compute the word dictionary

        self.char2idx = {'<unk>': 0}
        self.word2idx = {'<unk>': 0, '<eof>': 1, '<ROOT>': 2, '<unaligned>': 3}
        self.node2idx = {}
        self.word_counter = Counter()

        for amr in self.gold_amrs:
            for tok in amr.tokens:
                self.word_counter[tok] += 1
                self.word2idx.setdefault(tok, len(self.word2idx))
                for ch in tok:
                    self.char2idx.setdefault(ch, len(self.char2idx))
            for n in amr.nodes:
                self.node2idx.setdefault(amr.nodes[n], len(self.node2idx))

        self.stats["char2idx"] = self.char2idx
        self.stats["word2idx"] = self.word2idx
        self.stats["node2idx"] = self.node2idx
        self.stats["word_counter"] = self.word_counter

        # replaces the empty Counter placeholder set when stats was created
        self.stats['possible_predicates'] = self.possiblePredicates

        if addnode_count_cutoff:
            # ADDNODE actions seen at most `addnode_count_cutoff` times
            self.stats['addnode_blacklist'] = [
                a for a, c in self.stats['addnode_counts'].items()
                if c <= addnode_count_cutoff
            ]
            num_addnode_blackl = len(self.stats['addnode_blacklist'])
            num_addnode = len(self.stats['addnode_counts'])
            print(f'{num_addnode_blackl}/{num_addnode} blacklisted ADDNODES')
            del self.stats['addnode_counts']

        # Dump the overall stats as JSON
        if out_rule_stats:
            with open(out_rule_stats, 'w') as fid:
                fid.write(json.dumps(self.stats))