Example #1
    def __init__(self, args):
        super(ReplaceValOp, self).__init__(args, op_name=OP_REPLACE_VAL)
        # Learnable embedding table with one Glorot-initialized row per
        # constant value in the global vocabulary.
        v = torch.zeros(Dataset.num_const_values(), args.latent_dim, dtype=t_float)
        glorot_uniform(v)
        self.const_name_embedding = Parameter(v)

        # MLP head that scores every constant value given a latent state.
        self.const_name_pred = MLP(input_dim=args.latent_dim,
                                   hidden_dims=[args.latent_dim, Dataset.num_const_values()],
                                   nonlinearity=args.act_func)
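A minimal, self-contained sketch of the same pattern, assuming the project's glorot_uniform matches torch's built-in xavier_uniform_ and that MLP(input_dim, hidden_dims=[h, out]) is a plain two-layer perceptron (both are assumptions; the sizes are made up):

import torch
import torch.nn as nn

num_values, latent_dim = 100, 64               # hypothetical sizes
v = torch.zeros(num_values, latent_dim)
nn.init.xavier_uniform_(v)                     # Glorot/Xavier initialization
const_name_embedding = nn.Parameter(v)

const_name_pred = nn.Sequential(               # stand-in for the project's MLP
    nn.Linear(latent_dim, latent_dim), nn.Tanh(),
    nn.Linear(latent_dim, num_values),
)

states = torch.randn(8, latent_dim)            # a batch of latent states
logits = const_name_pred(states)               # (8, num_values) value scores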
Example #2
    def get_node_type(self, target_types=None, beam_size=1):
        # When ground-truth types are given, restrict the choice to their ids.
        target_indices = None
        if target_types is not None:
            target_indices = [Dataset.get_id_of_ntype(t) for t in target_types]
        logits = self.node_type_pred(self.states)
        num_tries = min(beam_size, Dataset.num_node_types())
        self._get_fixdim_choices(logits,
                                 self.node_type_embedding,
                                 target_indices=target_indices,
                                 masks=None,
                                 beam_size=beam_size,
                                 num_tries=num_tries)
        # Map the chosen ids back to node-type names in the choice history.
        found_types = []
        for i in range(len(self.hist_choices)):
            t = Dataset.get_ntype_of_id(self.hist_choices[i][-1])
            self.hist_choices[i][-1] = t
            found_types.append(t)
        return found_types if target_types is None else target_types
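The snippet delegates the actual selection to _get_fixdim_choices, which is not shown. As a rough illustration only (an assumption, not the project's implementation), restricting a fixed-dimensional choice to target_indices and keeping the top beam_size candidates can be sketched as:

import torch

num_types, beam_size = 10, 3
logits = torch.randn(1, num_types)
target_indices = [2, 5, 7]                 # permitted type ids
mask = torch.full_like(logits, float("-inf"))
mask[:, target_indices] = 0.0              # -inf everywhere else
num_tries = min(beam_size, len(target_indices))
top = torch.topk(logits + mask, k=num_tries, dim=-1)
print(top.indices)                         # best permitted type ids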
Example #3
    def forward(self, ll, states, gnn_graphs, node_embedding, sample_indices, init_edits, hist_choices, beam_size=1, pred_gt=False, loc_given=False):
        self.setup_forward(ll, states, gnn_graphs, node_embedding, sample_indices, hist_choices, init_edits)

        # select node name
        if pred_gt:
            # Teacher forcing: resolve each ground-truth value to a candidate
            # index, either an in-graph content node or, offset by the number
            # of graph nodes, an entry of the global value vocabulary.
            target_name_indices = []
            for i, g in enumerate(self.sample_buf):
                if self.cur_edits[i].node_name in g.pg.contents:
                    target_name_idx = g.pg.contents[self.cur_edits[i].node_name].index
                elif self.cur_edits[i].node_name in Dataset.get_value_vocab():
                    target_name_idx = g.num_nodes + Dataset.get_value_vocab()[self.cur_edits[i].node_name]
                else:
                    raise NotImplementedError
                target_name_indices.append(target_name_idx)
        else:
            target_name_indices = None
        self.get_node(self.node_embedding,
                      fn_node_select=lambda bid, node: node.node_type == CONTENT_NODE_TYPE,
                      target_indices=target_name_indices,
                      const_node_embeds=self.const_name_embedding,
                      const_nodes=Dataset._id_value_map,
                      fn_const_node_pred=self.const_name_pred,
                      beam_size=beam_size)

        # Apply the edit: in a copy of each AST, overwrite the value of the
        # chosen node and record the corresponding edit command.
        new_asts = []
        new_refs = []
        for i, g in enumerate(self.sample_buf):
            ast = deepcopy(g.pg.ast)
            target_node, _, target_name = self.hist_choices[i]
            if isinstance(target_name, CgNode):
                target_name = target_name.name
            node_to_be_edit = ast.nodes[target_node.index]
            node_to_be_edit.value = target_name
            g.pg.refs = adjust_refs(deepcopy(g.pg.refs), target_node.index)
            new_asts.append(ast)
            new_refs.append(g.pg.refs)
            target_name = "None" if not target_name else target_name
            ast.append_edit(GraphEditCmd(SEPARATOR.join([OP_REPLACE_VAL, str(target_node.index), target_name])))

        return new_asts, self.ll, self.states, self.sample_indices, new_refs
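The candidate space above is split in two: indices below g.num_nodes refer to content nodes of the graph, and indices at or above it refer to entries of the global value vocabulary. A toy version of that lookup (names and numbers are invented):

num_nodes = 5
contents = {"x": 2}                  # name -> in-graph content node index
value_vocab = {"foo": 0, "bar": 1}   # global value vocabulary

def target_index(name):
    if name in contents:
        return contents[name]
    if name in value_vocab:
        return num_nodes + value_vocab[name]   # offset past the graph nodes
    raise NotImplementedError(name)

print(target_index("x"), target_index("bar"))  # 2 6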
Example #4
import os
import numpy as np
from time import time
import torch
from tqdm import tqdm
from gtrans.eval.utils import ast_acc_cnt, setup_dicts, loc_acc_cnt, val_acc_cnt, type_acc_cnt, op_acc_cnt, get_top_k, get_val
from gtrans.data_process.utils import get_bug_prefix
from gtrans.common.configs import cmd_args
from gtrans.common.dataset import Dataset, GraphEditCmd
from gtrans.model.gtrans_model import GraphTrans
from gtrans.common.consts import DEVICE
from gtrans.common.consts import OP_REPLACE_VAL, OP_ADD_NODE, OP_REPLACE_TYPE, OP_DEL_NODE, OP_NONE

const_val_vocab = np.load(os.path.join(cmd_args.data_root, "vocab_" +
                                       cmd_args.vocab_type + ".npy"),
                          allow_pickle=True).item()
Dataset.set_value_vocab(const_val_vocab)
Dataset.add_value2vocab(None)
Dataset.add_value2vocab("UNKNOWN")

dataset = Dataset(cmd_args.data_root, cmd_args.gnn_type)
dataset.load_partition()

phase = "test"

torch.set_num_threads(1)


def sample_gen(s_list):
    yield s_list
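The vocabulary file is a pickled dict stored via NumPy; np.load(..., allow_pickle=True).item() recovers the original dict. A standalone round trip:

import numpy as np

vocab = {"foo": 0, "bar": 1}
np.save("vocab_demo.npy", vocab)     # np.save pickles non-array objects
loaded = np.load("vocab_demo.npy", allow_pickle=True).item()
assert loaded == vocab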

Example #5
import numpy as np
import torch
import re
import os
import sys
import glob
from tqdm import tqdm
from gtrans.common.dataset import Dataset
from gtrans.model.gtrans_model import GraphTrans
from gtrans.common.configs import cmd_args
from gtrans.common.consts import DEVICE

const_val_vocab = np.load(os.path.join(cmd_args.data_root, "vocab_" + cmd_args.vocab_type + ".npy"), allow_pickle=True).item()
Dataset.set_value_vocab(const_val_vocab)
Dataset.add_value2vocab(None)
Dataset.add_value2vocab("UNKNOWN")

dataset = Dataset(cmd_args.data_root, cmd_args.gnn_type)
dataset.load_partition()

torch.set_num_threads(1)

reg = re.escape(cmd_args.save_dir) + r"/epoch-([0-9]*)\.ckpt" # was missing forward slash
# print(reg) # /Users/zhutao/lab/data/small_trainingResult/epoch-([0-9]*).ckpt
loss_file = cmd_args.loss_file
loss_dict = {}

if not os.path.exists(loss_file):
    open(loss_file, "w").close()
else:
    with open(loss_file, "r") as f:
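The snippet is cut off inside the with block. The regex above is built to pull epoch numbers out of checkpoint file names; a sketch of that use under a hypothetical save_dir:

import re

save_dir = "/tmp/ckpts"                                  # hypothetical path
reg = re.escape(save_dir) + r"/epoch-([0-9]*)\.ckpt"
m = re.match(reg, save_dir + "/epoch-12.ckpt")
if m:
    print(int(m.group(1)))                               # 12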
Example #6
        os.makedirs(out_dir)
    return out_dir


if __name__ == '__main__':
    random.seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    torch.set_num_threads(1)
    torch.manual_seed(cmd_args.seed)
    torch.autograd.set_detect_anomaly(True)

    vocab_name = 'vocab_%s.npy' % cmd_args.vocab_type
    print('loading value vocab from', vocab_name)
    const_val_vocab = np.load(os.path.join(cmd_args.data_root, vocab_name),
                              allow_pickle=True).item()
    Dataset.set_value_vocab(const_val_vocab)
    Dataset.add_value2vocab(None)
    Dataset.add_value2vocab("UNKNOWN")
    print('global value table size', Dataset.num_const_values())

    dataset = Dataset(cmd_args.data_root,
                      cmd_args.gnn_type,
                      data_in_mem=cmd_args.data_in_mem,
                      resampling=cmd_args.resampling)

    f_per_part = 1000
    cur_part = 0
    cnt = 0
    cur_out_dir = get_save_dir(cur_part)
    for s in tqdm(dataset.data_samples):
        for fname, cg in [(s.f_bug, s.buggy_code_graph),
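The loop is truncated mid-statement. Judging from f_per_part, cur_part and cnt, each partition directory presumably receives at most f_per_part files before rolling over to the next one; a sketch of that bookkeeping under that assumption (get_save_dir and the root path here are stand-ins):

import os

def get_save_dir(part, root="/tmp/parts"):     # hypothetical root
    out_dir = os.path.join(root, "part-%d" % part)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    return out_dir

f_per_part, cur_part, cnt = 1000, 0, 0
cur_out_dir = get_save_dir(cur_part)
for i in range(2500):                          # pretend: one file per sample
    cnt += 1
    if cnt >= f_per_part:                      # partition full: start the next
        cnt = 0
        cur_part += 1
        cur_out_dir = get_save_dir(cur_part)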
Example #7
    torch.set_num_threads(1)
    torch.manual_seed(cmd_args.seed)
    torch.autograd.set_detect_anomaly(True)

    # remove previously saved ckpt files in save_dir if not empty
    if os.listdir(cmd_args.save_dir):
        files = glob.glob(os.path.join(cmd_args.save_dir, '*'))
        for f in files:
            print(f)
            os.remove(f)

    vocab_name = 'vocab_%s.npy' % cmd_args.vocab_type  # vocab_fixes.npy or vocab_full.npy
    print('loading value vocab from', vocab_name)
    const_val_vocab = np.load(os.path.join(cmd_args.data_root, vocab_name),
                              allow_pickle=True).item()
    Dataset.set_value_vocab(const_val_vocab)
    Dataset.add_value2vocab(None)
    Dataset.add_value2vocab("UNKNOWN")
    print('global value table size', Dataset.num_const_values())

    dataset = Dataset(cmd_args.data_root,
                      cmd_args.gnn_type,
                      data_in_mem=cmd_args.data_in_mem,
                      resampling=cmd_args.resampling)

    dataset.load_partition()
    train_gen = dataset.data_gen(cmd_args.batch_size,
                                 phase='train',
                                 infinite=True)

    best_test_loss = None
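data_gen(..., infinite=True) is not shown; an infinite batch generator of this shape typically reshuffles and restarts after each pass. A minimal sketch (an assumption about the project's generator, not its actual code):

import random

def data_gen(samples, batch_size, infinite=False):
    while True:
        random.shuffle(samples)
        for i in range(0, len(samples), batch_size):
            yield samples[i:i + batch_size]
        if not infinite:
            return

gen = data_gen(list(range(10)), batch_size=3, infinite=True)
print(next(gen), next(gen))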
Example #8
    def get_cached_mask(self, i):
        # Lazily build, cache, and reuse a zero mask over all node types.
        if i not in self.cached_masks:
            self.cached_masks[i] = torch.zeros(Dataset.num_node_types(), dtype=t_float).to(DEVICE)
        return self.cached_masks[i]
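The same lazy-caching pattern in isolation: the zero mask for a key is built once on first request and returned by identity afterwards (sizes and device are made up):

import torch

cached_masks = {}

def get_cached_mask(i, num_types=10, device="cpu"):
    if i not in cached_masks:
        cached_masks[i] = torch.zeros(num_types).to(device)
    return cached_masks[i]

m = get_cached_mask(3)
assert get_cached_mask(3) is m   # second call hits the cache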
Example #9
    def forward(self, ll, states, gnn_graphs, node_embedding, sample_indices, init_edits, hist_choices, beam_size=1, pred_gt=False, loc_given=False):
        self.setup_forward(ll, states, gnn_graphs, node_embedding, sample_indices, hist_choices, init_edits)

        # Each ground-truth edit carries node_type, node_name, parent_id and
        # child_rank; predict the type of the new node first.
        target_types = [e.node_type for e in self.cur_edits] if pred_gt else None
        self.get_node_type(target_types, beam_size=beam_size)

        # select node name
        if pred_gt:
            # Teacher forcing, as in the replace-value op: map each
            # ground-truth value to an in-graph content node or a
            # vocabulary entry offset by the node count.
            target_name_indices = []
            for i, g in enumerate(self.sample_buf):
                if self.cur_edits[i].node_name in g.pg.contents:
                    target_name_idx = g.pg.contents[self.cur_edits[i].node_name].index
                elif self.cur_edits[i].node_name in Dataset.get_value_vocab():
                    target_name_idx = g.num_nodes + Dataset.get_value_vocab()[self.cur_edits[i].node_name]
                else:
                    raise NotImplementedError
                target_name_indices.append(target_name_idx)
        else:
            target_name_indices = None

        self.get_node(self.node_embedding,
                      fn_node_select=lambda bid, node: node.node_type == CONTENT_NODE_TYPE,
                      target_indices=target_name_indices,
                      const_node_embeds=self.const_name_embedding,
                      const_nodes=Dataset._id_value_map,
                      fn_const_node_pred=self.const_name_pred,
                      beam_size=beam_size)

        parent_nodes = [self.hist_choices[i][0] for i in range(len(self.hist_choices))]

        # Fall back to predicting the location if any edit lacks a child_rank.
        for e in self.cur_edits:
            if not hasattr(e, 'child_rank'):
                loc_given = False
                break

        # select left sibling: for child_rank 0 the parent itself plays the
        # role of the "previous" node; otherwise it is the (child_rank - 1)-th
        # child of the parent.
        if pred_gt or loc_given:
            target_prenode_ids = []
            for i, g in enumerate(self.sample_buf):
                if self.cur_edits[i].child_rank == 0:
                    target_prenode_ids.append(parent_nodes[i].index)
                else:
                    ch_idx = self.cur_edits[i].child_rank - 1
                    par_node = parent_nodes[i].ast_node
                    target_prenode_ids.append(par_node.children[ch_idx].index)
        else:
            target_prenode_ids = None

        self.get_node(self.node_embedding,
                      fn_node_select=lambda bid, node: self._sibling_filter(node, self.sample_buf[bid].pg.ast, parent_nodes[bid].index),
                      target_indices=target_prenode_ids,
                      beam_size=beam_size)

        # Apply the edit: insert the new node into a copy of each AST and
        # record the corresponding edit command.
        new_asts = []
        new_refs = []
        for i, g in enumerate(self.sample_buf):
            ast = deepcopy(g.pg.ast)

            parent_node, _, node_type, target_name, pre_node = self.hist_choices[i]
            if isinstance(target_name, CgNode):
                target_name = target_name.name
            if pre_node.index == parent_node.index:
                ch_rank = 0
            else:
                ch_rank = parent_node.ast_node.child_rank(pre_node.ast_node) + 1

            new_node = AstNode(node_type=node_type, value=target_name)
            ast.add_node(new_node)

            p_node = ast.nodes[parent_node.index]
            p_node.add_child(new_node, ch_rank)
            new_asts.append(ast)
            new_refs.append(g.pg.refs)

            if not target_name:
                target_name = "None"

            # keep only the base type in the textual command if the type name
            # carries a SEPARATOR-joined suffix
            if SEPARATOR in node_type:
                tmp_type = node_type.split(SEPARATOR)[0]
            else:
                tmp_type = node_type
            e = GraphEditCmd(SEPARATOR.join([OP_ADD_NODE, str(parent_node.index), str(ch_rank), tmp_type, target_name]))
            e.node_type = node_type
            ast.append_edit(e)

        return new_asts, self.ll, self.states, self.sample_indices, new_refs
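The child_rank arithmetic above reads: if the predicted "previous node" is the parent itself, the new node becomes the first child; otherwise it is inserted right after that sibling. A toy Node class makes the rule concrete (a hypothetical class, mirroring only the calls used above):

class Node:
    def __init__(self, index):
        self.index = index
        self.children = []
    def child_rank(self, ch):
        return self.children.index(ch)
    def add_child(self, ch, rank):
        self.children.insert(rank, ch)

parent, a, b = Node(0), Node(1), Node(2)
parent.add_child(a, 0)
parent.add_child(b, 1)

pre_node = a                       # predicted left sibling
ch_rank = 0 if pre_node.index == parent.index else parent.child_rank(pre_node) + 1
parent.add_child(Node(3), ch_rank)
print([c.index for c in parent.children])   # [1, 3, 2]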
Example #10
    def __init__(self, args, op_name):
        super(TypedGraphOp, self).__init__(args, op_name=op_name)
        # MLP head that scores every node type, paired with a learnable
        # embedding per node type over the same id space.
        self.node_type_pred = MLP(input_dim=args.latent_dim,
                                  hidden_dims=[args.latent_dim, Dataset.num_node_types()],
                                  nonlinearity=args.act_func)
        self.node_type_embedding = nn.Embedding(Dataset.num_node_types(), args.latent_dim)
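Pairing an nn.Embedding with a scoring head over the same id space, as above, lets the model both predict a type and embed the prediction. A runnable sketch with made-up sizes (nn.Sequential stands in for the project's MLP):

import torch
import torch.nn as nn

num_node_types, latent_dim = 20, 64
node_type_embedding = nn.Embedding(num_node_types, latent_dim)
node_type_pred = nn.Sequential(
    nn.Linear(latent_dim, latent_dim), nn.Tanh(),
    nn.Linear(latent_dim, num_node_types),
)

states = torch.randn(4, latent_dim)
logits = node_type_pred(states)        # (4, num_node_types) type scores
chosen = logits.argmax(dim=-1)         # predicted type ids
emb = node_type_embedding(chosen)      # (4, latent_dim) embeddings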