def evaluate():
    """Eval CIFAR-10 for a number of steps.

    Builds a fresh graph containing the evaluation input pipeline and the
    RNN model, creates a saver that maps every trainable variable to its
    moving-average checkpoint name (all other variables keep their op name),
    and then calls ``eval_once`` in a loop. The loop exits after one round
    when ``FLAGS.run_once`` is set; otherwise it sleeps for
    ``FLAGS.eval_interval_secs`` seconds between rounds and never returns.
    """
    with tf.Graph().as_default():
        # Get images and labels for CIFAR-10.
        eval_data = True
        label_enqueue, images, labels = load_input.inputs(eval_data,
                                                          distorted=False)

        # Build a Graph that computes the logits predictions from the
        # inference model. train() unpacks two values from
        # model.rnn_model(); only the first (the logits) is needed here.
        logits, _ = model.rnn_model(images)

        # Calculate predictions.
        top_k_op = tf.nn.in_top_k(logits, labels, 1)

        # Restore the moving average version of the learned variables for
        # eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay)
        # Fetch the trainable collection once instead of once per variable.
        trainable_variables = tf.trainable_variables()
        variables_to_restore = {}
        for v in tf.all_variables():
            if v in trainable_variables:
                restore_name = variable_averages.average_name(v)
            else:
                restore_name = v.op.name
            variables_to_restore[restore_name] = v
        saver = tf.train.Saver(variables_to_restore)

        # Build the summary operation based on the TF collection of
        # Summaries.
        summary_op = tf.merge_all_summaries()

        graph_def = tf.get_default_graph().as_graph_def()
        summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir,
                                                graph_def=graph_def)

        while True:
            eval_once(saver, summary_writer, top_k_op, summary_op,
                      label_enqueue)
            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)
def train():
    """Train the RNN model for ``FLAGS.max_steps`` steps.

    Builds the training graph (distorted inputs, model, loss, train op),
    initializes all variables, then overwrites the conv1-conv3 glimpse
    weights and biases with the values stored in
    ``FLAGS.pretrained_checkpoint_path``.

    During training it:
      * prints loss and throughput every 10 steps,
      * writes summaries to ``FLAGS.train_dir`` every 100 steps,
      * saves a checkpoint every 1000 steps and on the final step,
      * re-runs ``label_enqueue`` whenever any queue-runner queue would hold
        fewer than ``FLAGS.min_queue_size`` elements after removing one
        batch.

    Raises:
      AssertionError: if the loss becomes NaN.
    """
    with tf.Graph().as_default(), tf.device('/gpu:0'):
        global_step = tf.get_variable(
            'global_step', [],
            initializer=tf.constant_initializer(0), trainable=False)

        eval_data = False
        label_enqueue, images, labels = load_input.inputs(eval_data,
                                                          distorted=True)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits, glimpse_vars = model.rnn_model(images)

        # Calculate loss.
        loss = model.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = model.train(loss, global_step)

        # Create a saver for the full model.
        saver = tf.train.Saver(tf.all_variables())

        # Checkpoint name -> graph variable for the pretrained glimpse
        # convolution layers.
        pretrained_glimpse_vars = {
            u'conv1/weights': glimpse_vars['conv1/weights:0'],
            u'conv1/biases': glimpse_vars['conv1/biases:0'],
            u'conv2/weights': glimpse_vars['conv2/weights:0'],
            u'conv2/biases': glimpse_vars['conv2/biases:0'],
            u'conv3/weights': glimpse_vars['conv3/weights:0'],
            u'conv3/biases': glimpse_vars['conv3/biases:0'],
        }
        pretrained_glimpse_saver = tf.train.Saver(pretrained_glimpse_vars)

        # Build the summary operation based on the TF collection of
        # Summaries.
        summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Build the queue-size ops once, up front. Creating them inside the
        # training loop would add new nodes to the graph on every step.
        queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
        queue_size_ops = [qr._queue.size() for qr in queue_runners]

        # Start running operations on the Graph.
        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement)
        with tf.Session(config=session_config) as sess:
            sess.run(init)

            # Restore after init so the pretrained values are not overwritten
            # by the initializers.
            pretrained_ckpt = FLAGS.pretrained_checkpoint_path
            pretrained_glimpse_saver.restore(sess, pretrained_ckpt)

            coord = tf.train.Coordinator()
            threads = []
            try:
                for qr in queue_runners:
                    threads.extend(qr.create_threads(sess, coord=coord,
                                                     daemon=True, start=True))

                sess.run(label_enqueue)

                summary_writer = tf.train.SummaryWriter(
                    FLAGS.train_dir, graph_def=sess.graph_def)

                for step in xrange(FLAGS.max_steps):
                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss])
                    duration = time.time() - start_time

                    assert not np.isnan(loss_value), \
                        'Model diverged with loss = NaN'

                    if step % 10 == 0:
                        num_examples_per_step = FLAGS.batch_size
                        examples_per_sec = (num_examples_per_step /
                                            float(duration))
                        sec_per_batch = float(duration)

                        format_str = ('%s: step %d, loss = %.2f '
                                      '(%.1f examples/sec; %.3f '
                                      'sec/batch)')
                        print(format_str % (datetime.now(), step, loss_value,
                                            examples_per_sec, sec_per_batch))

                    if step % 100 == 0:
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, step)

                    # Save the model checkpoint periodically.
                    if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                        checkpoint_path = os.path.join(FLAGS.train_dir,
                                                       'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)

                    # Refill the labels once any queue is about to run low.
                    if step > 0 and queue_size_ops:
                        queue_sizes = sess.run(queue_size_ops)
                        end_epoch = any(
                            size - FLAGS.batch_size < FLAGS.min_queue_size
                            for size in queue_sizes)
                        if end_epoch:
                            sess.run(label_enqueue)
            finally:
                # Always stop and join the queue-runner threads, even when
                # training aborts with an exception.
                coord.request_stop()
                coord.join(threads)