def train():
    """Train CIFAR-10 for ten steps using a two-phase partial run.

    Each step runs one ``Session.partial_run`` with two phases:

    1. Fetch the gradients produced by ``cifar10.train_part1``.
    2. Feed those gradient values back through placeholders into the update
       op built by ``cifar10.train_part2``.
    """
    g1 = tf.Graph()
    with g1.as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)
        grads = cifar10.train_part1(loss, global_step)

        only_gradients = [g for g, _ in grads]

        # One placeholder per gradient, paired with the variable it updates,
        # so the update op consumes gradient values supplied at run time.
        placeholder_gradients = [
            (tf.placeholder('float', shape=grad.get_shape()), var)
            for grad, var in grads
        ]
        feeds = [ph for ph, _ in placeholder_gradients]

        train_op = cifar10.train_part2(global_step, placeholder_gradients)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # NOTE(review): cifar10.distorted_inputs() presumably builds a
            # queue-based input pipeline (as in the TF CIFAR-10 tutorial). A
            # plain Session does not start queue runners, so fetches would
            # block. This is a no-op if the pipeline registers no queue runners.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for _ in range(10):
                    # A partial-run handle lets each feed and each fetch be
                    # used only once, so every step needs a fresh handle.
                    h = sess.partial_run_setup([only_gradients, train_op], feeds)
                    # Phase 1: compute gradients. The placeholders are fed
                    # only in phase 2; feeding them here as well would be a
                    # "fed twice" error.
                    res_grads = sess.partial_run(h, only_gradients)
                    # Phase 2: apply the gradients just computed.
                    sess.partial_run(h, train_op,
                                     feed_dict=dict(zip(feeds, res_grads)))
            finally:
                coord.request_stop()
                coord.join(threads)
def train():
  """Train CIFAR-10 as a worker that syncs variables with two parameter servers.

  The trainable variables are split at the module-level ``half_index``. The
  first ``half_index`` of them are exchanged with the server on
  ``port_main_1``/``port_ps1``; the rest go to the server on
  ``port_main_2``/``port_ps2``.

  Initial values are fetched once from the ``port_main_*`` ports. After that,
  every step:

  - feeds the current values into the graph and computes the gradients;
  - sends each half of the gradients to its server over the persistent
    ``port_ps*`` connections;
  - receives the updated values to feed into the next step.

  Training stops at ``FLAGS.max_steps`` via ``StopAtStepHook``.
  """

  def _connect(port):
    """Open a TCP connection to (TCP_IP, port); close the socket if connect fails."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
      sock.connect((TCP_IP, port))
    except Exception:
      sock.close()
      raise
    return sock

  def _send_msg(sock, obj):
    """Send obj as a pickled length header followed by the pickled payload."""
    payload = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
    sock.sendall(pickle.dumps(len(payload), pickle.HIGHEST_PROTOCOL))
    sock.sendall(payload)

  def _recv_msg(sock):
    """Receive one message framed as by _send_msg and return the unpickled object."""
    # NOTE(review): the length header is read as exactly 8 bytes. That only
    # works if the peer's pickled length is exactly 8 bytes long, which
    # depends on the value and the pickle protocol (it holds e.g. for
    # Python 2 protocol-2 ints in [65536, 2**31)). Kept as-is because the
    # server must agree on the framing; confirm against the server side.
    # SECURITY: pickle.loads on network data can execute arbitrary code.
    # Only talk to trusted servers.
    size = pickle.loads(safe_recv(8, sock))
    return pickle.loads(safe_recv(size, sock))

  def _fetch_once(port):
    """Connect to port, receive a single message, and close the connection."""
    sock = _connect(port)
    try:
      return _recv_msg(sock)
    finally:
      sock.close()

  def _make_feed_dict(variables, vals_1, vals_2):
    """Map each variable to its value: index < half_index from vals_1, the rest from vals_2."""
    return {v: (vals_1[i] if i < half_index else vals_2[i - half_index])
            for i, v in enumerate(variables)}

  g1 = tf.Graph()
  with g1.as_default():
    # A plain int32 counter, used instead of get_or_create_global_step.
    # It is bumped by one on every training step; starting at -1 makes the
    # first step's increment yield 0.
    global_step = tf.Variable(-1, name='global_step', trainable=False, dtype=tf.int32)
    increment_global_step_op = tf.assign(global_step, global_step + 1)

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs2()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)
    grads = cifar10.train_part1(loss, global_step)

    only_gradients = [g for g, _ in grads]

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement, gpu_options=gpu_options)) as mon_sess:
      # Initial variable values: one half from each server.
      var_vals_1 = _fetch_once(port_main_1)
      var_vals_2 = _fetch_once(port_main_2)

      # The graph is finalized inside the session, so this list is fixed.
      train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
      feed_dict = _make_feed_dict(train_vars, var_vals_1, var_vals_2)
      print("Received variable values from ps")

      ps_socks = []
      try:
        ps_socks.append(_connect(port_ps1))
        ps_socks.append(_connect(port_ps2))
        s1, s2 = ps_socks
        print("Connected to both PSs")
        while not mon_sess.should_stop():
          gradients, _ = mon_sess.run(
              [only_gradients, increment_global_step_op], feed_dict=feed_dict)

          # Send each half of the gradients to the server that owns it...
          _send_msg(s1, [g for i, g in enumerate(gradients) if i < half_index])
          _send_msg(s2, [g for i, g in enumerate(gradients) if i >= half_index])

          # ...then wait for both servers' updated variable values.
          var_vals_1 = _recv_msg(s1)
          var_vals_2 = _recv_msg(s2)
          feed_dict = _make_feed_dict(train_vars, var_vals_1, var_vals_2)
      finally:
        # Close the PS connections even if training raises.
        for sock in ps_socks:
          sock.close()
# Example no. 3
def train():
    """Train CIFAR-10 by round-tripping gradients through a remote peer.

    Each step:

    - computes the gradients for one batch;
    - sends them, pickled and length-prefixed, over ``s``;
    - receives a list of gradient values back;
    - feeds those into the placeholder-driven update op built by
      ``cifar10.train_part2``.

    ``s`` is not defined here; presumably it is a module-level connected
    socket (verify against the rest of the file).
    """
    g1 = tf.Graph()
    with g1.as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)
        grads = cifar10.train_part1(loss, global_step)

        only_gradients = [g for g, _ in grads]

        # One placeholder per gradient, paired with its variable, so the update
        # op consumes gradient values supplied from outside the graph. The
        # placeholders are created after `grads`, so the gradient computation
        # never depends on them and they need not be fed when fetching it.
        placeholder_gradients = [
            (tf.placeholder('float', shape=grad.get_shape()), var)
            for grad, var in grads
        ]
        placeholders = [ph for ph, _ in placeholder_gradients]

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train_part2(global_step, placeholder_gradients)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement,
                    gpu_options=gpu_options)) as mon_sess:

            # NOTE(review): hooks run on every mon_sess.run call, and each step
            # below makes two calls. So _LoggerHook._step advances twice per
            # training step, and the loss the hooks request is also evaluated
            # during the train_op run.
            while not mon_sess.should_stop():
                gradients = mon_sess.run(only_gradients)

                # Send the pickled size first, then the pickled gradients.
                send_data = pickle.dumps(gradients, pickle.HIGHEST_PROTOCOL)
                s.sendall(pickle.dumps(len(send_data), pickle.HIGHEST_PROTOCOL))
                s.sendall(send_data)

                # Receive the size (fixed 8-byte pickled header), then the
                # payload. SECURITY: pickle.loads on network data can execute
                # arbitrary code; only use with a trusted peer.
                recv_size = pickle.loads(safe_recv(8, s))
                gradients2 = pickle.loads(safe_recv(recv_size, s))

                if len(gradients2) != len(placeholders):
                    raise ValueError(
                        'Expected %d gradients from peer, got %d'
                        % (len(placeholders), len(gradients2)))

                mon_sess.run(train_op,
                             feed_dict=dict(zip(placeholders, gradients2)))