示例#1
0
def parse_ifttt_dataset():
    """Build, serialize and return the IFTTT train/dev/test datasets.

    The annotation and code files are preprocessed, a grammar and an
    annotation vocabulary are derived from all entries, and every parse tree
    is turned into a list of ``APPLY_RULE`` actions.  Each example is routed
    to a split according to its id, the data matrices are initialised, and
    the three datasets are written to
    ``data/ifttt.freq<WORD_FREQ_CUT_OFF>.bin``.

    Returns:
        The tuple ``(train_data, dev_data, test_data)``.

    Raises:
        RuntimeError: if a production hangs off a value node; no terminals
            are expected in this dataset.
    """
    WORD_FREQ_CUT_OFF = 2

    # split boundaries on the example id: ids in [0, train_end) go to train,
    # remaining ids below dev_end go to dev, everything else goes to test
    train_end = 77495
    dev_end = 77495 + 5171

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/Data/lang.all.txt'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/Data/code.all.txt'

    data = preprocess_ifttt_dataset(annot_file, code_file)

    # grammar induced from every parse tree in the corpus
    grammar = get_grammar([item['parse_tree'] for item in data])

    # annotation vocabulary over the flattened query tokens
    annot_tokens = list(chain.from_iterable(item['query_tokens'] for item in data))
    annot_vocab = gen_vocab(annot_tokens, vocab_size=30000,
                            freq_cutoff=WORD_FREQ_CUT_OFF)

    logging.info('annot vocab. size: %d', annot_vocab.size)

    # ifttt has no terminal tokens, so the terminal vocabulary is built from
    # an empty token list
    terminal_vocab = gen_vocab([], vocab_size=4000,
                               freq_cutoff=WORD_FREQ_CUT_OFF)

    train_data, dev_data, test_data = (
        DataSet(annot_vocab, terminal_vocab, grammar, split_name)
        for split_name in ('ifttt.train_data', 'ifttt.dev_data', 'ifttt.test_data')
    )

    all_examples = []
    can_fully_reconstructed_examples_num = 0
    examples_with_empty_actions_num = 0

    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        # ids of the non-punctuation query tokens, then those that are not unk
        non_punct_ids = [annot_vocab[tok] for tok in query_tokens
                         if tok not in string.punctuation]
        known_ids = [tok_id for tok_id in non_punct_ids
                     if tok_id != annot_vocab.unk]

        # train/dev examples made up only of rare words are dropped to avoid
        # overfitting; test examples are always kept
        if not known_ids and 0 <= idx < dev_end:
            continue

        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)

        actions = []
        step_of_rule = {}  # production -> index of the action that applied it

        for position, production in enumerate(rule_list):
            if grammar.is_value_node(production.parent):
                raise RuntimeError('no terminals should be in ifttt dataset!')

            assert production.value is None
            parent_rule = rule_parents[(position, production)][0]
            # productions without a parent point at time step 0
            parent_t = step_of_rule[parent_rule] if parent_rule else 0

            step_of_rule[production] = len(actions)
            actions.append(Action(APPLY_RULE, {'rule': production,
                                               'parent_t': parent_t,
                                               'parent_rule': parent_rule}))

        if not actions:
            examples_with_empty_actions_num += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions,
                            {'str_map': None, 'raw_code': entry['raw_code']})

        # with no terminals to generate, every kept example counts as fully
        # reconstructable
        can_fully_reconstructed_examples_num += 1

        # train, valid, test splits
        if 0 <= idx < train_end:
            target_split = train_data
        elif idx < dev_end:
            target_split = dev_data
        else:
            target_split = test_data
        target_split.add(example)

        all_examples.append(example)

    # statistics
    max_query_len = max([len(ex.query) for ex in all_examples])
    max_actions_len = max([len(ex.actions) for ex in all_examples])
    total_examples = len(all_examples)

    logging.info('train_data examples: %d', train_data.count)
    logging.info('dev_data examples: %d', dev_data.count)
    logging.info('test_data examples: %d', test_data.count)

    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_reconstructed_examples_num, total_examples,
                 can_fully_reconstructed_examples_num / total_examples)
    logging.info('empty_actions_count: %d', examples_with_empty_actions_num)

    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    # only the training split is given explicit length limits
    train_data.init_data_matrices(max_query_length=40, max_example_action_num=6)
    dev_data.init_data_matrices()
    test_data.init_data_matrices()

    output_path = 'data/ifttt.freq{WORD_FREQ_CUT_OFF}.bin'.format(
        WORD_FREQ_CUT_OFF=WORD_FREQ_CUT_OFF)
    serialize_to_file((train_data, dev_data, test_data), output_path)

    return train_data, dev_data, test_data
示例#2
0
        # for e in short_examples:
        #     print e.parse_tree
        # print 'short examples num: ', len(short_examples)

        # dataset = test_data # test_data.get_dataset_by_ids([1,2,3,4,5,6,7,8,9,10], name='sample')
        # cProfile.run('decode_dataset(model, dataset)', sort=2)

        # from evaluation import decode_and_evaluate_ifttt
        if args.data_type == 'ifttt':
            decode_results = decode_and_evaluate_ifttt_by_split(
                model, test_data)
        else:
            dataset = eval(args.type)
            decode_results = decode_python_dataset(model, dataset)

        serialize_to_file(decode_results, args.saveto)

    if args.operation == 'evaluate':
        dataset = eval(args.type)
        if config.mode == 'self':
            decode_results_file = args.input
            decode_results = deserialize_from_file(decode_results_file)

            evaluate_decode_results(dataset, decode_results)
        elif config.mode == 'seq2tree':
            from evaluation import evaluate_seq2tree_sample_file
            evaluate_seq2tree_sample_file(config.seq2tree_sample_file,
                                          config.seq2tree_id_file, dataset)
        elif config.mode == 'seq2seq':
            from evaluation import evaluate_seq2seq_decode_results
            evaluate_seq2seq_decode_results(dataset,
示例#3
0
def parse_django_dataset_nt_only():
    """Build the Django train/dev/test datasets from per-split annotation files.

    A vocabulary is generated from the full annotation file and the grammar
    plus all parse trees are obtained from ``parse_django``.  The parse trees
    are sliced into train ([0, 16000)), dev ([16000, 17000)) and test
    ([17000, 18805)) portions, each paired line-by-line with its own
    annotation file.  Entries whose parse tree is a leaf are skipped.  The
    three datasets are serialized to ``django.typed_rule.bin``.

    Returns:
        The tuple ``(train_data, dev_data, test_data)``, matching the other
        dataset builders in this file.
    """
    from parse import parse_django

    annot_file = 'all.anno'

    vocab = gen_vocab(annot_file, vocab_size=4500)

    code_file = 'all.code'

    grammar, all_parse_trees = parse_django(code_file)

    train_data = DataSet(vocab, grammar, name='train')
    dev_data = DataSet(vocab, grammar, name='dev')
    test_data = DataSet(vocab, grammar, name='test')

    def _populate(dataset, split_annot_file, split_parse_trees):
        # Pair every annotation line with its parse tree, add the non-leaf
        # ones to `dataset`, then initialise its data matrices.  The file is
        # opened in a `with` block so the handle is closed deterministically.
        with open(split_annot_file) as annot_fp:
            for line, parse_tree in zip(annot_fp, split_parse_trees):
                if parse_tree.is_leaf:
                    continue

                tokens = tokenize(line.strip())
                dataset.add(DataEntry(tokens, parse_tree))

        dataset.init_data_matrices()

    # train_data
    _populate(train_data,
              '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/train.anno',
              all_parse_trees[0:16000])

    # dev_data
    _populate(dev_data,
              '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/dev.anno',
              all_parse_trees[16000:17000])

    # test_data
    _populate(test_data,
              '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/test.anno',
              all_parse_trees[17000:18805])

    serialize_to_file((train_data, dev_data, test_data),
                      'django.typed_rule.bin')

    return train_data, dev_data, test_data
示例#4
0
def parse_django_dataset():
    """Build, serialize and return the Django train/dev/test datasets.

    Every code snippet is parsed with ``parse_raw``; the grammar is derived
    from the resulting trees and dumped to
    ``django.grammar.unary_closure.txt``.  Two passes over the data follow:
    the first collects terminal tokens to build the terminal vocabulary, the
    second converts each parse tree into a sequence of actions
    (``APPLY_RULE`` for productions, ``GEN_TOKEN`` / ``COPY_TOKEN`` /
    ``GEN_COPY_TOKEN`` for terminal tokens, closed by a ``<eos>``
    ``GEN_TOKEN``).  Examples are split on their id: [0, 16000) train,
    [16000, 17000) dev, everything else test.

    Returns:
        The tuple ``(train_data, dev_data, test_data)``.
    """
    from lang.py.parse import parse_raw
    from lang.util import escape  # not referenced in this function
    MAX_QUERY_LENGTH = 70
    UNARY_CUTOFF_FREQ = 30  # only relevant to the disabled unary-closure step

    annot_file = 'all.anno'
    code_file = 'all.code'

    data = preprocess_dataset(annot_file, code_file)

    # parse every snippet and remember the tree on the entry itself
    parse_trees = []
    for item in data:
        tree = parse_raw(item['code'])
        item['parse_tree'] = tree
        parse_trees.append(tree)

    # note: no unary closures are applied to the trees in this variant

    # build the grammar and dump it, one rule per line
    grammar = get_grammar(parse_trees)
    with open('django.grammar.unary_closure.txt', 'w') as grammar_out:
        grammar_out.writelines(rule.__repr__() + '\n' for rule in grammar)

    annot_tokens = list(chain.from_iterable(item['query_tokens'] for item in data))
    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=3)

    def get_terminal_tokens(_terminal_str):
        # Split on single spaces, keeping each space as its own ' ' token
        # and dropping the empty strings produced by the split.
        pieces = []
        for chunk in _terminal_str.split(' '):
            if chunk:
                pieces.append(chunk)
            pieces.append(' ')

        # a separator is appended after every chunk; drop the trailing one
        del pieces[-1:]
        return pieces

    # first pass: gather all terminal tokens for the terminal vocabulary
    terminal_token_seq = []
    for entry in data:
        for leaf in entry['parse_tree'].get_leaves():
            if not grammar.is_value_node(leaf):
                continue

            for piece in get_terminal_tokens(str(leaf.value)):
                assert len(piece) > 0
                terminal_token_seq.append(piece)

    terminal_vocab = gen_vocab(terminal_token_seq, vocab_size=5000,
                               freq_cutoff=3)
    assert '_STR:0_' in terminal_vocab

    train_data, dev_data, test_data = (
        DataSet(annot_vocab, terminal_vocab, grammar, split_name)
        for split_name in ('train_data', 'dev_data', 'test_data')
    )

    all_examples = []
    can_fully_gen_num = 0
    empty_actions_count = 0

    # second pass: turn every parse tree into an action sequence
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        rule_list, rule_parents = parse_tree.get_productions(
            include_value_node=True)

        actions = []
        can_fully_gen = True
        step_of_rule = {}  # production -> index of the action that applied it

        for position, production in enumerate(rule_list):
            if not grammar.is_value_node(production.parent):
                # non-terminal production: one APPLY_RULE action
                assert production.value is None
                parent_rule = rule_parents[(position, production)][0]
                # productions without a parent point at time step 0
                parent_t = step_of_rule[parent_rule] if parent_rule else 0

                step_of_rule[production] = len(actions)
                actions.append(Action(APPLY_RULE, {'rule': production,
                                                   'parent_t': parent_t,
                                                   'parent_rule': parent_rule}))
                continue

            # value node: one action per terminal token, then <eos>
            assert production.is_leaf

            parent_rule = rule_parents[(position, production)][0]
            parent_t = step_of_rule[parent_rule]

            for tok in get_terminal_tokens(str(production.value)):
                vocab_id = terminal_vocab[tok]
                # position of the first occurrence in the query, -1 if absent
                src_pos = query_tokens.index(tok) if tok in query_tokens else -1

                payload = {'literal': tok,
                           'rule': production,
                           'parent_rule': parent_rule,
                           'parent_t': parent_t}

                if 0 <= src_pos < MAX_QUERY_LENGTH:
                    # copyable; also generable when the token is not unk
                    payload['source_idx'] = src_pos
                    if vocab_id == terminal_vocab.unk:
                        action = Action(COPY_TOKEN, payload)
                    else:
                        action = Action(GEN_COPY_TOKEN, payload)
                else:
                    # cannot copy, only generation (the token could be unk!)
                    action = Action(GEN_TOKEN, payload)
                    if tok not in terminal_vocab and tok not in query_tokens:
                        can_fully_gen = False

                actions.append(action)

            actions.append(Action(GEN_TOKEN, {'literal': '<eos>',
                                              'rule': production,
                                              'parent_rule': parent_rule,
                                              'parent_t': parent_t}))

        if not actions:
            empty_actions_count += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions,
                            {'raw_code': entry['raw_code'],
                             'str_map': entry['str_map']})

        if can_fully_gen:
            can_fully_gen_num += 1

        # train, valid, test
        if 0 <= idx < 16000:
            target_split = train_data
        elif 16000 <= idx < 17000:
            target_split = dev_data
        else:
            target_split = test_data
        target_split.add(example)

        all_examples.append(example)

    # statistics
    query_lengths = [len(ex.query) for ex in all_examples]
    action_lengths = [len(ex.actions) for ex in all_examples]
    max_query_len = max(query_lengths)
    max_actions_len = max(action_lengths)
    total_examples = len(all_examples)

    serialize_to_file(query_lengths, 'query.len')
    serialize_to_file(action_lengths, 'actions.len')

    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_gen_num, total_examples,
                 can_fully_gen_num / total_examples)
    logging.info('empty_actions_count: %d', empty_actions_count)
    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    for split in (train_data, dev_data, test_data):
        split.init_data_matrices()

    serialize_to_file(
        (train_data, dev_data, test_data),
        'data/django.cleaned.dataset.freq3.par_info.refact.space_only.order_by_ulink_len.bin')

    return train_data, dev_data, test_data
示例#5
0
def parse_django_dataset():
    """Build, serialize and return the Django train/dev/test datasets.

    Every code snippet is parsed with ``parse_raw``; the grammar is derived
    from the resulting trees and dumped to
    ``django.grammar.unary_closure.txt``.  Two passes over the data follow:
    the first collects terminal tokens to build the terminal vocabulary, the
    second converts each parse tree into a sequence of actions
    (``APPLY_RULE`` for productions, ``GEN_TOKEN`` / ``COPY_TOKEN`` /
    ``GEN_COPY_TOKEN`` for terminal tokens, closed by a ``<eos>``
    ``GEN_TOKEN``).  Examples are split on their id: [0, 16000) train,
    [16000, 17000) dev, everything else test.

    Returns:
        The tuple ``(train_data, dev_data, test_data)``.
    """
    from lang.py.parse import parse_raw
    from lang.util import escape  # not referenced in this function
    MAX_QUERY_LENGTH = 70
    UNARY_CUTOFF_FREQ = 30  # only relevant to the disabled unary-closure step

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

    data = preprocess_dataset(annot_file, code_file)

    # parse every snippet and remember the tree on the entry itself
    parse_trees = []
    for item in data:
        tree = parse_raw(item['code'])
        item['parse_tree'] = tree
        parse_trees.append(tree)

    # note: no unary closures are applied to the trees in this variant

    # build the grammar and dump it, one rule per line
    grammar = get_grammar(parse_trees)
    with open('django.grammar.unary_closure.txt', 'w') as grammar_out:
        grammar_out.writelines(rule.__repr__() + '\n' for rule in grammar)

    annot_tokens = list(chain.from_iterable(item['query_tokens'] for item in data))
    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=3)

    def get_terminal_tokens(_terminal_str):
        # Split on single spaces, keeping each space as its own ' ' token
        # and dropping the empty strings produced by the split.
        pieces = []
        for chunk in _terminal_str.split(' '):
            if chunk:
                pieces.append(chunk)
            pieces.append(' ')

        # a separator is appended after every chunk; drop the trailing one
        del pieces[-1:]
        return pieces

    # first pass: gather all terminal tokens for the terminal vocabulary
    terminal_token_seq = []
    for entry in data:
        for leaf in entry['parse_tree'].get_leaves():
            if not grammar.is_value_node(leaf):
                continue

            for piece in get_terminal_tokens(str(leaf.value)):
                assert len(piece) > 0
                terminal_token_seq.append(piece)

    terminal_vocab = gen_vocab(terminal_token_seq, vocab_size=5000,
                               freq_cutoff=3)
    assert '_STR:0_' in terminal_vocab

    train_data, dev_data, test_data = (
        DataSet(annot_vocab, terminal_vocab, grammar, split_name)
        for split_name in ('train_data', 'dev_data', 'test_data')
    )

    all_examples = []
    can_fully_gen_num = 0
    empty_actions_count = 0

    # second pass: turn every parse tree into an action sequence
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        rule_list, rule_parents = parse_tree.get_productions(
            include_value_node=True)

        actions = []
        can_fully_gen = True
        step_of_rule = {}  # production -> index of the action that applied it

        for position, production in enumerate(rule_list):
            if not grammar.is_value_node(production.parent):
                # non-terminal production: one APPLY_RULE action
                assert production.value is None
                parent_rule = rule_parents[(position, production)][0]
                # productions without a parent point at time step 0
                parent_t = step_of_rule[parent_rule] if parent_rule else 0

                step_of_rule[production] = len(actions)
                actions.append(Action(APPLY_RULE, {'rule': production,
                                                   'parent_t': parent_t,
                                                   'parent_rule': parent_rule}))
                continue

            # value node: one action per terminal token, then <eos>
            assert production.is_leaf

            parent_rule = rule_parents[(position, production)][0]
            parent_t = step_of_rule[parent_rule]

            for tok in get_terminal_tokens(str(production.value)):
                vocab_id = terminal_vocab[tok]
                # position of the first occurrence in the query, -1 if absent
                src_pos = query_tokens.index(tok) if tok in query_tokens else -1

                payload = {'literal': tok,
                           'rule': production,
                           'parent_rule': parent_rule,
                           'parent_t': parent_t}

                if 0 <= src_pos < MAX_QUERY_LENGTH:
                    # copyable; also generable when the token is not unk
                    payload['source_idx'] = src_pos
                    if vocab_id == terminal_vocab.unk:
                        action = Action(COPY_TOKEN, payload)
                    else:
                        action = Action(GEN_COPY_TOKEN, payload)
                else:
                    # cannot copy, only generation (the token could be unk!)
                    action = Action(GEN_TOKEN, payload)
                    if tok not in terminal_vocab and tok not in query_tokens:
                        can_fully_gen = False

                actions.append(action)

            actions.append(Action(GEN_TOKEN, {'literal': '<eos>',
                                              'rule': production,
                                              'parent_rule': parent_rule,
                                              'parent_t': parent_t}))

        if not actions:
            empty_actions_count += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions,
                            {'raw_code': entry['raw_code'],
                             'str_map': entry['str_map']})

        if can_fully_gen:
            can_fully_gen_num += 1

        # train, valid, test
        if 0 <= idx < 16000:
            target_split = train_data
        elif 16000 <= idx < 17000:
            target_split = dev_data
        else:
            target_split = test_data
        target_split.add(example)

        all_examples.append(example)

    # statistics
    query_lengths = [len(ex.query) for ex in all_examples]
    action_lengths = [len(ex.actions) for ex in all_examples]
    max_query_len = max(query_lengths)
    max_actions_len = max(action_lengths)
    total_examples = len(all_examples)

    serialize_to_file(query_lengths, 'query.len')
    serialize_to_file(action_lengths, 'actions.len')

    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_gen_num, total_examples,
                 can_fully_gen_num / total_examples)
    logging.info('empty_actions_count: %d', empty_actions_count)
    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    for split in (train_data, dev_data, test_data):
        split.init_data_matrices()

    serialize_to_file(
        (train_data, dev_data, test_data),
        'data/django.cleaned.dataset.freq3.par_info.refact.space_only.order_by_ulink_len.bin')

    return train_data, dev_data, test_data
示例#6
0
def parse_ifttt_dataset():
    """Build, serialize and return the IFTTT train/dev/test datasets.

    The annotation and code files are preprocessed, a grammar and an
    annotation vocabulary are derived from all entries, and every parse tree
    is turned into a list of ``APPLY_RULE`` actions.  Each example is routed
    to a split according to its id, the data matrices are initialised, and
    the three datasets are written to
    ``data/ifttt.freq<WORD_FREQ_CUT_OFF>.bin``.

    Returns:
        The tuple ``(train_data, dev_data, test_data)``.

    Raises:
        RuntimeError: if a production hangs off a value node; no terminals
            are expected in this dataset.
    """
    WORD_FREQ_CUT_OFF = 2

    # split boundaries on the example id: ids in [0, train_end) go to train,
    # remaining ids below dev_end go to dev, everything else goes to test
    train_end = 77495
    dev_end = 77495 + 5171

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/Data/lang.all.txt'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/ifttt/Data/code.all.txt'

    data = preprocess_ifttt_dataset(annot_file, code_file)

    # grammar induced from every parse tree in the corpus
    grammar = get_grammar([item['parse_tree'] for item in data])

    # annotation vocabulary over the flattened query tokens
    annot_tokens = list(chain.from_iterable(item['query_tokens'] for item in data))
    annot_vocab = gen_vocab(annot_tokens, vocab_size=30000,
                            freq_cutoff=WORD_FREQ_CUT_OFF)

    logging.info('annot vocab. size: %d', annot_vocab.size)

    # ifttt has no terminal tokens, so the terminal vocabulary is built from
    # an empty token list
    terminal_vocab = gen_vocab([], vocab_size=4000,
                               freq_cutoff=WORD_FREQ_CUT_OFF)

    train_data, dev_data, test_data = (
        DataSet(annot_vocab, terminal_vocab, grammar, split_name)
        for split_name in ('ifttt.train_data', 'ifttt.dev_data', 'ifttt.test_data')
    )

    all_examples = []
    can_fully_reconstructed_examples_num = 0
    examples_with_empty_actions_num = 0

    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        # ids of the non-punctuation query tokens, then those that are not unk
        non_punct_ids = [annot_vocab[tok] for tok in query_tokens
                         if tok not in string.punctuation]
        known_ids = [tok_id for tok_id in non_punct_ids
                     if tok_id != annot_vocab.unk]

        # train/dev examples made up only of rare words are dropped to avoid
        # overfitting; test examples are always kept
        if not known_ids and 0 <= idx < dev_end:
            continue

        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)

        actions = []
        step_of_rule = {}  # production -> index of the action that applied it

        for position, production in enumerate(rule_list):
            if grammar.is_value_node(production.parent):
                raise RuntimeError('no terminals should be in ifttt dataset!')

            assert production.value is None
            parent_rule = rule_parents[(position, production)][0]
            # productions without a parent point at time step 0
            parent_t = step_of_rule[parent_rule] if parent_rule else 0

            step_of_rule[production] = len(actions)
            actions.append(Action(APPLY_RULE, {'rule': production,
                                               'parent_t': parent_t,
                                               'parent_rule': parent_rule}))

        if not actions:
            examples_with_empty_actions_num += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions,
                            {'str_map': None, 'raw_code': entry['raw_code']})

        # with no terminals to generate, every kept example counts as fully
        # reconstructable
        can_fully_reconstructed_examples_num += 1

        # train, valid, test splits
        if 0 <= idx < train_end:
            target_split = train_data
        elif idx < dev_end:
            target_split = dev_data
        else:
            target_split = test_data
        target_split.add(example)

        all_examples.append(example)

    # statistics
    max_query_len = max([len(ex.query) for ex in all_examples])
    max_actions_len = max([len(ex.actions) for ex in all_examples])
    total_examples = len(all_examples)

    logging.info('train_data examples: %d', train_data.count)
    logging.info('dev_data examples: %d', dev_data.count)
    logging.info('test_data examples: %d', test_data.count)

    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_reconstructed_examples_num, total_examples,
                 can_fully_reconstructed_examples_num / total_examples)
    logging.info('empty_actions_count: %d', examples_with_empty_actions_num)

    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    # only the training split is given explicit length limits
    train_data.init_data_matrices(max_query_length=40, max_example_action_num=6)
    dev_data.init_data_matrices()
    test_data.init_data_matrices()

    output_path = 'data/ifttt.freq{WORD_FREQ_CUT_OFF}.bin'.format(
        WORD_FREQ_CUT_OFF=WORD_FREQ_CUT_OFF)
    serialize_to_file((train_data, dev_data, test_data), output_path)

    return train_data, dev_data, test_data
示例#7
0
def parse_hs_dataset():
    """Build, serialize and return the HearthStone train/dev/test datasets.

    After preprocessing, the top-20 unary closures are applied to every
    parse tree, the grammar is derived and dumped to
    ``hs.grammar.unary_closure.txt``, and annotation / terminal vocabularies
    are generated.  Each parse tree is then converted into a sequence of
    actions (``APPLY_RULE`` for productions, ``GEN_TOKEN`` / ``COPY_TOKEN``
    / ``GEN_COPY_TOKEN`` for terminal tokens, closed by a ``<eos>``
    ``GEN_TOKEN``).  Examples are split on their id: [0, 533) train,
    remaining ids below 599 dev, everything else test.

    Returns:
        The tuple ``(train_data, dev_data, test_data)``.
    """
    MAX_QUERY_LENGTH = 70  # FIXME: figure out the best config!
    WORD_FREQ_CUT_OFF = 3

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.mod.in'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'

    data = preprocess_hs_dataset(annot_file, code_file)
    parse_trees = [item['parse_tree'] for item in data]

    # apply the most frequent unary closures to each tree
    unary_closures = get_top_unary_closures(parse_trees, k=20)
    for tree in parse_trees:
        apply_unary_closures(tree, unary_closures)

    # build the grammar and dump it, one rule per line
    grammar = get_grammar(parse_trees)
    with open('hs.grammar.unary_closure.txt', 'w') as grammar_out:
        grammar_out.writelines(rule.__repr__() + '\n' for rule in grammar)

    annot_tokens = list(chain.from_iterable(item['query_tokens'] for item in data))
    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000,
                            freq_cutoff=WORD_FREQ_CUT_OFF)

    def get_terminal_tokens(_terminal_str):
        """
        Tokenize a terminal string.

        Space-separated words are further broken at lower-to-upper case
        boundaries (``MinionCards`` becomes ``Minion``, ``Cards``) and a
        ``' '`` token is placed between consecutive words.
        """
        pieces = []
        for word in _terminal_str.split(' '):
            if len(word) == 0:
                continue

            camel_split = re.sub(r'([a-z])([A-Z])', r'\1 \2', word)
            pieces.extend(camel_split.split(' '))
            pieces.append(' ')

        # a separator is appended after every word; drop the trailing one
        del pieces[-1:]
        return pieces

    # enumerate all terminal tokens to build up the terminal vocabulary
    all_terminal_tokens = []
    for entry in data:
        for leaf in entry['parse_tree'].get_leaves():
            if not grammar.is_value_node(leaf):
                continue

            for piece in get_terminal_tokens(str(leaf.value)):
                assert len(piece) > 0
                all_terminal_tokens.append(piece)

    terminal_vocab = gen_vocab(all_terminal_tokens, vocab_size=5000,
                               freq_cutoff=WORD_FREQ_CUT_OFF)

    # now generate the dataset!
    train_data, dev_data, test_data = (
        DataSet(annot_vocab, terminal_vocab, grammar, split_name)
        for split_name in ('hs.train_data', 'hs.dev_data', 'hs.test_data')
    )

    all_examples = []
    can_fully_reconstructed_examples_num = 0
    examples_with_empty_actions_num = 0

    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        rule_list, rule_parents = parse_tree.get_productions(
            include_value_node=True)

        actions = []
        can_fully_reconstructed = True
        step_of_rule = {}  # production -> index of the action that applied it

        for position, production in enumerate(rule_list):
            if not grammar.is_value_node(production.parent):
                # non-terminal production: one APPLY_RULE action
                assert production.value is None
                parent_rule = rule_parents[(position, production)][0]
                # productions without a parent point at time step 0
                parent_t = step_of_rule[parent_rule] if parent_rule else 0

                step_of_rule[production] = len(actions)
                actions.append(Action(APPLY_RULE, {'rule': production,
                                                   'parent_t': parent_t,
                                                   'parent_rule': parent_rule}))
                continue

            # value node: one action per terminal token, then <eos>
            assert production.is_leaf

            parent_rule = rule_parents[(position, production)][0]
            parent_t = step_of_rule[parent_rule]

            for tok in get_terminal_tokens(str(production.value)):
                vocab_id = terminal_vocab[tok]
                # position of the first occurrence in the query, -1 if absent
                src_pos = query_tokens.index(tok) if tok in query_tokens else -1

                payload = {'literal': tok,
                           'rule': production,
                           'parent_rule': parent_rule,
                           'parent_t': parent_t}

                if 0 <= src_pos < MAX_QUERY_LENGTH:
                    # copyable; also generable when the token is not unk
                    payload['source_idx'] = src_pos
                    if vocab_id == terminal_vocab.unk:
                        action = Action(COPY_TOKEN, payload)
                    else:
                        action = Action(GEN_COPY_TOKEN, payload)
                else:
                    # cannot copy, only generation (the token could be unk!)
                    action = Action(GEN_TOKEN, payload)
                    if tok not in terminal_vocab and tok not in query_tokens:
                        can_fully_reconstructed = False

                actions.append(action)

            actions.append(Action(GEN_TOKEN, {'literal': '<eos>',
                                              'rule': production,
                                              'parent_rule': parent_rule,
                                              'parent_t': parent_t}))

        if not actions:
            examples_with_empty_actions_num += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions,
                            {'str_map': None, 'raw_code': entry['raw_code']})

        if can_fully_reconstructed:
            can_fully_reconstructed_examples_num += 1

        # train, valid, test splits
        if 0 <= idx < 533:
            target_split = train_data
        elif idx < 599:
            target_split = dev_data
        else:
            target_split = test_data
        target_split.add(example)

        all_examples.append(example)

    # statistics
    max_query_len = max([len(ex.query) for ex in all_examples])
    max_actions_len = max([len(ex.actions) for ex in all_examples])
    total_examples = len(all_examples)

    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_reconstructed_examples_num, total_examples,
                 can_fully_reconstructed_examples_num / total_examples)
    logging.info('empty_actions_count: %d', examples_with_empty_actions_num)

    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    # all three splits share the same length limits
    for split in (train_data, dev_data, test_data):
        split.init_data_matrices(max_query_length=70,
                                 max_example_action_num=350)

    output_path = 'data/hs.freq{WORD_FREQ_CUT_OFF}.max_action350.pre_suf.unary_closure.bin'.format(
        WORD_FREQ_CUT_OFF=WORD_FREQ_CUT_OFF)
    serialize_to_file((train_data, dev_data, test_data), output_path)

    return train_data, dev_data, test_data
示例#8
0
def parse_django_dataset_nt_only():
    """Build and serialize the Django train/dev/test datasets.

    The vocabulary is produced by ``gen_vocab`` from the full annotation file,
    and the grammar and parse trees by ``parse_django`` from the full code
    file.  Each split pairs the lines of its own annotation file with a fixed
    slice of the parse trees; pairs whose parse tree is a leaf are skipped.
    After a split has been filled its data matrices are initialised, and
    finally all three datasets are written to ``django.typed_rule.bin``.

    Returns:
        None.  The result is only available through the serialized file.
    """
    from parse import parse_django

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.anno'

    vocab = gen_vocab(annot_file, vocab_size=4500)

    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/all.code'

    grammar, all_parse_trees = parse_django(code_file)

    train_data = DataSet(vocab, grammar, name='train')
    dev_data = DataSet(vocab, grammar, name='dev')
    test_data = DataSet(vocab, grammar, name='test')

    # (dataset to fill, annotation file, parse trees paired with its lines)
    splits = [
        (train_data,
         '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/train.anno',
         all_parse_trees[0:16000]),
        (dev_data,
         '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/dev.anno',
         all_parse_trees[16000:17000]),
        (test_data,
         '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/en-django/test.anno',
         all_parse_trees[17000:18805]),
    ]

    for dataset, split_annot_file, split_parse_trees in splits:
        # Open the annotation file in a context manager so the handle is
        # closed as soon as the split has been read.
        with open(split_annot_file) as annot_lines:
            for line, parse_tree in zip(annot_lines, split_parse_trees):
                if parse_tree.is_leaf:
                    continue

                tokens = tokenize(line.strip())
                dataset.add(DataEntry(tokens, parse_tree))

        dataset.init_data_matrices()

    serialize_to_file((train_data, dev_data, test_data), 'django.typed_rule.bin')
示例#9
0
    logging.info('source vocab size: %d', train_data.annot_vocab.size)
    logging.info('target vocab size: %d', train_data.terminal_vocab.size)

    if args.operation in ['train', 'decode', 'interactive', 'align']:
        if args.enable_retrieval:
            model = RetrievalModel()
        else:
            model = Model()
        model.build()

        if args.model:
            model.load(args.model)

    if args.operation == 'align':
        aligned_train_data = compute_alignments(model, train_data)
        serialize_to_file((aligned_train_data, dev_data, test_data), args.saveto)

    if args.operation == 'train':
        # train_data = train_data.get_dataset_by_ids(range(2000), 'train_sample')
        # dev_data = dev_data.get_dataset_by_ids(range(10), 'dev_sample')
        learner = Learner(model, train_data, dev_data)
        learner.train()

    if args.operation == 'decode':
        # ==========================
        # investigate short examples
        # ==========================

        # short_examples = [e for e in test_data.examples if e.parse_tree.size <= 2]
        # for e in short_examples:
        #     print e.parse_tree
示例#10
0
def parse_train_dataset(args):
    """Build, serialize and return the SQL train/dev/test datasets.

    The train, dev and test files are preprocessed and concatenated (in that
    order) so that one grammar, one natural-language vocabulary and one
    terminal vocabulary are shared by all splits.  The terminal vocabulary is
    then passed through ``gen_schema_vocab`` together with the schema
    vocabulary loaded from ``args.table_schema``, and ``gen_db_mask`` produces
    a mask that is looked up per example via the entry's ``db_id``.

    Every parse tree is converted into a list of actions: an ``APPLY_RULE``
    action for each non-value rule, and for each value node one
    ``GEN_TOKEN`` / ``GEN_COPY_TOKEN`` / ``COPY_TOKEN`` action per terminal
    token followed by a closing ``<eos>`` ``GEN_TOKEN`` action.  Entries that
    yield no actions are dropped.

    Args:
        args: namespace providing ``train_data``/``train_data_ast``,
            ``dev_data``/``dev_data_ast``, ``test_data``/``test_data_ast``,
            ``table_schema``, ``train_data_size``, ``dev_data_size`` and
            ``output_path``.

    Returns:
        The tuple ``(train_data, dev_data, test_data)`` of ``DataSet``
        objects; the same tuple is serialized to ``args.output_path``.

    Side effects:
        Writes the grammar rules to ``sql.grammar.unary_closure.txt``, prints
        dataset sizes and logs summary statistics.
    """
    MAX_QUERY_LENGTH = 70  # FIXME: figure out the best config!
    WORD_FREQ_CUT_OFF = 0

    train_entries = preprocess_sql_dataset(args.train_data, args.train_data_ast)
    dev_entries = preprocess_sql_dataset(args.dev_data, args.dev_data_ast)
    test_entries = preprocess_sql_dataset(args.test_data, args.test_data_ast)
    data = train_entries + dev_entries + test_entries
    print("data size: {}".format(len(data)))
    parse_trees = [e['parse_tree'] for e in data]

    # Unary closures are intentionally not applied for this dataset.

    # build the grammar
    grammar = get_grammar(parse_trees)

    with open('sql.grammar.unary_closure.txt', 'w') as f:
        for rule in grammar:
            f.write(rule.__repr__() + '\n')

    nl_tokens = list(chain(*[e['query_tokens'] for e in data]))
    nl_vocab = gen_vocab(nl_tokens,
                         vocab_size=5000,
                         freq_cutoff=WORD_FREQ_CUT_OFF)

    # enumerate all terminal tokens to build up the terminal tokens vocabulary
    all_terminal_tokens = []
    for entry in data:
        parse_tree = entry['parse_tree']
        for node in parse_tree.get_leaves():
            if grammar.is_value_node(node):
                terminal_val = node.value
                terminal_str = str(terminal_val)

                terminal_tokens = get_terminal_tokens(terminal_str)

                for terminal_token in terminal_tokens:
                    assert len(terminal_token) > 0
                    all_terminal_tokens.append(terminal_token)

    table_schema = args.table_schema

    terminal_vocab = gen_vocab(all_terminal_tokens,
                               vocab_size=5000,
                               freq_cutoff=WORD_FREQ_CUT_OFF)
    # Remember the size before schema tokens are merged in; gen_db_mask needs
    # it alongside the combined vocabulary.
    non_schema_vocab_size = terminal_vocab.size
    db_dict, schema_vocab = load_table_schema_data(table_schema)
    terminal_vocab = gen_schema_vocab(schema_vocab, terminal_vocab)
    db_mask = gen_db_mask(terminal_vocab, non_schema_vocab_size, table_schema)

    # now generate the dataset!
    train_data = DataSet(nl_vocab, terminal_vocab, grammar, db_mask,
                         'sql.train_data')
    dev_data = DataSet(nl_vocab, terminal_vocab, grammar, db_mask,
                       'sql.dev_data')
    test_data = DataSet(nl_vocab, terminal_vocab, grammar, db_mask,
                        'sql.test_data')

    all_examples = []

    can_fully_reconstructed_examples_num = 0
    examples_with_empty_actions_num = 0

    for index, entry in enumerate(data):
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        rule_list, rule_parents = parse_tree.get_productions(
            include_value_node=True)

        actions = []
        can_fully_reconstructed = True
        # maps a rule to the position of the action that applied it
        rule_pos_map = dict()

        for rule_count, rule in enumerate(rule_list):
            if not grammar.is_value_node(rule.parent):
                assert rule.value is None, rule.value
                parent_rule = rule_parents[(rule_count, rule)][0]
                if parent_rule:
                    parent_t = rule_pos_map[parent_rule]
                else:
                    parent_t = 0

                rule_pos_map[rule] = len(actions)

                d = {
                    'rule': rule,
                    'parent_t': parent_t,
                    'parent_rule': parent_rule
                }
                actions.append(Action(APPLY_RULE, d))
            else:
                assert rule.is_leaf, (rule.type, rule.value, rule.label)

                parent_rule = rule_parents[(rule_count, rule)][0]
                parent_t = rule_pos_map[parent_rule]

                terminal_val = rule.value
                terminal_str = str(terminal_val)
                terminal_tokens = get_terminal_tokens(terminal_str)

                for terminal_token in terminal_tokens:
                    term_tok_id = terminal_vocab[terminal_token]
                    try:
                        tok_src_idx = query_tokens.index(terminal_token)
                    except ValueError:
                        # the token does not occur in the query
                        tok_src_idx = -1

                    d = {
                        'literal': terminal_token,
                        'rule': rule,
                        'parent_rule': parent_rule,
                        'parent_t': parent_t
                    }

                    if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:
                        # cannot copy, only generation -- could be unk!
                        action = Action(GEN_TOKEN, d)
                        if (terminal_token not in terminal_vocab
                                and terminal_token not in query_tokens):
                            can_fully_reconstructed = False
                    else:
                        # the token can be copied from the query
                        d['source_idx'] = tok_src_idx
                        if term_tok_id != terminal_vocab.unk:
                            action = Action(GEN_COPY_TOKEN, d)
                        else:
                            action = Action(COPY_TOKEN, d)

                    actions.append(action)

                # close the token sequence of this value node
                d = {
                    'literal': '<eos>',
                    'rule': rule,
                    'parent_rule': parent_rule,
                    'parent_t': parent_t
                }
                actions.append(Action(GEN_TOKEN, d))

        if len(actions) == 0:
            examples_with_empty_actions_num += 1
            continue
        mask = db_mask[entry['db_id']]
        example = DataEntry(idx, query_tokens, parse_tree, code, actions, mask,
                            {
                                'str_map': None,
                                'raw_code': entry['raw_code']
                            })

        if can_fully_reconstructed:
            can_fully_reconstructed_examples_num += 1

        # train, valid, test splits -- decided by the entry's position in the
        # concatenated data, not by its id.
        # NOTE(review): presumably args.train_data_size / args.dev_data_size
        # equal the number of entries in the train / dev files -- confirm.
        if 0 <= index < args.train_data_size:
            train_data.add(example)
        elif index < args.train_data_size + args.dev_data_size:
            dev_data.add(example)
        else:
            test_data.add(example)

        all_examples.append(example)

    # print statistics
    max_query_len = max(len(e.query) for e in all_examples)
    max_actions_len = max(len(e.actions) for e in all_examples)

    # float() keeps the ratio a true fraction even under integer division
    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_reconstructed_examples_num, len(all_examples),
                 float(can_fully_reconstructed_examples_num) / len(all_examples))
    logging.info('empty_actions_count: %d', examples_with_empty_actions_num)

    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    # Use the same query-length limit that gated the copy actions above.
    train_data.init_data_matrices(max_query_length=MAX_QUERY_LENGTH,
                                  max_example_action_num=350)
    dev_data.init_data_matrices(max_query_length=MAX_QUERY_LENGTH,
                                max_example_action_num=350)
    test_data.init_data_matrices(max_query_length=MAX_QUERY_LENGTH,
                                 max_example_action_num=350)

    print("train data size:{}".format(train_data.count))
    print("dev data size:{}".format(dev_data.count))
    print("test data size:{}".format(test_data.count))
    serialize_to_file((train_data, dev_data, test_data), args.output_path)
    return train_data, dev_data, test_data
示例#11
0
def parse_hs_dataset():
    """Build, serialize and return the HearthStone train/dev/test datasets.

    The annotation and code files are preprocessed, the 20 top unary closures
    returned by ``get_top_unary_closures`` are applied to every parse tree,
    and a grammar plus annotation and terminal vocabularies are built from
    the result.  Every parse tree is converted into a list of actions: an
    ``APPLY_RULE`` action for each non-value rule, and for each value node
    one ``GEN_TOKEN`` / ``GEN_COPY_TOKEN`` / ``COPY_TOKEN`` action per
    terminal token followed by a closing ``<eos>`` ``GEN_TOKEN`` action.
    Entries that yield no actions are dropped.

    Returns:
        The tuple ``(train_data, dev_data, test_data)`` of ``DataSet``
        objects; the same tuple is serialized under ``data/``.

    Side effects:
        Writes the grammar rules to ``hs.grammar.unary_closure.txt`` and logs
        summary statistics.
    """
    MAX_QUERY_LENGTH = 70  # FIXME: figure out the best config!
    WORD_FREQ_CUT_OFF = 3

    annot_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.mod.in'
    code_file = '/Users/yinpengcheng/Research/SemanticParsing/CodeGeneration/card_datasets/hearthstone/all_hs.out'

    data = preprocess_hs_dataset(annot_file, code_file)
    parse_trees = [e['parse_tree'] for e in data]

    # apply unary closures
    unary_closures = get_top_unary_closures(parse_trees, k=20)
    for parse_tree in parse_trees:
        apply_unary_closures(parse_tree, unary_closures)

    # build the grammar
    grammar = get_grammar(parse_trees)

    with open('hs.grammar.unary_closure.txt', 'w') as f:
        for rule in grammar:
            f.write(rule.__repr__() + '\n')

    annot_tokens = list(chain(*[e['query_tokens'] for e in data]))
    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=WORD_FREQ_CUT_OFF)

    def get_terminal_tokens(_terminal_str):
        """Split a terminal string into tokens.

        The string is split on spaces (empty pieces are dropped) and each
        piece is further split at lower-case/upper-case boundaries, so
        ``MinionCards`` becomes ``Minion`` and ``Cards``.  A ``' '`` token is
        emitted after every space-separated piece except the last one.
        """
        tmp_terminal_tokens = [t for t in _terminal_str.split(' ') if len(t) > 0]
        _terminal_tokens = []
        for token in tmp_terminal_tokens:
            sub_tokens = re.sub(r'([a-z])([A-Z])', r'\1 \2', token).split(' ')
            _terminal_tokens.extend(sub_tokens)

            _terminal_tokens.append(' ')

        # drop the separator appended after the final piece
        return _terminal_tokens[:-1]

    # enumerate all terminal tokens to build up the terminal tokens vocabulary
    all_terminal_tokens = []
    for entry in data:
        parse_tree = entry['parse_tree']
        for node in parse_tree.get_leaves():
            if grammar.is_value_node(node):
                terminal_val = node.value
                terminal_str = str(terminal_val)

                terminal_tokens = get_terminal_tokens(terminal_str)

                for terminal_token in terminal_tokens:
                    assert len(terminal_token) > 0
                    all_terminal_tokens.append(terminal_token)

    terminal_vocab = gen_vocab(all_terminal_tokens, vocab_size=5000, freq_cutoff=WORD_FREQ_CUT_OFF)

    # now generate the dataset!

    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.train_data')
    dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.dev_data')
    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'hs.test_data')

    all_examples = []

    can_fully_reconstructed_examples_num = 0
    examples_with_empty_actions_num = 0

    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        rule_list, rule_parents = parse_tree.get_productions(include_value_node=True)

        actions = []
        can_fully_reconstructed = True
        # maps a rule to the position of the action that applied it
        rule_pos_map = dict()

        for rule_count, rule in enumerate(rule_list):
            if not grammar.is_value_node(rule.parent):
                assert rule.value is None
                parent_rule = rule_parents[(rule_count, rule)][0]
                if parent_rule:
                    parent_t = rule_pos_map[parent_rule]
                else:
                    parent_t = 0

                rule_pos_map[rule] = len(actions)

                d = {'rule': rule, 'parent_t': parent_t, 'parent_rule': parent_rule}
                actions.append(Action(APPLY_RULE, d))
            else:
                assert rule.is_leaf

                parent_rule = rule_parents[(rule_count, rule)][0]
                parent_t = rule_pos_map[parent_rule]

                terminal_val = rule.value
                terminal_str = str(terminal_val)
                terminal_tokens = get_terminal_tokens(terminal_str)

                for terminal_token in terminal_tokens:
                    term_tok_id = terminal_vocab[terminal_token]
                    try:
                        tok_src_idx = query_tokens.index(terminal_token)
                    except ValueError:
                        # the token does not occur in the query
                        tok_src_idx = -1

                    d = {'literal': terminal_token, 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}

                    if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:
                        # cannot copy, only generation -- could be unk!
                        action = Action(GEN_TOKEN, d)
                        if (terminal_token not in terminal_vocab
                                and terminal_token not in query_tokens):
                            can_fully_reconstructed = False
                    else:
                        # the token can be copied from the query
                        d['source_idx'] = tok_src_idx
                        if term_tok_id != terminal_vocab.unk:
                            action = Action(GEN_COPY_TOKEN, d)
                        else:
                            action = Action(COPY_TOKEN, d)

                    actions.append(action)

                # close the token sequence of this value node
                d = {'literal': '<eos>', 'rule': rule, 'parent_rule': parent_rule, 'parent_t': parent_t}
                actions.append(Action(GEN_TOKEN, d))

        if len(actions) == 0:
            examples_with_empty_actions_num += 1
            continue

        example = DataEntry(idx, query_tokens, parse_tree, code, actions, {'str_map': None, 'raw_code': entry['raw_code']})

        if can_fully_reconstructed:
            can_fully_reconstructed_examples_num += 1

        # train, valid, test splits (by entry id)
        if 0 <= idx < 533:
            train_data.add(example)
        elif idx < 599:
            dev_data.add(example)
        else:
            test_data.add(example)

        all_examples.append(example)

    # print statistics
    max_query_len = max(len(e.query) for e in all_examples)
    max_actions_len = max(len(e.actions) for e in all_examples)

    # float() keeps the ratio a true fraction even under integer division
    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_reconstructed_examples_num, len(all_examples),
                 float(can_fully_reconstructed_examples_num) / len(all_examples))
    logging.info('empty_actions_count: %d', examples_with_empty_actions_num)

    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    # Use the same query-length limit that gated the copy actions above.
    train_data.init_data_matrices(max_query_length=MAX_QUERY_LENGTH, max_example_action_num=350)
    dev_data.init_data_matrices(max_query_length=MAX_QUERY_LENGTH, max_example_action_num=350)
    test_data.init_data_matrices(max_query_length=MAX_QUERY_LENGTH, max_example_action_num=350)

    serialize_to_file((train_data, dev_data, test_data),
                      'data/hs.freq{WORD_FREQ_CUT_OFF}.max_action350.pre_suf.unary_closure.bin'.format(WORD_FREQ_CUT_OFF=WORD_FREQ_CUT_OFF))

    return train_data, dev_data, test_data
示例#12
0
        # short_examples = [e for e in test_data.examples if e.parse_tree.size <= 2]
        # for e in short_examples:
        #     print e.parse_tree
        # print 'short examples num: ', len(short_examples)

        # dataset = test_data # test_data.get_dataset_by_ids([1,2,3,4,5,6,7,8,9,10], name='sample')
        # cProfile.run('decode_dataset(model, dataset)', sort=2)

        # from evaluation import decode_and_evaluate_ifttt
        if args.data_type == 'ifttt':
            decode_results = decode_and_evaluate_ifttt_by_split(model, test_data)
        else:
            dataset = eval(args.type)
            decode_results = decode_python_dataset(model, dataset)

        serialize_to_file(decode_results, args.saveto)

    if args.operation == 'evaluate':
        dataset = eval(args.type)
        if config.mode == 'self':
            decode_results_file = args.input
            decode_results = deserialize_from_file(decode_results_file)

            evaluate_decode_results(dataset, decode_results)
        elif config.mode == 'seq2tree':
            from evaluation import evaluate_seq2tree_sample_file
            evaluate_seq2tree_sample_file(config.seq2tree_sample_file, config.seq2tree_id_file, dataset)
        elif config.mode == 'seq2seq':
            from evaluation import evaluate_seq2seq_decode_results
            evaluate_seq2seq_decode_results(dataset, config.seq2seq_decode_file, config.seq2seq_ref_file, is_nbest=config.is_nbest)
        elif config.mode == 'analyze':
示例#13
0
def parse_django_dataset():
    """Build, serialize and return the Django train/dev/test datasets.

    The entries produced by ``preprocess_syndataset`` are parsed with
    ``parse_raw``; a grammar, an annotation vocabulary, a terminal vocabulary
    and phrase/POS vocabularies are built from them.  Every parse tree is
    converted into a list of actions: an ``APPLY_RULE`` action for each
    non-value rule, and for each value node one ``GEN_TOKEN`` /
    ``GEN_COPY_TOKEN`` / ``COPY_TOKEN`` action per terminal token followed by
    a closing ``<eos>`` ``GEN_TOKEN`` action.  Entries that yield no actions
    are dropped.

    Returns:
        The tuple ``(train_data, dev_data, test_data)`` of ``DataSet``
        objects; the same tuple is serialized under ``data/``.

    Side effects:
        Writes the grammar rules to ``django.grammar.unary_closure.txt``,
        serializes the query and action lengths to ``query.len`` and
        ``actions.len``, and logs summary statistics.
    """
    from lang.py.parse import parse_raw
    # NOTE(review): `escape` is not used in this function; the import is kept
    # in case importing lang.util is relied upon elsewhere -- confirm.
    from lang.util import escape
    MAX_QUERY_LENGTH = 70
    # Only referenced by the disabled unary-closure step below.
    UNARY_CUTOFF_FREQ = 30

    annot_file = GENALLANNO
    code_file = GENALLCODE

    data = preprocess_syndataset(annot_file, code_file)

    for e in data:
        e['parse_tree'] = parse_raw(e['code'])

    parse_trees = [e['parse_tree'] for e in data]

    # Unary closures are currently disabled:
    # unary_closures = get_top_unary_closures(parse_trees, k=0, freq=UNARY_CUTOFF_FREQ)
    # for i, parse_tree in enumerate(parse_trees):
    #     apply_unary_closures(parse_tree, unary_closures)

    # build the grammar
    grammar = get_grammar(parse_trees)

    # write grammar
    with open('django.grammar.unary_closure.txt', 'w') as f:
        for rule in grammar:
            f.write(rule.__repr__() + '\n')

    annot_tokens = list(chain(*[e['query_tokens'] for e in data]))
    annot_tokens = normalize_query_tokens(annot_tokens)

    annot_vocab = gen_vocab(
        annot_tokens, vocab_size=5000,
        freq_cutoff=3)

    terminal_token_seq = []
    empty_actions_count = 0

    def get_terminal_tokens(_terminal_str):
        """Split a terminal string on single spaces, keeping the spaces.

        Non-empty pieces are kept as tokens and a ``' '`` token is emitted
        after every piece (including empty ones) except the last.
        """
        tmp_terminal_tokens = _terminal_str.split(' ')
        _terminal_tokens = []
        for token in tmp_terminal_tokens:
            if token:
                _terminal_tokens.append(token)
            _terminal_tokens.append(' ')

        # drop the separator appended after the final piece
        return _terminal_tokens[:-1]

    # first pass: collect every terminal token to build the terminal vocab
    for entry in data:
        parse_tree = entry['parse_tree']

        for node in parse_tree.get_leaves():
            if grammar.is_value_node(node):
                terminal_val = node.value
                terminal_str = str(terminal_val)

                terminal_tokens = get_terminal_tokens(terminal_str)

                for terminal_token in terminal_tokens:
                    assert len(terminal_token) > 0
                    terminal_token_seq.append(terminal_token)

    terminal_vocab = gen_vocab(terminal_token_seq,
                               vocab_size=5000,
                               freq_cutoff=3)
    phrase_vocab = gen_phrase_vocab(data)
    pos_vocab = gen_pos_vocab(data)

    train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'train_data',
                         phrase_vocab, pos_vocab)
    dev_data = DataSet(annot_vocab, terminal_vocab, grammar, 'dev_data',
                       phrase_vocab, pos_vocab)
    test_data = DataSet(annot_vocab, terminal_vocab, grammar, 'test_data',
                        phrase_vocab, pos_vocab)

    all_examples = []

    can_fully_gen_num = 0

    # second pass: turn every parse tree into an action sequence
    for entry in data:
        idx = entry['id']
        query_tokens = entry['query_tokens']
        code = entry['code']
        parse_tree = entry['parse_tree']

        rule_list, rule_parents = parse_tree.get_productions(
            include_value_node=True)

        actions = []
        can_fully_gen = True
        # maps a rule to the position of the action that applied it
        rule_pos_map = dict()

        for rule_count, rule in enumerate(rule_list):
            if not grammar.is_value_node(rule.parent):
                assert rule.value is None
                parent_rule = rule_parents[(rule_count, rule)][0]
                if parent_rule:
                    parent_t = rule_pos_map[parent_rule]
                else:
                    parent_t = 0

                rule_pos_map[rule] = len(actions)

                d = {
                    'rule': rule,
                    'parent_t': parent_t,
                    'parent_rule': parent_rule
                }
                actions.append(Action(APPLY_RULE, d))
            else:
                assert rule.is_leaf

                parent_rule = rule_parents[(rule_count, rule)][0]
                parent_t = rule_pos_map[parent_rule]

                terminal_val = rule.value
                terminal_str = str(terminal_val)
                terminal_tokens = get_terminal_tokens(terminal_str)

                for terminal_token in terminal_tokens:
                    term_tok_id = terminal_vocab[terminal_token]
                    try:
                        tok_src_idx = query_tokens.index(terminal_token)
                    except ValueError:
                        # the token does not occur in the query
                        tok_src_idx = -1

                    d = {
                        'literal': terminal_token,
                        'rule': rule,
                        'parent_rule': parent_rule,
                        'parent_t': parent_t
                    }

                    if tok_src_idx < 0 or tok_src_idx >= MAX_QUERY_LENGTH:
                        # cannot copy, only generation -- could be unk!
                        action = Action(GEN_TOKEN, d)
                        if (terminal_token not in terminal_vocab
                                and terminal_token not in query_tokens):
                            can_fully_gen = False
                    else:
                        # the token can be copied from the query
                        d['source_idx'] = tok_src_idx
                        if term_tok_id != terminal_vocab.unk:
                            action = Action(GEN_COPY_TOKEN, d)
                        else:
                            action = Action(COPY_TOKEN, d)

                    actions.append(action)

                # close the token sequence of this value node
                d = {
                    'literal': '<eos>',
                    'rule': rule,
                    'parent_rule': parent_rule,
                    'parent_t': parent_t
                }
                actions.append(Action(GEN_TOKEN, d))

        if len(actions) == 0:
            empty_actions_count += 1
            continue

        example = DataEntry(
            idx, query_tokens, parse_tree, code, actions, {
                'raw_code': entry['raw_code'],
                'str_map': entry['str_map'],
                'phrase': entry['phrase'],
                'pos': entry['pos'],
                'bannot': entry['bannot'],
                'bcode': entry['bcode'],
                'ref_type': entry['ref_type']
            })

        if can_fully_gen:
            can_fully_gen_num += 1

        # train, valid, test (by entry id)
        if 0 <= idx < 13000:
            train_data.add(example)
        elif 13000 <= idx < 14000:
            dev_data.add(example)
        else:
            test_data.add(example)

        all_examples.append(example)

    # print statistics
    max_query_len = max(len(e.query) for e in all_examples)
    max_actions_len = max(len(e.actions) for e in all_examples)

    serialize_to_file([len(e.query) for e in all_examples], 'query.len')
    serialize_to_file([len(e.actions) for e in all_examples], 'actions.len')

    # float() keeps the ratio a true fraction even under integer division
    logging.info('examples that can be fully reconstructed: %d/%d=%f',
                 can_fully_gen_num, len(all_examples),
                 float(can_fully_gen_num) / len(all_examples))
    logging.info('empty_actions_count: %d', empty_actions_count)
    logging.info('max_query_len: %d', max_query_len)
    logging.info('max_actions_len: %d', max_actions_len)

    train_data.init_data_matrices()
    dev_data.init_data_matrices()
    test_data.init_data_matrices()

    serialize_to_file(
        (train_data, dev_data, test_data),
        'data/django.pnet.fullcanon.dataset.freq3.par_info.refact.space_only.bin'
    )

    return train_data, dev_data, test_data