Example #1
    def run():
        asdl_text = open('asdl/lang/py3/py3_asdl.simplified.txt').read()
        grammar = ASDLGrammar.from_text(asdl_text)

        annot_file = 'data/natural/all.anno'
        code_file = 'data/natural/all.code'

        transition_system = Python3TransitionSystem(grammar)

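        # Iterate over the parallel annotation/code files: one natural-language query and one code snippet per line.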
        for idx, (src_query,
                  tgt_code) in enumerate(zip(open(annot_file),
                                             open(code_file))):
            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

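            # Canonicalize the pair: tokenize the query and abstract string literals in the code into placeholder slots (str_map records each literal's placeholder).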
            query_tokens, tgt_canonical_code, str_map = Natural.canonicalize_example(
                src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code)  #.body[0]
            gold_source = astor.to_source(python_ast)
            tgt_ast = python_ast_to_asdl_ast(python_ast, grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # sanity check
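            # Replay the oracle action sequence on an empty Hypothesis; every action must be a valid continuation of the partial tree.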
            hyp = Hypothesis()
            hyp2 = Hypothesis()
            for action in tgt_actions:
                assert action.__class__ in transition_system.get_valid_continuation_types(
                    hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(
                        hyp)
                hyp = hyp.clone_and_apply_action(action)
                hyp2.apply_action(action)

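            # Round-trip check: convert the reconstructed ASDL AST back to a Python AST and regenerate source; it must match the gold source.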
            src_from_hyp = astor.to_source(
                asdl_ast_to_python_ast(hyp.tree, grammar))
            assert src_from_hyp == gold_source
            assert hyp.tree == hyp2.tree and hyp.tree is not hyp2.tree

            print(idx)
Example #2
def preprocess_dataset(file_path,
                       transition_system,
                       name='train',
                       firstk=None):
    try:
        dataset = json.load(open(file_path))
    except json.JSONDecodeError:
        # fall back to JSON-lines input: one JSON object per line
        dataset = [json.loads(jline) for jline in open(file_path).readlines()]
    if firstk:
        dataset = dataset[:firstk]
    examples = []
    evaluator = ConalaEvaluator(transition_system)
    f = open(file_path + '.debug', 'w')
    skipped_list = []
    for i, example_json in enumerate(dataset):
        try:
            example_dict = preprocess_example(example_json)

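            # Parse the canonicalized snippet and convert its Python AST into the ASDL AST used by the transition system.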
            python_ast = ast.parse(example_dict['canonical_snippet'])
            canonical_code = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast,
                                             transition_system.grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # sanity check
            hyp = Hypothesis()
            for t, action in enumerate(tgt_actions):
                assert action.__class__ in transition_system.get_valid_continuation_types(
                    hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(
                        hyp)
                # p_t = -1
                # f_t = None
                # if hyp.frontier_node:
                #     p_t = hyp.frontier_node.created_time
                #     f_t = hyp.frontier_field.field.__repr__(plain=True)
                #
                # # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
                hyp = hyp.clone_and_apply_action(action)

            assert hyp.frontier_node is None and hyp.frontier_field is None
            hyp.code = code_from_hyp = astor.to_source(
                asdl_ast_to_python_ast(hyp.tree,
                                       transition_system.grammar)).strip()
            # print(code_from_hyp)
            # print(canonical_code)
            assert code_from_hyp == canonical_code

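            # Undo canonicalization with the slot map and check the result is AST-equivalent to the original snippet.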
            decanonicalized_code_from_hyp = decanonicalize_code(
                code_from_hyp, example_dict['slot_map'])
            assert compare_ast(ast.parse(example_json['snippet']),
                               ast.parse(decanonicalized_code_from_hyp))
            assert transition_system.compare_ast(
                transition_system.surface_code_to_ast(
                    decanonicalized_code_from_hyp),
                transition_system.surface_code_to_ast(example_json['snippet']))

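            # Attach per-action metadata (e.g. copy information relative to the intent tokens) to the oracle actions.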
            tgt_action_infos = get_action_infos(example_dict['intent_tokens'],
                                                tgt_actions)
        except (AssertionError, SyntaxError, ValueError, OverflowError) as e:
            skipped_list.append(example_json['question_id'])
            continue
        example = Example(idx=f'{i}-{example_json["question_id"]}',
                          src_sent=example_dict['intent_tokens'],
                          tgt_actions=tgt_action_infos,
                          tgt_code=canonical_code,
                          tgt_ast=tgt_ast,
                          meta=dict(example_dict=example_json,
                                    slot_map=example_dict['slot_map']))
        assert evaluator.is_hyp_correct(example, hyp)

        examples.append(example)

        # log!
        f.write(f'Example: {example.idx}\n')
        if 'rewritten_intent' in example.meta['example_dict']:
            f.write(
                f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n"
            )
        else:
            f.write(
                f"Original Utterance: {example.meta['example_dict']['intent']}\n"
            )
        f.write(
            f"Original Snippet: {example.meta['example_dict']['snippet']}\n")
        f.write(f"\n")
        f.write(f"Utterance: {' '.join(example.src_sent)}\n")
        f.write(f"Snippet: {example.tgt_code}\n")
        f.write(
            f"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
        )

    f.close()
    print('Skipped due to exceptions: %d' % len(skipped_list), file=sys.stderr)
    return examples
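
# Minimal usage sketch (assumes the grammar file shown in Example #1; the dataset path is a placeholder):
#   asdl_text = open('asdl/lang/py3/py3_asdl.simplified.txt').read()
#   transition_system = Python3TransitionSystem(ASDLGrammar.from_text(asdl_text))
#   train_examples = preprocess_dataset('conala-train.json', transition_system,
#                                       name='train', firstk=100)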
Example #3
def preprocess_dataset(file_path, transition_system, name='train'):
    dataset = json.load(open(file_path))
    examples = []
    evaluator = ConalaEvaluator(transition_system)

    f = open(file_path + '.debug', 'w')

    for i, example_json in enumerate(dataset):
        example_dict = preprocess_example(example_json)
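        # Pretty-print and skip a hard-coded set of problematic question ids.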
        if example_json['question_id'] in (18351951, 9497290, 19641579,
                                           32283692):
            pprint(preprocess_example(example_json))
            continue

        python_ast = ast.parse(example_dict['canonical_snippet'])
        canonical_code = astor.to_source(python_ast).strip()
        tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
        tgt_actions = transition_system.get_actions(tgt_ast)

        # sanity check
        hyp = Hypothesis()
        for t, action in enumerate(tgt_actions):
            assert action.__class__ in transition_system.get_valid_continuation_types(
                hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(
                    hyp)

            p_t = -1
            f_t = None
            if hyp.frontier_node:
                p_t = hyp.frontier_node.created_time
                f_t = hyp.frontier_field.field.__repr__(plain=True)

            # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None
        hyp.code = code_from_hyp = astor.to_source(
            asdl_ast_to_python_ast(hyp.tree,
                                   transition_system.grammar)).strip()
        assert code_from_hyp == canonical_code

        decanonicalized_code_from_hyp = decanonicalize_code(
            code_from_hyp, example_dict['slot_map'])
        assert compare_ast(ast.parse(example_json['snippet']),
                           ast.parse(decanonicalized_code_from_hyp))
        assert transition_system.compare_ast(
            transition_system.surface_code_to_ast(
                decanonicalized_code_from_hyp),
            transition_system.surface_code_to_ast(example_json['snippet']))

        tgt_action_infos = get_action_infos(example_dict['intent_tokens'],
                                            tgt_actions)

        example = Example(idx=f'{i}-{example_json["question_id"]}',
                          src_sent=example_dict['intent_tokens'],
                          tgt_actions=tgt_action_infos,
                          tgt_code=canonical_code,
                          tgt_ast=tgt_ast,
                          meta=dict(example_dict=example_json,
                                    slot_map=example_dict['slot_map']))
        assert evaluator.is_hyp_correct(example, hyp)

        examples.append(example)

        # log!
        f.write(f'Example: {example.idx}\n')
        f.write(
            f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n"
        )
        f.write(
            f"Original Snippet: {example.meta['example_dict']['snippet']}\n")
        f.write(f"\n")
        f.write(f"Utterance: {' '.join(example.src_sent)}\n")
        f.write(f"Snippet: {example.tgt_code}\n")
        f.write(
            f"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
        )

    f.close()

    return examples
Example #4
def preprocess_dataset(file_path, transition_system, name='train', firstk=None):
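    # Resolve the (possibly relative) path against the current working directory, accepting '/' or '\' separators.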
    file_path = os.path.join(os.getcwd(), *file_path.split('/' if '/' in file_path else "\\"))

    try:
        dataset = json.load(open(file_path))
    except json.JSONDecodeError:
        # fall back to JSON-lines input: one JSON object per line
        dataset = [json.loads(jline) for jline in open(file_path).readlines()]
    if firstk:
        dataset = dataset[:firstk]
    examples = []
    evaluator = ConalaEvaluator(transition_system)

    # Author: Gabe
    # Added in encoding to try and deal with UnicodeEncodeErrors
    f = open(file_path + '.debug', 'w', encoding='utf-8')

    skipped_list = []
    for i, example_json in tqdm(enumerate(dataset), file=sys.stdout, total=len(dataset),
                                desc='Preproc'):

        # Author: Gabe
        # Have to skip this one question because it causes the program to hang and never recover.
        if example_json['question_id'] in [39525993]:
            skipped_list.append(example_json['question_id'])
            tqdm.write(f"Skipping {example_json['question_id']} because it causes errors")
            continue
        try:
            example_dict = preprocess_example(example_json)

            python_ast = ast.parse(example_dict['canonical_snippet'])
            canonical_code = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # sanity check
            hyp = Hypothesis()
            for t, action in enumerate(tgt_actions):
                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in \
                           transition_system.get_valid_continuating_productions(
                               hyp)
                # p_t = -1
                # f_t = None
                # if hyp.frontier_node:
                #     p_t = hyp.frontier_node.created_time
                #     f_t = hyp.frontier_field.field.__repr__(plain=True)
                #
                # # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
                hyp = hyp.clone_and_apply_action(action)

            assert hyp.frontier_node is None and hyp.frontier_field is None
            hyp.code = code_from_hyp = astor.to_source(
                asdl_ast_to_python_ast(hyp.tree, transition_system.grammar)).strip()
            # print(code_from_hyp)
            # print(canonical_code)
            assert code_from_hyp == canonical_code

            decanonicalized_code_from_hyp = decanonicalize_code(code_from_hyp,
                                                                example_dict['slot_map'])
            assert compare_ast(ast.parse(example_json['snippet']),
                               ast.parse(decanonicalized_code_from_hyp))
            assert transition_system.compare_ast(
                transition_system.surface_code_to_ast(decanonicalized_code_from_hyp),
                transition_system.surface_code_to_ast(example_json['snippet']))

            tgt_action_infos = get_action_infos(example_dict['intent_tokens'], tgt_actions)
        except (AssertionError, SyntaxError, ValueError, OverflowError) as e:
            skipped_list.append(example_json['question_id'])
            tqdm.write(
                f"Skipping example {example_json['question_id']} because of {type(e).__name__}:{e}"
            )
            continue
        example = Example(idx=f'{i}-{example_json["question_id"]}',
                          src_sent=example_dict['intent_tokens'],
                          tgt_actions=tgt_action_infos,
                          tgt_code=canonical_code,
                          tgt_ast=tgt_ast,
                          meta=dict(example_dict=example_json,
                                    slot_map=example_dict['slot_map']))
        assert evaluator.is_hyp_correct(example, hyp)

        examples.append(example)

        # Author: Gabe
        # Had to remove logging, when the log file would get too large, it would cause the
        # program to hang.

        # log!
        # f.write(f'Example: {example.idx}\n')
        # if 'rewritten_intent' in example.meta['example_dict']:
        #     f.write(f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n")
        # else:
        #     f.write(f"Original Utterance: {example.meta['example_dict']['intent']}\n")
        # f.write(f"Original Snippet: {example.meta['example_dict']['snippet']}\n")
        # f.write(f"\n")
        # f.write(f"Utterance: {' '.join(example.src_sent)}\n")
        # f.write(f"Snippet: {example.tgt_code}\n")
        # f.write(f"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n")

    f.close()
    print('Skipped due to exceptions: %d' % len(skipped_list), file=sys.stderr)
    return examples
Example #5
    def parse_natural_dataset(asdl_file_path,
                              max_query_len=70,
                              vocab_freq_cutoff=10):
        asdl_text = open(asdl_file_path).read()
        print('building grammar')
        grammar = ASDLGrammar.from_text(asdl_text)
        transition_system = Python3TransitionSystem(grammar)

        loaded_examples = []

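        # Collect parallel (annotation, code) pairs from every datagen/<folder>'s inputs.txt and outputs.txt, then shuffle them with a fixed seed.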
        annotations = []
        codes = []
        path = os.path.join(os.path.dirname(__file__), "datagen")
        datagens = os.listdir(path)
        for folder in datagens:
            if "__" in folder or not os.path.isdir(os.path.join(path, folder)):
                continue
            with open(os.path.join(path, folder, "inputs.txt"), 'r') as file:
                annotations += file.read().split('\n')
            with open(os.path.join(path, folder, "outputs.txt"), 'r') as file:
                codes += file.read().split('\n')
        annotation_codes = list(zip(annotations, codes))
        np.random.seed(42)
        np.random.shuffle(annotation_codes)

        from components.vocab import Vocab, VocabEntry
        from components.dataset import Example

        print('processing examples')
        for idx, (src_query, tgt_code) in enumerate(annotation_codes):
            if (idx % 100 == 0):
                sys.stdout.write("\r%s / %s" % (idx, len(annotation_codes)))
                sys.stdout.flush()

            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            src_query_tokens, tgt_canonical_code, str_map = Natural.canonicalize_example(
                src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code)  #.body[0]
            gold_source = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast,
                                             transition_system.grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)
            # print('+' * 60)
            # print('Example: %d' % idx)
            # print('Source: %s' % ' '.join(src_query_tokens))
            # if str_map:
            #     print('Original String Map:')
            #     for str_literal, str_repr in str_map.items():
            #         print('\t%s: %s' % (str_literal, str_repr))
            # print('Code:\n%s' % gold_source)
            # print('Actions:')

            # sanity check
            hyp = Hypothesis()
            for t, action in enumerate(tgt_actions):
                assert action.__class__ in transition_system.get_valid_continuation_types(
                    hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(
                        hyp)
                # assert action.__class__ in transition_system.get_valid_continuation_types(
                # hyp)

                p_t = -1
                f_t = None
                if hyp.frontier_node:
                    p_t = hyp.frontier_node.created_time
                    f_t = hyp.frontier_field.field.__repr__(plain=True)

                # print('\t[%d] %s, frontier field: %s, parent: %d' %
                #     (t, action, f_t, p_t))
                hyp = hyp.clone_and_apply_action(action)

            # assert hyp.frontier_node is None and hyp.frontier_field is None

            src_from_hyp = astor.to_source(
                asdl_ast_to_python_ast(hyp.tree, grammar)).strip()
            if "b'" not in str(gold_source) and 'b"' not in str(gold_source):
                assert src_from_hyp == gold_source

            # print('+' * 60)

            loaded_examples.append({
                'src_query_tokens': src_query_tokens,
                'tgt_canonical_code': gold_source,
                'tgt_ast': tgt_ast,
                'tgt_actions': tgt_actions,
                'raw_code': tgt_code,
                'str_map': str_map
            })

            # print('first pass, processed %d' % idx, file=sys.stderr)

        train_examples = []
        dev_examples = []
        test_examples = []

        action_len = []

        print("\nsplitting train/dev/test")
        for idx, e in enumerate(loaded_examples):
            src_query_tokens = e['src_query_tokens'][:max_query_len]
            tgt_actions = e['tgt_actions']
            tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

            example = Example(idx=idx,
                              src_sent=src_query_tokens,
                              tgt_actions=tgt_action_infos,
                              tgt_code=e['tgt_canonical_code'],
                              tgt_ast=e['tgt_ast'],
                              meta={
                                  'raw_code': e['raw_code'],
                                  'str_map': e['str_map']
                              })

            # print('second pass, processed %d' % idx, file=sys.stderr)

            action_len.append(len(tgt_action_infos))

            # train, valid, test split
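            # ~90/5/5 split: the last two ceil(5%)-sized chunks become dev and test; everything earlier is train.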
            total_examples = len(loaded_examples)
            split_size = np.ceil(total_examples * 0.05)
            (dev_split, test_split) = (total_examples - split_size * 2,
                                       total_examples - split_size)
            if 0 <= idx < dev_split:
                train_examples.append(example)
            elif dev_split <= idx < test_split:
                dev_examples.append(example)
            else:
                test_examples.append(example)

        print('Max action len: %d' % max(action_len), file=sys.stderr)
        print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
        print('Actions larger than 100: %d' %
              len(list(filter(lambda x: x > 100, action_len))),
              file=sys.stderr)

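        # Build source, primitive (GenToken), and code-token vocabularies from the training split only.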
        src_vocab = VocabEntry.from_corpus(
            [e.src_sent for e in train_examples],
            size=5000,
            freq_cutoff=vocab_freq_cutoff)

        primitive_tokens = [
            [a.action.token for a in e.tgt_actions
             if isinstance(a.action, GenTokenAction)]
            for e in train_examples
        ]

        primitive_vocab = VocabEntry.from_corpus(primitive_tokens,
                                                 size=5000,
                                                 freq_cutoff=vocab_freq_cutoff)
        # assert '_STR:0_' in primitive_vocab

        # generate vocabulary for the code tokens!
        code_tokens = [
            tokenize_code(e.tgt_code, mode='decoder') for e in train_examples
        ]
        code_vocab = VocabEntry.from_corpus(code_tokens,
                                            size=5000,
                                            freq_cutoff=vocab_freq_cutoff)

        vocab = Vocab(source=src_vocab,
                      primitive=primitive_vocab,
                      code=code_vocab)
        print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

        return (train_examples, dev_examples, test_examples), vocab