def infer_single_image(checkpoint, fname):
    """Run LPRnet on one image file and print the decoded string(s).

    The image is loaded with OpenCV and resized to ``IMG_SIZE``. It is fed
    to the network as a batch of one, and each decoded sequence is printed
    on its own line.

    Args:
        checkpoint: checkpoint identifier forwarded to ``restore_checkpoint``.
        fname: path of the image file to recognise.

    Returns:
        None. Returns early, after printing a message, if the file does not
        exist or cannot be decoded as an image. Also returns early if
        ``restore_checkpoint`` reports failure.
    """
    if not os.path.isfile(fname):
        print('file {} does not exist.'.format(fname))
        return

    img = cv2.imread(fname)
    # cv2.imread does not raise on unreadable/corrupt files; it returns None,
    # which would otherwise fail later inside cv2.resize with an opaque error.
    if img is None:
        print('file {} is not a readable image.'.format(fname))
        return

    img_w, img_h = IMG_SIZE[0], IMG_SIZE[1]
    img = cv2.resize(img, (img_w, img_h))
    # Add a leading batch axis so the single image is fed as a batch of one.
    img_batch = np.expand_dims(img, axis=0)

    lprnet = LPRnet(is_train=False)
    with tf.Session() as sess:
        sess.run(lprnet.init)
        saver = tf.train.Saver(tf.global_variables())
        if not restore_checkpoint(sess, saver, checkpoint, is_train=False):
            return

        test_feed = {lprnet.inputs: img_batch}
        dense_decode = sess.run(lprnet.dense_decoded, test_feed)

        # Entries equal to -1 (presumably padding from the dense decoding --
        # confirm) map to no character; all others index into DECODE_DICT.
        decoded_labels = [
            ''.join('' if i == -1 else DECODE_DICT[i] for i in item)
            for item in dense_decode
        ]
        for label in decoded_labels:
            print(label)
def test(checkpoint):
    """Evaluate a restored LPRnet checkpoint on the images in ``TEST_DIR``.

    Builds the inference graph and restores ``checkpoint``. It then passes
    the live session, the model and a data iterator over ``TEST_DIR`` to
    ``inference``. If ``restore_checkpoint`` reports failure, nothing is
    evaluated.

    Args:
        checkpoint: checkpoint identifier forwarded to ``restore_checkpoint``.
    """
    model = LPRnet(is_train=False)
    test_iter = utils.DataIterator(img_dir=TEST_DIR)
    with tf.Session() as session:
        session.run(model.init)
        restorer = tf.train.Saver(tf.global_variables())
        restored = restore_checkpoint(session, restorer, checkpoint, is_train=False)
        if restored:
            inference(session, model, test_iter)
def train(checkpoint, runtime_generate=False):
    """Train LPRnet for ``TRAIN_EPOCHS`` epochs of ``BATCH_PER_EPOCH`` batches.

    A checkpoint is written to ``CHECKPOINT_DIR`` every ``SAVE_STEPS`` global
    steps. ``inference`` is run on the ``VAL_DIR`` iterator every
    ``VALIDATE_EPOCHS`` epochs (see the note on the epoch offset below).

    Args:
        checkpoint: checkpoint to resume from, forwarded to
            ``restore_checkpoint``. Its return value is ignored. Variables
            are initialised first, so training starts from that
            initialisation if nothing is restored.
        runtime_generate: if True, training batches come from
            ``train_gen.next_gen_batch()``; otherwise from
            ``train_gen.next_batch()``. Presumably this selects
            generated-on-the-fly samples over files in ``TRAIN_DIR``
            (confirm in ``utils.DataIterator``).
    """
    lprnet = LPRnet(is_train=True)
    train_gen = utils.DataIterator(img_dir=TRAIN_DIR, runtime_generate=runtime_generate)
    val_gen = utils.DataIterator(img_dir=VAL_DIR)

    def train_batch(train_gen):
        """Run one optimiser step and save a checkpoint every SAVE_STEPS steps.

        Reads ``sess`` and ``saver`` from the enclosing scope. Both are bound
        inside the ``with`` block below before the first call.

        Returns:
            ``(loss, global_step, learning_rate)`` fetched for this batch.
        """
        if runtime_generate:
            train_inputs, train_targets, _ = train_gen.next_gen_batch()
        else:
            train_inputs, train_targets, _ = train_gen.next_batch()
        feed = {lprnet.inputs: train_inputs, lprnet.targets: train_targets}
        loss, steps, _, lr = sess.run(
            [lprnet.loss, lprnet.global_step, lprnet.optimizer, lprnet.learning_rate],
            feed)

        if steps > 0 and steps % SAVE_STEPS == 0:
            ckpt_dir = CHECKPOINT_DIR
            ckpt_file = os.path.join(
                ckpt_dir, 'LPRnet_steps{}_loss_{:.3f}.ckpt'.format(steps, loss))
            # makedirs (not mkdir) so a nested CHECKPOINT_DIR whose parent
            # directories do not exist yet is created instead of raising.
            if not os.path.isdir(ckpt_dir):
                os.makedirs(ckpt_dir)
            saver.save(sess, ckpt_file)
            print('checkpoint ', ckpt_file)
        return loss, steps, lr

    with tf.Session() as sess:
        sess.run(lprnet.init)
        # Retain up to 30 checkpoint files on disk.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=30)
        restore_checkpoint(sess, saver, checkpoint)

        print('training...')
        for curr_epoch in range(TRAIN_EPOCHS):
            print('Epoch {}/{}'.format(curr_epoch + 1, TRAIN_EPOCHS))
            train_loss = lr = 0
            st = time.time()
            for _ in range(BATCH_PER_EPOCH):
                b_loss, steps, lr = train_batch(train_gen)
                train_loss += b_loss
            tim = time.time() - st
            train_loss /= BATCH_PER_EPOCH
            log = "train loss: {:.3f}, steps: {}, time: {:.1f}s, learning rate: {:.5f}"
            print(log.format(train_loss, steps, tim, lr))

            # NOTE(review): curr_epoch is 0-based, so validation runs after the
            # printed epochs k*VALIDATE_EPOCHS+1 (k >= 1), e.g. 6, 11, ... for
            # VALIDATE_EPOCHS=5 -- confirm that offset is intended.
            if curr_epoch > 0 and curr_epoch % VALIDATE_EPOCHS == 0:
                inference(sess, lprnet, val_gen)
def export(checkpoint, format, path):
    """Export a trained LPRnet checkpoint for serving.

    Args:
        checkpoint: checkpoint path passed directly to
            ``tf.train.Saver.restore``.
        format: the export type, one of:

            - ``"saved_model"``: write a SavedModel to ``path``. It has the
              tag ``"serve"`` and a ``serving_default`` predict signature
              mapping input ``inputs`` to outputs ``decoded`` and
              ``probability``.
            - ``"frozen_graph"``: write ``frozen_graph.pb`` into the directory
              ``path``, which is created if missing. Variables are converted
              to constants for the ``decoded`` and ``probability`` output
              nodes.
        path: output location (see ``format``).

    Raises:
        ValueError: if ``format`` is not one of the supported values. This is
            checked before the graph is built, so a typo fails fast instead of
            silently exporting nothing.
    """
    # ``format`` shadows the builtin but is part of the public signature.
    supported_formats = ('saved_model', 'frozen_graph')
    if format not in supported_formats:
        raise ValueError('unsupported export format {!r}; expected one of: {}'.format(
            format, ', '.join(supported_formats)))

    lprnet = LPRnet(is_train=False)
    with tf.Session() as sess:
        sess.run(lprnet.init)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=30)
        saver.restore(sess, checkpoint)
        graph = sess.graph

        if format == 'saved_model':
            builder = tf.saved_model.builder.SavedModelBuilder(path)
            signature = tf.saved_model.signature_def_utils.predict_signature_def(
                {'inputs': graph.get_tensor_by_name('inputs:0')},
                {
                    'decoded': graph.get_tensor_by_name('decoded:0'),
                    'probability': graph.get_tensor_by_name('probability:0'),
                })
            builder.add_meta_graph_and_variables(
                sess,
                ['serve'],
                signature_def_map={'serving_default': signature},
                clear_devices=True)
            builder.save()
        else:  # format == 'frozen_graph'
            if not os.path.exists(path):
                os.makedirs(path)
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), ['decoded', 'probability'])
            with tf.gfile.GFile(os.path.join(path, 'frozen_graph.pb'), 'wb') as outfile:
                outfile.write(output_graph_def.SerializeToString())