def make_graph_edits(file_tuple):
    vocab = {}
    if any(not os.path.isfile(f) for f in file_tuple):
        return file_tuple, (None, None), None, ('buggy/fixed/ast_diff file missing', None), vocab
    f_bug, f_bug_src, f_fixed, f_diff = file_tuple
    sample_name = get_bug_prefix(f_bug)
    if os.path.exists(os.path.join(cmd_args.save_dir, sample_name + "_refs.npy")):
        return file_tuple, (None, None), None, ('Already exists', None), vocab
    elif f_bug in processed:
        return file_tuple, (None, None), None, ('Already processed', None), vocab

    ast_bug = build_ast(f_bug)
    ast_fixed = build_ast(f_fixed)
    if (not ast_bug or not ast_fixed
            or ast_bug.num_nodes > cmd_args.max_ast_nodes
            or ast_fixed.num_nodes > cmd_args.max_ast_nodes):
        return file_tuple, (None, None), None, ('too many nodes in ast', None), vocab

    # count value tokens from both ASTs, skipping JS keywords
    all_nodes = ast_bug.nodes + ast_fixed.nodes
    for node in all_nodes:
        if node.value and node.value not in js_keywords:
            vocab[node.value] = vocab.get(node.value, 0) + 1

    gp = GraphEditParser(ast_bug, ast_fixed, f_diff)

    # cache both ASTs as pickles for the later dataset-loading step
    buggy_pkl = os.path.join(cmd_args.save_dir, '%s_buggy.pkl' % sample_name)
    with open(buggy_pkl, 'wb') as f:
        cp.dump(ast_bug, f)
    fixed_pkl = os.path.join(cmd_args.save_dir, '%s_fixed.pkl' % sample_name)
    with open(fixed_pkl, 'wb') as f:
        cp.dump(ast_fixed, f)

    buggy_refs = get_ref_edges(f_bug_src, buggy_pkl)
    return file_tuple, (ast_bug, ast_fixed), buggy_refs, gp.parse_edits(), vocab
def make_graph_edits(file_tuple):
    vocab = {}
    for f in file_tuple:
        if not os.path.isfile(f):
            print(f)
    if any(not os.path.isfile(f) for f in file_tuple):
        return file_tuple, (None, None), None, ('buggy/fixed/ast_diff file missing', None), vocab
    f_bug, f_bug_src, f_fixed, f_diff = file_tuple
    sample_name = get_bug_prefix(f_bug)  # e.g. 'SHIFT_01-01-2019:00_6_0selectors'
    if os.path.exists(os.path.join(cmd_args.save_dir, sample_name + "_refs.npy")):
        return file_tuple, (None, None), None, ('Already exists', None), vocab
    elif f_bug in processed:
        return file_tuple, (None, None), None, ('Already processed', None), vocab

    ast_bug = build_ast(f_bug)
    ast_fixed = build_ast(f_fixed)
    if (not ast_bug or not ast_fixed
            or ast_bug.num_nodes > cmd_args.max_ast_nodes
            or ast_fixed.num_nodes > cmd_args.max_ast_nodes):
        return file_tuple, (None, None), None, ('too many nodes in ast', None), vocab

    gp = GraphEditParser(ast_bug, ast_fixed, f_diff)
    buggy_pkl = os.path.join(cmd_args.save_dir, '%s_buggy.pkl' % sample_name)
    with open(buggy_pkl, 'wb') as f:
        cp.dump(ast_bug, f)
    fixed_pkl = os.path.join(cmd_args.save_dir, '%s_fixed.pkl' % sample_name)
    with open(fixed_pkl, 'wb') as f:
        cp.dump(ast_fixed, f)
    buggy_refs = get_ref_edges(f_bug_src, buggy_pkl)
    return file_tuple, (ast_bug, ast_fixed), buggy_refs, gp.parse_edits(), vocab
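# --- Hypothetical usage sketch (not from the original code) ---
# Shows how make_graph_edits could be driven over a list of
# (buggy_js, buggy_src, fixed_js, ast_diff) file tuples and how the
# per-sample vocab counters could be merged; the names used here
# (cook_all, merged_vocab) are illustrative assumptions.
def cook_all(file_tuples):
    merged_vocab = {}
    results = []
    for ft in file_tuples:
        result = make_graph_edits(ft)
        # the last element of the returned tuple is the per-sample vocab
        # (possibly empty, depending on which variant above is used)
        for tok, cnt in result[-1].items():
            merged_vocab[tok] = merged_vocab.get(tok, 0) + cnt
        results.append(result)
    return results, merged_vocab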
    # Inside the per-sample loop of the cooking script: collect candidate
    # vocabulary tokens into `values`. (The `if cmd_args.vocab_type == ...`
    # line that opens this branch is not included in this excerpt.)
        for e_idx, edit in enumerate(edits):
            ge = GraphEditCmd(edit)
            # only value-replacing and node-adding edits carry a usable name
            if not (ge.op == OP_REPLACE_VAL or ge.op == OP_ADD_NODE):
                continue
            if ge.clean_name:
                values.append(ge.clean_name)
    elif cmd_args.vocab_type == "full":
        # otherwise take every node value from both ASTs
        if not ast_bug or not ast_fixed:
            continue
        all_nodes = ast_bug.nodes + ast_fixed.nodes
        for node in all_nodes:
            if node.value:
                values.append(node.value)

    sample_name = get_bug_prefix(buggy_file)
    if error_log is not None:
        # cooking failed for this sample; log the reason to the error CSV
        writer.writerow([sample_name, error_log])
    else:
        sample_list.append(sample_name)
        pbar.set_description('# valid: %d' % len(sample_list))
        # dump the parsed graph edits for this sample as a JSON array
        with open(os.path.join(cmd_args.save_dir, '%s_gedit.txt' % sample_name), 'w') as f:
            json_arr = []
            for row in edits:
                edit_obj = {}
                edit_obj["edit"] = row
                json_arr.append(edit_obj)
def __init__(self, data_root, gnn_type, data_in_mem=False, resampling=False,
             sample_types=None, lang_dict=None, valpred_type=None, phases=None):
    self.data_root = data_root
    self.gnn_type = gnn_type
    self.data_in_mem = data_in_mem
    self.resampling = resampling
    self.sample_types = sample_types
    Dataset._lang_type = lang_dict

    print('loading cooked asts and edits')
    self.data_samples = []
    self.sample_index = {}
    self.sample_edit_type = {}
    cooked_gen = code_group_generator(
        data_root,
        file_suffix=['_buggy.pkl', '_fixed.pkl', '_gedit.txt', '_refs.npy'])

    # optionally restrict loading to the samples listed in '<phase>.txt' files
    if phases is not None:
        avail_set = set()
        for phase in phases:
            idx_file = os.path.join(self.data_root, '%s.txt' % phase)
            if not os.path.isfile(idx_file):
                continue
            with open(idx_file, 'r') as f:
                for row in f:
                    sname = row.strip()
                    avail_set.add(sname)

    fidx = 0
    for file_tuple in tqdm(cooked_gen):
        f_bug, f_fixed, f_diff, b_refs = file_tuple
        sample_name = get_bug_prefix(f_bug)
        if phases is not None and sample_name not in avail_set:
            continue
        if any(not os.path.isfile(f) for f in file_tuple):
            continue
        sample = DataSample(fidx, f_bug, f_fixed, f_diff, b_refs)
        if self.resampling or sample_types is not None:
            # index samples by the op of their first edit, for resampling/filtering
            s_type = sample.g_edits[0].op
            if sample_types is not None and s_type not in sample_types:
                continue
            self.sample_edit_type[fidx] = s_type
        self.data_samples.append(sample)
        self.sample_index[sample_name] = fidx
        fidx += 1
    assert len(self.data_samples) == fidx
    print(fidx, 'samples loaded.')

    # load or build the node-type vocabulary
    f_type_vocab = os.path.join(data_root, 'type_vocab.pkl')
    if os.path.isfile(f_type_vocab):
        Dataset.load_type_vocab(f_type_vocab)
    else:
        print('building vocab and saving to', f_type_vocab)
        self.build_node_types()
        with open(f_type_vocab, 'wb') as f:
            d = {}
            d['_node_type_dict'] = Dataset._node_type_dict
            d['_id_ntype_map'] = Dataset._id_ntype_map
            cp.dump(d, f, cp.HIGHEST_PROTOCOL)

    # load or build the value/language dictionary used for value prediction
    if lang_dict is not None and lang_dict != 'None':
        f_val_dict = os.path.join(
            data_root, 'val_dict-%s-%s.pkl' % (lang_dict, valpred_type))
        if os.path.isfile(f_val_dict):
            print('loading %s dict from' % lang_dict, f_val_dict)
            with open(f_val_dict, 'rb') as f:
                d = cp.load(f)
                Dataset._lang_dict = d['_lang_dict']
                Dataset._id_lang_map = d['_id_lang_map']
        else:
            print('building %s dict and saving to' % lang_dict, f_val_dict)
            for s in tqdm(self.data_samples):
                for e in s._gedits:
                    e.clean_unk(s.buggy_code_graph.contents)
                    val = e.node_name if valpred_type == 'node_name' else e.clean_name
                    if val is None:
                        val = 'None'
                    for c in Dataset.split_sentence(val):
                        self.add_language_token(c)
            with open(f_val_dict, 'wb') as f:
                d = {}
                d['_lang_dict'] = Dataset._lang_dict
                d['_id_lang_map'] = Dataset._id_lang_map
                cp.dump(d, f, cp.HIGHEST_PROTOCOL)
        print('language dict size', len(Dataset._lang_dict))

    print(Dataset.num_node_types(), 'types of nodes in total.')
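# --- Hypothetical instantiation sketch (not from the original code) ---
# Illustrates how the constructor above might be called once the cooking step
# has filled save_dir with *_buggy.pkl / *_fixed.pkl / *_gedit.txt / *_refs.npy
# files plus one '<phase>.txt' index file per phase. Every argument value below
# is an illustrative placeholder, except valpred_type='node_name', which the
# constructor explicitly checks for.
dataset = Dataset(data_root=cmd_args.save_dir,
                  gnn_type='s2v',            # placeholder GNN identifier
                  resampling=False,
                  lang_dict='token',         # any value other than None/'None' builds the value dict
                  valpred_type='node_name',
                  phases=['train'])          # reads '<data_root>/train.txt' if present
print(len(dataset.data_samples), 'training samples available')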