def input_pipeline(is_training=True, model_scope=FLAGS.model_scope, num_epochs=None):
    # Build lookup tables mapping each class id to its left/right normalization
    # values: the global tables when training over all categories ('all' in
    # model_scope), otherwise the per-category (local) tables.
    if 'all' in model_scope:
        lnorm_table = tf.contrib.lookup.HashTable(
            tf.contrib.lookup.KeyValueTensorInitializer(
                tf.constant(config.global_norm_key, dtype=tf.int64),
                tf.constant(config.global_norm_lvalues, dtype=tf.int64)), 0)
        rnorm_table = tf.contrib.lookup.HashTable(
            tf.contrib.lookup.KeyValueTensorInitializer(
                tf.constant(config.global_norm_key, dtype=tf.int64),
                tf.constant(config.global_norm_rvalues, dtype=tf.int64)), 1)
    else:
        lnorm_table = tf.contrib.lookup.HashTable(
            tf.contrib.lookup.KeyValueTensorInitializer(
                tf.constant(config.local_norm_key, dtype=tf.int64),
                tf.constant(config.local_norm_lvalues, dtype=tf.int64)), 0)
        rnorm_table = tf.contrib.lookup.HashTable(
            tf.contrib.lookup.KeyValueTensorInitializer(
                tf.constant(config.local_norm_key, dtype=tf.int64),
                tf.constant(config.local_norm_rvalues, dtype=tf.int64)), 1)

    # Per-example preprocessing: resize to the training resolution and build the
    # keypoint heatmap targets for the requested category.
    preprocessing_fn = lambda org_image, classid, shape, key_x, key_y, key_v: preprocessing.preprocess_image(
        org_image, classid, shape, FLAGS.train_image_size, FLAGS.train_image_size,
        key_x, key_y, key_v, (lnorm_table, rnorm_table),
        is_training=is_training,
        data_format=('NCHW' if FLAGS.data_format == 'channels_first' else 'NHWC'),
        category=(model_scope if 'all' not in model_scope else '*'),
        bbox_border=FLAGS.bbox_border,
        heatmap_sigma=FLAGS.heatmap_sigma,
        heatmap_size=FLAGS.heatmap_size)

    # Read, preprocess and batch examples with the slim dataset pipeline.
    images, shape, classid, targets, key_v, isvalid, norm_value = dataset.slim_get_split(
        FLAGS.data_dir, preprocessing_fn, FLAGS.batch_size, FLAGS.num_readers,
        FLAGS.num_preprocessing_threads, num_epochs=num_epochs,
        is_training=is_training, file_pattern=FLAGS.dataset_name,
        category=(model_scope if 'all' not in model_scope else '*'),
        reader=None)

    return images, {'targets': targets,
                    'key_v': key_v,
                    'shape': shape,
                    'classid': classid,
                    'isvalid': isvalid,
                    'norm_value': norm_value}
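
# Minimal usage sketch (an assumption, not part of the original file): the
# tf.contrib.lookup tables built above must be initialized via
# tf.tables_initializer() before the first batch is pulled, and the slim reader
# queues (inferred from FLAGS.num_readers) need queue runners. The helper name
# below is hypothetical.
def _example_run_one_batch():
    features, labels = input_pipeline(is_training=True)
    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer(),
                  tf.tables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Pull a single preprocessed batch and its target dict.
        images_np, targets_np = sess.run([features, labels])
        coord.request_stop()
        coord.join(threads)
        return images_np, targets_np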
def train():
    """Train the model on the flowers dataset for a number of steps."""
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Create a variable to count training steps (not incremented here, since
        # apply_gradients() is called without global_step).
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        # Create the optimizer with a fixed learning rate.
        opt = tf.train.AdamOptimizer(1e-4)

        # Placeholder that switches the graph between training and inference mode.
        isTrain_ph = tf.placeholder(tf.bool, shape=None, name="is_train")
        # images = tf.placeholder(tf.float32, [None, 224, 224, 3])
        # labels = tf.placeholder(tf.float32, shape=[1024])

        # Build the input pipeline: a slim DatasetDataProvider feeding a shuffled,
        # preprocessed batch queue.
        dataset = flowers.get_split(FLAGS.subset, FLAGS.data_dir)
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=True,
            common_queue_capacity=2 * FLAGS.batch_size,
            common_queue_min=FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        image = preprocessing.preprocess_image(image, 224, 224, is_training=True)
        images, labels = tf.train.batch([image, label],
                                        batch_size=FLAGS.batch_size,
                                        num_threads=4,
                                        capacity=5 * FLAGS.batch_size)

        with tf.variable_scope(tf.get_variable_scope()) as scope:
            loss = cpu_loss(images, labels, scope, isTrain_ph)

        # Calculate the gradients for this batch of data.
        grads = opt.compute_gradients(loss)

        # Apply the gradients to adjust the shared variables.
        # apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        train_op = opt.apply_gradients(grads)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set
        # to True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Start the queue runners that feed the batch queue.
        tf.train.start_queue_runners(sess=sess)

        for step in xrange(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss],
                                     feed_dict={isTrain_ph: False})
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 10 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = duration / FLAGS.num_gpus

                format_str = ('%s: step %d, loss = %.2f '
                              '(%.1f examples/sec; %.3f sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))

            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
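
# Hedged sketch (not from the original file) of restoring the checkpoints written
# by train() above, e.g. for evaluation; the graph must be rebuilt the same way
# before restoring. The helper name is hypothetical.
def _example_restore_latest(sess, saver):
    # Find the most recent checkpoint under FLAGS.train_dir, if any.
    ckpt = tf.train.latest_checkpoint(FLAGS.train_dir)
    if ckpt is not None:
        # Restores all variables saved by the tf.train.Saver in train().
        saver.restore(sess, ckpt)
    return ckpt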