def train():
    """Train the CNN-LSTM-CTC OCR model on the ``'train'`` split.

    Reads module-level settings defined elsewhere in this file: ``batch_size``,
    ``log_dir``, ``restore``, ``checkpoint_dir``, ``num_epochs``,
    ``validation_steps`` and ``save_epochs``.

    Every step trains on one batch from ``data_prep.input_batch_generator``.
    Every ``validation_steps`` steps, and on the very last batch, one extra batch
    is drawn from the same generator and used only to write the merged summaries.
    A checkpoint is saved every ``save_epochs`` epochs and after the final epoch.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR('train')
    model.build_graph()

    train_feeder, num_train_samples = data_prep.input_batch_generator(
        'train', is_training=True, batch_size=batch_size)
    print('get image: ', num_train_samples)
    # Round up so a trailing partial batch still counts as a step.
    steps_per_epoch = int(math.ceil(num_train_samples / float(batch_size)))
    last_epoch = num_epochs - 1
    last_batch = steps_per_epoch - 1

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(log_dir + '/train', sess.graph)

        ckpt = tf.train.latest_checkpoint(checkpoint_dir) if restore else None
        if ckpt:
            # global_step is restored together with the other variables.
            saver.restore(sess, ckpt)
            print('restore from the checkpoint{0}'.format(ckpt))

        for epoch in range(num_epochs):
            for batch_idx in range(steps_per_epoch):
                tic = time.time()
                images, labels, _ = next(train_feeder)
                loss, step, _ = sess.run(
                    [model.cost, model.global_step, model.train_op],
                    {model.inputs: images, model.labels: labels})
                if step % 100 == 0:
                    print('{}/{}:{},loss={}, time={}'.format(
                        step, epoch, num_epochs, loss, time.time() - tic))

                # Write summaries periodically and once more on the final batch.
                is_final_batch = epoch == last_epoch and batch_idx == last_batch
                if step % validation_steps == 0 or is_final_batch:
                    images, labels, _ = next(train_feeder)
                    summary = sess.run(
                        model.merged_summay,
                        {model.inputs: images, model.labels: labels})
                    train_writer.add_summary(summary, step)

            if epoch % save_epochs == 0 or epoch == last_epoch:
                if not os.path.isdir(checkpoint_dir):
                    os.mkdir(checkpoint_dir)
                print('save the checkpoint of step {}'.format(step))
                saver.save(sess, os.path.join(checkpoint_dir, 'ocr-model'),
                           global_step=step)
def infer(path, mode='infer'):
    """Run the OCR model on every ``.jpg`` file in ``path`` and print the predictions.

    Each image is read as grayscale and scaled to [0, 1]. It is then resized so its
    height equals ``FLAGS.image_height`` (aspect ratio kept) and right-padded with
    zeros to the widest image in its batch. The unpadded widths are fed as
    ``model.seq_len``.

    Args:
        path: Directory containing the ``.jpg`` images.
        mode: Mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.

    Note:
        ``FLAGS.checkpoint_dir`` is passed straight to ``saver.restore``, so it must
        name a checkpoint prefix, not a directory. Only full batches are processed;
        up to ``FLAGS.batch_size - 1`` trailing images are skipped.
    """
    imgList = [os.path.join(path, e) for e in os.listdir(path) if e.endswith('.jpg')]
    print(len(imgList))
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()
    # range() needs an int; the former `/` produced a float on Python 3.
    total_steps = len(imgList) // FLAGS.batch_size
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = FLAGS.checkpoint_dir
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            print('cannot restore')

        for curr_step in range(total_steps):
            batch_names = imgList[curr_step * FLAGS.batch_size:(curr_step + 1) * FLAGS.batch_size]
            batch_imgs = []
            for img in batch_names:
                im = cv2.imread(img, 0).astype(np.float32) / 255.
                # float() avoids integer division (scale == 0) on Python 2.
                scale = float(FLAGS.image_height) / im.shape[0]
                im = cv2.resize(im, None, fx=scale, fy=scale)
                batch_imgs.append(np.expand_dims(im, 2))

            max_width = max(e.shape[1] for e in batch_imgs)
            inputs_imgs = np.zeros((len(batch_imgs), batch_imgs[0].shape[0], max_width, 1))
            for idx, item in enumerate(batch_imgs):
                # Pad along the width axis (axis 2); the old slice indexed the height axis.
                inputs_imgs[idx, :, 0:item.shape[1], :] = item
            seq_len_input = np.asarray([e.shape[1] for e in batch_imgs])

            feed = {model.inputs: inputs_imgs, model.seq_len: seq_len_input}
            dense_decoded_code = sess.run(model.dense_decoded, feed)
            batch_result = [utils.label2text(code) for code in dense_decoded_code]
            for name, text in zip(batch_names, batch_result):
                print(name, text)
def infer(img_path, batch_size=64, image_height=60, image_width=180, image_channel=1,
          checkpoint_dir="../checkpoint/"):
    """Recognise the images in ``img_path`` and append the decoded strings to ``./result.txt``.

    Files whose name contains ``"label"`` are ignored. The remaining files must be
    named ``<int>.<ext>`` and are processed in numeric order. Each image is converted
    to grayscale, scaled to [0, 1] and reshaped to
    ``(image_height, image_width, image_channel)``. No resizing is done, so the files
    must already have that size.

    Args:
        img_path: Directory containing the images.
        batch_size: Images per forward pass; also passed to the model.
        image_height, image_width, image_channel: Expected image shape.
        checkpoint_dir: Directory searched with ``tf.train.latest_checkpoint``.

    Raises:
        Exception: If no checkpoint is found in ``checkpoint_dir``.

    Note:
        Only full batches are run; up to ``batch_size - 1`` trailing images are skipped.
    """
    # Collect the image file names (skipping label files), ordered by numeric stem.
    file_names = [t for t in os.listdir(img_path) if t.find("label") < 0]
    file_names.sort(key=lambda x: int(x.split('.')[0]))
    file_names = np.asarray([os.path.join(img_path, file_name) for file_name in file_names])

    # Build the model.
    model = cnn_lstm_otc_ocr.LSTMOCR(num_classes=NumClasses, batch_size=batch_size, is_train=False)
    model.build_graph()

    results = []
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Initialise, then load the trained weights.
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt:
            print('restore from ckpt{}'.format(ckpt))
            tf.train.Saver(tf.global_variables(), max_to_keep=100).restore(sess, ckpt)
        else:
            print('cannot restore')
            raise Exception("cannot restore")

        for curr_step in range(len(file_names) // batch_size):
            # Read one batch of image data.
            images_input = []
            for img in file_names[curr_step * batch_size: (curr_step + 1) * batch_size]:
                image_data = np.asarray(Image.open(img).convert("L"), dtype=np.float32) / 255.
                image_data = np.reshape(image_data, [image_height, image_width, image_channel])
                images_input.append(image_data)
            images_input = np.asarray(images_input)

            # Decode one label sequence per image. The previous version fetched
            # [logits, seq_len, decoded, log_prob, dense_decoded] and decoded each of
            # those five tensors instead of each image's row.
            dense_decoded = sess.run(model.dense_decoded, {model.inputs: images_input})
            for item in dense_decoded:
                result = DataIterator.get_result(item)
                results.append(result)
                print(result)

    # Save the results (appended to any existing file).
    with open('./result.txt', 'a') as f:
        for code in results:
            f.write(code + '\n')
def __init__(self, checkpoint_dir='/media/zdyd/dujing/yjx/textrecognition/checkpoint'):
    """Build the inference graph, open a session and load the trained weights.

    Args:
        checkpoint_dir: Location handed to ``self._load_weights``. The default is
            the path that used to be hard-coded here, so existing callers are
            unaffected.
    """
    self.model = cnn_lstm_otc_ocr.LSTMOCR('infer')
    # NOTE(review): stores whatever build_graph() returns; confirm it returns a graph.
    self.graph = self.model.build_graph()
    config = tf.ConfigProto(allow_soft_placement=True)
    # Cap this process at 60% of the GPU memory.
    config.gpu_options.per_process_gpu_memory_fraction = 0.6
    self.sess = tf.Session(config=config)
    self.saver = tf.train.Saver()
    self._load_weights(checkpoint_dir, self.sess, self.saver)
def __init__(self, model_dir=model_dir):
    """Build the export-style inference graph on a fixed-size placeholder and restore it.

    Args:
        model_dir: Directory searched with ``tf.train.latest_checkpoint``. The default
            is the module-level ``model_dir`` as bound when this function was defined.

    Raises:
        ValueError: If ``model_dir`` contains no checkpoint.
    """
    self.X = tf.placeholder(
        tf.float32,
        shape=[FLAGS.batch_size, FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel],
        name='input')
    model = cnn_lstm_otc_ocr.LSTMOCR('infer', '0')
    self.decodes, self.prob = model.build_graph_for_export(self.X)
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True  # allocate GPU memory on demand
    self.sess = tf.Session(config=config)
    self.sess.run(tf.global_variables_initializer())
    ckpt = tf.train.latest_checkpoint(model_dir)
    if ckpt is None:
        # Fail clearly (and release the session) instead of calling restore(sess, None).
        self.sess.close()
        raise ValueError('no checkpoint found in {}'.format(model_dir))
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    saver.restore(self.sess, ckpt)
def train(train_dir=None, mode='train'):
    """Train the (optionally multi-GPU) OCR model from a queue-fed input pipeline.

    The graph is built in a fresh ``tf.Graph`` under ``/cpu:0``. The GPU ids in
    ``FLAGS.gpus`` (comma separated, empty entries dropped) are passed to
    ``cnn_lstm_otc_ocr.LSTMOCR``. Inputs come from
    ``utils.DataIterator().distored_inputs()`` through TF queue runners.

    A checkpoint is written whenever ``step % FLAGS.save_steps == 1``. A progress
    line is printed whenever ``(step + 1) % 100 == 1``.

    Args:
        train_dir: Unused; kept so existing callers keep working.
        mode: Mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        gpus = list(filter(lambda x: x, FLAGS.gpus.split(',')))
        model = cnn_lstm_otc_ocr.LSTMOCR(mode, gpus)
        train_feeder = utils.DataIterator()
        X, Y = train_feeder.distored_inputs()
        train_op, _ = model.build_graph(X, Y)
        print('len(labels):%d, batch_size:%d' % (len(train_feeder.labels), FLAGS.batch_size))
        # Dividing by len(gpus) assumes one step consumes batch_size examples per GPU.
        num_batches_per_epoch = int(
            len(train_feeder.labels) / FLAGS.batch_size / len(gpus))
        config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
                train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
                if FLAGS.restore:
                    ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                    if ckpt:
                        saver.restore(sess, ckpt)
                        print('restore from checkpoint{0}'.format(ckpt))
                        print('global_step:', model.global_step.eval())
                        print('assign value %d' % (FLAGS.num_epochs * num_batches_per_epoch / 3))
                        #sess.run(tf.assign(model.global_step, FLAGS.num_epochs*num_batches_per_epoch/3))
                        print('global_step:', model.global_step.eval())
                print(
                    '=============================begin training============================='
                )
                for cur_epoch in range(FLAGS.num_epochs):
                    batch_time = time.time()
                    for cur_batch in range(num_batches_per_epoch):
                        # Fetch the loss in the same run as the train op. A separate
                        # model.loss.eval() would dequeue and evaluate another batch.
                        _, step, loss_value = sess.run(
                            [train_op, model.global_step, model.loss])
                        if step % FLAGS.save_steps == 1:
                            if not os.path.isdir(FLAGS.checkpoint_dir):
                                os.mkdir(FLAGS.checkpoint_dir)
                            saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                                       global_step=step)
                        if (step + 1) % 100 == 1:
                            # The time column is measured from the start of the epoch.
                            print('step: %d, batch: %d time: %d, learning rate: %.8f, loss:%.4f' %
                                  (step, cur_batch, time.time() - batch_time,
                                   model.lrn_rate.eval(), loss_value))
            finally:
                # Always stop and join the queue-runner threads, even if training fails.
                coord.request_stop()
                coord.join(threads)
def eval_model(self):
    """Decode every batch of ``self.split_name`` with a restored model and print the text.

    ``self.checkpoint_path`` may be a checkpoint directory, in which case its latest
    checkpoint is used, or a checkpoint file prefix. Per batch, the decoded strings
    and the inference time are printed, and the merged summaries are written to
    ``<log_dir>/<split_name>``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR('eval')
    model.build_graph()
    val_feeder, num_samples = self.input_batch_generator(
        self.split_name, is_training=False, batch_size=FLAGS.batch_size)
    n_batches = int(math.ceil(num_samples / float(FLAGS.batch_size)))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        eval_writer = tf.summary.FileWriter(
            "{}/{}".format(log_dir, self.split_name), sess.graph)

        path = self.checkpoint_path
        checkpoint_file = tf.train.latest_checkpoint(path) if tf.gfile.IsDirectory(path) else path
        print('Evaluating checkpoint_path={}, split={}, num_samples={}'.format(
            checkpoint_file, self.split_name, num_samples))
        saver.restore(sess, checkpoint_file)

        for batch_idx in range(n_batches):
            images, _, _ = next(val_feeder)
            t0 = time.time()
            decoded = sess.run(model.dense_decoded, {model.inputs: images})
            for row in decoded:
                # -1 entries are padding in the dense decoding and map to nothing.
                text = ''.join(utils.decode_maps[c] if c != -1 else '' for c in row)
                print("%s" % text)
            print('{}/{}, {:.5f} seconds.'.format(batch_idx, n_batches, time.time() - t0))
            # NOTE(review): this run has no feed; it fails if merged_summay depends on
            # the input placeholders. Confirm which summaries exist in 'eval' mode.
            summary_str, step = sess.run([model.merged_summay, model.global_step])
            eval_writer.add_summary(summary_str, step)
def infer(img_path, mode='infer'):
    """Recognise the images under ``img_path`` and write ground truth and predictions.

    ``./actual.txt`` gets one label per image, parsed from the file name (the text
    between the first ``_`` and the following ``.``). ``./result.txt`` gets one
    decoded string per processed image. Images are read as grayscale, scaled to
    [0, 1] and resized to ``(FLAGS.image_width, FLAGS.image_height)``.

    Note:
        Only full batches are decoded, so when ``len(imgList)`` is not a multiple of
        ``FLAGS.batch_size`` the two files differ in length (TODO: confirm whether the
        tail should be processed).
    """
    imgList = helper.load_img_path(img_path)
    print(imgList[:5])
    with open('./actual.txt', 'w') as f:
        for name in imgList:
            code = name.split('/')[-1].split('_')[1].split('.')[0]
            f.write(code + '\n')

    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()
    total_steps = len(imgList) // FLAGS.batch_size
    config = tf.ConfigProto(allow_soft_placement=True)
    decoded_expression = []
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            print('cannot restore')

        for curr_step in range(total_steps):
            imgs_input = []
            for img in imgList[curr_step * FLAGS.batch_size:(curr_step + 1) * FLAGS.batch_size]:
                im = cv2.imread(img, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
                im = cv2.resize(im, (FLAGS.image_width, FLAGS.image_height))
                im = np.reshape(im, [FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel])
                imgs_input.append(im)
            imgs_input = np.asarray(imgs_input)
            # Only the images are fed. The per-image seq_len the old code built here
            # was never passed to the model, so it has been removed.
            dense_decoded_code = sess.run(model.dense_decoded, {model.inputs: imgs_input})
            for item in dense_decoded_code:
                # -1 entries are padding in the dense decoding and are skipped.
                decoded_expression.append(''.join(utils.decode_maps[i] for i in item if i != -1))

    with open('./result.txt', 'w') as f:
        for code in decoded_expression:
            f.write(code + '\n')
# Script fragment: time image loading, graph construction and one inference call.
# Relies on names defined outside this view: `root` (image directory), `img_list`
# (a list that is appended to), `fit` (the inference helper) and `utils`.

# Decode every file in `root`. np.fromfile + cv2.imdecode is used instead of
# cv2.imread, presumably so that paths with non-ASCII characters load (TODO confirm).
read_img_begin = time.time()
for image in os.listdir(root):
    image_name = os.path.join(root, image)
    # flag -1 (cv2.IMREAD_UNCHANGED) keeps the image's channels/depth as stored
    img = cv2.imdecode(np.fromfile(image_name, dtype=np.uint8), -1)
    img_list.append(img)
read_img_end = time.time()
#print len(img_list)
#with tf.device('/gpus:0'):
#build_model_begin = time.time()
#model = cnn_lstm_otc_ocr.LSTMOCR('infer')
#model.build_graph()
#build_model_end = time.time()
# Build the inference graph with its ops placed on /cpu:0.
with tf.device('/cpu:0'):
    build_model_begin = time.time()
    model = cnn_lstm_otc_ocr.LSTMOCR('infer')
    model.build_graph()
    build_model_end = time.time()
config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
    ckpt = tf.train.latest_checkpoint(utils.checkpoint_dir)
    if ckpt:
        saver.restore(sess, ckpt)
        print('restore from ckpt{}'.format(ckpt))
    else:
        # Execution continues with freshly initialised (untrained) weights.
        print('cannot restore')
    fit_image_begin = time.time()
    # Only the first loaded image is passed to fit().
    result = fit(model, sess, img_list[:1])
    fit_image_end = time.time()
def export():
    """Export the OCR inference graph as a TensorFlow Serving SavedModel.

    The exported graph accepts serialized ``tf.Example`` protos with an
    ``image/encoded`` bytes feature. Each value is turned into an image by
    ``preprocess_image`` and decoded by the model. Weights come from the latest
    checkpoint in ``FLAGS.checkpoint_dir``. The SavedModel is written to
    ``FLAGS.output_dir/<FLAGS.model_version>`` with two signatures:
    ``predict_images`` (input ``images``: encoded image strings; output ``classes``)
    and the default serving signature (classify API on the serialized examples).

    Raises:
        ValueError: If ``FLAGS.checkpoint_dir`` contains no checkpoint.
    """
    # Open the device scope *inside* the new graph. A tf.device scope opened before
    # tf.Graph().as_default() only applies to the previous default graph.
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        serialized_tf_recognition = tf.placeholder(tf.string, name='tf_recognition')
        feature_configs = {
            'image/encoded': tf.FixedLenFeature(shape=[], dtype=tf.string),
        }
        tf_recognition = tf.parse_example(serialized_tf_recognition, feature_configs)
        jpegs = tf_recognition['image/encoded']
        images = tf.map_fn(preprocess_image, jpegs, dtype=tf.float32)
        model = cnn_lstm_otc_ocr.LSTMOCR('infer', '0')
        outputs = model.build_graph_for_export(images)
        # Other code in this file unpacks build_graph_for_export() as (decodes, prob);
        # accept that form as well as a bare tensor.
        decodes = outputs[0] if isinstance(outputs, (tuple, list)) else outputs

        config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt is None:
                raise ValueError('no checkpoint found in {}'.format(FLAGS.checkpoint_dir))
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            saver.restore(sess, ckpt)

            output_path = os.path.join(
                tf.compat.as_bytes(FLAGS.output_dir),
                tf.compat.as_bytes(str(FLAGS.model_version)))
            print('Exporting trained model to', output_path)
            builder = tf.saved_model.builder.SavedModelBuilder(output_path)

            # Build the signature_def_map.
            classify_inputs_tensor_info = tf.saved_model.utils.build_tensor_info(
                serialized_tf_recognition)
            classes_output_tensor_info = tf.saved_model.utils.build_tensor_info(decodes)
            classification_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={
                        tf.saved_model.signature_constants.CLASSIFY_INPUTS:
                            classify_inputs_tensor_info
                    },
                    outputs={
                        tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:
                            classes_output_tensor_info
                    },
                    method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME))

            predict_inputs_tensor_info = tf.saved_model.utils.build_tensor_info(jpegs)
            prediction_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={'images': predict_inputs_tensor_info},
                    outputs={
                        'classes': classes_output_tensor_info,
                    },
                    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    'predict_images': prediction_signature,
                    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        classification_signature,
                },
                clear_devices=True)

            builder.save()
            print('Successfully exported model to %s' % FLAGS.output_dir)
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the LSTM OCR model using ``ReadData`` dataset iterators.

    Training and validation images are listed with ``utils.DataIterator`` from
    ``FLAGS.train_dir`` / ``FLAGS.val_dir`` and streamed through
    ``ReadData.ReadData``. Every ``FLAGS.validation_steps`` steps the validation set
    (full batches only) is decoded, the accuracy is printed and the best accuracy
    seen so far is tracked.

    Args:
        train_dir: Unused; ``FLAGS.train_dir`` is read instead.
        val_dir: Unused; ``FLAGS.val_dir`` is read instead.
        mode: Mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    if FLAGS.model == 'lstm':
        model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    else:
        print("no such model")
        sys.exit()
    # Build the graph.
    model.build_graph()

    # ---- training data ----
    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir=FLAGS.train_dir, istrain=True)
    print('get image data size: ', train_feeder.size)
    filename = train_feeder.image
    label = train_feeder.labels
    print(len(filename))
    train_data = ReadData.ReadData(filename, label)

    # ---- validation data ----
    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir=FLAGS.val_dir, istrain=False)
    filename1 = val_feeder.image
    label1 = val_feeder.labels
    test_data = ReadData.ReadData(filename1, label1)
    print('val get image: ', val_feeder.size)

    # ---- batch counts (partial final batches are dropped) ----
    num_train_samples = train_feeder.size
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)  # batches per training epoch
    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)  # batches per validation pass

    with tf.device('/cpu:0'):
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            # Initialise the dataset iterators, then swap in their next-batch tensors.
            train_data.init_itetator(sess)
            test_data.init_itetator(sess)
            train_data = train_data.get_nex_batch()
            test_data = test_data.get_nex_batch()

            # Initialise global variables; the saver writes checkpoints.
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)

            # Optionally resume from a previously trained model.
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # the global_step will restore sa well
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))
                else:
                    print("No checkpoint")

            print(
                '=============================begin training============================='
            )
            accuracy_res = []
            epoch_res = []
            tmp_max = 0
            tmp_epoch = 0
            for cur_epoch in range(FLAGS.num_epochs):
                train_cost = 0
                batch_time = time.time()
                for cur_batch in range(num_batches_per_epoch):
                    # Fetch this batch's images and labels; labels become a sparse tuple for CTC.
                    batch_inputs, batch_labels = sess.run(train_data)
                    new_batch_labels = utils.sparse_tuple_from_label(batch_labels.tolist())
                    batch_seq_len = np.asarray(
                        [FLAGS.max_stepsize for _ in batch_inputs], dtype=np.int64)
                    feed = {
                        model.inputs: batch_inputs,
                        model.labels: new_batch_labels,
                        model.seq_len: batch_seq_len
                    }
                    summary_str, batch_cost, step, _ = \
                        sess.run([model.merged_summay, model.cost, model.global_step,
                                  model.train_op], feed)
                    train_cost += batch_cost * FLAGS.batch_size

                    # Save a checkpoint.
                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.mkdir(FLAGS.checkpoint_dir)
                        # Lazy %-argument; the old `format(step)` arg had no placeholder,
                        # which made logging report a formatting error.
                        logger.info('save the checkpoint of %s', step)
                        saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                                   global_step=step)

                    if cur_batch % 100 == 1:
                        print('batch', cur_batch, ': time', time.time() - batch_time,
                              'loss', batch_cost)
                        batch_time = time.time()

                    # Validation over all full validation batches.
                    if step % FLAGS.validation_steps == 0:
                        validation_start_time = time.time()
                        acc_batch_total = 0
                        lr = 0
                        for j in range(num_batches_per_epoch_val):
                            batch_inputs, batch_labels = sess.run(test_data)
                            new_batch_labels = utils.sparse_tuple_from_label(
                                batch_labels.tolist())
                            batch_seq_len = np.asarray(
                                [FLAGS.max_stepsize for _ in batch_inputs], dtype=np.int64)
                            val_feed = {
                                model.inputs: batch_inputs,
                                model.labels: new_batch_labels,
                                model.seq_len: batch_seq_len
                            }
                            dense_decoded, lr = \
                                sess.run([model.dense_decoded, model.lrn_rate], val_feed)
                            acc = utils.accuracy_calculation(
                                batch_labels.tolist(), dense_decoded,
                                ignore_value=-1, isPrint=True)
                            acc_batch_total += acc
                        # Average over the batches actually evaluated. The old denominator
                        # also counted the dropped remainder and raised on an empty set.
                        if num_batches_per_epoch_val:
                            accuracy = float(acc_batch_total) / num_batches_per_epoch_val
                        else:
                            accuracy = 0.0
                        accuracy_res.append(accuracy)
                        epoch_res.append(cur_epoch)
                        if accuracy > tmp_max:
                            tmp_max = accuracy
                            tmp_epoch = cur_epoch
                        avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)
                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "max_accuracy = {:.3f},max_Epoch {},accuracy = {:.3f},acc_batch_total = {:.3f},avg_train_cost = {:.3f}, " \
                              " time = {:.3f},lr={:.8f}"
                        print(
                            log.format(now.month, now.day, now.hour, now.minute, now.second,
                                       cur_epoch + 1, FLAGS.num_epochs, tmp_max, tmp_epoch,
                                       accuracy, acc_batch_total, avg_train_cost,
                                       time.time() - validation_start_time, lr))
def infer(img_path, mode='infer'):
    """Decode the images under ``img_path``, write predictions and print the accuracy.

    Each image is read as grayscale, scaled to [0, 1] and reshaped (not resized) to
    ``(FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel)``. The ground-truth
    label is the file-name text after the last ``_`` with ``.jpg`` removed.
    ``./result.txt`` receives one ``"<path> <label> <prediction>"`` line per decoded
    image.

    Note:
        Only full batches are decoded. Accuracy is computed over the decoded images
        only (previously the skipped tail was counted as wrong).
    """
    imgList = helper.load_img_path(img_path)
    print(imgList[:5])
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()
    total_steps = len(imgList) // FLAGS.batch_size
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    decoded_expression = []
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            print('cannot restore')

        for curr_step in range(total_steps):
            imgs_input = []
            for img in imgList[curr_step * FLAGS.batch_size:(curr_step + 1) * FLAGS.batch_size]:
                im = cv2.imread(img, 0).astype(np.float32) / 255.
                im = np.reshape(im, [FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel])
                imgs_input.append(im)
            imgs_input = np.asarray(imgs_input)
            # Only the images are fed; the seq_len array built here before was never used.
            dense_decoded_code = sess.run(model.dense_decoded, {model.inputs: imgs_input})
            for item in dense_decoded_code:
                # -1 entries are padding in the dense decoding and are skipped.
                decoded_expression.append(''.join(utils.decode_maps[i] for i in item if i != -1))
    print(decoded_expression)

    true_count = 0
    evaluated = decoded_expression[0:len(imgList)]
    with open('./result.txt', 'w') as f:
        for img_name, code in zip(imgList, evaluated):
            img_label = img_name.split('_')[-1].replace('.jpg', '')
            if code == img_label:
                true_count += 1
            f.write('{} {} {}\n'.format(img_name, img_label, code))
    num_evaluated = len(evaluated)
    accuracy = float(true_count) / num_evaluated if num_evaluated else 0.0
    print('{}/{} = {}'.format(true_count, num_evaluated, accuracy))
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the OCR model from ``data_prep`` batch generators.

    Every ``FLAGS.validation_steps`` steps one validation batch (twice the training
    batch size) is decoded and the accuracy is printed. A checkpoint is saved
    whenever ``step % FLAGS.save_steps == 1``.

    Args:
        train_dir: Unused; kept for call compatibility.
        val_dir: Unused; kept for call compatibility.
        mode: Mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    print('loading train data, please wait---------------------')
    train_feeder, num_train_samples = data_prep.input_batch_generator(
        'train', is_training=True, batch_size=FLAGS.batch_size)
    print('get image: ', num_train_samples)
    print('loading validation data, please wait---------------------')
    val_feeder, num_val_samples = data_prep.input_batch_generator(
        'val', is_training=False, batch_size=FLAGS.batch_size * 2)
    print('get image: ', num_val_samples)
    num_batches_per_epoch = int(math.ceil(num_train_samples / float(FLAGS.batch_size)))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step will restore sa well
                saver.restore(sess, ckpt)
                print('restore from the checkpoint{0}'.format(ckpt))
        print('=============================begin training=============================')
        for cur_epoch in range(FLAGS.num_epochs):
            start_time = time.time()
            batch_time = time.time()
            for cur_batch in range(num_batches_per_epoch):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time', time.time() - batch_time)
                    batch_time = time.time()
                batch_inputs, batch_labels, _ = next(train_feeder)
                feed = {model.inputs: batch_inputs, model.labels: batch_labels}
                summary_str, batch_cost, step, _ = \
                    sess.run([model.merged_summay, model.cost, model.global_step,
                              model.train_op], feed)
                train_writer.add_summary(summary_str, step)

                # Save a checkpoint.
                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # Lazy %-argument; the old `format(step)` arg had no placeholder.
                    logger.info('save the checkpoint of %s', step)
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                # Validate on one batch.
                if step % FLAGS.validation_steps == 0:
                    val_inputs, val_labels, ori_labels = next(val_feeder)
                    val_feed = {model.inputs: val_inputs, model.labels: val_labels}
                    dense_decoded, lr = \
                        sess.run([model.dense_decoded, model.lrn_rate], val_feed)
                    accuracy = utils.accuracy_calculation(ori_labels, dense_decoded,
                                                          ignore_value=-1, isPrint=True)
                    now = datetime.datetime.now()
                    # The stray ", " that printed ", , time" has been removed.
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.5f},train_cost = {:.5f}, " \
                          "time = {:.3f},lr={:.8f}"
                    print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                                     cur_epoch + 1, FLAGS.num_epochs, accuracy, batch_cost,
                                     time.time() - start_time, lr))
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the OCR model from the queue-based ``train_batch`` input pipeline.

    Runs 1000 steps (the loop counter is named ``epoch`` but each iteration trains on
    a single batch). A checkpoint is saved whenever ``step % FLAGS.save_steps == 1``.
    Every ``FLAGS.validation_steps`` steps the batch just trained on is decoded and
    its accuracy is printed.

    Args:
        train_dir: Unused; kept for call compatibility.
        val_dir: Unused; kept for call compatibility.
        mode: Mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Initialise variables *before* starting the queue runners, which may read
        # (local) variables such as input epoch counters.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # the global_step will restore sa well
                    saver.restore(sess, ckpt)
                    print('restore from checkpoint{0}'.format(ckpt))
            print(
                '=============================begin training============================='
            )
            for epoch in range(1000):
                start_time = time.time()
                train_feeder = sess.run(train_batch)
                batch_inputs, batch_labels = \
                    train_feeder[0], read_labels(train_feeder[1])
                feed = {model.inputs: batch_inputs, model.labels: batch_labels}
                summary_str, batch_cost, step, _ = \
                    sess.run([model.merged_summay, model.cost, model.global_step,
                              model.train_op], feed)
                train_writer.add_summary(summary_str, step)

                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # Lazy %-argument; the old `format(step)` arg had no placeholder.
                    logger.info('save checkpoint at step %s', step)
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                # do validation
                if step % FLAGS.validation_steps == 0:
                    # NOTE(review): no separate validation input is visible, so this
                    # evaluates the batch just trained on. The old loop repeated this
                    # identical evaluation N times and reported sum * batch_size / 2.
                    val_inputs, val_labels = batch_inputs, batch_labels
                    val_feed = {model.inputs: val_inputs, model.labels: val_labels}
                    dense_decoded, lastbatch_err, lr = \
                        sess.run([model.dense_decoded, model.cost, model.lrn_rate], val_feed)
                    # print the decode result
                    print(dense_decoded)
                    print(val_labels)
                    accuracy = utils.accuracy_calculation(val_labels, dense_decoded,
                                                          ignore_value=-1, isPrint=True)
                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    # Report the batch cost and the real elapsed time; the old call
                    # passed the loop counter for both.
                    print(
                        log.format(now.month, now.day, now.hour, now.minute, now.second,
                                   epoch + 1, FLAGS.num_epochs, accuracy, batch_cost,
                                   lastbatch_err, time.time() - start_time, lr))
        finally:
            # Always stop and join the queue-runner threads.
            coord.request_stop()
            coord.join(threads)
def infer(root, mode='infer'):
    """Evaluate the model on each entry of ``root`` and append one summary line per entry.

    For every entry ``img_file`` in ``root``, images are collected with
    ``helper.load_img_path(os.path.join(root, img_file))``. Each image's label is the
    upper-cased text between its last ``_`` and the following ``.``. Images are read as
    grayscale, scaled to [0, 1] and resized to ``(FLAGS.image_width, FLAGS.image_height)``.
    A line ``name,mean_batch_accuracy,sample_num,seconds_per_sample`` is appended to
    ``./result.txt``, where ``name`` is the text of ``img_file`` before its first ``_``.
    Entries with fewer than ``FLAGS.batch_size`` images are skipped; previously they
    caused a division by zero.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            print('cannot restore')

        for img_file in os.listdir(root):
            start_time = time.time()
            img_path = os.path.join(root, img_file)
            print(img_path)
            # Same as img_path.split('/')[-1] on POSIX, but portable.
            file_name = img_file.split('_')[0]
            imgList = helper.load_img_path(img_path)
            total_steps = len(imgList) // FLAGS.batch_size
            # NOTE(review): the image count is reported times 3; confirm why.
            sample_num = len(imgList) * 3
            if total_steps == 0:
                print('skip {}: fewer than {} images'.format(img_path, FLAGS.batch_size))
                continue

            total_acc = 0
            for curr_step in range(total_steps):
                decoded_expression = []
                imgs_input = []
                imgs_label = []
                for img in imgList[curr_step * FLAGS.batch_size:(curr_step + 1) * FLAGS.batch_size]:
                    label = img.split('_')[-1].split('.')[0]
                    imgs_label.append(label.upper())
                    im = cv2.imread(img, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
                    im = cv2.resize(im, (FLAGS.image_width, FLAGS.image_height))
                    im = np.reshape(im, [FLAGS.image_height, FLAGS.image_width,
                                         FLAGS.image_channel])
                    imgs_input.append(im)
                imgs_input = np.asarray(imgs_input)
                # Every image gets the same fixed sequence length.
                seq_len_input = np.full(len(imgs_input), FLAGS.max_stepsize, dtype=np.int64)
                feed = {model.inputs: imgs_input, model.seq_len: seq_len_input}
                dense_decoded_code = sess.run(model.dense_decoded, feed)
                for item in dense_decoded_code:
                    # -1 entries are padding in the dense decoding and are skipped.
                    decoded_expression.append(
                        ''.join(utils.decode_maps[i] for i in item if i != -1))
                acc = utils.test_accuracy_calculation(imgs_label, decoded_expression, True)
                total_acc += acc

            mean_acc = total_acc / float(total_steps)
            print(mean_acc)
            print(file_name)
            print(sample_num)
            with open('./result.txt', 'a') as f:
                f.write(file_name + ',' + str(round(mean_acc, 2)) + ',' + str(sample_num) +
                        ',' + str(round((time.time() - start_time) / sample_num, 2)) + '\n')
def main(_):
    """Train the OCR model on ``train`` and periodically validate on ``val``.

    Every ``FLAGS.save_steps`` steps (when ``step % save_steps == 1``) a
    checkpoint is written. Every ``FLAGS.validation_steps`` steps the whole
    validation set is decoded and a progress line is printed.

    Args:
        _: unused (the ``tf.app.run`` argv argument).
    """
    model = cnn_lstm_otc_ocr.LSTMOCR('train')
    model.build_graph()

    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir='train')
    print('get image: ', type(train_feeder.image[0].shape), train_feeder.labels)

    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir='val')
    print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size  # 100000
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)  # example: 100000/100

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)  # example: 10000/100
    shuffle_idx_val = np.random.permutation(num_val_samples)

    with tf.device('/gpu:0'):
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.per_process_gpu_memory_fraction = 0.6
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # the global_step will restore as well
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))

            print('=============================begin training=============================')
            for cur_epoch in range(FLAGS.num_epochs):
                shuffle_idx = np.random.permutation(num_train_samples)
                train_cost = 0
                start_time = time.time()
                batch_time = time.time()

                # the training part
                for cur_batch in range(num_batches_per_epoch):
                    if (cur_batch + 1) % 100 == 0:
                        print('batch', cur_batch, ': time', time.time() - batch_time)
                    batch_time = time.time()

                    indexs = [shuffle_idx[i % num_train_samples]
                              for i in range(cur_batch * FLAGS.batch_size,
                                             (cur_batch + 1) * FLAGS.batch_size)]
                    batch_inputs, batch_seq_len, batch_labels = \
                        train_feeder.input_index_generate_batch(indexs)
                    feed = {model.inputs: batch_inputs,
                            model.labels: batch_labels,
                            model.seq_len: batch_seq_len}

                    summary_str, batch_cost, step, _ = sess.run(
                        [model.merged_summay, model.cost, model.global_step, model.train_op], feed)
                    # batch_cost is a per-batch mean; accumulate the batch total.
                    train_cost += batch_cost * FLAGS.batch_size
                    print('batch_cost is: ', batch_cost)

                    train_writer.add_summary(summary_str, step)

                    # save the checkpoint
                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.mkdir(FLAGS.checkpoint_dir)
                        # logging expects %-style args; the old '{0}' + format() call
                        # raised a formatting error inside logging and lost the message.
                        logger.info('save the checkpoint of %s', step)
                        saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                                   global_step=step)

                    # do validation
                    if step % FLAGS.validation_steps == 0:
                        acc_batch_total = 0
                        lastbatch_err = 0
                        lr = 0
                        for j in range(num_batches_per_epoch_val):
                            indexs_val = [shuffle_idx_val[i % num_val_samples]
                                          for i in range(j * FLAGS.batch_size,
                                                         (j + 1) * FLAGS.batch_size)]
                            val_inputs, val_seq_len, val_labels = \
                                val_feeder.input_index_generate_batch(indexs_val)
                            val_feed = {model.inputs: val_inputs,
                                        model.labels: val_labels,
                                        model.seq_len: val_seq_len}

                            # Three fetches for three targets (model.cost was missing).
                            dense_decoded, lastbatch_err, lr = sess.run(
                                [model.dense_decoded, model.cost, model.lrn_rate], val_feed)

                            # print the decode result
                            ori_labels = val_feeder.the_label(indexs_val)
                            acc = utils.accuracy_calculation(ori_labels, dense_decoded,
                                                             ignore_value=-1, isPrint=True)
                            acc_batch_total += acc

                        accuracy = (acc_batch_total * FLAGS.batch_size) / num_val_samples
                        avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)

                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                              "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                        print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                                         cur_epoch + 1, FLAGS.num_epochs, accuracy, avg_train_cost,
                                         lastbatch_err, time.time() - start_time, lr))
def infer(img_path, mode='infer'):
    """Decode the images under ``img_path`` and report prediction accuracy.

    The expected code of an image is its base file name split on ``-``, with
    the last field dropped, re-joined with ``-``. Expected codes go to
    ``./actual.txt`` and predictions to ``./result.txt``. Up to six correct
    and six wrong samples are printed, followed by the accuracy.

    Only full batches of ``FLAGS.batch_size`` images are decoded. Images of a
    trailing partial batch are excluded from the comparison, so predictions
    and ground truth always line up.
    """
    imgList = helper.load_img_path(img_path)
    print(imgList[:5])

    actual = []
    with open('./actual.txt', 'w') as f:
        for name in imgList:
            code = '-'.join(name.split('/')[-1].split('-')[:-1])
            actual.append(code)
            f.write(code + '\n')
    actual = np.asarray(actual)

    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    total_steps = len(imgList) // FLAGS.batch_size

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            print('cannot restore')

        decoded_expression = []
        for curr_step in range(total_steps):
            imgs_input = []
            for img in imgList[curr_step * FLAGS.batch_size:(curr_step + 1) * FLAGS.batch_size]:
                im = cv2.imread(img, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
                im = cv2.resize(im, (FLAGS.image_width, FLAGS.image_height))
                im = np.reshape(im, [FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel])
                imgs_input.append(im)
            imgs_input = np.asarray(imgs_input)

            # This model variant is fed images only (no sequence lengths).
            feed = {model.inputs: imgs_input}
            dense_decoded_code = sess.run(model.dense_decoded, feed)
            for item in dense_decoded_code:
                # Codes equal to -1 are skipped (presumably dense padding).
                decoded_expression.append(''.join(utils.decode_maps[i] for i in item if i != -1))

    with open('./result.txt', 'w') as f:
        for code in decoded_expression:
            f.write(code + '\n')

    # Compare only the images that were actually decoded.
    num_decoded = len(decoded_expression)
    if num_decoded == 0:
        print('no full batch of images to evaluate')
        return
    decoded_expression = np.asarray(decoded_expression)
    actual = actual[:num_decoded]
    imgList = np.asarray(imgList[:num_decoded])

    # print up to 6 correct and 6 incorrect predictions
    c = decoded_expression == actual
    w = decoded_expression != actual
    correct = imgList[c]
    wrong = imgList[w]
    print("correct predictions:")
    print(correct[:6])
    print("********")
    print("wrong predictions:")
    print(wrong[:6])
    print("********")
    wrong_pred = decoded_expression[w]
    wrong_actual = actual[w]
    for i in range(min(6, len(wrong_pred))):
        print("prediction = {}".format(wrong_pred[i]))
        print("actual = {}".format(wrong_actual[i]))
        print("********")
    acc = float(c.sum()) / num_decoded
    print("accuracy = {}".format(acc))
def infer(self):
    """Decode the validation set with the latest checkpoint and dump results.

    For each validation image this:

    * writes ``image,annotation,prediction`` to ``FLAGS.output_dir/result``;
    * writes the prediction to ``FLAGS.output_dir/<image stem>.txt``;
    * saves a debug image produced by ``self.draw_debug``. It receives a flag
      telling whether the prediction equals the annotation.
    """
    FLAGS.num_threads = 1
    gpus = list(filter(lambda x: x, FLAGS.gpus.split(',')))
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        train_feeder = utils.DataIterator(is_val=True, random_shuff=False)
        X, Y = train_feeder.distored_inputs()
        model = cnn_lstm_otc_ocr.LSTMOCR('infer', gpus)
        train_op, decodes = model.build_graph(X, Y)
        # Ceiling division; '//' keeps total_steps an int so range() works on Python 3.
        total_steps = (len(train_feeder.image) + FLAGS.batch_size - 1) // FLAGS.batch_size
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess, \
                open(os.path.join(FLAGS.output_dir, 'result'), 'w') as f:
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            print(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print('restore from ckpt{}'.format(ckpt))
            else:
                print('cannot restore')

            count = 0
            for curr_step in range(total_steps):
                decoded_expression = []
                start = time.time()
                dense_decoded_code = sess.run(decodes)
                print('time cost:', (time.time() - start))
                print("dense_decoded_code:", dense_decoded_code)
                for d in dense_decoded_code:
                    for item in d:
                        # Codes missing from decode_maps (e.g. padding) are dropped.
                        expression = ''.join(utils.decode_maps[i] for i in item
                                             if i in utils.decode_maps)
                        decoded_expression.append(expression)

                for code in decoded_expression:
                    # The last batch may be padded past the end of the image list.
                    if count >= len(train_feeder.image):
                        break
                    # NOTE(review): the .encode('utf-8') calls assume Python 2 byte
                    # strings; on Python 3 they write b'...' reprs / need a binary file.
                    f.write("%s,%s,%s\n" % (train_feeder.image[count],
                                            train_feeder.anno[count].encode('utf-8'),
                                            code.encode('utf-8')))
                    filename = os.path.splitext(
                        os.path.basename(train_feeder.image[count]))[0] + ".txt"
                    output_file = os.path.join(FLAGS.output_dir, filename)
                    with open(output_file, "w") as cur:
                        cur.write(code.encode('utf-8'))
                    print(code.encode('utf-8'))
                    try:
                        image_debug = cv2.imread(train_feeder.image[count])
                        image_debug = self.draw_debug(image_debug, code.encode('utf-8'),
                                                      code == train_feeder.anno[count])
                        image_path = os.path.join(FLAGS.output_dir,
                                                  os.path.basename(train_feeder.image[count]))
                        cv2.imwrite(image_path, image_debug)
                    except Exception as e:
                        # Debug rendering is best-effort; report and continue.
                        print(e)
                    count += 1
            coord.request_stop()
            coord.join(threads)
def train(train_dir = None, val_dir = None, mode = 'train'):
    """Train the OCR model with periodic checkpoints and validation.

    Args:
        train_dir: data directory for the training ``utils.DataIterator``.
        val_dir: data directory for the validation ``utils.DataIterator``.
        mode: mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    # Build the graph.
    model.build_graph()

    print('loading train data, please wait---------------------')
    # Training data iterator.
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('get image:', train_feeder.size)

    print('loading validation data, please wait---------------------')
    # Validation data iterator.
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('get image:', val_feeder.size)

    num_train_samples = train_feeder.size  # 100000
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)  # example: 100000 / 100

    num_val_samples = val_feeder.size  # 100000
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)  # example: 100000 / 100
    # Shuffle the validation samples once.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    with tf.device('/gpu:0'):
        # ConfigProto configures the session; allow_soft_placement lets TF
        # place ops on another device if the requested one does not exist.
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            # Saver used to save and restore the model parameters.
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            # Write the session's graph to the training log.
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
            # Restore previously saved parameters into this session if requested.
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))

            print('=============================begin training=============================')
            # Start training.
            for cur_epoch in range(FLAGS.num_epochs):
                shuffle_idx = np.random.permutation(num_train_samples)
                train_cost = 0
                start_time = time.time()
                batch_time = time.time()

                for cur_batch in range(num_batches_per_epoch):
                    if (cur_batch + 1) % 100 == 0:
                        print('batch', cur_batch, ':time', time.time() - batch_time)
                    batch_time = time.time()

                    # Sample indices of the current batch, taken in order from this
                    # epoch's random permutation of the training set.
                    indexs = [shuffle_idx[i % num_train_samples]
                              for i in range(cur_batch * FLAGS.batch_size,
                                             (cur_batch + 1) * FLAGS.batch_size)]
                    batch_inputs, batch_seq_len, batch_labels = \
                        train_feeder.input_index_generate_batch(indexs)
                    # Build the feed dict for the model.
                    feed = {model.inputs: batch_inputs,
                            model.labels: batch_labels,
                            model.seq_len: batch_seq_len}

                    # Run one training step and fetch the summary, loss and step.
                    summary_str, batch_cost, step, _ = sess.run(
                        [model.merged_summay, model.cost, model.global_step, model.train_op], feed)
                    # batch_cost is the mean over one batch; accumulate the batch total.
                    train_cost += batch_cost * FLAGS.batch_size
                    # Visualisation (TensorBoard).
                    train_writer.add_summary(summary_str, step)

                    # Save a checkpoint.
                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            os.mkdir(FLAGS.checkpoint_dir)
                        # logging uses %-style args (the '{0}' + format() form was broken).
                        logger.info('save the checkpoint of %s', step)
                        saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                                   global_step=step)

                    # Decode the validation set.
                    if step % FLAGS.validation_steps == 0:
                        acc_batch_total = 0
                        lastbatch_err = 0
                        lr = 0
                        # Iterate over the validation set batch by batch.
                        for j in range(num_batches_per_epoch_val):
                            indexs_val = [shuffle_idx_val[i % num_val_samples]
                                          for i in range(j * FLAGS.batch_size,
                                                         (j + 1) * FLAGS.batch_size)]
                            val_inputs, val_seq_len, val_labels = \
                                val_feeder.input_index_generate_batch(indexs_val)
                            val_feed = {model.inputs: val_inputs,
                                        model.labels: val_labels,
                                        model.seq_len: val_seq_len}

                            # Three fetches for three targets (model.cost was missing).
                            dense_decoded, lastbatch_err, lr = sess.run(
                                [model.dense_decoded, model.cost, model.lrn_rate], val_feed)

                            # Print the results decoded on the validation set.
                            ori_labels = val_feeder.the_label(indexs_val)
                            acc = utils.accuracy_calculation(ori_labels, dense_decoded,
                                                             ignore_value=-1, isPrint=True)
                            acc_batch_total += acc

                        accuracy = (acc_batch_total * FLAGS.batch_size) / num_val_samples
                        avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)

                        # datetime.now() (the original datetime.datetime.time() is an
                        # unbound method call and raised TypeError).
                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                              "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                        print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                                         cur_epoch + 1, FLAGS.num_epochs, accuracy, avg_train_cost,
                                         lastbatch_err, time.time() - start_time, lr))
def infer(img_path, mode='infer'):
    """Decode every full batch of images under ``img_path`` with the latest checkpoint.

    Images are read in colour (``cv2.IMREAD_COLOR``) and scaled to [0, 1].
    They are then reshaped (not resized) to
    ``[FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel]``, so each
    must already hold exactly that many values. Decoded strings are printed
    and appended to ``./result.txt``. Images in a trailing partial batch are
    not decoded.
    """
    imgList = helper.load_img_path(img_path)
    print(imgList[:5])

    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    total_steps = len(imgList) // FLAGS.batch_size

    # Hard-coded: expose only GPU 2 to this process, and let TF grow GPU memory on demand.
    os.environ["CUDA_VISIBLE_DEVICES"] = '2'
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from ckpt{}'.format(ckpt))
        else:
            print('cannot restore')

        decoded_expression = []
        for curr_step in range(total_steps):
            imgs_input = []
            for img in imgList[curr_step * FLAGS.batch_size:(curr_step + 1) * FLAGS.batch_size]:
                im = cv2.imread(img, cv2.IMREAD_COLOR).astype(np.float32) / 255.
                im = np.reshape(im, [FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel])
                imgs_input.append(im)
            imgs_input = np.asarray(imgs_input)

            # Only images are fed; the per-image sequence lengths the original
            # built were never used, so that dead computation is gone.
            feed = {model.inputs: imgs_input}
            dense_decoded_code = sess.run(model.dense_decoded, feed)
            for item in dense_decoded_code:
                # Codes equal to -1 are skipped (presumably dense padding).
                decoded_expression.append(''.join(utils.decode_maps[i] for i in item if i != -1))

        with open('./result.txt', 'a') as f:
            for code in decoded_expression:
                print(code)
                f.write(code + '\n')
def train(train_dir=None, mode='train'):
    """Train on ``train_dir`` with an 80/20 train/validation split.

    A sample is a ``.txt`` label file in ``train_dir`` that has a matching
    ``.jpg`` next to it. The first 80 % of the listing (in ``os.listdir``
    order) is used for training and the rest for validation.

    Args:
        train_dir: directory containing ``.jpg`` images and ``.txt`` labels.
        mode: mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()

    # os.path.join (the original 'os.pat.join' raised AttributeError).
    label_files = [os.path.join(train_dir, e) for e in os.listdir(train_dir)
                   if e.endswith('.txt')
                   and os.path.exists(os.path.join(train_dir, e.replace('.txt', '.jpg')))]
    train_num = int(len(label_files) * 0.8)
    test_num = len(label_files) - train_num
    print('total num', len(label_files), 'train num', train_num, 'test num', test_num)
    train_imgs = label_files[0:train_num]
    test_imgs = label_files[train_num:]

    print('loading train data')
    train_feeder = utils.DataIterator(data_dir=train_imgs)
    print('loading validation data')
    val_feeder = utils.DataIterator(data_dir=test_imgs)

    num_batches_per_epoch_train = int(train_num / FLAGS.batch_size)  # example: 100000/100
    num_batches_per_epoch_val = int(test_num / FLAGS.batch_size)  # example: 10000/100

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step will restore as well
                saver.restore(sess, ckpt)
                print('restore from checkpoint{0}'.format(ckpt))

        print('=============================begin training=============================')
        for cur_epoch in range(FLAGS.num_epochs):
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()
            # The training data is shuffled once, before the first epoch only.
            if cur_epoch == 0:
                random.shuffle(train_feeder.train_data)

            # the training part
            for cur_batch in range(num_batches_per_epoch_train):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time', time.time() - batch_time)
                batch_time = time.time()
                batch_inputs, result_img_length, batch_labels = \
                    train_feeder.get_batchsize_data(cur_batch)
                feed = {model.inputs: batch_inputs,
                        model.labels: batch_labels,
                        model.seq_len: result_img_length}

                summary_str, batch_cost, step, _ = sess.run(
                    [model.merged_summay, model.cost, model.global_step, model.train_op], feed)
                # batch_cost is a per-batch mean; accumulate the batch total.
                train_cost += batch_cost * FLAGS.batch_size
                train_writer.add_summary(summary_str, step)

                # save the checkpoint
                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # logging uses %-style args (the '{0}' + format() form was broken).
                    logger.info('save checkpoint at step %s', step)
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                # do validation
                if step % FLAGS.validation_steps == 0:
                    acc_batch_total = 0
                    lastbatch_err = 0
                    lr = 0
                    for val_j in range(num_batches_per_epoch_val):
                        result_img_val, seq_len_input_val, batch_label_val = \
                            val_feeder.get_batchsize_data(val_j)
                        val_feed = {model.inputs: result_img_val,
                                    model.labels: batch_label_val,
                                    model.seq_len: seq_len_input_val}
                        dense_decoded, lastbatch_err, lr = sess.run(
                            [model.dense_decoded, model.cost, model.lrn_rate], val_feed)

                        # decode predictions to strings before comparing with labels
                        val_pre_list = [utils.label2text(decode_code)
                                        for decode_code in dense_decoded]
                        ori_labels = val_feeder.get_val_label(val_j)
                        acc = utils.accuracy_calculation(ori_labels, val_pre_list,
                                                         ignore_value=-1, isPrint=True)
                        acc_batch_total += acc

                    # A validation split smaller than one batch has no batches;
                    # report 0 instead of dividing by zero.
                    if num_batches_per_epoch_val:
                        accuracy = acc_batch_total / num_batches_per_epoch_val
                    else:
                        accuracy = 0.0
                    avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                                     cur_epoch + 1, FLAGS.num_epochs, accuracy, avg_train_cost,
                                     lastbatch_err, time.time() - start_time, lr))
def train(train_dir, batch_size=64, image_height=60, image_width=180, image_channel=1,
          checkpoint_dir="../checkpoint/", num_epochs=100):
    """Train the OCR model, checkpointing and validating after every epoch.

    The ``DataIterator`` for ``train_dir`` is built with ``begin=0, end=800``
    for training and ``begin=800, end=1000`` for validation (hard-coded).

    Args:
        train_dir: data directory passed to both ``DataIterator`` instances.
        batch_size: samples per batch.
        image_height, image_width, image_channel: model input geometry.
        checkpoint_dir: where checkpoints and the ``train`` summary dir go.
        num_epochs: number of training epochs.
    """
    # Load the data.
    train_data = DataIterator(data_dir=train_dir, batch_size=batch_size, begin=0, end=800)
    valid_data = DataIterator(data_dir=train_dir, batch_size=batch_size, begin=800, end=1000)
    print('train data batch number: {}'.format(train_data.number_batch))
    print('valid data batch number: {}'.format(valid_data.number_batch))

    # Build the model.
    model = cnn_lstm_otc_ocr.LSTMOCR(NumClasses, batch_size, image_height=image_height,
                                     image_width=image_width, image_channel=image_channel,
                                     is_train=True)
    model.build_graph()

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True),
                            allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        # Initialise variables.
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        # os.path.join also works when checkpoint_dir has no trailing slash.
        train_writer = tf.summary.FileWriter(os.path.join(checkpoint_dir, 'train'), sess.graph)

        # Restore the latest checkpoint, if any.
        ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print('restore from checkpoint{0}'.format(ckpt))
        else:
            print('no checkpoint to restore')

        print('=======begin training=======')
        for cur_epoch in range(num_epochs):
            start_time = time.time()
            batch_time = time.time()

            # Training.
            train_cost = 0
            for cur_batch in range(train_data.number_batch):
                if cur_batch % 100 == 0:
                    print('batch {}/{} time: {}'.format(cur_batch, train_data.number_batch,
                                                        time.time() - batch_time))
                batch_time = time.time()
                batch_inputs, _, sparse_labels = train_data.next_train_batch()
                summary, cost, step, _ = sess.run(
                    [model.merged_summay, model.cost, model.global_step, model.train_op],
                    {model.inputs: batch_inputs, model.labels: sparse_labels})
                train_cost += cost
                train_writer.add_summary(summary, step)
            # max(..., 1) avoids dividing by zero when there are no training batches.
            print("loss is {}".format(train_cost / max(train_data.number_batch, 1)))

            # Save a checkpoint after every epoch (tagged with the epoch index).
            if not os.path.isdir(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver.save(sess, os.path.join(checkpoint_dir, 'ocr-model'), global_step=cur_epoch)

            # Validate after every epoch.
            lr = 0
            acc_batch_total = 0
            for j in range(valid_data.number_batch):
                val_inputs, _, sparse_labels, ori_labels = valid_data.next_test_batch(j)
                dense_decoded, lr = sess.run(
                    [model.dense_decoded, model.lrn_rate],
                    {model.inputs: val_inputs, model.labels: sparse_labels})
                acc_batch_total += accuracy_calculation(ori_labels, dense_decoded, -1)
            accuracy = acc_batch_total / max(valid_data.number_batch, 1)

            now = datetime.datetime.now()
            log = "{}/{} {}:{}:{} Epoch {}/{}, accuracy = {:.3f}, time = {:.3f},lr={:.8f}"
            print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                             cur_epoch + 1, num_epochs, accuracy, time.time() - start_time, lr))
def infer(mode='infer'):
    """Decode the validation set with averaged weights and write one ``.txt`` per image.

    The weights are restored via ``model.variable_averages.variables_to_restore()``
    (presumably a ``tf.train.ExponentialMovingAverage`` -- TODO confirm).
    Each decoded sequence is cut at the first ``<EOS>`` token. Codes missing
    from ``utils.decode_maps`` are dropped. The image path, annotation and
    prediction are printed, and the prediction is written to
    ``FLAGS.output_dir/<image stem>.txt``.

    Args:
        mode: mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    FLAGS.num_threads = 1
    gpus = list(filter(lambda x: x, FLAGS.gpus.split(',')))
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        train_feeder = utils.DataIterator(is_val=True, random_shuff=False)
        X, Y_in, Y_out, length = train_feeder.distored_inputs()
        model = cnn_lstm_otc_ocr.LSTMOCR(mode, gpus)
        train_op, decodes = model.build_graph(X, Y_in, Y_out, length)
        # Ceiling division: the last batch may be only partially filled.
        total_steps = (len(train_feeder.image) + FLAGS.batch_size - 1) // FLAGS.batch_size
        config = tf.ConfigProto(allow_soft_placement=True)
        result_dir = os.path.dirname(FLAGS.infer_file)
        # NOTE(review): result_digit_v1.txt is opened (and truncated) but nothing
        # is written to it any more.
        with tf.Session(config=config) as sess, \
                open(os.path.join(result_dir, 'result_digit_v1.txt'), 'w') as f:
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            variables_to_restore = model.variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            print("search from ", FLAGS.checkpoint_dir)
            print(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print('restore from ckpt{}'.format(ckpt))
            else:
                print('cannot restore')

            if not os.path.exists(FLAGS.output_dir):
                os.makedirs(FLAGS.output_dir)

            eos = utils.TOKEN["<EOS>"]
            count = 0
            for curr_step in range(total_steps):
                decoded_expression = []
                dense_decoded_code = sess.run(decodes)
                for batch in dense_decoded_code:
                    for sequence in batch:
                        expression = ''
                        for code in sequence:
                            if code == eos:
                                break
                            if code in utils.decode_maps:
                                expression += utils.decode_maps[code]
                        decoded_expression.append(expression)

                for expression in decoded_expression:
                    # The padded last batch can run past the end of the image list.
                    if count >= len(train_feeder.image):
                        break
                    print(train_feeder.image[count])
                    print(train_feeder.anno[count])
                    print(expression)
                    print('')
                    filename = os.path.splitext(
                        os.path.basename(train_feeder.image[count]))[0] + ".txt"
                    output_file = os.path.join(FLAGS.output_dir, filename)
                    with open(output_file, "w") as cur:
                        cur.write(expression)
                    count += 1

            coord.request_stop()
            coord.join(threads)
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the OCR model from a TFRecord file using TF1 input queues.

    ``train_dir`` and ``val_dir`` are accepted for interface compatibility but
    are unused; the TFRecord path below is hard-coded. Training stops after
    ``FLAGS.num_epochs`` epochs, or earlier if the input queue (limited to
    ``EPOCHS`` passes) is exhausted. The queue-runner threads are always stopped.
    """
    # load dataset
    tfrecords_filename = '/home/youth/DL/CNN_LSTM_CTC_Tensorflow/tfrecords/train.tfrecords'
    filename_queue = tf.train.string_input_producer([tfrecords_filename],
                                                    num_epochs=EPOCHS, shuffle=True)
    images, names, labels = gen_tfrecord.read_and_decode(filename_queue)
    # print() call (the Python 2 print statement is a SyntaxError on Python 3).
    print(images, names, labels)
    # NOTE(review): np.shape() on a graph tensor and sparse_tuple_from_label()
    # on a tensor look suspicious -- confirm they work with read_and_decode's outputs.
    shape = np.shape(images)
    seq_len = np.array([12 for _ in range(shape[0])], dtype=np.int64)
    labels = utils.sparse_tuple_from_label(labels)
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph(images, labels, seq_len)

    num_train_samples = gen_tfrecord.get_size(tfrecords_filename)  # 100000
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)  # example: 100000/100

    with tf.device('/cpu:0'):
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            # Local variables are initialised too; string_input_producer with
            # num_epochs keeps its epoch counter in one.
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # the global_step will restore as well
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('=============================begin training=============================')
            try:
                for cur_epoch in range(FLAGS.num_epochs):
                    train_cost = 0
                    start_time = time.time()
                    batch_time = time.time()

                    # the training part
                    for cur_batch in range(num_batches_per_epoch):
                        if (cur_batch + 1) % 100 == 0:
                            print('batch', cur_batch, ': time', time.time() - batch_time)
                        batch_time = time.time()
                        # Inputs come from the queues, so no feed dict is needed.
                        summary_str, batch_cost, step, _ = sess.run(
                            [model.merged_summay, model.cost, model.global_step, model.train_op])
                        # calculate the cost
                        train_cost += batch_cost * FLAGS.batch_size
                        train_writer.add_summary(summary_str, step)

                        # save the checkpoint
                        if step % FLAGS.save_steps == 1:
                            if not os.path.isdir(FLAGS.checkpoint_dir):
                                os.mkdir(FLAGS.checkpoint_dir)
                            # logging uses %-style args (the '{0}' + format() form was broken).
                            logger.info('save the checkpoint of %s', step)
                            saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                                       global_step=step)

                    avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)
                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{},avg_train_cost = {:.3f}, time = {:.3f}"
                    print(log.format(now.month, now.day, now.hour, now.minute, now.second,
                                     cur_epoch + 1, FLAGS.num_epochs, avg_train_cost,
                                     time.time() - start_time))
            except tf.errors.OutOfRangeError:
                # Raised once the input queue has produced its EPOCHS passes.
                print('input queue exhausted')
            finally:
                coord.request_stop()
                coord.join(threads)
def train(train_dir=None, val_dir=None, mode='train'):
    """Train the OCR model, checkpointing and validating periodically.

    Validation progress lines are printed and appended to ``test_acc.txt``.
    The graph is finalized before training, so accidental op creation inside
    the loop raises immediately.

    Args:
        train_dir: data directory for the training ``utils.DataIterator``.
        val_dir: data directory for the validation ``utils.DataIterator``.
        mode: mode string passed to ``cnn_lstm_otc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    model.build_graph()
    print(FLAGS.image_channel)

    print('loading train data')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('size: ', train_feeder.size)

    print('loading validation data')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('size: {}\n'.format(val_feeder.size))

    num_train_samples = train_feeder.size  # 100000
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)  # example: 100000/100

    num_val_samples = val_feeder.size
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)  # example: 10000/100
    shuffle_idx_val = np.random.permutation(num_val_samples)

    # Hard-coded: expose only GPU 2 to this process, and let TF grow GPU memory on demand.
    os.environ["CUDA_VISIBLE_DEVICES"] = '2'
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # the global_step will restore as well
                saver.restore(sess, ckpt)
                print('restore from checkpoint{0}'.format(ckpt))

        print('=============================begin training=============================')
        sess.graph.finalize()
        for cur_epoch in range(FLAGS.num_epochs):
            shuffle_idx = np.random.permutation(num_train_samples)
            train_cost = 0
            start_time = time.time()
            batch_time = time.time()

            # the training part
            for cur_batch in range(num_batches_per_epoch):
                if (cur_batch + 1) % 100 == 0:
                    print('batch', cur_batch, ': time', time.time() - batch_time)
                batch_time = time.time()
                indexs = [shuffle_idx[i % num_train_samples]
                          for i in range(cur_batch * FLAGS.batch_size,
                                         (cur_batch + 1) * FLAGS.batch_size)]
                # This model variant is fed without sequence lengths.
                batch_inputs, _, batch_labels = train_feeder.input_index_generate_batch(indexs)
                feed = {model.inputs: batch_inputs,
                        model.labels: batch_labels}

                summary_str, batch_cost, step, _ = sess.run(
                    [model.merged_summay, model.cost, model.global_step, model.train_op], feed)
                # batch_cost is a per-batch mean; accumulate the batch total.
                train_cost += batch_cost * FLAGS.batch_size
                train_writer.add_summary(summary_str, step)

                # save the checkpoint
                if step % FLAGS.save_steps == 1:
                    if not os.path.isdir(FLAGS.checkpoint_dir):
                        os.mkdir(FLAGS.checkpoint_dir)
                    # logging uses %-style args (the '{0}' + format() form was broken).
                    logger.info('save checkpoint at step %s', step)
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'ocr-model'),
                               global_step=step)

                # do validation
                if step % FLAGS.validation_steps == 0:
                    acc_batch_total = 0
                    lastbatch_err = 0
                    lr = 0
                    for j in range(num_batches_per_epoch_val):
                        indexs_val = [shuffle_idx_val[i % num_val_samples]
                                      for i in range(j * FLAGS.batch_size,
                                                     (j + 1) * FLAGS.batch_size)]
                        val_inputs, _, val_labels = \
                            val_feeder.input_index_generate_batch(indexs_val)
                        val_feed = {model.inputs: val_inputs,
                                    model.labels: val_labels}

                        dense_decoded, lastbatch_err, lr = sess.run(
                            [model.dense_decoded, model.cost, model.lrn_rate], val_feed)

                        # print the decode result
                        ori_labels = val_feeder.the_label(indexs_val)
                        acc = utils.accuracy_calculation(ori_labels, dense_decoded,
                                                         ignore_value=-1, isPrint=True)
                        acc_batch_total += acc

                    accuracy = (acc_batch_total * FLAGS.batch_size) / num_val_samples
                    avg_train_cost = train_cost / ((cur_batch + 1) * FLAGS.batch_size)

                    now = datetime.datetime.now()
                    log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                          "accuracy = {:.3f},avg_train_cost = {:.3f}, " \
                          "lastbatch_err = {:.3f}, time = {:.3f},lr={:.8f}"
                    # Format once so the file and console report identical values.
                    log_line = log.format(now.month, now.day, now.hour, now.minute, now.second,
                                          cur_epoch + 1, FLAGS.num_epochs, accuracy,
                                          avg_train_cost, lastbatch_err,
                                          time.time() - start_time, lr)
                    with open('test_acc.txt', 'a') as f:
                        f.write(log_line + "\n")
                    print(log_line)
def _evaluate_on_validation(sess, model, val_feeder, shuffle_idx_val,
                            num_batches_per_epoch_val, num_val_samples):
    """Run one decoding pass over the validation set and score it.

    Iterates ``num_batches_per_epoch_val`` batches of ``FLAGS.batch_size``
    samples in the order given by ``shuffle_idx_val``, fetches
    ``model.dense_decoded`` and ``model.lrn_rate`` for each batch, and scores
    it with ``utils.accuracy_calculation`` and
    ``utils.accuracy_calculation_single`` (both called with ``isPrint=True``).

    Returns:
        ``(accuracy, accuracy_per, acc_batch_total, lr)``:
        ``accuracy`` / ``accuracy_per`` are the summed per-batch scores scaled
        by ``FLAGS.batch_size / num_val_samples`` (0 when there are no
        validation samples); ``acc_batch_total`` is the raw sum of the
        ``accuracy_calculation`` scores; ``lr`` is the learning rate fetched
        with the last batch (0 if no batch was run).
    """
    acc_batch_total = 0
    acc_per_batch_total = 0
    lr = 0
    for j in range(num_batches_per_epoch_val):
        indexs_val = [
            shuffle_idx_val[i % num_val_samples]
            for i in range(j * FLAGS.batch_size, (j + 1) * FLAGS.batch_size)
        ]
        val_inputs, val_seq_len, val_labels = \
            val_feeder.input_index_generate_batch(indexs_val)
        val_feed = {
            model.inputs: val_inputs,
            model.labels: val_labels,
            model.seq_len: val_seq_len
        }
        dense_decoded, lr = \
            sess.run([model.dense_decoded, model.lrn_rate], val_feed)
        # Score the decoded batch against its ground-truth labels
        # (the helpers also print the decode result).
        ori_labels = val_feeder.the_label(indexs_val)
        acc = utils.accuracy_calculation(ori_labels, dense_decoded,
                                         ignore_value=-1, isPrint=True)
        acc_per = utils.accuracy_calculation_single(
            ori_labels, dense_decoded, ignore_value=-1, isPrint=True)
        acc_per_batch_total += acc_per
        acc_batch_total += acc

    if num_val_samples == 0:
        # Empty validation set: nothing was scored; avoid dividing by zero.
        return 0, 0, acc_batch_total, lr
    accuracy_per = (acc_per_batch_total * FLAGS.batch_size) / num_val_samples
    accuracy = (acc_batch_total * FLAGS.batch_size) / num_val_samples
    return accuracy, accuracy_per, acc_batch_total, lr


def train(train_dir=None, val_dir=None, mode='train'):
    """Train the ``cnn_lstm_otc_ocr.LSTMOCR`` model.

    Args:
        train_dir: passed as ``data_dir`` to ``utils.DataIterator`` for the
            training set.
        val_dir: passed as ``data_dir`` to ``utils.DataIterator`` for the
            validation set.
        mode: passed to ``cnn_lstm_otc_ocr.LSTMOCR``.

    Behaviour (all driven by ``FLAGS``):
        * exits the process via ``sys.exit()`` if ``FLAGS.model != 'lstm'``;
        * optionally restores the latest checkpoint when ``FLAGS.restore``;
        * writes a training summary every step under
          ``FLAGS.log_dir + '/train'``;
        * saves a checkpoint under ``FLAGS.checkpoint_dir`` whenever
          ``step % FLAGS.save_steps == 1``;
        * runs a full validation pass and prints a log line whenever
          ``step % FLAGS.validation_steps == 0``.
    """
    if FLAGS.model == 'lstm':
        model = cnn_lstm_otc_ocr.LSTMOCR(mode)
    else:
        print("no such model")
        sys.exit()
    # Build the computation graph.
    model.build_graph()

    print('loading train data, please wait---------------------')
    train_feeder = utils.DataIterator(data_dir=train_dir, num=4000000)
    print('get image: ', train_feeder.size)
    print('loading validation data, please wait---------------------')
    val_feeder = utils.DataIterator(data_dir=val_dir, num=40000)
    print('get image: ', val_feeder.size)

    num_train_samples = train_feeder.size
    # Number of full batches per training epoch (a trailing partial batch is dropped).
    num_batches_per_epoch = int(num_train_samples / FLAGS.batch_size)
    num_val_samples = val_feeder.size
    # Number of full batches in one validation pass.
    num_batches_per_epoch_val = int(num_val_samples / FLAGS.batch_size)
    # The validation order is shuffled once and reused by every validation pass.
    shuffle_idx_val = np.random.permutation(num_val_samples)

    # NOTE(review): the graph was built above, outside this scope, so this
    # device scope does not re-pin those already-created ops; confirm intent.
    with tf.device('/cpu:0'):
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            # Initialise all global variables.
            sess.run(tf.global_variables_initializer())
            # Saver used to store (and optionally restore) checkpoints.
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                                 sess.graph)
            # Optionally resume from the latest (pre-trained) checkpoint.
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # the global_step will restore as well
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))
                else:
                    print("No checkpoint")

            print(
                '=============================begin training============================='
            )
            # History of validation results; collected here but not read
            # again within this function.
            accuracy_res = []
            accuracy_per_res = []
            epoch_res = []
            # Best accuracy_per seen so far and the (0-based) epoch it occurred in.
            tmp_max = 0
            tmp_epoch = 0
            for cur_epoch in range(FLAGS.num_epochs):
                shuffle_idx = np.random.permutation(num_train_samples)
                train_cost = 0
                start_time = time.time()
                batch_time = time.time()

                for cur_batch in range(num_batches_per_epoch):
                    if (cur_batch + 1) % 100 == 0:
                        print('batch', cur_batch, ': time',
                              time.time() - batch_time)
                        batch_time = time.time()
                    # Sample indices that make up this batch.
                    indexs = [
                        shuffle_idx[i % num_train_samples]
                        for i in range(cur_batch * FLAGS.batch_size,
                                       (cur_batch + 1) * FLAGS.batch_size)
                    ]
                    batch_inputs, batch_seq_len, batch_labels = \
                        train_feeder.input_index_generate_batch(indexs)
                    feed = {
                        model.inputs: batch_inputs,
                        model.labels: batch_labels,
                        model.seq_len: batch_seq_len
                    }
                    summary_str, batch_cost, step, _ = \
                        sess.run([model.merged_summay, model.cost,
                                  model.global_step, model.train_op], feed)
                    # Scale by batch size so avg_train_cost below is a
                    # per-sample average (assumes model.cost is a batch
                    # mean; TODO confirm).
                    train_cost += batch_cost * FLAGS.batch_size
                    train_writer.add_summary(summary_str, step)

                    # Save a checkpoint.
                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            # makedirs also creates missing parent directories.
                            os.makedirs(FLAGS.checkpoint_dir)
                        # Lazy %-formatting; the original passed the step as a
                        # stray argument to a message with no placeholder.
                        logger.info('save the checkpoint of %s', step)
                        saver.save(sess,
                                   os.path.join(FLAGS.checkpoint_dir,
                                                'ocr-model'),
                                   global_step=step)

                    # Run validation.
                    if step % FLAGS.validation_steps == 0:
                        accuracy, accuracy_per, acc_batch_total, lr = \
                            _evaluate_on_validation(
                                sess, model, val_feeder, shuffle_idx_val,
                                num_batches_per_epoch_val, num_val_samples)
                        accuracy_per_res.append(accuracy_per)
                        accuracy_res.append(accuracy)
                        epoch_res.append(cur_epoch)
                        # Compare and store the SAME metric; previously
                        # accuracy was stored after comparing accuracy_per.
                        if accuracy_per > tmp_max:
                            tmp_max = accuracy_per
                            tmp_epoch = cur_epoch

                        avg_train_cost = train_cost / (
                            (cur_batch + 1) * FLAGS.batch_size)
                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "max_accuracy = {:.3f},max_Epoch {},accuracy = {:.3f},acc_batch_total = {:.3f},avg_train_cost = {:.3f}, " \
                              " time = {:.3f},lr={:.8f}"
                        print(
                            log.format(now.month, now.day, now.hour,
                                       now.minute, now.second, cur_epoch + 1,
                                       FLAGS.num_epochs, tmp_max, tmp_epoch,
                                       accuracy_per, acc_batch_total,
                                       avg_train_cost,
                                       time.time() - start_time, lr))