Example #1
0
import sys  # needed for sys.exit below; utils and the sub_*_dir config are assumed

def nt_seq_to_int(time_steps=50, status='TRAIN'):
    # Further process the NT sequences: first convert every token to an integer.
    # For train and valid data, extend all ast-seqs into one flat list for easy
    # format conversion at training time; for test data, append each ast-seq so
    # every AST keeps its own independent sequence.
    tt_token_to_int, tt_int_to_token, nt_token_to_int, nt_int_to_token = utils.load_dict_parameter(is_lower=False)
    total_num_nt_pair = 0
    if status == 'TRAIN':
        sub_data_dir = sub_train_data_dir
        num_sub_data = num_sub_train_data
        sub_int_data_dir = sub_int_train_dir
    elif status == 'VALID':
        sub_data_dir = sub_valid_data_dir
        num_sub_data = num_sub_valid_data
        sub_int_data_dir = sub_int_valid_dir
    elif status == 'TEST':
        sub_data_dir = sub_test_data_dir
        num_sub_data = num_sub_test_data
        sub_int_data_dir = sub_int_test_dir
    else:
        print('ERROR! Unknown command!')
        sys.exit(1)

    def get_subset_data():  # read each part's nt_sequence and yield it for processing
        for i in range(1, num_sub_data + 1):
            data_path = sub_data_dir + 'part{}.json'.format(i)
            data = utils.pickle_load(data_path)
            yield (i, data)

    subset_generator = get_subset_data()
    for index, data in subset_generator:
        data_seq = []
        for one_ast in data:  # encode each nt_seq to integers, then save
            if len(one_ast) < time_steps:  # AST shorter than time_steps; discard it
                continue
            try:
                nt_int_seq = [(nt_token_to_int[n], tt_token_to_int.get(
                    t, tt_token_to_int[unknown_token])) for n, t in one_ast]
            except KeyError:
                print('key error')
                continue
            # In train and valid, all ast-seqs are extended into one flat list;
            # in test, each ast-seq stays independent.
            if status == 'TEST':
                data_seq.append(nt_int_seq)
            else:
                data_seq.extend(nt_int_seq)
            total_num_nt_pair += len(nt_int_seq)

        one_sub_int_data_dir = sub_int_data_dir + 'int_part{}.json'.format(index)
        utils.pickle_save(one_sub_int_data_dir, data_seq)
    # old: 14,976,250  new: 157,237,460  training dataset size comparison
    # old:  1,557,285  new:  81,078,099  test dataset size comparison
    print('There are {} nt_pair in {} dataset...'.format(total_num_nt_pair, status))
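
For reference, a minimal self-contained sketch of the encoding step above, using toy vocabularies and a hypothetical 'UNK' token (the real dictionaries come from utils.load_dict_parameter):

# Toy vocabularies standing in for the loaded dictionaries.
nt_token_to_int = {'Program': 0, 'ExpressionStatement': 1}
tt_token_to_int = {'foo': 0, 'bar': 1, 'UNK': 2}
unknown_token = 'UNK'

one_ast = [('Program', 'foo'), ('ExpressionStatement', 'baz')]
nt_int_seq = [(nt_token_to_int[n],
               tt_token_to_int.get(t, tt_token_to_int[unknown_token]))
              for n, t in one_ast]
print(nt_int_seq)  # [(0, 0), (1, 2)]: 'baz' falls back to the UNK index

train_seq, test_seq = [], []
train_seq.extend(nt_int_seq)  # TRAIN/VALID: one flat stream of (N, T) pairs
test_seq.append(nt_int_seq)   # TEST: each AST keeps its own sequence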
Example #2
0
    def __init__(
        self,
        num_ntoken,
        num_ttoken,
    ):
        self.model = RnnModel(num_ntoken, num_ttoken, is_training=False)
        self.log_file = open(completion_log_dir, 'w')
        self.tt_token_to_int, self.tt_int_to_token, self.nt_token_to_int, self.nt_int_to_token = \
            utils.load_dict_parameter(is_lower=False)
        self.session = tf.Session()
        checkpoints_path = tf.train.latest_checkpoint(model_save_dir)
        saver = tf.train.Saver()
        self.generator = DataGenerator()
        saver.restore(self.session, checkpoints_path)
        self.test_log(checkpoints_path + ' is being used...')

        self.get_identifer_set()
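
A hedged usage sketch for the class this __init__ belongs to; the excerpt omits the class statement, so CodeCompletion is a hypothetical name:

# Hypothetical usage; vocabulary sizes are derived from the loaded dictionaries.
tt_token_to_int, tt_int_to_token, nt_token_to_int, nt_int_to_token = \
    utils.load_dict_parameter(is_lower=False)
completion = CodeCompletion(len(nt_int_to_token), len(tt_int_to_token))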
Example #3
0
import copy
import json
import sys
from json import JSONDecodeError

# Project modules (utils, data_process, test_setting, ast_to_seq) are assumed.

def get_one_test_ast():
    """Yield encoded test ASTs one at a time from the raw test-data file."""
    sys.setrecursionlimit(10000)  # raise the recursion limit for deep ASTs
    print('setrecursionlimit == 10000')
    tt_token_to_int, tt_int_to_token, nt_token_to_int, nt_int_to_token = utils.load_dict_parameter(
        is_lower=False)

    def nt_seq_to_int(nt_sequence):
        """Encode an (N, T) sequence to integer pairs; return [] on a key error."""
        try:
            int_seq = [(nt_token_to_int[n],
                        tt_token_to_int.get(t, tt_token_to_int[unknown_token]))
                       for n, t in nt_sequence]
        except KeyError:
            print('key error')
            return []  # unknown N token; the caller skips empty prefixes
        return int_seq

    unknown_token = test_setting.unknown_token
    data_path = test_setting.origin_test_data_dir
    total_size = 50000
    file = open(data_path, 'r')
    for i in range(1, total_size + 1):
        try:
            line = file.readline()  # read one line from the file (one AST)
            ast = json.loads(line)  # parse the line as JSON
            ori_ast = copy.deepcopy(ast)
            binary_tree = data_process.bulid_binary_tree(
                ast)  # convert the AST to a binary tree
            prefix, expectation, predict_token_index = ast_to_seq(
                binary_tree)  # convert the binary tree to an nt_sequence
        except UnicodeDecodeError as error:  # raised by readline
            print(error)
        except JSONDecodeError as error:  # raised by json.loads
            print(error)
        except RecursionError as error:
            print(error)
        except BaseException:
            pass  # other unknown error; please check the code
        else:
            int_prefix = nt_seq_to_int(prefix)
            if len(int_prefix) != 0:
                yield i, ori_ast, int_prefix, expectation, predict_token_index
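
A minimal driver sketch for consuming this generator, assuming it is called from the same module:

# Hypothetical consumption loop over the encoded test ASTs.
for index, ori_ast, int_prefix, expectation, predict_index in get_one_test_ast():
    # int_prefix is the encoded (N, T) context; expectation is the token
    # to predict at position predict_index.
    print(index, len(int_prefix), expectation)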
Example #4
0
    def __init__(
        self,
        num_ntoken,
        num_ttoken,
    ):
        origin_graph = tf.Graph()
        embedding_graph = tf.Graph()
        with origin_graph.as_default():
            self.origin_model = orgin_model(num_ntoken,
                                            num_ttoken,
                                            is_training=False)
            origin_checkpoints_path = tf.train.latest_checkpoint(
                origin_trained_model_dir)
            saver = tf.train.Saver()
            self.origin_session = tf.Session(graph=origin_graph)
            saver.restore(self.origin_session, origin_checkpoints_path)
        with embedding_graph.as_default():
            self.embedding_model = embedding_model(num_ntoken,
                                                   num_ttoken,
                                                   is_training=False)
            self.embedding_session = tf.Session(graph=embedding_graph)
            embedding_checkpoints_path = tf.train.latest_checkpoint(
                embedding_trained_model_dir)
            saver = tf.train.Saver()
            saver.restore(self.embedding_session, embedding_checkpoints_path)

        self.generator = DataGenerator()
        self.tt_token_to_int, self.tt_int_to_token, self.nt_token_to_int, self.nt_int_to_token = \
            utils.load_dict_parameter(is_lower=False)

        self.n_incorrect = open(
            'temp_data/predict_compare/nt_compare' + str(current_time) +
            '.txt', 'w')
        self.t_incorrect = open(
            'temp_data/predict_compare/tt_compare' + str(current_time) +
            '.txt', 'w')

        self.n_incorrect_pickle_list = []
        self.t_incorrect_pickle_list = []
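
The pattern above (one tf.Graph and one tf.Session per model, so the two checkpoints do not collide) generalizes. A minimal sketch, assuming the TF1.x API used in this example and placeholder build_model/ckpt_dir arguments:

import tensorflow as tf  # TF1.x-style API, as in the example above

def restore_in_own_graph(build_model, ckpt_dir):
    graph = tf.Graph()
    with graph.as_default():
        model = build_model()              # ops are created inside this graph
        session = tf.Session(graph=graph)  # session bound to this graph only
        saver = tf.train.Saver()           # saver sees only this graph's variables
        saver.restore(session, tf.train.latest_checkpoint(ckpt_dir))
    return model, session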
Example #5
0
        #     valid_n_accuracy += n_accuracy
        #     if valid_step >= valid_times:
        #         break
        #
        # valid_n_accuracy /= valid_step
        # valid_end_time = time.time()
        # valid_log = "VALID epoch:{}/{}  ".format(epoch, self.num_epochs) + \
        #             "global step:{}  ".format(global_step) + \
        #             "valid_nt_accu:{:.2f}%  ".format(valid_n_accuracy * 100) + \
        #             "valid time cost:{:.2f}s".format(valid_end_time - valid_start_time)
        # self.print_and_log(valid_log)

    def print_and_log(self, info):
        try:
            self.log_file.write(info)
            self.log_file.write('\n')
        except BaseException:  # log file closed or not yet opened; reopen and retry
            self.log_file = open(training_log_dir, 'w')
            self.log_file.write(info)
            self.log_file.write('\n')
        print(info)

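Reopening the log with 'w' truncates everything written so far; a variant using append mode (same training_log_dir, narrower exceptions) would preserve earlier entries:

    def print_and_log(self, info):
        try:
            self.log_file.write(info + '\n')
        except (ValueError, AttributeError):  # file closed or never opened
            self.log_file = open(training_log_dir, 'a')  # 'a' keeps earlier lines
            self.log_file.write(info + '\n')
        print(info)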

if __name__ == '__main__':
    tt_token_to_int, tt_int_to_token, nt_token_to_int, nt_int_to_token = \
        utils.load_dict_parameter()
    n_ntoken = len(nt_int_to_token)
    n_ttoken = len(tt_int_to_token)
    model = DoubleLstmModel(n_ntoken, n_ttoken)
    model.train_tt_model()