import time
from datetime import datetime

import tensorflow as tf

import mnist_input  # project-local input pipeline (assumed available)
import mnist_model  # project-local model definition (assumed available)


def train():
	"""Train MNIST for a number of steps."""
	# Create a new graph and use it as the default graph in the following context.
	with tf.Graph().as_default():
		global_step = tf.train.get_or_create_global_step()

		# Get images and labels
		# Force input pipeline to CPU:0 to avoid operations sometimes ending up on
		# GPU and resulting in a slowdown.
		with tf.device('/cpu:0'):
			labels, images = mnist_input.inputs([FILE_NAMES], batchSize=100, shuffle=True)
	
		# Build a Graph that computes the logits predictions from the
		# inference model.
		logits = mnist_model.inference(images)
	
		# Calculate loss.
		loss = mnist_model.loss(logits, labels)
	
		# Build a Graph that trains the model with one batch of examples and
		# updates the model parameters.
		train_op = mnist_model.train(loss, 0.001, global_step)
	
		class _LoggerHook(tf.train.SessionRunHook):
			"""Logs loss and runtime."""

			def begin(self):
				self._step = -1
				self._start_time = time.time()
		
			def before_run(self, run_context):
				self._step += 1
				return tf.train.SessionRunArgs(loss)  # Asks for loss value.
		
			def after_run(self, run_context, run_values):
				if self._step % 100 == 0:
					current_time = time.time()
					duration = current_time - self._start_time
					self._start_time = current_time
		
					loss_value = run_values.results
					sec_per_batch = float(duration / 100)
		
					format_str = '%s: step %d, loss = %.2f (%.3f sec/batch)'
					print(format_str % (datetime.now(), self._step, loss_value, sec_per_batch))
	
		with tf.train.MonitoredTrainingSession(hooks=[_LoggerHook()]) as mon_sess:
			while not mon_sess.should_stop():
				mon_sess.run(train_op)
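As written, the MonitoredTrainingSession loop above only stops once the input pipeline is exhausted. A minimal sketch of bounding the run with tf.train.StopAtStepHook; the 10000-step budget is an illustrative value, not part of the original:

# Sketch: same training loop, but capped at a fixed number of global steps.
hooks = [_LoggerHook(), tf.train.StopAtStepHook(last_step=10000)]
with tf.train.MonitoredTrainingSession(hooks=hooks) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)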
Example #2
def tower_loss(scope, images, labels):
    """Compute the total loss for a single tower.

    scope: name scope assigned to the current device.
    images, labels: the data batch for this tower.
    """

    # construct network on current device
    logits = mnist_model.inference(images)
    # construct the backward path (i.e. compute the loss) on the current device;
    # the return value is discarded so we never fetch a tensor that lives outside
    # this tower -- mnist_model.loss() adds its terms to the 'losses' collection
    _ = mnist_model.loss(logits, labels)
    # fetch only the loss terms that belong to the current tower's scope
    losses = tf.get_collection('losses', scope)
    # compute total loss
    total_loss = tf.add_n(losses, name='total_loss')
    # add summaries
    for l in losses + [total_loss]:
        loss_name = re.sub('%s_[0-9]*/' % 'tower', '', l.op.name)
        tf.summary.scalar(loss_name, l)

    # done; return the total loss for the current batch
    return total_loss
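A short sketch of how tower_loss is typically driven, assuming the usual one-tower-per-GPU loop; the GPU count and the images_batch/labels_batch tensors are illustrative assumptions, while the 'tower' name prefix matches the regex in the function above:

# Sketch: build one tower per GPU and collect the per-tower losses.
tower_losses = []
for i in range(2):  # assumed number of GPUs
    with tf.device('/gpu:%d' % i):
        with tf.name_scope('%s_%d' % ('tower', i)) as scope:
            # images_batch / labels_batch are assumed to come from the input pipeline.
            tower_losses.append(tower_loss(scope, images_batch, labels_batch))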
Example #3
# Input image node (x): 2D tensor; batch size x flattened 28 x 28 MNIST image.
# Target output classes node (y_): 2D tensor; batch size x number of classes.
x = tf.placeholder(tf.float32, shape=[None, IMAGE_PIXELS], name="image")
y_ = tf.placeholder(tf.float32, shape=[None, N_CLASSES], name="label")

global_step = tf.Variable(0,
                          dtype=tf.int32,
                          trainable=False,
                          name='global_step')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
train_phase = tf.placeholder_with_default(False, shape=[], name='train_phase')

logits = mnist_model.inference(x, IMAGE_SIZE, N_CLASSES, keep_prob,
                               train_phase)

loss = mnist_model.loss(y_, logits)

train_op = mnist_model.train(loss)

accuracy = mnist_model.evaluate(logits, y_)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(max_to_keep=1)

    # Set the checkpoint path and restore the latest checkpoint if one exists
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, save_path=ckpt.model_checkpoint_path)
        print('Loaded model from latest checkpoint')
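The snippet ends right after restoring; a minimal continuation sketch, assuming a tutorial-style data_set = input_data.read_data_sets(...) feed (as in the next example) and treating the keep probability, step count, and logging interval as illustrative values:

    # Sketch: drive the graph above with feed_dict batches (values are illustrative).
    for step in range(1000):
        batch_x, batch_y = data_set.train.next_batch(100)
        _, loss_val = sess.run(
            [train_op, loss],
            feed_dict={x: batch_x, y_: batch_y, keep_prob: 0.5, train_phase: True})
        if step % 100 == 0:
            acc = sess.run(accuracy,
                           feed_dict={x: batch_x, y_: batch_y,
                                      keep_prob: 1.0, train_phase: False})
            print('step %d, loss %.4f, batch accuracy %.3f' % (step, loss_val, acc))
    saver.save(sess, FLAGS.checkpoint_path + '/model.ckpt')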
Example #4
"""
@author: herow
"""
import tensorflow as tf
import mnist_model as mnist
import tensorflow.examples.tutorials.mnist.input_data as input_data

data_set = input_data.read_data_sets('MNIST_data', one_hot=True)

with tf.Graph().as_default():
    images_placeholder = tf.placeholder("float", shape=[None, 784])
    labels_placeholder = tf.placeholder("float", shape=[None, 10])
    drop_out = tf.placeholder("float")
    learning_rate = tf.placeholder("float")
    predicts = mnist.inference(images_placeholder, drop_out)
    loss = mnist.loss(predicts, labels_placeholder)
    train_op = mnist.training(loss, learning_rate)
    correct_prediction = tf.equal(tf.argmax(predicts, 1),
                                  tf.argmax(labels_placeholder, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    for i in range(2000):
        batch = data_set.train.next_batch(50)
        if i % 100 == 0:
            # Periodically evaluate on the current batch with dropout disabled.
            train_accuracy = sess.run(accuracy, feed_dict={
                images_placeholder: batch[0], labels_placeholder: batch[1],
                drop_out: 1.0})
            print("step %d, training accuracy %g" % (i, train_accuracy))
        # Training step (not shown in the original snippet); drop_out=0.5 is an
        # assumed keep probability, learning_rate=1e-4 follows the commented-out code.
        sess.run(train_op, feed_dict={
            images_placeholder: batch[0], labels_placeholder: batch[1],
            drop_out: 0.5, learning_rate: 1e-4})
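A short follow-on sketch, not in the original snippet: final accuracy on the MNIST test split, with dropout disabled.

    # Sketch: evaluate the trained graph on the held-out test set.
    test_accuracy = sess.run(accuracy, feed_dict={
        images_placeholder: data_set.test.images,
        labels_placeholder: data_set.test.labels,
        drop_out: 1.0})
    print("test accuracy %g" % test_accuracy)
    sess.close()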
Example #5
def train():
  """Train model for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Get images and labels for MNIST data.
    images, labels = model.inputs()
    
    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = model.inference(images)

    # Calculate loss.
    loss = model.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = model.train(loss, global_step)

    accuracy = model.accuracy(logits, labels)

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement))
    print('Initializing all variables...')
    sess.run(init)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)
    print('Starting queue runners...')
    
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value, accuracy_value = sess.run([train_op, loss, accuracy])
      #print(sess.run(images)[0])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 1 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, batch accuracy = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.now(), step, accuracy_value,
                            examples_per_sec, sec_per_batch))

      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
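The FLAGS object used throughout train() is never defined in the snippet; it is typically declared near the top of the file with tf.app.flags. A sketch with illustrative defaults (the actual values are not part of the original):

# Sketch: flag definitions assumed by the training script above; defaults are illustrative.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('train_dir', '/tmp/mnist_train',
                           'Directory for checkpoints and summaries.')
tf.app.flags.DEFINE_integer('max_steps', 10000, 'Number of training steps to run.')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Number of examples per batch.')
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            'Whether to log op/device placement.')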
Example #6
def train(data_dir,
          model_dir,
          log_dir,
          batch_size=BATCH_SIZE,
          max_batches=MAX_TRAINING_STEPS):
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False, name='global_step')

        images_ph, labels_ph, keep_rate_ph = mnist_model.placeholders()
        pred, logits = mnist_model.inference(images_ph, keep_rate_ph)
        loss = mnist_model.loss(logits, labels_ph)
        avg_loss = mnist_model.avg_loss()
        train_op = mnist_model.training(loss, LEARNING_RATE, global_step)
        accuracy = mnist_model.evaluation(pred, labels_ph)
        images, labels = mnist_input.input_pipeline(
            data_dir, batch_size, mnist_input.DataTypes.train)

        merged_summary_op = tf.merge_all_summaries()

        saver = tf.train.Saver()
        sess = tf.Session()
        ckpt = _get_checkpoint(model_dir)
        if not ckpt:
            print("Grand New training")
            init_op = tf.initialize_all_variables()
            sess.run(init_op)
        else:
            print("Resume training after %s" % ckpt)
            saver.restore(sess, ckpt)

        coord = tf.train.Coordinator()
        # Start the queue runner, QueueRunner created in mnist_input.py
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        train_writer = tf.train.SummaryWriter(log_dir + '/train', sess.graph)

        acc_step = sess.run(global_step)
        print('accumulated step = %d' % acc_step)
        print('previous avg_loss = %.3f' % sess.run(avg_loss))
        # Training cycle
        try:
            lst = []
            for step in range(1, max_batches + 1):
                if coord.should_stop():
                    break

                images_r, labels_r = sess.run([images, labels])

                train_feed = {
                    images_ph: images_r,
                    labels_ph: labels_r,
                    keep_rate_ph: 0.5
                }

                start_time = time.time()
                _, train_loss = sess.run([train_op, loss],
                                         feed_dict=train_feed)
                duration = time.time() - start_time
                lst.append(train_loss)

                assert not np.isnan(
                    train_loss), 'Model diverged with loss = NaN'

                if step % DISPLAY_STEP == 0:
                    examples_per_sec = batch_size / duration
                    sec_per_batch = float(duration)
                    print(
                        '%s: step %d, train_loss = %.6f (%.1f examples/sec; %.3f sec/batch)'
                        % (datetime.now(), step, train_loss, examples_per_sec,
                           sec_per_batch))

                if step % LOG_STEP == 0:
                    avg = np.mean(lst)
                    del lst[:]
                    #print('avg loss = %.3f' % avg)
                    sess.run(avg_loss.assign(avg))
                    summary_str = sess.run(merged_summary_op,
                                           feed_dict=train_feed)
                    train_writer.add_summary(summary_str, acc_step + step)

                if step % CKPT_STEP == 0 or step == max_batches:
                    ckpt_path = os.path.join(model_dir, 'model.ckpt')
                    saver.save(sess, ckpt_path, global_step)

        except tf.errors.OutOfRangeError:
            print('Done training -- input queue exhausted')

        finally:
            # When done, ask the threads to stop
            coord.request_stop()

        coord.join(threads)
        train_writer.close()
        sess.close()
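The helper _get_checkpoint used above is not shown in the snippet; a plausible minimal implementation, consistent with how its return value is used (a checkpoint path or a falsy value):

# Sketch: _get_checkpoint is not defined in the original snippet; this is one
# plausible implementation based on how it is used above.
def _get_checkpoint(model_dir):
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        return ckpt.model_checkpoint_path
    return None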