def __init__(self, src_dict, tgt_dict, machine_type, machine_rules=None,
             entity_rules=None):

    # Get all actions indexed by prefix
    self.action_indexer = get_action_indexer(tgt_dict.symbols)

    # Load rule stats if provided
    if machine_rules is not None:
        assert os.path.isfile(machine_rules)
        rule_stats = read_rule_stats(machine_rules)
        self.get_new_state_machine = machine_generator(
            rule_stats['possible_predicates'],
            entity_rules=entity_rules
        )
    else:
        assert machine_type != 'AMR', \
            "AMR machine expects --machine-rules"
        rule_stats = None
        self.get_new_state_machine = machine_generator(None)

    # store some variables
    self.src_dict = src_dict
    self.tgt_dict = tgt_dict
    self.machine_type = machine_type
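# Hedged usage sketch for the constructor above. `TaskWithStateMachine` is a
# hypothetical name for the enclosing class and the file path is illustrative;
# fairseq-style dictionaries expose a `.symbols` list, which is all this
# constructor relies on.
def _example_constructor_usage(src_dict, tgt_dict):
    task = TaskWithStateMachine(  # hypothetical class name
        src_dict,
        tgt_dict,
        machine_type='AMR',
        machine_rules='oracle/train.rules.json',  # illustrative path
        entity_rules=None,
    )
    # machine_generator() returned a factory: one initialized state machine
    # per tokenized sentence
    machine = task.get_new_state_machine(
        'The boy wants to go'.split(),
        machine_type=task.machine_type,
    )
    return machine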
def main():

    # Argument handling
    args = argument_parser()

    # Read
    # Load AMR (replace some unicode characters)
    if args.in_amr:
        corpus = read_amr(args.in_amr, unicode_fixes=True)
        amrs = corpus.amrs
    # Load tokens
    if args.in_tokens:
        sentences = read_tokenized_sentences(args.in_tokens, separator='\t')
    # Load actions, i.e. the oracle
    if args.in_actions:
        actions = read_tokenized_sentences(args.in_actions, separator='\t')
    # Load scored actions, i.e. the mined oracle
    if args.in_scored_actions:
        scored_actions = read_action_scores(args.in_scored_actions)
        # measure performance
        print_score_action_stats(scored_actions)
    # Load rule stats
    if args.in_rule_stats:
        rule_stats = read_rule_stats(args.in_rule_stats)

    # Modify
    # merge --in-actions and --in-scored-actions and store in --out-actions
    if args.merge_mined:

        # sanity checks
        assert args.in_tokens, "--merge-mined requires --in-tokens"
        assert args.in_actions, "--merge-mined requires --in-actions"
        assert args.in_rule_stats, "--merge-mined requires --in-rule-stats"
        assert args.out_rule_stats, "--merge-mined requires --out-rule-stats"
        if args.in_actions:
            assert len(actions) == len(scored_actions)
        print(f'Merging {args.in_actions} and {args.in_scored_actions}')

        # actions
        actions = merge_actions(actions, scored_actions)

    # fix actions split by whitespace arguments
    if args.fix_actions:
        actions = fix_actions_split_by_spaces(actions)

    # merge rules
    if args.merge_mined:
        out_rule_stats = merge_rules(sentences, actions, rule_stats,
                                     entity_rules=args.entity_rules)
        print(f'Merging {args.in_rule_stats} into {args.out_rule_stats}')

    # Write
    # actions
    if args.out_actions:
        dirname = os.path.dirname(args.out_actions)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        write_tokenized_sentences(
            args.out_actions,
            actions,
            separator='\t'
        )
        print(f'Wrote {args.out_actions}')

    # rule stats
    if args.out_rule_stats:
        write_rule_stats(args.out_rule_stats, out_rule_stats)
        print(f'Wrote {args.out_rule_stats}')

    # AMR
    if args.out_amr:
        with open(args.out_amr, 'w') as fid:
            for amr in amrs:
                fid.write(amr.toJAMRString())
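# Example invocation of this script for the --merge-mined path. The script
# name and file paths are illustrative; the flag names follow the argparse
# attributes and assertion messages above:
#
#   python merge_oracle_actions.py \
#       --in-tokens oracle/train.en \
#       --in-actions oracle/train.actions \
#       --in-scored-actions mined/train.scored_actions \
#       --in-rule-stats oracle/train.rules.json \
#       --merge-mined \
#       --out-actions oracle/train.merged.actions \
#       --out-rule-stats oracle/train.merged.rules.json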
def make_binary_stack(args, target_vocab, input_prefix, output_prefix,
                      eos_idx, pad_idx, mask_predicates=False,
                      allow_unk=False, tokenize=None):

    assert tokenize

    # involved files
    input_sentences = input_prefix + '.en'
    input_actions = input_prefix + '.actions'

    # The AMR state machine always expects rules
    if args.machine_type == 'AMR':
        assert args.machine_rules and os.path.isfile(args.machine_rules), \
            f'Missing {args.machine_rules}'
        # Read rules
        train_rule_stats = read_rule_stats(args.machine_rules)
        actions_by_stack_rules = train_rule_stats['possible_predicates']
    else:
        actions_by_stack_rules = None
    action_indexer = get_action_indexer(target_vocab.symbols)

    # initialize indices for each of the variables
    # memory (stack, buffer, dead) (position in memory)
    stack_buffer_names = ['memory', 'memory_pos']
    # FIXME: These values are hard-coded elsewhere in the code
    state_indices = [3, 4, 5]
    assert eos_idx not in state_indices, "Cannot reuse EOS index"
    assert pad_idx not in state_indices, "Cannot reuse PAD index"
    indexed_data = {}
    for name in stack_buffer_names:
        indexed_data[name] = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, name, "bin"),
            impl=args.dataset_impl,
        )
    if mask_predicates:
        # mask of target predictions
        masks_path = dataset_dest_file(args, output_prefix, 'target_masks', "bin")
        indexed_target_masks = indexed_dataset.make_builder(
            masks_path,
            impl=args.dataset_impl,
        )
        # active indices
        active_logits_path = dataset_dest_file(args, output_prefix, 'active_logits', "bin")
        indexed_active_logits = indexed_dataset.make_builder(
            active_logits_path,
            impl=args.dataset_impl,
        )

    # Returns a function that generates initialized state machines for a
    # given sentence
    get_new_state_machine = machine_generator(actions_by_stack_rules,
                                              entity_rules=args.entity_rules)

    num_sents = 0
    missing_actions = Counter()
    with open(input_actions, 'r') as fid_act, \
            open(input_sentences, 'r') as fid_sent:

        # Loop over sentences
        for sentence in tqdm(fid_sent):

            # Get actions and tokens
            sent_tokens = tokenize(sentence)
            sent_actions = tokenize(fid_act.readline())

            # initialize a state machine batch of size 1
            state_machine = get_new_state_machine(
                sent_tokens,
                machine_type=args.machine_type
            )

            # collect target and source masks
            sent_data = {}
            for name in stack_buffer_names:
                sent_data[name] = []

            shape = (len(sent_actions), len(target_vocab.symbols))
            logits_mask = np.zeros(shape)
            active_logits = set()
            for action_idx, gold_action in enumerate(sent_actions):

                # active logits for this action
                if mask_predicates:

                    # Get total valid actions by expanding base ones
                    valid_actions, invalid_actions = state_machine.get_valid_actions()
                    valid_action_idx = (
                        action_indexer(valid_actions) -
                        action_indexer(invalid_actions)
                    )

                    # if the gold action is missing, add it and count it
                    if gold_action in target_vocab.symbols:
                        gold_action_index = target_vocab.symbols.index(gold_action)
                    else:
                        gold_action_index = target_vocab.symbols.index('<unk>')
                    if gold_action_index not in valid_action_idx:
                        valid_action_idx.add(gold_action_index)
                        missing_actions.update([gold_action])

                    # if length 1, add pad to avoid deltas during learning
                    if len(valid_action_idx) == 1:
                        valid_action_idx.add(pad_idx)

                    # activate the valid action columns for this time step
                    logits_mask[action_idx, list(valid_action_idx)] = 1
                    active_logits |= valid_action_idx

                # stack and buffer
                memory, memory_pos = get_word_states(
                    state_machine,
                    sent_tokens,
                    indices=state_indices
                )

                # word states
                sent_data['memory'].append(torch.Tensor(memory))
                # note we use position 0 for reduced words
                sent_data['memory_pos'].append(
                    torch.Tensor(memory_pos)
                )

                # Update machine
                state_machine.applyAction(gold_action)

            for name in stack_buffer_names:
                # note that the data needs to be stored as a 1d array
                indexed_data[name].add_item(
                    torch.stack(sent_data[name]).view(-1)
                )

            # valid nodes
            if mask_predicates:
                active_logits = list(active_logits)
                # reduce size to active items
                logits_mask = logits_mask[:, active_logits]
                indexed_target_masks.add_item(
                    torch.Tensor(logits_mask).view(-1)
                )
                # active indices
                indexed_active_logits.add_item(torch.Tensor(
                    active_logits
                ))

            # update number of sentences
            num_sents += 1
            if num_sents % 100 == 0:
                print("\r%d sentences" % num_sents, end='')
        print("")

    # close indexed data files
    for name in stack_buffer_names:
        output_file_idx = dataset_dest_file(args, output_prefix, name, "idx")
        indexed_data[name].finalize(output_file_idx)
    # close valid action mask
    if mask_predicates:
        target_mask_idx = dataset_dest_file(args, output_prefix, 'target_masks', "idx")
        indexed_target_masks.finalize(target_mask_idx)
        # active indices
        active_logits_idx = dataset_dest_file(args, output_prefix, 'active_logits', "idx")
        indexed_active_logits.finalize(active_logits_idx)

    # inform about missing actions
    if missing_actions:
        print(yellow_font("There were missing actions"))
        print(missing_actions)
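# Sketch of how the compressed mask written above could be expanded back to
# full vocabulary size at training time. This loader-side helper is an
# assumption (the repo's actual dataset code lives elsewhere): the mask was
# stored with only its active columns, together with the active vocabulary
# indices needed to scatter it back.
def _example_expand_logits_mask(logits_mask_flat, active_logits, num_actions,
                                vocab_size):
    import torch
    # recover the (num_actions, num_active) matrix from flat 1d storage
    reduced = logits_mask_flat.view(num_actions, -1)
    full = torch.zeros(num_actions, vocab_size)
    # scatter reduced columns back to their original vocabulary positions
    full[:, active_logits.long()] = reduced
    return full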
def main():

    # Argument handling
    args = argument_parser()

    # read rules
    train_rule_stats = read_rule_stats(args.in_rule_stats)
    assert 'action_vocabulary' in train_rule_stats
    assert 'possible_predicates' in train_rule_stats
    action_list = list(sorted(train_rule_stats['action_vocabulary'].keys()))

    # get all actions indexed by action root
    action_by_basic = defaultdict(list)
    for action in train_rule_stats['action_vocabulary'].keys():
        key = action.split('(')[0]
        action_by_basic[key].append(action)

    # open files for writing if provided
    write_out_states = h5_writer(args.out_word_states)
    write_out_valid_actions = h5_writer(args.out_valid_actions)

    # Read content
    # TODO: Point to LDC data
    sentences = readlines(args.in_sentences)
    actions = readlines(args.in_actions)
    assert len(sentences) == len(actions)

    # initialize spacy lemmatizer outside of the sentence loop for speed
    spacy_lemmatizer = get_spacy_lemmatizer()

    sent_idx = -1
    stats = {
        'missing_pred_count': Counter(),
        'missing_action_count': Counter(),
        'fan_out_count': Counter()
    }
    for sent_tokens, sent_actions in tqdm(
        zip(sentences, actions),
        desc='extracting oracle masks',
        total=len(actions)
    ):

        # keep count of sentence index
        sent_idx += 1
        if args.offset and sent_idx < args.offset:
            continue

        # Initialize state machine
        amr_state_machine = AMRStateMachine(
            sent_tokens,
            spacy_lemmatizer=spacy_lemmatizer
        )

        # process each action
        word_states_sent = []
        valid_actions_sent = []
        for raw_action in sent_actions:

            # Store states BEFORE ACTION
            # state of each word (buffer B, stack S, reduced X)
            word_states = get_word_states(amr_state_machine, sent_tokens)

            # Get actions valid for this state
            valid_actions = get_valid_actions(
                action_list,
                amr_state_machine,
                train_rule_stats,
                action_by_basic,
                raw_action,
                stats
            )

            # update info
            word_states_sent.append(word_states)
            valid_actions_sent.append(valid_actions)

            # Update machine
            amr_state_machine.applyAction(raw_action)

        # Write states for this sentence
        write_out_states(word_states_sent)
        write_out_valid_actions(valid_actions_sent)

    # inform user about missing predicates and actions
    for miss in ['missing_pred_count', 'missing_action_count']:
        num_missing = len(stats[miss])
        if num_missing:
            alert_str = f'{num_missing} {miss} missing from rule_stats'
            print(yellow_font(alert_str))

    # inform user about fan-out stats
    mode_fan_out = stats['fan_out_count'].most_common(1)[0][0]
    max_fan_out = max(stats['fan_out_count'].keys())
    alert_str = f'num_actions mode: {mode_fan_out} max: {max_fan_out}'
    print(alert_str)

    # Close files
    write_out_states()
    write_out_valid_actions()
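# Worked example of the action_by_basic grouping built above: actions are
# grouped by their base name (the prefix before the first parenthesis), so a
# basic action can be expanded into all variants seen in training.
#
#   >>> from collections import defaultdict
#   >>> vocab = ['SHIFT', 'PRED(want-01)', 'PRED(boy)', 'LA(ARG0)']
#   >>> action_by_basic = defaultdict(list)
#   >>> for action in vocab:
#   ...     action_by_basic[action.split('(')[0]].append(action)
#   >>> action_by_basic['PRED']
#   ['PRED(want-01)', 'PRED(boy)']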
def main():

    # Argument handling
    args = argument_parser()

    # Get data
    sentences = read_tokenized_sentences(args.in_sentences,
                                         separator=args.separator)

    # Initialize logger/printer
    logger = Logger(
        step_by_step=args.step_by_step,
        clear_print=args.clear_print,
        pause_time=args.pause_time,
        verbose=args.verbose
    )

    # generate rules to restrict the action space by stack content
    if args.action_rules_from_stats:
        rule_stats = read_rule_stats(args.action_rules_from_stats)
        actions_by_stack_rules = rule_stats['possible_predicates']
        for token, counter in rule_stats['possible_predicates'].items():
            actions_by_stack_rules[token] = Counter(counter)
    else:
        actions_by_stack_rules = None

    # Fake parser built from actions
    actions = read_tokenized_sentences(args.in_actions,
                                       separator=args.separator)
    assert len(sentences) == len(actions)
    parsing_model = FakeAMRParser(
        from_sent_act_pairs=zip(sentences, actions),
        machine_type=args.machine_type,
        logger=logger,
        actions_by_stack_rules=actions_by_stack_rules,
        no_whitespace_in_actions=args.no_whitespace_in_actions,
        entity_rules=args.entity_rules
    )

    # Get output writers
    if args.out_amr:
        amr_write = writer(args.out_amr)
    if args.out_bio_tags:
        bio_write = writer(args.out_bio_tags)

    # Loop over sentences
    for sent_idx, tokens in tqdm(enumerate(sentences), desc='parsing'):

        # fast-forward until the desired sentence number
        if args.offset and sent_idx < args.offset:
            continue

        # parse
        # NOTE: To simulate the real endpoint, input is provided as a string
        # of whitespace-separated tokens
        machine, bio_tags = parsing_model.parse_sentence(" ".join(tokens))

        # sanity check annotations
        dupes = get_duplicate_edges(machine.amr)
        if args.sanity_check and any(dupes):
            msg = yellow_font('WARNING:')
            print(f'{msg} duplicated edges in sent {sent_idx}', end=' ')
            print(dict(dupes))
            print(' '.join(machine.tokens))

        # store output BIO tags
        if args.out_bio_tags:
            tag_str = '\n'.join([f'{to} {ta}' for to, ta in bio_tags])
            tag_str += '\n\n'
            bio_write(tag_str)

        # store output AMR
        if args.out_amr:
            try:
                amr_write(machine.amr.toJAMRString())
            except InvalidAMRError as exception:
                print(f'\nFailed at sentence {sent_idx}\n')
                raise exception

    if getattr(parsing_model, "rule_violation", None):
        print(yellow_font("There were one or more action rule violations"))
        print(parsing_model.rule_violation)

    if args.action_rules_from_stats:
        print("Predict rules had the following statistics")
        print(parsing_model.pred_counts)

    # close output writers
    if args.out_amr:
        amr_write()
    if args.out_bio_tags:
        bio_write()
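# Example invocation of this oracle playback script. The script name and
# file paths are illustrative; flag names follow the argparse attributes
# used above:
#
#   python play_oracle.py \
#       --in-sentences oracle/dev.en \
#       --in-actions oracle/dev.actions \
#       --machine-type AMR \
#       --action-rules-from-stats oracle/train.rules.json \
#       --out-amr oracle/dev.oracle.amr \
#       --sanity-check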