def main(unused_args):
  """Start this task's TensorFlow server and run its `ps` or `worker` role.

  The cluster layout is read from the comma-separated FLAGS.ps_hosts and
  FLAGS.worker_hosts. A `ps` task blocks in server.join(). A `worker` task
  builds the ImageNet dataset, lets task 0 create FLAGS.train_dir when it is
  missing, and then calls train().

  Args:
    unused_args: ignored.
  """
  assert FLAGS.job_name in ('ps', 'worker'), 'job_name must be ps or worker'

  # Host lists for both jobs; the same mapping describes the cluster to the
  # ClusterSpec and to the Server.
  ps_hosts = FLAGS.ps_hosts.split(',')
  worker_hosts = FLAGS.worker_hosts.split(',')
  tf.logging.info('PS hosts are: %s' % ps_hosts)
  tf.logging.info('Worker hosts are: %s' % worker_hosts)

  hosts_by_job = {'ps': ps_hosts, 'worker': worker_hosts}
  cluster_spec = tf.train.ClusterSpec(hosts_by_job)
  server = tf.train.Server(hosts_by_job,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_id,
                           protocol=FLAGS.protocol)

  if FLAGS.job_name == 'ps':
    # Parameter servers only wait for incoming connections from the workers.
    server.join()
    return

  # Everything below runs on `worker` tasks only.
  dataset = ImagenetData(subset=FLAGS.subset)
  assert dataset.data_files()

  # Only the chief (task 0) checks for or creates train_dir.
  if FLAGS.task_id == 0 and not tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.MakeDirs(FLAGS.train_dir)

  train(server.target, dataset, cluster_spec)
Esempio n. 2
0
def main(unused_argv=None):
  """Run evaluation on the ImageNet subset named by FLAGS.subset.

  FLAGS.eval_dir is wiped (if it exists) and recreated before evaluate() is
  called, so every run starts from an empty directory.

  Args:
    unused_argv: ignored.
  """
  imagenet = ImagenetData(subset=FLAGS.subset)
  assert imagenet.data_files()

  # Start from a clean evaluation directory.
  eval_dir = FLAGS.eval_dir
  if tf.gfile.Exists(eval_dir):
    tf.gfile.DeleteRecursively(eval_dir)
  tf.gfile.MakeDirs(eval_dir)

  evaluate(imagenet)
Esempio n. 3
0
def main(_):
    """Train on the ImageNet subset named by FLAGS.subset.

    FLAGS.train_dir is wiped (if it exists) and recreated before
    inception_train.train() is called, so every run starts from an empty
    directory.

    Args:
      _: ignored.
    """
    imagenet = ImagenetData(subset=FLAGS.subset)
    assert imagenet.data_files()

    # Start from a clean training directory.
    train_dir = FLAGS.train_dir
    if tf.gfile.Exists(train_dir):
        tf.gfile.DeleteRecursively(train_dir)
    tf.gfile.MakeDirs(train_dir)

    inception_train.train(imagenet)
def distorted_inputs():
    """Build the distorted training input from the ImageNet dataset.

    Besides building the input, this resets FLAGS.train_dir: the directory is
    deleted recursively when it exists and then recreated empty.

    Returns:
      Whatever cifar10_imagenet.distorted_inputs returns for the dataset,
      FLAGS.batch_size and FLAGS.num_preprocess_threads (presumably an
      (images, labels) pair of tensors -- confirm against that helper).

    Raises:
      ValueError: If FLAGS.data_dir is empty.
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')

    imagenet = ImagenetData(subset=FLAGS.subset)
    assert imagenet.data_files()

    # Side effect: start from an empty training directory.
    train_dir = FLAGS.train_dir
    if tf.gfile.Exists(train_dir):
        tf.gfile.DeleteRecursively(train_dir)
    tf.gfile.MakeDirs(train_dir)

    batch = cifar10_imagenet.distorted_inputs(
        imagenet,
        batch_size=FLAGS.batch_size,
        num_preprocess_threads=FLAGS.num_preprocess_threads)
    return batch
Esempio n. 5
0
def tower_loss(scope):
  """Build one baxNet tower and return its total loss.

  The tower reads distorted training images from ImagenetData, runs
  baxNet.inference and baxNet.loss on them, and sums every entry of the
  'losses' collection that falls under `scope`. Raw and moving-average scalar
  summaries are attached to each individual loss and to the total.

  Args:
    scope: unique prefix string identifying the tower, e.g. 'tower_0'

  Returns:
    The total-loss tensor for the tower; evaluating it also runs the
    moving-average update of the losses.
  """
  train_data = ImagenetData(subset='train')
  assert train_data.data_files()

  # Scale the preprocessing threads with the number of GPUs.
  thread_count = FLAGS.num_preprocess_threads * FLAGS.num_gpus
  images, labels = image_processing.distorted_inputs(
      train_data, num_preprocess_threads=thread_count)

  # Inference graph, then the loss ops. The return value of baxNet.loss is
  # not needed here: the individual losses are read back from the 'losses'
  # collection below.
  logits = baxNet.inference(images)
  baxNet.loss(logits, labels)

  # Losses that belong to the current tower only, and their sum.
  tower_losses = tf.get_collection('losses', scope)
  total_loss = tf.add_n(tower_losses, name='total_loss')
  tracked = tower_losses + [total_loss]

  # Moving average of every individual loss and of the total.
  averager = tf.train.ExponentialMovingAverage(0.9, name='avg')
  update_averages = averager.apply(tracked)

  # Strip the 'tower_[0-9]/' prefix so multi-GPU runs share summary names on
  # tensorboard. The raw value is tagged '(raw)'; the averaged value keeps
  # the plain loss name.
  tower_prefix = '%s_[0-9]*/' % baxNet.TOWER_NAME
  for tensor in tracked:
    summary_name = re.sub(tower_prefix, '', tensor.op.name)
    tf.scalar_summary(summary_name +' (raw)', tensor)
    tf.scalar_summary(summary_name, averager.average(tensor))

  # Tie the averages update to the returned tensor.
  with tf.control_dependencies([update_averages]):
    return tf.identity(total_loss)
Esempio n. 6
0
def inputs(subset):
    """Build an input batch for the dataset selected by FLAGS.dataset.

    Args:
      subset: 'train' or 'validation'. Selects FLAGS.train_batch_size or
        FLAGS.test_batch_size respectively and is forwarded to the dataset
        constructor.

    Returns:
      The (images, labels) pair produced by image_reader.create_data_batch.
      `images` is also added to the 'images' graph collection.

    Raises:
      ValueError: If FLAGS.dataset is None or not a supported dataset name,
        or if `subset` is neither 'train' nor 'validation'.
    """
    if FLAGS.dataset is None:
        raise ValueError('Please supply a dataset')

    # Reject unknown subsets before any dataset object is built; previously
    # this fell through and failed later on an unbound `images`.
    if subset not in ('train', 'validation'):
        raise ValueError('Unknown subset: %r' % (subset,))

    if FLAGS.dataset in ('imagenet', 'imagenet_scale'):
        dataset = ImagenetData(FLAGS.dataset, subset=subset)
    elif FLAGS.dataset == 'cifar10':
        dataset = Cifar10Data(FLAGS.dataset, subset=subset)
    elif FLAGS.dataset == 'cifar100':
        dataset = Cifar100Data(FLAGS.dataset, subset=subset)
    else:
        # Previously an unknown name left `dataset` unbound.
        raise ValueError('Unknown dataset: %r' % (FLAGS.dataset,))

    if subset == 'train':
        batch_size = FLAGS.train_batch_size
    else:
        batch_size = FLAGS.test_batch_size
    images, labels = image_reader.create_data_batch(dataset, batch_size)

    tf.add_to_collection('images', images)
    return images, labels
Esempio n. 7
0
 def __init__(self, split, batch_size):
   """Create the ImageNet dataset for `split` and its input batch.

   Args:
     split: 'train' (distorted inputs) or 'validation' (plain inputs).
     batch_size: forwarded to the input-building function.

   Raises:
     Exception: if `split` is neither 'train' nor 'validation'.
   """
   self.dataset = ImagenetData(split)

   # Pick the input builder for the split, then call it once.
   if split == 'train':
     build_batch = imagenet.distorted_inputs
   elif split == 'validation':
     build_batch = imagenet.inputs
   else:
     raise Exception('Unknown split {}'.format(split))

   self.batch = build_batch(self.dataset, batch_size=batch_size)
Esempio n. 8
0
def get_validation_num():
    """Return num_examples_per_epoch() of the selected validation set.

    The dataset is chosen by FLAGS.dataset ('imagenet', 'imagenet_scale',
    'cifar10' or 'cifar100'). Both ImageNet names use the 'imagenet' dataset.

    Returns:
      The value of num_examples_per_epoch() for the validation subset, or
      None when FLAGS.dataset is not one of the names above.
    """
    selected = FLAGS.dataset

    if selected in ('imagenet', 'imagenet_scale'):
        validation_set = ImagenetData('imagenet', subset='validation')
    elif selected == 'cifar10':
        validation_set = Cifar10Data('cifar10', subset='validation')
    elif selected == 'cifar100':
        validation_set = Cifar100Data('cifar100', subset='validation')
    else:
        # Unknown dataset name: no count available.
        return None

    return validation_set.num_examples_per_epoch()
Esempio n. 9
0
def main(argv=None):
    """Train on ImageNet, distributed when FLAGS.job_name is set.

    With a job name, the name must be 'ps' or 'worker' and
    train.train_dis_() is used. Without one, the dataset must have data
    files and train.train() is used.

    Args:
      argv: ignored.
    """
    dataset = ImagenetData(subset=FLAGS.subset)
    distributed = bool(FLAGS.job_name)

    # Mode-specific sanity check.
    if distributed:
        assert FLAGS.job_name in ['ps', 'worker'], 'job_name must be ps or worker'
    else:
        assert dataset.data_files()

    # Both modes report the configuration and the dataset before training.
    printInfo()
    print("dataset",dataset)

    if distributed:
        train.train_dis_(dataset)
    else:
        train.train(dataset)
Esempio n. 10
0
def main(_):
    """Evaluate a network on the ImageNet subset named by FLAGS.subset.

    FLAGS.my selects between the ModelMySlim variants and the
    ModelGoogleSlim models. Only the ModelGoogleSlim path calls
    model.check_norm() with the processor's `normalize` attribute.

    Args:
      _: ignored.
    """
    util.check_tensorflow_version()

    imagenet = ImagenetData(subset=FLAGS.subset)

    image_processor = ProcessorImagenet()
    image_processor.label_offset = FLAGS.label_offset

    feed = FeedImagesWithLabels(dataset=imagenet, processor=image_processor)

    # Parameters shared by both model families.
    params = {}
    params['num_classes'] = feed.num_classes_for_network()
    params['network'] = FLAGS.network

    if not FLAGS.my:
        # Google's tf.slim models
        model = ModelGoogleSlim(params=params)
        model.check_norm(image_processor.normalize)
    else:
        # My variants of Resnet, Inception, and VGG networks
        model = ModelMySlim(params=params)

    exec_eval.evaluate(feed=feed, model=model)
Esempio n. 11
0
def main(_):
    """Build a multi-GPU ImageNet training graph and run the training loop.

    One model tower is built per GPU (FLAGS.gpu_nums of them), each on one
    slice of the input batch. The per-tower gradients are combined with
    average_gradients() and applied with RMSProp, together with a moving
    average of the variables and the collected UPDATE_OPS. Training resumes
    from the latest checkpoint in FLAGS.SNAPSHOT_DIR when one exists, runs
    for FLAGS.num_steps steps, writes summaries every 100 steps and saves a
    checkpoint every 50000 steps (except step 0).

    Args:
      _: ignored.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        dataset = ImagenetData(subset=FLAGS.subset)
        assert dataset.data_files()
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        # Calculate the learning rate schedule.
        num_batches_per_epoch = (dataset.num_examples_per_epoch() /
                                 FLAGS.batch_size)
        decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay)

        # Decay the learning rate exponentially based on the number of steps.
        learning_rate = tf.train.exponential_decay(
            FLAGS.learning_rate,
            global_step,
            decay_steps,
            FLAGS.learning_rate_decay_factor,
            staircase=True)

        tf.summary.scalar('lr', learning_rate)

        # Fed as True on every step of the training loop below.
        is_training = tf.placeholder(tf.bool)

        #opt = tf.train.AdamOptimizer(learning_rate)
        opt = tf.train.RMSPropOptimizer(learning_rate,
                                        RMSPROP_DECAY,
                                        momentum=RMSPROP_MOMENTUM,
                                        epsilon=RMSPROP_EPSILON)

        with tf.name_scope("create_inputs"):
            #if tf.gfile.Exists(FLAGS.SNAPSHOT_DIR):
            #    tf.gfile.DeleteRecursively(FLAGS.SNAPSHOT_DIR)
            #tf.gfile.MakeDirs(FLAGS.SNAPSHOT_DIR)

            # Get images and labels for ImageNet and split the batch across GPUs.
            assert FLAGS.batch_size % FLAGS.gpu_nums == 0, (
                'Batch size must be divisible by number of GPUs')
            # NOTE(review): split_batch_size is not referenced again in this
            # function.
            split_batch_size = int(FLAGS.batch_size / FLAGS.gpu_nums)

            # Override the number of preprocessing threads to account for the increased
            # number of GPU towers.
            num_preprocess_threads = FLAGS.num_preprocess_threads * FLAGS.gpu_nums
            # No batch_size is passed here; presumably distorted_inputs falls
            # back to a flag-defined batch size -- confirm against the helper.
            images, labels = image_processing.distorted_inputs(
                dataset, num_preprocess_threads=num_preprocess_threads)
            #tf.summary.image('images', images, max_outputs = 10)

            # One slice of images and of one-hot labels per GPU tower.
            images_splits = tf.split(axis=0,
                                     num_or_size_splits=FLAGS.gpu_nums,
                                     value=images)
            labels_splits = tf.split(axis=0,
                                     num_or_size_splits=FLAGS.gpu_nums,
                                     value=tf.one_hot(indices=labels,
                                                      depth=FLAGS.num_classes))

        # Per-tower gradient lists, averaged after the loop.
        multi_grads = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in xrange(FLAGS.gpu_nums):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' % ('ImageNet', i)) as scope:

                        graph = Model_Graph(num_class=FLAGS.num_classes,
                                            is_training=is_training)

                        model = graph._build_defaut_graph(
                            images=images_splits[i])

                        # Top-1 accuracy
                        top1acc = tf.reduce_mean(
                            tf.cast(
                                tf.nn.in_top_k(
                                    model.logits,
                                    tf.argmax(labels_splits[i], axis=1), 1),
                                tf.float32))
                        # Top-n accuracy
                        topnacc = tf.reduce_mean(
                            tf.cast(
                                tf.nn.in_top_k(
                                    model.logits,
                                    tf.argmax(labels_splits[i], axis=1),
                                    FLAGS.top_k), tf.float32))

                        tf.summary.scalar('top1acc_{}'.format(i), top1acc)
                        tf.summary.scalar('topkacc_{}'.format(i), topnacc)

                        # Snapshot of the trainable variables currently in
                        # the graph; gradients are computed w.r.t. these.
                        all_trainable = [v for v in tf.trainable_variables()]

                        loss = tf.nn.softmax_cross_entropy_with_logits(
                            logits=model.logits, labels=labels_splits[i])

                        # Weight decay on every trainable variable whose name
                        # contains 'weights'.
                        l2_losses = [
                            FLAGS.weight_decay * tf.nn.l2_loss(v)
                            for v in tf.trainable_variables()
                            if 'weights' in v.name
                        ]
                        reduced_loss = tf.reduce_mean(loss) + tf.add_n(
                            l2_losses)

                        tf.summary.scalar('loss_{}'.format(i), reduced_loss)

                        # Towers built after this point reuse the existing
                        # variables instead of creating new ones.
                        tf.get_variable_scope().reuse_variables()

                        # Collected without a scope filter; the value left
                        # after the last tower is the one grouped into the
                        # train op below.
                        #batchnorm_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)
                        batchnorm_updates = tf.get_collection(
                            tf.GraphKeys.UPDATE_OPS)

                        grads = opt.compute_gradients(reduced_loss,
                                                      all_trainable)
                        multi_grads.append(grads)

        # Combine the per-tower gradients into a single list.
        grads = average_gradients(multi_grads)

        # Track the moving averages of all trainable variables.
        # Note that we maintain a "double-average" of the BatchNormalization
        # global statistics. This is more complicated then need be but we employ
        # this for backward-compatibility with our previous models.
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.MOVING_AVERAGE_DECAY, global_step)

        variables_to_average = (tf.trainable_variables() +
                                tf.moving_average_variables())
        variables_averages_op = variable_averages.apply(variables_to_average)

        # Group all updates to into a single train op.
        batchnorm_updates_op = tf.group(*batchnorm_updates)
        train_op = tf.group(opt.apply_gradients(grads, global_step),
                            variables_averages_op, batchnorm_updates_op)

        #grads_value = list(zip(grads, all_trainable))
        #for grad, var in grads_value:
        #    tf.summary.histogram(var.name + '/gradient', grad)

        summary_op = tf.summary.merge_all()

        # Set up tf session and initialize variables.
        config = tf.ConfigProto()
        config.allow_soft_placement = True
        sess = tf.Session(config=config)
        init = tf.global_variables_initializer()

        sess.run(init)

        # Saver for storing checkpoints of the model.
        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=2)

        # Variables restored from an existing checkpoint: all trainable
        # variables plus the moving mean/variance statistics and global_step.
        restore_var = [v for v in tf.trainable_variables()] + [
            v for v in tf.global_variables() if 'moving_mean' in v.name
            or 'moving_variance' in v.name or 'global_step' in v.name
        ]

        ckpt = tf.train.get_checkpoint_state(FLAGS.SNAPSHOT_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            loader = tf.train.Saver(var_list=restore_var)
            load(loader, sess, ckpt.model_checkpoint_path)
        else:
            print('No checkpoint file found.')
            # NOTE(review): load_step is not read anywhere in this function.
            load_step = 0

        summary_writer = tf.summary.FileWriter(FLAGS.SNAPSHOT_DIR,
                                               graph=sess.graph)

        # Iterate over training steps.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)

        # The loop always counts from 0, also after restoring a checkpoint.
        # reduced_loss is the loss tensor of the last tower built above.
        for step in range(FLAGS.num_steps):
            start_time = time.time()

            feed_dict = {is_training: True}
            if step % 50000 == 0 and step != 0:
                # Checkpoint step. This branch wins over the summary branch
                # below, so no summary is written on these steps.
                loss_value, _ = sess.run([reduced_loss, train_op],
                                         feed_dict=feed_dict)
                save(saver, sess, FLAGS.SNAPSHOT_DIR, step)
            elif step % 100 == 0:
                # Summary step: also report the loss and the step duration.
                summary_str, loss_value, _ = sess.run(
                    [summary_op, reduced_loss, train_op], feed_dict=feed_dict)
                duration = time.time() - start_time
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()
                print('step {:d} \t loss = {:.3f}, ({:.3f} sec/step)'.format(
                    step, loss_value, duration))
            else:
                # Plain training step.
                loss_value, _ = sess.run([reduced_loss, train_op],
                                         feed_dict=feed_dict)

        # Stop the input queue threads and wait for them to finish.
        coord.request_stop()
        coord.join(threads)
Esempio n. 12
0
def train():
    """Train a multi-GPU model on ImageNet with periodic validation.

    Builds one tower per GPU (helper.N_GPUS of them) fed from a prefetch
    queue, averages the tower gradients with average_gradients() and applies
    them with Nesterov momentum under a polynomially decayed learning rate.
    When RESTORE is true the latest checkpoint in SAVE_POINT is loaded and
    the step counter continues from it. Every 5000 steps (and on the last
    step) a checkpoint is written and evaluate() is run on the validation
    inputs; the best top-1 and top-5 results seen so far are printed.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Training images and labels come from the ImageNet 'train' subset.
        #dataset = CIFARData(subset='train')
        dataset = ImagenetData(subset='train')
        assert dataset.data_files()

        #test_set = CIFARData(subset='validation')
        test_set = ImagenetData(subset='validation')
        assert test_set.data_files()

        # Steps at 50% and 75% of helper.MAX_EPOCHS; used as the boundaries
        # of the piecewise-constant schedule and for the message in the loop.
        epoch1 = .5 * helper.MAX_EPOCHS
        epoch2 = .75 * helper.MAX_EPOCHS
        step1 = dataset.num_examples_per_epoch() * epoch1 // (
            helper.BATCH_SIZE)
        step2 = dataset.num_examples_per_epoch() * epoch2 // (
            helper.BATCH_SIZE)
        print('Reducing learning rate at step ' + str(step1) + ' and step ' +
              str(step2) + ' and ending at ' + str(helper.MAX_STEPS))

        # Create a variable to count the number of train() calls. This equals the
        # number of batches processed * FLAGS.num_gpus.
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        # Learning rate
        lr = .1

        #learning_rate = tf.placeholder(tf.float32, shape=[], name='learning_rate')
        dropout = tf.placeholder(tf.float32, shape=[], name='dropout')
        is_training = tf.placeholder(tf.bool, shape=[], name='is_training')

        boundaries = [step1, step2]
        values = [lr, lr / 10, lr / 100]

        # Piecewise-constant schedule. In this function it is only written to
        # the summaries; the optimizer below uses decayed_lr.
        learning_rate = tf.train.piecewise_constant(global_step,
                                                    boundaries,
                                                    values,
                                                    name=None)

        # Polynomial decay from lr to 0.0001 over helper.MAX_STEPS steps.
        decayed_lr = tf.train.polynomial_decay(lr,
                                               global_step,
                                               helper.MAX_STEPS,
                                               end_learning_rate=0.0001,
                                               power=4.0,
                                               cycle=False,
                                               name=None)

        # Create an optimizer that performs gradient descent.
        with tf.name_scope('Optimizer'):
            opt = tf.train.MomentumOptimizer(learning_rate=decayed_lr,
                                             momentum=0.9,
                                             use_nesterov=True)
            #opt = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9, use_nesterov=True)

        tf.summary.scalar('decayed_learning_rate', decayed_lr)
        tf.summary.scalar('learning_rate', learning_rate)

        # Override the number of preprocessing threads to account for the increased
        # number of GPU towers.
        num_preprocess_threads = helper.NUM_THREADS * helper.N_GPUS
        distorted_images, distorted_labels = image_processing.distorted_inputs(
            dataset,
            batch_size=helper.SPLIT_BATCH_SIZE,
            num_preprocess_threads=num_preprocess_threads)

        # Validation inputs, passed to evaluate() in the training loop.
        #images, labels = image_processing.inputs(dataset, batch_size=helper.BATCH_SIZE, num_preprocess_threads=num_preprocess_threads)
        test_images, test_labels = image_processing.inputs(
            test_set,
            batch_size=helper.SPLIT_BATCH_SIZE,
            num_preprocess_threads=num_preprocess_threads)

        # Copy of the summaries registered so far (i.e. before the towers are
        # built); appended to the tower summaries further down.
        input_summaries = copy.copy(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Split the batch of images and labels for towers.
        #images_splits = tf.split(axis=0, num_or_size_splits=helper.N_GPUS, value=distorted_images)
        #labels_splits = tf.split(axis=0, num_or_size_splits=helper.N_GPUS, value=distorted_labels)

        # Each tower dequeues its own batch from this prefetch queue.
        batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue(
            [distorted_images, distorted_labels], capacity=2 * helper.N_GPUS)

        # Calculate the gradients for each model tower.
        tower_grads = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(helper.N_GPUS):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' %
                                       (helper.TOWER_NAME, i)) as scope:
                        # Calculate the loss for one tower of the model. This function
                        # constructs the entire model but shares the variables across
                        # all towers.
                        image_batch, label_batch = batch_queue.dequeue()
                        loss = tower_loss(scope,
                                          image_batch,
                                          label_batch,
                                          dropout=dropout,
                                          is_training=is_training)
                        #loss = tower_loss(scope, images_splits[i], labels_splits[i], dropout=dropout, is_training=is_training)

                        # Retain the summaries from the final tower.
                        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                      scope)

                        # Towers built after this point reuse the variables.
                        tf.get_variable_scope().reuse_variables()

                        grads = opt.compute_gradients(loss)

                        tower_grads.append(grads)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = average_gradients(tower_grads)

        # Add a summaries for the input processing and global_step.
        summaries.extend(input_summaries)

        # Apply the gradients to adjust the shared variables.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            apply_gradient_op = opt.apply_gradients(grads,
                                                    global_step=global_step)

            # Track the moving averages of all trainable variables.
            variable_averages = tf.train.ExponentialMovingAverage(
                helper.MOVING_AVERAGE_DECAY, global_step)
            variables_averages_op = variable_averages.apply(
                tf.trainable_variables())

            # Group all updates to into a single train op.
            #train_op = apply_gradient_op
            train_op = tf.group(apply_gradient_op, variables_averages_op)

        # Add histograms for trainable variables.
        #for var in tf.trainable_variables():
        #    summaries.append(tf.summary.histogram(var.op.name, var))

        # Histogram of each variable that has a gradient; the gradient itself
        # is not summarized.
        for grad, var in grads:
            summaries.append(tf.summary.histogram(var.op.name, var))
            #summaries.append(tf.summary.histogram(var.op.name + '_gradient', grad))

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        # Mean over the 'cross_entropies' and 'accuracy' collections
        # (presumably filled by tower_loss -- confirm).
        cross_entropy_op = tf.reduce_mean(tf.get_collection('cross_entropies'),
                                          name='cross_entropy')

        accuracy_op = tf.reduce_mean(tf.get_collection('accuracy'),
                                     name='accuracies')
        summaries.append(tf.summary.scalar('cross_entropy', cross_entropy_op))
        summaries.append(tf.summary.scalar('accuracy', accuracy_op))

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge(summaries)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False))

        #run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        #run_metadata = tf.RunMetadata()

        sess.run(init)
        tf.train.start_queue_runners(sess=sess)

        # RESTORE and SAVE_POINT are not defined inside this function.
        if RESTORE == True:
            # NOTE(review): ckpt is used without a None check; confirm that a
            # checkpoint always exists in SAVE_POINT when RESTORE is set.
            ckpt = tf.train.get_checkpoint_state(SAVE_POINT)
            saver.restore(sess, ckpt.model_checkpoint_path)

            # Assuming model_checkpoint_path looks something like:
            #   /my-favorite-path/imagenet_train/model.ckpt-0,
            # extract global_step from it.
            restored_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]
            print('Successfully loaded model from %s at step=%s.' %
                  (ckpt.model_checkpoint_path, restored_step))
            step = int(restored_step)
            # Continue counting from the restored step.
            range_step = range(step, helper.MAX_STEPS)
            tf.get_variable_scope().reuse_variables()
            global_step = tf.get_variable('global_step', trainable=False)
        else:
            range_step = range(helper.MAX_STEPS)

        summary_writer = tf.summary.FileWriter('summary', graph=sess.graph)
        num_params = helper.count_params() / 1e6
        print('Total number of params = %.2fM' % num_params)
        print("training")
        # Each list holds [top1, top5] of the best evaluation so far: best by
        # top1 in top1_error, best by top5 in top5_error. They are replaced
        # when a larger value arrives; -1.0 marks "nothing recorded yet".
        top1_error = [-1.0, -1.0]
        top1_step = 0
        top5_error = [-1.0, -1.0]
        top5_step = 0

        for step in range_step:

            # One training step; `loss` is the loss tensor of the last tower.
            start_time = time.time()
            _, loss_value, cross_entropy_value, accuracy_value = sess.run(
                [train_op, loss, cross_entropy_op, accuracy_op],
                feed_dict={
                    dropout: 0.8,
                    is_training: True
                }
            )  #, options=run_options, run_metadata=run_metadata)#, learning_rate: lr})
            duration = time.time() - start_time

            if step == step1 or step == step2:
                print('Decreasing Learning Rate')
                # Only the local Python value changes: lr is not read again
                # after graph construction and is not fed to the session.
                lr /= 10

            if step % 10 == 0:
                num_examples_per_step = helper.BATCH_SIZE
                # NOTE(review): examples_per_sec is computed but not used.
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = duration

                format_str = (
                    'step %d, loss = %.2f, cross entropy = %.2f, accuracy = %.2f, %.3f sec/batch'
                )
                print(format_str % (step, loss_value, cross_entropy_value,
                                    accuracy_value, sec_per_batch))
                # The string literal below is disabled timeline-tracing code
                # kept as a bare expression; it has no effect.
                """
                # Create the Timeline object, and write it to a json
                tl = timeline.Timeline(run_metadata.step_stats)
                ctf = tl.generate_chrome_trace_format()
                with open('timeline.json', 'w') as f:
                    f.write(ctf)
                """

            if step % 100 == 0:
                # Summaries are evaluated in a separate run with
                # is_training=False (dropout is still fed as 0.8).
                summary_str = sess.run(summary_op,
                                       feed_dict={
                                           dropout: 0.8,
                                           is_training: False
                                       })  #, learning_rate: lr})
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % 5000 == 0 or (step + 1) == helper.MAX_STEPS:
                if step != 0:
                    # Plain string concatenation; presumably SAVE_POINT ends
                    # with a path separator -- confirm.
                    checkpoint_path = SAVE_POINT + 'model.ckpt'
                    saver.save(sess, checkpoint_path, global_step=step)
                    print('Model saved')

                    # Validate right after saving and update the best results.
                    #evaluate(distorted_images, distorted_labels, sess, dropout=dropout, is_training=is_training, train=True)
                    top1, top5 = evaluate(test_images,
                                          test_labels,
                                          sess,
                                          dropout=dropout,
                                          is_training=is_training,
                                          train=False)
                    if top1 > top1_error[0]:
                        top1_error[0] = top1
                        top1_error[1] = top5
                        top1_step = step
                    if top5 > top5_error[1]:
                        top5_error[0] = top1
                        top5_error[1] = top5
                        top5_step = step
                    print(
                        "Best top1 model achieved top1: %.4f, top5: %.4f at step %d"
                        % (top1_error[0], top1_error[1], top1_step))
                    print(
                        "Best top5 model achieved top1: %.4f, top5: %.4f at step %d"
                        % (top5_error[0], top5_error[1], top5_step))
Esempio n. 13
0
        # Calculate predictions.
        top_1_op = tf.nn.in_top_k(model.logits, labels, 1)
        top_5_op = tf.nn.in_top_k(model.logits, labels, 5)

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)
        #saver = tf.train.Saver(tf.global_variables())

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        graph_def = tf.get_default_graph().as_graph_def()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
                                               graph_def=graph_def)

        while True:
            _eval_once(saver, summary_writer, top_1_op, top_5_op, summary_op,
                       model, labels)
            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)


# Script entry point: evaluate on the ImageNet 'validation' subset.
if __name__ == '__main__':
    dataset = ImagenetData(subset='validation')
    # Stop early when the dataset reports no data files.
    assert dataset.data_files()
    evaluate(dataset)
Esempio n. 14
0
    images = tf.cast(images, tf.float32)
    images = tf.reshape(images, shape=[batch_size, height, width, depth])

    # Display the training images in the visualizer.
    tf.summary.image('images', images)

    return images, tf.reshape(label_index_batch, [batch_size])



if __name__ == '__main__':
    # Visual smoke test of the input pipeline: fetch five batches from the
    # "val" subset and show the first image of each with its label.

#     dataset = FlowersData(subset="validation")
#     images, labels = distorted_inputs(dataset)

    dataset = ImagenetData(subset="val")
    images, labels = distorted_inputs(dataset)

    print(images,labels)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            for i in range(5):
                image_np, label_np = sess.run([images, labels])
                plt.imshow(image_np[0,:,:,:])
                plt.title('label name:'+str(label_np[0]))
                plt.show()
        finally:
            # Stop the queue-runner threads and wait for them, also when a
            # fetch or the plotting fails; previously they were left running
            # until the process exited.
            coord.request_stop()
            coord.join(threads)
            
def main(argv=None):
    """Run one task of a distributed Inception training job.

    The cluster layout, job name, task id and protocol are read from FLAGS.
    A 'ps' task joins the TensorFlow server; 'ps' task 0 additionally creates
    and serves the SspManager RPC server first.  Every other task builds the
    training graph on its worker device, runs FLAGS.max_steps training steps
    and calls rpcClient.check_staleness(task_id, step) after each step.

    Args:
        argv: Unused.
    """
    ps_hosts = FLAGS.ps_hosts.split(',')
    worker_hosts = FLAGS.worker_hosts.split(',')
    tf.logging.info('PS hosts are: %s' % ps_hosts)
    tf.logging.info('Worker hosts are: %s' % worker_hosts)
    cluster_spec = tf.train.ClusterSpec({
        'ps': ps_hosts,
        'worker': worker_hosts
    })
    server = tf.train.Server({
        'ps': ps_hosts,
        'worker': worker_hosts
    },
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_id,
                             protocol=FLAGS.protocol)

    # NOTE(review): the meaning of the second argument (5) is not visible
    # here -- presumably a staleness bound; confirm against SspManager.
    sspManager = SspManager(len(worker_hosts), 5)
    if FLAGS.job_name == 'ps':
        if FLAGS.task_id == 0:
            # The RPC server is bound to the host part of the first ps address.
            rpcServer = sspManager.create_rpc_server(ps_hosts[0].split(':')[0])
            rpcServer.serve()
        # NOTE(review): there is no return after join(); this relies on
        # join() never returning so ps tasks do not run the worker code below.
        server.join()

    # Fixed delay before connecting the RPC client -- presumably to give the
    # RPC server on ps task 0 time to come up; TODO confirm.
    time.sleep(5)
    rpcClient = sspManager.create_rpc_client(ps_hosts[0].split(':')[0])

    dataset = ImagenetData(subset=FLAGS.subset)
    assert dataset.data_files()
    # Worker 0 acts as the chief; only the chief creates train_dir.
    is_chief = (FLAGS.task_id == 0)
    if is_chief:
        if not tf.gfile.Exists(FLAGS.train_dir):
            tf.gfile.MakeDirs(FLAGS.train_dir)

    num_workers = len(cluster_spec.as_dict()['worker'])
    num_parameter_servers = len(cluster_spec.as_dict()['ps'])

    with tf.device('/job:worker/task:%d' % FLAGS.task_id):
        # slim variables and the global step get their device from
        # VariableDeviceChooser(num_parameter_servers).
        with slim.scopes.arg_scope(
            [slim.variables.variable, slim.variables.global_step],
                device=slim.variables.VariableDeviceChooser(
                    num_parameter_servers)):
            '''Prepare Input'''
            global_step = slim.variables.global_step()
            # The batch size is a scalar placeholder that is fed on every
            # training step.
            batch_size = tf.placeholder(dtype=tf.int32,
                                        shape=(),
                                        name='batch_size')
            images, labels = image_processing.distorted_inputs(
                dataset,
                batch_size,
                num_preprocess_threads=FLAGS.num_preprocess_threads)
            # One more class than the dataset reports -- presumably label 0 is
            # a reserved background class; verify against the dataset.
            num_classes = dataset.num_classes() + 1
            '''Inference'''
            logits = inception.inference(images,
                                         num_classes,
                                         for_training=True)
            '''Loss'''
            inception.loss(logits, labels, batch_size)
            # Total loss = everything in the slim losses collection plus the
            # regularization losses.
            losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
            losses += tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            total_loss = tf.add_n(losses, name='total_loss')
            if is_chief:
                # On the chief, evaluating total_loss also updates moving
                # averages of the individual losses and of the total.
                loss_averages = tf.train.ExponentialMovingAverage(0.9,
                                                                  name='avg')
                loss_averages_op = loss_averages.apply(losses + [total_loss])
                with tf.control_dependencies([loss_averages_op]):
                    total_loss = tf.identity(total_loss)
            '''Optimizer'''
            # NOTE(review): exp_moving_averager and variables_to_average are
            # never used later in this function.
            exp_moving_averager = tf.train.ExponentialMovingAverage(
                inception.MOVING_AVERAGE_DECAY, global_step)
            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())
            num_batches_per_epoch = (dataset.num_examples_per_epoch() /
                                     FLAGS.batch_size)
            # The decay interval is divided by the number of workers.
            decay_steps = int(num_batches_per_epoch *
                              FLAGS.num_epochs_per_decay / num_workers)
            lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                            global_step,
                                            decay_steps,
                                            FLAGS.learning_rate_decay_factor,
                                            staircase=True)
            opt = tf.train.RMSPropOptimizer(lr,
                                            RMSPROP_DECAY,
                                            momentum=RMSPROP_MOMENTUM,
                                            epsilon=RMSPROP_EPSILON)
            '''Train Operation'''
            batchnorm_updates = tf.get_collection(
                slim.ops.UPDATE_OPS_COLLECTION)
            assert batchnorm_updates, 'Batchnorm updates are missing'
            batchnorm_updates_op = tf.group(*batchnorm_updates)
            # Evaluating total_loss now also runs the batch-norm update ops.
            with tf.control_dependencies([batchnorm_updates_op]):
                total_loss = tf.identity(total_loss)
            naive_grads = opt.compute_gradients(total_loss)
            # Scale every gradient by (fed batch size / FLAGS.batch_size).
            grads = [(tf.scalar_mul(
                tf.cast(batch_size / FLAGS.batch_size, tf.float32), grad), var)
                     for grad, var in naive_grads]
            apply_gradients_op = opt.apply_gradients(grads,
                                                     global_step=global_step)
            # train_op returns the loss value and forces the gradient update.
            with tf.control_dependencies([apply_gradients_op]):
                train_op = tf.identity(total_loss, name='train_op')
            '''Supervisor and Session'''
            saver = tf.train.Saver()
            init_op = tf.global_variables_initializer()
            # summary_op=None: no summary op is handed to the supervisor.
            sv = tf.train.Supervisor(is_chief=is_chief,
                                     logdir=FLAGS.train_dir,
                                     init_op=init_op,
                                     summary_op=None,
                                     global_step=global_step,
                                     recovery_wait_secs=1,
                                     saver=saver,
                                     save_model_secs=FLAGS.save_interval_secs)
            tf.logging.info('%s Supervisor' % datetime.now())
            sess_config = tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=FLAGS.log_device_placement)
            sess = sv.prepare_or_wait_for_session(server.target,
                                                  config=sess_config)
            queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
            '''Start Training'''
            sv.start_queue_runners(sess, queue_runners)
            tf.logging.info('Started %d queues for processing input data.',
                            len(queue_runners))

            # The fed batch size stays at FLAGS.batch_size for the whole run.
            batch_size_num = FLAGS.batch_size
            for step in range(FLAGS.max_steps):
                start_time = time.time()
                # NOTE(review): a full trace is requested on every step, but
                # run_metadata is never read afterwards.
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                loss_value, gs = sess.run(
                    [train_op, global_step],
                    feed_dict={batch_size: batch_size_num},
                    options=run_options,
                    run_metadata=run_metadata)

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                duration = time.time() - start_time
                examples_per_sec = batch_size_num / float(duration)
                sec_per_batch = float(duration)
                # Log the wall-clock timestamp, the local step, the global
                # step (gs), the loss and the throughput.
                format_str = (
                    "time: " + str(time.time()) +
                    '; %s: step %d (gs %d), loss= %.2f (%.1f samples/s; %.3f s/batch)'
                )
                tf.logging.info(format_str %
                                (datetime.now(), step, gs, loss_value,
                                 examples_per_sec, sec_per_batch))
                # Report this worker's local step through the RPC client after
                # every step -- presumably this blocks while the worker is too
                # far ahead; confirm against SspManager.
                rpcClient.check_staleness(FLAGS.task_id, step)
def train():
    """Train Inception on a dataset for a number of steps.

    Starts the TensorFlow server for this task using the cluster layout in
    FLAGS.  A 'ps' task joins the server; 'ps' task 0 first creates and serves
    the BatchSizeManager RPC server.  Any other task builds the inception_v3
    training graph on its worker device and runs training steps until the
    supervisor asks to stop or the global step exceeds FLAGS.max_steps.

    Raises:
        AssertionError: If the dataset has no data files, or the cluster has
            no workers or no parameter servers.
    """
    ps_hosts = FLAGS.ps_hosts.split(',')
    worker_hosts = FLAGS.worker_hosts.split(',')
    tf.logging.info('PS hosts are: %s' % ps_hosts)
    tf.logging.info('Worker hosts are: %s' % worker_hosts)

    cluster_spec = tf.train.ClusterSpec({
        'ps': ps_hosts,
        'worker': worker_hosts
    })
    server = tf.train.Server({
        'ps': ps_hosts,
        'worker': worker_hosts
    },
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_id,
                             protocol=FLAGS.protocol)
    batchSizeManager = BatchSizeManager(FLAGS.batch_size, len(worker_hosts))
    # The RPC endpoint lives on the host part of the first ps address.
    rpc_host = ps_hosts[0].split(':')[0]
    if FLAGS.job_name == 'ps':
        if FLAGS.task_id == 0:
            rpcServer = batchSizeManager.create_rpc_server(rpc_host)
            rpcServer.serve()
        # NOTE(review): there is no return after join(); this relies on
        # join() never returning so ps tasks do not run the worker code below.
        server.join()

    dataset = ImagenetData(subset=FLAGS.subset)
    # The client is not called in the training loop below (dynamic batch-size
    # updates are disabled); it is still created as before in case creating
    # it has side effects -- TODO confirm and drop if it does not.
    rpcClient = batchSizeManager.create_rpc_client(rpc_host)
    assert dataset.data_files()
    # Choose worker 0 as the chief. Any worker could be the chief, but there
    # must be exactly one.  Only the chief checks for or creates train_dir.
    is_chief = (FLAGS.task_id == 0)
    if is_chief:
        if not tf.gfile.Exists(FLAGS.train_dir):
            tf.gfile.MakeDirs(FLAGS.train_dir)

    num_workers = len(cluster_spec.as_dict()['worker'])
    num_parameter_servers = len(cluster_spec.as_dict()['ps'])
    if FLAGS.num_replicas_to_aggregate == -1:
        num_replicas_to_aggregate = num_workers
    else:
        num_replicas_to_aggregate = FLAGS.num_replicas_to_aggregate

    # Both should be greater than 0 in a distributed training.
    assert num_workers > 0 and num_parameter_servers > 0, (
        ' num_workers and '
        'num_parameter_servers'
        ' must be > 0.')

    tf.logging.info('cccc-num_parameter_servers:' + str(num_parameter_servers))
    partitioner = tf.fixed_size_partitioner(num_parameter_servers, 0)
    device_setter = tf.train.replica_device_setter(
        ps_tasks=num_parameter_servers)
    # Unused below; kept in case touching tf.contrib.slim here matters --
    # TODO confirm and drop.
    slim = tf.contrib.slim

    # Ops are assigned to this worker by default; inside the inner device
    # scope, placement is decided by the replica device setter.
    with tf.device('/job:worker/task:%d' % FLAGS.task_id):
        with tf.variable_scope('root', partitioner=partitioner):
            with tf.device(device_setter):
                # Counts the number of updates applied to the variables.
                global_step = tf.Variable(0, trainable=False)

                # The batch size is a scalar placeholder fed on every step.
                batch_size = tf.placeholder(dtype=tf.int32,
                                            shape=(),
                                            name='batch_size')

                # Learning rate schedule: staircase exponential decay. The
                # decay interval is divided by the number of replicas to
                # aggregate.
                num_batches_per_epoch = (dataset.num_examples_per_epoch() /
                                         FLAGS.batch_size)
                decay_steps = int(num_batches_per_epoch *
                                  FLAGS.num_epochs_per_decay /
                                  num_replicas_to_aggregate)
                lr = tf.train.exponential_decay(
                    FLAGS.initial_learning_rate,
                    global_step,
                    decay_steps,
                    FLAGS.learning_rate_decay_factor,
                    staircase=True)

                images, labels = image_processing.distorted_inputs(
                    dataset,
                    batch_size,
                    num_preprocess_threads=FLAGS.num_preprocess_threads)
                print(images.get_shape())
                print(labels.get_shape())

                num_classes = dataset.num_classes()
                print(num_classes)
                network_fn = nets_factory.get_network_fn(
                    'inception_v3', num_classes=num_classes)
                (logits, _) = network_fn(images)
                print(logits.get_shape())

                # The one-hot depth must match the width of the logits, so it
                # is taken from num_classes rather than hard-coded.
                labels = tf.one_hot(labels, num_classes, 1, 0)
                cross_entropy = tf.losses.softmax_cross_entropy(
                    logits=logits, onehot_labels=labels)
                losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
                # Total loss = cross entropy + explicit L2 weight decay over
                # all trainable variables.
                total_loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
                    [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

                loss_averages = tf.train.ExponentialMovingAverage(0.9,
                                                                  name='avg')
                loss_averages_op = loss_averages.apply(losses + [total_loss])

                with tf.control_dependencies([loss_averages_op]):
                    opt = tf.train.RMSPropOptimizer(lr,
                                                    RMSPROP_DECAY,
                                                    momentum=RMSPROP_MOMENTUM,
                                                    epsilon=RMSPROP_EPSILON)
                    raw_grads = opt.compute_gradients(total_loss)
                    # Scale every gradient by (fed batch size /
                    # FLAGS.batch_size).  The cast happens before the division
                    # so the ratio is a true division on every Python version,
                    # and the factor is built once instead of per gradient.
                    grad_scale = (tf.cast(batch_size, tf.float32) /
                                  FLAGS.batch_size)
                    grads = [(tf.scalar_mul(grad_scale, grad), var)
                             for grad, var in raw_grads]
                    total_loss = tf.identity(total_loss)

                exp_moving_averager = tf.train.ExponentialMovingAverage(
                    MOVING_AVERAGE_DECAY, global_step)
                variables_averages_op = exp_moving_averager.apply(
                    tf.trainable_variables())

                apply_gradients_op = opt.apply_gradients(
                    grads, global_step=global_step)

                # train_op returns the loss and forces both the gradient
                # update and the variable moving-average update.
                with tf.control_dependencies(
                    [apply_gradients_op, variables_averages_op]):
                    train_op = tf.identity(total_loss, name='train_op')

                # NOTE(review): this saver is created but the Supervisor below
                # is given saver=None, so the Supervisor does no automatic
                # checkpointing even though save_model_secs is set. Confirm
                # whether that is intended.
                saver = tf.train.Saver()

                # Build an initialization operation to run below.
                init_op = tf.global_variables_initializer()

                # summary_op=None: no summary op is handed to the supervisor.
                sv = tf.train.Supervisor(
                    is_chief=is_chief,
                    logdir=FLAGS.train_dir,
                    init_op=init_op,
                    summary_op=None,
                    global_step=global_step,
                    recovery_wait_secs=1,
                    saver=None,
                    save_model_secs=FLAGS.save_interval_secs)

                tf.logging.info('%s Supervisor' % datetime.now())

                sess_config = tf.ConfigProto(
                    allow_soft_placement=True,
                    log_device_placement=FLAGS.log_device_placement)

                # Get a session.
                sess = sv.prepare_or_wait_for_session(server.target,
                                                      config=sess_config)

                # Start the queue runners.
                queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
                sv.start_queue_runners(sess, queue_runners)
                tf.logging.info('Started %d queues for processing input data.',
                                len(queue_runners))

    # The training loop creates no graph ops, so it runs outside the device
    # and variable scopes.
    step = 0
    time0 = time.time()
    # Fixed batch size fed on every step.
    batch_size_num = 32
    # NOTE(review): a full trace is requested on every step, but run_metadata
    # is not consumed here.
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    try:
        while not sv.should_stop():
            start_time = time.time()
            run_metadata = tf.RunMetadata()

            # The image batch is still fetched (and discarded), as before.
            _, loss_value, step = sess.run(
                [images, train_op, global_step],
                feed_dict={batch_size: batch_size_num},
                options=run_options,
                run_metadata=run_metadata)
            run_end = time.time()
            if step > FLAGS.max_steps:
                break
            duration = time.time() - start_time
            bs_start = time.time()

            # Throughput is based on the batch size actually fed to the
            # graph, not on the FLAGS default.
            examples_per_sec = batch_size_num / float(duration)
            log_time = time.time()
            tf.logging.info("time statistics" +
                            " - train_time: " +
                            str(run_end - start_time) +
                            " - get_batch_time: " +
                            str(bs_start - run_end) + " - get_bs_time:  " +
                            str(log_time - bs_start) + " - accum_time: " +
                            str(log_time - time0) +
                            " - batch_size: " +
                            str(batch_size_num))
            format_str = (
                'Worker %d: %s: step %d, loss = %.2f'
                '(%.1f examples/sec; %.3f  sec/batch)')
            tf.logging.info(
                format_str %
                (FLAGS.task_id, datetime.now(), step,
                 loss_value, examples_per_sec, duration))
    except Exception:
        if is_chief:
            tf.logging.info('Chief got exception while running!')
        raise

    # Stop the supervisor.  This also waits for service threads to finish.
    sv.stop()
Esempio n. 17
0
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('data_dir', 'imagenet-data', """XXXX""")
tf.app.flags.DEFINE_string('train_dir', 'log',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 700001,
                            """Number of iterations to run.""")
tf.app.flags.DEFINE_string('model_file', 'model/DCNet_', """Directory to save model""")


# Boolean placeholder that switches the graph between the training and the
# validation input pipeline.
is_training = tf.placeholder("bool")

train_set = ImagenetData(subset='train')
tr_images, tr_labels = alex2012_image_processing.distorted_inputs(train_set)

val_set  = ImagenetData(subset='validation')
val_images, val_labels = alex2012_image_processing.inputs(val_set)

# Select the training or the validation batch depending on is_training.
images, labels = tf.cond(is_training, lambda: [tr_images, tr_labels], lambda: [val_images, val_labels])

cnn = VGG()
cnn.build(images, train_set.num_classes(), is_training)

# Total loss = data-fit loss + everything in the 'orth_constraint' collection
# + the registered regularization losses.
fit_loss = loss2(cnn.score, labels, train_set.num_classes(), 'c_entropy') 
reg_loss = tf.add_n(tf.losses.get_regularization_losses())
orth_loss = tf.add_n(tf.get_collection('orth_constraint'))
loss_op = fit_loss + orth_loss + reg_loss
Esempio n. 18
0
def train():
    """Train the model on helper.N_GPUS GPU towers with shared variables.

    One input batch of helper.BATCH_SIZE examples is produced per step and
    split evenly across the towers.  Each tower computes its own gradients,
    the gradients are averaged, and a single Adam update is applied.  Runs
    helper.MAX_STEPS steps and prints the loss and throughput of each step.

    Raises:
        AssertionError: If the dataset has no data files or helper.BATCH_SIZE
            is not divisible by helper.N_GPUS.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Create a variable to count the number of train() calls.
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        # Learning rate
        lr = .001

        # Create the optimizer shared by all towers.
        opt = tf.train.AdamOptimizer(lr)

        dataset = ImagenetData(subset='train')
        assert dataset.data_files()

        assert helper.BATCH_SIZE % helper.N_GPUS == 0, (
            'Batch size must be divisible by number of GPUs')
        # Number of examples each tower receives per step.
        split_batch_size = int(helper.BATCH_SIZE / helper.N_GPUS)

        # Scale the number of preprocessing threads with the number of GPU
        # towers.
        num_preprocess_threads = helper.NUM_THREADS * helper.N_GPUS
        images, labels = image_processing.distorted_inputs(
            dataset,
            batch_size=helper.BATCH_SIZE,
            num_preprocess_threads=num_preprocess_threads)

        # Split the batch of images and labels for towers.
        images_splits = tf.split(axis=0,
                                 num_or_size_splits=helper.N_GPUS,
                                 value=images)
        labels_splits = tf.split(axis=0,
                                 num_or_size_splits=helper.N_GPUS,
                                 value=labels)

        # Calculate the gradients for each model tower.
        tower_grads = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(helper.N_GPUS):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' %
                                       (helper.TOWER_NAME, i)) as scope:
                        # Build the loss for one tower on its slice of the
                        # batch.
                        loss = tower_loss(scope, images_splits[i],
                                          labels_splits[i])

                        # Every tower after the first reuses the variables
                        # created by the first one.
                        tf.get_variable_scope().reuse_variables()

                        grads = opt.compute_gradients(loss)

                        tower_grads.append(grads)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = average_gradients(tower_grads)

        # Apply the gradients to adjust the shared variables.
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        # Track the moving averages of all trainable variables.
        variable_averages = tf.train.ExponentialMovingAverage(
            helper.MOVING_AVERAGE_DECAY, global_step)
        variables_averages_op = variable_averages.apply(
            tf.trainable_variables())

        # Group all updates into a single train op.
        train_op = tf.group(apply_gradient_op, variables_averages_op)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False))
        sess.run(init)
        tf.train.start_queue_runners(sess=sess)
        print("training")

        # Each iteration is one training step; `loss` is the loss of the last
        # tower built above.
        for epoch in range(helper.MAX_STEPS):

            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time

            # One step consumes a single BATCH_SIZE batch (split across the
            # towers), so that is both the example count and the batch timed.
            num_examples_per_step = helper.BATCH_SIZE
            examples_per_sec = num_examples_per_step / duration
            sec_per_batch = duration

            format_str = (
                '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)')
            # Report the training-loop counter, not the leftover tower index.
            print(format_str % (datetime.now(), epoch, loss_value,
                                examples_per_sec, sec_per_batch))