def evaluate(validation_set, validation_labels):
    """Build the evaluation graph and repeatedly evaluate via ``do_eval``.

    The graph is built on the CPU with placeholders sized to the whole
    validation set (``validation_set.shape[0]`` examples), so each call to
    ``do_eval`` can feed the full set in one batch.

    Args:
        validation_set: Validation examples. Must expose ``.shape[0]`` as
            the number of examples; forwarded unchanged to ``do_eval``.
        validation_labels: Labels for ``validation_set``; forwarded
            unchanged to ``do_eval``.

    The loop runs once when ``FLAGS.run_once`` is truthy; otherwise it runs
    forever, sleeping ``FLAGS.eval_interval_secs`` seconds between rounds.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Graph creation: one batch covers the whole validation set.
        batch_size = validation_set.shape[0]
        images_placeholder, labels_placeholder = cifar10.placeholder_inputs(
            batch_size)
        logits = resnet.inference(
            images_placeholder, FLAGS.num_residual_blocks, reuse=False)
        predictions = tf.nn.softmax(logits)
        in_top1 = tf.cast(
            tf.nn.in_top_k(predictions, labels_placeholder, k=1), tf.float32)
        num_correct = tf.reduce_sum(in_top1)

        # NOTE(review): despite its name, this is the fraction of examples
        # whose label is NOT the top-1 prediction (an error rate). The
        # formula is kept as-is because do_eval's interpretation of the
        # value is not visible here -- confirm which one is intended.
        validation_accuracy = (batch_size - num_correct) / float(batch_size)
        validation_loss = resnet.loss(logits, labels_placeholder)

        # No session is created here: this function never runs an op
        # itself, it only hands the graph tensors and the saver to do_eval.
        saver = tf.train.Saver()

        # Create summary writer.
        graph_def = tf.get_default_graph().as_graph_def()
        summary_writer = tf.summary.FileWriter(
            FLAGS.eval_dir, graph_def=graph_def)

        try:
            # -1 is the initial "no previous global step" value; each round
            # feeds do_eval's return value back in as prev_global_step.
            step = -1
            while True:
                step = do_eval(saver, summary_writer, validation_accuracy,
                               validation_loss, images_placeholder,
                               labels_placeholder, validation_set,
                               validation_labels, prev_global_step=step)
                if FLAGS.run_once:
                    break
                time.sleep(FLAGS.eval_interval_secs)
        finally:
            # Flush buffered events even if an evaluation round raises.
            summary_writer.close()
def train(training_set, training_labels):
    """Train the ResNet on the given dataset for ``FLAGS.max_steps`` steps.

    Builds a single-device ('/gpu:0', with soft placement) training graph
    that minimizes the model loss plus all collected regularization losses
    with a momentum optimizer. The learning rate is fed through a scalar
    placeholder on every step.

    Args:
        training_set: Training examples. Must expose ``.shape[0]`` as the
            number of examples; forwarded to ``cifar10.fill_feed_dict``.
        training_labels: Labels for ``training_set``; forwarded to
            ``cifar10.fill_feed_dict``.

    Raises:
        AssertionError: If the loss becomes NaN.

    Side effects:
        * Prints one progress line per step.
        * Multiplies ``FLAGS.init_lr`` by 0.1 after step
          ``FLAGS.decay_step0`` and after step ``FLAGS.decay_step1``
          (this mutates the global flag value).
        * Writes checkpoints to ``FLAGS.train_dir`` under the ``model.ckpt``
          prefix.
    """
    # Checkpoint cadence in steps; the final step is always checkpointed.
    checkpoint_interval = 195

    with tf.Graph().as_default(), tf.device('/gpu:0'):
        # Counts the number of applied gradient updates.
        global_step = tf.Variable(0, name="global_step", trainable=False)

        # Whole batches per epoch; used only for progress reporting.
        dataset_num_examples = training_set.shape[0]
        num_batches_per_epoch = int(dataset_num_examples / FLAGS.batch_size)

        # The learning rate is fed per step so it can be decayed from
        # Python without rebuilding the graph.
        lr_placeholder = tf.placeholder(dtype=tf.float32, shape=[])
        opt = tf.train.MomentumOptimizer(lr_placeholder, MOMENTUM)

        # Inputs are placeholders filled by cifar10.fill_feed_dict.
        images, labels = cifar10.placeholder_inputs(FLAGS.batch_size)
        logits = resnet.inference(
            images, FLAGS.num_residual_blocks, reuse=False)

        # Total loss = model loss + every registered regularization loss.
        loss = resnet.loss(logits, labels)
        regu_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([loss] + regu_losses)

        grads = opt.compute_gradients(total_loss)
        apply_gradients_op = opt.apply_gradients(
            grads, global_step=global_step)

        # Fetching train_op forces the gradient update to run.
        with tf.control_dependencies([apply_gradients_op]):
            train_op = tf.identity(total_loss, name='train_op')

        saver = tf.train.Saver(tf.global_variables())

        # Accuracy on the current training batch.
        # NOTE(review): assumes resnet.evaluation yields a per-batch
        # correctness count/indicator -- confirm its dtype so this division
        # is not an integer division.
        batch_accuracy = tf.reduce_sum(
            resnet.evaluation(logits, labels)) / tf.constant(FLAGS.batch_size)

        init = tf.global_variables_initializer()

        # allow_soft_placement lets ops without a GPU implementation fall
        # back to another device.
        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement)

        # The context manager guarantees the session is closed, including
        # when the NaN check below aborts training.
        with tf.Session(config=session_config) as sess:
            sess.run(init)

            # State threaded through cifar10.fill_feed_dict across steps.
            local_data_batch_idx = 0
            epoch_counter = 0
            # 1-based index of the current batch within the epoch (display).
            batch_counter = 0

            # Start the queue runners.
            tf.train.start_queue_runners(sess=sess)

            checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')

            for step in range(FLAGS.max_steps):
                epoch_counter, local_data_batch_idx, feed_dict = (
                    cifar10.fill_feed_dict(
                        training_set, training_labels, images, labels,
                        FLAGS.batch_size, local_data_batch_idx,
                        epoch_counter, FLAGS.init_lr, lr_placeholder))

                # Wrap to 1, not 0: the batch following the last batch of
                # an epoch is the first batch of the next one. Wrapping to
                # 0 made every displayed epoch one step too long.
                batch_counter += 1
                if batch_counter > num_batches_per_epoch:
                    batch_counter = 1

                start_time = time.time()
                _, loss_value, acc = sess.run(
                    [train_op, total_loss, batch_accuracy],
                    feed_dict=feed_dict)
                duration = time.time() - start_time

                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                percent_done = 100. * batch_counter / num_batches_per_epoch
                print(
                    'Train Epoch: {} [{}/{} ({:.0f}%)], Train Loss: {}, Time Cost: {}, Train Acc: {}'
                    .format(epoch_counter, batch_counter,
                            num_batches_per_epoch, percent_done, loss_value,
                            duration, acc))

                # Step-wise learning-rate decay; takes effect from the next
                # step because the feed dict is built before this point.
                if step in (FLAGS.decay_step0, FLAGS.decay_step1):
                    FLAGS.init_lr = 0.1 * FLAGS.init_lr

                is_last_step = (step + 1) == FLAGS.max_steps
                if step % checkpoint_interval == 0 or is_last_step:
                    saver.save(sess, checkpoint_path, global_step=step)