def get_amr(tokens, actions, entity_rules):
    """Replay a sequence of oracle actions over a state machine and
    return the resulting AMR graph (re-parsed from its JAMR form)."""
    machine = AMRStateMachine(tokens, entity_rules=entity_rules)
    for act in actions:
        # FIXME: It is unclear that this will be always the right option.
        # Manual exploration of dev yielded 4 cases and it works for the 4:
        # map "<unk>" to a PRED over the lowercased stack-top token.
        if act == "<unk>":
            top_token = machine.get_top_of_stack()[0]
            act = f'PRED({top_token.lower()})'
        machine.applyAction(act)

    # Sanity check: if the action sequence did not close the machine,
    # warn and force-close it.
    if not machine.is_closed:
        print(yellow_font('Machine not closed!'))
        machine.CLOSE()

    # TODO: Probably wasting resources here — serialize to a JAMR string
    # only to re-parse it line by line into an AMR object.
    jamr_lines = machine.amr.toJAMRString().split('\n')
    return AMR.get_amr_line(jamr_lines)
def runOracle(self, gold_amrs, propbank_args=None, out_oracle=None,
              out_amr=None, out_sentences=None, out_actions=None,
              out_rule_stats=None, add_unaligned=0,
              no_whitespace_in_actions=False, multitask_words=None,
              copy_lemma_action=False, addnode_count_cutoff=None):
    """Run the oracle over a corpus of gold AMRs.

    For every gold AMR, plays an AMRStateMachine and greedily selects one
    action per step (MERGE / ADDNODE / DEPENDENT / PRED-or-COPY /
    INTRODUCE / LA / RA / REDUCE / UNSHIFT / SHIFT) via the try* helpers,
    then writes the resulting oracle traces, AMRs, token and action
    sequences to the given output paths and accumulates vocabulary /
    count statistics into ``self.stats``.

    Args:
        gold_amrs: iterable of gold AMR objects; deep-copied before use.
        propbank_args: unused in this method body (kept for interface).
        out_oracle: path for the oracle trace output (optional).
        out_amr: path for the JAMR-formatted AMR output (optional).
        out_sentences: path for tokenized sentences (optional).
        out_actions: path for action sequences (optional).
        out_rule_stats: path for the JSON stats dump (optional).
        add_unaligned: number of extra unaligned positions to add.
        no_whitespace_in_actions: use " " instead of "\\t" as separator
            and forbid "_" inside PRED actions (being deprecated).
        multitask_words: if given, SHIFT actions get relabeled via
            ``label_shift`` for multi-task training.
        copy_lemma_action: emit COPY_LEMMA / COPY_SENSE01 instead of PRED
            when the lemma (or its "-01" sense) matches the gold node.
        addnode_count_cutoff: if set, ADDNODE actions occurring at most
            this many times are recorded in an ``addnode_blacklist``.

    Side effects: mutates ``self.gold_amrs``, ``self.transitions``,
    ``self.amrs``, ``self.stats`` and several *2idx vocabularies; writes
    to the given output files.
    """

    print_log("oracle", "Parsing data")
    # Deep copy of gold AMRs (they are mutated below, e.g. node deletion)
    self.gold_amrs = [gold_amr.copy() for gold_amr in gold_amrs]

    # Print about inconsistencies in annotations
    alert_inconsistencies(self.gold_amrs)

    # Open all files (if paths provided) and get writers to them;
    # writer() presumably returns a no-op when the path is None — confirm
    oracle_write = writer(out_oracle)
    amr_write = writer(out_amr)
    sentence_write = writer(out_sentences, add_return=True)
    actions_write = writer(out_actions, add_return=True)

    # This will store overall stats
    self.stats = {
        'possible_predicates': Counter(),
        'action_vocabulary': Counter(),
        'addnode_counts': Counter()
    }

    # Unaligned tokens that are still included in the oracle
    included_unaligned = [
        '-', 'and', 'multi-sentence', 'person', 'cause-01', 'you',
        'more', 'imperative', '1', 'thing',
    ]

    # Initialize spacy lemmatizer outside of the sentence loop for speed
    spacy_lemmatizer = None
    if copy_lemma_action:
        spacy_lemmatizer = get_spacy_lemmatizer()

    # Loop over gold AMRs
    for sent_idx, gold_amr in tqdm(enumerate(self.gold_amrs),
                                   desc=f'computing oracle',
                                   total=len(self.gold_amrs)):

        if self.verbose:
            print("New Sentence " + str(sent_idx) + "\n\n\n")

        # TODO: Describe what is this pre-processing
        gold_amr = preprocess_amr(gold_amr, add_unaligned,
                                  included_unaligned)

        # Initialize state machine
        tr = AMRStateMachine(gold_amr.tokens,
                             verbose=self.verbose,
                             add_unaligned=add_unaligned,
                             spacy_lemmatizer=spacy_lemmatizer,
                             entity_rules=self.entity_rules)
        self.transitions.append(tr)
        self.amrs.append(tr.amr)

        # Loop over potential actions; each try* call is expected to set
        # side-channel fields on self (new_node, new_edge, entity_type, …)
        while tr.buffer or tr.stack:

            if self.tryMerge(tr, tr.amr, gold_amr):
                action = 'MERGE'

            elif self.tryEntity(tr, tr.amr, gold_amr):
                action = f'ADDNODE({self.entity_type})'

            elif self.tryDependent(tr, tr.amr, gold_amr):
                # Strip the leading ':' from the edge label if present
                edge = self.new_edge[1:] \
                    if self.new_edge.startswith(':') else self.new_edge
                action = f'DEPENDENT({self.new_node},{edge})'
                self.dep_id = None

            elif self.tryConfirm(tr, tr.amr, gold_amr):
                # If --copy-lemma-action, check whether the lemma or its
                # first sense equals the gold node name and use the
                # corresponding COPY action instead of PRED
                if copy_lemma_action:
                    lemma, _ = tr.get_top_of_stack(lemma=True)
                    if copy_lemma_action and lemma == self.new_node:
                        action = 'COPY_LEMMA'
                    elif copy_lemma_action and f'{lemma}-01' == self.new_node:
                        action = 'COPY_SENSE01'
                    else:
                        action = f'PRED({self.new_node})'
                else:
                    action = f'PRED({self.new_node})'

            elif self.tryIntroduce(tr, tr.amr, gold_amr):
                action = 'INTRODUCE'

            elif self.tryLA(tr, tr.amr, gold_amr):
                # 'root' keeps its label as-is; other labels drop the ':'
                if self.new_edge == 'root':
                    action = f'LA({self.new_edge})'
                else:
                    action = f'LA({self.new_edge[1:]})'

            elif self.tryRA(tr, tr.amr, gold_amr):
                if self.new_edge == 'root':
                    action = f'RA({self.new_edge})'
                else:
                    action = f'RA({self.new_edge[1:]})'

            elif self.tryReduce(tr, tr.amr, gold_amr):
                action = 'REDUCE'

            elif self.trySWAP(tr, tr.amr, gold_amr):
                action = 'UNSHIFT'

            elif tr.buffer:
                action = 'SHIFT'

            else:
                # No applicable action: empty the machine and give up on
                # this sentence
                tr.stack = []
                tr.buffer = []
                break

            # Store stats
            # Get token(s) at the top of the stack
            token, merged_tokens = tr.get_top_of_stack()
            action_label = action.split('(')[0]

            # Check action has no invalid chars and normalize
            # TODO: --no-whitespace-in-actions being deprecated
            if no_whitespace_in_actions and action_label == 'PRED':
                assert '_' not in action, \
                    "--no-whitespace-in-actions prohibits use of _ in actions"
            if ' ' in action_label:
                action = action.replace(' ', '_')

            # Add prediction to top of the buffer (multi-task labeling)
            if action == 'SHIFT' and multitask_words is not None:
                action = label_shift(tr, multitask_words)

            # APPLY ACTION
            tr.applyAction(action)

        # Close machine
        # NOTE(review): `use_addnode_rules` is not defined in this method
        # (presumably a module-level flag) and the keyword is spelled
        # `use_addnonde_rules` — verify both against CLOSE()'s signature.
        tr.CLOSE(
            training=True,
            gold_amr=gold_amr,
            use_addnonde_rules=use_addnode_rules
        )

        # Update files
        if out_oracle:
            # guarded to avoid printing when no oracle output requested
            oracle_write(str(tr))
        # JAMR format AMR
        amr_write(tr.amr.toJAMRString())

        # Tokens and actions
        # extra tag to be reduced at start
        tokens = tr.amr.tokens
        actions = tr.actions

        # Update action count
        self.stats['action_vocabulary'].update(actions)
        # Drop the last node of the (copied) gold AMR;
        # presumably the extra unaligned/root node added earlier — confirm
        del gold_amr.nodes[-1]
        addnode_actions = [a for a in actions if a.startswith('ADDNODE')]
        self.stats['addnode_counts'].update(addnode_actions)

        # Separator between tokens/actions in the output files
        if no_whitespace_in_actions:
            sep = " "
        else:
            sep = "\t"
        tokens = sep.join(tokens)
        actions = sep.join(actions)

        # Write
        sentence_write(tokens)
        actions_write(actions)

    print_log("oracle", "Done")

    # Close files if open (writer called with no args closes the file)
    oracle_write()
    amr_write()
    sentence_write()
    actions_write()

    # Build action / predicate / label vocabularies ('<pad>' reserved at 0)
    self.labelsO2idx = {'<pad>': 0}
    self.labelsA2idx = {'<pad>': 0}
    self.pred2idx = {'<pad>': 0}
    self.action2idx = {'<pad>': 0}

    for tr in self.transitions:
        for a in tr.actions:
            # keep only the action name, dropping its argument
            a = AMRStateMachine.readAction(a)[0]
            self.action2idx.setdefault(a, len(self.action2idx))
        for p in tr.predicates:
            self.pred2idx.setdefault(p, len(self.pred2idx))
        for l in tr.labels:
            self.labelsO2idx.setdefault(l, len(self.labelsO2idx))
        for l in tr.labelsA:
            self.labelsA2idx.setdefault(l, len(self.labelsA2idx))

    self.stats["action2idx"] = self.action2idx
    self.stats["pred2idx"] = self.pred2idx
    self.stats["labelsO2idx"] = self.labelsO2idx
    self.stats["labelsA2idx"] = self.labelsA2idx

    # Compute the word dictionary (special symbols reserved first)
    self.char2idx = {'<unk>': 0}
    self.word2idx = {'<unk>': 0, '<eof>': 1, '<ROOT>': 2, '<unaligned>': 3}
    self.node2idx = {}
    self.word_counter = Counter()
    for amr in self.gold_amrs:
        for tok in amr.tokens:
            self.word_counter[tok] += 1
            self.word2idx.setdefault(tok, len(self.word2idx))
            for ch in tok:
                self.char2idx.setdefault(ch, len(self.char2idx))
        for n in amr.nodes:
            self.node2idx.setdefault(amr.nodes[n], len(self.node2idx))

    self.stats["char2idx"] = self.char2idx
    self.stats["word2idx"] = self.word2idx
    self.stats["node2idx"] = self.node2idx
    self.stats["word_counter"] = self.word_counter

    self.stats['possible_predicates'] = self.possiblePredicates

    # Optionally blacklist rare ADDNODE actions
    if addnode_count_cutoff:
        self.stats['addnode_blacklist'] = [
            a
            for a, c in self.stats['addnode_counts'].items()
            if c <= addnode_count_cutoff
        ]
        num_addnode_blackl = len(self.stats['addnode_blacklist'])
        num_addnode = len(self.stats['addnode_counts'])
        print(f'{num_addnode_blackl}/{num_addnode} blacklisted ADDNODES')
        del self.stats['addnode_counts']

    # Dump state-machine stats as JSON
    if out_rule_stats:
        with open(out_rule_stats, 'w') as fid:
            fid.write(json.dumps(self.stats))