示例#1
0
def tower_loss(scope, keep_prob):
    """Build one training tower and return the sum of its losses.

    Args:
        scope: name scope used to filter the 'losses' collection.
        keep_prob: value forwarded to ``captcha.inference``.

    Returns:
        A tensor named 'total_loss' that adds every tensor in the 'losses'
        collection under ``scope``.
    """
    batch_images, batch_labels = captcha.inputs(train=True,
                                                batch_size=FLAGS.batch_size)
    tower_logits = captcha.inference(batch_images, keep_prob)
    # The return value is deliberately ignored; presumably captcha.loss adds
    # its result to the 'losses' collection read below — confirm.
    _ = captcha.loss(tower_logits, batch_labels)
    return tf.add_n(tf.get_collection('losses', scope), name='total_loss')
示例#2
0
def run_eval():
    """Evaluate the latest checkpoint on the evaluation input pipeline.

    Restores the newest checkpoint from ``FLAGS.checkpoint_dir``. It then runs
    ``ceil(FLAGS.num_examples / FLAGS.batch_size)`` batches, or fewer if the
    coordinator asks to stop, and logs per-batch and overall precision.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        images, labels = captcha.inputs(train=False,
                                        batch_size=FLAGS.batch_size)
        logits = captcha.inference(images, keep_prob=1)
        eval_correct = captcha.evaluation(logits, labels)
        sess = tf.Session()
        # Close the session on every path, including a failed restore.
        try:
            saver = tf.train.Saver()
            saver.restore(sess,
                          tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                # float() forces true division on Python 2 as well; otherwise
                # ceil() would receive an already-truncated quotient.
                num_iter = int(math.ceil(
                    float(FLAGS.num_examples) / FLAGS.batch_size))
                total_sample_count = num_iter * FLAGS.batch_size
                log('loop: %d, total_sample_count: %d' %
                    (num_iter, total_sample_count))
                total_true_count = 0
                step = 0
                while step < num_iter and not coord.should_stop():
                    true_count = sess.run(eval_correct)
                    total_true_count += true_count
                    precision = float(true_count) / FLAGS.batch_size
                    log('Step %d: true/total: %d/%d precision @ 1 = %.3f' %
                        (step, true_count, FLAGS.batch_size, precision))
                    step += 1
                # Use the samples actually evaluated: the loop may end early
                # when the coordinator requests a stop.
                evaluated_count = step * FLAGS.batch_size
                if evaluated_count:
                    precision = float(total_true_count) / evaluated_count
                else:
                    precision = 0.0
                log('true/total: %d/%d precision @ 1 = %.3f' %
                    (total_true_count, evaluated_count, precision))
            except Exception as e:
                # Record the failure so coord.join() can report it.
                coord.request_stop(e)
            finally:
                coord.request_stop()
            coord.join(threads)
        finally:
            sess.close()
示例#3
0
    def predict(self, base64_image_str):
        """Recognize the CAPTCHA text in one encoded image.

        Args:
            base64_image_str: image payload passed to ``self.input_data``
                (presumably a base64 string, per the parameter name).

        Returns:
            The first element of ``self.one_hot_to_texts`` applied to the
            model output.
        """
        with tf.Graph().as_default(), tf.device('/cpu:0'):
            input_images = self.input_data(base64_image_str)
            images = tf.constant(input_images)
            logits = captcha.inference(images, keep_prob=1)
            result = captcha.output(logits)
            saver = tf.train.Saver()
            # The context manager closes the session even if restore() or
            # run() raises, e.g. when no checkpoint exists.
            # NOTE(review): checkpoint_dir is a name from the enclosing module
            # scope, not shown here — confirm.
            with tf.Session() as sess:
                saver.restore(sess,
                              tf.train.latest_checkpoint(checkpoint_dir))
                recog_result = sess.run(result)
            text = self.one_hot_to_texts(recog_result)

            return text[0]
示例#4
0
def run_train():
    """Train CAPTCHA for a number of steps.

    Runs the training op until the coordinator stops. The loss is logged
    every 10 steps and a checkpoint is saved to ``FLAGS.checkpoint`` every
    100 steps. If an exception ends training, one more checkpoint is saved
    before shutting down.
    """
    with tf.Graph().as_default():
        images, labels = captcha.inputs(train=True,
                                        batch_size=FLAGS.batch_size)
        logits = captcha.inference(images, keep_prob=0.5)
        loss = captcha.loss(logits, labels)
        train_op = captcha.training(loss)

        saver = tf.train.Saver(tf.global_variables())
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        sess = tf.Session()
        # Close the session on every path, including a failed emergency save.
        try:
            sess.run(init_op)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            step = 0
            try:
                while not coord.should_stop():
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss])
                    duration = time.time() - start_time
                    if step % 10 == 0:
                        logger.info('Step %d run_train: loss = %.2f (%.3f sec)',
                                    step, loss_value, duration)
                    if step % 100 == 0:
                        logger.info('%s Saving in %s',
                                    datetime.now(), FLAGS.checkpoint)
                        saver.save(sess, FLAGS.checkpoint, global_step=step)
                    step += 1
            except Exception as e:
                # Record the failure first so it is not lost if saving fails.
                coord.request_stop(e)
                logger.info('%s Saving in %s', datetime.now(), FLAGS.checkpoint)
                saver.save(sess, FLAGS.checkpoint, global_step=step)
            finally:
                coord.request_stop()
            coord.join(threads)
        finally:
            sess.close()
示例#5
0
def run_predict():
    """Recognize every CAPTCHA image under ``FLAGS.captcha_dir`` and log accuracy.

    An image counts as correct when its recognized text is a substring of its
    filename. If no images are found, the function logs that and returns
    without building the model.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        input_images, input_filenames = input_data(FLAGS.captcha_dir)
        logger.info(input_filenames)
        # len() rather than truthiness: the container may be an array type.
        total_count = len(input_filenames)
        if total_count == 0:
            # Avoids a degenerate constant tensor and a divide-by-zero below.
            logger.info('no images found in %s', FLAGS.captcha_dir)
            return
        images = tf.constant(input_images)
        logits = captcha.inference(images, keep_prob=1)
        result = captcha.output(logits)
        saver = tf.train.Saver()
        # The context manager closes the session even if restore()/run() raises.
        with tf.Session() as sess:
            checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            saver.restore(sess, checkpoint)
            logger.info(checkpoint)
            recog_result = sess.run(result)
        text = one_hot_to_texts(recog_result)
        true_count = 0
        for filename, recognized in zip(input_filenames, text):
            logger.info("image %s recognize ----> '%s'", filename, recognized)
            if recognized in filename:
                true_count += 1
        precision = float(true_count) / total_count
        logger.info('true/total: %d/%d recognize @ 1 = %.3f',
                    true_count, total_count, precision)