Example #1
def preprocess_dataset(decl_file, desc_file, code_file):
    # Clean each aligned (declaration, description, code) triple into an
    # example dict; create_query, canonicalize_code and parse_raw are helpers
    # from the surrounding module. Entries that fail to parse are skipped.
    examples = []

    err_num = 0
    for idx, (decl, desc, code) in enumerate(zip(decl_file, desc_file, code_file)):
        decl = decl.strip()
        desc = desc.strip()
        code = code.strip()
        try:
            clean_query_tokens = create_query(decl, desc)
            clean_code = canonicalize_code(code)
            example = {'id': idx,
                       'query_tokens': clean_query_tokens,
                       'code': clean_code,
                       'raw_code': code,
                       'parse_tree': parse_raw(clean_code)}
            examples.append(example)

        except Exception:
            print(code)
            err_num += 1

    print('error num: %d' % err_num)
    print('preprocess_dataset: cleaned example num: %d' % len(examples))
    decl_file.close()
    desc_file.close()
    code_file.close()
    return examples
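
A minimal usage sketch (hypothetical file names; the function expects three
line-parallel open files and closes them itself):

decl_f = open('all.decl')  # hypothetical paths
desc_f = open('all.desc')
code_f = open('all.code')
examples = preprocess_dataset(decl_f, desc_f, code_f)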
Example #2
def canonicalize_example(query, code):
    from lang.py.parse import parse_raw, parse_tree_to_python_ast, canonicalize_code as make_it_compilable
    import astor, ast

    canonical_query, str_map = canonicalize_query(query)
    canonical_code = code

    # replace each raw string literal in the code with its quoted placeholder
    for str_literal, str_repr in str_map.items():
        canonical_code = canonical_code.replace(str_literal,
                                                '\'' + str_repr + '\'')

    canonical_code = make_it_compilable(canonical_code)

    # sanity check: the tree produced by parse_raw must regenerate exactly
    # the source that Python's own ast round-trip produces
    parse_tree = parse_raw(canonical_code)
    gold_ast_tree = ast.parse(canonical_code).body[0]
    gold_source = astor.to_source(gold_ast_tree)
    ast_tree = parse_tree_to_python_ast(parse_tree)
    source = astor.to_source(ast_tree)

    assert gold_source == source, 'sanity check fails: gold=[%s], actual=[%s]' % (
        gold_source, source)

    query_tokens = canonical_query.split(' ')

    return query_tokens, canonical_code, str_map
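
For illustration, a hypothetical run of the placeholder scheme (the '_STR:0_'
form matches the vocabulary assertion in Example #6 below; the exact behavior
of canonicalize_query is defined elsewhere):

canonical_query, str_map = canonicalize_query("open the file 'foo.txt'")
# canonical_query == 'open the file _STR:0_'      (assumed)
# str_map == {"'foo.txt'": '_STR:0_'}             (assumed)
# the loop above then rewrites open('foo.txt') to open('_STR:0_')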
Example #3
def canonicalize_code(code):
    from lang.py.parse import parse_raw, parse_tree_to_python_ast, canonicalize_code as make_it_compilable
    import astor, ast

    canonical_code = make_it_compilable(code)

    # sanity check
    parse_tree = parse_raw(canonical_code)
    gold_ast_tree = ast.parse(canonical_code).body[0]
    gold_source = astor.to_source(gold_ast_tree)
    ast_tree = parse_tree_to_python_ast(parse_tree)
    source = astor.to_source(ast_tree)

    assert gold_source == source, 'sanity check fails: gold=[%s], actual=[%s]' % (gold_source, source)

    return canonical_code
Example #4
def canonicalize_hs_example(query, code):
    import re, ast, astor, nltk
    from lang.py.parse import parse_raw, parse_tree_to_python_ast

    # strip HTML-style markup (e.g. <b>...</b>) from the card description
    query = re.sub(r'<.*?>', '', query)
    query_tokens = nltk.word_tokenize(query)

    # the HS dataset encodes newlines inside code snippets as '§'
    code = code.replace('§', '\n').strip()

    # sanity check
    parse_tree = parse_raw(code)
    gold_ast_tree = ast.parse(code).body[0]
    gold_source = astor.to_source(gold_ast_tree)
    ast_tree = parse_tree_to_python_ast(parse_tree)
    pred_source = astor.to_source(ast_tree)

    assert gold_source == pred_source, 'sanity check fails: gold=[%s], actual=[%s]' % (gold_source, pred_source)

    return query_tokens, code, parse_tree
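
A toy invocation sketch (made-up card text; requires NLTK's 'punkt' tokenizer
data to be installed):

q_toks, code, tree = canonicalize_hs_example('<b>Battlecry:</b> draw a card.',
                                             'x = 1')
# q_toks comes out roughly as ['Battlecry', ':', 'draw', 'a', 'card', '.']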
Example #5
def canonicalize_hs_example(query, code):
    # variant of Example #4 that handles code with multiple top-level
    # statements by round-tripping every statement in the module body
    import re, ast, astor, nltk
    from lang.py.parse import parse_raw, parse_tree_to_python_ast

    query = re.sub(r'<.*?>', '', query)
    query_tokens = nltk.word_tokenize(query)

    code = code.replace('§', '\n').strip()

    # sanity check
    parse_tree = parse_raw(code)
    gold_ast_tree = ast.parse(code).body
    gold_source = ''
    for pt in gold_ast_tree:
        gold_source += astor.to_source(pt)

    pred_source = ''
    for pc in parse_tree.children:
        ast_tree = parse_tree_to_python_ast(pc)
        pred_source += astor.to_source(ast_tree)

    assert gold_source == pred_source, 'sanity check fails: gold=[%s], actual=[%s]' % (
        gold_source, pred_source)

    return query_tokens, code, parse_tree
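
The same kind of toy invocation, now with two top-level statements ('§' is the
dataset's line separator; the input is again made up):

q_toks, code, tree = canonicalize_hs_example('<b>Taunt</b>', 'x = 1§y = x + 1')
# code == 'x = 1\ny = x + 1'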
Example #6
def parse_django_dataset():
    from itertools import chain
    from lang.py.parse import parse_raw
    from lang.util import escape
    # also relies on module-level names from the surrounding codebase:
    # preprocess_dataset, get_grammar, gen_vocab, DataSet, DataEntry, Action,
    # the action-type constants (APPLY_RULE, GEN_TOKEN, COPY_TOKEN,
    # GEN_COPY_TOKEN), serialize_to_file and logging
    MAX_QUERY_LENGTH = 70
    UNARY_CUTOFF_FREQ = 30

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

    # note: this calls a two-argument variant of preprocess_dataset, not the
    # three-file version shown in Example #1
    data = preprocess_dataset(annot_file, code_file)

    for e in data:
        e['parse_tree'] = parse_raw(e['code'])

    parse_trees = [e['parse_tree'] for e in data]

    # apply unary closures
    # unary_closures = get_top_unary_closures(parse_trees, k=0, freq=UNARY_CUTOFF_FREQ)
    # for i, parse_tree in enumerate(parse_trees):
    #     apply_unary_closures(parse_tree, unary_closures)

    # build the grammar
    grammar = get_grammar(parse_trees)

    # write grammar
    with open('django.grammar.unary_closure.txt', 'w') as f:
        for rule in grammar:
            f.write(repr(rule) + '\n')

    # # build grammar ...
    # from lang.py.py_dataset import extract_grammar
    # grammar, all_parse_trees = extract_grammar(code_file)

    annot_tokens = list(chain(*[e['query_tokens'] for e in data]))
    annot_vocab = gen_vocab(
        annot_tokens, vocab_size=5000,
        freq_cutoff=3)  # gen_vocab(annot_tokens, vocab_size=5980)

    print "got annotation vocabulary"

    terminal_token_seq = []
    empty_actions_count = 0

    # helper function begins
    def get_terminal_tokens(_terminal_str):
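        # split on single spaces but keep each space as its own token, so
        # e.g. 'foo  bar' -> ['foo', ' ', ' ', 'bar'] and ''.join(tokens)
        # reconstructs the original string exactly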
        # _terminal_tokens = filter(None, re.split('([, .?!])', _terminal_str)) # _terminal_str.split('-SP-')
        # _terminal_tokens = filter(None, re.split('( )', _terminal_str))  # _terminal_str.split('-SP-')
        tmp_terminal_tokens = _terminal_str.split(' ')
        _terminal_tokens = []
        for token in tmp_terminal_tokens:
            if token:
                _terminal_tokens.append(token)
            _terminal_tokens.append(' ')

        return _terminal_tokens[:-1]
        # return _terminal_tokens

    # helper function ends

    # first pass
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        for node in parse_tree.get_leaves():
            if grammar.is_value_node(node):
                terminal_val = node.value
                terminal_str = str(terminal_val)

                terminal_tokens = get_terminal_tokens(terminal_str)

                for terminal_token in terminal_tokens:
                    assert len(terminal_token) > 0
                    terminal_token_seq.append(terminal_token)

    terminal_vocab = gen_vocab(terminal_token_seq,
                               vocab_size=5000,
                               freq_cutoff=3)
    assert '_STR:0_' in terminal_vocab

    print "in terminal vocab"

    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'train_data')
    dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'dev_data')
    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'test_data')

    all_examples = []

    can_fully_gen_num = 0

    # second pass
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        str_map = entry['str_map']
        parse_tree = entry['parse_tree']

        rule_list, rule_parents = parse_tree.get_productions(
            include_value_node=True)

        actions = []
        can_fully_gen = True
        rule_pos_map = dict()

        for rule_count, rule in enumerate(rule_list):
            if not grammar.is_value_node(rule.parent):
                assert rule.value is None
                parent_rule = rule_parents[(rule_count, rule)][0]
                if parent_rule:
                    parent_t = rule_pos_map[parent_rule]
                else:
                    parent_t = 0

                rule_pos_map[rule] = len(actions)

                d = {
                    'rule': rule,
                    'parent_t': parent_t,
                    'parent_rule': parent_rule
                }
                action = Action(APPLY_RULE, d)

                actions.append(action)
            else:
                assert rule.is_leaf

                parent_rule = rule_parents[(rule_count, rule)][0]
                parent_t = rule_pos_map[parent_rule]

                terminal_val = rule.value
                terminal_str = str(terminal_val)
                terminal_tokens = get_terminal_tokens(terminal_str)

                # assert len(terminal_tokens) > 0

                for terminal_token in terminal_tokens:
                    term_tok_id = terminal_vocab[terminal_token]
                    tok_src_idx = -1
                    try:
                        tok_src_idx = query_tokens.index(terminal_token)
                    except ValueError:
                        pass

                    d = {
                        'literal': terminal_token,
                        'rule': rule,
                        'parent_rule': parent_rule,
                        'parent_t': parent_t
                    }

                    # cannot copy, only generation
                    # could be unk!
                    if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:
                        action = Action(GEN_TOKEN, d)
                        if terminal_token not in terminal_vocab:
                            if terminal_token not in query_tokens:
                                # print terminal_token
                                can_fully_gen = False
                    else:  # copy
                        if term_tok_id != terminal_vocab.unk:
                            d['source_idx'] = tok_src_idx
                            action = Action(GEN_COPY_TOKEN, d)
                        else:
                            d['source_idx'] = tok_src_idx
                            action = Action(COPY_TOKEN, d)

                    actions.append(action)

                d = {
                    'literal': '<eos>',
                    'rule': rule,
                    'parent_rule': parent_rule,
                    'parent_t': parent_t
                }
                actions.append(Action(GEN_TOKEN, d))

        if len(actions) == 0:
            empty_actions_count += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions, {
            'raw_code': entry['raw_code'],
            'str_map': entry['str_map']
        })

        if can_fully_gen:
            can_fully_gen_num += 1

        # train, valid, test
        if 0 <= idx < 16000:
            train_data.add(example)
        elif 16000 <= idx < 17000:
            dev_data.add(example)
        else:
            test_data.add(example)

        all_examples.append(example)

    # print statistics
    max_query_len = max(len(e.query) for e in all_examples)
    max_actions_len = max(len(e.actions) for e in all_examples)

    serialize_to_file([len(e.query) for e in all_examples], 'query.len')
    serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')

    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_gen_num, len(all_examples),
                 can_fully_gen_num / len(all_examples))
    logging.info('empty_actions_count: %d', empty_actions_count)
    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    train_data.init_data_matrices()
    dev_data.init_data_matrices()
    test_data.init_data_matrices()

    serialize_to_file(
        (train_data, dev_data, test_data),
        'data/django.cleaned.dataset.freq3.par_info.refact.space_only.order_by_ulink_len.bin')
    # alternative name when unary closures are applied:
    # 'data/django.cleaned.dataset.freq5.par_info.refact.space_only.unary_closure.freq{UNARY_CUTOFF_FREQ}.order_by_ulink_len.bin'.format(UNARY_CUTOFF_FREQ=UNARY_CUTOFF_FREQ))

    return train_data, dev_data, test_data
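
A loading sketch for the serialized dataset (assuming a matching
deserialize_from_file counterpart to serialize_to_file exists in this
codebase):

# deserialize_from_file is assumed, not shown in these examples
train_data, dev_data, test_data = deserialize_from_file(
    'data/django.cleaned.dataset.freq3.par_info.refact.space_only.order_by_ulink_len.bin')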
Example #7
def parse_django_dataset_for_seq2tree():
    from collections import defaultdict
    from lang.py.parse import parse_raw
    # also relies on module-level helpers from the surrounding codebase:
    # preprocess_dataset, get_grammar, break_value_nodes,
    # ast_tree_to_seq2tree_repr, seq2tree_repr_to_ast_tree and
    # merge_broken_value_nodes
    MAX_QUERY_LENGTH = 70
    MAX_DECODING_TIME_STEP = 300
    UNARY_CUTOFF_FREQ = 30

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

    data = preprocess_dataset(annot_file, code_file)

    for e in data:
        e['parse_tree'] = parse_raw(e['code'])

    parse_trees = [e['parse_tree'] for e in data]

    # apply unary closures
    # unary_closures = get_top_unary_closures(parse_trees, k=0, freq=UNARY_CUTOFF_FREQ)
    # for i, parse_tree in enumerate(parse_trees):
    #     apply_unary_closures(parse_tree, unary_closures)

    # build the grammar
    grammar = get_grammar(parse_trees)

    # # build grammar ...
    # from lang.py.py_dataset import extract_grammar
    # grammar, all_parse_trees = extract_grammar(code_file)

    f_train = open(
        '/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/train.txt',
        'w')
    f_dev = open(
        '/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/dev.txt',
        'w')
    f_test = open(
        '/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/test.txt',
        'w')

    f_train_rawid = open(
        '/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/train.id.txt',
        'w')
    f_dev_rawid = open(
        '/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/dev.id.txt',
        'w')
    f_test_rawid = open(
        '/Users/yinpengcheng/Research/lang2logic/seq2tree/django/data/test.id.txt',
        'w')

    decode_time_steps = defaultdict(int)

    # first pass
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        original_parse_tree = parse_tree.copy()
        break_value_nodes(parse_tree)
        tree_repr = ast_tree_to_seq2tree_repr(parse_tree)

        num_decode_time_step = len(tree_repr.split(' '))
        decode_time_steps[num_decode_time_step] += 1

        new_tree = seq2tree_repr_to_ast_tree(tree_repr)
        merge_broken_value_nodes(new_tree)

        query_tokens = [t for t in query_tokens if t != ''][:MAX_QUERY_LENGTH]
        query = ' '.join(query_tokens)
        line = query + '\t' + tree_repr

        if num_decode_time_step > MAX_DECODING_TIME_STEP:
            continue

        # train, valid, test
        if 0 <= idx < 16000:
            f_train.write(line + '\n')
            f_train_rawid.write(str(idx) + '\n')
        elif 16000 <= idx < 17000:
            f_dev.write(line + '\n')
            f_dev_rawid.write(str(idx) + '\n')
        else:
            f_test.write(line + '\n')
            f_test_rawid.write(str(idx) + '\n')

        if original_parse_tree != new_tree:
            # flag examples whose tree representation fails to round-trip
            print('*' * 50)
            print(idx)
            print(code)

    f_train.close()
    f_dev.close()
    f_test.close()

    f_train_rawid.close()
    f_dev_rawid.close()
    f_test_rawid.close()
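
Each emitted line is the space-joined query, a tab, then the serialized tree.
A quick integrity check over one of the written files might look like this
(hypothetical snippet; point it at the paths opened above):

with open('train.txt') as f:  # hypothetical path
    for line in f:
        query, tree_repr = line.rstrip('\n').split('\t')
        assert query and tree_repr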