def get_new_state_machine(sent_tokens, machine_type=None):
    nonlocal actions_by_stack_rules
    nonlocal spacy_lemmatizer
    nonlocal entity_rules

    # automatic determination of machine if no flag provided
    if sent_tokens[0] in ['<NER>', '<AMR>', '<SRL>']:
        assert machine_type is None, \
            "specify --machine-type OR pre-append <machine-type token>"
        machine_type = sent_tokens[0][1:-1]
    elif machine_type is None:
        raise Exception(
            "needs either --machine-type or appending <machine-type token>"
        )

    # select machine
    if machine_type == 'AMR':
        return AMRStateMachine(
            sent_tokens,
            actions_by_stack_rules=actions_by_stack_rules,
            spacy_lemmatizer=spacy_lemmatizer,
            entity_rules=entity_rules
        )
    elif machine_type == 'dep-parsing':
        assert sent_tokens[-1] == 'ROOT'
        # sent_tokens.pop()
        # sent_tokens.append('<ROOT>')
        return DepParsingStateMachine(sent_tokens)
    elif machine_type in ['NER', 'SRL']:
        from bio_tags.machine import BIOStateMachine
        return BIOStateMachine(sent_tokens)
    else:
        raise Exception(f'Unknown machine {machine_type}')
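
# Illustrative sketch (not in the original code): how the pre-appended
# machine-type sentinel is interpreted by get_new_state_machine above.
# The token list is hypothetical.
example_tokens = ['<AMR>', 'The', 'boy', 'wants', 'to', 'go']
assert example_tokens[0][1:-1] == 'AMR'   # '<AMR>' -> machine_type 'AMR'
# get_new_state_machine(example_tokens) would then return an AMRStateMachine;
# alternatively, pass machine_type='AMR' explicitly and omit the sentinel.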
def merge_rules(sentences, actions, rule_stats, entity_rules=None):

    # generate rules to restrict action space by stack content
    # (convert raw count dicts into Counter objects)
    actions_by_stack_rules = rule_stats['possible_predicates']
    for token, counter in rule_stats['possible_predicates'].items():
        actions_by_stack_rules[token] = Counter(counter)

    spacy_lemmatizer = get_spacy_lemmatizer()

    possible_predicates = defaultdict(lambda: Counter())
    for index, sentence_actions in tqdm(enumerate(actions),
                                        desc='merge rules'):

        tokens = sentences[index]

        # Initialize machine
        state_machine = AMRStateMachine(
            tokens,
            actions_by_stack_rules=actions_by_stack_rules,
            spacy_lemmatizer=spacy_lemmatizer,
            entity_rules=entity_rules
        )

        for action in sentence_actions:

            # NOTE: At the oracle, possible predicates are collected before
            # the PRED/COPY decision (tryConfirm action), so we have to take
            # all of them into account
            position, mpositions = \
                state_machine.get_top_of_stack(positions=True)
            if action.startswith('PRED'):
                node = action[5:-1]
                possible_predicates[tokens[position]].update([node])
                if mpositions:
                    mtokens = ','.join([tokens[p] for p in mpositions])
                    possible_predicates[mtokens].update([node])

            elif action == 'COPY_LEMMA':
                lemma, _ = state_machine.get_top_of_stack(lemma=True)
                node = lemma
                possible_predicates[tokens[position]].update([node])
                if mpositions:
                    mtokens = ','.join([tokens[p] for p in mpositions])
                    possible_predicates[mtokens].update([node])

            elif action == 'COPY_SENSE01':
                lemma, _ = state_machine.get_top_of_stack(lemma=True)
                node = f'{lemma}-01'
                possible_predicates[tokens[position]].update([node])
                if mpositions:
                    mtokens = ','.join([tokens[p] for p in mpositions])
                    possible_predicates[mtokens].update([node])

            # execute action
            state_machine.applyAction(action)

    # merge the predicates collected here with the ones provided
    out_rule_stats = rule_stats
    new_possible_predicates = merge_both_rules(possible_predicates,
                                               actions_by_stack_rules)
    out_rule_stats['possible_predicates'] = new_possible_predicates

    return out_rule_stats
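
# Illustrative sketch (not in the original code): the shape of
# rule_stats['possible_predicates'] consumed and produced by merge_rules --
# surface token(s) mapped to a Counter of node names they were predicted as.
# Token and node names below are hypothetical.
from collections import Counter
example_possible_predicates = {
    'wants': Counter({'want-01': 12}),
    'boy': Counter({'boy': 8}),
    # merged tokens are keyed by their comma-joined surface forms
    'New,York': Counter({'city': 3}),
}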
def get_amr(tokens, actions, entity_rules):

    # play state machine to get AMR
    state_machine = AMRStateMachine(tokens, entity_rules=entity_rules)
    for action in actions:
        # FIXME: It is unclear that this will always be the right option.
        # Manual exploration of dev yielded 4 cases and it works for all 4.
        if action == "<unk>":
            action = f'PRED({state_machine.get_top_of_stack()[0].lower()})'
        state_machine.applyAction(action)

    # sanity check: force close
    if not state_machine.is_closed:
        alert_str = yellow_font('Machine not closed!')
        print(alert_str)
        state_machine.CLOSE()

    # TODO: Probably wasting resources here
    amr_str = state_machine.amr.toJAMRString()
    return AMR.get_amr_line(amr_str.split('\n'))
def build_amr(self, tokens, actions, labels, labelsA, predicates):

    apply_actions = []
    for act, label, labelA, predicate in zip(actions, labels, labelsA,
                                             predicates):
        # print(act, label, labelA, predicate)
        if act.startswith('PR'):
            apply_actions.append(act + f'({predicate})')
        elif act.startswith('RA') or (act.startswith('LA')
                                      and not act.endswith('(root)')):
            apply_actions.append(act + f'({label})')
        elif act.startswith('AD'):
            apply_actions.append(act + f'({labelA})')
        else:
            apply_actions.append(act)

    toks = [tok for tok in tokens if tok != "<eof>"]
    tr = AMRStateMachine(toks, verbose=False,
                         spacy_lemmatizer=self.lemmatizer)
    tr.applyActions(apply_actions)
    return tr.amr
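
# Illustrative sketch (not in the original code): how build_amr decorates the
# bare action sequence with the parallel label/labelA/predicate predictions.
# The action and label strings below are hypothetical.
# actions    = ['SHIFT', 'PRED',          'RA',        'ADDNODE']
# labels     = ['',      '',              'ARG0',      '']
# labelsA    = ['',      '',              '',          'person']
# predicates = ['',      'want-01',       '',          '']
# -> apply_actions == ['SHIFT', 'PRED(want-01)', 'RA(ARG0)', 'ADDNODE(person)']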
def read_actions(self, actions_file):
    transitions = []
    with open(actions_file, 'r', encoding='utf8') as f:
        sentences = f.read()
    sentences = sentences.replace('\r', '')
    sentences = sentences.split('\n\n')
    for sent in sentences:
        if not sent.strip():
            continue
        s = sent.split('\n')
        if len(s) < 2:
            raise IOError(f'Action file formatted incorrectly: {sent}')
        tokens = s[0].split('\t')
        actions = s[1].split('\t')
        transitions.append(
            AMRStateMachine(tokens, entity_rules=self.entity_rules))
        transitions[-1].applyActions(actions)
    self.transitions = transitions
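
# Illustrative sketch (not in the original code): the file format expected by
# read_actions above -- one block per sentence, a tab-separated token line
# followed by a tab-separated action line, blocks separated by a blank line.
# The tokens and actions below are hypothetical.
example_block = (
    '\t'.join(['The', 'boy', 'wants', 'to', 'go']) + '\n'
    + '\t'.join(['SHIFT', 'PRED(boy)', 'SHIFT', 'PRED(want-01)']) + '\n\n'
)
# writing example_block to a file and calling read_actions(path) would
# produce one AMRStateMachine per block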
def eval_parser(exp_name, model, args, dev_sentences, h5py_test, epoch_idx,
                smatch_file, save_model, param_groups):

    # save also current learning rate
    learning_rate = ",".join([str(x['lr']) for x in param_groups])

    model.eval()

    # evaluate on dev
    print_log('eval', f'Evaluating on: {args.amr_dev_data}')
    if save_model:
        predicted_amr_file = \
            f'{save_model}/{exp_name}_amrs.epoch{epoch_idx}.dev.txt'
    else:
        predicted_amr_file = f'{exp_name}_amrs.epoch{epoch_idx}.dev.txt'
    with open(predicted_amr_file, 'w+') as f:
        f.write('')
    print_log('eval', f'Writing amr graphs to: {predicted_amr_file}')

    if save_model:
        actions_file = \
            f'{save_model}/{exp_name}_actions.epoch{epoch_idx}.dev.txt'
    else:
        actions_file = f'{exp_name}_actions.epoch{epoch_idx}.dev.txt'
    if args.write_actions:
        print_log('eval', f'Writing actions to: {actions_file}')
        with open(actions_file, 'w+') as f:
            f.write('')

    sent_idx = 0
    dev_hash = 0
    for tokens in tqdm(dev_sentences):
        sent_rep = utils.vectorize_words(model, tokens, training=False,
                                         gpu=args.gpu)
        dev_b_emb = get_bert_embeddings(h5py_test, sent_idx, tokens) \
            if not args.no_bert else None
        _, actions, labels, labelsA, predicates = model.forward_single(
            sent_rep, mode='predict', tokens=tokens,
            bert_embedding=dev_b_emb)

        # write amr graphs
        apply_actions = []
        for act, label, labelA, predicate in zip(actions, labels, labelsA,
                                                 predicates):
            # print(act, label, labelA, predicate)
            if act.startswith('PR'):
                apply_actions.append(act + f'({predicate})')
            elif act.startswith('RA') or (act.startswith('LA')
                                          and not act.endswith('(root)')):
                apply_actions.append(act + f'({label})')
            elif act.startswith('AD'):
                apply_actions.append(act + f'({labelA})')
            else:
                apply_actions.append(act)

        if args.unit_tests:
            dev_hash += sum(model.action2idx[a] for a in actions)
            dev_hash += sum(model.labelsO2idx[l] for l in labels if l)
            dev_hash += sum(model.labelsA2idx[l] for l in labelsA if l)
            dev_hash += sum(model.pred2idx[p] if p in model.pred2idx else 0
                            for p in predicates if p)

        # print('[eval]', apply_actions)
        if args.write_actions:
            with open(actions_file, 'a') as f:
                f.write('\t'.join(tokens) + '\n')
                f.write('\t'.join(apply_actions) + '\n\n')
        tr = AMRStateMachine(tokens, verbose=False)
        tr.applyActions(apply_actions)
        with open(predicted_amr_file, 'a') as f:
            f.write(tr.amr.toJAMRString())
        sent_idx += 1

    # run smatch
    print_log('eval', 'Computing SMATCH')
    smatch_score = smatch_wrapper(args.amr_dev_data, predicted_amr_file,
                                  significant=3)
    print_log('eval', f'SMATCH: {smatch_score}')
    timestamp = str(datetime.now()).split('.')[0]

    # store all information in file
    print_log('eval', f'Writing SMATCH and other info to: {smatch_file}')
    with open(smatch_file, 'a') as fid:
        fid.write("\t".join([
            f'epoch {epoch_idx}',
            f'learning_rate {learning_rate}',
            f'time {timestamp}',
            f'F-score {smatch_score}\n'
        ]))

    if args.unit_tests:
        test1 = (model.epoch_loss == 3360.1150283813477)
        test2 = (dev_hash == 6038)
        print(
            f'[run tests] epoch_loss==3360.1150283813477 '
            f'(got {model.epoch_loss}) {"pass" if test1 else "fail"}',
            file=sys.stderr)
        print(
            f'[run tests] dev hash==6038 '
            f'(got {dev_hash}) {"pass" if test2 else "fail"}',
            file=sys.stderr)
        assert test1
        assert test2

    return smatch_score
def main():

    # Argument handling
    args = argument_parser()

    # read rules
    train_rule_stats = read_rule_stats(args.in_rule_stats)
    assert 'action_vocabulary' in train_rule_stats
    assert 'possible_predicates' in train_rule_stats
    action_list = list(sorted(train_rule_stats['action_vocabulary'].keys()))

    # get all actions indexed by action root
    action_by_basic = defaultdict(list)
    for action in train_rule_stats['action_vocabulary'].keys():
        key = action.split('(')[0]
        action_by_basic[key].append(action)

    # open files for writing (if paths provided)
    write_out_states = h5_writer(args.out_word_states)
    write_out_valid_actions = h5_writer(args.out_valid_actions)

    # Read content
    # TODO: Point to LDC data
    sentences = readlines(args.in_sentences)
    actions = readlines(args.in_actions)
    assert len(sentences) == len(actions)

    # initialize spacy lemmatizer outside of the sentence loop for speed
    spacy_lemmatizer = get_spacy_lemmatizer()

    sent_idx = -1
    stats = {
        'missing_pred_count': Counter(),
        'missing_action_count': Counter(),
        'fan_out_count': Counter()
    }
    for sent_tokens, sent_actions in tqdm(
        zip(sentences, actions),
        desc='extracting oracle masks',
        total=len(actions)
    ):

        # keep count of sentence index
        sent_idx += 1
        if args.offset and sent_idx < args.offset:
            continue

        # Initialize state machine
        amr_state_machine = AMRStateMachine(
            sent_tokens,
            spacy_lemmatizer=spacy_lemmatizer
        )

        # process each action
        word_states_sent = []
        valid_actions_sent = []
        for raw_action in sent_actions:

            # Store states BEFORE ACTION
            # state of each word (buffer B, stack S, reduced X)
            word_states = get_word_states(amr_state_machine, sent_tokens)

            # Get actions valid for this state
            valid_actions = get_valid_actions(
                action_list,
                amr_state_machine,
                train_rule_stats,
                action_by_basic,
                raw_action,
                stats
            )

            # update info
            word_states_sent.append(word_states)
            valid_actions_sent.append(valid_actions)

            # Update machine
            amr_state_machine.applyAction(raw_action)

        # Write states for this sentence
        write_out_states(word_states_sent)
        write_out_valid_actions(valid_actions_sent)

    # inform user about missing predicates and actions
    for miss in ['missing_pred_count', 'missing_action_count']:
        num_missing = len(stats[miss])
        if num_missing:
            alert_str = f'{num_missing} {miss} not in rule_stats'
            print(yellow_font(alert_str))

    # inform user about fan-out stats
    mode_fan_out = stats['fan_out_count'].most_common(1)[0][0]
    max_fan_out = max(stats['fan_out_count'].keys())
    alert_str = f'num_actions mode: {mode_fan_out} max: {max_fan_out}'
    print(alert_str)

    # Close files
    write_out_states()
    write_out_valid_actions()
def parse_sentence(self, sentence_str):
    """
    sentence_str is a string with whitespace separated tokens
    """

    # simulated actions given by a parsing model
    assert sentence_str in self.actions_by_sentence, \
        "Fake parser has no actions for sentence: %s" % sentence_str
    actions = self.actions_by_sentence[sentence_str]
    tokens = sentence_str.split()

    # Initialize state machine
    if self.machine_type == 'AMR':
        state_machine = AMRStateMachine(
            tokens,
            actions_by_stack_rules=self.actions_by_stack_rules,
            spacy_lemmatizer=self.spacy_lemmatizer,
            entity_rules=self.entity_rules)
    elif self.machine_type == 'dep-parsing':
        state_machine = DepParsingStateMachine(tokens)

    # this will store the AMR parse as BIO tags (PRED, ADDNODE)
    bio_alignments = {}

    # execute parsing model
    while not state_machine.is_closed:

        # Print state (pause if solicited)
        self.logger.update(self.sent_idx, state_machine)

        if len(actions) <= state_machine.time_step:
            # if the machine is not properly closed, force a close
            print(yellow_font(
                f'machine not closed at step {state_machine.time_step}'
            ))
            raw_action = 'CLOSE'
        else:
            # get action from model
            raw_action = actions[state_machine.time_step]

        # restrict action space according to machine restrictions and
        # statistics
        if self.machine_type == 'AMR':
            raw_action = restrict_action(state_machine, raw_action,
                                         self.pred_counts,
                                         self.rule_violation)

        # update bio tags from AMR
        bio_alignments.update(
            get_bio_from_machine(state_machine, raw_action))

        # Update state machine
        state_machine.applyAction(raw_action)

    # build bio tags
    bio_tags = get_bio_tags(state_machine, bio_alignments)

    # count one more sentence
    self.sent_idx += 1

    return state_machine, bio_tags
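
# Illustrative sketch (not in the original code): the shape of
# self.actions_by_sentence consumed by parse_sentence above -- the
# whitespace-joined sentence string maps to the action sequence replayed by
# time step. The sentence and actions below are hypothetical.
example_actions_by_sentence = {
    'The boy wants to go': [
        'SHIFT', 'PRED(boy)', 'SHIFT', 'PRED(want-01)', 'LA(ARG0)', 'CLOSE'
    ]
}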
def runOracle(self, gold_amrs, propbank_args=None, out_oracle=None,
              out_amr=None, out_sentences=None, out_actions=None,
              out_rule_stats=None, add_unaligned=0,
              no_whitespace_in_actions=False, multitask_words=None,
              copy_lemma_action=False, addnode_count_cutoff=None):

    print_log("oracle", "Parsing data")

    # deep copy of gold AMRs
    self.gold_amrs = [gold_amr.copy() for gold_amr in gold_amrs]

    # report inconsistencies in annotations
    alert_inconsistencies(self.gold_amrs)

    # open all files (if paths provided) and get writers to them
    oracle_write = writer(out_oracle)
    amr_write = writer(out_amr)
    sentence_write = writer(out_sentences, add_return=True)
    actions_write = writer(out_actions, add_return=True)

    # This will store overall stats
    self.stats = {
        'possible_predicates': Counter(),
        'action_vocabulary': Counter(),
        'addnode_counts': Counter()
    }

    # unaligned tokens
    included_unaligned = [
        '-', 'and', 'multi-sentence', 'person', 'cause-01', 'you', 'more',
        'imperative', '1', 'thing',
    ]

    # initialize spacy lemmatizer outside of the sentence loop for speed
    spacy_lemmatizer = None
    if copy_lemma_action:
        spacy_lemmatizer = get_spacy_lemmatizer()

    # Loop over gold AMRs
    for sent_idx, gold_amr in tqdm(enumerate(self.gold_amrs),
                                   desc='computing oracle',
                                   total=len(self.gold_amrs)):

        if self.verbose:
            print("New Sentence " + str(sent_idx) + "\n\n\n")

        # TODO: Describe what this pre-processing does
        gold_amr = preprocess_amr(gold_amr, add_unaligned,
                                  included_unaligned)

        # Initialize state machine
        tr = AMRStateMachine(gold_amr.tokens,
                             verbose=self.verbose,
                             add_unaligned=add_unaligned,
                             spacy_lemmatizer=spacy_lemmatizer,
                             entity_rules=self.entity_rules)
        self.transitions.append(tr)
        self.amrs.append(tr.amr)

        # Loop over potential actions
        while tr.buffer or tr.stack:

            if self.tryMerge(tr, tr.amr, gold_amr):
                action = 'MERGE'

            elif self.tryEntity(tr, tr.amr, gold_amr):
                action = f'ADDNODE({self.entity_type})'

            elif self.tryDependent(tr, tr.amr, gold_amr):
                edge = self.new_edge[1:] \
                    if self.new_edge.startswith(':') else self.new_edge
                action = f'DEPENDENT({self.new_node},{edge})'
                self.dep_id = None

            elif self.tryConfirm(tr, tr.amr, gold_amr):
                # if --copy-lemma-action, check whether the lemma or its
                # first sense equals the node name and use the corresponding
                # COPY action; otherwise fall back to PRED
                if copy_lemma_action:
                    lemma, _ = tr.get_top_of_stack(lemma=True)
                    if lemma == self.new_node:
                        action = 'COPY_LEMMA'
                    elif f'{lemma}-01' == self.new_node:
                        action = 'COPY_SENSE01'
                    else:
                        action = f'PRED({self.new_node})'
                else:
                    action = f'PRED({self.new_node})'

            elif self.tryIntroduce(tr, tr.amr, gold_amr):
                action = 'INTRODUCE'

            elif self.tryLA(tr, tr.amr, gold_amr):
                if self.new_edge == 'root':
                    action = f'LA({self.new_edge})'
                else:
                    action = f'LA({self.new_edge[1:]})'

            elif self.tryRA(tr, tr.amr, gold_amr):
                if self.new_edge == 'root':
                    action = f'RA({self.new_edge})'
                else:
                    action = f'RA({self.new_edge[1:]})'

            elif self.tryReduce(tr, tr.amr, gold_amr):
                action = 'REDUCE'

            elif self.trySWAP(tr, tr.amr, gold_amr):
                action = 'UNSHIFT'

            elif tr.buffer:
                action = 'SHIFT'

            else:
                tr.stack = []
                tr.buffer = []
                break

            # Store stats
            # get token(s) at the top of the stack
            token, merged_tokens = tr.get_top_of_stack()
            action_label = action.split('(')[0]

            # check the action has no invalid chars and normalize
            # TODO: --no-whitespace-in-actions being deprecated
            if no_whitespace_in_actions and action_label == 'PRED':
                assert '_' not in action, \
                    "--no-whitespace-in-actions prohibits use of _ in actions"
                if ' ' in action_label:
                    action = action.replace(' ', '_')

            # Add prediction to top of the buffer
            if action == 'SHIFT' and multitask_words is not None:
                action = label_shift(tr, multitask_words)

            # APPLY ACTION
            tr.applyAction(action)

        # Close machine
        tr.CLOSE(training=True, gold_amr=gold_amr,
                 use_addnonde_rules=use_addnode_rules)

        # update files
        if out_oracle:
            # to avoid printing
            oracle_write(str(tr))
        # JAMR format AMR
        amr_write(tr.amr.toJAMRString())
        # Tokens and actions
        # extra tag to be reduced at start
        tokens = tr.amr.tokens
        actions = tr.actions

        # Update action count
        self.stats['action_vocabulary'].update(actions)
        del gold_amr.nodes[-1]
        addnode_actions = [a for a in actions if a.startswith('ADDNODE')]
        self.stats['addnode_counts'].update(addnode_actions)

        # separator
        if no_whitespace_in_actions:
            sep = " "
        else:
            sep = "\t"
        tokens = sep.join(tokens)
        actions = sep.join(actions)

        # Write
        sentence_write(tokens)
        actions_write(actions)

    print_log("oracle", "Done")

    # close files if open
    oracle_write()
    amr_write()
    sentence_write()
    actions_write()

    self.labelsO2idx = {'<pad>': 0}
    self.labelsA2idx = {'<pad>': 0}
    self.pred2idx = {'<pad>': 0}
    self.action2idx = {'<pad>': 0}

    for tr in self.transitions:
        for a in tr.actions:
            a = AMRStateMachine.readAction(a)[0]
            self.action2idx.setdefault(a, len(self.action2idx))
        for p in tr.predicates:
            self.pred2idx.setdefault(p, len(self.pred2idx))
        for l in tr.labels:
            self.labelsO2idx.setdefault(l, len(self.labelsO2idx))
        for l in tr.labelsA:
            self.labelsA2idx.setdefault(l, len(self.labelsA2idx))

    self.stats["action2idx"] = self.action2idx
    self.stats["pred2idx"] = self.pred2idx
    self.stats["labelsO2idx"] = self.labelsO2idx
    self.stats["labelsA2idx"] = self.labelsA2idx

    # Compute the word dictionary
    self.char2idx = {'<unk>': 0}
    self.word2idx = {'<unk>': 0, '<eof>': 1, '<ROOT>': 2, '<unaligned>': 3}
    self.node2idx = {}
    self.word_counter = Counter()

    for amr in self.gold_amrs:
        for tok in amr.tokens:
            self.word_counter[tok] += 1
            self.word2idx.setdefault(tok, len(self.word2idx))
            for ch in tok:
                self.char2idx.setdefault(ch, len(self.char2idx))
        for n in amr.nodes:
            self.node2idx.setdefault(amr.nodes[n], len(self.node2idx))

    self.stats["char2idx"] = self.char2idx
    self.stats["word2idx"] = self.word2idx
    self.stats["node2idx"] = self.node2idx
    self.stats["word_counter"] = self.word_counter

    self.stats['possible_predicates'] = self.possiblePredicates

    if addnode_count_cutoff:
        self.stats['addnode_blacklist'] = [
            a for a, c in self.stats['addnode_counts'].items()
            if c <= addnode_count_cutoff
        ]
        num_addnode_blackl = len(self.stats['addnode_blacklist'])
        num_addnode = len(self.stats['addnode_counts'])
        print(f'{num_addnode_blackl}/{num_addnode} blacklisted ADDNODES')
        del self.stats['addnode_counts']

    # State machine stats for this sentence
    if out_rule_stats:
        with open(out_rule_stats, 'w') as fid:
            fid.write(json.dumps(self.stats))
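
# Illustrative sketch (not in the original code): rough shape of the JSON
# written to out_rule_stats by runOracle above. Keys follow the code; the
# example values are hypothetical.
# {
#   "possible_predicates": {"wants": {"want-01": 12}},
#   "action_vocabulary": {"SHIFT": 1034, "PRED(want-01)": 12},
#   "addnode_counts": {"ADDNODE(person)": 40},   # removed if a cutoff blacklist is built
#   "action2idx": {"<pad>": 0, "SHIFT": 1},
#   "pred2idx": {"<pad>": 0, "want-01": 1},
#   "labelsO2idx": {"<pad>": 0, "ARG0": 1},
#   "labelsA2idx": {"<pad>": 0, "person": 1},
#   "char2idx": {"<unk>": 0, "T": 1},
#   "word2idx": {"<unk>": 0, "<eof>": 1, "<ROOT>": 2, "<unaligned>": 3},
#   "node2idx": {"want-01": 0},
#   "word_counter": {"the": 57},
#   "addnode_blacklist": ["ADDNODE(person)"]     # only with addnode_count_cutoff
# }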