Example #1
    def compare_ast(self, hyp_ast, ref_ast):
        hyp_code = self.ast_to_surface_code(hyp_ast)
        ref_reformatted_code = self.ast_to_surface_code(ref_ast)

        ref_code_tokens = tokenize_code(ref_reformatted_code)
        hyp_code_tokens = tokenize_code(hyp_code)

        return ref_code_tokens == hyp_code_tokens
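
compare_ast regenerates surface code from both ASTs and compares token sequences, so exact match is judged on canonicalized tokens rather than on raw strings. As a rough illustration, a minimal tokenize_code stand-in can be built on Python's standard tokenize module (an assumption for clarity only; the project's actual tokenizer may differ, e.g. in how it treats string literals):

    import tokenize
    from io import StringIO

    def tokenize_code_sketch(code):
        # Hypothetical stand-in for tokenize_code: lex the code and keep
        # only the token strings, dropping layout and comment tokens.
        layout = (tokenize.NEWLINE, tokenize.NL, tokenize.INDENT,
                  tokenize.DEDENT, tokenize.ENDMARKER, tokenize.COMMENT)
        return [tok.string
                for tok in tokenize.generate_tokens(StringIO(code).readline)
                if tok.type not in layout]

Under this sketch, tokenize_code_sketch('x = 1') and tokenize_code_sketch('x=1') both yield ['x', '=', '1'], which is exactly the insensitivity to formatting that the comparison above relies on.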
Example #2
    def hyp_bleu(self, hyp, example):
        ref_code = example.tgt_code
        ref_py_ast = ast.parse(ref_code)
        ref_reformatted_code = astor.to_source(ref_py_ast).strip()

        ref_code_tokens = tokenize_code(ref_reformatted_code)
        hyp_code_tokens = tokenize_code(hyp.code)

        return bt.compute_bleu(ref_code_tokens, hyp_code_tokens)
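
bt here is the project's BLEU module. For intuition, an analogous sentence-level BLEU over code tokens can be computed with NLTK (an assumption for illustration, not the implementation used above; smoothing keeps short snippets from scoring zero):

    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

    def code_bleu_sketch(ref_code_tokens, hyp_code_tokens):
        # sentence_bleu expects a list of reference token lists plus one hypothesis
        smooth = SmoothingFunction().method3
        return sentence_bleu([ref_code_tokens], hyp_code_tokens,
                             smoothing_function=smooth)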
Example #3
    def hyp_correct(self, hyp, example):
        ref_code = example.tgt_code
        ref_py_ast = ast.parse(ref_code).body[0]
        ref_reformatted_code = astor.to_source(ref_py_ast).strip()

        ref_code_tokens = tokenize_code(ref_reformatted_code)
        hyp_code_tokens = tokenize_code(hyp.code)

        return ref_code_tokens == hyp_code_tokens
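
Note the ast.parse(...).body[0] / astor.to_source round-trip: the reference is parsed and regenerated so that the exact-match test compares canonicalized surface forms, not the original formatting. A quick demonstration (astor's exact spacing may vary slightly across versions):

    import ast
    import astor

    stmt = ast.parse('x=1+  2').body[0]   # first (and only) statement node
    print(astor.to_source(stmt).strip())  # -> x = 1 + 2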
Example #4
def load_code_data(file_path):
    # read in lines of Python code to train a prior
    processed_code = []
    for line in open(file_path):
        raw_code = line.strip()
        # perform canonicalization same as how we pre-process the dataset
        if raw_code:
            code = Django.canonicalize_code(raw_code)
            try:
                # use the astor-style code
                py_ast = ast.parse(code).body[0]
                code = astor.to_source(py_ast).strip()
                code_tokens = tokenize_code(code, mode='canonicalize')
                if len(code_tokens) < 50:
                    processed_code.append({
                        'code': code,
                        'tokens': code_tokens
                    })
            except Exception:
                print("Exception in reading line: %s" % raw_code,
                      file=sys.stdout)
                print('-' * 60, file=sys.stdout)
                traceback.print_exc(file=sys.stdout)
                print('-' * 60, file=sys.stdout)

    return processed_code
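
A hypothetical call, with an illustrative file path; each returned entry pairs a canonicalized code string with its token list:

    prior_code = load_code_data('data/django/all.code')  # path is illustrative
    print('loaded %d snippets for the prior' % len(prior_code))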
Example #5
    def parse_django_dataset(annot_file,
                             code_file,
                             asdl_file_path,
                             max_query_len=70,
                             vocab_freq_cutoff=10):
        asdl_text = open(asdl_file_path).read()
        grammar = ASDLGrammar.from_text(asdl_text)
        transition_system = PythonTransitionSystem(grammar)

        loaded_examples = []

        from components.vocab import Vocab, VocabEntry
        from components.dataset import Example

        for idx, (src_query,
                  tgt_code) in enumerate(zip(open(annot_file),
                                             open(code_file))):
            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            src_query_tokens, tgt_canonical_code, str_map = Django.canonicalize_example(
                src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code).body[0]
            gold_source = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast, grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # print('+' * 60)
            # print('Example: %d' % idx)
            # print('Source: %s' % ' '.join(src_query_tokens))
            # if str_map:
            #     print('Original String Map:')
            #     for str_literal, str_repr in str_map.items():
            #         print('\t%s: %s' % (str_literal, str_repr))
            # print('Code:\n%s' % gold_source)
            # print('Actions:')

            # sanity check
            try:
                hyp = Hypothesis()
                for t, action in enumerate(tgt_actions):
                    # assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                    # if isinstance(action, ApplyRuleAction):
                    #     assert action.production in transition_system.get_valid_continuating_productions(hyp)

                    p_t = -1
                    f_t = None
                    if hyp.frontier_node:
                        p_t = hyp.frontier_node.created_time
                        f_t = hyp.frontier_field.field.__repr__(plain=True)

                    # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
                    hyp = hyp.clone_and_apply_action(action)

                assert hyp.frontier_node is None and hyp.frontier_field is None

                src_from_hyp = astor.to_source(
                    asdl_ast_to_python_ast(hyp.tree, grammar)).strip()
                assert src_from_hyp == gold_source

                # print('+' * 60)
            except Exception:
                # skip examples that fail the round-trip sanity check
                continue

            loaded_examples.append({
                'src_query_tokens': src_query_tokens,
                'tgt_canonical_code': gold_source,
                'tgt_ast': tgt_ast,
                'tgt_actions': tgt_actions,
                'raw_code': tgt_code,
                'str_map': str_map
            })

            # print('first pass, processed %d' % idx, file=sys.stderr)

        train_examples = []
        dev_examples = []
        test_examples = []

        action_len = []

        for idx, e in enumerate(loaded_examples):
            src_query_tokens = e['src_query_tokens'][:max_query_len]
            tgt_actions = e['tgt_actions']
            tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

            example = Example(idx=idx,
                              src_sent=src_query_tokens,
                              tgt_actions=tgt_action_infos,
                              tgt_code=e['tgt_canonical_code'],
                              tgt_ast=e['tgt_ast'],
                              meta={
                                  'raw_code': e['raw_code'],
                                  'str_map': e['str_map']
                              })

            # print('second pass, processed %d' % idx, file=sys.stderr)

            action_len.append(len(tgt_action_infos))

            # train, valid, test split
            if 0 <= idx < 16000:
                train_examples.append(example)
            elif 16000 <= idx < 17000:
                dev_examples.append(example)
            else:
                test_examples.append(example)

        print('Max action len: %d' % max(action_len), file=sys.stderr)
        print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
        print('Actions larger than 100: %d' %
              len(list(filter(lambda x: x > 100, action_len))),
              file=sys.stderr)

        src_vocab = VocabEntry.from_corpus(
            [e.src_sent for e in train_examples],
            size=5000,
            freq_cutoff=vocab_freq_cutoff)

        primitive_tokens = [
            [a.action.token for a in e.tgt_actions
             if isinstance(a.action, GenTokenAction)]
            for e in train_examples
        ]

        primitive_vocab = VocabEntry.from_corpus(primitive_tokens,
                                                 size=5000,
                                                 freq_cutoff=vocab_freq_cutoff)
        assert '_STR:0_' in primitive_vocab

        # generate vocabulary for the code tokens!
        code_tokens = [
            tokenize_code(e.tgt_code, mode='decoder') for e in train_examples
        ]
        code_vocab = VocabEntry.from_corpus(code_tokens,
                                            size=5000,
                                            freq_cutoff=vocab_freq_cutoff)

        vocab = Vocab(source=src_vocab,
                      primitive=primitive_vocab,
                      code=code_vocab)
        print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

        return (train_examples, dev_examples, test_examples), vocab
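
A hypothetical invocation with illustrative paths (qualification by the enclosing class is omitted); the splits and vocabulary would typically be serialized for training:

    import pickle

    (train, dev, test), vocab = parse_django_dataset('data/django/all.anno',
                                                     'data/django/all.code',
                                                     'py_asdl.txt')
    with open('train.bin', 'wb') as f:
        pickle.dump(train, f)
    with open('vocab.bin', 'wb') as f:
        pickle.dump(vocab, f)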
Example #6
    def tokenize_code(self, code, mode=None):
        return tokenize_code(code, mode)
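
This wrapper simply delegates to the module-level tokenize_code, so callers can reach the tokenizer through an object that bundles dataset utilities. A hypothetical call (the instance name is illustrative):

    tokens = utils.tokenize_code("d = {'a': 1}", mode='canonicalize')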
Example #7
    def parse_django_dataset(annot_file, code_file, asdl_file_path, max_query_len=70, vocab_freq_cutoff=10):
        asdl_text = open(asdl_file_path).read()
        grammar = ASDLGrammar.from_text(asdl_text)
        transition_system = PythonTransitionSystem(grammar)

        loaded_examples = []

        from components.vocab import Vocab, VocabEntry
        from components.dataset import Example

        for idx, (src_query, tgt_code) in enumerate(zip(open(annot_file), open(code_file))):
            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            src_query_tokens, tgt_canonical_code, str_map = Django.canonicalize_example(src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code).body[0]
            gold_source = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast, grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            print('+' * 60)
            print('Example: %d' % idx)
            print('Source: %s' % ' '.join(src_query_tokens))
            if str_map:
                print('Original String Map:')
                for str_literal, str_repr in str_map.items():
                    print('\t%s: %s' % (str_literal, str_repr))
            print('Code:\n%s' % gold_source)
            print('Actions:')

            # sanity check
            hyp = Hypothesis()
            for t, action in enumerate(tgt_actions):
                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(hyp)

                p_t = -1
                f_t = None
                if hyp.frontier_node:
                    p_t = hyp.frontier_node.created_time
                    f_t = hyp.frontier_field.field.__repr__(plain=True)

                print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
                hyp = hyp.clone_and_apply_action(action)

            assert hyp.frontier_node is None and hyp.frontier_field is None

            src_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, grammar)).strip()
            assert src_from_hyp == gold_source

            print('+' * 60)

            loaded_examples.append({'src_query_tokens': src_query_tokens,
                                    'tgt_canonical_code': gold_source,
                                    'tgt_ast': tgt_ast,
                                    'tgt_actions': tgt_actions,
                                    'raw_code': tgt_code, 'str_map': str_map})

            # print('first pass, processed %d' % idx, file=sys.stderr)

        train_examples = []
        dev_examples = []
        test_examples = []

        action_len = []

        for idx, e in enumerate(loaded_examples):
            src_query_tokens = e['src_query_tokens'][:max_query_len]
            tgt_actions = e['tgt_actions']
            tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

            example = Example(idx=idx,
                              src_sent=src_query_tokens,
                              tgt_actions=tgt_action_infos,
                              tgt_code=e['tgt_canonical_code'],
                              tgt_ast=e['tgt_ast'],
                              meta={'raw_code': e['raw_code'], 'str_map': e['str_map']})

            # print('second pass, processed %d' % idx, file=sys.stderr)

            action_len.append(len(tgt_action_infos))

            # train, valid, test split
            if 0 <= idx < 16000:
                train_examples.append(example)
            elif 16000 <= idx < 17000:
                dev_examples.append(example)
            else:
                test_examples.append(example)

        print('Max action len: %d' % max(action_len), file=sys.stderr)
        print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
        print('Actions larger than 100: %d' % len(list(filter(lambda x: x > 100, action_len))), file=sys.stderr)

        src_vocab = VocabEntry.from_corpus([e.src_sent for e in train_examples], size=5000, freq_cutoff=vocab_freq_cutoff)

        primitive_tokens = [[a.action.token
                             for a in e.tgt_actions
                             if isinstance(a.action, GenTokenAction)]
                            for e in train_examples]

        primitive_vocab = VocabEntry.from_corpus(primitive_tokens, size=5000, freq_cutoff=vocab_freq_cutoff)
        assert '_STR:0_' in primitive_vocab

        # generate vocabulary for the code tokens!
        code_tokens = [tokenize_code(e.tgt_code, mode='decoder') for e in train_examples]
        code_vocab = VocabEntry.from_corpus(code_tokens, size=5000, freq_cutoff=vocab_freq_cutoff)

        vocab = Vocab(source=src_vocab, primitive=primitive_vocab, code=code_vocab)
        print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

        return (train_examples, dev_examples, test_examples), vocab