def _get_accum_apply_and_agg_grad(var_op, grad, indices, dense_shape):
    # `num_accum_required` is expected to come from the enclosing scope: the
    # number of gradients that must arrive before one aggregate can be taken.
    if indices is None:
        # Dense gradient: accumulate the full tensor.
        tensor = variable_utils.get_read_var_tensor(var_op)
        grad_accum = data_flow_ops.ConditionalAccumulator(
            grad.dtype,
            shape=tensor.get_shape(),
            shared_name=var_op.name + "/grad_accum")
        # Get a copy of consumers list before creating accum_apply_op
        grad_consumers = list(grad.consumers())
        accum_apply_op = grad_accum.apply_grad(
            grad, local_step=MAX_INT64,
            name=grad.op.name + '_accum_apply_grad')
        agg_grad = grad_accum.take_grad(
            num_accum_required, name=var_op.name + '_take_grad')
        update_consumers(grad_consumers, grad, agg_grad)
        update_control_consumers(get_control_consumers(grad.op),
                                 grad.op, agg_grad.op)
    else:
        # Sparse gradient: accumulate IndexedSlices.
        grad_indexed_slices = ops.IndexedSlices(
            values=grad, indices=indices, dense_shape=dense_shape)
        grad_accum = data_flow_ops.SparseConditionalAccumulator(
            grad.dtype,
            shape=grad.shape,
            shared_name=var_op.name + "/grad_accum")
        # Get a copy of consumers list before creating accum_apply_op
        indices_consumers = list(indices.consumers())
        grad_consumers = list(grad.consumers())
        accum_apply_op = grad_accum.apply_indexed_slices_grad(
            grad_indexed_slices, local_step=MAX_INT64,
            name=grad.op.name + '_accum_apply_grad')
        agg_grad = grad_accum.take_indexed_slices_grad(
            num_accum_required, name=var_op.name + '_take_grad')
        agg_indices = agg_grad.indices
        if indices.dtype != agg_grad.indices.dtype:
            agg_indices = math_ops.cast(agg_grad.indices, indices.dtype)
        agg_grad = ops.IndexedSlices(values=agg_grad.values,
                                     indices=agg_indices,
                                     dense_shape=agg_grad.dense_shape)
        assert isinstance(agg_grad, ops.IndexedSlices)
        update_consumers(indices_consumers, indices, agg_grad.indices)
        update_consumers(grad_consumers, grad, agg_grad.values)
        update_control_consumers(get_control_consumers(indices.op),
                                 indices.op, agg_grad.indices.op)
        update_control_consumers(get_control_consumers(grad.op),
                                 grad.op, agg_grad.values.op)
    return accum_apply_op, agg_grad
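# A minimal, hedged sketch of the dense branch above: the accumulator
# round-trip in isolation. `worker_grads` is a hypothetical list of
# same-shaped gradient tensors, one per worker; TF1 graph mode is assumed,
# as in the synchronizer. Note the real code passes local_step=MAX_INT64 so
# that applied gradients are never rejected as stale.
def _demo_dense_accumulation(worker_grads):
    accum = data_flow_ops.ConditionalAccumulator(
        worker_grads[0].dtype,
        shape=worker_grads[0].get_shape(),
        shared_name='demo/grad_accum')
    # Each apply op contributes one gradient; take_grad blocks until
    # len(worker_grads) gradients have arrived, then emits their mean.
    apply_ops = [accum.apply_grad(g, local_step=0) for g in worker_grads]
    with ops.control_dependencies(apply_ops):
        return accum.take_grad(num_required=len(worker_grads))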
def _collect_sparse_gradients(self, graph_item, var_op_name):
    """Append collective ops after the gradient is calculated."""
    if self.num_workers > 1 and not ENV.AUTODIST_INTERNAL_TF.value:
        raise NotImplementedError(
            'Currently the collective NCCL AllGather is not supported in the TensorFlow release. '
            'Please choose another strategy.')
    conf = {}
    if self._spec:
        conf = {'communication_hint': self._spec}
    if self._compressor_type:
        logging.warning(
            'AllGather currently does not support AutoDist compressors; skipping compression.')
    if self.num_replicas * self.num_workers <= 1:
        raise ValueError('CollectiveOps requires collective group size > 1')
    for i in range(self.num_replicas):
        op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
        graph_item.updated = True
        grad, _, _ = graph_item.var_op_name_to_grad_info_v2[op_name]
        # TODO (Tairui): (3) Merge of reduction for performance
        indices_c_ops = grad.indices.consumers()
        indices_cc_ops = get_control_consumers(grad.indices.op)
        values_c_ops = grad.values.consumers()
        values_cc_ops = get_control_consumers(grad.values.op)
        with ops.name_scope(replica_prefix(i)):
            with ops.colocate_with(grad.indices.op):
                new_indices = collective_ops.all_gather(
                    grad.indices,
                    self.num_replicas * self.num_workers,
                    get_collective_keys().get_group_key(self.all_canonical_replica_devices),
                    get_collective_keys().get_instance_key(var_op_name + '-indices'),
                    **conf)
            with ops.colocate_with(grad.values.op):
                new_values = collective_ops.all_gather(
                    grad.values,
                    self.num_replicas * self.num_workers,
                    get_collective_keys().get_group_key(self.all_canonical_replica_devices),
                    get_collective_keys().get_instance_key(var_op_name + '-values'),
                    **conf)
        update_consumers(indices_c_ops, grad.indices, new_indices)
        update_control_consumers(indices_cc_ops, grad.indices.op, new_indices.op)
        update_consumers(values_c_ops, grad.values, new_values)
        update_control_consumers(values_cc_ops, grad.values.op, new_values.op)
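# A hedged sketch of the collective call above, reduced to its essentials.
# The group/instance keys are hypothetical constants; the real code derives
# them from get_collective_keys(). Every replica must run the same op with
# matching keys, and the result is the concatenation of all replicas'
# inputs along axis 0 (hence AllGather, not AllReduce, for sparse grads).
def _demo_all_gather_indexed_slices(grad, group_size):
    new_indices = collective_ops.all_gather(
        grad.indices, group_size, group_key=0, instance_key=1)
    new_values = collective_ops.all_gather(
        grad.values, group_size, group_key=0, instance_key=2)
    return ops.IndexedSlices(values=new_values, indices=new_indices,
                             dense_shape=grad.dense_shape)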
def _get_optimizer_source_op(update_op):
    """
    Identify the additional no_op between update_op and train_op, if it exists (for certain optimizers).

    Args:
        update_op: the op that applies a gradient to a variable.

    Returns:
        source_op: the no_op if it exists, otherwise update_op itself.
    """
    group_deps = [
        op for op in get_control_consumers(update_op)
        if 'Adam' in op.name and 'group_deps' in op.name and op.type == 'NoOp'
    ]
    source_op = group_deps[0] if group_deps else update_op
    return source_op
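# A hedged sketch of the graph pattern this helper matches: TF1-style Adam
# groups its per-variable updates (and beta-power updates) under a NoOp
# whose name ends in 'group_deps', and the train op hangs off that NoOp.
# The op names below are hypothetical.
def _demo_adam_group_deps_pattern():
    with ops.Graph().as_default():
        update_op = control_flow_ops.no_op(name='demo/update_w')
        train_op = control_flow_ops.group(update_op, name='Adam/group_deps')
        # train_op is a NoOp that control-depends on update_op, so the
        # helper walks one control edge up and returns it.
        assert _get_optimizer_source_op(update_op) is train_op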
def _prune_control_dependencies(self, graph_item, var_op_name, master_replica=0):
    """
    Prune the control dependencies between the train_op on a non-master replica and its update op.

    Since the replicator replicates the entire graph, the update ops on non-master replicas
    are replicated as well. If the train_op on a non-master replica is fetched (which is the
    case in our current feed-fetch remap implementation), it triggers those update ops and
    causes unnecessary updates to the trainable variables. This function prunes the control
    dependencies between the train_op and any variable synchronized by a PS syncer to avoid
    this situation.
    """
    for i in range(self.num_replicas):
        if i == master_replica:
            continue
        this_var_op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
        _, _, update_op = graph_item.var_op_name_to_grad_info[this_var_op_name]
        source_op = self._get_optimizer_source_op(update_op)
        remove_from_control_consumers(get_control_consumers(source_op), source_op)
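# A hedged sketch of what remove_from_control_consumers plausibly does (the
# real helper lives in AutoDist's utils and may differ): drop `source_op`
# from the control inputs of every op that currently waits on it, using
# private TF graph-editing APIs.
def _demo_remove_control_edges(source_op):
    for consumer in get_control_consumers(source_op):
        kept = [c for c in consumer.control_inputs if c is not source_op]
        consumer._remove_all_control_inputs()  # private TF API
        consumer._add_control_inputs(kept)     # private TF API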
def add_sync_op(self, graph_item, var_update_op, variable_replicator=None):
    """
    Add the ops needed for synchronous distributed training into the current graph.

    The main purposes of the additional ops are:
    1. Initialization
    2. Synchronization
    3. Gradient aggregation

    Args:
        graph_item (graph_item.GraphItem): the graph.
        var_update_op: the op that updates the variable with the aggregated gradient.
        variable_replicator: a dictionary of master variable op name -> list of
            replicated variables; may be None.

    Returns:
        None
    """
    this_worker_cpu = device_spec.DeviceSpecV2.from_string(self.worker_device)
    this_worker_cpu = this_worker_cpu.replace(device_type='CPU', device_index=0)

    var_op = var_update_op.inputs[UPDATE_OP_VAR_POS].op
    is_trainable = var_op in graph_item.trainable_var_op_to_var
    source_op = self._get_optimizer_source_op(var_update_op)
    cc = get_control_consumers(source_op)

    with ops.device(var_op.device):
        if self._staleness == 0:
            queue_ops = self._get_queue_ops(var_update_op, source_op,
                                            self.is_chief, is_trainable)
        elif self._staleness > 0:
            queue_ops = self._get_queue_ops_stale(var_update_op, source_op,
                                                  self.is_chief, is_trainable)
        else:
            raise ValueError("staleness should be greater than or equal to 0.")

        # Only dense trainable variables are replicated locally
        if variable_replicator:
            mirror_variable_update_ops = variable_replicator.get_all_update_ops(
                queue_ops, worker_device=this_worker_cpu)
            with ops.device(this_worker_cpu):
                finish_op = control_flow_ops.group(*mirror_variable_update_ops)
        else:
            finish_op = control_flow_ops.group(*queue_ops)

    # Place computation ops of aggregated gradients on PS.
    # Note that even though this is doing a graph traversal, it is called in such a way that it
    # only traverses from a gradient aggregator op to a gradient application op (or vice versa) --
    # these corresponding ops should always be adjacent in the graph.
    self._place_post_grad_agg_ops(
        device_spec.DeviceSpecV2.from_string(self.target_device),
        self._var_op_to_agg_grad,
        {var_op: var_update_op} if is_trainable else {})

    # Replace the control input of train_op with finish_op.
    # Note(Hao): this cc is stale, i.e. cc is a subset of get_control_consumers(source_op)
    update_control_consumers(cc, source_op, finish_op)
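# A hedged sketch of one common shape for the queue ops built above; the
# actual _get_queue_ops may differ. The chief enqueues one token per
# non-chief worker once the variable update has run, and every other worker
# blocks on a dequeue, so no worker proceeds before the chief's update
# lands. `var_name` and the queue's shared_name are hypothetical.
def _demo_sync_token_queue_ops(var_update_op, is_chief, num_workers, var_name):
    from tensorflow.python.framework import constant_op, dtypes
    token_queue = data_flow_ops.FIFOQueue(
        capacity=num_workers - 1,
        dtypes=[dtypes.bool],
        shapes=[[]],
        shared_name=var_name + '/sync_token_q')
    if is_chief:
        tokens = constant_op.constant([False] * (num_workers - 1))
        with ops.control_dependencies([var_update_op]):
            return [token_queue.enqueue_many(tokens)]
    return [token_queue.dequeue()]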