def train():
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        images, labels = cifar10_model.distorted_inputs()
        logits = cifar10_model.inference(images)
        loss = cifar10_model.loss(logits, labels)
        train_op = cifar10_model.train(loss, global_step)
        init = tf.global_variables_initializer()
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)
        tf.train.start_queue_runners(sess=sess)

        for step in range(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time

            if step % 100 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))
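These excerpts omit their imports and command-line flag definitions. A minimal sketch of the preamble they appear to assume is below; the flag names are taken from the snippets themselves, while the default values are illustrative guesses:

import math
import os
import time
from datetime import datetime

import numpy as np
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           'Directory for checkpoints and event files.')
tf.app.flags.DEFINE_integer('max_steps', 100000, 'Number of batches to run.')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Images per training batch.')
tf.app.flags.DEFINE_integer('log_frequency', 10,
                            'How often to log results to the console.')
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            'Whether to log device placement.')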
def model_fn(features, labels, mode, params):
    """Model function for use with tf.estimator.Estimator."""
    logits = cifar10_model.inference(image_batch=features,
                                     batch_size=params.get('batch_size'))

    # In PREDICT mode there are no labels, so return before building the loss.
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={'logits': logits})

    loss = cifar10_model.loss(logits, labels)

    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = cifar10_model.train(loss,
                                       batch_size=params.get('batch_size'))
        logging_hook = tf.train.LoggingTensorHook({'loss': loss},
                                                  every_n_iter=1000)
        return tf.estimator.EstimatorSpec(mode,
                                          loss=loss,
                                          train_op=train_op,
                                          training_hooks=[logging_hook])

    # EVAL mode: report the loss only.
    return tf.estimator.EstimatorSpec(mode, loss=loss)
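Wiring this model_fn into an Estimator is a one-liner plus an input function. A minimal usage sketch, assuming a train_input_fn (hypothetical, not shown in the excerpt) that returns a (features, labels) tuple or a tf.data.Dataset of such tuples:

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir='/tmp/cifar10_estimator',  # illustrative path
    params={'batch_size': 128})

# train_input_fn is assumed to exist; the Estimator calls it to build the
# input pipeline inside its own graph.
estimator.train(input_fn=train_input_fn, steps=10000)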
Example #3
            image = tf.placeholder(tf.uint8)
            label = tf.placeholder(tf.int32)
            phase = tf.placeholder(tf.bool)

            dataset_iterator = cifar10_input.input_dataset(image, label, BATCH_SIZE, 1)
            data = dataset_iterator.get_next()
            image_queue = data["features"]
            label_queue = data["label"]
            
            step = tf.train.get_or_create_global_step()
            learning_rate = tf.train.exponential_decay(in_lr, step, DECAY_STEP, DECAY_RATE, staircase=True)
            # learning_rate = tf.constant(in_lr)

            with tf.device('/gpu:0'):
                logits = cifar10_model.dnn(image_queue, mean, variance, phase)
                loss, train_step = cifar10_model.train(logits, label_queue, learning_rate, lmbd, step, tf.trainable_variables())
                accuracy = cifar10_model.old_evaluate(logits, label_queue)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True
            session_args = {
                'checkpoint_dir': './trained_model',
                'save_checkpoint_steps': 300,
                'config': config
            }
            with tf.train.MonitoredTrainingSession(**session_args) as sess:

                count = 1
                for epoch in range(NO_OF_EPOCHS):
                    l = list(range(5))
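cifar10_input.input_dataset is not shown in any of these excerpts. Judging from the call sites (two placeholders in, an iterator whose get_next() yields a {'features': ..., 'label': ...} batch out), it is plausibly a thin tf.data wrapper; a hedged reconstruction under that assumption:

def input_dataset(image, label, batch_size, num_epochs):
    # Hypothetical: slice the fed tensors into examples, shuffle, batch,
    # and repeat for the requested number of epochs.
    dataset = tf.data.Dataset.from_tensor_slices(
        {'features': image, 'label': label})
    dataset = dataset.shuffle(10000).batch(batch_size).repeat(num_epochs)
    return dataset.make_initializable_iterator()

With this shape, the caller would run the iterator's initializer once, feeding the image and label placeholders, before the first get_next() is evaluated.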
Example #4
def run_training():
    cifar10_data = Cifar10Data('./input_data')

    images_pl = tf.placeholder(tf.float32, [
        None, cifar10_model.IMAGE_BATCH_HEIGHT,
        cifar10_model.IMAGE_BATCH_WIDTH, cifar10_model.IMAGE_BATCH_DEPTH
    ])
    labels_pl = tf.placeholder(tf.int32)
    keep_prob_pl = tf.placeholder(tf.float32)
    learning_rate_pl = tf.placeholder(tf.float32)
    is_test_pl = tf.placeholder(tf.bool)
    iter_pl = tf.placeholder(tf.int32)

    with tf.Session() as sess:
        logits, update_ema = cifar10_model.inference(images_pl, iter_pl,
                                                     is_test_pl, keep_prob_pl)
        total_loss = cifar10_model.loss(logits, labels_pl)
        train_op = cifar10_model.train(total_loss, learning_rate_pl)
        eval_op = cifar10_eval.evaluation(logits, labels_pl)

        saver = tf.train.Saver()
        summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
        sess.run(tf.global_variables_initializer())

        # learning rate decay
        max_learning_rate = 0.02  # 0.003
        min_learning_rate = 0.0001
        decay_speed = 1600.0  # 2000.0
        for step in range(MAX_STEPS):
            print('step %d/%d' % (step, MAX_STEPS))
            start_time = time.time()
            images_feed, labels_feed = cifar10_data.random_training_batch(
                cifar10_model.BATCH_SIZE)
            learning_rate = min_learning_rate + (max_learning_rate -
                                                 min_learning_rate) * math.exp(
                                                     -step / decay_speed)

            images_feed = sess.run(
                cifar10_model.random_distort_images(images_feed))

            feed_dict = {
                images_pl: images_feed,
                labels_pl: labels_feed,
                keep_prob_pl: 0.75,
                learning_rate_pl: learning_rate,
                is_test_pl: False,
                iter_pl: step
            }
            sess.run(train_op, feed_dict=feed_dict)

            feed_dict = {
                images_pl: images_feed,
                labels_pl: labels_feed,
                keep_prob_pl: 1.0,
                learning_rate_pl: learning_rate,
                is_test_pl: False,
                iter_pl: step
            }
            sess.run(update_ema, feed_dict=feed_dict)

            duration = time.time() - start_time

            # Write the summaries and print an overview fairly often.
            if (step + 1) % 100 == 0 or (step + 1) == MAX_STEPS:
                train_eval_val, loss_value = sess.run([eval_op, total_loss],
                                                      feed_dict=feed_dict)
                print('Step %d: loss = %.2f, lr = %f (%.3f sec)' %
                      (step + 1, loss_value, learning_rate, duration))
                print('Training Data Eval: %.4f' % train_eval_val)

                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()

                # Evaluate the model periodically.
                # feed_dict = {images_pl: data_sets.testing_image,
                #             labels_pl: data_sets.testing_label}
                # test_eval_val = sess.run(eval_op, feed_dict=feed_dict)
                test_eval_val, test_loss_val = cifar10_eval.mass_evaluation(
                    cifar10_data, sess, eval_op, total_loss, images_pl,
                    labels_pl, keep_prob_pl, is_test_pl)
                print('Testing Data Eval: ' + str(test_eval_val) + '  loss: ' +
                      str(test_loss_val))

            # Save a checkpoint periodically.
            if (step + 1) % 1000 == 0 or (step + 1) == MAX_STEPS:
                checkpoint_file = os.path.join(LOG_DIR, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=step)

        summary_writer.close()
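The schedule above interpolates exponentially from max_learning_rate down toward min_learning_rate, i.e. lr(step) = min + (max - min) * exp(-step / decay_speed). A quick standalone check of the values it produces:

import math

max_lr, min_lr, decay_speed = 0.02, 0.0001, 1600.0
for step in (0, 1600, 4800, 16000):
    lr = min_lr + (max_lr - min_lr) * math.exp(-step / decay_speed)
    print(step, round(lr, 5))
# 0      0.02     (starts at max_learning_rate)
# 1600   0.00742  (one time constant: ~37% of the gap remains)
# 4800   0.00109
# 16000  0.0001   (asymptotically min_learning_rate)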
Example #5
def train():
    """
    train cifar10 for a number of steps
    """

    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force the input pipeline onto CPU:0 to avoid operations sometimes
        # ending up on the GPU and causing a slowdown.
        with tf.device('/cpu:0'):
            images, labels = cifar10_model.distorted_inputs()

        # build a graph that computes the logits predictions from
        # the inference model.
        logits = cifar10_model.inference(images)

        # calculate loss.
        loss = cifar10_model.loss(logits, labels)

        # build a graph that trains the model with one batch of examples
        # and updates the model parameters
        train_op = cifar10_model.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """
            Logs loss and runtime. 
            """
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
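MonitoredTrainingSession takes care of saving checkpoints on the training side; a companion evaluation loop typically restores the newest one by hand. A minimal sketch (eval_once is hypothetical), and incidentally where tf.train.get_checkpoint_state actually belongs, in contrast to the global-step line fixed above:

def eval_once(accuracy_op):
    """Restore the newest checkpoint from train_dir and run one eval pass."""
    saver = tf.train.Saver()
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print('No checkpoint found in %s' % FLAGS.train_dir)
            return
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Assumes accuracy_op is fed by a feedable or tf.data input pipeline
        # rather than queue runners, which would need to be started first.
        print('accuracy = %.4f' % sess.run(accuracy_op))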
Example #6
    LAMBDA = 0.001

    image = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
    label = tf.placeholder(tf.int32)

    dataset_iterator = cifar10_input.input_dataset(image, label, BATCH_SIZE, NO_OF_EPOCHS)
    data = dataset_iterator.get_next()
    image_queue = data["features"]
    label_queue = data["label"]

    step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.exponential_decay(INITIAL_LEARNING_RATE, step, DECAY_STEP, DECAY_RATE, staircase=True)
    logits_train = cifar10_model.dnn(image_queue, training=True)
    tf.get_variable_scope().reuse_variables()
    logits_test = cifar10_model.dnn(image_queue, training=False)
    loss, train_step = cifar10_model.train(logits_train, label_queue, learning_rate, LAMBDA, step)
    accuracy = cifar10_model.old_evaluate(logits_test, label_queue)

    path = './dataset/cifar-10-batches-py'
    filename_list = [(path + '/data_batch_%d' % i) for i in range(1, 6)]

    session_args = {
        'checkpoint_dir': './trained_model',
        'save_checkpoint_steps': 300
    }
    with tf.train.MonitoredTrainingSession(**session_args) as sess:

        count = 1
        for i in range(5):
            cifar10_dataset = cifar10_input.unpickle(filename_list[i])
            # CIFAR-10 rows are channel-first (3 x 32 x 32): reshape and then
            # transpose to NHWC instead of reshaping straight to 32x32x3.
            image_in = np.reshape(cifar10_dataset[b'data'],
                                  (-1, 3, 32, 32)).transpose(0, 2, 3, 1)
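The reshape above is worth double-checking: each CIFAR-10 row stores 3072 bytes channel-first (1024 red, then 1024 green, then 1024 blue values), so reshaping straight to (-1, 32, 32, 3) would interleave the color planes incorrectly. A standalone sanity check of the reshape-then-transpose used here:

import numpy as np

# One synthetic CIFAR-10 row: bytes 0..1023 are the red plane,
# 1024..2047 the green plane, 2048..3071 the blue plane.
row = np.arange(3072, dtype=np.int32)
nhwc = row.reshape(1, 3, 32, 32).transpose(0, 2, 3, 1)

assert nhwc.shape == (1, 32, 32, 3)
assert nhwc[0, 0, 0, 0] == 0      # first red value
assert nhwc[0, 0, 0, 1] == 1024   # first green value
assert nhwc[0, 0, 0, 2] == 2048   # first blue value
assert nhwc[0, 0, 1, 0] == 1      # red value of the next pixel over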