def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    testImg, testlabels = cifar10.inputs(eval_data=True)
    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)
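    # Build a second inference graph on the held-out test batch; test=True is
    # assumed to switch this modified cifar10.inference into evaluation mode.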
    test_pre = cifar10.inference(testImg, test=True)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()
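    # Note: tf.all_variables, tf.initialize_all_variables, tf.merge_all_summaries
    # and tf.train.SummaryWriter are the pre-1.0 names of tf.global_variables,
    # tf.global_variables_initializer, tf.summary.merge_all and tf.summary.FileWriter.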

    # Start running operations on the Graph.
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      if step % 10 == 0:
        print('loss ' + str(loss_value))

      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 10 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)

      # Evaluate on the held-out test batch periodically.
      if step % 10 == 0:
        cifar10.accuracy(test_pre, testlabels)
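        # NOTE: if cifar10.accuracy() builds a graph op, it would be better created
        # once outside the training loop and evaluated here with sess.run().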
Example 2
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():

    global_step = tf.Variable(0, trainable=False)
    images, labels = cifar10.distorted_inputs()
    logits = cifar10.inference(images)
    loss = cifar10.loss(logits, labels)
    # loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels))
    # train_op = tf.train.GradientDescentOptimizer(1e-2).minimize(loss)
    train_op = cifar10.train(loss, global_step)
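    # Boolean vector, True where the top-1 prediction matches the label; summed
    # below to keep a running count of correct predictions on the training batches.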
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    saver = tf.train.Saver(tf.all_variables())
    init = tf.initialize_all_variables()
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)
    tf.train.start_queue_runners(sess=sess)

    true_count = 0
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value, precisions = sess.run([train_op, loss, top_k_op])

      true_count += np.sum(precisions)

      if step % 10 == 0:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
        duration = time.time() - start_time
        print('step %d, loss = %.3f, acc = %.3f, dur = %.2f' %
              (step, loss_value, true_count / float(FLAGS.batch_size * 10), duration))
        true_count = 0
Example 3
def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables())

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

        for step in xrange(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), "Model diverged with loss = NaN"

            if step % 10 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = "%s: step %d, loss = %.2f (%.1f examples/sec; %.3f " "sec/batch)"
                print(format_str % (datetime.now(), step, loss_value, examples_per_sec, sec_per_batch))

            if step % 100 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, "model.ckpt")
                saver.save(sess, checkpoint_path, global_step=step)
Example 4
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
    # GPU and resulting in a slow down.
    with tf.device('/cpu:0'):
      images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))
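
    # MonitoredTrainingSession initializes variables, starts queue runners and
    # saves checkpoints/summaries to checkpoint_dir itself, so no explicit
    # Saver/init/summary boilerplate is needed in this version.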

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1

      def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step % 10 == 0:
          num_examples_per_step = FLAGS.batch_size
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))
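          # Expose the most recently logged step as a module-level variable,
          # presumably read elsewhere in the surrounding script.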
          global step_no 
          step_no = self._step

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement, inter_op_parallelism_threads=4,intra_op_parallelism_threads=0)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
        """run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
Example 6
def train():
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        images, labels = cifar10.distorted_inputs()

        logits = cifar10.inference(images)

        loss = cifar10.loss(logits, labels)

        train_op = cifar10.train(loss, global_step)

        saver = tf.train.Saver(tf.all_variables())

        summary_op = tf.merge_all_summaries()

        init = tf.initialize_all_variables()

        sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, graph_def=sess.graph_def)

        for step in xrange(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), "Model diverged with loss = NaN"

            if step % 10 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = "%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)"
                print(format_str % (datetime.now(), step, loss_value, examples_per_sec, sec_per_batch))

            if step % 100 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, "model.ckpt")
                saver.save(sess, checkpoint_path, global_step=step)
Example 7
def train():

    # Communication defines
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
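    # Data-parallel layout: rank 0 acts as the master and periodically evaluates
    # the model, while the other ranks compute gradients on shuffled copies of
    # the training data; models and gradients are exchanged over MPI below.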

    try:
        addr = socket.gethostbyname(socket.gethostname())
        print("Worker %d with address %s" % (rank, str(addr)))
    except socket.error:
        pass

    # Load data set
    images_train_raw, labels_train_raw, images_test_raw, labels_test_raw = cifar10_input.load_cifar_data_raw(
        rank)
    if rank != 0:
        random_permutation = np.random.permutation(images_train_raw.shape[0])
        images_train_raw = images_train_raw[random_permutation]
        labels_train_raw = labels_train_raw[random_permutation]

    # Basic model creation for cuda convnet
    scope_name = "parameters_1"
    with tf.variable_scope(scope_name):
        images = tf.placeholder(tf.float32,
                                shape=(None, cifar10.IMAGE_SIZE,
                                       cifar10.IMAGE_SIZE,
                                       cifar10.NUM_CHANNELS))
        labels = tf.placeholder(tf.int32, shape=(None, ))
        logits = cifar10.inference(images)
        loss_op = cifar10.loss(logits, labels, scope_name)
        train_op, grads_and_vars, opt = cifar10.train(loss_op, scope_name)
        top_k_op = tf.nn.in_top_k(logits, labels, 1)

    model_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                        scope=scope_name)
    model_variables_placeholders = [
        tf.placeholder(dtype=x.dtype, shape=x.get_shape())
        for x in model_variables
    ]
    model_variables_assign = [
        tf.assign(model_variables[i], model_variables_placeholders[i])
        for i in range(len(model_variables))
    ]

    apply_gradients_placeholders = [
        tf.placeholder(dtype=grad.dtype, shape=grad.get_shape())
        for grad, var in grads_and_vars
    ]
    apply_gradients_op = opt.apply_gradients(
        zip(apply_gradients_placeholders,
            [var for grad, var in grads_and_vars]))

    with tf.Session() as sess:

        tf.initialize_all_variables().run()
        tf.train.start_queue_runners(sess=sess)

        get_feed_dict.fractional_dataset_index = 0
        n_examples_processed = 0
        iteration = 0
        eval_iteration_interval = int(
            cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / (FLAGS.batch_size *
                                                        (size - 1)))
        evaluate_times = []
        t_start = time.time()

        sync_variables_times = 0
        accumulate_gradients_times = 0
        compute_times = 0
        previous_accuracy = 0

        for i in range(FLAGS.n_iterations):

            cur_epoch = n_examples_processed / cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN

            if cur_epoch >= FLAGS.n_epochs:
                break

            # Synchronize model
            t_synchronize_start = time.time()
            synchronize_model(sess, model_variables, comm, rank,
                              model_variables_assign,
                              model_variables_placeholders)
            t_synchronize_end = time.time()
            sync_variables_times += t_synchronize_end - t_synchronize_start

            if rank == 0 and iteration % 100 == 0:
                print("Epoch: %f" % (cur_epoch))

            if rank == 0:
                mean_sync = sync_variables_times / (iteration + 1)
                mean_compute = compute_times / (iteration + 1)
                mean_acc_gradients = accumulate_gradients_times / (iteration +
                                                                   1)
                print("Mean sync time: %f" % mean_sync)
                print("Mean compute time: %f" % mean_compute)
                print("Mean acc gradients time: %f" % mean_acc_gradients)

            if iteration % eval_iteration_interval == 0:

                # Evaluate on master
                if rank == 0 and iteration != 0:
                    print("Master evaluating...")
                    acc_total, loss_total = 0, 0
                    evaluate_t_start = time.time()
                    for eval_idx in range(0, cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN,
                                          FLAGS.evaluate_batchsize):
                        print("%d of %d" %
                              (eval_idx, cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN))
                        fd = get_feed_dict(FLAGS.evaluate_batchsize,
                                           images_train_raw, labels_train_raw,
                                           images, labels)
                        acc_p, loss_p = sess.run([top_k_op, loss_op],
                                                 feed_dict=fd)
                        acc_total += np.sum(acc_p)
                        loss_total += loss_p
                    evaluate_t_end = time.time()
                    evaluate_times.append(evaluate_t_end - evaluate_t_start)
                    acc_total /= cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
                    loss_total /= cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
                    previous_accuracy = acc_total
                    print("Epoch: %f, Time: %f, Accuracy: %f, Loss: %f" %
                          (cur_epoch, time.time() - sum(evaluate_times) -
                           t_start, acc_total, loss_total))

                    if acc_total >= FLAGS.accuracy_to_reach:
                        break
                comm.Barrier()

            # Perform distributed gradient descent
            t_compute_start = time.time()
            materialized_gradients = []
            if rank != 0:
                fd = get_feed_dict(FLAGS.batch_size, images_train_raw,
                                   labels_train_raw, images, labels)
                materialized_gradients = sess.run(
                    [x[0] for x in grads_and_vars], feed_dict=fd)

                # Save gradients on a particular worker
                if rank == 1 and FLAGS.save_gradient_magnitude_histogram:
                    if iteration == 0:
                        print("Plotting gradient magnitude histograms")
                        plt.cla()
                        figs, axes = plt.subplots(nrows=2,
                                                  ncols=5,
                                                  figsize=(15 * 3, 15))
                        axes = axes.flatten()
                        for i, (gradient, variable) in enumerate(
                                zip(materialized_gradients,
                                    [x[1] for x in grads_and_vars])):
                            vname = variable.name.replace("/", "_")
                            name = "iteration_%d_gradient_%s" % (iteration,
                                                                 vname)
                            title = "Layer %s" % vname
                            magnitudes = [
                                abs(x) for x in list(gradient.flatten())
                            ]
                            axes[i].hist(magnitudes, bins='auto')
                            axes[i].set_title(title, fontsize=30)
                        figs.tight_layout()
                        plt.savefig(
                            "SparsifyHistogramOfGradientMagnitudes.png")
                        print("Done!")

            comm.Barrier()
            t_compute_end = time.time()
            compute_times += t_compute_end - t_compute_start

            t_accumulate_gradients_start = time.time()
            aggregate_and_apply_gradients(sess, model_variables, comm, rank,
                                          size, materialized_gradients,
                                          apply_gradients_placeholders,
                                          apply_gradients_op,
                                          previous_accuracy)
            t_accumulate_gradients_end = time.time()
            accumulate_gradients_times += t_accumulate_gradients_end - t_accumulate_gradients_start

            n_examples_processed += (size - 1) * FLAGS.batch_size
            iteration += 1
def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from one of the two
        # inference models.
        if tfFLAGS.network == 1:
            logits, fc1_w, fc1_b, fc2_w, fc2_b = MyModel.inference(images)
        else:
            logits, fc1_w, fc1_b, fc2_w, fc2_b = MyModel2.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # L2 regularization for the fully connected parameters.
        regularizers = (tf.nn.l2_loss(fc1_w) + tf.nn.l2_loss(fc1_b) +
                        tf.nn.l2_loss(fc2_w) + tf.nn.l2_loss(fc2_b))

        # Add the regularization term to the loss.
        loss += 5e-4 * regularizers

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)    # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % tfFLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = tfFLAGS.log_frequency * tfFLAGS.batch_size / duration
                    sec_per_batch = float(duration / tfFLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                                'sec/batch)')
                    print_(format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch))
        
        texts = ['conv1:', 'conv1Biases:', 'conv2:', 'conv2Biases:', 'local3:',
                 'local3Biases:', 'local4:', 'local4Biases:', 'softmax:', 'softmaxBiases:']
        total_parameters = 0
        count = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():
                variable_parameters *= dim.value
            print('Number of hidden parameters of ' + texts[count], variable_parameters)
            total_parameters += variable_parameters
            count += 1
        print('Total Number of hidden parameters:', total_parameters)
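
        # device_count={'GPU': 0} hides all GPUs from the session, so this
        # version trains on the CPU only.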

        with tf.train.MonitoredTrainingSession(checkpoint_dir=tfFLAGS.train_dir,
                hooks=[tf.train.StopAtStepHook(last_step=tfFLAGS.max_steps), tf.train.NanTensorHook(loss),_LoggerHook()],
                config=tf.ConfigProto( device_count = {'GPU': 0}, log_device_placement=tfFLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
Example 9
def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default() as g:

        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        #labels=labels+aa
        #for ii in range(10):
        #  print(images)
        #  print(labels)

        noise = tf.constant(np.zeros(128), tf.int32)
        fakelabels = (labels + noise) % 10
        # To insert noise
        '''fakerate=0.1
    fakenum=int(fakerate*128)
    selectindex=rd.sample(range(128),fakenum)
    noisearray=np.zeros(128)
    for i in range(fakenum):
      noisearray[selectindex[i]]=int(10*rd.random())
    tf.assign(noise,noisearray)'''

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, fakelabels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        #init = tf.global_variables_initializer()

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_checkpoint_secs=10,  # save model every 10 seconds
                config=tf.ConfigProto(log_device_placement=FLAGS.
                                      log_device_placement)) as mon_sess:
            #mon_sess.run(init)
            while not mon_sess.should_stop():
                # To insert noise
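                # Shift the labels of a random half of the batch by 1-9 (mod 10),
                # i.e. corrupt them to a different class via `fakelabels` above.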
                fakerate = 0.5
                batch_size = 128
                fakenum = int(fakerate * batch_size)
                selectindex = rd.sample(range(batch_size), fakenum)
                noisearray = np.zeros(batch_size)
                aa = mon_sess.run(labels)
                for i in range(fakenum):
                    noisearray[selectindex[i]] = int(9 * rd.random()) + 1

                mon_sess.run(train_op, {noise: noisearray})
Example 10
def train(cifar10_data, epochs, L, learning_rate, scale3, Delta2, epsilon2,
          eps2_ratio, alpha, perturbFM, fgsm_eps, total_eps, logfile):
    """Train CIFAR-10 for a number of steps."""
    logfile.write("fgsm_eps \t %g, LR \t %g, alpha \t %d , epsilon \t %d \n" %
                  (fgsm_eps, learning_rate, alpha, total_eps))
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        eps_benign = 1 / (1 + eps2_ratio) * (epsilon2)
        eps_adv = eps2_ratio / (1 + eps2_ratio) * (epsilon2)

        # Parameters Declarification
        #with tf.variable_scope('conv1') as scope:
        kernel1 = _variable_with_weight_decay(
            'kernel1',
            shape=[4, 4, 3, 128],
            stddev=np.sqrt(2.0 / (5 * 5 * 256)) / math.ceil(5 / 2),
            wd=0.0,
            collect=[AECODER_VARIABLES])
        biases1 = _bias_on_cpu('biases1', [128],
                               tf.constant_initializer(0.0),
                               collect=[AECODER_VARIABLES])

        shape = kernel1.get_shape().as_list()
        w_t = tf.reshape(kernel1, [-1, shape[-1]])
        w = tf.transpose(w_t)
        sing_vals = tf.svd(w, compute_uv=False)
        sensitivity = tf.reduce_max(sing_vals)
        gamma = 2 * Delta2 / (L * sensitivity
                              )  #2*3*(14*14 + 2)*16/(L*sensitivity)

        #with tf.variable_scope('conv2') as scope:
        kernel2 = _variable_with_weight_decay(
            'kernel2',
            shape=[5, 5, 128, 128],
            stddev=np.sqrt(2.0 / (5 * 5 * 256)) / math.ceil(5 / 2),
            wd=0.0,
            collect=[CONV_VARIABLES])
        biases2 = _bias_on_cpu('biases2', [128],
                               tf.constant_initializer(0.1),
                               collect=[CONV_VARIABLES])
        #with tf.variable_scope('conv3') as scope:
        kernel3 = _variable_with_weight_decay(
            'kernel3',
            shape=[5, 5, 256, 256],
            stddev=np.sqrt(2.0 / (5 * 5 * 256)) / math.ceil(5 / 2),
            wd=0.0,
            collect=[CONV_VARIABLES])
        biases3 = _bias_on_cpu('biases3', [256],
                               tf.constant_initializer(0.1),
                               collect=[CONV_VARIABLES])
        #with tf.variable_scope('local4') as scope:
        kernel4 = _variable_with_weight_decay(
            'kernel4',
            shape=[int(image_size / 4)**2 * 256, hk],
            stddev=0.04,
            wd=0.004,
            collect=[CONV_VARIABLES])
        biases4 = _bias_on_cpu('biases4', [hk],
                               tf.constant_initializer(0.1),
                               collect=[CONV_VARIABLES])
        #with tf.variable_scope('local5') as scope:
        kernel5 = _variable_with_weight_decay(
            'kernel5', [hk, 10],
            stddev=np.sqrt(2.0 /
                           (int(image_size / 4)**2 * 256)) / math.ceil(5 / 2),
            wd=0.0,
            collect=[CONV_VARIABLES])
        biases5 = _bias_on_cpu('biases5', [10],
                               tf.constant_initializer(0.1),
                               collect=[CONV_VARIABLES])

        #scale2 = tf.Variable(tf.ones([hk]))
        #beta2 = tf.Variable(tf.zeros([hk]))

        params = [
            kernel1, biases1, kernel2, biases2, kernel3, biases3, kernel4,
            biases4, kernel5, biases5
        ]
        ########

        # Build a Graph that computes the logits predictions from the
        # inference model.
        FM_h = tf.placeholder(tf.float32, [None, 14, 14, 128])
        noise = tf.placeholder(tf.float32, [None, image_size, image_size, 3])
        adv_noise = tf.placeholder(tf.float32,
                                   [None, image_size, image_size, 3])

        x = tf.placeholder(tf.float32, [None, image_size, image_size, 3])
        adv_x = tf.placeholder(tf.float32, [None, image_size, image_size, 3])

        # Auto-Encoder #
        Enc_Layer2 = EncLayer(inpt=adv_x,
                              n_filter_in=3,
                              n_filter_out=128,
                              filter_size=3,
                              W=kernel1,
                              b=biases1,
                              activation=tf.nn.relu)
        pretrain_adv = Enc_Layer2.get_train_ops2(xShape=tf.shape(adv_x)[0],
                                                 Delta=Delta2,
                                                 epsilon=epsilon2,
                                                 batch_size=L,
                                                 learning_rate=learning_rate,
                                                 W=kernel1,
                                                 b=biases1,
                                                 perturbFMx=adv_noise,
                                                 perturbFM_h=FM_h)
        Enc_Layer3 = EncLayer(inpt=x,
                              n_filter_in=3,
                              n_filter_out=128,
                              filter_size=3,
                              W=kernel1,
                              b=biases1,
                              activation=tf.nn.relu)
        pretrain_benign = Enc_Layer3.get_train_ops2(
            xShape=tf.shape(x)[0],
            Delta=Delta2,
            epsilon=epsilon2,
            batch_size=L,
            learning_rate=learning_rate,
            W=kernel1,
            b=biases1,
            perturbFMx=noise,
            perturbFM_h=FM_h)
        cost = tf.reduce_sum((Enc_Layer2.cost + Enc_Layer3.cost) / 2.0)
        ###

        x_image = x + noise
        y_conv = inference(x_image, FM_h, params)
        softmax_y_conv = tf.nn.softmax(y_conv)
        y_ = tf.placeholder(tf.float32, [None, 10])

        adv_x_image = adv_x + adv_noise
        y_adv_conv = inference(adv_x_image, FM_h, params)
        adv_y_ = tf.placeholder(tf.float32, [None, 10])

        # Calculate loss. Apply Taylor Expansion for the output layer
        perturbW = perturbFM * params[8]
        loss = cifar10.TaylorExp(y_conv, y_, y_adv_conv, adv_y_, L, alpha,
                                 perturbW)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        #pretrain_step = tf.train.AdamOptimizer(1e-4).minimize(pretrain_adv, global_step=global_step, var_list=[kernel1, biases1]);
        pretrain_var_list = tf.get_collection(AECODER_VARIABLES)
        train_var_list = tf.get_collection(CONV_VARIABLES)
        #print(pretrain_var_list)
        #print(train_var_list)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            pretrain_step = tf.train.AdamOptimizer(learning_rate).minimize(
                pretrain_adv + pretrain_benign,
                global_step=global_step,
                var_list=pretrain_var_list)
            train_op = cifar10.train(loss,
                                     global_step,
                                     learning_rate,
                                     _var_list=train_var_list)
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))

        sess.run(kernel1.initializer)
        dp_epsilon = 0.005
        _gamma = sess.run(gamma)
        _gamma_x = Delta2 / L
        epsilon2_update = epsilon2 / (1.0 + 1.0 / _gamma + 1 / _gamma_x)
        print(epsilon2_update / _gamma + epsilon2_update / _gamma_x)
        print(epsilon2_update)
        delta_r = fgsm_eps * (image_size**2)
        _sensitivityW = sess.run(sensitivity)
        delta_h = _sensitivityW * (14**2)
        #delta_h = 1.0 * delta_r; #sensitivity*(14**2) = sensitivity*(\beta**2) can also be used
        #dp_mult = (Delta2/(L*epsilon2))/(delta_r / dp_epsilon) + (2*Delta2/(L*epsilon2))/(delta_h / dp_epsilon)
        #dp_mult = (Delta2/(L*epsilon2_update))/(delta_r / dp_epsilon) + (2*Delta2/(L*epsilon2_update))/(delta_h / dp_epsilon)
        dp_mult = (Delta2) / (L * epsilon2_update * (delta_h / 2 + delta_r))

        dynamic_eps = tf.placeholder(tf.float32)
        """y_test = inference(x, FM_h, params)
    softmax_y = tf.nn.softmax(y_test);
    c_x_adv = fgsm(x, softmax_y, eps=dynamic_eps/3, clip_min=-1.0, clip_max=1.0)
    x_adv = tf.reshape(c_x_adv, [L, image_size, image_size, 3])"""

        attack_switch = {
            'fgsm': True,
            'ifgsm': True,
            'deepfool': False,
            'mim': True,
            'spsa': False,
            'cwl2': False,
            'madry': True,
            'stm': False
        }

        ch_model_probs = CustomCallableModelWrapper(
            callable_fn=inference_test_input_probs,
            output_layer='probs',
            params=params,
            image_size=image_size,
            adv_noise=adv_noise)

        # define each attack method's tensor
        mu_alpha = tf.placeholder(tf.float32, [1])
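        # mu_alpha feeds the per-step attack strength (eps) into each attack graph.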
        attack_tensor_dict = {}
        # FastGradientMethod
        if attack_switch['fgsm']:
            print('creating attack tensor of FastGradientMethod')
            fgsm_obj = FastGradientMethod(model=ch_model_probs, sess=sess)
            #x_adv_test_fgsm = fgsm_obj.generate(x=x, eps=fgsm_eps, clip_min=-1.0, clip_max=1.0, ord=2) # testing now
            x_adv_test_fgsm = fgsm_obj.generate(x=x,
                                                eps=mu_alpha,
                                                clip_min=-1.0,
                                                clip_max=1.0)  # testing now
            attack_tensor_dict['fgsm'] = x_adv_test_fgsm

        # Iterative FGSM (BasicIterativeMethod/ProjectedGradientMethod with no random init)
        # default: eps_iter=0.05, nb_iter=10
        if attack_switch['ifgsm']:
            print('creating attack tensor of BasicIterativeMethod')
            ifgsm_obj = BasicIterativeMethod(model=ch_model_probs, sess=sess)
            #x_adv_test_ifgsm = ifgsm_obj.generate(x=x, eps=fgsm_eps, eps_iter=fgsm_eps/10, nb_iter=10, clip_min=-1.0, clip_max=1.0, ord=2)
            x_adv_test_ifgsm = ifgsm_obj.generate(x=x,
                                                  eps=mu_alpha,
                                                  eps_iter=fgsm_eps / 3,
                                                  nb_iter=3,
                                                  clip_min=-1.0,
                                                  clip_max=1.0)
            attack_tensor_dict['ifgsm'] = x_adv_test_ifgsm

        # MomentumIterativeMethod
        # default: eps_iter=0.06, nb_iter=10
        if attack_switch['mim']:
            print('creating attack tensor of MomentumIterativeMethod')
            mim_obj = MomentumIterativeMethod(model=ch_model_probs, sess=sess)
            #x_adv_test_mim = mim_obj.generate(x=x, eps=fgsm_eps, eps_iter=fgsm_eps/10, nb_iter=10, decay_factor=1.0, clip_min=-1.0, clip_max=1.0, ord=2)
            x_adv_test_mim = mim_obj.generate(x=x,
                                              eps=mu_alpha,
                                              eps_iter=fgsm_eps / 3,
                                              nb_iter=3,
                                              decay_factor=1.0,
                                              clip_min=-1.0,
                                              clip_max=1.0)
            attack_tensor_dict['mim'] = x_adv_test_mim

        # MadryEtAl (Projected Grdient with random init, same as rand+fgsm)
        # default: eps_iter=0.01, nb_iter=40
        if attack_switch['madry']:
            print('creating attack tensor of MadryEtAl')
            madry_obj = MadryEtAl(model=ch_model_probs, sess=sess)
            #x_adv_test_madry = madry_obj.generate(x=x, eps=fgsm_eps, eps_iter=fgsm_eps/10, nb_iter=10, clip_min=-1.0, clip_max=1.0, ord=2)
            x_adv_test_madry = madry_obj.generate(x=x,
                                                  eps=mu_alpha,
                                                  eps_iter=fgsm_eps / 3,
                                                  nb_iter=3,
                                                  clip_min=-1.0,
                                                  clip_max=1.0)
            attack_tensor_dict['madry'] = x_adv_test_madry

        #====================== attack =========================

        #adv_logits, _ = inference(c_x_adv + W_conv1Noise, perturbFM, params)

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables())

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()
        sess.run(init)

        # Start the queue runners.
        #tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.summary.FileWriter(os.getcwd() + dirCheckpoint,
                                               sess.graph)

        # load the most recent models
        _global_step = 0
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            _global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        else:
            print('No checkpoint file found')

        T = int(int(math.ceil(D / L)) * epochs + 1)  # number of steps
        step_for_epoch = int(math.ceil(D / L))
        #number of steps for one epoch

        perturbH_test = np.random.laplace(0.0, 0, 14 * 14 * 128)
        perturbH_test = np.reshape(perturbH_test, [-1, 14, 14, 128])

        #W_conv1Noise = np.random.laplace(0.0, Delta2/(L*epsilon2), 32 * 32 * 3).astype(np.float32)
        #W_conv1Noise = np.reshape(_W_conv1Noise, [32, 32, 3])

        perturbFM_h = np.random.laplace(0.0,
                                        2 * Delta2 / (epsilon2_update * L),
                                        14 * 14 * 128)
        perturbFM_h = np.reshape(perturbFM_h, [-1, 14, 14, 128])

        #_W_adv = np.random.laplace(0.0, 0, 32 * 32 * 3).astype(np.float32)
        #_W_adv = np.reshape(_W_adv, [32, 32, 3])
        #_perturbFM_h_adv = np.random.laplace(0.0, 0, 10*10*128)
        #_perturbFM_h_adv = np.reshape(_perturbFM_h_adv, [10, 10, 128]);

        test_size = len(cifar10_data.test.images)
        #beta = redistributeNoise(os.getcwd() + '/LRP_0_25_v12.txt')
        #BenignLNoise = generateIdLMNoise(image_size, Delta2, eps_benign, L) #generateNoise(image_size, Delta2, eps_benign, L, beta);
        #AdvLnoise = generateIdLMNoise(image_size, Delta2, eps_adv, L)
        Noise = generateIdLMNoise(image_size, Delta2, epsilon2_update, L)
        #generateNoise(image_size, Delta2, eps_adv, L, beta);
        Noise_test = generateIdLMNoise(
            image_size, 0, epsilon2_update,
            L)  #generateNoise(image_size, 0, 2*epsilon2, test_size, beta);

        emsemble_L = int(L / 3)
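        # Each adversarial training batch of size L is assembled from three attacks
        # (I-FGSM, MIM, Madry), each run on a fresh batch of emsemble_L clean examples.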
        preT_epochs = 100
        pre_T = int(int(math.ceil(D / L)) * preT_epochs + 1)
        """logfile.write("pretrain: \n")
    for step in range(_global_step, _global_step + pre_T):
        d_eps = random.random()*0.5;
        batch = cifar10_data.train.next_batch(L); #Get a random batch.
        adv_images = sess.run(x_adv, feed_dict = {x: batch[0], dynamic_eps: d_eps, FM_h: perturbH_test})
        for iter in range(0, 2):
            adv_images = sess.run(x_adv, feed_dict = {x: adv_images, dynamic_eps: d_eps, FM_h: perturbH_test})
        #sess.run(pretrain_step, feed_dict = {x: batch[0], noise: AdvLnoise, FM_h: perturbFM_h});
        batch = cifar10_data.train.next_batch(L);
        sess.run(pretrain_step, feed_dict = {x: np.append(batch[0], adv_images, axis = 0), noise: Noise, FM_h: perturbFM_h});
        if step % int(25*step_for_epoch) == 0:
            cost_value = sess.run(cost, feed_dict={x: cifar10_data.test.images, noise: Noise_test, FM_h: perturbH_test})/(test_size*128)
            logfile.write("step \t %d \t %g \n"%(step, cost_value))
            print(cost_value)
    print('pre_train finished')"""

        _global_step = 0
        for step in xrange(_global_step, _global_step + T):
            start_time = time.time()
            d_eps = random.random() * 0.5
            batch = cifar10_data.train.next_batch(emsemble_L)
            #Get a random batch.
            y_adv_batch = batch[1]
            """adv_images = sess.run(x_adv, feed_dict = {x: batch[0], dynamic_eps: d_eps, FM_h: perturbH_test})
      for iter in range(0, 2):
          adv_images = sess.run(x_adv, feed_dict = {x: adv_images, dynamic_eps: d_eps, FM_h: perturbH_test})"""
            adv_images_ifgsm = sess.run(attack_tensor_dict['ifgsm'],
                                        feed_dict={
                                            x: batch[0],
                                            adv_noise: Noise,
                                            mu_alpha: [d_eps]
                                        })
            batch = cifar10_data.train.next_batch(emsemble_L)
            y_adv_batch = np.append(y_adv_batch, batch[1], axis=0)
            adv_images_mim = sess.run(attack_tensor_dict['mim'],
                                      feed_dict={
                                          x: batch[0],
                                          adv_noise: Noise,
                                          mu_alpha: [d_eps]
                                      })
            batch = cifar10_data.train.next_batch(emsemble_L)
            y_adv_batch = np.append(y_adv_batch, batch[1], axis=0)
            adv_images_madry = sess.run(attack_tensor_dict['madry'],
                                        feed_dict={
                                            x: batch[0],
                                            adv_noise: Noise,
                                            mu_alpha: [d_eps]
                                        })
            adv_images = np.append(np.append(adv_images_ifgsm,
                                             adv_images_mim,
                                             axis=0),
                                   adv_images_madry,
                                   axis=0)

            batch = cifar10_data.train.next_batch(L)
            #Get a random batch.

            sess.run(pretrain_step,
                     feed_dict={
                         x: batch[0],
                         adv_x: adv_images,
                         adv_noise: Noise_test,
                         noise: Noise,
                         FM_h: perturbFM_h
                     })
            _, loss_value = sess.run(
                [train_op, loss],
                feed_dict={
                    x: batch[0],
                    y_: batch[1],
                    adv_x: adv_images,
                    adv_y_: y_adv_batch,
                    noise: Noise,
                    adv_noise: Noise_test,
                    FM_h: perturbFM_h
                })
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            # report the result periodically
            if step % (50 * step_for_epoch) == 0 and step >= (300 *
                                                              step_for_epoch):
                '''predictions_form_argmax = np.zeros([test_size, 10])
          softmax_predictions = sess.run(softmax_y_conv, feed_dict={x: cifar10_data.test.images, noise: Noise_test, FM_h: perturbH_test})
          argmax_predictions = np.argmax(softmax_predictions, axis=1)
          """for n_draws in range(0, 2000):
            _BenignLNoise = generateIdLMNoise(image_size, Delta2, epsilon2, L)
            _perturbFM_h = np.random.laplace(0.0, 2*Delta2/(epsilon2*L), 14*14*128)
            _perturbFM_h = np.reshape(_perturbFM_h, [-1, 14, 14, 128]);"""
          for j in range(test_size):
            pred = argmax_predictions[j]
            predictions_form_argmax[j, pred] += 2000;
          """softmax_predictions = sess.run(softmax_y_conv, feed_dict={x: cifar10_data.test.images, noise: _BenignLNoise, FM_h: _perturbFM_h})
            argmax_predictions = np.argmax(softmax_predictions, axis=1)"""
          final_predictions = predictions_form_argmax;
          is_correct = []
          is_robust = []
          for j in range(test_size):
              is_correct.append(np.argmax(cifar10_data.test.labels[j]) == np.argmax(final_predictions[j]))
              robustness_from_argmax = robustness.robustness_size_argmax(counts=predictions_form_argmax[j],eta=0.05,dp_attack_size=fgsm_eps, dp_epsilon=1.0, dp_delta=0.05, dp_mechanism='laplace') * dp_mult
              is_robust.append(robustness_from_argmax >= fgsm_eps)
          acc = np.sum(is_correct)*1.0/test_size
          robust_acc = np.sum([a and b for a,b in zip(is_robust, is_correct)])*1.0/np.sum(is_robust)
          robust_utility = np.sum(is_robust)*1.0/test_size
          log_str = "step: {:.1f}\t epsilon: {:.1f}\t benign: {:.4f} \t {:.4f} \t {:.4f} \t {:.4f} \t".format(step, total_eps, acc, robust_acc, robust_utility, robust_acc*robust_utility)'''

                #===================adv samples=====================
                log_str = "step: {:.1f}\t epsilon: {:.1f}\t".format(
                    step, total_eps)
                """adv_images_dict = {}
          for atk in attack_switch.keys():
              if attack_switch[atk]:
                  adv_images_dict[atk] = sess.run(attack_tensor_dict[atk], feed_dict ={x:cifar10_data.test.images})
          print("Done with the generating of Adversarial samples")"""
                #===================adv samples=====================
                adv_acc_dict = {}
                robust_adv_acc_dict = {}
                robust_adv_utility_dict = {}
                test_bach_size = 5000
                for atk in attack_switch.keys():
                    print(atk)
                    if atk not in adv_acc_dict:
                        adv_acc_dict[atk] = -1
                        robust_adv_acc_dict[atk] = -1
                        robust_adv_utility_dict[atk] = -1
                    if attack_switch[atk]:
                        test_bach = cifar10_data.test.next_batch(
                            test_bach_size)
                        adv_images_dict = sess.run(attack_tensor_dict[atk],
                                                   feed_dict={
                                                       x: test_bach[0],
                                                       adv_noise: Noise_test,
                                                       mu_alpha: [fgsm_eps]
                                                   })
                        print("Done adversarial examples")
                        ### PixelDP Robustness ###
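                        # Monte-Carlo certification: repeat the noisy forward pass
                        # 1000 times, accumulate per-example class counts, and derive
                        # a certified robustness size from those counts below.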
                        predictions_form_argmax = np.zeros(
                            [test_bach_size, 10])
                        softmax_predictions = sess.run(softmax_y_conv,
                                                       feed_dict={
                                                           x: adv_images_dict,
                                                           noise: Noise,
                                                           FM_h: perturbFM_h
                                                       })
                        argmax_predictions = np.argmax(softmax_predictions,
                                                       axis=1)
                        for n_draws in range(0, 1000):
                            _BenignLNoise = generateIdLMNoise(
                                image_size, Delta2, epsilon2_update, L)
                            _perturbFM_h = np.random.laplace(
                                0.0, 2 * Delta2 / (epsilon2_update * L),
                                14 * 14 * 128)
                            _perturbFM_h = np.reshape(_perturbFM_h,
                                                      [-1, 14, 14, 128])
                            if n_draws == 500:
                                print("n_draws = 500")
                            for j in range(test_bach_size):
                                pred = argmax_predictions[j]
                                predictions_form_argmax[j, pred] += 1
                            softmax_predictions = sess.run(
                                softmax_y_conv,
                                feed_dict={
                                    x: adv_images_dict,
                                    noise: (_BenignLNoise / 10 + Noise),
                                    FM_h: perturbFM_h
                                }) * sess.run(
                                    softmax_y_conv,
                                    feed_dict={
                                        x: adv_images_dict,
                                        noise: Noise,
                                        FM_h: (_perturbFM_h / 10 + perturbFM_h)
                                    })
                            #softmax_predictions = sess.run(softmax_y_conv, feed_dict={x: adv_images_dict, noise: (_BenignLNoise), FM_h: perturbFM_h}) * sess.run(softmax_y_conv, feed_dict={x: adv_images_dict, noise: Noise, FM_h: (_perturbFM_h)})
                            argmax_predictions = np.argmax(softmax_predictions,
                                                           axis=1)
                        final_predictions = predictions_form_argmax
                        is_correct = []
                        is_robust = []
                        for j in range(test_bach_size):
                            is_correct.append(
                                np.argmax(test_bach[1][j]) == np.argmax(
                                    final_predictions[j]))
                            robustness_from_argmax = robustness.robustness_size_argmax(
                                counts=predictions_form_argmax[j],
                                eta=0.05,
                                dp_attack_size=fgsm_eps,
                                dp_epsilon=dp_epsilon,
                                dp_delta=0.05,
                                dp_mechanism='laplace') * dp_mult
                            is_robust.append(
                                robustness_from_argmax >= fgsm_eps)
                        adv_acc_dict[atk] = np.sum(
                            is_correct) * 1.0 / test_bach_size
                        robust_adv_acc_dict[atk] = np.sum([
                            a and b for a, b in zip(is_robust, is_correct)
                        ]) * 1.0 / np.sum(is_robust)
                        robust_adv_utility_dict[atk] = np.sum(
                            is_robust) * 1.0 / test_bach_size
                        ##############################
                for atk in attack_switch.keys():
                    if attack_switch[atk]:
                        # added robust prediction
                        log_str += " {}: {:.4f} {:.4f} {:.4f} {:.4f}".format(
                            atk, adv_acc_dict[atk], robust_adv_acc_dict[atk],
                            robust_adv_utility_dict[atk],
                            robust_adv_acc_dict[atk] *
                            robust_adv_utility_dict[atk])
                print(log_str)
                logfile.write(log_str + '\n')

            # Save the model checkpoint periodically.
            if step % (10 * step_for_epoch) == 0 and (step > _global_step):
                num_examples_per_step = L
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))
            """if step % (50*step_for_epoch) == 0 and (step >= 900*step_for_epoch):
Example 11
def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()
                self._temp_time = time.time()
                self._temp_step = -1

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                ###############################################################
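                # Roughly once a second, append (timestamp, steps since the last
                # row) to a CSV via `writer` and `env`, which are presumably
                # defined at module level.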
                if time.time() - self._temp_time > 1:
                    writer.writerow([
                        '%f' % (time.time()), 1,
                        '%f' % (self._step - self._temp_step), 1, env
                    ])
                    self._temp_time = time.time()
                    self._temp_step = self._step
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(log_device_placement=FLAGS.
                                      log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
Example no. 12
def train(graph=tf.Graph()):
    dt = _prepare_checkpoint_and_flow(graph)
    """Train CIFAR-10 for a number of steps."""
    with graph.as_default():
        with graph.device('/cpu:0'):
            global_step = tf.contrib.framework.get_or_create_global_step()

            # Get images and labels for CIFAR-10.
            images, labels = cifar10.distorted_inputs()

        logits = cifar10.inference(dt,
                                   images,
                                   is_compressed=FLAGS.checkpoint_version > 0)
        dt.flow.cost = dict()  # empty cost

        # Calculate loss.
        loss = cifar10.loss(dt, logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            train_op, variable_averages = cifar10.train(
                dt, loss, global_step, l1_penalty=FLAGS.l1_penalty)

        # Restore the moving average version of the learned variables for eval.
        variables_to_restore = variable_averages.variables_to_restore()

        # Get all bind-to variables (not using moving averages)
        var_gamma = []
        var_beta = []
        var_weights = []
        var_biases = []
        for name, v in variables_to_restore.items():
            if re.match(r'.*gamma/.*', name):
                var_gamma.append(v)
            if re.match(r'.*beta/.*', name):
                var_beta.append(v)
            if re.match(r'.*weights/.*', name):
                var_weights.append(v)
            if re.match(r'.*biases/.*', name):
                var_biases.append(v)

        group_lasso, sparsity = dt.regularization(var_gamma)
        reset_global_step = tf.assign(global_step, tf.zeros_like(global_step))
        prune_op = dt.prune(tf.get_collection('is prune'))
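        # dt.regularization returns a group-lasso penalty over the batch-norm
        # gammas plus a sparsity estimate (logged as "fake sparsity" below);
        # prune_op is run once training ends, and reset_global_step restarts
        # the counter before the compressed model's checkpoint is written.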

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def after_create_session(self, sess, coord):
                pass

            def before_run(self, run_context):
                self._step += 1
                if self._step % FLAGS.log_frequency == 0:
                    return tf.train.SessionRunArgs(
                        [loss, group_lasso, sparsity])  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results[0]
                    group_lasso = run_values.results[1]
                    sparsity = run_values.results[2]
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f, group lasso = %.2f, fake sparsity = %.2f, (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str %
                          (datetime.now(), self._step, loss_value, group_lasso,
                           sparsity, examples_per_sec, sec_per_batch))

            def end(self, sess):
                sess.run(prune_op)
                # save flow graph
                dt.flow.prune()
                dt.flow.print_summary()
                dt.save_flow(
                    _get_train_dir(dt.checkpoint_version) +
                    '/pruned_flow_graph.pkl')

        config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)
        config.gpu_options.allow_growth = True
        #config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=_get_train_dir(dt.checkpoint_version),
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=config) as mon_sess:

            # training and sparisifying
            while not mon_sess.should_stop():
                mon_sess.run(train_op)

        # get compressed network
        if not FLAGS.squeeze_model:
            return tf.get_default_graph()
        else:
            dt.is_compressed = True
            dt.cached_flow = dt.flow
            dt.flow = NeuralFlow()

        with tf.variable_scope('compressed') as scope:
            logits_c = cifar10.inference(dt, images, is_compressed=True)
            loss = cifar10.loss(dt, logits_c, labels)

        all_var = variables_to_restore.values()
        all_trainable_var_compressed = [
            v for v in tf.trainable_variables()
            if v.op.name.startswith('compressed')
        ]
        _ = variable_averages.apply(all_trainable_var_compressed)
        all_var_compressed = [
            v for v in tf.global_variables()
            if v.op.name.startswith('compressed')
        ]

        #print(all_var)
        #print(all_var_compressed)
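        # Transformer.compress_to is expected to return assign ops that copy
        # the surviving (unpruned) parameters from the original variables into
        # the smaller variables under the 'compressed/' scope.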
        assign_ops, assign_ops2 = Transformer.compress_to(
            dt, all_var, all_var_compressed)

        #print("Variables to restore from checkpoints: ", variables_to_restore)
        saver = tf.train.Saver(variables_to_restore)
        ckpt = tf.train.get_checkpoint_state(
            _get_train_dir(dt.checkpoint_version))

        variables_to_save = {}
        for v in all_var_compressed:
            vname = v.op.name.replace('compressed/', '')
            variables_to_save[vname] = v

        variables_to_save[global_step.op.name] = global_step
        saver_c = tf.train.Saver(variables_to_save)

        sess = tf.Session()
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            sess.run(dt.shadow[logits.op.name])  # rebuild betas
            sess.run(assign_ops)
            sess.run(assign_ops2)
            sess.run(reset_global_step)
        saver_c.save(
            sess,
            _get_train_dir(dt.checkpoint_version + 1) + '/model.ckpt-0')
        sess.close()

        g = tf.get_default_graph()
        return g
Example no. 13
def train_simult():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)
        tl_global_step = tf.Variable(0, trainable=False)
        tlm_global_step = tf.Variable(0, trainable=False)
        tlms_global_step = tf.Variable(0, trainable=False)
        tm_global_step = tf.Variable(0, trainable=False)
        ts_global_step = tf.Variable(0, trainable=False)
        l_global_step = tf.Variable(0, trainable=False)
        m_global_step = tf.Variable(0, trainable=False)
        s_global_step = tf.Variable(0, trainable=False)
        #    tlm_l_global_step = tf.Variable(0, trainable=False)
        #    tlms_l_global_step = tf.Variable(0, trainable=False)
        #    m_l_global_step = tf.Variable(0, trainable=False)
        #    s_l_global_step = tf.Variable(0, trainable=False)

        # Get images and labels for CIFAR-10.
        with tf.device('/cpu:0'):
            with tf.variable_scope('train') as scope:
                images, labels = cifar10.distorted_inputs()
            with tf.variable_scope('eval') as scope:
                images_ev, labels_ev = cifar10.inputs(eval_data='test')

        with tf.variable_scope('model') as scope:
            # Build a Graph that computes the logits predictions from the
            # inference model.
            logits = cifar10.inference(images)
            targets = cifar10.mix(cifar10.multinomial(logits), labels)
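            # Soft targets for the student models: class labels sampled from
            # the teacher's predictive distribution, mixed with the true
            # labels by cifar10.mix.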

            # Calculate loss.
            loss = cifar10.loss(logits, labels)

            # Compute logits and calculate predictions for validation error
            scope.reuse_variables()
            logits_ev = cifar10.inference(images_ev)
            top_k_op = tf.nn.in_top_k(logits_ev, labels_ev, 1)

            # Build a graph that trains the model with one batch of examples
            # and updates the model parameters.
            train_op = cifar10.train(loss, global_step)

        with tf.variable_scope('t.l') as scope:
            # Student graph that computes the logits predictions from the
            # inference model.
            tl_logits = cifar10.inference(images)
            tl_targets = cifar10.mix(cifar10.multinomial(tl_logits), labels)

            # Calculate loss according to multinomial sampled target predictions,
            # equally weighted against original loss with labels.
            #lg_loss = tf.add(loss, cifar10.loss(lg_logits, targets))
            tl_loss = cifar10.loss(tl_logits, targets)

            scope.reuse_variables()
            tl_logits_ev = cifar10.inference(images_ev)
            tl_top_k_op = tf.nn.in_top_k(tl_logits_ev, labels_ev, 1)
            tl_train_op = cifar10.train(tl_loss, tl_global_step)

        with tf.variable_scope('t.l.m') as scope:
            tlm_logits = cifar10.inference_vars(images, 48, 48, 192, 96)
            tlm_targets = cifar10.mix(cifar10.multinomial(tlm_logits), labels)
            tlm_loss = cifar10.loss(tlm_logits, tl_targets)
            scope.reuse_variables()
            tlm_logits_ev = cifar10.inference_vars(images_ev, 48, 48, 192, 96)
            tlm_top_k_op = tf.nn.in_top_k(tlm_logits_ev, labels_ev, 1)
            tlm_train_op = cifar10.train(tlm_loss, tlm_global_step)

        with tf.variable_scope('t.l.m.s') as scope:
            tlms_logits = cifar10.inference_vars(images, 32, 32, 96, 48)
            tlms_loss = cifar10.loss(tlms_logits, tlm_targets)
            scope.reuse_variables()
            tlms_logits_ev = cifar10.inference_vars(images_ev, 32, 32, 96, 48)
            tlms_top_k_op = tf.nn.in_top_k(tlms_logits_ev, labels_ev, 1)
            tlms_train_op = cifar10.train(tlms_loss, tlms_global_step)

        with tf.variable_scope('t.m') as scope:
            tm_logits = cifar10.inference_vars(images, 48, 48, 192, 96)
            tm_loss = cifar10.loss(tm_logits, targets)
            scope.reuse_variables()
            tm_logits_ev = cifar10.inference_vars(images_ev, 48, 48, 192, 96)
            tm_top_k_op = tf.nn.in_top_k(tm_logits_ev, labels_ev, 1)
            tm_train_op = cifar10.train(tm_loss, tm_global_step)

        with tf.variable_scope('t.s') as scope:
            ts_logits = cifar10.inference_vars(images, 32, 32, 96, 48)
            ts_loss = cifar10.loss(ts_logits, targets)
            scope.reuse_variables()
            ts_logits_ev = cifar10.inference_vars(images_ev, 32, 32, 96, 48)
            ts_top_k_op = tf.nn.in_top_k(ts_logits_ev, labels_ev, 1)
            ts_train_op = cifar10.train(ts_loss, ts_global_step)

        # Large sized model trained on labels
        with tf.variable_scope('l') as scope:
            l_logits = cifar10.inference_vars(images)
            l_loss = cifar10.loss(l_logits, labels)
            scope.reuse_variables()
            l_logits_ev = cifar10.inference_vars(images_ev)
            l_top_k_op = tf.nn.in_top_k(l_logits_ev, labels_ev, 1)
            l_train_op = cifar10.train(l_loss, l_global_step)

        # Medium sized model trained on labels
        with tf.variable_scope('m') as scope:
            m_logits = cifar10.inference_vars(images, 48, 48, 192, 96)
            m_loss = cifar10.loss(m_logits, labels)
            scope.reuse_variables()
            m_logits_ev = cifar10.inference_vars(images_ev, 48, 48, 192, 96)
            m_top_k_op = tf.nn.in_top_k(m_logits_ev, labels_ev, 1)
            m_train_op = cifar10.train(m_loss, m_global_step)

        # Small sized model trained on labels
        with tf.variable_scope('s') as scope:
            s_logits = cifar10.inference_vars(images, 32, 32, 96, 48)
            s_loss = cifar10.loss(s_logits, labels)
            scope.reuse_variables()
            s_logits_ev = cifar10.inference_vars(images_ev, 32, 32, 96, 48)
            s_top_k_op = tf.nn.in_top_k(s_logits_ev, labels_ev, 1)
            s_train_op = cifar10.train(s_loss, s_global_step)

        # Medium sized model trained on large model, delayed start
        # with tf.variable_scope('t.l.m_late') as scope:
        #     tlm_l_logits = cifar10.inference_vars(images, 48, 48, 192, 96)
        #     tlm_l_loss = cifar10.loss(tlm_l_logits, tl_targets)
        #     scope.reuse_variables()
        #     tlm_l_logits_ev = cifar10.inference_vars(images_ev, 48, 48, 192, 96)
        #     tlm_l_top_k_op = tf.nn.in_top_k(tlm_l_logits_ev, labels_ev, 1)
        #     tlm_l_train_op = cifar10.train(tlm_l_loss, tlm_l_global_step)

        # Small sized model trained on medium model, delayed start
        # with tf.variable_scope('t.l.m.s_late') as scope:
        #     tlms_l_logits = cifar10.inference_vars(images, 32, 32, 96, 48)
        #     tlms_l_loss = cifar10.loss(tlms_l_logits, tlm_targets)
        #     scope.reuse_variables()
        #     tlms_l_logits_ev = cifar10.inference_vars(images_ev, 32, 32, 96, 48)
        #     tlms_l_top_k_op = tf.nn.in_top_k(tlms_l_logits_ev, labels_ev, 1)
        #     tlms_l_train_op = cifar10.train(tlms_l_loss, tlms_l_global_step)

        # Medium sized model trained on labels, delayed start
        # with tf.variable_scope('m_late') as scope:
        #     m_l_logits = cifar10.inference_vars(images, 48, 48, 192, 96)
        #     m_l_loss = cifar10.loss(m_l_logits, labels)
        #     scope.reuse_variables()
        #     m_l_logits_ev = cifar10.inference_vars(images_ev, 48, 48, 192, 96)
        #     m_l_top_k_op = tf.nn.in_top_k(m_l_logits_ev, labels_ev, 1)
        #     m_l_train_op = cifar10.train(m_l_loss, m_l_global_step)

        # Small sized model trained on labels, delayed start
        # with tf.variable_scope('s_late') as scope:
        #     s_l_logits = cifar10.inference_vars(images, 32, 32, 96, 48)
        #     s_l_loss = cifar10.loss(s_l_logits, labels)
        #     scope.reuse_variables()
        #     s_l_logits_ev = cifar10.inference_vars(images_ev, 32, 32, 96, 48)
        #     s_l_top_k_op = tf.nn.in_top_k(s_l_logits_ev, labels_ev, 1)
        #     s_l_train_op = cifar10.train(s_l_loss, s_l_global_step)

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables())

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.merge_all_summaries()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

        num_examples = 10000
        num_iter = int(np.floor(num_examples / FLAGS.batch_size))
        total_sample_count = (num_iter) * FLAGS.batch_size
        print(num_iter, total_sample_count)

        accuracy = []
        losses = []

        #    if True:
        #      for step in range(3):
        #        M1, M2 = sess.run([mask, mask_neg])
        #        print(np.array([M1, M2]).T)
        #      for step in range(10000):
        #        V0, V1, V2 = sess.run([images_ev[0,0,0,0], images_ev[1,0,0,0], images_ev[2,0,0,0]])
        #        print(step, V0, V1, V2)

        sm_val = 0.0
        md_val = 0.0
        sms_val = 0.0
        mds_val = 0.0
        for step in xrange(FLAGS.max_steps):
            start_time = time.time()
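            # Staged schedule: the base and label-trained models run from the
            # start, the t.l.m student joins after 5000 steps, and the
            # t.l.m.s student after 10000 steps.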
            if step < 5000:
                (_, loss_value, _l, l_loss_val, _m, m_loss_val, _s, s_loss_val,
                 _tl, tl_loss_val, _tm, tm_loss_val, _ts,
                 ts_loss_val) = sess.run([
                     train_op, loss, l_train_op, l_loss, m_train_op, m_loss,
                     s_train_op, s_loss, tl_train_op, tl_loss, tm_train_op,
                     tm_loss, ts_train_op, ts_loss
                 ])
            elif step < 10000:
                (_, loss_value, _l, l_loss_val, _m, m_loss_val, _s, s_loss_val,
                 _tl, tl_loss_val, _tm, tm_loss_val, _ts, ts_loss_val, _tlm,
                 tlm_loss_val) = sess.run([
                     train_op, loss, l_train_op, l_loss, m_train_op, m_loss,
                     s_train_op, s_loss, tl_train_op, tl_loss, tm_train_op,
                     tm_loss, ts_train_op, ts_loss, tlm_train_op, tlm_loss
                 ])
            else:
                (_, loss_value, _l, l_loss_val, _m, m_loss_val, _s, s_loss_val,
                 _tl, tl_loss_val, _tm, tm_loss_val, _ts, ts_loss_val, _tlm,
                 tlm_loss_val, _tlms, tlms_loss_val) = sess.run([
                     train_op, loss, l_train_op, l_loss, m_train_op, m_loss,
                     s_train_op, s_loss, tl_train_op, tl_loss, tm_train_op,
                     tm_loss, ts_train_op, ts_loss, tlm_train_op, tlm_loss,
                     tlms_train_op, tlms_loss
                 ])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            #losses.append(np.array([loss_value, lg_loss_value, md_val, sm_val]))

            if step % 10 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = (
                    '%s: step %d, loss = %.2f, lg_loss = %.2f, md_loss = %.2f, '
                    'sm_loss = %.2f, (%.1f examples/sec; %.3f sec/batch)')
                print(
                    format_str %
                    (datetime.now(), step, loss_value, tl_loss_val, m_loss_val,
                     s_loss_val, examples_per_sec, sec_per_batch))

            if step % 500 == 0:
                #summary_str = sess.run(summary_op)
                #summary_writer.add_summary(summary_str, step)
                true_count = np.zeros(9)
                for eval_step in xrange(num_iter):
                    predictions = sess.run([
                        top_k_op, l_top_k_op, m_top_k_op, s_top_k_op,
                        tl_top_k_op, tm_top_k_op, ts_top_k_op, tlm_top_k_op,
                        tlms_top_k_op
                    ])
                    predictions = np.array(predictions)
                    true_count += np.sum(predictions, axis=1)
                    #eval_step += 1

                precision = true_count / total_sample_count
                print(precision)
                accuracy.append(precision)

            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir,
                                               'model_student.ckpt')
                saver.save(sess,
                           checkpoint_path,
                           global_step=step,
                           latest_filename='checkpoint_student')

                eval_history = np.array(accuracy).T
                #loss_history = np.array(losses).T
                np.save(
                    os.path.join(FLAGS.train_dir,
                                 'eval_history%s' % date.today()),
                    eval_history)
Example no. 14
def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        st_global_step = tf.Variable(0, trainable=False)

        images = tf.placeholder(tf.float32,
                                shape=(None, IMAGE_HEIGHT, IMAGE_WIDTH,
                                       IMAGE_DEPTH))
        logits = tf.placeholder(tf.float32, shape=(None, NUM_CLASSES))
        labels = tf.placeholder(tf.int32, shape=None)
        targets = cifar10.multinomial(logits)
        #targets = cifar10.multinomial(logits, labels)
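        # The student is trained from stored teacher outputs: `images` and the
        # teacher `logits` are fed from the .npz files loaded below, and the
        # targets are class labels sampled from those logits.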

        with tf.variable_scope('student') as s_scope:
            # Build a Graph that computes the logits predictions from the
            # inference model.
            st_logits = cifar10.inference(images)

            # Calculate loss.
            st_loss = cifar10.loss(st_logits, targets)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        st_train_op = cifar10.train(st_loss, st_global_step)

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables())

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        images_path = os.path.join(FLAGS.data_dir, 'img.npz')
        logits_path = os.path.join(FLAGS.train_dir, 'log.npz')

        if not tf.gfile.Exists(images_path):
            raise ValueError('Failed to find file: ' + images_path)
        if not tf.gfile.Exists(logits_path):
            raise ValueError('Failed to find file: ' + logits_path)

        with np.load(images_path) as data:
            images_set = data['images_set']
            print('images_set shape type ', images_set.shape, images_set.dtype)
        with np.load(logits_path) as data:
            logits_set = data['logits_set']

        data_set = Dataset(images_set, logits_set)

        for step in xrange(FLAGS.max_steps):
            start_time = time.time()
            #      feed_dict = fill_feed_dict(data_set, images, targets)
            feed_dict = fill_feed_dict(data_set, images, logits)
            _, st_loss_value = sess.run([st_train_op, st_loss],
                                        feed_dict=feed_dict)
            duration = time.time() - start_time

            assert not np.isnan(
                st_loss_value), 'Model diverged with loss = NaN'

            if step % 10 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = (
                    '%s: step %d, st_loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, st_loss_value,
                                    examples_per_sec, sec_per_batch))

            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir,
                                               'model_student.ckpt')
                saver.save(sess,
                           checkpoint_path,
                           global_step=step,
                           latest_filename='checkpoint_student')
Example no. 15
    #                                                   batch_size=3)
    global_step = tf.contrib.framework.get_or_create_global_step()

    y_pred = cifar10.inference(img_batch)
    # label_batch = tf.cast(label_batch, tf.float32)
    # loss = tf.nn.l2_loss(y_pred-label_batch)
    label_batch = tf.cast(label_batch, tf.int64)
    label_batch = tf.reshape(label_batch, [3])
    # label_batch = tf.reduce_sum(label_batch, axis=-1)
    # loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    #     logits=y_pred, labels=label_batch, name='cross_entropy_per_example')
    loss = cifar10.loss(y_pred, label_batch)
    # loss_mean = tf.reduce_mean(loss)

    # train_op = tf.train.AdamOptimizer().minimize(loss)
    train_op = cifar10.train(loss, global_step)
    # Initialize all ops.
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        # Start the input queue runners.
        threads = tf.train.start_queue_runners(sess=sess)
        for i in range(0, 3):
            # val, l = sess.run([img_batch, label_batch])
            # Fetch into a new name so the `y_pred` tensor is not overwritten
            # by its numpy value on the first iteration.
            _, y_pred_val = sess.run([train_op, y_pred])
            #l = to_categorical(l, 12)
            # print(val.shape, l)
            print(y_pred_val)
            # print(loss_val)
            print("Worked!!!")
Example no. 16
def train():
  """Train CIFAR-10 for a number of steps."""

  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)

  if FLAGS.job_name == "ps":
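    # Parameter-server tasks only host variables: server.join() blocks
    # forever, while worker tasks fall through to build and run the training
    # graph below.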
    server.join()
  elif FLAGS.job_name == "worker":

    # Assigns ops to the local worker by default.
    with tf.Graph().as_default(), tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):
      global_step = tf.contrib.framework.get_or_create_global_step()
      # Get images and labels for CIFAR-10.
      images, labels = cifar10.distorted_inputs()

      # Build a Graph that computes the logits predictions from the
      # inference model.
      logits = cifar10.inference(images)

      # Calculate loss.
      loss = cifar10.loss(logits, labels)

      # Build a Graph that trains the model with one batch of examples and
      # updates the model parameters.
      train_op = cifar10.train(loss, global_step)

      class _LoggerHook(tf.train.SessionRunHook):
        """Logs loss and runtime."""

        def begin(self):
          self._step = -1

        def before_run(self, run_context):
          self._step += 1
          self._start_time = time.time()
          return tf.train.SessionRunArgs(loss)  # Asks for loss value.

        def after_run(self, run_context, run_values):
          duration = time.time() - self._start_time
          loss_value = run_values.results
          if self._step % 10 == 0:
            num_examples_per_step = FLAGS.batch_size
            examples_per_sec = num_examples_per_step / duration
            sec_per_batch = float(duration)

            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print (format_str % (datetime.now(), self._step, loss_value,
                                 examples_per_sec, sec_per_batch))

      with tf.train.MonitoredTrainingSession(master=server.target,
          is_chief=(FLAGS.task_index == 0),
          checkpoint_dir=FLAGS.train_dir,
          hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                 tf.train.NanTensorHook(loss),
                 _LoggerHook()],
          config=tf.ConfigProto(
              log_device_placement=FLAGS.log_device_placement)) as mon_sess:
        while not mon_sess.should_stop():
          mon_sess.run(train_op)
def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()
        #eval_data = FLAGS.eval_data == 'test'

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # Create a saver.
        #saver = tf.train.Saver(tf.all_variables())
        saver = tf.train.Saver(tf.global_variables())

        # Build the summary operation based on the TF collection of Summaries.
        #summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        #summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

        #summary_writer =  tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
        # Create the exports and sessions repositories if they don't exist
        export_folder_name = 'exports - loss(' + FLAGS.loss + ') le ' + time.strftime(
            "%d-%m-%Y à %H:%M")
        session_folder_name = 'sessions - loss(' + FLAGS.loss + ') le ' + time.strftime(
            "%d-%m-%Y à %H:%M")
        make_sure_path_exists(export_folder_name)
        make_sure_path_exists(session_folder_name)

        # Initialize two different csv files
        test_csv_file = export_folder_name + '/test.csv'
        init_test_csv(test_csv_file)

        tps1 = time.time()

        for step in xrange(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 250 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = (
                    '%s: step %d, loss (%s) = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step, FLAGS.loss, loss_value,
                       examples_per_sec, sec_per_batch))

            # if step % 100 == 0:
            #   summary_str = sess.run(summary_op)
            #   summary_writer.add_summary(summary_str, step)
            #tf.get_variable_scope().reuse_variables()

            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

                with tf.Graph().as_default() as g:
                    # Get images and labels for CIFAR-10.
                    #eval_data = FLAGS.eval_data == 'test'
                    images_eval, labels_eval = cifar10.inputs(eval_data=True)

                    # Build a Graph that computes the logits predictions from the
                    # inference model.
                    logits_eval = cifar10.inference(images_eval)

                    # Calculate predictions.
                    top_k_op = tf.nn.in_top_k(logits_eval, labels_eval, 1)
                    variable_averages = tf.train.ExponentialMovingAverage(
                        cifar10.MOVING_AVERAGE_DECAY)
                    variables_to_restore = variable_averages.variables_to_restore()
                    # Use a separate saver for the eval graph so the training
                    # saver defined above is not shadowed.
                    eval_saver = tf.train.Saver(variables_to_restore)
                    precision = eval_once(eval_saver, top_k_op)
                    nb_epochs = int((step + 1) * 128 / 50000)
                    csv_writerow(test_csv_file,
                                 [FLAGS.loss] + [step] + [nb_epochs] +
                                 [precision] + [time.time() - tps1])
                # Snapshot the training session outside the eval graph, using
                # the training saver.
                saver.save(
                    sess, session_folder_name +
                    '/Session-Iteration-%d-epoch-%d' % (step, nb_epochs))
Example no. 18
def train():
  #setConfig() # Already set in main
  config = network_config.getConfig()
  train_dir = config['train_dir']
  max_steps = config['max_steps']
  log_device_placement = config['log_device_placement']
  batch_size = config['batch_size']
    
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.       
    
    # 20 layer network
#     logits = cifar10_model.buildResidualStyleNetwork(images, is_train_phase = True)
                                                     
    # 56 layers
#     logits = cifar10_model.buildResidualStyleNetwork(images, is_train_phase = True, 
#                                                      numStackedBlocks = 9)

    
    logits = cifar10_model.buildNetworkWithVariableScope(images, 
                          is_train_phase = True, 
                          gateType = MywayFFLayer.HIGHWAY_GATE, 
                          numStackedBlocks = 9)
    
    # Calculate loss.
    loss = cifar10.loss(logits, labels)
    print('Loss used logits from cifar10_model')
    
    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.    
    lr = tf.placeholder(tf.float32, [], "learning_rate")
    train_op = cifar10.train(loss, global_step, lr)   
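    # This cifar10.train variant takes the learning rate as a placeholder, so
    # the schedule in the training loop below can change it through feed_dict
    # without rebuilding the graph.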

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=log_device_placement))
    sess.run(init)

    # Create a saver.
    # Vinh: this should save the moving averages of batch mean and variance
    # =============== Maybe not, I need to check it now ============================= 
    saver = tf.train.Saver(tf.all_variables())
    #saver = tf.train.Saver() # What is the difference here?


    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(train_dir,
                                            graph_def=sess.graph_def)    
    
    for step in xrange(max_steps):
      start_time = time.time()
      
      # Vinh: change learning rates at steps 32k and 48k, terminating
      # at step 64k (counts from 1) (as in ResNet paper)
      if (step + 1) >= 48000:
        feed_dict = {lr: 0.001}
      elif (step + 1) >= 32000:
        feed_dict = {lr: 0.01}
      else:
        feed_dict = {lr: 0.1}
      _, loss_value = sess.run([train_op, loss], feed_dict = feed_dict)
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 10 == 0:
        num_examples_per_step = batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

      if step % 100 == 0:
        # Vinh: If I monitor the learning rate in cifar10.py, I'd need to
        # pass the feed_dict above to the running of summary_op 
        summary_str = sess.run(summary_op) 
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == max_steps:
        checkpoint_path = os.path.join(train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
Example no. 19
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)

    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    with tf.device('/cpu:0'):
      images, labels = cifar10.distorted_inputs()
      images_eval, labels_eval = cifar10.inputs(eval_data=True)

    # Build a Graph that computes the logits predictions from the
    # inference model, reusing the same variables for the eval inputs.
    logits = cifar10.inference(images)
    tf.get_variable_scope().reuse_variables()
    logits_eval = cifar10.inference(images_eval)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # For evaluation
    top_k = tf.nn.in_top_k(logits, labels, 1)
    top_k_eval = tf.nn.in_top_k(logits_eval, labels_eval, 1)

    # Add precision summary
    summary_train_prec = tf.placeholder(tf.float32)
    summary_eval_prec = tf.placeholder(tf.float32)
    tf.scalar_summary('precision/train', summary_train_prec)
    tf.scalar_summary('precision/eval', summary_eval_prec)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)

    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

      EVAL_STEP = 10
      EVAL_NUM_EXAMPLES = 1024
      if step % EVAL_STEP == 0:
        prec_train = evaluate_set(sess, top_k, EVAL_NUM_EXAMPLES)
        prec_eval = evaluate_set(sess, top_k_eval, EVAL_NUM_EXAMPLES)
        print('%s: precision train = %.3f' % (datetime.now(), prec_train))
        print('%s: precision eval  = %.3f' % (datetime.now(), prec_eval))

      if step % 100 == 0:
        summary_str = sess.run(summary_op,
                               feed_dict={summary_train_prec: prec_train,
                                          summary_eval_prec: prec_eval})
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
Example no. 21
def train(model_fn, train_folder, qn_id):
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = model_fn(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Calculate accuracy
        model_accuracy = cifar10.accuracy(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        global_step = tf.train.get_or_create_global_step()
        train_op = cifar10.train(loss, model_accuracy, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._start_time = time.time()

            def after_create_session(self, session, coord):
                self._step = session.run(global_step)

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs([loss, model_accuracy
                                                ])  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results[0]
                    acc_value = run_values.results[1]
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s - %s: step %d, loss = %.2f, acc = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str %
                          (qn_id, datetime.now(), self._step, loss_value,
                           acc_value, examples_per_sec, sec_per_batch))

        class _StopAtHook(tf.train.SessionRunHook):
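            """Requests a stop once the locally tracked global step (restored
            at session creation) reaches `last_step`."""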
            def __init__(self, last_step):
                self._last_step = last_step

            def after_create_session(self, session, coord):
                self._step = session.run(global_step)

            def before_run(self, run_context):  # pylint: disable=unused-argument
                self._step += 1
                return tf.train.SessionRunArgs(global_step)

            def after_run(self, run_context, run_values):
                if self._step >= self._last_step:
                    run_context.request_stop()

        # class _StopAtHook(tf.train.StopAtStepHook):
        #     def __init__(self, last_step):
        #         super().__init__(last_step=last_step)
        #
        #     def begin(self):
        #         self._global_step_tensor = global_step
        #
        #     def before_run(self, run_context):  # pylint: disable=unused-argument
        #         return tf.train.SessionRunArgs(global_step)
        #
        #     def after_run(self, run_context, run_values):
        #         gs = run_values.results + 1
        #         print("\tgs = {}/{}".format(gs, self._last_step))
        #         if gs >= self._last_step:
        #             # Check latest global step to ensure that the targeted last step is
        #             # reached. global_step read tensor is the value of global step
        #             # before running the operation. We're not sure whether current session.run
        #             # incremented the global_step or not. Here we're checking it.
        #
        #             step = run_context.session.run(self._global_step_tensor)
        #             print("\t\tstep: {}. gs = {}/{}".format(step, gs, self._last_step))
        #             if step >= self._last_step:
        #                 run_context.request_stop()

        saver = tf.train.Saver()
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=train_folder,
                hooks=[
                    _StopAtHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(log_device_placement=FLAGS.
                                      log_device_placement)) as mon_sess:
            latest_checkpoint_path = tf.train.latest_checkpoint(train_folder)
            if latest_checkpoint_path is not None:
                # Restore from checkpoint
                print("Restoring checkpoint from %s" % latest_checkpoint_path)
                saver.restore(mon_sess, latest_checkpoint_path)

            while not mon_sess.should_stop():
                mon_sess.run(train_op)
Example no. 22
def train():
    port = 24454
    log_dir = './cen_logdir_%s_%s' % (FLAGS.job_name, FLAGS.task_index)
    log_dir_test = './cen_test_logdir_%s_%s' % (FLAGS.job_name,
                                                FLAGS.task_index)
    config_ps = tf.ConfigProto(intra_op_parallelism_threads=3,
                               inter_op_parallelism_threads=3)
    cluster = tf.train.ClusterSpec({
        'ps': ['localhost:%d' % port],
        'worker': [
            'localhost:%d' % (port + 1),
            'localhost:%d' % (port + 2),
            'localhost:%d' % (port + 3),
            'localhost:%d' % (port + 4)
        ]
    })
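    # Single-machine cluster: one parameter server plus four workers on
    # consecutive local ports; each worker is later pinned to its own GPU via
    # visible_device_list.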
    if FLAGS.job_name == 'ps':
        with tf.device('/job:ps/task:0/cpu:0'):
            server = tf.train.Server(cluster,
                                     job_name='ps',
                                     task_index=FLAGS.task_index,
                                     config=config_ps)
            server.join()

    else:
        is_chief = (FLAGS.task_index == 0)
        gpu_options = tf.GPUOptions(allow_growth=True,
                                    allocator_type="BFC",
                                    visible_device_list="%d" %
                                    FLAGS.task_index)
        config = tf.ConfigProto(gpu_options=gpu_options,
                                allow_soft_placement=True)
        server = tf.train.Server(cluster,
                                 job_name='worker',
                                 task_index=FLAGS.task_index,
                                 config=config)

        worker_device = '/job:worker/task:%d/gpu:%d' % (FLAGS.task_index,
                                                        FLAGS.task_index)

        with tf.device(
                tf.train.replica_device_setter(
                    worker_device=worker_device,
                    ps_device='/job:ps/task:0/cpu:0/',
                    ps_tasks=1)):
            global_step = tf.train.get_or_create_global_step()

            # Get images and labels for CIFAR-10.
            # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
            # GPU and resulting in a slow down.
            with tf.device('/cpu:0'):
                images, labels = cifar10.distorted_inputs()

            # Build a Graph that computes the logits predictions from the
            # inference model.
            logits = cifar10.inference(images)

            # Calculate loss.
            loss = cifar10.loss(logits, labels)

            # Build a Graph that trains the model with one batch of examples and
            # updates the model parameters.
            train_op, sync_replicas_hook = cifar10.train(
                loss, global_step, is_chief)

            stop_hook = tf.train.StopAtStepHook(last_step=FLAGS.max_steps)

            class _LoggerHook(tf.train.SessionRunHook):
                """Logs loss and runtime."""
                def begin(self):
                    self._step = -1
                    self._start_time = time.time()

                def before_run(self, run_context):
                    self._step += 1
                    return tf.train.SessionRunArgs(
                        loss)  # Asks for loss value.

                def after_run(self, run_context, run_values):
                    if self._step % FLAGS.log_frequency == 0:
                        current_time = time.time()
                        duration = current_time - self._start_time
                        self._start_time = current_time

                        loss_value = run_values.results
                        examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / FLAGS.log_frequency)

                        format_str = (
                            '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                            'sec/batch)')
                        print(format_str %
                              (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                master=server.target,
                hooks=[
                    sync_replicas_hook, stop_hook,
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=config) as sess:
            while not sess.should_stop():
                sess.run(train_op)
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default() as g:
    global_step = tf.train.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
    # GPU and resulting in a slow down.
    with tf.device('/cpu:0'):
      images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits, tensor_list = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op, retrieve_list = cifar10.train(loss, tensor_list, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))

    class _SparsityHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
       self._step = -1
       mode = sparsity_monitor.Mode.monitor
       data_format = "NHWC"
       self.monitor = sparsity_monitor.SparsityMonitor(mode, data_format, FLAGS.monitor_interval,\
                                                       FLAGS.monitor_period, retrieve_list)

      def before_run(self, run_context):
        self._step += 1
        selected_list = self.monitor.scheduler_before(self._step)
        return tf.train.SessionRunArgs(selected_list)  # Asks for the scheduled activation tensors.

      def after_run(self, run_context, run_values):
        self.monitor.scheduler_after(run_values.results, self._step)

    sparsity_summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(FLAGS.sparsity_dir, g)

    start = time.time()
    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               #tf.train.SummarySaverHook(save_steps=FLAGS.log_frequency, summary_writer=summary_writer, summary_op=sparsity_summary_op),
               _LoggerHook(),
               _SparsityHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
    end = time.time()
    print(end - start)
Example no. 24
def train(T_est, T_inv_est):
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        T_est = tf.constant(T_est)
        T_inv_est = tf.constant(T_inv_est)
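        # T_est / T_inv_est appear to be an estimated label-noise transition
        # matrix and its inverse for forward/backward loss correction; in the
        # active code path the true matrix T_tru from the input pipeline is
        # used instead (see the commented loss lines below).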

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            #images, labels = cifar10.distorted_inputs()
            images, labels, T_tru, T_mask_tru = cifar10.noisy_distorted_inputs(
                return_T_flag=True)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        dropout = tf.constant(0.75)
        logits = cifar10.inference(images, dropout, dropout_flag=True)

        # Calculate loss.
        #loss = loss_forward(logits, labels, T_est)
        loss = loss_forward(logits, labels, T_tru)
        #loss = loss_backward(logits, labels, T_inv_est)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op, variable_averages = cifar10.train(
            loss, global_step, return_variable_averages=True)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        #### build scalffold for MonitoredTrainingSession to restore the variables you wish
        ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir)
        variables_to_restore = variable_averages.variables_to_restore()
        #print(variables_to_restore)
        # Copy the keys so entries can be deleted while iterating (required in
        # Python 3, where .keys() is a live view).
        for var_name in list(variables_to_restore.keys()):
            if ('logits_T' in var_name) or ('global_step' in var_name):
                del variables_to_restore[var_name]
        #print(variables_to_restore)

        init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
            ckpt.model_checkpoint_path, variables_to_restore)

        def InitAssignFn(scaffold, sess):
            sess.run(init_assign_op, init_feed_dict)

        scaffold = tf.train.Scaffold(saver=tf.train.Saver(),
                                     init_fn=InitAssignFn)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45)
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:
            while not mon_sess.should_stop():
                res = mon_sess.run([train_op, global_step, T_tru, T_mask_tru])
                if res[1] % 1000 == 0:
                    print('True noise transition matrix\n', res[2])
                    print('Masked structure\n', res[3])
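The scaffold logic above (restore every moving-average shadow variable except logits_T and global_step) can be packaged on its own. A minimal sketch of that pattern, assuming a TF 1.x graph, a checkpoint under FLAGS.init_dir, and a variable_averages object returned by cifar10.train as in the example:

import tensorflow as tf

def make_restore_scaffold(init_dir, variable_averages,
                          skip_substrings=('logits_T', 'global_step')):
    """Scaffold whose init_fn restores moving-average shadow variables from
    the latest checkpoint in init_dir, skipping names in skip_substrings."""
    ckpt = tf.train.get_checkpoint_state(init_dir)
    variables_to_restore = variable_averages.variables_to_restore()
    for var_name in list(variables_to_restore.keys()):  # copy keys before deleting
        if any(s in var_name for s in skip_substrings):
            del variables_to_restore[var_name]

    init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
        ckpt.model_checkpoint_path, variables_to_restore)

    def init_fn(scaffold, sess):
        sess.run(init_assign_op, init_feed_dict)

    return tf.train.Scaffold(saver=tf.train.Saver(), init_fn=init_fn)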
Esempio n. 25
def train():
  """Train CIFAR-10 for a number of steps."""



  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)


    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()


    # # Visualize conv1 features
    # with tf.variable_scope('conv1') as scope_conv:
    #     #tf.get_variable_scope().reuse_variables()
    #     scope_conv.reuse_variables()
    #     weights = tf.get_variable('weights')
    #     grid_x = grid_y = 8   # to get a square grid for 64 conv1 features
    #     grid = put_kernels_on_grid (weights, (grid_y, grid_x))
    #     tf.image_summary('conv1/features', grid, max_images=1)


    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)



    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)


    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time


      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / float(duration)
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
Esempio n. 26
def train():
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs()

        logits = cifar10.inference(images)
        loss = cifar10.loss(logits, labels)
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            '''
            Hook that prints training information.
            '''
            def begin(self):
                '''
                Called before the session is created. When begin() runs, the
                default graph already exists and new ops can still be added to
                it here; after begin() returns, the default graph can no longer
                be modified.
                '''
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                '''
                Called before each sess.run(). It may return a
                tf.train.SessionRunArgs(op/tensor); those ops/tensors are added
                to the upcoming call, merged with the ops/tensors already
                requested by sess.run(), and executed together.
                @param run_context: A 'SessionRunContext' object
                @return: None or a 'SessionRunArgs' object
                '''
                self._step += 1
                # Return whatever you want to inspect during training here,
                # as a list if needed, e.g. [loss, accuracy].
                return tf.train.SessionRunArgs(loss)

            def after_run(self, run_context, run_values):
                '''
                Called after each sess.run(); run_values holds the results of
                the ops/tensors requested in before_run().
                run_context.request_stop() can be called here to stop the loop.
                If sess.run() raises an exception, after_run() is not called.
                @param run_context: A 'SessionRunContext' object
                @param run_values: A SessionRunValues object
                '''
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    # results holds the return value of before_run() above; loss
                    # was requested there, so a single loss value comes back. If
                    # a list had been requested, a list would come back (see the
                    # sketch after this example).
                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    print(
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)'
                        % (datetime.now(), self._step, loss_value,
                           examples_per_sec, sec_per_batch))

        '''
        Once the graph's ops have been defined, open a MonitoredTrainingSession
        to initialize/register the graph and related bookkeeping. Three hooks
        are passed in:
        1. tf.train.StopAtStepHook(last_step): requests a stop once training
        reaches the given step; it requires a step created beforehand with
        tf.train.get_or_create_global_step().
        2. tf.train.NanTensorHook(loss): monitors the loss and raises an
        exception if it becomes NaN.
        3. _LoggerHook(): a custom hook that tracks quantities such as loss or
        accuracy during training. begin() is called when the
        MonitoredTrainingSession is created and initializes the step counter;
        before_run() is called on every sess.run(), i.e. once per training
        step, and returns what should be logged; after_run() is called right
        after.
        '''
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(log_device_placement=FLAGS.
                                      log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
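As the comments above note, before_run() may also request a list of tensors, in which case run_values.results comes back as a list in the same order. A minimal sketch of that variant, assuming loss and accuracy tensors already exist in the graph:

class _MultiTensorHook(tf.train.SessionRunHook):
    """Logs several tensors per step by requesting them as a list."""

    def before_run(self, run_context):
        # The order of this list fixes the order of run_values.results.
        return tf.train.SessionRunArgs([loss, accuracy])

    def after_run(self, run_context, run_values):
        loss_value, acc_value = run_values.results
        print('loss = %.3f, accuracy = %.3f' % (loss_value, acc_value))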
Esempio n. 27
def train():
    """Train CIFAR-10 for a number of steps."""
    # return context manager make this graph as default graph
    # used if want to create multiple graph in same process
    with tf.Graph().as_default():
        # global_step = number of batches seen by the graph, used to save and resume training
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # Hook = tools run in process of train/eval of model
        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # arguments to be added to Session.run() call
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            # called after each Session.run() call
            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)'
                    )
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        # creates a MonitoredSession for training -> sets proper session initializer/restorer
        # create hooks related to checkpoint & summary saving
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(log_device_placement=FLAGS.
                                      log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
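When checkpoint_dir is given, MonitoredTrainingSession installs checkpoint and summary saving hooks automatically; the same behaviour can also be requested explicitly. A minimal sketch, assuming the graph above has been built and summaries have been added to it:

saver_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir=FLAGS.train_dir, save_secs=600, saver=tf.train.Saver())
summary_hook = tf.train.SummarySaverHook(
    save_steps=100, output_dir=FLAGS.train_dir,
    summary_op=tf.summary.merge_all())

with tf.train.MonitoredTrainingSession(
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               saver_hook, summary_hook]) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)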
Esempio n. 28
def my_train(saver):

    #print ('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # create vars for eval
        top_k_op = tf.nn.in_top_k(logits, labels, 1)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def __init__(self):
                self._losses = []

            def get_results(self):
                return {"loss": self._losses}

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    self._losses.append(loss_value)
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        logger = _LoggerHook()
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                save_checkpoint_secs=60,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss), logger
                ],
                config=tf.ConfigProto(log_device_placement=FLAGS.
                                      log_device_placement)) as mon_sess:
            # Main LOOP
            while logger._step < EPOCH_SIZE:
                # run epoch training
                mon_sess.run(train_op)
                train_res = logger.get_results()

        # train_acc = cifar10_eval.simple_eval_once(saver, top_k_op)["accuracy"]
        train_res["accuracy"] = 0
    return train_res
def train(infer_z, noisy_y, C, img_label):
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            #indices, images, labels = cifar10.distorted_inputs()
            indices, images, labels, T_tru, T_mask_tru = cifar10.noisy_distorted_inputs(
                return_T_flag=True, noise_ratio=FLAGS.noise_ratio)
            indices = indices[:, 0]  # rank 2 --> rank 1, i.e., (batch_size, 1) --> (batch_size,)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        is_training = tf.placeholder(tf.bool, shape=(), name='bn_flag')
        logits = cifar10.inference(images, training=is_training)
        preds = tf.nn.softmax(logits)

        # approximate Gibbs sampling
        T = tf.placeholder(tf.float32,
                           shape=[cifar10.NUM_CLASSES, cifar10.NUM_CLASSES],
                           name='transition')
        if FLAGS.groudtruth:
            unnorm_probs = preds * tf.gather(tf.transpose(T_tru, [1, 0]),
                                             labels)
        else:
            unnorm_probs = preds * tf.gather(tf.transpose(T, [1, 0]), labels)

        probs = unnorm_probs / tf.reduce_sum(
            unnorm_probs, axis=1, keepdims=True)
        sampler = OneHotCategorical(probs=probs)
        labels_ = tf.stop_gradient(tf.argmax(sampler.sample(), axis=1))

        loss = cifar10.loss(logits, labels_)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # Calculate prediction
        # acc_op contains acc and update_op. So it is the cumulative accuracy when sess runs acc_op
        # if you only want to inspect acc of each batch, just sess run acc_op[0]
        acc_op = tf.metrics.accuracy(labels, tf.argmax(logits, axis=1))
        tf.summary.scalar('training accuracy', acc_op[0])

        #### Build a scaffold for MonitoredTrainingSession to restore only the variables we want
        variables_to_restore = []
        #variables_to_restore += [var for var in tf.trainable_variables() if 'dense' not in var.name] # if final layer is not included
        variables_to_restore += tf.trainable_variables()  # if final layer is included
        variables_to_restore += [
            g for g in tf.global_variables()
            if 'moving_mean' in g.name or 'moving_variance' in g.name
        ]
        for var in variables_to_restore:
            print(var.name)
        #variables_to_restore = []
        ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir)
        init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
            ckpt.model_checkpoint_path, variables_to_restore)

        def InitAssignFn(scaffold, sess):
            sess.run(init_assign_op, init_feed_dict)

        scaffold = tf.train.Scaffold(saver=tf.train.Saver(),
                                     init_fn=InitAssignFn)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(
                    tf.get_collection('losses')[0])  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:
            ## initialize some params
            alpha = 1.0
            C_init = C.copy()
            trans_init = (C + alpha) / np.sum(C + alpha, axis=1, keepdims=True)

            ## running setting
            warming_up_step = 20000
            step = 0
            freq_trans = 200

            ### warming up transition
            with open('T_%.2f.pkl' % FLAGS.noise_ratio) as f:
                data = pickle.load(f)
            trans_warming = data[2]  # trans_init or np.eye(cifar10.NUM_CLASSES)

            ## record and run
            exemplars = []
            label_trace_exemplars = []
            infer_z_probs = dict()
            trans_before_after_trace = []
            while not mon_sess.should_stop():
                if step % freq_trans == 0:  # update the transition matrix every freq_trans steps
                    trans = (C + alpha) / np.sum(
                        C + alpha, axis=1, keepdims=True)

                if step < warming_up_step:
                    res = mon_sess.run([
                        train_op, acc_op, global_step, indices, labels,
                        labels_, probs
                    ],
                                       feed_dict={
                                           is_training: True,
                                           T: trans_warming
                                       })
                else:
                    res = mon_sess.run([
                        train_op, acc_op, global_step, indices, labels,
                        labels_, probs
                    ],
                                       feed_dict={
                                           is_training: True,
                                           T: trans
                                       })

                #print(res[3].shape)
                trans_before = (C + alpha) / np.sum(
                    C + alpha, axis=1, keepdims=True)
                C_before = C.copy()
                for i in xrange(res[3].shape[0]):
                    ind = res[3][i]
                    #print(noisy_y[ind],res[4][i])
                    assert noisy_y[ind] == res[4][i]
                    C[infer_z[ind]][noisy_y[ind]] -= 1
                    assert C[infer_z[ind]][noisy_y[ind]] >= 0
                    infer_z[ind] = res[5][i]
                    infer_z_probs[ind] = res[6][i]
                    C[infer_z[ind]][noisy_y[ind]] += 1
                    #print(res[4][i],res[5][i])

                trans_after = (C + alpha) / np.sum(
                    C + alpha, axis=1, keepdims=True)
                C_after = C.copy()
                trans_gap = np.sum(np.absolute(trans_after - trans_before))
                rou = np.sum(C_after - C_before, axis=-1) / np.sum(
                    C_before + alpha, axis=-1)
                rou_ = np.sum(np.absolute(C_after - C_before),
                              axis=-1) / np.sum(C_before + alpha, axis=-1)
                trans_bound = np.sum((np.absolute(rou) + rou_) / (1 + rou))
                trans_before_after_trace.append([step, trans_gap, trans_bound])
                #print(trans_gap, trans_bound)

                step = res[2]
                if step % 1000 == 0:
                    print('Counting matrix\n', C)
                    print('Initial counting matrix\n', C_init)
                    print('Transition matrix\n', trans)
                    print('Initial transition matrix\n', trans_init)

                if step % 5000 == 0:
                    exemplars.append([
                        infer_z.copy().keys(),
                        infer_z.copy().values(),
                        C.copy()
                    ])

                if step % FLAGS.max_steps_per_epoch == 0:
                    r_n = 0
                    all_n = 0
                    for key in infer_z.keys():
                        if infer_z[key] == img_label[key]:
                            r_n += 1
                        all_n += 1
                    acc = float(r_n) / all_n  # avoid integer division under Python 2
                    #print('accuracy: %.2f'%acc)
                    label_trace_exemplars.append(
                        [infer_z.copy(),
                         infer_z_probs.copy(), acc])

            if not FLAGS.groudtruth:
                with open('varC_learnt_%.2f.pkl' % FLAGS.noise_ratio,
                          'w') as w:
                    pickle.dump(exemplars, w)
            else:
                with open('varC_learnt_%.2f_tru.pkl' % FLAGS.noise_ratio,
                          'w') as w:
                    pickle.dump(exemplars, w)

            if FLAGS.labeltrace:
                with open('varC_label_trace_%.2f.pkl' % FLAGS.noise_ratio,
                          'w') as w:
                    pickle.dump([label_trace_exemplars, img_label], w)

            with open('varC_transvar_trace_%.2f.pkl' % FLAGS.noise_ratio,
                      'w') as w:
                pickle.dump(trans_before_after_trace, w)
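The transition matrix driving the approximate Gibbs sampling above is simply a Dirichlet-smoothed, row-normalised version of the count matrix C, updated one example at a time when an inferred clean label changes. A minimal NumPy sketch of that bookkeeping, assuming C is indexed as C[inferred_clean_label][noisy_label] and alpha is the smoothing constant:

import numpy as np

def transition_from_counts(C, alpha=1.0):
    """Row-normalised, Dirichlet(alpha)-smoothed transition matrix."""
    return (C + alpha) / np.sum(C + alpha, axis=1, keepdims=True)

def move_count(C, old_z, new_z, noisy_y):
    """Move one example's count when its inferred clean label changes."""
    C[old_z][noisy_y] -= 1
    assert C[old_z][noisy_y] >= 0
    C[new_z][noisy_y] += 1

# Tiny example with 3 classes and a handful of (clean, noisy) pairs.
C = np.zeros((3, 3))
C[0][0] = 5; C[1][1] = 4; C[2][1] = 1
print(transition_from_counts(C))
move_count(C, old_z=2, new_z=1, noisy_y=1)
print(transition_from_counts(C))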
def main_fun(argv, ctx):
    import tensorflow as tf
    import cifar10

    sys.argv = argv
    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_string(
        'train_dir', '/tmp/cifar10_train',
        """Directory where to write event logs """
        """and checkpoint.""")
    tf.app.flags.DEFINE_integer('max_steps', 1000000,
                                """Number of batches to run.""")
    tf.app.flags.DEFINE_boolean('log_device_placement', False,
                                """Whether to log device placement.""")
    tf.app.flags.DEFINE_boolean('rdma', False, """Whether to use rdma.""")

    # cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)

    cluster_spec, server = TFNode.start_cluster_server(ctx, 1, FLAGS.rdma)

    # Train CIFAR-10 for a number of steps.
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1

            def before_run(self, run_context):
                self._step += 1
                self._start_time = time.time()
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                duration = time.time() - self._start_time
                loss_value = run_values.results
                if self._step % 10 == 0:
                    num_examples_per_step = FLAGS.batch_size
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(log_device_placement=FLAGS.
                                      log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
Esempio n. 31
def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        cluster = tf.train.ClusterSpec({
            "ps": ["ms1108.utah.cloudlab.us:2222"],
            "worker": [
                "ms1106.utah.cloudlab.us:2222", "ms1126.utah.cloudlab.us:2222",
                "ms1128.utah.cloudlab.us:2222", "ms1127.utah.cloudlab.us:2222"
            ]
        })
        server = tf.train.Server(cluster,
                                 job_name=FLAGS.job_name,
                                 task_index=1)
        if FLAGS.job_name == "ps":
            server.join()
        elif FLAGS.job_name == "worker":
            global_step = tf.contrib.framework.get_or_create_global_step()

            # Get images and labels for CIFAR-10.
            # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
            # GPU and resulting in a slow down.
            with tf.device(tf.train.replica_device_setter(cluster=cluster)):
                images, labels = cifar10.distorted_inputs()

            # Build a Graph that computes the logits predictions from the
            # inference model.
            logits = cifar10.inference(images)

            # Calculate loss.
            loss = cifar10.loss(logits, labels)

            # Build a Graph that trains the model with one batch of examples and
            # updates the model parameters.
            train_op = cifar10.train(loss, global_step)

            class _LoggerHook(tf.train.SessionRunHook):
                """Logs loss and runtime."""
                def begin(self):
                    self._step = -1
                    self._start_time = time.time()

                def before_run(self, run_context):
                    self._step += 1
                    return tf.train.SessionRunArgs(
                        loss)  # Asks for loss value.

                def after_run(self, run_context, run_values):
                    if self._step % FLAGS.log_frequency == 0:
                        current_time = time.time()
                        duration = current_time - self._start_time
                        self._start_time = current_time

                        loss_value = run_values.results
                        examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / FLAGS.log_frequency)

                        format_str = (
                            '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                            'sec/batch)')
                        print(format_str %
                              (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))

            with tf.train.MonitoredTrainingSession(
                    master=server.target,
                    checkpoint_dir=FLAGS.train_dir,
                    hooks=[
                        tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                        tf.train.NanTensorHook(loss),
                        _LoggerHook()
                    ],
                    config=tf.ConfigProto(log_device_placement=FLAGS.
                                          log_device_placement)) as mon_sess:
                while not mon_sess.should_stop():
                    mon_sess.run(train_op)
def train():
    ps_hosts = FLAGS.ps_hosts.split(',')
    worker_hosts = FLAGS.worker_hosts.split(',')
    print('PS hosts are: %s' % ps_hosts)
    print('Worker hosts are: %s' % worker_hosts)

    server = tf.train.Server({
        'ps': ps_hosts,
        'worker': worker_hosts
    },
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_id)
    if FLAGS.job_name == 'ps':
        server.join()
    is_chief = (FLAGS.task_id == 0)
    if is_chief:
        if tf.gfile.Exists(FLAGS.train_dir):
            tf.gfile.DeleteRecursively(FLAGS.train_dir)
        tf.gfile.MakeDirs(FLAGS.train_dir)

    device_setter = tf.train.replica_device_setter(ps_tasks=len(ps_hosts))
    with tf.device('/job:worker/task:%d' % FLAGS.task_id):
        partitioner = tf.fixed_size_partitioner(len(ps_hosts), axis=0)
        with tf.variable_scope('root', partitioner=partitioner):
            with tf.device(device_setter):
                global_step = tf.Variable(0, trainable=False)

                decay_steps = 50000 * 350.0 / FLAGS.batch_size
                batch_size = tf.placeholder(dtype=tf.int32,
                                            shape=(),
                                            name='batch_size')
                images, labels = cifar10.distorted_inputs(batch_size)
                re = tf.shape(images)[0]
                inputs = tf.reshape(images, [-1, _HEIGHT, _WIDTH, _DEPTH])
                labels = tf.one_hot(labels, 10, 1, 0)
                network_fn = nets_factory.get_network_fn('vgg_16',
                                                         num_classes=10)
                (logits, _) = network_fn(inputs)
                cross_entropy = tf.losses.softmax_cross_entropy(
                    logits=logits, onehot_labels=labels)

                loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
                    [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

                train_op = cifar10.train(loss, global_step)

                sv = tf.train.Supervisor(is_chief=is_chief,
                                         logdir=FLAGS.train_dir,
                                         init_op=tf.group(
                                             tf.global_variables_initializer(),
                                             tf.local_variables_initializer()),
                                         summary_op=None,
                                         global_step=global_step,
                                         saver=None,
                                         recovery_wait_secs=1,
                                         save_model_secs=60)

                tf.logging.info('%s Supervisor' % datetime.now())
                sess_config = tf.ConfigProto(
                    allow_soft_placement=True,
                    log_device_placement=FLAGS.log_device_placement)
                sess_config.gpu_options.allow_growth = True

                # Get a session.
                sess = sv.prepare_or_wait_for_session(server.target,
                                                      config=sess_config)

                # Start the queue runners.
                queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
                sv.start_queue_runners(sess, queue_runners)
                """Train CIFAR-10 for a number of steps."""
                batch_size_num = FLAGS.batch_size
                for step in range(FLAGS.max_steps):
                    start_time = time.time()
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                    num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / batch_size_num
                    decay_steps_num = int(num_batches_per_epoch *
                                          NUM_EPOCHS_PER_DECAY)
                    _, loss_value, gs = sess.run(
                        [train_op, loss, global_step],
                        feed_dict={batch_size: batch_size_num},
                        options=run_options,
                        run_metadata=run_metadata)
                    duration = time.time() - start_time
                    num_examples_per_step = batch_size_num
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)
                    format_str = (
                        "time: " + str(time.time()) +
                        '; %s: step %d (global_step %d), loss = %.2f (%.1f examples/sec; %.3f sec/batch)'
                    )
                    tf.logging.info(format_str %
                                    (datetime.now(), step, gs, loss_value,
                                     examples_per_sec, sec_per_batch))
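The replica_device_setter used above sends variable ops to the ps tasks in round-robin order while leaving other ops on the worker, and fixed_size_partitioner additionally shards each variable across the ps tasks. A minimal placement sketch, assuming a cluster with two ps hosts:

device_setter = tf.train.replica_device_setter(ps_tasks=2)
with tf.device(device_setter):
    # Variables created here land on /job:ps/task:0 and /job:ps/task:1 in
    # round-robin order; non-variable ops keep the default worker placement.
    w = tf.get_variable('w', shape=[1024, 10])
    b = tf.get_variable('b', shape=[10])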
Esempio n. 33
def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

                if os.path.isfile('log_cifar.csv'):
                    os.remove('log_cifar.csv')
                df = pandas.DataFrame([], columns=['time', 'step', 'loss'])
                df.to_csv('log_cifar.csv', index=False)

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

                    df = pandas.read_csv('log_cifar.csv')
                    df2 = pandas.DataFrame(
                        [[time.time(), self._step, loss_value]],
                        columns=['time', 'step', 'loss'])
                    df = df.append(df2)
                    df.to_csv('log_cifar.csv', index=False)

        with tf.train.MonitoredTrainingSession(
                save_checkpoint_secs=300,
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(log_device_placement=FLAGS.
                                      log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
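Re-reading and rewriting log_cifar.csv inside after_run grows more expensive as the log gets longer; appending rows directly keeps each log step cheap. A minimal sketch of the same hook using the standard csv module, assuming the loss tensor and FLAGS from the example above:

import csv
import time

class _CsvLoggerHook(tf.train.SessionRunHook):
    """Appends (time, step, loss) rows to a CSV file while training runs."""

    def begin(self):
        self._step = -1
        self._file = open('log_cifar.csv', 'w')
        self._writer = csv.writer(self._file)
        self._writer.writerow(['time', 'step', 'loss'])

    def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)

    def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
            self._writer.writerow([time.time(), self._step, run_values.results])
            self._file.flush()

    def end(self, session):
        self._file.close()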
def train(T_fixed, T_init):
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            #indices, images, labels = cifar10.distorted_inputs()
            indices, images, labels, T_tru, T_mask_tru = cifar10.noisy_distorted_inputs(
                return_T_flag=True, noise_ratio=FLAGS.noise_ratio)
            indices = indices[:, 0]

        # Build a Graph that computes the logits predictions from the
        # inference model.
        is_training = tf.placeholder(tf.bool, shape=(), name='bn_flag')
        logits = cifar10.inference(images, training=is_training)
        preds = tf.nn.softmax(logits)

        # fixed adaption layer
        fixed_adaption_layer = tf.cast(tf.constant(T_fixed), tf.float32)

        # adaption layer
        logits_T = tf.get_variable(
            'logits_T',
            shape=[cifar10.NUM_CLASSES, cifar10.NUM_CLASSES],
            initializer=tf.constant_initializer(np.log(T_init + 1e-8)))
        adaption_layer = tf.nn.softmax(logits_T)

        # label adaption
        is_use = tf.placeholder(tf.bool, shape=(), name='warming_up_flag')
        adaption = tf.cond(is_use, lambda: fixed_adaption_layer,
                           lambda: adaption_layer)
        preds_aug = tf.clip_by_value(tf.matmul(preds, adaption), 1e-8,
                                     1.0 - 1e-8)
        logits_aug = tf.log(preds_aug)

        # Calculate loss.
        loss = cifar10.loss(logits_aug, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # Calculate prediction
        # acc_op contains acc and update_op. So it is the cumulative accuracy when sess runs acc_op
        # if you only want to inspect acc of each batch, just sess run acc_op[0]
        acc_op = tf.metrics.accuracy(labels, tf.argmax(logits, axis=1))
        tf.summary.scalar('training accuracy', acc_op[0])

        #### Build a scaffold for MonitoredTrainingSession to restore only the variables we want
        variables_to_restore = []
        #variables_to_restore += [var for var in tf.trainable_variables() if ('dense' not in var.name and 'logits_T' not in var.name)]
        variables_to_restore += [
            var for var in tf.trainable_variables()
            if 'logits_T' not in var.name
        ]
        variables_to_restore += [
            g for g in tf.global_variables()
            if 'moving_mean' in g.name or 'moving_variance' in g.name
        ]
        for var in variables_to_restore:
            print(var.name)
        ckpt = tf.train.get_checkpoint_state(FLAGS.init_dir)
        init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
            ckpt.model_checkpoint_path, variables_to_restore)

        def InitAssignFn(scaffold, sess):
            sess.run(init_assign_op, init_feed_dict)

        scaffold = tf.train.Scaffold(saver=tf.train.Saver(),
                                     init_fn=InitAssignFn)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(
                    tf.get_collection('losses')[0])  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                save_checkpoint_secs=60,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:
            warming_up_step = 32000
            step = 0
            varT_rec = []
            varT_trans_trace = []
            while not mon_sess.should_stop():
                if step < warming_up_step:
                    res = mon_sess.run([
                        train_op, acc_op, global_step, fixed_adaption_layer,
                        T_tru, T_mask_tru
                    ],
                                       feed_dict={
                                           is_training: True,
                                           is_use: True
                                       })
                else:
                    res = mon_sess.run([
                        train_op, acc_op, global_step, adaption_layer, T_tru,
                        T_mask_tru
                    ],
                                       feed_dict={
                                           is_training: True,
                                           is_use: False
                                       })
                step = res[2]

                if step % 5000 == 0:
                    varT_rec.append(res[3])

                if step == warming_up_step:
                    trans_before = res[3].copy()
                if step > warming_up_step:
                    trans_after = res[3].copy()
                    trans_gap = np.sum(np.absolute(trans_before - trans_after))
                    varT_trans_trace.append([step, trans_gap])

        with open('varT_learnt_%.2f.pkl' % FLAGS.noise_ratio, 'w') as w:
            pickle.dump(varT_rec, w)

        with open('varT_transvar_trace_%.2f.pkl' % FLAGS.noise_ratio,
                  'w') as w:
            pickle.dump(varT_trans_trace, w)
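The adaption layer above is a forward noise correction: the clean-label softmax is multiplied by a row-stochastic matrix T (fixed during warm-up, learnable afterwards) before the cross-entropy is computed. A minimal sketch of that transformation on its own, assuming logits of shape [batch, num_classes] and a num_classes x num_classes matrix T:

def forward_corrected_logits(logits, T, eps=1e-8):
    """Log of the noisy-label distribution softmax(logits) @ T."""
    preds = tf.nn.softmax(logits)                      # clean-label probabilities
    preds_aug = tf.clip_by_value(tf.matmul(preds, T),  # noisy-label probabilities
                                 eps, 1.0 - eps)
    return tf.log(preds_aug)

# These corrected logits then feed the usual cross-entropy, e.g.
# loss = cifar10.loss(forward_corrected_logits(logits, adaption_layer), labels)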
Esempio n. 35
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    print('init training')
    global_step = tf.train.get_or_create_global_step()
    print('create global step')
    t1=time.time()

    # Get images and labels for CIFAR-10.
    # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
    # GPU and resulting in a slow down.
    with tf.device('/cpu:0'):
      cifar_images, cifar_labels = cifar10.distorted_inputs()
      mnist_images, mnist_labels = cifar10.mnist_inputs("train")

    # Build a Graph that computes the logits predictions from the
    # inference model.

    with tf.variable_scope('shared_net') as scope:
      cifar_local4 = cifar10.inference_shared(cifar_images)
      scope.reuse_variables()
      mnist_local4 = cifar10.inference_shared(mnist_images)
      

    mnist_logits = cifar10.inference_mnist(mnist_local4)
    cifar_logits = cifar10.inference_cifar(cifar_local4)

    #mnist_labels = tf.Print(mnist_labels, [mnist_labels],'*.*.*.* MNIST labels:')
    #cifar_labels = tf.Print(cifar_labels, [cifar_labels],'*.*.*.* CIFAR labels:')


    #logits, mnist_logits = cifar10.inference(mnist_images)

    #logits, _ = cifar10.inference(images)
    #_, mnist_logits = cifar10.inference(mnist_images)



    # Calculate loss.
    with tf.variable_scope('cifar_losses'):
      cifar_loss = cifar10.loss(cifar_logits, cifar_labels)
    with tf.variable_scope('mnist_losses'):
      mnist_loss = cifar10.loss(mnist_logits,mnist_labels,lossname='mnist_losses')
    ct=time.time()
    print('From start to define losses: ',ct-t1,' sec')



    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.

    # Get variables
    train_vars = tf.trainable_variables()
    shared_vars = [var for var in train_vars if 'shared_' in var.name]
    cifar_vars = [var for var in train_vars if 'cifar_' in var.name]
    mnist_vars = [var for var in train_vars if 'mnist_' in var.name]

    # print('SHAREDSHAREDSHAREDSHARED:')
    # for i in shared_vars:
    #   print(i)
    # print('CIFARCIFARCIFARCIFARCIFARCIFAR:')
    # for i in cifar_vars:
    #   print(i)
    # print('MNISTMNISTMNIST:')
    # for i in mnist_vars:
    #   print(i)


    with tf.name_scope('cifar_train'):
      cifar_train_op = cifar10.train(cifar_loss, global_step,var_list=shared_vars+cifar_vars)

    with tf.name_scope('mnist_train'):
      mnist_train_op = cifar10.train(mnist_loss, global_step,var_list=mnist_vars)
      #mnist_train_op = cifar10.train(mnist_loss, global_step,var_list=shared_vars+mnist_vars)
    ct2=time.time()
    print('From loss to define trainer ops: ',ct2-ct,' sec')

    # # trying to run as simple session...
    # conf = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    # with tf.Session(config=conf) as sess:
    #   sess.run(tf.global_variables_initializer())
    #   print('Ready for training...')
    #   for i in range(FLAGS.max_steps):
    #     print(i)
    #     sess.run(mnist_train_op)
    #     if i% FLAGS.log_frequency:
    #       print('step %d, training loss %g' % (i, mnist_loss))


    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        #print('beforeee')
        return tf.train.SessionRunArgs([cifar_loss, mnist_loss])  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          cifar_loss_value = run_values.results[0]
          mnist_loss_value = run_values.results[1]
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, CIFAR loss = %.2f ; MNIST loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, cifar_loss_value, mnist_loss_value,
                               examples_per_sec, sec_per_batch))
    print('MonitoredTrainingSession is about to start')
    ct3=time.time()
    print('From trainer op to MTS start: ',ct3-ct2,' sec')
    saver = tf.train.Saver()

    #with tf.train.SingularMonitoredSession(
    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(cifar_loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement),
        save_checkpoint_secs=5000) as mon_sess:
      StepCount = 1
      dt=time.time()
      print('With MTS as mon_sess.. :',dt-ct, 'sec')
      print('while is coming')
      while not mon_sess.should_stop():
        if StepCount == 1:
          print('First cycle...')
          et=time.time()
          print('From MTS to first cycle :',et-dt,' sec')
        bt=time.time()
        mon_sess.run(cifar_train_op)
        bt2=time.time()
        if StepCount<10:
          print('Single session run :', bt2-bt, 'sec')
        #if (not mon_sess.should_stop()) and ((StepCount%20)==0) and StepCount<8000:
        if (not mon_sess.should_stop()) and ((StepCount%20)==0) and StepCount<=40000:
        # if (not mon_sess.should_stop()) and StepCount<40000:
          mon_sess.run(mnist_train_op)

        #if StepCount==50 or StepCount==100 or StepCount==500 or StepCount==1000 or StepCount==2000 or StepCount==4000 or StepCount==8000 or StepCount==10000 or StepCount==40000:
        #if (StepCount%50)==0:
        #  saver.save(mon_sess._sess._sess._sess._sess,'/tmp/cifar10_train/model.ckpt',global_step=StepCount)
          #time.sleep(4)
        if StepCount==40000:
          saver.save(mon_sess._sess._sess._sess._sess,'/tmp/cifar10_train/model.ckpt',global_step=StepCount)

        StepCount+=1
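Splitting the trainable variables by scope prefix and passing a per-task var_list, as done above through cifar10.train, is the standard way to update only part of a shared network. A minimal sketch of the same idea with a plain optimizer (the learning-rate value is only illustrative), assuming the losses and variable lists built in the example:

opt = tf.train.GradientDescentOptimizer(0.1)

# CIFAR batches update the shared trunk plus the CIFAR head ...
cifar_train_op = opt.minimize(cifar_loss, global_step=global_step,
                              var_list=shared_vars + cifar_vars)
# ... while MNIST batches update only the MNIST head.
mnist_train_op = opt.minimize(mnist_loss, var_list=mnist_vars)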
Esempio n. 36
def train():
  """Train a model for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Get images and labels for a segmentation model.
    images, labels, ground_truth = cifar10.distorted_inputs()
    tf.histogram_summary('label_hist/with_ignore', labels)
    tf.histogram_summary('label_hist/ground_truth', ground_truth)
    
    # Build a Graph that computes the logits predictions from the
    # inference model.
    print("before inference")
    print(images.get_shape())
    logits, nr_params = cifar10.inference(images)
    print("nr_params: "+str(nr_params) )
    print("after inference")
    # Calculate loss.
    loss = cifar10.loss(logits, labels)
    accuracy, precision, cat_accs = cifar10.accuracy(logits, ground_truth)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())
#    tf.image_summary('images2', images)
    print (logits)
#    tf.image_summary('predictions', logits)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement))
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      # Restores from checkpoint
      saver.restore(sess, ckpt.model_checkpoint_path)
      # Assuming model_checkpoint_path looks something like:
      #   /my-favorite-path/cifar10_train/model.ckpt-0,
      # extract global_step from it.
      global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
    else:
      print('No checkpoint file found')
      print('Initializing new model')
      sess.run(init)
      global_step = 0


    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    for step in xrange(global_step, FLAGS.max_steps):
      start_time = time.time()
      _, loss_value, accuracy_value, precision_value, cat_accs_val = sess.run(
          [train_op, loss, accuracy, precision, cat_accs])
      duration = time.time() - start_time

      print (precision_value)
      print (cat_accs_val)
      
      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      #precision_value = [0 if np.isnan(p) else p for p in precision_value]
      #print (precision_value)
      if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)\n Accuracy = %.4f, mean average precision = %.4f')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch,
                             accuracy_value, np.mean(precision_value)))

      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

        summary = tf.Summary()
        summary.value.add(tag='Accuracy (raw)', simple_value=float(accuracy_value))
        for i,s in enumerate(CLASSES):
          summary.value.add(tag="precision/"+s+" (raw)",simple_value=float(precision_value[i]))
          summary.value.add(tag="accs/"+s+" (raw)",simple_value=float(cat_accs_val[i]))
#        summary.value.add(tag='Human precision (raw)', simple_value=float(precision_value))
        summary_writer.add_summary(summary, step)
        print("hundred steps")
      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        print("thousand steps")
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():

    global_step = tf.Variable(0, trainable=False)

    # Get images and labels for CIFAR-10.
    # images, labels = cifar10.standard_distorted_inputs()
    inputs = cifar10.ram_inputs(unit_variance=True, is_train=True)
    images = inputs['images']
    labels = inputs['labels']

    # Batch generator
    batcher = cifar10.Cifar10BatchGenerator(
        inputs['data_images'], inputs['data_labels'], True,
        FLAGS.max_epochs)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images, 3, use_batchnorm=True,
        use_nrelu=False, id_decay=False, add_shortcuts=True, is_train=True)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)

    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement,
        gpu_options=gpu_options))
  
    sess.run(init)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    step = -1
    while not batcher.is_done():
      step += 1

      batch_im, batch_labs = batcher.next_batch()
      feed_dict = {
          inputs['images_pl']: batch_im,
          inputs['labels_pl']: batch_labs,
        }

      start_time = time.time()
      _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

      if step % 10 == 0:
        summary_str = sess.run(summary_op, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 10 == 0 or batcher.is_done():
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
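The example above trains from an in-memory generator, cifar10.Cifar10BatchGenerator, whose implementation is not shown. Below is a rough sketch of an epoch-limited, shuffling generator with the same is_done()/next_batch() interface; the constructor arguments and the default batch size are assumptions, not the actual cifar10 module code.

import numpy as np

class SimpleBatchGenerator(object):
  """Sketch of a shuffling, epoch-limited mini-batch generator."""

  def __init__(self, images, labels, shuffle, max_epochs, batch_size=128):
    self._images = images
    self._labels = labels
    self._shuffle = shuffle
    self._max_epochs = max_epochs
    self._batch_size = batch_size
    self._epoch = 0
    self._cursor = 0
    self._order = np.arange(len(images))
    if shuffle:
      np.random.shuffle(self._order)

  def is_done(self):
    return self._epoch >= self._max_epochs

  def next_batch(self):
    if self._cursor + self._batch_size > len(self._images):
      # Epoch finished: reshuffle and start over (the partial batch is dropped).
      self._epoch += 1
      self._cursor = 0
      if self._shuffle:
        np.random.shuffle(self._order)
    idx = self._order[self._cursor:self._cursor + self._batch_size]
    self._cursor += self._batch_size
    return self._images[idx], self._labels[idx]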
Example n. 38
def train():

  # # debug

  # true_classes = np.ndarray(shape=(FLAGS.batch_size, 1), dtype=int)
  # true_classes.fill(2)


  # # Create a pair of constant ops, add the numpy 
  # # array matrices.
  # true_classes_tf_matrix = tf.constant(true_classes, dtype=tf.int64)

  # # playing with introducing the sampler
  # classes_sampler = tf.nn.learned_unigram_candidate_sampler(
  #                                     true_classes_tf_matrix, 
  #                                     1,                # true_classes
  #                                     5,                # num_sampled
  #                                     False,            # unique
  #                                     10,               # range_max
  #                                     seed=None, 
  #                                     name="my_classes_sampler")

  # # print(classes_sampler)
  # # print("debug")
  # # print(classes_sampler.set_sampler)
  # # exit()

  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # print("images")
    # print(images)
    # images = tf.Print(images, [images])
    # print()
    # print(images[1])

    print("------------------- train calling interference ---------------------")
    print(cifar10.__file__)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    for step in xrange(FLAGS.max_steps):

      # manually load the contents of images and labels
      # before calling this sess.run()
      # 1. have Cifar10 dataset in memory
      # 2. create a mini-batch
      # 3. set the placeholders/vars to the the mini-batch data
      # 4. run one forward-backward step

      # print("training step: " + str(step))

      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

      # debug, temp change, go back to the one below
      summary_str = sess.run(summary_op)
      # print("summary: " + summary_str)
      summary_writer.add_summary(summary_str, step)
      summary_writer.flush()

      # if step % 100 == 0:
      #   summary_str = sess.run(summary_op)
      #   # print("summary: " + summary_str)
      #   summary_writer.add_summary(summary_str, step)
      #   summary_writer.flush()

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
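The comments at the top of the training loop above describe a feed_dict-driven variant (dataset held in memory, one mini-batch per step) that this snippet never actually implements. A compressed sketch of that path follows; the placeholder shapes and the next_minibatch() helper are hypothetical.

    images_pl = tf.placeholder(tf.float32, [FLAGS.batch_size, 24, 24, 3])  # shape assumed
    labels_pl = tf.placeholder(tf.int32, [FLAGS.batch_size])
    logits = cifar10.inference(images_pl)
    loss = cifar10.loss(logits, labels_pl)
    train_op = cifar10.train(loss, global_step)

    for step in xrange(FLAGS.max_steps):
      batch_im, batch_labs = next_minibatch()  # hypothetical in-memory batcher
      _, loss_value = sess.run([train_op, loss],
                               feed_dict={images_pl: batch_im,
                                          labels_pl: batch_labs})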
Example n. 39
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
    # GPU and resulting in a slow down.
    with tf.device('/cpu:0'):
      images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits, transformations, alignment_loss, transformations_regularizer = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)
#    alignment_loss /= (24**2)
    #transformations_regularizer /= (24**2)
    loss *= 1000
    loss += alignment_loss+transformations_regularizer
#    loss += alignment_loss

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
Example n. 40
def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            #      images, labels = cifar10.distorted_inputs()
            images, labels = cifar10.inputs(False)

        (x_train, y_train_orl), (x_test, y_test_orl) = dset.cifar10.load_data()
        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')
        x_train, x_test = normalize(x_train, x_test)

        y_train_orl = y_train_orl.astype('int32')
        y_test_orl = y_test_orl.astype('int32')
        y_train_flt = y_train_orl.ravel()
        y_test_flt = y_test_orl.ravel()

        print("image and lables:", images, labels)
        print("xtrian and ytrain:", type(x_train), x_train.shape,
              x_train.dtype, type(y_train_orl), y_train_orl.shape,
              y_train_orl.dtype)
        #    exit(0)

        x = tf.placeholder(tf.float32, shape=(FLAGS.batch_size, 32, 32, 3))
        y = tf.placeholder(tf.int32, shape=(FLAGS.batch_size, ))
        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits, _ = cifar10.inference(x)

        # Calculate loss.
        loss = cifar10.loss(logits, y)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0 and self._step > 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(log_device_placement=FLAGS.
                                      log_device_placement)) as mon_sess:
            step = 0
            f = open('tl.json', 'w')
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            time_begin = time.time()
            while not mon_sess.should_stop():
                offset = (step * FLAGS.batch_size) % (EPOCH_SIZE -
                                                      FLAGS.batch_size)
                x_data = x_train[offset:(offset + FLAGS.batch_size), ...]
                y_data_flt = y_train_flt[offset:(offset + FLAGS.batch_size)]
                mon_sess.run(train_op, feed_dict={
                    x: x_data,
                    y: y_data_flt
                })  #, options=run_options, run_metadata=run_metadata)
                #        tl = timeline.Timeline(run_metadata.step_stats)
                #       ctf = tl.generate_chrome_trace_format()
                #    f.write(ctf)
                step += 1
                if (step + 1 == FLAGS.max_steps):
                    predt(mon_sess, x_test, y_test_flt, logits, x, y)
            time_end = time.time()
            training_time = time_end - time_begin
            print("Training elapsed time: %f s" % training_time)
            f.close()
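The example above calls a normalize(x_train, x_test) helper that is not shown. One plausible implementation, offered only as an assumption, standardises both splits with per-channel statistics computed on the training set:

import numpy as np

def normalize(x_train, x_test):
    """Sketch: per-channel standardisation using training-set mean/std."""
    mean = x_train.mean(axis=(0, 1, 2), keepdims=True)
    std = x_train.std(axis=(0, 1, 2), keepdims=True) + 1e-7
    return (x_train - mean) / std, (x_test - mean) / std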
def main_fun(argv, ctx):
  import tensorflow as tf
  import cifar10

  sys.argv = argv
  FLAGS = tf.app.flags.FLAGS
  tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                             """Directory where to write event logs """
                             """and checkpoint.""")
  tf.app.flags.DEFINE_integer('max_steps', 1000000,
                              """Number of batches to run.""")
  tf.app.flags.DEFINE_boolean('log_device_placement', False,
                              """Whether to log device placement.""")
  tf.app.flags.DEFINE_boolean('rdma', False, """Whether to use rdma.""")

  # cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)

  cluster_spec, server = TFNode.start_cluster_server(ctx, 1, FLAGS.rdma)

  # Train CIFAR-10 for a number of steps.
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1

      def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step % 10 == 0:
          num_examples_per_step = FLAGS.batch_size
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
Example n. 42
def train(lambs):
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)
        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        lambs = tf.constant(lambs, dtype=tf.float32)
        loss = infor.loss(logits, labels, lambs)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op, lr_op = cifar10.train(loss, global_step)

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables())

        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement))

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.merge_all_summaries()

        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Assuming model_checkpoint_path looks something like:
            #   /my-favorite-path/cifar10_train/model.ckpt-0,
            # extract global_step from it.
            previous_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print('Resuming training from step %d' % previous_step)
        else:
            print('Starting training from step 0')
            previous_step = 0
            init = tf.initialize_all_variables()
            sess.run(init)

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                                graph_def=sess.graph_def)

        step = previous_step
        while step <= FLAGS.max_steps:
            start_time = time.time()
            _, loss_value, lr_value = sess.run([train_op, loss, lr_op])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 10 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = ('%s: step %d, loss = %.2f, lr_value = %.4f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value, lr_value,
                                    examples_per_sec, sec_per_batch))

            if step % 100 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            step += 1
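The resume logic above recovers the global step from a checkpoint path of the form .../model.ckpt-NNNN. A small helper, sketched here only for illustration, that does the same parsing and returns an integer:

import os

def step_from_checkpoint_path(path):
    """Sketch: extract NNNN from a path like /tmp/cifar10_train/model.ckpt-1000."""
    return int(os.path.basename(path).split('-')[-1])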
Example n. 43
def run_training():
    """Train MNIST for a number of steps."""
    # Get the sets of images and labels for training, validation, and
    # test on MNIST.
    data_sets = input_data.read_data_sets(FLAGS.input_data_dir,
                                          FLAGS.fake_data)

    ## Get images and labels for CIFAR-10.
    images1, labels1 = cifar10.distorted_inputs()

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():
        # Generate placeholders for the images and labels.
        images_placeholder, labels_placeholder = placeholder_inputs(
            FLAGS.batch_size)

        ## for cifar10
        images_placeholder1, labels_placeholder1 = placeholder_inputs1(
            FLAGS.batch_size)

        global_step1 = tf.contrib.framework.get_or_create_global_step()

        # Build a Graph that computes predictions from the inference model.
        logits = mnist.inference(images_placeholder, FLAGS.hidden1,
                                 FLAGS.hidden2)

        ## Build a Graph that computes the logits predictions from the
        ## inference model.
        logits1 = cifar10.inference(images_placeholder1)

        # Add to the Graph the Ops for loss calculation.
        loss = mnist.loss(logits, labels_placeholder)

        ## Calculate loss.
        loss1 = cifar10.loss(logits1, labels1)

        # Add to the Graph the Ops that calculate and apply gradients.
        train_op = mnist.training(loss, FLAGS.learning_rate)

        ## Build a Graph that trains the model with one batch of examples and
        ## updates the model parameters.
        train_op1 = cifar10.train(loss1, global_step1)

        # Add the Op to compare the logits to the labels during evaluation.
        eval_correct = mnist.evaluation(logits, labels_placeholder)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.
        saver = tf.train.Saver()

        # Create a session for running Ops on the Graph.
        sess = tf.Session()

        # Instantiate a SummaryWriter to output summaries and the Graph.
        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

        # And then after everything is built:

        # Run the Op to initialize the variables.
        sess.run(init)

        # Start the training loop.
        for step in xrange(FLAGS.max_steps):
            start_time = time.time()

            # Fill a feed dictionary with the actual set of images and labels
            # for this particular training step.
            feed_dict = fill_feed_dict(data_sets.train, images_placeholder,
                                       labels_placeholder)

            feed_dict1 = {
                images_placeholder1: images1,
                labels_placeholder1: labels1,
            }

            # Run one step of the model.  The return values are the activations
            # from the `train_op` (which is discarded) and the `loss` Op.  To
            # inspect the values of your Ops or variables, you may include them
            # in the list passed to sess.run() and the value tensors will be
            # returned in the tuple from the call.
            _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

            duration = time.time() - start_time

            # Write the summaries and print an overview fairly often.
            if step % 100 == 0:
                # Print status to stdout.
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, loss_value, duration))
                # Update the events file.
                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=step)
                # Evaluate against the training set.
                print('Training Data Eval:')
                do_eval(sess, eval_correct, images_placeholder,
                        labels_placeholder, data_sets.train)
                # Evaluate against the validation set.
                print('Validation Data Eval:')
                do_eval(sess, eval_correct, images_placeholder,
                        labels_placeholder, data_sets.validation)
                # Evaluate against the test set.
                print('Test Data Eval:')
                do_eval(sess, eval_correct, images_placeholder,
                        labels_placeholder, data_sets.test)
def main(_):

    class _LoggerHook(tf.train.SessionRunHook):
        """Logs loss and runtime."""

        def begin(self):
            self._step = -1

        def before_run(self, run_context):
            self._step += 1
            self._start_time = time.time()
            return tf.train.SessionRunArgs(loss)  # Asks for loss value.

        def after_run(self, run_context, run_values):
            duration = time.time() - self._start_time
            loss_value = run_values.results
            if self._step % 10 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                            'sec/batch)')
                print (format_str % (datetime.now(), self._step, loss_value,
                                    examples_per_sec, sec_per_batch))
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create and start a server for the local task.
    server = tf.train.Server(cluster,
                            job_name=FLAGS.job_name,
                            task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":

        # Assigns ops to the local worker by default.
        with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):

            global_step = tf.contrib.framework.get_or_create_global_step()

            # Get images and labels for CIFAR-10.
            images, labels = cifar10.distorted_inputs()

            # Build inference Graph.
            logits = cifar10.inference(images)

            # Build the portion of the Graph calculating the loss.
            loss = cifar10.loss(logits, labels)

            # Build a Graph that trains the model with one batch of examples and
            # updates the model parameters.
            train_op = cifar10.train(loss,global_step)

        # The StopAtStepHook handles stopping after running given steps.
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), _LoggerHook()]

        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing when done
        # or an error occurs.
        with tf.train.MonitoredTrainingSession(master=server.target,
                                                is_chief=(FLAGS.task_index == 0),
                                                checkpoint_dir=FLAGS.train_dir,
                                                save_checkpoint_secs=60,
                                                hooks=hooks) as mon_sess:
            while not mon_sess.should_stop():
                # Run a training step asynchronously.
                # See `tf.train.SyncReplicasOptimizer` for additional details on how to
                # perform *synchronous* training.
                # mon_sess.run handles AbortedError in case of preempted PS.
                mon_sess.run(train_op)
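The distributed worker above reads FLAGS.ps_hosts, FLAGS.worker_hosts, FLAGS.job_name and FLAGS.task_index, but their definitions are not part of the snippet. One possible set of definitions, following the tf.app.flags pattern used in the earlier examples; the flag names match the snippet, the defaults are assumptions:

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('ps_hosts', 'localhost:2222',
                           """Comma-separated list of parameter server hosts.""")
tf.app.flags.DEFINE_string('worker_hosts', 'localhost:2223,localhost:2224',
                           """Comma-separated list of worker hosts.""")
tf.app.flags.DEFINE_string('job_name', 'worker', """Either 'ps' or 'worker'.""")
tf.app.flags.DEFINE_integer('task_index', 0, """Index of the task within its job.""")

# Hypothetical launch commands for a one-ps / two-worker cluster:
#   python cifar10_distributed_train.py --job_name=ps --task_index=0
#   python cifar10_distributed_train.py --job_name=worker --task_index=0
#   python cifar10_distributed_train.py --job_name=worker --task_index=1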