# TensorFlow 1.x examples. All snippets assume the CIFAR-10 tutorial
# module `cifar10` is importable, along with these common imports (later
# examples additionally use math, os, pickle, numpy as np, and
# datetime.datetime).
import re
import time

import tensorflow as tf

import cifar10


def tower_loss(self, scope, images, labels):
        """Calculate the total loss on a single tower running the CIFAR model.
        
        Args:
          scope: unique prefix string identifying the CIFAR tower, e.g. 'tower_0'
          images: Images. 4D tensor of shape [batch_size, height, width, 3].
          labels: Labels. 1D tensor of shape [batch_size].
        
        Returns:
           Tensor of shape [] containing the total loss for a batch of data
        """

        # Build inference Graph.
        logits = cifar10.inference(images)

        # Build the portion of the Graph calculating the losses. Note that we will
        # assemble the total_loss using a custom function below.
        _ = cifar10.loss(logits, labels)

        # Assemble all of the losses for the current tower only.
        losses = tf.get_collection('losses', scope)

        # Calculate the total loss for the current tower.
        total_loss = tf.add_n(losses, name='total_loss')

        # Attach a scalar summary to all individual losses and the total loss; do the
        # same for the averaged version of the losses.
        for l in losses + [total_loss]:
            # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
            # session. This helps the clarity of presentation on tensorboard.
            loss_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '',
                               l.op.name)
            tf.summary.scalar(loss_name, l)

        return total_loss
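
For context, a minimal sketch of how `tower_loss` is typically driven in the multi-GPU training loop (written as if inside the same class, hence `self.tower_loss`; `num_gpus`, `images`, and `labels` are assumed to exist in the caller):

for i in range(num_gpus):
    with tf.device('/gpu:%d' % i):
        # Each tower gets a 'tower_i/' name scope; this prefix is what the
        # re.sub call above strips from the summary names.
        with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
            loss = self.tower_loss(scope, images, labels)
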
def evaluate():
    """Eval CIFAR-10 for a number of steps."""
    with tf.Graph().as_default() as g:
        # Get images and labels for CIFAR-10.
        eval_data = FLAGS.eval_data == 'test'
        images, labels = cifar10.inputs(eval_data=eval_data)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate predictions.
        top_k_op = tf.nn.in_top_k(logits, labels, 1)

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            cifar10.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

        while True:
            eval_once(saver, summary_writer, top_k_op, summary_op)
            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)
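
`evaluate` depends on an `eval_once` helper that this example does not include. The following is a minimal sketch consistent with the stock CIFAR-10 tutorial and with `eval_central_model` at the end of this page; the `FLAGS.checkpoint_dir`, `FLAGS.num_examples`, and `FLAGS.batch_size` flags and the `math`, `numpy` (as `np`), and `datetime` imports are assumptions:

def eval_once(saver, summary_writer, top_k_op, summary_op):
    """Run one pass over the eval data and report precision @ 1 (sketch)."""
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print('No checkpoint file found')
            return
        # Restore variables and recover the global step from the
        # checkpoint path (e.g. .../model.ckpt-1000 -> 1000).
        saver.restore(sess, ckpt.model_checkpoint_path)
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]

        coord = tf.train.Coordinator()
        try:
            threads = []
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(qr.create_threads(sess, coord=coord,
                                                 daemon=True, start=True))

            num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
            true_count = 0  # Counts the number of correct predictions.
            step = 0
            while step < num_iter and not coord.should_stop():
                predictions = sess.run([top_k_op])
                true_count += np.sum(predictions)
                step += 1

            # Compute precision @ 1 and log it as a summary.
            precision = true_count / (num_iter * FLAGS.batch_size)
            print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
            summary = tf.Summary()
            summary.ParseFromString(sess.run(summary_op))
            summary.value.add(tag='Precision @ 1', simple_value=precision)
            summary_writer.add_summary(summary, global_step)
        except Exception as e:  # pylint: disable=broad-except
            coord.request_stop(e)
        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)
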
Example #3
def train(self, name, X_train, Y_train, X_val, Y_val):
    # Assumes the graph (placeholders plus the train/loss/accuracy ops) has
    # already been built once, e.g. by buildModel in Example #4 below.
    # Creating ops such as cifar10.inference or cifar10.train inside the
    # loop, as the original did, would add new nodes to the graph on every
    # iteration.
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    num_training = len(X_train)
    batch_size = cifar10.FLAGS.batch_size
    step = 0
    losses = []
    accuracies = []
    print('-' * 5 + '  Start training  ' + '-' * 5)

    # `num_epoch` is assumed to be defined at module level.
    for epoch in range(num_epoch):
        print('train for epoch %d' % epoch)
        for i in range(num_training // batch_size):
            X_ = X_train[i * batch_size:(i + 1) * batch_size]
            Y_ = Y_train[i * batch_size:(i + 1) * batch_size]

            feed_dict = {self.images: X_, self.labels: Y_}
            fetches = [self.train_op, self.loss, self.accuracy]

            _, loss, accuracy = sess.run(fetches, feed_dict=feed_dict)
            losses.append(loss)
            accuracies.append(accuracy)
            #if step % self.log_step == 0:
            #    print('iteration (%d): loss = %.3f, accuracy = %.3f' %
            #          (step, loss, accuracy))
            step += 1

        # Print validation results and checkpoint the model once per epoch.
        print('validation for epoch %d' % epoch)
        #val_accuracy = self.evaluate(sess, X_val, Y_val)
        #print('-  epoch %d: validation accuracy = %.3f' % (epoch, val_accuracy))
        saver = tf.train.Saver()
        model_path = saver.save(sess, name + '.ckpt')
Example #4
def buildModel(self):
    # Placeholders
    self.images = tf.placeholder(tf.float32, [None, 32, 32, 3])
    self.labels = tf.placeholder(tf.int64, [None])
    self.global_step = tf.contrib.framework.get_or_create_global_step()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    self.logits = cifar10.inference(self.images)

    # Calculate loss.
    self.loss = cifar10.loss(self.logits, self.labels)

    # Top-1 accuracy. (The original called an `eval` helper that is not
    # shown and shadows the Python builtin; an explicit op is used here.)
    correct = tf.nn.in_top_k(self.logits, self.labels, 1)
    self.accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    self.train_op = cifar10.train(self.loss, self.global_step)
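
Hypothetical glue between this `buildModel` and the `train` method from Example #3; the `Model` class name is an invention for illustration, since the page only shows the methods:

model = Model()      # hypothetical class exposing buildModel() and train()
model.buildModel()   # build placeholders and the train/loss/accuracy ops once
model.train('cifar10', X_train, Y_train, X_val, Y_val)
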
def evaluate(eval_data, eval_dir, model_dir, client):
    """Eval CIFAR-10 for a number of steps."""
    FLAGS = parser.parse_args()  # `parser`: an argparse.ArgumentParser assumed at module level

    with tf.Graph().as_default() as g:
        # Get images and labels for CIFAR-10.

        images, labels = cifar10.inputs(eval_data=eval_data, client=client)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate predictions.
        predict = tf.argmax(logits, 1, output_type=tf.int32)
        correct = tf.equal(predict, labels)
        accuracy_op = tf.reduce_mean(tf.cast(correct, tf.float32))
        top_k_op = tf.nn.in_top_k(logits, labels, 1)
        loss = cifar10.loss(logits, labels)
        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            cifar10.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(eval_dir, g)

        while True:
            eval_once(saver, summary_writer, top_k_op, accuracy_op, summary_op,
                      loss, model_dir, eval_dir)
            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)
Example #6
def train(client):
    """Train CIFAR-10 for a number of steps."""
    FLAGS = parser.parse_args()
    model_dir = modelDir(FLAGS.train_dir, client)
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs(client)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

            def end(self, mon_sess):
                # Collect the final value of every trainable variable.
                weights = []
                for t in tf.trainable_variables():
                    print(t)
                    weights.append(t.eval(session=mon_sess))
                print(weights)
                # MonitoredTrainingSession ignores this return value; see
                # the note after this example.
                return weights

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=model_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
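
Note that `MonitoredTrainingSession` discards whatever `SessionRunHook.end` returns, so `return weights` above has no visible effect. A sketch of the usual workaround, stashing results on the hook instance:

logger_hook = _LoggerHook()
with tf.train.MonitoredTrainingSession(checkpoint_dir=model_dir,
                                       hooks=[logger_hook]) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)
# If end() assigned `self.final_weights = weights` instead of returning,
# the values would be available here as logger_hook.final_weights.
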
    def train(self):
      """Train CIFAR-10 for a number of steps."""
      client = self.index
      FLAGS = parser.parse_args()
      model_dir = self.modelDir(FLAGS.train_dir, client)

      with tf.Graph().as_default():   
        global_step = tf.contrib.framework.get_or_create_global_step()
    
        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
          images, labels = cifar10.distorted_inputs(client)
    
        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)
    
        # Calculate loss.
        loss = cifar10.loss(logits, labels)
    
        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)
        saver = tf.train.Saver()
        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            
            def begin(self):
                self._step = -1
                self._start_time = time.time()
                self.loss = []
            
            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.
            
            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time
            
                    loss_value = run_values.results
                    self.loss.append(loss_value)
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)
            
                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

            def end(self, mon_sess):
                # Append this run's losses to any previously pickled ones.
                filename = FLAGS.loss_dir + "loss_client" + str(client) + ".pkl"

                old_loss = []
                if os.path.exists(filename):
                    with open(filename, 'rb') as rfp:
                        old_loss = pickle.load(rfp)
                old_loss.append(self.loss)

                with open(filename, "wb") as fp:  # Pickling
                    pickle.dump(old_loss, fp)

                cifar10_eval.evaluate(True,
                                      "./models/cifar/eval/client" + str(client),
                                      FLAGS.train_dir + str(client), client)

                # Collect the final trainable-variable values.
                weights = []
                for t in tf.trainable_variables():
                    weights.append(t.eval(session=mon_sess))

        with tf.train.MonitoredTrainingSession(
            checkpoint_dir=model_dir,
            hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                   tf.train.NanTensorHook(loss),
                   _LoggerHook()],
            config=tf.ConfigProto(
                log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            #save_path = saver.restore(mon_sess._sess._sess._sess._sess, "/Users/Sara/Dropbox/Class/NN/Project/models/cifar/tfmodel.ckpt")

            # `central_weights` is assumed to be a module-level list holding
            # the central model's weights; compare the first conv kernel as a
            # sanity check.
            t = tf.trainable_variables()[0]
            if len(central_weights) > 0:
                d = central_weights[0][0][0][0] == t.eval(mon_sess)[0][0][0]
                #print("weights", d)

            while not mon_sess.should_stop():
                mon_sess.run(train_op)

def eval_central_model(client, eval_data, eval_dir, checkpoint_dir):
    """Eval CIFAR-10 for a number of steps."""
    FLAGS = parser.parse_args()

    with tf.Graph().as_default() as g:
        # Get images and labels for CIFAR-10.

        images, labels = process_images()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate predictions.
        top_k_op = tf.nn.in_top_k(logits, labels, 1)

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            cifar10.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # Restores from checkpoint
                saver.restore(sess, ckpt.model_checkpoint_path)
                # Assuming model_checkpoint_path looks something like:
                #   /my-favorite-path/cifar10_train/model.ckpt-0,
                # extract global_step from it.
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                    '-')[-1]
            else:
                print('No checkpoint file found')
                return
            coord = tf.train.Coordinator()
            try:
                threads = []
                for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                    threads.extend(
                        qr.create_threads(sess,
                                          coord=coord,
                                          daemon=True,
                                          start=True))

                num_iter = int(math.ceil(FLAGS.num_examples /
                                         FLAGS.batch_size))
                true_count = 0  # Counts the number of correct predictions.
                total_sample_count = num_iter * FLAGS.batch_size
                step = 0

                while step < num_iter and not coord.should_stop():
                    predictions = sess.run([top_k_op])
                    true_count += np.sum(predictions)
                    step += 1

                # Compute precision @ 1.
                precision = true_count / total_sample_count
                print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
                write(eval_dir, precision, client)

            except Exception as e:  # pylint: disable=broad-except
                coord.request_stop(e)
            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=10)
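
A hypothetical invocation of `eval_central_model`, mirroring the directory layout used elsewhere on this page (the paths are placeholders):

eval_central_model(client=0,
                   eval_data=True,
                   eval_dir='./models/cifar/eval/central',
                   checkpoint_dir='./models/cifar/central')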