コード例 #1
0
ファイル: test.py プロジェクト: kapitsa2811/ICDAR_task2
def acc():
    """Evaluate the restored transfer model over the whole test set.

    Builds ``model_transfer.Graph`` from the module-level ``pb_file_path``,
    restores the newest checkpoint in ``FLAGS.checkpoint_dir`` when one
    exists, decodes every sample exposed by ``utils.DataIterator2`` batch by
    batch and prints the list of per-batch accuracies and their mean.

    Relies on module-level names ``pb_file_path``, ``data_dir``,
    ``text_dir``, ``Flage_width`` and ``FLAGS``.
    """
    acc_all = []
    g = model_transfer.Graph(pb_file_path=pb_file_path)
    with tf.Session(graph=g.graph) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            # NOTE: evaluation continues with freshly initialised weights.
            print('cannot restore')
        test_feeder = utils.DataIterator2(data_dir=data_dir, text_dir=text_dir)
        print("total data:", test_feeder.size)
        print("total image in folder", test_feeder.total_pic_read)

        batch_size = FLAGS.batch_size
        num_samples = test_feeder.size
        # Ceiling division. ``int(size / bs) + 1`` produced a trailing batch
        # with zero indices whenever size was an exact multiple of bs.
        total_epoch = (num_samples + batch_size - 1) // batch_size
        for cur_batch in range(total_epoch):
            print("cur_epoch/total_epoch", cur_batch, "/", total_epoch)
            start = cur_batch * batch_size
            # The final batch may be smaller than batch_size.
            stop = min(start + batch_size, num_samples)
            indexs = list(range(start, stop))
            test_inputs, test_seq_len, test_labels = \
                test_feeder.input_index_generate_batch(indexs)
            cur_labels = [test_feeder.labels[i] for i in indexs]
            test_feed = {g.original_pic: test_inputs,
                         g.labels: test_labels,
                         g.seq_len: np.array([Flage_width] * test_inputs.shape[0])}
            dense_decoded = sess.run(g.dense_decoded, test_feed)
            batch_acc = utils.accuracy_calculation(cur_labels, dense_decoded,
                                                   ignore_value=-1, isPrint=False)
            acc_all.append(batch_acc)
        print("$$$$$$$$$$$$$$$$$ ACC is :", acc_all, "$$$$$$$$$$$$$$$$$")
        # NOTE(review): unweighted mean of per-batch values, so a smaller final
        # batch counts as much as a full one (assumes accuracy_calculation
        # returns a per-batch ratio — confirm).
        if acc_all:
            print("avg_acc:", np.array(acc_all).mean())
        else:
            print("avg_acc: no test data")
コード例 #2
0
def acc():
    """Decode the inference folder in one batch and print the accuracy.

    The newest checkpoint in ``FLAGS.checkpoint_dir`` is restored if one is
    found; otherwise 'cannot restore' is printed and the freshly initialised
    weights are used. Depends on module-level ``model``, ``utils``,
    ``inferFolder`` and ``FLAGS``.
    """
    graph = model.Graph()
    with tf.Session(graph=graph.graph) as session:
        session.run(tf.global_variables_initializer())
        restorer = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if not checkpoint:
            print('cannot restore')
        else:
            restorer.restore(session, checkpoint)
            print('restore from ckpt{}'.format(checkpoint))

        feeder = utils.DataIterator(data_dir=inferFolder)
        inputs, _feeder_seq_len, labels = feeder.input_index_generate_batch()
        # Every image is fed with a fixed sequence length of 27; the
        # feeder-provided lengths are ignored.
        fixed_lengths = np.array([27] * inputs.shape[0])
        feed_dict = {
            graph.inputs: inputs,
            graph.labels: labels,
            graph.seq_len: fixed_lengths,
        }
        decoded = session.run(graph.dense_decoded, feed_dict)

        # Compare the decoded batch against the feeder's labels (printing
        # per-sample results) and report the overall accuracy.
        score = utils.accuracy_calculation(
            feeder.labels, decoded, ignore_value=-1, isPrint=True)
        print(score)
コード例 #3
0
def train(train_dir=None, val_dir=None):
    """Train the CTC OCR graph, checkpointing and validating periodically.

    Args:
        train_dir: directory passed to ``utils.DataIterator`` for training data.
        val_dir: directory passed to ``utils.DataIterator`` for validation
            data. The whole validation set is fed as a single batch.

    Relies on the module-level settings ``batch_size``, ``num_epochs``,
    ``restore``, ``checkpoint_dir`` and ``validation_steps``.

    Raises:
        ValueError: if the training set holds fewer samples than one batch,
            so that no training step could run.
    """
    # Load the training and validation data.
    print('Loading training data...')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('Get images: ', train_feeder.size)

    print('Loading validate data...')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('Get images: ', val_feeder.size)

    # Build the network.
    g = Graph()

    # Total number of training samples.
    num_train_samples = train_feeder.size
    # Number of full batches per epoch; a trailing partial batch is skipped.
    num_batches_per_epoch = num_train_samples // batch_size
    if num_batches_per_epoch == 0:
        raise ValueError(
            'training set has %d samples, fewer than batch_size=%d'
            % (num_train_samples, batch_size))

    with tf.Session(graph=g.graph) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)

        # When ``restore`` is set, resume from the newest checkpoint
        # (global_step is restored together with the weights).
        if restore:
            ckpt = tf.train.latest_checkpoint(checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))

        print('============begin training============')
        # Build the validation feed once; it is reused at every validation.
        val_inputs, val_seq_len, val_labels = val_feeder.input_index_generate_batch()
        val_feed = {g.inputs: val_inputs, g.labels: val_labels, g.seq_len: val_seq_len}

        start_time = time.time()
        for cur_epoch in range(num_epochs):
            # Visit the training samples in a fresh random order every epoch.
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0.0

            for cur_batch in range(num_batches_per_epoch):
                begin = cur_batch * batch_size
                indexs = list(shuffle_idx[begin:begin + batch_size])
                batch_inputs, batch_seq_len, batch_labels = \
                    train_feeder.input_index_generate_batch(indexs)
                feed = {g.inputs: batch_inputs, g.labels: batch_labels,
                        g.seq_len: batch_seq_len}

                _summary_str, batch_cost, step, _ = sess.run(
                    [g.merged_summay, g.cost, g.global_step, g.optimizer], feed)
                train_cost += batch_cost

                if step % 50 == 1:
                    end_time = time.time()
                    print('No. %5d batches, loss: %5.2f, time: %3.1fs'
                          % (step, batch_cost, end_time - start_time))
                    start_time = time.time()

                # Checkpoint and validate on steps 1, 1 + validation_steps, ...
                # Written as ``(step - 1) % n == 0`` so n == 1 also works
                # (``step % 1 == 1`` is never true).
                if (step - 1) % validation_steps == 0:
                    if not os.path.isdir(checkpoint_dir):
                        os.makedirs(checkpoint_dir)
                    saver.save(sess, os.path.join(checkpoint_dir, 'ocr-model'),
                               global_step=step)

                    # Decode the validation set.
                    dense_decoded, _lastbatch_err, _lr = sess.run(
                        [g.dense_decoded, g.lerr, g.learning_rate], val_feed)
                    acc = utils.accuracy_calculation(val_feeder.labels, dense_decoded,
                                                     ignore_value=-1, isPrint=False)
                    print('-After %5d steps, Val accu: %4.2f%%' % (step, acc))

            print('Epoch %d/%d, avg train loss: %.4f'
                  % (cur_epoch + 1, num_epochs, train_cost / num_batches_per_epoch))
コード例 #4
0
ファイル: lstm_ocr.py プロジェクト: linVdcd/lstm_ctc_ocr
def train():
    """Train the LSTM-CTC OCR graph, validating once per epoch.

    Training data is read from ``'../train/'`` and validation data from
    ``'../test/'`` with ``utils.DataIterator``. A checkpoint is saved every
    ``FLAGS.save_steps`` steps. After each epoch the whole validation set is
    decoded as one batch and a summary line is printed. Hyper-parameters
    come from ``FLAGS``.
    """
    g = Graph()
    with g.graph.as_default():
        print('loading train data, please wait---------------------', end=' ')
        train_feeder = utils.DataIterator(data_dir='../train/')
        print('get image: ', train_feeder.size)

        print('loading validation data, please wait---------------------',
              end=' ')
        val_feeder = utils.DataIterator(data_dir='../test/')
        print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size  # e.g. 12800
    num_batches_per_epoch = int(num_train_samples /
                                FLAGS.batch_size)  # e.g. 12800/64
    # Only full batches are trained on, so this many samples are actually
    # seen per epoch. Used to average the epoch's training cost.
    num_samples_per_epoch = num_batches_per_epoch * FLAGS.batch_size

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=False)
    with tf.Session(graph=g.graph, config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        # Writes the graph to the log dir; no per-step summaries are recorded.
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # global_step is restored as well.
                saver.restore(sess, ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))

        print(
            '=============================begin training============================='
        )
        # The whole validation set is fed as one batch, built once.
        val_inputs, val_seq_len, val_labels = val_feeder.input_index_generate_batch(
        )
        val_feed = {
            g.inputs: val_inputs,
            g.labels: val_labels,
            g.seq_len: val_seq_len
        }
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start = time.time()
            batch_time = time.time()
            for cur_batch in range(num_batches_per_epoch):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time',
                          time.time() - batch_time)
                batch_time = time.time()
                indexs = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(cur_batch *
                                   FLAGS.batch_size, (cur_batch + 1) *
                                   FLAGS.batch_size)
                ]
                batch_inputs, batch_seq_len, batch_labels = train_feeder.input_index_generate_batch(
                    indexs)
                feed = {
                    g.inputs: batch_inputs,
                    g.labels: batch_labels,
                    g.seq_len: batch_seq_len
                }
                batch_cost, step, _ = sess.run(
                    [g.cost, g.global_step, g.optimizer], feed)
                # Assumes g.cost is a per-sample mean; scale it to a batch
                # total (TODO confirm).
                train_cost += batch_cost * FLAGS.batch_size

                # Save on steps 1, 1 + save_steps, ...; this form also works
                # when save_steps == 1.
                if (step - 1) % FLAGS.save_steps == 0:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.makedirs(FLAGS.checkpoint_dir)
                    logger.info('save the checkpoint of %s', step)
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

            d, lastbatch_err = sess.run([g.decoded[0], g.lerr], val_feed)
            # Densify the fetched SparseTensorValue in NumPy, filling missing
            # positions with -1. Creating a tf.sparse_tensor_to_dense op here
            # would add new graph nodes on every epoch.
            values = np.asarray(d.values)
            dense_decoded = np.full(tuple(d.dense_shape), -1, dtype=values.dtype)
            dense_decoded[tuple(np.asarray(d.indices).T)] = values
            # Print the decode result.
            acc = utils.accuracy_calculation(val_feeder.labels,
                                             dense_decoded,
                                             ignore_value=-1,
                                             isPrint=True)
            if num_samples_per_epoch:
                train_cost /= num_samples_per_epoch
            now = datetime.datetime.now()
            log = "{}-{} {}:{}:{} Epoch {}/{}, accuracy = {:.3f},train_cost = {:.3f}, lastbatch_err = {:.3f}, time = {:.3f}"
            print(
                log.format(now.month, now.day, now.hour, now.minute,
                           now.second, cur_epoch + 1, FLAGS.num_epochs, acc,
                           train_cost, lastbatch_err,
                           time.time() - start))
コード例 #5
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train ``orcmodel.LSTMOCR`` with mini-batch SGD and periodic validation.

    Args:
        train_dir: directory passed to ``utils.DataIterator`` for training data.
        val_dir: directory passed to ``utils.DataIterator`` for validation data.
        mode: mode string forwarded to ``orcmodel.LSTMOCR``.

    A checkpoint is written every ``FLAGS.save_steps`` steps. Every
    ``FLAGS.validation_steps`` steps, all full validation batches are decoded
    and a summary line is printed.
    """
    # Instantiate the model class.
    model = orcmodel.LSTMOCR(mode)
    # Build the model: the computation graph and training op are created here.
    model.build_graph()

    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir=train_dir)  # training data
    print('get image: ', train_feeder.size)

    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir=val_dir)  # validation data
    print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size  # number of training samples
    num_batches_per_epoch = int(
        num_train_samples / FLAGS.batch_size)  # batches per epoch, e.g. 100000/100

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(
        num_val_samples / FLAGS.batch_size)  # validation batches, e.g. 10000/100
    # The validation order is shuffled once and reused for every validation pass.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    with tf.device('/cpu:0'):
        # ConfigProto configures the Session; allow_soft_placement lets TF
        # pick another device when a requested one does not exist.
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            # Keep only the 50 most recent checkpoints.
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=50)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)  # TensorBoard summaries
            # Optionally resume from the newest checkpoint.
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # global_step is restored together with the weights.
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))

            print(
                '=============================begin training============================='
            )
            for cur_epoch in range(FLAGS.num_epochs):
                # Shuffle the training-sample indices for SGD.
                shuffle_idx = np.random.permutation(num_train_samples)
                train_cost = 0
                start_time = time.time()
                batch_time = time.time()

                # Run this epoch's batches.
                for cur_batch in range(num_batches_per_epoch):
                    if (cur_batch + 1) % 100 == 0:
                        print('batch', cur_batch, ': time',
                              time.time() - batch_time)
                    batch_time = time.time()

                    # Indices of this training batch.
                    indexs = [
                        shuffle_idx[i % num_train_samples]
                        for i in range(cur_batch *
                                       FLAGS.batch_size, (cur_batch + 1) *
                                       FLAGS.batch_size)
                    ]
                    batch_inputs, batch_seq_len, batch_labels = \
                        train_feeder.input_index_generate_batch(indexs)

                    feed = {
                        model.inputs: batch_inputs,
                        model.labels: batch_labels,
                        model.seq_len: batch_seq_len
                    }

                    # Record summaries, compute the cost, read global_step and train.
                    summary_str, batch_cost, step, _ = \
                        sess.run([model.merged_summay, model.cost, model.global_step,
                                  model.train_op], feed)

                    # batch_cost is presumably a per-sample mean; scale it to
                    # a batch total (TODO confirm).
                    train_cost += batch_cost * FLAGS.batch_size

                    train_writer.add_summary(summary_str, step)

                    # Save on steps 1, 1 + save_steps, ...; this form also
                    # works when save_steps == 1.
                    if (step - 1) % FLAGS.save_steps == 0:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.makedirs(FLAGS.checkpoint_dir)
                        logger.info('save the checkpoint of %s', step)
                        saver.save(sess,
                                   os.path.join(FLAGS.checkpoint_dir,
                                                'ocr-model'),
                                   global_step=step)

                    # Validation.
                    if step % FLAGS.validation_steps == 0:
                        acc_batch_total = 0
                        lr = 0
                        for j in range(num_batches_per_epoch_val):
                            # Indices of one validation mini-batch.
                            indexs_val = [
                                shuffle_idx_val[i % num_val_samples]
                                for i in range(j * FLAGS.batch_size, (j + 1) *
                                               FLAGS.batch_size)
                            ]
                            val_inputs, val_seq_len, val_labels = \
                                val_feeder.input_index_generate_batch(indexs_val)
                            val_feed = {
                                model.inputs: val_inputs,
                                model.labels: val_labels,
                                model.seq_len: val_seq_len
                            }

                            # Decode and read the current learning rate.
                            dense_decoded, lr = sess.run(
                                [model.dense_decoded, model.lrn_rate],
                                val_feed)

                            ori_labels = val_feeder.the_label(indexs_val)

                            # Accuracy of this batch (decode results are printed).
                            acc = utils.accuracy_calculation(ori_labels,
                                                             dense_decoded,
                                                             ignore_value=-1,
                                                             isPrint=True)
                            acc_batch_total += acc

                        # Mean of the per-batch accuracies. Only full batches
                        # are evaluated, so dividing by num_val_samples would
                        # under-report accuracy when the set size is not a
                        # multiple of the batch size.
                        if num_batches_per_epoch_val:
                            accuracy = acc_batch_total / num_batches_per_epoch_val
                        else:
                            accuracy = 0.0

                        # Average training cost so far in this epoch.
                        avg_train_cost = train_cost / (
                            (cur_batch + 1) * FLAGS.batch_size)

                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                              "lastbatch_cost = {:f}, time = {:.3f},lr={:.8f}"
                        print(
                            log.format(now.month, now.day, now.hour,
                                       now.minute, now.second, cur_epoch + 1,
                                       FLAGS.num_epochs, accuracy,
                                       avg_train_cost, batch_cost,
                                       time.time() - start_time, lr))
コード例 #6
0
def train(train_dir=None, mode='train'):
    """Train ``cnn_lstm_otc_ocr.LSTMOCR`` on an 80/20 split of ``train_dir``.

    Each ``*.txt`` label file in ``train_dir`` that has a ``.jpg`` with the
    same stem next to it becomes a sample. The first 80% (in ``os.listdir``
    order) are used for training and the rest for validation.

    Args:
        train_dir: directory holding paired ``.txt`` / ``.jpg`` files.
        mode: mode string forwarded to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    label_files = [
        os.path.join(train_dir, e) for e in os.listdir(train_dir)
        if e.endswith('.txt')
        # Swap only the suffix; str.replace would rewrite every '.txt'.
        and os.path.exists(os.path.join(train_dir, e[:-len('.txt')] + '.jpg'))
    ]
    train_num = int(len(label_files) * 0.8)
    test_num = len(label_files) - train_num

    print('total num', len(label_files), 'train num', train_num, 'test num', test_num)
    train_imgs = label_files[0:train_num]
    test_imgs = label_files[train_num:]

    print('loading train data')
    train_feeder = utils.DataIterator(data_dir=train_imgs)

    print('loading validation data')
    val_feeder = utils.DataIterator(data_dir=test_imgs)

    num_batches_per_epoch_train = int(train_num / FLAGS.batch_size)  # e.g. 100000/100
    num_batches_per_epoch_val = int(test_num / FLAGS.batch_size)  # e.g. 10000/100

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # global_step is restored as well.
                saver.restore(sess, ckpt)
                print('restore from checkpoint{0}'.format(ckpt))

        print('=============================begin training=============================')
        for cur_epoch in range(FLAGS.num_epochs):
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()
            # The training data is shuffled once, before the first epoch;
            # later epochs reuse that order.
            if cur_epoch == 0:
                random.shuffle(train_feeder.train_data)

            # Training.
            for cur_batch in range(num_batches_per_epoch_train):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time', time.time() - batch_time)
                batch_time = time.time()

                batch_inputs, result_img_length, batch_labels = \
                    train_feeder.get_batchsize_data(cur_batch)

                feed = {model.inputs: batch_inputs,
                        model.labels: batch_labels,
                        model.seq_len: result_img_length}

                summary_str, batch_cost, step, _ = \
                    sess.run([model.merged_summay, model.cost, model.global_step, model.train_op], feed)
                # Presumably batch_cost is a per-sample mean; scale it to a
                # batch total (TODO confirm).
                train_cost += batch_cost * FLAGS.batch_size

                train_writer.add_summary(summary_str, step)

                # Save on steps 1, 1 + save_steps, ...; this form also works
                # when save_steps == 1.
                if (step - 1) % FLAGS.save_steps == 0:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.makedirs(FLAGS.checkpoint_dir)
                    logger.info('save checkpoint at step %s', step)
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'), global_step=step)

                # Validation.
                if step % FLAGS.validation_steps == 0:
                    acc_batch_total = 0
                    lastbatch_err = 0
                    lr = 0
                    for val_j in range(num_batches_per_epoch_val):
                        result_img_val, seq_len_input_val, batch_label_val = \
                            val_feeder.get_batchsize_data(val_j)
                        val_feed = {model.inputs: result_img_val,
                                    model.labels: batch_label_val,
                                    model.seq_len: seq_len_input_val}

                        dense_decoded, lastbatch_err, lr = \
                            sess.run([model.dense_decoded, model.cost, model.lrn_rate],
                                     val_feed)

                        # Convert decoded label ids to text before comparing.
                        val_pre_list = [utils.label2text(decode_code)
                                        for decode_code in dense_decoded]
                        ori_labels = val_feeder.get_val_label(val_j)

                        acc = utils.accuracy_calculation(ori_labels, val_pre_list,
                                                         ignore_value=-1, isPrint=True)
                        acc_batch_total += acc

                    # Guard against a validation set smaller than one batch.
                    if num_batches_per_epoch_val:
                        accuracy = acc_batch_total / num_batches_per_epoch_val
                    else:
                        accuracy = 0.0

                    avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                                     cur_epoch + 1, FLAGS.num_epochs, accuracy, avg_train_cost,
                                     lastbatch_err, time.time() - start_time, lr))
コード例 #7
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train ``cnn_lstm_otc_ocr.LSTMOCR`` from a queue-fed input pipeline.

    Each iteration pulls one batch by evaluating the module-level
    ``train_batch`` tensor and converts its labels with ``read_labels``.
    ``train_dir`` and ``val_dir`` are accepted for interface compatibility
    but are not read by this function.

    A checkpoint is written every ``FLAGS.save_steps`` steps. Every
    ``FLAGS.validation_steps`` steps, the batch just trained on is decoded and
    a summary line is printed. No separate validation input is visible here,
    so that accuracy is measured on training data.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    num_iterations = 1000  # total number of training batches to run
    with tf.Session(config=config) as sess:
        # Initialise variables before starting the queue-runner threads:
        # runners may read variables, e.g. the local epoch counter that input
        # producers create when given ``num_epochs``.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)

            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # global_step is restored as well.
                    saver.restore(sess, ckpt)
                    print('restore from checkpoint{0}'.format(ckpt))

            print(
                '=============================begin training============================='
            )
            start_time = time.time()
            total_cost = 0.0
            for it in range(num_iterations):
                train_feeder = sess.run(train_batch)
                batch_inputs, batch_labels = \
                    train_feeder[0], read_labels(train_feeder[1])
                feed = {model.inputs: batch_inputs, model.labels: batch_labels}
                summary_str, batch_cost, step, _ = \
                    sess.run([model.merged_summay, model.cost, model.global_step, model.train_op], feed)
                total_cost += batch_cost
                train_writer.add_summary(summary_str, step)

                # Save on steps 1, 1 + save_steps, ...; this form also works
                # when save_steps == 1.
                if (step - 1) % FLAGS.save_steps == 0:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.makedirs(FLAGS.checkpoint_dir)
                    logger.info('save checkpoint at step %s', step)
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                if step % FLAGS.validation_steps == 0:
                    # Decode the batch just trained on once; re-running the
                    # identical feed adds no information.
                    dense_decoded, lastbatch_err, lr = \
                        sess.run([model.dense_decoded, model.cost, model.lrn_rate],
                                 feed)
                    accuracy = utils.accuracy_calculation(batch_labels,
                                                          dense_decoded,
                                                          ignore_value=-1,
                                                          isPrint=True)
                    # Mean of per-batch costs over the iterations run so far.
                    avg_train_cost = total_cost / (it + 1)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    print(
                        log.format(now.month, now.day, now.hour, now.minute,
                                   now.second, it + 1, num_iterations,
                                   accuracy, avg_train_cost, lastbatch_err,
                                   time.time() - start_time, lr))
        finally:
            # Stop and join the runner threads before the session closes.
            coord.request_stop()
            coord.join(threads)
コード例 #8
0
def train(train_dir=None,
          val_dir=None,
          train_text_dir=None,
          val_text_dir=None,
          pb_file_path=None):
    """Fine-tune the transfer OCR graph with periodic checkpoints and validation.

    Args:
        train_dir: image directory for ``utils.DataIterator2`` (training).
        val_dir: image directory for ``utils.DataIterator2`` (validation).
        train_text_dir: label directory paired with ``train_dir``.
        val_text_dir: label directory paired with ``val_dir``.
        pb_file_path: path passed to ``model_transfer.Graph``.

    Every image is fed with the module-level sequence length ``Flage_width``.
    Checkpoints are written when ``global_step`` is a multiple of
    ``FLAGS.save_steps``. Validation (the whole validation set as one batch)
    runs when it is a multiple of ``FLAGS.validation_steps``. When
    ``Flag_Isserver`` is true, validation log lines are also appended to
    ``../log/acc/acc.txt``.
    """
    g = model_transfer.Graph(is_training=True, pb_file_path=pb_file_path)

    print('loading train data, please wait---------------------', end=' ')
    train_feeder = utils.DataIterator2(data_dir=train_dir,
                                       text_dir=train_text_dir)
    print('***************get image: ', train_feeder.size)
    print('loading validation data, please wait---------------------', end=' ')
    val_feeder = utils.DataIterator2(data_dir=val_dir, text_dir=val_text_dir)
    print('***************get image: ', val_feeder.size)

    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(graph=g.graph, config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=3)
        if FLAGS.restore:
            print("restore is true")
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))

        # No further ops may be added to the graph after this point.
        g.graph.finalize()
        print(
            '=============================begin training============================='
        )
        cur_training_step = 0
        # The whole validation set is fed as one batch, built once.
        val_inputs, _val_seq_len, val_labels = val_feeder.input_index_generate_batch(
        )
        val_feed = {
            g.original_pic: val_inputs,
            g.labels: val_labels,
            g.seq_len: np.array([Flage_width] * val_inputs.shape[0])
        }

        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            for cur_batch in range(num_batches_per_epoch):
                cur_training_step += 1
                if cur_training_step % Flage_print_frequency == 0:
                    print("epochs", cur_epoch, cur_batch)

                indexs = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(cur_batch *
                                   FLAGS.batch_size, (cur_batch + 1) *
                                   FLAGS.batch_size)
                ]
                batch_inputs, batch_seq_len, batch_labels = train_feeder.input_index_generate_batch(
                    indexs)
                transfer_train_batch_feed = {
                    g.original_pic: batch_inputs,
                    g.seq_len: np.array([Flage_width] * batch_inputs.shape[0]),
                    g.labels: batch_labels
                }
                summary_str, batch_cost, all_step, _ = sess.run(
                    [g.merged_summay, g.cost, g.global_step, g.optimizer],
                    transfer_train_batch_feed)
                # Presumably g.cost is a per-sample mean; scale it to a batch
                # total (TODO confirm).
                train_cost += batch_cost * FLAGS.batch_size

                if all_step % FLAGS.save_steps == 0:
                    print("**********save checkpoint********** all_step:",
                          all_step)
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.makedirs(FLAGS.checkpoint_dir)
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=all_step)

                if all_step % FLAGS.validation_steps == 0:
                    print("**********CrossValidation********** all_step:",
                          all_step)
                    dense_decoded, lastbatch_err, lr = sess.run(
                        [g.dense_decoded, g.lerr, g.learning_rate], val_feed)
                    acc = utils.accuracy_calculation(val_feeder.labels,
                                                     dense_decoded,
                                                     ignore_value=-1,
                                                     isPrint=True)
                    avg_train_cost = train_cost / (
                        (cur_batch + 1) * FLAGS.batch_size)
                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{}  all_step==={}, Epoch {}/{}, accuracy = {:.3f},avg_train_cost = {:.3f}, lastbatch_err = {:.3f},lr={:.8f}\n"
                    line = log.format(now.month, now.day, now.hour, now.minute,
                                      now.second, all_step, cur_epoch + 1,
                                      FLAGS.num_epochs, acc, avg_train_cost,
                                      lastbatch_err, lr)
                    print(line)
                    if Flag_Isserver:
                        acc_log_path = '../log/acc/acc.txt'
                        acc_log_dir = os.path.dirname(acc_log_path)
                        if not os.path.isdir(acc_log_dir):
                            os.makedirs(acc_log_dir)
                        # ``with`` closes the file even if the write fails.
                        with open(acc_log_path, mode="a") as f:
                            f.write(line)
コード例 #9
0
def train(train_dir=None,
          val_dir=None,
          train_text_dir=None,
          val_text_dir=None):
    """Train the OCR graph with per-epoch shuffling and periodic validation.

    Builds ``model.Graph(is_training=True)``, loads training and validation
    data through ``utils.DataIterator2`` and runs ``FLAGS.num_epochs`` epochs
    of mini-batch optimisation. Every ``FLAGS.save_steps`` global steps a
    checkpoint is written to ``FLAGS.checkpoint_dir``. Every
    ``FLAGS.validation_steps`` global steps the whole validation set is
    decoded in one ``sess.run`` and accuracy/cost statistics are printed.
    When ``Flag_Isserver`` is set, the statistics are also appended to
    ``../log/acc/acc.txt``.

    Args:
        train_dir: image directory passed to ``DataIterator2`` for training.
        val_dir: image directory passed to ``DataIterator2`` for validation.
        train_text_dir: text/label directory for the training data.
        val_text_dir: text/label directory for the validation data.
    """
    g = model.Graph(is_training=True)
    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator2(data_dir=train_dir,
                                       text_dir=train_text_dir)
    print('get image: ', train_feeder.size)
    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator2(data_dir=val_dir, text_dir=val_text_dir)
    print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size
    # Samples that do not fill a whole final batch are skipped each epoch.
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)

    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(graph=g.graph, config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=3)
        # Make the graph read-only so no ops are accidentally added in the loop.
        g.graph.finalize()
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)
        if FLAGS.restore:
            print("restore is true")
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))

        print(
            '=============================begin training============================='
        )
        # The whole validation set is loaded once and fed as a single batch.
        val_inputs, val_seq_len, val_labels = \
            val_feeder.input_index_generate_batch()
        val_feed = {
            g.inputs: val_inputs,
            g.labels: val_labels,
            g.seq_len: np.array([g.cnn_time] * val_inputs.shape[0])
        }
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()
            for cur_batch in range(num_batches_per_epoch):
                # batch_time is reset every iteration, so this reports the
                # duration of the single preceding batch.
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time',
                          time.time() - batch_time)
                batch_time = time.time()
                indexs = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(cur_batch * FLAGS.batch_size,
                                   (cur_batch + 1) * FLAGS.batch_size)
                ]
                batch_inputs, batch_seq_len, batch_labels = \
                    train_feeder.input_index_generate_batch(indexs)
                feed = {
                    g.inputs: batch_inputs,
                    g.labels: batch_labels,
                    g.seq_len: np.array([g.cnn_time] * batch_inputs.shape[0])
                }

                summary_str, batch_cost, step, _ = sess.run(
                    [g.merged_summay, g.cost, g.global_step, g.optimizer],
                    feed)
                # Accumulate the summed (not mean) cost for the epoch average.
                train_cost += batch_cost * FLAGS.batch_size
                train_writer.add_summary(summary_str, step)

                # Save a checkpoint.
                if step % FLAGS.save_steps == 0:
                    print("save checkpoint", step)
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # logging uses %-style lazy formatting; the previous
                    # '{0}' + format(step) form produced a logging error.
                    logger.info('save the checkpoint of %s', step)
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                # Validation.
                if step % FLAGS.validation_steps == 0:
                    dense_decoded, lastbatch_err, lr = sess.run(
                        [g.dense_decoded, g.lerr, g.learning_rate], val_feed)
                    acc = utils.accuracy_calculation(val_feeder.labels,
                                                     dense_decoded,
                                                     ignore_value=-1,
                                                     isPrint=True)
                    avg_train_cost = train_cost / (
                        (cur_batch + 1) * FLAGS.batch_size)
                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{}  step==={}, Epoch {}/{}, accuracy = {:.3f},avg_train_cost = {:.3f}, lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}\n"
                    # Format once so the console and the log file agree.
                    log_line = log.format(now.month, now.day, now.hour,
                                          now.minute, now.second, step,
                                          cur_epoch + 1, FLAGS.num_epochs,
                                          acc, avg_train_cost, lastbatch_err,
                                          time.time() - start_time, lr)
                    print(log_line)
                    if Flag_Isserver:
                        with open('../log/acc/acc.txt', mode="a") as f:
                            f.write(log_line)
コード例 #10
0
ファイル: main.py プロジェクト: luoqingyu/new-vgg
def train(train_dir=None, val_dir=None, mode='train'):
    """Train ``cnn_lstm_otc_ocr.LSTMOCR`` from ``ReadData`` input pipelines.

    Builds the LSTM OCR graph and wraps the image/label lists exposed by
    ``utils.DataIterator`` in ``ReadData.ReadData`` iterators. It then runs
    ``FLAGS.num_epochs`` epochs of training in a session created under
    ``tf.device('/cpu:0')``. Every ``FLAGS.validation_steps`` global steps,
    ``num_batches_per_epoch_val`` batches are pulled from the validation
    iterator. The mean accuracy, the best accuracy so far and the average
    train cost are then printed.

    Note: ``train_dir`` and ``val_dir`` are accepted for interface
    compatibility only. Data is read from ``FLAGS.train_dir`` and
    ``FLAGS.val_dir``.

    Args:
        train_dir: unused (see note).
        val_dir: unused (see note).
        mode: mode string forwarded to ``LSTMOCR``.
    """
    if FLAGS.model == 'lstm':
        model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    else:
        print("no such model")
        sys.exit()

    # Build the graph.
    model.build_graph()

    # ---- training data ----
    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir=FLAGS.train_dir, istrain=True)
    print('get image  data size: ', train_feeder.size)
    filename = train_feeder.image
    label = train_feeder.labels
    print(len(filename))
    train_data = ReadData.ReadData(filename, label)

    # ---- validation data ----
    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir=FLAGS.val_dir, istrain=False)
    filename1 = val_feeder.image
    label1 = val_feeder.labels
    test_data = ReadData.ReadData(filename1, label1)
    print('val get image: ', val_feeder.size)

    # ---- batch counts (partial trailing batches are dropped) ----
    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)
    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)

    with tf.device('/cpu:0'):
        config = tf.ConfigProto(allow_soft_placement=True)

        with tf.Session(config=config) as sess:
            # Initialise the data iterators, then replace each ReadData
            # wrapper with whatever its get_nex_batch() returns; that value is
            # what sess.run() is called on below.
            train_data.init_itetator(sess)
            test_data.init_itetator(sess)
            train_data = train_data.get_nex_batch()
            test_data = test_data.get_nex_batch()
            # Initialise global variables.
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)

            # Optionally resume from the latest checkpoint.
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # the global_step will restore as well
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))
                else:
                    print("No checkpoint")

            print(
                '=============================begin training============================='
            )
            accuracy_res = []
            epoch_res = []
            tmp_max = 0  # best validation accuracy seen so far
            tmp_epoch = 0  # 0-based epoch index at which tmp_max was reached

            for cur_epoch in range(FLAGS.num_epochs):
                train_cost = 0
                batch_time = time.time()
                for cur_batch in range(num_batches_per_epoch):
                    batch_inputs, batch_labels = sess.run(train_data)
                    new_batch_labels = utils.sparse_tuple_from_label(
                        batch_labels.tolist())
                    batch_seq_len = np.asarray(
                        [FLAGS.max_stepsize for _ in batch_inputs],
                        dtype=np.int64)
                    feed = {
                        model.inputs: batch_inputs,
                        model.labels: new_batch_labels,
                        model.seq_len: batch_seq_len
                    }

                    summary_str, batch_cost, step, _ = \
                        sess.run([model.merged_summay, model.cost, model.global_step,
                                  model.train_op], feed)

                    # Accumulate the summed (not mean) cost for the average.
                    train_cost += batch_cost * FLAGS.batch_size

                    # Checkpoint when step % save_steps == 1 (step 1,
                    # save_steps + 1, ...).
                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.mkdir(FLAGS.checkpoint_dir)
                        # %-style lazy formatting; the previous '{0}' +
                        # format(step) form produced a logging error.
                        logger.info('save the checkpoint of %s', step)
                        saver.save(sess,
                                   os.path.join(FLAGS.checkpoint_dir,
                                                'ocr-model'),
                                   global_step=step)
                    # batch_time is only reset here, so this reports the time
                    # taken since the previous report.
                    if cur_batch % 100 == 1:
                        print('batch', cur_batch, ': time',
                              time.time() - batch_time, 'loss', batch_cost)
                        batch_time = time.time()

                    # Validation.
                    if step % FLAGS.validation_steps == 0:
                        validation_start_time = time.time()
                        acc_batch_total = 0
                        lr = 0
                        for _ in range(num_batches_per_epoch_val):
                            val_inputs, val_labels = sess.run(test_data)
                            new_val_labels = utils.sparse_tuple_from_label(
                                val_labels.tolist())
                            val_seq_len = np.asarray(
                                [FLAGS.max_stepsize for _ in val_inputs],
                                dtype=np.int64)
                            val_feed = {
                                model.inputs: val_inputs,
                                model.labels: new_val_labels,
                                model.seq_len: val_seq_len
                            }

                            dense_decoded, lr = \
                                sess.run([model.dense_decoded, model.lrn_rate],
                                         val_feed)

                            acc = utils.accuracy_calculation(
                                val_labels.tolist(),
                                dense_decoded,
                                ignore_value=-1,
                                isPrint=True)
                            acc_batch_total += acc
                        # Mean of per-batch accuracies, normalised by the full
                        # sample count (dropped trailing samples count as 0).
                        accuracy = (acc_batch_total *
                                    FLAGS.batch_size) / num_val_samples
                        accuracy_res.append(accuracy)
                        epoch_res.append(cur_epoch)
                        if accuracy > tmp_max:
                            tmp_max = accuracy
                            tmp_epoch = cur_epoch
                        avg_train_cost = train_cost / (
                            (cur_batch + 1) * FLAGS.batch_size)

                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "max_accuracy = {:.3f},max_Epoch {},accuracy = {:.3f},acc_batch_total = {:.3f},avg_train_cost = {:.3f}, " \
                              " time = {:.3f},lr={:.8f}"

                        print(
                            log.format(now.month, now.day, now.hour,
                                       now.minute, now.second, cur_epoch + 1,
                                       FLAGS.num_epochs, tmp_max, tmp_epoch,
                                       accuracy, acc_batch_total,
                                       avg_train_cost,
                                       time.time() - validation_start_time,
                                       lr))
コード例 #11
0
ファイル: main.py プロジェクト: 980044579/chinese_ocr
def train(train_dir=None, val_dir=None, mode='train'):
    """Train ``cnn_lstm_otc_ocr.LSTMOCR`` with periodic full validation passes.

    Loads up to 4,000,000 training and 40,000 validation samples through
    ``utils.DataIterator``. It then runs ``FLAGS.num_epochs`` shuffled epochs
    in a session created under ``tf.device('/cpu:0')``. Every
    ``FLAGS.validation_steps`` global steps it evaluates
    ``num_batches_per_epoch_val`` validation batches. For each batch it
    computes both ``utils.accuracy_calculation`` and
    ``utils.accuracy_calculation_single``. The best ``accuracy_per`` value so
    far is tracked and printed.

    Args:
        train_dir: training data directory passed to ``DataIterator``.
        val_dir: validation data directory passed to ``DataIterator``.
        mode: mode string forwarded to ``LSTMOCR``.
    """
    if FLAGS.model == 'lstm':
        model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    else:
        print("no such model")
        sys.exit()

    # Build the graph.
    model.build_graph()
    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir=train_dir, num=4000000)
    print('get image: ', train_feeder.size)

    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir=val_dir, num=40000)
    print('get image: ', val_feeder.size)

    # Batch counts per epoch (partial trailing batches are dropped).
    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)

    # The validation order is shuffled once and reused for every evaluation.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    with tf.device('/cpu:0'):
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)

            # Optionally resume from the latest checkpoint.
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # the global_step will restore as well
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))
                else:
                    print("No checkpoint")

            print(
                '=============================begin training============================='
            )
            accuracy_res = []
            accuracy_per_res = []
            epoch_res = []
            tmp_max = 0  # best accuracy_per seen so far
            tmp_epoch = 0  # 0-based epoch index at which tmp_max was reached
            for cur_epoch in range(FLAGS.num_epochs):
                shuffle_idx = np.random.permutation(num_train_samples)
                train_cost = 0
                start_time = time.time()
                batch_time = time.time()

                for cur_batch in range(num_batches_per_epoch):
                    # batch_time is reset every iteration, so this reports
                    # the duration of the single preceding batch.
                    if (cur_batch + 1) % 100 == 0:
                        print('batch', cur_batch, ': time',
                              time.time() - batch_time)
                    batch_time = time.time()

                    # Sample indices for this batch.
                    indexs = [
                        shuffle_idx[i % num_train_samples]
                        for i in range(cur_batch * FLAGS.batch_size,
                                       (cur_batch + 1) * FLAGS.batch_size)
                    ]

                    batch_inputs, batch_seq_len, batch_labels = \
                        train_feeder.input_index_generate_batch(indexs)
                    feed = {
                        model.inputs: batch_inputs,
                        model.labels: batch_labels,
                        model.seq_len: batch_seq_len
                    }

                    summary_str, batch_cost, step, _ = \
                        sess.run([model.merged_summay, model.cost, model.global_step,
                                  model.train_op], feed)
                    # Accumulate the summed (not mean) cost for the average.
                    train_cost += batch_cost * FLAGS.batch_size

                    train_writer.add_summary(summary_str, step)

                    # Checkpoint when step % save_steps == 1.
                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.mkdir(FLAGS.checkpoint_dir)
                        # %-style lazy formatting; the previous '{0}' +
                        # format(step) form produced a logging error.
                        logger.info('save the checkpoint of %s', step)
                        saver.save(sess,
                                   os.path.join(FLAGS.checkpoint_dir,
                                                'ocr-model'),
                                   global_step=step)

                    # Validation.
                    if step % FLAGS.validation_steps == 0:
                        acc_batch_total = 0
                        acc_per_batch_total = 0
                        lr = 0
                        for j in range(num_batches_per_epoch_val):
                            indexs_val = [
                                shuffle_idx_val[i % num_val_samples]
                                for i in range(j * FLAGS.batch_size,
                                               (j + 1) * FLAGS.batch_size)
                            ]
                            val_inputs, val_seq_len, val_labels = \
                                val_feeder.input_index_generate_batch(indexs_val)
                            val_feed = {
                                model.inputs: val_inputs,
                                model.labels: val_labels,
                                model.seq_len: val_seq_len
                            }

                            dense_decoded, lr = \
                                sess.run([model.dense_decoded, model.lrn_rate],
                                         val_feed)

                            # Compare the decoded results with the original labels.
                            ori_labels = val_feeder.the_label(indexs_val)
                            acc = utils.accuracy_calculation(ori_labels,
                                                             dense_decoded,
                                                             ignore_value=-1,
                                                             isPrint=True)

                            acc_per = utils.accuracy_calculation_single(
                                ori_labels,
                                dense_decoded,
                                ignore_value=-1,
                                isPrint=True)

                            acc_per_batch_total += acc_per
                            acc_batch_total += acc

                        accuracy_per = (acc_per_batch_total *
                                        FLAGS.batch_size) / num_val_samples
                        accuracy = (acc_batch_total *
                                    FLAGS.batch_size) / num_val_samples

                        accuracy_per_res.append(accuracy_per)
                        accuracy_res.append(accuracy)

                        epoch_res.append(cur_epoch)
                        # Track the same metric that is compared and printed
                        # as "accuracy". Previously accuracy_per was compared
                        # but accuracy was stored.
                        if accuracy_per > tmp_max:
                            tmp_max = accuracy_per
                            tmp_epoch = cur_epoch

                        avg_train_cost = train_cost / (
                            (cur_batch + 1) * FLAGS.batch_size)

                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "max_accuracy = {:.3f},max_Epoch {},accuracy = {:.3f},acc_batch_total = {:.3f},avg_train_cost = {:.3f}, " \
                              " time = {:.3f},lr={:.8f}"
                        print(
                            log.format(now.month, now.day, now.hour,
                                       now.minute, now.second, cur_epoch + 1,
                                       FLAGS.num_epochs, tmp_max, tmp_epoch,
                                       accuracy_per, acc_batch_total,
                                       avg_train_cost,
                                       time.time() - start_time, lr))
コード例 #12
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train ``mdlstm_ctc_ocr.LSTMOCR`` with shuffled epochs and validation.

    Loads training and validation data through ``utils.DataIterator`` and
    runs ``FLAGS.num_epochs`` epochs of mini-batch training. Every
    ``FLAGS.validation_steps`` global steps it evaluates
    ``num_batches_per_epoch_val`` validation batches (in an order shuffled
    once up front). It then prints the mean accuracy and the average train
    cost.

    Args:
        train_dir: training data directory passed to ``DataIterator``.
        val_dir: validation data directory passed to ``DataIterator``.
        mode: mode string forwarded to ``LSTMOCR``.
    """
    model = mdlstm_ctc_ocr.LSTMOCR(mode)
    model.build_graph()

    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('get image: ', train_feeder.size)

    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('get image: ', val_feeder.size)

    # Partial trailing batches are dropped.
    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)
    shuffle_idx_val = np.random.permutation(num_val_samples)

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step will restore as well
                saver.restore(sess, ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))

        print('=============================begin training=============================')
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()

            for cur_batch in range(num_batches_per_epoch):
                # batch_time is reset every iteration, so this reports the
                # duration of the single preceding batch.
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time', time.time() - batch_time)
                batch_time = time.time()
                indexs = [shuffle_idx[i % num_train_samples]
                          for i in range(cur_batch * FLAGS.batch_size,
                                         (cur_batch + 1) * FLAGS.batch_size)]
                batch_inputs, batch_seq_len, batch_labels = \
                    train_feeder.input_index_generate_batch(indexs)
                feed = {model.inputs: batch_inputs,
                        model.labels: batch_labels,
                        model.seq_len: batch_seq_len}
                # NOTE(review): this is a plain Python attribute set after
                # build_graph(); presumably it does not change the built graph
                # unless LSTMOCR reads it later. TODO confirm.
                model.is_training = True

                summary_str, batch_cost, step, _ = sess.run(
                    [model.merged_summay, model.cost, model.global_step,
                     model.train_op], feed)
                # Accumulate the summed (not mean) cost for the average.
                train_cost += batch_cost * FLAGS.batch_size
                train_writer.add_summary(summary_str, step)

                # Checkpoint when step % save_steps == 1.
                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # %-style lazy formatting; the previous '{0}' +
                    # format(step) form produced a logging error.
                    logger.info('save the checkpoint of %s', step)
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                # Validation.
                if step % FLAGS.validation_steps == 0:
                    acc_batch_total = 0
                    lr = 0
                    for j in range(num_batches_per_epoch_val):
                        indexs_val = [shuffle_idx_val[i % num_val_samples] for i in
                                      range(j * FLAGS.batch_size, (j + 1) * FLAGS.batch_size)]
                        val_inputs, val_seq_len, val_labels = \
                            val_feeder.input_index_generate_batch(indexs_val)
                        val_feed = {model.inputs: val_inputs,
                                    model.labels: val_labels,
                                    model.seq_len: val_seq_len}
                        # See the NOTE on the training-side assignment above.
                        model.is_training = False
                        dense_decoded, lr = \
                            sess.run([model.dense_decoded, model.lrn_rate],
                                     val_feed)

                        # Compare the decoded results with the original labels.
                        ori_labels = val_feeder.the_label(indexs_val)
                        acc = utils.accuracy_calculation(ori_labels, dense_decoded,
                                                         ignore_value=-1, isPrint=True)
                        acc_batch_total += acc

                    accuracy = (acc_batch_total * FLAGS.batch_size) / num_val_samples

                    avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "time = {:.3f},lr={:.8f}"
                    print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                                     cur_epoch + 1, FLAGS.num_epochs, accuracy, avg_train_cost,
                                     time.time() - start_time, lr))
コード例 #13
0
def train(train_dir=None,val_dir=None,initial_learning_rate=None,num_hidden=None,num_classes=None,hparam=None):
    """Train ``Graph()`` and report train/validation accuracy periodically.

    Loads training and validation data through ``utils.DataIterator``. It
    then runs ``FLAGS.num_epochs`` shuffled epochs, checkpointing when
    ``step % FLAGS.save_steps == 1``. Every ``FLAGS.validation_steps``
    global steps it decodes the whole validation set in one ``sess.run``
    and prints validation accuracy next to the accuracy on the most recent
    training batch.

    Args:
        train_dir: training data directory passed to ``DataIterator``.
        val_dir: validation data directory passed to ``DataIterator``.
        initial_learning_rate: unused; kept for interface compatibility.
        num_hidden: unused; kept for interface compatibility.
        num_classes: unused; kept for interface compatibility.
        hparam: suffix appended to ``FLAGS.log_dir`` for the summary writer.
            ``None`` (the default) means no suffix. Previously ``None``
            raised a TypeError on string concatenation.
    """
    g = Graph()
    print('loading train data, please wait---------------------')
    train_feeder=utils.DataIterator(data_dir=train_dir)
    print('get image: ',train_feeder.size)

    print('loading validation data, please wait---------------------')
    val_feeder=utils.DataIterator(data_dir=val_dir)
    print('get image: ',val_feeder.size)

    num_train_samples = train_feeder.size
    # Partial trailing batches are dropped.
    num_batches_per_epoch = int(num_train_samples/FLAGS.batch_size)

    config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=False)
    with tf.Session(graph=g.graph,config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(),max_to_keep=100)
        # Make the graph read-only so no ops are accidentally added in the loop.
        g.graph.finalize()
        train_writer=tf.summary.FileWriter(FLAGS.log_dir+(hparam or ''),sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step will restore as well
                saver.restore(sess,ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))

        print('=============================begin training=============================')
        # The whole validation set is loaded once and fed as a single batch.
        # The original comment says seq_len is a time-series length of 120.
        val_inputs,val_seq_len,val_labels,_=val_feeder.input_index_generate_batch()

        val_feed={g.inputs: val_inputs,
                  g.labels: val_labels,
                  g.seq_len: val_seq_len}
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx=np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            for cur_batch in range(num_batches_per_epoch):
                indexs = [shuffle_idx[i%num_train_samples] for i in range(cur_batch*FLAGS.batch_size,(cur_batch+1)*FLAGS.batch_size)]
                # label_batch holds the raw labels used for train accuracy below.
                batch_inputs,batch_seq_len,batch_labels,label_batch=train_feeder.input_index_generate_batch(indexs)
                feed={g.inputs: batch_inputs,
                        g.labels:batch_labels,
                        g.seq_len:batch_seq_len}

                train_dense_decoded,summary_str, batch_cost,step,_ = sess.run([g.dense_decoded,g.merged_summay,g.CTCloss,g.global_step,g.optimizer],feed)
                # Summed (not mean) cost for the epoch.
                train_cost+=batch_cost*FLAGS.batch_size
                train_writer.add_summary(summary_str,step)
                # Checkpoint when step % save_steps == 1.
                if step%FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    saver.save(sess,os.path.join(FLAGS.checkpoint_dir,'ocr-model'),global_step=step)
                if step%FLAGS.validation_steps == 0:
                    dense_decoded,validation_length_err,learningRate = sess.run([g.dense_decoded,g.lengthOferr,
                        g.learning_rate],val_feed)
                    valid_acc = utils.accuracy_calculation(val_feeder.labels,dense_decoded,ignore_value=-1,isPrint=True)
                    # Train accuracy is measured on the current batch only.
                    train_acc = utils.accuracy_calculation(label_batch,train_dense_decoded,ignore_value=-1,isPrint=True)
                    now = datetime.datetime.now()
                    log = "*{}/{} {}:{}:{} Epoch {}/{}, accOfvalidation = {:.3f},train_accuracy={:.3f}, time = {:.3f},learningRate={:.8f}"
                    print(log.format(now.month,now.day,now.hour,now.minute,now.second,cur_epoch+1,FLAGS.num_epochs,valid_acc,train_acc,time.time()-start_time,learningRate))
コード例 #14
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the CNN+LSTM+CTC OCR model using generator-based data feeders.

    Training batches are pulled from ``data_prep.input_batch_generator('train', ...)``.
    A checkpoint is written whenever ``step % FLAGS.save_steps == 1``. Every
    ``FLAGS.validation_steps`` global steps, a single validation batch is decoded
    and its accuracy is printed together with the cost of the latest training batch.

    Args:
        train_dir: Not used by this variant; kept for interface compatibility.
        val_dir: Not used by this variant; kept for interface compatibility.
        mode: Mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    print('loading train data, please wait---------------------')
    train_feeder, num_train_samples = data_prep.input_batch_generator('train', is_training=True, batch_size = FLAGS.batch_size)
    print('get image: ', num_train_samples)

    print('loading validation data, please wait---------------------')
    val_feeder, num_val_samples = data_prep.input_batch_generator('val', is_training=False, batch_size = FLAGS.batch_size * 2)
    print('get image: ', num_val_samples)

    # Round up so that a trailing partial batch still counts as a step of the epoch.
    num_batches_per_epoch = int(math.ceil(num_train_samples / float(FLAGS.batch_size)))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step is restored as well
                saver.restore(sess, ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))

        print('=============================begin training=============================')
        for cur_epoch in range(FLAGS.num_epochs):
            start_time = time.time()
            batch_time = time.time()

            # the training part
            for cur_batch in range(num_batches_per_epoch):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time', time.time() - batch_time)
                batch_time = time.time()
                # The third element yielded by the generator is not needed for training.
                batch_inputs, batch_labels, _ = next(train_feeder)
                feed = {model.inputs: batch_inputs,
                        model.labels: batch_labels}

                summary_str, batch_cost, step, _ = \
                    sess.run([model.merged_summay, model.cost, model.global_step,
                              model.train_op], feed)

                train_writer.add_summary(summary_str, step)

                # save the checkpoint
                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # Format eagerly. Passing format(step) as a separate logging
                    # argument never fills the '{0}' placeholder.
                    logger.info('save the checkpoint of{0}'.format(step))
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                # do validation
                if step % FLAGS.validation_steps == 0:
                    # Only one validation batch is evaluated here, not the whole set.
                    val_inputs, val_labels, ori_labels = next(val_feeder)
                    val_feed = {model.inputs: val_inputs,
                                model.labels: val_labels}

                    dense_decoded, lr = \
                        sess.run([model.dense_decoded, model.lrn_rate],
                                 val_feed)

                    # print the decode result
                    accuracy = utils.accuracy_calculation(ori_labels, dense_decoded,
                                                          ignore_value=-1, isPrint=True)

                    now = datetime.datetime.now()
                    # "train_cost" reports the cost of the most recent training batch only.
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.5f},train_cost = {:.5f}, " \
                          "time = {:.3f},lr={:.8f}"
                    print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                                     cur_epoch + 1, FLAGS.num_epochs, accuracy, batch_cost,
                                     time.time() - start_time, lr))
コード例 #15
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the CNN+LSTM+CTC OCR model and periodically validate on the full validation set.

    Every ``FLAGS.validation_steps`` global steps, all complete validation batches
    (of ``FLAGS.batch_size`` samples) are decoded. The mean per-batch accuracy is
    printed with the running average training cost of the current epoch.

    Args:
        train_dir: Directory passed to ``utils.DataIterator`` for training samples.
        val_dir: Directory passed to ``utils.DataIterator`` for validation samples.
        mode: Mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    print('loading train data')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('size: ', train_feeder.size)

    print('loading validation data')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('size: {}\n'.format(val_feeder.size))

    num_train_samples = train_feeder.size
    # Incomplete trailing batches are dropped.
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)
    shuffle_idx_val = np.random.permutation(num_val_samples)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step is restored as well
                saver.restore(sess, ckpt)
                print('restore from checkpoint{0}'.format(ckpt))

        print('=============================begin training=============================')
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()

            # the training part
            for cur_batch in range(num_batches_per_epoch):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time', time.time() - batch_time)
                batch_time = time.time()
                indexs = [shuffle_idx[i % num_train_samples] for i in
                          range(cur_batch * FLAGS.batch_size, (cur_batch + 1) * FLAGS.batch_size)]
                batch_inputs, _, batch_labels = \
                    train_feeder.input_index_generate_batch(indexs)
                feed = {model.inputs: batch_inputs,
                        model.labels: batch_labels}

                summary_str, batch_cost, step, _ = \
                    sess.run([model.merged_summay, model.cost, model.global_step, model.train_op], feed)
                # batch_cost is a per-batch mean; scale it back to a sum over samples.
                train_cost += batch_cost * FLAGS.batch_size

                train_writer.add_summary(summary_str, step)

                # save the checkpoint
                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # Format eagerly. Passing format(step) as a separate logging
                    # argument never fills the '{0}' placeholder.
                    logger.info('save checkpoint at step {0}'.format(step))
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'), global_step=step)

                # do validation
                if step % FLAGS.validation_steps == 0:
                    acc_batch_total = 0
                    lastbatch_err = 0
                    lr = 0
                    for j in range(num_batches_per_epoch_val):
                        indexs_val = [shuffle_idx_val[i % num_val_samples] for i in
                                      range(j * FLAGS.batch_size, (j + 1) * FLAGS.batch_size)]
                        val_inputs, _, val_labels = \
                            val_feeder.input_index_generate_batch(indexs_val)
                        val_feed = {model.inputs: val_inputs,
                                    model.labels: val_labels}

                        dense_decoded, lastbatch_err, lr = \
                            sess.run([model.dense_decoded, model.cost, model.lrn_rate],
                                     val_feed)

                        # print the decode result
                        ori_labels = val_feeder.the_label(indexs_val)
                        acc = utils.accuracy_calculation(ori_labels, dense_decoded,
                                                         ignore_value=-1, isPrint=True)
                        acc_batch_total += acc

                    # Average over the batches actually evaluated. The trailing
                    # num_val_samples % batch_size samples are never decoded, so they
                    # must not count toward the denominator. Assumes
                    # accuracy_calculation returns a per-batch fraction; confirm.
                    if num_batches_per_epoch_val:
                        accuracy = acc_batch_total / float(num_batches_per_epoch_val)
                    else:
                        accuracy = 0.0

                    avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                                     cur_epoch + 1, FLAGS.num_epochs, accuracy, avg_train_cost,
                                     lastbatch_err, time.time() - start_time, lr))
コード例 #16
0
def _list_sample_paths(paths):
    """Return ``os.path.join(d, name)`` for every entry of each directory in *paths*.

    *paths* may be a single directory path or a list of them. Entries are returned
    directory by directory in ``os.listdir`` order, unshuffled.
    """
    dirs = paths if isinstance(paths, list) else [paths]
    return [os.path.join(d, name) for d in dirs for name in os.listdir(d)]


def train_process(mode=RunMode.Trains):
    """Train the configurable CNN+RNN+CTC model until the end criteria are met.

    Samples come from TFRecords when ``TRAINS_USE_TFRECORDS`` is set, otherwise from
    the files under ``TRAINS_PATH``/``TEST_PATH``. When ``HAS_TEST_SET`` is false,
    the first ``TEST_SET_NUM`` shuffled training files become the test set.

    Training repeats epoch after epoch until three conditions all hold: validation
    accuracy >= ``TRAINS_END_ACC``, epoch count >= ``TRAINS_END_EPOCHS``, and
    average training cost <= ``TRAINS_END_COST``. The graph is then compiled via
    ``compile_graph``.

    Args:
        mode: Run mode forwarded to ``framework.GraphOCR``.
    """
    model = framework.GraphOCR(mode, NETWORK_MAP[NEU_CNN],
                               NETWORK_MAP[NEU_RECURRENT])
    model.build_graph()

    print('Loading Trains DataSet...')
    train_feeder = utils.DataIterator(mode=RunMode.Trains)
    if TRAINS_USE_TFRECORDS:
        train_feeder.read_sample_from_tfrecords(TRAINS_PATH)
        print('Loading Test DataSet...')
        test_feeder = utils.DataIterator(mode=RunMode.Test)
        test_feeder.read_sample_from_tfrecords(TEST_PATH)
    else:
        origin_list = _list_sample_paths(TRAINS_PATH)
        random.shuffle(origin_list)
        if not HAS_TEST_SET:
            # Carve the test set out of the shuffled training files.
            test_list = origin_list[:TEST_SET_NUM]
            trains_list = origin_list[TEST_SET_NUM:]
        else:
            test_list = _list_sample_paths(TEST_PATH)
            random.shuffle(test_list)
            trains_list = origin_list
        train_feeder.read_sample_from_files(trains_list)
        print('Loading Test DataSet...')
        test_feeder = utils.DataIterator(mode=RunMode.Test)
        test_feeder.read_sample_from_files(test_list)

    print('Total {} Trains DataSets'.format(train_feeder.size))
    print('Total {} Test DataSets'.format(test_feeder.size))
    if test_feeder.size >= train_feeder.size:
        exception(
            "The number of training sets cannot be less than the test set.", )

    num_train_samples = train_feeder.size
    num_test_samples = test_feeder.size
    if num_test_samples < TEST_BATCH_SIZE:
        exception(
            "The number of test sets cannot be less than the test batch size.",
            ConfigException.INSUFFICIENT_SAMPLE)
    # Incomplete trailing batches are dropped.
    # NOTE(review): if num_train_samples < BATCH_SIZE this is 0 and the loop
    # below never trains; confirm whether that should be rejected up front.
    num_batches_per_epoch = int(num_train_samples / BATCH_SIZE)

    config = tf.ConfigProto(
        # allow_soft_placement=True,
        log_device_placement=False,
        gpu_options=tf.GPUOptions(
            allocator_type='BFC',
            allow_growth=True,  # it will cause fragmentation.
            per_process_gpu_memory_fraction=GPU_USAGE))
    accuracy = 0
    epoch_count = 1

    with tf.Session(config=config) as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=2)
        train_writer = tf.summary.FileWriter('logs', sess.graph)
        try:
            saver.restore(sess, tf.train.latest_checkpoint(MODEL_PATH))
        except ValueError:
            # Saver.restore raises ValueError when save_path is None, i.e. no
            # checkpoint exists yet. In that case, start from scratch.
            pass
        print('Start training...')
        # Measured across all epochs, so the final 'Total Time' is really the total
        # (a per-epoch start time would only report the last epoch).
        total_start_time = time.time()

        while True:
            shuffle_trains_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            _avg_train_cost = 0
            for cur_batch in range(num_batches_per_epoch):
                batch_time = time.time()
                index_list = [
                    shuffle_trains_idx[i % num_train_samples]
                    for i in range(cur_batch * BATCH_SIZE, (cur_batch + 1) *
                                   BATCH_SIZE)
                ]
                if TRAINS_USE_TFRECORDS:
                    batch_inputs, batch_seq_len, batch_labels = train_feeder.generate_batch_by_tfrecords(
                        sess)
                else:
                    batch_inputs, batch_seq_len, batch_labels = train_feeder.generate_batch_by_files(
                        index_list)

                feed = {
                    model.inputs: batch_inputs,
                    model.labels: batch_labels,
                }

                summary_str, batch_cost, step, _ = sess.run(
                    [model.merged_summary, model.cost, model.global_step,
                     model.train_op],
                    feed_dict=feed)
                # batch_cost is a per-batch mean; scale it back to a sum over samples.
                train_cost += batch_cost * BATCH_SIZE
                avg_train_cost = train_cost / ((cur_batch + 1) * BATCH_SIZE)
                # Kept outside the inner loop for the end-of-epoch stop check.
                _avg_train_cost = avg_train_cost
                train_writer.add_summary(summary_str, step)

                if step % 100 == 0 and step != 0:
                    print('Step: {} Time: {:.3f}, Cost = {:.5f}'.format(
                        step,
                        time.time() - batch_time, avg_train_cost))

                if step % TRAINS_SAVE_STEPS == 0 and step != 0:
                    saver.save(sess, SAVE_MODEL, global_step=step)
                    # Format eagerly. Passing format(step) as a separate logging
                    # argument never fills the '{0}' placeholder.
                    logger.info('save checkpoint at step {0}'.format(step))

                if step % TRAINS_VALIDATION_STEPS == 0 and step != 0:
                    shuffle_test_idx = np.random.permutation(num_test_samples)
                    batch_time = time.time()
                    # A TEST_BATCH_SIZE window of a fresh permutation, i.e. a random
                    # test batch (the cur_batch offset just shifts the window).
                    index_test = [
                        shuffle_test_idx[i % num_test_samples]
                        for i in range(cur_batch *
                                       TEST_BATCH_SIZE, (cur_batch + 1) *
                                       TEST_BATCH_SIZE)
                    ]
                    if TRAINS_USE_TFRECORDS:
                        test_inputs, batch_seq_len, test_labels = test_feeder.generate_batch_by_tfrecords(
                            sess)
                    else:
                        test_inputs, batch_seq_len, test_labels = test_feeder.generate_batch_by_files(
                            index_test)

                    val_feed = {
                        model.inputs: test_inputs,
                        model.labels: test_labels
                    }
                    dense_decoded, lr = sess.run(
                        [model.dense_decoded, model.lrn_rate],
                        feed_dict=val_feed)
                    accuracy = utils.accuracy_calculation(
                        test_feeder.labels(
                            None if TRAINS_USE_TFRECORDS else index_test),
                        dense_decoded,
                        ignore_value=[0, -1],
                    )
                    log = "Epoch: {}, Step: {}, Accuracy = {:.4f}, Cost = {:.5f}, " \
                          "Time = {:.3f}, LearningRate: {}"
                    print(
                        log.format(epoch_count, step, accuracy, avg_train_cost,
                                   time.time() - batch_time, lr))

                    if accuracy >= TRAINS_END_ACC and epoch_count >= TRAINS_END_EPOCHS and avg_train_cost <= TRAINS_END_COST:
                        break
            if accuracy >= TRAINS_END_ACC and epoch_count >= TRAINS_END_EPOCHS and _avg_train_cost <= TRAINS_END_COST:
                compile_graph(accuracy)
                print('Total Time: {}'.format(time.time() - total_start_time))
                break
            epoch_count += 1
コード例 #17
0
def train(train_dir = None, val_dir = None, mode = 'train'):
    """Train the CNN+LSTM+CTC OCR model and periodically validate on the full validation set.

    Every ``FLAGS.validation_steps`` global steps, all complete validation batches are
    decoded. The mean per-batch accuracy is printed with the running average
    training cost of the current epoch.

    Args:
        train_dir: Directory passed to ``utils.DataIterator`` for training samples.
        val_dir: Directory passed to ``utils.DataIterator`` for validation samples.
        mode: Mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    # Build the graph.
    model.build_graph()

    print('loading train data, please wait---------------------')
    # Training data feeder.
    train_feeder = utils.DataIterator(data_dir = train_dir)
    print('get image:', train_feeder.size)
    print('loading validation data, please wait---------------------')
    # Validation data feeder.
    val_feeder = utils.DataIterator(data_dir = val_dir)
    print('get image:', val_feeder.size)

    num_train_samples = train_feeder.size
    # Incomplete trailing batches are dropped.
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)
    # Fix a random order of the validation samples once, up front.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    with tf.device('/gpu:0'):
        # tf.ConfigProto configures the session. allow_soft_placement=True lets
        # TF pick another device if the requested one does not exist.
        config = tf.ConfigProto(allow_soft_placement = True)
        with tf.Session(config = config) as sess:
            sess.run(tf.global_variables_initializer())

            # Saver used to save and restore the model parameters.
            saver = tf.train.Saver(tf.global_variables(), max_to_keep = 100)
            # Write the session's graph to the log directory.
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
            # If previously saved parameters exist, restore them into this session.
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))

            print('=============================begin training=============================')
            # Start training.
            for cur_epoch in range(FLAGS.num_epochs):
                shuffle_idx = np.random.permutation(num_train_samples)
                train_cost = 0
                start_time = time.time()
                batch_time = time.time()

                for cur_batch in range(num_batches_per_epoch):
                    if (cur_batch + 1) % 100 == 0:
                        print('batch', cur_batch, ':time', time.time() - batch_time)
                    batch_time = time.time()
                    # Sample indices for this batch: the next batch_size entries of
                    # this epoch's shuffled training order.
                    indexs = [shuffle_idx[i % num_train_samples] for i in range(cur_batch * FLAGS.batch_size, (cur_batch + 1) * FLAGS.batch_size)]
                    batch_inputs, batch_seq_len, batch_labels = train_feeder.input_index_generate_batch(indexs)
                    feed = {model.inputs: batch_inputs, model.labels: batch_labels, model.seq_len: batch_seq_len}

                    # Run one optimization step and fetch summary, cost and global step.
                    summary_str, batch_cost, step, _ = sess.run([model.merged_summay, model.cost, model.global_step, model.train_op], feed)
                    # batch_cost is the mean over one batch, so scale it back to a sum.
                    train_cost += batch_cost * FLAGS.batch_size
                    # TensorBoard visualization.
                    train_writer.add_summary(summary_str, step)

                    # Save a checkpoint.
                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.mkdir(FLAGS.checkpoint_dir)
                        # Format eagerly. Passing format(step) as a separate logging
                        # argument never fills the '{0}' placeholder.
                        logger.info('save the checkpoint of{0}'.format(step))
                        saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'), global_step = step)

                    # Decode the validation set and report accuracy.
                    if step % FLAGS.validation_steps == 0:
                        acc_batch_total = 0
                        lastbatch_err = 0
                        lr = 0
                        # Evaluate the validation set batch by batch.
                        for j in range(num_batches_per_epoch_val):
                            indexs_val = [shuffle_idx_val[i % num_val_samples] for i in range(j * FLAGS.batch_size, (j + 1) * FLAGS.batch_size)]
                            val_inputs, val_seq_len, val_labels = val_feeder.input_index_generate_batch(indexs_val)
                            val_feed = {model.inputs: val_inputs, model.labels: val_labels, model.seq_len: val_seq_len}

                            # Fetch model.cost as well so the three-way unpack matches.
                            dense_decoded, lastbatch_err, lr = sess.run([model.dense_decoded, model.cost, model.lrn_rate], val_feed)

                            # Print the decoding results for this validation batch.
                            ori_labels = val_feeder.the_label(indexs_val)
                            acc = utils.accuracy_calculation(ori_labels, dense_decoded, ignore_value = -1, isPrint = True)
                            acc_batch_total += acc

                        # Average over the batches actually evaluated. The trailing
                        # num_val_samples % batch_size samples are never decoded.
                        # Assumes accuracy_calculation returns a per-batch fraction; confirm.
                        if num_batches_per_epoch_val:
                            accuracy = acc_batch_total / float(num_batches_per_epoch_val)
                        else:
                            accuracy = 0.0

                        avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)

                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                              "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                        print(log.format(now.month, now.day, now.hour, now.minute, now.second, cur_epoch + 1, FLAGS.num_epochs, accuracy, avg_train_cost, lastbatch_err, time.time() - start_time, lr))
コード例 #18
0
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the CNN+LSTM+CTC OCR model, validating on the full validation set.

    Every ``FLAGS.validation_steps`` global steps, all complete validation batches
    are decoded. A summary line (mean accuracy, running average training cost,
    last validation-batch cost, elapsed epoch time, learning rate) is appended to
    ``test_acc.txt`` and printed.

    Args:
        train_dir: Directory passed to ``utils.DataIterator`` for training samples.
        val_dir: Directory passed to ``utils.DataIterator`` for validation samples.
        mode: Mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()
    print(FLAGS.image_channel)
    print('loading train data')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('size: ', train_feeder.size)

    print('loading validation data')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('size: {}\n'.format(val_feeder.size))

    num_train_samples = train_feeder.size
    # Incomplete trailing batches are dropped.
    num_batches_per_epoch = int(num_train_samples /
                                FLAGS.batch_size)

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples /
                                    FLAGS.batch_size)
    shuffle_idx_val = np.random.permutation(num_val_samples)

    # NOTE(review): hard-codes GPU 2 and is set after the graph is built; this
    # relies on CUDA not being initialised until the Session below. TODO confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = '2'
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step is restored as well
                saver.restore(sess, ckpt)
                print('restore from checkpoint{0}'.format(ckpt))

        print(
            '=============================begin training============================='
        )
        # Make the graph read-only so accidental op creation in the loop raises.
        sess.graph.finalize()
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()

            # the training part
            for cur_batch in range(num_batches_per_epoch):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time',
                          time.time() - batch_time)
                batch_time = time.time()
                indexs = [
                    shuffle_idx[i % num_train_samples]
                    for i in range(cur_batch *
                                   FLAGS.batch_size, (cur_batch + 1) *
                                   FLAGS.batch_size)
                ]
                batch_inputs, _, batch_labels = \
                    train_feeder.input_index_generate_batch(indexs)
                feed = {model.inputs: batch_inputs, model.labels: batch_labels}

                summary_str, batch_cost, step, _ = \
                    sess.run([model.merged_summay, model.cost, model.global_step, model.train_op], feed)
                # batch_cost is a per-batch mean; scale it back to a sum over samples.
                train_cost += batch_cost * FLAGS.batch_size

                train_writer.add_summary(summary_str, step)

                # save the checkpoint
                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # Format eagerly. Passing format(step) as a separate logging
                    # argument never fills the '{0}' placeholder.
                    logger.info('save checkpoint at step {0}'.format(step))
                    saver.save(sess,
                               os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                # do validation
                if step % FLAGS.validation_steps == 0:
                    acc_batch_total = 0
                    lastbatch_err = 0
                    lr = 0
                    for j in range(num_batches_per_epoch_val):
                        indexs_val = [
                            shuffle_idx_val[i % num_val_samples]
                            for i in range(j * FLAGS.batch_size, (j + 1) *
                                           FLAGS.batch_size)
                        ]
                        val_inputs, _, val_labels = \
                            val_feeder.input_index_generate_batch(indexs_val)
                        val_feed = {
                            model.inputs: val_inputs,
                            model.labels: val_labels
                        }

                        dense_decoded, lastbatch_err, lr = \
                            sess.run([model.dense_decoded, model.cost, model.lrn_rate],
                                     val_feed)

                        # print the decode result
                        ori_labels = val_feeder.the_label(indexs_val)

                        acc = utils.accuracy_calculation(ori_labels,
                                                         dense_decoded,
                                                         ignore_value=-1,
                                                         isPrint=True)
                        acc_batch_total += acc

                    # Average over the batches actually evaluated. The trailing
                    # num_val_samples % batch_size samples are never decoded.
                    # Assumes accuracy_calculation returns a per-batch fraction; confirm.
                    if num_batches_per_epoch_val:
                        accuracy = acc_batch_total / float(num_batches_per_epoch_val)
                    else:
                        accuracy = 0.0

                    avg_train_cost = train_cost / (
                        (cur_batch + 1) * FLAGS.batch_size)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    # Format once so the file and the console report the same elapsed time.
                    log_line = log.format(now.month, now.day, now.hour,
                                          now.minute, now.second, cur_epoch + 1,
                                          FLAGS.num_epochs, accuracy,
                                          avg_train_cost, lastbatch_err,
                                          time.time() - start_time, lr)
                    with open('test_acc.txt', 'a') as f:
                        f.write(log_line + "\n")
                    print(log_line)
コード例 #19
0
ファイル: trains.py プロジェクト: GoKey/captcha_trainer
def _load_feeder(mode, use_tfrecords, file_list):
    """Create a ``utils.DataIterator`` for ``mode`` and load its samples.

    :param mode: ``RunMode`` value passed through to ``utils.DataIterator``.
    :param use_tfrecords: when true, samples are read with
        ``read_sample_from_tfrecords()``; otherwise with
        ``read_sample_from_files(file_list)``.
    :param file_list: image paths for the file-based reader. May be ``None``
        (e.g. when ``HAS_TEST_SET`` is set); presumably the iterator then
        falls back to its own configured paths -- TODO confirm in utils.
    :return: the populated ``DataIterator``.
    """
    feeder = utils.DataIterator(mode=mode)
    if use_tfrecords:
        feeder.read_sample_from_tfrecords()
    else:
        feeder.read_sample_from_files(file_list)
    return feeder


def _batch_indices(shuffle_idx, batch_no):
    """Return the ``BATCH_SIZE`` shuffled sample indices for batch ``batch_no``.

    Indices wrap modulo ``len(shuffle_idx)``, so a batch running past the end
    of the permutation is filled from its beginning.
    """
    total = len(shuffle_idx)
    first = batch_no * BATCH_SIZE
    return [shuffle_idx[i % total] for i in range(first, first + BATCH_SIZE)]


def _training_finished(accuracy, epoch_count):
    """Return True once both the accuracy and the epoch stop thresholds are met."""
    return accuracy >= TRAINS_END_ACC and epoch_count >= TRAINS_END_EPOCHS


def _evaluate(sess, model, feeder, use_tfrecords, shuffle_idx, num_batches):
    """Score ``num_batches`` validation batches and return ``(accuracy, lr)``.

    :param use_tfrecords: must match how ``feeder`` was loaded (see
        ``_load_feeder``), since it selects the batch/label readers.
    :param num_batches: number of batches to score; must be >= 1.
    :return: ``accuracy`` is the mean of the per-batch values returned by
        ``utils.accuracy_calculation``; ``lr`` is ``model.lrn_rate`` as
        fetched with the last batch.
    """
    acc_total = 0
    lr = 0
    for batch_no in range(num_batches):
        index_val = _batch_indices(shuffle_idx, batch_no)
        if use_tfrecords:
            val_inputs, _, val_labels = feeder.generate_batch_by_tfrecords(
                sess)
        else:
            val_inputs, _, val_labels = feeder.generate_batch_by_index(
                index_val)

        val_feed = {
            model.batch_size: BATCH_SIZE,
            model.inputs: val_inputs,
            model.labels: val_labels,
        }
        # The validation cost is not fetched: nothing downstream uses it.
        dense_decoded, lr = sess.run(
            [model.dense_decoded, model.lrn_rate], val_feed)

        if use_tfrecords:
            ori_labels = feeder.label_by_tfrecords()
        else:
            ori_labels = feeder.label_by_index(index_val)
        acc_total += utils.accuracy_calculation(
            ori_labels,
            dense_decoded,
            ignore_value=-1,
        )
    # Average over the batches actually scored. Dividing by the full set
    # size (as before) under-reported accuracy whenever it was not a multiple
    # of BATCH_SIZE. Assumes accuracy_calculation returns a per-batch
    # fraction, as the previous BATCH_SIZE scaling implied.
    return acc_total / num_batches, lr


def train_process(mode=RunMode.Trains):
    """Train the LSTM captcha model until the configured stop criteria hold.

    Builds the graph, loads the train and test sets, restores the latest
    checkpoint from ``MODEL_PATH`` if one exists, then trains epoch by epoch.
    Every ``TRAINS_SAVE_STEPS`` global steps a checkpoint is saved to
    ``SAVE_MODEL``. Every ``TRAINS_VALIDATION_STEPS`` steps the test set is
    scored. Once accuracy >= ``TRAINS_END_ACC`` and at least
    ``TRAINS_END_EPOCHS`` epochs have run, ``compile_graph`` is called and
    training stops.

    :param mode: ``RunMode`` used to build ``framework_lstm.LSTM``.
    :raises ValueError: if the training or the test set is empty (there
        would be nothing to train on or no way to reach the stop criteria).
    """
    model = framework_lstm.LSTM(mode)
    model.build_graph()
    test_list, trains_list = None, None
    if not HAS_TEST_SET:
        # No dedicated test set: shuffle the training images and hold out
        # the first TEST_SET_NUM of them for validation.
        file_names = os.listdir(TRAINS_PATH)
        random.shuffle(file_names)
        origin_list = [os.path.join(TRAINS_PATH, name) for name in file_names]
        test_list = origin_list[:TEST_SET_NUM]
        trains_list = origin_list[TEST_SET_NUM:]

    print('Loading Trains DataSet...')
    train_feeder = _load_feeder(RunMode.Trains, TRAINS_USE_TFRECORDS,
                                trains_list)
    print('Total {} Trains DataSets'.format(train_feeder.size))

    print('Loading Test DataSet...')
    test_feeder = _load_feeder(RunMode.Test, TEST_USE_TFRECORDS, test_list)
    print('Total {} Test DataSets'.format(test_feeder.size))

    num_train_samples = train_feeder.size
    num_val_samples = test_feeder.size
    if num_train_samples == 0:
        raise ValueError('Trains DataSet is empty, nothing to train on.')
    if num_val_samples == 0:
        raise ValueError('Test DataSet is empty, cannot validate training.')

    # Drop the ragged tail, but never go below one batch: indices wrap (see
    # _batch_indices), so a set smaller than BATCH_SIZE still yields a full
    # batch instead of zero batches (which made training loop forever).
    num_batches_per_epoch = max(1, num_train_samples // BATCH_SIZE)
    num_batches_per_epoch_val = max(1, num_val_samples // BATCH_SIZE)
    shuffle_idx_val = np.random.permutation(num_val_samples)

    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        gpu_options=tf.GPUOptions(
            allow_growth=True,  # it will cause fragmentation.
            per_process_gpu_memory_fraction=GPU_USAGE))
    accuracy = 0
    epoch_count = 1

    with tf.Session(config=config) as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=2)
        train_writer = tf.summary.FileWriter('logs', sess.graph)
        try:
            saver.restore(sess, tf.train.latest_checkpoint(MODEL_PATH))
        except ValueError:
            # latest_checkpoint() returns None when MODEL_PATH holds no
            # checkpoint and Saver.restore() rejects a None path with
            # ValueError: keep the freshly initialised variables.
            pass

        print('Start training...')
        # Measured once, so 'Total Time' covers every epoch, not just the last.
        training_start = time.time()

        while True:
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0

            for cur_batch in range(num_batches_per_epoch):
                index_list = _batch_indices(shuffle_idx, cur_batch)
                if TRAINS_USE_TFRECORDS:
                    batch_inputs, _, batch_labels = train_feeder.generate_batch_by_tfrecords(
                        sess)
                else:
                    batch_inputs, _, batch_labels = train_feeder.generate_batch_by_index(
                        index_list)

                feed = {
                    model.batch_size: BATCH_SIZE,
                    model.inputs: batch_inputs,
                    model.labels: batch_labels,
                }

                summary_str, batch_cost, step, _ = sess.run([
                    model.merged_summary, model.cost, model.global_step,
                    model.train_op
                ], feed)
                train_cost += batch_cost * BATCH_SIZE
                train_writer.add_summary(summary_str, step)

                if step % TRAINS_SAVE_STEPS == 0:
                    saver.save(sess, SAVE_MODEL, global_step=step)
                    # Pre-format: passing the step as a logging arg against a
                    # '{0}' placeholder does not interpolate with stdlib logging.
                    logger.info('save checkpoint at step {0}'.format(step))

                if step % TRAINS_VALIDATION_STEPS == 0:
                    batch_time = time.time()
                    # The test feeder was loaded per TEST_USE_TFRECORDS, so it
                    # must be read back the same way (not TRAINS_USE_TFRECORDS).
                    accuracy, lr = _evaluate(sess, model, test_feeder,
                                             TEST_USE_TFRECORDS,
                                             shuffle_idx_val,
                                             num_batches_per_epoch_val)
                    avg_train_cost = train_cost / (
                        (cur_batch + 1) * BATCH_SIZE)

                    log = "Epoch: {}, Step: {} Accuracy = {:.3f}, Cost = {:.3f}, Time = {:.3f}, LearningRate: {}"
                    print(
                        log.format(epoch_count, step, accuracy, avg_train_cost,
                                   time.time() - batch_time, lr))
                    if _training_finished(accuracy, epoch_count):
                        break
            if _training_finished(accuracy, epoch_count):
                compile_graph(sess, accuracy)
                print('Total Time: {}'.format(time.time() - training_start))
                break
            epoch_count += 1

        coord.request_stop()
        coord.join(threads)