示例#1
0
def run_train():
    """Train the CAPTCHA model, checkpointing whenever the batch loss improves.

    Builds a fresh graph that reads training batches of 192 examples and runs
    inference with ``keep_prob=0.5``. All variables are initialized, then the
    newest checkpoint in ``FLAGS.checkpoint_dir`` is restored if possible; if
    the restore fails, the error is printed and training starts from the
    initialized values.

    The loop runs until the coordinator is asked to stop. Whenever the batch
    loss drops below the best value seen so far, a checkpoint is written to
    ``FLAGS.checkpoint`` (suffixed with the step). If an exception ends the
    loop, one final checkpoint is written before the coordinator is stopped.
    """
    with tf.Graph().as_default():
        images, labels = captcha.inputs(train=True, batch_size=192)
        logits = captcha.inference(images, keep_prob=0.5)
        loss = captcha.loss(logits, labels)
        train_op = captcha.training(loss)

        saver = tf.compat.v1.train.Saver()
        init_op = tf.group(tf.compat.v1.global_variables_initializer(),
                           tf.compat.v1.local_variables_initializer())

        # The context manager closes the session on every exit path, including
        # when coord.join() re-raises an error recorded by request_stop(e).
        with tf.compat.v1.Session() as sess:
            # Initialize unconditionally: Saver restores only global
            # variables, so local variables (e.g. input-pipeline epoch
            # counters, if captcha.inputs creates any) must be initialized
            # even when the restore below succeeds.
            sess.run(init_op)
            try:
                saver.restore(sess,
                              tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
            except Exception as e:
                # Best effort: latest_checkpoint() returns None when there is
                # no checkpoint, which makes restore() raise. Train from
                # scratch, re-initializing to discard any partial restore.
                print(e)
                sess.run(init_op)

            coord = tf.train.Coordinator()
            # compat.v1 entry point, consistent with the rest of this
            # function; start_queue_runners is not under tf.train in TF 2.x.
            threads = tf.compat.v1.train.start_queue_runners(sess=sess,
                                                             coord=coord)
            best_loss = float('inf')
            step = 0
            try:
                while not coord.should_stop():
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss])
                    duration = time.time() - start_time
                    if step % 10 == 0:
                        print('>> Step %d run_train: loss = %f (%.4f sec)' %
                              (step, loss_value, duration))
                    # Keep a checkpoint of every new best batch loss.
                    if loss_value < best_loss:
                        print('>> %s STEP %d LOSS %f SPAN %.4f' %
                              (datetime.now(), step, loss_value, duration))
                        saver.save(sess, FLAGS.checkpoint, global_step=step)
                        best_loss = loss_value
                    step += 1
            except Exception as e:
                # E.g. OutOfRangeError once the input queue is exhausted:
                # persist progress, then stop the queue-runner threads.
                saver.save(sess, FLAGS.checkpoint, global_step=step)
                coord.request_stop(e)
            finally:
                coord.request_stop()
            coord.join(threads)
示例#2
0
def run_train():
    """Train the CAPTCHA model, saving a checkpoint every 100 steps.

    Builds the training graph with batches of ``FLAGS.batch_size`` examples
    and ``keep_prob=1`` (presumably no dropout; depends on
    ``captcha.inference``). It then initializes all variables and writes the
    graph to an event file under ``./logs/``. When ``hasCheckpoint()`` reports
    a checkpoint, the newest one in ``./captcha_train`` is restored.

    Checkpoints are saved to ``FLAGS.checkpoint`` every 100 steps, and once
    more when an exception ends the loop. The event file is always closed
    before returning.
    """

    with tf.Graph().as_default():
        images, labels = captcha.inputs(
            train=True, batch_size=FLAGS.batch_size)
        logits = captcha.inference(images, keep_prob=1)
        loss = captcha.loss(logits, labels)

        train_op = captcha.training(loss)

        saver = tf.train.Saver(tf.global_variables())

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        with tf.Session() as sess:

            sess.run(init_op)

            writer = tf.summary.FileWriter('./logs/', sess.graph)
            try:
                # Resume from the newest checkpoint when one exists.
                # NOTE(review): restores from './captcha_train' but saves to
                # FLAGS.checkpoint; confirm both refer to the same directory.
                if hasCheckpoint():
                    saver.restore(
                        sess, tf.train.latest_checkpoint('./captcha_train'))

                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                step = 0
                try:
                    while not coord.should_stop():
                        start_time = time.time()
                        _, loss_value = sess.run([train_op, loss])
                        duration = time.time() - start_time
                        if step % 10 == 0:
                            print('>> Step %d run_train: loss = %.2f (%.3f sec)' %
                                  (step, loss_value, duration))
                        if step % 100 == 0:
                            print('>> %s Saving in %s' %
                                  (datetime.now(), FLAGS.checkpoint))
                            saver.save(sess, FLAGS.checkpoint, global_step=step)
                        step += 1
                except Exception as e:
                    # Persist progress before stopping the queue runners
                    # (e.g. on OutOfRangeError when input is exhausted).
                    print('>> %s Saving in %s' %
                          (datetime.now(), FLAGS.checkpoint))
                    saver.save(sess, FLAGS.checkpoint, global_step=step)
                    coord.request_stop(e)
                finally:
                    coord.request_stop()
                coord.join(threads)
            finally:
                # Close the event file on every exit path, including a failed
                # restore or coord.join() re-raising a recorded error.
                writer.close()
示例#3
0
def run_train():
    """Train the CAPTCHA model from scratch, checkpointing every 100 steps.

    Builds the training graph with batches of ``FLAGS.batch_size`` examples
    and ``keep_prob=0.5``, and initializes all variables (nothing is
    restored). Checkpoints are written with the path prefix
    ``<FLAGS.model_dir>/checkpoint`` (suffixed with the step) every 100
    steps. If an exception ends the loop, one final checkpoint is written
    before the coordinator is stopped.
    """

    with tf.Graph().as_default():
        images, labels = captcha.inputs(train=True,
                                        batch_size=FLAGS.batch_size)

        logits = captcha.inference(images, keep_prob=0.5)

        loss = captcha.loss(logits, labels)

        train_op = captcha.training(loss)

        saver = tf.train.Saver(tf.global_variables())

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # The context manager closes the session on every exit path, including
        # when coord.join() re-raises an error recorded by request_stop(e).
        with tf.Session() as sess:
            sess.run(init_op)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            checkpoint_path = os.path.join(FLAGS.model_dir, 'checkpoint')
            print("checkpoint_path: %s" % checkpoint_path)
            step = 0
            try:
                while not coord.should_stop():
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss])
                    duration = time.time() - start_time
                    if step % 10 == 0:
                        print('>> Step %d run_train: loss = %.2f (%.3f sec)' %
                              (step, loss_value, duration))
                    if step % 100 == 0:
                        print('>> %s Saving in %s' %
                              (datetime.now(), checkpoint_path))
                        saver.save(sess, checkpoint_path, global_step=step)
                    step += 1
            except Exception as e:
                # Persist progress before stopping the queue runners
                # (e.g. on OutOfRangeError when input is exhausted).
                print('>> %s Saving in %s' % (datetime.now(), checkpoint_path))
                saver.save(sess, checkpoint_path, global_step=step)
                coord.request_stop(e)
            finally:
                coord.request_stop()
            coord.join(threads)
示例#4
0
def run_train():
    """Train the CAPTCHA model while logging train/test accuracy.

    Builds two pipelines in one graph:
      * a training pipeline using ``keep_prob=0.7`` and ``is_training=True``;
      * a test pipeline using ``keep_prob=1`` and ``is_training=False``.

    Each step writes these summaries to ``FLAGS.train_dir``:
      * ``precision``: correct training predictions / batch size;
      * ``loss``;
      * up to 10 training images.

    Each step also appends the test batch's correct-prediction count to
    ``test.txt``. A checkpoint is saved to ``FLAGS.checkpoint`` every 300
    steps. Training ends after step 200000, or when an exception ends the
    loop; in that case a final checkpoint is saved first.

    If ``FLAGS.train_dir`` holds a checkpoint, its variable values are
    restored; otherwise training starts from the initializers.
    """

    with tf.Graph().as_default():
        images, labels = captcha.inputs(train=True,
                                        batch_size=FLAGS.batch_size)
        test_images, test_labels = captcha.inputs(train=False,
                                                  batch_size=FLAGS.batch_size)

        logits = captcha.inference(images, keep_prob=0.7, is_training=True)
        test_logits = captcha.inference(test_images,
                                        keep_prob=1,
                                        is_training=False)

        loss = captcha.loss(logits, labels)
        correct = captcha.evaluation(logits, labels)
        test_correct = captcha.evaluation(test_logits, test_labels)
        eval_correct = correct / FLAGS.batch_size

        tf.summary.scalar('precision', eval_correct)
        tf.summary.scalar('loss', loss)
        tf.summary.image('image', images, 10)
        summary = tf.summary.merge_all()

        train_op = captcha.training(loss)

        saver = tf.train.Saver(tf.global_variables())

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        with tf.Session() as sess:

            sess.run(init_op)

            summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
            try:
                # latest_checkpoint() returns None when no checkpoint exists,
                # and Saver.restore(None) raises; only resume when one is found.
                # NOTE(review): restores from FLAGS.train_dir but saves to
                # FLAGS.checkpoint; confirm these point at the same place.
                latest = tf.train.latest_checkpoint(FLAGS.train_dir)
                if latest:
                    saver.restore(sess, latest)
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                step = 0
                try:
                    while not coord.should_stop():
                        start_time = time.time()
                        # Fetch the merged summary in the same run as the
                        # train op. A separate sess.run(summary) re-evaluates
                        # the input pipeline (queue-based, see
                        # start_queue_runners), presumably dequeuing another
                        # batch. The logged values would then describe data
                        # that was never trained on, at twice the cost.
                        _, loss_value, test_co, train_co, result = sess.run(
                            [train_op, loss, test_correct, correct, summary])
                        summary_writer.add_summary(result, step)
                        summary_writer.flush()

                        duration = time.time() - start_time
                        if step % 100 == 0:
                            print(
                                '>> Step %d run_train: loss = %.2f , train_correct= %d, test_correct= %d  (%.3f sec)'
                                % (step, loss_value, train_co, test_co, duration))
                        with open('test.txt', 'a') as f:
                            f.write(str(test_co) + '\n')
                        if step % 300 == 0:
                            print('>> %s Saving in %s' %
                                  (datetime.now(), FLAGS.checkpoint))
                            saver.save(sess, FLAGS.checkpoint, global_step=step)
                        step += 1
                        if step > 200000:
                            break
                except Exception as e:
                    # Persist progress before stopping the queue runners.
                    print('>> %s Saving in %s' %
                          (datetime.now(), FLAGS.checkpoint))
                    saver.save(sess, FLAGS.checkpoint, global_step=step)
                    coord.request_stop(e)
                finally:
                    coord.request_stop()
                coord.join(threads)
            finally:
                # Flush and close the event file on every exit path.
                summary_writer.close()
def run_train():
    """Resume CAPTCHA training from the latest checkpoint.

    Builds the same train/test graph as a fresh run:
      * training pipeline: ``keep_prob=0.75``, ``is_training=True``;
      * test pipeline: ``keep_prob=1``, ``is_training=False``.

    Global variables are not initialized here; every one of them must come
    from the checkpoint in ``FLAGS.checkpoint_dir``. ``restore`` therefore
    raises if that directory holds no checkpoint.

    Each step writes these summaries to ``FLAGS.train_dir``:
      * ``loss``;
      * ``train_precision``;
      * up to 10 training images.

    Every 10 steps the loss and the train/test correct counts are printed.
    Every 100 steps a checkpoint is saved to ``FLAGS.checkpoint``. Training
    stops after step 2000000, on Ctrl-C, or on an exception. A final
    checkpoint is saved on every exit path once the loop has started.
    """

    with tf.Graph().as_default():
        images, labels = captcha.inputs(train=True,
                                        batch_size=FLAGS.batch_size)
        test_images, test_labels = captcha.inputs(train=False,
                                                  batch_size=FLAGS.batch_size)

        logits = captcha.inference(images, keep_prob=0.75, is_training=True)
        test_logits = captcha.inference(test_images,
                                        keep_prob=1,
                                        is_training=False)

        loss = captcha.loss(logits, labels)

        test_correct = captcha.evaluation(test_logits, test_labels)  # test
        correct = captcha.evaluation(logits, labels)  # train

        train_precision = correct / FLAGS.batch_size
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('train_precision', train_precision)
        tf.summary.image('images', images, 10)
        summary = tf.summary.merge_all()
        train_op = captcha.training(loss)
        saver = tf.train.Saver()
        # Saver restores only global variables. Local ones (if the input
        # pipeline creates any, e.g. epoch counters) still need initializing;
        # this op does nothing to the restored global values.
        local_init_op = tf.local_variables_initializer()

        # The context manager closes the session even if the cleanup in
        # `finally` raises (coord.join() re-raises recorded errors).
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
            saver.restore(sess,
                          tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
            sess.run(local_init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # Hard-coded resume step; presumably matches the restored
            # checkpoint. TODO: derive it from the checkpoint name.
            step = 120140
            try:
                while not coord.should_stop():
                    start_time = time.time()
                    # Fetch the summary in the same run as the train op so it
                    # describes the batch that was trained on. A separate
                    # sess.run(summary) re-evaluates the input pipeline,
                    # presumably dequeuing another batch.
                    (_, loss_value, test_value, train_value,
                     summary_str) = sess.run(
                         [train_op, loss, test_correct, correct, summary])
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                    duration = time.time() - start_time
                    step += 1
                    if step % 10 == 0:
                        print(
                            '>> Step %d run_train: loss = %.2f, test = %.2f, train = %.2f (%.3f sec)'
                            %
                            (step, loss_value, test_value, train_value, duration))

                    if step % 100 == 0:
                        print('>> %s Saving in %s' %
                              (datetime.now(), FLAGS.checkpoint))
                        saver.save(sess, FLAGS.checkpoint, global_step=step)
                        print(images.shape.as_list(), labels.shape.as_list())

                    if step > 2000000:
                        break
            except KeyboardInterrupt:
                print('INTERRUPTED')
                coord.request_stop()
            except Exception as e:
                coord.request_stop(e)
            finally:
                # Close the event file first so it is released even if the
                # final save below fails.
                summary_writer.close()
                saver.save(sess, FLAGS.checkpoint, global_step=step)
                print('Model saved in file :%s' % FLAGS.checkpoint)

                coord.request_stop()
                coord.join(threads)