Code example #1
import os

import numpy as np

# `splits`, `deserialize_from_file` and `number_of_ast_nodes` are assumed to be
# provided by the surrounding module.
def avg_and_max_number_of_actions(data_dir):
    """Mean, max and full list of per-example AST node counts
    (used here as a proxy for the number of actions)."""
    nodes_numbers = []
    for split in splits:
        file = os.path.join(data_dir, '{}/{}.out.bin'.format(split, split))
        codes = deserialize_from_file(file)
        for code in codes:
            nodes_numbers.append(number_of_ast_nodes(code))

    return np.mean(nodes_numbers), max(nodes_numbers), nodes_numbers
Code example #2
# Uses the same imports and module-level helpers as example #1; each entry in
# `codes` is assumed to be a source-code string.
def avg_and_max_number_char_in_code(data_dir):
    """Mean, max and full list of per-example character counts."""
    char_len = []
    for split in splits:
        file = os.path.join(data_dir, '{}/{}.out.bin'.format(split, split))
        codes = deserialize_from_file(file)
        for code in codes:
            char_len.append(len(code))

    return np.mean(char_len), max(char_len), char_len
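
A hedged usage sketch for the two statistics helpers above. It assumes `splits` is a module-level list of split names matching the `{split}/{split}.out.bin` layout, and that the data lives under the preprocessed directory used in example #5:

splits = ['train', 'dev', 'test']  # assumed split names

avg_n, max_n, _ = avg_and_max_number_of_actions('./preprocessed/django')
avg_c, max_c, _ = avg_and_max_number_char_in_code('./preprocessed/django')
print('AST nodes  - avg: {:.1f}, max: {}'.format(avg_n, max_n))
print('code chars - avg: {:.1f}, max: {}'.format(avg_c, max_c))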
Code example #3
    # Dataset method; `parents_prefix` is assumed to be a module-level dict
    # mapping a syntax name to the parents-file prefix.
    def load_input(self, data_dir, file_name, syntax):
        parents_file = os.path.join(
            data_dir, '{}.in.{}_parents'.format(file_name,
                                                parents_prefix[syntax]))
        tokens_file = os.path.join(data_dir, '{}.in.tokens'.format(file_name))
        strmap_file = os.path.join(data_dir,
                                   '{}.in.strmap.bin'.format(file_name))

        logging.info('Reading query trees...')
        self.query_trees = self.read_query_trees(parents_file)

        logging.info('Reading query tokens...')
        self.queries, self.query_tokens = self.read_query(tokens_file)

        logging.info('Reading strmap...')
        self.strmaps = deserialize_from_file(strmap_file)
Code example #4
# Assumes the project's Vocab, Constants and Dataset helpers, plus torch and
# deserialize_from_file, are importable in this module.
def load_test_dataset(data_dir, syntax, max_example_actions_num):
    # the test dataset here is always built with unary closures enabled
    terminal_vocab_file = os.path.join(data_dir, 'terminal_vocab.txt')
    grammar_file = os.path.join(data_dir, 'grammar.txt.uc.bin')

    grammar = deserialize_from_file(grammar_file)
    terminal_vocab = Vocab(
        terminal_vocab_file,
        data=[Constants.UNK_WORD, Constants.EOS_WORD, Constants.PAD_WORD])
    vocab = Vocab(
        os.path.join(data_dir, 'vocab.txt'),
        data=[Constants.UNK_WORD, Constants.EOS_WORD, Constants.PAD_WORD])

    prefix = 'uc_' + syntax + '_'
    test_dir = os.path.join(data_dir, 'test')
    # cache file name follows the same prefix convention as load_dataset (example #5)
    test_file = os.path.join(test_dir, prefix + 'test.pth')
    test = Dataset(test_dir, 'test', grammar, vocab, terminal_vocab, syntax,
                   max_example_actions_num, True)
    torch.save(test, test_file)
    # hand the dataset back to the caller as well as caching it on disk
    return test
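
A minimal call sketch; the syntax key and the action-count limit are hypothetical values, and the directory matches example #5's layout:

test = load_test_dataset('./preprocessed/django', 'ccg', 350)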
Code example #5
def load_dataset(config, force_regenerate=False):
    """Load the cached train/dev/test Django datasets, regenerating any that
    are missing (or all of them when force_regenerate is True)."""
    dj_dir = './preprocessed/django'
    logging.info('=' * 80)
    logging.info('Loading datasets from folder ' + dj_dir)
    logging.info('=' * 80)
    train, test, dev = None, None, None
    prefix = config.syntax + '_'
    if config.unary_closures:
        prefix += 'uc_'

    train_dir = os.path.join(dj_dir, 'train')
    train_file = os.path.join(train_dir, prefix + 'train.pth')
    if not force_regenerate and os.path.isfile(train_file):
        logging.info('Train dataset found, loading...')
        train = torch.load(train_file)
        train.config = config

    test_dir = os.path.join(dj_dir, 'test')
    test_file = os.path.join(test_dir, prefix + 'test.pth')
    if not force_regenerate and os.path.isfile(test_file):
        logging.info('Test dataset found, loading...')
        test = torch.load(test_file)
        test.config = config

    dev_dir = os.path.join(dj_dir, 'dev')
    dev_file = os.path.join(dev_dir, prefix + 'dev.pth')
    if not force_regenerate and os.path.isfile(dev_file):
        logging.info('Dev dataset found, loading...')
        dev = torch.load(dev_file)
        dev.config = config

    if train is None or test is None or dev is None:
        terminal_vocab_file = os.path.join(dj_dir, 'terminal_vocab.txt')
        if config.unary_closures:
            grammar_file = os.path.join(dj_dir, 'grammar.txt.uc.bin')
        else:
            grammar_file = os.path.join(dj_dir, 'grammar.txt.bin')

        grammar = deserialize_from_file(grammar_file)
        terminal_vocab = Vocab(
            terminal_vocab_file,
            data=[Constants.UNK_WORD, Constants.EOS_WORD, Constants.PAD_WORD])
        vocab = Vocab(
            os.path.join(dj_dir, 'vocab.txt'),
            data=[Constants.UNK_WORD, Constants.EOS_WORD, Constants.PAD_WORD])

        if test is None:
            logging.info('Test dataset not found, generating...')
            test = Dataset(test_dir, 'test', grammar, vocab, terminal_vocab,
                           config.syntax, config.max_example_action_num,
                           config.unary_closures)
            torch.save(test, test_file)

        if dev is None:
            logging.info('Dev dataset not found, generating...')
            dev = Dataset(dev_dir, 'dev', grammar, vocab, terminal_vocab,
                          config.syntax, config.max_example_action_num,
                          config.unary_closures)
            torch.save(dev, dev_file)

        if train is None:
            logging.info('Train dataset not found, generating...')
            train = Dataset(train_dir, 'train', grammar, vocab, terminal_vocab,
                            config.syntax, config.max_example_action_num,
                            config.unary_closures)
            torch.save(train, train_file)

    train.prepare_torch(config.cuda)
    dev.prepare_torch(config.cuda)
    test.prepare_torch(config.cuda)
    return train, dev, test
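
A usage sketch for `load_dataset`. The attribute names are exactly the ones the function reads above; the values are hypothetical stand-ins for the project's real configuration object:

from types import SimpleNamespace

config = SimpleNamespace(
    syntax='ccg',                  # hypothetical syntax key
    unary_closures=True,
    max_example_action_num=350,    # hypothetical cap on actions per example
    cuda=False,
)
train, dev, test = load_dataset(config)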
Code example #6
    # Dataset method; assumes load_input (example #3) has already populated
    # self.query_tokens and self.query_trees.
    def load_output(self, data_dir, file_name):
        logging.info('Reading code files...')
        if self.unary_closures:
            trees_file = '{}.out.trees.uc.bin'.format(file_name)
        else:
            trees_file = '{}.out.trees.bin'.format(file_name)

        trees_file = os.path.join(data_dir, trees_file)
        code_file = os.path.join(data_dir, '{}.out.bin'.format(file_name))
        code_raw_file = os.path.join(data_dir,
                                     '{}.out.raw.bin'.format(file_name))

        self.code_trees = deserialize_from_file(trees_file)
        self.codes = deserialize_from_file(code_file)
        self.codes_raw = deserialize_from_file(code_raw_file)

        logging.info('Constructing code representation...')
        self.actions = []

        for code_tree, query_tokens, query_tree in \
                tqdm(zip(self.code_trees, self.query_tokens, self.query_trees)):

            if code_tree is None or query_tree is None:
                self.actions.append(None)
                continue

            rule_list, rule_parents = code_tree.get_productions(
                include_value_node=True)

            actions = []
            # maps each applied rule to the timestep of its APPLY_RULE action
            rule_pos_map = dict()

            for rule_count, rule in enumerate(rule_list):
                if not self.grammar.is_value_node(rule.parent):
                    assert rule.value is None
                    parent_rule = rule_parents[(rule_count, rule)][0]
                    if parent_rule:
                        parent_t = rule_pos_map[parent_rule]
                    else:
                        parent_t = 0

                    rule_pos_map[rule] = len(actions)

                    d = {
                        'rule': rule,
                        'parent_t': parent_t,
                        'parent_rule': parent_rule
                    }
                    action = Action(APPLY_RULE, d)

                    actions.append(action)
                else:
                    assert rule.is_leaf

                    parent_rule = rule_parents[(rule_count, rule)][0]
                    parent_t = rule_pos_map[parent_rule]

                    terminal_val = rule.value
                    terminal_str = str(terminal_val)
                    terminal_tokens = get_terminal_tokens(terminal_str)

                    for terminal_token in terminal_tokens:
                        term_tok_id = self.terminal_vocab.getIndex(
                            terminal_token, Constants.UNK)
                        tok_src_idx = -1
                        try:
                            tok_src_idx = query_tokens.index(terminal_token)
                        except ValueError:
                            pass

                        d = {
                            'literal': terminal_token,
                            'rule': rule,
                            'parent_rule': parent_rule,
                            'parent_t': parent_t
                        }

                        # cannot copy, only generation
                        # could be unk!
                        if tok_src_idx < 0 or tok_src_idx > query_tree.size() - 1:
                            action = Action(GEN_TOKEN, d)
                        else:  # copy
                            d['source_idx'] = tok_src_idx
                            if term_tok_id != Constants.UNK:
                                action = Action(GEN_COPY_TOKEN, d)
                            else:
                                action = Action(COPY_TOKEN, d)

                        actions.append(action)

                    # explicitly terminate this terminal's token sequence
                    d = {
                        'literal': '<eos>',
                        'rule': rule,
                        'parent_rule': parent_rule,
                        'parent_t': parent_t
                    }
                    actions.append(Action(GEN_TOKEN, d))

            if len(actions) == 0:
                # examples yielding no actions are skipped entirely
                # (note: nothing is appended here, unlike the None case above)
                continue

            self.actions.append(actions)
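
Taken together, the loop above flattens each example into a linear sequence of Action objects: APPLY_RULE actions expand grammar productions, each recording the timestep (`parent_t`) of its parent rule, while terminal values are tokenized and emitted as GEN_TOKEN (pure generation), GEN_COPY_TOKEN (in the terminal vocabulary and also copyable from the query), or COPY_TOKEN (out-of-vocabulary, copy-only) actions, with an explicit `<eos>` token closing each terminal's token sequence.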