Example 1
def inference(images,
              num_classes,
              for_training=False,
              restore_logits=True,
              scope=None):
    # Parameters for BatchNorm.
    batch_norm_params = {
        # Decay for the moving averages.
        'decay': BATCHNORM_MOVING_AVERAGE_DECAY,
        # epsilon to prevent 0s in variance.
        'epsilon': 0.001,
    }
    # Set weight_decay for weights in Conv and FC layers.
    with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
        with slim.arg_scope([slim.ops.conv2d],
                            stddev=0.1,
                            activation=tf.nn.relu,
                            batch_norm_params=batch_norm_params):
            logits, endpoints, logits_2048 = slim.inception.inception_v3(
                images,
                dropout_keep_prob=0.8,
                num_classes=num_classes,
                is_training=for_training,
                restore_logits=restore_logits,
                scope=scope)

    # Grab the logits associated with the side head. Employed during training.
    auxiliary_logits = endpoints['aux_logits']

    return logits, auxiliary_logits, logits_2048
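
A minimal sketch of how the three returned tensors are typically consumed, following the loss pattern used in the test examples further down; `dense_labels` is a hypothetical smoothed/one-hot label tensor, and `images`/`num_classes` are assumed to be in scope:

logits, auxiliary_logits, logits_2048 = inference(images, num_classes,
                                                  for_training=True)
# Main softmax head.
slim.losses.cross_entropy_loss(logits, dense_labels,
                               label_smoothing=0.1, weight=1.0)
# Auxiliary (side) head, weighted down as in the tests below.
slim.losses.cross_entropy_loss(auxiliary_logits, dense_labels,
                               label_smoothing=0.1, weight=0.4,
                               scope='aux_loss')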
Example 2
def inference(images,
              num_classes,
              num_of_exs,
              for_training=False,
              restore_logits=True,
              scope=None):
    # Parameters for BatchNorm.
    batch_norm_params = {
        # Decay for the moving averages.
        'decay': BATCHNORM_MOVING_AVERAGE_DECAY,
        # epsilon to prevent 0s in variance.
        'epsilon': 0.001,
    }
    # Set weight_decay for weights in Conv and FC layers.
    with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
        with slim.arg_scope([slim.ops.conv2d],
                            activation=tf.nn.relu,
                            batch_norm_params=batch_norm_params):
            logits, dssm, endpoints = slim.nin_dssm.nin_dssm(
                images,
                num_classes=num_classes,
                num_of_exs=num_of_exs,
                is_training=for_training,
                restore_logits=restore_logits,
                scope=scope)

    # Add summaries for viewing model statistics on TensorBoard.
    _activation_summaries(endpoints)

    # Grab the logits associated with the side head. Employed during training.
    # auxiliary_logits = endpoints['aux_logits']

    # TODO: add endpoints for feature extraction.
    return [logits, dssm], endpoints
Example 3
def inference(images,
              num_classes,
              for_training=False,
              restore_logits=True,
              scope=None):
    """Build Inception v3 model architecture.

  See here for reference: http://arxiv.org/abs/1512.00567

  Args:
    images: Images returned from inputs() or distorted_inputs().
    num_classes: number of classes
    for_training: If set to `True`, build the inference model for training.
      Kernels that operate differently for inference during training
      e.g. dropout, are appropriately configured.
    restore_logits: whether or not the logits layers should be restored.
      Useful for fine-tuning a model with different num_classes.
    scope: optional prefix string identifying the ImageNet tower.

  Returns:
    Logits. 2-D float Tensor.
    Auxiliary Logits. 2-D float Tensor of side-head. Used for training only.
  """
    # Parameters for BatchNorm.
    batch_norm_params = {
        # Decay for the moving averages.
        'decay': BATCHNORM_MOVING_AVERAGE_DECAY,
        # epsilon to prevent 0s in variance.
        'epsilon': 0.001,
    }
    # Preprocessing: map inputs from [0, 1] to [-1, 1].
    images = tf.subtract(images, 0.5)
    images = tf.multiply(images, 2.0)

    # Set weight_decay for weights in Conv and FC layers.
    with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
        with slim.arg_scope([slim.ops.conv2d],
                            stddev=0.1,
                            activation=tf.nn.relu,
                            batch_norm_params=batch_norm_params):
            logits, endpoints = slim.inception.inception_v3(
                images,
                dropout_keep_prob=0.8,
                num_classes=num_classes,
                is_training=for_training,
                restore_logits=restore_logits,
                scope=scope)

    # Add summaries for viewing model statistics on TensorBoard.
    _activation_summaries(endpoints)

    # Grab the logits associated with the side head. Employed during training.
    auxiliary_logits = endpoints['aux_logits']

    #return logits, auxiliary_logits
    return logits, auxiliary_logits, endpoints['mixed_35x35x288b']
Example 4
def inference(images, num_classes, for_training=False, restore_logits=True,
              scope=None):
  """Build Inception v3 model architecture.

  See here for reference: http://arxiv.org/abs/1512.00567

  Args:
    images: Images returned from inputs() or distorted_inputs().
    num_classes: number of classes
    for_training: If set to `True`, build the inference model for training.
      Kernels that operate differently for inference during training
      e.g. dropout, are appropriately configured.
    restore_logits: whether or not the logits layers should be restored.
      Useful for fine-tuning a model with different num_classes.
    scope: optional prefix string identifying the ImageNet tower.

  Returns:
    Logits. 2-D float Tensor.
    Auxiliary Logits. 2-D float Tensor of side-head. Used for training only.
  """
  # Parameters for BatchNorm.
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': BATCHNORM_MOVING_AVERAGE_DECAY,
      # epsilon to prevent 0s in variance.
      'epsilon': 0.001,
  }
  # Set weight_decay for weights in Conv and FC layers.
  with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.0001):
    with slim.arg_scope([slim.ops.conv2d],
                        stddev=0.1,
                        activation=tf.nn.relu,
                        batch_norm_params=batch_norm_params):
#      with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
      logits, endpoints = slim.inception.inception_v3(
          images,
          dropout_keep_prob=0.8,
          num_classes=num_classes,
          is_training=for_training,
          restore_logits=restore_logits,
          scope=scope)

  # Add summaries for viewing model statistics on TensorBoard.
  _activation_summaries(endpoints)

  # Grab the logits associated with the side head. Employed during training.
#  auxiliary_logits = endpoints['aux_logits']
  auxiliary_logits = None

  return logits, auxiliary_logits
Example 5
 def testVariablesByLayer(self):
   batch_size = 5
   height, width = 299, 299
   with self.test_session():
     inputs = tf.random_uniform((batch_size, height, width, 3))
     with slim.arg_scope([slim.ops.conv2d],
                         batch_norm_params={'decay': 0.9997}):
       slim.inception.inception_v3(inputs)
     self.assertEqual(len(get_variables()), 388)
     self.assertEqual(len(get_variables('conv0')), 4)
     self.assertEqual(len(get_variables('conv1')), 4)
     self.assertEqual(len(get_variables('conv2')), 4)
     self.assertEqual(len(get_variables('conv3')), 4)
     self.assertEqual(len(get_variables('conv4')), 4)
     self.assertEqual(len(get_variables('mixed_35x35x256a')), 28)
     self.assertEqual(len(get_variables('mixed_35x35x288a')), 28)
     self.assertEqual(len(get_variables('mixed_35x35x288b')), 28)
     self.assertEqual(len(get_variables('mixed_17x17x768a')), 16)
     self.assertEqual(len(get_variables('mixed_17x17x768b')), 40)
     self.assertEqual(len(get_variables('mixed_17x17x768c')), 40)
     self.assertEqual(len(get_variables('mixed_17x17x768d')), 40)
     self.assertEqual(len(get_variables('mixed_17x17x768e')), 40)
     self.assertEqual(len(get_variables('mixed_8x8x2048a')), 36)
     self.assertEqual(len(get_variables('mixed_8x8x2048b')), 36)
     self.assertEqual(len(get_variables('logits')), 2)
     self.assertEqual(len(get_variables('aux_logits')), 10)
Example 6
 def testVariablesByLayer(self):
     batch_size = 5
     height, width = 299, 299
     with self.test_session():
         inputs = tf.random_uniform((batch_size, height, width, 3))
         with slim.arg_scope([slim.ops.conv2d],
                             batch_norm_params={'decay': 0.9997}):
             slim.inception.inception_v3(inputs)
         self.assertEqual(len(get_variables()), 388)
         self.assertEqual(len(get_variables('conv0')), 4)
         self.assertEqual(len(get_variables('conv1')), 4)
         self.assertEqual(len(get_variables('conv2')), 4)
         self.assertEqual(len(get_variables('conv3')), 4)
         self.assertEqual(len(get_variables('conv4')), 4)
         self.assertEqual(len(get_variables('mixed_35x35x256a')), 28)
         self.assertEqual(len(get_variables('mixed_35x35x288a')), 28)
         self.assertEqual(len(get_variables('mixed_35x35x288b')), 28)
         self.assertEqual(len(get_variables('mixed_17x17x768a')), 16)
         self.assertEqual(len(get_variables('mixed_17x17x768b')), 40)
         self.assertEqual(len(get_variables('mixed_17x17x768c')), 40)
         self.assertEqual(len(get_variables('mixed_17x17x768d')), 40)
         self.assertEqual(len(get_variables('mixed_17x17x768e')), 40)
         self.assertEqual(len(get_variables('mixed_8x8x2048a')), 36)
         self.assertEqual(len(get_variables('mixed_8x8x2048b')), 36)
         self.assertEqual(len(get_variables('logits')), 2)
         self.assertEqual(len(get_variables('aux_logits')), 10)
Example 7
 def testTotalLossWithRegularization(self):
     batch_size = 5
     height, width = 299, 299
     num_classes = 1000
     with self.test_session():
         inputs = tf.random_uniform((batch_size, height, width, 3))
         dense_labels = tf.random_uniform((batch_size, num_classes))
         with slim.arg_scope([slim.ops.conv2d, slim.ops.fc],
                             weight_decay=0.00004):
             logits, end_points = slim.inception.inception_v3(
                 inputs, num_classes)
             # Cross entropy loss for the main softmax prediction.
             slim.losses.cross_entropy_loss(logits,
                                            dense_labels,
                                            label_smoothing=0.1,
                                            weight=1.0)
             # Cross entropy loss for the auxiliary softmax head.
             slim.losses.cross_entropy_loss(end_points['aux_logits'],
                                            dense_labels,
                                            label_smoothing=0.1,
                                            weight=0.4,
                                            scope='aux_loss')
         losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
         self.assertEqual(len(losses), 2)
         reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
         self.assertEqual(len(reg_losses), 98)
Example 8
 def testRegularizationLosses(self):
   batch_size = 5
   height, width = 299, 299
   with self.test_session():
     inputs = tf.random_uniform((batch_size, height, width, 3))
     with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
       slim.inception.inception_v3(inputs)
     losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(losses), len(get_variables_by_name('weights')))
Example 9
 def testRegularizationLosses(self):
   batch_size = 5
   height, width = 299, 299
   with self.test_session():
     inputs = tf.random_uniform((batch_size, height, width, 3))
     with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
       slim.inception.inception_v3(inputs)
     losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(losses), len(get_variables_by_name('weights')))
Example 10
 def testVariablesToRestoreWithoutLogits(self):
   batch_size = 5
   height, width = 299, 299
   with self.test_session():
     inputs = tf.random_uniform((batch_size, height, width, 3))
     with slim.arg_scope([slim.ops.conv2d],
                         batch_norm_params={'decay': 0.9997}):
       slim.inception.inception_v3(inputs, restore_logits=False)
     variables_to_restore = tf.get_collection(
         slim.variables.VARIABLES_TO_RESTORE)
     self.assertEqual(len(variables_to_restore), 384)
Example 11
 def testVariablesToRestoreWithoutLogits(self):
     batch_size = 5
     height, width = 299, 299
     with self.test_session():
         inputs = tf.random_uniform((batch_size, height, width, 3))
         with slim.arg_scope([slim.ops.conv2d],
                             batch_norm_params={'decay': 0.9997}):
             slim.inception.inception_v3(inputs, restore_logits=False)
         variables_to_restore = tf.get_collection(
             slim.variables.VARIABLES_TO_RESTORE)
         self.assertEqual(len(variables_to_restore), 384)
Example 12
 def testVariablesWithoutBatchNorm(self):
     batch_size = 5
     height, width = 299, 299
     with self.test_session():
         inputs = tf.random_uniform((batch_size, height, width, 3))
         with slim.arg_scope([slim.ops.conv2d], batch_norm_params=None):
             slim.inception.inception_v3(inputs)
         self.assertEqual(len(get_variables()), 196)
         self.assertEqual(len(get_variables_by_name('weights')), 98)
         self.assertEqual(len(get_variables_by_name('biases')), 98)
         self.assertEqual(len(get_variables_by_name('beta')), 0)
         self.assertEqual(len(get_variables_by_name('gamma')), 0)
         self.assertEqual(len(get_variables_by_name('moving_mean')), 0)
         self.assertEqual(len(get_variables_by_name('moving_variance')), 0)
Example 13
 def testVariablesWithoutBatchNorm(self):
   batch_size = 5
   height, width = 299, 299
   with self.test_session():
     inputs = tf.random_uniform((batch_size, height, width, 3))
     with slim.arg_scope([slim.ops.conv2d],
                         batch_norm_params=None):
       slim.inception.inception_v3(inputs)
     self.assertEqual(len(get_variables()), 196)
     self.assertEqual(len(get_variables_by_name('weights')), 98)
     self.assertEqual(len(get_variables_by_name('biases')), 98)
     self.assertEqual(len(get_variables_by_name('beta')), 0)
     self.assertEqual(len(get_variables_by_name('gamma')), 0)
     self.assertEqual(len(get_variables_by_name('moving_mean')), 0)
     self.assertEqual(len(get_variables_by_name('moving_variance')), 0)
Example 14
 def testTotalLossWithRegularization(self):
   batch_size = 5
   height, width = 299, 299
   num_classes = 1000
   with self.test_session():
     inputs = tf.random_uniform((batch_size, height, width, 3))
     dense_labels = tf.random_uniform((batch_size, num_classes))
     with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
       logits, end_points = slim.inception.inception_v3(inputs, num_classes)
       # Cross entropy loss for the main softmax prediction.
       slim.losses.cross_entropy_loss(logits,
                                      dense_labels,
                                      label_smoothing=0.1,
                                      weight=1.0)
       # Cross entropy loss for the auxiliary softmax head.
       slim.losses.cross_entropy_loss(end_points['aux_logits'],
                                      dense_labels,
                                      label_smoothing=0.1,
                                      weight=0.4,
                                      scope='aux_loss')
     losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
     self.assertEqual(len(losses), 2)
     reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
     self.assertEqual(len(reg_losses), 98)
Example 15
def train(dataset):
    """Train on dataset for a number of steps."""
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Create a variable to count the number of train() calls. This equals the
        # number of batches processed * FLAGS.num_gpus.
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        # Calculate the learning rate schedule.
        num_batches_per_epoch = (dataset.num_examples_per_epoch() /
                                 FLAGS.batch_size)
        decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay)

        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                        global_step,
                                        decay_steps,
                                        FLAGS.learning_rate_decay_factor,
                                        staircase=True)

        # Create an optimizer that performs gradient descent.
        opt = tf.train.RMSPropOptimizer(lr,
                                        RMSPROP_DECAY,
                                        momentum=RMSPROP_MOMENTUM,
                                        epsilon=RMSPROP_EPSILON)

        # Get images and labels for ImageNet and split the batch across GPUs.
        assert FLAGS.batch_size % FLAGS.num_gpus == 0, (
            'Batch size must be divisible by number of GPUs')
        split_batch_size = int(FLAGS.batch_size / FLAGS.num_gpus)

        # Override the number of preprocessing threads to account for the increased
        # number of GPU towers.
        num_preprocess_threads = FLAGS.num_preprocess_threads * FLAGS.num_gpus
        images, labels = image_processing.distorted_inputs(
            dataset, num_preprocess_threads=num_preprocess_threads)

        input_summaries = copy.copy(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Number of classes in the Dataset label set plus 1.
        # Label 0 is reserved for an (unused) background class.
        num_classes = dataset.num_classes() + 1

        # Split the batch of images and labels for towers.
        images_splits = tf.split(0, FLAGS.num_gpus, images)
        labels_splits = tf.split(0, FLAGS.num_gpus, labels)

        # Calculate the gradients for each model tower.
        tower_grads = []
        for i in range(FLAGS.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' %
                                   (inception.TOWER_NAME, i)) as scope:
                    # Force all Variables to reside on the CPU.
                    with slim.arg_scope([slim.variables.variable],
                                        device='/cpu:0'):
                        # Calculate the loss for one tower of the ImageNet model. This
                        # function constructs the entire ImageNet model but shares the
                        # variables across all towers.
                        loss = _tower_loss(images_splits[i], labels_splits[i],
                                           num_classes, scope)

                    # Reuse variables for the next tower.
                    tf.get_variable_scope().reuse_variables()

                    # Retain the summaries from the final tower.
                    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                  scope)

                    # Retain the Batch Normalization updates operations only from the
                    # final tower. Ideally, we should grab the updates from all towers
                    # but these stats accumulate extremely fast so we can ignore the
                    # other stats from the other towers without significant detriment.
                    batchnorm_updates = tf.get_collection(
                        slim.ops.UPDATE_OPS_COLLECTION, scope)

                    # Calculate the gradients for the batch of data on this ImageNet
                    # tower.
                    grads = opt.compute_gradients(loss)

                    # Keep track of the gradients across all towers.
                    tower_grads.append(grads)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = _average_gradients(tower_grads)

        # Add summaries for the input processing and global_step.
        summaries.extend(input_summaries)

        # Add a summary to track the learning rate.
        summaries.append(tf.scalar_summary('learning_rate', lr))

        # Add histograms for gradients.
        for grad, var in grads:
            if grad is not None:
                summaries.append(
                    tf.histogram_summary(var.op.name + '/gradients', grad))

        # Apply the gradients to adjust the shared variables.
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            summaries.append(tf.histogram_summary(var.op.name, var))

        # Track the moving averages of all trainable variables.
        # Note that we maintain a "double-average" of the BatchNormalization
        # global statistics. This is more complicated than it needs to be but we employ
        # this for backward-compatibility with our previous models.
        variable_averages = tf.train.ExponentialMovingAverage(
            inception.MOVING_AVERAGE_DECAY, global_step)

        # Another possibility is to use tf.slim.get_variables().
        variables_to_average = (tf.trainable_variables() +
                                tf.moving_average_variables())
        variables_averages_op = variable_averages.apply(variables_to_average)

        # Group all updates into a single train op.
        batchnorm_updates_op = tf.group(*batchnorm_updates)
        train_op = tf.group(apply_gradient_op, variables_averages_op,
                            batchnorm_updates_op)

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables())

        # Build the summary operation from the last tower summaries.
        summary_op = tf.merge_summary(summaries)

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        if FLAGS.pretrained_model_checkpoint_path:
            assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
            variables_to_restore = tf.get_collection(
                slim.variables.VARIABLES_TO_RESTORE)
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.train.SummaryWriter(
            FLAGS.train_dir,
            graph_def=sess.graph.as_graph_def(add_shapes=True))

        for step in range(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 10 == 0:
                examples_per_sec = FLAGS.batch_size / float(duration)
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, duration))

            if step % 100 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % 5000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
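
The train() function above calls an _average_gradients helper that is not part of the excerpt. A minimal sketch of the usual multi-tower averaging, written against the same legacy tf.concat(concat_dim, values) signature that this example's tf.split(0, ...) calls imply, might look like:

def _average_gradients(tower_grads):
    """Average gradients across towers; this is the synchronization point."""
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars is ((grad0, var), (grad1, var), ...) for one variable;
        # this sketch assumes every variable received a gradient on every tower.
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(0, grads), 0)
        # The variable is shared across towers, so the first tower's pointer
        # is as good as any.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads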
Example 16
def inference(images,
              num_classes,
              net,
              for_training=False,
              restore_logits=True,
              scope=None):
    """Build Inception v3 model architecture.

  See here for reference: http://arxiv.org/abs/1512.00567

  Args:
    images: Images returned from inputs() or distorted_inputs().
    num_classes: number of classes
    net: name of the network architecture to build ('inception_v3' or a
      builder exported by slim.models).
    for_training: If set to `True`, build the inference model for training.
      Kernels that operate differently for inference during training
      e.g. dropout, are appropriately configured.
    restore_logits: whether or not the logits layers should be restored.
      Useful for fine-tuning a model with different num_classes.
    scope: optional prefix string identifying the ImageNet tower.

  Returns:
    Logits. 2-D float Tensor.
    Auxiliary Logits. 2-D float Tensor of side-head. Used for training only.
  """
    # Parameters for BatchNorm.
    batch_norm_params = {
        # Decay for the moving averages.
        'decay': BATCHNORM_MOVING_AVERAGE_DECAY,
        # epsilon to prevent 0s in variance.
        'epsilon': 0.001,
    }
    if batch_norm_params:
        print("INFO: batch_norm_params is initialized for slim.ops.conv2d")
    # Set weight_decay for weights in Conv and FC layers.
    with slim.arg_scope([slim.ops.conv2d, slim.ops.fc],
                        weight_decay=FLAGS.weight_decay
                        ):  # default 0.00004 for inception_v3
        with slim.arg_scope([slim.ops.conv2d],
                            stddev=0.1,
                            activation=tf.nn.relu,
                            batch_norm_params=batch_norm_params):
            if net == 'inception_v3':
                logits, endpoints = slim.inception.inception_v3(
                    images,
                    dropout_keep_prob=FLAGS.dropout_keep_prob,
                    num_classes=num_classes,
                    is_training=for_training,
                    restore_logits=restore_logits,
                    seed=FLAGS.seed,
                    scope=scope)
            else:
                method_to_call = getattr(slim.models, net)
                logits, endpoints = method_to_call(
                    images,
                    dropout_keep_prob=FLAGS.dropout_keep_prob,
                    num_classes=num_classes,
                    is_training=for_training,
                    restore_logits=restore_logits,
                    seed=FLAGS.seed,
                    weight_decay=FLAGS.weight_decay,
                    scope=scope)
            #else:
            #    raise ValueError("Wrong net type:{}".format(net))

    # Add summaries for viewing model statistics on TensorBoard.
    # _activation_summaries(endpoints)

    # Grab the logits associated with the side head. Employed during training.
    auxiliary_logits = endpoints['aux_logits']

    return logits, auxiliary_logits
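
For any architecture other than Inception v3, the builder above is looked up by name on slim.models with getattr, so the caller selects the network with a string. A hypothetical call site (whether a given name exists depends on what slim.models exports in this fork):

# 'inception_v3' takes the dedicated branch; any other name is resolved
# dynamically on slim.models.
logits, aux_logits = inference(images, num_classes=1001,
                               net='inception_v3', for_training=True)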
Example 17
    # Split the batch of images and labels for towers.
    images_splits = tf.split(0, FLAGS.num_gpus, images)
    labels_splits = tf.split(0, FLAGS.num_gpus, labels)

    # Calculate the gradients for each model tower.
    tower_grads = []
    reuse_variables = None
    for i in range(FLAGS.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:
          # Force all Variables to reside on the CPU.
          with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
            # Calculate the loss for one tower of the ImageNet model. This
            # function constructs the entire ImageNet model but shares the
            # variables across all towers.
            loss = _tower_loss(images_splits[i], labels_splits[i], num_classes,
                               scope, reuse_variables)

          # Reuse variables for the next tower.
          reuse_variables = True

          # Retain the summaries from the final tower.
          summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)

          # Retain the Batch Normalization updates operations only from the
          # final tower. Ideally, we should grab the updates from all towers
          # but these stats accumulate extremely fast so we can ignore the
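
The reuse_variables flag introduced above is passed down to _tower_loss, whose body lies outside the excerpt. A plausible sketch, assuming dense (one-hot) labels as in the test examples and an inference() of the form shown earlier, is to re-enter the outer variable scope with reuse set before building the tower:

def _tower_loss(images, labels, num_classes, scope, reuse_variables=None):
    """Sketch of a per-tower loss that honours the reuse flag."""
    # Towers after the first reuse the variables created by the first tower
    # instead of creating their own copies.
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
        logits, aux_logits = inference(images, num_classes,
                                       for_training=True, scope=scope)
    slim.losses.cross_entropy_loss(logits, labels,
                                   label_smoothing=0.1, weight=1.0)
    slim.losses.cross_entropy_loss(aux_logits, labels,
                                   label_smoothing=0.1, weight=0.4,
                                   scope='aux_loss')
    # Total loss for this tower: its cross-entropy terms plus the shared
    # regularization losses.
    losses = tf.get_collection(slim.losses.LOSSES_COLLECTION, scope)
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    return tf.add_n(losses + reg_losses, name='total_loss')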
Example 18
def train(dataset):
  """Train on dataset for a number of steps."""
  with tf.Graph().as_default(), tf.device('/cpu:0'):
    # Create a variable to count the number of train() calls. This equals the
    # number of batches processed * FLAGS.num_gpus.
    global_step = tf.get_variable(
        'global_step', [],
        initializer=tf.constant_initializer(0), trainable=False)

    # Calculate the learning rate schedule.
    num_batches_per_epoch = (dataset.num_examples_per_epoch() /
                             FLAGS.batch_size)
    decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay)

    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                    global_step,
                                    decay_steps,
                                    FLAGS.learning_rate_decay_factor,
                                    staircase=True)

    # Create an optimizer that performs gradient descent.
    opt = tf.train.RMSPropOptimizer(lr, RMSPROP_DECAY,
                                    momentum=RMSPROP_MOMENTUM,
                                    epsilon=RMSPROP_EPSILON)

    # Get images and labels for ImageNet and split the batch across GPUs.
    assert FLAGS.batch_size % FLAGS.num_gpus == 0, (
        'Batch size must be divisible by number of GPUs')
    split_batch_size = int(FLAGS.batch_size / FLAGS.num_gpus)

    # Override the number of preprocessing threads to account for the increased
    # number of GPU towers.
    num_preprocess_threads = FLAGS.num_preprocess_threads * FLAGS.num_gpus
    images, labels = image_processing.distorted_inputs(
        dataset,
        num_preprocess_threads=num_preprocess_threads)

    input_summaries = copy.copy(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Number of classes in the Dataset label set plus 1.
    # Label 0 is reserved for an (unused) background class.
    num_classes = dataset.num_classes() + 1

    # Split the batch of images and labels for towers.
    images_splits = tf.split(0, FLAGS.num_gpus, images)
    labels_splits = tf.split(0, FLAGS.num_gpus, labels)

    # Calculate the gradients for each model tower.
    tower_grads = []
    reuse_variables = None
    for i in xrange(FLAGS.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:
          # Force all Variables to reside on the CPU.
          with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
            # Calculate the loss for one tower of the ImageNet model. This
            # function constructs the entire ImageNet model but shares the
            # variables across all towers.
            loss = _tower_loss(images_splits[i], labels_splits[i], num_classes,
                               scope, reuse_variables)

          # Reuse variables for the next tower.
          reuse_variables = True

          # Retain the summaries from the final tower.
          summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)

          # Retain the Batch Normalization updates operations only from the
          # final tower. Ideally, we should grab the updates from all towers
          # but these stats accumulate extremely fast so we can ignore the
          # other stats from the other towers without significant detriment.
          batchnorm_updates = tf.get_collection(slim.ops.UPDATE_OPS_COLLECTION,
                                                scope)

          # Calculate the gradients for the batch of data on this ImageNet
          # tower.
          grads = opt.compute_gradients(loss)

          # Keep track of the gradients across all towers.
          tower_grads.append(grads)

    # We must calculate the mean of each gradient. Note that this is the
    # synchronization point across all towers.
    grads = _average_gradients(tower_grads)

    # Add summaries for the input processing and global_step.
    summaries.extend(input_summaries)

    # Add a summary to track the learning rate.
    summaries.append(tf.scalar_summary('learning_rate', lr))

    # Add histograms for gradients.
    for grad, var in grads:
      if grad is not None:
        summaries.append(
            tf.histogram_summary(var.op.name + '/gradients', grad))

    # Apply the gradients to adjust the shared variables.
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    # Add histograms for trainable variables.
    for var in tf.trainable_variables():
      summaries.append(tf.histogram_summary(var.op.name, var))

    # Track the moving averages of all trainable variables.
    # Note that we maintain a "double-average" of the BatchNormalization
    # global statistics. This is more complicated than it needs to be but we employ
    # this for backward-compatibility with our previous models.
    variable_averages = tf.train.ExponentialMovingAverage(
        inception.MOVING_AVERAGE_DECAY, global_step)

    # Another possibility is to use tf.slim.get_variables().
    variables_to_average = (tf.trainable_variables() +
                            tf.moving_average_variables())
    variables_averages_op = variable_averages.apply(variables_to_average)

    # Group all updates into a single train op.
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    train_op = tf.group(apply_gradient_op, variables_averages_op,
                        batchnorm_updates_op)

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    # Build the summary operation from the last tower summaries.
    summary_op = tf.merge_summary(summaries)

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph. allow_soft_placement must be set to
    # True to build towers on GPU, as some of the ops do not have GPU
    # implementations.
    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    if FLAGS.pretrained_model_checkpoint_path:
      assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
      variables_to_restore = tf.get_collection(
          slim.variables.VARIABLES_TO_RESTORE)
      restorer = tf.train.Saver(variables_to_restore)
      restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
      print('%s: Pre-trained model restored from %s' %
            (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(
        FLAGS.train_dir,
        graph_def=sess.graph.as_graph_def(add_shapes=True))

    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 10 == 0:
        examples_per_sec = FLAGS.batch_size / float(duration)
        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.now(), step, loss_value,
                            examples_per_sec, duration))

      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 5000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
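
Note that this example splits the batch with the legacy tf.split(split_dim, num_split, value) signature, whereas Example 20 below uses the argument order introduced in TensorFlow 1.0. Assuming the batch dimension is axis 0, the two spellings of the same split are:

# Legacy order, as used here and in Examples 15 and 17:
images_splits = tf.split(0, FLAGS.num_gpus, images)
# Post-1.0 order, as used in Example 20:
images_splits = tf.split(images, num_or_size_splits=FLAGS.num_gpus, axis=0)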
Example 19
def train(dataset):
  """Train on dataset for a number of steps."""
  with tf.Graph().as_default(), tf.device('/cpu:0'):
    tf.set_random_seed(FLAGS.seed)
    if FLAGS.num_nodes > 0:
      num_nodes = FLAGS.num_nodes
    else:
      num_nodes = FLAGS.num_gpus
    # Create a variable to count the number of train() calls. This equals the
    # number of batches processed * FLAGS.num_nodes.
    global_step = tf.get_variable(
        'global_step', [],
        initializer=tf.constant_initializer(0), trainable=False)

    # Calculate the learning rate schedule.
    num_batches_per_epoch = (dataset.num_examples_per_epoch() /
                             FLAGS.batch_size)
    decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay)

    # Decay the learning rate exponentially based on the number of steps.
    if ('fixed'==FLAGS.learning_rate_decay_type or 'adam' == FLAGS.optimizer):
      lr = FLAGS.initial_learning_rate
    elif 'exponential'==FLAGS.learning_rate_decay_type:
      lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                    global_step/num_nodes,
                                    decay_steps,
                                    FLAGS.learning_rate_decay_factor,
                                    staircase=True)
    elif 'polynomial'==FLAGS.learning_rate_decay_type:
      lr = tf.train.polynomial_decay(FLAGS.initial_learning_rate,
                                    global_step/num_nodes,
                                    FLAGS.max_steps,
                                    end_learning_rate=0.0,
                                    power=0.5)
    else:
      raise ValueError('Wrong learning_rate_decay_type!')

    # Create an optimizer that performs gradient descent.
    opt = None
    if ('gd' == FLAGS.optimizer):
        opt = tf.train.GradientDescentOptimizer(lr)
    elif ('momentum' == FLAGS.optimizer):
        opt = tf.train.MomentumOptimizer(lr, FLAGS.momentum)
    elif ('adam' == FLAGS.optimizer):
        opt = tf.train.AdamOptimizer(lr)
    elif ('rmsprop' == FLAGS.optimizer):
        opt = tf.train.RMSPropOptimizer(lr, RMSPROP_DECAY,
              momentum=FLAGS.momentum,
              epsilon=RMSPROP_EPSILON)
    else:
        raise ValueError("Wrong optimizer!")

    # Get images and labels for ImageNet and split the batch across GPUs.
    assert FLAGS.batch_size % num_nodes == 0, (
        'Batch size must be divisible by number of nodes')

    # Override the number of preprocessing threads to account for the increased
    # number of GPU towers.
    num_preprocess_threads = FLAGS.num_preprocess_threads * num_nodes
    if FLAGS.benchmark_mode:
      images = tf.constant(0.5, shape=[FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3])
      labels = tf.random_uniform([FLAGS.batch_size], minval=0, maxval=dataset.num_classes()-1, dtype=tf.int32)
    else:
      images, labels = image_processing.distorted_inputs(
          dataset,
          num_preprocess_threads=num_preprocess_threads)


    # Number of classes in the Dataset label set plus 1.
    # Label 0 is reserved for an (unused) background class.
    if FLAGS.dataset_name == 'imagenet':
      num_classes = dataset.num_classes() + 1
    else:
      num_classes = dataset.num_classes()

    # Split the batch of images and labels for towers.
    images_splits = tf.split(images, num_nodes, 0)
    labels_splits = tf.split(labels, num_nodes, 0)

    # Calculate the gradients for each model tower.
    tower_grads = [] # gradients of cross entropy or total cost for each tower
    tower_floating_grads = []  # full-precision gradients, kept for the floating/binary switch below
    tower_batchnorm_updates = []
    tower_scalers = []
    #tower_reg_grads = []
    reuse_variables = None
    tower_entropy_losses = []
    tower_reg_losses = []
    for i in range(num_nodes):
      with tf.device('/gpu:%d' % (i%FLAGS.num_gpus)):
        with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:
          with tf.variable_scope('%s_%d' % (inception.TOWER_NAME, i)):
            # Force Variables to reside on the individual GPU.
            #with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
            with slim.arg_scope([slim.variables.variable], device='/gpu:%d' % (i%FLAGS.num_gpus)):
              # Calculate the loss for one tower of the ImageNet model. This
              # function constructs the entire ImageNet model but shares the
              # variables across all towers.
              loss, entropy_loss, reg_loss = _tower_loss(images_splits[i], labels_splits[i], num_classes,
                                 scope, reuse_variables)
            tower_entropy_losses.append(entropy_loss)
            tower_reg_losses.append(reg_loss)

            # Reuse variables for the next tower?
            reuse_variables = None

            # Retain the Batch Normalization updates operations.
            batchnorm_updates = tf.get_collection(slim.ops.UPDATE_OPS_COLLECTION,
                                                scope)
            batchnorm_updates = batchnorm_updates + \
                                tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)

            tower_batchnorm_updates.append(batchnorm_updates)

            # Calculate the gradients for the batch of data on this ImageNet
            # tower.
            grads = opt.compute_gradients(loss, tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope))

            # Keep track of the gradients across all towers.
            tower_grads.append(grads)
            tower_floating_grads.append(grads)

            # Calculate the scalers of binary gradients
            if 1 == FLAGS.grad_bits:
              # Always calculate scalers whatever clip_factor is.
              # Returns max value when clip_factor==0.0
              scalers = bingrad_common.gradient_binarizing_scalers(grads, FLAGS.clip_factor)
              tower_scalers.append(scalers)

            # regularization gradients
            #if FLAGS.weight_decay:
            #  reg_grads = opt.compute_gradients(reg_loss, tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope))
            #  tower_reg_grads.append(reg_grads)

    if 1 == FLAGS.grad_bits:
      # for grads in tower_grads:
      #   _gradient_summary(grads, 'floating')

      # We must calculate the mean of each scaler. Note that this is the
      # synchronization point across all towers @ CPU.
      # mean_scalers = bingrad_common.average_scalers(tower_scalers)
      mean_scalers = bingrad_common.max_scalers(tower_scalers)
      # for mscaler in mean_scalers:
      #   if mscaler is not None:
      #     tf.summary.scalar(mscaler.op.name + '/mean_scaler', mscaler)

      grad_shapes_for_deocder = []
      for i in xrange(num_nodes):
        with tf.device('/gpu:%d' % (i%FLAGS.num_gpus)):
          with tf.name_scope('binarizer_%d' % (i)) as scope:
            # Clip and binarize gradients
            # and keep track of the gradients across all towers.
            if FLAGS.quantize_logits:
              tower_grads[i][:] = bingrad_common.stochastical_binarize_gradients(
                tower_grads[i][:], mean_scalers[:])
            else:
              tower_grads[i][:-2] = bingrad_common.stochastical_binarize_gradients(
                  tower_grads[i][:-2], mean_scalers[:-2])

            _gradient_summary(tower_grads[i], 'binary', add_sparsity=True)

          if FLAGS.use_encoding:
            # encoding
            with tf.name_scope('encoder_%d' % (i)) as scope:
              if 0==i:
                tower_grads[i][:-2], grad_shapes_for_deocder = \
                  bingrad_common.encode_to_ternary_gradients(tower_grads[i][:-2], get_shape=True)
              else:
                tower_grads[i][:-2] = bingrad_common.encode_to_ternary_gradients(tower_grads[i][:-2], get_shape=False)

    # decoding @ CPU
    if (1 == FLAGS.grad_bits) and FLAGS.use_encoding:
      with tf.name_scope('decoder') as scope:
        for i in xrange(num_nodes):
          tower_grads[i][:-2] = bingrad_common.decode_from_ternary_gradients(
            tower_grads[i][:-2], mean_scalers[:-2], grad_shapes_for_deocder)

    # Switch between binarized and floating gradients
    if (FLAGS.floating_grad_epoch>0) and (1 == FLAGS.grad_bits):
      epoch_remainder = tf.mod( ( (global_step / num_nodes) * FLAGS.batch_size) / dataset.num_examples_per_epoch(),
             FLAGS.floating_grad_epoch)
      cond_op = tf.equal(tf.to_int32(tf.floor(epoch_remainder)), tf.to_int32(FLAGS.floating_grad_epoch-1))
      for i in xrange(num_nodes):
        with tf.name_scope('switcher_%d' % (i)) as scope:
          _, selected_variables = zip( *(tower_floating_grads[i]) )
          selected_gradients = []
          for j in range(len(tower_floating_grads[i])):
            selected_gradients.append( tf.cond(cond_op,
                                  lambda: tower_floating_grads[i][j][0],
                                  lambda: tower_grads[i][j][0]) )
          tower_grads[i] = list(zip(selected_gradients, selected_variables))


    # We must calculate the mean of each gradient. Note that this is the
    # synchronization point across all towers @ CPU.
    if len(tower_grads)>1:
      tower_grads = bingrad_common.average_gradients2(tower_grads)


    # Add a summary to track the learning rate.
    tf.summary.scalar('learning_rate', lr)

    # Add histograms for gradients.
    # for grads in tower_grads:
    #   _gradient_summary(grads, 'final')

    # Apply the gradients to adjust the shared variables.
    # @ GPUs
    #apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
    apply_gradient_op = []
    for i in xrange(num_nodes):
      with tf.device('/gpu:%d' % (i%FLAGS.num_gpus)):
        with tf.name_scope('grad_applier_%d' % (i)) as scope:
          # apply data loss SGD. global_step is incremented by num_nodes per iter
          apply_gradient_op.append(opt.apply_gradients(tower_grads[i],
                                          global_step=global_step))
          #if FLAGS.weight_decay:
          #  # apply regularization, global_step is omitted to avoid incrementation
          #  apply_gradient_op.append(opt.apply_gradients(tower_reg_grads[i]))

    # Add histograms for trainable variables.
    for var in tf.trainable_variables():
      tf.summary.histogram(var.op.name, var)

    # Track the moving averages of all trainable variables.
    # Note that we maintain a "double-average" of the BatchNormalization
    # global statistics. This is more complicated than it needs to be but we employ
    # this for backward-compatibility with our previous models.
    variable_averages = tf.train.ExponentialMovingAverage(
        inception.MOVING_AVERAGE_DECAY, global_step/num_nodes)

    # Another possibility is to use tf.slim.get_variables().
    variables_to_average = (tf.trainable_variables() +
                            tf.moving_average_variables())
    variables_averages_op = variable_averages.apply(variables_to_average)

    # Group all updates into a single train op.
    #batchnorm_updates_op = tf.group(*batchnorm_updates)
    #train_op = tf.group(apply_gradient_op, variables_averages_op,
    #                    batchnorm_updates_op)
    batchnorm_updates_op = tf.no_op()
    for tower_batchnorm_update in tower_batchnorm_updates:
      batchnorm_updates_op = tf.group(batchnorm_updates_op, *tower_batchnorm_update)
    apply_gradient_op = tf.group(*apply_gradient_op)
    train_op = tf.group(apply_gradient_op, variables_averages_op, batchnorm_updates_op)

    # Create a saver.
    #saver = tf.train.Saver(tf.all_variables())
    if FLAGS.save_tower>=0:
      # Only save the variables in a tower
      save_pattern = ('(%s_%d)' % (inception.TOWER_NAME, FLAGS.save_tower)) + ".*" #+ ".*ExponentialMovingAverage"
      var_dic = {}
      _vars = tf.global_variables()
      for _var in _vars:
          if re.compile(save_pattern).match(_var.op.name):
              _var_name = re.sub('%s_[0-9]*/' % inception.TOWER_NAME, '', _var.op.name)
              var_dic[_var_name] = _var
      saver = tf.train.Saver(var_dic)
    else:
      saver = tf.train.Saver(tf.global_variables())

    # average loss summaries
    avg_entropy_loss = tf.reduce_mean(tower_entropy_losses)
    avg_reg_loss = tf.reduce_mean(tower_reg_losses)
    avg_total_loss = tf.add(avg_entropy_loss, avg_reg_loss)
    tf.summary.scalar('avg_entropy_loss', avg_entropy_loss)
    tf.summary.scalar('avg_reg_loss', avg_reg_loss)
    tf.summary.scalar('avg_total_loss', avg_total_loss)

    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)

    # Build the summary operation from the last tower summaries.
    summary_op = tf.summary.merge(summaries)

    # Build an initialization operation to run below.
    init = tf.global_variables_initializer()

    # Start running operations on the Graph. allow_soft_placement must be set to
    # True to build towers on GPU, as some of the ops do not have GPU
    # implementations.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True  # ops without a GPU kernel may fall back to the CPU
    config.log_device_placement = FLAGS.log_device_placement
    sess = tf.Session(config=config)
    sess.run(init)

    trained_step = 0
    if FLAGS.pretrained_model_checkpoint_path:
      assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
      ckpt = tf.train.get_checkpoint_state(FLAGS.pretrained_model_checkpoint_path)
      if ckpt and ckpt.model_checkpoint_path:
        trained_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        trained_step = int(trained_step) + 1
        variables_to_restore = tf.get_collection(
            slim.variables.VARIABLES_TO_RESTORE)+ \
                               tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        restorer = tf.train.Saver(variables_to_restore)
        if os.path.isabs(ckpt.model_checkpoint_path):
          restorer.restore(sess, ckpt.model_checkpoint_path)
        else:
          restorer.restore(sess, os.path.join(FLAGS.pretrained_model_checkpoint_path,
                                         ckpt.model_checkpoint_path))
        print('%s: Pre-trained model restored from %s' %
              (datetime.now(), FLAGS.pretrained_model_checkpoint_path))
      else:
        print('%s: Restoring pre-trained model from %s failed!' %
              (datetime.now(), FLAGS.pretrained_model_checkpoint_path))
        exit()

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir,
        graph=tf.get_default_graph())

    for step in range(trained_step, FLAGS.max_steps):
      start_time = time.time()
      _, entropy_loss_value, reg_loss_value = sess.run([train_op, entropy_loss, reg_loss])
      duration = time.time() - start_time

      assert not np.isnan(entropy_loss_value), 'Model diverged with entropy_loss = NaN'

      if step % 10 == 0:
        examples_per_sec = FLAGS.batch_size / float(duration)
        format_str = ('%s: step %d, entropy_loss = %.2f, reg_loss = %.2f, total_loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.now(), step,
                            entropy_loss_value, reg_loss_value, entropy_loss_value+reg_loss_value,
                            examples_per_sec, duration))

      if step % FLAGS.save_iter == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % FLAGS.save_iter == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
Example 20
def train(dataset):
    """Train on dataset for a number of steps."""
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Create a variable to count the number of train() calls. This equals the
        # number of batches processed * FLAGS.num_gpus.
        global_step = tf.get_variable(
            'global_step', [],
            initializer=tf.constant_initializer(0), trainable=False)

        # Calculate the learning rate schedule.
        num_batches_per_epoch = (dataset.num_examples_per_epoch() /
                                 FLAGS.batch_size)
        decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay)

        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                        global_step,
                                        decay_steps,
                                        FLAGS.learning_rate_decay_factor,
                                        staircase=True)

        # Create an optimizer that performs gradient descent.
        opt = tf.train.RMSPropOptimizer(lr, RMSPROP_DECAY,
                                        momentum=RMSPROP_MOMENTUM,
                                        epsilon=RMSPROP_EPSILON)

        # Get images and labels for ImageNet and split the batch across GPUs.
        assert FLAGS.batch_size % FLAGS.num_gpus == 0, (
            'Batch size must be divisible by number of GPUs')
        split_batch_size = int(FLAGS.batch_size / FLAGS.num_gpus)

        # Override the number of preprocessing threads to account for the increased
        # number of GPU towers.
        num_preprocess_threads = FLAGS.num_preprocess_threads * FLAGS.num_gpus
        images, labels = image_processing.distorted_inputs(
            dataset,
            num_preprocess_threads=num_preprocess_threads)

        input_summaries = copy.copy(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Number of classes in the Dataset label set plus 1.
        # Label 0 is reserved for an (unused) background class.
        num_classes = dataset.num_classes() + 1

        # Split the batch of images and labels for towers.
        images_splits = tf.split(axis=0, num_or_size_splits=FLAGS.num_gpus, value=images)
        labels_splits = tf.split(axis=0, num_or_size_splits=FLAGS.num_gpus, value=labels)

        # Calculate the gradients for each model tower.
        tower_grads = []
        reuse_variables = None
        for i in range(FLAGS.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:
                    # Force all Variables to reside on the CPU.
                    with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
                        # Calculate the loss for one tower of the ImageNet model. This
                        # function constructs the entire ImageNet model but shares the
                        # variables across all towers.
                        loss = _tower_loss(images_splits[i], labels_splits[i], num_classes,
                                           scope, reuse_variables)

                    # Reuse variables for the next tower.
                    reuse_variables = True

                    # Retain the summaries from the final tower.
                    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)

                    # Retain the Batch Normalization updates operations only from the
                    # final tower. Ideally, we should grab the updates from all towers
                    # but these stats accumulate extremely fast so we can ignore the
                    # other stats from the other towers without significant detriment.
                    batchnorm_updates = tf.get_collection(slim.ops.UPDATE_OPS_COLLECTION,
                                                          scope)

                    # Calculate the gradients for the batch of data on this ImageNet
                    # tower.
                    grads = opt.compute_gradients(loss)

                    # Keep track of the gradients across all towers.
                    tower_grads.append(grads)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = _average_gradients(tower_grads)

        # Add summaries for the input processing and global_step.
        summaries.extend(input_summaries)

        # Add a summary to track the learning rate.
        summaries.append(tf.summary.scalar('learning_rate', lr))

        # Add histograms for gradients.
        for grad, var in grads:
            if grad is not None:
                summaries.append(
                    tf.summary.histogram(var.op.name + '/gradients', grad))

        # Apply the gradients to adjust the shared variables.
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            summaries.append(tf.summary.histogram(var.op.name, var))

        # Track the moving averages of all trainable variables.
        # Note that we maintain a "double-average" of the BatchNormalization
        # global statistics. This is more complicated than it needs to be but we employ
        # this for backward-compatibility with our previous models.
        variable_averages = tf.train.ExponentialMovingAverage(
            inception.MOVING_AVERAGE_DECAY, global_step)

        # Another possibility is to use tf.slim.get_variables().
        variables_to_average = (tf.trainable_variables() +
                                tf.moving_average_variables())
        variables_averages_op = variable_averages.apply(variables_to_average)

        # Group all updates into a single train op.
        batchnorm_updates_op = tf.group(*batchnorm_updates)
        train_op = tf.group(apply_gradient_op, variables_averages_op,
                            batchnorm_updates_op)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge(summaries)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))


        def profile(run_metadata, epoch=0):
            with open('profs/timeline_step' + str(epoch) + '.json', 'w') as f:
                # Create the Timeline object, and write it to a json file
                fetched_timeline = timeline.Timeline(run_metadata.step_stats)
                chrome_trace = fetched_timeline.generate_chrome_trace_format()
                f.write(chrome_trace)

        def graph_to_dot(graph):
            dot = Digraph()
            for n in graph.as_graph_def().node:
                dot.node(n.name, label= n.name)
                for i in n.input:
                    dot.edge(i, n.name)
            return dot

        dot_rep = graph_to_dot(tf.get_default_graph())
        s = Source(dot_rep, filename="test.gv", format="PNG")
        with open('profs/A_dot.dot', 'w') as fwr:
            fwr.write(str(dot_rep))

        options = tf.RunOptions(trace_level=tf.RunOptions.SOFTWARE_TRACE)
        run_metadata = tf.RunMetadata()

        sess.run(init, run_metadata=run_metadata, options=options)

        profile(run_metadata, -1)

      #  s.view()
        s.save('inc.PNG')



        if FLAGS.pretrained_model_checkpoint_path:
            assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
            variables_to_restore = tf.get_collection(
                slim.variables.VARIABLES_TO_RESTORE)
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.summary.FileWriter(
            FLAGS.train_dir,
            graph=sess.graph)

        operations_tensors = {}
        operations_names = tf.get_default_graph().get_operations()
        count1 = 0
        count2 = 0

        for operation in operations_names:
            operation_name = operation.name
            operations_info = tf.get_default_graph().get_operation_by_name(operation_name).values()
            if len(operations_info) > 0:
                if not (operations_info[0].shape.ndims is None):
                    operation_shape = operations_info[0].shape.as_list()
                    operation_dtype_size = operations_info[0].dtype.size
                    if not (operation_dtype_size is None):
                        operation_no_of_elements = 1
                        for dim in operation_shape:
                            if not(dim is None):
                                operation_no_of_elements = operation_no_of_elements * dim
                        total_size = operation_no_of_elements * operation_dtype_size
                        operations_tensors[operation_name] = total_size
                    else:
                        count1 = count1 + 1
                else:
                    count1 = count1 + 1
                    operations_tensors[operation_name] = -1

            else:
                count2 = count2 + 1
                operations_tensors[operation_name] = -1
        print(count1)
        print(count2)

        with open('tensors_sz.json', 'w') as f:
            json.dump(operations_tensors, f)

        for step in range(FLAGS.max_steps):
            start_time = time.time()
            if step > 100 and step % 101 == 0:
                _, loss_value = sess.run([train_op, loss], run_metadata=run_metadata, options=options)
                profile(run_metadata, step)
            else:
                _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 10 == 0:
                examples_per_sec = FLAGS.batch_size / float(duration)
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, duration))

            if step % 100 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % 5000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
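
The profiling and graph-dumping code in this last example depends on imports that fall outside the excerpt; presumably something along these lines (graphviz provides Digraph and Source, and the timeline module lives in TensorFlow's client package):

import json
import os
import time
from datetime import datetime

import numpy as np
import tensorflow as tf
from graphviz import Digraph, Source            # GraphDef visualization
from tensorflow.python.client import timeline   # Chrome-trace profiling

The generated timeline_step*.json files can then be inspected in Chrome's chrome://tracing viewer.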