コード例 #1
0
def tower_loss(scope, keep_prob):
    """Build one tower's input, forward pass and loss; return the summed loss.

    Args:
      scope: scope string used to filter the 'losses' collection, so only
        loss terms created under this scope are summed.
      keep_prob: dropout keep probability forwarded to captcha.inference.

    Returns:
      A tensor named 'total_loss' holding the sum of the collected losses.
    """
    batch_images, batch_labels = captcha.inputs(train=True,
                                                batch_size=FLAGS.batch_size)
    tower_logits = captcha.inference(batch_images, keep_prob)
    # The return value is unused. captcha.loss is presumably called so that it
    # registers its terms in the 'losses' collection read below; verify.
    captcha.loss(tower_logits, batch_labels)
    tower_losses = tf.get_collection('losses', scope)
    return tf.add_n(tower_losses, name='total_loss')
コード例 #2
0
def run_train():
    """Train CAPTCHA, checkpointing whenever the batch loss reaches a new low.

    Tries to resume from the newest checkpoint in FLAGS.checkpoint_dir. If
    that fails for any reason, the error is printed and all variables are
    initialized instead. The loop runs until the coordinator requests a stop
    or sess.run raises. On an exception, a final checkpoint is written to
    FLAGS.checkpoint before the queue-runner threads are stopped and joined.
    """
    with tf.Graph().as_default():
        # NOTE(review): batch size is hard-coded to 192 rather than read from
        # FLAGS; confirm this is intended.
        images, labels = captcha.inputs(train=True, batch_size=192)
        logits = captcha.inference(images, keep_prob=0.5)
        loss = captcha.loss(logits, labels)
        train_op = captcha.training(loss)

        saver = tf.compat.v1.train.Saver()
        init_op = tf.group(tf.compat.v1.global_variables_initializer(),
                           tf.compat.v1.local_variables_initializer())

        # The context manager guarantees the session is closed even when
        # coord.join() re-raises the exception handed to request_stop(e).
        with tf.compat.v1.Session() as sess:
            try:
                saver.restore(
                    sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
            except Exception as e:
                # Best-effort resume: with no usable checkpoint, start fresh.
                print(e)
                sess.run(init_op)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # Large starting value so the first step's loss counts as an
            # improvement (assumes real losses are below 100000).
            best_loss = 100000
            step = 0
            try:
                while not coord.should_stop():
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss])
                    duration = time.time() - start_time
                    if step % 10 == 0:
                        print('>> Step %d run_train: loss = %f (%.4f sec)' %
                              (step, loss_value, duration))
                    if best_loss > loss_value:
                        print('>> %s STEP %d LOSS %f SPAN %.4f' %
                              (datetime.now(), step, loss_value, duration))
                        saver.save(sess, FLAGS.checkpoint, global_step=step)
                        best_loss = loss_value
                    step += 1
            except Exception as e:
                # Includes tf.errors.OutOfRangeError when the input runs out.
                saver.save(sess, FLAGS.checkpoint, global_step=step)
                coord.request_stop(e)
            finally:
                coord.request_stop()
            coord.join(threads)
コード例 #3
0
def run_train():
    """Train CAPTCHA, saving a checkpoint to FLAGS.checkpoint every 100 steps.

    Variables are initialized first. When hasCheckpoint() reports a saved
    model, they are then overwritten from the newest checkpoint in
    './captcha_train'. The graph is written to './logs/' for TensorBoard.
    On an exception in the training loop, a final checkpoint is saved before
    the queue-runner threads are stopped.
    """
    with tf.Graph().as_default():
        images, labels = captcha.inputs(
            train=True, batch_size=FLAGS.batch_size)
        # NOTE(review): keep_prob=1 presumably disables dropout during
        # training; confirm this is intended.
        logits = captcha.inference(images, keep_prob=1)
        loss = captcha.loss(logits, labels)

        train_op = captcha.training(loss)

        saver = tf.train.Saver(tf.global_variables())

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        with tf.Session() as sess:

            sess.run(init_op)

            writer = tf.summary.FileWriter('./logs/', sess.graph)
            # Close the writer in `finally`: coord.join() re-raises any
            # exception passed to coord.request_stop(e).
            try:
                if hasCheckpoint():
                    # NOTE(review): restores from './captcha_train' but saves
                    # to FLAGS.checkpoint; confirm both refer to the same place.
                    saver.restore(
                        sess, tf.train.latest_checkpoint('./captcha_train'))

                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                step = 0
                try:
                    while not coord.should_stop():
                        start_time = time.time()
                        _, loss_value = sess.run([train_op, loss])
                        duration = time.time() - start_time
                        if step % 10 == 0:
                            print('>> Step %d run_train: loss = %.2f (%.3f sec)' %
                                  (step, loss_value, duration))
                        if step % 100 == 0:
                            print('>> %s Saving in %s' %
                                  (datetime.now(), FLAGS.checkpoint))
                            saver.save(sess, FLAGS.checkpoint, global_step=step)
                        step += 1
                except Exception as e:
                    print('>> %s Saving in %s' %
                          (datetime.now(), FLAGS.checkpoint))
                    saver.save(sess, FLAGS.checkpoint, global_step=step)
                    coord.request_stop(e)
                finally:
                    coord.request_stop()
                coord.join(threads)
            finally:
                writer.close()
コード例 #4
0
def run_train():
    """Train CAPTCHA from scratch, checkpointing every 100 steps.

    Checkpoints are written under FLAGS.model_dir with the prefix
    'checkpoint'. On an exception in the training loop (including the input
    running out), a final checkpoint is saved before the queue-runner threads
    are stopped.
    """
    with tf.Graph().as_default():
        images, labels = captcha.inputs(train=True,
                                        batch_size=FLAGS.batch_size)

        logits = captcha.inference(images, keep_prob=0.5)

        loss = captcha.loss(logits, labels)

        train_op = captcha.training(loss)

        saver = tf.train.Saver(tf.global_variables())

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # The context manager closes the session even when coord.join()
        # re-raises the exception passed to coord.request_stop(e).
        with tf.Session() as sess:
            sess.run(init_op)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            checkpoint_path = os.path.join(FLAGS.model_dir, 'checkpoint')
            print("checkpoint_path: %s" % checkpoint_path)
            step = 0
            try:
                while not coord.should_stop():
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss])
                    duration = time.time() - start_time
                    if step % 10 == 0:
                        print('>> Step %d run_train: loss = %.2f (%.3f sec)' %
                              (step, loss_value, duration))
                    if step % 100 == 0:
                        print('>> %s Saving in %s' %
                              (datetime.now(), checkpoint_path))
                        saver.save(sess, checkpoint_path, global_step=step)
                    step += 1
            except Exception as e:
                print('>> %s Saving in %s' % (datetime.now(), checkpoint_path))
                saver.save(sess, checkpoint_path, global_step=step)
                coord.request_stop(e)
            finally:
                coord.request_stop()
            coord.join(threads)
コード例 #5
0
def main():
    """Train with a feedable iterator that switches between two TFRecord sets.

    A string-handle placeholder selects which dataset feeds the model. The
    training handle is fed for each optimizer step. Every 10 steps the
    validation handle is fed to compute accuracy. After 10000 steps the
    variables are saved to "save/model".
    """
    output_types, output_shapes, train_handle = captcha_data.generate_handle(
        "tfrecords/train.tfrecords")
    _, _, validation_handle = captcha_data.generate_handle(
        "tfrecords/validation.tfrecords")

    # Feeding a different string handle switches the input pipeline
    # without rebuilding the graph.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(string_handle=handle,
                                                   output_types=output_types,
                                                   output_shapes=output_shapes)
    next_element = iterator.get_next()
    images = next_element['image']
    labels = next_element['label']

    logits = captcha_model.inference(images)
    loss = captcha_model.loss(logits, labels)
    accuracy = captcha_model.evaluation(logits, labels)
    train_op = captcha_model.train(loss)

    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    # Build the Saver with the rest of the graph, before the session starts.
    saver = tf.train.Saver()

    config = tf.ConfigProto(allow_soft_placement=True)
    # The context manager closes the session even if a training step raises.
    with tf.Session(config=config) as sess:
        sess.run(init)

        # Presumably string-handle tensors; evaluate each once and reuse the
        # resulting values as feeds.
        train_handle_value = sess.run(train_handle)
        validation_handle_value = sess.run(validation_handle)

        for i in range(10000):
            _, loss_value = sess.run([train_op, loss],
                                     feed_dict={handle: train_handle_value})
            if (i + 1) % 10 == 0:
                accuracy_value = sess.run(
                    accuracy, feed_dict={handle: validation_handle_value})
                print("loop: %d, loss: %f accuracy: %f" %
                      (i + 1, loss_value, accuracy_value))
        saver.save(sess, "save/model")
コード例 #6
0
def run_train():
    """Train CAPTCHA while logging train/test accuracy and TensorBoard summaries.

    Builds two towers:
      - a training tower on the training input (keep_prob=0.7,
        is_training=True);
      - an evaluation tower on the test input (keep_prob=1,
        is_training=False).

    Each step runs the train op together with the loss, the per-batch correct
    counts of both towers, and the merged summary. The summary is written to
    FLAGS.train_dir and the test correct count is appended to 'test.txt'. A
    checkpoint is saved to FLAGS.checkpoint every 300 steps. Training stops
    after 200000 steps or when the coordinator asks to stop.
    """
    with tf.Graph().as_default():
        images, labels = captcha.inputs(train=True,
                                        batch_size=FLAGS.batch_size)
        test_images, test_labels = captcha.inputs(train=False,
                                                  batch_size=FLAGS.batch_size)

        logits = captcha.inference(images, keep_prob=0.7, is_training=True)
        test_logits = captcha.inference(test_images,
                                        keep_prob=1,
                                        is_training=False)

        loss = captcha.loss(logits, labels)
        correct = captcha.evaluation(logits, labels)
        test_correct = captcha.evaluation(test_logits, test_labels)
        eval_correct = correct / FLAGS.batch_size

        tf.summary.scalar('precision', eval_correct)
        tf.summary.scalar('loss', loss)
        tf.summary.image('image', images, 10)
        summary = tf.summary.merge_all()

        train_op = captcha.training(loss)

        saver = tf.train.Saver(tf.global_variables())

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        with tf.Session() as sess:

            sess.run(init_op)

            summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
            try:
                # latest_checkpoint() returns None when FLAGS.train_dir holds
                # no checkpoint, and restoring None raises. In that case keep
                # the freshly initialized variables.
                latest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
                if latest_checkpoint:
                    saver.restore(sess, latest_checkpoint)

                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                step = 0
                try:
                    while not coord.should_stop():
                        start_time = time.time()
                        # Fetch the merged summary in the same run as the
                        # train op. `summary` depends on `images` and `loss`,
                        # so a separate sess.run(summary) would pull another
                        # batch from the input pipeline. The logged summary
                        # would then describe a batch that was never trained on.
                        _, loss_value, test_co, train_co, result = sess.run(
                            [train_op, loss, test_correct, correct, summary])
                        summary_writer.add_summary(result, step)
                        summary_writer.flush()

                        duration = time.time() - start_time
                        if step % 100 == 0:
                            print(
                                '>> Step %d run_train: loss = %.2f , train_correct= %d, test_correct= %d  (%.3f sec)'
                                % (step, loss_value, train_co, test_co, duration))
                        with open('test.txt', 'a') as f:
                            f.write(str(test_co) + '\n')
                        if step % 300 == 0:
                            print('>> %s Saving in %s' %
                                  (datetime.now(), FLAGS.checkpoint))
                            saver.save(sess, FLAGS.checkpoint, global_step=step)
                        step += 1
                        if step > 200000:
                            break
                except Exception as e:
                    print('>> %s Saving in %s' %
                          (datetime.now(), FLAGS.checkpoint))
                    saver.save(sess, FLAGS.checkpoint, global_step=step)
                    coord.request_stop(e)
                finally:
                    coord.request_stop()
                # join() re-raises any exception given to request_stop(e).
                coord.join(threads)
            finally:
                summary_writer.close()
コード例 #7
0
def run_train():
    """Resume CAPTCHA training from the newest checkpoint in FLAGS.checkpoint_dir.

    Variables are only restored here, never initialized, so a checkpoint must
    exist. Each step runs the train op together with the loss, the test/train
    correct counts, and the merged summary (written to FLAGS.train_dir).
    A checkpoint is saved to FLAGS.checkpoint every 100 steps and once more
    when training ends for any reason, including Ctrl-C.
    """
    with tf.Graph().as_default():
        images, labels = captcha.inputs(train=True,
                                        batch_size=FLAGS.batch_size)
        test_images, test_labels = captcha.inputs(train=False,
                                                  batch_size=FLAGS.batch_size)

        logits = captcha.inference(images, keep_prob=0.75, is_training=True)
        test_logits = captcha.inference(test_images,
                                        keep_prob=1,
                                        is_training=False)

        loss = captcha.loss(logits, labels)

        test_correct = captcha.evaluation(test_logits, test_labels)  # test
        correct = captcha.evaluation(logits, labels)  # train

        train_precision = correct / FLAGS.batch_size
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('train_precision', train_precision)
        tf.summary.image('images', images, 10)
        summary = tf.summary.merge_all()
        train_op = captcha.training(loss)
        saver = tf.train.Saver()

        # The context manager closes the session even if the restore fails or
        # coord.join() re-raises the exception passed to request_stop(e).
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
            try:
                saver.restore(sess,
                              tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                # HACK(review): the resume step is hard-coded. It is presumably
                # the global step of the checkpoint being restored; update it
                # when resuming from a different checkpoint.
                step = 120140
                try:
                    while not coord.should_stop():
                        start_time = time.time()
                        # Fetch the summary in the same run as the train op.
                        # A separate sess.run(summary) would pull another batch
                        # from the input pipeline, so the logged summary would
                        # describe a batch that was never trained on.
                        (_, loss_value, test_value, train_value,
                         summary_str) = sess.run(
                             [train_op, loss, test_correct, correct, summary])
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                        duration = time.time() - start_time
                        step += 1
                        if step % 10 == 0:
                            print(
                                '>> Step %d run_train: loss = %.2f, test = %.2f, train = %.2f (%.3f sec)'
                                % (step, loss_value, test_value, train_value,
                                   duration))

                        if step % 100 == 0:
                            print('>> %s Saving in %s' %
                                  (datetime.now(), FLAGS.checkpoint))
                            saver.save(sess, FLAGS.checkpoint, global_step=step)
                            print(images.shape.as_list(), labels.shape.as_list())

                        if step > 2000000:
                            break
                except KeyboardInterrupt:
                    print('INTERRUPTED')
                    coord.request_stop()
                except Exception as e:
                    coord.request_stop(e)
                finally:
                    saver.save(sess, FLAGS.checkpoint, global_step=step)
                    print('Model saved in file :%s' % FLAGS.checkpoint)

                    coord.request_stop()
                    coord.join(threads)
            finally:
                summary_writer.close()
コード例 #8
0
ファイル: captcha_train.py プロジェクト: zion302/CAPTCHA
def main():
    """Train from TFRecord queues, printing validation accuracy every 5 steps.

    Training batches come from ../images/tfrecords/train.tfrecords and
    validation batches from ../images/tfrecords/validation.tfrecords. The
    validation tower reuses the training weights and is placed on the CPU.
    Runs until the input is exhausted (tf.errors.OutOfRangeError), the
    coordinator stops, or the user presses Ctrl-C.
    """
    filename_queue = tf.train.string_input_producer(
        ["../images/tfrecords/train.tfrecords"])
    image, label = captcha_inputs.read_records(filename_queue)
    images, labels = captcha_inputs.records_inputs(image, label,
                                                   MIN_AFTER_DEQUEUE)

    validation_filename_queue = tf.train.string_input_producer(
        ["../images/tfrecords/validation.tfrecords"])
    validation_image, validation_label = captcha_inputs.read_records(
        validation_filename_queue)
    validation_images, validation_labels = captcha_inputs.records_inputs(
        validation_image, validation_label, VALIDATION_MIN_AFTER_DEQUEUE)

    logits = captcha_model.inference(images)
    # Reuse the training weights for the validation tower, but only inside
    # this scope. Calling reuse_variables() on the root scope would leave
    # reuse switched on for everything built afterwards, including any
    # variables the optimizer in captcha_model.train() may need to create.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        with tf.device("/cpu:0"):
            validation_logits = captcha_model.inference(validation_images)
            accuracy = captcha_model.evaluation(validation_logits,
                                                validation_labels)
    loss = captcha_model.loss(logits, labels)
    train_op = captcha_model.train(loss)

    # NOTE(review): these are the old initializer names, presumably kept for
    # the TF version this script targets.
    init = tf.group(tf.initialize_all_variables(),
                    tf.initialize_local_variables())
    # The context manager closes the session exactly once on every exit path.
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        start_time = datetime.now()
        try:
            index = 1
            while not coord.should_stop():
                _, loss_value = sess.run([train_op, loss])
                print("step: " + str(index) + " loss:" + str(loss_value))
                if index % 5 == 0:
                    validation_accuracy = sess.run(accuracy)
                    print("validation accuracy: " + str(validation_accuracy))
                index += 1
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
            end_time = datetime.now()
            print("Time Consumption: " + str(end_time - start_time))
        except KeyboardInterrupt:
            # Do not `del sess` here: the session is still needed to stop the
            # queue runners, and the `with` block closes it afterwards.
            print("keyboard interrupt detected, stop running")
        finally:
            # When done, ask the threads to stop.
            coord.request_stop()

        # Wait for threads to finish.
        coord.join(threads)