from datetime import datetime
import os
import time

import numpy as np
import tensorflow as tf

import model_input
import model_inference

FLAGS = tf.app.flags.FLAGS


def distorted_inputs():
  """Fetch distorted (left image, disparity, confidence) training batches.

  The batches are cast to float16 when FLAGS.use_fp16 is set.
  """
  lefts, disps, confs = model_input.distorted_inputs(
      data_dir=FLAGS.data_name, batch_size=FLAGS.batch_size)
  if FLAGS.use_fp16:
    lefts = tf.cast(lefts, tf.float16)
    disps = tf.cast(disps, tf.float16)
    confs = tf.cast(confs, tf.float16)
  return lefts, disps, confs
def distorted_inputs(): """Construct distorted input for training using the Reader ops. Returns: images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. labels: Labels. 1D tensor of [batch_size] size. Raises: ValueError: If no data_dir """ if not FLAGS.data_dir: raise ValueError('Please supply a data_dir') data_dir = os.path.join(FLAGS.data_dir, 'batches-bin') # Add filename return model_input.distorted_inputs(data_dir=data_dir, batch_size=FLAGS.batch_size)
def distorted_inputs():
  """Fetch distorted training batches directly from FLAGS.data_dir."""
  if not FLAGS.data_dir:
    raise ValueError('Please supply a data_dir')
  return model_input.distorted_inputs(data_dir=FLAGS.data_dir,
                                      batch_size=FLAGS.batch_size)
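
# The flags and hyperparameters below are referenced above and in train() but
# are defined elsewhere in this repo; this is a minimal sketch with
# illustrative placeholder values (loosely following the TensorFlow CIFAR-10
# tutorial defaults), not the project's actual settings.
tf.app.flags.DEFINE_string('train_dir', '/tmp/model_train',
                           'Directory for checkpoints and event files.')
tf.app.flags.DEFINE_string('data_dir', '/tmp/model_data',
                           'Directory holding the training data.')
tf.app.flags.DEFINE_string('data_name', '/tmp/model_data',
                           'Data directory passed to model_input.')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Examples per batch.')
tf.app.flags.DEFINE_integer('run_epochs', 100, 'Number of epochs to train.')
tf.app.flags.DEFINE_boolean('use_fp16', False, 'Train using float16.')

MOVING_AVERAGE_DECAY = 0.9999      # Decay for the variable moving averages.
NUM_EPOCHS_PER_DECAY = 350.0       # Epochs between learning-rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.1   # Multiplicative learning-rate decay.
INITIAL_LEARNING_RATE = 0.1        # Learning rate at step zero.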
def train():
  """Build the training graph and run it for FLAGS.run_epochs epochs."""
  global_step = tf.Variable(0, trainable=False)

  images, labels = model_input.distorted_inputs(
      ['train.tfrecords'], FLAGS.batch_size)
  logits = model_inference.inference(images, FLAGS.batch_size)
  loss = model_inference.loss(logits, labels)

  # Prepare the training operation: set up a learning rate that decays
  # exponentially as training progresses.
  num_batches_per_epoch = int(model_input.TRAINING_SET_SIZE / FLAGS.batch_size)
  decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
  lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                  global_step,
                                  decay_steps,
                                  LEARNING_RATE_DECAY_FACTOR,
                                  staircase=True)
  tf.scalar_summary('learning_rate', lr)

  # Compute the moving average of all individual losses and the total loss.
  loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
  losses = tf.get_collection('losses')
  loss_averages_op = loss_averages.apply(losses + [loss])

  # Attach a scalar summary to all individual losses and the total loss; do
  # the same for the averaged versions. Each raw loss is named '(raw)' and
  # the moving-average version keeps the original loss name.
  for l in losses + [loss]:
    tf.scalar_summary(l.op.name + ' (raw)', l)
    tf.scalar_summary(l.op.name, loss_averages.average(l))

  # Compute gradients once the loss averages have been updated.
  with tf.control_dependencies([loss_averages_op]):
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(loss)
    # Gradient processing (e.g. clipping) could be inserted here before
    # the gradients are applied.
  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

  # Add histograms for trainable variables.
  for var in tf.trainable_variables():
    tf.histogram_summary(var.op.name, var)

  # Add histograms for gradients.
  for grad, var in grads:
    if grad is not None:
      tf.histogram_summary(var.op.name + '/gradients', grad)

  # Track the moving averages of all trainable variables.
  variable_averages = tf.train.ExponentialMovingAverage(
      MOVING_AVERAGE_DECAY, global_step)
  variables_averages_op = variable_averages.apply(tf.trainable_variables())

  # This no-op 'joins' apply_gradient_op and variables_averages_op: the
  # control dependency guarantees both are executed on every training step.
  with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
    train_op = tf.no_op(name='train')

  summary_op = tf.merge_all_summaries()
  init = tf.initialize_all_variables()

  sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
  sess.run(init)
  tf.train.start_queue_runners(sess=sess)

  summary_writer = tf.train.SummaryWriter(
      FLAGS.train_dir, graph_def=sess.graph.as_graph_def(add_shapes=True))

  num_steps = int(num_batches_per_epoch * FLAGS.run_epochs)
  saver = tf.train.Saver(tf.all_variables())
  print('will run for %d steps' % num_steps)

  for step in xrange(num_steps):
    start_time = time.time()
    _, loss_value = sess.run([train_op, loss])
    duration = time.time() - start_time

    assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

    if step % 10 == 0:
      num_examples_per_step = FLAGS.batch_size
      examples_per_sec = num_examples_per_step / duration
      sec_per_batch = float(duration)
      format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
      print(format_str % (datetime.now(), step, loss_value,
                          examples_per_sec, sec_per_batch))

    if step % 100 == 0:
      summary_str = sess.run(summary_op)
      summary_writer.add_summary(summary_str, step)

    # Save the model checkpoint periodically and at the end of training.
    if step % 1000 == 0 or (step + 1) == num_steps:
      checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
      saver.save(sess, checkpoint_path, global_step=step)
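
# A minimal entry-point sketch, assuming this script is run directly. The
# tf.gfile calls (available in the same pre-1.0 TF releases as the summary
# ops used above) wipe and recreate FLAGS.train_dir so each run starts with
# a clean checkpoint/event directory.
def main(argv=None):  # pylint: disable=unused-argument
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  tf.app.run()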