def test(ckpt_path, img_path):
    """Run a restored CRAFT_net on one image and save its two score maps.

    The image at ``img_path`` is resized to 512x512, normalised with
    ``normalizeMeanVariance`` and fed through the network whose weights are
    restored from ``ckpt_path``.

    Outputs:
        - The two output channels are first passed to ``get_result_img`` at
          256x256.
        - They are then upsampled to 512x512 and written to
          ./result/weight.jpg (channel 0) and ./result/weight_aff.jpg
          (channel 1).

    Args:
        ckpt_path: checkpoint prefix handed to ``tf.train.Saver.restore``.
        img_path: path of the image to run inference on.
    """
    image_ph = tf.placeholder(shape=[None, 512, 512, 3], dtype=tf.float32)
    heatmaps, _ = CRAFT_net(image_ph)

    original = cv2.resize(Image.imread(img_path), (512, 512))
    batch = np.reshape(normalizeMeanVariance(original), (1, 512, 512, 3))

    saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        print('------loading weight------')
        saver.restore(sess, ckpt_path)
        print('------complete------')

        # Output is reshaped to (256, 256, 2), half the 512 input size.
        # The reshape assumes a batch of exactly one image.
        scores = np.reshape(
            sess.run(heatmaps, feed_dict={image_ph: batch}), (256, 256, 2))
        get_result_img(original, scores[:, :, 0], scores[:, :, 1])

        upscaled = cv2.resize(scores, (512, 512))
        plt.imsave('./result/weight.jpg', upscaled[:, :, 0])
        plt.imsave('./result/weight_aff.jpg', upscaled[:, :, 1])
def train(train=True):
    """Train CRAFT_net with Adam and a piecewise-constant learning rate.

    Setup:
        - Builds the graph and initialises all variables.
        - With ``train=True``, restores only the VGG-16 convolutional layers
          from ./model/vgg16.ckpt.
        - With ``train=False``, restores every variable from
          ./demo/CRAFT_15000.ckpt.

    Training loop:
        - Consumes batches from ``generator``.
        - Every 100 steps, prints the average loss and time per step, runs
          the network on ./te.jpg, and writes the two output channels to
          ./result/.
        - Every 2000 steps, and once more after the loop if the last step
          was not already saved, writes a checkpoint under ./demo/.

    Args:
        train: True for a fresh start from the VGG-16 backbone, False to
            resume from the CRAFT checkpoint.
    """
    x = tf.placeholder(shape=[None, 512, 512, 3], dtype=tf.float32, name='x')
    y = tf.placeholder(shape=[None, 256, 256, 2], dtype=tf.float32, name='y')
    y_pre, end_point = CRAFT_net(x)
    loss = MSE_OHEM_Loss(y_pre, y)
    end_point['loss'] = loss

    # Fixed preview image, evaluated every `log_every` steps.
    # np.reshape requires ./te.jpg to already be exactly 512x512x3.
    preview_img = Image.imread('./te.jpg')
    preview_batch = normalizeMeanVariance(
        np.reshape(preview_img, (1, 512, 512, 3)))

    # Only the VGG-16 conv layers are restored; fc6-fc8 are deliberately
    # left out. The previous list lacked a comma after conv2_2. Python then
    # glued it to conv3_1 as one string, so neither layer was restored.
    include = [
        'vgg_16/conv1/conv1_1', 'vgg_16/conv1/conv1_2',
        'vgg_16/conv2/conv2_1', 'vgg_16/conv2/conv2_2',
        'vgg_16/conv3/conv3_1', 'vgg_16/conv3/conv3_2', 'vgg_16/conv3/conv3_3',
        'vgg_16/conv4/conv4_1', 'vgg_16/conv4/conv4_2', 'vgg_16/conv4/conv4_3',
        'vgg_16/conv5/conv5_1', 'vgg_16/conv5/conv5_2', 'vgg_16/conv5/conv5_3',
    ]
    variables_to_restore = slim.get_variables_to_restore(include=include)

    # Step counter only; not a weight, so keep it out of the trainable set.
    # The auto-generated variable name is unchanged, so old checkpoints load.
    global_step = tf.Variable(0, trainable=False)
    boundaries = [15000, 25000]
    lr_values = [0.001, 0.0001, 0.00001]
    learning_rate = tf.train.piecewise_constant(
        global_step, boundaries=boundaries, values=lr_values)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_step = optimizer.minimize(loss, global_step=global_step)

    # Fresh start: restore only the backbone. Resume: restore everything.
    if train:
        restorer = tf.train.Saver(variables_to_restore)
    else:
        restorer = tf.train.Saver()
    saver = tf.train.Saver()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.98
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if train:
            print('-----load vgg-----')
            restorer.restore(sess, './model/vgg16.ckpt')
            print('-----load vgg complete-----')
            print('-----training-----')
        else:
            print('-----load ckpt-----')
            restorer.restore(sess, './demo/CRAFT_15000.ckpt')
            print('-----load ckpt complete')
            print('-----training------')

        batch_size = 3
        epoch = 1
        data_len = 64735  # 858750
        log_every = 100
        save_every = 2000

        loss_t = 0.0
        window_steps = 0  # steps accumulated into loss_t since the last log
        for _ in range(epoch):
            gen = generator(shuffle=True, batch_size=batch_size)
            chkpnt = time.time()
            start = time.time()
            for _ in range(data_len // batch_size):
                image, label = next(gen)
                _, loss_f0, learning_rate0, global_step0 = sess.run(
                    [train_step, loss, learning_rate, global_step],
                    feed_dict={x: image, y: label})
                step_time = time.time() - start
                start = time.time()
                print(
                    '\rstep: %2d learning_rate: %4g total_loss: %4g avg_time: %2g'
                    % (global_step0, learning_rate0, loss_f0, step_time),
                    end='')
                loss_t += loss_f0
                window_steps += 1

                if global_step0 % log_every == 0:
                    # Average over the steps actually seen, not a fixed 100;
                    # the first window after a restore may be shorter.
                    avg_loss = loss_t / window_steps
                    res = sess.run(y_pre, feed_dict={x: preview_batch})
                    # Pass the un-batched (H, W, 3) image, as test() does.
                    get_result_img(preview_img, res[0, :, :, 0], res[0, :, :, 1])
                    plt.imsave('./result/result_c.jpg',
                               cv2.resize(res[0, :, :, 0], (512, 512)))
                    plt.imsave('./result/result_a.jpg',
                               cv2.resize(res[0, :, :, 1], (512, 512)))
                    avg_time = (time.time() - chkpnt) / window_steps
                    chkpnt = time.time()
                    print(
                        '\nstep: %2d learning_rate: %4g avg_total_loss: %4g avg_time: %2g'
                        % (global_step0, learning_rate0, avg_loss, avg_time))
                    loss_t = 0.0
                    window_steps = 0

                if global_step0 % save_every == 0:
                    saver.save(sess, "./demo/CRAFT_%d.ckpt" % (global_step0))

        # Keep the work done after the last periodic checkpoint.
        # Otherwise it would be lost whenever the step count is not a
        # multiple of save_every.
        final_step = sess.run(global_step)
        if final_step % save_every != 0:
            saver.save(sess, "./demo/CRAFT_%d.ckpt" % (final_step))