예제 #1
0
#
# Example 1 driver: build the recognizer, train it, then run prediction on
# every validation image.
#
model = ModelRecog()

# -- load the training set, then the validation set --
print('loading data ...')
data_train, data_valid = [
    model_data.load_data(data_dir)
    for data_dir in (meta.dir_data_train, meta.dir_data_valid)
]
print('load finished.')

# -- train (the method name indicates validation happens as part of it) --
model.train_and_valid(data_train, data_valid)

# -- prediction: load the prediction graph and open a session for it --
# NOTE(review): sess is never closed in this script.
model.load_pb_for_prediction()
sess = model.create_session_for_prediction()

list_images_valid = model_data.getFilesInDirect(
    meta.dir_images_valid, meta.str_dot_img_ext)

for img_file in list_images_valid:
    # echo the path for progress, then predict with results aimed at
    # ./results_prediction
    print(img_file)
    model.predict(sess, img_file, out_dir='./results_prediction')
예제 #2
0
# Example 2 (test script): build the recognition graph, collect test images,
# reset the results directory, then open a session and look up a checkpoint.
#
# input-output, graph
# x: float32 images, rank 4 with 3 channels in the last axis; the axis order
#    is presumably (batch, height, width, channels) — TODO confirm.
x = tf.placeholder(tf.float32, (None, None, None, 3), name='x-input')
# yT: sparse int32 targets, consumed by the CTC loss below.
yT = tf.sparse_placeholder(tf.int32, shape=(None, None), name='y-input')
# w: one int32 value per batch element; presumably image widths — confirm.
w = tf.placeholder(tf.int32, (None, ), name='w-input')
#
# NOTE(review): the conv layers are built in TRAIN mode although this script
# processes test images; the trailing comment suggests INFER was considered.
# Confirm which mode is intended.
features, sequence_length = model_recog.conv_feat_layers(
    x, w, learn.ModeKeys.TRAIN)  # alternative considered: INFER
result_logits = model_recog.rnn_recog_layers(features, sequence_length,
                                             num_classes)
#
loss = model_recog.ctc_loss_layer(yT, result_logits, sequence_length)
#
print('graph loaded')
#
# get test images (presumably files in dir_images ending in str_dot_img_ext)
list_images = model_recog_data.getFilesInDirect(dir_images, str_dot_img_ext)
#
# test_result save-path: delete any previous results, then recreate the dir.
if os.path.exists(dir_results): shutil.rmtree(dir_results)
# short pause, presumably so the deletion settles before mkdir on
# filesystems where removal is not immediately visible — confirm.
time.sleep(0.1)
os.mkdir(dir_results)
#
# to process
saver = tf.train.Saver()
with tf.Session() as sess:
    #
    # initialize every variable first; restored checkpoint values would
    # overwrite these.
    tf.global_variables_initializer().run()
    #
    # restore with saved data: look up the checkpoint state in model_dir.
    # NOTE(review): no saver.restore(...) follows in the visible lines; this
    # example appears truncated here.
    ckpt = tf.train.get_checkpoint_state(model_dir)
    #
    def train_and_valid(self,
                        data_train,
                        data_valid,
                        load_from_pretrained=True):
        """Train the recognition model, checkpointing and validating periodically.

        Builds the training graph into ``self.z_graph``, optionally loads
        shared ("common") tensors from the pretrained detection model,
        restores the latest recognition checkpoint if one exists, and then
        runs ``TRAINING_STEPS`` iterations of random mini-batch training.

        Args:
            data_train: indexable as ``data_train['x']`` (images fed to
                'x-input') and ``data_train['y']`` (targets converted to a
                sparse tensor value); both are indexed by sample position.
                Presumably a dict of parallel lists — confirm against the
                data loader.
            data_valid: passed unchanged to ``self.validate`` each time a
                checkpoint is saved.
            load_from_pretrained: if True and no '.meta' files are found in
                ``meta.model_recog_dir``, call
                ``self.z_load_from_pretrained_detection_model()``. Has no
                effect when such files already exist.

        NOTE(review): in the lines shown, checkpoints are written only at the
        top of an iteration whose ``step`` is a multiple of
        ``self.z_valid_freq``; there is no save after the loop, so the final
        steps may go unsaved — confirm.
        """
        #
        # model save-path (os.mkdir creates only the last path component)
        if not os.path.exists(meta.model_recog_dir):
            os.mkdir(meta.model_recog_dir)
        #
        # training graph (the second argument presumably selects the
        # training variant — confirm in z_define_graph_all)
        self.z_graph = tf.Graph()
        self.z_define_graph_all(self.z_graph, True)
        #
        # load from pretrained: an existing '.meta' file is taken to mean a
        # recognition checkpoint exists, which is restored below instead.
        list_ckpt = model_recog_data.getFilesInDirect(meta.model_recog_dir,
                                                      '.meta')
        #
        print(' ')
        #
        if len(list_ckpt) > 0:
            print(
                'model_recog ckpt already exists, no need to load common tensors.'
            )
        elif load_from_pretrained == False:  # NOTE(review): prefer `not load_from_pretrained`
            print('not to load common tensors, by manual setting.')
        else:
            print('load common tensors from pretrained detection model.')
            self.z_load_from_pretrained_detection_model()
        print(' ')
        #
        # restore and train
        with self.z_graph.as_default():
            #
            saver = tf.train.Saver()
            with tf.Session(config=self.z_sess_config) as sess:
                #
                # Initialize everything, then overwrite with checkpoint values
                # when one is found. NOTE(review): the pretrained load above
                # ran before this session existed; presumably it persists its
                # result as a checkpoint in meta.model_recog_dir, otherwise
                # this initializer would discard it — confirm.
                tf.global_variables_initializer().run()
                #
                # restore with saved data
                ckpt = tf.train.get_checkpoint_state(meta.model_recog_dir)
                #
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                #
                # variables: look up feed/fetch tensors by name
                #
                x = self.z_graph.get_tensor_by_name('x-input:0')
                w = self.z_graph.get_tensor_by_name('w-input:0')
                #
                # The sparse placeholder 'y-input' is fed through its three
                # component placeholders (shape, indices, values):
                y_s = self.z_graph.get_tensor_by_name('y-input/shape:0')
                y_i = self.z_graph.get_tensor_by_name('y-input/indices:0')
                y_v = self.z_graph.get_tensor_by_name('y-input/values:0')
                #
                # <tf.Operation 'y-input/shape' type=Placeholder>,
                # <tf.Operation 'y-input/values' type=Placeholder>,
                # <tf.Operation 'y-input/indices' type=Placeholder>]
                #
                # (disabled debug probe)
                #conv_feat = self.z_graph.get_tensor_by_name('conv_comm/conv_feat:0')
                #
                loss = self.z_graph.get_tensor_by_name('loss:0')
                #
                global_step = self.z_graph.get_tensor_by_name('global_step:0')
                learning_rate = self.z_graph.get_tensor_by_name(
                    'learning_rate:0')
                # Fetching this tensor is what runs the training update,
                # presumably via a control dependency on the optimizer op —
                # confirm in z_define_graph_all.
                train_op = self.z_graph.get_tensor_by_name(
                    'train_op/control_dependency:0')
                #
                print('begin to train ...')
                #
                num_samples = len(data_train['x'])
                #
                # start training
                # start_time: wall-clock start of the whole loop;
                # begin_time: start of the current logging section.
                start_time = time.time()
                begin_time = start_time
                # `step` holds the global_step value returned by the previous
                # sess.run; it starts at 0 even when a checkpoint was restored.
                step = 0
                #
                # NOTE(review): runs TRAINING_STEPS more iterations regardless
                # of the restored global step (curr_iter is unused).
                for curr_iter in range(TRAINING_STEPS):
                    #
                    # save and validate
                    # Fires on the first iteration (step == 0) before any
                    # training. NOTE(review): after a restore this first save
                    # is labeled global_step=0 because `step` is reset locally.
                    if step % self.z_valid_freq == 0:
                        #
                        # ckpt
                        print('save model to ckpt ...')
                        saver.save(sess, os.path.join(meta.model_recog_dir, meta.model_recog_name), \
                                   global_step = step)
                        #
                        # validate
                        print('validating ...')
                        self.validate(data_valid, step)
                        #
                    #
                    # train: sample a batch of distinct sample indices.
                    # random.sample raises ValueError if z_batch_size exceeds
                    # num_samples.
                    index_batch = random.sample(range(num_samples),
                                                self.z_batch_size)
                    #
                    images = [data_train['x'][i] for i in index_batch]
                    targets = [data_train['y'][i] for i in index_batch]
                    # Every sample in the batch gets the same value,
                    # meta.width_norm. NOTE(review): this array is float32;
                    # the dtype of 'w-input' is set in z_define_graph_all (not
                    # shown) — confirm it accepts this feed.
                    w_arr = np.ones((self.z_batch_size, ),
                                    dtype=np.float32) * meta.width_norm
                    #
                    # targets_sparse_value: targets packed for the sparse
                    # 'y-input' placeholder
                    tsv = model_def.convert2SparseTensorValue(targets)
                    #
                    # feed the sparse target through its shape/indices/values
                    # component placeholders
                    feed_dict = {
                        x: images,
                        w: w_arr,
                        y_s: tsv.dense_shape,
                        y_i: tsv.indices,
                        y_v: tsv.values
                    }
                    #

                    #print(width)
                    #conv_v = sess.run(conv_feat, feed_dict)
                    #print(len(conv_v))

                    # sess.run: one training step. `step` is rebound to the
                    # global_step value fetched alongside train_op (whether
                    # that is pre- or post-increment depends on the graph).
                    _, loss_value, step, lr = sess.run([train_op, loss, global_step, learning_rate], \
                                                        feed_dict)
                    #
                    # NOTE(review): `step % 1 == 0` is always true, so this
                    # logs on every step.
                    if step % 1 == 0:
                        #
                        curr_time = time.time()
                        #
                        # sect_time: seconds since the previous log line;
                        # total_time: seconds since training started.
                        print(
                            'step: %d, loss: %g, lr: %g, sect_time: %.1f, total_time: %.1f'
                            % (step, loss_value, lr, curr_time - begin_time,
                               curr_time - start_time))
                        #
                        begin_time = curr_time