def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Apply gradients to variables.

  This contains most of the synchronization implementation and also wraps the
  apply_gradients() from the real optimizer.

  Each dense gradient is pushed into a per-variable FIFO queue; the aggregated
  (summed) gradients are then applied by the wrapped optimizer. A token queue
  is used to release the replicas for the next step, and a per-replica local
  step is compared against `global_step` to drop stale gradients.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      compute_gradients(). A gradient may be a Tensor, an IndexedSlices or
      None (None is passed through to the wrapped optimizer).
    global_step: Variable to increment by one after the variables have been
      updated. Although it defaults to None for signature compatibility, it
      is required here: the staleness check cannot be built without it.
    name: Optional name for the returned operation.

  Returns:
    train_op: The op to dequeue a token so the replicas can exit this batch
    and start the next one. This is executed by each replica.

  Raises:
    ValueError: If the grads_and_vars is empty.
    ValueError: If global step is not provided, the staleness cannot be
      checked.
    ValueError: If a gradient is neither a Tensor, an IndexedSlices nor None.
  """
  if not grads_and_vars:
    raise ValueError("Must supply at least one variable")

  if global_step is None:
    raise ValueError("Global step is required to check staleness")

  self._global_step = global_step
  train_ops = []
  aggregated_grad = []
  inputs = []
  var_list = []
  # Flatten every gradient and variable into the values list handed to the
  # name scope below.
  for x in grads_and_vars:
    inputs.extend(list(x))

  # One local-step slot per replica, colocated with the global step.
  with ops.device(global_step.device):
    self._local_steps = variables.Variable(
        array_ops.zeros(
            [self._total_num_replicas],
            dtype=global_step.dtype),
        trainable=False,
        name="local_steps")

  # Check staleness. Note that this has to be ref(), otherwise identity will
  # be accessed and it will be old values.
  local_step = array_ops.slice(self._local_steps.ref(),
                               array_ops.reshape(self._replica_id, (1,)),
                               [1],
                               name="get_local_step")
  local_step = array_ops.reshape(local_step, ())
  is_stale = math_ops.less(local_step, global_step)

  with ops.name_scope(None, self._name, inputs):
    for grad, var in grads_and_vars:
      var_list.append(var)
      with ops.device(var.device):
        if isinstance(grad, ops.Tensor):
          # Dense gradient: one shared queue per variable, placed on the
          # variable's device.
          gradient_queue = (data_flow_ops.FIFOQueue(self._tokens_per_step * 2,
                                                    grad.dtype,
                                                    shapes=var.get_shape(),
                                                    shared_name=var.name))
          self._one_element_queue_list.append((gradient_queue, var.device))
          train_ops.append(gradient_queue.enqueue([grad]))

          # Aggregate all gradients
          gradients = gradient_queue.dequeue_many(
              self._replicas_to_aggregate)
          aggregated_grad.append(math_ops.reduce_sum(gradients, [0]))
        elif grad is None:
          aggregated_grad.append(None)  # pass-through.
        else:
          if not isinstance(grad, ops.IndexedSlices):
            raise ValueError("Unknown grad type!")
          aggregated_grad.append(self._aggregate_sparse_grad(grad, var,
                                                             train_ops))

    # Materialize the pairs: zip() is a one-shot iterator on Python 3, which
    # would break if the wrapped optimizer iterates over them more than once.
    aggregated_grads_and_vars = list(zip(aggregated_grad, var_list))

    # sync_op will be assigned to the same device as the global step.
    with ops.device(global_step.device), ops.name_scope(""):
      update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
                                            global_step)

    # Create token queue.
    with ops.device(global_step.device), ops.name_scope(""):
      sync_token_queue = (
          data_flow_ops.FIFOQueue(-1,
                                  global_step.dtype.base_dtype,
                                  shapes=(),
                                  shared_name="sync_token_q"))
      self._sync_token_queue = sync_token_queue

      # dummy_queue is passed to the queue runner. Don't use the real queues
      # because the queue runner doesn't automatically reopen it once it
      # closed queues in PS devices.
      dummy_queue = (
          data_flow_ops.FIFOQueue(1,
                                  types_pb2.DT_INT32,
                                  shapes=(),
                                  shared_name="dummy_queue"))

    # Clear all the gradients queues in case there are stale gradients.
    clear_queue_ops = []
    with ops.control_dependencies([update_op]):
      for queue, dev in self._one_element_queue_list:
        with ops.device(dev):
          stale_grads = queue.dequeue_many(queue.size())
          clear_queue_ops.append(stale_grads)

      for queue, dev in self._sparse_grad_queues_and_devs:
        with ops.device(dev):
          _, stale_indices = queue.dequeue_many(queue.size())
          clear_queue_ops.append(stale_indices)

    with ops.device(global_step.device):
      self._clean_up_op = control_flow_ops.abort(
          error_msg="From sync_replicas")

    # According to the staleness, select between the enqueue op (real_grad)
    # or no-op (no_op_grad). Effectively dropping all the stale gradients.
    no_op_grad = lambda: [control_flow_ops.no_op(name="no_grad_enqueue")]
    real_grad = lambda: [control_flow_ops.group(*train_ops)]
    final_train_ops = control_flow_ops.cond(is_stale, no_op_grad, real_grad)

    with ops.device(global_step.device), ops.name_scope(""):
      # Replicas have to wait until they can get a token from the token queue.
      with ops.control_dependencies([final_train_ops]):
        token = sync_token_queue.dequeue()
        # Record the dequeued token in this replica's local-step slot. The
        # caller-supplied `name` is applied to this returned op.
        train_op = state_ops.scatter_update(self._local_steps,
                                            self._replica_id,
                                            token,
                                            name=name)

      with ops.control_dependencies(clear_queue_ops):
        # Sync_op needs to insert tokens to the token queue at the end of the
        # step so the replicas can fetch them to start the next step.
        # Note that ref() is used to avoid reading the old step value from
        # the identity.
        tokens = array_ops.fill([self._tokens_per_step], global_step.ref())
        sync_op = sync_token_queue.enqueue_many((tokens,))

      if self._variable_averages is not None:
        # Update the variable averages only after the tokens are enqueued.
        with ops.control_dependencies([sync_op]), ops.name_scope(""):
          sync_op = self._variable_averages.apply(
              self._variables_to_average)

      self._chief_queue_runner = queue_runner.QueueRunner(dummy_queue,
                                                          [sync_op])
    self._gradients_applied = True
    return train_op
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Apply gradients to variables.

  This contains most of the synchronization implementation and also wraps the
  apply_gradients() from the real optimizer.

  Each dense gradient is pushed into a per-variable FIFO queue; the aggregated
  (summed) gradients are then applied by the wrapped optimizer. A token queue
  is used to release the replicas for the next step, and a per-replica local
  step is compared against `global_step` to drop stale gradients.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      compute_gradients(). A gradient may be a Tensor, an IndexedSlices or
      None (None is passed through to the wrapped optimizer).
    global_step: Variable to increment by one after the variables have been
      updated. Although it defaults to None for signature compatibility, it
      is required here: the staleness check cannot be built without it.
    name: Optional name for the returned operation.

  Returns:
    train_op: The op to dequeue a token so the replicas can exit this batch
    and start the next one. This is executed by each replica.

  Raises:
    ValueError: If the grads_and_vars is empty.
    ValueError: If global step is not provided, the staleness cannot be
      checked.
    ValueError: If a gradient is neither a Tensor, an IndexedSlices nor None.
  """
  if not grads_and_vars:
    raise ValueError("Must supply at least one variable")

  if global_step is None:
    raise ValueError("Global step is required to check staleness")

  self._global_step = global_step
  train_ops = []
  aggregated_grad = []
  inputs = []
  var_list = []
  # Flatten every gradient and variable into the values list handed to the
  # name scope below.
  for x in grads_and_vars:
    inputs.extend(list(x))

  # One local-step slot per replica, colocated with the global step.
  with ops.device(global_step.device):
    self._local_steps = variables.Variable(
        array_ops.zeros(
            [self._total_num_replicas],
            dtype=global_step.dtype),
        trainable=False,
        name="local_steps")

  # Check staleness. Note that this has to be ref(), otherwise identity will
  # be accessed and it will be old values.
  local_step = array_ops.slice(self._local_steps.ref(),
                               array_ops.reshape(self._replica_id, (1,)),
                               [1],
                               name="get_local_step")
  local_step = array_ops.reshape(local_step, ())
  is_stale = math_ops.less(local_step, global_step)

  # name_scope(name, default_name, values) replaces the deprecated
  # op_scope(values, name, default_name); the arguments are the same ones,
  # reordered.
  with ops.name_scope(None, self._name, inputs):
    for grad, var in grads_and_vars:
      var_list.append(var)
      with ops.device(var.device):
        if isinstance(grad, ops.Tensor):
          # Dense gradient: one shared queue per variable, placed on the
          # variable's device.
          gradient_queue = (data_flow_ops.FIFOQueue(self._tokens_per_step * 2,
                                                    grad.dtype,
                                                    shapes=var.get_shape(),
                                                    shared_name=var.name))
          self._one_element_queue_list.append((gradient_queue, var.device))
          train_ops.append(gradient_queue.enqueue([grad]))

          # Aggregate all gradients
          gradients = gradient_queue.dequeue_many(
              self._replicas_to_aggregate)
          aggregated_grad.append(math_ops.reduce_sum(gradients, [0]))
        elif grad is None:
          aggregated_grad.append(None)  # pass-through.
        else:
          if not isinstance(grad, ops.IndexedSlices):
            raise ValueError("Unknown grad type!")
          aggregated_grad.append(self._aggregate_sparse_grad(grad, var,
                                                             train_ops))

    # Materialize the pairs: zip() is a one-shot iterator on Python 3, which
    # would break if the wrapped optimizer iterates over them more than once.
    aggregated_grads_and_vars = list(zip(aggregated_grad, var_list))

    # sync_op will be assigned to the same device as the global step.
    with ops.device(global_step.device), ops.name_scope(""):
      update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
                                            global_step)

    # Create token queue.
    with ops.device(global_step.device), ops.name_scope(""):
      sync_token_queue = (
          data_flow_ops.FIFOQueue(-1,
                                  global_step.dtype.base_dtype,
                                  shapes=(),
                                  shared_name="sync_token_q"))
      self._sync_token_queue = sync_token_queue

      # dummy_queue is passed to the queue runner. Don't use the real queues
      # because the queue runner doesn't automatically reopen it once it
      # closed queues in PS devices.
      dummy_queue = (
          data_flow_ops.FIFOQueue(1,
                                  types_pb2.DT_INT32,
                                  shapes=(),
                                  shared_name="dummy_queue"))

    # Clear all the gradients queues in case there are stale gradients.
    clear_queue_ops = []
    with ops.control_dependencies([update_op]):
      for queue, dev in self._one_element_queue_list:
        with ops.device(dev):
          stale_grads = queue.dequeue_many(queue.size())
          clear_queue_ops.append(stale_grads)

      for queue, dev in self._sparse_grad_queues_and_devs:
        with ops.device(dev):
          _, stale_indices = queue.dequeue_many(queue.size())
          clear_queue_ops.append(stale_indices)

    with ops.device(global_step.device):
      self._clean_up_op = control_flow_ops.abort(
          error_msg="From sync_replicas")

    # According to the staleness, select between the enqueue op (real_grad)
    # or no-op (no_op_grad). Effectively dropping all the stale gradients.
    no_op_grad = lambda: [control_flow_ops.no_op(name="no_grad_enqueue")]
    real_grad = lambda: [control_flow_ops.group(*train_ops)]
    final_train_ops = control_flow_ops.cond(is_stale, no_op_grad, real_grad)

    with ops.device(global_step.device), ops.name_scope(""):
      # Replicas have to wait until they can get a token from the token queue.
      with ops.control_dependencies([final_train_ops]):
        token = sync_token_queue.dequeue()
        # Record the dequeued token in this replica's local-step slot. The
        # caller-supplied `name` is applied to this returned op.
        train_op = state_ops.scatter_update(self._local_steps,
                                            self._replica_id,
                                            token,
                                            name=name)

      with ops.control_dependencies(clear_queue_ops):
        # Sync_op needs to insert tokens to the token queue at the end of the
        # step so the replicas can fetch them to start the next step.
        # Note that ref() is used to avoid reading the old step value from
        # the identity.
        tokens = array_ops.fill([self._tokens_per_step], global_step.ref())
        sync_op = sync_token_queue.enqueue_many((tokens,))

      if self._variable_averages is not None:
        # Update the variable averages only after the tokens are enqueued.
        with ops.control_dependencies([sync_op]), ops.name_scope(""):
          sync_op = self._variable_averages.apply(
              self._variables_to_average)

      self._chief_queue_runner = queue_runner.QueueRunner(dummy_queue,
                                                          [sync_op])
    self._gradients_applied = True
    return train_op