def __init__(self, args):
    super(ReplaceValOp, self).__init__(args, op_name=OP_REPLACE_VAL)
    # Learnable embedding table for the constant values in the global vocabulary.
    v = torch.zeros(Dataset.num_const_values(), args.latent_dim, dtype=t_float)
    glorot_uniform(v)
    self.const_name_embedding = Parameter(v)
    # Scores each constant value against the current latent state.
    self.const_name_pred = MLP(input_dim=args.latent_dim,
                               hidden_dims=[args.latent_dim, Dataset.num_const_values()],
                               nonlinearity=args.act_func)
def get_node_type(self, target_types=None, beam_size=1):
    # With teacher forcing, map the ground-truth type names to their ids.
    target_indices = None
    if target_types is not None:
        target_indices = [Dataset.get_id_of_ntype(t) for t in target_types]
    logits = self.node_type_pred(self.states)
    num_tries = min(beam_size, Dataset.num_node_types())
    self._get_fixdim_choices(logits, self.node_type_embedding,
                             target_indices=target_indices, masks=None,
                             beam_size=beam_size, num_tries=num_tries)
    # Replace the chosen type ids in the decision history with their names.
    found_types = []
    for i in range(len(self.hist_choices)):
        t = Dataset.get_ntype_of_id(self.hist_choices[i][-1])
        self.hist_choices[i][-1] = t
        found_types.append(t)
    return found_types if target_types is None else target_types
def forward(self, ll, states, gnn_graphs, node_embedding, sample_indices, init_edits,
            hist_choices, beam_size=1, pred_gt=False, loc_given=False):
    self.setup_forward(ll, states, gnn_graphs, node_embedding, sample_indices, hist_choices, init_edits)

    # select node name
    if pred_gt:
        target_name_indices = []
        for i, g in enumerate(self.sample_buf):
            if self.cur_edits[i].node_name in g.pg.contents:
                target_name_idx = g.pg.contents[self.cur_edits[i].node_name].index
            elif self.cur_edits[i].node_name in Dataset.get_value_vocab():
                # Constant values are indexed after the graph's own nodes.
                target_name_idx = g.num_nodes + Dataset.get_value_vocab()[self.cur_edits[i].node_name]
            else:
                raise NotImplementedError
            target_name_indices.append(target_name_idx)
    else:
        target_name_indices = None
    self.get_node(self.node_embedding,
                  fn_node_select=lambda bid, node: node.node_type == CONTENT_NODE_TYPE,
                  target_indices=target_name_indices,
                  const_node_embeds=self.const_name_embedding,
                  const_nodes=Dataset._id_value_map,
                  fn_const_node_pred=self.const_name_pred,
                  beam_size=beam_size)

    # Apply the chosen value replacement to a copy of each AST.
    new_asts = []
    new_refs = []
    for i, g in enumerate(self.sample_buf):
        ast = deepcopy(g.pg.ast)
        target_node, _, target_name = self.hist_choices[i]
        if isinstance(target_name, CgNode):
            target_name = target_name.name
        node_to_be_edit = ast.nodes[target_node.index]
        node_to_be_edit.value = target_name
        g.pg.refs = adjust_refs(deepcopy(g.pg.refs), target_node.index)
        new_asts.append(ast)
        new_refs.append(g.pg.refs)
        target_name = "None" if not target_name else target_name
        ast.append_edit(GraphEditCmd(SEPARATOR.join([OP_REPLACE_VAL, str(target_node.index), target_name])))
    return new_asts, self.ll, self.states, self.sample_indices, new_refs
import os
import numpy as np
from time import time
import torch
from tqdm import tqdm

from gtrans.eval.utils import ast_acc_cnt, setup_dicts, loc_acc_cnt, val_acc_cnt, type_acc_cnt, op_acc_cnt, get_top_k, get_val
from gtrans.data_process.utils import get_bug_prefix
from gtrans.common.configs import cmd_args
from gtrans.common.dataset import Dataset, GraphEditCmd
from gtrans.model.gtrans_model import GraphTrans
from gtrans.common.consts import DEVICE
from gtrans.common.consts import OP_REPLACE_VAL, OP_ADD_NODE, OP_REPLACE_TYPE, OP_DEL_NODE, OP_NONE

# Load the constant-value vocabulary and register the two special entries.
const_val_vocab = np.load(os.path.join(cmd_args.data_root, "vocab_" + cmd_args.vocab_type + ".npy"), allow_pickle=True).item()
Dataset.set_value_vocab(const_val_vocab)
Dataset.add_value2vocab(None)
Dataset.add_value2vocab("UNKNOWN")

dataset = Dataset(cmd_args.data_root, cmd_args.gnn_type)
dataset.load_partition()

phase = "test"
torch.set_num_threads(1)

def sample_gen(s_list):
    yield s_list
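# Minimal usage sketch (an assumption, not taken from the repo): the evaluation
# code appears to consume a generator of sample lists, so sample_gen wraps a
# single list in a one-shot generator. "sample_placeholder" is a stand-in value.
_demo_gen = sample_gen(["sample_placeholder"])
assert next(_demo_gen) == ["sample_placeholder"]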
import numpy as np
import torch
import re
import os
import sys
import glob
from tqdm import tqdm

from gtrans.common.dataset import Dataset
from gtrans.model.gtrans_model import GraphTrans
from gtrans.common.configs import cmd_args
from gtrans.common.consts import DEVICE

const_val_vocab = np.load(os.path.join(cmd_args.data_root, "vocab_" + cmd_args.vocab_type + ".npy"), allow_pickle=True).item()
Dataset.set_value_vocab(const_val_vocab)
Dataset.add_value2vocab(None)
Dataset.add_value2vocab("UNKNOWN")

dataset = Dataset(cmd_args.data_root, cmd_args.gnn_type)
dataset.load_partition()

torch.set_num_threads(1)

reg = re.escape(cmd_args.save_dir) + r"/epoch-([0-9]*).ckpt"  # was missing forward slash
# print(reg)  # /Users/zhutao/lab/data/small_trainingResult/epoch-([0-9]*).ckpt

loss_file = cmd_args.loss_file
loss_dict = {}
if not os.path.exists(loss_file):
    open(loss_file, "w").close()
else:
    with open(loss_file, "r") as f:
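# Standalone sketch (assumed usage, not from the repo): the regex above is meant
# to recover the epoch index from checkpoint paths like ".../epoch-12.ckpt" so
# checkpoints can be processed in epoch order. The paths below are stand-ins.
_example_paths = ["/tmp/run/epoch-3.ckpt", "/tmp/run/epoch-12.ckpt", "/tmp/run/log.txt"]
_pattern = re.escape("/tmp/run") + r"/epoch-([0-9]*).ckpt"
_epochs = sorted(int(m.group(1)) for p in _example_paths
                 for m in [re.match(_pattern, p)] if m)
# _epochs == [3, 12]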
    # Tail of get_save_dir(): create and return the output directory.
    os.makedirs(out_dir)
    return out_dir

if __name__ == '__main__':
    random.seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    torch.set_num_threads(1)
    torch.manual_seed(cmd_args.seed)
    torch.autograd.set_detect_anomaly(True)

    vocab_name = 'vocab_%s.npy' % cmd_args.vocab_type
    print('loading value vocab from', vocab_name)
    const_val_vocab = np.load(os.path.join(cmd_args.data_root, vocab_name), allow_pickle=True).item()
    Dataset.set_value_vocab(const_val_vocab)
    Dataset.add_value2vocab(None)
    Dataset.add_value2vocab("UNKNOWN")
    print('global value table size', Dataset.num_const_values())

    dataset = Dataset(cmd_args.data_root, cmd_args.gnn_type, data_in_mem=cmd_args.data_in_mem, resampling=cmd_args.resampling)

    # partition counters: f_per_part files per output directory
    f_per_part = 1000
    cur_part = 0
    cnt = 0
    cur_out_dir = get_save_dir(cur_part)
    for s in tqdm(dataset.data_samples):
        for fname, cg in [(s.f_bug, s.buggy_code_graph),
torch.set_num_threads(1)
torch.manual_seed(cmd_args.seed)
torch.autograd.set_detect_anomaly(True)

# remove previously saved ckpt files in save_dir if not empty
if os.listdir(cmd_args.save_dir):
    files = glob.glob(os.path.join(cmd_args.save_dir, '*'))
    for f in files:
        print(f)
        os.remove(f)

vocab_name = 'vocab_%s.npy' % cmd_args.vocab_type  # vocab_fixes.npy or vocab_full.npy
print('loading value vocab from', vocab_name)
const_val_vocab = np.load(os.path.join(cmd_args.data_root, vocab_name), allow_pickle=True).item()
Dataset.set_value_vocab(const_val_vocab)
Dataset.add_value2vocab(None)
Dataset.add_value2vocab("UNKNOWN")
print('global value table size', Dataset.num_const_values())

dataset = Dataset(cmd_args.data_root, cmd_args.gnn_type, data_in_mem=cmd_args.data_in_mem, resampling=cmd_args.resampling)
dataset.load_partition()

train_gen = dataset.data_gen(cmd_args.batch_size, phase='train', infinite=True)

best_test_loss = None
def get_cached_mask(self, i):
    # Lazily allocate one zero mask per key and reuse it across calls.
    if i not in self.cached_masks:
        self.cached_masks[i] = torch.zeros(Dataset.num_node_types(), dtype=t_float).to(DEVICE)
    return self.cached_masks[i]
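# Self-contained analogue of the caching pattern above (a sketch; the reset-and-
# fill usage and NUM_NODE_TYPES are assumptions, with NUM_NODE_TYPES standing in
# for Dataset.num_node_types()).
import torch

NUM_NODE_TYPES = 8  # placeholder size

class _MaskCache:
    def __init__(self):
        self.cached_masks = {}

    def get_cached_mask(self, i):
        # Allocate one zero mask per key, then reuse it on later calls.
        if i not in self.cached_masks:
            self.cached_masks[i] = torch.zeros(NUM_NODE_TYPES, dtype=torch.float32)
        return self.cached_masks[i]

_cache = _MaskCache()
_mask = _cache.get_cached_mask(0)
_mask.zero_()        # clear values left over from a previous call
_mask[[1, 3]] = 1.0  # hypothetical allowed node-type ids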
def forward(self, ll, states, gnn_graphs, node_embedding, sample_indices, init_edits,
            hist_choices, beam_size=1, pred_gt=False, loc_given=False):
    self.setup_forward(ll, states, gnn_graphs, node_embedding, sample_indices, hist_choices, init_edits)

    # self.node_type, self.node_name, self.parent_id, self.child_rank
    target_types = [e.node_type for e in self.cur_edits] if pred_gt else None
    self.get_node_type(target_types, beam_size=beam_size)

    # select node name
    if pred_gt:
        target_name_indices = []
        for i, g in enumerate(self.sample_buf):
            if self.cur_edits[i].node_name in g.pg.contents:
                target_name_idx = g.pg.contents[self.cur_edits[i].node_name].index
            elif self.cur_edits[i].node_name in Dataset.get_value_vocab():
                target_name_idx = g.num_nodes + Dataset.get_value_vocab()[self.cur_edits[i].node_name]
            else:
                raise NotImplementedError
            target_name_indices.append(target_name_idx)
    else:
        target_name_indices = None
    self.get_node(self.node_embedding,
                  fn_node_select=lambda bid, node: node.node_type == CONTENT_NODE_TYPE,
                  target_indices=target_name_indices,
                  const_node_embeds=self.const_name_embedding,
                  const_nodes=Dataset._id_value_map,
                  fn_const_node_pred=self.const_name_pred,
                  beam_size=beam_size)

    parent_nodes = [self.hist_choices[i][0] for i in range(len(self.hist_choices))]
    # Ground-truth child ranks are only usable if every edit carries one.
    for e in self.cur_edits:
        if not hasattr(e, 'child_rank'):
            loc_given = False
            break

    # select left sibling
    if pred_gt or loc_given:
        target_prenode_ids = []
        for i, g in enumerate(self.sample_buf):
            if self.cur_edits[i].child_rank == 0:
                target_prenode_ids.append(parent_nodes[i].index)
            else:
                ch_idx = self.cur_edits[i].child_rank - 1
                par_node = parent_nodes[i].ast_node
                target_prenode_ids.append(par_node.children[ch_idx].index)
    else:
        target_prenode_ids = None
    self.get_node(self.node_embedding,
                  fn_node_select=lambda bid, node: self._sibling_filter(node, self.sample_buf[bid].pg.ast, parent_nodes[bid].index),
                  target_indices=target_prenode_ids,
                  beam_size=beam_size)

    # Insert the chosen node into a copy of each AST.
    new_asts = []
    new_refs = []
    for i, g in enumerate(self.sample_buf):
        ast = deepcopy(g.pg.ast)
        parent_node, _, node_type, target_name, pre_node = self.hist_choices[i]
        if isinstance(target_name, CgNode):
            target_name = target_name.name
        if pre_node.index == parent_node.index:
            ch_rank = 0
        else:
            ch_rank = parent_node.ast_node.child_rank(pre_node.ast_node) + 1
        new_node = AstNode(node_type=node_type, value=target_name)
        ast.add_node(new_node)
        p_node = ast.nodes[parent_node.index]
        p_node.add_child(new_node, ch_rank)
        new_asts.append(ast)
        new_refs.append(g.pg.refs)
        if not target_name:
            target_name = "None"
        if SEPARATOR in node_type:
            tmp_type = node_type.split(SEPARATOR)[0]
        else:
            tmp_type = node_type
        e = GraphEditCmd(SEPARATOR.join([OP_ADD_NODE, str(parent_node.index), str(ch_rank), tmp_type, target_name]))
        e.node_type = node_type
        ast.append_edit(e)
    return new_asts, self.ll, self.states, self.sample_indices, new_refs
def __init__(self, args, op_name):
    super(TypedGraphOp, self).__init__(args, op_name=op_name)
    # Predicts a node-type distribution from the latent state and embeds type ids.
    self.node_type_pred = MLP(input_dim=args.latent_dim,
                              hidden_dims=[args.latent_dim, Dataset.num_node_types()],
                              nonlinearity=args.act_func)
    self.node_type_embedding = nn.Embedding(Dataset.num_node_types(), args.latent_dim)