def infer(self, examples):
    # currently uses beam search as the sampling method
    # set the model to evaluation mode for beam search, so that dropout behaves properly!
    was_training = self.encoder.training
    self.encoder.eval()
    hypotheses = [self.encoder.parse(e.src_sent, beam_size=self.args.sample_size) for e in examples]
    if len(hypotheses) == 0:
        raise ValueError('No candidate hypotheses.')

    if was_training:
        self.encoder.train()

    # some sources may not have corresponding samples, so we only retain those that have sampled logical forms
    sampled_examples = []
    for e_id, (example, hyps) in enumerate(zip(examples, hypotheses)):
        for hyp_id, hyp in enumerate(hyps):
            try:
                code = self.transition_system.ast_to_surface_code(hyp.tree)
                self.transition_system.tokenize_code(code)  # make sure the code is tokenizable!

                sampled_example = Example(idx='%d-sample%d' % (example.idx, hyp_id),
                                          src_sent=example.src_sent,
                                          tgt_code=code,
                                          tgt_actions=hyp.action_infos,
                                          tgt_ast=hyp.tree)
                sampled_examples.append(sampled_example)
            except:
                print("Exception in converting tree to code:", file=sys.stdout)
                print('-' * 60, file=sys.stdout)
                traceback.print_exc(file=sys.stdout)
                print('-' * 60, file=sys.stdout)

    sample_scores, enc_states = self.encoder.score(sampled_examples, return_enc_state=True)

    return sampled_examples, sample_scores, enc_states

def load_dataset(transition_system, dataset_file, reorder_predicates=True):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        src_query, tgt_code = line.strip().split('\t')

        src_query_tokens = src_query.split(' ')

        lf = parse_lambda_expr(tgt_code)
        assert lf.to_string() == tgt_code

        if reorder_predicates:
            ordered_lf = get_canonical_order_of_logical_form(lf, order_by='alphabet')
            assert ordered_lf == lf
            lf = ordered_lf

        gold_source = lf.to_string()

        tgt_ast = logical_form_to_ast(grammar, lf)
        reconstructed_lf = ast_to_logical_form(tgt_ast)
        assert lf == reconstructed_lf

        tgt_actions = transition_system.get_actions(tgt_ast)

        print(idx)
        print('Utterance: %s' % src_query)
        print('Reference: %s' % tgt_code)
        # print('===== Actions =====')

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)
            # print(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None

        src_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert src_from_hyp == gold_source

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        # print(' '.join(src_query_tokens))
        print('***')
        print(lf.to_string())
        print()

        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=gold_source,
                          tgt_ast=tgt_ast,
                          meta=None)

        examples.append(example)

    return examples

def infer_backward(self, examples, sample_size):
    # get samples of p_{\theta}(x|z) in evaluation mode
    was_training = self.decoder.training
    self.decoder.eval()
    hypotheses = [self.decoder.sample(e.tgt_code, cuda=self.args.cuda) for e in examples]
    if len(hypotheses) == 0:
        raise ValueError('No candidate hypotheses.')

    if was_training:
        self.decoder.train()

    sampled_examples = []
    picked_example = []
    for e_id, (example, hyps) in enumerate(zip(examples, hypotheses)):
        if len(hyps) != sample_size:
            print('invalid sample size: expected %d but got %d' % (sample_size, len(hyps)))
        samples_temp = []
        for hyp_id, hyp in enumerate(hyps):
            if len(hyp) == 0:  # generated x is empty
                continue
            try:
                sampled_example = Example(idx='%d-resample%d' % (example.idx, hyp_id),
                                          src_sent=hyp,
                                          tgt_code=example.tgt_code,
                                          tgt_actions=example.tgt_actions,
                                          tgt_ast=example.tgt_ast)
                samples_temp.append(sampled_example)
            except:
                pass

        if len(samples_temp) < sample_size / 2.0:
            # too few valid samples, skip this example
            continue
        if len(samples_temp) < sample_size:
            # fewer valid samples than the sample size, pad by repeating valid samples
            samples_temp += samples_temp[:sample_size - len(samples_temp)]
        assert len(samples_temp) == sample_size

        # initialize cached samples with the first sample
        if hasattr(example, 'recache'):
            sample_final = [example.recache]
        else:
            sample_final = samples_temp[:1]
        sample_final += samples_temp

        sampled_examples += sample_final
        picked_example.append(example)

    # sort samples by descending source length; back_index maps the original order into the sorted batch
    index = sorted(range(len(sampled_examples)), key=lambda i: -len(sampled_examples[i].src_sent))
    back_index = [(index[i], i) for i in index]
    back_index.sort(key=lambda x: x[0])
    assert [index[x[1]] for x in back_index] == list(range(len(sampled_examples)))
    back_index = np.array([x[1] for x in back_index])
    sampled_examples = [sampled_examples[i] for i in index]

    sample_scores = self.decoder.score(sampled_examples)

    return sampled_examples, sample_scores, back_index, picked_example

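# Illustrative sketch (not part of the original code): infer_backward sorts the sampled examples by
# descending source length before scoring, and back_index lets the caller map scores computed on the
# sorted batch back to the original sample order. The toy lengths below are made up purely to
# demonstrate the invariant sorted_values[back_index[j]] == original_values[j]; np.argsort(index)
# computes the same inverse permutation as the tuple-sorting code above.
import numpy as np


def _back_index_demo():
    lengths = [2, 5, 3, 4]  # hypothetical src_sent lengths of four sampled examples
    index = sorted(range(len(lengths)), key=lambda i: -lengths[i])  # e.g. [1, 3, 2, 0]
    back_index = np.argsort(index)  # position of each original item inside the sorted batch
    sorted_lengths = [lengths[i] for i in index]
    # restoring the original order from the length-sorted order
    assert [sorted_lengths[back_index[j]] for j in range(len(lengths))] == lengths
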
def infer(self, examples, sample_size):
    # get samples of q_{\phi}(z|x) in evaluation mode
    was_training = self.encoder.training
    self.encoder.eval()
    hypotheses = [self.encoder.parse_sample(e.src_sent) for e in examples]
    if len(hypotheses) == 0:
        raise ValueError('No candidate hypotheses.')

    if was_training:
        self.encoder.train()

    # some sources may not have corresponding samples, so we only retain those that have sampled logical forms
    sampled_examples = []
    picked_example = []
    for e_id, (example, hyps) in enumerate(zip(examples, hypotheses)):
        if len(hyps) != sample_size:
            print('invalid sample size: expected %d but got %d' % (sample_size, len(hyps)))
        samples_temp = []
        for hyp_id, hyp in enumerate(hyps):
            try:
                code = self.transition_system.ast_to_surface_code(hyp.tree)
                self.transition_system.tokenize_code(code)  # make sure the code is tokenizable!
                if len(code) == 0:
                    continue

                sampled_example = Example(idx='%d-sample%d' % (example.idx, hyp_id),
                                          src_sent=example.src_sent,
                                          tgt_code=code,
                                          tgt_actions=hyp.action_infos,
                                          tgt_ast=hyp.tree)
                samples_temp.append(sampled_example)
            except:
                pass

        if len(samples_temp) < sample_size / 2.0:
            # too few valid samples, skip this example
            continue
        if len(samples_temp) < sample_size:
            # fewer valid samples than the sample size, pad by repeating valid samples
            samples_temp += samples_temp[:sample_size - len(samples_temp)]
        assert len(samples_temp) == sample_size

        # initialize cached samples with the first sample
        if hasattr(example, 'cache'):
            sample_final = [example.cache]
        else:
            sample_final = samples_temp[:1]
        sample_final += samples_temp

        sampled_examples += sample_final
        picked_example.append(example)

    sample_scores = self.encoder.score(sampled_examples)

    return sampled_examples, sample_scores, picked_example

def load_regex_dataset(transition_system, split):
    prefix = 'data/regex/'
    src_file = join(prefix, "src-{}.txt".format(split))
    spec_file = join(prefix, "spec-{}.txt".format(split))

    examples = []
    for idx, (src_line, spec_line) in enumerate(zip(open(src_file), open(spec_file))):
        print(idx)

        src_line = src_line.rstrip()
        spec_line = spec_line.rstrip()
        src_toks = src_line.split()
        spec_toks = spec_line.rstrip().split()
        spec_ast = regex_expr_to_ast(transition_system.grammar, spec_toks)

        # sanity check
        reconstructed_expr = transition_system.ast_to_surface_code(spec_ast)
        print(spec_line, reconstructed_expr)
        assert spec_line == reconstructed_expr

        tgt_actions = transition_system.get_actions(spec_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None
        assert is_equal_ast(hyp.tree, spec_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert expr_from_hyp == spec_line

        tgt_action_infos = get_action_infos(src_toks, tgt_actions)

        example = Example(idx=idx,
                          src_sent=src_toks,
                          tgt_actions=tgt_action_infos,
                          tgt_code=spec_line,
                          tgt_ast=spec_ast,
                          meta=None)

        examples.append(example)

    return examples

def load_dataset(transition_system, path, num, reorder_predicates=True):
    grammar = transition_system.grammar

    examples = []
    pre_len = 0
    if os.path.exists('data/pdf/train.bin'):
        examples = pickle.load(open('data/pdf/train.bin', 'rb'))
        pre_len = len(examples)

    idx = 0
    for item in os.listdir(path):
        item_path = os.path.join(path, item)
        print(item)
        try:
            pdf = PdfReader(item_path)
        except:
            continue

        for page in pdf.pages:
            idx += 1
            if idx <= pre_len:
                continue
            print(idx)

            try:
                tgt_ast = pdf_to_ast(grammar, page, [])
            except:
                continue

            tgt_actions = transition_system.get_actions(tgt_ast)

            """
            hyp = Hypothesis()
            for action in tgt_actions:
                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(hyp)
                hyp = hyp.clone_and_apply_action(action)

            assert hyp.frontier_node is None and hyp.frontier_field is None
            """

            tgt_action_infos = get_action_infos(tgt_actions)

            example = Example(idx=idx,
                              tgt_actions=tgt_action_infos,
                              meta=None)

            examples.append(example)

            if idx >= num:
                break

    return examples

def load_dataset(transition_system, dataset_file):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        print(line)
        src_query, tgt_code = line.strip().split('~')

        tgt_code = tgt_code.replace("(", " ( ")
        tgt_code = tgt_code.replace(")", " ) ")
        tgt_code = " ".join(tgt_code.split())

        src_query = src_query.replace("(", "")
        src_query = src_query.replace(")", "")
        src_query_tokens = src_query.split(' ')

        tgt_ast = lisp_expr_to_ast(transition_system.grammar, tgt_code)
        reconstructed_lisp_expr = ast_to_lisp_expr(tgt_ast)
        assert tgt_code == reconstructed_lisp_expr

        tgt_actions = transition_system.get_actions(tgt_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None
        assert is_equal_ast(hyp.tree, tgt_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert expr_from_hyp == tgt_code

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)
        print(idx)

        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=tgt_code,
                          tgt_ast=tgt_ast,
                          meta=None)

        examples.append(example)

    return examples

def load_dataset(transition_system, dataset_file):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        src_query, tgt_code = line.strip().split('\t')

        src_query_tokens = src_query.split(' ')

        lf = parse_lambda_expr(tgt_code)
        gold_source = lf.to_string()
        assert gold_source == tgt_code

        tgt_ast = logical_form_to_ast(grammar, lf)
        reconstructed_lf = ast_to_logical_form(tgt_ast)
        assert lf == reconstructed_lf

        tgt_actions = transition_system.get_actions(tgt_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None

        src_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert src_from_hyp == gold_source

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)
        print(idx)

        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=gold_source,
                          tgt_ast=tgt_ast,
                          meta=None)

        examples.append(example)

    return examples

def initialize_rerank_features(self, examples, decode_results):
    hyp_examples = []
    # print('initializing features...', file=sys.stderr)
    for example, hyps in zip(examples, decode_results):
        for hyp_id, hyp in enumerate(hyps):
            hyp_example = Example(idx=None,
                                  src_sent=example.src_sent,
                                  tgt_code=hyp.code,
                                  tgt_actions=None,
                                  tgt_ast=None)
            hyp_examples.append(hyp_example)

            # hyp.tokenized_code = len(self.transition_system.tokenize_code(hyp.code))
            # hyp.code_token_count = len(hyp.code.split(' '))

            feat_vals = OrderedDict()
            hyp.rerank_feature_values = feat_vals

    for batch_examples in utils.batch_iter(hyp_examples, batch_size=128):
        for feat_name, feat in self.batched_features.items():
            batch_example_scores = feat.score(batch_examples).data.cpu().tolist()
            for i, e in enumerate(batch_examples):
                setattr(e, feat_name, batch_example_scores[i])

    e_ptr = 0
    for example, hyps in zip(examples, decode_results):
        for hyp in hyps:
            for feat_name, feat in self.batched_features.items():
                hyp.rerank_feature_values[feat_name] = getattr(hyp_examples[e_ptr], feat_name)
            e_ptr += 1

    for example, hyps in zip(examples, decode_results):
        for hyp_id, hyp in enumerate(hyps):
            for feat_name, feat in self.feat_map.items():
                if not feat.is_batched:
                    feat_val = feat.get_feat_value(example, hyp,
                                                   hyp_id=hyp_id,
                                                   all_hyps=hyps,
                                                   transition_system=self.transition_system)
                    hyp.rerank_feature_values[feat_name] = feat_val

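# Illustrative sketch (not part of the original code): once rerank_feature_values has been populated
# by initialize_rerank_features, a reranker would typically combine the per-hypothesis feature values
# into a single score and re-sort each beam by it. The helper name and the `feat_weights` mapping
# below are hypothetical; the actual reranker in this codebase may combine features differently.
def rerank_by_weighted_sum(decode_results, feat_weights):
    """Re-sort each hypothesis list by a weighted sum of its rerank feature values."""
    reranked = []
    for hyps in decode_results:
        scored = [(sum(feat_weights.get(name, 0.0) * val
                       for name, val in hyp.rerank_feature_values.items()), hyp)
                  for hyp in hyps]
        scored.sort(key=lambda t: -t[0])  # higher combined score first
        reranked.append([hyp for _, hyp in scored])
    return reranked
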
def load_dataset(split, transition_system):
    prefix = 'data/turk/'
    src_file = join(prefix, "src-{}.txt".format(split))
    spec_file = join(prefix, "spec-{}.txt".format(split))

    examples = []
    for idx, (src_line, spec_line) in enumerate(zip(open(src_file), open(spec_file))):
        print(idx)

        src_line = src_line.rstrip()
        spec_line = spec_line.rstrip()
        src_toks = src_line.split()
        spec_toks = spec_line.rstrip().split()
        spec_ast = regex_expr_to_ast(transition_system.grammar, spec_toks)

        # sanity check
        reconstructed_expr = transition_system.ast_to_surface_code(spec_ast)
        print(spec_line, reconstructed_expr)
        assert spec_line == reconstructed_expr

        tgt_action_tree = transition_system.get_action_tree(spec_ast)

        # sanity check
        ast_from_action = transition_system.build_ast_from_actions(tgt_action_tree)
        assert is_equal_ast(ast_from_action, spec_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(ast_from_action)
        assert expr_from_hyp == spec_line

        # sanity check
        # tgt_action_infos = get_action_infos(src_toks, tgt_actions)

        example = Example(idx=idx,
                          src_toks=src_toks,
                          tgt_actions=tgt_action_tree,
                          tgt_toks=spec_toks,
                          tgt_ast=spec_ast,
                          meta=None)

        examples.append(example)

    return examples

def parse(self, utterance, debug=False):
    utterance = utterance.strip()
    processed_utterance_tokens, utterance_meta = self.example_processor.pre_process_utterance(utterance)
    if debug:
        print(processed_utterance_tokens)
        print(utterance_meta)

    examples = [Example(idx=None,
                        src_sent=processed_utterance_tokens,
                        tgt_code=None,
                        tgt_actions=None,
                        tgt_ast=None)]

    hypotheses = self.parser.parse(processed_utterance_tokens, beam_size=self.beam_size, debug=debug)

    if self.reranker:
        hypotheses = self.decode_tree_to_code(hypotheses)
        hypotheses = self.reranker.rerank_hypotheses(examples, [hypotheses])[0]

    valid_hypotheses = list(filter(
        lambda hyp: self.parser.transition_system.is_valid_hypothesis(hyp),
        hypotheses))

    for hyp in valid_hypotheses:
        self.example_processor.post_process_hypothesis(hyp, utterance_meta)

    if debug:
        for hyp_id, hyp in enumerate(valid_hypotheses):
            print('------------------ Hypothesis %d ------------------' % hyp_id)
            print(hyp.code)
            print(hyp.tree.to_string())
            print('Actions:')
            for action_t in hyp.action_infos:
                print(action_t.action)

    return valid_hypotheses

def parse_django_dataset(annot_file, code_file, asdl_file_path, max_query_len=70, vocab_freq_cutoff=10):
    asdl_text = open(asdl_file_path).read()
    grammar = ASDLGrammar.from_text(asdl_text)
    transition_system = PythonTransitionSystem(grammar)

    loaded_examples = []

    from components.vocab import Vocab, VocabEntry
    from components.dataset import Example

    for idx, (src_query, tgt_code) in enumerate(zip(open(annot_file), open(code_file))):
        src_query = src_query.strip()
        tgt_code = tgt_code.strip()

        src_query_tokens, tgt_canonical_code, str_map = Django.canonicalize_example(src_query, tgt_code)
        python_ast = ast.parse(tgt_canonical_code).body[0]
        gold_source = astor.to_source(python_ast).strip()
        tgt_ast = python_ast_to_asdl_ast(python_ast, grammar)
        tgt_actions = transition_system.get_actions(tgt_ast)

        # print('+' * 60)
        # print('Example: %d' % idx)
        # print('Source: %s' % ' '.join(src_query_tokens))
        # if str_map:
        #     print('Original String Map:')
        #     for str_literal, str_repr in str_map.items():
        #         print('\t%s: %s' % (str_literal, str_repr))
        # print('Code:\n%s' % gold_source)
        # print('Actions:')

        # sanity check
        try:
            hyp = Hypothesis()
            for t, action in enumerate(tgt_actions):
                # assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                # if isinstance(action, ApplyRuleAction):
                #     assert action.production in transition_system.get_valid_continuating_productions(hyp)

                p_t = -1
                f_t = None
                if hyp.frontier_node:
                    p_t = hyp.frontier_node.created_time
                    f_t = hyp.frontier_field.field.__repr__(plain=True)

                # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
                hyp = hyp.clone_and_apply_action(action)

            assert hyp.frontier_node is None and hyp.frontier_field is None

            src_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, grammar)).strip()
            assert src_from_hyp == gold_source

            # print('+' * 60)
        except:
            continue

        loaded_examples.append({'src_query_tokens': src_query_tokens,
                                'tgt_canonical_code': gold_source,
                                'tgt_ast': tgt_ast,
                                'tgt_actions': tgt_actions,
                                'raw_code': tgt_code,
                                'str_map': str_map})

        # print('first pass, processed %d' % idx, file=sys.stderr)

    train_examples = []
    dev_examples = []
    test_examples = []

    action_len = []

    for idx, e in enumerate(loaded_examples):
        src_query_tokens = e['src_query_tokens'][:max_query_len]
        tgt_actions = e['tgt_actions']
        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=e['tgt_canonical_code'],
                          tgt_ast=e['tgt_ast'],
                          meta={'raw_code': e['raw_code'], 'str_map': e['str_map']})

        # print('second pass, processed %d' % idx, file=sys.stderr)

        action_len.append(len(tgt_action_infos))

        # train, valid, test split
        if 0 <= idx < 16000:
            train_examples.append(example)
        elif 16000 <= idx < 17000:
            dev_examples.append(example)
        else:
            test_examples.append(example)

    print('Max action len: %d' % max(action_len), file=sys.stderr)
    print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
    print('Actions larger than 100: %d' % len(list(filter(lambda x: x > 100, action_len))), file=sys.stderr)

    src_vocab = VocabEntry.from_corpus([e.src_sent for e in train_examples],
                                       size=5000, freq_cutoff=vocab_freq_cutoff)

    primitive_tokens = [map(lambda a: a.action.token,
                            filter(lambda a: isinstance(a.action, GenTokenAction), e.tgt_actions))
                        for e in train_examples]

    primitive_vocab = VocabEntry.from_corpus(primitive_tokens, size=5000, freq_cutoff=vocab_freq_cutoff)
    assert '_STR:0_' in primitive_vocab

    # generate vocabulary for the code tokens!
    code_tokens = [tokenize_code(e.tgt_code, mode='decoder') for e in train_examples]
    code_vocab = VocabEntry.from_corpus(code_tokens, size=5000, freq_cutoff=vocab_freq_cutoff)

    vocab = Vocab(source=src_vocab, primitive=primitive_vocab, code=code_vocab)
    print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

    return (train_examples, dev_examples, test_examples), vocab

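# Illustrative usage sketch (not part of the original code): how the splits and vocabulary returned
# by parse_django_dataset might be serialized for later training. The file paths and the
# 'data/django/' layout below are assumptions for the example, not something this module defines.
if __name__ == '__main__':
    import pickle

    (train, dev, test), vocab = parse_django_dataset('data/django/all.anno',
                                                     'data/django/all.code',
                                                     'asdl/lang/py/py_asdl.txt')
    pickle.dump(train, open('data/django/train.bin', 'wb'))
    pickle.dump(dev, open('data/django/dev.bin', 'wb'))
    pickle.dump(test, open('data/django/test.bin', 'wb'))
    pickle.dump(vocab, open('data/django/vocab.bin', 'wb'))
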
def parse_natural_dataset(asdl_file_path, max_query_len=70, vocab_freq_cutoff=10):
    asdl_text = open(asdl_file_path).read()
    print('building grammar')
    grammar = ASDLGrammar.from_text(asdl_text)
    transition_system = Python3TransitionSystem(grammar)

    loaded_examples = []

    annotations = []
    codes = []
    path = os.path.join(os.path.dirname(__file__), "datagen")
    datagens = os.listdir(path)
    for folder in datagens:
        if "__" in folder or not os.path.isdir(os.path.join(path, folder)):
            continue
        with open(os.path.join(path, folder, "inputs.txt"), 'r') as file:
            annotations += file.read().split('\n')
        with open(os.path.join(path, folder, "outputs.txt"), 'r') as file:
            codes += file.read().split('\n')

    annotation_codes = list(zip(annotations, codes))
    np.random.seed(42)
    np.random.shuffle(annotation_codes)

    from components.vocab import Vocab, VocabEntry
    from components.dataset import Example

    print('processing examples')
    for idx, (src_query, tgt_code) in enumerate(annotation_codes):
        if idx % 100 == 0:
            sys.stdout.write("\r%s / %s" % (idx, len(annotation_codes)))
            sys.stdout.flush()

        src_query = src_query.strip()
        tgt_code = tgt_code.strip()

        src_query_tokens, tgt_canonical_code, str_map = Natural.canonicalize_example(src_query, tgt_code)
        python_ast = ast.parse(tgt_canonical_code)  # .body[0]
        gold_source = astor.to_source(python_ast).strip()
        tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
        tgt_actions = transition_system.get_actions(tgt_ast)

        # print('+' * 60)
        # print('Example: %d' % idx)
        # print('Source: %s' % ' '.join(src_query_tokens))
        # if str_map:
        #     print('Original String Map:')
        #     for str_literal, str_repr in str_map.items():
        #         print('\t%s: %s' % (str_literal, str_repr))
        # print('Code:\n%s' % gold_source)
        # print('Actions:')

        # sanity check
        hyp = Hypothesis()
        for t, action in enumerate(tgt_actions):
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)

            p_t = -1
            f_t = None
            if hyp.frontier_node:
                p_t = hyp.frontier_node.created_time
                f_t = hyp.frontier_field.field.__repr__(plain=True)

            # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
            hyp = hyp.clone_and_apply_action(action)

        # assert hyp.frontier_node is None and hyp.frontier_field is None

        src_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, grammar)).strip()
        if "b'" not in str(gold_source) and 'b"' not in str(gold_source):
            assert src_from_hyp == gold_source

        # print('+' * 60)

        loaded_examples.append({'src_query_tokens': src_query_tokens,
                                'tgt_canonical_code': gold_source,
                                'tgt_ast': tgt_ast,
                                'tgt_actions': tgt_actions,
                                'raw_code': tgt_code,
                                'str_map': str_map})

        # print('first pass, processed %d' % idx, file=sys.stderr)

    train_examples = []
    dev_examples = []
    test_examples = []

    action_len = []

    print("\nsplitting train/dev/test")
    for idx, e in enumerate(loaded_examples):
        src_query_tokens = e['src_query_tokens'][:max_query_len]
        tgt_actions = e['tgt_actions']
        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=e['tgt_canonical_code'],
                          tgt_ast=e['tgt_ast'],
                          meta={'raw_code': e['raw_code'], 'str_map': e['str_map']})

        # print('second pass, processed %d' % idx, file=sys.stderr)

        action_len.append(len(tgt_action_infos))

        # train, valid, test split: the last two 5% chunks are held out for dev and test
        total_examples = len(loaded_examples)
        split_size = np.ceil(total_examples * 0.05)
        (dev_split, test_split) = (total_examples - split_size * 2, total_examples - split_size)
        if 0 <= idx < dev_split:
            train_examples.append(example)
        elif dev_split <= idx < test_split:
            dev_examples.append(example)
        else:
            test_examples.append(example)

    print('Max action len: %d' % max(action_len), file=sys.stderr)
    print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
    print('Actions larger than 100: %d' % len(list(filter(lambda x: x > 100, action_len))), file=sys.stderr)

    src_vocab = VocabEntry.from_corpus([e.src_sent for e in train_examples],
                                       size=5000, freq_cutoff=vocab_freq_cutoff)

    primitive_tokens = [map(lambda a: a.action.token,
                            filter(lambda a: isinstance(a.action, GenTokenAction), e.tgt_actions))
                        for e in train_examples]

    primitive_vocab = VocabEntry.from_corpus(primitive_tokens, size=5000, freq_cutoff=vocab_freq_cutoff)
    # assert '_STR:0_' in primitive_vocab

    # generate vocabulary for the code tokens!
    code_tokens = [tokenize_code(e.tgt_code, mode='decoder') for e in train_examples]
    code_vocab = VocabEntry.from_corpus(code_tokens, size=5000, freq_cutoff=vocab_freq_cutoff)

    vocab = Vocab(source=src_vocab, primitive=primitive_vocab, code=code_vocab)
    print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

    return (train_examples, dev_examples, test_examples), vocab

def load_regex_dataset(transition_system, split):
    prefix = 'data/streg/'
    src_file = join(prefix, "src-{}.txt".format(split))
    spec_file = join(prefix, "targ-{}.txt".format(split))
    map_file = join(prefix, "map-{}.txt".format(split))
    exs_file = join(prefix, "exs-{}.txt".format(split))
    rec_file = join(prefix, "rec-{}.pkl".format(split))

    exs_info = StReg.load_examples(exs_file)
    map_info = StReg.load_map_file(map_file)
    rec_info = StReg.load_rec(rec_file)

    examples = []
    for idx, (src_line, spec_line, str_exs, cmap, rec) in enumerate(
            zip(open(src_file), open(spec_file), exs_info, map_info, rec_info)):
        print(idx)

        src_line = src_line.rstrip()
        spec_line = spec_line.rstrip()
        src_toks = src_line.split()
        spec_toks = spec_line.rstrip().split()
        spec_ast = streg_expr_to_ast(transition_system.grammar, spec_toks)

        # sanity check
        reconstructed_expr = transition_system.ast_to_surface_code(spec_ast)
        # print("Spec", spec_line)
        # print("Rcon", reconstructed_expr)
        assert spec_line == reconstructed_expr

        tgt_actions = transition_system.get_actions(spec_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None
        assert is_equal_ast(hyp.tree, spec_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert expr_from_hyp == spec_line

        tgt_action_infos = get_action_infos(src_toks, tgt_actions)

        example = Example(idx=idx,
                          src_sent=src_toks,
                          tgt_actions=tgt_action_infos,
                          tgt_code=spec_line,
                          tgt_ast=spec_ast,
                          meta={"str_exs": str_exs,
                                "const_map": cmap,
                                "worker_info": rec})

        examples.append(example)

    return examples

def preprocess_dataset(file_path, transition_system, name='train'):
    dataset = json.load(open(file_path))
    examples = []
    evaluator = ConalaEvaluator(transition_system)

    f = open(file_path + '.debug', 'w')

    for i, example_json in enumerate(dataset):
        example_dict = preprocess_example(example_json)
        if example_json['question_id'] in (18351951, 9497290, 19641579, 32283692):
            pprint(preprocess_example(example_json))
            continue

        python_ast = ast.parse(example_dict['canonical_snippet'])
        canonical_code = astor.to_source(python_ast).strip()
        tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
        tgt_actions = transition_system.get_actions(tgt_ast)

        # sanity check
        hyp = Hypothesis()
        for t, action in enumerate(tgt_actions):
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)

            p_t = -1
            f_t = None
            if hyp.frontier_node:
                p_t = hyp.frontier_node.created_time
                f_t = hyp.frontier_field.field.__repr__(plain=True)

            # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None

        hyp.code = code_from_hyp = astor.to_source(
            asdl_ast_to_python_ast(hyp.tree, transition_system.grammar)).strip()
        assert code_from_hyp == canonical_code

        decanonicalized_code_from_hyp = decanonicalize_code(code_from_hyp, example_dict['slot_map'])
        assert compare_ast(ast.parse(example_json['snippet']), ast.parse(decanonicalized_code_from_hyp))
        assert transition_system.compare_ast(
            transition_system.surface_code_to_ast(decanonicalized_code_from_hyp),
            transition_system.surface_code_to_ast(example_json['snippet']))

        tgt_action_infos = get_action_infos(example_dict['intent_tokens'], tgt_actions)

        example = Example(idx=f'{i}-{example_json["question_id"]}',
                          src_sent=example_dict['intent_tokens'],
                          tgt_actions=tgt_action_infos,
                          tgt_code=canonical_code,
                          tgt_ast=tgt_ast,
                          meta=dict(example_dict=example_json,
                                    slot_map=example_dict['slot_map']))
        assert evaluator.is_hyp_correct(example, hyp)

        examples.append(example)

        # log!
        f.write(f'Example: {example.idx}\n')
        f.write(f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n")
        f.write(f"Original Snippet: {example.meta['example_dict']['snippet']}\n")
        f.write(f"\n")
        f.write(f"Utterance: {' '.join(example.src_sent)}\n")
        f.write(f"Snippet: {example.tgt_code}\n")
        f.write(f"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n")

    f.close()

    return examples

def preprocess_dataset(file_path, transition_system, name='train', firstk=None):
    file_path = os.path.join(os.getcwd(), *file_path.split('/' if '/' in file_path else "\\"))
    try:
        dataset = json.load(open(file_path))
    except:
        dataset = [json.loads(jline) for jline in open(file_path).readlines()]
    if firstk:
        dataset = dataset[:firstk]
    examples = []
    evaluator = ConalaEvaluator(transition_system)

    # Author: Gabe
    # Added encoding to deal with UnicodeEncodeErrors.
    f = open(file_path + '.debug', 'w', encoding='utf-8')
    skipped_list = []
    for i, example_json in tqdm(enumerate(dataset), file=sys.stdout, total=len(dataset), desc='Preproc'):
        # Author: Gabe
        # Have to skip this one question because it causes the program to hang and never recover.
        if example_json['question_id'] in [39525993]:
            skipped_list.append(example_json['question_id'])
            tqdm.write(f"Skipping {example_json['question_id']} because it causes errors")
            continue
        try:
            example_dict = preprocess_example(example_json)
            python_ast = ast.parse(example_dict['canonical_snippet'])
            canonical_code = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # sanity check
            hyp = Hypothesis()
            for t, action in enumerate(tgt_actions):
                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in \
                           transition_system.get_valid_continuating_productions(hyp)
                # p_t = -1
                # f_t = None
                # if hyp.frontier_node:
                #     p_t = hyp.frontier_node.created_time
                #     f_t = hyp.frontier_field.field.__repr__(plain=True)
                #
                # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
                hyp = hyp.clone_and_apply_action(action)

            assert hyp.frontier_node is None and hyp.frontier_field is None

            hyp.code = code_from_hyp = astor.to_source(
                asdl_ast_to_python_ast(hyp.tree, transition_system.grammar)).strip()
            # print(code_from_hyp)
            # print(canonical_code)
            assert code_from_hyp == canonical_code

            decanonicalized_code_from_hyp = decanonicalize_code(code_from_hyp, example_dict['slot_map'])
            assert compare_ast(ast.parse(example_json['snippet']),
                               ast.parse(decanonicalized_code_from_hyp))
            assert transition_system.compare_ast(
                transition_system.surface_code_to_ast(decanonicalized_code_from_hyp),
                transition_system.surface_code_to_ast(example_json['snippet']))

            tgt_action_infos = get_action_infos(example_dict['intent_tokens'], tgt_actions)
        except (AssertionError, SyntaxError, ValueError, OverflowError) as e:
            skipped_list.append(example_json['question_id'])
            tqdm.write(f"Skipping example {example_json['question_id']} because of {type(e).__name__}: {e}")
            continue
        example = Example(idx=f'{i}-{example_json["question_id"]}',
                          src_sent=example_dict['intent_tokens'],
                          tgt_actions=tgt_action_infos,
                          tgt_code=canonical_code,
                          tgt_ast=tgt_ast,
                          meta=dict(example_dict=example_json,
                                    slot_map=example_dict['slot_map']))
        assert evaluator.is_hyp_correct(example, hyp)

        examples.append(example)

        # Author: Gabe
        # Had to remove logging: when the log file got too large, it would cause the program to hang.
        # log!
        # f.write(f'Example: {example.idx}\n')
        # if 'rewritten_intent' in example.meta['example_dict']:
        #     f.write(f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n")
        # else:
        #     f.write(f"Original Utterance: {example.meta['example_dict']['intent']}\n")
        # f.write(f"Original Snippet: {example.meta['example_dict']['snippet']}\n")
        # f.write(f"\n")
        # f.write(f"Utterance: {' '.join(example.src_sent)}\n")
        # f.write(f"Snippet: {example.tgt_code}\n")
        # f.write(f"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n")

    f.close()
    print('Skipped due to exceptions: %d' % len(skipped_list), file=sys.stderr)
    return examples

def self_training(args):
    """Perform self-training.

    First load decoding results on disjoint data and the pre-trained model,
    then perform supervised training on both the existing training data
    and the decoded results.
    """
    print('load pre-trained model from [%s]' % args.load_model, file=sys.stderr)
    params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    vocab = params['vocab']
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_state = params['state_dict']

    # transfer arguments
    saved_args.cuda = args.cuda
    saved_args.save_to = args.save_to
    saved_args.train_file = args.train_file
    saved_args.unlabeled_file = args.unlabeled_file
    saved_args.dev_file = args.dev_file
    saved_args.load_decode_results = args.load_decode_results
    args = saved_args

    update_args(args)

    model = Parser(saved_args, vocab, transition_system)
    model.load_state_dict(saved_state)

    if args.cuda:
        model = model.cuda()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    print('load unlabeled data [%s]' % args.unlabeled_file, file=sys.stderr)
    unlabeled_data = Dataset.from_bin_file(args.unlabeled_file)

    print('load decoding results of unlabeled data [%s]' % args.load_decode_results, file=sys.stderr)
    decode_results = pickle.load(open(args.load_decode_results, 'rb'))

    labeled_data = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)

    print('Num. examples in unlabeled data: %d' % len(unlabeled_data), file=sys.stderr)
    assert len(unlabeled_data) == len(decode_results)
    self_train_examples = []
    for example, hyps in zip(unlabeled_data, decode_results):
        if hyps:
            hyp = hyps[0]
            sampled_example = Example(idx='self_train-%s' % example.idx,
                                      src_sent=example.src_sent,
                                      tgt_code=hyp.code,
                                      tgt_actions=hyp.action_infos,
                                      tgt_ast=hyp.tree)
            self_train_examples.append(sampled_example)
    print('Num. self training examples: %d, Num. labeled examples: %d' % (len(self_train_examples),
                                                                          len(labeled_data)),
          file=sys.stderr)

    train_set = Dataset(examples=labeled_data.examples + self_train_examples)

    print('begin training, %d training examples, %d dev examples' % (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size, shuffle=True):
            batch_examples = [e for e in batch_examples if len(e.tgt_actions) <= args.decode_max_time_step]

            train_iter += 1
            optimizer.zero_grad()

            loss = -model.score(batch_examples)
            # print(loss.data)
            loss_val = torch.sum(loss).data[0]
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' % (train_iter, report_loss / report_examples),
                      file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)
        # model_file = args.save_to + '.iter%d.bin' % train_iter
        # print('save model to [%s]' % model_file, file=sys.stderr)
        # model.save(model_file)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples, model, args, verbose=True)
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' % (epoch, dev_acc, time.time() - eval_start),
              file=sys.stderr)
        is_better = history_dev_scores == [] or dev_acc > max(history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            exit(0)
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

            if patience == args.patience:
                num_trial += 1
                print('hit #%d trial' % num_trial, file=sys.stderr)
                if num_trial == args.max_num_trial:
                    print('early stop!', file=sys.stderr)
                    exit(0)

                # decay lr, and restore from the previously best checkpoint
                lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

                # load model
                params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage)
                model.load_state_dict(params['state_dict'])
                if args.cuda:
                    model = model.cuda()

                # load optimizers
                if args.reset_optimizer:
                    print('reset optimizer', file=sys.stderr)
                    optimizer = torch.optim.Adam(model.inference_model.parameters(), lr=lr)
                else:
                    print('restore parameters of the optimizers', file=sys.stderr)
                    optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

                # set new lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                # reset patience
                patience = 0

def preprocess_dataset(file_path, transition_system, name='train', firstk=None):
    try:
        dataset = json.load(open(file_path))
    except:
        dataset = [json.loads(jline) for jline in open(file_path).readlines()]
    if firstk:
        dataset = dataset[:firstk]
    examples = []
    evaluator = ConalaEvaluator(transition_system)

    f = open(file_path + '.debug', 'w')
    skipped_list = []
    for i, example_json in enumerate(dataset):
        try:
            example_dict = preprocess_example(example_json)

            python_ast = ast.parse(example_dict['canonical_snippet'])
            canonical_code = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # sanity check
            hyp = Hypothesis()
            for t, action in enumerate(tgt_actions):
                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(hyp)

                # p_t = -1
                # f_t = None
                # if hyp.frontier_node:
                #     p_t = hyp.frontier_node.created_time
                #     f_t = hyp.frontier_field.field.__repr__(plain=True)
                #
                # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
                hyp = hyp.clone_and_apply_action(action)

            assert hyp.frontier_node is None and hyp.frontier_field is None

            hyp.code = code_from_hyp = astor.to_source(
                asdl_ast_to_python_ast(hyp.tree, transition_system.grammar)).strip()
            # print(code_from_hyp)
            # print(canonical_code)
            assert code_from_hyp == canonical_code

            decanonicalized_code_from_hyp = decanonicalize_code(code_from_hyp, example_dict['slot_map'])
            assert compare_ast(ast.parse(example_json['snippet']),
                               ast.parse(decanonicalized_code_from_hyp))
            assert transition_system.compare_ast(
                transition_system.surface_code_to_ast(decanonicalized_code_from_hyp),
                transition_system.surface_code_to_ast(example_json['snippet']))

            tgt_action_infos = get_action_infos(example_dict['intent_tokens'], tgt_actions)
        except (AssertionError, SyntaxError, ValueError, OverflowError) as e:
            skipped_list.append(example_json['question_id'])
            continue
        example = Example(idx=f'{i}-{example_json["question_id"]}',
                          src_sent=example_dict['intent_tokens'],
                          tgt_actions=tgt_action_infos,
                          tgt_code=canonical_code,
                          tgt_ast=tgt_ast,
                          meta=dict(example_dict=example_json,
                                    slot_map=example_dict['slot_map']))
        assert evaluator.is_hyp_correct(example, hyp)

        examples.append(example)

        # log!
        f.write(f'Example: {example.idx}\n')
        if 'rewritten_intent' in example.meta['example_dict']:
            f.write(f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n")
        else:
            f.write(f"Original Utterance: {example.meta['example_dict']['intent']}\n")
        f.write(f"Original Snippet: {example.meta['example_dict']['snippet']}\n")
        f.write(f"\n")
        f.write(f"Utterance: {' '.join(example.src_sent)}\n")
        f.write(f"Snippet: {example.tgt_code}\n")
        f.write(f"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n")

    f.close()
    print('Skipped due to exceptions: %d' % len(skipped_list), file=sys.stderr)
    return examples