def predict(self, img_file, out_dir=None):
    """Run text detection on a single image file.

    Args:
        img_file: path to the input image (opened with PIL).
        out_dir: optional output directory. When given, two annotated
            copies of the image are written there: 'predicted_<name>'
            with the raw text boxes and 'connected_<name>' with the
            NMS-connected boxes. When None, nothing is written.

    Returns:
        Tuple (conn_bbox, text_bbox, conf_bbox) as produced by
        model_data.trans_results / do_nms_and_connection.
    """
    # Load the image and scale pixel values into [0, 1].
    img = Image.open(img_file)
    img_data = np.array(img, dtype=np.float32) / 255  # height, width, channel
    # Keep only the first three channels (drops alpha for RGBA input).
    # NOTE(review): assumes the decoded image has >= 3 channels — a
    # grayscale input would fail here; confirm against callers.
    img_data = [img_data[:, :, 0:3]]
    img_size = img.size  # PIL size is (width, height)
    w_arr = np.array([img_size[0]], dtype=np.int32)
    #
    with self.graph.as_default():
        # Forward pass through the detection network.
        feed_dict = {self.x: img_data, self.w: w_arr}
        r_cls, r_ver, r_hor = self.sess.run(
            [self.rnn_cls, self.rnn_ver, self.rnn_hor], feed_dict)
        #
        # Translate the raw network outputs into candidate text boxes,
        # then suppress/merge them into connected boxes.
        text_bbox, conf_bbox = model_data.trans_results(
            r_cls, r_ver, r_hor, meta.anchor_heights, meta.threshold)
        conn_bbox = model_data.do_nms_and_connection(text_bbox, conf_bbox)
        #
        # Fixed: identity comparison with None (was `out_dir == None`).
        if out_dir is None:
            return conn_bbox, text_bbox, conf_bbox
        #
        # Prediction-result save path.
        # Fixed: os.makedirs instead of os.mkdir, so missing parent
        # directories no longer raise FileNotFoundError.
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        filename = os.path.basename(img_file)
        #
        # Save one annotated copy of the input per box set (the two
        # stanzas were previously duplicated verbatim).
        for prefix, boxes in (('predicted_', text_bbox),
                              ('connected_', conn_bbox)):
            file_target = os.path.join(out_dir, prefix + filename)
            img_target = Image.fromarray(np.uint8(img_data[0] * 255))  # .convert('RGB')
            img_target.save(file_target)
            model_data.draw_text_boxes(file_target, boxes)
        #
        return conn_bbox, text_bbox, conf_bbox
def validate(self, data_valid, step):
    """Validate the detection model at training step `step`.

    Rebuilds the graph in inference mode, restores the latest checkpoint,
    freezes it to a .pb file, then runs every image in `data_valid`
    through the network: prints the per-image loss, saves an annotated
    copy of each image into meta.dir_results_valid, and prunes annotated
    results from an older step.

    Args:
        data_valid: iterable of image file paths.
        step: current training step; used to tag output files and to
            decide which old result files to delete.
    """
    #
    # Ensure the validation-result save path exists.
    if not os.path.exists(meta.dir_results_valid):
        os.mkdir(meta.dir_results_valid)
    #
    # Rebuild the full graph in inference mode.
    self.create_graph_all(training = False)
    #
    with self.graph.as_default():
        #
        saver = tf.train.Saver()
        with tf.Session(config = self.sess_config) as sess:
            #
            tf.global_variables_initializer().run()
            #sess.run(tf.assign(self.is_train, tf.constant(False, dtype=tf.bool)))
            #
            # Restore weights from the latest saved checkpoint, if any.
            ckpt = tf.train.get_checkpoint_state(meta.model_detect_dir)
            #
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            #
            # Freeze the restored graph to a .pb file with the three
            # detection heads as outputs.
            constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names = \
                                                                       ['rnn_cls','rnn_ver','rnn_hor'])
            with tf.gfile.FastGFile(self.pb_file, mode='wb') as f:
                f.write(constant_graph.SerializeToString())
            #
            # Run validation over every image.
            NumImages = len(data_valid)
            curr = 0
            for img_file in data_valid:
                # print(img_file)
                #
                # Ground-truth annotation file paired with this image.
                txt_file = model_data.get_target_txt_file(img_file)
                #
                # Input data and training targets for the loss.
                img_data, feat_size, target_cls, target_ver, target_hor = \
                    model_data.get_image_and_targets(img_file, txt_file, meta.anchor_heights)
                #
                img_size = img_data[0].shape  # height, width, channel
                #
                w_arr = np.array([ img_size[1] ], dtype = np.int32)
                #
                # Feed both the image and the targets so loss is computed.
                feed_dict = {self.x: img_data, self.w: w_arr, \
                             self.t_cls: target_cls, self.t_ver: target_ver, self.t_hor: target_hor}
                #
                r_cls, r_ver, r_hor, loss_value = sess.run([self.rnn_cls, self.rnn_ver, self.rnn_hor, self.loss], feed_dict)
                #
                curr += 1
                print('curr: %d / %d, loss: %f' % (curr, NumImages, loss_value))
                #
                # Translate raw outputs into text boxes and connect them.
                # NOTE(review): conn_bbox is computed but only text_bbox is
                # drawn below — presumably intentional; verify.
                text_bbox, conf_bbox = model_data.trans_results(r_cls, r_ver, r_hor, \
                                                                meta.anchor_heights, meta.threshold)
                #
                conn_bbox = model_data.do_nms_and_connection(text_bbox, conf_bbox)
                #
                # Save an annotated copy of the image, tagged with the step.
                filename = os.path.basename(img_file)
                file_target = os.path.join(meta.dir_results_valid, str(step) + '_predicted_' + filename)
                img_target = Image.fromarray(np.uint8(img_data[0] *255) )  #.convert('RGB')
                img_target.save(file_target)
                model_data.draw_text_boxes(file_target, text_bbox)
                #
                # Prune the annotated result saved keep_near validations ago,
                # unless its step id falls on a keep_freq multiple.
                id_remove = step - self.valid_freq * self.keep_near
                if id_remove % self.keep_freq:
                    file_temp = os.path.join(meta.dir_results_valid, str(id_remove) + '_predicted_' + filename)
                    if os.path.exists(file_temp):
                        os.remove(file_temp)
            #
            print('validation finished')