Пример #1
0
def acc():
    """Decode one inference batch with the latest checkpoint and print accuracy.

    Builds ``model.Graph()``, initializes its variables and, when a checkpoint
    exists in ``FLAGS.checkpoint_dir``, restores it (otherwise the freshly
    initialized variables are used). One batch produced by
    ``utils.DataIterator(data_dir=inferFolder)`` is decoded and the value
    returned by ``utils.accuracy_calculation`` is printed.
    """
    net = model.Graph()
    with tf.Session(graph=net.graph) as session:
        session.run(tf.global_variables_initializer())
        checkpoint_saver = tf.train.Saver(tf.global_variables(),
                                          max_to_keep=100)
        latest = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if not latest:
            print('cannot restore')
        else:
            checkpoint_saver.restore(session, latest)
            print('restore from ckpt{}'.format(latest))

        feeder = utils.DataIterator(data_dir=inferFolder)
        inputs, _seq_len, labels = feeder.input_index_generate_batch()

        # The sequence length fed to the graph is a fixed 27 per sample; the
        # lengths returned by the feeder are not used.
        fixed_lengths = np.array([27] * inputs.shape[0])
        feed = {
            net.inputs: inputs,
            net.labels: labels,
            net.seq_len: fixed_lengths,
        }
        decoded = session.run(net.dense_decoded, feed)

        # accuracy_calculation is asked to print the decode result itself
        accuracy = utils.accuracy_calculation(feeder.labels,
                                              decoded,
                                              ignore_value=-1,
                                              isPrint=True)
        print(accuracy)
Пример #2
0
def doWordListClassify(keyfile, outfilename, sue_weight=0.0, k=0):
    '''
    Classify every item of keyfile by counting sentiment words.

    For each item, the score is the total count of its BOW words found in
    the 'pos' word list, minus the total count of the remaining words found
    in the 'neg' word list, plus sue_weight * SUE.  One label per item is
    written to outfilename: '1' if score > k, '3' if score < -k, else '2'.
    '''
    lexicon = utils.getSentVoc()
    positive_words = lexicon['pos']
    negative_words = lexicon['neg']

    with open(outfilename, 'w') as out:
        for record in utils.DataIterator(keyfile):
            bag = record['BOW']
            sue = record['SUE']
            pos_total = sum(n for w, n in bag.items() if w in positive_words)
            # A word present in both lists only counts as positive.
            neg_total = sum(n for w, n in bag.items()
                            if w not in positive_words
                            and w in negative_words)
            score = float(pos_total - neg_total) + sue_weight * sue
            if score > k:
                label = '1\n'
            elif score < -k:
                label = '3\n'
            else:
                label = '2\n'
            out.write(label)
Пример #3
0
def doMostCommonClassify(keyfile, outfilename):
    '''
    Most Common Baseline: write the label 3 for every instance in keyfile,
    one label per line, to outfilename.
    '''
    with open(outfilename, 'w') as out:
        out.writelines('3\n' for _ in utils.DataIterator(keyfile))
Пример #4
0
def train(checkpoint, runtime_generate=False):
    """Train LPRnet, checkpointing every SAVE_STEPS steps.

    Args:
        checkpoint: forwarded to ``restore_checkpoint`` before training
            starts.
        runtime_generate: when true, training batches are taken from
            ``DataIterator.next_gen_batch`` instead of ``next_batch``; the
            flag is also forwarded to the training ``DataIterator``.
    """
    lprnet = LPRnet(is_train=True)
    train_gen = utils.DataIterator(img_dir=TRAIN_DIR,
                                   runtime_generate=runtime_generate)
    val_gen = utils.DataIterator(img_dir=VAL_DIR)

    def train_batch(train_gen):
        """Run one optimizer step; return (loss, global step, learning rate)."""
        # `sess` and `saver` are looked up in the enclosing scope at call
        # time; both are bound in the session block below before the first
        # call.
        if runtime_generate:
            train_inputs, train_targets, _ = train_gen.next_gen_batch()
        else:
            train_inputs, train_targets, _ = train_gen.next_batch()

        feed = {lprnet.inputs: train_inputs, lprnet.targets: train_targets}

        loss, steps, _, lr = sess.run(
            [lprnet.loss, lprnet.global_step, lprnet.optimizer,
             lprnet.learning_rate], feed)

        if steps > 0 and steps % SAVE_STEPS == 0:
            ckpt_dir = CHECKPOINT_DIR
            ckpt_file = os.path.join(
                ckpt_dir,
                'LPRnet_steps{}_loss_{:.3f}.ckpt'.format(steps, loss))
            # makedirs rather than mkdir: a nested checkpoint directory whose
            # parent does not exist yet must not abort training at save time.
            if not os.path.isdir(ckpt_dir):
                os.makedirs(ckpt_dir)
            saver.save(sess, ckpt_file)
            print('checkpoint ', ckpt_file)
        return loss, steps, lr

    with tf.Session() as sess:
        sess.run(lprnet.init)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=30)
        restore_checkpoint(sess, saver, checkpoint)

        print('training...')
        for curr_epoch in range(TRAIN_EPOCHS):
            print('Epoch {}/{}'.format(curr_epoch + 1, TRAIN_EPOCHS))
            train_loss = lr = 0
            st = time.time()
            for _ in range(BATCH_PER_EPOCH):
                b_loss, steps, lr = train_batch(train_gen)
                train_loss += b_loss
            tim = time.time() - st
            # Mean loss over the batches of this epoch.
            train_loss /= BATCH_PER_EPOCH
            log = "train loss: {:.3f}, steps: {}, time: {:.1f}s, learning rate: {:.5f}"
            print(log.format(train_loss, steps, tim, lr))

            # Validate every VALIDATE_EPOCHS epochs, never after the first.
            if curr_epoch > 0 and curr_epoch % VALIDATE_EPOCHS == 0:
                inference(sess, lprnet, val_gen)
Пример #5
0
def eval(data_path):
    """Evaluate the latest checkpoint on the data found under ``data_path``.

    Builds ``model.Graph()``, restores the newest checkpoint from
    ``FLAGS.checkpoint_dir`` when one exists, decodes the data set in
    consecutive batches of ``batch_size`` samples and prints per-batch and
    overall line/character accuracy as reported by
    ``utils.accuracy_calculation2``.

    Args:
        data_path: directory passed to ``utils.DataIterator`` as the only
            entry of ``data_dirs``.
    """

    def ratio(hits, total):
        # Report 0.0 instead of raising ZeroDivisionError when nothing was
        # counted (e.g. fewer samples than one batch).
        return hits * 1.0 / total if total else 0.0

    val_feeder = utils.DataIterator(data_dirs=[data_path])
    print('get image: ', val_feeder.size)

    num_eval_samples = val_feeder.size
    # Only whole batches are evaluated; a trailing partial batch is skipped.
    num_batches_per_epoch = int(num_eval_samples / batch_size)

    g = model.Graph()
    with tf.Session(graph=g.graph) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            print('cannot restore')

        # Build the sparse -> dense conversion once. Creating it inside the
        # batch loop would add new ops to the graph on every iteration.
        dense_decoded_op = tf.sparse_tensor_to_dense(g.decoded[0],
                                                     default_value=0)

        print('=============================begin evaluation=============================')

        start = time.time()
        total_lines = 0
        total_hit_lines = 0
        total_chars = 0
        total_hit_chars = 0

        for cur_batch in range(num_batches_per_epoch):
            batch_time = time.time()
            indexs = range(cur_batch * batch_size, (cur_batch + 1) * batch_size)
            batch_inputs, batch_seq_len, batch_labels, batch_lab_len, labels = \
                val_feeder.input_index_generate_batch_warp(indexs)
            # Keep probabilities are fed as 1.0 for evaluation.
            val_feed = {g.inputs: batch_inputs,
                        g.seq_len: batch_seq_len,
                        g.output_keep_prob: 1.0,
                        g.input_keep_prob: 1.0}
            dense_decoded = sess.run(dense_decoded_op, val_feed)
            reco_time = time.time()

            hit_lines, lines, hit_chars, chars = utils.accuracy_calculation2(
                labels, dense_decoded, ignore_value=0, isPrint=False)

            total_lines += lines
            total_hit_lines += hit_lines
            total_chars += chars
            total_hit_chars += hit_chars
            log = "Batch {}, Lines {}, Chars {},  line accuracy = {:.5f}, char accuracy = {:.5f}, recognize time = {:.3f}, total time={:.3f}"
            print(log.format(cur_batch + 1, lines, chars,
                             ratio(hit_lines, lines), ratio(hit_chars, chars),
                             reco_time - batch_time, time.time() - batch_time))

        print('=============================Total Result=============================')
        print(data_path)
        log = "time = {:.3f}, line accuracy = {:.5f}, char accuracy = {:.5f}"
        print(log.format(time.time() - start,
                         ratio(total_hit_lines, total_lines),
                         ratio(total_hit_chars, total_chars)))
        print("Lines {}, Chars {}".format(total_lines, total_chars))
def train(train_dir=None, mode='train'):
    """Train the CNN+LSTM OCR model from the queue-based input pipeline.

    Args:
        train_dir: not used by this function; kept so existing callers keep
            working.
        mode: forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        gpus = [gpu for gpu in FLAGS.gpus.split(',') if gpu]
        model = cnn_lstm_otc_ocr.LSTMOCR(mode, gpus)
        train_feeder = utils.DataIterator()
        X, Y = train_feeder.distored_inputs()
        train_op, _ = model.build_graph(X, Y)
        print('len(labels):%d, batch_size:%d' %
              (len(train_feeder.labels), FLAGS.batch_size))
        # Each step handles one batch per listed GPU.
        num_batches_per_epoch = int(
            len(train_feeder.labels) / FLAGS.batch_size / len(gpus))
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # try/finally so the queue-runner threads are always asked to
            # stop and joined, even when training raises.
            try:
                saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
                train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                     sess.graph)
                if FLAGS.restore:
                    ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                    if ckpt:
                        saver.restore(sess, ckpt)
                        print('restore from checkpoint{0}'.format(ckpt))
                        print('global_step:', model.global_step.eval())
                        print('assign value %d' %
                              (FLAGS.num_epochs * num_batches_per_epoch / 3))
                        #sess.run(tf.assign(model.global_step, FLAGS.num_epochs*num_batches_per_epoch/3))
                        print('global_step:', model.global_step.eval())
                print(
                    '=============================begin training============================='
                )
                for cur_epoch in range(FLAGS.num_epochs):
                    batch_time = time.time()
                    # the training part
                    for cur_batch in range(num_batches_per_epoch):
                        _, step = sess.run([train_op, model.global_step])
                        if step % FLAGS.save_steps == 1:
                            # makedirs: tolerate a nested checkpoint path
                            # whose parent directory is missing.
                            if not os.path.isdir(FLAGS.checkpoint_dir):
                                os.makedirs(FLAGS.checkpoint_dir)
                            saver.save(sess,
                                       os.path.join(FLAGS.checkpoint_dir,
                                                    'ocr-model'),
                                       global_step=step)
                        if step % 100 == 0:
                            # NOTE(review): model.loss.eval() is a separate
                            # run of the graph, so it presumably pulls another
                            # batch from the input queue - confirm.
                            print(
                                'step: %d, batch: %d time: %d, learning rate: %.8f, loss:%.4f'
                                % (step, cur_batch, time.time() - batch_time,
                                   model.lrn_rate.eval(), model.loss.eval()))
            finally:
                coord.request_stop()
                coord.join(threads)
Пример #7
0
def test(checkpoint):
    """Run inference over TEST_DIR using weights restored from ``checkpoint``.

    Inference is skipped when ``restore_checkpoint`` returns a falsy value.
    """
    net = LPRnet(is_train=False)
    test_gen = utils.DataIterator(img_dir=TEST_DIR)
    with tf.Session() as session:
        session.run(net.init)
        saver = tf.train.Saver(tf.global_variables())

        restored = restore_checkpoint(session, saver, checkpoint,
                                      is_train=False)
        if restored:
            inference(session, net, test_gen)
Пример #8
0
    def readDevFile(self, devfilename):
        """Load the dev set into ``self.devFeature`` and ``self.devLabel``.

        The number of rows is taken from reading ``devfilename`` as a
        header-less CSV. Each item yielded by ``utils.DataIterator`` fills
        one feature row: column 0 holds ``SUE + 200`` and the column given by
        ``self.Vocabulary[word]`` holds that word's count. Words missing from
        the vocabulary are ignored.
        """
        self.devNum = pd.read_csv(devfilename, header=None).shape[0]
        self.devFeature = np.zeros((self.devNum, self.featNum))
        self.devLabel = np.zeros(self.devNum)

        vocabulary = self.Vocabulary
        for row, record in enumerate(utils.DataIterator(devfilename)):
            label = int(record['label'])
            bag = record['BOW']
            sue = float(record['SUE'] + 200)
            for word, count in bag.items():
                if word in vocabulary:
                    self.devFeature[row][vocabulary[word]] = count
            self.devLabel[row] = label
            self.devFeature[row][0] = sue
Пример #9
0
 def GenerateDictionary(self, keyfile, outdictname):
     '''
     Count words over keyfile and write the filtered vocabulary to a file.

     Word counts from every item's BOW are accumulated into self.allcounts.
     Words longer than 3 characters whose count exceeds one millionth of the
     total word count are written to outdictname as "word:count" lines,
     sorted by ascending count.

     Args:
         keyfile(str): path to look for the keyfile
         outdictname: name of the output dictionary
     '''
     for point in utils.DataIterator(keyfile):
         for word, count in point['BOW'].items():
             self.allcounts[word] += count
     totalWordNum = sum(self.allcounts.values())
     wordfilter = float(totalWordNum) * 0.000001
     filterVocabulary = {
         key: value
         for key, value in self.allcounts.items()
         if (value > wordfilter and len(key) > 3)
     }
     # items() exists on both Python 2 and 3; iteritems() is Python 2 only.
     sortedVocabulary = sorted(filterVocabulary.items(),
                               key=operator.itemgetter(1))
     with open(outdictname, 'w') as f:
         for word, count in sortedVocabulary:
             f.write(word + ':' + str(count) + '\n')
Пример #10
0
    def readTrainFile(self, dictionarypath, trainfilename):
        """Build the vocabulary and load the training set.

        ``dictionarypath`` is read as "word:count" lines; every distinct word
        gets a feature column starting at 1 (column 0 is reserved for the
        ``SUE + 200`` feature). The training items are then loaded into
        ``self.trainFeature`` and ``self.trainLabel``; words missing from the
        vocabulary are ignored.

        Args:
            dictionarypath: path of the dictionary file.
            trainfilename: path of the training file, read both as a
                header-less CSV (for the row count) and through
                ``utils.DataIterator``.
        """
        with open(dictionarypath, 'r') as f:
            # Skip blank lines instead of remove(''), which raised ValueError
            # when the file had no trailing newline and left any further
            # blank lines in the vocabulary.
            wordlist = [
                line.split(':')[0] for line in f.read().split('\n') if line
            ]

        self.Vocabulary = {}
        keyIndex = 1
        for key in set(wordlist):
            self.Vocabulary[key] = keyIndex
            keyIndex += 1
        self.featNum = len(self.Vocabulary) + 1
        # Parenthesised form prints the same text on Python 2 and 3.
        print('length of voc = ' + str(len(self.Vocabulary)))

        trainfile = pd.read_csv(trainfilename, header=None)
        self.trainNum = trainfile.shape[0]
        self.trainFeature = np.zeros((self.trainNum, self.featNum))
        self.trainLabel = np.zeros(self.trainNum)

        index = 0
        for item in utils.DataIterator(trainfilename):
            label = int(item['label'])
            BOW = item['BOW']
            sue = float(item['SUE'] + 200)

            for word, count in BOW.items():
                try:
                    position = self.Vocabulary[word]
                    self.trainFeature[index][position] = count
                except KeyError:
                    # word is not part of the vocabulary
                    pass
            self.trainLabel[index] = label
            self.trainFeature[index][0] = sue
            index += 1
Пример #11
0
def train(train_dir=None, val_dir=None, initial_learning_rate=None,
          num_hidden=None, num_classes=None, hparam=None):
    """Train the CTC OCR graph, writing summaries and validating periodically.

    Args:
        train_dir: directory given to ``utils.DataIterator`` for training data.
        val_dir: directory given to ``utils.DataIterator`` for validation data.
        initial_learning_rate: not used by this function.
        num_hidden: not used by this function.
        num_classes: not used by this function.
        hparam: optional suffix appended to ``FLAGS.log_dir`` to name the
            summary directory; ``None`` means no suffix.
    """
    g = Graph()
    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('get image: ', train_feeder.size)

    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=False)
    with tf.Session(graph=g.graph, config=config) as sess:
        sess.run(tf.global_variables_initializer())
        # Saver used for persisting checkpoints.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        # Mark the graph read-only; it must be complete at this point.
        g.graph.finalize()
        # `hparam or ''` keeps the default hparam=None from raising TypeError
        # on string concatenation.
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + (hparam or ''),
                                             sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step will be restored as well
                saver.restore(sess, ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))

        print('=============================begin training=============================')
        # One fixed validation batch, reused at every validation step.
        val_inputs, val_seq_len, val_labels, _ = \
            val_feeder.input_index_generate_batch()

        val_feed = {g.inputs: val_inputs,
                    g.labels: val_labels,
                    g.seq_len: val_seq_len}
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            for cur_batch in range(num_batches_per_epoch):
                indexs = [shuffle_idx[i % num_train_samples]
                          for i in range(cur_batch * FLAGS.batch_size,
                                         (cur_batch + 1) * FLAGS.batch_size)]
                batch_inputs, batch_seq_len, batch_labels, label_batch = \
                    train_feeder.input_index_generate_batch(indexs)
                feed = {g.inputs: batch_inputs,
                        g.labels: batch_labels,
                        g.seq_len: batch_seq_len}

                train_dense_decoded, summary_str, batch_cost, step, _ = sess.run(
                    [g.dense_decoded, g.merged_summay, g.CTCloss,
                     g.global_step, g.optimizer], feed)
                train_cost += batch_cost * FLAGS.batch_size
                train_writer.add_summary(summary_str, step)
                # save the checkpoint
                if step % FLAGS.save_steps == 1:
                    # makedirs: tolerate a nested checkpoint path whose
                    # parent directory is missing.
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.makedirs(FLAGS.checkpoint_dir)
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)
                if step % FLAGS.validation_steps == 0:
                    dense_decoded, validation_length_err, learningRate = sess.run(
                        [g.dense_decoded, g.lengthOferr, g.learning_rate],
                        val_feed)
                    valid_acc = utils.accuracy_calculation(
                        val_feeder.labels, dense_decoded,
                        ignore_value=-1, isPrint=True)
                    # Training accuracy is measured on the batch just trained.
                    train_acc = utils.accuracy_calculation(
                        label_batch, train_dense_decoded,
                        ignore_value=-1, isPrint=True)
                    avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)
                    now = datetime.datetime.now()
                    log = "*{}/{} {}:{}:{} Epoch {}/{}, accOfvalidation = {:.3f},train_accuracy={:.3f}, time = {:.3f},learningRate={:.8f}"
                    print(log.format(now.month, now.day, now.hour, now.minute,
                                     now.second, cur_epoch + 1,
                                     FLAGS.num_epochs, valid_acc, train_acc,
                                     time.time() - start_time, learningRate))
Пример #12
0
def train(train_dir=None, val_dir=None):
    """Train the OCR graph, validating and checkpointing periodically.

    Reads the module-level settings ``batch_size``, ``num_epochs``,
    ``validation_steps``, ``restore`` and ``checkpoint_dir``.

    Args:
        train_dir: directory given to ``utils.DataIterator`` for training data.
        val_dir: directory given to ``utils.DataIterator`` for validation data.
    """
    # Load training and validation data.
    print('Loading training data...')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('Get images: ', train_feeder.size)

    print('Loading validate data...')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('Get images: ', val_feeder.size)

    # Define the network structure.
    g = Graph()

    # Total number of training samples.
    num_train_samples = train_feeder.size
    # Number of whole batches in one epoch.
    num_batches_per_epoch = int(num_train_samples / batch_size)

    with tf.Session(graph=g.graph) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)

        # When `restore` is set, load the latest checkpoint if there is one.
        if restore:
            ckpt = tf.train.latest_checkpoint(checkpoint_dir)
            if ckpt:
                # global_step is restored as well
                saver.restore(sess, ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))

        print('============begin training============')
        # Fetch one batch of validation data in the form the placeholders
        # expect; it is reused at every validation step.
        val_inputs, val_seq_len, val_labels = \
            val_feeder.input_index_generate_batch()
        val_feed = {g.inputs: val_inputs,
                    g.labels: val_labels,
                    g.seq_len: val_seq_len}

        start_time = time.time()
        for cur_epoch in range(num_epochs):
            # Shuffle the training sample indices for this epoch.
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0

            for cur_batch in range(num_batches_per_epoch):
                # Fetch one batch of training samples in placeholder form.
                indexs = [shuffle_idx[i % num_train_samples]
                          for i in range(cur_batch * batch_size,
                                         (cur_batch + 1) * batch_size)]
                batch_inputs, batch_seq_len, batch_labels = \
                    train_feeder.input_index_generate_batch(indexs)
                feed = {g.inputs: batch_inputs,
                        g.labels: batch_labels,
                        g.seq_len: batch_seq_len}

                # Run one training step.
                summary_str, batch_cost, step, _ = sess.run(
                    [g.merged_summay, g.cost, g.global_step, g.optimizer],
                    feed)
                # Accumulate the loss.
                train_cost += batch_cost

                # Progress report; the timer restarts after each report.
                if step % 50 == 1:
                    end_time = time.time()
                    print('No. %5d batches, loss: %5.2f, time: %3.1fs'
                          % (step, batch_cost, end_time - start_time))
                    start_time = time.time()

                # Save a checkpoint, then evaluate on the validation batch.
                if step % validation_steps == 1:
                    # makedirs: tolerate a nested checkpoint path whose
                    # parent directory is missing.
                    if not os.path.isdir(checkpoint_dir):
                        os.makedirs(checkpoint_dir)
                    saver.save(sess,
                               os.path.join(checkpoint_dir, 'ocr-model'),
                               global_step=step)

                    # Decoded predictions for the validation batch.
                    dense_decoded, lastbatch_err, lr = sess.run(
                        [g.dense_decoded, g.lerr, g.learning_rate], val_feed)
                    acc = utils.accuracy_calculation(val_feeder.labels,
                                                     dense_decoded,
                                                     ignore_value=-1,
                                                     isPrint=False)
                    print('-After %5d steps, Val accu: %4.2f%%' % (step, acc))
Пример #13
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the MDLSTM CTC OCR model with periodic checkpoints and validation.

    Args:
        train_dir: directory given to ``utils.DataIterator`` for training data.
        val_dir: directory given to ``utils.DataIterator`` for validation data.
        mode: forwarded to ``mdlstm_ctc_ocr.LSTMOCR``.
    """
    model = mdlstm_ctc_ocr.LSTMOCR(mode)
    model.build_graph()

    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('get image: ', train_feeder.size)

    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size  # 100000
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)  # example: 100000/100

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)  # example: 10000/100
    # The validation order is shuffled once and reused for every validation.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step will be restored as well
                saver.restore(sess, ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))

        print('=============================begin training=============================')
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()

            # the training part
            for cur_batch in range(num_batches_per_epoch):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time', time.time() - batch_time)
                batch_time = time.time()
                indexs = [shuffle_idx[i % num_train_samples]
                          for i in range(cur_batch * FLAGS.batch_size,
                                         (cur_batch + 1) * FLAGS.batch_size)]
                batch_inputs, batch_seq_len, batch_labels = \
                    train_feeder.input_index_generate_batch(indexs)
                feed = {model.inputs: batch_inputs,
                        model.labels: batch_labels,
                        model.seq_len: batch_seq_len}
                model.is_training = True

                summary_str, batch_cost, step, _ = sess.run(
                    [model.merged_summay, model.cost, model.global_step,
                     model.train_op], feed)
                # accumulate the cost, weighted by the batch size
                delta_batch_cost = batch_cost * FLAGS.batch_size
                train_cost += delta_batch_cost
                train_writer.add_summary(summary_str, step)

                # save the checkpoint
                if step % FLAGS.save_steps == 1:
                    # makedirs: tolerate a nested checkpoint path whose
                    # parent directory is missing.
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.makedirs(FLAGS.checkpoint_dir)
                    # .format(), not a second positional argument: logging
                    # would try to %-format the message with it and fail.
                    logger.info('save the checkpoint of{0}'.format(step))
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                # do validation
                if step % FLAGS.validation_steps == 0:
                    acc_batch_total = 0
                    lr = 0
                    for j in range(num_batches_per_epoch_val):
                        indexs_val = [shuffle_idx_val[i % num_val_samples] for i in
                                      range(j * FLAGS.batch_size, (j + 1) * FLAGS.batch_size)]
                        val_inputs, val_seq_len, val_labels = \
                            val_feeder.input_index_generate_batch(indexs_val)
                        val_feed = {model.inputs: val_inputs,
                                    model.labels: val_labels,
                                    model.seq_len: val_seq_len}
                        model.is_training = False
                        dense_decoded, lr = \
                            sess.run([model.dense_decoded, model.lrn_rate],
                                     val_feed)

                        # print the decode result
                        ori_labels = val_feeder.the_label(indexs_val)
                        acc = utils.accuracy_calculation(ori_labels, dense_decoded,
                                                         ignore_value=-1, isPrint=True)
                        acc_batch_total += acc

                    # Per-batch accuracies weighted by the batch size and
                    # normalised by the validation set size.
                    accuracy = (acc_batch_total * FLAGS.batch_size) / num_val_samples

                    avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "time = {:.3f},lr={:.8f}"
                    print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                                     cur_epoch + 1, FLAGS.num_epochs, accuracy, avg_train_cost,
                                     time.time() - start_time, lr))
Пример #14
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the LSTM OCR model with periodic checkpoints and validation.

    Args:
        train_dir: directory given to ``utils.DataIterator`` for training data.
        val_dir: directory given to ``utils.DataIterator`` for validation data.
        mode: forwarded to ``orcmodel.LSTMOCR``.
    """
    # Instantiate the model class.
    model = orcmodel.LSTMOCR(mode)
    # Build the model; the attributes used below (train_op, cost, ...) are
    # read from it after this call.
    model.build_graph()

    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir=train_dir)  # training data
    print('get image: ', train_feeder.size)

    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir=val_dir)  # validation data
    print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size  # number of training samples
    num_batches_per_epoch = int(
        num_train_samples / FLAGS.batch_size)  # batches per epoch, example: 100000/100

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(
        num_val_samples / FLAGS.batch_size)  # batches per epoch, example: 10000/100
    # The validation order is shuffled once and reused for every validation.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    with tf.device('/cpu:0'):
        # ConfigProto configures the Session; allow_soft_placement lets TF
        # pick a device when the requested one does not exist.
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            # max_to_keep: how many of the most recent checkpoints are kept
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=50)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)  # summaries for visualisation
            # Restore weights if configured to do so.
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # global_step is restored as well
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))

            print(
                '=============================begin training============================='
            )
            for cur_epoch in range(FLAGS.num_epochs):
                # Shuffle the training sample indices for this epoch.
                shuffle_idx = np.random.permutation(num_train_samples)
                train_cost = 0
                start_time = time.time()
                batch_time = time.time()

                # Train on every batch of this epoch.
                for cur_batch in range(num_batches_per_epoch):
                    if (cur_batch + 1) % 100 == 0:
                        print('batch', cur_batch, ': time',
                              time.time() - batch_time)
                    batch_time = time.time()

                    # Indices of the samples in this training batch.
                    indexs = [
                        shuffle_idx[i % num_train_samples]
                        for i in range(cur_batch * FLAGS.batch_size,
                                       (cur_batch + 1) * FLAGS.batch_size)
                    ]
                    batch_inputs, batch_seq_len, batch_labels = \
                        train_feeder.input_index_generate_batch(indexs)

                    # Fill the placeholders.
                    feed = {
                        model.inputs: batch_inputs,
                        model.labels: batch_labels,
                        model.seq_len: batch_seq_len
                    }

                    # One run: summary data, cost, global_step and the
                    # training op.
                    summary_str, batch_cost, step, _ = \
                        sess.run([model.merged_summay, model.cost, model.global_step,
                                  model.train_op], feed)

                    # batch_cost is presumably a per-sample mean; scale it
                    # to the cost of the whole batch.
                    train_cost += batch_cost * FLAGS.batch_size

                    # Write the summary so it can be visualised.
                    train_writer.add_summary(summary_str, step)

                    # Save a checkpoint.
                    if step % FLAGS.save_steps == 1:
                        # makedirs: tolerate a nested checkpoint path whose
                        # parent directory is missing.
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.makedirs(FLAGS.checkpoint_dir)
                        # .format(), not a second positional argument:
                        # logging would try to %-format the message with it
                        # and fail.
                        logger.info('save the checkpoint of{0}'.format(step))
                        saver.save(sess,
                                   os.path.join(FLAGS.checkpoint_dir,
                                                'ocr-model'),
                                   global_step=step)

                    # Validation.
                    if step % FLAGS.validation_steps == 0:
                        acc_batch_total = 0
                        lr = 0
                        for j in range(num_batches_per_epoch_val):
                            # Indices of one validation mini-batch.
                            indexs_val = [
                                shuffle_idx_val[i % num_val_samples]
                                for i in range(j * FLAGS.batch_size,
                                               (j + 1) * FLAGS.batch_size)
                            ]
                            val_inputs, val_seq_len, val_labels = \
                                val_feeder.input_index_generate_batch(indexs_val)
                            val_feed = {
                                model.inputs: val_inputs,
                                model.labels: val_labels,
                                model.seq_len: val_seq_len
                            }

                            # Decode and fetch the current learning rate.
                            dense_decoded, lr = sess.run(
                                [model.dense_decoded, model.lrn_rate],
                                val_feed)

                            # print the decode result
                            ori_labels = val_feeder.the_label(indexs_val)

                            # Accuracy of this batch.
                            acc = utils.accuracy_calculation(ori_labels,
                                                             dense_decoded,
                                                             ignore_value=-1,
                                                             isPrint=True)
                            acc_batch_total += acc

                        # Per-batch accuracies weighted by the batch size
                        # and normalised by the validation set size.
                        accuracy = (acc_batch_total * FLAGS.batch_size
                                    ) / num_val_samples

                        # Average training cost so far in this epoch.
                        avg_train_cost = train_cost / (
                            (cur_batch + 1) * FLAGS.batch_size)

                        # Print the latest training information.
                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                              "lastbatch_cost = {:f}, time = {:.3f},lr={:.8f}"
                        print(
                            log.format(now.month, now.day, now.hour,
                                       now.minute, now.second, cur_epoch + 1,
                                       FLAGS.num_epochs, accuracy,
                                       avg_train_cost, batch_cost,
                                       time.time() - start_time, lr))
Пример #15
0
def train():
    """Train the OCR graph and report validation accuracy after every epoch.

    Builds a ``Graph``, loads training images from ``../train/`` and
    validation images from ``../test/``, optionally restores the latest
    checkpoint from ``FLAGS.checkpoint_dir`` (when ``FLAGS.restore`` is set)
    and then runs ``FLAGS.num_epochs`` epochs. A checkpoint is written
    whenever ``step % FLAGS.save_steps == 1``. At the end of each epoch the
    validation batch is decoded and a one-line summary is printed.
    """
    g = Graph()
    with g.graph.as_default():
        print('loading train data, please wait---------------------', end=' ')
        train_feeder = utils.DataIterator(data_dir='../train/')
        print('get image: ', train_feeder.size)

        print('loading validation data, please wait---------------------',
              end=' ')
        val_feeder = utils.DataIterator(data_dir='../test/')
        print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size
    # Truncating division: a trailing partial batch is not trained on.
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=False)
    with tf.Session(graph=g.graph, config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        # Creating the writer records the graph under FLAGS.log_dir/train.
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # The saver covers all global variables.
                saver.restore(sess, ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))

        print(
            '=============================begin training============================='
        )

        # The validation feed is built once and reused for every epoch.
        val_inputs, val_seq_len, val_labels = \
            val_feeder.input_index_generate_batch()
        val_feed = {
            g.inputs: val_inputs,
            g.labels: val_labels,
            g.seq_len: val_seq_len
        }
        # Build the sparse -> dense conversion once. Creating it inside the
        # epoch loop would add a new op to the graph on every epoch.
        dense_decoded_op = tf.sparse_tensor_to_dense(g.decoded[0],
                                                     default_value=-1)

        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start = time.time()
            batch_time = time.time()

            for cur_batch in range(num_batches_per_epoch):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time',
                          time.time() - batch_time)
                batch_time = time.time()

                first = cur_batch * FLAGS.batch_size
                last = first + FLAGS.batch_size
                indexs = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(first, last)
                ]
                batch_inputs, batch_seq_len, batch_labels = \
                    train_feeder.input_index_generate_batch(indexs)
                feed = {
                    g.inputs: batch_inputs,
                    g.labels: batch_labels,
                    g.seq_len: batch_seq_len
                }

                batch_cost, step, _ = sess.run(
                    [g.cost, g.global_step, g.optimizer], feed)
                # Accumulate the batch cost scaled by the batch size; it is
                # divided by the sample count at the end of the epoch.
                train_cost += batch_cost * FLAGS.batch_size

                # Save a checkpoint.
                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # The original passed ``format(step)`` as a separate
                    # logging argument, so the step was never interpolated.
                    logger.info('save the checkpoint of{0}'.format(step))
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

            dense_decoded, lastbatch_err = sess.run(
                [dense_decoded_op, g.lerr], val_feed)
            # Print the decode result and compute the accuracy.
            acc = utils.accuracy_calculation(val_feeder.labels,
                                             dense_decoded,
                                             ignore_value=-1,
                                             isPrint=True)
            train_cost /= num_train_samples
            now = datetime.datetime.now()
            log = "{}-{} {}:{}:{} Epoch {}/{}, accuracy = {:.3f},train_cost = {:.3f}, lastbatch_err = {:.3f}, time = {:.3f}"
            print(
                log.format(now.month, now.day, now.hour, now.minute,
                           now.second, cur_epoch + 1, FLAGS.num_epochs, acc,
                           train_cost, lastbatch_err,
                           time.time() - start))
Пример #16
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the LSTM OCR model, validating periodically and tracking the best score.

    Args:
        train_dir: Directory passed to ``utils.DataIterator`` for training data.
        val_dir: Directory passed to ``utils.DataIterator`` for validation data.
        mode: Mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.

    The process exits through ``sys.exit()`` when ``FLAGS.model`` is not
    ``'lstm'``. A checkpoint is written whenever
    ``step % FLAGS.save_steps == 1`` and the validation set is evaluated
    whenever ``step % FLAGS.validation_steps == 0``.
    """
    if FLAGS.model == 'lstm':
        model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    else:
        print("no such model")
        sys.exit()

    # Build the graph.
    model.build_graph()
    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir=train_dir, num=4000000)
    print('get image: ', train_feeder.size)

    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir=val_dir, num=40000)
    print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size
    # Number of batches per training epoch (trailing partial batch dropped).
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)

    num_val_samples = val_feeder.size
    # Number of batches per validation pass (trailing partial batch dropped).
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)

    # The validation order is shuffled once and reused for every pass.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    with tf.device('/cpu:0'):
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            # Initialise all global variables.
            sess.run(tf.global_variables_initializer())

            # Saver used to store the model.
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)

            # Load a previously trained model if requested.
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # The saver covers all global variables.
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))
                else:
                    print("No checkpoint")

            print(
                '=============================begin training============================='
            )
            accuracy_res = []
            accuracy_per_res = []
            epoch_res = []
            tmp_max = 0
            tmp_epoch = 0
            for cur_epoch in range(FLAGS.num_epochs):
                shuffle_idx = np.random.permutation(num_train_samples)
                train_cost = 0
                start_time = time.time()
                batch_time = time.time()

                for cur_batch in range(num_batches_per_epoch):
                    if (cur_batch + 1) % 100 == 0:
                        print('batch', cur_batch, ': time',
                              time.time() - batch_time)
                    batch_time = time.time()

                    # Sample indices that make up this batch.
                    first = cur_batch * FLAGS.batch_size
                    last = first + FLAGS.batch_size
                    indexs = [
                        shuffle_idx[i % num_train_samples]
                        for i in range(first, last)
                    ]

                    batch_inputs, batch_seq_len, batch_labels = \
                        train_feeder.input_index_generate_batch(indexs)
                    feed = {
                        model.inputs: batch_inputs,
                        model.labels: batch_labels,
                        model.seq_len: batch_seq_len
                    }

                    summary_str, batch_cost, step, _ = \
                        sess.run([model.merged_summay, model.cost,
                                  model.global_step, model.train_op], feed)
                    # Accumulate the batch cost scaled by the batch size.
                    train_cost += batch_cost * FLAGS.batch_size

                    train_writer.add_summary(summary_str, step)

                    # Save a checkpoint.
                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.mkdir(FLAGS.checkpoint_dir)
                        # The original passed ``format(step)`` as a separate
                        # logging argument, so the step was never interpolated.
                        logger.info('save the checkpoint of{0}'.format(step))
                        saver.save(sess,
                                   os.path.join(FLAGS.checkpoint_dir,
                                                'ocr-model'),
                                   global_step=step)

                    # Run validation.
                    if step % FLAGS.validation_steps == 0:
                        acc_batch_total = 0
                        acc_per_batch_total = 0
                        lr = 0
                        for j in range(num_batches_per_epoch_val):
                            val_first = j * FLAGS.batch_size
                            val_last = val_first + FLAGS.batch_size
                            indexs_val = [
                                shuffle_idx_val[i % num_val_samples]
                                for i in range(val_first, val_last)
                            ]
                            val_inputs, val_seq_len, val_labels = \
                                val_feeder.input_index_generate_batch(indexs_val)
                            val_feed = {
                                model.inputs: val_inputs,
                                model.labels: val_labels,
                                model.seq_len: val_seq_len
                            }

                            dense_decoded, lr = \
                                sess.run([model.dense_decoded, model.lrn_rate],
                                         val_feed)

                            # Print the decode result and score the batch.
                            ori_labels = val_feeder.the_label(indexs_val)
                            acc = utils.accuracy_calculation(ori_labels,
                                                             dense_decoded,
                                                             ignore_value=-1,
                                                             isPrint=True)

                            acc_per = utils.accuracy_calculation_single(
                                ori_labels,
                                dense_decoded,
                                ignore_value=-1,
                                isPrint=True)

                            acc_per_batch_total += acc_per
                            acc_batch_total += acc

                        accuracy_per = (acc_per_batch_total *
                                        FLAGS.batch_size) / num_val_samples
                        accuracy = (acc_batch_total *
                                    FLAGS.batch_size) / num_val_samples

                        accuracy_per_res.append(accuracy_per)
                        accuracy_res.append(accuracy)

                        epoch_res.append(cur_epoch)
                        # Track the best value of the metric that is being
                        # compared. The original stored ``accuracy`` here,
                        # mixing two different metrics in ``tmp_max``.
                        if accuracy_per > tmp_max:
                            tmp_max = accuracy_per
                            tmp_epoch = cur_epoch

                        avg_train_cost = train_cost / (
                            (cur_batch + 1) * FLAGS.batch_size)

                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "max_accuracy = {:.3f},max_Epoch {},accuracy = {:.3f},acc_batch_total = {:.3f},avg_train_cost = {:.3f}, " \
                              " time = {:.3f},lr={:.8f}"
                        print(
                            log.format(now.month, now.day, now.hour,
                                       now.minute, now.second, cur_epoch + 1,
                                       FLAGS.num_epochs, tmp_max, tmp_epoch,
                                       accuracy_per, acc_batch_total,
                                       avg_train_cost,
                                       time.time() - start_time, lr))
Пример #17
0
def train(train_dir=None, mode='train'):
    """Train the LSTM OCR model on an 80/20 split of one labelled directory.

    Args:
        train_dir: Directory containing ``.txt`` label files. A label file is
            used only when a ``.jpg`` with the same stem exists next to it.
        mode: Mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.

    The first 80% of the discovered label files are used for training and
    the rest for validation. A checkpoint is written whenever
    ``step % FLAGS.save_steps == 1`` and validation runs whenever
    ``step % FLAGS.validation_steps == 0``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    # The original called ``os.pat.join``, which raises AttributeError.
    label_files = [
        os.path.join(train_dir, name)
        for name in os.listdir(train_dir)
        if name.endswith('.txt') and os.path.exists(
            os.path.join(train_dir, name.replace('.txt', '.jpg')))
    ]
    train_num = int(len(label_files) * 0.8)
    test_num = len(label_files) - train_num

    print('total num', len(label_files), 'train num', train_num, 'test num', test_num)
    train_imgs = label_files[0:train_num]
    test_imgs = label_files[train_num:]

    print('loading train data')
    train_feeder = utils.DataIterator(data_dir=train_imgs)

    print('loading validation data')
    val_feeder = utils.DataIterator(data_dir=test_imgs)

    # Truncating division: trailing partial batches are dropped.
    num_batches_per_epoch_train = int(train_num / FLAGS.batch_size)
    num_batches_per_epoch_val = int(test_num / FLAGS.batch_size)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # The saver covers all global variables.
                saver.restore(sess, ckpt)
                print('restore from checkpoint{0}'.format(ckpt))

        print('=============================begin training=============================')
        for cur_epoch in range(FLAGS.num_epochs):
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()
            # The training data is shuffled once, before the first epoch only.
            if cur_epoch == 0:
                random.shuffle(train_feeder.train_data)

            # The training part.
            for cur_batch in range(num_batches_per_epoch_train):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time', time.time() - batch_time)
                batch_time = time.time()

                batch_inputs, result_img_length, batch_labels = \
                    train_feeder.get_batchsize_data(cur_batch)

                feed = {model.inputs: batch_inputs,
                        model.labels: batch_labels,
                        model.seq_len: result_img_length}

                summary_str, batch_cost, step, _ = sess.run(
                    [model.merged_summay, model.cost, model.global_step,
                     model.train_op], feed)
                # Accumulate the batch cost scaled by the batch size.
                train_cost += batch_cost * FLAGS.batch_size

                train_writer.add_summary(summary_str, step)

                # Save a checkpoint.
                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # The original passed ``format(step)`` as a separate
                    # logging argument, so the step was never interpolated.
                    logger.info('save checkpoint at step {0}'.format(step))
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'), global_step=step)

                # Run validation.
                if step % FLAGS.validation_steps == 0:
                    acc_batch_total = 0
                    lastbatch_err = 0
                    lr = 0
                    for val_j in range(num_batches_per_epoch_val):
                        result_img_val, seq_len_input_val, batch_label_val = \
                            val_feeder.get_batchsize_data(val_j)
                        val_feed = {model.inputs: result_img_val,
                                    model.labels: batch_label_val,
                                    model.seq_len: seq_len_input_val}

                        dense_decoded, lastbatch_err, lr = \
                            sess.run([model.dense_decoded, model.cost, model.lrn_rate],
                                     val_feed)

                        # Convert each decoded row to text before scoring.
                        val_pre_list = [utils.label2text(decode_code)
                                        for decode_code in dense_decoded]
                        ori_labels = val_feeder.get_val_label(val_j)

                        acc = utils.accuracy_calculation(ori_labels, val_pre_list,
                                                         ignore_value=-1, isPrint=True)
                        acc_batch_total += acc

                    accuracy = acc_batch_total / num_batches_per_epoch_val

                    avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                                     cur_epoch + 1, FLAGS.num_epochs, accuracy, avg_train_cost,
                                     lastbatch_err, time.time() - start_time, lr))
 def infer(self):
     """Decode every image of the validation iterator and dump the results.

     Restores the latest checkpoint from ``FLAGS.checkpoint_dir`` (when one
     exists), runs the decoder for enough steps to cover all images and, for
     each image, writes:

     * one ``image,annotation,prediction`` line to ``<output_dir>/result``;
     * ``<output_dir>/<image stem>.txt`` containing the prediction;
     * a debug image produced by ``self.draw_debug`` (best effort; failures
       are printed and skipped).
     """
     FLAGS.num_threads = 1
     gpus = [gpu for gpu in FLAGS.gpus.split(',') if gpu]
     with tf.Graph().as_default(), tf.device('/cpu:0'):
         train_feeder = utils.DataIterator(is_val=True, random_shuff=False)
         X, Y = train_feeder.distored_inputs()
         model = cnn_lstm_otc_ocr.LSTMOCR('infer', gpus)
         _, decodes = model.build_graph(X, Y)
         # Ceiling division. ``//`` keeps this an int on Python 3, where the
         # original ``/`` gave a float and ``range()`` rejected it.
         total_steps = (len(train_feeder.image) + FLAGS.batch_size -
                        1) // FLAGS.batch_size
         config = tf.ConfigProto(allow_soft_placement=True)
         with tf.Session(config=config) as sess, open(
                 os.path.join(FLAGS.output_dir, 'result'), 'w') as f:
             sess.run(tf.global_variables_initializer())
             coord = tf.train.Coordinator()
             threads = tf.train.start_queue_runners(sess=sess, coord=coord)
             try:
                 saver = tf.train.Saver(tf.global_variables(),
                                        max_to_keep=100)
                 ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                 print(FLAGS.checkpoint_dir)
                 if ckpt:
                     saver.restore(sess, ckpt)
                     print('restore from ckpt{}'.format(ckpt))
                 else:
                     print('cannot restore')
                 count = 0
                 for curr_step in range(total_steps):
                     start = time.time()
                     dense_decoded_code = sess.run(decodes)
                     print('time cost:', (time.time() - start))
                     print("dense_decoded_code:", dense_decoded_code)
                     # Codes missing from ``utils.decode_maps`` are skipped.
                     decoded_expression = [
                         ''.join(utils.decode_maps[i]
                                 for i in item
                                 if i in utils.decode_maps)
                         for d in dense_decoded_code
                         for item in d
                     ]
                     for code in decoded_expression:
                         # The last batch may hold more rows than there are
                         # images left; stop at the end of the image list.
                         if count >= len(train_feeder.image):
                             break
                         f.write("%s,%s,%s\n" %
                                 (train_feeder.image[count],
                                  train_feeder.anno[count].encode('utf-8'),
                                  code.encode('utf-8')))
                         filename = os.path.splitext(
                             os.path.basename(
                                 train_feeder.image[count]))[0] + ".txt"
                         output_file = os.path.join(FLAGS.output_dir,
                                                    filename)
                         # ``with`` closes the file even if the write raises.
                         with open(output_file, "w") as cur:
                             cur.write(code.encode('utf-8'))
                         print(code.encode('utf-8'))
                         # The debug image is best effort: report the error
                         # and carry on with the next image.
                         try:
                             image_debug = cv2.imread(
                                 train_feeder.image[count])
                             image_debug = self.draw_debug(
                                 image_debug, code.encode('utf-8'),
                                 code == train_feeder.anno[count])
                             image_path = os.path.join(
                                 FLAGS.output_dir,
                                 os.path.basename(train_feeder.image[count]))
                             cv2.imwrite(image_path, image_debug)
                         except Exception as e:
                             print(e)
                         count += 1
             finally:
                 # Always stop the queue-runner threads, even on failure.
                 coord.request_stop()
                 coord.join(threads)
Пример #19
0
def train(train_dir = None, val_dir = None, mode = 'train'):
    """Train the LSTM OCR model with periodic checkpointing and validation.

    Args:
        train_dir: Directory passed to ``utils.DataIterator`` for training data.
        val_dir: Directory passed to ``utils.DataIterator`` for validation data.
        mode: Mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.

    A checkpoint is written whenever ``step % FLAGS.save_steps == 1`` and the
    validation set is evaluated whenever
    ``step % FLAGS.validation_steps == 0``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    # Build the graph.
    model.build_graph()

    print('loading train data, please wait---------------------')
    # Training data iterator.
    train_feeder = utils.DataIterator(data_dir = train_dir)
    print('get image:', train_feeder.size)
    print('loading validation data, please wait---------------------')
    # Validation data iterator.
    val_feeder = utils.DataIterator(data_dir = val_dir)
    print('get image:', val_feeder.size)

    num_train_samples = train_feeder.size
    # Truncating division: a trailing partial batch is dropped.
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)
    # Shuffle the validation samples once; the order is reused for every pass.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    with tf.device('/gpu:0'):
        # allow_soft_placement lets TensorFlow pick another device when the
        # requested one does not exist.
        config = tf.ConfigProto(allow_soft_placement = True)
        with tf.Session(config = config) as sess:
            sess.run(tf.global_variables_initializer())

            # Saver used to store and restore the model parameters.
            saver = tf.train.Saver(tf.global_variables(), max_to_keep = 100)
            # Write the session graph to the log directory.
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
            # Restore previously saved parameters into this session, if any.
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))

            print('=============================begin training=============================')
            # Start training.
            for cur_epoch in range(FLAGS.num_epochs):
                shuffle_idx = np.random.permutation(num_train_samples)
                train_cost = 0
                start_time = time.time()
                batch_time = time.time()

                for cur_batch in range(num_batches_per_epoch):
                    if (cur_batch + 1) % 100 == 0:
                        print('batch', cur_batch, ':time', time.time() - batch_time)
                    batch_time = time.time()
                    # Indices of the samples in the current batch, taken from
                    # this epoch's random permutation.
                    indexs = [shuffle_idx[i % num_train_samples] for i in range(cur_batch * FLAGS.batch_size, (cur_batch + 1) * FLAGS.batch_size)]
                    batch_inputs, batch_seq_len, batch_labels = train_feeder.input_index_generate_batch(indexs)
                    # Feed dictionary for the model.
                    feed = {model.inputs: batch_inputs, model.labels: batch_labels, model.seq_len: batch_seq_len}

                    # Run the graph and fetch the tensors.
                    summar_str, batch_cost, step, _ = sess.run([model.merged_summay, model.cost, model.global_step, model.train_op], feed)
                    # Accumulate the loss; per the original author's note,
                    # batch_cost is the mean over one batch.
                    train_cost += batch_cost * FLAGS.batch_size
                    # Record the summary for visualisation.
                    train_writer.add_summary(summar_str, step)

                    # Save a checkpoint.
                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.mkdir(FLAGS.checkpoint_dir)
                        # The original passed ``format(step)`` as a separate
                        # logging argument, so the step was never interpolated.
                        logger.info('save the checkpoint of{0}'.format(step))
                        saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'), global_step = step)

                    # Decode the validation set.
                    if step % FLAGS.validation_steps == 0:
                        acc_batch_total = 0
                        lastbatch_err = 0
                        lr = 0
                        # Iterate over the validation set batch by batch.
                        for j in range(num_batches_per_epoch_val):
                            indexs_val = [shuffle_idx_val[i % num_val_samples] for i in range(j * FLAGS.batch_size, (j + 1) * FLAGS.batch_size)]
                            val_inputs, val_seq_len, val_labels = val_feeder.input_index_generate_batch(indexs_val)
                            val_feed = {model.inputs: val_inputs, model.labels: val_labels, model.seq_len: val_seq_len}

                            # Three values are unpacked, so three tensors must
                            # be fetched; model.cost supplies lastbatch_err.
                            dense_decoded, lastbatch_err, lr = sess.run([model.dense_decoded, model.cost, model.lrn_rate], val_feed)

                            # Print the decode result and score the batch.
                            ori_labels = val_feeder.the_label(indexs_val)
                            acc = utils.accuracy_calculation(ori_labels, dense_decoded, ignore_value = -1, isPrint = True)
                            acc_batch_total += acc

                        accuracy = (acc_batch_total * FLAGS.batch_size) / num_val_samples

                        avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)

                        # ``datetime.datetime.time()`` is an instance method;
                        # ``now()`` is what yields the current timestamp.
                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                              "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                        print(log.format(now.month, now.day, now.hour, now.minute, now.second, cur_epoch + 1, FLAGS.num_epochs, accuracy, avg_train_cost, lastbatch_err, time.time() - start_time, lr))
Пример #20
0
def train(train_dir=None, val_dir=None, mode='train', config=None):
    """Train the CNN+LSTM+CTC captcha model.

    Args:
        train_dir: Directory passed to ``utils.DataIterator`` for training data.
        val_dir: Directory passed to ``utils.DataIterator`` for validation data.
        mode: Mode string forwarded to ``CNN_LSTM_CTC_CAPTCHA``.
        config: Session configuration handed to ``tf.Session``.

    Unlike a truncating split, the batch count is rounded up, so the last
    batch of an epoch may hold fewer than ``FLAGS.batch_size`` samples. A
    checkpoint is written whenever ``step % FLAGS.save_steps == 0`` and the
    validation cost is computed whenever
    ``step % FLAGS.validation_steps == 999``.
    """
    model = CNN_LSTM_CTC_CAPTCHA(mode)
    model.build_graph()
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('train size: ', train_feeder.size)
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('validation size: {}\n'.format(val_feeder.size))

    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(
        math.ceil(float(num_train_samples) / FLAGS.batch_size))

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(
        math.ceil(float(num_val_samples) / FLAGS.batch_size))
    # The validation order is shuffled once and reused for every pass.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)
        saver = utils.resore_or_load_mode(sess, FLAGS)

        print('++++++++++++start training+++++++++++++')
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()

            # The training part.
            for cur_batch in range(num_batches_per_epoch):
                # Clamp the end of the slice so the final, possibly partial,
                # batch stops at the last sample.
                first = cur_batch * FLAGS.batch_size
                last = min(first + FLAGS.batch_size, num_train_samples)
                indexs = [shuffle_idx[i] for i in range(first, last)]

                batch_inputs, _, batch_labels = \
                    train_feeder.input_index_generate_batch(indexs)
                feed = {model.inputs: batch_inputs, model.labels: batch_labels}

                summary_str, batch_cost, step, _ = sess.run(
                    [model.merged_summay, model.cost, model.global_step,
                     model.train_op],
                    feed_dict=feed)
                # Accumulate the batch cost scaled by the nominal batch size.
                train_cost += batch_cost * FLAGS.batch_size
                train_writer.add_summary(summary_str, step)

                print('current batch loss:%f' % (batch_cost))

                # Save a checkpoint.
                if step % FLAGS.save_steps == 0:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # The original passed ``format(step)`` as a separate
                    # logging argument, so the step was never interpolated.
                    logger.info('save checkpoint at step {0}'.format(step))
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'captcha'),
                               global_step=step)

                # Run validation.
                # NOTE(review): the remainder 999 is hard-coded, so this never
                # fires when FLAGS.validation_steps <= 999 - confirm intent.
                if step % FLAGS.validation_steps == 999:
                    lastbatch_err = 0
                    lr = 0
                    print('++++++++++++start validation+++++++++++++')
                    for j in range(num_batches_per_epoch_val):
                        val_first = j * FLAGS.batch_size
                        val_last = min(val_first + FLAGS.batch_size,
                                       num_val_samples)
                        indexs_val = [
                            shuffle_idx_val[i]
                            for i in range(val_first, val_last)
                        ]
                        val_inputs, _, val_labels = \
                            val_feeder.input_index_generate_batch(indexs_val)

                        val_feed = {
                            model.inputs: val_inputs,
                            model.labels: val_labels
                        }
                        # Only the values from the last validation batch
                        # survive the loop and are reported below.
                        lastbatch_err, lr = sess.run(
                            [model.cost, model.lrn_rate], feed_dict=val_feed)

                    avg_train_cost = train_cost / (
                        (cur_batch + 1) * FLAGS.batch_size)
                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    print(
                        log.format(now.month, now.day, now.hour, now.minute,
                                   now.second, cur_epoch + 1, FLAGS.num_epochs,
                                   avg_train_cost, lastbatch_err,
                                   time.time() - start_time, lr))
Пример #21
0
def train():
    """Run the custom LPRNet training loop configured by the global ``args``.

    Every epoch performs a fixed number of optimisation steps on the CTC
    loss. On every 25th epoch (epoch 0 excluded) the evaluator is run and the
    weights are written to ``new_out_model_best.pb`` when the validation loss
    improved. After the last epoch the model is saved to
    ``new_out_model_last.pb``.
    """
    # Network under training.
    net = LPRNet(NUM_CLASS)

    # Batch sizes as supplied through the argument parser.
    batch_size = args["batch_size"]
    val_batch_size = args["val_batch_size"]

    # Custom batch generators (see utils.py).
    train_gen = utils.DataIterator(img_dir=args["train_dir"],
                                   batch_size=batch_size)
    val_gen = utils.DataIterator(img_dir=args["val_dir"],
                                 batch_size=val_batch_size)

    # Data-set sizes: the number of files directly inside each directory.
    _, _, train_files = next(os.walk(args["train_dir"]))
    train_len = len(train_files)
    _, _, val_files = next(os.walk(args["val_dir"]))
    val_len = len(val_files)
    print("Train Len is", train_len)

    # Number of optimisation steps that make up one epoch.
    steps_per_epoch = (train_len if batch_size == 1 else
                       int(math.ceil(train_len / batch_size)))

    # TensorBoard callback, driven by hand because there is no model.fit().
    tensorboard = keras.callbacks.TensorBoard(log_dir='tmp/my_tf_logs',
                                              histogram_freq=0,
                                              batch_size=batch_size,
                                              write_graph=True)

    # Only whole validation batches are handed to the evaluator.
    val_batch_len = int(math.floor(val_len / val_batch_size))
    evaluator = evaluate.Evaluator(val_gen, net, CHARS, val_batch_len,
                                   val_batch_size)
    best_val_loss = float("inf")

    # Warm-start from pretrained weights when a path was given.
    pretrained_path = args["pretrained"]
    if pretrained_path:
        net.load_weights(pretrained_path)

    model = net.model
    tensorboard.set_model(model)

    # Exponentially decaying learning rate feeding an Adam optimizer.
    lr_schedule = keras.optimizers.schedules.ExponentialDecay(
        args["lr"],
        decay_steps=args["decay_steps"],
        decay_rate=args["decay_rate"],
        staircase=args["staircase"])
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    print('Training ...')

    total_epochs = args["train_epochs"]
    best_weights_name = "new_out_model_best.pb"
    last_model_name = "new_out_model_last.pb"

    for epoch in range(total_epochs):

        print("Start of epoch {} / {}".format(epoch, total_epochs))

        # The running loss is reset so that it is reported per epoch.
        train_loss = 0
        start_time = time.time()

        for _ in range(steps_per_epoch):
            # Fetch one batch; the third returned value is not used here.
            train_inputs, train_targets, _ = train_gen.next_batch()
            train_inputs = train_inputs.astype('float32')

            # tf.nn.ctc_loss expects its labels as a sparse tensor.
            sparse_targets = tf.SparseTensor(indices=train_targets[0],
                                             values=train_targets[1],
                                             dense_shape=train_targets[2])

            # Record the forward pass so gradients can be derived from it.
            with tf.GradientTape() as tape:
                raw_output = model(train_inputs, training=True)

                # Collapse axis 1, then treat the remaining axes as
                # (batch, time, classes) -- presumably; confirm against the
                # model's output layout.
                pooled = tf.reduce_mean(raw_output, axis=1)
                pooled_shape = tf.shape(pooled)
                seq_len = tf.fill([pooled_shape[0]], pooled_shape[1])

                # Swap the first two axes before handing over to the loss.
                time_major = tf.transpose(pooled, (1, 0, 2))
                ctc_loss = tf.nn.ctc_loss(labels=sparse_targets,
                                          inputs=time_major,
                                          sequence_length=seq_len)
                loss_value = tf.reduce_mean(ctc_loss)

            # NOTE(review): gradients are taken on the per-sample ctc_loss,
            # not on the averaged loss_value that is reported -- confirm this
            # is intended.
            weights = model.trainable_weights
            grads = tape.gradient(
                ctc_loss,
                weights,
                unconnected_gradients=tf.UnconnectedGradients.NONE)
            optimizer.apply_gradients(zip(grads, weights))
            train_loss += float(loss_value)

        elapsed = time.time() - start_time

        print("Train loss {}, time {} \n".format(
            float(train_loss / steps_per_epoch), elapsed))

        # Validation only runs on every 25th epoch, and never on epoch 0.
        if epoch == 0 or epoch % 25 != 0:
            continue

        val_loss = evaluator.evaluate()
        if not val_loss < best_val_loss:
            print("Validation loss is greater than best_val_loss ")
            continue

        # New best validation loss: remember it and persist the weights.
        best_val_loss = val_loss
        net.save_weights(os.path.join(args["saved_dir"], best_weights_name))
        print("Weights updated in {}/{}".format(args["saved_dir"],
                                                best_weights_name))

    net.save(os.path.join(args["saved_dir"], last_model_name))
    print("Final Weights saved in {}/{}".format(args["saved_dir"],
                                                last_model_name))
    tensorboard.on_train_end(None)
Пример #22
0
def _list_sample_files(paths):
    """Return the joined path of every directory entry found under *paths*.

    *paths* is either one directory or a list of directories. Entries are
    returned directory by directory, each in ``os.listdir`` order.
    """
    directories = paths if isinstance(paths, list) else [paths]
    return [
        os.path.join(directory, name)
        for directory in directories
        for name in os.listdir(directory)
    ]


def train_process(mode=RunMode.Trains):
    """Build the OCR graph, load the data sets and train until the stop rule holds.

    Samples come either from TFRecords or from plain files, depending on
    ``TRAINS_USE_TFRECORDS``. Training loops over epochs without an upper
    bound and stops once the last measured accuracy is at least
    ``TRAINS_END_ACC``, the epoch counter is at least ``TRAINS_END_EPOCHS``
    and the average training cost is at most ``TRAINS_END_COST``; the graph
    is then handed to ``compile_graph``.

    Args:
        mode: run mode forwarded to ``framework.GraphOCR``.
    """
    model = framework.GraphOCR(mode, NETWORK_MAP[NEU_CNN],
                               NETWORK_MAP[NEU_RECURRENT])
    model.build_graph()

    print('Loading Trains DataSet...')
    train_feeder = utils.DataIterator(mode=RunMode.Trains)
    if TRAINS_USE_TFRECORDS:
        train_feeder.read_sample_from_tfrecords(TRAINS_PATH)
        print('Loading Test DataSet...')
        test_feeder = utils.DataIterator(mode=RunMode.Test)
        test_feeder.read_sample_from_tfrecords(TEST_PATH)
    else:
        origin_list = _list_sample_files(TRAINS_PATH)
        random.shuffle(origin_list)
        if not HAS_TEST_SET:
            # No dedicated test set: split the first TEST_SET_NUM shuffled
            # samples off the training data.
            test_list = origin_list[:TEST_SET_NUM]
            trains_list = origin_list[TEST_SET_NUM:]
        else:
            test_list = _list_sample_files(TEST_PATH)
            random.shuffle(test_list)
            trains_list = origin_list
        train_feeder.read_sample_from_files(trains_list)
        print('Loading Test DataSet...')
        test_feeder = utils.DataIterator(mode=RunMode.Test)
        test_feeder.read_sample_from_files(test_list)

    print('Total {} Trains DataSets'.format(train_feeder.size))
    print('Total {} Test DataSets'.format(test_feeder.size))
    if test_feeder.size >= train_feeder.size:
        exception(
            "The number of training sets cannot be less than the test set.", )

    num_train_samples = train_feeder.size
    num_test_samples = test_feeder.size
    if num_test_samples < TEST_BATCH_SIZE:
        exception(
            "The number of test sets cannot be less than the test batch size.",
            ConfigException.INSUFFICIENT_SAMPLE)
    num_batches_per_epoch = int(num_train_samples / BATCH_SIZE)

    config = tf.ConfigProto(
        # allow_soft_placement=True,
        log_device_placement=False,
        gpu_options=tf.GPUOptions(
            allocator_type='BFC',
            allow_growth=True,  # it will cause fragmentation.
            per_process_gpu_memory_fraction=GPU_USAGE))
    accuracy = 0
    epoch_count = 1

    with tf.Session(config=config) as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=2)
        train_writer = tf.summary.FileWriter('logs', sess.graph)
        try:
            saver.restore(sess, tf.train.latest_checkpoint(MODEL_PATH))
        except ValueError:
            # Deliberate best effort: when nothing can be restored, training
            # simply starts from the freshly initialised variables.
            pass
        print('Start training...')

        while True:
            shuffle_trains_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            # Copy of the latest average cost that survives the batch loop,
            # used by the stop check after the epoch.
            _avg_train_cost = 0
            for cur_batch in range(num_batches_per_epoch):
                batch_time = time.time()
                if TRAINS_USE_TFRECORDS:
                    batch_inputs, _, batch_labels = train_feeder.generate_batch_by_tfrecords(
                        sess)
                else:
                    # The modulo keeps every index inside the permutation.
                    index_list = [
                        shuffle_trains_idx[i % num_train_samples]
                        for i in range(cur_batch * BATCH_SIZE, (cur_batch + 1) *
                                       BATCH_SIZE)
                    ]
                    batch_inputs, _, batch_labels = train_feeder.generate_batch_by_files(
                        index_list)

                feed = {
                    model.inputs: batch_inputs,
                    model.labels: batch_labels,
                }

                summary_str, batch_cost, step, _ = sess.run([
                    model.merged_summary, model.cost, model.global_step,
                    model.train_op
                ],
                                                            feed_dict=feed)
                train_cost += batch_cost * BATCH_SIZE
                avg_train_cost = train_cost / ((cur_batch + 1) * BATCH_SIZE)
                _avg_train_cost = avg_train_cost
                train_writer.add_summary(summary_str, step)

                if step % 100 == 0 and step != 0:
                    print('Step: {} Time: {:.3f}, Cost = {:.5f}'.format(
                        step,
                        time.time() - batch_time, avg_train_cost))

                if step % TRAINS_SAVE_STEPS == 0 and step != 0:
                    saver.save(sess, SAVE_MODEL, global_step=step)
                    # Format the message before logging it; passing
                    # format(step) as a separate argument left the "{0}"
                    # placeholder unfilled.
                    logger.info('save checkpoint at step {0}'.format(step))

                if step % TRAINS_VALIDATION_STEPS == 0 and step != 0:
                    shuffle_test_idx = np.random.permutation(num_test_samples)
                    batch_time = time.time()
                    index_test = [
                        shuffle_test_idx[i % num_test_samples]
                        for i in range(cur_batch *
                                       TEST_BATCH_SIZE, (cur_batch + 1) *
                                       TEST_BATCH_SIZE)
                    ]
                    if TRAINS_USE_TFRECORDS:
                        test_inputs, _, test_labels = test_feeder.generate_batch_by_tfrecords(
                            sess)
                    else:
                        test_inputs, _, test_labels = test_feeder.generate_batch_by_files(
                            index_test)

                    val_feed = {
                        model.inputs: test_inputs,
                        model.labels: test_labels
                    }
                    dense_decoded, lr = sess.run(
                        [model.dense_decoded, model.lrn_rate],
                        feed_dict=val_feed)
                    accuracy = utils.accuracy_calculation(
                        test_feeder.labels(
                            None if TRAINS_USE_TFRECORDS else index_test),
                        dense_decoded,
                        ignore_value=[0, -1],
                    )
                    log = "Epoch: {}, Step: {}, Accuracy = {:.4f}, Cost = {:.5f}, " \
                          "Time = {:.3f}, LearningRate: {}"
                    print(
                        log.format(epoch_count, step, accuracy, avg_train_cost,
                                   time.time() - batch_time, lr))

                    # Leave the batch loop early; the check after the loop
                    # decides whether training as a whole is finished.
                    if accuracy >= TRAINS_END_ACC and epoch_count >= TRAINS_END_EPOCHS and avg_train_cost <= TRAINS_END_COST:
                        break
            if accuracy >= TRAINS_END_ACC and epoch_count >= TRAINS_END_EPOCHS and _avg_train_cost <= TRAINS_END_COST:
                compile_graph(accuracy)
                print('Total Time: {}'.format(time.time() - start_time))
                break
            epoch_count += 1
Пример #23
0
        return  self.iterator.get_next()



def parse_function(filename, label):
    """Load a PNG as a single-channel image scaled by 1/255, 32 rows by 256 columns.

    The image is resized to a height of 32 while its width is kept, then
    padded on the right up to a width of 256. The padding is filled with the
    value of the image's last pixel (last row, last column).

    Returns:
        The padded image tensor and ``label`` unchanged.
    """
    raw_bytes = tf.read_file(filename)
    pixels = tf.image.decode_png(raw_bytes, channels=1) / 255
    resized = tf.image.resize_images(pixels, [32, tf.shape(pixels)[1]])
    # Number of columns still missing to reach the target width of 256.
    missing_columns = 256 - tf.shape(resized)[1]
    padding = tf.zeros((32, missing_columns, 1)) + pixels[-1][-1]
    return tf.concat([resized, padding], 1), label
if __name__ == '__main__':
    # Demo: feed the test set through a tf.data pipeline and create the op
    # that yields its next batch.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    val_feeder = utils.DataIterator(data_dir='../data/test/', istrain=False)
    filename = val_feeder.image
    print(len(filename))
    label = val_feeder.labels

    # repeat() without an argument repeats for an unlimited number of epochs;
    # shuffle() then draws elements at random from a 20000-element buffer.
    dataset = (tf.data.Dataset.from_tensor_slices((filename, label))
               .map(parse_function)
               .repeat()
               .shuffle(buffer_size=20000))
    batched_dataset = dataset.batch(128)
    iterator = batched_dataset.make_initializable_iterator()
    with tf.Session() as sess:
        sess.run(iterator.initializer)

        tf_train_data = iterator.get_next()
Пример #24
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the OCR network, validating and checkpointing at fixed step intervals.

    For ``FLAGS.num_epochs`` epochs the training set is visited in a fresh
    random order. Checkpoints are written whenever
    ``step % FLAGS.save_steps == 1`` and a validation pass runs whenever
    ``step % FLAGS.validation_steps == 0``. Every
    ``FLAGS.train_with_val_steps`` epochs the network is additionally trained
    on the validation set.

    Args:
        train_dir: directory handed to the training ``utils.DataIterator``.
        val_dir: directory handed to the validation ``utils.DataIterator``.
        mode: mode string forwarded to ``model.Model``.
    """
    network = model.Model(mode)
    network.build_graph()

    print('loading train data')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('size: ', train_feeder.size)

    print('loading validation data')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('size: {}\n'.format(val_feeder.size))

    # Only whole batches are used; a trailing partial batch is dropped.
    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)
    shuffle_idx_val = np.random.permutation(num_val_samples)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        # makedirs also creates missing parent directories, which a plain
        # mkdir would fail on.
        if not os.path.isdir(FLAGS.log_dir):
            os.makedirs(FLAGS.log_dir)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step will restore as well
                saver.restore(sess, ckpt)
                print('restore from checkpoint{0}'.format(ckpt))

        print(
            '=============================begin training============================='
        )
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()

            # the training part
            for cur_batch in range(num_batches_per_epoch):
                batch_time = time.time()
                indexs = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(cur_batch *
                                   FLAGS.batch_size, (cur_batch + 1) *
                                   FLAGS.batch_size)
                ]
                batch_inputs, batch_labels = train_feeder.input_index_generate_batch(
                    indexs)
                feed = {
                    network.inputs: batch_inputs,
                    network.labels: batch_labels
                }

                # summary and calculate the loss
                summary_str, batch_loss, step, _ = \
                    sess.run([network.merged_summay, network.loss, network.global_step, network.train_op], feed_dict=feed)
                train_cost += batch_loss * FLAGS.batch_size
                train_writer.add_summary(summary_str, step)

                if (cur_batch + 1) % 2 == 0:
                    print('batch', cur_batch, '/', num_batches_per_epoch,
                          ' loss =', batch_loss, ' time',
                          time.time() - batch_time)

                # save the checkpoint
                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.makedirs(FLAGS.checkpoint_dir)
                    # Format the message before logging it; passing
                    # format(step) as a separate argument left the "{0}"
                    # placeholder unfilled.
                    logger.info('save checkpoint at step {0}'.format(step))
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                # do validation
                if step % FLAGS.validation_steps == 0:
                    acc_batch_total = 0.
                    lastbatch_err = 0.
                    lr = 0
                    for j in range(num_batches_per_epoch_val):
                        indexs_val = [
                            shuffle_idx_val[i % num_val_samples]
                            for i in range(j * FLAGS.batch_size, (j + 1) *
                                           FLAGS.batch_size)
                        ]
                        val_inputs, val_labels = val_feeder.input_index_generate_batch(
                            indexs_val)
                        val_feed = {
                            network.inputs: val_inputs,
                            network.labels: val_labels
                        }

                        lastbatch_err, acc, lr = \
                            sess.run([network.loss, network.accuracy, network.lrn_rate], feed_dict=val_feed)

                        acc_batch_total += acc

                    accuracy = (acc_batch_total *
                                FLAGS.batch_size) / num_val_samples

                    avg_train_cost = train_cost / (
                        (cur_batch + 1) * FLAGS.batch_size)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    print(
                        log.format(now.month, now.day, now.hour, now.minute,
                                   now.second, cur_epoch + 1, FLAGS.num_epochs,
                                   accuracy, avg_train_cost, lastbatch_err,
                                   time.time() - start_time, lr))

            # train with val dataset to reduce overfitting
            if (cur_epoch + 1) % FLAGS.train_with_val_steps == 0:
                shuffle_idx_val = np.random.permutation(num_val_samples)
                for cur_batch in range(num_batches_per_epoch_val):
                    batch_time = time.time()
                    indexs = [
                        shuffle_idx_val[i % num_val_samples]
                        for i in range(cur_batch *
                                       FLAGS.batch_size, (cur_batch + 1) *
                                       FLAGS.batch_size)
                    ]
                    batch_inputs, batch_labels = val_feeder.input_index_generate_batch(
                        indexs)
                    feed = {
                        network.inputs: batch_inputs,
                        network.labels: batch_labels
                    }

                    batch_loss, step, _ = \
                        sess.run([network.loss, network.global_step, network.train_op], feed_dict=feed)

                    if (cur_batch + 1) % 2 == 0:
                        print('train with val dataset: batch', cur_batch, '/',
                              num_batches_per_epoch_val, ' loss =', batch_loss,
                              ' time',
                              time.time() - batch_time)
Пример #25
0
def train(train_dir=None,
          val_dir=None,
          train_text_dir=None,
          val_text_dir=None):
    """Train the graph model and periodically score it on one fixed validation batch.

    The validation batch is generated once before training starts. During
    training a checkpoint is saved whenever ``step % FLAGS.save_steps == 0``
    and the validation batch is decoded and scored whenever
    ``step % FLAGS.validation_steps == 0``. When ``Flag_Isserver`` is set the
    validation log line is also appended to ``../log/acc/acc.txt``.

    Args:
        train_dir: directory handed to the training ``utils.DataIterator``.
        val_dir: directory handed to the validation ``utils.DataIterator``.
        train_text_dir: accepted for interface compatibility; not used here.
        val_text_dir: accepted for interface compatibility; not used here.
    """
    g = model.Graph(is_training=True)
    # end=' ' keeps the cursor on the line so the size printed next follows
    # the "please wait" banner.
    print('loading train data, please wait---------------------', end=' ')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('get image: ', train_feeder.size)
    print('loading validation data, please wait---------------------', end=' ')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('get image: ', val_feeder.size)

    # Only whole batches are used; a trailing partial batch is dropped.
    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(graph=g.graph, config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=3)
        # The saver has to be created before the graph is finalized.
        g.graph.finalize()
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)
        if FLAGS.restore:
            print("restore is true")
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))

        print(
            '=============================begin training============================='
        )
        # One validation batch, generated once and reused for every check.
        val_inputs, _, val_labels = val_feeder.input_index_generate_batch(
        )
        val_feed = {
            g.inputs: val_inputs,
            g.labels: val_labels,
            g.seq_len: np.array([g.cnn_time] * val_inputs.shape[0])
        }
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()
            # the training part
            for cur_batch in range(num_batches_per_epoch):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time',
                          time.time() - batch_time)
                batch_time = time.time()
                indexs = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(cur_batch *
                                   FLAGS.batch_size, (cur_batch + 1) *
                                   FLAGS.batch_size)
                ]
                batch_inputs, _, batch_labels = train_feeder.input_index_generate_batch(
                    indexs)
                feed = {
                    g.inputs: batch_inputs,
                    g.labels: batch_labels,
                    g.seq_len: np.array([g.cnn_time] * batch_inputs.shape[0])
                }

                summary_str, batch_cost, step, _ = sess.run(
                    [g.merged_summay, g.cost, g.global_step, g.optimizer],
                    feed)
                # accumulate the cost
                train_cost += batch_cost * FLAGS.batch_size
                train_writer.add_summary(summary_str, step)
                print("cur_epoch====", cur_epoch, "cur_batch----", cur_batch,
                      "g_step****", step, "cost", batch_cost)
                # save the checkpoint
                if step % FLAGS.save_steps == 0:
                    print("save checkpoint", step)
                    # makedirs also creates missing parent directories.
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.makedirs(FLAGS.checkpoint_dir)
                    # Format the message before logging it; passing
                    # format(step) as a separate argument left the "{0}"
                    # placeholder unfilled.
                    logger.info('save the checkpoint of{0}'.format(step))
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)
                # do validation
                if step % FLAGS.validation_steps == 0:
                    dense_decoded, lastbatch_err, lr = sess.run(
                        [g.dense_decoded, g.lerr, g.learning_rate], val_feed)
                    # print the decode result
                    acc = utils.accuracy_calculation(val_feeder.labels,
                                                     dense_decoded,
                                                     ignore_value=-1,
                                                     isPrint=True)
                    avg_train_cost = train_cost / (
                        (cur_batch + 1) * FLAGS.batch_size)
                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{}  step==={}, Epoch {}/{}, accuracy = {:.3f},avg_train_cost = {:.3f}, lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}\n"
                    # Format once so the console and the log file receive
                    # exactly the same line.
                    message = log.format(now.month, now.day, now.hour,
                                         now.minute, now.second, step,
                                         cur_epoch + 1, FLAGS.num_epochs, acc,
                                         avg_train_cost, lastbatch_err,
                                         time.time() - start_time, lr)
                    print(message)
                    if Flag_Isserver:
                        # The context manager closes the file even when the
                        # write fails.
                        with open('../log/acc/acc.txt', mode="a") as f:
                            f.write(message)
Пример #26
0
        image_string = tf.read_file(filename)
        image_decoded = tf.image.decode_jpeg(image_string, channels=1)
        image_decoded = image_decoded / 255
        #image_resize = tf.image.resize_images(image_decoded,[32,tf.shape(image_decoded)[1]])
        add = tf.zeros((tf.shape(image_decoded)[0],256-tf.shape(image_decoded)[1],1))+image_decoded[-1][-1]
        im =tf.concat( [image_decoded,add],1)
        image_resize = tf.image.resize_images(im,[32,256])
        #print(im.shape)
        return image_resize, label
    def init_itetator(self, sess):
        """Run this object's iterator initializer in the given session."""
        initializer_op = self.iterator.initializer
        sess.run(initializer_op)
    def get_nex_batch(self):
        """Return the result of calling ``get_next()`` on this object's iterator."""
        next_batch = self.iterator.get_next()
        return next_batch
if __name__ == '__main__':
    # Smoke test: pull 100 batches through the ReadData input pipeline.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    val_feeder = utils.DataIterator(data_dir='/home/work/data/', istrain=True)
    filename1 = val_feeder.image
    print(len(filename1))
    label1 = val_feeder.labels
    train_data = ReadData(filename1, label1)
    config = tf.ConfigProto(allow_soft_placement=False)
    with tf.Session(config=config) as sess:
        train_data.init_itetator(sess)
        # This line must stay outside the loop: the next-batch handle is
        # created once and then run repeatedly.
        tf_train_data = train_data.get_nex_batch()
        start_time = time.time()
        for i in range(100):
            fetched = sess.run(tf_train_data)
            imgbatch, label_batch = fetched
def infer(mode='infer'):
    """Decode every sample of the validation feeder and write one text file per image.

    The model is restored from the latest checkpoint in
    ``FLAGS.checkpoint_dir`` (using the moving-average variables). Each
    decoded sequence is cut at the ``<EOS>`` token, mapped to text through
    ``utils.decode_maps``, printed next to its image path and annotation, and
    written to ``<FLAGS.output_dir>/<image basename>.txt``.

    Args:
        mode: mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    FLAGS.num_threads = 1
    # Drop empty entries such as the one produced by a trailing comma.
    gpus = [gpu for gpu in FLAGS.gpus.split(',') if gpu]
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        train_feeder = utils.DataIterator(is_val=True, random_shuff=False)
        X, Y_in, Y_out, length = train_feeder.distored_inputs()
        model = cnn_lstm_otc_ocr.LSTMOCR(mode, gpus)
        _, decodes = model.build_graph(X, Y_in, Y_out, length)
        num_images = len(train_feeder.image)
        # Ceiling division: enough steps to cover every image once.
        total_steps = int((num_images + FLAGS.batch_size - 1) /
                          FLAGS.batch_size)
        config = tf.ConfigProto(allow_soft_placement=True)
        result_dir = os.path.dirname(FLAGS.infer_file)
        # NOTE: nothing is written to result_digit_v1.txt at the moment, but
        # opening it still creates/truncates the file, so the open is kept.
        with tf.Session(config=config) as sess, open(
                os.path.join(result_dir, 'result_digit_v1.txt'), 'w') as f:
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            variables_to_restore = model.variable_averages.variables_to_restore(
            )
            saver = tf.train.Saver(variables_to_restore)

            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            print("search from ", FLAGS.checkpoint_dir)
            print(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print('restore from ckpt{}'.format(ckpt))
            else:
                print('cannot restore')
            if not os.path.exists(FLAGS.output_dir):
                os.makedirs(FLAGS.output_dir)
            count = 0
            try:
                for curr_step in range(total_steps):
                    decoded_expression = []

                    dense_decoded_code = sess.run(decodes)
                    for batch in dense_decoded_code:
                        for sequence in batch:
                            pieces = []
                            for code in sequence:
                                if code == utils.TOKEN["<EOS>"]:
                                    break
                                # Codes without a mapping add nothing.
                                if code in utils.decode_maps:
                                    pieces.append(utils.decode_maps[code])
                            decoded_expression.append(''.join(pieces))
                    for expression in decoded_expression:
                        # The last batch may hold more results than there
                        # are images left to label.
                        if count >= num_images:
                            break
                        print(train_feeder.image[count])
                        print(train_feeder.anno[count])
                        print(expression)
                        print('')
                        filename = os.path.splitext(
                            os.path.basename(
                                train_feeder.image[count]))[0] + ".txt"
                        output_file = os.path.join(FLAGS.output_dir, filename)
                        # The context manager closes the file even when the
                        # write fails.
                        with open(output_file, "w") as cur:
                            cur.write(expression)
                        count += 1
            finally:
                # Always ask the input threads to stop, also on failure.
                coord.request_stop()
            coord.join(threads)
Пример #28
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the OCR model, checkpointing and validating periodically.

    Data locations are read from ``FLAGS.train_dir`` and ``FLAGS.val_dir``;
    the ``train_dir`` and ``val_dir`` arguments are accepted but not read by
    this implementation.

    Args:
        train_dir: Not used here (``FLAGS.train_dir`` is read instead).
        val_dir: Not used here (``FLAGS.val_dir`` is read instead).
        mode: Passed through to ``cnn_lstm_otc_ocr.LSTMOCR``.

    Exits the process via ``sys.exit()`` when ``FLAGS.model`` is not
    ``'lstm'``.
    """
    if FLAGS.model == 'lstm':
        model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    else:
        print("no such model")
        sys.exit()

    # Build the graph.
    model.build_graph()

    # ---- read training data ----
    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir=FLAGS.train_dir, istrain=True)
    print('get image  data size: ', train_feeder.size)
    filename = train_feeder.image
    label = train_feeder.labels
    print(len(filename))
    train_data = ReadData.ReadData(filename, label)

    # ---- read validation data ----
    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir=FLAGS.val_dir, istrain=False)
    filename1 = val_feeder.image
    label1 = val_feeder.labels
    test_data = ReadData.ReadData(filename1, label1)
    print('val get image: ', val_feeder.size)

    # ---- number of batches per epoch (partial trailing batch is dropped) ----
    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)
    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)

    with tf.device('/cpu:0'):

        config = tf.ConfigProto(allow_soft_placement=True)

        with tf.Session(config=config) as sess:
            # Initialise the data iterators, then rebind the names to the
            # "next batch" fetches that are evaluated with sess.run below.
            train_data.init_itetator(sess)
            test_data.init_itetator(sess)
            train_data = train_data.get_nex_batch()
            test_data = test_data.get_nex_batch()

            # Initialise global variables.
            sess.run(tf.global_variables_initializer())
            # Saver used to store the model checkpoints.
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)

            # Optionally load a previously trained model.
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # the global_step will be restored as well
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))
                else:
                    print("No checkpoint")

            print(
                '=============================begin training============================='
            )
            accuracy_res = []
            epoch_res = []
            tmp_max = 0  # best validation accuracy seen so far
            tmp_epoch = 0  # epoch index at which tmp_max was reached

            for cur_epoch in range(FLAGS.num_epochs):

                train_cost = 0
                batch_time = time.time()
                for cur_batch in range(num_batches_per_epoch):
                    # Fetch the data for this batch.
                    batch_inputs, batch_labels = sess.run(train_data)

                    new_batch_labels = utils.sparse_tuple_from_label(
                        batch_labels.tolist())
                    # Every sample is fed with the same sequence length.
                    batch_seq_len = np.asarray(
                        [FLAGS.max_stepsize for _ in batch_inputs],
                        dtype=np.int64)

                    feed = {
                        model.inputs: batch_inputs,
                        model.labels: new_batch_labels,
                        model.seq_len: batch_seq_len
                    }

                    # The summary is fetched but not written: the
                    # train_writer.add_summary call is disabled here.
                    summary_str, batch_cost, step, _ = \
                        sess.run([model.merged_summay, model.cost, model.global_step,
                                  model.train_op], feed)

                    # accumulate the cost
                    train_cost += batch_cost * FLAGS.batch_size

                    # save the checkpoint
                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.mkdir(FLAGS.checkpoint_dir)
                        # The step must be formatted into the message; passing
                        # it as a logging argument with a '{0}' placeholder
                        # makes the record fail to format.
                        logger.info('save the checkpoint of{0}'.format(step))
                        saver.save(sess,
                                   os.path.join(FLAGS.checkpoint_dir,
                                                'ocr-model'),
                                   global_step=step)
                    if (cur_batch) % 100 == 1:
                        print('batch', cur_batch, ': time',
                              time.time() - batch_time, 'loss', batch_cost)
                        batch_time = time.time()

                    # do validation
                    if step % FLAGS.validation_steps == 0:
                        validation_start_time = time.time()
                        acc_batch_total = 0
                        lr = 0
                        for j in range(num_batches_per_epoch_val):
                            batch_inputs, batch_labels = sess.run(test_data)
                            new_batch_labels = utils.sparse_tuple_from_label(
                                batch_labels.tolist())
                            batch_seq_len = np.asarray(
                                [FLAGS.max_stepsize for _ in batch_inputs],
                                dtype=np.int64)
                            val_feed = {
                                model.inputs: batch_inputs,
                                model.labels: new_batch_labels,
                                model.seq_len: batch_seq_len
                            }

                            dense_decoded, lr = \
                                sess.run([model.dense_decoded, model.lrn_rate],
                                         val_feed)

                            acc = utils.accuracy_calculation(
                                batch_labels.tolist(),
                                dense_decoded,
                                ignore_value=-1,
                                isPrint=True)
                            acc_batch_total += acc
                        accuracy = (acc_batch_total *
                                    FLAGS.batch_size) / num_val_samples
                        accuracy_res.append(accuracy)
                        epoch_res.append(cur_epoch)
                        if accuracy > tmp_max:
                            tmp_max = accuracy
                            tmp_epoch = cur_epoch
                        # Average over the samples seen so far in this epoch.
                        avg_train_cost = train_cost / (
                            (cur_batch + 1) * FLAGS.batch_size)

                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "max_accuracy = {:.3f},max_Epoch {},accuracy = {:.3f},acc_batch_total = {:.3f},avg_train_cost = {:.3f}, " \
                              " time = {:.3f},lr={:.8f}"

                        print(
                            log.format(now.month, now.day, now.hour,
                                       now.minute, now.second, cur_epoch + 1,
                                       FLAGS.num_epochs, tmp_max, tmp_epoch,
                                       accuracy, acc_batch_total,
                                       avg_train_cost,
                                       time.time() - validation_start_time,
                                       lr))

            # Flush and release the event file opened for the graph.
            train_writer.close()
Пример #29
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the OCR model on ``train_dir`` and validate on ``val_dir``.

    Each epoch visits the training samples in a fresh random order. A
    checkpoint is written whenever ``step % FLAGS.save_steps == 1`` and a
    validation pass runs whenever ``step % FLAGS.validation_steps == 0``; the
    validation report is printed and appended to ``test_acc.txt``.

    Args:
        train_dir: Passed to ``utils.DataIterator`` as the training data dir.
        val_dir: Passed to ``utils.DataIterator`` as the validation data dir.
        mode: Passed through to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()
    print(FLAGS.image_channel)
    print('loading train data')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('size: ', train_feeder.size)

    print('loading validation data')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('size: {}\n'.format(val_feeder.size))

    # Integer number of full batches; a partial trailing batch is dropped.
    num_train_samples = train_feeder.size  # 100000
    num_batches_per_epoch = int(num_train_samples /
                                FLAGS.batch_size)  # example: 100000/100

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples /
                                    FLAGS.batch_size)  # example: 10000/100
    # The validation order is drawn once and reused for every validation pass.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    os.environ["CUDA_VISIBLE_DEVICES"] = '2'
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step will be restored as well
                saver.restore(sess, ckpt)
                print('restore from checkpoint{0}'.format(ckpt))

        print(
            '=============================begin training============================='
        )
        # No graph ops may be added past this point.
        sess.graph.finalize()
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()

            # the training part
            for cur_batch in range(num_batches_per_epoch):
                # batch_time is reset on every iteration, so the value
                # printed here is the duration of the preceding batch only.
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time',
                          time.time() - batch_time)
                batch_time = time.time()
                indexs = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(cur_batch *
                                   FLAGS.batch_size, (cur_batch + 1) *
                                   FLAGS.batch_size)
                ]
                batch_inputs, _, batch_labels = \
                    train_feeder.input_index_generate_batch(indexs)
                feed = {model.inputs: batch_inputs, model.labels: batch_labels}

                # run one optimisation step and fetch the summary
                summary_str, batch_cost, step, _ = \
                    sess.run([model.merged_summay, model.cost, model.global_step, model.train_op], feed)
                # accumulate the cost
                train_cost += batch_cost * FLAGS.batch_size

                train_writer.add_summary(summary_str, step)

                # save the checkpoint
                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # The step must be formatted into the message; passing it
                    # as a logging argument with a '{0}' placeholder makes the
                    # record fail to format.
                    logger.info('save checkpoint at step {0}'.format(step))
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                # do validation
                if step % FLAGS.validation_steps == 0:
                    acc_batch_total = 0
                    lastbatch_err = 0
                    lr = 0
                    for j in range(num_batches_per_epoch_val):
                        indexs_val = [
                            shuffle_idx_val[i % num_val_samples]
                            for i in range(j * FLAGS.batch_size, (j + 1) *
                                           FLAGS.batch_size)
                        ]
                        val_inputs, _, val_labels = \
                            val_feeder.input_index_generate_batch(indexs_val)
                        val_feed = {
                            model.inputs: val_inputs,
                            model.labels: val_labels
                        }

                        dense_decoded, lastbatch_err, lr = \
                            sess.run([model.dense_decoded, model.cost, model.lrn_rate],
                                     val_feed)

                        # print the decode result
                        ori_labels = val_feeder.the_label(indexs_val)

                        acc = utils.accuracy_calculation(ori_labels,
                                                         dense_decoded,
                                                         ignore_value=-1,
                                                         isPrint=True)
                        acc_batch_total += acc

                    accuracy = (acc_batch_total *
                                FLAGS.batch_size) / num_val_samples

                    # Average over the samples seen so far in this epoch.
                    avg_train_cost = train_cost / (
                        (cur_batch + 1) * FLAGS.batch_size)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    with open('test_acc.txt', 'a') as f:
                        f.write(
                            str(
                                log.format(now.month, now.day, now.hour,
                                           now.minute, now.second, cur_epoch +
                                           1, FLAGS.num_epochs, accuracy,
                                           avg_train_cost, lastbatch_err,
                                           time.time() - start_time, lr)) +
                            "\n")
                    print(
                        log.format(now.month, now.day, now.hour, now.minute,
                                   now.second, cur_epoch + 1, FLAGS.num_epochs,
                                   accuracy, avg_train_cost, lastbatch_err,
                                   time.time() - start_time, lr))

        # Flush pending summaries and release the event file.
        train_writer.close()
Пример #30
0
def train_process(mode=RunMode.Trains):
    """Train the LSTM model until the accuracy and epoch targets are met.

    When ``HAS_TEST_SET`` is false, the files under ``TRAINS_PATH`` are
    shuffled and split: the first ``TEST_SET_NUM`` become the test set and
    the remainder the training set. Training loops over epochs until a
    validation pass reports ``accuracy >= TRAINS_END_ACC`` with
    ``epoch_count >= TRAINS_END_EPOCHS``; ``compile_graph(sess, accuracy)``
    is then called and the loop ends.

    Args:
        mode: Run mode passed through to ``framework_lstm.LSTM``.
    """
    model = framework_lstm.LSTM(mode)
    model.build_graph()
    test_list, trains_list = None, None
    if not HAS_TEST_SET:
        trains_list = os.listdir(TRAINS_PATH)
        random.shuffle(trains_list)
        origin_list = [
            os.path.join(TRAINS_PATH, trains) for trains in trains_list
        ]
        test_list = origin_list[:TEST_SET_NUM]
        trains_list = origin_list[TEST_SET_NUM:]

    print('Loading Trains DataSet...')
    train_feeder = utils.DataIterator(mode=RunMode.Trains)
    if TRAINS_USE_TFRECORDS:
        train_feeder.read_sample_from_tfrecords()
    else:
        train_feeder.read_sample_from_files(trains_list)
    print('Total {} Trains DataSets'.format(train_feeder.size))

    print('Loading Test DataSet...')
    test_feeder = utils.DataIterator(mode=RunMode.Test)
    if TEST_USE_TFRECORDS:
        test_feeder.read_sample_from_tfrecords()
    else:
        test_feeder.read_sample_from_files(test_list)
    print('Total {} Test DataSets'.format(test_feeder.size))

    # Integer number of full batches; a partial trailing batch is dropped.
    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / BATCH_SIZE)

    num_val_samples = test_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / BATCH_SIZE)
    # The validation order is drawn once and reused for every validation pass.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        gpu_options=tf.GPUOptions(
            allow_growth=True,  # it will cause fragmentation.
            per_process_gpu_memory_fraction=GPU_USAGE))
    accuracy = 0
    epoch_count = 1

    with tf.Session(config=config) as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=2)
        train_writer = tf.summary.FileWriter('logs', sess.graph)
        # Best-effort restore: a ValueError from restore() is ignored and
        # training continues from the freshly initialised variables.
        try:
            saver.restore(sess, tf.train.latest_checkpoint(MODEL_PATH))
        except ValueError:
            pass

        print('Start training...')

        while 1:
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            # NOTE(review): start_time is reset every epoch, so the
            # 'Total Time' printed at the end covers the final epoch only.
            start_time = time.time()

            for cur_batch in range(num_batches_per_epoch):
                index_list = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(cur_batch * BATCH_SIZE, (cur_batch + 1) *
                                   BATCH_SIZE)
                ]
                if TRAINS_USE_TFRECORDS:
                    batch_inputs, _, batch_labels = train_feeder.generate_batch_by_tfrecords(
                        sess)
                else:
                    batch_inputs, _, batch_labels = train_feeder.generate_batch_by_index(
                        index_list)

                feed = {
                    model.batch_size: BATCH_SIZE,
                    model.inputs: batch_inputs,
                    model.labels: batch_labels,
                }

                summary_str, batch_cost, step, _ = sess.run([
                    model.merged_summary, model.cost, model.global_step,
                    model.train_op
                ], feed)
                train_cost += batch_cost * BATCH_SIZE
                train_writer.add_summary(summary_str, step)

                if step % TRAINS_SAVE_STEPS == 0:
                    saver.save(sess, SAVE_MODEL, global_step=step)
                    # The step must be formatted into the message; passing it
                    # as a logging argument with a '{0}' placeholder makes the
                    # record fail to format.
                    logger.info('save checkpoint at step {0}'.format(step))

                if step % TRAINS_VALIDATION_STEPS == 0:
                    acc_batch_total = 0
                    lr = 0
                    batch_time = time.time()

                    for j in range(num_batches_per_epoch_val):
                        index_val = [
                            shuffle_idx_val[i % num_val_samples]
                            for i in range(j * BATCH_SIZE, (j + 1) *
                                           BATCH_SIZE)
                        ]
                        # test_feeder was loaded according to
                        # TEST_USE_TFRECORDS, so the same flag must select
                        # how its batches and labels are read.
                        if TEST_USE_TFRECORDS:
                            val_inputs, _, val_labels = test_feeder.generate_batch_by_tfrecords(
                                sess)
                        else:
                            val_inputs, _, val_labels = test_feeder.generate_batch_by_index(
                                index_val)

                        val_feed = {
                            model.batch_size: BATCH_SIZE,
                            model.inputs: val_inputs,
                            model.labels: val_labels,
                        }

                        dense_decoded, last_batch_err, lr = sess.run(
                            [model.dense_decoded, model.cost, model.lrn_rate],
                            val_feed)
                        if TEST_USE_TFRECORDS:
                            ori_labels = test_feeder.label_by_tfrecords()
                        else:
                            ori_labels = test_feeder.label_by_index(index_val)
                        acc = utils.accuracy_calculation(
                            ori_labels,
                            dense_decoded,
                            ignore_value=-1,
                        )
                        acc_batch_total += acc

                    accuracy = (acc_batch_total * BATCH_SIZE) / num_val_samples
                    # Average over the samples seen so far in this epoch.
                    avg_train_cost = train_cost / (
                        (cur_batch + 1) * BATCH_SIZE)

                    log = "Epoch: {}, Step: {} Accuracy = {:.3f}, Cost = {:.3f}, Time = {:.3f}, LearningRate: {}"
                    print(
                        log.format(epoch_count, step, accuracy, avg_train_cost,
                                   time.time() - batch_time, lr))
                    # Leave the batch loop as soon as both targets are met;
                    # the check below then ends the epoch loop.
                    if accuracy >= TRAINS_END_ACC and epoch_count >= TRAINS_END_EPOCHS:
                        break
            if accuracy >= TRAINS_END_ACC and epoch_count >= TRAINS_END_EPOCHS:
                compile_graph(sess, accuracy)
                print('Total Time: {}'.format(time.time() - start_time))
                break
            epoch_count += 1

        coord.request_stop()
        coord.join(threads)
        # Flush pending summaries and release the event file.
        train_writer.close()