def nt_seq_to_int(time_steps=50, status='TRAIN'):
    # Further processing of the NT sequences: first convert every token to an
    # integer. For train and valid data, extend all ast-seqs into one flat list,
    # which is convenient for the format conversion during training. For test
    # data, append each ast-seq instead, so every AST keeps its own sequence.
    tt_token_to_int, tt_int_to_token, nt_token_to_int, nt_int_to_token = \
        utils.load_dict_parameter(is_lower=False)
    total_num_nt_pair = 0
    if status == 'TRAIN':
        sub_data_dir = sub_train_data_dir
        num_sub_data = num_sub_train_data
        sub_int_data_dir = sub_int_train_dir
    elif status == 'VALID':
        sub_data_dir = sub_valid_data_dir
        num_sub_data = num_sub_valid_data
        sub_int_data_dir = sub_int_valid_dir
    elif status == 'TEST':
        sub_data_dir = sub_test_data_dir
        num_sub_data = num_sub_test_data
        sub_int_data_dir = sub_int_test_dir
    else:
        print('ERROR! Unknown command!!')
        sys.exit(1)

    def get_subset_data():
        # Read and yield the nt_sequence of each part, to be processed below.
        for i in range(1, num_sub_data + 1):
            data_path = sub_data_dir + 'part{}.json'.format(i)
            data = utils.pickle_load(data_path)
            yield (i, data)

    subset_generator = get_subset_data()
    for index, data in subset_generator:
        data_seq = []
        for one_ast in data:
            # Truncate each nt_seq, encode it as integers, then save it.
            if len(one_ast) < time_steps:
                # This AST is shorter than one time step, so discard it.
                continue
            try:
                nt_int_seq = [(nt_token_to_int[n],
                               tt_token_to_int.get(t, tt_token_to_int[unknown_token]))
                              for n, t in one_ast]
            except KeyError:
                print('key error')
                continue
            # For train and valid, all ast-seqs are extended into a single list;
            # for test, each ast-seq is kept as an independent sequence.
            if status == 'TEST':
                data_seq.append(nt_int_seq)
                total_num_nt_pair += len(nt_int_seq)
            else:
                data_seq.extend(nt_int_seq)
                total_num_nt_pair += len(nt_int_seq)
        one_sub_int_data_dir = sub_int_data_dir + 'int_part{}.json'.format(index)
        utils.pickle_save(one_sub_int_data_dir, data_seq)

    # Training dataset size comparison: old 14,976,250 vs new 157,237,460 nt-pairs.
    # Test dataset size comparison: old 1,557,285 vs new 81,078,099 nt-pairs.
    print('There are {} nt_pair in {} dataset...'.format(total_num_nt_pair, status))
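
# A minimal usage sketch (an assumption, not part of the original module): run
# the encoding for each split, relying on the sub_*_data_dir / num_sub_*_data /
# sub_int_*_dir constants that the function above reads from module level.
if __name__ == '__main__':
    for split in ('TRAIN', 'VALID', 'TEST'):
        nt_seq_to_int(time_steps=50, status=split)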
    def __init__(self, num_ntoken, num_ttoken):
        self.model = RnnModel(num_ntoken, num_ttoken, is_training=False)
        self.log_file = open(completion_log_dir, 'w')
        self.tt_token_to_int, self.tt_int_to_token, self.nt_token_to_int, self.nt_int_to_token = \
            utils.load_dict_parameter(is_lower=False)
        # Restore the most recent checkpoint into a fresh session for inference.
        self.session = tf.Session()
        checkpoints_path = tf.train.latest_checkpoint(model_save_dir)
        saver = tf.train.Saver()
        self.generator = DataGenerator()
        saver.restore(self.session, checkpoints_path)
        self.test_log(checkpoints_path + ' is being used...')
        self.get_identifer_set()
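
# A minimal usage sketch (the enclosing class is not shown in this excerpt; the
# name CodeCompletion below is assumed): look up the vocabulary sizes and
# restore the trained model for completion queries.
#
#     tt_token_to_int, tt_int_to_token, nt_token_to_int, nt_int_to_token = \
#         utils.load_dict_parameter(is_lower=False)
#     completion = CodeCompletion(len(nt_int_to_token), len(tt_int_to_token))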
def get_one_test_ast():
    """Yield test ASTs one at a time: the index, the original AST, the
    integer-encoded prefix nt-sequence, the expected token and the index of
    the token to predict."""
    sys.setrecursionlimit(10000)  # raise the maximum recursion depth
    print('setrecursionlimit == 10000')
    tt_token_to_int, tt_int_to_token, nt_token_to_int, nt_int_to_token = \
        utils.load_dict_parameter(is_lower=False)

    def nt_seq_to_int(nt_sequence):
        """Encode one nt_sequence as (non-terminal int, terminal int) pairs."""
        try:
            int_seq = [(nt_token_to_int[n],
                        tt_token_to_int.get(t, tt_token_to_int[unknown_token]))
                       for n, t in nt_sequence]
        except KeyError:
            print('key error')
        else:
            return int_seq

    unknown_token = test_setting.unknown_token
    data_path = test_setting.origin_test_data_dir
    total_size = 50000
    file = open(data_path, 'r')
    for i in range(1, total_size + 1):
        try:
            line = file.readline()  # read one line (one AST) from the file
            ast = json.loads(line)  # parse it as JSON
            ori_ast = copy.deepcopy(ast)
            binary_tree = data_process.bulid_binary_tree(ast)  # AST to binary tree
            prefix, expectation, predict_token_index = ast_to_seq(
                binary_tree)  # binary tree to nt_sequence
        except UnicodeDecodeError as error:  # raised by readline
            print(error)
        except JSONDecodeError as error:  # raised by json.loads
            print(error)
        except RecursionError as error:
            print(error)
        except BaseException:
            pass  # print('other unknown error, please check the code')
        else:
            int_prefix = nt_seq_to_int(prefix)
            # nt_seq_to_int returns None after a KeyError, so guard against
            # that as well as against an empty prefix.
            if int_prefix:
                yield i, ori_ast, int_prefix, expectation, predict_token_index
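
# A minimal usage sketch (an assumption, not part of the original test driver):
# walk the generator and inspect the first few encoded prefixes.
if __name__ == '__main__':
    for i, ori_ast, int_prefix, expectation, predict_token_index in get_one_test_ast():
        print('ast #{}: prefix length {}, predict index {}'.format(
            i, len(int_prefix), predict_token_index))
        if i >= 5:
            break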
    def __init__(self, num_ntoken, num_ttoken):
        # Build the original model and the embedding model in two separate
        # graphs, so each checkpoint can be restored into its own session.
        origin_graph = tf.Graph()
        embedding_graph = tf.Graph()
        with origin_graph.as_default():
            self.origin_model = orgin_model(num_ntoken, num_ttoken, is_training=False)
            origin_checkpoints_path = tf.train.latest_checkpoint(origin_trained_model_dir)
            saver = tf.train.Saver()
            self.origin_session = tf.Session(graph=origin_graph)
            saver.restore(self.origin_session, origin_checkpoints_path)

        with embedding_graph.as_default():
            self.embedding_model = embedding_model(num_ntoken, num_ttoken, is_training=False)
            self.embedding_session = tf.Session(graph=embedding_graph)
            embedding_checkpoints_path = tf.train.latest_checkpoint(embedding_trained_model_dir)
            saver = tf.train.Saver()
            saver.restore(self.embedding_session, embedding_checkpoints_path)

        self.generator = DataGenerator()
        self.tt_token_to_int, self.tt_int_to_token, self.nt_token_to_int, self.nt_int_to_token = \
            utils.load_dict_parameter(is_lower=False)
        # Output files for comparing non-terminal and terminal predictions.
        self.n_incorrect = open(
            'temp_data/predict_compare/nt_compare' + str(current_time) + '.txt', 'w')
        self.t_incorrect = open(
            'temp_data/predict_compare/tt_compare' + str(current_time) + '.txt', 'w')
        self.n_incorrect_pickle_list = []
        self.t_incorrect_pickle_list = []
        #     valid_n_accuracy += n_accuracy
        #     if valid_step >= valid_times:
        #         break
        #
        # valid_n_accuracy /= valid_step
        # valid_end_time = time.time()
        # valid_log = "VALID epoch:{}/{} ".format(epoch, self.num_epochs) + \
        #     "global step:{} ".format(global_step) + \
        #     "valid_nt_accu:{:.2f}% ".format(valid_n_accuracy * 100) + \
        #     "valid time cost:{:.2f}s".format(valid_end_time - valid_start_time)
        # self.print_and_log(valid_log)

    def print_and_log(self, info):
        try:
            self.log_file.write(info)
            self.log_file.write('\n')
        except BaseException:
            self.log_file = open(training_log_dir, 'w')
            self.log_file.write(info)
            self.log_file.write('\n')
        print(info)


if __name__ == '__main__':
    tt_token_to_int, tt_int_to_token, nt_token_to_int, nt_int_to_token = \
        utils.load_dict_parameter()
    n_ntoken = len(nt_int_to_token)
    n_ttoken = len(tt_int_to_token)
    model = DoubleLstmModel(n_ntoken, n_ttoken)
    model.train_tt_model()