Example #1
    def infer(self, examples):
        # currently uses beam search as the sampling method
        # set model to evaluation mode for beam search to make sure dropout behaves properly
        was_training = self.encoder.training
        self.encoder.eval()

        hypotheses = [self.encoder.parse(e.src_sent, beam_size=self.args.sample_size) for e in examples]

        if len(hypotheses) == 0:
            raise ValueError('No candidate hypotheses.')

        if was_training: self.encoder.train()

        # some sources may not yield any samples, so we only retain those that have sampled logical forms
        sampled_examples = []
        for e_id, (example, hyps) in enumerate(zip(examples, hypotheses)):
            for hyp_id, hyp in enumerate(hyps):
                try:
                    code = self.transition_system.ast_to_surface_code(hyp.tree)
                    self.transition_system.tokenize_code(code)  # make sure the code is tokenizable!
                    sampled_example = Example(idx='%d-sample%d' % (example.idx, hyp_id),
                                              src_sent=example.src_sent,
                                              tgt_code=code,
                                              tgt_actions=hyp.action_infos,
                                              tgt_ast=hyp.tree)
                    sampled_examples.append(sampled_example)
                except:
                    print("Exception in converting tree to code:", file=sys.stdout)
                    print('-' * 60, file=sys.stdout)
                    traceback.print_exc(file=sys.stdout)
                    print('-' * 60, file=sys.stdout)

        sample_scores, enc_states = self.encoder.score(sampled_examples, return_enc_state=True)

        return sampled_examples, sample_scores, enc_states
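
The was_training / eval() / train() bookkeeping that brackets the beam search above is a standard PyTorch pattern. A minimal, hedged sketch of factoring it into a reusable context manager (the helper below is illustrative and not part of the original codebase):

from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def evaluating(module: nn.Module):
    # Temporarily switch `module` to eval mode, then restore its previous state.
    was_training = module.training
    module.eval()
    try:
        yield module
    finally:
        if was_training:
            module.train()

# e.g. `with evaluating(self.encoder): ...` could replace the manual bookkeeping above.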
Example #2
def load_dataset(transition_system, dataset_file, reorder_predicates=True):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        src_query, tgt_code = line.strip().split('\t')

        src_query_tokens = src_query.split(' ')

        lf = parse_lambda_expr(tgt_code)
        assert lf.to_string() == tgt_code

        if reorder_predicates:
            ordered_lf = get_canonical_order_of_logical_form(
                lf, order_by='alphabet')
            assert ordered_lf == lf
            lf = ordered_lf

        gold_source = lf.to_string()

        tgt_ast = logical_form_to_ast(grammar, lf)
        reconstructed_lf = ast_to_logical_form(tgt_ast)
        assert lf == reconstructed_lf

        tgt_actions = transition_system.get_actions(tgt_ast)

        print(idx)
        print('Utterance: %s' % src_query)
        print('Reference: %s' % tgt_code)
        # print('===== Actions =====')
        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(
                hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(
                    hyp)
            hyp = hyp.clone_and_apply_action(action)
            # print(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None

        src_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert src_from_hyp == gold_source

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        # print(' '.join(src_query_tokens))
        print('***')
        print(lf.to_string())
        print()
        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=gold_source,
                          tgt_ast=tgt_ast,
                          meta=None)

        examples.append(example)

    return examples
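
For orientation, the Example container that these loaders construct (its fields are visible from the keyword arguments used throughout the snippets) can be sketched roughly as below. This is an assumption inferred from the call sites, not the project's actual definition in components.dataset:

class Example(object):
    # Hedged sketch of the container implied by the calls above.
    def __init__(self, src_sent='', tgt_actions=None, tgt_code='', tgt_ast=None,
                 idx=0, meta=None):
        self.src_sent = src_sent        # tokenized natural-language utterance
        self.tgt_actions = tgt_actions  # sequence of target ActionInfo objects
        self.tgt_code = tgt_code        # surface form of the target program
        self.tgt_ast = tgt_ast          # target ASDL abstract syntax tree
        self.idx = idx                  # example identifier
        self.meta = meta                # optional per-example metadata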
Example #3
    def infer_backward(self, examples, sample_size):
        # get samples of p_{\theta}(x|z) in evaluation mode
        was_training = self.decoder.training
        self.decoder.eval()
        hypotheses = [
            self.decoder.sample(e.tgt_code, cuda=self.args.cuda)
            for e in examples
        ]
        if len(hypotheses) == 0:
            raise ValueError('No candidate hypotheses.')

        if was_training: self.decoder.train()
        sampled_examples = []
        picked_example = []
        for e_id, (example, hyps) in enumerate(zip(examples, hypotheses)):
            if len(hyps) != sample_size:
                print('invalid sample size: expected %d but got %d' %
                      (sample_size, len(hyps)))
            samples_temp = []
            for hyp_id, hyp in enumerate(hyps):
                if len(hyp) == 0:  # generated x is null
                    continue
                try:
                    sampled_example = Example(idx='%d-resample%d' %
                                              (example.idx, hyp_id),
                                              src_sent=hyp,
                                              tgt_code=example.tgt_code,
                                              tgt_actions=example.tgt_actions,
                                              tgt_ast=example.tgt_ast)
                    samples_temp.append(sampled_example)
                except:
                    pass
            if len(samples_temp) < sample_size / 2.0:
                # too few valid samples; skip this example
                continue
            if len(samples_temp) < sample_size:
                # fewer valid samples than sample_size; pad by repeating valid ones
                samples_temp += samples_temp[:sample_size - len(samples_temp)]
            assert len(samples_temp) == sample_size
            # initialize cached samples with the first sample
            if hasattr(example, 'recache'):
                sample_final = [example.recache]
            else:
                sample_final = samples_temp[:1]
            sample_final += samples_temp
            sampled_examples += sample_final
            picked_example.append(example)
        index = sorted(range(len(sampled_examples)),
                       key=lambda i: -len(sampled_examples[i].src_sent))
        back_index = [(index[i], i) for i in index]
        back_index.sort(key=lambda x: x[0])
        assert [index[x[1]]
                for x in back_index] == list(range(len(sampled_examples)))
        back_index = np.array([x[1] for x in back_index])
        sampled_examples = [sampled_examples[i] for i in index]
        sample_scores = self.decoder.score(sampled_examples)
        return sampled_examples, sample_scores, back_index, picked_example
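
The index / back_index bookkeeping above sorts the sampled examples by source length (so the decoder can score length-sorted batches) and stores the inverse permutation needed to map the scores back to the original order. A small self-contained illustration of the same idea (names are illustrative):

import numpy as np

def sort_with_inverse(items, key):
    # Sort items by key; also return the inverse permutation that undoes the sort.
    order = sorted(range(len(items)), key=lambda i: key(items[i]))
    inverse = np.empty(len(items), dtype=int)
    inverse[order] = np.arange(len(items))
    return [items[i] for i in order], inverse

words = ['ast', 'semantic', 'parsing']
sorted_words, inverse = sort_with_inverse(words, key=lambda w: -len(w))
assert [sorted_words[i] for i in inverse] == words  # original order restored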
Example #4
    def infer(self, examples, sample_size):
        # get samples of q_{\phi}(z|x) in evaluation mode
        was_training = self.encoder.training
        self.encoder.eval()

        hypotheses = [self.encoder.parse_sample(e.src_sent) for e in examples]

        if len(hypotheses) == 0:
            raise ValueError('No candidate hypotheses.')

        if was_training: self.encoder.train()

        # some sources may not yield any samples, so we only retain those that have sampled logical forms
        sampled_examples = []
        picked_example = []
        for e_id, (example, hyps) in enumerate(zip(examples, hypotheses)):
            if len(hyps) != sample_size:
                print('invalid sample size: expected %d but got %d' %
                      (sample_size, len(hyps)))
            samples_temp = []
            for hyp_id, hyp in enumerate(hyps):

                try:
                    code = self.transition_system.ast_to_surface_code(hyp.tree)
                    self.transition_system.tokenize_code(
                        code)  # make sure the code is tokenizable!
                    if len(code) == 0:
                        continue
                    sampled_example = Example(idx='%d-sample%d' %
                                              (example.idx, hyp_id),
                                              src_sent=example.src_sent,
                                              tgt_code=code,
                                              tgt_actions=hyp.action_infos,
                                              tgt_ast=hyp.tree)
                    samples_temp.append(sampled_example)
                except:
                    pass
            if len(samples_temp) < sample_size / 2.0:
                # too few valid samples; skip this example
                continue
            if len(samples_temp) < sample_size:
                # fewer valid samples than sample_size; pad by repeating valid ones
                samples_temp += samples_temp[:sample_size - len(samples_temp)]
            assert len(samples_temp) == sample_size
            # initialize cached samples with the first sample
            if hasattr(example, 'cache'):
                sample_final = [example.cache]
            else:
                sample_final = samples_temp[:1]
            sample_final += samples_temp
            sampled_examples += sample_final
            picked_example.append(example)
        sample_scores = self.encoder.score(sampled_examples)

        return sampled_examples, sample_scores, picked_example
Example #5
    def load_regex_dataset(transition_system, split):
        prefix = 'data/regex/'
        src_file = join(prefix, "src-{}.txt".format(split))
        spec_file = join(prefix, "spec-{}.txt".format(split))

        examples = []
        for idx, (src_line,
                  spec_line) in enumerate(zip(open(src_file),
                                              open(spec_file))):
            print(idx)

            src_line = src_line.rstrip()
            spec_line = spec_line.rstrip()
            src_toks = src_line.split()

            spec_toks = spec_line.rstrip().split()
            spec_ast = regex_expr_to_ast(transition_system.grammar, spec_toks)

            # sanity check
            reconstructed_expr = transition_system.ast_to_surface_code(
                spec_ast)
            print(spec_line, reconstructed_expr)
            assert spec_line == reconstructed_expr

            tgt_actions = transition_system.get_actions(spec_ast)

            # sanity check
            hyp = Hypothesis()
            for action in tgt_actions:
                assert action.__class__ in transition_system.get_valid_continuation_types(
                    hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(
                        hyp)
                hyp = hyp.clone_and_apply_action(action)

            assert hyp.frontier_node is None and hyp.frontier_field is None
            assert is_equal_ast(hyp.tree, spec_ast)

            expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
            assert expr_from_hyp == spec_line

            tgt_action_infos = get_action_infos(src_toks, tgt_actions)
            example = Example(idx=idx,
                              src_sent=src_toks,
                              tgt_actions=tgt_action_infos,
                              tgt_code=spec_line,
                              tgt_ast=spec_ast,
                              meta=None)

            examples.append(example)
        return examples
Example #6
def load_dataset(transition_system, path, num, reorder_predicates=True):
    grammar = transition_system.grammar

    examples = []
    pre_len = 0
    if os.path.exists('data/pdf/train.bin'):
        examples = pickle.load(open('data/pdf/train.bin', 'rb'))
        pre_len = len(examples)

    idx = 0
    for item in os.listdir(path):
        item_path = os.path.join(path, item)
        print(item)

        try:
            pdf = PdfReader(item_path)
        except:
            continue

        for page in pdf.pages:
            idx += 1
            if idx <= pre_len:
                continue
            print(idx)

            try:
                tgt_ast = pdf_to_ast(grammar, page, [])
            except:
                continue

            tgt_actions = transition_system.get_actions(tgt_ast)

            """
            hyp = Hypothesis()
            for action in tgt_actions:
                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(hyp)
                hyp = hyp.clone_and_apply_action(action)
            assert hyp.frontier_node is None and hyp.frontier_field is None 
            """

            tgt_action_infos = get_action_infos(tgt_actions)

            example = Example(idx=idx, tgt_actions=tgt_action_infos, meta=None)

            examples.append(example)

        if idx >= num:
            break

    return examples
Example #7
def load_dataset(transition_system, dataset_file):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        print(line)
        src_query, tgt_code = line.strip().split('~')

        tgt_code = tgt_code.replace("("," ( ")
        tgt_code = tgt_code.replace(")"," ) ")
        tgt_code = " ".join(tgt_code.split())
        src_query = src_query.replace("(","")
        src_query = src_query.replace(")","")
        src_query_tokens = src_query.split(' ')

        tgt_ast = lisp_expr_to_ast(transition_system.grammar, tgt_code)
        reconstructed_lisp_expr = ast_to_lisp_expr(tgt_ast)
        assert tgt_code == reconstructed_lisp_expr

        tgt_actions = transition_system.get_actions(tgt_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None

        assert is_equal_ast(hyp.tree, tgt_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert expr_from_hyp == tgt_code

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        print(idx)
        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=tgt_code,
                          tgt_ast=tgt_ast,
                          meta=None)

        examples.append(example)

    return examples
Example #8
def load_dataset(transition_system, dataset_file):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        src_query, tgt_code = line.strip().split('\t')

        src_query_tokens = src_query.split(' ')

        lf = parse_lambda_expr(tgt_code)
        gold_source = lf.to_string()
        assert gold_source == tgt_code
        tgt_ast = logical_form_to_ast(grammar, lf)
        reconstructed_lf = ast_to_logical_form(tgt_ast)
        assert lf == reconstructed_lf

        tgt_actions = transition_system.get_actions(tgt_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(
                hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(
                    hyp)
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None

        src_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert src_from_hyp == gold_source

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        print(idx)
        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=gold_source,
                          tgt_ast=tgt_ast,
                          meta=None)

        examples.append(example)

    return examples
Example #9
    def initialize_rerank_features(self, examples, decode_results):
        hyp_examples = []
        # print('initializing features...', file=sys.stderr)
        for example, hyps in zip(examples, decode_results):
            for hyp_id, hyp in enumerate(hyps):
                hyp_example = Example(idx=None,
                                      src_sent=example.src_sent,
                                      tgt_code=hyp.code,
                                      tgt_actions=None,
                                      tgt_ast=None)
                hyp_examples.append(hyp_example)
                # hyp.tokenized_code = len(self.transition_system.tokenize_code(hyp.code))
                # hyp.code_token_count = len(hyp.code.split(' '))

                feat_vals = OrderedDict()
                hyp.rerank_feature_values = feat_vals

        for batch_examples in utils.batch_iter(hyp_examples, batch_size=128):
            for feat_name, feat in self.batched_features.items():
                batch_example_scores = feat.score(
                    batch_examples).data.cpu().tolist()
                for i, e in enumerate(batch_examples):
                    setattr(e, feat_name, batch_example_scores[i])

        e_ptr = 0
        for example, hyps in zip(examples, decode_results):
            for hyp in hyps:
                for feat_name, feat in self.batched_features.items():
                    hyp.rerank_feature_values[feat_name] = getattr(
                        hyp_examples[e_ptr], feat_name)
                e_ptr += 1

        for example, hyps in zip(examples, decode_results):
            for hyp_id, hyp in enumerate(hyps):
                for feat_name, feat in self.feat_map.items():
                    if not feat.is_batched:
                        feat_val = feat.get_feat_value(
                            example,
                            hyp,
                            hyp_id=hyp_id,
                            all_hyps=hyps,
                            transition_system=self.transition_system)
                        hyp.rerank_feature_values[feat_name] = feat_val
Example #10
def load_dataset(split, transition_system):

    prefix = 'data/turk/'
    src_file = join(prefix, "src-{}.txt".format(split))
    spec_file = join(prefix, "spec-{}.txt".format(split))

    examples = []
    for idx, (src_line,
              spec_line) in enumerate(zip(open(src_file), open(spec_file))):
        print(idx)

        src_line = src_line.rstrip()
        spec_line = spec_line.rstrip()
        src_toks = src_line.split()

        spec_toks = spec_line.rstrip().split()
        spec_ast = regex_expr_to_ast(transition_system.grammar, spec_toks)

        # sanity check
        reconstructed_expr = transition_system.ast_to_surface_code(spec_ast)
        print(spec_line, reconstructed_expr)
        assert spec_line == reconstructed_expr

        tgt_action_tree = transition_system.get_action_tree(spec_ast)

        # sanity check
        ast_from_action = transition_system.build_ast_from_actions(
            tgt_action_tree)
        assert is_equal_ast(ast_from_action, spec_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(ast_from_action)
        assert expr_from_hyp == spec_line
        # sanity check
        # tgt_action_infos = get_action_infos(src_toks, tgt_actions)
        example = Example(idx=idx,
                          src_toks=src_toks,
                          tgt_actions=tgt_action_tree,
                          tgt_toks=spec_toks,
                          tgt_ast=spec_ast,
                          meta=None)

        examples.append(example)
    return examples
Example #11
    def parse(self, utterance, debug=False):
        utterance = utterance.strip()
        processed_utterance_tokens, utterance_meta = self.example_processor.pre_process_utterance(
            utterance)
        if debug:
            print(processed_utterance_tokens)
            print(utterance_meta)
        examples = [
            Example(idx=None,
                    src_sent=processed_utterance_tokens,
                    tgt_code=None,
                    tgt_actions=None,
                    tgt_ast=None)
        ]
        hypotheses = self.parser.parse(processed_utterance_tokens,
                                       beam_size=self.beam_size,
                                       debug=debug)

        if self.reranker:
            hypotheses = self.decode_tree_to_code(hypotheses)
            hypotheses = self.reranker.rerank_hypotheses(
                examples, [hypotheses])[0]

        valid_hypotheses = list(
            filter(
                lambda hyp: self.parser.transition_system.is_valid_hypothesis(
                    hyp), hypotheses))
        for hyp in valid_hypotheses:
            self.example_processor.post_process_hypothesis(hyp, utterance_meta)

        if debug:
            for hyp_id, hyp in enumerate(valid_hypotheses):
                print('------------------ Hypothesis %d ------------------' %
                      hyp_id)
                print(hyp.code)
                print(hyp.tree.to_string())
                print('Actions:')
                for action_t in hyp.action_infos:
                    print(action_t.action)

        return valid_hypotheses
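
A hedged usage sketch of this method; the wrapper object's name and construction below are assumptions made purely for illustration:

# `standalone_parser` stands in for whatever object exposes parse(); construct it
# according to the project's own setup before calling.
hypotheses = standalone_parser.parse('sort my_list in descending order', debug=True)
if hypotheses:
    best = hypotheses[0]
    print(best.code)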
Example #12
    def parse_django_dataset(annot_file,
                             code_file,
                             asdl_file_path,
                             max_query_len=70,
                             vocab_freq_cutoff=10):
        asdl_text = open(asdl_file_path).read()
        grammar = ASDLGrammar.from_text(asdl_text)
        transition_system = PythonTransitionSystem(grammar)

        loaded_examples = []

        from components.vocab import Vocab, VocabEntry
        from components.dataset import Example

        for idx, (src_query,
                  tgt_code) in enumerate(zip(open(annot_file),
                                             open(code_file))):
            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            src_query_tokens, tgt_canonical_code, str_map = Django.canonicalize_example(
                src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code).body[0]
            gold_source = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast, grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # print('+' * 60)
            # print('Example: %d' % idx)
            # print('Source: %s' % ' '.join(src_query_tokens))
            # if str_map:
            #     print('Original String Map:')
            #     for str_literal, str_repr in str_map.items():
            #         print('\t%s: %s' % (str_literal, str_repr))
            # print('Code:\n%s' % gold_source)
            # print('Actions:')

            # sanity check
            try:
                hyp = Hypothesis()
                for t, action in enumerate(tgt_actions):
                    # assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                    # if isinstance(action, ApplyRuleAction):
                    #     assert action.production in transition_system.get_valid_continuating_productions(hyp)

                    p_t = -1
                    f_t = None
                    if hyp.frontier_node:
                        p_t = hyp.frontier_node.created_time
                        f_t = hyp.frontier_field.field.__repr__(plain=True)

                    # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
                    hyp = hyp.clone_and_apply_action(action)

                assert hyp.frontier_node is None and hyp.frontier_field is None

                src_from_hyp = astor.to_source(
                    asdl_ast_to_python_ast(hyp.tree, grammar)).strip()
                assert src_from_hyp == gold_source

                # print('+' * 60)
            except:
                continue

            loaded_examples.append({
                'src_query_tokens': src_query_tokens,
                'tgt_canonical_code': gold_source,
                'tgt_ast': tgt_ast,
                'tgt_actions': tgt_actions,
                'raw_code': tgt_code,
                'str_map': str_map
            })

            # print('first pass, processed %d' % idx, file=sys.stderr)

        train_examples = []
        dev_examples = []
        test_examples = []

        action_len = []

        for idx, e in enumerate(loaded_examples):
            src_query_tokens = e['src_query_tokens'][:max_query_len]
            tgt_actions = e['tgt_actions']
            tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

            example = Example(idx=idx,
                              src_sent=src_query_tokens,
                              tgt_actions=tgt_action_infos,
                              tgt_code=e['tgt_canonical_code'],
                              tgt_ast=e['tgt_ast'],
                              meta={
                                  'raw_code': e['raw_code'],
                                  'str_map': e['str_map']
                              })

            # print('second pass, processed %d' % idx, file=sys.stderr)

            action_len.append(len(tgt_action_infos))

            # train, valid, test split
            if 0 <= idx < 16000:
                train_examples.append(example)
            elif 16000 <= idx < 17000:
                dev_examples.append(example)
            else:
                test_examples.append(example)

        print('Max action len: %d' % max(action_len), file=sys.stderr)
        print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
        print('Actions larger than 100: %d' %
              len(list(filter(lambda x: x > 100, action_len))),
              file=sys.stderr)

        src_vocab = VocabEntry.from_corpus(
            [e.src_sent for e in train_examples],
            size=5000,
            freq_cutoff=vocab_freq_cutoff)

        primitive_tokens = [
            map(
                lambda a: a.action.token,
                filter(lambda a: isinstance(a.action, GenTokenAction),
                       e.tgt_actions)) for e in train_examples
        ]

        primitive_vocab = VocabEntry.from_corpus(primitive_tokens,
                                                 size=5000,
                                                 freq_cutoff=vocab_freq_cutoff)
        assert '_STR:0_' in primitive_vocab

        # generate vocabulary for the code tokens!
        code_tokens = [
            tokenize_code(e.tgt_code, mode='decoder') for e in train_examples
        ]
        code_vocab = VocabEntry.from_corpus(code_tokens,
                                            size=5000,
                                            freq_cutoff=vocab_freq_cutoff)

        vocab = Vocab(source=src_vocab,
                      primitive=primitive_vocab,
                      code=code_vocab)
        print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

        return (train_examples, dev_examples, test_examples), vocab
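
The primitive_tokens construction above relies on nested map/filter objects; an equivalent and arguably clearer list-comprehension form is shown below (illustrative only; it yields the same tokens, materialized as lists rather than one-shot map iterators):

primitive_tokens = [
    [a.action.token for a in e.tgt_actions if isinstance(a.action, GenTokenAction)]
    for e in train_examples
]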
Example #13
    def parse_natural_dataset(asdl_file_path,
                              max_query_len=70,
                              vocab_freq_cutoff=10):
        asdl_text = open(asdl_file_path).read()
        print('building grammar')
        grammar = ASDLGrammar.from_text(asdl_text)
        transition_system = Python3TransitionSystem(grammar)

        loaded_examples = []

        annotations = []
        codes = []
        path = os.path.join(os.path.dirname(__file__), "datagen")
        datagens = os.listdir(path)
        for folder in datagens:
            if "__" in folder or not os.path.isdir(os.path.join(path, folder)):
                continue
            with open(os.path.join(path, folder, "inputs.txt"), 'r') as file:
                annotations += file.read().split('\n')
            with open(os.path.join(path, folder, "outputs.txt"), 'r') as file:
                codes += file.read().split('\n')
        annotation_codes = list(zip(annotations, codes))
        np.random.seed(42)
        np.random.shuffle(annotation_codes)

        from components.vocab import Vocab, VocabEntry
        from components.dataset import Example

        print('processing examples')
        for idx, (src_query, tgt_code) in enumerate(annotation_codes):
            if (idx % 100 == 0):
                sys.stdout.write("\r%s / %s" % (idx, len(annotation_codes)))
                sys.stdout.flush()

            src_query = src_query.strip()
            tgt_code = tgt_code.strip()

            src_query_tokens, tgt_canonical_code, str_map = Natural.canonicalize_example(
                src_query, tgt_code)
            python_ast = ast.parse(tgt_canonical_code)  #.body[0]
            gold_source = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast,
                                             transition_system.grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)
            # print('+' * 60)
            # print('Example: %d' % idx)
            # print('Source: %s' % ' '.join(src_query_tokens))
            # if str_map:
            #     print('Original String Map:')
            #     for str_literal, str_repr in str_map.items():
            #         print('\t%s: %s' % (str_literal, str_repr))
            # print('Code:\n%s' % gold_source)
            # print('Actions:')

            # sanity check
            hyp = Hypothesis()
            for t, action in enumerate(tgt_actions):
                assert action.__class__ in transition_system.get_valid_continuation_types(
                    hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(
                        hyp)
                # assert action.__class__ in transition_system.get_valid_continuation_types(
                # hyp)

                p_t = -1
                f_t = None
                if hyp.frontier_node:
                    p_t = hyp.frontier_node.created_time
                    f_t = hyp.frontier_field.field.__repr__(plain=True)

                # print('\t[%d] %s, frontier field: %s, parent: %d' %
                #     (t, action, f_t, p_t))
                hyp = hyp.clone_and_apply_action(action)

            # assert hyp.frontier_node is None and hyp.frontier_field is None

            src_from_hyp = astor.to_source(
                asdl_ast_to_python_ast(hyp.tree, grammar)).strip()
            if "b'" not in str(gold_source) and 'b"' not in str(gold_source):
                assert src_from_hyp == gold_source

            # print('+' * 60)

            loaded_examples.append({
                'src_query_tokens': src_query_tokens,
                'tgt_canonical_code': gold_source,
                'tgt_ast': tgt_ast,
                'tgt_actions': tgt_actions,
                'raw_code': tgt_code,
                'str_map': str_map
            })

            # print('first pass, processed %d' % idx, file=sys.stderr)

        train_examples = []
        dev_examples = []
        test_examples = []

        action_len = []

        print("\nsplitting train/dev/test")
        for idx, e in enumerate(loaded_examples):
            src_query_tokens = e['src_query_tokens'][:max_query_len]
            tgt_actions = e['tgt_actions']
            tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

            example = Example(idx=idx,
                              src_sent=src_query_tokens,
                              tgt_actions=tgt_action_infos,
                              tgt_code=e['tgt_canonical_code'],
                              tgt_ast=e['tgt_ast'],
                              meta={
                                  'raw_code': e['raw_code'],
                                  'str_map': e['str_map']
                              })

            # print('second pass, processed %d' % idx, file=sys.stderr)

            action_len.append(len(tgt_action_infos))

            # train, valid, test split
            total_examples = len(loaded_examples)
            split_size = np.ceil(total_examples * 0.05)
            (dev_split, test_split) = (total_examples - split_size * 2,
                                       total_examples - split_size)
            if 0 <= idx < dev_split:
                train_examples.append(example)
            elif dev_split <= idx < test_split:
                dev_examples.append(example)
            else:
                test_examples.append(example)

        print('Max action len: %d' % max(action_len), file=sys.stderr)
        print('Avg action len: %d' % np.average(action_len), file=sys.stderr)
        print('Actions larger than 100: %d' %
              len(list(filter(lambda x: x > 100, action_len))),
              file=sys.stderr)

        src_vocab = VocabEntry.from_corpus(
            [e.src_sent for e in train_examples],
            size=5000,
            freq_cutoff=vocab_freq_cutoff)

        primitive_tokens = [
            map(
                lambda a: a.action.token,
                filter(lambda a: isinstance(a.action, GenTokenAction),
                       e.tgt_actions)) for e in train_examples
        ]

        primitive_vocab = VocabEntry.from_corpus(primitive_tokens,
                                                 size=5000,
                                                 freq_cutoff=vocab_freq_cutoff)
        # assert '_STR:0_' in primitive_vocab

        # generate vocabulary for the code tokens!
        code_tokens = [
            tokenize_code(e.tgt_code, mode='decoder') for e in train_examples
        ]
        code_vocab = VocabEntry.from_corpus(code_tokens,
                                            size=5000,
                                            freq_cutoff=vocab_freq_cutoff)

        vocab = Vocab(source=src_vocab,
                      primitive=primitive_vocab,
                      code=code_vocab)
        print('generated vocabulary %s' % repr(vocab), file=sys.stderr)

        return (train_examples, dev_examples, test_examples), vocab
Example #14
    def load_regex_dataset(transition_system, split):
        prefix = 'data/streg/'
        src_file = join(prefix, "src-{}.txt".format(split))
        spec_file = join(prefix, "targ-{}.txt".format(split))
        map_file = join(prefix, "map-{}.txt".format(split))
        exs_file = join(prefix, "exs-{}.txt".format(split))
        rec_file = join(prefix, "rec-{}.pkl".format(split))

        exs_info = StReg.load_examples(exs_file)
        map_info = StReg.load_map_file(map_file)
        rec_info = StReg.load_rec(rec_file)

        examples = []
        for idx, (src_line, spec_line, str_exs, cmap, rec) in enumerate(
                zip(open(src_file), open(spec_file), exs_info, map_info,
                    rec_info)):
            print(idx)

            src_line = src_line.rstrip()
            spec_line = spec_line.rstrip()
            src_toks = src_line.split()

            spec_toks = spec_line.rstrip().split()
            spec_ast = streg_expr_to_ast(transition_system.grammar, spec_toks)
            # sanity check
            reconstructed_expr = transition_system.ast_to_surface_code(
                spec_ast)
            # print("Spec", spec_line)
            # print("Rcon", reconstructed_expr)
            assert spec_line == reconstructed_expr

            tgt_actions = transition_system.get_actions(spec_ast)
            # sanity check
            hyp = Hypothesis()
            for action in tgt_actions:
                assert action.__class__ in transition_system.get_valid_continuation_types(
                    hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(
                        hyp)
                hyp = hyp.clone_and_apply_action(action)

            assert hyp.frontier_node is None and hyp.frontier_field is None
            assert is_equal_ast(hyp.tree, spec_ast)

            expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
            assert expr_from_hyp == spec_line

            tgt_action_infos = get_action_infos(src_toks, tgt_actions)

            example = Example(idx=idx,
                              src_sent=src_toks,
                              tgt_actions=tgt_action_infos,
                              tgt_code=spec_line,
                              tgt_ast=spec_ast,
                              meta={
                                  "str_exs": str_exs,
                                  "const_map": cmap,
                                  "worker_info": rec
                              })
            examples.append(example)
        return examples
Example #15
def preprocess_dataset(file_path, transition_system, name='train'):
    dataset = json.load(open(file_path))
    examples = []
    evaluator = ConalaEvaluator(transition_system)

    f = open(file_path + '.debug', 'w')

    for i, example_json in enumerate(dataset):
        example_dict = preprocess_example(example_json)
        if example_json['question_id'] in (18351951, 9497290, 19641579,
                                           32283692):
            pprint(preprocess_example(example_json))
            continue

        python_ast = ast.parse(example_dict['canonical_snippet'])
        canonical_code = astor.to_source(python_ast).strip()
        tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
        tgt_actions = transition_system.get_actions(tgt_ast)

        # sanity check
        hyp = Hypothesis()
        for t, action in enumerate(tgt_actions):
            assert action.__class__ in transition_system.get_valid_continuation_types(
                hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(
                    hyp)

            p_t = -1
            f_t = None
            if hyp.frontier_node:
                p_t = hyp.frontier_node.created_time
                f_t = hyp.frontier_field.field.__repr__(plain=True)

            # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None
        hyp.code = code_from_hyp = astor.to_source(
            asdl_ast_to_python_ast(hyp.tree,
                                   transition_system.grammar)).strip()
        assert code_from_hyp == canonical_code

        decanonicalized_code_from_hyp = decanonicalize_code(
            code_from_hyp, example_dict['slot_map'])
        assert compare_ast(ast.parse(example_json['snippet']),
                           ast.parse(decanonicalized_code_from_hyp))
        assert transition_system.compare_ast(
            transition_system.surface_code_to_ast(
                decanonicalized_code_from_hyp),
            transition_system.surface_code_to_ast(example_json['snippet']))

        tgt_action_infos = get_action_infos(example_dict['intent_tokens'],
                                            tgt_actions)

        example = Example(idx=f'{i}-{example_json["question_id"]}',
                          src_sent=example_dict['intent_tokens'],
                          tgt_actions=tgt_action_infos,
                          tgt_code=canonical_code,
                          tgt_ast=tgt_ast,
                          meta=dict(example_dict=example_json,
                                    slot_map=example_dict['slot_map']))
        assert evaluator.is_hyp_correct(example, hyp)

        examples.append(example)

        # log!
        f.write(f'Example: {example.idx}\n')
        f.write(
            f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n"
        )
        f.write(
            f"Original Snippet: {example.meta['example_dict']['snippet']}\n")
        f.write(f"\n")
        f.write(f"Utterance: {' '.join(example.src_sent)}\n")
        f.write(f"Snippet: {example.tgt_code}\n")
        f.write(
            f"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
        )

    f.close()

    return examples
Example #16
def preprocess_dataset(file_path, transition_system, name='train', firstk=None):
    file_path = os.path.join(os.getcwd(), *file_path.split('/' if '/' in file_path else "\\"))

    try:
        dataset = json.load(open(file_path))
    except:
        dataset = [json.loads(jline) for jline in open(file_path).readlines()]
    if firstk:
        dataset = dataset[:firstk]
    examples = []
    evaluator = ConalaEvaluator(transition_system)

    # Author: Gabe
    # Added in encoding to try and deal with UnicodeEncodeErrors
    f = open(file_path + '.debug', 'w', encoding='utf-8')

    skipped_list = []
    for i, example_json in tqdm(enumerate(dataset), file=sys.stdout, total=len(dataset),
                                desc='Preproc'):

        # Author: Gabe
        # Have to skip this one question because it causes the program to hang and never recover.
        if example_json['question_id'] in [39525993]:
            skipped_list.append(example_json['question_id'])
            tqdm.write(f"Skipping {example_json['question_id']} because it causes errors")
            continue
        try:
            example_dict = preprocess_example(example_json)

            python_ast = ast.parse(example_dict['canonical_snippet'])
            canonical_code = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # sanity check
            hyp = Hypothesis()
            for t, action in enumerate(tgt_actions):
                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in \
                           transition_system.get_valid_continuating_productions(
                               hyp)
                # p_t = -1
                # f_t = None
                # if hyp.frontier_node:
                #     p_t = hyp.frontier_node.created_time
                #     f_t = hyp.frontier_field.field.__repr__(plain=True)
                #
                # # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
                hyp = hyp.clone_and_apply_action(action)

            assert hyp.frontier_node is None and hyp.frontier_field is None
            hyp.code = code_from_hyp = astor.to_source(
                asdl_ast_to_python_ast(hyp.tree, transition_system.grammar)).strip()
            # print(code_from_hyp)
            # print(canonical_code)
            assert code_from_hyp == canonical_code

            decanonicalized_code_from_hyp = decanonicalize_code(code_from_hyp,
                                                                example_dict['slot_map'])
            assert compare_ast(ast.parse(example_json['snippet']),
                               ast.parse(decanonicalized_code_from_hyp))
            assert transition_system.compare_ast(
                transition_system.surface_code_to_ast(decanonicalized_code_from_hyp),
                transition_system.surface_code_to_ast(example_json['snippet']))

            tgt_action_infos = get_action_infos(example_dict['intent_tokens'], tgt_actions)
        except (AssertionError, SyntaxError, ValueError, OverflowError) as e:
            skipped_list.append(example_json['question_id'])
            tqdm.write(
                f"Skipping example {example_json['question_id']} because of {type(e).__name__}:{e}"
            )
            continue
        example = Example(idx=f'{i}-{example_json["question_id"]}',
                          src_sent=example_dict['intent_tokens'],
                          tgt_actions=tgt_action_infos,
                          tgt_code=canonical_code,
                          tgt_ast=tgt_ast,
                          meta=dict(example_dict=example_json,
                                    slot_map=example_dict['slot_map']))
        assert evaluator.is_hyp_correct(example, hyp)

        examples.append(example)

        # Author: Gabe
        # Had to remove logging, when the log file would get too large, it would cause the
        # program to hang.

        # log!
        # f.write(f'Example: {example.idx}\n')
        # if 'rewritten_intent' in example.meta['example_dict']:
        #     f.write(f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n")
        # else:
        #     f.write(f"Original Utterance: {example.meta['example_dict']['intent']}\n")
        # f.write(f"Original Snippet: {example.meta['example_dict']['snippet']}\n")
        # f.write(f"\n")
        # f.write(f"Utterance: {' '.join(example.src_sent)}\n")
        # f.write(f"Snippet: {example.tgt_code}\n")
        # f.write(f"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n")

    f.close()
    print('Skipped due to exceptions: %d' % len(skipped_list), file=sys.stderr)
    return examples
Example #17
def self_training(args):
    """Perform self-training

    First load decoding results on disjoint data
    also load pre-trained model and perform supervised
    training on both existing training data and the
    decoded results
    """

    print('load pre-trained model from [%s]' % args.load_model,
          file=sys.stderr)
    params = torch.load(args.load_model,
                        map_location=lambda storage, loc: storage)
    vocab = params['vocab']
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_state = params['state_dict']

    # transfer arguments
    saved_args.cuda = args.cuda
    saved_args.save_to = args.save_to
    saved_args.train_file = args.train_file
    saved_args.unlabeled_file = args.unlabeled_file
    saved_args.dev_file = args.dev_file
    saved_args.load_decode_results = args.load_decode_results
    args = saved_args

    update_args(args)

    model = Parser(saved_args, vocab, transition_system)
    model.load_state_dict(saved_state)

    if args.cuda: model = model.cuda()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    print('load unlabeled data [%s]' % args.unlabeled_file, file=sys.stderr)
    unlabeled_data = Dataset.from_bin_file(args.unlabeled_file)

    print('load decoding results of unlabeled data [%s]' %
          args.load_decode_results,
          file=sys.stderr)
    decode_results = pickle.load(open(args.load_decode_results, 'rb'))

    labeled_data = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)

    print('Num. examples in unlabeled data: %d' % len(unlabeled_data),
          file=sys.stderr)
    assert len(unlabeled_data) == len(decode_results)
    self_train_examples = []
    for example, hyps in zip(unlabeled_data, decode_results):
        if hyps:
            hyp = hyps[0]
            sampled_example = Example(idx='self_train-%s' % example.idx,
                                      src_sent=example.src_sent,
                                      tgt_code=hyp.code,
                                      tgt_actions=hyp.action_infos,
                                      tgt_ast=hyp.tree)
            self_train_examples.append(sampled_example)
    print('Num. self training examples: %d, Num. labeled examples: %d' %
          (len(self_train_examples), len(labeled_data)),
          file=sys.stderr)

    train_set = Dataset(examples=labeled_data.examples + self_train_examples)

    print('begin training, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]

            train_iter += 1
            optimizer.zero_grad()

            loss = -model.score(batch_examples)
            # print(loss.data)
            loss_val = torch.sum(loss).data[0]
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                grad_norm = torch.nn.utils.clip_grad_norm(
                    model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' %
                      (train_iter, report_loss / report_examples),
                      file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)
        # model_file = args.save_to + '.iter%d.bin' % train_iter
        # print('save model to [%s]' % model_file, file=sys.stderr)
        # model.save(model_file)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples,
                                           model,
                                           args,
                                           verbose=True)
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' %
              (epoch, dev_acc, time.time() - eval_start),
              file=sys.stderr)
        is_better = history_dev_scores == [] or dev_acc > max(
            history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            exit(0)
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load model
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda: model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                optimizer = torch.optim.Adam(
                    model.inference_model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
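
The validation block above implements patience-based early stopping with learning-rate decay and reloading of the best checkpoint. A condensed, hedged sketch of just that control flow, with illustrative names:

def validation_step(dev_acc, history, patience, num_trial, max_patience, max_trial):
    # Returns (is_better, patience, num_trial, stop). The caller saves the model
    # when is_better is True and decays the LR / reloads the best checkpoint
    # whenever patience is exhausted and reset.
    is_better = not history or dev_acc > max(history)
    history.append(dev_acc)
    if is_better:
        patience = 0
    else:
        patience += 1
        if patience == max_patience:
            num_trial += 1
            patience = 0
            if num_trial == max_trial:
                return is_better, patience, num_trial, True  # early stop
    return is_better, patience, num_trial, False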
Example #18
def preprocess_dataset(file_path,
                       transition_system,
                       name='train',
                       firstk=None):
    try:
        dataset = json.load(open(file_path))
    except:
        dataset = [json.loads(jline) for jline in open(file_path).readlines()]
    if firstk:
        dataset = dataset[:firstk]
    examples = []
    evaluator = ConalaEvaluator(transition_system)
    f = open(file_path + '.debug', 'w')
    skipped_list = []
    for i, example_json in enumerate(dataset):
        try:
            example_dict = preprocess_example(example_json)

            python_ast = ast.parse(example_dict['canonical_snippet'])
            canonical_code = astor.to_source(python_ast).strip()
            tgt_ast = python_ast_to_asdl_ast(python_ast,
                                             transition_system.grammar)
            tgt_actions = transition_system.get_actions(tgt_ast)

            # sanity check
            hyp = Hypothesis()
            for t, action in enumerate(tgt_actions):
                assert action.__class__ in transition_system.get_valid_continuation_types(
                    hyp)
                if isinstance(action, ApplyRuleAction):
                    assert action.production in transition_system.get_valid_continuating_productions(
                        hyp)
                # p_t = -1
                # f_t = None
                # if hyp.frontier_node:
                #     p_t = hyp.frontier_node.created_time
                #     f_t = hyp.frontier_field.field.__repr__(plain=True)
                #
                # # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
                hyp = hyp.clone_and_apply_action(action)

            assert hyp.frontier_node is None and hyp.frontier_field is None
            hyp.code = code_from_hyp = astor.to_source(
                asdl_ast_to_python_ast(hyp.tree,
                                       transition_system.grammar)).strip()
            # print(code_from_hyp)
            # print(canonical_code)
            assert code_from_hyp == canonical_code

            decanonicalized_code_from_hyp = decanonicalize_code(
                code_from_hyp, example_dict['slot_map'])
            assert compare_ast(ast.parse(example_json['snippet']),
                               ast.parse(decanonicalized_code_from_hyp))
            assert transition_system.compare_ast(
                transition_system.surface_code_to_ast(
                    decanonicalized_code_from_hyp),
                transition_system.surface_code_to_ast(example_json['snippet']))

            tgt_action_infos = get_action_infos(example_dict['intent_tokens'],
                                                tgt_actions)
        except (AssertionError, SyntaxError, ValueError, OverflowError) as e:
            skipped_list.append(example_json['question_id'])
            continue
        example = Example(idx=f'{i}-{example_json["question_id"]}',
                          src_sent=example_dict['intent_tokens'],
                          tgt_actions=tgt_action_infos,
                          tgt_code=canonical_code,
                          tgt_ast=tgt_ast,
                          meta=dict(example_dict=example_json,
                                    slot_map=example_dict['slot_map']))
        assert evaluator.is_hyp_correct(example, hyp)

        examples.append(example)

        # log!
        f.write(f'Example: {example.idx}\n')
        if 'rewritten_intent' in example.meta['example_dict']:
            f.write(
                f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n"
            )
        else:
            f.write(
                f"Original Utterance: {example.meta['example_dict']['intent']}\n"
            )
        f.write(
            f"Original Snippet: {example.meta['example_dict']['snippet']}\n")
        f.write(f"\n")
        f.write(f"Utterance: {' '.join(example.src_sent)}\n")
        f.write(f"Snippet: {example.tgt_code}\n")
        f.write(
            f"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
        )

    f.close()
    print('Skipped due to exceptions: %d' % len(skipped_list), file=sys.stderr)
    return examples