Example #1
0
def pool_compile_and_save(full_original_code):
    # Compile every source string in parallel and collect per-sample results.
    compile_pool = get_compile_pool()
    compile_result_list = list(
        compile_pool.starmap(compile_and_read_error_info,
                             [[i] for i in full_original_code]))
    # compile_result_list = list(itertools.starmap(compile_and_read_error_info, [[i] for i in full_original_code]))
    # Split the (success, compiler_output) pairs into two parallel lists.
    compile_res_list, compile_info_list = list(zip(*compile_result_list))
    error_list = [extract_error_message(info) for info in compile_info_list]
    return compile_res_list, compile_info_list, error_list
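# Hypothetical stand-ins for the two helpers used above, sketched only so the example is
# self-contained; the signatures are inferred from the call sites in examples #1 and #2,
# not taken from the real project code.
import subprocess
import tempfile


def compile_and_read_error_info(code, file_path='', target_file_path='main.out',
                                log_file_path='main.log'):
    # Write the source to a temporary .c file and invoke gcc on it; the extra path
    # arguments are accepted for call-site compatibility but mostly ignored here.
    with tempfile.NamedTemporaryFile('w', suffix='.c', delete=False) as f:
        f.write(code)
        src_path = f.name
    proc = subprocess.run(['gcc', src_path, '-o', target_file_path],
                          capture_output=True, text=True)
    # Return (compiled_ok, raw compiler output), the pair shape unpacked by the callers.
    return proc.returncode == 0, proc.stderr


def extract_error_message(info):
    # Keep only the 'error:' diagnostic lines from the compiler output.
    return [line for line in info.splitlines() if 'error:' in line]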
Example #2
0
def compile_code_ids_list(final_output, continue_list, result_list, vocabulary, includes_list, file_path='',
                          target_file_path='main.out', log_file_path='main.log', do_compile_pool=True, need_transform=True):
    # Compile every still-active sample in the batch and return, per sample, whether to
    # keep iterating, the latest compile result, and the number of compiler errors.
    compile_pool = get_compile_pool()
    batch_size = len(final_output)
    cur_continue = [True for _ in range(batch_size)]
    cur_result_list = [False for _ in range(batch_size)]
    compile_args_list = []
    # Maps each compile job back to its index in the batch (a list, despite the name).
    code_index_dict = []

    count_i = 0
    for code_list, con, includes in zip(final_output, continue_list, includes_list):
        if not con:
            # Sample already finished: carry its previous result forward and skip compilation.
            cur_continue[count_i] = False
            res = result_list[count_i]
            cur_result_list[count_i] = res
            count_i += 1
            continue
        if need_transform:
            code_list = [vocabulary.id_to_word(c) for c in code_list]
        code = ' '.join(code_list)
        # Prepend the include directives to the source before compiling.
        for inc in includes:
            code = inc + '\n' + code
        compile_args_list += [(code, file_path, target_file_path, log_file_path)]
        code_index_dict += [count_i]
        count_i += 1
    if do_compile_pool:
        # part_res_list = list(compile_pool.starmap(compile_c_code_by_gcc, compile_args_list))
        part_res_list = list(compile_pool.starmap(compile_and_read_error_info, compile_args_list))
    else:
        # part_res_list = map(compile_c_code_by_gcc_one_arg, compile_args_list)
        part_res_list = map(compile_and_read_error_info_one_arg, compile_args_list)

    # Entries stay -1 for samples that were skipped above.
    error_count_list = [-1 for _ in range(batch_size)]
    for i, (res, msg) in enumerate(part_res_list):
        error_list = extract_error_message(msg)
        # Map the job result back to its position in the original batch.
        act_i = code_index_dict[i]
        cur_result_list[act_i] = res
        # Keep iterating only on samples that still fail to compile.
        cur_continue[act_i] = not res
        error_count_list[act_i] = len(error_list)

    return cur_continue, cur_result_list, error_count_list
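# A hedged usage sketch of compile_code_ids_list, mirroring how it is called from
# example #3; the batch contents and paths below are made-up placeholders.
example_batch = [['int', 'main', '(', ')', '{', 'return', '0', ';', '}']]
example_continue = [True]
example_result = [False]
example_includes = [['#include <stdio.h>']]

cur_continue, cur_result_list, error_count_list = compile_code_ids_list(
    example_batch, example_continue, example_result,
    vocabulary=None,            # not needed when need_transform=False
    includes_list=example_includes,
    file_path='tmp/main.c',
    target_file_path='tmp/main.out',
    log_file_path='tmp/main.log',
    do_compile_pool=False,      # run serially via the *_one_arg wrapper instead of the pool
    need_transform=False)       # inputs are already token names, not vocabulary ids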
Example #3
0
def save_addition_data(original_states, states, tokenize_fn, batch_size, file_path, target_file_path, vocabulary=None,
                       max_distance=None, only_error=False, save_list=None):
    # Collect (ac_code, edit actions, includes, distance, id) records produced during
    # sampling so they can be appended to the training data.
    from common.reinforcement_generate_util import generate_action_between_two_code
    save_data_dict = {'ac_code': [], 'action_character_list': [], 'includes': [],
                      'error_count': [], 'distance': [], 'id': []}

    ac_code_names_list = original_states['input_seq_name']
    # Drop the begin/end marker tokens from each error code sequence.
    error_code_ids_list = [c[1:l - 1] for c, l in zip(states['input_seq'], states['copy_length'])]
    error_code_names_list = states['input_seq_name']

    # Leftover debug hook: lets a breakpoint inspect token ids above 5941; it has no functional effect.
    for ids, _ in zip(error_code_ids_list, states['copy_length']):
        for p in ids:
            if p > 5941:
                a = 1

    error_code_names_list = retokenize_error_code(error_code_names_list, tokenize_fn)

    do_compile_check = True
    if do_compile_check:
        # Compile the ac code and the error code in a single batch, then split the results below.
        compile_list = ac_code_names_list + error_code_names_list
        continue_list = [True for _ in range(len(compile_list))]
        last_res_list = [False for _ in range(len(compile_list))]
        include_list = original_states['includes'] + original_states['includes']

        _, compile_res_list, _ = compile_code_ids_list(
            compile_list, continue_list, last_res_list, vocabulary=vocabulary,
            includes_list=include_list, file_path=file_path,
            target_file_path=target_file_path, do_compile_pool=True,
            need_transform=False)
        ac_res_list = compile_res_list[:len(ac_code_names_list)]
        error_res_list = compile_res_list[len(ac_code_names_list):]
    else:
        ac_res_list = [True for _ in range(len(ac_code_names_list))]
        error_res_list = [False for _ in range(len(ac_code_names_list))]

    pool = get_compile_pool()
    max_distance_list = [None for _ in range(batch_size)]
    # Compute, in parallel, the edit distance and edit actions between each error code and its ac code.
    generate_args = list(zip(error_code_names_list, ac_code_names_list, max_distance_list))
    generate_result = list(pool.starmap(generate_action_between_two_code, generate_args))
    # generate_result = list(itertools.starmap(generate_action_between_two_code, generate_args))
    distance_list, action_list = list(zip(*generate_result))

    # Default save_list before it is used below.
    if save_list is None:
        save_list = [True for _ in range(len(ac_code_names_list))]

    print_save_data = False
    if print_save_data:
        for i in range(batch_size):
            info('--------------------------- in save data {} batch ------------------------------------'.format(i))
            ac_full_code = ' '.join(ac_code_names_list[i])
            error_full_code = ' '.join(error_code_names_list[i])
            actions = action_list[i]
            dis = distance_list[i]
            info('ac_code : {}'.format(ac_full_code))
            info('err_code: {}'.format(error_full_code))
            info('dis: {}'.format(dis))
            info('actions: {}'.format(str(actions)))
            info('effect batch: {}'.format(save_list[i]))

    for ac_code_list, inc, prog_id, ac_res, err_res, actions, dis, sav \
            in zip(ac_code_names_list, original_states['includes'], original_states['id'],
                   ac_res_list, error_res_list, action_list, distance_list, save_list):
        if not sav:
            continue

        if dis < 0:
            continue
        if max_distance is not None and dis > max_distance:
            continue

        if only_error and err_res:
            continue

        # if 0 > dis or dis >= max_generate_distance:
        #     continue

        ac_code = ' '.join(ac_code_list)
        if len(actions) == 0:
            actions = create_random_action(ac_code_list)
        save_data_dict['ac_code'] += [ac_code]
        save_data_dict['action_character_list'] += [actions]
        save_data_dict['includes'] += [inc]
        save_data_dict['error_count'] += [dis]
        save_data_dict['distance'] += [dis]
        save_data_dict['id'] += [prog_id]
    return save_data_dict
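# save_addition_data returns a column-oriented dict of parallel lists. A plausible (assumed,
# not part of the project) downstream step is to flatten it into a DataFrame for persistence:
import os
import pandas as pd


def append_save_data(save_data_dict, out_path='addition_data.csv'):
    # Parallel lists keyed by column name map directly onto DataFrame columns.
    df = pd.DataFrame(save_data_dict)
    # Append rows, writing the header only when the file does not exist yet.
    df.to_csv(out_path, mode='a', index=False, header=not os.path.exists(out_path))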
    def step(self, actions, states, states_tensor, file_path,
             target_file_path):
        # One environment step: turn the sampled actions into new error code, run the solver
        # model on it, and derive rewards and done flags from the evaluation result.
        with torch.no_grad():
            # calculate p1 and p2 in code without label
            p1 = (actions[0] - 1).tolist()
            p2 = (actions[1] - 1).tolist()
            ac_action_pos = list(zip(p1, p2))

            # create error code by generate output
            # for i in range(len(self.step_action_list)):
            ori_states = states.copy()
            batch_data, output_ids, effect_sample_output_list_length = \
                self.preprocess_next_input_for_solver_fn(states, states_tensor, actions)

            # Restore the original records for samples that are no longer active.
            for k in batch_data.keys():
                batch_data[k] = [
                    b if c else s for s, b, c in zip(
                        ori_states[k], batch_data[k], self.continue_list)
                ]
            ori_error_data = batch_data.copy()

            # a = 1
            # for i, c in enumerate(self.continue_list):
            #     if not c:
            #         for k in batch_data.keys():
            #             batch_data[k][i] = ori_states[k][i]

            pool = get_compile_pool()
            batch_size = len(output_ids)

            # generate action between ac code and error code
            max_distance_list = [None for _ in range(batch_size)]
            cur_error_code_ids = batch_data['input_seq']
            cur_error_code_names = batch_data['input_seq_name']
            ac_code_ids = self.ac_batch_data['input_seq']
            ac_code_names = self.ac_batch_data['input_seq_name']
            # generate_args = list(zip(cur_error_code_ids, ac_code_ids, max_distance_list))
            # generate_result_list = list(pool.starmap(generate_action_between_two_code, generate_args))
            # action_list, distance_list = list(zip(*generate_result_list))

            # ac_code_ids = self.ac_batch_data['input_seq']
            prog_id_list = self.ac_batch_data['id']
            includes_list = self.ac_batch_data['includes']
            keyword_ids_list = [self.keyword_ids for _ in range(batch_size)]
            inner_begin_label_list = [
                self.inner_begin_label for _ in range(batch_size)
            ]
            inner_end_label_list = [
                self.inner_end_label for _ in range(batch_size)
            ]
            use_ast_list = [self.use_ast for _ in range(batch_size)]
            vocabulary_list = [self.vocabulary for _ in range(batch_size)]
            # generate_target_by_code_and_actions(ac_code_ids, action_list, prog_id_list, includes_list, self.keyword_ids,
            #                                     self.inner_begin_label, self.inner_end_label)

            generate_args = list(
                zip(cur_error_code_names, ac_code_names, max_distance_list,
                    prog_id_list, includes_list, keyword_ids_list,
                    inner_begin_label_list, inner_end_label_list,
                    self.continue_list, self.last_sample, use_ast_list,
                    vocabulary_list))
            # generate_args = [list(args) for args in generate_args]
            generate_result_list = list(
                pool.starmap(
                    generate_ac_to_error_action_and_create_input_and_target,
                    generate_args))
            # generate_result_list = list(itertools.starmap(generate_ac_to_error_action_and_create_input_and_target, generate_args))
            self.last_sample = generate_result_list
            # Per-sample success flag, AND-ed with each solver iteration's evaluation below.
            result = torch.ones(batch_size).byte().to(actions[0].device)
            for one_iterate_sample in get_one_step_of_sample(
                    generate_result_list):
                model_input = self.parse_input_batch_data_for_solver_fn(
                    one_iterate_sample, do_sample=False)
                model_output = self.s_model.forward(*model_input,
                                                    do_sample=False)
                output_records_list = self.create_records_all_output_fn(
                    model_input, model_output, do_sample=False)
                model_target = self.parse_target_batch_data_fn(
                    one_iterate_sample)
                result_list = self.evaluate_output_result_fn(
                    output_records_list, model_target, one_iterate_sample)
                result = result_list & result

            # generate_action_between_two_code(batch_data, self.ac_batch_data, max_distance=0)
            # for i in range(len(self.step_action_list)):
            #     model_input = self.parse_input_batch_data_for_solver_fn(batch_data, do_sample=True)
            #     model_output = self.s_model.forward(*model_input, do_sample=True)
            #     input_data, final_output, output_records = self.solver_create_next_input_batch_fn(batch_data, model_input, model_output, self.continue_list)
            #
            #     _, self.result_list = self.compile_code_ids_fn(final_output, self.continue_list, self.result_list,
            #                                          vocabulary=self.vocabulary,
            #                                          includes_list=self.extract_includes_fn(input_data),
            #                                          file_path=file_path,
            #                                          target_file_path=target_file_path)

            print_output = False
            global count
            count += 1
            if print_output and count % 10 == 0:
                # NOTE: final_output comes from the commented-out solver loop above, so this
                # debug block assumes that loop has been re-enabled.
                k = 0
                for ori_code_id, ori_error_id, fin_code_id, res in zip(
                        ori_states['input_seq'], ori_error_data['input_seq'],
                        final_output, self.result_list):
                    if not res:
                        ori_code_id = ori_code_id[1:-1]
                        ori_error_id = ori_error_id[1:-1]

                        ori_code_list = [
                            self.vocabulary.id_to_word(c) for c in ori_code_id
                        ]
                        ori_code = ' '.join(ori_code_list)

                        ori_error_list = [
                            self.vocabulary.id_to_word(c) for c in ori_error_id
                        ]
                        ori_error_code = ' '.join(ori_error_list)

                        fin_code_list = [
                            self.vocabulary.id_to_word(c) for c in fin_code_id
                        ]
                        fin_code = ' '.join(fin_code_list)

                        info(
                            '--------------------------- one ------------------------------------'
                        )
                        for a in actions:
                            info(str(a[k]))
                        info('ori_code: ' + ori_code)
                        info('err_code: ' + ori_error_code)
                        info('fin_code: ' + fin_code)
                    k += 1

            reward_list, done_list = self.create_reward_by_compile_fn(
                result, states, actions, self.continue_list)
            self.continue_list = [not done for done in done_list]

            save_list = [reward > 0 for reward in reward_list]
            # done_list = [False for _ in range(len(reward_list))]
        return ori_error_data, reward_list, done_list, {
            'save_list': save_list,
            'ac_action_pos': ac_action_pos,
            'effect_sample_output_list_length':
            effect_sample_output_list_length
        }
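# The step() method above follows a gym-like contract: it returns the restored batch data,
# per-sample rewards, done flags, and an info dict. Below is a minimal, assumed rollout
# driver over that contract; the agent object and its sample_actions/to_tensor helpers are
# hypothetical stand-ins, not part of the project.
def rollout(env, agent, states, states_tensor, file_path, target_file_path, max_steps=10):
    trajectory = []
    for _ in range(max_steps):
        # Sample edit actions from the agent and advance the environment one step.
        actions = agent.sample_actions(states_tensor)
        states, reward_list, done_list, extra = env.step(
            actions, states, states_tensor, file_path, target_file_path)
        trajectory.append((states, reward_list, done_list, extra['save_list']))
        # Stop once every sample in the batch is done.
        if all(done_list):
            break
        states_tensor = agent.to_tensor(states)
    return trajectory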