Example 1
def input_pipeline(is_training=True,
                   model_scope=FLAGS.model_scope,
                   num_epochs=None):
    # Build the left/right keypoint-normalization lookup tables: the global
    # tables when model_scope covers all categories, the per-category (local)
    # ones otherwise.
    if 'all' in model_scope:
        lnorm_table = tf.contrib.lookup.HashTable(
            tf.contrib.lookup.KeyValueTensorInitializer(
                tf.constant(config.global_norm_key, dtype=tf.int64),
                tf.constant(config.global_norm_lvalues, dtype=tf.int64)), 0)
        rnorm_table = tf.contrib.lookup.HashTable(
            tf.contrib.lookup.KeyValueTensorInitializer(
                tf.constant(config.global_norm_key, dtype=tf.int64),
                tf.constant(config.global_norm_rvalues, dtype=tf.int64)), 1)
    else:
        lnorm_table = tf.contrib.lookup.HashTable(
            tf.contrib.lookup.KeyValueTensorInitializer(
                tf.constant(config.local_norm_key, dtype=tf.int64),
                tf.constant(config.local_norm_lvalues, dtype=tf.int64)), 0)
        rnorm_table = tf.contrib.lookup.HashTable(
            tf.contrib.lookup.KeyValueTensorInitializer(
                tf.constant(config.local_norm_key, dtype=tf.int64),
                tf.constant(config.local_norm_rvalues, dtype=tf.int64)), 1)

    # Bind the static arguments (target size, normalization tables, data
    # layout, category, heatmap settings) so the reader only has to pass the
    # per-example tensors.
    preprocessing_fn = lambda org_image, classid, shape, key_x, key_y, key_v: preprocessing.preprocess_image(
        org_image,
        classid,
        shape,
        FLAGS.train_image_size,
        FLAGS.train_image_size,
        key_x,
        key_y,
        key_v, (lnorm_table, rnorm_table),
        is_training=is_training,
        data_format=('NCHW'
                     if FLAGS.data_format == 'channels_first' else 'NHWC'),
        category=(model_scope if 'all' not in model_scope else '*'),
        bbox_border=FLAGS.bbox_border,
        heatmap_sigma=FLAGS.heatmap_sigma,
        heatmap_size=FLAGS.heatmap_size)

    # Read, preprocess and batch examples with the slim dataset reader.
    images, shape, classid, targets, key_v, isvalid, norm_value = dataset.slim_get_split(
        FLAGS.data_dir,
        preprocessing_fn,
        FLAGS.batch_size,
        FLAGS.num_readers,
        FLAGS.num_preprocessing_threads,
        num_epochs=num_epochs,
        is_training=is_training,
        file_pattern=FLAGS.dataset_name,
        category=(model_scope if 'all' not in model_scope else '*'),
        reader=None)

    return images, {
        'targets': targets,
        'key_v': key_v,
        'shape': shape,
        'classid': classid,
        'isvalid': isvalid,
        'norm_value': norm_value
    }
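
A minimal usage sketch (not part of the original example): input_pipeline only builds graph ops, so in an Estimator setting it would typically be wrapped in an input_fn. The model_fn name and the num_epochs value below are placeholders, and note that the tf.contrib.lookup tables require tf.tables_initializer() to run, which tf.estimator does by default; a plain tf.Session would have to run it explicitly.

def train_input_fn():
    # Defer graph construction to the Estimator, which also initializes the
    # HashTables via its default Scaffold (tf.tables_initializer()).
    return input_pipeline(is_training=True, num_epochs=1)

# est = tf.estimator.Estimator(model_fn=my_model_fn,  # hypothetical model_fn
#                              model_dir=FLAGS.model_dir)
# est.train(input_fn=train_input_fn)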
Example 2
def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Create a variable to count training steps. Note that it is never
        # incremented below, because apply_gradients() is called without it.
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        # Create the optimizer with a fixed learning rate.
        opt = tf.train.AdamOptimizer(1e-4)

        isTrain_ph = tf.placeholder(tf.bool, shape=None, name="is_train")
        # images = tf.placeholder(tf.float32, [None, 224, 224, 3])
        # labels = tf.placeholder(tf.float32, shape=[1024])

        dataset = flowers.get_split(FLAGS.subset, FLAGS.data_dir)
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=True,
            common_queue_capacity=2 * FLAGS.batch_size,
            common_queue_min=FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])
        image = preprocessing.preprocess_image(image,
                                               224,
                                               224,
                                               is_training=True)
        images, labels = tf.train.batch([image, label],
                                        batch_size=FLAGS.batch_size,
                                        num_threads=4,
                                        capacity=5 * FLAGS.batch_size)

        with tf.variable_scope(tf.get_variable_scope()) as scope:
            loss = cpu_loss(images, labels, scope, isTrain_ph)
            # Compute the gradients for the current batch.
            grads = opt.compute_gradients(loss)

        # Apply the gradients to adjust the shared variables.
        # apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        train_op = opt.apply_gradients(grads)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())
        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement lets ops
        # without an implementation on the requested device fall back to a
        # supported one.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        for step in range(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss],
                                     feed_dict={isTrain_ph: False})
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 10 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                # Single tower on CPU, so seconds per batch is the step time.
                sec_per_batch = float(duration)

                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))

            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
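
A minimal launcher sketch (an assumption, not shown in the example): TF1 scripts like this are usually started through tf.app.run(), with the FLAGS referenced above (train_dir, data_dir, batch_size, max_steps, ...) defined via tf.app.flags and the module-level imports (numpy, time, datetime, os, slim, flowers, preprocessing) already in place.

def main(argv=None):  # argv is unused but required by tf.app.run()
    # Make sure the checkpoint directory exists before training starts.
    if not tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.MakeDirs(FLAGS.train_dir)
    train()

if __name__ == '__main__':
    tf.app.run()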