def _inference_crnn_ctc():
    """Run the test-phase CRNN-CTC graph on every image in ``FLAGS.image_list``.

    The latest checkpoint in ``FLAGS.model_dir`` is restored, each image is
    read from ``FLAGS.image_dir``, resized to ``_IMAGE_HEIGHT`` while keeping
    its aspect ratio, decoded with CTC beam search, and the first entry
    returned by ``_sparse_matrix_to_list`` is printed.

    Images that cannot be read are reported and skipped.

    Raises:
        ValueError: if no checkpoint is found in ``FLAGS.model_dir``.
    """
    input_image = tf.placeholder(dtype=tf.float32,
                                 shape=[1, _IMAGE_HEIGHT, None, 3])

    # Use a context manager so the char-map file handle is not leaked.
    with open(FLAGS.char_map_json_file, 'r') as map_file:
        char_map_dict = json.load(map_file)

    # initialise the net model
    # NOTE(review): the extra class is presumably the CTC blank -- confirm.
    crnn_net = model.CRNNCTCNetwork(phase='test',
                                    hidden_num=FLAGS.lstm_hidden_uints,
                                    layers_num=FLAGS.lstm_hidden_layers,
                                    num_classes=len(char_map_dict) + 1)

    with tf.variable_scope('CRNN_CTC', reuse=False):
        net_out = crnn_net.build_network(input_image)

    input_sequence_length = tf.placeholder(tf.int32,
                                           shape=[1],
                                           name='input_sequence_length')

    # merge_repeated=False: with True the decoder collapses consecutive
    # identical classes in its output, which drops genuinely doubled
    # characters (e.g. the "ll" in "hello").
    ctc_decoded, ct_log_prob = tf.nn.ctc_beam_search_decoder(
        net_out, input_sequence_length, merge_repeated=False)

    # Ignore blank lines so they do not turn into bogus image paths.
    with open(FLAGS.image_list, 'r') as fd:
        image_names = [line.strip() for line in fd if line.strip()]

    # set checkpoint saver
    saver = tf.train.Saver()
    save_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    if save_path is None:
        raise ValueError(
            'No checkpoint found in {:s}'.format(str(FLAGS.model_dir)))

    with tf.Session() as sess:
        # restore all variables
        saver.restore(sess=sess, save_path=save_path)

        for image_name in image_names:
            image_path = os.path.join(FLAGS.image_dir, image_name)
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals failure by returning None, not raising.
                print('Skip {:s}: cannot read image {:s}'.format(
                    image_name, image_path))
                continue

            # Resize to the fixed network height, keeping the aspect ratio.
            h, w = image.shape[:2]
            height = _IMAGE_HEIGHT
            width = int(w * height / h)
            image = cv2.resize(image, (width, height))
            image = np.expand_dims(image, axis=0).astype(np.float32)

            # NOTE(review): width // 4 presumably mirrors the network's
            # horizontal downsampling factor -- confirm against the model.
            seq_len = np.array([width // 4], dtype=np.int32)

            preds = sess.run(ctc_decoded,
                             feed_dict={
                                 input_image: image,
                                 input_sequence_length: seq_len
                             })

            preds = _sparse_matrix_to_list(preds[0])

            print('Predict {:s} image as: {:s}'.format(image_name, preds[0]))
def _eval_crnn_ctc():
    """Evaluate the latest checkpoint on ``validation.tfrecord``.

    Reads ``validation.tfrecord`` from ``FLAGS.data_dir``, batches it with
    ``FLAGS.batch_size`` and runs ``test_sample_count // FLAGS.batch_size``
    steps, so a trailing partial batch is not evaluated. For every sample
    the ground truth and the prediction are printed, followed by the mean
    per-sample accuracy.

    Per-sample accuracy is the fraction of ground-truth positions whose
    character equals the prediction at the same position; positions beyond
    the end of the prediction count as wrong. An empty ground truth scores
    1 when the prediction is empty as well and 0 otherwise.

    Raises:
        ValueError: if no checkpoint is found in ``FLAGS.model_dir``.
    """
    tfrecord_path = os.path.join(FLAGS.data_dir, 'validation.tfrecord')
    images, labels, sequence_lengths, imagenames = _read_tfrecord(
        tfrecord_path=tfrecord_path)

    # batch the validation data read from the tfrecord
    batch_images, batch_labels, batch_sequence_lengths, batch_imagenames = tf.train.batch(
        tensors=[images, labels, sequence_lengths, imagenames],
        batch_size=FLAGS.batch_size,
        dynamic_pad=True,
        capacity=1000 + 2 * FLAGS.batch_size,
        num_threads=FLAGS.num_threads)

    input_images = tf.placeholder(tf.float32,
                                  shape=[FLAGS.batch_size, 32, None, 3],
                                  name='input_images')
    input_labels = tf.sparse_placeholder(tf.int32, name='input_labels')
    input_sequence_lengths = tf.placeholder(dtype=tf.int32,
                                            shape=[FLAGS.batch_size],
                                            name='input_sequence_lengths')

    # Use a context manager so the char-map file handle is not leaked.
    with open(FLAGS.char_map_json_file, 'r') as map_file:
        char_map_dict = json.load(map_file)

    # initialise the net model
    # NOTE(review): the extra class is presumably the CTC blank -- confirm.
    crnn_net = model.CRNNCTCNetwork(phase='test',
                                    hidden_num=FLAGS.lstm_hidden_uints,
                                    layers_num=FLAGS.lstm_hidden_layers,
                                    num_classes=len(char_map_dict) + 1)

    with tf.variable_scope('CRNN_CTC', reuse=False):
        net_out = crnn_net.build_network(
            images=input_images, sequence_length=input_sequence_lengths)

    ctc_decoded, ct_log_prob = tf.nn.ctc_beam_search_decoder(
        net_out, input_sequence_lengths, merge_repeated=False)

    # set checkpoint saver
    saver = tf.train.Saver()
    save_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    if save_path is None:
        raise ValueError(
            'No checkpoint found in {:s}'.format(str(FLAGS.model_dir)))

    # Count the records once to know how many full batches are available.
    test_sample_count = sum(
        1 for _ in tf.python_io.tf_record_iterator(tfrecord_path))
    step_nums = test_sample_count // FLAGS.batch_size

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        # restore all variables
        saver.restore(sess=sess, save_path=save_path)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            accuracy = []

            for _ in range(step_nums):
                imgs, lbls, seq_lens, names = sess.run([
                    batch_images, batch_labels, batch_sequence_lengths,
                    batch_imagenames
                ])
                preds = sess.run(ctc_decoded,
                                 feed_dict={
                                     input_images: imgs,
                                     input_labels: lbls,
                                     input_sequence_lengths: seq_lens
                                 })

                preds = _sparse_matrix_to_list(preds[0], char_map_dict)
                lbls = _sparse_matrix_to_list(lbls, char_map_dict)

                for index, lbl in enumerate(lbls):
                    pred = preds[index]
                    if len(lbl) == 0:
                        # Nothing to match: perfect only if nothing was
                        # predicted either.
                        accuracy.append(1 if len(pred) == 0 else 0)
                        continue
                    # zip() stops at the shorter sequence, so ground-truth
                    # positions past the end of the prediction count as
                    # misses.
                    matched = sum(1 for expected, actual in zip(lbl, pred)
                                  if expected == actual)
                    # float() keeps this a true division on Python 2 too.
                    accuracy.append(float(matched) / len(lbl))

                for index, _img in enumerate(imgs):
                    print(
                        'Predict {:s} image with gt label: {:s} <--> predict label: {:s}'
                        .format(str(names[index]), str(lbls[index]),
                                str(preds[index])))

            if accuracy:
                accuracy = np.mean(np.array(accuracy).astype(np.float32),
                                   axis=0)
                print('Mean test accuracy is {:5f}'.format(accuracy))
            else:
                # Fewer samples than one batch: averaging would give NaN.
                print('No full batch was evaluated; '
                      'mean test accuracy is undefined')
        finally:
            # stop file queue, even when evaluation fails part-way through
            coord.request_stop()
            coord.join(threads=threads)
def _train_crnn_ctc():
    """Train the CRNN-CTC network on ``train.tfrecord``.

    Reads ``train.tfrecord`` from ``FLAGS.data_dir``, builds the train-phase
    graph (CTC loss, beam-search decoder, edit distance, Adadelta with a
    staircase exponentially decayed learning rate) and runs
    ``FLAGS.max_train_steps`` steps. Every ``FLAGS.step_per_save`` steps a
    summary and a checkpoint are written to ``FLAGS.model_dir``; every
    ``FLAGS.step_per_eval`` steps the loss, learning rate, edit distance and
    accuracy of the current training batch are printed.

    Per-sample accuracy is the fraction of ground-truth positions whose
    character equals the prediction at the same position; positions beyond
    the end of the prediction count as wrong. An empty ground truth scores
    1 when the prediction is empty as well and 0 otherwise.
    """
    tfrecord_path = os.path.join(FLAGS.data_dir, 'train.tfrecord')
    images, labels, sequence_lengths, _ = _read_tfrecord(
        tfrecord_path=tfrecord_path)

    # batch the training data read from the tfrecord
    batch_images, batch_labels, batch_sequence_lengths = tf.train.batch(
        tensors=[images, labels, sequence_lengths],
        batch_size=FLAGS.batch_size,
        dynamic_pad=True,
        capacity=1000 + 2 * FLAGS.batch_size,
        num_threads=FLAGS.num_threads)

    input_images = tf.placeholder(tf.float32,
                                  shape=[FLAGS.batch_size, 32, None, 3],
                                  name='input_images')
    input_labels = tf.sparse_placeholder(tf.int32, name='input_labels')
    input_sequence_lengths = tf.placeholder(dtype=tf.int32,
                                            shape=[FLAGS.batch_size],
                                            name='input_sequence_lengths')

    # Use a context manager so the char-map file handle is not leaked.
    with open(FLAGS.char_map_json_file, 'r') as map_file:
        char_map_dict = json.load(map_file)

    # initialise the net model
    # NOTE(review): the extra class is presumably the CTC blank -- confirm.
    crnn_net = model.CRNNCTCNetwork(phase='train',
                                    hidden_num=FLAGS.lstm_hidden_uints,
                                    layers_num=FLAGS.lstm_hidden_layers,
                                    num_classes=len(char_map_dict) + 1)

    with tf.variable_scope('CRNN_CTC', reuse=False):
        net_out = crnn_net.build_network(
            images=input_images, sequence_length=input_sequence_lengths)

    ctc_loss = tf.reduce_mean(
        tf.nn.ctc_loss(labels=input_labels,
                       inputs=net_out,
                       sequence_length=input_sequence_lengths,
                       ignore_longer_outputs_than_inputs=True))

    ctc_decoded, ct_log_prob = tf.nn.ctc_beam_search_decoder(
        net_out, input_sequence_lengths, merge_repeated=False)

    sequence_distance = tf.reduce_mean(
        tf.edit_distance(tf.cast(ctc_decoded[0], tf.int32), input_labels))

    global_step = tf.train.create_global_step()

    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                               global_step,
                                               FLAGS.decay_steps,
                                               FLAGS.decay_rate,
                                               staircase=True)

    # Run the ops collected under UPDATE_OPS before each optimisation step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # minimize() returns the training op, not the optimizer itself.
        train_op = tf.train.AdadeltaOptimizer(
            learning_rate=learning_rate).minimize(loss=ctc_loss,
                                                  global_step=global_step)

    init_op = tf.global_variables_initializer()

    # set tf summary
    tf.summary.scalar(name='CTC_Loss', tensor=ctc_loss)
    tf.summary.scalar(name='Learning_Rate', tensor=learning_rate)
    tf.summary.scalar(name='Seqence_Distance', tensor=sequence_distance)
    merge_summary_op = tf.summary.merge_all()

    # set checkpoint saver
    saver = tf.train.Saver()
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)
    # Checkpoints are named after the wall-clock time training started.
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
    model_name = 'crnn_ctc_ocr_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = os.path.join(FLAGS.model_dir, model_name)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
        try:
            summary_writer.add_graph(sess.graph)

            # init all variables
            sess.run(init_op)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                for step in range(FLAGS.max_train_steps):
                    imgs, lbls, seq_lens = sess.run(
                        [batch_images, batch_labels, batch_sequence_lengths])

                    _, cl, lr, sd, preds, summary = sess.run(
                        [
                            train_op, ctc_loss, learning_rate,
                            sequence_distance, ctc_decoded, merge_summary_op
                        ],
                        feed_dict={
                            input_images: imgs,
                            input_labels: lbls,
                            input_sequence_lengths: seq_lens
                        })

                    if (step + 1) % FLAGS.step_per_save == 0:
                        summary_writer.add_summary(summary=summary,
                                                   global_step=step)
                        saver.save(sess=sess,
                                   save_path=model_save_path,
                                   global_step=step)

                    if (step + 1) % FLAGS.step_per_eval == 0:
                        # calculate the precision on the current batch
                        preds = _sparse_matrix_to_list(preds[0])
                        gt_labels = _sparse_matrix_to_list(lbls)

                        accuracy = []

                        for index, gt_label in enumerate(gt_labels):
                            pred = preds[index]
                            if len(gt_label) == 0:
                                # Nothing to match: perfect only if nothing
                                # was predicted either.
                                accuracy.append(1 if len(pred) == 0 else 0)
                                continue
                            # zip() stops at the shorter sequence, so
                            # ground-truth positions past the end of the
                            # prediction count as misses.
                            matched = sum(
                                1 for expected, actual in zip(gt_label, pred)
                                if expected == actual)
                            # float() keeps this a true division on
                            # Python 2 too.
                            accuracy.append(float(matched) / len(gt_label))

                        accuracy = np.mean(
                            np.array(accuracy).astype(np.float32), axis=0)

                        print(
                            'step:{:d} learning_rate={:9f} ctc_loss={:9f} sequence_distance={:9f} train_accuracy={:9f}'
                            .format(step + 1, lr, cl, sd, accuracy))
            finally:
                # stop file queue, even when training fails part-way through
                coord.request_stop()
                coord.join(threads=threads)
        finally:
            # close tensorboard writer so pending events are flushed
            summary_writer.close()
Exemple #4
0
def _inference_crnn_ctc():
    """Run inference on the listed images and report positional accuracy.

    Restores the latest checkpoint found under ``model_path``, decodes every
    image named in ``image_list_path`` (read from ``create_image_path``) and
    prints the prediction for each one. The ground truth is taken from the
    file name, which must look like ``<digits>_<label>.jpg``. Finally the
    mean per-image accuracy is printed.

    Per-image accuracy is the fraction of ground-truth positions whose
    character equals the prediction at the same position; positions beyond
    the end of the prediction count as wrong. An empty ground truth scores
    1 when the prediction is empty as well and 0 otherwise.

    Unreadable images and file names that do not match the expected pattern
    are reported and skipped.

    Raises:
        ValueError: if no checkpoint is found under ``model_path``.
    """
    input_image = tf.placeholder(dtype=tf.float32, shape=[1, _IMAGE_HEIGHT, None, 3])

    # Use a context manager so the char-map file handle is not leaked.
    with open(json_word_dict_file_path, 'r') as map_file:
        char_map_dict = json.load(map_file)

    # initialise the net model
    # NOTE(review): the extra class is presumably the CTC blank -- confirm.
    crnn_net = model.CRNNCTCNetwork(phase='test',
                                    hidden_num=train_lstm_hidden_uints,
                                    layers_num=train_lstm_hidden_layers,
                                    num_classes=len(char_map_dict) + 1)

    with tf.variable_scope('CRNN_CTC', reuse=False):
        net_out = crnn_net.build_network(input_image)

    input_sequence_length = tf.placeholder(tf.int32, shape=[1], name='input_sequence_length')

    # merge_repeated=False keeps genuinely doubled characters in the output.
    ctc_decoded, ct_log_prob = tf.nn.ctc_beam_search_decoder(net_out, input_sequence_length, beam_width=100,
                                                             top_paths=1, merge_repeated=False)

    # Ignore blank lines so they do not turn into bogus image paths.
    with open(image_list_path, 'r') as fd:
        image_names = [line.strip() for line in fd if line.strip()]

    # set checkpoint saver
    saver = tf.train.Saver()
    save_path = tf.train.latest_checkpoint(model_path)
    if save_path is None:
        raise ValueError('No checkpoint found in {:s}'.format(str(model_path)))

    # File names carry the ground truth: "<digits>_<label>.jpg".
    name_pattern = re.compile(r'(\d+_)(.*)(\.jpg)')

    with tf.Session() as sess:
        # restore all variables
        saver.restore(sess=sess, save_path=save_path)

        accuracy = []

        for image_name in image_names:
            image_path = os.path.join(create_image_path, image_name)
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals failure by returning None, not raising.
                print('Skip {:s}: cannot read image {:s}'.format(image_name, image_path))
                continue

            # Resize to the fixed network height, keeping the aspect ratio.
            h, w = image.shape[:2]
            height = _IMAGE_HEIGHT
            width = int(w * height / h)
            image = cv2.resize(image, (width, height))
            image = np.expand_dims(image, axis=0).astype(np.float32)

            # NOTE(review): width // 4 presumably mirrors the network's
            # horizontal downsampling factor -- confirm against the model.
            seq_len = np.array([width // 4], dtype=np.int32)

            preds = sess.run(ctc_decoded, feed_dict={input_image: image, input_sequence_length: seq_len})

            preds = _sparse_matrix_to_list(preds[0], char_map_dict)
            pred = preds[0]

            print('Predict {:s} image as: {:s}'.format(image_name, pred))

            name_match = name_pattern.match(image_name)
            if name_match is None:
                # Without a parsable label there is nothing to score against.
                print('Skip accuracy for {:s}: no ground truth in file name'.format(image_name))
                continue
            gt_label = name_match.group(2)

            if len(gt_label) == 0:
                # The original handler indexed pred with a loop variable that
                # is unbound (or stale) when the label is empty; compare the
                # whole prediction's length instead.
                accuracy.append(1 if len(pred) == 0 else 0)
                continue

            # zip() stops at the shorter sequence, so ground-truth positions
            # past the end of the prediction count as misses.
            matched = sum(1 for expected, actual in zip(gt_label, pred) if expected == actual)
            # float() keeps this a true division on Python 2 too.
            accuracy.append(float(matched) / len(gt_label))

        if accuracy:
            accuracy = np.mean(np.array(accuracy).astype(np.float32), axis=0)

            print(
                'test_accuracy={:9f}'.format(accuracy))
        else:
            # Nothing was scored: averaging an empty list would give NaN.
            print('No image was scored; test accuracy is undefined')