def evaluate():
    """Evaluate the CIFAR-10 model, once or repeatedly on a timer.

    Builds the evaluation graph pinned to '/gpu:0', maps checkpoint names to
    graph variables so trainable variables are loaded from their
    moving-average entries, and then calls eval_once in a loop. The loop
    exits after the first pass when FLAGS.run_once is set; otherwise it
    sleeps FLAGS.eval_interval_secs between passes.
    """
    with tf.Graph().as_default(), tf.device('/gpu:0'):
        # Evaluation input pipeline, without distortion.
        label_enqueue, images, labels = load_input.inputs(
            True, distorted=False)

        # Forward pass and top-1 correctness op.
        predictions = model.inference(images)
        top_k_op = tf.nn.in_top_k(predictions, labels, 1)

        # Trainable variables are restored under their moving-average names;
        # every other variable is restored under its own op name.
        ema = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay)
        trainable = tf.trainable_variables()
        restore_map = dict(
            (ema.average_name(var) if var in trainable else var.op.name, var)
            for var in tf.all_variables())
        saver = tf.train.Saver(restore_map)

        # Merged summaries plus a writer that also records the graph.
        summary_op = tf.merge_all_summaries()
        current_graph_def = tf.get_default_graph().as_graph_def()
        summary_writer = tf.train.SummaryWriter(
            FLAGS.eval_dir, graph_def=current_graph_def)

        while True:
            eval_once(saver, summary_writer, top_k_op, summary_op,
                      label_enqueue)
            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)
def train():
    """Train the model for FLAGS.max_steps steps.

    Builds the training graph (distorted inputs, inference, loss, train op),
    initializes all variables, and restores the latest checkpoint found in
    FLAGS.checkpoint_dir when one exists. It then runs the training loop:

    * every 10 steps, prints the loss and throughput;
    * every 100 steps, writes merged summaries to FLAGS.train_dir;
    * every 1000 steps and on the final step, saves a checkpoint;
    * after every step but the first, re-runs ``label_enqueue`` when any
      input queue would drop below FLAGS.min_queue_size after one more batch.

    The Python step counter always starts at 0, even after a restore.

    Raises:
      AssertionError: if the loss becomes NaN.
    """
    with tf.Graph().as_default():
        global_step = tf.get_variable(
            'global_step', [],
            initializer=tf.constant_initializer(0), trainable=False)

        eval_data = False
        label_enqueue, images, labels = load_input.inputs(
            eval_data, distorted=True)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = model.inference(images)

        # Calculate loss.
        loss = model.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = model.train(loss, global_step)

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables())

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Build the queue-size ops once. Creating them inside the training
        # loop would add new nodes to the graph on every step.
        queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
        queue_size_ops = [qr._queue.size() for qr in queue_runners]

        # Start running operations on the Graph.
        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement)
        with tf.Session(config=session_config) as sess:
            sess.run(init)

            # Resume from the latest checkpoint, if any. The restored values
            # overwrite the freshly initialized variables.
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)

            coord = tf.train.Coordinator()
            threads = []
            try:
                for qr in queue_runners:
                    threads.extend(qr.create_threads(
                        sess, coord=coord, daemon=True, start=True))

                # NOTE(review): presumably this primes the label queue for the
                # first epoch; confirm against load_input.inputs.
                sess.run(label_enqueue)

                summary_writer = tf.train.SummaryWriter(
                    FLAGS.train_dir, graph_def=sess.graph_def)

                for step in xrange(FLAGS.max_steps):
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss])
                    duration = time.time() - start_time

                    assert not np.isnan(loss_value), \
                        'Model diverged with loss = NaN'

                    if step % 10 == 0:
                        num_examples_per_step = FLAGS.batch_size
                        examples_per_sec = (
                            num_examples_per_step / float(duration))
                        sec_per_batch = float(duration)

                        format_str = ('%s: step %d, loss = %.2f '
                                      '(%.1f examples/sec; %.3f sec/batch)')
                        # Single-argument print(...) gives the same output
                        # under Python 2 and Python 3.
                        print(format_str % (datetime.now(), step, loss_value,
                                            examples_per_sec, sec_per_batch))
                        # Flush after printing so the line shows up at once.
                        sys.stdout.flush()

                    if step % 100 == 0:
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, step)

                    # Save the model checkpoint periodically.
                    if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                        checkpoint_path = os.path.join(
                            FLAGS.train_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)

                    # Re-run label_enqueue when any input queue would fall
                    # below the minimum after one more batch is taken.
                    if step > 0 and queue_size_ops:
                        queue_sizes = sess.run(queue_size_ops)
                        end_epoch = any(
                            size - FLAGS.batch_size < FLAGS.min_queue_size
                            for size in queue_sizes)
                        if end_epoch:
                            sess.run(label_enqueue)
            finally:
                # Always stop and join the queue-runner threads, including
                # when the loop exits through an exception.
                coord.request_stop()
                coord.join(threads)