def avg_and_max_number_of_actions(data_dir):
    nodes_numbers = []
    for split in splits:
        file = os.path.join(data_dir, '{}/{}.out.bin'.format(split, split))
        codes = deserialize_from_file(file)
        for code in codes:
            nodes_numbers.append(number_of_ast_nodes(code))
    return np.mean(nodes_numbers), max(nodes_numbers), nodes_numbers


def avg_and_max_number_char_in_code(data_dir):
    char_len = []
    for split in splits:
        file = os.path.join(data_dir, '{}/{}.out.bin'.format(split, split))
        codes = deserialize_from_file(file)
        for code in codes:
            char_len.append(len(code))
    return np.mean(char_len), max(char_len), char_len


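# Usage sketch for the two statistics helpers above. It assumes the
# module-level `splits` list (e.g. ['train', 'dev', 'test']) and the
# preprocessed Django layout used elsewhere in this file; the directory
# path below is illustrative only.
#
#     avg_nodes, max_nodes, _ = avg_and_max_number_of_actions('./preprocessed/django')
#     avg_chars, max_chars, _ = avg_and_max_number_char_in_code('./preprocessed/django')
#     print('AST nodes: avg {:.1f}, max {}'.format(avg_nodes, max_nodes))
#     print('code chars: avg {:.1f}, max {}'.format(avg_chars, max_chars))

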
def load_input(self, data_dir, file_name, syntax):
    parents_file = os.path.join(
        data_dir, '{}.in.{}_parents'.format(file_name, parents_prefix[syntax]))
    tokens_file = os.path.join(data_dir, '{}.in.tokens'.format(file_name))
    strmap_file = os.path.join(data_dir, '{}.in.strmap.bin'.format(file_name))
    logging.info('Reading query trees...')
    self.query_trees = self.read_query_trees(parents_file)
    logging.info('Reading query tokens...')
    self.queries, self.query_tokens = self.read_query(tokens_file)
    logging.info('Reading strmap...')
    self.strmaps = deserialize_from_file(strmap_file)


def load_test_dataset(data_dir, syntax, max_example_actions_num):
    # all with unary closures
    terminal_vocab_file = os.path.join(data_dir, 'terminal_vocab.txt')
    grammar_file = os.path.join(data_dir, 'grammar.txt.uc.bin')
    grammar = deserialize_from_file(grammar_file)
    terminal_vocab = Vocab(
        terminal_vocab_file,
        data=[Constants.UNK_WORD, Constants.EOS_WORD, Constants.PAD_WORD])
    vocab = Vocab(
        os.path.join(data_dir, 'vocab.txt'),
        data=[Constants.UNK_WORD, Constants.EOS_WORD, Constants.PAD_WORD])
    prefix = 'uc_' + syntax + '_'
    test_dir = os.path.join(data_dir, 'test')
    # cache path, following the prefixed naming used in load_dataset
    test_file = os.path.join(test_dir, prefix + 'test.pth')
    test = Dataset(test_dir, 'test', grammar, vocab, terminal_vocab, syntax,
                   max_example_actions_num, True)
    torch.save(test, test_file)
    return test


def load_dataset(config, force_regenerate=False):
    dj_dir = './preprocessed/django'
    logging.info('=' * 80)
    logging.info('Loading datasets from folder ' + dj_dir)
    logging.info('=' * 80)

    train, test, dev = None, None, None
    prefix = config.syntax + '_'
    if config.unary_closures:
        prefix += 'uc_'

    train_dir = os.path.join(dj_dir, 'train')
    train_file = os.path.join(train_dir, prefix + 'train.pth')
    if not force_regenerate and os.path.isfile(train_file):
        logging.info('Train dataset found, loading...')
        train = torch.load(train_file)
        train.config = config

    test_dir = os.path.join(dj_dir, 'test')
    test_file = os.path.join(test_dir, prefix + 'test.pth')
    if not force_regenerate and os.path.isfile(test_file):
        logging.info('Test dataset found, loading...')
        test = torch.load(test_file)
        test.config = config

    dev_dir = os.path.join(dj_dir, 'dev')
    dev_file = os.path.join(dev_dir, prefix + 'dev.pth')
    if not force_regenerate and os.path.isfile(dev_file):
        logging.info('Dev dataset found, loading...')
        dev = torch.load(dev_file)
        dev.config = config

    if train is None or test is None or dev is None:
        terminal_vocab_file = os.path.join(dj_dir, 'terminal_vocab.txt')
        if config.unary_closures:
            grammar_file = os.path.join(dj_dir, 'grammar.txt.uc.bin')
        else:
            grammar_file = os.path.join(dj_dir, 'grammar.txt.bin')
        grammar = deserialize_from_file(grammar_file)
        terminal_vocab = Vocab(
            terminal_vocab_file,
            data=[Constants.UNK_WORD, Constants.EOS_WORD, Constants.PAD_WORD])
        vocab = Vocab(
            os.path.join(dj_dir, 'vocab.txt'),
            data=[Constants.UNK_WORD, Constants.EOS_WORD, Constants.PAD_WORD])

        if test is None:
            logging.info('Test dataset not found, generating...')
            test = Dataset(test_dir, 'test', grammar, vocab, terminal_vocab,
                           config.syntax, config.max_example_action_num,
                           config.unary_closures)
            torch.save(test, test_file)
        if dev is None:
            logging.info('Dev dataset not found, generating...')
            dev = Dataset(dev_dir, 'dev', grammar, vocab, terminal_vocab,
                          config.syntax, config.max_example_action_num,
                          config.unary_closures)
            torch.save(dev, dev_file)
        if train is None:
            logging.info('Train dataset not found, generating...')
            train = Dataset(train_dir, 'train', grammar, vocab, terminal_vocab,
                            config.syntax, config.max_example_action_num,
                            config.unary_closures)
            torch.save(train, train_file)

    train.prepare_torch(config.cuda)
    dev.prepare_torch(config.cuda)
    test.prepare_torch(config.cuda)
    return train, dev, test


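# Usage sketch for load_dataset. The `config` below is a stand-in: any object
# exposing the attributes read above (syntax, unary_closures,
# max_example_action_num, cuda) will do, and the values shown are placeholders
# rather than the project's defaults.
#
#     from argparse import Namespace
#     config = Namespace(syntax='ccg', unary_closures=True,
#                        max_example_action_num=350, cuda=False)
#     train, dev, test = load_dataset(config)

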
def load_output(self, data_dir, file_name):
    logging.info('Reading code files...')
    if self.unary_closures:
        trees_file = '{}.out.trees.uc.bin'.format(file_name)
    else:
        trees_file = '{}.out.trees.bin'.format(file_name)
    trees_file = os.path.join(data_dir, trees_file)
    code_file = os.path.join(data_dir, '{}.out.bin'.format(file_name))
    code_raw_file = os.path.join(data_dir, '{}.out.raw.bin'.format(file_name))
    self.code_trees = deserialize_from_file(trees_file)
    self.codes = deserialize_from_file(code_file)
    self.codes_raw = deserialize_from_file(code_raw_file)

    logging.info('Constructing code representation...')
    self.actions = []
    for code_tree, query_tokens, query_tree in \
            tqdm(zip(self.code_trees, self.query_tokens, self.query_trees)):
        if code_tree is None or query_tree is None:
            self.actions.append(None)
            continue
        rule_list, rule_parents = code_tree.get_productions(
            include_value_node=True)
        actions = []
        rule_pos_map = dict()
        for rule_count, rule in enumerate(rule_list):
            if not self.grammar.is_value_node(rule.parent):
                # non-terminal production: emit an APPLY_RULE action
                assert rule.value is None
                parent_rule = rule_parents[(rule_count, rule)][0]
                if parent_rule:
                    parent_t = rule_pos_map[parent_rule]
                else:
                    parent_t = 0
                rule_pos_map[rule] = len(actions)
                d = {
                    'rule': rule,
                    'parent_t': parent_t,
                    'parent_rule': parent_rule
                }
                action = Action(APPLY_RULE, d)
                actions.append(action)
            else:
                # terminal value: emit one action per terminal token
                assert rule.is_leaf
                parent_rule = rule_parents[(rule_count, rule)][0]
                parent_t = rule_pos_map[parent_rule]
                terminal_val = rule.value
                terminal_str = str(terminal_val)
                terminal_tokens = get_terminal_tokens(terminal_str)
                for terminal_token in terminal_tokens:
                    term_tok_id = self.terminal_vocab.getIndex(
                        terminal_token, Constants.UNK)
                    tok_src_idx = -1
                    try:
                        tok_src_idx = query_tokens.index(terminal_token)
                    except ValueError:
                        pass
                    d = {
                        'literal': terminal_token,
                        'rule': rule,
                        'parent_rule': parent_rule,
                        'parent_t': parent_t
                    }
                    # cannot copy, only generation
                    # could be unk!
                    if tok_src_idx < 0 or tok_src_idx > query_tree.size() - 1:
                        action = Action(GEN_TOKEN, d)
                    else:  # copy
                        if term_tok_id != Constants.UNK:
                            d['source_idx'] = tok_src_idx
                            action = Action(GEN_COPY_TOKEN, d)
                        else:
                            d['source_idx'] = tok_src_idx
                            action = Action(COPY_TOKEN, d)
                    actions.append(action)
                # close this terminal's token sequence with an <eos> token
                d = {
                    'literal': '<eos>',
                    'rule': rule,
                    'parent_rule': parent_rule,
                    'parent_t': parent_t
                }
                actions.append(Action(GEN_TOKEN, d))
        if len(actions) == 0:
            continue
        self.actions.append(actions)
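

# Summary of the action encoding built by load_output above: non-terminal
# productions become APPLY_RULE actions; each terminal value is split into
# tokens, and each token becomes GEN_TOKEN (token absent from the query, so it
# can only be generated and may be unknown), GEN_COPY_TOKEN (token present in
# the query and known to the terminal vocabulary, so it can be generated or
# copied), or COPY_TOKEN (token present in the query but unknown, so it can
# only be copied). A final GEN_TOKEN with literal '<eos>' closes each
# terminal's token sequence, and every action records its parent rule plus the
# time step ('parent_t') at which that parent was produced.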