示例#1
0
 def predict(self, sess, img_file, out_dir='./results_prediction'):
     """Run text detection on one image and save an annotated copy.

     The image is preprocessed with ``model_detect_data.getImageAndTargets``
     and fed through ``self.x`` / ``self.w`` in ``self.graph``. The
     ``rnn_cls`` / ``rnn_ver`` / ``rnn_hor`` outputs are then converted to
     text boxes by ``model_detect_data.transResults``. The preprocessed image
     is written to ``<out_dir>/<stem>_predict.png``, and the boxes are drawn
     onto that file with ``model_detect_data.drawTextBox``.

     Args:
         sess: session holding the variables of ``self.graph``. Presumably
             already restored from a checkpoint by the caller (TODO confirm).
         img_file: path of the image to run detection on.
         out_dir: directory for the annotated image. It is created,
             including any missing parent directories, if it does not exist.
     """
     # input data; the training targets are not needed for inference
     img_data, feat_size, _, _, _ = \
         model_detect_data.getImageAndTargets(img_file, meta.anchor_heights)
     img_size = model_detect_data.getImageSize(img_file)  # width, height
     #
     # feat_size[0] entries, each equal to the image width. int() because
     # np.ones() rejects a non-integer shape.
     w_arr = np.ones((int(feat_size[0]), ), dtype=np.int32) * img_size[0]
     #
     # predication_result save-path; makedirs (unlike mkdir) also creates
     # missing parent directories
     if not os.path.exists(out_dir):
         os.makedirs(out_dir)
     #
     with self.graph.as_default():
         feed_dict = {self.x: img_data, self.w: w_arr}
         r_cls, r_ver, r_hor = sess.run(
             [self.rnn_cls, self.rnn_ver, self.rnn_hor], feed_dict)
         #
         stem = os.path.splitext(os.path.basename(img_file))[0]
         file_target = os.path.join(out_dir, stem + '_predict.png')
         #
         # Rebuild an RGB image from the first item of img_data, one channel
         # at a time. The *255 suggests values are scaled to [0, 1]
         # (TODO confirm).
         r, g, b = (Image.fromarray(img_data[0][:, :, c] * 255).convert('L')
                    for c in range(3))
         Image.merge("RGB", (r, g, b)).save(file_target)
         #
         # trans: raw network outputs -> text boxes, drawn onto the saved file
         text_bbox = model_detect_data.transResults(
             r_cls, r_ver, r_hor, meta.anchor_heights, meta.threshold)
         model_detect_data.drawTextBox(file_target, text_bbox)
 def train_and_valid(self):
     """Train the detection model, saving checkpoints and validating periodically.

     Steps:

     * Build the training graph in ``self.z_graph`` via
       ``self.z_define_graph_all(self.z_graph, True)``.
     * Initialise all variables, then restore the latest checkpoint in
       ``meta.model_detect_dir`` if there is one.
     * Run ``TRAINING_STEPS`` steps. Each step feeds one image chosen at
       random from ``meta.dir_images_train`` and prints the step, loss,
       learning rate, timings and file name.
     * Whenever the global step returned by the session is a multiple of
       ``self.z_valid_freq``, save a checkpoint named
       ``meta.model_detect_name`` and call
       ``self.validate(step, self.z_valid_option)``.
     """
     # get training images
     list_images_train = model_detect_data.getFilesInDirect(
         meta.dir_images_train, meta.str_dot_img_ext)
     #
     # model save-path; makedirs (unlike mkdir) also creates missing parents
     if not os.path.exists(meta.model_detect_dir):
         os.makedirs(meta.model_detect_dir)
     #
     # training graph
     graph = tf.Graph()
     self.z_graph = graph
     self.z_define_graph_all(graph, True)
     #
     with graph.as_default():
         saver = tf.train.Saver()
         #
         with tf.Session(config=self.z_sess_config) as sess:
             tf.global_variables_initializer().run()
             #
             # Restore with saved data. This runs after the initializer, so
             # checkpoint values (when present) are the ones that survive.
             ckpt = tf.train.get_checkpoint_state(meta.model_detect_dir)
             if ckpt and ckpt.model_checkpoint_path:
                 saver.restore(sess, ckpt.model_checkpoint_path)
             #
             # graph endpoints, looked up by the names they were given when
             # the graph was defined
             x = graph.get_tensor_by_name('x-input:0')
             w = graph.get_tensor_by_name('w-input:0')
             t_cls = graph.get_tensor_by_name('c-input:0')
             t_ver = graph.get_tensor_by_name('v-input:0')
             t_hor = graph.get_tensor_by_name('h-input:0')
             loss = graph.get_tensor_by_name('loss:0')
             global_step = graph.get_tensor_by_name('global_step:0')
             learning_rate = graph.get_tensor_by_name('learning_rate:0')
             train_op = graph.get_tensor_by_name(
                 'train_op/control_dependency:0')
             #
             print('begin to train ...')
             #
             start_time = time.time()
             begin_time = start_time
             #
             for _ in range(TRAINING_STEPS):
                 img_file = random.choice(list_images_train)
                 #
                 # input data
                 img_data, feat_size, target_cls, target_ver, target_hor = \
                     model_detect_data.getImageAndTargets(
                         img_file, meta.anchor_heights)
                 img_size = model_detect_data.getImageSize(img_file)  # width, height
                 w_arr = np.ones((int(feat_size[0]),), dtype=np.int32) * img_size[0]
                 #
                 feed_dict = {x: img_data, w: w_arr,
                              t_cls: target_cls, t_ver: target_ver,
                              t_hor: target_hor}
                 _, loss_value, step, lr = sess.run(
                     [train_op, loss, global_step, learning_rate], feed_dict)
                 #
                 # Log every step. sect_time is the time since the previous
                 # log line; total_time is the time since training began.
                 curr_time = time.time()
                 print('step: %d, loss: %g, lr: %g, sect_time: %.1f, total_time: %.1f, %s' %
                       (step, loss_value, lr,
                        curr_time - begin_time,
                        curr_time - start_time,
                        os.path.basename(img_file)))
                 begin_time = curr_time
                 #
                 # checkpoint + validation
                 if step % self.z_valid_freq == 0:
                     saver.save(sess,
                                os.path.join(meta.model_detect_dir,
                                             meta.model_detect_name),
                                global_step=step)
                     self.validate(step, self.z_valid_option)
 def validate(self, step, training):
     """Evaluate the latest checkpoint on the validation images.

     Steps:

     * Build a fresh graph in ``self.graph`` via
       ``self.z_define_graph_all(self.graph, training)``.
     * Restore the latest checkpoint from ``meta.model_detect_dir`` if there
       is one.
     * Export the graph, with variables folded into constants and output
       nodes ``rnn_cls``, ``rnn_ver`` and ``rnn_hor``, to
       ``self.z_pb_file``.
     * For every image in ``meta.dir_images_valid``, print the loss and write
       ``<step>_<stem>.png`` to ``meta.dir_results_valid`` with the detected
       text boxes drawn on it.

     Args:
         step: training step. Used only as a prefix of the output file names.
         training: forwarded unchanged to ``self.z_define_graph_all``.
     """
     # get validation images
     list_images_valid = model_detect_data.getFilesInDirect(
         meta.dir_images_valid, meta.str_dot_img_ext)
     #
     # valid_result save-path; makedirs (unlike mkdir) also creates missing
     # parent directories
     if not os.path.exists(meta.dir_results_valid):
         os.makedirs(meta.dir_results_valid)
     #
     # validation graph
     self.graph = tf.Graph()
     self.z_define_graph_all(self.graph, training)
     #
     with self.graph.as_default():
         saver = tf.train.Saver()
         #
         with tf.Session(config=self.z_sess_config) as sess:
             tf.global_variables_initializer().run()
             #
             # restore with saved data (overrides the fresh initialisation)
             ckpt = tf.train.get_checkpoint_state(meta.model_detect_dir)
             if ckpt and ckpt.model_checkpoint_path:
                 saver.restore(sess, ckpt.model_checkpoint_path)
             #
             # export a .pb with variables folded into constants
             constant_graph = graph_util.convert_variables_to_constants(
                 sess, sess.graph_def,
                 output_node_names=['rnn_cls', 'rnn_ver', 'rnn_hor'])
             # GFile replaces the deprecated FastGFile
             with tf.gfile.GFile(self.z_pb_file, mode='wb') as f:
                 f.write(constant_graph.SerializeToString())
             #
             # graph endpoints, looked up by name
             x = self.graph.get_tensor_by_name('x-input:0')
             w = self.graph.get_tensor_by_name('w-input:0')
             rnn_cls = self.graph.get_tensor_by_name('rnn_cls:0')
             rnn_ver = self.graph.get_tensor_by_name('rnn_ver:0')
             rnn_hor = self.graph.get_tensor_by_name('rnn_hor:0')
             t_cls = self.graph.get_tensor_by_name('c-input:0')
             t_ver = self.graph.get_tensor_by_name('v-input:0')
             t_hor = self.graph.get_tensor_by_name('h-input:0')
             loss = self.graph.get_tensor_by_name('loss:0')
             #
             # test
             num_images = len(list_images_valid)
             for curr, img_file in enumerate(list_images_valid, 1):
                 # input data
                 img_data, feat_size, target_cls, target_ver, target_hor = \
                     model_detect_data.getImageAndTargets(
                         img_file, meta.anchor_heights)
                 img_size = model_detect_data.getImageSize(img_file)  # width, height
                 w_arr = np.ones((int(feat_size[0]),), dtype=np.int32) * img_size[0]
                 #
                 feed_dict = {x: img_data, w: w_arr,
                              t_cls: target_cls, t_ver: target_ver,
                              t_hor: target_hor}
                 r_cls, r_ver, r_hor, loss_value = sess.run(
                     [rnn_cls, rnn_ver, rnn_hor, loss], feed_dict)
                 print('curr: %d / %d, loss: %f' % (curr, num_images, loss_value))
                 #
                 # Save the preprocessed image (first item of img_data),
                 # rebuilt channel by channel.
                 stem = os.path.splitext(os.path.basename(img_file))[0]
                 file_target = os.path.join(meta.dir_results_valid,
                                            str(step) + '_' + stem + '.png')
                 r, g, b = (Image.fromarray(img_data[0][:, :, c] * 255).convert('L')
                            for c in range(3))
                 Image.merge("RGB", (r, g, b)).save(file_target)
                 #
                 # trans: raw outputs -> text boxes, drawn onto the saved file
                 text_bbox = model_detect_data.transResults(
                     r_cls, r_ver, r_hor, meta.anchor_heights, meta.threshold)
                 model_detect_data.drawTextBox(file_target, text_bbox)
             #
             print('validation finished')
示例#4
0
    def train(
        self,
        data_path='validation/',
    ):
        """Train the detector on images sampled at random from ``data_path``.

        Each iteration first runs ``self.conv_feat`` on the sampled image to
        get its feature size. It then rebuilds the image and targets for that
        size with ``model_detect_data.getImageAndTargets`` and runs one
        ``self.train_op`` step.

        Periodic side effects:

        * every 500 iterations, the mean loss of the last 500 iterations is
          logged;
        * at iterations 50000, 60000, 70000, 80000 and 90000, ``self.save``
          is called;
        * every 5000 iterations, each file listed in ``result/`` is run
          through the network and written, with its detected text boxes, to
          ``result/<i>/<n>.png``.

        Args:
            data_path: directory holding the training images.
        """
        list_path = model_detect_data.getFilesInDirect(data_path)

        display = 500          # iterations per logged loss average
        show_every = 5000      # iterations between visual result dumps
        save_iters = (50000, 60000, 70000, 80000, 90000)

        self.graph = tf.Graph()
        self.define_graph(self.graph)

        with self.graph.as_default():
            with tf.Session(config=self.config) as sess:

                self.get_node()
                # Initialise BEFORE restoring. Running the initializer after
                # restore() would overwrite every restored variable and
                # silently discard the checkpoint.
                tf.global_variables_initializer().run()
                self.restore(sess, ckpt_path=CKPT_PATH)

                def build_feed(img_file, heights):
                    """Return the training feed_dict for ``img_file``.

                    The conv feature size is measured first because
                    getImageAndTargets takes it as an argument.
                    """
                    img_data, _ = model_detect_data.transform_image(img_file)
                    feat = sess.run(self.conv_feat,
                                    feed_dict={self.x: [img_data]})
                    feat_size = (feat.shape[0], feat.shape[1])
                    img_data, target_cls, target_ver, target_hor = \
                        model_detect_data.getImageAndTargets(
                            img_file, heights, feat_size)
                    return {
                        self.x: img_data,
                        self.t_cls: target_cls,
                        self.t_ver: target_ver,
                        self.t_hor: target_hor,
                        # NOTE(review): hard-coded sequence length of 64
                        self.sequence_length: np.ones([feat_size[0]]) * 64,
                    }

                # NOTE(review): training passes the bare `anchor_heights`
                # while the visual dump passes `meta.anchor_heights`. Both
                # are kept as in the original; confirm they are the same.
                total_loss = 0
                # Starts at 1, so TRAINING_STEPS - 1 iterations run. Kept as
                # in the original so that i == 0 never triggers the
                # log/save/dump branches below.
                for i in range(1, TRAINING_STEPS):

                    img_file = random.choice(list_path)
                    feed_dict = build_feed(img_file, anchor_heights)

                    _, loss_value, step, lr = sess.run([
                        self.train_op,
                        self.loss,
                        self.global_step,
                        self.learning_rate,
                    ], feed_dict)

                    total_loss += loss_value
                    if i % display == 0:
                        logging.info(
                            'iter: {:}, loss: {:.4f}, learning_rate: {:g}'.
                            format(step, total_loss / display, lr))
                        total_loss = 0

                    if i in save_iters:
                        self.save(sess=sess, step=i)

                    if i % show_every == 0:
                        out_dir = os.path.join('result', str(i))
                        # makedirs (unlike mkdir) also creates 'result/'
                        if not os.path.exists(out_dir):
                            os.makedirs(out_dir)
                        show_list = model_detect_data.getFilesInDirect(
                            'result/')
                        for count, show_path in enumerate(show_list):
                            show_feed = build_feed(show_path,
                                                   meta.anchor_heights)
                            r_cls, r_ver, r_hor = sess.run(
                                [self.rnn_cls, self.rnn_ver, self.rnn_hor],
                                show_feed)

                            file_target = os.path.join(
                                out_dir, str(count) + '.png')
                            img, _ = model_detect_data.transform_image(
                                show_path)
                            # NOTE(review): scipy.misc.imsave was removed in
                            # SciPy 1.2; this needs an older SciPy.
                            misc.imsave(file_target, img)
                            # trans
                            text_bbox = model_detect_data.transResults(
                                r_cls, r_ver, r_hor, anchor_heights,
                                meta.threshold)
                            model_detect_data.drawTextBox(
                                file_target, text_bbox)
 if ckpt and ckpt.model_checkpoint_path:
     saver.restore(sess, ckpt.model_checkpoint_path)
 #
 # start training
 start_time = time.time()
 begin_time = start_time
 #
 for i in range(TRAINING_STEPS):
     #
     img_file = random.choice(list_images_train)
     #
     #print(img_file)
     #
     # input data
     img_data, feat_size, target_cls, target_ver, target_hor = \
           model_detect_data.getImageAndTargets(img_file, anchor_heights)
     #
     img_size = model_detect_data.getImageSize(img_file)  # width, height
     #
     w_arr = np.ones((feat_size[0], ), dtype=np.int32) * img_size[0]
     #
     #
     feed_dict = {x: img_data, w: w_arr, \
                  t_cls: target_cls, t_ver: target_ver, t_hor: target_hor}
     #
     _, loss_value, step, lr = sess.run(
         [train_op, loss, global_step, learning_rate], feed_dict)
     #
     if i % 1 == 0:
         #
         saver.save(sess,