Example 1
def prepare_dataset():
    vocab_freq_cutoff = 1
    grammar = ASDLGrammar.from_text(open('asdl/lang/lambda_dcs/lambda_asdl.txt').read())
    transition_system = LambdaCalculusTransitionSystem(grammar)

    train_set = load_dataset(transition_system, 'data/atis/train.txt')
    dev_set = load_dataset(transition_system, 'data/atis/dev.txt')
    test_set = load_dataset(transition_system, 'data/atis/test.txt')

    # generate vocabulary
    src_vocab = VocabEntry.from_corpus([e.src_sent for e in train_set], size=5000, freq_cutoff=vocab_freq_cutoff)

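    # collect the primitive tokens emitted by GenTokenAction in each training example's action sequence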
    primitive_tokens = [map(lambda a: a.action.token,
                            filter(lambda a: isinstance(a.action, GenTokenAction), e.tgt_actions))
                        for e in train_set]

    primitive_vocab = VocabEntry.from_corpus(primitive_tokens, size=5000, freq_cutoff=0)

    # generate vocabulary for the code tokens!
    code_tokens = [transition_system.tokenize_code(e.tgt_code, mode='decoder') for e in train_set]
    code_vocab = VocabEntry.from_corpus(code_tokens, size=5000, freq_cutoff=0)

    vocab = Vocab(source=src_vocab, primitive=primitive_vocab, code=code_vocab)
    print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

    action_len = [len(e.tgt_actions) for e in chain(train_set, dev_set, test_set)]
    print('Max action len: %d' % max(action_len), file=sys.stderr)
    print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
    print('Actions larger than 100: %d' %
          len(list(filter(lambda x: x > 100, action_len))),
          file=sys.stderr)

    pickle.dump(train_set, open('data/atis/train.bin', 'wb'))
    pickle.dump(dev_set, open('data/atis/dev.bin', 'wb'))
    pickle.dump(test_set, open('data/atis/test.bin', 'wb'))
    pickle.dump(vocab, open('data/atis/vocab.bin', 'wb'))
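
A minimal usage sketch (not part of the original example), assuming the module-level imports of the source script and the raw ATIS files under data/atis/ are in place:

if __name__ == '__main__':
    prepare_dataset()  # builds the vocabularies and writes the *.bin pickles under data/atis/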
Example 2
    def process_regex_dataset():
        grammar = ASDLGrammar.from_text(
            open('asdl/lang/streg/streg_asdl.txt').read())
        transition_system = StRegTransitionSystem(grammar)
        train_set = StReg.load_regex_dataset(transition_system, "train")
        val_set = StReg.load_regex_dataset(transition_system, "val")
        testi_set = StReg.load_regex_dataset(transition_system, "testi")
        teste_set = StReg.load_regex_dataset(transition_system, "teste")

        # generate vocabulary
        vocab_freq_cutoff = 2
        src_vocab = VocabEntry.from_corpus([e.src_sent for e in train_set],
                                           size=5000,
                                           freq_cutoff=vocab_freq_cutoff)

        primitive_tokens = [
            map(
                lambda a: a.action.token,
                filter(lambda a: isinstance(a.action, GenTokenAction),
                       e.tgt_actions)) for e in train_set
        ]
        primitive_vocab = VocabEntry.from_corpus(primitive_tokens,
                                                 size=5000,
                                                 freq_cutoff=0)

        # generate vocabulary for the code tokens!
        code_tokens = [
            transition_system.tokenize_code(e.tgt_code, mode='decoder')
            for e in train_set
        ]
        code_vocab = VocabEntry.from_corpus(code_tokens,
                                            size=5000,
                                            freq_cutoff=0)

        vocab = Vocab(source=src_vocab,
                      primitive=primitive_vocab,
                      code=code_vocab)
        print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

        action_len = [
            len(e.tgt_actions) for e in chain(train_set, testi_set, teste_set)
        ]
        print('Max action len: %d' % max(action_len), file=sys.stderr)
        print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
        print('Actions larger than 100: %d' %
              len(list(filter(lambda x: x > 100, action_len))),
              file=sys.stderr)

        pickle.dump(train_set, open('data/streg/train.bin', 'wb'))
        pickle.dump(val_set, open('data/streg/val.bin', 'wb'))
        pickle.dump(testi_set, open('data/streg/testi.bin', 'wb'))
        pickle.dump(teste_set, open('data/streg/teste.bin', 'wb'))
        pickle.dump(
            vocab,
            open('data/streg/vocab.freq%d.bin' % (vocab_freq_cutoff), 'wb'))
Example 3
def prepare_dataset(data_path):
    grammar = ASDLGrammar.from_text(
        open('../../asdl/lang/sql/sql_asdl.txt').read())
    transition_system = SqlTransitionSystem(grammar)

    datasets = []
    for file in ['dev', 'test', 'train']:
        print('processing %s' % file, file=sys.stderr)
        dataset_path = os.path.join(data_path, file + '.jsonl')
        table_path = os.path.join(data_path, file + '.tables.jsonl')
        dataset = load_dataset(transition_system, dataset_path, table_path)
        pickle.dump(dataset, open('../../data/wikisql1/%s.bin' % file, 'wb'))

        datasets.append(dataset)

    train_set = datasets[2]
    dev_set = datasets[0]
    test_set = datasets[1]
    # generate vocabulary
    src_vocab = VocabEntry.from_corpus([e.src_sent for e in train_set],
                                       size=100000,
                                       freq_cutoff=2)
    primitive_vocab = VocabEntry()
    primitive_vocab.add('</primitive>')

    vocab = Vocab(source=src_vocab, primitive=primitive_vocab)
    print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

    pickle.dump(vocab, open('../../data/wikisql1/vocab.bin', 'wb'))

    action_len = [
        len(e.tgt_actions) for e in chain(train_set, dev_set, test_set)
    ]
    print('Max action len: %d' % max(action_len), file=sys.stderr)
    print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
    print('Actions larger than 100: %d' %
          len(list(filter(lambda x: x > 100, action_len))),
          file=sys.stderr)
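
A minimal usage sketch (not part of the original example); the directory is a hypothetical placeholder and only needs to contain the dev/test/train .jsonl and .tables.jsonl files read above:

if __name__ == '__main__':
    prepare_dataset('path/to/wikisql')  # hypothetical location of the raw WikiSQL dump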
Example 4
    def parse_django_dataset(annot_file,
                             code_file,
                             asdl_file_path,
                             max_query_len=70,
                             vocab_freq_cutoff=10):
        asdl_text = open(asdl_file_path).read()
        grammar = ASDLGrammar.from_text(asdl_text)
        transition_system = PythonTransitionSystem(grammar)

        loaded_examples = []

        from components.vocab import Vocab, VocabEntry
        from components.dataset import Example

        for idx, (src_query,
                  tgt_code) in enumerate(zip(open(annot_file),
                                             open(code_file))):
            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            src_query_tokens, tgt_canonical_code, str_map = Django.canonicalize_example(
                src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code).body[0]
            gold_source = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast, grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # print('+' * 60)
            # print('Example: %d' % idx)
            # print('Source: %s' % ' '.join(src_query_tokens))
            # if str_map:
            #     print('Original String Map:')
            #     for str_literal, str_repr in str_map.items():
            #         print('\t%s: %s' % (str_literal, str_repr))
            # print('Code:\n%s' % gold_source)
            # print('Actions:')

            # sanity check
            try:
                hyp = Hypothesis()
                for t, action in enumerate(tgt_actions):
                    # assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                    # if isinstance(action, ApplyRuleAction):
                    #     assert action.production in transition_system.get_valid_continuating_productions(hyp)

                    p_t = -1
                    f_t = None
                    if hyp.frontier_node:
                        p_t = hyp.frontier_node.created_time
                        f_t = hyp.frontier_field.field.__repr__(plain=True)

                    # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
                    hyp = hyp.clone_and_apply_action(action)

                assert hyp.frontier_node is None and hyp.frontier_field is None

                src_from_hyp = astor.to_source(
                    asdl_ast_to_python_ast(hyp.tree, grammar)).strip()
                assert src_from_hyp == gold_source

                # print('+' * 60)
            except Exception:
                # skip examples whose action sequence fails the round-trip sanity check
                continue

            loaded_examples.append({
                'src_query_tokens': src_query_tokens,
                'tgt_canonical_code': gold_source,
                'tgt_ast': tgt_ast,
                'tgt_actions': tgt_actions,
                'raw_code': tgt_code,
                'str_map': str_map
            })

            # print('first pass, processed %d' % idx, file=sys.stderr)

        train_examples = []
        dev_examples = []
        test_examples = []

        action_len = []

        for idx, e in enumerate(loaded_examples):
            src_query_tokens = e['src_query_tokens'][:max_query_len]
            tgt_actions = e['tgt_actions']
            tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

            example = Example(idx=idx,
                              src_sent=src_query_tokens,
                              tgt_actions=tgt_action_infos,
                              tgt_code=e['tgt_canonical_code'],
                              tgt_ast=e['tgt_ast'],
                              meta={
                                  'raw_code': e['raw_code'],
                                  'str_map': e['str_map']
                              })

            # print('second pass, processed %d' % idx, file=sys.stderr)

            action_len.append(len(tgt_action_infos))

            # train, valid, test split
            if 0 <= idx < 16000:
                train_examples.append(example)
            elif 16000 <= idx < 17000:
                dev_examples.append(example)
            else:
                test_examples.append(example)

        print('Max action len: %d' % max(action_len), file=sys.stderr)
        print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
        print('Actions larger than 100: %d' %
              len(list(filter(lambda x: x > 100, action_len))),
              file=sys.stderr)

        src_vocab = VocabEntry.from_corpus(
            [e.src_sent for e in train_examples],
            size=5000,
            freq_cutoff=vocab_freq_cutoff)

        primitive_tokens = [
            map(
                lambda a: a.action.token,
                filter(lambda a: isinstance(a.action, GenTokenAction),
                       e.tgt_actions)) for e in train_examples
        ]

        primitive_vocab = VocabEntry.from_corpus(primitive_tokens,
                                                 size=5000,
                                                 freq_cutoff=vocab_freq_cutoff)
        assert '_STR:0_' in primitive_vocab

        # generate vocabulary for the code tokens!
        code_tokens = [
            tokenize_code(e.tgt_code, mode='decoder') for e in train_examples
        ]
        code_vocab = VocabEntry.from_corpus(code_tokens,
                                            size=5000,
                                            freq_cutoff=vocab_freq_cutoff)

        vocab = Vocab(source=src_vocab,
                      primitive=primitive_vocab,
                      code=code_vocab)
        print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

        return (train_examples, dev_examples, test_examples), vocab
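
A minimal usage sketch (not part of the original example), assuming the function is reachable at module level; the annotation, code, and grammar paths are hypothetical placeholders:

(train, dev, test), vocab = parse_django_dataset(
    annot_file='path/to/all.anno',         # hypothetical: one natural-language query per line
    code_file='path/to/all.code',          # hypothetical: aligned target code snippets, one per line
    asdl_file_path='path/to/py_asdl.txt')  # hypothetical: Python ASDL grammar definition
pickle.dump(vocab, open('path/to/vocab.bin', 'wb'))  # persist outputs the same way the other examples do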
Example 5
def preprocess_conala_dataset(train_file,
                              test_file,
                              grammar_file,
                              src_freq=3,
                              code_freq=3,
                              mined_data_file=None,
                              api_data_file=None,
                              vocab_size=20000,
                              num_mined=0,
                              out_dir='data/conala'):
    np.random.seed(1234)

    asdl_text = open(grammar_file).read()
    grammar = ASDLGrammar.from_text(asdl_text)
    transition_system = Python3TransitionSystem(grammar)

    print('process gold training data...')
    train_examples = preprocess_dataset(train_file,
                                        name='train',
                                        transition_system=transition_system)

    # held out 200 examples for development
    full_train_examples = train_examples[:]
    np.random.shuffle(train_examples)
    dev_examples = train_examples[:200]
    train_examples = train_examples[200:]

    mined_examples = []
    api_examples = []
    if mined_data_file and num_mined > 0:
        print("use mined data: ", num_mined)
        print("from file: ", mined_data_file)
        mined_examples = preprocess_dataset(
            mined_data_file,
            name='mined',
            transition_system=transition_system,
            firstk=num_mined)
        pickle.dump(
            mined_examples,
            open(os.path.join(out_dir, 'mined_{}.bin'.format(num_mined)),
                 'wb'))

    if api_data_file:
        print("use api docs from file: ", api_data_file)
        name = os.path.splitext(os.path.basename(api_data_file))[0]
        api_examples = preprocess_dataset(api_data_file,
                                          name='api',
                                          transition_system=transition_system)
        pickle.dump(api_examples,
                    open(os.path.join(out_dir, name + '.bin'), 'wb'))

    if mined_examples and api_examples:
        pickle.dump(
            mined_examples + api_examples,
            open(
                os.path.join(out_dir, 'pre_{}_{}.bin'.format(num_mined, name)),
                'wb'))

    # combine to make vocab
    train_examples += mined_examples
    train_examples += api_examples
    print(f'{len(train_examples)} training instances', file=sys.stderr)
    print(f'{len(dev_examples)} dev instances', file=sys.stderr)

    print('process testing data...')
    test_examples = preprocess_dataset(test_file,
                                       name='test',
                                       transition_system=transition_system)
    print(f'{len(test_examples)} testing instances', file=sys.stderr)

    src_vocab = VocabEntry.from_corpus([e.src_sent for e in train_examples],
                                       size=vocab_size,
                                       freq_cutoff=src_freq)
    primitive_tokens = [
        map(
            lambda a: a.action.token,
            filter(lambda a: isinstance(a.action, GenTokenAction),
                   e.tgt_actions)) for e in train_examples
    ]
    primitive_vocab = VocabEntry.from_corpus(primitive_tokens,
                                             size=vocab_size,
                                             freq_cutoff=code_freq)

    # generate vocabulary for the code tokens!
    code_tokens = [
        transition_system.tokenize_code(e.tgt_code, mode='decoder')
        for e in train_examples
    ]

    code_vocab = VocabEntry.from_corpus(code_tokens,
                                        size=vocab_size,
                                        freq_cutoff=code_freq)

    vocab = Vocab(source=src_vocab, primitive=primitive_vocab, code=code_vocab)
    print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

    action_lens = [len(e.tgt_actions) for e in train_examples]
    print('Max action len: %d' % max(action_lens), file=sys.stderr)
    print('Avg action len: %d' % np.average(action_lens), file=sys.stderr)
    print('Actions larger than 100: %d' %
          len(list(filter(lambda x: x > 100, action_lens))),
          file=sys.stderr)

    pickle.dump(
        train_examples,
        open(os.path.join(out_dir, 'train.all_{}.bin'.format(num_mined)),
             'wb'))
    pickle.dump(full_train_examples,
                open(os.path.join(out_dir, 'train.gold.full.bin'), 'wb'))
    pickle.dump(dev_examples, open(os.path.join(out_dir, 'dev.bin'), 'wb'))
    pickle.dump(test_examples, open(os.path.join(out_dir, 'test.bin'), 'wb'))
    if mined_examples and api_examples:
        vocab_name = 'vocab.src_freq%d.code_freq%d.mined_%s.%s.bin' % (
            src_freq, code_freq, num_mined, name)
    elif mined_examples:
        vocab_name = 'vocab.src_freq%d.code_freq%d.mined_%s.bin' % (
            src_freq, code_freq, num_mined)
    elif api_examples:
        vocab_name = 'vocab.src_freq%d.code_freq%d.%s.bin' % (src_freq,
                                                              code_freq, name)
    else:
        vocab_name = 'vocab.src_freq%d.code_freq%d.bin' % (src_freq, code_freq)
    pickle.dump(vocab, open(os.path.join(out_dir, vocab_name), 'wb'))
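
A minimal usage sketch (not part of the original example); every path below is a hypothetical placeholder for the CoNaLa release and the Python 3 ASDL grammar:

preprocess_conala_dataset(
    train_file='path/to/conala-train.json',        # hypothetical raw training split
    test_file='path/to/conala-test.json',          # hypothetical raw test split
    grammar_file='path/to/py3_asdl.txt',           # hypothetical grammar file
    mined_data_file='path/to/conala-mined.jsonl',  # optional mined pairs (hypothetical path)
    num_mined=50000,                               # how many mined examples to take
    out_dir='data/conala')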
Example 6
def preprocess_conala_dataset(train_file,
                              test_file,
                              grammar_file,
                              src_freq=3,
                              code_freq=3):
    np.random.seed(1234)

    # load grammar and transition system
    asdl_text = open(grammar_file).read()
    grammar = ASDLGrammar.from_text(asdl_text)
    transition_system = Python3TransitionSystem(grammar)

    print('process training data...')
    train_examples = preprocess_dataset(train_file,
                                        name='train',
                                        transition_system=transition_system)

    # held out 200 examples for development
    full_train_examples = train_examples[:]
    np.random.shuffle(train_examples)
    dev_examples = train_examples[:200]
    train_examples = train_examples[200:]

    # full_train_examples = train_examples[:]
    # np.random.shuffle(train_examples)
    # dev_examples = []
    # dev_questions = set()
    # dev_examples_id = []
    # for i, example in enumerate(full_train_examples):
    #     qid = example.meta['example_dict']['question_id']
    #     if qid not in dev_questions and len(dev_examples) < 200:
    #         dev_questions.add(qid)
    #         dev_examples.append(example)
    #         dev_examples_id.append(i)

    # train_examples = [e for i, e in enumerate(full_train_examples) if i not in dev_examples_id]
    print(f'{len(train_examples)} training instances', file=sys.stderr)
    print(f'{len(dev_examples)} dev instances', file=sys.stderr)

    print('process testing data...')
    test_examples = preprocess_dataset(test_file,
                                       name='test',
                                       transition_system=transition_system)
    print(f'{len(test_examples)} testing instances', file=sys.stderr)

    src_vocab = VocabEntry.from_corpus([e.src_sent for e in train_examples],
                                       size=5000,
                                       freq_cutoff=src_freq)
    primitive_tokens = [
        map(
            lambda a: a.action.token,
            filter(lambda a: isinstance(a.action, GenTokenAction),
                   e.tgt_actions)) for e in train_examples
    ]
    primitive_vocab = VocabEntry.from_corpus(primitive_tokens,
                                             size=5000,
                                             freq_cutoff=code_freq)

    # generate vocabulary for the code tokens!
    code_tokens = [
        transition_system.tokenize_code(e.tgt_code, mode='decoder')
        for e in train_examples
    ]
    code_vocab = VocabEntry.from_corpus(code_tokens,
                                        size=5000,
                                        freq_cutoff=code_freq)

    vocab = Vocab(source=src_vocab, primitive=primitive_vocab, code=code_vocab)
    print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

    action_lens = [len(e.tgt_actions) for e in train_examples]
    print('Max action len: %d' % max(action_lens), file=sys.stderr)
    print('Avg action len: %d' % np.average(action_lens), file=sys.stderr)
    print('Actions larger than 100: %d' %
          len(list(filter(lambda x: x > 100, action_lens))),
          file=sys.stderr)

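    # NOTE: the pickle.dump calls below never run because they are guarded by `if False`;
    # this variant only reports statistics and builds the vocabulary in memory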
    if False:
        pickle.dump(train_examples,
                    open('data/conala/train.var_str_sep.bin', 'wb'))
        pickle.dump(full_train_examples,
                    open('data/conala/train.var_str_sep.full.bin', 'wb'))
        pickle.dump(dev_examples, open('data/conala/dev.var_str_sep.bin',
                                       'wb'))
        pickle.dump(test_examples,
                    open('data/conala/test.var_str_sep.bin', 'wb'))
        pickle.dump(
            vocab,
            open(
                'data/conala/vocab.var_str_sep.new_dev.src_freq%d.code_freq%d.bin'
                % (src_freq, code_freq), 'wb'))
Example 7
def prepare_geo_dataset():
    vocab_freq_cutoff = 2  # for geo query
    grammar = ASDLGrammar.from_text(
        open('asdl/lang/lambda_dcs/lambda_asdl.txt').read())
    transition_system = LambdaCalculusTransitionSystem(grammar)

    train_set = load_dataset(transition_system,
                             'data/geo/train.txt',
                             reorder_predicates=False)
    test_set = load_dataset(transition_system,
                            'data/geo/test.txt',
                            reorder_predicates=False)

    # generate vocabulary
    src_vocab = VocabEntry.from_corpus([e.src_sent for e in train_set],
                                       size=5000,
                                       freq_cutoff=vocab_freq_cutoff)

    primitive_tokens = [
        map(
            lambda a: a.action.token,
            filter(lambda a: isinstance(a.action, GenTokenAction),
                   e.tgt_actions)) for e in train_set
    ]

    primitive_vocab = VocabEntry.from_corpus(primitive_tokens,
                                             size=5000,
                                             freq_cutoff=0)

    # generate vocabulary for the code tokens!
    code_tokens = [
        transition_system.tokenize_code(e.tgt_code, mode='decoder')
        for e in train_set
    ]
    code_vocab = VocabEntry.from_corpus(code_tokens, size=5000, freq_cutoff=0)

    vocab = Vocab(source=src_vocab, primitive=primitive_vocab, code=code_vocab)
    print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

    action_len = [len(e.tgt_actions) for e in chain(train_set, test_set)]
    print('Max action len: %d' % max(action_len), file=sys.stderr)
    print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
    print('Actions larger than 100: %d' %
          len(list(filter(lambda x: x > 100, action_len))),
          file=sys.stderr)

    pickle.dump(train_set, open('data/geo/train.bin', 'wb'))

    # randomly sample 20% data as the dev set
    np.random.seed(1234)
    dev_example_ids = np.random.choice(list(range(len(train_set))),
                                       replace=False,
                                       size=int(0.2 * len(train_set)))
    train_example_ids = [
        i for i in range(len(train_set)) if i not in dev_example_ids
    ]

    print('# splitted train examples %d, # splitted dev examples %d' %
          (len(train_example_ids), len(dev_example_ids)),
          file=sys.stderr)

    dev_set = [train_set[i] for i in dev_example_ids]
    train_set = [train_set[i] for i in train_example_ids]

    pickle.dump(train_set, open('data/geo/train.split.bin', 'wb'))
    pickle.dump(dev_set, open('data/geo/dev.bin', 'wb'))

    pickle.dump(test_set, open('data/geo/test.bin', 'wb'))
    pickle.dump(vocab, open('data/geo/vocab.freq2.bin', 'wb'))
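
A minimal usage sketch (not part of the original example), assuming the GeoQuery files under data/geo/ are in place:

if __name__ == '__main__':
    prepare_geo_dataset()  # writes the train/dev/test splits and vocab.freq2.bin under data/geo/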
Example 8
    def parse_natural_dataset(asdl_file_path,
                              max_query_len=70,
                              vocab_freq_cutoff=10):
        asdl_text = open(asdl_file_path).read()
        print('building grammar')
        grammar = ASDLGrammar.from_text(asdl_text)
        transition_system = Python3TransitionSystem(grammar)

        loaded_examples = []

        annotations = []
        codes = []
        path = os.path.join(os.path.dirname(__file__), "datagen")
        datagens = os.listdir(path)
        for folder in datagens:
            if "__" in folder or not os.path.isdir(os.path.join(path, folder)):
                continue
            with open(os.path.join(path, folder, "inputs.txt"), 'r') as file:
                annotations += file.read().split('\n')
            with open(os.path.join(path, folder, "outputs.txt"), 'r') as file:
                codes += file.read().split('\n')
        annotation_codes = list(zip(annotations, codes))
        np.random.seed(42)
        np.random.shuffle(annotation_codes)

        from components.vocab import Vocab, VocabEntry
        from components.dataset import Example

        print('processing examples')
        for idx, (src_query, tgt_code) in enumerate(annotation_codes):
            if (idx % 100 == 0):
                sys.stdout.write("\r%s / %s" % (idx, len(annotation_codes)))
                sys.stdout.flush()

            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            src_query_tokens, tgt_canonical_code, str_map = Natural.canonicalize_example(
                src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code)  #.body[0]
            gold_source = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast,
                                             transition_system.grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)
            # print('+' * 60)
            # print('Example: %d' % idx)
            # print('Source: %s' % ' '.join(src_query_tokens))
            # if str_map:
            #     print('Original String Map:')
            #     for str_literal, str_repr in str_map.items():
            #         print('\t%s: %s' % (str_literal, str_repr))
            # print('Code:\n%s' % gold_source)
            # print('Actions:')

            # sanity check
            hyp = Hypothesis()
            for t, action in enumerate(tgt_actions):
                assert action.__class__ in transition_system.get_valid_continuation_types(
                    hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(
                        hyp)
                # assert action.__class__ in transition_system.get_valid_continuation_types(
                # hyp)

                p_t = -1
                f_t = None
                if hyp.frontier_node:
                    p_t = hyp.frontier_node.created_time
                    f_t = hyp.frontier_field.field.__repr__(plain=True)

                # print('\t[%d] %s, frontier field: %s, parent: %d' %
                #     (t, action, f_t, p_t))
                hyp = hyp.clone_and_apply_action(action)

            # assert hyp.frontier_node is None and hyp.frontier_field is None

            src_from_hyp = astor.to_source(
                asdl_ast_to_python_ast(hyp.tree, grammar)).strip()
            if "b'" not in str(gold_source) and 'b"' not in str(gold_source):
                assert src_from_hyp == gold_source

            # print('+' * 60)

            loaded_examples.append({
                'src_query_tokens': src_query_tokens,
                'tgt_canonical_code': gold_source,
                'tgt_ast': tgt_ast,
                'tgt_actions': tgt_actions,
                'raw_code': tgt_code,
                'str_map': str_map
            })

            # print('first pass, processed %d' % idx, file=sys.stderr)

        train_examples = []
        dev_examples = []
        test_examples = []

        action_len = []

        print("\nsplitting train/dev/test")
        for idx, e in enumerate(loaded_examples):
            src_query_tokens = e['src_query_tokens'][:max_query_len]
            tgt_actions = e['tgt_actions']
            tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

            example = Example(idx=idx,
                              src_sent=src_query_tokens,
                              tgt_actions=tgt_action_infos,
                              tgt_code=e['tgt_canonical_code'],
                              tgt_ast=e['tgt_ast'],
                              meta={
                                  'raw_code': e['raw_code'],
                                  'str_map': e['str_map']
                              })

            # print('second pass, processed %d' % idx, file=sys.stderr)

            action_len.append(len(tgt_action_infos))

            # train, valid, test split
            total_examples = len(loaded_examples)
            split_size = np.ceil(total_examples * 0.05)
            (dev_split, test_split) = (total_examples - split_size * 2,
                                       total_examples - split_size)
            if 0 <= idx < dev_split:
                train_examples.append(example)
            elif dev_split <= idx < test_split:
                dev_examples.append(example)
            else:
                test_examples.append(example)

        print('Max action len: %d' % max(action_len), file=sys.stderr)
        print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
        print('Actions larger than 100: %d' %
              len(list(filter(lambda x: x > 100, action_len))),
              file=sys.stderr)

        src_vocab = VocabEntry.from_corpus(
            [e.src_sent for e in train_examples],
            size=5000,
            freq_cutoff=vocab_freq_cutoff)

        primitive_tokens = [
            map(
                lambda a: a.action.token,
                filter(lambda a: isinstance(a.action, GenTokenAction),
                       e.tgt_actions)) for e in train_examples
        ]

        primitive_vocab = VocabEntry.from_corpus(primitive_tokens,
                                                 size=5000,
                                                 freq_cutoff=vocab_freq_cutoff)
        # assert '_STR:0_' in primitive_vocab

        # generate vocabulary for the code tokens!
        code_tokens = [
            tokenize_code(e.tgt_code, mode='decoder') for e in train_examples
        ]
        code_vocab = VocabEntry.from_corpus(code_tokens,
                                            size=5000,
                                            freq_cutoff=vocab_freq_cutoff)

        vocab = Vocab(source=src_vocab,
                      primitive=primitive_vocab,
                      code=code_vocab)
        print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

        return (train_examples, dev_examples, test_examples), vocab
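
A minimal usage sketch (not part of the original example), assuming the function is reachable at module level; the grammar path is a hypothetical placeholder:

(train, dev, test), vocab = parse_natural_dataset('path/to/py3_asdl.txt')
pickle.dump(vocab, open('path/to/vocab.bin', 'wb'))  # persist outputs the same way the other examples do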