# Assumed imports, inferred from usage in the functions below; project-internal
# helpers (parse, parse_raw, parse_tree_to_python_ast, canonicalize_code,
# get_grammar, gen_vocab, DataSet, DataEntry, Action, the APPLY_RULE/*_TOKEN
# action constants, serialize_to_file, the preprocess_* functions, etc.) are
# provided by the surrounding package and are not re-declared here.
import ast
import logging
import re
from collections import defaultdict
from itertools import chain

import astor


def process_heart_stone_dataset():
    data_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'
    parse_trees = []
    rule_num = 0.
    example_num = 0
    for line in open(data_file):
        code = line.replace('§', '\n').strip()
        parse_tree = parse(code)

        # sanity check: convert the parse tree back to a Python AST and make
        # sure the regenerated source matches the reference source
        pred_ast = parse_tree_to_python_ast(parse_tree)
        pred_code = astor.to_source(pred_ast)
        ref_ast = ast.parse(code)
        ref_code = astor.to_source(ref_ast)

        if pred_code != ref_code:
            raise RuntimeError('code mismatch!')

        rules, _ = parse_tree.get_productions(include_value_node=False)
        rule_num += len(rules)
        example_num += 1

        parse_trees.append(parse_tree)

    grammar = get_grammar(parse_trees)

    with open('hs.grammar.txt', 'w') as f:
        for rule in grammar:
            str = rule.__repr__()
            f.write(str + '\n')

    with open('hs.parse_trees.txt', 'w') as f:
        for tree in parse_trees:
            f.write(tree.__repr__() + '\n')

    print 'avg. nums of rules: %f' % (rule_num / example_num)
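
# --- Illustrative addition (not in the original file) -----------------------
# The sanity check in process_heart_stone_dataset() relies on astor.to_source
# emitting the same surface form for equivalent ASTs. This hypothetical helper
# sketches that round-trip on a toy snippet; it only assumes the stdlib `ast`
# module and the third-party `astor` package imported above.
def _demo_astor_round_trip():
    snippet = 'x=1+2'
    normalized = astor.to_source(ast.parse(snippet))
    # re-parsing the pretty-printed code and printing it again is a fixed
    # point, which is what makes the pred_code == ref_code comparison above a
    # meaningful equivalence test
    assert astor.to_source(ast.parse(normalized)) == normalized
    return normalized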
def extract_grammar(code_file, prefix='py'):
    line_num = 0
    parse_trees = []
    for line in open(code_file):
        code = line.strip()
        parse_tree = parse(code)

        # leaves = parse_tree.get_leaves()
        # for leaf in leaves:
        #     if not is_terminal_type(leaf.type):
        #         print parse_tree

        # parse_tree = add_root(parse_tree)

        parse_trees.append(parse_tree)

        # sanity check
        ast_tree = parse_tree_to_python_ast(parse_tree)
        ref_ast_tree = ast.parse(canonicalize_code(code)).body[0]
        source1 = astor.to_source(ast_tree)
        source2 = astor.to_source(ref_ast_tree)

        assert source1 == source2

        # check rules
        # rule_list = parse_tree.get_rule_list(include_leaf=True)
        # for rule in rule_list:
        #     if rule.parent.type == int and rule.children[0].type == int:
        #         # rule.parent.type == str and rule.children[0].type == str:
        #         pass

        # ast_tree = tree_to_ast(parse_tree)
        # print astor.to_source(ast_tree)
        # print parse_tree
        # except Exception as e:
        #     error_num += 1
        #     #pass
        #     #print e

        line_num += 1

    print 'total line of code: %d' % line_num

    grammar = get_grammar(parse_trees)

    with open(prefix + '.grammar.txt', 'w') as f:
        for rule in grammar:
            str = rule.__repr__()
            f.write(str + '\n')

    with open(prefix + '.parse_trees.txt', 'w') as f:
        for tree in parse_trees:
            f.write(tree.__repr__() + '\n')

    return grammar, parse_trees
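
# --- Illustrative addition (not in the original file) -----------------------
# Minimal usage sketch for extract_grammar(); 'all.code' is the per-line code
# file referenced elsewhere in this module and is assumed to exist on disk.
def _demo_extract_grammar():
    grammar, parse_trees = extract_grammar('all.code', prefix='py')
    # extract_grammar() also writes py.grammar.txt / py.parse_trees.txt as a
    # side effect; here we only report how many trees were parsed
    print 'extracted grammar from %d parse trees' % len(parse_trees)
    return grammar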
def process_heart_stone_dataset():
    data_file = '/home1/zjq/try3/en-django/all.anno'
    parse_trees = []
    rule_num = 0.
    example_num = 0
    for line in open(data_file):
        code = line.replace('§', '\n').strip()
        parse_tree = parse(code)

        # sanity check
        pred_ast = parse_tree_to_python_ast(parse_tree)
        pred_code = astor.to_source(pred_ast)
        #print(pred_code)
        ref_ast = ast.parse(code)
        ref_code = astor.to_source(ref_ast)
        #print(ref_code)

        if pred_code != ref_code:
            raise RuntimeError('code mismatch!')

        rules, _ = parse_tree.get_productions(include_value_node=False)
        rule_num += len(rules)
        example_num += 1

        parse_trees.append(parse_tree)

    grammar = get_grammar(parse_trees)

    with open('hs.grammar.txt', 'w') as f:
        for rule in grammar:
            str = rule.__repr__()
            f.write(str + '\n')

    with open('hs.parse_trees.txt', 'w') as f:
        for tree in parse_trees:
            f.write(tree.__repr__() + '\n')

    print('avg. nums of rules: %f' % (rule_num / example_num))
def parse_django_dataset():
    from lang.py.parse import parse_raw
    from lang.util import escape
    MAX_QUERY_LENGTH = 70
    UNARY_CUTOFF_FREQ = 30

    annot_file = 'all.anno'
    code_file = 'all.code'

    data = preprocess_dataset(annot_file, code_file)

    for e in data:
        e['parse_tree'] = parse_raw(e['code'])

    parse_trees = [e['parse_tree'] for e in data]

    # apply unary closures
    # unary_closures = get_top_unary_closures(parse_trees, k=0, freq=UNARY_CUTOFF_FREQ)
    # for i, parse_tree in enumerate(parse_trees):
    #     apply_unary_closures(parse_tree, unary_closures)

    # build the grammar
    grammar = get_grammar(parse_trees)

    # write grammar
    with open('django.grammar.unary_closure.txt', 'w') as f:
        for rule in grammar:
            f.write(rule.__repr__() + '\n')

    # # build grammar ...
    # from lang.py.py_dataset import extract_grammar
    # grammar, all_parse_trees = extract_grammar(code_file)

    annot_tokens = list(chain(*[e['query_tokens'] for e in data]))
    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=3)  # gen_vocab(annot_tokens, vocab_size=5980)

    terminal_token_seq = []
    empty_actions_count = 0

    # helper function begins
    def get_terminal_tokens(_terminal_str):
        # _terminal_tokens = filter(None, re.split('([, .?!])', _terminal_str))  # _terminal_str.split('-SP-')
        # _terminal_tokens = filter(None, re.split('( )', _terminal_str))  # _terminal_str.split('-SP-')
        tmp_terminal_tokens = _terminal_str.split(' ')
        _terminal_tokens = []
        for token in tmp_terminal_tokens:
            if token:
                _terminal_tokens.append(token)
            _terminal_tokens.append(' ')

        return _terminal_tokens[:-1]
        # return _terminal_tokens
    # helper function ends

    # first pass
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        for node in parse_tree.get_leaves():
            if grammar.is_value_node(node):
                terminal_val = node.value
                terminal_str = str(terminal_val)

                terminal_tokens = get_terminal_tokens(terminal_str)

                for terminal_token in terminal_tokens:
                    assert len(terminal_token) > 0
                    terminal_token_seq.append(terminal_token)

    terminal_vocab = gen_vocab(terminal_token_seq, vocab_size=5000, freq_cutoff=3)
    assert '_STR:0_' in terminal_vocab

    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'train_data')
    dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'dev_data')
    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'test_data')

    all_examples = []

    can_fully_gen_num = 0

    # second pass
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        str_map = entry['str_map']
        parse_tree = entry['parse_tree']

        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)

        actions = []
        can_fully_gen = True
        rule_pos_map = dict()

        for rule_count, rule in enumerate(rule_list):
            if not grammar.is_value_node(rule.parent):
                assert rule.value is None
                parent_rule = rule_parents[(rule_count, rule)][0]
                if parent_rule:
                    parent_t = rule_pos_map[parent_rule]
                else:
                    parent_t = 0

                rule_pos_map[rule] = len(actions)

                d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}
                action = Action(APPLY_RULE, d)

                actions.append(action)
            else:
                assert rule.is_leaf

                parent_rule = rule_parents[(rule_count, rule)][0]
                parent_t = rule_pos_map[parent_rule]

                terminal_val = rule.value
                terminal_str = str(terminal_val)
                terminal_tokens = get_terminal_tokens(terminal_str)

                # assert len(terminal_tokens) > 0

                for terminal_token in terminal_tokens:
                    term_tok_id = terminal_vocab[terminal_token]
                    tok_src_idx = -1
                    try:
                        tok_src_idx = query_tokens.index(terminal_token)
                    except ValueError:
                        pass

                    d = {'literal': terminal_token, 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}

                    # cannot copy, only generation
                    # could be unk!
                    if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:
                        action = Action(GEN_TOKEN, d)
                        if terminal_token not in terminal_vocab:
                            if terminal_token not in query_tokens:
                                # print terminal_token
                                can_fully_gen = False
                    else:  # copy
                        if term_tok_id != terminal_vocab.unk:
                            d['source_idx'] = tok_src_idx
                            action = Action(GEN_COPY_TOKEN, d)
                        else:
                            d['source_idx'] = tok_src_idx
                            action = Action(COPY_TOKEN, d)

                    actions.append(action)

                d = {'literal': '<eos>', 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}
                actions.append(Action(GEN_TOKEN, d))

        if len(actions) == 0:
            empty_actions_count += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions,
                            {'raw_code': entry['raw_code'], 'str_map': entry['str_map']})

        if can_fully_gen:
            can_fully_gen_num += 1

        # train, valid, test
        if 0 <= idx < 16000:
            train_data.add(example)
        elif 16000 <= idx < 17000:
            dev_data.add(example)
        else:
            test_data.add(example)

        all_examples.append(example)

    # print statistics
    max_query_len = max(len(e.query) for e in all_examples)
    max_actions_len = max(len(e.actions) for e in all_examples)

    serialize_to_file([len(e.query) for e in all_examples], 'query.len')
    serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')

    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_gen_num, len(all_examples),
                 can_fully_gen_num / len(all_examples))
    logging.info('empty_actions_count: %d', empty_actions_count)
    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    train_data.init_data_matrices()
    dev_data.init_data_matrices()
    test_data.init_data_matrices()

    serialize_to_file((train_data, dev_data, test_data),
                      'data/django.cleaned.dataset.freq3.par_info.refact.space_only.order_by_ulink_len.bin')
    # 'data/django.cleaned.dataset.freq5.par_info.refact.space_only.unary_closure.freq{UNARY_CUTOFF_FREQ}.order_by_ulink_len.bin'.format(UNARY_CUTOFF_FREQ=UNARY_CUTOFF_FREQ))

    return train_data, dev_data, test_data
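
# --- Illustrative addition (not in the original file) -----------------------
# The nested get_terminal_tokens() helper above splits a terminal value on
# spaces but keeps ' ' as its own token, so the original string can be
# reassembled by plain concatenation. A standalone, hypothetical re-statement
# of that behaviour with its expected output on a simple input:
def _demo_django_terminal_tokenization():
    def split_terminal(s):
        tokens = []
        for token in s.split(' '):
            if token:
                tokens.append(token)
            tokens.append(' ')
        return tokens[:-1]

    assert split_terminal('foo bar') == ['foo', ' ', 'bar']
    # keeping the space tokens makes the split lossless
    assert ''.join(split_terminal('foo bar')) == 'foo bar'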
def parse_django_dataset():
    from lang.py.parse import parse_raw
    from lang.util import escape
    MAX_QUERY_LENGTH = 70
    UNARY_CUTOFF_FREQ = 30

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

    data = preprocess_dataset(annot_file, code_file)

    for e in data:
        e['parse_tree'] = parse_raw(e['code'])

    parse_trees = [e['parse_tree'] for e in data]

    # apply unary closures
    # unary_closures = get_top_unary_closures(parse_trees, k=0, freq=UNARY_CUTOFF_FREQ)
    # for i, parse_tree in enumerate(parse_trees):
    #     apply_unary_closures(parse_tree, unary_closures)

    # build the grammar
    grammar = get_grammar(parse_trees)

    # write grammar
    with open('django.grammar.unary_closure.txt', 'w') as f:
        for rule in grammar:
            f.write(rule.__repr__() + '\n')

    # # build grammar ...
    # from lang.py.py_dataset import extract_grammar
    # grammar, all_parse_trees = extract_grammar(code_file)

    annot_tokens = list(chain(*[e['query_tokens'] for e in data]))
    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=3)  # gen_vocab(annot_tokens, vocab_size=5980)

    terminal_token_seq = []
    empty_actions_count = 0

    # helper function begins
    def get_terminal_tokens(_terminal_str):
        # _terminal_tokens = filter(None, re.split('([, .?!])', _terminal_str))  # _terminal_str.split('-SP-')
        # _terminal_tokens = filter(None, re.split('( )', _terminal_str))  # _terminal_str.split('-SP-')
        tmp_terminal_tokens = _terminal_str.split(' ')
        _terminal_tokens = []
        for token in tmp_terminal_tokens:
            if token:
                _terminal_tokens.append(token)
            _terminal_tokens.append(' ')

        return _terminal_tokens[:-1]
        # return _terminal_tokens
    # helper function ends

    # first pass
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        for node in parse_tree.get_leaves():
            if grammar.is_value_node(node):
                terminal_val = node.value
                terminal_str = str(terminal_val)

                terminal_tokens = get_terminal_tokens(terminal_str)

                for terminal_token in terminal_tokens:
                    assert len(terminal_token) > 0
                    terminal_token_seq.append(terminal_token)

    terminal_vocab = gen_vocab(terminal_token_seq, vocab_size=5000, freq_cutoff=3)
    assert '_STR:0_' in terminal_vocab

    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'train_data')
    dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'dev_data')
    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'test_data')

    all_examples = []

    can_fully_gen_num = 0

    # second pass
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        str_map = entry['str_map']
        parse_tree = entry['parse_tree']

        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)

        actions = []
        can_fully_gen = True
        rule_pos_map = dict()

        for rule_count, rule in enumerate(rule_list):
            if not grammar.is_value_node(rule.parent):
                assert rule.value is None
                parent_rule = rule_parents[(rule_count, rule)][0]
                if parent_rule:
                    parent_t = rule_pos_map[parent_rule]
                else:
                    parent_t = 0

                rule_pos_map[rule] = len(actions)

                d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}
                action = Action(APPLY_RULE, d)

                actions.append(action)
            else:
                assert rule.is_leaf

                parent_rule = rule_parents[(rule_count, rule)][0]
                parent_t = rule_pos_map[parent_rule]

                terminal_val = rule.value
                terminal_str = str(terminal_val)
                terminal_tokens = get_terminal_tokens(terminal_str)

                # assert len(terminal_tokens) > 0

                for terminal_token in terminal_tokens:
                    term_tok_id = terminal_vocab[terminal_token]
                    tok_src_idx = -1
                    try:
                        tok_src_idx = query_tokens.index(terminal_token)
                    except ValueError:
                        pass

                    d = {'literal': terminal_token, 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}

                    # cannot copy, only generation
                    # could be unk!
                    if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:
                        action = Action(GEN_TOKEN, d)
                        if terminal_token not in terminal_vocab:
                            if terminal_token not in query_tokens:
                                # print terminal_token
                                can_fully_gen = False
                    else:  # copy
                        if term_tok_id != terminal_vocab.unk:
                            d['source_idx'] = tok_src_idx
                            action = Action(GEN_COPY_TOKEN, d)
                        else:
                            d['source_idx'] = tok_src_idx
                            action = Action(COPY_TOKEN, d)

                    actions.append(action)

                d = {'literal': '<eos>', 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}
                actions.append(Action(GEN_TOKEN, d))

        if len(actions) == 0:
            empty_actions_count += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions,
                            {'raw_code': entry['raw_code'], 'str_map': entry['str_map']})

        if can_fully_gen:
            can_fully_gen_num += 1

        # train, valid, test
        if 0 <= idx < 16000:
            train_data.add(example)
        elif 16000 <= idx < 17000:
            dev_data.add(example)
        else:
            test_data.add(example)

        all_examples.append(example)

    # print statistics
    max_query_len = max(len(e.query) for e in all_examples)
    max_actions_len = max(len(e.actions) for e in all_examples)

    serialize_to_file([len(e.query) for e in all_examples], 'query.len')
    serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')

    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_gen_num, len(all_examples),
                 can_fully_gen_num / len(all_examples))
    logging.info('empty_actions_count: %d', empty_actions_count)
    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    train_data.init_data_matrices()
    dev_data.init_data_matrices()
    test_data.init_data_matrices()

    serialize_to_file((train_data, dev_data, test_data),
                      'data/django.cleaned.dataset.freq3.par_info.refact.space_only.order_by_ulink_len.bin')
    # 'data/django.cleaned.dataset.freq5.par_info.refact.space_only.unary_closure.freq{UNARY_CUTOFF_FREQ}.order_by_ulink_len.bin'.format(UNARY_CUTOFF_FREQ=UNARY_CUTOFF_FREQ))

    return train_data, dev_data, test_data
def parse_hs_dataset():
    MAX_QUERY_LENGTH = 70  # FIXME: figure out the best config!
    WORD_FREQ_CUT_OFF = 3

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.mod.in'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'

    data = preprocess_hs_dataset(annot_file, code_file)
    parse_trees = [e['parse_tree'] for e in data]

    # apply unary closures
    unary_closures = get_top_unary_closures(parse_trees, k=20)
    for parse_tree in parse_trees:
        apply_unary_closures(parse_tree, unary_closures)

    # build the grammar
    grammar = get_grammar(parse_trees)

    with open('hs.grammar.unary_closure.txt', 'w') as f:
        for rule in grammar:
            f.write(rule.__repr__() + '\n')

    annot_tokens = list(chain(*[e['query_tokens'] for e in data]))
    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=WORD_FREQ_CUT_OFF)

    def get_terminal_tokens(_terminal_str):
        """
        get terminal tokens
        break words like MinionCards into [Minion, Cards]
        """
        tmp_terminal_tokens = [t for t in _terminal_str.split(' ') if len(t) > 0]
        _terminal_tokens = []
        for token in tmp_terminal_tokens:
            sub_tokens = re.sub(r'([a-z])([A-Z])', r'\1 \2', token).split(' ')
            _terminal_tokens.extend(sub_tokens)

            _terminal_tokens.append(' ')

        return _terminal_tokens[:-1]

    # enumerate all terminal tokens to build up the terminal tokens vocabulary
    all_terminal_tokens = []
    for entry in data:
        parse_tree = entry['parse_tree']
        for node in parse_tree.get_leaves():
            if grammar.is_value_node(node):
                terminal_val = node.value
                terminal_str = str(terminal_val)

                terminal_tokens = get_terminal_tokens(terminal_str)

                for terminal_token in terminal_tokens:
                    assert len(terminal_token) > 0
                    all_terminal_tokens.append(terminal_token)

    terminal_vocab = gen_vocab(all_terminal_tokens, vocab_size=5000, freq_cutoff=WORD_FREQ_CUT_OFF)

    # now generate the dataset!
    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.train_data')
    dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.dev_data')
    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.test_data')

    all_examples = []

    can_fully_reconstructed_examples_num = 0
    examples_with_empty_actions_num = 0

    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)

        actions = []
        can_fully_reconstructed = True
        rule_pos_map = dict()

        for rule_count, rule in enumerate(rule_list):
            if not grammar.is_value_node(rule.parent):
                assert rule.value is None
                parent_rule = rule_parents[(rule_count, rule)][0]
                if parent_rule:
                    parent_t = rule_pos_map[parent_rule]
                else:
                    parent_t = 0

                rule_pos_map[rule] = len(actions)

                d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}
                action = Action(APPLY_RULE, d)

                actions.append(action)
            else:
                assert rule.is_leaf

                parent_rule = rule_parents[(rule_count, rule)][0]
                parent_t = rule_pos_map[parent_rule]

                terminal_val = rule.value
                terminal_str = str(terminal_val)
                terminal_tokens = get_terminal_tokens(terminal_str)

                # assert len(terminal_tokens) > 0

                for terminal_token in terminal_tokens:
                    term_tok_id = terminal_vocab[terminal_token]
                    tok_src_idx = -1
                    try:
                        tok_src_idx = query_tokens.index(terminal_token)
                    except ValueError:
                        pass

                    d = {'literal': terminal_token, 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}

                    # cannot copy, only generation
                    # could be unk!
                    if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:
                        action = Action(GEN_TOKEN, d)
                        if terminal_token not in terminal_vocab:
                            if terminal_token not in query_tokens:
                                # print terminal_token
                                can_fully_reconstructed = False
                    else:  # copy
                        if term_tok_id != terminal_vocab.unk:
                            d['source_idx'] = tok_src_idx
                            action = Action(GEN_COPY_TOKEN, d)
                        else:
                            d['source_idx'] = tok_src_idx
                            action = Action(COPY_TOKEN, d)

                    actions.append(action)

                d = {'literal': '<eos>', 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}
                actions.append(Action(GEN_TOKEN, d))

        if len(actions) == 0:
            examples_with_empty_actions_num += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions,
                            {'str_map': None, 'raw_code': entry['raw_code']})

        if can_fully_reconstructed:
            can_fully_reconstructed_examples_num += 1

        # train, valid, test splits
        if 0 <= idx < 533:
            train_data.add(example)
        elif idx < 599:
            dev_data.add(example)
        else:
            test_data.add(example)

        all_examples.append(example)

    # print statistics
    max_query_len = max(len(e.query) for e in all_examples)
    max_actions_len = max(len(e.actions) for e in all_examples)

    # serialize_to_file([len(e.query) for e in all_examples], 'query.len')
    # serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')

    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_reconstructed_examples_num, len(all_examples),
                 can_fully_reconstructed_examples_num / len(all_examples))
    logging.info('empty_actions_count: %d', examples_with_empty_actions_num)
    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    train_data.init_data_matrices(max_query_length=70, max_example_action_num=350)
    dev_data.init_data_matrices(max_query_length=70, max_example_action_num=350)
    test_data.init_data_matrices(max_query_length=70, max_example_action_num=350)

    serialize_to_file((train_data, dev_data, test_data),
                      'data/hs.freq{WORD_FREQ_CUT_OFF}.max_action350.pre_suf.unary_closure.bin'.format(WORD_FREQ_CUT_OFF=WORD_FREQ_CUT_OFF))

    return train_data, dev_data, test_data
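
# --- Illustrative addition (not in the original file) -----------------------
# The HS variant of get_terminal_tokens() above additionally splits camel-case
# identifiers before building the terminal vocabulary ("break words like
# MinionCards into [Minion, Cards]"). The regex it uses can be exercised on
# its own; only the stdlib `re` module imported above is assumed.
def _demo_camel_case_split():
    assert re.sub(r'([a-z])([A-Z])', r'\1 \2', 'MinionCard').split(' ') == ['Minion', 'Card']
    # tokens without an internal lower-to-upper boundary pass through unchanged
    assert re.sub(r'([a-z])([A-Z])', r'\1 \2', 'Charge').split(' ') == ['Charge']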
def parse_hs_dataset():
    MAX_QUERY_LENGTH = 70  # FIXME: figure out the best config!
    WORD_FREQ_CUT_OFF = 3

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.mod.in'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'

    data = preprocess_hs_dataset(annot_file, code_file)
    parse_trees = [e['parse_tree'] for e in data]

    # apply unary closures
    unary_closures = get_top_unary_closures(parse_trees, k=20)
    for parse_tree in parse_trees:
        apply_unary_closures(parse_tree, unary_closures)

    # build the grammar
    grammar = get_grammar(parse_trees)

    with open('hs.grammar.unary_closure.txt', 'w') as f:
        for rule in grammar:
            f.write(rule.__repr__() + '\n')

    annot_tokens = list(chain(*[e['query_tokens'] for e in data]))
    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=WORD_FREQ_CUT_OFF)

    def get_terminal_tokens(_terminal_str):
        """
        get terminal tokens
        break words like MinionCards into [Minion, Cards]
        """
        tmp_terminal_tokens = [t for t in _terminal_str.split(' ') if len(t) > 0]
        _terminal_tokens = []
        for token in tmp_terminal_tokens:
            sub_tokens = re.sub(r'([a-z])([A-Z])', r'\1 \2', token).split(' ')
            _terminal_tokens.extend(sub_tokens)

            _terminal_tokens.append(' ')

        return _terminal_tokens[:-1]

    # enumerate all terminal tokens to build up the terminal tokens vocabulary
    all_terminal_tokens = []
    for entry in data:
        parse_tree = entry['parse_tree']
        for node in parse_tree.get_leaves():
            if grammar.is_value_node(node):
                terminal_val = node.value
                terminal_str = str(terminal_val)

                terminal_tokens = get_terminal_tokens(terminal_str)

                for terminal_token in terminal_tokens:
                    assert len(terminal_token) > 0
                    all_terminal_tokens.append(terminal_token)

    terminal_vocab = gen_vocab(all_terminal_tokens, vocab_size=5000, freq_cutoff=WORD_FREQ_CUT_OFF)

    # now generate the dataset!
    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.train_data')
    dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.dev_data')
    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.test_data')

    all_examples = []

    can_fully_reconstructed_examples_num = 0
    examples_with_empty_actions_num = 0

    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)

        actions = []
        can_fully_reconstructed = True
        rule_pos_map = dict()

        for rule_count, rule in enumerate(rule_list):
            if not grammar.is_value_node(rule.parent):
                assert rule.value is None
                parent_rule = rule_parents[(rule_count, rule)][0]
                if parent_rule:
                    parent_t = rule_pos_map[parent_rule]
                else:
                    parent_t = 0

                rule_pos_map[rule] = len(actions)

                d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}
                action = Action(APPLY_RULE, d)

                actions.append(action)
            else:
                assert rule.is_leaf

                parent_rule = rule_parents[(rule_count, rule)][0]
                parent_t = rule_pos_map[parent_rule]

                terminal_val = rule.value
                terminal_str = str(terminal_val)
                terminal_tokens = get_terminal_tokens(terminal_str)

                # assert len(terminal_tokens) > 0

                for terminal_token in terminal_tokens:
                    term_tok_id = terminal_vocab[terminal_token]
                    tok_src_idx = -1
                    try:
                        tok_src_idx = query_tokens.index(terminal_token)
                    except ValueError:
                        pass

                    d = {'literal': terminal_token, 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}

                    # cannot copy, only generation
                    # could be unk!
                    if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:
                        action = Action(GEN_TOKEN, d)
                        if terminal_token not in terminal_vocab:
                            if terminal_token not in query_tokens:
                                # print terminal_token
                                can_fully_reconstructed = False
                    else:  # copy
                        if term_tok_id != terminal_vocab.unk:
                            d['source_idx'] = tok_src_idx
                            action = Action(GEN_COPY_TOKEN, d)
                        else:
                            d['source_idx'] = tok_src_idx
                            action = Action(COPY_TOKEN, d)

                    actions.append(action)

                d = {'literal': '<eos>', 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}
                actions.append(Action(GEN_TOKEN, d))

        if len(actions) == 0:
            examples_with_empty_actions_num += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions,
                            {'str_map': None, 'raw_code': entry['raw_code']})

        if can_fully_reconstructed:
            can_fully_reconstructed_examples_num += 1

        # train, valid, test splits
        if 0 <= idx < 533:
            train_data.add(example)
        elif idx < 599:
            dev_data.add(example)
        else:
            test_data.add(example)

        all_examples.append(example)

    # print statistics
    max_query_len = max(len(e.query) for e in all_examples)
    max_actions_len = max(len(e.actions) for e in all_examples)

    # serialize_to_file([len(e.query) for e in all_examples], 'query.len')
    # serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')

    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_reconstructed_examples_num, len(all_examples),
                 can_fully_reconstructed_examples_num / len(all_examples))
    logging.info('empty_actions_count: %d', examples_with_empty_actions_num)
    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    train_data.init_data_matrices(max_query_length=70, max_example_action_num=350)
    dev_data.init_data_matrices(max_query_length=70, max_example_action_num=350)
    test_data.init_data_matrices(max_query_length=70, max_example_action_num=350)

    serialize_to_file((train_data, dev_data, test_data),
                      'data/hs.freq{WORD_FREQ_CUT_OFF}.max_action350.pre_suf.unary_closure.bin'.format(WORD_FREQ_CUT_OFF=WORD_FREQ_CUT_OFF))

    return train_data, dev_data, test_data
def parse_django_dataset():
    from lang.py.parse import parse_raw
    from lang.util import escape
    MAX_QUERY_LENGTH = 70
    UNARY_CUTOFF_FREQ = 30

    ##annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
    ##code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'
    #annot_file = ALLANNO
    #code_file = ALLCODE
    annot_file = GENALLANNO
    code_file = GENALLCODE

    #data = preprocess_dataset(annot_file, code_file)
    #data = preprocess_gendataset(annot_file, code_file)
    data = preprocess_syndataset(annot_file, code_file)

    # print data
    #for e in data:
    #    print "-" * 60
    #    print "\n\n"
    #    print "idx - ", e['id']
    #    print "\n\n"
    #    print "query tokens - ", e['query_tokens']
    #    print "\n\n"
    #    print "code - ", e['code']
    #    print "\n\n"
    #    print "str_map - ", e['str_map']
    #    print "\n\n"
    #    print "raw_code - ", e['raw_code']
    #    print "\n\n"
    #    print "bannot - ", e['bannot']
    #    print "\n\n"
    #    print "bcode - ", e['bcode']
    #    print "\n\n"
    #    print "ref_type - ", e['ref_type']
    #    print "\n\n"

    for e in data:
        e['parse_tree'] = parse_raw(e['code'])

    parse_trees = [e['parse_tree'] for e in data]

    # apply unary closures
    # unary_closures = get_top_unary_closures(parse_trees, k=0, freq=UNARY_CUTOFF_FREQ)
    # for i, parse_tree in enumerate(parse_trees):
    #     apply_unary_closures(parse_tree, unary_closures)

    # build the grammar
    grammar = get_grammar(parse_trees)

    # write grammar
    with open('django.grammar.unary_closure.txt', 'w') as f:
        for rule in grammar:
            f.write(rule.__repr__() + '\n')

    # # build grammar ...
    # from lang.py.py_dataset import extract_grammar
    # grammar, all_parse_trees = extract_grammar(code_file)

    annot_tokens = list(chain(*[e['query_tokens'] for e in data]))
    annot_tokens = normalize_query_tokens(annot_tokens)
    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=3)  # gen_vocab(annot_tokens, vocab_size=5980)
    #annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=0)  # gen_vocab(annot_tokens, vocab_size=5980)

    terminal_token_seq = []
    empty_actions_count = 0

    # helper function begins
    def get_terminal_tokens(_terminal_str):
        # _terminal_tokens = filter(None, re.split('([, .?!])', _terminal_str))  # _terminal_str.split('-SP-')
        # _terminal_tokens = filter(None, re.split('( )', _terminal_str))  # _terminal_str.split('-SP-')
        tmp_terminal_tokens = _terminal_str.split(' ')
        _terminal_tokens = []
        for token in tmp_terminal_tokens:
            if token:
                _terminal_tokens.append(token)
            _terminal_tokens.append(' ')

        return _terminal_tokens[:-1]
        # return _terminal_tokens
    # helper function ends

    # first pass
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        for node in parse_tree.get_leaves():
            if grammar.is_value_node(node):
                terminal_val = node.value
                terminal_str = str(terminal_val)

                terminal_tokens = get_terminal_tokens(terminal_str)

                for terminal_token in terminal_tokens:
                    assert len(terminal_token) > 0
                    terminal_token_seq.append(terminal_token)

    terminal_vocab = gen_vocab(terminal_token_seq, vocab_size=5000, freq_cutoff=3)
    #terminal_vocab = gen_vocab(terminal_token_seq, vocab_size=5000, freq_cutoff=0)

    phrase_vocab = gen_phrase_vocab(data)
    pos_vocab = gen_pos_vocab(data)

    #assert '_STR:0_' in terminal_vocab

    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'train_data', phrase_vocab, pos_vocab)
    dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'dev_data', phrase_vocab, pos_vocab)
    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'test_data', phrase_vocab, pos_vocab)

    all_examples = []

    can_fully_gen_num = 0

    # second pass
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        str_map = entry['str_map']
        parse_tree = entry['parse_tree']

        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)

        #print "Rule List - "
        #for r in rule_list:
        #    print "Rule -", r

        #for k, v in rule_parents.iteritems():
        #    print "Rule parents - ", k, " - ", v

        #print "Rule parents - ", rule_parents

        actions = []
        can_fully_gen = True
        rule_pos_map = dict()

        for rule_count, rule in enumerate(rule_list):
            if not grammar.is_value_node(rule.parent):
                assert rule.value is None
                parent_rule = rule_parents[(rule_count, rule)][0]
                if parent_rule:
                    parent_t = rule_pos_map[parent_rule]
                else:
                    parent_t = 0

                rule_pos_map[rule] = len(actions)

                d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}
                action = Action(APPLY_RULE, d)

                actions.append(action)
            else:
                assert rule.is_leaf

                parent_rule = rule_parents[(rule_count, rule)][0]
                parent_t = rule_pos_map[parent_rule]

                terminal_val = rule.value
                terminal_str = str(terminal_val)
                terminal_tokens = get_terminal_tokens(terminal_str)

                # assert len(terminal_tokens) > 0

                for terminal_token in terminal_tokens:
                    term_tok_id = terminal_vocab[terminal_token]
                    tok_src_idx = -1
                    try:
                        tok_src_idx = query_tokens.index(terminal_token)
                    except ValueError:
                        pass

                    d = {'literal': terminal_token, 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}

                    # cannot copy, only generation
                    # could be unk!
                    if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:
                        action = Action(GEN_TOKEN, d)
                        if terminal_token not in terminal_vocab:
                            if terminal_token not in query_tokens:
                                # print terminal_token
                                can_fully_gen = False
                    else:  # copy
                        if term_tok_id != terminal_vocab.unk:
                            d['source_idx'] = tok_src_idx
                            action = Action(GEN_COPY_TOKEN, d)
                        else:
                            d['source_idx'] = tok_src_idx
                            action = Action(COPY_TOKEN, d)

                    actions.append(action)

                d = {'literal': '<eos>', 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}
                actions.append(Action(GEN_TOKEN, d))

        if len(actions) == 0:
            empty_actions_count += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions,
                            {'raw_code': entry['raw_code'],
                             'str_map': entry['str_map'],
                             'phrase': entry['phrase'],
                             'pos': entry['pos'],
                             'bannot': entry['bannot'],
                             'bcode': entry['bcode'],
                             'ref_type': entry['ref_type']})

        if can_fully_gen:
            can_fully_gen_num += 1

        # train, valid, test
        if 0 <= idx < 13000:
            train_data.add(example)
        elif 13000 <= idx < 14000:
            dev_data.add(example)
        else:
            test_data.add(example)

        # modified train valid test counts
        #if 0 <= idx < 10000:
        #    train_data.add(example)
        #elif 10000 <= idx < 11000:
        #    dev_data.add(example)
        #else:
        #    test_data.add(example)

        all_examples.append(example)

    # print statistics
    max_query_len = max(len(e.query) for e in all_examples)
    max_actions_len = max(len(e.actions) for e in all_examples)

    serialize_to_file([len(e.query) for e in all_examples], 'query.len')
    serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')

    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_gen_num, len(all_examples),
                 can_fully_gen_num / len(all_examples))
    logging.info('empty_actions_count: %d', empty_actions_count)
    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    train_data.init_data_matrices()
    dev_data.init_data_matrices()
    test_data.init_data_matrices()

    #print train_data

    ## print train_data matrix
    #print "Data matrix: query_tokens "
    #print train_data.data_matrix['query_tokens']
    #print "\n" * 2
    #print "Data matrix : query_tokens_phrase"
    #print "\n" * 2
    #print train_data.data_matrix['query_tokens_phrase']
    #print "\n" * 2
    #print "Data matrix : query_tokens_pos"
    #print "\n" * 2
    #print train_data.data_matrix['query_tokens_pos']
    #print "\n" * 2
    #print "Data matrix : query_tokens_cid"
    #print "\n" * 2
    #print train_data.data_matrix['query_tokens_cid']
    #print "\n" * 2

    ## print few data entries
    #for d in train_data.examples[:5]:
    #    print "\n" * 2
    #    print d

    ## lets print dataset for good measure
    serialize_to_file(
        (train_data, dev_data, test_data),
        # 'data/django.pnet.qparse.dataset.freq3.par_info.refact.space_only.bin')
        'data/django.pnet.fullcanon.dataset.freq3.par_info.refact.space_only.bin'
    )
    # 'data/django.pnet.dataset.freq3.par_info.refact.space_only.bin')
    #'data/django.cleaned.dataset.freq3.par_info.refact.space_only.order_by_ulink_len.bin')
    # 'data/django.cleaned.dataset.freq5.par_info.refact.space_only.unary_closure.freq{UNARY_CUTOFF_FREQ}.order_by_ulink_len.bin'.format(UNARY_CUTOFF_FREQ=UNARY_CUTOFF_FREQ))

    return train_data, dev_data, test_data
def parse_hs_dataset_for_seq2tree():
    from lang.py.py_dataset import preprocess_hs_dataset
    MAX_QUERY_LENGTH = 70  # FIXME: figure out the best config!
    WORD_FREQ_CUT_OFF = 3
    MAX_DECODING_TIME_STEP = 800

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.mod.in'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'

    data = preprocess_hs_dataset(annot_file, code_file)
    parse_trees = [e['parse_tree'] for e in data]

    # apply unary closures
    unary_closures = get_top_unary_closures(parse_trees, k=20)
    for parse_tree in parse_trees:
        apply_unary_closures(parse_tree, unary_closures)

    # build the grammar
    grammar = get_grammar(parse_trees)

    decode_time_steps = defaultdict(int)

    f_train = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/hs/data_unkreplaced/train.txt', 'w')
    f_dev = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/hs/data_unkreplaced/dev.txt', 'w')
    f_test = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/hs/data_unkreplaced/test.txt', 'w')

    f_train_rawid = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/hs/data_unkreplaced/train.id.txt', 'w')
    f_dev_rawid = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/hs/data_unkreplaced/dev.id.txt', 'w')
    f_test_rawid = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/hs/data_unkreplaced/test.id.txt', 'w')

    # first pass
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']  # fix: `code` is used by the mismatch debug print below but was never read from the entry
        parse_tree = entry['parse_tree']

        original_parse_tree = parse_tree.copy()

        break_value_nodes(parse_tree, hs=True)
        tree_repr = ast_tree_to_seq2tree_repr(parse_tree)

        num_decode_time_step = len(tree_repr.split(' '))
        decode_time_steps[num_decode_time_step] += 1

        new_tree = seq2tree_repr_to_ast_tree(tree_repr)
        merge_broken_value_nodes(new_tree)

        query_tokens = [t for t in query_tokens if t != ''][:MAX_QUERY_LENGTH]
        query = ' '.join(query_tokens)
        line = query + '\t' + tree_repr

        if num_decode_time_step > MAX_DECODING_TIME_STEP:
            continue

        # train, valid, test
        if 0 <= idx < 533:
            f_train.write(line + '\n')
            f_train_rawid.write(str(idx) + '\n')
        elif idx < 599:
            f_dev.write(line + '\n')
            f_dev_rawid.write(str(idx) + '\n')
        else:
            f_test.write(line + '\n')
            f_test_rawid.write(str(idx) + '\n')

        if original_parse_tree != new_tree:
            print '*' * 50
            print idx
            print code

    f_train.close()
    f_dev.close()
    f_test.close()

    f_train_rawid.close()
    f_dev_rawid.close()
    f_test_rawid.close()

    # print 'num. of decoding time steps distribution:'
    for k in sorted(decode_time_steps):
        print '%d\t%d' % (k, decode_time_steps[k])
def parse_django_dataset_for_seq2tree():
    from lang.py.parse import parse_raw
    MAX_QUERY_LENGTH = 70
    MAX_DECODING_TIME_STEP = 300
    UNARY_CUTOFF_FREQ = 30

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

    data = preprocess_dataset(annot_file, code_file)

    for e in data:
        e['parse_tree'] = parse_raw(e['code'])

    parse_trees = [e['parse_tree'] for e in data]

    # apply unary closures
    # unary_closures = get_top_unary_closures(parse_trees, k=0, freq=UNARY_CUTOFF_FREQ)
    # for i, parse_tree in enumerate(parse_trees):
    #     apply_unary_closures(parse_tree, unary_closures)

    # build the grammar
    grammar = get_grammar(parse_trees)

    # # build grammar ...
    # from lang.py.py_dataset import extract_grammar
    # grammar, all_parse_trees = extract_grammar(code_file)

    f_train = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/train.txt', 'w')
    f_dev = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/dev.txt', 'w')
    f_test = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/test.txt', 'w')

    f_train_rawid = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/train.id.txt', 'w')
    f_dev_rawid = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/dev.id.txt', 'w')
    f_test_rawid = open('/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/test.id.txt', 'w')

    decode_time_steps = defaultdict(int)

    # first pass
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        original_parse_tree = parse_tree.copy()

        break_value_nodes(parse_tree)
        tree_repr = ast_tree_to_seq2tree_repr(parse_tree)

        num_decode_time_step = len(tree_repr.split(' '))
        decode_time_steps[num_decode_time_step] += 1

        new_tree = seq2tree_repr_to_ast_tree(tree_repr)
        merge_broken_value_nodes(new_tree)

        query_tokens = [t for t in query_tokens if t != ''][:MAX_QUERY_LENGTH]
        query = ' '.join(query_tokens)
        line = query + '\t' + tree_repr

        if num_decode_time_step > MAX_DECODING_TIME_STEP:
            continue

        # train, valid, test
        if 0 <= idx < 16000:
            f_train.write(line + '\n')
            f_train_rawid.write(str(idx) + '\n')
        elif 16000 <= idx < 17000:
            f_dev.write(line + '\n')
            f_dev_rawid.write(str(idx) + '\n')
        else:
            f_test.write(line + '\n')
            f_test_rawid.write(str(idx) + '\n')

        if original_parse_tree != new_tree:
            print '*' * 50
            print idx
            print code

    f_train.close()
    f_dev.close()
    f_test.close()

    f_train_rawid.close()
    f_dev_rawid.close()
    f_test_rawid.close()
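

# --- Illustrative addition (not in the original file) -----------------------
# Hypothetical entry point showing how these dataset builders might be driven
# from the command line; the original module may wire this up differently.
if __name__ == '__main__':
    import sys

    logging.basicConfig(level=logging.INFO)
    target = sys.argv[1] if len(sys.argv) > 1 else 'django'
    if target == 'django':
        parse_django_dataset()
    elif target == 'hs':
        parse_hs_dataset()
    elif target == 'django_seq2tree':
        parse_django_dataset_for_seq2tree()
    elif target == 'hs_seq2tree':
        parse_hs_dataset_for_seq2tree()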