def eval_model(self):
    """Evaluate the OCR model on ``self.split_name`` and print its accuracy.

    Builds an ``LSTMOCR('eval')`` graph, restores weights from
    ``self.checkpoint_path`` (either a checkpoint file, or a directory whose
    latest checkpoint is used), runs ``ceil(num_samples / batch_size)``
    batches, prints every ground-truth/prediction pair and finally prints
    the fraction of exact string matches.

    Raises:
        ValueError: if ``self.checkpoint_path`` is a directory that holds no
            checkpoint.
    """
    model = cnn_lstm_ctc_ocr.LSTMOCR('eval')
    model.build_graph()

    val_feeder, num_samples = self.input_batch_generator(
        self.split_name,
        batch_size=FLAGS.batch_size,
        data_dir=FLAGS.data_dir)
    num_batches_per_epoch = int(
        math.ceil(num_samples / float(FLAGS.batch_size)))

    def decode_rows(rows):
        # Map every row of class indices to a string; -1 entries are
        # skipped (they contribute the empty string).
        return [
            ''.join(utils.decode_maps[c] if c != -1 else '' for c in row)
            for row in rows
        ]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)

        if tf.gfile.IsDirectory(self.checkpoint_path):
            checkpoint_file = tf.train.latest_checkpoint(self.checkpoint_path)
        else:
            checkpoint_file = self.checkpoint_path
        if checkpoint_file is None:
            # latest_checkpoint() returns None for a directory without a
            # checkpoint; fail here with a message that names the directory.
            raise ValueError(
                'No checkpoint found in {}'.format(self.checkpoint_path))

        print('Evaluating checkpoint_path={}, split={}, num_samples={}'.format(
            checkpoint_file, self.split_name, num_samples))
        saver.restore(sess, checkpoint_file)

        correct = 0
        total = 0
        for i in range(num_batches_per_epoch):
            inputs, labels, _ = next(val_feeder)
            feed = {model.inputs: inputs, model.labels: labels}

            start = time.time()
            _, predictions = sess.run(
                [model.names_to_updates, model.dense_decoded], feed)

            gt = decode_rows(self.label_from_sparse_tuple(labels))
            pred = decode_rows(predictions)
            for expected, actual in zip(gt, pred):
                print("%s : %s" % (expected, actual))
                if expected == actual:
                    correct += 1
                total += 1

            elapsed = time.time() - start
            print('{}/{}, {:.5f} seconds.'.format(
                i, num_batches_per_epoch, elapsed))

        # NOTE(review): the source indentation was lost; the overall accuracy
        # is reported once, after all batches -- confirm against the original.
        # Guard against an empty split so we report 0 instead of dividing by 0.
        accuracy = correct / float(total) if total else 0.0
        print("accuracy: %f" % accuracy)
def infer_model(self, img):
    """Run the OCR model on one image and return its decoded string.

    The image is cast to float32 and divided by 255, resized to
    ``FLAGS.image_width`` x ``FLAGS.image_height`` and reshaped to
    ``[image_height, image_width, image_channel]``. A fresh
    ``LSTMOCR('eval')`` graph is built and restored from
    ``self.checkpoint_path`` (a checkpoint file, or a directory whose latest
    checkpoint is used) on every call.

    Args:
        img: image array; must support ``astype`` and be accepted by
            ``cv2.resize``.

    Returns:
        The decoded string of the last row of the model's dense output.
    """
    # Pre-process the image into the layout fed to the network.
    scaled = img.astype(np.float32) / 255.
    resized = cv2.resize(scaled, (FLAGS.image_width, FLAGS.image_height))
    target_shape = [
        FLAGS.image_height, FLAGS.image_width, FLAGS.image_channel
    ]
    sample = np.reshape(resized, target_shape)

    model = cnn_lstm_ctc_ocr.LSTMOCR('eval')
    model.build_graph()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)

        # A directory means "use its most recent checkpoint".
        if tf.gfile.IsDirectory(self.checkpoint_path):
            checkpoint_file = tf.train.latest_checkpoint(self.checkpoint_path)
        else:
            checkpoint_file = self.checkpoint_path
        print('Evaluating checkpoint_path={}'.format(checkpoint_file))
        saver.restore(sess, checkpoint_file)

        # Single-image batch.
        predictions = sess.run(model.dense_decoded, {model.inputs: [sample]})

        # -1 entries in the dense output are dropped while decoding.
        decoded = [
            ''.join(utils.decode_maps[c] if c != -1 else '' for c in row)
            for row in predictions
        ]

    return decoded[-1]
def train(mode='train'):
    """Train the OCR model, writing summaries, checkpoints and validation logs.

    Builds ``LSTMOCR(mode)``, loads the 'train' and 'val' batch generators
    from ``FLAGS.data_dir`` and runs ``FLAGS.num_epochs`` epochs. After every
    training step the summary is written to ``FLAGS.log_dir + '/train'``; a
    checkpoint is saved to ``FLAGS.checkpoint_dir`` whenever
    ``step % FLAGS.save_steps == 1``; one validation batch is evaluated and
    logged whenever ``step % FLAGS.validation_steps == 0``.

    Args:
        mode: mode string forwarded to ``cnn_lstm_ctc_ocr.LSTMOCR``.
    """
    model = cnn_lstm_ctc_ocr.LSTMOCR(mode)
    model.build_graph()

    print('loading train data, please wait---------------------')
    train_feeder, num_train_samples = data_prep.input_batch_generator(
        'train', batch_size=FLAGS.batch_size, data_dir=FLAGS.data_dir)
    print('get image: ', num_train_samples)

    print('loading validation data, please wait---------------------')
    val_feeder, num_val_samples = data_prep.input_batch_generator(
        'val', batch_size=FLAGS.batch_size * 2, data_dir=FLAGS.data_dir)
    print('get image: ', num_val_samples)

    num_batches_per_epoch = int(
        math.ceil(num_train_samples / float(FLAGS.batch_size)))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)
        # The writer is closed in the finally clause so pending summaries
        # are flushed even if training is interrupted.
        try:
            if FLAGS.restore:
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    # the global_step will be restored as well
                    saver.restore(sess, ckpt)
                    print('restore from the checkpoint{0}'.format(ckpt))
                else:
                    # Make the fallback visible instead of silently
                    # training from scratch.
                    print('no checkpoint found in {0}, training from '
                          'scratch'.format(FLAGS.checkpoint_dir))

            print(
                '=============================begin training============================='
            )
            for cur_epoch in range(FLAGS.num_epochs):
                start_time = time.time()
                batch_time = time.time()

                # the training part
                for cur_batch in range(num_batches_per_epoch):
                    if (cur_batch + 1) % 100 == 0:
                        print('batch', cur_batch, ': time',
                              time.time() - batch_time)
                    # NOTE(review): source indentation was lost; the timer is
                    # reset on every batch here -- confirm against the
                    # original.
                    batch_time = time.time()

                    batch_inputs, batch_labels, _ = next(train_feeder)
                    feed = {
                        model.inputs: batch_inputs,
                        model.labels: batch_labels
                    }

                    summary_str, batch_cost, step, _ = sess.run(
                        [model.merged_summay, model.cost,
                         model.global_step, model.train_op],
                        feed)
                    train_writer.add_summary(summary_str, step)

                    # save the checkpoint
                    if step % FLAGS.save_steps == 1:
                        if not os.path.isdir(FLAGS.checkpoint_dir):
                            # makedirs also creates missing parents.
                            os.makedirs(FLAGS.checkpoint_dir)
                        # The original passed format(step) as a separate
                        # argument, so the message was never formatted.
                        logger.info(
                            'save the checkpoint of{0}'.format(step))
                        saver.save(
                            sess,
                            os.path.join(FLAGS.checkpoint_dir, 'ocr-CCR'),
                            global_step=step)

                    # do validation
                    if step % FLAGS.validation_steps == 0:
                        val_inputs, val_labels, ori_labels = next(val_feeder)
                        val_feed = {
                            model.inputs: val_inputs,
                            model.labels: val_labels
                        }
                        dense_decoded, lr = sess.run(
                            [model.dense_decoded, model.lrn_rate], val_feed)

                        # print the decode result
                        accuracy = utils.accuracy_calculation(
                            ori_labels, dense_decoded,
                            ignore_value=-1, isPrint=True)

                        now = datetime.datetime.now()
                        log = "{}/{} {}:{}:{} Epoch {}/{}, " \
                              "accuracy = {:.5f},train_cost = {:.5f}, " \
                              ", time = {:.3f},lr={:.8f}"
                        print(
                            log.format(now.month, now.day, now.hour,
                                       now.minute, now.second,
                                       cur_epoch + 1, FLAGS.num_epochs,
                                       accuracy, batch_cost,
                                       time.time() - start_time, lr))
        finally:
            train_writer.close()
if __name__ == "__main__":
    # Script entry point: builds the OCR graph on top of a string-image
    # placeholder and restores its weights from configure.ccr_checkpoint_path.
    with tf.get_default_graph().as_default():
        # send image by base64
        # 1-D string placeholder of unspecified length; each element is passed
        # through image_decode to produce a float32 tensor.
        # NOTE(review): presumably the strings are base64-encoded images --
        # confirm against image_decode.
        image_string_list = tf.placeholder(tf.string, shape=[None, ], name='image_string')
        batch_input_tensor = tf.map_fn(image_decode, image_string_list, dtype=tf.float32)

        tfconfig = tf.ConfigProto()
        tfconfig.gpu_options.allow_growth = True  # maybe necessary, used to avoid cuda initialize error
        tfconfig.allow_soft_placement = True  # maybe necessary, used to avoid cuda initialize error
        # tfconfig.log_device_placement = True  # print message verbose
        with tf.Session(config=tfconfig) as sess:
            # model of cnn lstm ctc
            cnn_lstm_ctc = EvaluateModel()
            # The model reads its inputs from the decoded image batch above.
            ocr_model = cnn_lstm_ctc_ocr.LSTMOCR('eval', inputs=batch_input_tensor)
            ocr_model.build_graph()

            # A directory means "use its most recent checkpoint"; otherwise
            # the configured path is used as the checkpoint file itself.
            if tf.gfile.IsDirectory(configure.ccr_checkpoint_path):
                checkpoint_file = tf.train.latest_checkpoint(configure.ccr_checkpoint_path)
            else:
                checkpoint_file = configure.ccr_checkpoint_path

            # Only global variables under the 'cnn', 'lstm' and 'stn-1'
            # scopes are restored from the checkpoint.
            ocr_cnn_scope_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='cnn')
            ocr_lstm_scope_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='lstm')
            ocr_stn_scope_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='stn-1')
            ocr_restore = tf.train.Saver(
                ocr_cnn_scope_to_restore + ocr_lstm_scope_to_restore + ocr_stn_scope_to_restore)
            ocr_restore.restore(sess, checkpoint_file)

            # Output tensor of the restored model, kept under its own name.
            ccr_dense_decoded = ocr_model.dense_decoded
            # tf server configure
return predictions, img # ------------------------------------------- from tensorflow.python.tools import inspect_checkpoint as chkp if __name__ == "__main__": # config = configure.Config(root_path=root_path) ocr_checkpoint_path = "/data2/CNN_LSTM_CTC_Tensorflow/checkpoint" with tf.get_default_graph().as_default(): with tf.Session(config=tf.ConfigProto( allow_soft_placement=True)) as sess: cnn_lstm_ctc = EvaluateModel() ocr_model = cnn_lstm_ctc_ocr.LSTMOCR('eval') ocr_model.build_graph() if tf.gfile.IsDirectory(ocr_checkpoint_path): checkpoint_file = tf.train.latest_checkpoint( ocr_checkpoint_path) else: checkpoint_file = ocr_checkpoint_path # show tensors in the checkpoint chkp.print_tensors_in_checkpoint_file(checkpoint_file, tensor_name='', all_tensors=False) # get variable to restore ocr_stn_scope_to_restore = tf.get_collection(