Ejemplo n.º 1
0
    def eval_model(self):
        """Evaluate a restored OCR model on ``self.split_name`` and print accuracy.

        Builds an ``LSTMOCR`` graph in ``'eval'`` mode and restores the latest
        checkpoint from ``self.checkpoint_path``, or the file it names. It then
        makes one pass over the split in batches of ``FLAGS.batch_size``. Along
        the way it prints each ``ground_truth  :  prediction`` pair and a
        per-batch timing line. At the end it prints the sequence-level accuracy:
        a sample counts as correct only when the whole decoded string equals the
        ground-truth string.

        Raises:
            ValueError: if no checkpoint can be located at ``self.checkpoint_path``.

        Returns:
            None. Results are only printed.
        """

        def _decode(encoded_rows):
            # Turn each row of label indices into a string via utils.decode_maps.
            # Entries equal to -1 are mapped to '' (presumably padding — confirm).
            return [''.join([utils.decode_maps[c] if c != -1 else '' for c in row])
                    for row in encoded_rows]

        model = cnn_lstm_ctc_ocr.LSTMOCR('eval')
        model.build_graph()
        val_feeder, num_samples = self.input_batch_generator(self.split_name,
                                                             batch_size=FLAGS.batch_size,
                                                             data_dir=FLAGS.data_dir)

        num_batches_per_epoch = int(math.ceil(num_samples / float(FLAGS.batch_size)))

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)

            if tf.gfile.IsDirectory(self.checkpoint_path):
                checkpoint_file = tf.train.latest_checkpoint(self.checkpoint_path)
            else:
                checkpoint_file = self.checkpoint_path
            print('Evaluating checkpoint_path={}, split={}, num_samples={}'.format(
                checkpoint_file, self.split_name, num_samples))

            # latest_checkpoint() returns None for a directory without checkpoints;
            # fail with a clear message instead of an opaque restore error.
            if checkpoint_file is None:
                raise ValueError(
                    'No checkpoint found at {}'.format(self.checkpoint_path))
            saver.restore(sess, checkpoint_file)

            num_correct = 0
            num_total = 0
            for i in range(num_batches_per_epoch):
                inputs, labels, _ = next(val_feeder)
                feed = {model.inputs: inputs, model.labels: labels}
                start = time.time()
                _, predictions = sess.run([model.names_to_updates, model.dense_decoded], feed)

                gt = _decode(self.label_from_sparse_tuple(labels))
                pred = _decode(predictions)
                for j, gt_text in enumerate(gt):
                    pred_text = pred[j]
                    print("%s  :  %s" % (gt_text, pred_text))
                    if gt_text == pred_text:
                        num_correct += 1
                    num_total += 1

                # Timing covers the session run, decoding and printing for this batch.
                elapsed = time.time() - start
                print('{}/{}, {:.5f} seconds.'.format(i, num_batches_per_epoch, elapsed))

            if num_total:
                print("accuracy: %f" % (num_correct / float(num_total)))
            else:
                # An empty split previously crashed with ZeroDivisionError.
                print("accuracy: n/a (no samples evaluated)")
Ejemplo n.º 2
0
    def infer_model(self, img):
        """Recognise the text in a single image with a freshly restored OCR model.

        The image is scaled to ``[0, 1]`` float32 and resized to
        ``(FLAGS.image_width, FLAGS.image_height)``. It is then reshaped to
        ``[FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel]``. An
        ``LSTMOCR('eval')`` graph is built and its weights are restored from
        ``self.checkpoint_path`` (a directory's latest checkpoint, or the path
        itself). The image is fed as a batch of one.

        Args:
            img: image array accepted by ``cv2.resize`` (assumed pixel range
                0-255 — confirm with callers).

        Returns:
            The decoded string for the last row of ``model.dense_decoded``.
            That is the only row, since the batch holds one image.
        """
        height = FLAGS.image_height
        width = FLAGS.image_width
        channels = FLAGS.image_channel

        # Normalise, resize to the network input size, make the channel axis explicit.
        scaled = img.astype(np.float32) / 255.
        resized = cv2.resize(scaled, (width, height))
        sample = np.reshape(resized, [height, width, channels])

        model = cnn_lstm_ctc_ocr.LSTMOCR('eval')
        model.build_graph()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            path = self.checkpoint_path
            checkpoint_file = (tf.train.latest_checkpoint(path)
                               if tf.gfile.IsDirectory(path) else path)
            print('Evaluating checkpoint_path={}'.format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            decoded_rows = sess.run(model.dense_decoded, {model.inputs: [sample]})

            # -1 entries in a decoded row are rendered as the empty string.
            texts = [
                ''.join([utils.decode_maps[c] if c != -1 else '' for c in row])
                for row in decoded_rows
            ]
            return texts[-1]
Ejemplo n.º 3
0
Archivo: train.py Proyecto: Yorwxue/CCR
def train(mode='train'):
    """Train the CNN-LSTM-CTC OCR model with periodic checkpointing and validation.

    Args:
        mode: mode string passed to ``cnn_lstm_ctc_ocr.LSTMOCR`` (default ``'train'``).

    Side effects:
        - Writes TensorBoard summaries to ``FLAGS.log_dir + '/train'``.
        - Saves checkpoints with prefix ``ocr-CCR`` under ``FLAGS.checkpoint_dir``
          at steps 1, ``save_steps + 1``, ``2 * save_steps + 1``, and so on.
        - Every ``FLAGS.validation_steps`` steps, decodes one validation batch
          (batch size ``2 * FLAGS.batch_size``) and prints accuracy, cost and
          learning rate.
        - When ``FLAGS.restore`` is set, resumes from the latest checkpoint in
          ``FLAGS.checkpoint_dir`` if one exists.
    """
    model = cnn_lstm_ctc_ocr.LSTMOCR(mode)
    model.build_graph()

    print('loading train data, please wait---------------------')
    train_feeder, num_train_samples = data_prep.input_batch_generator(
        'train', batch_size=FLAGS.batch_size, data_dir=FLAGS.data_dir)
    print('get image: ', num_train_samples)

    print('loading validation data, please wait---------------------')
    val_feeder, num_val_samples = data_prep.input_batch_generator(
        'val', batch_size=FLAGS.batch_size * 2, data_dir=FLAGS.data_dir)
    print('get image: ', num_val_samples)

    num_batches_per_epoch = int(
        math.ceil(num_train_samples / float(FLAGS.batch_size)))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)
        # try/finally guarantees the summary writer is flushed and closed even
        # if training is interrupted by an exception.
        try:
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # Original author's note: global_step is restored as well,
                    # so step-based saving/validation resumes where it left off.
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))

            print(
                '=============================begin training============================='
            )
            for cur_epoch in range(FLAGS.num_epochs):
                start_time = time.time()
                batch_time = time.time()

                for cur_batch in range(num_batches_per_epoch):
                    if (cur_batch + 1) % 100 == 0:
                        # batch_time is reset each iteration, so this is roughly
                        # the wall time of the previous iteration.
                        print('batch', cur_batch, ': time',
                              time.time() - batch_time)
                    batch_time = time.time()
                    batch_inputs, batch_labels, _ = next(train_feeder)
                    feed = {model.inputs: batch_inputs, model.labels: batch_labels}

                    summary_str, batch_cost, step, _ = \
                        sess.run([model.merged_summay, model.cost, model.global_step,
                                  model.train_op], feed)

                    train_writer.add_summary(summary_str, step)

                    # Save at steps 1, save_steps + 1, ... Written as
                    # (step - 1) % save_steps == 0 instead of
                    # step % save_steps == 1 so that save_steps == 1 still saves.
                    if (step - 1) % FLAGS.save_steps == 0:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            # makedirs also creates missing parent directories.
                            os.makedirs(FLAGS.checkpoint_dir)
                        # Lazy %-style args: the old call passed an extra argument
                        # with no placeholder, so logging reported an error instead.
                        logger.info('save the checkpoint of %s', step)
                        saver.save(sess,
                                   os.path.join(FLAGS.checkpoint_dir, 'ocr-CCR'),
                                   global_step=step)

                    if step % FLAGS.validation_steps == 0:
                        val_inputs, val_labels, ori_labels = next(val_feeder)
                        val_feed = {
                            model.inputs: val_inputs,
                            model.labels: val_labels
                        }

                        dense_decoded, lr = \
                            sess.run([model.dense_decoded, model.lrn_rate],
                                     val_feed)

                        # Prints the decode results and returns the accuracy.
                        accuracy = utils.accuracy_calculation(ori_labels,
                                                              dense_decoded,
                                                              ignore_value=-1,
                                                              isPrint=True)

                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "accuracy = {:.5f},train_cost = {:.5f}, " \
                              "time = {:.3f},lr={:.8f}"
                        print(
                            log.format(now.month, now.day, now.hour, now.minute,
                                       now.second, cur_epoch + 1, FLAGS.num_epochs,
                                       accuracy, batch_cost,
                                       time.time() - start_time, lr))
        finally:
            train_writer.close()
Ejemplo n.º 4
0
if __name__ == "__main__":
    # Script entry point: build a graph that decodes a batch of encoded image
    # strings and feeds them into the CNN-LSTM-CTC OCR model, then restore the
    # OCR weights from configure.ccr_checkpoint_path.
    # NOTE(review): entering the default graph's as_default() is redundant here
    # (ops are already added to the default graph); kept as-is.
    with tf.get_default_graph().as_default():
        # send image by base64
        # A 1-D batch of image strings (encoding per the original note: base64 —
        # confirm against image_decode, which is defined elsewhere in the file).
        image_string_list = tf.placeholder(tf.string, shape=[None, ], name='image_string')

        # Decode each string independently with image_decode; outputs are float32.
        batch_input_tensor = tf.map_fn(image_decode, image_string_list, dtype=tf.float32)

        tfconfig = tf.ConfigProto()
        tfconfig.gpu_options.allow_growth = True  # maybe necessary, used to avoid cuda initialize error
        tfconfig.allow_soft_placement = True  # maybe necessary, used to avoid cuda initialize error
        # tfconfig.log_device_placement = True  # print message verbose
        with tf.Session(config=tfconfig) as sess:
            # # model of cnn lstm ctc
            # NOTE(review): cnn_lstm_ctc is not used in the visible code; it may be
            # used by the serving setup that follows (not shown here).
            cnn_lstm_ctc = EvaluateModel()
            # The OCR graph takes the decoded batch tensor as its input
            # (presumably instead of creating its own placeholder — confirm in LSTMOCR).
            ocr_model = cnn_lstm_ctc_ocr.LSTMOCR('eval', inputs=batch_input_tensor)
            ocr_model.build_graph()
            # Accept either a checkpoint directory (use its latest checkpoint) or
            # a direct checkpoint path.
            if tf.gfile.IsDirectory(configure.ccr_checkpoint_path):
                checkpoint_file = tf.train.latest_checkpoint(configure.ccr_checkpoint_path)
            else:
                checkpoint_file = configure.ccr_checkpoint_path
            # Restore only the global variables under the 'cnn', 'lstm' and
            # 'stn-1' scopes.
            # NOTE(review): no initializer is run in this block, so any global
            # variable outside these scopes stays uninitialized — confirm none is
            # needed at inference time.
            ocr_cnn_scope_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='cnn')
            ocr_lstm_scope_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='lstm')
            ocr_stn_scope_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='stn-1')
            ocr_restore = tf.train.Saver(
                ocr_cnn_scope_to_restore + ocr_lstm_scope_to_restore + ocr_stn_scope_to_restore)
            ocr_restore.restore(sess, checkpoint_file)

            # Decoded-label tensor used as the model's output.
            ccr_dense_decoded = ocr_model.dense_decoded

            # tf server configure
Ejemplo n.º 5
0
        return predictions, img


# -------------------------------------------

from tensorflow.python.tools import inspect_checkpoint as chkp

if __name__ == "__main__":
    # config = configure.Config(root_path=root_path)
    ocr_checkpoint_path = "/data2/CNN_LSTM_CTC_Tensorflow/checkpoint"

    with tf.get_default_graph().as_default():
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            cnn_lstm_ctc = EvaluateModel()
            ocr_model = cnn_lstm_ctc_ocr.LSTMOCR('eval')
            ocr_model.build_graph()

            if tf.gfile.IsDirectory(ocr_checkpoint_path):
                checkpoint_file = tf.train.latest_checkpoint(
                    ocr_checkpoint_path)
            else:
                checkpoint_file = ocr_checkpoint_path

            # show tensors in the checkpoint
            chkp.print_tensors_in_checkpoint_file(checkpoint_file,
                                                  tensor_name='',
                                                  all_tensors=False)

            # get variable to restore
            ocr_stn_scope_to_restore = tf.get_collection(