Example #1
0
def infer_single_image(checkpoint, fname):
    """Run LPRnet inference on a single image file and print the decoded text.

    Args:
        checkpoint: checkpoint path (or directory) to restore the model from.
        fname: path to the image file to decode.

    Returns:
        List of decoded label strings (one per batch item), or None when the
        file is missing/unreadable or the checkpoint cannot be restored.
    """
    if not os.path.isfile(fname):
        print('file {} does not exist.'.format(fname))
        return None

    img_w = IMG_SIZE[0]
    img_h = IMG_SIZE[1]

    img = cv2.imread(fname)
    if img is None:
        # cv2.imread returns None instead of raising for corrupt/unsupported
        # files; without this check cv2.resize below would crash.
        print('file {} could not be read as an image.'.format(fname))
        return None
    img = cv2.resize(img, (img_w, img_h))
    img_batch = np.expand_dims(img, axis=0)  # add batch dim: (1, h, w, c)

    lprnet = LPRnet(is_train=False)

    with tf.Session() as sess:
        sess.run(lprnet.init)
        saver = tf.train.Saver(tf.global_variables())

        if not restore_checkpoint(sess, saver, checkpoint, is_train=False):
            return None

        test_feed = {lprnet.inputs: img_batch}
        dense_decode = sess.run(lprnet.dense_decoded, test_feed)

        decoded_labels = []
        for item in dense_decode:
            # -1 marks CTC blank/padding positions; map the rest via the dict.
            expression = ''.join('' if i == -1 else DECODE_DICT[i] for i in item)
            decoded_labels.append(expression)

        for label in decoded_labels:
            print(label)
        return decoded_labels
Example #2
0
def test(checkpoint):
    """Restore the given checkpoint and run inference over the test set."""
    net = LPRnet(is_train=False)
    data_iter = utils.DataIterator(img_dir=TEST_DIR)
    with tf.Session() as sess:
        sess.run(net.init)
        saver = tf.train.Saver(tf.global_variables())
        # Only evaluate if the checkpoint restored successfully.
        if restore_checkpoint(sess, saver, checkpoint, is_train=False):
            inference(sess, net, data_iter)
Example #3
0
def train(checkpoint, runtime_generate=False):
    """Train LPRnet, periodically saving checkpoints and running validation.

    Args:
        checkpoint: checkpoint to resume from; restored best-effort (a failed
            restore simply starts training from scratch).
        runtime_generate: if True, training batches come from
            train_gen.next_gen_batch(); otherwise from next_batch().
    """
    lprnet = LPRnet(is_train=True)
    train_gen = utils.DataIterator(img_dir=TRAIN_DIR,
                                   runtime_generate=runtime_generate)
    val_gen = utils.DataIterator(img_dir=VAL_DIR)

    def train_batch(train_gen):
        # Runs one optimization step. NOTE: closes over `sess` and `saver`,
        # which are only bound later inside the tf.Session block below — this
        # function must not be called before that block is entered.
        if runtime_generate:
            train_inputs, train_targets, _ = train_gen.next_gen_batch()
        else:
            train_inputs, train_targets, _ = train_gen.next_batch()

        feed = {lprnet.inputs: train_inputs, lprnet.targets: train_targets}

        loss, steps, _, lr = sess.run( \
            [lprnet.loss, lprnet.global_step, lprnet.optimizer, lprnet.learning_rate], feed)

        # Persist a checkpoint every SAVE_STEPS global steps.
        if steps > 0 and steps % SAVE_STEPS == 0:
            ckpt_dir = CHECKPOINT_DIR
            ckpt_file = os.path.join(ckpt_dir, \
                        'LPRnet_steps{}_loss_{:.3f}.ckpt'.format(steps, loss))
            if not os.path.isdir(ckpt_dir): os.mkdir(ckpt_dir)
            saver.save(sess, ckpt_file)
            print('checkpoint ', ckpt_file)
        return loss, steps, lr

    with tf.Session() as sess:
        sess.run(lprnet.init)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=30)
        # Best-effort restore: return value deliberately ignored so training
        # can proceed from scratch when no usable checkpoint exists.
        restore_checkpoint(sess, saver, checkpoint)

        print('training...')
        for curr_epoch in range(TRAIN_EPOCHS):
            print('Epoch {}/{}'.format(curr_epoch + 1, TRAIN_EPOCHS))
            train_loss = lr = 0
            st = time.time()
            for batch in range(BATCH_PER_EPOCH):
                b_loss, steps, lr = train_batch(train_gen)
                train_loss += b_loss
            tim = time.time() - st
            # Average loss over the epoch; `steps`/`lr` keep their values from
            # the last batch of the inner loop.
            train_loss /= BATCH_PER_EPOCH
            log = "train loss: {:.3f}, steps: {}, time: {:.1f}s, learning rate: {:.5f}"
            print(log.format(train_loss, steps, tim, lr))

            # Validate every VALIDATE_EPOCHS epochs (epoch 0 is skipped).
            if curr_epoch > 0 and curr_epoch % VALIDATE_EPOCHS == 0:
                inference(sess, lprnet, val_gen)
Example #4
0
def export(checkpoint, format, path):
    """Export a trained LPRnet model for serving.

    Args:
        checkpoint: checkpoint file to restore the weights from.
        format: export format, either "saved_model" or "frozen_graph".
        path: output directory for the exported artifact.
    """
    lprnet = LPRnet(is_train=False)
    with tf.Session() as sess:
        sess.run(lprnet.init)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=30)
        saver.restore(sess, checkpoint)

        if format == "saved_model":
            builder = tf.saved_model.builder.SavedModelBuilder(path)
            graph = sess.graph
            # Wire the graph's named input/output tensors into the default
            # serving signature expected by TF Serving.
            builder.add_meta_graph_and_variables(
                sess, ["serve"],
                signature_def_map={
                    'serving_default':
                    tf.saved_model.signature_def_utils.predict_signature_def(
                        {'inputs': graph.get_tensor_by_name('inputs:0')},
                        {
                            'decoded': graph.get_tensor_by_name('decoded:0'),
                            'probability': graph.get_tensor_by_name('probability:0'),
                        }),
                },
                clear_devices=True)
            builder.save()

        elif format == "frozen_graph":
            if not os.path.exists(path):
                os.makedirs(path)

            # Bake variables into constants so the graph file is standalone.
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess,
                sess.graph.as_graph_def(),
                ['decoded', 'probability'],
            )

            # os.path.join instead of '/' concatenation for portability.
            with tf.gfile.GFile(os.path.join(path, 'frozen_graph.pb'), "wb") as outfile:
                outfile.write(output_graph_def.SerializeToString())

        else:
            # Previously an unrecognized format silently exported nothing.
            print('unknown export format: {}'.format(format))