def image_cnn(image):
    """Decode the text of a single image using the latest saved checkpoint.

    Builds the CNN graph, restores the most recent checkpoint found in
    ``logDIR``, and maps the arg-max class of every character position to
    the corresponding entry of ``Model.char_set``.

    Args:
        image: One input sample. It is fed as a single-element batch into a
            float32 placeholder declared with shape
            ``[None, Model.image_h * Model.image_w * 3]``.

    Returns:
        The predicted characters joined into one string.

    Raises:
        IOError: If no checkpoint can be found in ``logDIR``.
    """
    # NOTE(review): these placeholders are created here but never handed to
    # CNN(); presumably the model locates its inputs some other way --
    # confirm they are actually wired into the graph that produces `output`.
    train_x = tf.placeholder(tf.float32, [None, Model.image_h * Model.image_w * 3])
    keep_prob = tf.placeholder(tf.float32)

    model = CNN()
    output = model.cnn_model()

    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(logDIR)
    # get_checkpoint_state returns None when the directory holds no valid
    # checkpoint; fail with a clear message rather than an AttributeError.
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise IOError('No checkpoint found in %r' % (logDIR,))

    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)
        # View the flat output as (batch, position, class) and take the
        # highest-scoring class index at each character position.
        per_char = tf.reshape(output, [-1, Model.max_length, Model.len_char])
        predict = tf.argmax(per_char, 2)
        # keep_prob is fed as 1 for inference.
        text_list = sess.run(predict, feed_dict={train_x: [image], keep_prob: 1})

    # Only one image was fed, so the first row is the whole prediction.
    indices = text_list[0].tolist()
    return ''.join(Model.char_set[idx] for idx in indices)
def train_cnn(image_txt, label_txt, image_file):
    """Train the CNN until it has scored perfect accuracy on ten evaluations.

    Loads the data set with ``get_data``, then runs the optimizer on
    batches of 32 samples indefinitely. Every 100 steps the accuracy is
    measured on a batch of 64 samples drawn from the same data; every 1000
    steps a checkpoint is written to ``logDIR``. Training stops (after a
    final checkpoint) once ten evaluations in total -- not necessarily
    consecutive -- have reported an accuracy equal to 1.

    Args:
        image_txt: Forwarded unchanged to ``get_data``.
        label_txt: Forwarded unchanged to ``get_data``.
        image_file: Forwarded unchanged to ``get_data``.

    Returns:
        None. Progress is printed and checkpoints are written as side
        effects.
    """
    # NOTE(review): these placeholders are created here but never handed to
    # CNN(); presumably the model locates its inputs some other way --
    # confirm they are actually wired into the loss/accuracy ops.
    train_x = tf.placeholder(tf.float32, [None, Model.image_h * Model.image_w * 3])
    train_y = tf.placeholder(tf.float32, [None, Model.max_length * Model.len_char])
    image, label = get_data(image_txt, label_txt, image_file)
    keep_prob = tf.placeholder(tf.float32)

    model = CNN()
    output = model.cnn_model()
    optimizer, loss = model.model_loss(output)
    accuracy = model.model_acc(output)

    saver = tf.train.Saver()
    # The checkpoint prefix never changes, so compute it once up front.
    checkpoint_path = os.path.join(logDIR, 'model.ckpt')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        step = 0
        count = 0  # number of evaluations that reported accuracy == 1
        while True:
            batch_x, batch_y = get_next_batch(32, image, label)
            _, loss_ = sess.run(
                [optimizer, loss],
                feed_dict={train_x: batch_x, train_y: batch_y, keep_prob: 0.75},
            )
            print(step, loss_)

            # Evaluate accuracy once every 100 steps.
            if step % 100 == 0:
                batch_x_test, batch_y_test = get_next_batch(64, image, label)
                acc = sess.run(
                    accuracy,
                    feed_dict={
                        train_x: batch_x_test,
                        train_y: batch_y_test,
                        keep_prob: 1.,
                    },
                )
                print(step, acc)

                finished = False
                if acc == 1:
                    count += 1
                    finished = count == 10

                # Write a periodic checkpoint every 1000 steps and a final
                # one when training finishes; a single save covers both when
                # they coincide, avoiding a redundant second write.
                if finished or step % 1000 == 0:
                    saver.save(sess, checkpoint_path, global_step=step)
                if finished:
                    break

            step += 1