def validate(self, data_valid, step):
    #
    # valid_result save-path
    if not os.path.exists(meta.dir_results_valid):
        os.mkdir(meta.dir_results_valid)
    #
    self.create_graph_all(training=False)
    #
    with self.graph.as_default():
        #
        saver = tf.train.Saver()
        with tf.Session(config=self.sess_config) as sess:
            #
            tf.global_variables_initializer().run()
            #sess.run(tf.assign(self.is_train, tf.constant(False, dtype=tf.bool)))
            #
            # restore with saved data
            ckpt = tf.train.get_checkpoint_state(meta.model_detect_dir)
            #
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            #
            # pb: freeze the variables into a .pb graph for later standalone inference
            constant_graph = graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names=['rnn_cls', 'rnn_ver', 'rnn_hor'])
            with tf.gfile.FastGFile(self.pb_file, mode='wb') as f:
                f.write(constant_graph.SerializeToString())
            #
            # test
            NumImages = len(data_valid)
            curr = 0
            for img_file in data_valid:
                #
                print(img_file)
                #
                txt_file = model_data.get_target_txt_file(img_file)
                #
                # input data
                img_data, feat_size, target_cls, target_ver, target_hor = \
                    model_data.get_image_and_targets(img_file, txt_file, meta.anchor_heights)
                #
                img_size = img_data[0].shape  # height, width, channel
                #
                w_arr = np.array([img_size[1]], dtype=np.int32)
                #
                feed_dict = {self.x: img_data, self.w: w_arr,
                             self.t_cls: target_cls, self.t_ver: target_ver, self.t_hor: target_hor}
                #
                r_cls, r_ver, r_hor, loss_value = sess.run(
                    [self.rnn_cls, self.rnn_ver, self.rnn_hor, self.loss], feed_dict)
                #
                curr += 1
                print('curr: %d / %d, loss: %f' % (curr, NumImages, loss_value))
                #
                # trans: convert raw network outputs into text boxes with confidences
                text_bbox, conf_bbox = model_data.trans_results(r_cls, r_ver, r_hor,
                                                                meta.anchor_heights, meta.threshold)
                #
                conn_bbox = model_data.do_nms_and_connection(text_bbox, conf_bbox)
                #
                # image
                #
                filename = os.path.basename(img_file)
                file_target = os.path.join(meta.dir_results_valid, str(step) + '_predicted_' + filename)
                img_target = Image.fromarray(np.uint8(img_data[0] * 255))  # .convert('RGB')
                img_target.save(file_target)
                model_data.draw_text_boxes(file_target, text_bbox)
                #
                # remove prediction images left over from older validation steps
                id_remove = step - self.valid_freq * self.keep_near
                if id_remove % self.keep_freq:
                    file_temp = os.path.join(meta.dir_results_valid, str(id_remove) + '_predicted_' + filename)
                    if os.path.exists(file_temp):
                        os.remove(file_temp)
                #
            #
            print('validation finished')
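# A minimal sketch (not part of the original class) of how the frozen .pb file written by
# validate() could be reloaded for standalone inference. Only the output node names
# ('rnn_cls', 'rnn_ver', 'rnn_hor') come from the freeze call above; the helper name and
# the input placeholder names are assumptions and would need to match the real graph.
#
#   import tensorflow as tf
#
#   def load_frozen_detector(pb_path):
#       # read the serialized GraphDef and import it into a fresh graph
#       with tf.gfile.FastGFile(pb_path, 'rb') as f:
#           graph_def = tf.GraphDef()
#           graph_def.ParseFromString(f.read())
#       graph = tf.Graph()
#       with graph.as_default():
#           tf.import_graph_def(graph_def, name='')
#       # fetch the frozen output tensors by the node names used in convert_variables_to_constants
#       outputs = [graph.get_tensor_by_name(n + ':0') for n in ('rnn_cls', 'rnn_ver', 'rnn_hor')]
#       return graph, outputs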
def train_and_valid(self, data_train, data_valid):
    #
    # model save-path
    if not os.path.exists(meta.model_detect_dir):
        os.mkdir(meta.model_detect_dir)
    #
    # graph
    self.create_graph_all(training=True)
    #
    # restore and train
    with self.graph.as_default():
        #
        saver = tf.train.Saver()
        with tf.Session(config=self.sess_config) as sess:
            #
            tf.global_variables_initializer().run()
            sess.run(tf.assign(self.learning_rate, tf.constant(self.learning_rate_base, dtype=tf.float32)))
            #
            # write the net structure for viewing in TensorBoard (path is hard-coded)
            writer = tf.summary.FileWriter("D:/ocr_codes/OCR-DETECTION-CTPN-master/tensorboard_view", sess.graph)
            # restore with saved data
            ckpt = tf.train.get_checkpoint_state(meta.model_detect_dir)
            #
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            #
            print('begin to train ...')
            #
            # start training
            start_time = time.time()
            begin_time = start_time
            #
            step = sess.run(self.global_step)
            #
            train_step_half = int(self.train_steps * 0.5)
            train_step_quar = int(self.train_steps * 0.75)
            #
            while step < self.train_steps:
                #
                # decay the learning rate at 50% and 75% of the total steps
                if step == train_step_half:
                    sess.run(tf.assign(self.learning_rate,
                                       tf.constant(self.learning_rate_base / 10, dtype=tf.float32)))
                if step == train_step_quar:
                    sess.run(tf.assign(self.learning_rate,
                                       tf.constant(self.learning_rate_base / 100, dtype=tf.float32)))
                #
                # save and validation
                if step % self.valid_freq == 0:
                    #
                    print('save model to ckpt ...')
                    saver.save(sess, os.path.join(meta.model_detect_dir, meta.model_detect_name),
                               global_step=step)
                    #
                    print('validating ...')
                    model_v = ModelDetect()
                    model_v.validate(data_valid, step)
                #
                img_file = random.choice(data_train)  # data_train: list of image files
                if not os.path.exists(img_file):
                    print('image_file: %s NOT exist' % img_file)
                    continue
                #
                txt_file = model_data.get_target_txt_file(img_file)
                if not os.path.exists(txt_file):
                    print('label_file: %s NOT exist' % txt_file)
                    continue
                #
                # input data
                img_data, feat_size, target_cls, target_ver, target_hor = \
                    model_data.get_image_and_targets(img_file, txt_file, meta.anchor_heights)
                #
                img_size = img_data[0].shape  # height, width, channel
                #
                w_arr = np.array([img_size[1]], dtype=np.int32)
                #
                feed_dict = {self.x: img_data, self.w: w_arr,
                             self.t_cls: target_cls, self.t_ver: target_ver, self.t_hor: target_hor}
                #
                _, loss_value, step, lr = sess.run(
                    [self.train_op, self.loss, self.global_step, self.learning_rate], feed_dict)
                #
                if step % self.loss_freq == 0:
                    #
                    curr_time = time.time()
                    #
                    print('step: %d, loss: %g, lr: %g, sect_time: %.1f, total_time: %.1f, %s' %
                          (step, loss_value, lr,
                           curr_time - begin_time,
                           curr_time - start_time,
                           os.path.basename(img_file)))
                    #
                    begin_time = curr_time
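# A hypothetical driver (a minimal sketch, not from the original project) showing how
# train_and_valid could be launched. The directory layout, the file extension, and the use
# of glob instead of a project-specific listing helper are assumptions; only the
# ModelDetect class and the method signature come from the code above.
#
#   import glob
#   import os
#
#   data_train = glob.glob(os.path.join('./data/train_images', '*.png'))
#   data_valid = glob.glob(os.path.join('./data/valid_images', '*.png'))
#
#   model = ModelDetect()
#   model.train_and_valid(data_train, data_valid)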