Example 1
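The snippets in these examples are excerpted from a tranX-style preprocessing pipeline and omit their imports. As a rough guide, a preamble along the following lines is needed; the standard-library and third-party imports are certain, while the project-local module paths (commented out) are assumptions based on the usual tranX layout and may need adjusting. Other helpers referenced below (Natural, preprocess_dataset, tokenize_code, get_action_infos) likewise come from the surrounding project.

import ast
import os
import pickle
import sys

import astor
import numpy as np

# Project-local imports -- module paths are assumptions based on the tranX layout:
# from asdl.asdl import ASDLGrammar
# from asdl.hypothesis import Hypothesis
# from asdl.transition_system import ApplyRuleAction, GenTokenAction
# from asdl.lang.py3.py3_transition_system import Python3TransitionSystem
# from asdl.lang.py3.py3_asdl_helper import python_ast_to_asdl_ast, asdl_ast_to_python_ast
# from components.vocab import Vocab, VocabEntry
# from components.dataset import Example
# from components.action_info import get_action_infos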
    # Round-trip sanity check over the "natural" NL/code pairs: canonicalize
    # each example, convert its Python AST into an ASDL AST, derive the action
    # sequence, then replay the actions and verify that the reconstructed
    # source matches the gold source.
    def run():
        asdl_text = open('asdl/lang/py3/py3_asdl.simplified.txt').read()
        grammar = ASDLGrammar.from_text(asdl_text)

        annot_file = 'data/natural/all.anno'
        code_file = 'data/natural/all.code'

        transition_system = Python3TransitionSystem(grammar)

        for idx, (src_query,
                  tgt_code) in enumerate(zip(open(annot_file),
                                             open(code_file))):
            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            query_tokens, tgt_canonical_code, str_map = Natural.canonicalize_example(
                src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code)  #.body[0]
            gold_source = astor.to_source(python_ast)
            tgt_ast = python_ast_to_asdl_ast(python_ast, grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # sanity check
            hyp = Hypothesis()
            hyp2 = Hypothesis()
            for action in tgt_actions:
                assert action.__class__ in transition_system.get_valid_continuation_types(
                    hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(
                        hyp)
                hyp = hyp.clone_and_apply_action(action)
                hyp2.apply_action(action)

            # the cloned and the in-place hypotheses must rebuild the same tree,
            # and that tree must round-trip back to the gold source (compared on
            # astor's rendering, not on the raw target code)
            src_from_hyp = astor.to_source(
                asdl_ast_to_python_ast(hyp.tree, grammar))
            assert src_from_hyp == gold_source
            assert hyp.tree == hyp2.tree and hyp.tree is not hyp2.tree

            print(idx)
Example 2
# Preprocess the CoNaLa dataset: hold out 200 gold examples for dev, optionally
# add mined and API-doc examples to the training set, build the
# source/primitive/code vocabularies, and pickle all splits under `out_dir`.
def preprocess_conala_dataset(train_file,
                              test_file,
                              grammar_file,
                              src_freq=3,
                              code_freq=3,
                              mined_data_file=None,
                              api_data_file=None,
                              vocab_size=20000,
                              num_mined=0,
                              out_dir='data/conala'):
    np.random.seed(1234)

    asdl_text = open(grammar_file).read()
    grammar = ASDLGrammar.from_text(asdl_text)
    transition_system = Python3TransitionSystem(grammar)

    print('process gold training data...')
    train_examples = preprocess_dataset(train_file,
                                        name='train',
                                        transition_system=transition_system)

    # hold out 200 examples for development
    full_train_examples = train_examples[:]
    np.random.shuffle(train_examples)
    dev_examples = train_examples[:200]
    train_examples = train_examples[200:]

    mined_examples = []
    api_examples = []
    if mined_data_file and num_mined > 0:
        print("use mined data: ", num_mined)
        print("from file: ", mined_data_file)
        mined_examples = preprocess_dataset(
            mined_data_file,
            name='mined',
            transition_system=transition_system,
            firstk=num_mined)
        pickle.dump(
            mined_examples,
            open(os.path.join(out_dir, 'mined_{}.bin'.format(num_mined)),
                 'wb'))

    if api_data_file:
        print("use api docs from file: ", api_data_file)
        name = os.path.splitext(os.path.basename(api_data_file))[0]
        api_examples = preprocess_dataset(api_data_file,
                                          name='api',
                                          transition_system=transition_system)
        pickle.dump(api_examples,
                    open(os.path.join(out_dir, name + '.bin'), 'wb'))

    if mined_examples and api_examples:
        pickle.dump(
            mined_examples + api_examples,
            open(
                os.path.join(out_dir, 'pre_{}_{}.bin'.format(num_mined, name)),
                'wb'))

    # combine gold, mined, and API examples to build the vocabularies
    train_examples += mined_examples
    train_examples += api_examples
    print(f'{len(train_examples)} training instances', file=sys.stderr)
    print(f'{len(dev_examples)} dev instances', file=sys.stderr)

    print('process testing data...')
    test_examples = preprocess_dataset(test_file,
                                       name='test',
                                       transition_system=transition_system)
    print(f'{len(test_examples)} testing instances', file=sys.stderr)

    src_vocab = VocabEntry.from_corpus([e.src_sent for e in train_examples],
                                       size=vocab_size,
                                       freq_cutoff=src_freq)
    primitive_tokens = [
        map(
            lambda a: a.action.token,
            filter(lambda a: isinstance(a.action, GenTokenAction),
                   e.tgt_actions)) for e in train_examples
    ]
    primitive_vocab = VocabEntry.from_corpus(primitive_tokens,
                                             size=vocab_size,
                                             freq_cutoff=code_freq)

    # generate vocabulary for the code tokens!
    code_tokens = [
        transition_system.tokenize_code(e.tgt_code, mode='decoder')
        for e in train_examples
    ]

    code_vocab = VocabEntry.from_corpus(code_tokens,
                                        size=vocab_size,
                                        freq_cutoff=code_freq)

    vocab = Vocab(source=src_vocab, primitive=primitive_vocab, code=code_vocab)
    print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

    action_lens = [len(e.tgt_actions) for e in train_examples]
    print('Max action len: %d' % max(action_lens), file=sys.stderr)
    print('Avg action len: %d' % np.average(action_lens), file=sys.stderr)
    print('Actions larger than 100: %d' %
          len(list(filter(lambda x: x > 100, action_lens))),
          file=sys.stderr)

    pickle.dump(
        train_examples,
        open(os.path.join(out_dir, 'train.all_{}.bin'.format(num_mined)),
             'wb'))
    pickle.dump(full_train_examples,
                open(os.path.join(out_dir, 'train.gold.full.bin'), 'wb'))
    pickle.dump(dev_examples, open(os.path.join(out_dir, 'dev.bin'), 'wb'))
    pickle.dump(test_examples, open(os.path.join(out_dir, 'test.bin'), 'wb'))
    if mined_examples and api_examples:
        vocab_name = 'vocab.src_freq%d.code_freq%d.mined_%s.%s.bin' % (
            src_freq, code_freq, num_mined, name)
    elif mined_examples:
        vocab_name = 'vocab.src_freq%d.code_freq%d.mined_%s.bin' % (
            src_freq, code_freq, num_mined)
    elif api_examples:
        vocab_name = 'vocab.src_freq%d.code_freq%d.%s.bin' % (src_freq,
                                                              code_freq, name)
    else:
        vocab_name = 'vocab.src_freq%d.code_freq%d.bin' % (src_freq, code_freq)
    pickle.dump(vocab, open(os.path.join(out_dir, vocab_name), 'wb'))
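For reference, a minimal invocation sketch; the CoNaLa data-file names and the `num_mined` value are assumptions and should point at the locally downloaded files:

if __name__ == '__main__':
    # hypothetical paths and mined-example count -- adjust to the local checkout
    preprocess_conala_dataset(train_file='data/conala/conala-train.json',
                              test_file='data/conala/conala-test.json',
                              grammar_file='asdl/lang/py3/py3_asdl.simplified.txt',
                              mined_data_file='data/conala/conala-mined.jsonl',
                              num_mined=50000,
                              out_dir='data/conala')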
Example 3
# CoNaLa preprocessing variant using gold data only: hold out 200 examples for
# dev, build the vocabularies with a fixed size of 5000, and leave the final
# serialization step disabled (see the `if False` block below).
def preprocess_conala_dataset(train_file,
                              test_file,
                              grammar_file,
                              src_freq=3,
                              code_freq=3):
    np.random.seed(1234)

    # load grammar and transition system
    asdl_text = open(grammar_file).read()
    grammar = ASDLGrammar.from_text(asdl_text)
    transition_system = Python3TransitionSystem(grammar)

    print('process training data...')
    train_examples = preprocess_dataset(train_file,
                                        name='train',
                                        transition_system=transition_system)

    # hold out 200 examples for development
    full_train_examples = train_examples[:]
    np.random.shuffle(train_examples)
    dev_examples = train_examples[:200]
    train_examples = train_examples[200:]

    # full_train_examples = train_examples[:]
    # np.random.shuffle(train_examples)
    # dev_examples = []
    # dev_questions = set()
    # dev_examples_id = []
    # for i, example in enumerate(full_train_examples):
    #     qid = example.meta['example_dict']['question_id']
    #     if qid not in dev_questions and len(dev_examples) < 200:
    #         dev_questions.add(qid)
    #         dev_examples.append(example)
    #         dev_examples_id.append(i)

    # train_examples = [e for i, e in enumerate(full_train_examples) if i not in dev_examples_id]
    print(f'{len(train_examples)} training instances', file=sys.stderr)
    print(f'{len(dev_examples)} dev instances', file=sys.stderr)

    print('process testing data...')
    test_examples = preprocess_dataset(test_file,
                                       name='test',
                                       transition_system=transition_system)
    print(f'{len(test_examples)} testing instances', file=sys.stderr)

    src_vocab = VocabEntry.from_corpus([e.src_sent for e in train_examples],
                                       size=5000,
                                       freq_cutoff=src_freq)
    primitive_tokens = [
        map(
            lambda a: a.action.token,
            filter(lambda a: isinstance(a.action, GenTokenAction),
                   e.tgt_actions)) for e in train_examples
    ]
    primitive_vocab = VocabEntry.from_corpus(primitive_tokens,
                                             size=5000,
                                             freq_cutoff=code_freq)

    # generate vocabulary for the code tokens!
    code_tokens = [
        transition_system.tokenize_code(e.tgt_code, mode='decoder')
        for e in train_examples
    ]
    code_vocab = VocabEntry.from_corpus(code_tokens,
                                        size=5000,
                                        freq_cutoff=code_freq)

    vocab = Vocab(source=src_vocab, primitive=primitive_vocab, code=code_vocab)
    print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

    action_lens = [len(e.tgt_actions) for e in train_examples]
    print('Max action len: %d' % max(action_lens), file=sys.stderr)
    print('Avg action len: %d' % np.average(action_lens), file=sys.stderr)
    print('Actions larger than 100: %d' %
          len(list(filter(lambda x: x > 100, action_lens))),
          file=sys.stderr)

    if False:  # serialization disabled in this variant; flip to True to write the preprocessed splits
        pickle.dump(train_examples,
                    open('data/conala/train.var_str_sep.bin', 'wb'))
        pickle.dump(full_train_examples,
                    open('data/conala/train.var_str_sep.full.bin', 'wb'))
        pickle.dump(dev_examples, open('data/conala/dev.var_str_sep.bin',
                                       'wb'))
        pickle.dump(test_examples,
                    open('data/conala/test.var_str_sep.bin', 'wb'))
        pickle.dump(
            vocab,
            open(
                'data/conala/vocab.var_str_sep.new_dev.src_freq%d.code_freq%d.bin'
                % (src_freq, code_freq), 'wb'))
Example 4
    # Parse the generated "natural" dataset: read NL/code pairs from the
    # datagen folders, canonicalize them, derive ASDL action sequences (with a
    # replay sanity check), split them roughly 90/5/5 into train/dev/test, and
    # build the vocabularies.
    def parse_natural_dataset(asdl_file_path,
                              max_query_len=70,
                              vocab_freq_cutoff=10):
        asdl_text = open(asdl_file_path).read()
        print('building grammar')
        grammar = ASDLGrammar.from_text(asdl_text)
        transition_system = Python3TransitionSystem(grammar)

        loaded_examples = []

        annotations = []
        codes = []
        path = os.path.join(os.path.dirname(__file__), "datagen")
        datagens = os.listdir(path)
        for folder in datagens:
            if "__" in folder or not os.path.isdir(os.path.join(path, folder)):
                continue
            with open(os.path.join(path, folder, "inputs.txt"), 'r') as file:
                annotations += file.read().split('\n')
            with open(os.path.join(path, folder, "outputs.txt"), 'r') as file:
                codes += file.read().split('\n')
        annotation_codes = list(zip(annotations, codes))
        np.random.seed(42)
        np.random.shuffle(annotation_codes)

        from components.vocab import Vocab, VocabEntry
        from components.dataset import Example

        print('processing examples')
        for idx, (src_query, tgt_code) in enumerate(annotation_codes):
            if idx % 100 == 0:
                sys.stdout.write("\r%s / %s" % (idx, len(annotation_codes)))
                sys.stdout.flush()

            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            src_query_tokens, tgt_canonical_code, str_map = Natural.canonicalize_example(
                src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code)  #.body[0]
            gold_source = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast,
                                             transition_system.grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)
            # print('+' * 60)
            # print('Example: %d' % idx)
            # print('Source: %s' % ' '.join(src_query_tokens))
            # if str_map:
            #     print('Original String Map:')
            #     for str_literal, str_repr in str_map.items():
            #         print('\t%s: %s' % (str_literal, str_repr))
            # print('Code:\n%s' % gold_source)
            # print('Actions:')

            # sanity check
            hyp = Hypothesis()
            for t, action in enumerate(tgt_actions):
                assert action.__class__ in transition_system.get_valid_continuation_types(
                    hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(
                        hyp)
                # assert action.__class__ in transition_system.get_valid_continuation_types(
                # hyp)

                p_t = -1
                f_t = None
                if hyp.frontier_node:
                    p_t = hyp.frontier_node.created_time
                    f_t = hyp.frontier_field.field.__repr__(plain=True)

                # print('\t[%d] %s, frontier field: %s, parent: %d' %
                #     (t, action, f_t, p_t))
                hyp = hyp.clone_and_apply_action(action)

            # assert hyp.frontier_node is None and hyp.frontier_field is None

            src_from_hyp = astor.to_source(
                asdl_ast_to_python_ast(hyp.tree, grammar)).strip()
            if "b'" not in str(gold_source) and 'b"' not in str(gold_source):
                assert src_from_hyp == gold_source

            # print('+' * 60)

            loaded_examples.append({
                'src_query_tokens': src_query_tokens,
                'tgt_canonical_code': gold_source,
                'tgt_ast': tgt_ast,
                'tgt_actions': tgt_actions,
                'raw_code': tgt_code,
                'str_map': str_map
            })

            # print('first pass, processed %d' % idx, file=sys.stderr)

        train_examples = []
        dev_examples = []
        test_examples = []

        action_len = []

        print("\nsplitting train/dev/test")
        for idx, e in enumerate(loaded_examples):
            src_query_tokens = e['src_query_tokens'][:max_query_len]
            tgt_actions = e['tgt_actions']
            tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

            example = Example(idx=idx,
                              src_sent=src_query_tokens,
                              tgt_actions=tgt_action_infos,
                              tgt_code=e['tgt_canonical_code'],
                              tgt_ast=e['tgt_ast'],
                              meta={
                                  'raw_code': e['raw_code'],
                                  'str_map': e['str_map']
                              })

            # print('second pass, processed %d' % idx, file=sys.stderr)

            action_len.append(len(tgt_action_infos))

            # train, valid, test split
            total_examples = len(loaded_examples)
            split_size = np.ceil(total_examples * 0.05)
            (dev_split, test_split) = (total_examples - split_size * 2,
                                       total_examples - split_size)
            if 0 <= idx < dev_split:
                train_examples.append(example)
            elif dev_split <= idx < test_split:
                dev_examples.append(example)
            else:
                test_examples.append(example)

        print('Max action len: %d' % max(action_len), file=sys.stderr)
        print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
        print('Actions larger than 100: %d' %
              len(list(filter(lambda x: x > 100, action_len))),
              file=sys.stderr)

        src_vocab = VocabEntry.from_corpus(
            [e.src_sent for e in train_examples],
            size=5000,
            freq_cutoff=vocab_freq_cutoff)

        primitive_tokens = [
            map(
                lambda a: a.action.token,
                filter(lambda a: isinstance(a.action, GenTokenAction),
                       e.tgt_actions)) for e in train_examples
        ]

        primitive_vocab = VocabEntry.from_corpus(primitive_tokens,
                                                 size=5000,
                                                 freq_cutoff=vocab_freq_cutoff)
        # assert '_STR:0_' in primitive_vocab

        # generate vocabulary for the code tokens!
        code_tokens = [
            tokenize_code(e.tgt_code, mode='decoder') for e in train_examples
        ]
        code_vocab = VocabEntry.from_corpus(code_tokens,
                                            size=5000,
                                            freq_cutoff=vocab_freq_cutoff)

        vocab = Vocab(source=src_vocab,
                      primitive=primitive_vocab,
                      code=code_vocab)
        print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

        return (train_examples, dev_examples, test_examples), vocab
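A minimal usage sketch; the enclosing `Natural` class (suggested by the method's indentation and by the calls to `Natural.canonicalize_example`) and the output file names are assumptions:

# hypothetical driver -- assumes parse_natural_dataset is a static method of the
# Natural dataset class and that the py3 grammar file ships with the repository
(train, dev, test), vocab = Natural.parse_natural_dataset(
    'asdl/lang/py3/py3_asdl.simplified.txt')
pickle.dump(train, open('data/natural/train.bin', 'wb'))
pickle.dump(dev, open('data/natural/dev.bin', 'wb'))
pickle.dump(test, open('data/natural/test.bin', 'wb'))
pickle.dump(vocab, open('data/natural/vocab.bin', 'wb'))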