def image_cnn(image):
    """Predict the text contained in a single image using the latest checkpoint.

    Builds the CNN graph, restores the checkpoint recorded in ``logDIR`` and
    decodes the network output into a string.

    Args:
        image: one input sample; it is fed as ``[image]`` into a placeholder of
            shape ``[None, Model.image_h * Model.image_w * 3]``, so presumably a
            flattened image of that length (TODO confirm against caller).

    Returns:
        The predicted text: each predicted class index is mapped through
        ``Model.char_set`` and the characters are concatenated.

    Raises:
        FileNotFoundError: if ``logDIR`` contains no usable checkpoint.
    """
    # NOTE(review): these placeholders are local to this function and are the
    # tensors fed below; confirm that CNN.cnn_model() actually builds on these
    # same tensors rather than on separately defined placeholders.
    train_x = tf.placeholder(tf.float32, [None, Model.image_h * Model.image_w * 3])
    keep_prob = tf.placeholder(tf.float32)
    model = CNN()
    output = model.cnn_model()
    # Reshape the flat output to (batch, max_length, len_char) and pick the most
    # likely character class for every position.
    predict = tf.argmax(tf.reshape(output, [-1, Model.max_length, Model.len_char]), 2)
    # NOTE(review): every call adds a fresh model to the default graph; calling
    # this repeatedly in one process may grow the graph / mismatch variable names
    # on restore -- verify if it is used in a loop.
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(logDIR)
    # get_checkpoint_state returns None when there is no checkpoint; fail with a
    # clear message instead of an AttributeError on ``None``.
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError('no checkpoint found in %r' % (logDIR,))
    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)
        # keep_prob=1 disables dropout at inference time.
        text_list = sess.run(predict, feed_dict={train_x: [image], keep_prob: 1})
    return ''.join(Model.char_set[char] for char in text_list[0].tolist())
def train_cnn(image_txt, label_txt, image_file, batch_size=32,
              eval_batch_size=64, train_keep_prob=0.75, max_steps=None):
    """Train the CNN and write checkpoints to ``logDIR``.

    Training runs until accuracy on a sampled evaluation batch has been exactly
    1 on ten evaluations, or until ``max_steps`` steps have run (if given).

    Args:
        image_txt, label_txt, image_file: passed straight to ``get_data`` to load
            the images and labels.
        batch_size: number of samples per training step (default 32).
        eval_batch_size: number of samples per accuracy evaluation (default 64).
        train_keep_prob: value fed to ``keep_prob`` during training (default 0.75).
        max_steps: optional upper bound on training steps. ``None`` (the default)
            keeps the original behaviour of looping until the early-stop
            condition is met.

    Side effects:
        Prints the loss every step and the accuracy every 100 steps. Saves a
        checkpoint to ``logDIR/model.ckpt-<step>`` every 1000 steps and once
        more when training stops.
    """
    train_x = tf.placeholder(tf.float32,
                             [None, Model.image_h * Model.image_w * 3])
    train_y = tf.placeholder(tf.float32,
                             [None, Model.max_length * Model.len_char])
    image, label = get_data(image_txt, label_txt, image_file)
    keep_prob = tf.placeholder(tf.float32)
    model = CNN()
    output = model.cnn_model()
    optimizer, loss = model.model_loss(output)
    accuracy = model.model_acc(output)
    saver = tf.train.Saver()
    # The checkpoint prefix never changes, so build it once.
    checkpoint_path = os.path.join(logDIR, 'model.ckpt')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        step = 0
        # NOTE(review): this counts *all* perfect evaluations so far and is never
        # reset, so they need not be consecutive -- confirm that is intended.
        count = 0
        while max_steps is None or step < max_steps:
            batch_x, batch_y = get_next_batch(batch_size, image, label)
            _, loss_ = sess.run([optimizer, loss],
                                feed_dict={
                                    train_x: batch_x,
                                    train_y: batch_y,
                                    keep_prob: train_keep_prob
                                })
            print(step, loss_)
            # Compute accuracy every 100 steps.
            # NOTE(review): the evaluation batch comes from the same image/label
            # data used for training, so this is training accuracy, not held-out.
            if step % 100 == 0:
                batch_x_test, batch_y_test = get_next_batch(eval_batch_size, image, label)
                acc = sess.run(accuracy,
                               feed_dict={
                                   train_x: batch_x_test,
                                   train_y: batch_y_test,
                                   keep_prob: 1.
                               })
                print(step, acc)
                if acc == 1:
                    count = count + 1
                if count == 10:
                    # Early stop: save once (previously saved twice when this
                    # coincided with a 1000-step boundary).
                    saver.save(sess, checkpoint_path, global_step=step)
                    break
                if step % 1000 == 0:
                    saver.save(sess, checkpoint_path, global_step=step)
            step += 1
        else:
            # Runs only when max_steps was exhausted without an early stop, so
            # the final weights are not lost.
            saver.save(sess, checkpoint_path, global_step=step)