Пример #1
0
def decode_python_dataset(model, dataset, verbose=True):
    from lang.py.parse import decode_tree_to_python_ast
    # if verbose:
    # logging.info('decoding [%s] set, num. examples: %d', dataset.name, dataset.count)

    decode_results = []
    cum_num = 0
    for example in dataset.examples:
        # import pdb;pdb.set_trace()
        cand_list = model.decode(example, dataset.grammar, dataset.terminal_vocab,
                                 beam_size=config.beam_size, max_time_step=config.decode_max_time_step)

        exg_decode_results = []
        for cid, cand in enumerate(cand_list[:10]):
            try:
                ast_tree = decode_tree_to_python_ast(cand.tree)
                code = astor.to_source(ast_tree)
                exg_decode_results.append((cid, cand, ast_tree, code))
            except:
                if verbose:
                    print "Exception in converting tree to code:"
                    print '-' * 60
                    print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid)
                    traceback.print_exc(file=sys.stdout)
                    print '-' * 60

        cum_num += 1
        if cum_num % 50 == 0 and verbose:
            print '%d examples so far ...' % cum_num

        decode_results.append(exg_decode_results)

    return decode_results
Пример #2
0
def decode_python_dataset(model, dataset, verbose=True):
    from lang.py.parse import decode_tree_to_python_ast
    if verbose:
        logging.info('decoding [%s] set, num. examples: %d', dataset.name, dataset.count)

    decode_results = []
    cum_num = 0
    for example in dataset.examples:
        cand_list = model.decode(example, dataset.grammar, dataset.terminal_vocab,
                                 beam_size=config.beam_size, max_time_step=config.decode_max_time_step)

        exg_decode_results = []
        for cid, cand in enumerate(cand_list[:10]):
            try:
                ast_tree = decode_tree_to_python_ast(cand.tree)
                code = astor.to_source(ast_tree)
                exg_decode_results.append((cid, cand, ast_tree, code))
            except:
                if verbose:
                    print "Exception in converting tree to code:"
                    print '-' * 60
                    print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid)
                    traceback.print_exc(file=sys.stdout)
                    print '-' * 60

        cum_num += 1
        if cum_num % 50 == 0 and verbose:
            print '%d examples so far ...' % cum_num

        decode_results.append(exg_decode_results)

    return decode_results
Пример #3
0
        def handle(clientsocket):
            # while 1:
            cmd = clientsocket.recv(MAX_LENGTH)
            print cmd
            if args.mode == 'dataset':
                try:
                    example_id = int(cmd)
                    example = [
                        e for e in test_data.examples if e.raw_id == example_id
                    ][0]
                except:
                    print 'something went wrong ...'
                    # continue
            elif args.mode == 'new':
                # we play with new examples!
                query, str_map = canonicalize_query(cmd)
                vocab = train_data.annot_vocab
                query_tokens = query.split(' ')
                query_tokens_data = [query_to_data(query, vocab)]
                example = namedtuple('example',
                                     ['query', 'data'])(query=query_tokens,
                                                        data=query_tokens_data)

            if hasattr(example, 'parse_tree'):
                print 'gold parse tree:'
                print example.parse_tree

            cand_list = model.decode(example,
                                     train_data.grammar,
                                     train_data.terminal_vocab,
                                     beam_size=args.beam_size,
                                     max_time_step=args.decode_max_time_step,
                                     log=True)

            has_grammar_error = any(
                [c for c in cand_list if c.has_grammar_error])
            print 'has_grammar_error: ', has_grammar_error

            for cid, cand in enumerate(cand_list[:5]):
                print '*' * 60
                print 'cand #%d, score: %f' % (cid, cand.score)

                try:
                    ast_tree = decode_tree_to_python_ast(cand.tree)
                    code = astor.to_source(ast_tree)
                    print 'code: ', code
                    print 'decode log: ', cand.log
                except:
                    print "Exception in converting tree to code:"
                    print '-' * 60
                    print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid)
                    traceback.print_exc(file=sys.stdout)
                    print '-' * 60
                finally:
                    print '* parse tree *'
                    print cand.tree.__repr__()
                    print 'n_timestep: %d' % cand.n_timestep
                    print 'ast size: %d' % cand.tree.size
                    print '*' * 60
Пример #4
0
def decode_python_dataset(model,
                          train_data,
                          dev_data,
                          test_data,
                          dataset_type="test_data",
                          retrieval=False,
                          verbose=True):
    dataset = eval(dataset_type)
    from lang.py.parse import decode_tree_to_python_ast
    if verbose:
        logging.info('decoding [%s] set, num. examples: %d', dataset.name,
                     dataset.count)

    decode_results = []
    cum_num = 0

    for example in dataset.examples:
        print 'raw_id: %d' % example.raw_id
        if retrieval:
            ngrams = retrieve_translation_pieces(train_data, example)
            cand_list = model.decode_with_retrieval(
                example,
                dataset.grammar,
                dataset.terminal_vocab,
                ngrams,
                beam_size=config.beam_size,
                max_time_step=config.decode_max_time_step,
                log=True)
        else:
            cand_list = model.decode(example,
                                     dataset.grammar,
                                     dataset.terminal_vocab,
                                     beam_size=config.beam_size,
                                     max_time_step=config.decode_max_time_step)

        exg_decode_results = []
        for cid, cand in enumerate(cand_list[:10]):
            try:
                ast_tree = decode_tree_to_python_ast(cand.tree)
                code = astor.to_source(ast_tree)
                exg_decode_results.append((cid, cand, ast_tree, code))
            except:
                if verbose:
                    print "Exception in converting tree to code:"
                    print '-' * 60
                    print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid)
                    traceback.print_exc(file=sys.stdout)
                    print '-' * 60

        cum_num += 1
        if cum_num % 50 == 0 and verbose:
            print '%d examples so far ...' % cum_num

        decode_results.append(exg_decode_results)

    return decode_results
Пример #5
0
def decode_python_dataset(model, dataset, verbose=True):
    """Decode every example of *dataset* and convert candidates to code.

    Unlike the single-call astor variant, each child of the decoded tree is
    rendered separately and the pieces are concatenated, so multi-statement
    trees yield one code string.  Returns a list (one per example) of lists
    of (cid, cand, ast_tree, code) tuples; unconvertible candidates are
    dropped (the error is printed when *verbose*).
    """
    from lang.py.parse import decode_tree_to_python_ast
    if verbose:
        logging.info('decoding [%s] set, num. examples: %d', dataset.name,
                     dataset.count)

    decode_results = []
    cum_num = 0
    for example in dataset.examples:
        cand_list = model.decode(example,
                                 dataset.grammar,
                                 dataset.terminal_vocab,
                                 beam_size=config.beam_size,
                                 max_time_step=config.decode_max_time_step)
        print(len(cand_list))
        exg_decode_results = []
        for cid, cand in enumerate(cand_list[:10]):
            try:
                ast_tree = decode_tree_to_python_ast(cand.tree)
                # join instead of repeated += (avoids quadratic string build)
                code = ''.join(astor.to_source(at) for at in ast_tree.children)
                exg_decode_results.append((cid, cand, ast_tree, code))
            except Exception:  # narrowed from bare except
                if verbose:
                    print("Exception in converting tree to code:")
                    print('-' * 60)
                    print('raw_id: %d, beam pos: %d' % (example.raw_id, cid))
                    traceback.print_exc(file=sys.stdout)
                    print('-' * 60)

        cum_num += 1
        if cum_num % 50 == 0 and verbose:
            print('%d examples so far ...' % cum_num)

        decode_results.append(exg_decode_results)

    return decode_results
Пример #6
0
def write_to_code_file(mode, data, path_to_load, path_to_export,
                       path_raw_code):
    g = data.grammar
    nt = {v: reverse_typename(k) for k, v in g.node_type_to_id.items()}

    #print(nt,g.node_type_to_id)
    v = data.terminal_vocab

    raw = []
    with open(path_raw_code, 'r') as f:
        for line in f:
            raw.append(line[:-1])

    with open(path_to_load, 'r') as f:
        l = json.load(f, encoding='utf8')
    l_code = []
    for i in range(len(l)):
        # print(raw[i])
        try:
            t = ASTNode.from_dict(l[i], nt, v)
            ast_tree = parse.decode_tree_to_python_ast(t)
            code = astor.to_source(ast_tree)[:-1]
            real_code = parse.de_canonicalize_code(code, raw[i])
            if (mode == "hs"):
                real_code = " ".join(parse.tokenize_code_adv(
                    real_code, True)).replace("\n", "#NEWLINE#").replace(
                        "#NEWLINE# ", "").replace("#INDENT# ", "")
                real_code = " ".join(parse.tokenize_code_adv(real_code, False))
            #print(real_code,raw[i])
            l_code.append(real_code)
        except:
            print "Tree %d impossible to parse" % (i)
            l_code.append("")

    with open(path_to_export, 'w') as f:
        for c in l_code:
            f.write(c + "\n")
Пример #7
0
                                     train_data.terminal_vocab,
                                     beam_size=args.beam_size,
                                     max_time_step=args.decode_max_time_step,
                                     log=True)

            has_grammar_error = any(
                [c for c in cand_list if c.has_grammar_error])
            print 'has_grammar_error: ', has_grammar_error
            query_temp, str_m = canonicalize_query(cmd)
            # print 'Query',query_temp
            for cid, cand in enumerate(cand_list[:5]):
                print '*' * 60
                print 'cand #%d, score: %f' % (cid, cand.score)

                try:
                    ast_tree = decode_tree_to_python_ast(cand.tree)
                    code = astor.to_source(ast_tree)

                    print 'code: ', code
                    print 'decode log: ', cand.log
                except:
                    print "Exception in converting tree to code:"
                    print '-' * 60
                    print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid)
                    traceback.print_exc(file=sys.stdout)
                    print '-' * 60
                finally:
                    print '* parse tree *'
                    print cand.tree.__repr__()
                    print 'n_timestep: %d' % cand.n_timestep
                    print 'ast size: %d' % cand.tree.size
Пример #8
0
    return cand_list

if __name__ == '__main__':
    print 'run in interactive mode'
    while True:
        query = raw_input('input a query: ')
        cand_list = decode_query(query)

        # output top 5 candidates
        for cid, cand in enumerate(cand_list[:5]):
            print '*' * 60
            print 'cand #%d, score: %f' % (cid, cand.score)

            try:
                ast_tree = decode_tree_to_python_ast(cand.tree)
                code = astor.to_source(ast_tree)
                print 'code: ', code
                print 'decode log: ', cand.log
            except:
                print "Exception in converting tree to code:"
                print '-' * 60
                print 'raw_id: %d, beam pos: %d' % (example.raw_id, cid)
                traceback.print_exc(file=sys.stdout)
                print '-' * 60
            finally:
                print '* parse tree *'
                print cand.tree.__repr__()
                print 'n_timestep: %d' % cand.n_timestep
                print 'ast size: %d' % cand.tree.size
                print '*' * 60
Пример #9
0
def evaluate_seq2tree_sample_file(sample_file, id_file, dataset):
    """Evaluate a seq2tree sample file against the reference *dataset*.

    *sample_file* holds three lines per example (reference tree repr,
    predicted tree repr, separator line); *id_file* maps each sample line
    index to the raw example id.  Each predicted repr is converted back to a
    Python AST and then to code, which is scored for exact-match accuracy
    and sentence-level BLEU.  Aggregate metrics are written to the log;
    returns None.
    """
    from lang.py.parse import tokenize_code, de_canonicalize_code
    import ast, astor
    import traceback
    from lang.py.seq2tree_exp import seq2tree_repr_to_ast_tree, merge_broken_value_nodes
    from lang.py.parse import decode_tree_to_python_ast

    # NOTE(review): file handles opened here are never closed explicitly
    f_sample = open(sample_file)
    line_id_to_raw_id = OrderedDict()  # sample line index -> raw example id
    raw_id_to_eid = OrderedDict()  # raw example id -> index into dataset.examples
    for i, line in enumerate(open(id_file)):
        raw_id = int(line.strip())
        line_id_to_raw_id[i] = raw_id

    for eid in range(len(dataset.examples)):
        raw_id_to_eid[dataset.examples[eid].raw_id] = eid

    # per-line map of unk placeholder id -> original rare word (optional)
    rare_word_map = defaultdict(dict)
    if config.seq2tree_rareword_map:
        logging.info('use rare word map')
        for i, line in enumerate(open(config.seq2tree_rareword_map)):
            line = line.strip()
            if line:
                for e in line.split(' '):
                    d = e.split(':', 1)
                    rare_word_map[i][int(d[0])] = d[1]

    cum_bleu = 0.0
    cum_acc = 0.0
    sm = SmoothingFunction()
    convert_error_num = 0  # predictions that failed tree or code conversion

    for i in range(len(line_id_to_raw_id)):
        # print 'working on %d' % i
        # each sample occupies three lines: reference, prediction, separator
        ref_repr = f_sample.readline().strip()
        predict_repr = f_sample.readline().strip()
        predict_repr = predict_repr.replace(
            '<U>', 'str{}{unk}')  # .replace('( )', '( str{}{unk} )')
        f_sample.readline()

        # if ' ( ) ' in ref_repr:
        #     print i, ref_repr

        # substitute rare-word placeholders back into both reprs
        if i in rare_word_map:
            for unk_id, w in rare_word_map[i].iteritems():
                ref_repr = ref_repr.replace(' str{}{unk_%s} ' % unk_id,
                                            ' str{}{%s} ' % w)
                predict_repr = predict_repr.replace(' str{}{unk_%s} ' % unk_id,
                                                    ' str{}{%s} ' % w)

        # prediction repr -> internal parse tree; failures are counted and
        # the example is skipped
        try:
            parse_tree = seq2tree_repr_to_ast_tree(predict_repr)
            merge_broken_value_nodes(parse_tree)
        except:
            print('error when converting:')
            print(predict_repr)
            convert_error_num += 1
            continue

        raw_id = line_id_to_raw_id[i]
        eid = raw_id_to_eid[raw_id]
        example = dataset.examples[eid]

        # normalize the reference through an AST round-trip so both sides
        # are tokenized from astor-rendered source
        ref_code = example.code
        ref_ast_tree = ast.parse(ref_code).body[0]
        refer_source = astor.to_source(ref_ast_tree).strip()
        refer_tokens = tokenize_code(refer_source)

        try:
            ast_tree = decode_tree_to_python_ast(parse_tree)
            code = astor.to_source(ast_tree).strip()
        except:
            print("Exception in converting tree to code:")
            print('-' * 60)
            print('line id: %d' % i)
            traceback.print_exc(file=sys.stdout)
            print('-' * 60)
            convert_error_num += 1
            continue

        # assumes config.data_type is 'django' or 'hs'; any other value
        # leaves *_for_bleu unbound -- TODO confirm upstream guarantees this
        if config.data_type == 'django':
            ref_code_for_bleu = example.meta_data['raw_code']
            pred_code_for_bleu = de_canonicalize_code(
                code, example.meta_data['raw_code'])
            # convert canonicalized code to raw code
            for literal, place_holder in example.meta_data[
                    'str_map'].iteritems():
                pred_code_for_bleu = pred_code_for_bleu.replace(
                    '\'' + place_holder + '\'', literal)
        elif config.data_type == 'hs':
            ref_code_for_bleu = ref_code
            pred_code_for_bleu = code

        # we apply Ling Wang's trick when evaluating BLEU scores
        refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu)
        pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)

        # exact-match accuracy on normalized code tokens
        predict_tokens = tokenize_code(code)
        # if ref_repr == predict_repr:
        if predict_tokens == refer_tokens:
            cum_acc += 1

        # shorten the n-gram order for references with fewer than 4 tokens
        ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))
        bleu_score = sentence_bleu([refer_tokens_for_bleu],
                                   pred_tokens_for_bleu,
                                   weights=ngram_weights,
                                   smoothing_function=sm.method3)
        cum_bleu += bleu_score

    cum_bleu /= len(line_id_to_raw_id)
    cum_acc /= len(line_id_to_raw_id)
    logging.info('nun. examples: %d', len(line_id_to_raw_id))
    logging.info('num. errors when converting repr to tree: %d',
                 convert_error_num)
    logging.info('ratio of grammatically incorrect trees: %f',
                 convert_error_num / float(len(line_id_to_raw_id)))
    logging.info('sentence level bleu: %f', cum_bleu)
    logging.info('accuracy: %f', cum_acc)
Пример #10
0
        return reverse_typename(t[:-1])
    else:
        return vars(ast)[t]

if __name__ == '__main__':

    # Evaluation driver: load decoded AST dicts from disk, rebuild Python
    # ASTs for each one, and run the standard evaluation on the test split.
    flag = "hs"
    path_to_load = "../data/exp/results/test_hs_10_iter.json"
    if flag == "django":
        train_data, dev_data, test_data = deserialize_from_file("../../django.cleaned.dataset.freq5.par_info.refact.space_only.bin")
    elif flag == "hs":
        train_data, dev_data, test_data = deserialize_from_file("../../hs.freq3.pre_suf.unary_closure.bin")

    data = test_data
    grammar = data.grammar
    # invert the grammar's node_type_to_id map: id -> type
    id_to_type = dict((type_id, reverse_typename(type_name))
                      for type_name, type_id in grammar.node_type_to_id.items())
    terminal_vocab = data.terminal_vocab

    with open(path_to_load, 'r') as fin:
        decoded_trees = json.load(fin, encoding='utf8')

    results = []
    for tree_dict in decoded_trees:
        node = ASTNode.from_dict(tree_dict, id_to_type, terminal_vocab)
        results.append(parse.decode_tree_to_python_ast(node))

    evaluate_decode_results(flag, test_data, results, verbose=True)
Пример #11
0
def evaluate_seq2tree_sample_file(sample_file, id_file, dataset):
    """Evaluate a seq2tree sample file against the reference *dataset*.

    *sample_file* holds three lines per example (reference tree repr,
    predicted tree repr, separator line); *id_file* maps each sample line
    index to the raw example id.  Each predicted repr is converted back to a
    Python AST and then to code, which is scored for exact-match accuracy
    and sentence-level BLEU.  Aggregate metrics are written to the log;
    returns None.
    """
    from lang.py.parse import tokenize_code, de_canonicalize_code
    import ast, astor
    import traceback
    from lang.py.seq2tree_exp import seq2tree_repr_to_ast_tree, merge_broken_value_nodes
    from lang.py.parse import decode_tree_to_python_ast

    # NOTE(review): file handles opened here are never closed explicitly
    f_sample = open(sample_file)
    line_id_to_raw_id = OrderedDict()  # sample line index -> raw example id
    raw_id_to_eid = OrderedDict()  # raw example id -> index into dataset.examples
    for i, line in enumerate(open(id_file)):
        raw_id = int(line.strip())
        line_id_to_raw_id[i] = raw_id

    for eid in range(len(dataset.examples)):
        raw_id_to_eid[dataset.examples[eid].raw_id] = eid

    # per-line map of unk placeholder id -> original rare word (optional)
    rare_word_map = defaultdict(dict)
    if config.seq2tree_rareword_map:
        logging.info('use rare word map')
        for i, line in enumerate(open(config.seq2tree_rareword_map)):
            line = line.strip()
            if line:
                for e in line.split(' '):
                    d = e.split(':', 1)
                    rare_word_map[i][int(d[0])] = d[1]

    cum_bleu = 0.0
    cum_acc = 0.0
    sm = SmoothingFunction()
    convert_error_num = 0  # predictions that failed tree or code conversion

    for i in range(len(line_id_to_raw_id)):
        # print 'working on %d' % i
        # each sample occupies three lines: reference, prediction, separator
        ref_repr = f_sample.readline().strip()
        predict_repr = f_sample.readline().strip()
        predict_repr = predict_repr.replace('<U>', 'str{}{unk}') # .replace('( )', '( str{}{unk} )')
        f_sample.readline()

        # if ' ( ) ' in ref_repr:
        #     print i, ref_repr

        # substitute rare-word placeholders back into both reprs
        if i in rare_word_map:
            for unk_id, w in rare_word_map[i].iteritems():
                ref_repr = ref_repr.replace(' str{}{unk_%s} ' % unk_id, ' str{}{%s} ' % w)
                predict_repr = predict_repr.replace(' str{}{unk_%s} ' % unk_id, ' str{}{%s} ' % w)

        # prediction repr -> internal parse tree; failures are counted and
        # the example is skipped
        try:
            parse_tree = seq2tree_repr_to_ast_tree(predict_repr)
            merge_broken_value_nodes(parse_tree)
        except:
            print 'error when converting:'
            print predict_repr
            convert_error_num += 1
            continue

        raw_id = line_id_to_raw_id[i]
        eid = raw_id_to_eid[raw_id]
        example = dataset.examples[eid]

        # normalize the reference through an AST round-trip so both sides
        # are tokenized from astor-rendered source
        ref_code = example.code
        ref_ast_tree = ast.parse(ref_code).body[0]
        refer_source = astor.to_source(ref_ast_tree).strip()
        refer_tokens = tokenize_code(refer_source)

        try:
            ast_tree = decode_tree_to_python_ast(parse_tree)
            code = astor.to_source(ast_tree).strip()
        except:
            print "Exception in converting tree to code:"
            print '-' * 60
            print 'line id: %d' % i
            traceback.print_exc(file=sys.stdout)
            print '-' * 60
            convert_error_num += 1
            continue

        # assumes config.data_type is 'django' or 'hs'; any other value
        # leaves *_for_bleu unbound -- TODO confirm upstream guarantees this
        if config.data_type == 'django':
            ref_code_for_bleu = example.meta_data['raw_code']
            pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code'])
            # convert canonicalized code to raw code
            for literal, place_holder in example.meta_data['str_map'].iteritems():
                pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal)
        elif config.data_type == 'hs':
            ref_code_for_bleu = ref_code
            pred_code_for_bleu = code

        # we apply Ling Wang's trick when evaluating BLEU scores
        refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu)
        pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)

        # exact-match accuracy on normalized code tokens
        predict_tokens = tokenize_code(code)
        # if ref_repr == predict_repr:
        if predict_tokens == refer_tokens:
            cum_acc += 1

        # shorten the n-gram order for references with fewer than 4 tokens
        ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))
        bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights,
                                   smoothing_function=sm.method3)
        cum_bleu += bleu_score

    cum_bleu /= len(line_id_to_raw_id)
    cum_acc /= len(line_id_to_raw_id)
    logging.info('nun. examples: %d', len(line_id_to_raw_id))
    logging.info('num. errors when converting repr to tree: %d', convert_error_num)
    logging.info('ratio of grammatically incorrect trees: %f', convert_error_num / float(len(line_id_to_raw_id)))
    logging.info('sentence level bleu: %f', cum_bleu)
    logging.info('accuracy: %f', cum_acc)