def _predicate_nodes_for_action(state_machine, action):
    """Return the node name produced by a predicate-like ``action`` as a list.

    The result is a one-element list for ``PRED(...)``, ``COPY_LEMMA`` and
    ``COPY_SENSE01`` actions and an empty list for every other action, so it
    can be handed directly to ``Counter.update``.
    """
    if action.startswith('PRED'):
        # drop the leading 'PRED(' and the trailing ')'
        return [action[5:-1]]
    if action == 'COPY_LEMMA':
        lemma, _ = state_machine.get_top_of_stack(lemma=True)
        return [lemma]
    if action == 'COPY_SENSE01':
        lemma, _ = state_machine.get_top_of_stack(lemma=True)
        return [f'{lemma}-01']
    return []


def merge_rules(sentences, actions, rule_stats, entity_rules=None):
    """Merge predicate rules observed in ``actions`` into ``rule_stats``.

    Every action sequence is replayed on an ``AMRStateMachine``. Before each
    PRED / COPY_LEMMA / COPY_SENSE01 action is applied, the resulting node
    name is counted for the token at the top of the stack and, when the top
    of the stack holds merged tokens, also for their comma-joined form.
    The collected counts are then combined with the pre-existing ones through
    ``merge_both_rules``.

    Args:
        sentences: indexable collection; ``sentences[i]`` holds the tokens
            for ``actions[i]`` (it is indexed by stack position and passed
            to ``AMRStateMachine``).
        actions: iterable of per-sentence action sequences (strings).
        rule_stats: mapping with a ``'possible_predicates'`` entry. It is
            modified in place: the entry's values are first converted to
            ``Counter`` and the entry is finally replaced by the merged
            rules.
        entity_rules: forwarded unchanged to ``AMRStateMachine``.

    Returns:
        The same ``rule_stats`` object that was passed in, updated.
    """
    # generate rules to restrict action space by stack content
    # NOTE: this aliases rule_stats['possible_predicates']; the Counter
    # conversion below is therefore visible to the caller as well.
    actions_by_stack_rules = rule_stats['possible_predicates']
    for token, counter in actions_by_stack_rules.items():
        actions_by_stack_rules[token] = Counter(counter)

    spacy_lemmatizer = get_spacy_lemmatizer()

    possible_predicates = defaultdict(Counter)
    for index, sentence_actions in tqdm(enumerate(actions),
                                        desc='merge rules'):
        tokens = sentences[index]

        # Initialize machine
        state_machine = AMRStateMachine(
            tokens,
            actions_by_stack_rules=actions_by_stack_rules,
            spacy_lemmatizer=spacy_lemmatizer,
            entity_rules=entity_rules
        )

        for action in sentence_actions:
            # NOTE: At the oracle, possible predicates are collected before
            # the PRED/COPY decision (tryConfirm action), so all of them
            # have to be taken into account.
            position, mpositions = \
                state_machine.get_top_of_stack(positions=True)

            nodes = _predicate_nodes_for_action(state_machine, action)
            if nodes:
                possible_predicates[tokens[position]].update(nodes)
                if mpositions:
                    mtokens = ','.join([tokens[p] for p in mpositions])
                    possible_predicates[mtokens].update(nodes)

            # execute action
            state_machine.applyAction(action)

    rule_stats['possible_predicates'] = merge_both_rules(
        possible_predicates, actions_by_stack_rules
    )
    return rule_stats
def get_amr(tokens, actions, entity_rules):
    """Replay ``actions`` on a fresh state machine and return the AMR.

    ``"<unk>"`` actions are substituted by a ``PRED(...)`` built from the
    lower-cased token at the top of the stack. If the machine is still open
    once all actions are consumed, a warning is printed and it is closed
    explicitly. The result is ``AMR.get_amr_line`` applied to the lines of
    the machine's JAMR string.
    """
    # play state machine to get AMR
    machine = AMRStateMachine(tokens, entity_rules=entity_rules)
    for step_action in actions:
        # FIXME: It is unclear whether this is always the right option;
        # manual exploration of dev yielded 4 cases and it works for those.
        if step_action == "<unk>":
            top_token = machine.get_top_of_stack()[0]
            step_action = f'PRED({top_token.lower()})'
        machine.applyAction(step_action)

    # sanity check: force close
    if not machine.is_closed:
        print(yellow_font('Machine not closed!'))
        machine.CLOSE()

    # TODO: Probably wasting resources here (string round trip)
    jamr_lines = machine.amr.toJAMRString().split('\n')
    return AMR.get_amr_line(jamr_lines)
def main():
    """Extract per-step word states and valid-action masks for oracle data.

    Reads rule statistics, sentences and action sequences from the paths in
    the parsed command-line arguments, replays every action sequence on an
    ``AMRStateMachine`` and, before each action, records the word states and
    the set of valid actions. Results are handed to the two ``h5_writer``
    callables sentence by sentence; summary statistics are printed at the
    end.
    """
    # Argument handling
    args = argument_parser()

    # read rules
    train_rule_stats = read_rule_stats(args.in_rule_stats)
    assert 'action_vocabulary' in train_rule_stats
    assert 'possible_predicates' in train_rule_stats
    action_vocabulary = train_rule_stats['action_vocabulary']
    action_list = sorted(action_vocabulary.keys())

    # get all actions indexed by action root (text before the first '(')
    action_by_basic = defaultdict(list)
    for action in action_vocabulary.keys():
        action_by_basic[action.split('(')[0]].append(action)

    # get writers for the outputs
    write_out_states = h5_writer(args.out_word_states)
    write_out_valid_actions = h5_writer(args.out_valid_actions)

    # Read content
    # TODO: Point to LDC data
    sentences = readlines(args.in_sentences)
    actions = readlines(args.in_actions)
    assert len(sentences) == len(actions)

    # initialize spacy lemmatizer out of the sentence loop for speed
    spacy_lemmatizer = get_spacy_lemmatizer()

    stats = {
        'missing_pred_count': Counter(),
        'missing_action_count': Counter(),
        'fan_out_count': Counter()
    }
    for sent_idx, (sent_tokens, sent_actions) in tqdm(
        enumerate(zip(sentences, actions)),
        desc='extracting oracle masks',
        total=len(actions)
    ):

        # skip sentences before the requested offset
        if args.offset and sent_idx < args.offset:
            continue

        # Initialize state machine
        amr_state_machine = AMRStateMachine(
            sent_tokens,
            spacy_lemmatizer=spacy_lemmatizer
        )

        # process each action
        word_states_sent = []
        valid_actions_sent = []
        for raw_action in sent_actions:

            # Store states BEFORE ACTION
            # state of each word (buffer B, stack S, reduced X)
            word_states = get_word_states(amr_state_machine, sent_tokens)

            # Get actions valid for this state
            valid_actions = get_valid_actions(
                action_list,
                amr_state_machine,
                train_rule_stats,
                action_by_basic,
                raw_action,
                stats
            )

            # update info
            word_states_sent.append(word_states)
            valid_actions_sent.append(valid_actions)

            # Update machine
            amr_state_machine.applyAction(raw_action)

        # Write states for this sentence
        write_out_states(word_states_sent)
        write_out_valid_actions(valid_actions_sent)

    # inform user about missing predicates / actions
    for miss in ['missing_pred_count', 'missing_action_count']:
        num_missing = len(stats[miss])
        if num_missing:
            alert_str = f'{num_missing} {miss} rule_stats'
            print(yellow_font(alert_str))

    # inform user about fan-out stats; the counter stays empty when no
    # action was processed (empty input or every sentence skipped by the
    # offset), in which case most_common()/max() would raise
    fan_out_count = stats['fan_out_count']
    if fan_out_count:
        mode_fan_out = fan_out_count.most_common(1)[0][0]
        max_fan_out = max(fan_out_count.keys())
        alert_str = f'num_actions mode: {mode_fan_out} max: {max_fan_out}'
        print(alert_str)
    else:
        print(yellow_font('no fan-out statistics collected'))

    # Close file (writers are called without arguments to finish)
    write_out_states()
    write_out_valid_actions()
def parse_sentence(self, sentence_str):
    """Simulate parsing of one sentence with pre-stored actions.

    Args:
        sentence_str: string with whitespace separated tokens; it must be a
            key of ``self.actions_by_sentence``.

    Returns:
        Tuple ``(state_machine, bio_tags)``: the closed state machine and the
        result of ``get_bio_tags`` for it.

    Raises:
        AssertionError: if no actions are stored for ``sentence_str``.
        ValueError: if ``self.machine_type`` is neither ``'AMR'`` nor
            ``'dep-parsing'``.
    """
    # simulated actions given by a parsing model
    assert sentence_str in self.actions_by_sentence, \
        "Fake parser has no actions for sentence: %s" % sentence_str
    actions = self.actions_by_sentence[sentence_str]
    tokens = sentence_str.split()

    # Initialize state machine
    if self.machine_type == 'AMR':
        state_machine = AMRStateMachine(
            tokens,
            actions_by_stack_rules=self.actions_by_stack_rules,
            spacy_lemmatizer=self.spacy_lemmatizer,
            entity_rules=self.entity_rules
        )
    elif self.machine_type == 'dep-parsing':
        state_machine = DepParsingStateMachine(tokens)
    else:
        # fail early with a clear message instead of hitting an unbound
        # local variable in the parsing loop below
        raise ValueError(f'Unknown machine type: {self.machine_type}')

    # this will store AMR parsing as BIO tag (PRED, ADDNODE)
    bio_alignments = {}

    # execute parsing model
    while not state_machine.is_closed:

        # Print state (pause if solicited)
        self.logger.update(self.sent_idx, state_machine)

        if len(actions) <= state_machine.time_step:
            # stored actions are exhausted but the machine is not properly
            # closed: force a CLOSE
            print(yellow_font(
                f'machine not closed at step {state_machine.time_step}'
            ))
            raw_action = 'CLOSE'
        else:
            # get action from model
            raw_action = actions[state_machine.time_step]

        # restrict action space according to machine restrictions and
        # statistics
        # NOTE(review): nesting reconstructed from a whitespace-mangled
        # source; the BIO update is assumed to be AMR-only -- confirm.
        if self.machine_type == 'AMR':
            raw_action = restrict_action(
                state_machine,
                raw_action,
                self.pred_counts,
                self.rule_violation
            )

            # update bio tags from AMR
            bio_alignments.update(
                get_bio_from_machine(state_machine, raw_action)
            )

        # Update state machine
        state_machine.applyAction(raw_action)

    # build bio tags
    bio_tags = get_bio_tags(state_machine, bio_alignments)

    # count one sentence more
    self.sent_idx += 1

    return state_machine, bio_tags
def runOracle(self, gold_amrs, propbank_args=None, out_oracle=None,
              out_amr=None, out_sentences=None, out_actions=None,
              out_rule_stats=None, add_unaligned=0,
              no_whitespace_in_actions=False, multitask_words=None,
              copy_lemma_action=False, addnode_count_cutoff=None):
    """Run the oracle over gold AMRs and collect actions and statistics.

    For each gold AMR a new ``AMRStateMachine`` is created and driven by the
    first applicable ``try*`` rule until buffer and stack are empty. The
    resulting oracle dump, JAMR string, tokens and actions are passed to the
    corresponding writers, and vocabularies / counts are accumulated in
    ``self.stats`` (optionally dumped as JSON).

    Args:
        gold_amrs: iterable of gold AMR objects; each one is copied with
            ``.copy()`` and the copies are stored in ``self.gold_amrs``.
        propbank_args: not referenced in this method body.
        out_oracle, out_amr, out_sentences, out_actions: handed to
            ``writer(...)`` -- presumably output paths, with ``None``
            disabling the output; confirm against ``writer``.
        out_rule_stats: when truthy, path that ``self.stats`` is written to
            as JSON.
        add_unaligned: forwarded to ``preprocess_amr`` and
            ``AMRStateMachine``.
        no_whitespace_in_actions: when true, tokens/actions are joined with
            a space instead of a tab and PRED actions must not contain
            ``_``.
        multitask_words: when not ``None``, SHIFT actions are replaced by
            the result of ``label_shift(tr, multitask_words)``.
        copy_lemma_action: enables the COPY_LEMMA / COPY_SENSE01 actions
            and the creation of the spacy lemmatizer.
        addnode_count_cutoff: when truthy, ADDNODE actions whose count is
            ``<=`` this value are listed in ``stats['addnode_blacklist']``.

    Returns:
        None. Results are left in ``self.stats``, ``self.transitions``,
        ``self.amrs`` and the ``*2idx`` / ``word_counter`` attributes.

    NOTE(review): the indentation of this method was reconstructed from a
    whitespace-mangled source; spots where the nesting could not be
    determined unambiguously are marked below.
    """

    print_log("oracle", "Parsing data")
    # deep copy of gold AMRs
    self.gold_amrs = [gold_amr.copy() for gold_amr in gold_amrs]

    # print about inconsistencies in annotations
    alert_inconsistencies(self.gold_amrs)

    # open all files (if paths provided) and get writers to them
    oracle_write = writer(out_oracle)
    amr_write = writer(out_amr)
    sentence_write = writer(out_sentences, add_return=True)
    actions_write = writer(out_actions, add_return=True)

    # This will store overall stats
    self.stats = {
        'possible_predicates': Counter(),
        'action_vocabulary': Counter(),
        'addnode_counts': Counter()
    }

    # unaligned tokens
    included_unaligned = [
        '-', 'and', 'multi-sentence', 'person', 'cause-01', 'you', 'more',
        'imperative', '1', 'thing',
    ]

    # initialize spacy lemmatizer out of the sentence loop for speed
    # (only needed when the COPY actions are enabled)
    spacy_lemmatizer = None
    if copy_lemma_action:
        spacy_lemmatizer = get_spacy_lemmatizer()

    # Loop over gold AMRs
    for sent_idx, gold_amr in tqdm(enumerate(self.gold_amrs),
                                   desc=f'computing oracle',
                                   total=len(self.gold_amrs)):

        if self.verbose:
            print("New Sentence " + str(sent_idx) + "\n\n\n")

        # TODO: Describe what is this pre-processing
        gold_amr = preprocess_amr(gold_amr, add_unaligned,
                                  included_unaligned)

        # Initialize state machine
        tr = AMRStateMachine(gold_amr.tokens, verbose=self.verbose,
                             add_unaligned=add_unaligned,
                             spacy_lemmatizer=spacy_lemmatizer,
                             entity_rules=self.entity_rules)
        self.transitions.append(tr)
        self.amrs.append(tr.amr)

        # Loop over potential actions. The try* rules are tested in a fixed
        # priority order; the first one that succeeds decides the action.
        while tr.buffer or tr.stack:

            if self.tryMerge(tr, tr.amr, gold_amr):
                action = 'MERGE'

            elif self.tryEntity(tr, tr.amr, gold_amr):
                action = f'ADDNODE({self.entity_type})'

            elif self.tryDependent(tr, tr.amr, gold_amr):
                # edge label is emitted without its leading ':'
                edge = self.new_edge[1:] \
                    if self.new_edge.startswith(':') else self.new_edge
                action = f'DEPENDENT({self.new_node},{edge})'
                self.dep_id = None

            elif self.tryConfirm(tr, tr.amr, gold_amr):
                # if --copy-lemma-action check if lemma or first sense
                # equal node name. Use corresponding action
                if copy_lemma_action:
                    lemma, _ = tr.get_top_of_stack(lemma=True)
                    if copy_lemma_action and lemma == self.new_node:
                        action = 'COPY_LEMMA'
                    elif copy_lemma_action and f'{lemma}-01' == self.new_node:
                        action = 'COPY_SENSE01'
                    else:
                        action = f'PRED({self.new_node})'
                else:
                    action = f'PRED({self.new_node})'

            elif self.tryIntroduce(tr, tr.amr, gold_amr):
                action = 'INTRODUCE'

            elif self.tryLA(tr, tr.amr, gold_amr):
                # 'root' is kept as is, other labels lose their first char
                if self.new_edge == 'root':
                    action = f'LA({self.new_edge})'
                else:
                    action = f'LA({self.new_edge[1:]})'

            elif self.tryRA(tr, tr.amr, gold_amr):
                if self.new_edge == 'root':
                    action = f'RA({self.new_edge})'
                else:
                    action = f'RA({self.new_edge[1:]})'

            elif self.tryReduce(tr, tr.amr, gold_amr):
                action = 'REDUCE'

            elif self.trySWAP(tr, tr.amr, gold_amr):
                action = 'UNSHIFT'

            elif tr.buffer:
                action = 'SHIFT'

            else:
                # no rule applies and nothing is left to shift: empty the
                # machine and stop
                tr.stack = []
                tr.buffer = []
                break

            # Store stats
            # get token(s) at the top of the stack
            # (token and merged_tokens are not used further in this method)
            token, merged_tokens = tr.get_top_of_stack()
            action_label = action.split('(')[0]

            # check action has not invalid chars and normalize
            # TODO: --no-whitespace-in-actions being deprecated
            if no_whitespace_in_actions and action_label == 'PRED':
                assert '_' not in action, \
                    "--no-whitespace-in-actions prohibits use of _ in actions"
                # NOTE(review): nesting of this check is ambiguous in the
                # source. If it sits inside the PRED branch as written here,
                # action_label is 'PRED' and the test can never be true --
                # confirm whether `action` was intended.
                if ' ' in action_label:
                    action = action.replace(' ', '_')

            # Add prediction to top of the buffer
            if action == 'SHIFT' and multitask_words is not None:
                action = label_shift(tr, multitask_words)

            # APPLY ACTION
            tr.applyAction(action)

        # Close machine
        # NOTE(review): use_addnode_rules is not a parameter or local of
        # this method; presumably a module-level name -- confirm.
        tr.CLOSE(training=True, gold_amr=gold_amr,
                 use_addnonde_rules=use_addnode_rules)

        # update files
        if out_oracle:
            # to avoid printing
            oracle_write(str(tr))
        # JAMR format AMR
        # NOTE(review): assumed to be outside the `if out_oracle` guard --
        # confirm against the original layout.
        amr_write(tr.amr.toJAMRString())
        # Tokens and actions
        # extra tag to be reduced at start
        tokens = tr.amr.tokens
        actions = tr.actions

        # Update action count
        self.stats['action_vocabulary'].update(actions)
        # drop the entry with key -1 from the gold AMR nodes (presumably an
        # artificial node added during pre-processing -- confirm)
        del gold_amr.nodes[-1]
        addnode_actions = [a for a in actions if a.startswith('ADDNODE')]
        self.stats['addnode_counts'].update(addnode_actions)

        # separator
        if no_whitespace_in_actions:
            sep = " "
        else:
            sep = "\t"
        tokens = sep.join(tokens)
        actions = sep.join(actions)
        # Write
        sentence_write(tokens)
        actions_write(actions)

    print_log("oracle", "Done")

    # close files if open
    oracle_write()
    amr_write()
    sentence_write()
    actions_write()

    # Build index dictionaries over everything seen by the state machines;
    # index 0 is reserved for '<pad>'
    self.labelsO2idx = {'<pad>': 0}
    self.labelsA2idx = {'<pad>': 0}
    self.pred2idx = {'<pad>': 0}
    self.action2idx = {'<pad>': 0}

    for tr in self.transitions:
        for a in tr.actions:
            # index by the first element returned by readAction
            a = AMRStateMachine.readAction(a)[0]
            self.action2idx.setdefault(a, len(self.action2idx))
        for p in tr.predicates:
            self.pred2idx.setdefault(p, len(self.pred2idx))
        for l in tr.labels:
            self.labelsO2idx.setdefault(l, len(self.labelsO2idx))
        for l in tr.labelsA:
            self.labelsA2idx.setdefault(l, len(self.labelsA2idx))

    self.stats["action2idx"] = self.action2idx
    self.stats["pred2idx"] = self.pred2idx
    self.stats["labelsO2idx"] = self.labelsO2idx
    self.stats["labelsA2idx"] = self.labelsA2idx

    # Compute the word dictionary
    self.char2idx = {'<unk>': 0}
    self.word2idx = {'<unk>': 0, '<eof>': 1, '<ROOT>': 2, '<unaligned>': 3}
    self.node2idx = {}
    self.word_counter = Counter()

    for amr in self.gold_amrs:
        for tok in amr.tokens:
            self.word_counter[tok] += 1
            self.word2idx.setdefault(tok, len(self.word2idx))
            for ch in tok:
                self.char2idx.setdefault(ch, len(self.char2idx))
        for n in amr.nodes:
            self.node2idx.setdefault(amr.nodes[n], len(self.node2idx))

    self.stats["char2idx"] = self.char2idx
    self.stats["word2idx"] = self.word2idx
    self.stats["node2idx"] = self.node2idx
    self.stats["word_counter"] = self.word_counter

    self.stats['possible_predicates'] = self.possiblePredicates

    # ADDNODE actions seen at most `addnode_count_cutoff` times are
    # blacklisted
    if addnode_count_cutoff:
        self.stats['addnode_blacklist'] = [
            a
            for a, c in self.stats['addnode_counts'].items()
            if c <= addnode_count_cutoff
        ]
        num_addnode_blackl = len(self.stats['addnode_blacklist'])
        num_addnode = len(self.stats['addnode_counts'])
        print(f'{num_addnode_blackl}/{num_addnode} blacklisted ADDNODES')
        # NOTE(review): whether this deletion belongs inside the cutoff
        # branch cannot be determined from the mangled source -- confirm.
        del self.stats['addnode_counts']

    # Dump the collected stats as JSON
    if out_rule_stats:
        with open(out_rule_stats, 'w') as fid:
            fid.write(json.dumps(self.stats))