Example #1
0
 def validate(self, data_valid, step):
     """Validate the detector on *data_valid* at training step *step*.

     Rebuilds the whole graph in inference mode, restores the newest
     checkpoint from ``meta.model_detect_dir``, freezes the graph into
     ``self.pb_file``, then for each validation image: runs the network,
     prints the loss, saves the image with predicted text boxes drawn into
     ``meta.dir_results_valid``, and prunes result images left over from an
     older validation step.

     Args:
         data_valid: iterable of image file paths; each is assumed to have a
             matching label txt file resolvable by
             model_data.get_target_txt_file -- TODO confirm.
         step: current global training step; used to tag new result files
             and to compute which old result files to delete.
     """
     #
     # Ensure the directory for validation result images exists.
     if not os.path.exists(meta.dir_results_valid): os.mkdir(meta.dir_results_valid)
     #
     # Rebuild the full graph in inference (non-training) mode.
     self.create_graph_all(training = False)
     #
     with self.graph.as_default():
         #
         saver = tf.train.Saver()
         with tf.Session(config = self.sess_config) as sess:
             #
             # Initialize first; the restore below overwrites initialized
             # values with the checkpointed ones.
             tf.global_variables_initializer().run()
             #sess.run(tf.assign(self.is_train, tf.constant(False, dtype=tf.bool)))
             #
             # restore with saved data
             ckpt = tf.train.get_checkpoint_state(meta.model_detect_dir)
             #
             if ckpt and ckpt.model_checkpoint_path:
                 saver.restore(sess, ckpt.model_checkpoint_path)
             #
             # Freeze variables into constants and export the inference
             # graph as a .pb file (output nodes: rnn_cls/rnn_ver/rnn_hor).
             constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names = \
                                                                        ['rnn_cls','rnn_ver','rnn_hor'])
             with tf.gfile.FastGFile(self.pb_file, mode='wb') as f:
                 f.write(constant_graph.SerializeToString())
             #
             # Run the network over every validation image.
             NumImages = len(data_valid)
             curr = 0
             for img_file in data_valid:
                 #
                 print(img_file)
                 #
                 txt_file = model_data.get_target_txt_file(img_file)
                 #
                 # Load the image batch plus the per-anchor targets used to
                 # compute the validation loss.
                 img_data, feat_size, target_cls, target_ver, target_hor = \
                 model_data.get_image_and_targets(img_file, txt_file, meta.anchor_heights)
                 #
                 img_size = img_data[0].shape   # height, width, channel
                 #
                 # The network also takes the image width as a 1-element array.
                 w_arr = np.array([ img_size[1] ], dtype = np.int32)
                 #
                 #
                 feed_dict = {self.x: img_data, self.w: w_arr, \
                              self.t_cls: target_cls, self.t_ver: target_ver, self.t_hor: target_hor}
                 #
                 r_cls, r_ver, r_hor, loss_value = sess.run([self.rnn_cls, self.rnn_ver, self.rnn_hor, self.loss], feed_dict)
                 #
                 #
                 curr += 1
                 print('curr: %d / %d, loss: %f' % (curr, NumImages, loss_value))
                 #
                 # Convert raw network outputs into text bounding boxes.
                 text_bbox, conf_bbox = model_data.trans_results(r_cls, r_ver, r_hor, \
                                                                 meta.anchor_heights, meta.threshold)
                 # conn_bbox = model_data.do_nms_and_connection(text_bbox, conf_bbox)
                 #
                 # Save the image and draw the predicted boxes onto the file.
                 # NOTE(review): assumes img_data[0] holds floats in [0, 1]
                 # (scaled by 255 here) -- confirm against the data loader.
                 filename = os.path.basename(img_file)
                 file_target = os.path.join(meta.dir_results_valid, str(step) + '_predicted_' + filename)
                 img_target = Image.fromarray(np.uint8(img_data[0] *255) ) #.convert('RGB')
                 img_target.save(file_target)
                 model_data.draw_text_boxes(file_target, text_bbox)
                 #
                 # Prune the result image written self.keep_near validations
                 # ago, unless its step id is a multiple of self.keep_freq
                 # (those are kept as long-term samples).
                 id_remove = step - self.valid_freq * self.keep_near
                 if id_remove % self.keep_freq:
                     file_temp = os.path.join(meta.dir_results_valid, str(id_remove) + '_predicted_' + filename)
                     if os.path.exists(file_temp): os.remove(file_temp)
                 #
             #
             print('validation finished')
Example #2
0
    def train_and_valid(self, data_train, data_valid):
        """Train the detector, periodically checkpointing and validating.

        Builds the graph in training mode, restores the newest checkpoint if
        one exists (which also restores ``self.global_step``), then trains on
        randomly chosen images from *data_train* until the global step
        reaches ``self.train_steps``.  Every ``self.valid_freq`` steps the
        model is saved and validated on *data_valid*.  The learning rate
        starts at ``self.learning_rate_base`` and is divided by 10 at 50%
        and by 100 at 75% of the total step budget.

        Args:
            data_train: list of training image file paths; each needs a
                matching label txt file (missing files are skipped).
            data_valid: list of validation image file paths, forwarded to
                ModelDetect.validate().
        """
        #
        # model save-path
        if not os.path.exists(meta.model_detect_dir): os.mkdir(meta.model_detect_dir)
        #
        # Build the full graph in training mode.
        self.create_graph_all(training = True)
        #
        # restore and train
        with self.graph.as_default():
            #
            saver = tf.train.Saver()
            with tf.Session(config = self.sess_config) as sess:
                #
                tf.global_variables_initializer().run()
                # Start from the base learning rate.
                sess.run(tf.assign(self.learning_rate, tf.constant(self.learning_rate_base, dtype=tf.float32)))
                #
                # Dump the graph definition for viewing in TensorBoard.
                # NOTE(review): hard-coded absolute path, and the writer is
                # never closed within this view -- consider parameterizing
                # the path and closing the writer when training ends.
                writer = tf.summary.FileWriter("D:/ocr_codes/OCR-DETECTION-CTPN-master/tensorboard_view", sess.graph)
                # Restore with saved data; overwrites the initialized values,
                # including the global step, so training resumes where it stopped.
                ckpt = tf.train.get_checkpoint_state(meta.model_detect_dir)
                #
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                #
                print('begin to train ...')
                #
                # start training
                start_time = time.time()
                begin_time = start_time
                #
                step = sess.run(self.global_step)
                #
                # Learning-rate decay boundaries: half of the budget and --
                # despite the "quar" name -- the three-quarter mark.
                train_step_half = int(self.train_steps * 0.5)
                train_step_quar = int(self.train_steps * 0.75)
                #
                while step < self.train_steps:
                    #
                    if step == train_step_half:
                        sess.run(tf.assign(self.learning_rate, tf.constant(self.learning_rate_base/10, dtype=tf.float32)))
                    if step == train_step_quar:
                        sess.run(tf.assign(self.learning_rate, tf.constant(self.learning_rate_base/100, dtype=tf.float32)))
                    #
                    # save and validation
                    if step % self.valid_freq == 0:
                        #
                        print('save model to ckpt ...')
                        saver.save(sess, os.path.join(meta.model_detect_dir, meta.model_detect_name), \
                                   global_step = step)
                        #
                        print('validating ...')
                        # A fresh instance, so validation builds its own
                        # inference graph without disturbing this one --
                        # presumably; confirm against ModelDetect.__init__.
                        model_v = ModelDetect()
                        model_v.validate(data_valid, step)
                        #

                    #
                    # Pick one random training sample; skip (without running
                    # a training step) if the image or its label is missing.
                    img_file = random.choice(data_train)  # list image files
                    if not os.path.exists(img_file):
                        print('image_file: %s NOT exist' % img_file)
                        continue
                    #
                    txt_file = model_data.get_target_txt_file(img_file)
                    if not os.path.exists(txt_file):
                        print('label_file: %s NOT exist' % txt_file)
                        continue
                    #
                    # Load the image batch and per-anchor training targets.
                    img_data, feat_size, target_cls, target_ver, target_hor = \
                    model_data.get_image_and_targets(img_file, txt_file, meta.anchor_heights)
                    #
                    img_size = img_data[0].shape   # height, width, channel
                    #
                    # The network also takes the image width as a 1-element array.
                    w_arr = np.array([ img_size[1] ], dtype = np.int32)
                    #
                    #
                    feed_dict = {self.x: img_data, self.w: w_arr, \
                                 self.t_cls: target_cls, self.t_ver: target_ver, self.t_hor: target_hor}
                    #
                    # One optimization step; `step` is refreshed from
                    # self.global_step (presumably incremented by train_op --
                    # confirm, as the while-loop relies on it advancing).
                    _, loss_value, step, lr = sess.run([self.train_op, self.loss, self.global_step, self.learning_rate],\
                                                       feed_dict)
                    #
                    # Progress report every self.loss_freq steps.
                    if step % self.loss_freq == 0:
                        #
                        curr_time = time.time()
                        #
                        print('step: %d, loss: %g, lr: %g, sect_time: %.1f, total_time: %.1f, %s' %
                              (step, loss_value, lr,
                               curr_time - begin_time,
                               curr_time - start_time,
                               os.path.basename(img_file)))
                        #
                        begin_time = curr_time