def train():
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        images, labels = cifar10_model.distorted_inputs()
        logits = cifar10_model.inference(images)
        loss = cifar10_model.loss(logits, labels)
        train_op = cifar10_model.train(loss, global_step)
        init = tf.global_variables_initializer()
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)
        tf.train.start_queue_runners(sess=sess)

        for step in range(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time

            if step % 100 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))
def train():
    # Read the input images and run them through the network
    images, labels = cifar10_input.distorted_inputs(DATA_DIR, BATCH_SIZE)
    t_logits = cifar10_model.inference(images)
    # Loss value
    t_loss = cifar10_model.loss(t_logits, labels)
    tf.summary.scalar('loss_value', t_loss)
    # Optimizer
    global_step = tf.Variable(0, trainable=False)
    t_optimizer = cifar10_model.train_step(t_loss, global_step)
    # Accuracy
    t_accuracy = cifar10_model.accuracy(t_logits, labels)  # accuracy on the training batch
    tf.summary.scalar('accuracy_value', t_accuracy)

    merged = tf.summary.merge_all()
    saver = tf.train.Saver()
    Accuracy_value = []
    Loss_value = []
    # Cap the fraction of GPU memory this process may use
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.5
    with tf.Session(config=config) as session:
        session.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=session, coord=coord)
        train_writer = tf.summary.FileWriter('./signal_GPU/logs',
                                             session.graph)
        for index in range(EPOCHES):
            _, loss_value, accuracy_value, summary = session.run(
                [t_optimizer, t_loss, t_accuracy, merged])
            Accuracy_value.append(accuracy_value)
            Loss_value.append(loss_value)
            if index % 1000 == 0:
                print('index:', index, ' loss_value:', loss_value,
                      ' accuracy_value:', accuracy_value)
            train_writer.add_summary(summary, index)
        saver.save(session, os.path.join('./signal_GPU/saver/', 'model.ckpt'))
        # accuracy value
        plt.figure(figsize=(20, 10))
        plt.plot(range(EPOCHES), Accuracy_value)
        plt.xlabel('training step')
        plt.ylabel('accuracy value')
        plt.title('the accuracy value of training data')
        plt.savefig('./signal_GPU/accuracy.png')
        # loss value
        plt.figure()
        plt.plot(range(EPOCHES), Loss_value)
        plt.xlabel('training step')
        plt.ylabel('loss value')
        plt.title('the value of the loss function of the training data')
        plt.savefig('./signal_GPU/loss.png')
        train_writer.close()
        coord.request_stop()
        coord.join(threads)
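# The train() above depends on module-level imports and constants that are
# not shown; a plausible setup (paths and values are assumptions):
import os
import tensorflow as tf
import matplotlib
matplotlib.use('Agg')  # write the PNG figures without needing a display
import matplotlib.pyplot as plt

DATA_DIR = './cifar10_data'  # where cifar10_input expects the CIFAR-10 binaries
BATCH_SIZE = 128
EPOCHES = 100000  # number of training iterations (spelling kept from the snippet)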
def model_fn(features, labels, mode, params):
    logits = cifar10_model.inference(image_batch=features,
                                     batch_size=params.get('batch_size'))
    loss = cifar10_model.loss(logits, labels)

    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = cifar10_model.train(loss,
                                       batch_size=params.get('batch_size'))
        logging_hook = tf.train.LoggingTensorHook({'loss': loss},
                                                  every_n_iter=1000)
        return tf.estimator.EstimatorSpec(mode,
                                          loss=loss,
                                          train_op=train_op,
                                          training_hooks=[logging_hook])
    # EVAL (and other) modes need a spec too; returning None would crash
    # the Estimator.
    return tf.estimator.EstimatorSpec(mode, loss=loss)
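# A minimal way to drive the model_fn above with tf.estimator; the input_fn,
# model_dir, and parameter values below are illustrative assumptions:
def cifar10_train_input_fn():
    images, labels = cifar10_input.distorted_inputs('./cifar10_data', 128)
    return images, labels

estimator = tf.estimator.Estimator(model_fn=model_fn,
                                   model_dir='./estimator_logs',
                                   params={'batch_size': 128})
estimator.train(input_fn=cifar10_train_input_fn, max_steps=100000)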
Example #4
def run_training():
    cifar10_data = Cifar10Data('./input_data')

    images_pl = tf.placeholder(tf.float32, [
        None, cifar10_model.IMAGE_BATCH_HEIGHT,
        cifar10_model.IMAGE_BATCH_WIDTH, cifar10_model.IMAGE_BATCH_DEPTH
    ])
    labels_pl = tf.placeholder(tf.int32)
    keep_prob_pl = tf.placeholder(tf.float32)
    learning_rate_pl = tf.placeholder(tf.float32)
    is_test_pl = tf.placeholder(tf.bool)
    iter_pl = tf.placeholder(tf.int32)

    with tf.Session() as sess:
        logits, update_ema = cifar10_model.inference(images_pl, iter_pl,
                                                     is_test_pl, keep_prob_pl)
        total_loss = cifar10_model.loss(logits, labels_pl)
        train_op = cifar10_model.train(total_loss, learning_rate_pl)
        eval_op = cifar10_eval.evaluation(logits, labels_pl)

        saver = tf.train.Saver()
        summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
        sess.run(tf.global_variables_initializer())

        # learning rate decay
        max_learning_rate = 0.02  # 0.003
        min_learning_rate = 0.0001
        decay_speed = 1600.0  # 2000.0
        for step in range(MAX_STEPS):
            print('step %d/%d' % (step, MAX_STEPS))
            start_time = time.time()
            images_feed, labels_feed = cifar10_data.random_training_batch(
                cifar10_model.BATCH_SIZE)
            learning_rate = min_learning_rate + (max_learning_rate -
                                                 min_learning_rate) * math.exp(
                                                     -step / decay_speed)

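            # Note: building random_distort_images(...) inside the loop adds
            # new ops to the default graph on every step; constructing the
            # distortion op once outside the loop (fed via a placeholder)
            # avoids this unbounded graph growth.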
            images_feed = sess.run(
                cifar10_model.random_distort_images(images_feed))

            feed_dict = {
                images_pl: images_feed,
                labels_pl: labels_feed,
                keep_prob_pl: 0.75,
                learning_rate_pl: learning_rate,
                is_test_pl: False,
                iter_pl: step
            }
            sess.run(train_op, feed_dict=feed_dict)

            feed_dict = {
                images_pl: images_feed,
                labels_pl: labels_feed,
                keep_prob_pl: 1.0,
                learning_rate_pl: learning_rate,
                is_test_pl: False,
                iter_pl: step
            }
            sess.run(update_ema, feed_dict=feed_dict)

            duration = time.time() - start_time

            # Write the summaries and print an overview fairly often.
            if (step + 1) % 100 == 0 or (step + 1) == MAX_STEPS:
                train_eval_val, loss_value = sess.run([eval_op, total_loss],
                                                      feed_dict=feed_dict)
                print('Step %d: loss = %.2f, lr = %f (%.3f sec)' %
                      (step + 1, loss_value, learning_rate, duration))
                print('Training Data Eval: %.4f' % train_eval_val)

                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()

                # Evaluate the model periodically.
                # feed_dict = {images_pl: data_sets.testing_image,
                #             labels_pl: data_sets.testing_label}
                # test_eval_val = sess.run(eval_op, feed_dict=feed_dict)
                test_eval_val, test_loss_val = cifar10_eval.mass_evaluation(
                    cifar10_data, sess, eval_op, total_loss, images_pl,
                    labels_pl, keep_prob_pl, is_test_pl)
                print('Testing Data Eval: ' + str(test_eval_val) + '  loss: ' +
                      str(test_loss_val))

            # Save a checkpoint periodically.
            if (step + 1) % 1000 == 0 or (step + 1) == MAX_STEPS:
                checkpoint_file = os.path.join(LOG_DIR, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=step)

        summary_writer.close()
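# The learning-rate schedule above is a plain exponential decay toward a
# floor; pulled out as a standalone helper for clarity (constants copied
# from the snippet):
import math

def decayed_learning_rate(step, max_lr=0.02, min_lr=0.0001,
                          decay_speed=1600.0):
    """lr(step) = min_lr + (max_lr - min_lr) * exp(-step / decay_speed)."""
    return min_lr + (max_lr - min_lr) * math.exp(-step / decay_speed)

# decayed_learning_rate(0) == 0.02, and each decay_speed steps shrink the
# gap above min_lr by a factor of e.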
Example #5
def train():
    """
    Train CIFAR-10 for a number of steps.
    """

    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # get images and labels for CIFAR-10.
        # force the input pipeline to CPU:0 to avoid operations sometimes
        # ending up on the GPU and causing a slowdown.
        with tf.device('/cpu:0'):
            images, labels = cifar10_model.distorted_inputs()

        # build a graph that computes the logits predictions from
        # the inference model.
        logits = cifar10_model.inference(images)

        # calculate loss.
        loss = cifar10_model.loss(logits, labels)

        # build a graph that trains the model with one batch of examples
        # and updates the model parameters
        train_op = cifar10_model.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """
            Logs loss and runtime.
            """
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
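# A typical TF1-style entry point for the MonitoredTrainingSession variant
# above; the directory handling mirrors the stock CIFAR-10 tutorial, and the
# download helper is hypothetical:
def main(argv=None):
    cifar10_model.maybe_download_and_extract()  # hypothetical download helper
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    train()

if __name__ == '__main__':
    tf.app.run()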