  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This contains most of the synchronization implementation and also wraps the
    apply_gradients() from the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Defaults to the
        name passed to the Optimizer constructor.

    Returns:
      train_op: The op to dequeue a token so the replicas can exit this batch
      and start the next one. This is executed by each replica.

    Raises:
      ValueError: If grads_and_vars is empty.
      ValueError: If global_step is None, since staleness cannot be checked
        without it.
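
    Example:
      A minimal usage sketch; `loss` and `global_step` are assumed to be
      defined by the caller, and the wrapper is assumed to be exported as
      `tf.train.SyncReplicasOptimizer`:

      ```python
      # Wrap a plain optimizer so gradients from 50 replicas are aggregated
      # before each update is applied.
      opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
      opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
                                           total_num_replicas=50)
      grads_and_vars = opt.compute_gradients(loss)
      # Each replica runs train_op; it blocks until a token is available.
      train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
      ```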
    """
    if not grads_and_vars:
      raise ValueError("Must supply at least one variable")

    if global_step is None:
      raise ValueError("Global step is required to check staleness")

    self._global_step = global_step
    train_ops = []
    aggregated_grad = []
    inputs = []
    var_list = []
    for x in grads_and_vars:
      inputs.extend(list(x))

    with ops.device(global_step.device):
      self._local_steps = variables.Variable(
          array_ops.zeros(
              [self._total_num_replicas],
              dtype=global_step.dtype),
          trainable=False,
          name="local_steps")

    # Check staleness. Note that this has to read through ref(); otherwise the
    # identity op would be accessed and it could return old values.
    local_step = array_ops.slice(self._local_steps.ref(),
                                 array_ops.reshape(self._replica_id, (1,)),
                                 [1],
                                 name="get_local_step")
    local_step = array_ops.reshape(local_step, ())
    is_stale = math_ops.less(local_step, global_step)

    with ops.name_scope(None, self._name, inputs):
      for grad, var in grads_and_vars:
        var_list.append(var)
        with ops.device(var.device):
          if isinstance(grad, ops.Tensor):
            gradient_queue = (data_flow_ops.FIFOQueue(self._tokens_per_step * 2,
                                                      grad.dtype,
                                                      shapes=var.get_shape(),
                                                      shared_name=var.name))
            self._one_element_queue_list.append((gradient_queue, var.device))
            train_ops.append(gradient_queue.enqueue([grad]))

            # Aggregate all gradients
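            # dequeue_many blocks until replicas_to_aggregate gradients have
            # been enqueued for this variable; the dequeued gradients are then
            # summed along the replica dimension.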
            gradients = gradient_queue.dequeue_many(
                self._replicas_to_aggregate)
            aggregated_grad.append(math_ops.reduce_sum(gradients, [0]))
          elif grad is None:
            aggregated_grad.append(None)  # pass-through.
          else:
            if not isinstance(grad, ops.IndexedSlices):
              raise ValueError("Unknown grad type!")
            aggregated_grad.append(self._aggregate_sparse_grad(grad, var,
                                                               train_ops))

      aggregated_grads_and_vars = zip(aggregated_grad, var_list)

      # The update op is assigned to the same device as the global step.
      with ops.device(global_step.device), ops.name_scope(""):
        update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
                                              global_step)

      # Create token queue.
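      # Each replica dequeues one token per step from this queue. Tokens are
      # only enqueued (by sync_op below) after the aggregated update has been
      # applied, which is what gates replicas on the completed update.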
      with ops.device(global_step.device), ops.name_scope(""):
        sync_token_queue = (
            data_flow_ops.FIFOQueue(-1,
                                    global_step.dtype.base_dtype,
                                    shapes=(),
                                    shared_name="sync_token_q"))
        self._sync_token_queue = sync_token_queue

        # dummy_queue is passed to the queue runner instead of the real queues
        # because the queue runner does not automatically reopen queues on PS
        # devices once they have been closed.
        dummy_queue = (
            data_flow_ops.FIFOQueue(1,
                                    types_pb2.DT_INT32,
                                    shapes=(),
                                    shared_name="dummy_queue"))
      # Clear all the gradient queues in case there are any stale gradients.
      clear_queue_ops = []
      with ops.control_dependencies([update_op]):
        for queue, dev in self._one_element_queue_list:
          with ops.device(dev):
            stale_grads = queue.dequeue_many(queue.size())
            clear_queue_ops.append(stale_grads)

        for queue, dev in self._sparse_grad_queues_and_devs:
          with ops.device(dev):
            _, stale_indices = queue.dequeue_many(queue.size())
            clear_queue_ops.append(stale_indices)

      with ops.device(global_step.device):
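        # Running this op raises an error ("From sync_replicas"); it is stored
        # so that clean-up logic can use it to break replicas out of blocking
        # queue operations, e.g. on shutdown.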
        self._clean_up_op = control_flow_ops.abort(
            error_msg="From sync_replicas")

      # Depending on the staleness, select either the enqueue ops (real_grad)
      # or a no-op (no_op_grad), effectively dropping all the stale gradients.
      no_op_grad = lambda: [control_flow_ops.no_op(name="no_grad_enqueue")]
      real_grad = lambda: [control_flow_ops.group(*train_ops)]
      final_train_ops = control_flow_ops.cond(is_stale, no_op_grad, real_grad)

      with ops.device(global_step.device), ops.name_scope(""):
        # Replicas have to wait until they can get a token from the token queue.
        with ops.control_dependencies([final_train_ops]):
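          # The dequeued token holds the global step at which the tokens were
          # enqueued; writing it into this replica's local_steps slot is what
          # the staleness check above compares against.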
          token = sync_token_queue.dequeue()
          train_op = state_ops.scatter_update(self._local_steps,
                                              self._replica_id, token)

        with ops.control_dependencies(clear_queue_ops):
          # sync_op inserts tokens into the token queue at the end of the step
          # so the replicas can fetch them to start the next step. Note that
          # ref() is used to avoid reading a stale step value through the
          # identity op.
          tokens = array_ops.fill([self._tokens_per_step], global_step.ref())
          sync_op = sync_token_queue.enqueue_many((tokens,))
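          # tokens_per_step tokens are pushed so that every replica, not only
          # the ones whose gradients made it into the aggregation, can dequeue
          # one and move on to the next step.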

        if self._variable_averages is not None:
          with ops.control_dependencies([sync_op]), ops.name_scope(""):
            sync_op = self._variable_averages.apply(
                self._variables_to_average)

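        # The chief is expected to start this queue runner; it repeatedly runs
        # sync_op, and dummy_queue is only a placeholder (see the note above).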
        self._chief_queue_runner = queue_runner.QueueRunner(dummy_queue,
                                                            [sync_op])
        self._gradients_applied = True
        return train_op