コード例 #1
0
    def train(self):
        """Train for ``self._epoch`` epochs, saving the model on improvement.

        Data is loaded and batched, the model/optimizer/scheduler are built,
        and every epoch runs one 'Train' pass followed by one 'Test' pass
        under ``torch.no_grad()``. Whenever the value returned by the test
        pass is lower than the best one seen so far, the model is saved.
        """
        train_set, test_set = self.load_data()
        train_batch, test_batch = self.process_data(train_set, test_set)
        model, optimizer, scheduler = self.init_model()

        # Very large sentinel: the first test result normally counts as an
        # improvement and triggers a save.
        best_pplx = sys.maxsize

        for epoch in range(1, self._epoch + 1):
            print_msg('Epoch %d' % epoch, 'INFO')

            # Training pass: the only one that receives optimizer/scheduler.
            train_opts = dict(optimizer=optimizer,
                              scheduler=scheduler,
                              mode='Train')
            self.run_epoch(model, train_batch, epoch, **train_opts)

            # Evaluation pass with gradient tracking disabled.
            with torch.no_grad():
                test_pplx = self.run_epoch(model, test_batch, epoch,
                                           mode='Test')

            if best_pplx > test_pplx:
                best_pplx = test_pplx
                self.save_model(model, epoch)
コード例 #2
0
    def gen_code(self, printer, model):
        """Grow a seed AST fragment by fragment and render it as code.

        Starting from the seed returned by ``prepare_seed``, the model is
        queried for its top-k candidate fragments and ``append_frag`` chooses
        one to attach. The loop ends once ``parent_idx`` becomes None.

        Args:
            printer: object whose ``ast2code(root)`` renders the final tree.
            model: object handed to ``prepare_seed`` and queried through
                ``model.run``.

        Returns:
            The value of ``printer.ast2code(root)`` on success, or None when
            the insertion budget ``self._max_ins`` is used up or
            ``append_frag`` reports that no candidate was accepted.
        """
        stack = []
        ins_cnt = 0
        (seed_name, root, model_input) = self.prepare_seed(model)
        frag, hidden, parent_idx, frag_type = model_input

        # None is a singleton, so test identity rather than equality.
        while parent_idx is not None:
            # Check max insertion condition
            if ins_cnt >= self._max_ins:
                return None
            ins_cnt += 1

            frag = data2tensor(frag)
            # Keep the unconverted type: frag_type is overwritten by its
            # tensor form on the next line, but append_frag gets this one.
            valid_type = frag_type
            parent_idx, frag_type = self.info2tensor(parent_idx, frag_type)
            outputs, hidden = model.run(frag, hidden, parent_idx, frag_type)

            _, cand_tensor = torch.topk(outputs[0][0], self._top_k)
            cand_list = cand_tensor.data.tolist()

            (found, frag, parent_idx,
             frag_type) = self.append_frag(cand_list, valid_type, root, stack)
            if not found:
                msg = 'Failed to select valid frag at %d' % ins_cnt
                print_msg(msg, 'WARN')
                return None

        harness_list = self._harness.get_list(seed_name)
        self.resolve_errors(root, harness_list)

        root = self.postprocess(root, harness_list)
        js_path = printer.ast2code(root)
        return js_path
コード例 #3
0
def construct_map(conf, def_dict):
    """Map each harness-provided definition name to the harness files defining it.

    Only harness keys that also appear in the ``<conf.data_dir>/seed``
    directory listing are considered. For each such file, every name defined
    by one of its harness files is recorded, unless the file itself is a key
    of ``def_dict`` and defines that name too.

    Args:
        conf: object providing ``data_dir`` and ``seed_dir``.
        def_dict: mapping from file name to the collection of names it defines.

    Returns:
        dict mapping a definition name to the set of harness files that
        define it.
    """
    no_err_path = os.path.join(conf.data_dir, 'seed')
    # A set keeps the membership test inside the loop O(1).
    no_err_files = set(os.listdir(no_err_path))

    harness = Harness(conf.seed_dir)
    harness_keys = harness.get_keys()
    num_files = len(harness_keys)

    id_harness_map = {}
    for idx, file_name in enumerate(harness_keys, start=1):
        msg = 'Harness [%d/%d] %s' % (idx, num_files, file_name)
        print_msg(msg, 'INFO')

        if file_name not in no_err_files:
            continue

        # Names the file defines itself; looked up once per file instead of
        # once per candidate definition. Empty when the file has no entry.
        own_defs = def_dict.get(file_name, ())
        for harness_file in harness.get_list(file_name):
            if harness_file not in def_dict:
                continue
            for def_name in def_dict[harness_file]:
                if def_name in own_defs:
                    continue
                id_harness_map.setdefault(def_name, set()).add(harness_file)
    return id_harness_map
コード例 #4
0
 def resolve_errors(self, root, harness_list):
     """Run identifier resolution on ``root``, logging a ResolveBug as a warning.

     ``hoisting`` is called first and its result is passed to ``resolve_id``
     together with ``harness_list``. A ``ResolveBug`` raised by either call
     is reported through ``print_msg`` and not re-raised.
     """
     try:
         # ID Resolve
         hoisted = hoisting(root, ([], []), True)
         resolve_id(root, None, hoisted, True, cand=[], hlist=harness_list)
     except ResolveBug as err:
         print_msg('Resolve Failed: {}'.format(err), 'WARN')
コード例 #5
0
ファイル: map.py プロジェクト: thuanpv/Montage
def build_id_map(conf):
    """Build the identifier map in two logged stages and write it out.

    Stage 1 parses the seeds and builds the definition dictionary; stage 2
    constructs the map from it and passes the result to ``write_map``.
    """
    print_msg('[1/2] Building def dictionary')
    def_dict = build_dict(parse_seed(conf))

    print_msg('[2/2] Building ID map')
    write_map(construct_map(conf, def_dict))
コード例 #6
0
def framentize(ast_path):
    """Load the AST at ``ast_path`` and run ``make_frags`` over it.

    Any exception raised while loading is logged as a warning and the
    function returns early. Nothing is returned: the fragment list filled by
    ``make_frags`` is local to this call.
    """
    try:
        _, ast = load_ast(ast_path)
    except Exception as err:
        print_msg(str(err), "WARN")
        return None

    make_frags(ast, [])
コード例 #7
0
def exec_fuzz(conf):
    """Start fuzzing, provided the identifier map file already exists.

    When ``fuzz/id_map.py`` is missing, an error is logged and the process
    exits with status 1.
    """
    if os.path.exists('fuzz/id_map.py'):
        # Imported only after the existence check above, as in the
        # original ordering.
        from fuzz.fuzz import fuzz
        fuzz(conf)
        return

    print_msg(
        'Please build a map for identifiers predefined in the harness files first.',
        'ERROR')
    sys.exit(1)
コード例 #8
0
def pool_map(pool, func, list, **args):
    """Apply ``func`` to every item of ``list`` through ``pool.map``.

    The keyword arguments in ``args`` are bound to ``func`` beforehand.
    On KeyboardInterrupt the pool is terminated and joined, and SIGKILL is
    then sent to the process group whose id equals this process's pid.

    Returns:
        Whatever ``pool.map`` returns.
    """
    try:
        bound_func = partial(func, **args)
        results = pool.map(bound_func, list)
    except KeyboardInterrupt:
        print_msg('Terminating workers ...', 'INFO')
        pool.terminate()
        pool.join()
        print_msg('Killed processes', 'INFO')
        os.killpg(os.getpid(), signal.SIGKILL)
    else:
        return results
コード例 #9
0
    def parse(self, seed_dir, ast_dir):
        """Run ``node utils/parse.js <seed_dir> <ast_dir>`` and wait for it.

        The number of entries returned by ``list_dir(seed_dir)`` is logged
        first. The child's stdout and stderr are captured and discarded.

        Args:
            seed_dir: first argument passed to the parser script.
            ast_dir: second argument passed to the parser script
                (presumably the AST output directory; not verifiable here).
        """
        js_list = list_dir(seed_dir)
        num_js = len(js_list)
        msg = 'Start parsing %d JS files' % (num_js)
        print_msg(msg, 'INFO')

        cmd = ['node', 'utils/parse.js', seed_dir, ast_dir]
        parser = Popen(cmd, cwd='./', stdin=PIPE, stdout=PIPE, stderr=PIPE)
        # communicate() drains stdout/stderr while waiting and closes stdin.
        # A bare wait() with PIPE-d output can deadlock once the child has
        # written enough to fill an OS pipe buffer.
        parser.communicate()
コード例 #10
0
ファイル: fragmentize.py プロジェクト: thuanpv/Montage
def fragmentize(ast_path):
    """Load one AST file and collect what ``make_frags`` extracts from it.

    Returns:
        ``(file_name, frag_seq, frag_info_seq, node_types)``, where the two
        sequences are lists and ``node_types`` is a set, all filled in place
        by ``make_frags``. Returns None, after logging a warning, when
        loading the AST raises.
    """
    try:
        file_name, ast = load_ast(ast_path)
    except Exception as err:
        print_msg(str(err), 'WARN')
        return None

    frags = []
    frag_infos = []
    seen_types = set()
    # The last argument is a fresh working stack for make_frags.
    make_frags(ast, frags, frag_infos, seen_types, [])
    return file_name, frags, frag_infos, seen_types
コード例 #11
0
    def strip(self, ast_path):
        """Rewrite the AST stored at ``ast_path`` and persist it if it changed.

        A deep copy of the AST is taken before ``self.rewrite`` runs;
        ``self.rewrite_ast`` is called only when the AST no longer compares
        equal to that copy. A failure while loading or copying is logged as
        a warning and ends the call.
        """
        single_parser = SingleParser()
        try:
            _, ast = load_ast(ast_path)
            pristine = deepcopy(ast)
        except Exception as err:
            print_msg(str(err), 'WARN')
            return None

        self.rewrite(ast, single_parser)
        changed = pristine != ast
        if changed:
            ast_path = self.rewrite_ast(ast, ast_path)
コード例 #12
0
ファイル: normalize.py プロジェクト: hao-wang/vae_montage
def normalize(ast_path):
    """Normalize the identifiers of the AST at ``ast_path`` and write it back.

    The AST is loaded, ``collect_id`` fills an identifier dictionary (with
    counters starting at 0 under the keys 'v' and 'f'), ``normalize_id``
    applies it, and the AST is written to the same path. A load failure is
    logged as a warning and nothing is written.
    """
    try:
        js_name, ast = load_ast(ast_path)
    except Exception as err:
        print_msg(str(err), 'WARN')
        return None

    id_dict, id_cnt = {}, {'v': 0, 'f': 0}
    collect_id(ast, id_dict, id_cnt)
    normalize_id(ast, id_dict)
    write_ast(ast_path, ast)
コード例 #13
0
ファイル: map.py プロジェクト: thuanpv/Montage
def build_dict(ast_dir):
    """Build a mapping from trimmed seed name to a set filled by ``build_def_dict``.

    Every entry returned by ``list_dir(ast_dir)`` is loaded with
    ``load_ast``; progress is logged per entry. Entries that share a trimmed
    name share one set.
    """
    ast_paths = list_dir(ast_dir)
    total = len(ast_paths)

    def_dict = {}
    for position, ast_path in enumerate(ast_paths, start=1):
        print_msg('[%d/%d] %s' % (position, total, ast_path), 'INFO')
        raw_name, ast = load_ast(ast_path)
        seed_name = trim_seed_name(raw_name)
        build_def_dict(ast, def_dict.setdefault(seed_name, set()))
    return def_dict
コード例 #14
0
ファイル: execute.py プロジェクト: thuanpv/Montage
def main(pool, conf):
    """Execute every small-enough ``.js`` seed through ``exec_eng`` on the pool.

    The log directory is prepared with ``make_dir`` first. Only entries of
    ``list_dir(conf.seed_dir)`` that end in '.js' and are smaller than
    30 * 1024 bytes are executed.
    """
    make_dir(conf.log_dir)

    # Size cut-off in bytes (30 KiB); larger files are skipped.
    js_list = [
        js for js in list_dir(conf.seed_dir)
        if js.endswith('.js') and os.path.getsize(js) < 30 * 1024
    ]

    print_msg('Start executing %d JS files' % len(js_list), 'INFO')
    pool_map(pool, exec_eng, js_list, conf=conf)
コード例 #15
0
    def parse(self, seed_dir, ast_dir):
        """Run ``node utils/parse.js <seed_dir> <ast_dir>`` and wait for it.

        The number of entries returned by ``list_dir(seed_dir)`` is logged
        and the command line is printed before the child is started. The
        child's stdout and stderr are captured and discarded.

        Args:
            seed_dir: first argument passed to the parser script.
            ast_dir: second argument passed to the parser script
                (presumably the AST output directory; not verifiable here).
        """
        js_list = list_dir(seed_dir)
        num_js = len(js_list)
        msg = 'Start parsing %d JS files' % (num_js)
        print_msg(msg, 'INFO')

        cmd = ['node', 'utils/parse.js', seed_dir, ast_dir]
        print(cmd)
        parser = Popen(cmd, cwd='./', stdin=PIPE, stdout=PIPE, stderr=PIPE)
        # communicate() drains stdout/stderr while waiting and closes stdin.
        # A bare wait() with PIPE-d output can deadlock once the child has
        # written enough to fill an OS pipe buffer.
        # Parser errors stay silent by default; when debugging, keep the
        # (stdout, stderr) pair returned here and print its second element.
        parser.communicate()
コード例 #16
0
    def preprocess(self):
        """Run preprocessing stages 1 to 5, logging a banner before each one.

        The banners count eight stages, but only the first five are performed
        in this body.
        """
        stages = (
            ('[1/8] Filtering out JS with errors', 'remove_js_with_errors'),
            ('[2/8] Parsing JS code into ASTs', 'parse'),
            ('[3/8] Stripping args of eval func calls', 'strip_eval'),
            ('[4/8] Normalizing identifiers', 'normalize_ast'),
        )
        for banner, method_name in stages:
            print_msg(banner)
            # Look the method up just before calling it, as a direct
            # ``self.<method>()`` call would.
            getattr(self, method_name)()

        print_msg('[5/8] Fragmentizing JS ASTs')
        # The result is not used further within this body.
        ast_data = self.fragment_ast()
コード例 #17
0
 def exec_eng(self, js_path):
     """Run the engine on ``js_path``; keep and log the file only on a crash.

     The engine is started in ``self._seed_dir`` with a timer that calls
     ``kill_proc`` on the process after ``self._timeout`` seconds. If the
     return code is -4 or -11, a comma-separated record of the command,
     the path and the return code is written to ``self._crash_log``.
     Otherwise ``js_path`` is deleted.
     """
     base_cmd = [self._eng_path] + self._opt
     proc = Popen(base_cmd + [js_path], cwd=self._seed_dir,
                  stdout=PIPE, stderr=PIPE)

     watchdog = threading.Timer(self._timeout, kill_proc, [proc])
     watchdog.start()
     proc.communicate()
     watchdog.cancel()

     # A negative return code -N means the child was ended by signal N;
     # 4 and 11 are SIGILL and SIGSEGV on Linux.
     if proc.returncode not in (-4, -11):
         os.remove(js_path)
         return

     fields = base_cmd + [js_path, str(proc.returncode)]
     # The crash log is written as bytes.
     self._crash_log.write((','.join(fields) + '\n').encode())
     print_msg('Found a bug (%s)' % js_path, 'INFO')
コード例 #18
0
def make_dir(dir_path):
    """Create ``dir_path``, asking before wiping it when it already exists.

    If the path exists the user is prompted: 'y' removes and recreates the
    directory, while any other answer terminates the process via
    ``os._exit(1)`` (after logging an error when the answer is not 'n').

    Returns:
        ``dir_path``, unchanged.
    """
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
        return dir_path

    prompt = get_msg('Do you want to delete %s? (y/n): ' % (dir_path), 'WARN')
    ans = input(prompt)
    if ans != 'y':
        if ans != 'n':
            print_msg('Wrong Answer', 'ERROR')
        os._exit(1)

    shutil.rmtree(dir_path)
    os.makedirs(dir_path)
    return dir_path
コード例 #19
0
ファイル: oov.py プロジェクト: hao-wang/vae_montage
def update_ast(seed_dict, frag_list, hash_frag_list, new_frag_dict):
    """Return a copy of ``seed_dict`` with both sequences of every entry updated.

    Each value of ``seed_dict`` is a ``(frag_seq, frag_info_seq)`` pair. The
    first element is passed through ``update_frag_seq`` and the second
    through ``update_frag_info``. Progress is logged per file.
    """
    total = len(seed_dict)
    updated = {}
    for position, (file_name, sequences) in enumerate(seed_dict.items(),
                                                      start=1):
        print_msg('[%d/%d] %s' % (position, total, file_name), 'INFO')

        frag_seq, frag_info_seq = sequences
        updated[file_name] = (
            # Update frag_seq
            update_frag_seq(frag_seq, new_frag_dict, frag_list,
                            hash_frag_list),
            # Update frag_info_seq
            update_frag_info(frag_info_seq, new_frag_dict, frag_list,
                             hash_frag_list),
        )
    return updated
コード例 #20
0
def make_dir(dir_path):
    """Create ``dir_path``, wiping and recreating it when it already exists.

    The confirmation prompt is currently bypassed: the answer is pinned to
    'y', so an existing directory is always removed and recreated.

    Returns:
        ``dir_path``, unchanged.
    """
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
        return dir_path

    msg = 'Do you want to delete %s? (y/n): ' % (dir_path)
    msg = get_msg(msg, 'WARN')
    # The interactive prompt is disabled because, per the original author,
    # input() hangs under VS Code's Jupyter mode. Replace the pinned answer
    # below with ``input(msg)`` to restore the prompt.
    ans = 'y'
    if ans != 'y':
        if ans != 'n':
            print_msg('Wrong Answer', 'ERROR')
        os._exit(1)

    shutil.rmtree(dir_path)
    os.makedirs(dir_path)
    return dir_path
コード例 #21
0
ファイル: oov.py プロジェクト: hao-wang/vae_montage
def sort_frags(seed_dict, oov_threshold=5):
    """Rank fragments by frequency and find where the rare ones begin.

    Args:
        seed_dict: mapping whose values are 2-item sequences; the first item
            is an iterable of hashable fragments.
        oov_threshold: highest frequency still treated as rare. The default
            of 5 reproduces the previously hard-coded cut-off.

    Returns:
        ``(sorted_frags, oov_idx)``. ``sorted_frags`` is a tuple of distinct
        fragments in descending frequency order (ties keep first-seen
        order). ``oov_idx`` is the index of the first fragment whose
        frequency is at most ``oov_threshold``, or ``len(sorted_frags)``
        when there is none. Empty input yields ``((), 0)``.
    """
    # Count fragments
    counter = collections.Counter()
    for frag_seq, _ in seed_dict.values():
        counter.update(frag_seq)

    # Sort by frequencies (most_common() sorts by count, descending)
    ranked = counter.most_common()
    sorted_frags = tuple(frag for frag, _ in ranked)

    # Searching for "<= threshold" instead of an exact count avoids the
    # ValueError that frq.index(5) raised whenever no fragment occurred
    # exactly five times. When such a fragment exists the index is the same.
    oov_idx = next(
        (idx for idx, (_, count) in enumerate(ranked)
         if count <= oov_threshold),
        len(ranked))

    msg = 'OOV IDX = %d' % oov_idx
    print_msg(msg, 'INFO')
    return sorted_frags, oov_idx
コード例 #22
0
def main():
    """Entry point: check for CUDA, load the config and run the chosen action.

    The process exits with status 1 when CUDA is unavailable. ``args.opt``
    selects one of 'preprocess', 'train', 'fuzz' or 'build_map'; any other
    value does nothing.
    """
    if not torch.cuda.is_available():
        print_msg('Montage only supports CUDA-enabled machines', 'ERROR')
        sys.exit(1)

    # Increase max recursion depth limit
    sys.setrecursionlimit(10000)

    args = get_args()
    conf = Config(args.config)

    actions = {
        'preprocess': exec_preprocess,
        'train': exec_train,
        'fuzz': exec_fuzz,
        'build_map': build_map,
    }
    action = actions.get(args.opt)
    if action is not None:
        action(conf)
コード例 #23
0
def build_dict(ast_dir):
    '''
    Load every AST listed under ``ast_dir`` and collect, per trimmed seed
    name, the set that ``build_def_dict`` fills from that AST.

    :param ast_dir: directory handed to ``list_dir`` to enumerate the ASTs.
    :return: dict mapping trimmed seed name to a set; entries that share a
        trimmed name share one set.
    '''
    ast_paths = list_dir(ast_dir)
    total = len(ast_paths)

    def_dict = {}
    for position, ast_path in enumerate(ast_paths, start=1):
        print_msg('[%d/%d] %s' % (position, total, ast_path), 'INFO')

        raw_name, ast = load_ast(ast_path)
        seed_name = trim_seed_name(raw_name)

        defs = def_dict.get(seed_name)
        if defs is None:
            defs = def_dict[seed_name] = set()
        build_def_dict(ast, defs)
    return def_dict
コード例 #24
0
  def print_metrics(self, mode, epoch,
                    total_loss, pplx, total_diff, acc):
    """Log loss, perplexity, top-k difference and accuracy for one epoch.

    Each metric is written as its own INFO line, prefixed with ``mode`` and
    the epoch number. The last line ends with an extra newline.
    """
    report = (
        ('%s Loss at Epoch %d = %f', total_loss),
        ('%s Perplexity at Epoch %d = %f', pplx),
        ('%s Top K Difference at Epoch %d = %f', total_diff),
        ('%s Accuracy at Epoch %d = %f\n', acc),
    )
    for template, value in report:
      print_msg(template % (mode, epoch, value), 'INFO')
コード例 #25
0
def main(pool, conf):
    """
    Prepare the log directory, then run ``exec_eng`` on the pool for every
    seed file that ends in '.js' and is smaller than 30 * 1024 bytes.

    :param pool: worker pool handed to ``pool_map``.
    :param conf: configuration object; ``log_dir`` and ``seed_dir`` are read
        here and the object itself is forwarded to ``exec_eng``.
    :return: None
    """
    make_dir(conf.log_dir)

    def is_small_js(path):
        # Size cut-off in bytes (30 KiB); larger files are skipped.
        return path.endswith('.js') and os.path.getsize(path) < 30 * 1024

    js_list = [js for js in list_dir(conf.seed_dir) if is_small_js(js)]

    print_msg('Start executing %d JS files' % len(js_list), 'INFO')
    pool_map(pool, exec_eng, js_list, conf=conf)
コード例 #26
0
def fragmentize(ast_path):
    """
    Load one AST file and collect what ``make_frags`` extracts from it.

    :param ast_path: path handed to ``load_ast``.
    :return: None (after logging a warning) when loading raises, otherwise
        a tuple of
        file_name: name returned by ``load_ast``
        frag_seq: list filled in place by ``make_frags``
        frag_info_seq: list filled in place by ``make_frags``
        node_types: set filled in place by ``make_frags``
    """
    try:
        file_name, ast = load_ast(ast_path)
    except Exception as err:
        print_msg(str(err), 'WARN')
        return None

    frags = []
    frag_infos = []
    seen_types = set()
    # The last argument is a fresh working stack for make_frags.
    make_frags(ast, frags, frag_infos, seen_types, [])
    return file_name, frags, frag_infos, seen_types
コード例 #27
0
ファイル: preprocess.py プロジェクト: thuanpv/Montage
    def preprocess(self):
        """Run the eight preprocessing stages in order, logging each banner.

        Stage 5 produces the fragment data consumed by stage 6. The worker
        pool is terminated right after stage 6.
        """
        def run_stages(stages):
            for banner, method_name in stages:
                print_msg(banner)
                # Look the method up just before calling it, as a direct
                # ``self.<method>()`` call would.
                getattr(self, method_name)()

        run_stages((
            ('[1/8] Filtering out JS with errors', 'remove_js_with_errors'),
            ('[2/8] Parsing JS code into ASTs', 'parse'),
            ('[3/8] Stripping args of eval func calls', 'strip_eval'),
            ('[4/8] Normalizing identifiers', 'normalize_ast'),
        ))

        print_msg('[5/8] Fragmentizing JS ASTs')
        fragments = self.fragment_ast()

        print_msg('[6/8] Aggregating fragments')
        self.aggregate_frags(fragments)
        self._pool.terminate()

        run_stages((
            ('[7/8] Replacing uncommon fragments', 'mark_oov'),
            ('[8/8] Writing data into files', 'write_data'),
        ))
コード例 #28
0
  def print_config(self):
    """Log the vocabulary size and the training hyperparameters.

    One INFO line is written per value: the length of
    ``self._oov_frag_list``, then the embedding size, initial learning rate,
    decay factor, momentum and weight decay. The last line ends with an
    extra newline.
    """
    vocab_size = len(self._oov_frag_list)
    print_msg('# of Vocabularies: %d' % vocab_size, 'INFO')

    settings = (
        ('Embedding Size = %d', '_emb_size'),
        ('Initial LR = %f', '_lr'),
        ('LR Decay Factor = %f', '_gamma'),
        ('Momentum = %f', '_momentum'),
        ('L2 Regularization Penalty = %f\n', '_weight_decay'),
    )
    for template, attr_name in settings:
      # Read each attribute only when its line is about to be logged.
      print_msg(template % getattr(self, attr_name), 'INFO')