def evaluate(dataset,n):
  """Build the evaluation graph for one image crop and run the eval loop.

  Args:
    dataset: forwarded unchanged to `image_processing.inputs`.
    n: crop/view identifier. It does not influence graph construction here;
      it is only handed on to `_eval_once`.
  """
  with tf.Graph().as_default():
    images, pitchs, yaws, rolls, names = image_processing.inputs(dataset)

    # Only the variant that used to be guarded by `n == 5` is active. The
    # four disabled variants (n == 1..4) took 85x85 crops starting at offset
    # 0 or 14 along each spatial axis and resized them the same way.
    crop = images[:,0:7,92:99,:]
    crop = tf.image.resize_bilinear(crop, [32, 32], align_corners=False)
    predictions = model.inference(crop)

    # Stack the three angle vectors as columns: [pitch, yaw, roll].
    angle_columns = [tf.expand_dims(angle, 1) for angle in (pitchs, yaws, rolls)]
    labels = tf.concat(angle_columns, 1)

    # Absolute error summed over axis 0, and the element-wise absolute error.
    error_op = tf.reduce_sum(tf.abs(predictions-labels),0)
    acc_op = tf.abs(predictions-labels)

    saver = tf.train.Saver()

    # Merge whatever summaries were registered while building the graph.
    summary_op = tf.summary.merge_all()

    graph_def = tf.get_default_graph().as_graph_def()
    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
                                            graph_def=graph_def)

    # Evaluate once, then either stop or wait and evaluate again.
    while True:
      _eval_once(saver, summary_writer, error_op, summary_op, acc_op,
                 labels, predictions, n)
      if not FLAGS.run_once:
        time.sleep(FLAGS.eval_interval_secs)
        continue
      break
Exemplo n.º 2
0
  def build_inputs(self):
    """Create the input tensors for the current mode.

    Everything is built inside the 'images_and_labels' variable scope.

    Sets:
      self.images: image batch. In 'train' and 'validation' mode it comes
        from `image_processing`; in any other mode it is a float32
        placeholder of shape [1, FLAGS.image_size, FLAGS.image_size, 3].
      self.labels: labels matching `self.images`; only assigned in 'train'
        and 'validation' mode.
    """
    mode = self.mode

    with tf.variable_scope('images_and_labels'):
      if mode == 'train':
        # Training reads through the distorting input pipeline.
        batch = image_processing.distorted_inputs(
            batch_size=self.batch_size,
            num_preprocess_threads=self.num_preprocess_threads)
        self.images, self.labels = batch
      elif mode == 'validation':
        # Validation uses the plain pipeline and its own batch size.
        batch = image_processing.inputs(
            batch_size=self.batch_size_val,
            num_preprocess_threads=self.num_preprocess_threads)
        self.images, self.labels = batch
      else:
        # Any other mode feeds a single image through a placeholder.
        placeholder_shape = [1, FLAGS.image_size, FLAGS.image_size, 3]
        self.images = tf.placeholder(dtype=tf.float32,
                                     shape=placeholder_shape)

    print('complete build inputs.')
Exemplo n.º 3
0
def evaluate(dataset):
  """Build the pose-regression evaluation graph and run the eval loop.

  Args:
    dataset: forwarded unchanged to `image_processing.inputs`.
  """
  with tf.Graph().as_default():
    images, pitchs, yaws, rolls, names = image_processing.inputs(dataset)

    # Stack the three angle vectors as columns: [pitch, yaw, roll].
    angle_columns = [tf.expand_dims(angle, 1) for angle in (pitchs, yaws, rolls)]
    labels = tf.concat(angle_columns, 1)

    predictions = model.inference(images,FLAGS.is_training)

    # Absolute error summed over axis 0, and the element-wise absolute error.
    error_op = tf.reduce_sum(tf.abs(predictions-labels),0)
    acc_op = tf.abs(predictions-labels)

    saver = tf.train.Saver()

    # Merge whatever summaries were registered while building the graph.
    summary_op = tf.summary.merge_all()

    graph_def = tf.get_default_graph().as_graph_def()
    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
                                            graph_def=graph_def)

    # Evaluate once, then either stop or wait and evaluate again.
    finished = False
    while not finished:
      _eval_once(saver, summary_writer, error_op, summary_op, acc_op)
      finished = bool(FLAGS.run_once)
      if not finished:
        time.sleep(FLAGS.eval_interval_secs)
Exemplo n.º 4
0
def build_val_graph(config, dataset):
    """Build the classification validation graph.

    Args:
        config: nested mapping. Keys read here: parameters.batch_size,
            input.height / width / channels / classes and
            model.architecture.
        dataset: forwarded unchanged to `image_processing.inputs`.

    Returns:
        A tuple `(loss, accuracy, summary_op)` where `summary_op` merges
        every summary in the 'validation' collection.
    """
    input_cfg = config['input']

    # Input pipeline is pinned to the CPU.
    with tf.device('/cpu:0'):
        inputs, labels = image_processing.inputs(
            dataset,
            batch_size=config['parameters']['batch_size'],
            height=input_cfg['height'],
            width=input_cfg['width'],
            channels=input_cfg['channels'],
            num_preprocess_threads=8)

    # The model reuses existing variables (reuse=True) and runs on GPU 0.
    with tf.device('/gpu:0'):
        logits, endpoints = cnn_architectures.create_model(
            config['model']['architecture'],
            inputs,
            is_training=False,
            num_classes=config['input']['classes'],
            reuse=True)

    # tf.argmax yields int64, so the labels are cast to match.
    labels = tf.cast(labels, tf.int64)
    per_example_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)
    mean_xent = tf.reduce_mean(per_example_xent)

    # Total loss = mean cross entropy + all registered regularization losses.
    regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    loss = tf.add_n([mean_xent] + regularizers, name='total_loss')

    hits = tf.equal(tf.argmax(logits, 1), labels)
    accuracy = tf.reduce_mean(tf.cast(hits, tf.float32))

    tf.summary.scalar('val/accuracy', accuracy, collections=['validation'])
    tf.summary.scalar('val/loss', loss, collections=['validation'])

    for variable in tf.trainable_variables():
        tf.summary.histogram(variable.op.name,
                             variable,
                             collections=['validation'])

    summary_op = tf.summary.merge_all(key='validation')
    return loss, accuracy, summary_op
Exemplo n.º 5
0
def build_val_graph(config, dataset):
    """Build the validation graph with streaming loss / accuracy metrics.

    Args:
        config: nested mapping. Keys read here: parameters.batch_size,
            parameters.loss ('regression' or 'classification'),
            parameters.label_mean (regression only), input.height / width /
            channels / classes, model.architecture and
            output.trainable_variables_to_summary.
        dataset: forwarded unchanged to `image_processing.inputs`.

    Returns:
        A 5-tuple `(m_loss, m_accuracy, summary_op, update_op, reset_op)`:
        the streaming means of loss and accuracy, the merged 'validation'
        summaries, one op grouping both streaming updates, and a
        single-element list holding the op that re-initializes the local
        variables whose name contains 'metrics'.

    Raises:
        ValueError: if `config['parameters']['loss']` is neither
            'regression' nor 'classification'.
    """
    # Fail early with a clear message; otherwise an unsupported value leaves
    # `loss` and `accuracy` unbound and surfaces as an UnboundLocalError.
    loss_kind = config['parameters']['loss']
    if loss_kind not in ('regression', 'classification'):
        raise ValueError(
            "config['parameters']['loss'] must be 'regression' or "
            "'classification', got %r" % (loss_kind,))

    with tf.device('/cpu:0'):
        inputs, labels = image_processing.inputs(
            dataset,
            batch_size=config['parameters']['batch_size'],
            height=config['input']['height'],
            width=config['input']['width'],
            channels=config['input']['channels'],
            num_preprocess_threads=8)

    # The model reuses existing variables (reuse=True) and runs on GPU 0.
    with tf.device('/gpu:0'):
        logits, endpoints = cnn_architectures.create_model(
            config['model']['architecture'],
            inputs,
            is_training=False,
            num_classes=config['input']['classes'],
            reuse=True)

    regularization_losses = tf.get_collection(
        tf.GraphKeys.REGULARIZATION_LOSSES)

    if loss_kind == 'regression':
        # Labels are shifted by the configured mean before the comparison.
        labels = tf.cast(labels - config['parameters']['label_mean'],
                         tf.float32)
        mean_squared_error = tf.losses.mean_squared_error(labels=labels,
                                                          predictions=logits)
        loss = tf.add_n([mean_squared_error] + regularization_losses,
                        name='total_loss')
        # Accuracy is not computed for regression; a constant 0 is reported.
        accuracy = tf.constant(0, shape=[], dtype=tf.float32)
    else:  # 'classification'
        # Floor-dividing by 5 turns the raw label into a class index
        # (presumably 5-unit bins -- TODO confirm against the label source).
        labels = tf.cast(labels // 5, tf.int64)

        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels)
        cross_entropy_mean = tf.reduce_mean(cross_entropy)
        loss = tf.add_n([cross_entropy_mean] + regularization_losses,
                        name='total_loss')

        correct_prediction = tf.equal(tf.argmax(logits, 1), labels)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.name_scope('metrics'):
        m_loss, loss_update_op = tf.contrib.metrics.streaming_mean(
            loss, name='loss')
        m_accuracy, accuracy_update_op = tf.contrib.metrics.streaming_mean(
            accuracy, name='accuracy')

    # The streaming accumulators are local variables created under the
    # 'metrics' name scope; re-initializing them resets the running means.
    stream_vars = [v for v in tf.local_variables() if 'metrics' in v.name]
    reset_op = [tf.variables_initializer(stream_vars)]

    tf.summary.scalar('loss', m_loss, collections=['validation'])
    # NOTE(review): this logs the per-batch `accuracy`, while the loss
    # summary above logs the streaming mean. Possibly `m_accuracy` was
    # intended; left unchanged so existing summaries keep their meaning.
    tf.summary.scalar('accuracy', accuracy, collections=['validation'])

    if config['output']['trainable_variables_to_summary']:
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var, collections=['validation'])

    summary_op = tf.summary.merge_all(key='validation')
    update_op = tf.group(loss_update_op, accuracy_update_op)
    return m_loss, m_accuracy, summary_op, update_op, reset_op
Exemplo n.º 6
0
def evaluate(dataset):
    """Build the ResNet evaluation graph and run the eval loop.

    Args:
        dataset: forwarded to `image_processing.inputs`; its `num_classes()`
            is also used to size the classifier.
    """
    with tf.Graph().as_default():
        images, labels = image_processing.inputs(dataset)

        # One extra class: label 0 is reserved for an (unused) background.
        num_classes = dataset.num_classes() + 1

        # Choose the inference builder; the plain ResNet is used only when
        # the flag is exactly False.
        if FLAGS.bitpack is not False:
            logits, _ = inception.inference_resnet_bitpack_val(
                images, num_classes, for_training=False)
        else:
            logits, _ = inception.inference_resnet(images,
                                                   num_classes,
                                                   for_training=False)

        # The loss helper needs the static batch size of the image tensor.
        static_shape = images.get_shape().as_list()
        split_batch_size = static_shape[0]
        loss_terms = inception.loss_resnet(logits,
                                           labels,
                                           batch_size=split_batch_size)
        total_loss = tf.add_n(loss_terms)

        # Per-example booleans: is the label within the top-1 / top-5 logits?
        top_1_op = tf.nn.in_top_k(logits, labels, 1)
        top_5_op = tf.nn.in_top_k(logits, labels, 5)

        # Restore the moving-average versions of the learned variables.
        ema = tf.train.ExponentialMovingAverage(
            inception.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(ema.variables_to_restore())

        # Merge whatever summaries were registered while building the graph.
        summary_op = tf.summary.merge_all()

        graph_def = tf.get_default_graph().as_graph_def()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
                                               graph_def=graph_def)

        # Evaluate once, then either stop or wait and evaluate again.
        while True:
            _eval_once(saver, summary_writer, top_1_op, top_5_op,
                       summary_op, loss=total_loss)
            if not FLAGS.run_once:
                time.sleep(FLAGS.eval_interval_secs)
                continue
            break
Exemplo n.º 7
0
def evaluation():
    """Restore a saved model and print per-batch accuracy until input ends.

    Reads batches through `image_processing.inputs` for `FLAGS.num_epochs`
    epochs, restores sixteen named weight / bias variables from
    `FLAGS.model_path`, and prints the accuracy of every batch. The loop
    finishes when the input pipeline raises `tf.errors.OutOfRangeError`.

    Raises:
        IndexError: if one of the expected variables is not present in the
            graph's global-variables collection.
    """
    data_files_ = data_files(FLAGS.train_or_validation)
    images, labels = image_processing.inputs(data_files_,
                                             FLAGS.num_epochs,
                                             batch_size=FLAGS.batch_size)

    # Labels are one-hot encoded over 1000 classes and compared by argmax.
    labels = tf.one_hot(labels, 1000)
    logits = inference(images)
    # tf.argmax replaces the deprecated alias tf.arg_max.
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Map checkpoint names to the graph variables of the same name so the
    # saver restores exactly these tensors.
    variable_names = [
        'w1', 'b1', 'w2', 'b2', 'w3', 'b3', 'w4', 'b4', 'w5', 'b5', 'w_fc1',
        'b_fc1', 'w_fc2', 'b_fc2', 'w_output', 'b_output'
    ]
    global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    restore_map = {}
    for name in variable_names:
        restore_map[name] = [
            v for v in global_vars if v.name == name + ':0'
        ][0]
    saver = tf.train.Saver(restore_map)
    saver.restore(sess, FLAGS.model_path)

    try:
        step = 0
        start_time = time.time()
        while not coord.should_stop():
            start_batch = time.time()
            acc = sess.run(accuracy)
            duration = time.time() - start_batch
            # Format the message before printing. The original applied `%`
            # to the value returned by print(), which raises TypeError on
            # Python 3.
            print('Step %d | accuracy = %.2f (%.3f sec/batch)' %
                  (step, acc, duration))
            step += 1
    except tf.errors.OutOfRangeError:
        # Raised by the input queue once all epochs have been consumed.
        print('Done evaluating for %d epochs, %d steps, %.1f min.' %
              (FLAGS.num_epochs, step, (time.time() - start_time) / 60))
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()
Exemplo n.º 8
0
def evaluate(dataset):
    """Build the `Model_Graph` evaluation graph and run the eval loop.

    Args:
        dataset: forwarded to `image_processing.inputs`; its `num_classes()`
            is also used to size the classifier.
    """
    with tf.Graph().as_default():
        images, labels = image_processing.inputs(dataset)

        # One extra class: label 0 is reserved for an (unused) background.
        num_classes = dataset.num_classes() + 1

        # Build the inference model in evaluation mode.
        builder = Model_Graph(num_class=num_classes, is_training=False)
        net = builder._build_defaut_graph(images=images)

        # Per-example booleans: is the label within the top-1 / top-5 logits?
        top_1_op = tf.nn.in_top_k(net.logits, labels, 1)
        top_5_op = tf.nn.in_top_k(net.logits, labels, 5)

        # Restore the moving-average versions of the learned variables.
        ema = tf.train.ExponentialMovingAverage(FLAGS.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(ema.variables_to_restore())

        # Merge whatever summaries were registered while building the graph.
        summary_op = tf.summary.merge_all()

        graph_def = tf.get_default_graph().as_graph_def()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
                                               graph_def=graph_def)

        # Evaluate once, then either stop or wait and evaluate again.
        finished = False
        while not finished:
            _eval_once(saver, summary_writer, top_1_op, top_5_op,
                       summary_op, net, labels)
            finished = bool(FLAGS.run_once)
            if not finished:
                time.sleep(FLAGS.eval_interval_secs)
Exemplo n.º 9
0
def evaluate(dataset):
  """Build the Inception evaluation graph and run the eval loop.

  Args:
    dataset: forwarded to `image_processing.inputs`; its `num_classes()` is
      also used to size the classifier.
  """
  with tf.Graph().as_default():
    images, labels = image_processing.inputs(dataset)

    # One extra class: label 0 is reserved for an (unused) background.
    num_classes = dataset.num_classes() + 1

    # Logits from the inference model; the second return value is unused.
    logits, _ = inception.inference(images, num_classes)

    # Per-example booleans: is the label within the top-1 / top-5 logits?
    top_1_op = tf.nn.in_top_k(logits, labels, 1)
    top_5_op = tf.nn.in_top_k(logits, labels, 5)

    # Restore the moving-average versions of the learned variables.
    ema = tf.train.ExponentialMovingAverage(inception.MOVING_AVERAGE_DECAY)
    saver = tf.train.Saver(ema.variables_to_restore())

    # Merge whatever summaries were registered while building the graph.
    summary_op = tf.merge_all_summaries()

    graph_def = tf.get_default_graph().as_graph_def()
    summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir,
                                            graph_def=graph_def)

    # Evaluate once, then either stop or wait and evaluate again.
    while True:
      _eval_once(saver, summary_writer, top_1_op, top_5_op, summary_op)
      if not FLAGS.run_once:
        time.sleep(FLAGS.eval_interval_secs)
        continue
      break
Exemplo n.º 10
0
def evaluate(dataset, n):
    """Evaluate pose regression on one image view and dump the raw errors.

    Builds a fresh graph, restores the latest checkpoint listed in
    `FLAGS.checkpoint_dir`, and runs
    `ceil(FLAGS.num_examples / FLAGS.batch_size)` batches. It then prints,
    for each of the three label columns (pitch, yaw, roll), the mean
    absolute error and the fraction of samples whose absolute error is
    below 5. The signed per-sample errors are written to `error_<n>.txt`.
    If no checkpoint is found, a message is printed and the function returns.

    Args:
        dataset: forwarded unchanged to `image_processing.inputs`.
        n: view selector. 1-4 select an 85-wide slice starting at offset 0
            or 14 along axes 1 and 2 of the image batch; 5 uses the whole
            image. The selected view is resized to 32x32 before inference.

    Raises:
        ValueError: if `n` is not one of 1, 2, 3, 4, 5.
    """
    # Reject unknown views up front; previously this surfaced later as a
    # NameError on `eval_output`.
    if n not in (1, 2, 3, 4, 5):
        raise ValueError('n must be one of 1, 2, 3, 4 or 5, got %r' % (n,))

    with tf.Graph().as_default():
        # Get images and labels from the dataset.
        images, pitchs, yaws, rolls, names = image_processing.inputs(dataset)

        if n == 1:
            view = images[:, 0:85, 0:85, :]
        elif n == 2:
            view = images[:, 14:99, 0:85, :]
        elif n == 3:
            view = images[:, 0:85, 14:99, :]
        elif n == 4:
            view = images[:, 14:99, 14:99, :]
        else:
            # NOTE(review): the original also computed
            # images[:, 0:7, 92:99, :] here but never used it -- the resize
            # was applied to the full `images` tensor. That effective
            # behavior is kept; confirm whether a crop was intended.
            view = images
        view = tf.image.resize_bilinear(view, [32, 32], align_corners=False)
        eval_output = model.inference(view)

        # Stack the three angle vectors as columns: [pitch, yaw, roll].
        p = tf.expand_dims(pitchs, 1)
        y = tf.expand_dims(yaws, 1)
        r = tf.expand_dims(rolls, 1)
        labels = tf.concat([p, y, r], 1)

        # Signed error; absolute values are taken on the NumPy side.
        error_op = eval_output - labels

        saver = tf.train.Saver()

        # Build the summary operation based on the TF collection of Summaries.
        # Neither the op nor the writer is used further below.
        summary_op = tf.summary.merge_all()

        graph_def = tf.get_default_graph().as_graph_def()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
                                               graph_def=graph_def)

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return

            if os.path.isabs(ckpt.model_checkpoint_path):
                # Restores from checkpoint with absolute path.
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                # Restores from checkpoint with relative path.
                saver.restore(
                    sess,
                    os.path.join(FLAGS.checkpoint_dir,
                                 ckpt.model_checkpoint_path))

            # Assuming model_checkpoint_path looks something like:
            #   /my-favorite-path/imagenet_train/model.ckpt-0,
            # extract global_step from it.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]
            print('Successfully loaded model from %s at step=%s.' %
                  (ckpt.model_checkpoint_path, global_step))

            # Start the queue runners.
            coord = tf.train.Coordinator()
            threads = []
            try:
                for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                    threads.extend(
                        qr.create_threads(sess,
                                          coord=coord,
                                          daemon=True,
                                          start=True))

                num_iter = int(math.ceil(FLAGS.num_examples /
                                         FLAGS.batch_size))
                total_sample_count = num_iter * FLAGS.batch_size
                step = 0

                # Per-column accumulators. Explicit arrays avoid relying on
                # `list += ndarray` silently turning a list into an array.
                total_error = np.zeros(3)
                total_acc = np.zeros(3)

                # Signed error rows of every evaluated sample, in order.
                g_p_error = []

                print('%s: starting evaluation on (%s).' %
                      (datetime.now(), FLAGS.subset))
                start_time = time.time()
                while step < num_iter and not coord.should_stop():
                    error = sess.run(error_op)
                    abs_error = np.abs(error)

                    g_p_error.extend(error)

                    total_error += np.sum(abs_error, axis=0)
                    # A prediction counts as correct when its absolute
                    # error is strictly below 5.
                    total_acc += np.sum(abs_error < 5, axis=0)
                    step += 1
                    if step % 20 == 0:
                        duration = time.time() - start_time
                        sec_per_batch = duration / 20.0
                        examples_per_sec = FLAGS.batch_size / sec_per_batch
                        print(
                            '%s: [%d batches out of %d] (%.1f examples/sec; %.3f'
                            'sec/batch)' % (datetime.now(), step, num_iter,
                                            examples_per_sec, sec_per_batch))
                        start_time = time.time()

                # Average over every sample that was scheduled to run.
                mae = total_error / total_sample_count
                acc = total_acc / total_sample_count

                print(
                    '%s: pitch_mae = %.4f yaw_mae = %.4f roll_mae = %.4f pitch_acc = %.4f yaw_acc = %.4f roll_acc = %.4f [%d examples]'
                    % (datetime.now(), mae[0], mae[1], mae[2], acc[0], acc[1],
                       acc[2], total_sample_count))

                g_p = np.array(g_p_error)
                np.savetxt("error_" + str(n) + ".txt", g_p)

            except Exception as exc:  # pylint: disable=broad-except
                # Hand the failure to the coordinator; join() below
                # re-raises it after the threads have been asked to stop.
                coord.request_stop(exc)

            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=10)
Exemplo n.º 11
0
def train():
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Get images and labels for CIFAR-10.
        #dataset = CIFARData(subset='train')
        dataset = ImagenetData(subset='train')
        assert dataset.data_files()

        #test_set = CIFARData(subset='validation')
        test_set = ImagenetData(subset='validation')
        assert test_set.data_files()

        epoch1 = .5 * helper.MAX_EPOCHS
        epoch2 = .75 * helper.MAX_EPOCHS
        step1 = dataset.num_examples_per_epoch() * epoch1 // (
            helper.BATCH_SIZE)
        step2 = dataset.num_examples_per_epoch() * epoch2 // (
            helper.BATCH_SIZE)
        print('Reducing learning rate at step ' + str(step1) + ' and step ' +
              str(step2) + ' and ending at ' + str(helper.MAX_STEPS))

        # Create a variable to count the number of train() calls. This equals the
        # number of batches processed * FLAGS.num_gpus.
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        # Learning rate
        lr = .1

        #learning_rate = tf.placeholder(tf.float32, shape=[], name='learning_rate')
        dropout = tf.placeholder(tf.float32, shape=[], name='dropout')
        is_training = tf.placeholder(tf.bool, shape=[], name='is_training')

        boundaries = [step1, step2]
        values = [lr, lr / 10, lr / 100]

        learning_rate = tf.train.piecewise_constant(global_step,
                                                    boundaries,
                                                    values,
                                                    name=None)

        decayed_lr = tf.train.polynomial_decay(lr,
                                               global_step,
                                               helper.MAX_STEPS,
                                               end_learning_rate=0.0001,
                                               power=4.0,
                                               cycle=False,
                                               name=None)

        # Create an optimizer that performs gradient descent.
        with tf.name_scope('Optimizer'):
            opt = tf.train.MomentumOptimizer(learning_rate=decayed_lr,
                                             momentum=0.9,
                                             use_nesterov=True)
            #opt = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9, use_nesterov=True)

        tf.summary.scalar('decayed_learning_rate', decayed_lr)
        tf.summary.scalar('learning_rate', learning_rate)

        # Override the number of preprocessing threads to account for the increased
        # number of GPU towers.
        num_preprocess_threads = helper.NUM_THREADS * helper.N_GPUS
        distorted_images, distorted_labels = image_processing.distorted_inputs(
            dataset,
            batch_size=helper.SPLIT_BATCH_SIZE,
            num_preprocess_threads=num_preprocess_threads)

        #images, labels = image_processing.inputs(dataset, batch_size=helper.BATCH_SIZE, num_preprocess_threads=num_preprocess_threads)
        test_images, test_labels = image_processing.inputs(
            test_set,
            batch_size=helper.SPLIT_BATCH_SIZE,
            num_preprocess_threads=num_preprocess_threads)

        input_summaries = copy.copy(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Split the batch of images and labels for towers.
        #images_splits = tf.split(axis=0, num_or_size_splits=helper.N_GPUS, value=distorted_images)
        #labels_splits = tf.split(axis=0, num_or_size_splits=helper.N_GPUS, value=distorted_labels)

        batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue(
            [distorted_images, distorted_labels], capacity=2 * helper.N_GPUS)

        # Calculate the gradients for each model tower.
        tower_grads = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(helper.N_GPUS):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' %
                                       (helper.TOWER_NAME, i)) as scope:
                        # Calculate the loss for one tower of the CIFAR model. This function
                        # constructs the entire CIFAR model but shares the variables across
                        # all towers.
                        image_batch, label_batch = batch_queue.dequeue()
                        loss = tower_loss(scope,
                                          image_batch,
                                          label_batch,
                                          dropout=dropout,
                                          is_training=is_training)
                        #loss = tower_loss(scope, images_splits[i], labels_splits[i], dropout=dropout, is_training=is_training)

                        # Retain the summaries from the final tower.
                        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                      scope)

                        tf.get_variable_scope().reuse_variables()

                        grads = opt.compute_gradients(loss)

                        tower_grads.append(grads)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = average_gradients(tower_grads)

        # Add a summaries for the input processing and global_step.
        summaries.extend(input_summaries)

        # Apply the gradients to adjust the shared variables.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            apply_gradient_op = opt.apply_gradients(grads,
                                                    global_step=global_step)

            # Track the moving averages of all trainable variables.
            variable_averages = tf.train.ExponentialMovingAverage(
                helper.MOVING_AVERAGE_DECAY, global_step)
            variables_averages_op = variable_averages.apply(
                tf.trainable_variables())

            # Group all updates to into a single train op.
            #train_op = apply_gradient_op
            train_op = tf.group(apply_gradient_op, variables_averages_op)

        # Add histograms for trainable variables.
        #for var in tf.trainable_variables():
        #    summaries.append(tf.summary.histogram(var.op.name, var))

        for grad, var in grads:
            summaries.append(tf.summary.histogram(var.op.name, var))
            #summaries.append(tf.summary.histogram(var.op.name + '_gradient', grad))

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        cross_entropy_op = tf.reduce_mean(tf.get_collection('cross_entropies'),
                                          name='cross_entropy')

        accuracy_op = tf.reduce_mean(tf.get_collection('accuracy'),
                                     name='accuracies')
        summaries.append(tf.summary.scalar('cross_entropy', cross_entropy_op))
        summaries.append(tf.summary.scalar('accuracy', accuracy_op))

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge(summaries)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False))

        #run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        #run_metadata = tf.RunMetadata()

        sess.run(init)
        tf.train.start_queue_runners(sess=sess)

        if RESTORE == True:
            ckpt = tf.train.get_checkpoint_state(SAVE_POINT)
            saver.restore(sess, ckpt.model_checkpoint_path)

            # Assuming model_checkpoint_path looks something like:
            #   /my-favorite-path/imagenet_train/model.ckpt-0,
            # extract global_step from it.
            restored_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]
            print('Successfully loaded model from %s at step=%s.' %
                  (ckpt.model_checkpoint_path, restored_step))
            step = int(restored_step)
            range_step = range(step, helper.MAX_STEPS)
            tf.get_variable_scope().reuse_variables()
            global_step = tf.get_variable('global_step', trainable=False)
        else:
            range_step = range(helper.MAX_STEPS)

        summary_writer = tf.summary.FileWriter('summary', graph=sess.graph)
        num_params = helper.count_params() / 1e6
        print('Total number of params = %.2fM' % num_params)
        print("training")
        top1_error = [-1.0, -1.0]
        top1_step = 0
        top5_error = [-1.0, -1.0]
        top5_step = 0

        for step in range_step:

            start_time = time.time()
            _, loss_value, cross_entropy_value, accuracy_value = sess.run(
                [train_op, loss, cross_entropy_op, accuracy_op],
                feed_dict={
                    dropout: 0.8,
                    is_training: True
                }
            )  #, options=run_options, run_metadata=run_metadata)#, learning_rate: lr})
            duration = time.time() - start_time

            if step == step1 or step == step2:
                print('Decreasing Learning Rate')
                lr /= 10

            if step % 10 == 0:
                num_examples_per_step = helper.BATCH_SIZE
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = duration

                format_str = (
                    'step %d, loss = %.2f, cross entropy = %.2f, accuracy = %.2f, %.3f sec/batch'
                )
                print(format_str % (step, loss_value, cross_entropy_value,
                                    accuracy_value, sec_per_batch))
                """
                # Create the Timeline object, and write it to a json
                tl = timeline.Timeline(run_metadata.step_stats)
                ctf = tl.generate_chrome_trace_format()
                with open('timeline.json', 'w') as f:
                    f.write(ctf)
                """

            if step % 100 == 0:
                summary_str = sess.run(summary_op,
                                       feed_dict={
                                           dropout: 0.8,
                                           is_training: False
                                       })  #, learning_rate: lr})
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % 5000 == 0 or (step + 1) == helper.MAX_STEPS:
                if step != 0:
                    checkpoint_path = SAVE_POINT + 'model.ckpt'
                    saver.save(sess, checkpoint_path, global_step=step)
                    print('Model saved')

                    #evaluate(distorted_images, distorted_labels, sess, dropout=dropout, is_training=is_training, train=True)
                    top1, top5 = evaluate(test_images,
                                          test_labels,
                                          sess,
                                          dropout=dropout,
                                          is_training=is_training,
                                          train=False)
                    if top1 > top1_error[0]:
                        top1_error[0] = top1
                        top1_error[1] = top5
                        top1_step = step
                    if top5 > top5_error[1]:
                        top5_error[0] = top1
                        top5_error[1] = top5
                        top5_step = step
                    print(
                        "Best top1 model achieved top1: %.4f, top5: %.4f at step %d"
                        % (top1_error[0], top1_error[1], top1_step))
                    print(
                        "Best top5 model achieved top1: %.4f, top5: %.4f at step %d"
                        % (top5_error[0], top5_error[1], top5_step))