Example #1
0
def train():
    """Train the CNN-LSTM-CTC OCR model, writing summaries and checkpoints.

    Relies on module-level settings defined elsewhere in the file:
    ``batch_size``, ``num_epochs``, ``validation_steps``, ``save_epochs``,
    ``log_dir``, ``checkpoint_dir`` and ``restore``.

    A summary is written every ``validation_steps`` global steps and after the
    very last batch; a checkpoint is saved every ``save_epochs`` epochs and
    after the final epoch.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR('train')
    model.build_graph()
    train_feeder, num_train_samples = data_prep.input_batch_generator(
        'train', is_training=True, batch_size=batch_size)
    print('get image: ', num_train_samples)

    num_batches_per_epoch = int(
        math.ceil(num_train_samples / float(batch_size)))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(log_dir + '/train', sess.graph)
        try:
            if restore:
                ckpt = tf.train.latest_checkpoint(checkpoint_dir)
                if ckpt:
                    # The global_step variable is restored as well.
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))

            # Read the current global step so `step` is defined for the
            # checkpoint save even when an epoch runs no batches.
            step = sess.run(model.global_step)

            for cur_epoch in range(num_epochs):
                for cur_batch in range(num_batches_per_epoch):
                    batch_time = time.time()
                    batch_inputs, batch_labels, _ = next(train_feeder)
                    feed = {model.inputs: batch_inputs,
                            model.labels: batch_labels}

                    loss, step, _ = sess.run(
                        [model.cost, model.global_step, model.train_op], feed)

                    if step % 100 == 0:
                        print('{}/{}:{},loss={}, time={}'.format(
                            step, cur_epoch, num_epochs, loss,
                            time.time() - batch_time))

                    # Monitor training: write a summary periodically and after
                    # the very last batch of training.
                    is_last_batch = (cur_epoch == num_epochs - 1 and
                                     cur_batch == num_batches_per_epoch - 1)
                    if step % validation_steps == 0 or is_last_batch:
                        # NOTE: this pulls an extra batch from the training
                        # feeder; that batch is summarized but not trained on.
                        batch_inputs, batch_labels, _ = next(train_feeder)
                        feed = {model.inputs: batch_inputs,
                                model.labels: batch_labels}
                        summary_str = sess.run(model.merged_summay, feed)
                        train_writer.add_summary(summary_str, step)

                # Save a checkpoint every `save_epochs` epochs and after the
                # final epoch.
                if cur_epoch % save_epochs == 0 or cur_epoch == num_epochs - 1:
                    if not os.path.isdir(checkpoint_dir):
                        # makedirs also creates missing parent directories.
                        os.makedirs(checkpoint_dir)
                    print('save the checkpoint of step {}'.format(step))
                    saver.save(sess,
                               os.path.join(checkpoint_dir, 'ocr-model'),
                               global_step=step)
        finally:
            # Flush pending summaries to disk.
            train_writer.close()
Example #2
0
def infer(path, mode='infer'):
    """Recognize the text in every ``.jpg`` image in ``path`` and print it.

    Each image is processed as follows:

    - read as grayscale and scaled to [0, 1];
    - resized so its height becomes ``FLAGS.image_height`` (aspect ratio kept);
    - zero-padded on the right to the widest image of its batch.

    Each batch is fed together with the per-image (unpadded) widths as
    ``model.seq_len``.

    Args:
        path: directory containing the ``.jpg`` images.
        mode: mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.

    Raises:
        IOError: if an image file cannot be decoded by OpenCV.
    """
    img_list = [os.path.join(path, e) for e in os.listdir(path)
                if e.endswith('.jpg')]
    print(len(img_list))

    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    batch_size = FLAGS.batch_size
    # Integer ceil division: the trailing partial batch is decoded too, and
    # `range` gets an int (plain `/` yields a float on Python 3).
    total_steps = (len(img_list) + batch_size - 1) // batch_size

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        # FLAGS.checkpoint_dir is used directly as the checkpoint path to
        # restore; it is not resolved with tf.train.latest_checkpoint.
        ckpt = FLAGS.checkpoint_dir
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            print('cannot restore')

        for curr_step in range(total_steps):
            batch_paths = img_list[curr_step * batch_size:
                                   (curr_step + 1) * batch_size]

            batch_imgs = []
            for img_path in batch_paths:
                im = cv2.imread(img_path, 0)  # flag 0: load as grayscale
                if im is None:
                    raise IOError('cannot read image: {}'.format(img_path))
                im = im.astype(np.float32) / 255.
                # float() keeps the scale fractional under Python 2 as well.
                scale = float(FLAGS.image_height) / im.shape[0]
                im = cv2.resize(im, None, fx=scale, fy=scale)
                batch_imgs.append(np.expand_dims(im, 2))

            # Zero-pad every image to the batch's largest height and width so
            # that one-pixel rounding differences from resize cannot break it.
            max_height = max(e.shape[0] for e in batch_imgs)
            max_width = max(e.shape[1] for e in batch_imgs)
            imgs_input = np.zeros((len(batch_imgs), max_height, max_width, 1),
                                  dtype=np.float32)
            for idx, item in enumerate(batch_imgs):
                # Axis 1 is height and axis 2 is width; images are left-aligned.
                imgs_input[idx, :item.shape[0], :item.shape[1], :] = item
            seq_len_input = np.asarray([e.shape[1] for e in batch_imgs])

            feed = {model.inputs: imgs_input, model.seq_len: seq_len_input}
            dense_decoded_code = sess.run(model.dense_decoded, feed)

            batch_result = [utils.label2text(code)
                            for code in dense_decoded_code]
            for img_name, pred_strings in zip(batch_paths, batch_result):
                print(img_name, pred_strings)
Example #3
0
def infer(img_path, batch_size=64, image_height=60, image_width=180, image_channel=1, checkpoint_dir="../checkpoint/"):
    """Recognize the images in ``img_path`` and append results to ./result.txt.

    File selection and preprocessing:

    - files whose name contains "label" are ignored;
    - the remaining files must be named ``<integer>.<ext>`` and are processed
      in numeric order;
    - each image is read as grayscale, scaled to [0, 1] and reshaped (not
      resized) to ``[image_height, image_width, image_channel]``.

    Args:
        img_path: directory containing the images.
        batch_size: batch size the model graph is built with.
        image_height: expected image height.
        image_width: expected image width.
        image_channel: expected number of channels.
        checkpoint_dir: directory searched for the latest checkpoint.

    Raises:
        Exception: if no checkpoint is found in ``checkpoint_dir``.
    """
    # Collect the image file names, skipping label files, in numeric order.
    file_names = [t for t in os.listdir(img_path) if t.find("label") < 0]
    file_names.sort(key=lambda x: int(x.split('.')[0]))
    file_names = [os.path.join(img_path, name) for name in file_names]

    # Build the model; its batch size is fixed at construction time.
    model = cnn_lstm_otc_ocr.LSTMOCR(num_classes=NumClasses, batch_size=batch_size, is_train=False)
    model.build_graph()

    results = []
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Initialize variables, then overwrite them from the checkpoint.
        sess.run(tf.global_variables_initializer())

        ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt:
            print('restore from ckpt{}'.format(ckpt))
            tf.train.Saver(tf.global_variables(), max_to_keep=100).restore(sess, ckpt)
        else:
            print('cannot restore')
            raise Exception("cannot restore")

        # Only full batches are run because the graph's batch size is fixed;
        # trailing images that do not fill a batch are skipped.
        for curr_step in range(len(file_names) // batch_size):
            images_input = []
            for img in file_names[curr_step * batch_size: (curr_step + 1) * batch_size]:
                # `with` closes the image file once its pixels are read.
                with Image.open(img) as pil_img:
                    image_data = np.asarray(pil_img.convert("L"), dtype=np.float32) / 255.
                images_input.append(np.reshape(image_data, [image_height, image_width, image_channel]))
            images_input = np.asarray(images_input)

            # Fetch only the dense decoded labels. Iterating over the result
            # then yields one row per image, which is what get_result decodes.
            dense_decoded = sess.run(model.dense_decoded, {model.inputs: images_input})

            for item in dense_decoded:
                result = DataIterator.get_result(item)
                results.append(result)
                print(result)

    # Append the results, one per line.
    with open('./result.txt', 'a') as f:
        for code in results:
            f.write(code + '\n')
Example #4
0
	def __init__(self, checkpoint_dir='/media/zdyd/dujing/yjx/textrecognition/checkpoint', gpu_memory_fraction=0.6):
		"""Build the inference graph, open a session and load trained weights.

		Args:
			checkpoint_dir: location handed to ``self._load_weights``; defaults
				to the previously hard-coded path.
			gpu_memory_fraction: value for
				``gpu_options.per_process_gpu_memory_fraction`` (default 0.6).
		"""
		self.model = cnn_lstm_otc_ocr.LSTMOCR('infer')
		# NOTE(review): whatever build_graph() returns is stored as-is;
		# whether it is a tf.Graph is not visible here - confirm.
		self.graph = self.model.build_graph()

		config = tf.ConfigProto(allow_soft_placement=True)
		config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction

		self.sess = tf.Session(config=config)
		self.saver = tf.train.Saver()
		self._load_weights(checkpoint_dir, self.sess, self.saver)
Example #5
0
 def __init__(self, model_dir = model_dir):
     """Build the export-style inference graph and restore the latest checkpoint.

     Args:
         model_dir: directory searched with ``tf.train.latest_checkpoint``;
             defaults to the module-level ``model_dir``.

     Raises:
         ValueError: if ``model_dir`` contains no checkpoint.
     """
     # Fixed-shape input: batch size and image size come from FLAGS.
     self.X = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel], name='input')
     model = cnn_lstm_otc_ocr.LSTMOCR('infer', '0')
     self.decodes, self.prob = model.build_graph_for_export(self.X)

     # Resolve the checkpoint before opening a session, so a missing
     # checkpoint fails with a clear error and leaves no session open.
     ckpt = tf.train.latest_checkpoint(model_dir)
     if ckpt is None:
         raise ValueError('no checkpoint found in {}'.format(model_dir))

     config = tf.ConfigProto(allow_soft_placement=True)
     # Grow GPU memory on demand instead of reserving it all up front.
     config.gpu_options.allow_growth = True
     self.sess = tf.Session(config=config)
     self.sess.run(tf.global_variables_initializer())
     saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
     saver.restore(self.sess, ckpt)
def train(train_dir=None, mode='train'):
    """Train the OCR model on the GPUs listed in ``FLAGS.gpus``.

    Inputs come from the queue-based pipeline returned by
    ``utils.DataIterator().distored_inputs()``. The pipeline is driven by TF
    queue runners under a ``tf.train.Coordinator``. ``train_dir`` is accepted
    for interface compatibility but is not used.

    Args:
        train_dir: unused.
        mode: mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Non-empty entries of the comma-separated GPU list.
        gpus = [g for g in FLAGS.gpus.split(',') if g]
        model = cnn_lstm_otc_ocr.LSTMOCR(mode, gpus)
        train_feeder = utils.DataIterator()
        X, Y = train_feeder.distored_inputs()
        train_op, _ = model.build_graph(X, Y)
        print('len(labels):%d, batch_size:%d' %
              (len(train_feeder.labels), FLAGS.batch_size))
        # One step consumes one batch per GPU.
        num_batches_per_epoch = int(
            len(train_feeder.labels) / FLAGS.batch_size / len(gpus))
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                if FLAGS.restore:
                    ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                    if ckpt:
                        saver.restore(sess, ckpt)
                        print('restore from checkpoint{0}'.format(ckpt))
                        print('global_step:', model.global_step.eval())
                        # NOTE: only printed; the matching tf.assign below is
                        # disabled, so global_step is not changed.
                        print('assign value %d' %
                              (FLAGS.num_epochs * num_batches_per_epoch / 3))
                        #sess.run(tf.assign(model.global_step, FLAGS.num_epochs*num_batches_per_epoch/3))
                        print('global_step:', model.global_step.eval())
                print(
                    '=============================begin training============================='
                )
                for cur_epoch in range(FLAGS.num_epochs):
                    batch_time = time.time()
                    for cur_batch in range(num_batches_per_epoch):
                        # Fetch lr and loss in the same run as train_op so they
                        # describe the batch just trained on. Separate .eval()
                        # calls would run the graph again and dequeue another
                        # input batch.
                        _, step, lrn_rate, loss = sess.run([
                            train_op, model.global_step, model.lrn_rate,
                            model.loss
                        ])
                        if step % FLAGS.save_steps == 1:
                            if not os.path.isdir(FLAGS.checkpoint_dir):
                                os.makedirs(FLAGS.checkpoint_dir)
                            saver.save(sess,
                                       os.path.join(FLAGS.checkpoint_dir,
                                                    'ocr-model'),
                                       global_step=step)
                        if (step + 1) % 100 == 1:
                            print(
                                'step: %d, batch: %d time: %d, learning rate: %.8f, loss:%.4f'
                                % (step, cur_batch, time.time() - batch_time,
                                   lrn_rate, loss))
            finally:
                # Stop the input threads and flush summaries even if
                # training raised.
                coord.request_stop()
                coord.join(threads)
                train_writer.close()
Example #7
0
File: infer.py Project: Yorwxue/CCR
    def eval_model(self):
        """Decode ``self.split_name`` with a checkpoint and print the predictions.

        The checkpoint is taken from ``self.checkpoint_path``: it may be a
        checkpoint path, or a directory, in which case its latest checkpoint
        is used. After all batches, one summary is written under
        ``<log_dir>/<split_name>`` at the restored global step.

        Raises:
            ValueError: if ``self.checkpoint_path`` is a directory containing
                no checkpoint.
        """
        model = cnn_lstm_otc_ocr.LSTMOCR('eval')
        model.build_graph()
        val_feeder, num_samples = self.input_batch_generator(
            self.split_name, is_training=False, batch_size=FLAGS.batch_size)
        num_batches_per_epoch = int(
            math.ceil(num_samples / float(FLAGS.batch_size)))

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            eval_writer = tf.summary.FileWriter(
                "{}/{}".format(log_dir, self.split_name), sess.graph)
            try:
                if tf.gfile.IsDirectory(self.checkpoint_path):
                    checkpoint_file = tf.train.latest_checkpoint(
                        self.checkpoint_path)
                else:
                    checkpoint_file = self.checkpoint_path
                print('Evaluating checkpoint_path={}, split={}, num_samples={}'.
                      format(checkpoint_file, self.split_name, num_samples))
                if checkpoint_file is None:
                    raise ValueError('no checkpoint found in {}'.format(
                        self.checkpoint_path))
                saver.restore(sess, checkpoint_file)

                for i in range(num_batches_per_epoch):
                    inputs, _labels, _ = next(val_feeder)
                    feed = {model.inputs: inputs}
                    start = time.time()
                    predictions = sess.run(model.dense_decoded, feed)
                    # Print the decoded text of each sample; -1 entries are
                    # skipped (presumably padding of the dense output).
                    for prediction in predictions:
                        code = ''.join(
                            utils.decode_maps[c] if c != -1 else ''
                            for c in prediction)
                        print("%s" % code)
                    # Timing covers inference, decoding and printing.
                    elapsed = time.time() - start
                    print('{}/{}, {:.5f} seconds.'.format(
                        i, num_batches_per_epoch, elapsed))

                summary_str, step = sess.run(
                    [model.merged_summay, model.global_step])
                eval_writer.add_summary(summary_str, step)
            finally:
                # Flush the summary to disk; FileWriter buffers events.
                eval_writer.close()
            return
def infer(img_path, mode='infer'):
    """Decode the images under ``img_path`` and write labels and predictions.

    Ground-truth codes are parsed from the file names: the second
    ``_``-separated field of the base name, cut at its first ``.``. They are
    written to ./actual.txt, one per line. Decoded predictions are written to
    ./result.txt, one per line. Images are read as grayscale, scaled to [0, 1]
    and resized to ``FLAGS.image_width`` x ``FLAGS.image_height``.

    Args:
        img_path: image location passed to ``helper.load_img_path``.
        mode: mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.

    Raises:
        IOError: if an image file cannot be decoded by OpenCV.
    """
    img_list = helper.load_img_path(img_path)
    print(img_list[:5])
    with open('./actual.txt', 'w') as f:
        for name in img_list:
            code = name.split('/')[-1].split('_')[1].split('.')[0]
            f.write(code + '\n')

    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    batch_size = FLAGS.batch_size
    # NOTE: only full batches are decoded, so trailing images that do not
    # fill a batch get no line in result.txt (actual.txt lists every image).
    total_steps = len(img_list) // batch_size

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            print('cannot restore')

        decoded_expression = []
        for curr_step in range(total_steps):
            imgs_input = []
            for img in img_list[curr_step * batch_size:
                                (curr_step + 1) * batch_size]:
                im = cv2.imread(img, cv2.IMREAD_GRAYSCALE)
                if im is None:
                    raise IOError('cannot read image: {}'.format(img))
                im = im.astype(np.float32) / 255.
                im = cv2.resize(im, (FLAGS.image_width, FLAGS.image_height))
                imgs_input.append(np.reshape(im, [
                    FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel
                ]))

            # Only the images are fed; no sequence lengths are supplied.
            feed = {model.inputs: np.asarray(imgs_input)}
            dense_decoded_code = sess.run(model.dense_decoded, feed)

            # -1 entries are skipped (presumably padding of the dense output).
            for item in dense_decoded_code:
                decoded_expression.append(
                    ''.join(utils.decode_maps[i] for i in item if i != -1))

        with open('./result.txt', 'w') as f:
            for code in decoded_expression:
                f.write(code + '\n')
Example #9
0
# Read every file in `root` into `img_list` (both defined earlier in the
# script, outside this view) and time the reading.
read_img_begin = time.time()
for image in os.listdir(root):
    image_name = os.path.join(root, image)
    # Read the raw bytes with numpy and decode them with OpenCV rather than
    # calling cv2.imread, presumably so non-ASCII paths work. Flag -1 is
    # cv2.IMREAD_UNCHANGED.
    # NOTE(review): imdecode returns None for undecodable files, and that None
    # is appended to img_list unchanged.
    img = cv2.imdecode(np.fromfile(image_name, dtype=np.uint8), -1)
    img_list.append(img)
read_img_end = time.time()
# Disabled alternative (kept for reference): build the model without a
# CPU device scope.
#print len(img_list)
#with tf.device('/gpus:0'):
#build_model_begin = time.time()
#model = cnn_lstm_otc_ocr.LSTMOCR('infer')
#model.build_graph()
#build_model_end = time.time()

# Build the inference model under a '/cpu:0' device scope, restore the latest
# checkpoint from utils.checkpoint_dir, then time `fit` on the first image only.
with tf.device('/cpu:0'):
    build_model_begin = time.time()
    model = cnn_lstm_otc_ocr.LSTMOCR('infer')
    model.build_graph()
    build_model_end = time.time()
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(utils.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            # Execution continues with the freshly initialized weights.
            print('cannot restore')
        fit_image_begin = time.time()
        # Only the first image (img_list[:1]) is passed to fit().
        result = fit(model, sess, img_list[:1])
        fit_image_end = time.time()
def export():
  """Export the latest checkpoint in FLAGS.checkpoint_dir as a SavedModel.

  The serving graph accepts serialized ``tf.Example`` protos with an
  ``image/encoded`` string feature. Each encoded image is converted by
  ``preprocess_image`` before being fed to the model.

  Two signatures are exported:

  - ``'predict_images'``: encoded images in, ``'classes'`` out;
  - the default serving signature, using the classify API.

  The model is written to ``<FLAGS.output_dir>/<FLAGS.model_version>``.

  Raises:
    ValueError: if no checkpoint is found in ``FLAGS.checkpoint_dir``.
  """
  with tf.device('/cpu:0'):
    with tf.Graph().as_default():
      serialized_tf_recognition = tf.placeholder(tf.string, name='tf_recognition')
      feature_configs = {
          'image/encoded': tf.FixedLenFeature(
              shape=[], dtype=tf.string),
      }
      tf_recognition = tf.parse_example(serialized_tf_recognition, feature_configs)
      jpegs = tf_recognition['image/encoded']
      images = tf.map_fn(preprocess_image, jpegs, dtype=tf.float32)
      model = cnn_lstm_otc_ocr.LSTMOCR('infer', '0')
      decodes = model.build_graph_for_export(images)
      # NOTE(review): elsewhere in this file build_graph_for_export() is
      # unpacked as `(decodes, prob)`. Accept either form and export only the
      # decoded labels, because build_tensor_info needs a single tensor.
      if isinstance(decodes, (tuple, list)):
        decodes = decodes[0]

      # Fail early with a clear message instead of inside saver.restore().
      ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
      if ckpt is None:
        raise ValueError(
            'no checkpoint found in {}'.format(FLAGS.checkpoint_dir))

      with tf.device('/cpu:0'), tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        saver.restore(sess, ckpt)
        output_path = os.path.join(
          tf.compat.as_bytes(FLAGS.output_dir),
          tf.compat.as_bytes(str(FLAGS.model_version)))
        print('Exporting trained model to', output_path)
        builder = tf.saved_model.builder.SavedModelBuilder(output_path)

        # Build the signature_def_map.
        classify_inputs_tensor_info = tf.saved_model.utils.build_tensor_info(
            serialized_tf_recognition)
        classes_output_tensor_info = tf.saved_model.utils.build_tensor_info(
          decodes)

        classification_signature = (
          tf.saved_model.signature_def_utils.build_signature_def(
              inputs={
                  tf.saved_model.signature_constants.CLASSIFY_INPUTS:
                      classify_inputs_tensor_info
              },
              outputs={
                  tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:
                      classes_output_tensor_info
              },
              method_name=tf.saved_model.signature_constants.
              CLASSIFY_METHOD_NAME))

        predict_inputs_tensor_info = tf.saved_model.utils.build_tensor_info(jpegs)
        prediction_signature = (
          tf.saved_model.signature_def_utils.build_signature_def(
              inputs={'images': predict_inputs_tensor_info},
              outputs={
                  'classes': classes_output_tensor_info,
              },
              method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
          ))
        # clear_devices=True strips device placements from the saved graph.
        builder.add_meta_graph_and_variables(
          sess, [tf.saved_model.tag_constants.SERVING],
          signature_def_map={
              'predict_images':
                  prediction_signature,
              tf.saved_model.signature_constants.
              DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                  classification_signature,
          },
          clear_devices=True)

        builder.save()
        print('Successfully exported model to %s' % FLAGS.output_dir)
Example #11
0
def _build_ocr_feed(model, batch_inputs, batch_labels):
    """Return the feed dict for one batch of images and dense labels.

    Labels are converted with ``utils.sparse_tuple_from_label``. Every sample
    gets the same sequence length, ``FLAGS.max_stepsize``.
    """
    return {
        model.inputs: batch_inputs,
        model.labels: utils.sparse_tuple_from_label(batch_labels.tolist()),
        model.seq_len: np.asarray(
            [FLAGS.max_stepsize for _ in batch_inputs], dtype=np.int64),
    }


def _run_validation(sess, model, val_batch, num_batches):
    """Decode ``num_batches`` validation batches.

    Returns:
        ``(acc_total, lr)``, where:

        - ``acc_total`` is the sum of the per-batch values returned by
          ``utils.accuracy_calculation``;
        - ``lr`` is the learning rate fetched with the last batch (0 if no
          batch was run).
    """
    acc_total = 0
    lr = 0
    for _ in range(num_batches):
        batch_inputs, batch_labels = sess.run(val_batch)
        val_feed = _build_ocr_feed(model, batch_inputs, batch_labels)
        dense_decoded, lr = sess.run(
            [model.dense_decoded, model.lrn_rate], val_feed)
        acc_total += utils.accuracy_calculation(
            batch_labels.tolist(),
            dense_decoded,
            ignore_value=-1,
            isPrint=True)
    return acc_total, lr


def train(train_dir=None, val_dir=None, mode='train'):
    """Train the LSTM OCR model with periodic checkpointing and validation.

    Training and validation data are read from ``FLAGS.train_dir`` and
    ``FLAGS.val_dir``. The ``train_dir`` and ``val_dir`` arguments are
    accepted for interface compatibility but are not used.

    Args:
        train_dir: unused.
        val_dir: unused.
        mode: mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    if FLAGS.model == 'lstm':
        model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    else:
        print("no such model")
        # Non-zero status so scripts can detect the failure; a bare
        # sys.exit() reports success.
        sys.exit(1)

    # Build the graph.
    model.build_graph()

    # ---------------------------- training data ----------------------------
    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir=FLAGS.train_dir, istrain=True)
    print('get image  data size: ', train_feeder.size)
    train_images = train_feeder.image
    train_labels = train_feeder.labels
    print(len(train_images))
    train_reader = ReadData.ReadData(train_images, train_labels)

    # --------------------------- validation data ---------------------------
    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir=FLAGS.val_dir, istrain=False)
    val_images = val_feeder.image
    val_labels = val_feeder.labels
    val_reader = ReadData.ReadData(val_images, val_labels)
    print('val get image: ', val_feeder.size)

    # Full batches per epoch; a trailing partial batch is not used.
    num_batches_per_epoch = int(train_feeder.size / FLAGS.batch_size)
    num_batches_per_epoch_val = int(val_feeder.size / FLAGS.batch_size)

    with tf.device('/cpu:0'):
        config = tf.ConfigProto(allow_soft_placement=True)

        with tf.Session(config=config) as sess:
            # Initialize the data iterators, then get their next-batch ops.
            train_reader.init_itetator(sess)
            val_reader.init_itetator(sess)
            train_batch = train_reader.get_nex_batch()
            val_batch = val_reader.get_nex_batch()
            # Initialize all global variables.
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)
            try:
                # Optionally resume from a pre-trained checkpoint.
                if FLAGS.restore:
                    ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                    if ckpt:
                        # The global_step variable is restored as well.
                        saver.restore(sess, ckpt)
                        print('restore from the checkpoint{0}'.format(ckpt))
                    else:
                        print("No checkpoint")

                print(
                    '=============================begin training============================='
                )
                tmp_max = 0  # best validation accuracy so far
                tmp_epoch = 0  # epoch in which tmp_max was reached

                for cur_epoch in range(FLAGS.num_epochs):
                    train_cost = 0
                    batch_time = time.time()
                    for cur_batch in range(num_batches_per_epoch):
                        batch_inputs, batch_labels = sess.run(train_batch)
                        feed = _build_ocr_feed(model, batch_inputs,
                                               batch_labels)
                        batch_cost, step, _ = sess.run(
                            [model.cost, model.global_step, model.train_op],
                            feed)

                        # Weighted by batch size; avg_train_cost below divides
                        # by the number of samples seen this epoch.
                        train_cost += batch_cost * FLAGS.batch_size

                        if step % FLAGS.save_steps == 1:
                            if not os.path.isdir(FLAGS.checkpoint_dir):
                                os.makedirs(FLAGS.checkpoint_dir)
                            logger.info('save the checkpoint of %s', step)
                            saver.save(sess,
                                       os.path.join(FLAGS.checkpoint_dir,
                                                    'ocr-model'),
                                       global_step=step)
                        if cur_batch % 100 == 1:
                            print('batch', cur_batch, ': time',
                                  time.time() - batch_time, 'loss', batch_cost)
                            batch_time = time.time()

                        if step % FLAGS.validation_steps == 0:
                            validation_start_time = time.time()
                            acc_batch_total, lr = _run_validation(
                                sess, model, val_batch,
                                num_batches_per_epoch_val)
                            # Mean of the per-batch accuracies over the batches
                            # actually evaluated. Samples of the skipped partial
                            # batch must not enlarge the denominator.
                            accuracy = acc_batch_total / float(
                                max(num_batches_per_epoch_val, 1))
                            if accuracy > tmp_max:
                                tmp_max = accuracy
                                tmp_epoch = cur_epoch
                            avg_train_cost = train_cost / (
                                (cur_batch + 1) * FLAGS.batch_size)

                            now = datetime.datetime.now()
                            log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                                  "max_accuracy = {:.3f},max_Epoch {},accuracy = {:.3f},acc_batch_total = {:.3f},avg_train_cost = {:.3f}, " \
                                  " time = {:.3f},lr={:.8f}"

                            print(
                                log.format(now.month, now.day, now.hour,
                                           now.minute, now.second,
                                           cur_epoch + 1, FLAGS.num_epochs,
                                           tmp_max, tmp_epoch, accuracy,
                                           acc_batch_total, avg_train_cost,
                                           time.time() - validation_start_time,
                                           lr))
            finally:
                # Flush the graph event to disk.
                train_writer.close()
Example #12
0
def infer(img_path, mode='infer'):
    """Decode images under ``img_path``, write results and print accuracy.

    Each image is read as grayscale, scaled to [0, 1] and reshaped (not
    resized) to ``[FLAGS.image_height, FLAGS.image_width,
    FLAGS.image_channel]``. The ground-truth label is the text after the
    file name's last ``_`` with ``.jpg`` removed.

    Outputs:

    - ./result.txt gets one ``"<path> <label> <prediction>"`` line per
      decoded image;
    - ``"<correct>/<decoded> = <accuracy>"`` is printed.

    Args:
        img_path: image location passed to ``helper.load_img_path``.
        mode: mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    img_list = helper.load_img_path(img_path)
    print(img_list[:5])

    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    batch_size = FLAGS.batch_size
    # Only full batches are decoded; trailing images are skipped.
    total_steps = len(img_list) // batch_size

    config = tf.ConfigProto(allow_soft_placement=True)
    # Grow GPU memory on demand instead of reserving it all up front.
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            print('cannot restore')

        decoded_expression = []
        for curr_step in range(total_steps):
            imgs_input = []
            for img in img_list[curr_step * batch_size:
                                (curr_step + 1) * batch_size]:
                im = cv2.imread(img, 0).astype(np.float32) / 255.
                imgs_input.append(np.reshape(im, [
                    FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel
                ]))

            # Only the images are fed; no sequence lengths are supplied.
            feed = {model.inputs: np.asarray(imgs_input)}
            dense_decoded_code = sess.run(model.dense_decoded, feed)

            # -1 entries are skipped (presumably padding of the dense output).
            for item in dense_decoded_code:
                decoded_expression.append(
                    ''.join(utils.decode_maps[i] for i in item if i != -1))

        print(decoded_expression)
        with open('./result.txt', 'w') as f:
            true_count = 0
            # Predictions line up with the first len(decoded_expression) paths.
            for img_name, code in zip(img_list, decoded_expression):
                img_label = img_name.split('_')[-1].replace('.jpg', '')
                if code == img_label:
                    true_count = true_count + 1
                f.write('{} {} {}\n'.format(img_name, img_label, code))
            # Accuracy over the images actually decoded; skipped images are
            # not counted as wrong, and an empty run does not divide by zero.
            num_decoded = len(decoded_expression)
            accuracy = float(true_count) / num_decoded if num_decoded else 0.0
            print('{}/{} = {}'.format(true_count, num_decoded, accuracy))
Example #13
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the CNN-LSTM-CTC OCR model with periodic checkpoints/validation.

    Training and validation batches come from
    ``data_prep.input_batch_generator``. A checkpoint is saved whenever
    ``step % FLAGS.save_steps == 1`` and one validation batch is decoded and
    scored whenever ``step % FLAGS.validation_steps == 0``.

    Args:
        train_dir: Unused; kept for interface compatibility (the generator
            selects the 'train' split itself).
        val_dir: Unused; kept for interface compatibility.
        mode: Mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    print('loading train data, please wait---------------------')
    train_feeder, num_train_samples = data_prep.input_batch_generator(
        'train', is_training=True, batch_size=FLAGS.batch_size)
    print('get image: ', num_train_samples)

    print('loading validation data, please wait---------------------')
    # Validation batches are twice the training batch size.
    val_feeder, num_val_samples = data_prep.input_batch_generator(
        'val', is_training=False, batch_size=FLAGS.batch_size * 2)
    print('get image: ', num_val_samples)

    num_batches_per_epoch = int(
        math.ceil(num_train_samples / float(FLAGS.batch_size)))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # global_step is restored from the checkpoint as well.
                saver.restore(sess, ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))

        print('=============================begin training=============================')
        for cur_epoch in range(FLAGS.num_epochs):
            start_time = time.time()
            batch_time = time.time()

            for cur_batch in range(num_batches_per_epoch):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time', time.time() - batch_time)
                batch_time = time.time()
                batch_inputs, batch_labels, _ = next(train_feeder)
                feed = {model.inputs: batch_inputs,
                        model.labels: batch_labels}

                summary_str, batch_cost, step, _ = sess.run(
                    [model.merged_summay, model.cost, model.global_step,
                     model.train_op], feed)
                train_writer.add_summary(summary_str, step)

                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # Pass `step` as a lazy %-argument. The old call
                    # logger.info('...{0}', format(step)) supplied an argument
                    # with no placeholder, so (stdlib) logging reported a
                    # formatting error instead of the message.
                    logger.info('save the checkpoint of %s', step)
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                if step % FLAGS.validation_steps == 0:
                    val_inputs, val_labels, ori_labels = next(val_feeder)
                    val_feed = {model.inputs: val_inputs,
                                model.labels: val_labels}

                    dense_decoded, lr = sess.run(
                        [model.dense_decoded, model.lrn_rate], val_feed)

                    # isPrint=True presumably echoes the decoded results.
                    accuracy = utils.accuracy_calculation(
                        ori_labels, dense_decoded, ignore_value=-1,
                        isPrint=True)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.5f},train_cost = {:.5f}, " \
                          "time = {:.3f},lr={:.8f}"
                    print(log.format(now.month, now.day, now.hour, now.minute,
                                     now.second, cur_epoch + 1,
                                     FLAGS.num_epochs, accuracy, batch_cost,
                                     time.time() - start_time, lr))
Example #14
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the OCR model from a queue-runner input pipeline.

    Each iteration evaluates ``train_batch`` (a tensor pair defined outside
    this function: images, and raw labels that ``read_labels`` converts),
    runs one optimizer step and writes a summary. A checkpoint is saved
    whenever ``step % FLAGS.save_steps == 1``. The current batch is decoded
    and scored whenever ``step % FLAGS.validation_steps == 0``.

    Args:
        train_dir: Unused; kept for interface compatibility.
        val_dir: Unused; kept for interface compatibility.
        mode: Mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    # Each iteration below is a single training batch, not a full epoch.
    num_steps = 1000
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)

            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # global_step is restored from the checkpoint as well.
                    saver.restore(sess, ckpt)
                    print('restore from checkpoint{0}'.format(ckpt))

            print(
                '=============================begin training============================='
            )
            start_time = time.time()
            train_cost = 0.0
            for epoch in range(num_steps):
                train_feeder = sess.run(train_batch)
                batch_inputs = train_feeder[0]
                batch_labels = read_labels(train_feeder[1])
                feed = {model.inputs: batch_inputs, model.labels: batch_labels}
                summary_str, batch_cost, step, _ = sess.run(
                    [model.merged_summay, model.cost, model.global_step,
                     model.train_op], feed)
                train_cost += batch_cost
                train_writer.add_summary(summary_str, step)

                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # Lazy %-argument; the old `, format(step)` form had no
                    # placeholder and made logging report a format error.
                    logger.info('save checkpoint at step %s', step)
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                if step % FLAGS.validation_steps == 0:
                    # NOTE(review): this function has no separate validation
                    # pipeline, so the batch just trained on is scored. The
                    # previous version re-ran this same batch
                    # TRAIN_SET_NUM // batch_size times and scaled the summed
                    # accuracy by batch_size / 2. One pass scores the batch
                    # without those redundant runs.
                    val_feed = {model.inputs: batch_inputs,
                                model.labels: batch_labels}
                    dense_decoded, lastbatch_err, lr = sess.run(
                        [model.dense_decoded, model.cost, model.lrn_rate],
                        val_feed)
                    # Print the decode result.
                    print(dense_decoded)
                    print(batch_labels)

                    accuracy = utils.accuracy_calculation(batch_labels,
                                                          dense_decoded,
                                                          ignore_value=-1,
                                                          isPrint=True)
                    avg_train_cost = train_cost / (epoch + 1)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    print(
                        log.format(now.month, now.day, now.hour, now.minute,
                                   now.second, epoch + 1, FLAGS.num_epochs,
                                   accuracy, avg_train_cost, lastbatch_err,
                                   time.time() - start_time, lr))
        finally:
            # Always stop and join the queue-runner threads, even on error.
            coord.request_stop()
            coord.join(threads)
Example #15
0
def infer(root, mode='infer'):
    """Decode every image set under ``root`` and append per-set accuracy.

    For each entry of ``root``, the image paths returned by
    ``helper.load_img_path`` are decoded in full batches of
    ``FLAGS.batch_size``. Each prediction is compared with the upper-cased
    label embedded in its file name (the text after the last '_' and
    before the first '.'). One line ``name,accuracy,sample_num,seconds``
    is appended to ``./result.txt`` per entry.

    Args:
        root: Directory whose entries are passed to ``helper.load_img_path``.
        mode: Mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """

    def get_input_lens(seqs):
        # Every sequence gets the same fixed length, FLAGS.max_stepsize.
        length = np.array([FLAGS.max_stepsize for _ in seqs], dtype=np.int64)
        return seqs, length

    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            print('cannot restore')

        for img_file in os.listdir(root):
            start_time = time.time()
            img_path = os.path.join(root, img_file)
            print(img_path)
            file_name = img_path.split('/')[-1].split('_')[0]
            imgList = helper.load_img_path(img_path)

            # Floor division keeps total_steps an int on Python 2 and 3 (the
            # old `/` + `xrange` only worked on Python 2). Trailing images
            # that do not fill a whole batch are skipped, as before.
            total_steps = len(imgList) // FLAGS.batch_size
            # NOTE(review): the factor 3 is carried over unchanged; its
            # meaning is not visible here - confirm.
            sample_num = len(imgList) * 3
            if total_steps == 0:
                # Nothing to decode; avoids ZeroDivisionError further down.
                print('skip {}: fewer than {} images'.format(
                    img_path, FLAGS.batch_size))
                continue

            total_acc = 0
            for curr_step in range(total_steps):
                batch_paths = imgList[curr_step * FLAGS.batch_size:
                                      (curr_step + 1) * FLAGS.batch_size]
                decoded_expression = []
                imgs_input = []
                seq_len_input = []
                imgs_label = []
                for img in batch_paths:
                    label = img.split('_')[-1].split('.')[0]
                    imgs_label.append(label.upper())

                    im = cv2.imread(img, cv2.IMREAD_GRAYSCALE).astype(
                        np.float32) / 255.
                    im = cv2.resize(im,
                                    (FLAGS.image_width, FLAGS.image_height))
                    im = np.reshape(im, [
                        FLAGS.image_height, FLAGS.image_width,
                        FLAGS.image_channel
                    ])

                    _, seq_len = get_input_lens(np.array([im]))
                    imgs_input.append(im)
                    seq_len_input.append(seq_len)

                imgs_input = np.asarray(imgs_input)
                seq_len_input = np.reshape(np.asarray(seq_len_input), [-1])

                feed = {model.inputs: imgs_input, model.seq_len: seq_len_input}
                dense_decoded_code = sess.run(model.dense_decoded, feed)

                for item in dense_decoded_code:
                    # Entries equal to -1 are skipped (presumably padding).
                    decoded_expression.append(
                        ''.join(utils.decode_maps[i] for i in item if i != -1))

                acc = utils.test_accuracy_calculation(imgs_label,
                                                      decoded_expression, True)
                total_acc += acc

            # float() so the average is not truncated under Python 2.
            avg_acc = float(total_acc) / total_steps
            print(avg_acc)
            print(file_name)
            print(sample_num)
            with open('./result.txt', 'a') as f:
                f.write(file_name + ',' +
                        str(round(avg_acc, 2)) + ',' +
                        str(sample_num) + ',' +
                        str(round((time.time() - start_time) /
                                  sample_num, 2)) + '\n')
Example #16
0
def main(_):
    """Train the OCR model on ``utils.DataIterator`` data with validation.

    Each epoch visits the training samples in a fresh random order. A
    checkpoint is saved whenever ``step % FLAGS.save_steps == 1``. Whenever
    ``step % FLAGS.validation_steps == 0``, every full validation batch is
    decoded and the accuracy, average training cost, last validation batch
    cost and learning rate are printed.

    Args:
        _: Ignored (presumably argv passed by ``tf.app.run`` - confirm).
    """
    model = cnn_lstm_otc_ocr.LSTMOCR('train')
    model.build_graph()

    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir='train')
    # Report the sample count, like the validation set below. The old
    # `train_feeder.image[0]` raised IndexError on an empty set and dumped
    # every label.
    print('get image: ', train_feeder.size)

    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir='val')
    print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples /
                                FLAGS.batch_size)  # example: 100000/100

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples /
                                    FLAGS.batch_size)  # example: 10000/100
    # The validation order is shuffled once and reused for every validation.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    with tf.device('/gpu:0'):
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.per_process_gpu_memory_fraction = 0.6
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # global_step is restored from the checkpoint as well.
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))

            print(
                '=============================begin training============================='
            )
            for cur_epoch in range(FLAGS.num_epochs):
                shuffle_idx = np.random.permutation(num_train_samples)
                train_cost = 0
                start_time = time.time()
                batch_time = time.time()

                for cur_batch in range(num_batches_per_epoch):
                    if (cur_batch + 1) % 100 == 0:
                        print('batch', cur_batch, ': time',
                              time.time() - batch_time)
                    batch_time = time.time()
                    # The next batch_size entries of this epoch's permutation.
                    indexs = [
                        shuffle_idx[i % num_train_samples]
                        for i in range(cur_batch * FLAGS.batch_size,
                                       (cur_batch + 1) * FLAGS.batch_size)
                    ]
                    batch_inputs, batch_seq_len, batch_labels = \
                        train_feeder.input_index_generate_batch(indexs)
                    feed = {
                        model.inputs: batch_inputs,
                        model.labels: batch_labels,
                        model.seq_len: batch_seq_len
                    }

                    summary_str, batch_cost, step, _ = \
                        sess.run([model.merged_summay, model.cost, model.global_step,
                                  model.train_op], feed)
                    # Accumulate a per-sample total for the running average.
                    train_cost += batch_cost * FLAGS.batch_size
                    print('batch_cost is: ', batch_cost)
                    train_writer.add_summary(summary_str, step)

                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.mkdir(FLAGS.checkpoint_dir)
                        # Lazy %-argument; the old `, format(step)` form had
                        # no placeholder and made logging report an error.
                        logger.info('save the checkpoint of %s', step)
                        saver.save(sess,
                                   os.path.join(FLAGS.checkpoint_dir,
                                                'ocr-model'),
                                   global_step=step)

                    if step % FLAGS.validation_steps == 0:
                        acc_batch_total = 0
                        lastbatch_err = 0
                        lr = 0
                        for j in range(num_batches_per_epoch_val):
                            indexs_val = [
                                shuffle_idx_val[i % num_val_samples]
                                for i in range(j * FLAGS.batch_size,
                                               (j + 1) * FLAGS.batch_size)
                            ]
                            val_inputs, val_seq_len, val_labels = \
                                val_feeder.input_index_generate_batch(indexs_val)
                            val_feed = {
                                model.inputs: val_inputs,
                                model.labels: val_labels,
                                model.seq_len: val_seq_len
                            }

                            # Fetch model.cost as well. Three names are
                            # unpacked, so the old two-tensor fetch raised
                            # ValueError.
                            dense_decoded, lastbatch_err, lr = \
                                sess.run([model.dense_decoded, model.cost,
                                          model.lrn_rate], val_feed)

                            ori_labels = val_feeder.the_label(indexs_val)
                            acc = utils.accuracy_calculation(ori_labels,
                                                             dense_decoded,
                                                             ignore_value=-1,
                                                             isPrint=True)
                            acc_batch_total += acc

                        # Guard against an empty validation set.
                        accuracy = ((acc_batch_total * FLAGS.batch_size) /
                                    float(num_val_samples)
                                    if num_val_samples else 0.0)

                        avg_train_cost = train_cost / (
                            (cur_batch + 1) * FLAGS.batch_size)

                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                              "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                        print(
                            log.format(now.month, now.day, now.hour,
                                       now.minute, now.second, cur_epoch + 1,
                                       FLAGS.num_epochs, accuracy,
                                       avg_train_cost, lastbatch_err,
                                       time.time() - start_time, lr))
Example #17
0
def infer(img_path, mode='infer'):
    """Decode all images under ``img_path`` and report prediction accuracy.

    The expected code for each image comes from its file name: the base
    name is split on '-', the last piece is dropped, and the rest is
    re-joined with '-'. Expected codes are written to ``./actual.txt`` and
    predictions to ``./result.txt``, one per line. Up to six correct and
    six wrong predictions are printed, followed by the overall accuracy.

    Only full batches of ``FLAGS.batch_size`` images are decoded. Images in
    a trailing partial batch get no prediction and are excluded from
    scoring.

    Args:
        img_path: Path passed to ``helper.load_img_path``.
        mode: Mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    imgList = helper.load_img_path(img_path)
    print(imgList[:5])

    actual = []
    with open('./actual.txt', 'w') as f:
        for name in imgList:
            code = '-'.join(name.split('/')[-1].split('-')[:-1])
            actual.append(code)
            f.write(code + '\n')
    actual = np.asarray(actual)

    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    total_steps = len(imgList) // FLAGS.batch_size

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            print('cannot restore')

        decoded_expression = []
        for curr_step in range(total_steps):
            batch_paths = imgList[curr_step * FLAGS.batch_size:
                                  (curr_step + 1) * FLAGS.batch_size]
            imgs_input = []
            for img in batch_paths:
                im = cv2.imread(img, cv2.IMREAD_GRAYSCALE).astype(
                    np.float32) / 255.
                im = cv2.resize(im, (FLAGS.image_width, FLAGS.image_height))
                im = np.reshape(im, [FLAGS.image_height, FLAGS.image_width,
                                     FLAGS.image_channel])
                imgs_input.append(im)

            # Only images are fed; this graph variant takes no seq_len input.
            feed = {model.inputs: np.asarray(imgs_input)}
            dense_decoded_code = sess.run(model.dense_decoded, feed)

            for item in dense_decoded_code:
                # Entries equal to -1 are skipped (presumably padding).
                decoded_expression.append(
                    ''.join(utils.decode_maps[i] for i in item if i != -1))

        with open('./result.txt', 'w') as f:
            for code in decoded_expression:
                f.write(code + '\n')

        n_decoded = len(decoded_expression)
        if n_decoded == 0:
            print("no predictions: fewer than {} images".format(
                FLAGS.batch_size))
            return

        # Score only the images that received a prediction. Before, the
        # full-length `actual` was compared with the shorter prediction array
        # whenever len(imgList) was not a multiple of batch_size. NumPy either
        # rejects that comparison or collapses it to a single False.
        decoded_expression = np.asarray(decoded_expression)
        actual = actual[:n_decoded]
        imgList = np.asarray(imgList[:n_decoded])

        # Print up to 6 correct and 6 incorrect predictions.
        c = decoded_expression == actual
        w = decoded_expression != actual
        correct = imgList[c]
        wrong = imgList[w]
        print("correct predictions:")
        print(correct[:6])
        print("********")
        print("wrong predictions:")
        print(wrong[:6])
        print("********")
        wrong_pred = decoded_expression[w]
        wrong_actual = actual[w]
        # min(): fewer than 6 wrong predictions used to raise IndexError.
        for i in range(min(6, len(wrong_pred))):
            print("prediction = {}".format(wrong_pred[i]))
            print("actual = {}".format(wrong_actual[i]))
            print("********")

        acc = float(c.sum()) / n_decoded
        print("accuracy = {}".format(acc))
 def infer(self):
     """Decode the validation split and write per-image results to disk.

     Builds an inference graph over ``utils.DataIterator(is_val=True,
     random_shuff=False)`` with a single reader thread, then restores the
     latest checkpoint from ``FLAGS.checkpoint_dir``. For every image it
     writes:

     * a line ``image,annotation,prediction`` to ``FLAGS.output_dir/result``;
     * the prediction alone to ``FLAGS.output_dir/<image stem>.txt``;
     * a debug image from ``self.draw_debug`` into ``FLAGS.output_dir``.
     """
     FLAGS.num_threads = 1
     gpus = [g for g in FLAGS.gpus.split(',') if g]
     with tf.Graph().as_default(), tf.device('/cpu:0'):
         train_feeder = utils.DataIterator(is_val=True, random_shuff=False)
         X, Y = train_feeder.distored_inputs()
         model = cnn_lstm_otc_ocr.LSTMOCR('infer', gpus)
         _, decodes = model.build_graph(X, Y)
         num_images = len(train_feeder.image)
         # Ceiling division with `//` so range() receives an int on Python 3
         # too (plain `/` yields a float there).
         total_steps = (num_images + FLAGS.batch_size - 1) // FLAGS.batch_size
         config = tf.ConfigProto(allow_soft_placement=True)
         with tf.Session(config=config) as sess, open(
                 os.path.join(FLAGS.output_dir, 'result'), 'w') as f:
             sess.run(tf.global_variables_initializer())
             coord = tf.train.Coordinator()
             threads = tf.train.start_queue_runners(sess=sess, coord=coord)
             try:
                 saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
                 ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                 print(FLAGS.checkpoint_dir)
                 if ckpt:
                     saver.restore(sess, ckpt)
                     print('restore from ckpt{}'.format(ckpt))
                 else:
                     print('cannot restore')
                 count = 0
                 for _ in range(total_steps):
                     start = time.time()
                     dense_decoded_code = sess.run(decodes)
                     print('time cost:', (time.time() - start))
                     print("dense_decoded_code:", dense_decoded_code)
                     # `decodes` evaluates to a sequence of dense label
                     # matrices (presumably one per device). Indices missing
                     # from utils.decode_maps are dropped.
                     decoded_expression = [
                         ''.join(utils.decode_maps[i] for i in item
                                 if i in utils.decode_maps)
                         for d in dense_decoded_code for item in d
                     ]
                     for code in decoded_expression:
                         # Predictions beyond the number of images (e.g. from
                         # a padded final batch) are ignored.
                         if count >= num_images:
                             break
                         image_file = train_feeder.image[count]
                         anno = train_feeder.anno[count]
                         # NOTE(review): `.encode('utf-8')` writes byte strings
                         # as Python 2 expects; on Python 3 this prints b'...'
                         # and fails for the text-mode write below - confirm
                         # the target interpreter.
                         f.write("%s,%s,%s\n" %
                                 (image_file,
                                  anno.encode('utf-8'),
                                  code.encode('utf-8')))
                         filename = os.path.splitext(
                             os.path.basename(image_file))[0] + ".txt"
                         output_file = os.path.join(FLAGS.output_dir, filename)
                         # `with` closes the file even if write() raises.
                         with open(output_file, "w") as cur:
                             cur.write(code.encode('utf-8'))
                         print(code.encode('utf-8'))
                         try:
                             image_debug = cv2.imread(image_file)
                             image_debug = self.draw_debug(
                                 image_debug, code.encode('utf-8'),
                                 code == anno)
                             image_path = os.path.join(
                                 FLAGS.output_dir, os.path.basename(image_file))
                             cv2.imwrite(image_path, image_debug)
                         except Exception as e:
                             # Debug images are best-effort: report and go on.
                             print(e)
                         count += 1
             finally:
                 # Stop and join the reader threads even if decoding fails.
                 coord.request_stop()
                 coord.join(threads)
Example #19
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the OCR model on ``utils.DataIterator`` data with validation.

    Each epoch visits the training samples in a fresh random order. A
    checkpoint is saved whenever ``step % FLAGS.save_steps == 1``. Whenever
    ``step % FLAGS.validation_steps == 0``, every full validation batch is
    decoded and the accuracy, average training cost, last validation batch
    cost and learning rate are printed.

    Args:
        train_dir: Data directory for the training ``utils.DataIterator``.
        val_dir: Data directory for the validation ``utils.DataIterator``.
        mode: Mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    # Build the graph.
    model.build_graph()

    print('loading train data, please wait---------------------')
    # Training data iterator.
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('get image:', train_feeder.size)
    print('loading validation data, please wait---------------------')
    # Validation data iterator.
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('get image:', val_feeder.size)

    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)  # example: 100000 / 100

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)
    # Shuffle the validation sample order once, up front.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    with tf.device('/gpu:0'):
        # allow_soft_placement lets TensorFlow pick another device when the
        # requested one does not exist.
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            # Saver used to store and restore the model parameters.
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            # Write the session graph to the log directory.
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)
            # Restore previously saved parameters into this session if asked.
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))

            # Indentation is now consistent; the previous mix of 12/13/17/20/21
            # columns made this function fail to parse.
            print('=============================begin training=============================')
            # Start training.
            for cur_epoch in range(FLAGS.num_epochs):
                shuffle_idx = np.random.permutation(num_train_samples)
                train_cost = 0
                start_time = time.time()
                batch_time = time.time()

                for cur_batch in range(num_batches_per_epoch):
                    if (cur_batch + 1) % 100 == 0:
                        print('batch', cur_batch, ':time', time.time() - batch_time)
                    batch_time = time.time()
                    # Sample indices for this batch: the next batch_size
                    # entries of this epoch's random permutation.
                    indexs = [shuffle_idx[i % num_train_samples]
                              for i in range(cur_batch * FLAGS.batch_size,
                                             (cur_batch + 1) * FLAGS.batch_size)]
                    batch_inputs, batch_seq_len, batch_labels = \
                        train_feeder.input_index_generate_batch(indexs)
                    # Feed dict for the model.
                    feed = {model.inputs: batch_inputs,
                            model.labels: batch_labels,
                            model.seq_len: batch_seq_len}

                    # Run one training step and fetch the tensors.
                    summary_str, batch_cost, step, _ = sess.run(
                        [model.merged_summay, model.cost, model.global_step,
                         model.train_op], feed)
                    # batch_cost is the mean over the batch, so scale it by the
                    # batch size to accumulate a total.
                    train_cost += batch_cost * FLAGS.batch_size
                    # TensorBoard summaries.
                    train_writer.add_summary(summary_str, step)

                    # Save a checkpoint.
                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.mkdir(FLAGS.checkpoint_dir)
                        # Lazy %-argument; the old `, format(step)` form had
                        # no placeholder and made logging report an error.
                        logger.info('save the checkpoint of %s', step)
                        saver.save(sess,
                                   os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                                   global_step=step)

                    # Decode the validation set, one batch at a time.
                    if step % FLAGS.validation_steps == 0:
                        acc_batch_total = 0
                        lastbatch_err = 0
                        lr = 0
                        for j in range(num_batches_per_epoch_val):
                            indexs_val = [shuffle_idx_val[i % num_val_samples]
                                          for i in range(j * FLAGS.batch_size,
                                                         (j + 1) * FLAGS.batch_size)]
                            val_inputs, val_seq_len, val_labels = \
                                val_feeder.input_index_generate_batch(indexs_val)
                            # `model` (the old `modell` was a NameError).
                            val_feed = {model.inputs: val_inputs,
                                        model.labels: val_labels,
                                        model.seq_len: val_seq_len}

                            # Fetch model.cost as well. Three names are
                            # unpacked, so the old two-tensor fetch raised
                            # ValueError.
                            dense_decoded, lastbatch_err, lr = sess.run(
                                [model.dense_decoded, model.cost, model.lrn_rate],
                                val_feed)

                            # Score the decoded validation results.
                            ori_labels = val_feeder.the_label(indexs_val)
                            acc = utils.accuracy_calculation(
                                ori_labels, dense_decoded, ignore_value=-1,
                                isPrint=True)
                            acc_batch_total += acc

                        # Guard against an empty validation set.
                        accuracy = ((acc_batch_total * FLAGS.batch_size) /
                                    float(num_val_samples)
                                    if num_val_samples else 0.0)

                        avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)

                        # datetime.datetime.now(): the old
                        # datetime.datetime.time() called an instance method
                        # on the class and raised TypeError.
                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                              "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                        print(log.format(now.month, now.day, now.hour,
                                         now.minute, now.second, cur_epoch + 1,
                                         FLAGS.num_epochs, accuracy,
                                         avg_train_cost, lastbatch_err,
                                         time.time() - start_time, lr))
Example #20
0
def infer(img_path, mode='infer'):
    """Run OCR inference on the images under ``img_path`` and save the results.

    Images are fed to the model in batches of exactly ``FLAGS.batch_size``.
    Every decoded string is printed and appended, one per line, to
    ``./result.txt``.

    Args:
        img_path: path handed to ``helper.load_img_path`` to list the images.
        mode: mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.

    Raises:
        IOError: if OpenCV cannot read one of the images.
    """
    imgList = helper.load_img_path(img_path)
    print(imgList[:5])

    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    # Only whole batches are run.  NOTE(review): trailing images that do not
    # fill a batch are skipped, presumably because the graph assumes a fixed
    # batch size -- confirm before switching to ceiling division.
    total_steps = len(imgList) // FLAGS.batch_size
    num_skipped = len(imgList) - total_steps * FLAGS.batch_size
    if num_skipped:
        print('warning: {} trailing image(s) do not fill a batch and are '
              'skipped'.format(num_skipped))

    os.environ["CUDA_VISIBLE_DEVICES"] = '2'
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            print('cannot restore')

        decoded_expression = []
        for curr_step in range(total_steps):
            batch_paths = imgList[curr_step * FLAGS.batch_size:
                                  (curr_step + 1) * FLAGS.batch_size]
            imgs_input = []
            for img in batch_paths:
                im = cv2.imread(img, cv2.IMREAD_COLOR)
                # cv2.imread signals failure by returning None, not raising.
                if im is None:
                    raise IOError('cannot read image: {}'.format(img))
                im = im.astype(np.float32) / 255.
                # Also fails loudly if the image is not the size the model expects.
                im = np.reshape(im, [
                    FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel
                ])
                imgs_input.append(im)

            feed = {model.inputs: np.asarray(imgs_input)}
            dense_decoded_code = sess.run(model.dense_decoded, feed)

            for item in dense_decoded_code:
                # Codes equal to -1 are skipped (presumably the decoder's
                # padding value); every other code is mapped to its character.
                decoded_expression.append(
                    ''.join(utils.decode_maps[i] for i in item if i != -1))

        with open('./result.txt', 'a') as f:
            for code in decoded_expression:
                print(code)
                f.write(code + '\n')
Example #21
0
def train(train_dir=None, mode='train'):
    """Train the OCR model on the image/label pairs found in ``train_dir``.

    ``train_dir`` is scanned for ``<name>.txt`` label files that have a
    matching ``<name>.jpg`` image.  The pairs are sorted by filename and
    split 80/20 into training and validation sets.  Training runs for
    ``FLAGS.num_epochs`` epochs, writes TensorBoard summaries, saves a
    checkpoint whenever ``step % FLAGS.save_steps == 1`` and validates
    whenever ``step % FLAGS.validation_steps == 0``.

    Args:
        train_dir: directory holding the ``.txt``/``.jpg`` pairs.
        mode: mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    # Sorting makes the train/validation split reproducible: os.listdir()
    # returns entries in arbitrary order, so without it a restored run could
    # validate on files it previously trained on.
    label_files = [
        os.path.join(train_dir, e) for e in sorted(os.listdir(train_dir))
        if e.endswith('.txt')
        and os.path.exists(os.path.join(train_dir, e[:-len('.txt')] + '.jpg'))
    ]
    train_num = int(len(label_files) * 0.8)
    test_num = len(label_files) - train_num

    print('total num', len(label_files), 'train num', train_num, 'test num', test_num)
    train_imgs = label_files[0:train_num]
    test_imgs = label_files[train_num:]

    print('loading train data')
    train_feeder = utils.DataIterator(data_dir=train_imgs)

    print('loading validation data')
    val_feeder = utils.DataIterator(data_dir=test_imgs)

    # Only full batches are iterated; a trailing partial batch is never used.
    num_batches_per_epoch_train = int(train_num / FLAGS.batch_size)  # example: 100000/100
    num_batches_per_epoch_val = int(test_num / FLAGS.batch_size)  # example: 10000/100

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step is restored as well
                saver.restore(sess, ckpt)
                print('restore from checkpoint{0}'.format(ckpt))

        print('=============================begin training=============================')
        for cur_epoch in range(FLAGS.num_epochs):
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()
            # NOTE(review): the training data is shuffled only before the
            # first epoch; later epochs reuse that order -- confirm intent.
            if cur_epoch == 0:
                random.shuffle(train_feeder.train_data)

            # the training part
            for cur_batch in range(num_batches_per_epoch_train):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time', time.time() - batch_time)
                batch_time = time.time()

                batch_inputs, result_img_length, batch_labels = \
                    train_feeder.get_batchsize_data(cur_batch)

                feed = {model.inputs: batch_inputs,
                        model.labels: batch_labels,
                        model.seq_len: result_img_length}

                summary_str, batch_cost, step, _ = \
                    sess.run([model.merged_summay, model.cost, model.global_step, model.train_op], feed)
                # batch_cost is weighted by batch size so the average below
                # is per sample.
                train_cost += batch_cost * FLAGS.batch_size

                train_writer.add_summary(summary_str, step)

                # save the checkpoint
                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.makedirs(FLAGS.checkpoint_dir)
                    logger.info('save checkpoint at step {0}'.format(step))
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'), global_step=step)

                # do validation
                if step % FLAGS.validation_steps == 0:
                    acc_batch_total = 0
                    lastbatch_err = 0
                    lr = 0
                    for val_j in range(num_batches_per_epoch_val):
                        result_img_val, seq_len_input_val, batch_label_val = \
                            val_feeder.get_batchsize_data(val_j)
                        val_feed = {model.inputs: result_img_val,
                                    model.labels: batch_label_val,
                                    model.seq_len: seq_len_input_val}

                        dense_decoded, lastbatch_err, lr = \
                            sess.run([model.dense_decoded, model.cost, model.lrn_rate],
                                     val_feed)

                        # convert the decoded codes to text and score them
                        val_pre_list = [utils.label2text(decode_code)
                                        for decode_code in dense_decoded]
                        ori_labels = val_feeder.get_val_label(val_j)

                        acc = utils.accuracy_calculation(ori_labels, val_pre_list,
                                                         ignore_value=-1, isPrint=True)
                        acc_batch_total += acc

                    # A validation set smaller than one batch yields zero
                    # batches; report 0 instead of dividing by zero.
                    if num_batches_per_epoch_val:
                        accuracy = acc_batch_total / num_batches_per_epoch_val
                    else:
                        accuracy = 0.0

                    avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                                     cur_epoch + 1, FLAGS.num_epochs, accuracy, avg_train_cost,
                                     lastbatch_err, time.time() - start_time, lr))
Example #22
0
def train(train_dir, batch_size=64, image_height=60, image_width=180, image_channel=1,
          checkpoint_dir="../checkpoint/", num_epochs=100):
    """Train the OCR model, checkpointing and validating after every epoch.

    Samples ``[0, 800)`` of ``train_dir`` are used for training and
    ``[800, 1000)`` for validation (via ``DataIterator``'s ``begin``/``end``).
    TensorBoard summaries go to ``<checkpoint_dir>/train``; an existing
    checkpoint in ``checkpoint_dir`` is restored before training starts.

    Args:
        train_dir: data directory passed to ``DataIterator``.
        batch_size: samples per batch for both iterators and the model.
        image_height, image_width, image_channel: input image geometry
            forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.
        checkpoint_dir: directory for checkpoints and summaries.
        num_epochs: number of training epochs.
    """
    # Load the data
    train_data = DataIterator(data_dir=train_dir, batch_size=batch_size, begin=0, end=800)
    valid_data = DataIterator(data_dir=train_dir, batch_size=batch_size, begin=800, end=1000)
    print('train data batch number: {}'.format(train_data.number_batch))
    print('valid data batch number: {}'.format(valid_data.number_batch))

    # Build the model
    model = cnn_lstm_otc_ocr.LSTMOCR(NumClasses, batch_size, image_height=image_height,
                                     image_width=image_width, image_channel=image_channel, is_train=True)
    model.build_graph()

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True), allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        # Initialize variables
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        # os.path.join works whether or not checkpoint_dir ends with a separator.
        train_writer = tf.summary.FileWriter(os.path.join(checkpoint_dir, 'train'), sess.graph)

        # Restore the latest checkpoint, if any
        ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from checkpoint{0}'.format(ckpt))
        else:
            print('no checkpoint to restore')

        print('=======begin training=======')
        for cur_epoch in range(num_epochs):
            start_time = time.time()
            batch_time = time.time()

            # Training
            train_cost = 0
            for cur_batch in range(train_data.number_batch):
                if cur_batch % 100 == 0:
                    print('batch {}/{} time: {}'.format(cur_batch, train_data.number_batch, time.time() - batch_time))
                batch_time = time.time()

                batch_inputs, _, sparse_labels = train_data.next_train_batch()

                summary, cost, step, _ = sess.run([model.merged_summay, model.cost, model.global_step, model.train_op],
                                                  {model.inputs: batch_inputs, model.labels: sparse_labels})
                train_cost += cost
                train_writer.add_summary(summary, step)
            # max(..., 1) avoids ZeroDivisionError when there are no batches.
            print("loss is {}".format(train_cost / max(train_data.number_batch, 1)))

            # Save a checkpoint every epoch.  NOTE(review): it is numbered by
            # the epoch index, which restarts at 0 after a restore.
            if not os.path.isdir(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver.save(sess, os.path.join(checkpoint_dir, 'ocr-model'), global_step=cur_epoch)

            # Validation, every epoch
            lr = 0
            acc_batch_total = 0
            for j in range(valid_data.number_batch):
                val_inputs, _, sparse_labels, ori_labels = valid_data.next_test_batch(j)
                dense_decoded, lr = sess.run([model.dense_decoded, model.lrn_rate],
                                             {model.inputs: val_inputs, model.labels: sparse_labels})
                acc_batch_total += accuracy_calculation(ori_labels, dense_decoded, -1)

            if valid_data.number_batch:
                accuracy = acc_batch_total / valid_data.number_batch
            else:
                accuracy = 0.0

            now = datetime.datetime.now()
            log = "{}/{} {}:{}:{} Epoch {}/{}, accuracy = {:.3f}, time = {:.3f},lr={:.8f}"
            print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                             cur_epoch + 1, num_epochs, accuracy, time.time() - start_time, lr))
def infer(mode='infer'):
    """Decode every image of the validation ``DataIterator`` and write results.

    For each image, the path, its annotation and the decoded text are printed.
    The decoded text is then written to
    ``<FLAGS.output_dir>/<image basename>.txt``.  Weights are restored through
    ``model.variable_averages.variables_to_restore()`` (presumably
    moving-average shadow variables -- confirm).

    Args:
        mode: mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    FLAGS.num_threads = 1
    gpus = [g for g in FLAGS.gpus.split(',') if g]
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        train_feeder = utils.DataIterator(is_val=True, random_shuff=False)
        X, Y_in, Y_out, length = train_feeder.distored_inputs()
        model = cnn_lstm_otc_ocr.LSTMOCR(mode, gpus)
        train_op, decodes = model.build_graph(X, Y_in, Y_out, length)
        num_images = len(train_feeder.image)
        # Ceiling division so that a trailing partial batch is still decoded.
        total_steps = (num_images + FLAGS.batch_size - 1) // FLAGS.batch_size
        config = tf.ConfigProto(allow_soft_placement=True)
        result_dir = os.path.dirname(FLAGS.infer_file)
        # NOTE(review): result_digit_v1.txt is created/truncated but nothing is
        # written to it any more; it is opened only to keep that side effect.
        with tf.Session(config=config) as sess, open(
                os.path.join(result_dir, 'result_digit_v1.txt'), 'w') as f:
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # try/finally guarantees the queue-runner threads are stopped and
            # joined even if restoring or decoding raises.
            try:
                variables_to_restore = model.variable_averages.variables_to_restore()
                saver = tf.train.Saver(variables_to_restore)

                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                print("search from ", FLAGS.checkpoint_dir)
                if ckpt:
                    saver.restore(sess, ckpt)
                    print('restore from ckpt{}'.format(ckpt))
                else:
                    print('cannot restore')
                if not os.path.exists(FLAGS.output_dir):
                    os.makedirs(FLAGS.output_dir)

                count = 0
                for curr_step in range(total_steps):
                    dense_decoded_code = sess.run(decodes)

                    decoded_expression = []
                    for batch in dense_decoded_code:
                        for sequence in batch:
                            chars = []
                            for code in sequence:
                                # Stop at the end-of-sequence token; drop
                                # codes that have no character mapping.
                                if code == utils.TOKEN["<EOS>"]:
                                    break
                                if code in utils.decode_maps:
                                    chars.append(utils.decode_maps[code])
                            decoded_expression.append(''.join(chars))

                    for expression in decoded_expression:
                        # The last batch may contain more sequences than
                        # there are remaining images.
                        if count >= num_images:
                            break
                        image_path = train_feeder.image[count]
                        print(image_path)
                        print(train_feeder.anno[count])
                        print(expression)
                        print('')
                        filename = os.path.splitext(
                            os.path.basename(image_path))[0] + ".txt"
                        output_file = os.path.join(FLAGS.output_dir, filename)
                        with open(output_file, "w") as cur:
                            cur.write(expression)
                        count += 1
            finally:
                coord.request_stop()
                coord.join(threads)
Example #24
0
def train(train_dir=None, val_dir=None, mode='train',
          tfrecords_filename='/home/youth/DL/CNN_LSTM_CTC_Tensorflow/tfrecords/train.tfrecords'):
    """Train the OCR model from a TFRecords file using TF1 input queues.

    Args:
        train_dir: unused; kept for signature compatibility.
        val_dir: unused; kept for signature compatibility.
        mode: mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.
        tfrecords_filename: path of the training TFRecords file.  This was
            previously hard-coded; the default keeps the old value.
    """
    # Load the dataset
    filename_queue = tf.train.string_input_producer([tfrecords_filename],
                                                    num_epochs=EPOCHS,
                                                    shuffle=True)
    images, names, labels = gen_tfrecord.read_and_decode(filename_queue)
    print(images, names, labels)
    shape = np.shape(images)
    # Every sample in the batch gets the same fixed sequence length of 12.
    seq_len = np.array([12 for _ in range(shape[0])], dtype=np.int64)
    labels = utils.sparse_tuple_from_label(labels)

    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph(images, labels, seq_len)

    num_train_samples = gen_tfrecord.get_size(tfrecords_filename)  # 100000
    num_batches_per_epoch = int(num_train_samples /
                                FLAGS.batch_size)  # example: 100000/100

    with tf.device('/cpu:0'):
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            # Local variables must be initialised too: string_input_producer
            # with num_epochs keeps its epoch counter in one.
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # the global_step is restored as well
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            print(
                '=============================begin training============================='
            )

            try:
                for cur_epoch in range(FLAGS.num_epochs):
                    train_cost = 0
                    start_time = time.time()
                    batch_time = time.time()

                    for cur_batch in range(num_batches_per_epoch):
                        if (cur_batch + 1) % 100 == 0:
                            print('batch', cur_batch, ': time',
                                  time.time() - batch_time)
                        batch_time = time.time()
                        summary_str, batch_cost, step, _ = \
                            sess.run([model.merged_summay, model.cost, model.global_step,
                                      model.train_op])
                        # weight by batch size so the average is per sample
                        train_cost += batch_cost * FLAGS.batch_size

                        train_writer.add_summary(summary_str, step)

                        # save the checkpoint
                        if step % FLAGS.save_steps == 1:
                            if not os.path.isdir(FLAGS.checkpoint_dir):
                                os.makedirs(FLAGS.checkpoint_dir)
                            logger.info('save the checkpoint of{0}'.format(step))
                            saver.save(sess,
                                       os.path.join(FLAGS.checkpoint_dir,
                                                    'ocr-model'),
                                       global_step=step)

                        # NOTE(review): this progress line is printed after
                        # every batch.
                        avg_train_cost = train_cost / (
                            (cur_batch + 1) * FLAGS.batch_size)
                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{},avg_train_cost = {:.3f}, time = {:.3f}"
                        print(
                            log.format(now.month, now.day, now.hour, now.minute,
                                       now.second, cur_epoch + 1, FLAGS.num_epochs,
                                       avg_train_cost,
                                       time.time() - start_time))
            except tf.errors.OutOfRangeError:
                # The input queue was built with num_epochs=EPOCHS; once it
                # is exhausted, there is no more data to train on.
                print('input queue exhausted, training finished')
            finally:
                coord.request_stop()
            coord.join(threads)
Example #25
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the OCR model, checkpointing and validating periodically.

    Training batches are drawn from a fresh random permutation of the
    training set each epoch.  A checkpoint is saved whenever
    ``step % FLAGS.save_steps == 1``.  Validation runs over a fixed random
    permutation whenever ``step % FLAGS.validation_steps == 0``; its summary
    line is printed and appended to ``test_acc.txt``.

    Args:
        train_dir: training data directory passed to ``utils.DataIterator``.
        val_dir: validation data directory passed to ``utils.DataIterator``.
        mode: mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()
    print(FLAGS.image_channel)
    print('loading train data')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('size: ', train_feeder.size)

    print('loading validation data')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('size: {}\n'.format(val_feeder.size))

    num_train_samples = train_feeder.size  # 100000
    # Only full batches are iterated; a trailing partial batch is never used.
    num_batches_per_epoch = int(num_train_samples /
                                FLAGS.batch_size)  # example: 100000/100

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples /
                                    FLAGS.batch_size)  # example: 10000/100
    shuffle_idx_val = np.random.permutation(num_val_samples)

    os.environ["CUDA_VISIBLE_DEVICES"] = '2'
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step is restored as well
                saver.restore(sess, ckpt)
                print('restore from checkpoint{0}'.format(ckpt))

        print(
            '=============================begin training============================='
        )
        # Freeze the graph so accidental op creation in the loop raises.
        sess.graph.finalize()
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()

            # the training part
            for cur_batch in range(num_batches_per_epoch):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time',
                          time.time() - batch_time)
                batch_time = time.time()
                indexs = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(cur_batch * FLAGS.batch_size,
                                   (cur_batch + 1) * FLAGS.batch_size)
                ]
                batch_inputs, _, batch_labels = \
                    train_feeder.input_index_generate_batch(indexs)
                feed = {model.inputs: batch_inputs, model.labels: batch_labels}

                summary_str, batch_cost, step, _ = \
                    sess.run([model.merged_summay, model.cost, model.global_step, model.train_op], feed)
                # weight by batch size so the average below is per sample
                train_cost += batch_cost * FLAGS.batch_size

                train_writer.add_summary(summary_str, step)

                # save the checkpoint
                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.makedirs(FLAGS.checkpoint_dir)
                    logger.info('save checkpoint at step {0}'.format(step))
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                # do validation
                if step % FLAGS.validation_steps == 0:
                    acc_batch_total = 0
                    lastbatch_err = 0
                    lr = 0
                    for j in range(num_batches_per_epoch_val):
                        indexs_val = [
                            shuffle_idx_val[i % num_val_samples]
                            for i in range(j * FLAGS.batch_size,
                                           (j + 1) * FLAGS.batch_size)
                        ]
                        val_inputs, _, val_labels = \
                            val_feeder.input_index_generate_batch(indexs_val)
                        val_feed = {
                            model.inputs: val_inputs,
                            model.labels: val_labels
                        }

                        dense_decoded, lastbatch_err, lr = \
                            sess.run([model.dense_decoded, model.cost, model.lrn_rate],
                                     val_feed)

                        ori_labels = val_feeder.the_label(indexs_val)

                        acc = utils.accuracy_calculation(ori_labels,
                                                         dense_decoded,
                                                         ignore_value=-1,
                                                         isPrint=True)
                        acc_batch_total += acc

                    # Average over the batches actually validated
                    # (num_batches_per_epoch_val * batch_size samples).
                    # Dividing by num_val_samples under-reported accuracy
                    # whenever it was not a multiple of the batch size.
                    if num_batches_per_epoch_val:
                        accuracy = acc_batch_total / num_batches_per_epoch_val
                    else:
                        accuracy = 0.0

                    avg_train_cost = train_cost / (
                        (cur_batch + 1) * FLAGS.batch_size)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    # Format once so the file and stdout show the same line.
                    log_line = log.format(now.month, now.day, now.hour,
                                          now.minute, now.second, cur_epoch + 1,
                                          FLAGS.num_epochs, accuracy,
                                          avg_train_cost, lastbatch_err,
                                          time.time() - start_time, lr)
                    with open('test_acc.txt', 'a') as f:
                        f.write(log_line + "\n")
                    print(log_line)
Example #26
0
def train(train_dir=None, val_dir=None, mode='train'):
    if FLAGS.model == 'lstm':
        model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    else:
        print("no such model")
        sys.exit()

    #开始构建图
    model.build_graph()
    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir=train_dir, num=4000000)
    print('get image: ', train_feeder.size)

    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir=val_dir, num=40000)
    print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples /
                                FLAGS.batch_size)  # 训练集一次epoch需要的batch数

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples /
                                    FLAGS.batch_size)  # 验证集一次epoch需要的batch数

    shuffle_idx_val = np.random.permutation(num_val_samples)

    with tf.device('/cpu:0'):
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            #全局变量初始化
            sess.run(tf.global_variables_initializer())

            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=100)  #存储模型
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)

            #导入预训练模型
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # the global_step will restore sa well
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))
                else:
                    print("No checkpoint")

            print(
                '=============================begin training============================='
            )
            accuracy_res = []
            accuracy_per_res = []
            epoch_res = []
            tmp_max = 0
            tmp_epoch = 0
            for cur_epoch in range(FLAGS.num_epochs):
                shuffle_idx = np.random.permutation(num_train_samples)
                train_cost = 0
                start_time = time.time()
                batch_time = time.time()

                # the tracing part
                for cur_batch in range(num_batches_per_epoch):
                    if (cur_batch + 1) % 100 == 0:
                        print('batch', cur_batch, ': time',
                              time.time() - batch_time)
                    batch_time = time.time()

                    #获得这一轮batch数据的标号
                    indexs = [
                        shuffle_idx[i % num_train_samples]
                        for i in range(cur_batch *
                                       FLAGS.batch_size, (cur_batch + 1) *
                                       FLAGS.batch_size)
                    ]

                    batch_inputs, batch_seq_len, batch_labels = \
                        train_feeder.input_index_generate_batch(indexs)
                    # batch_inputs,batch_seq_len,batch_labels=utils.gen_batch(FLAGS.batch_size)
                    feed = {
                        model.inputs: batch_inputs,
                        model.labels: batch_labels,
                        model.seq_len: batch_seq_len
                    }

                    # if summary is needed
                    # batch_cost,step,train_summary,_ = sess.run([cost,global_step,merged_summay,optimizer],feed)

                    summary_str, batch_cost, step, _ = \
                        sess.run([model.merged_summay, model.cost, model.global_step,
                                  model.train_op], feed)
                    # calculate the cost
                    train_cost += batch_cost * FLAGS.batch_size

                    train_writer.add_summary(summary_str, step)

                    # save the checkpoint
                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.mkdir(FLAGS.checkpoint_dir)
                        logger.info('save the checkpoint of{0}', format(step))
                        saver.save(sess,
                                   os.path.join(FLAGS.checkpoint_dir,
                                                'ocr-model'),
                                   global_step=step)

                    # train_err += the_err * FLAGS.batch_size
                    # do validation
                    if step % FLAGS.validation_steps == 0:
                        acc_batch_total = 0
                        acc_per_batch_total = 0
                        lastbatch_err = 0
                        lr = 0
                        for j in range(num_batches_per_epoch_val):
                            indexs_val = [
                                shuffle_idx_val[i % num_val_samples]
                                for i in range(j * FLAGS.batch_size, (j + 1) *
                                               FLAGS.batch_size)
                            ]
                            val_inputs, val_seq_len, val_labels = \
                                val_feeder.input_index_generate_batch(indexs_val)
                            val_feed = {
                                model.inputs: val_inputs,
                                model.labels: val_labels,
                                model.seq_len: val_seq_len
                            }

                            dense_decoded, lr = \
                                sess.run([model.dense_decoded, model.lrn_rate],
                                         val_feed)

                            # print the decode result
                            ori_labels = val_feeder.the_label(indexs_val)
                            acc = utils.accuracy_calculation(ori_labels,
                                                             dense_decoded,
                                                             ignore_value=-1,
                                                             isPrint=True)

                            acc_per = utils.accuracy_calculation_single(
                                ori_labels,
                                dense_decoded,
                                ignore_value=-1,
                                isPrint=True)

                            acc_per_batch_total += acc_per
                            acc_batch_total += acc

                        accuracy_per = (acc_per_batch_total *
                                        FLAGS.batch_size) / num_val_samples
                        accuracy = (acc_batch_total *
                                    FLAGS.batch_size) / num_val_samples

                        accuracy_per_res.append(accuracy_per)
                        accuracy_res.append(accuracy)

                        epoch_res.append(cur_epoch)
                        if accuracy_per > tmp_max:
                            tmp_max = accuracy
                            tmp_epoch = cur_epoch

                        avg_train_cost = train_cost / (
                            (cur_batch + 1) * FLAGS.batch_size)

                        # train_err /= num_train_samples
                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "max_accuracy = {:.3f},max_Epoch {},accuracy = {:.3f},acc_batch_total = {:.3f},avg_train_cost = {:.3f}, " \
                              " time = {:.3f},lr={:.8f}"
                        print(
                            log.format(now.month, now.day, now.hour,
                                       now.minute, now.second, cur_epoch + 1,
                                       FLAGS.num_epochs, tmp_max, tmp_epoch,
                                       accuracy_per, acc_batch_total,
                                       avg_train_cost,
                                       time.time() - start_time, lr))