def _collect_dense_gradients(self, graph_item, var_op_name):
    """Append collective ops after the gradient is calculated."""
    if self.num_replicas * self.num_workers <= 1:
        raise ValueError('CollectiveOps requires collective group size > 1')

    compressors = defaultdict(
        lambda: Compressor.create(self._compressor_type, var_op_name))

    conf = CollectiveOpsConfig()
    conf.group_size = len(self.all_canonical_replica_devices)
    conf.group_key = get_collective_keys().get_group_key(
        self.all_canonical_replica_devices)
    conf.instance_key = get_collective_keys().get_instance_key(var_op_name)
    conf.merge_op = 'Add'
    conf.final_op = 'Div'
    if self._spec:
        setattr(conf, 'communication_hint', self._spec)

    for i in range(0, self.num_replicas):
        op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
        graph_item.updated = True
        grad, _, _ = graph_item.var_op_name_to_grad_info_v2[op_name]
        # TODO (Tairui): (3) Merge of reduction for performance
        grad_consumers = get_consumers(grad.op)  # this line must happen before the reduction
        # "/" is added for name scope reuse
        with ops.name_scope(
                replica_prefix(i) + "/collective-group-{}/".format(self._group)):
            with ops.colocate_with(grad.op):
                reduced_grad = compressors[i].reduce(grad, conf)
        update_consumers(grad_consumers, grad, reduced_grad)
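# A minimal sketch, and an assumption rather than the actual Compressor implementation, of
# how `compressors[i].reduce(grad, conf)` above is expected to map a CollectiveOpsConfig
# onto TensorFlow's collective all-reduce when no compression is applied. The helper name
# `_all_reduce_sketch` is hypothetical.
def _all_reduce_sketch(tensor, conf):
    from tensorflow.python.ops import collective_ops as _collective_ops
    return _collective_ops.all_reduce(
        tensor,
        conf.group_size,    # number of participating replica devices
        conf.group_key,     # identifies the device group
        conf.instance_key,  # identifies this particular all-reduce instance
        conf.merge_op,      # 'Add': sum the gradients across replicas
        conf.final_op,      # 'Div': divide by the group size, i.e. average
        communication_hint=getattr(conf, 'communication_hint', 'auto'))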
def _update_consumer(self, other, consumer_op):
    """
    Update consumer.

    Args:
        other (Operation): Other resource var op.
        consumer_op (Operation): The new consumer.
    """
    old_read_var_op = self._consumer_to_read_var_op[consumer_op]
    new_read_var_op = self._read_var_ops_mappings[other][old_read_var_op]
    update_consumers(
        [consumer_op],
        old_tensor=old_read_var_op.outputs[0],
        new_tensor=new_read_var_op.outputs[0])
def _update_gradient_consumers(new_graph_item, consumer_ops, control_consumer_ops,
                               old_tensor_name, new_tensor):
    """Make gradient's consumers consume the aggregated gradient instead of the original one of replica_0."""
    # Get the original tensor (the one from replica 0) to replace
    old_op_name = strip_replica_prefix(get_op_name(old_tensor_name))
    replica_0_op_name = ops.prepend_name_scope(old_op_name, replica_prefix(0))
    replica_0_op = new_graph_item.graph.get_operation_by_name(replica_0_op_name)
    output_idx = get_index_from_tensor_name(old_tensor_name)
    replica_0_tensor = replica_0_op.outputs[output_idx]

    update_consumers(consumer_ops, replica_0_tensor, new_tensor)
    update_control_consumers(control_consumer_ops, replica_0_tensor.op, new_tensor.op)
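# A minimal sketch, assuming (not quoting) the behavior of the `update_consumers` helper
# used throughout this file: every consumer input that currently refers to `old_tensor` is
# rewired to `new_tensor` through TensorFlow's private `Operation._update_input` API. The
# name `_update_consumers_sketch` is hypothetical and only illustrates the rewiring.
def _update_consumers_sketch(consumer_ops, old_tensor, new_tensor):
    for consumer_op in consumer_ops:
        for idx, inp in enumerate(consumer_op.inputs):
            if inp.name == old_tensor.name:
                # pylint: disable=protected-access
                consumer_op._update_input(idx, new_tensor)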
def _collect_sparse_gradients(self, graph_item, var_op_name):
    """Append collective ops after the gradient is calculated."""
    if self.num_workers > 1 and not ENV.AUTODIST_INTERNAL_TF.value:
        raise NotImplementedError(
            'Currently the collective NCCL AllGather is not supported in the TensorFlow release. '
            'Please choose another strategy.')
    conf = {}
    if self._spec:
        conf = {'communication_hint': self._spec}
    if self._compressor_type:
        logging.warning('AllGather currently does not support AutoDist compressors, so compression is skipped.')
    if self.num_replicas * self.num_workers <= 1:
        raise ValueError('CollectiveOps requires collective group size > 1')

    for i in range(0, self.num_replicas):
        op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
        graph_item.updated = True
        grad, _, _ = graph_item.var_op_name_to_grad_info_v2[op_name]
        # TODO (Tairui): (3) Merge of reduction for performance
        indices_c_ops = grad.indices.consumers()
        indices_cc_ops = get_control_consumers(grad.indices.op)
        values_c_ops = grad.values.consumers()
        values_cc_ops = get_control_consumers(grad.values.op)
        with ops.name_scope(replica_prefix(i)):
            with ops.colocate_with(grad.indices.op):
                new_indices = collective_ops.all_gather(
                    grad.indices,
                    self.num_replicas * self.num_workers,
                    get_collective_keys().get_group_key(self.all_canonical_replica_devices),
                    get_collective_keys().get_instance_key(var_op_name + '-indices'),
                    **conf)
            with ops.colocate_with(grad.values.op):
                new_values = collective_ops.all_gather(
                    grad.values,
                    self.num_replicas * self.num_workers,
                    get_collective_keys().get_group_key(self.all_canonical_replica_devices),
                    get_collective_keys().get_instance_key(var_op_name + '-values'),
                    **conf)
        update_consumers(indices_c_ops, grad.indices, new_indices)
        update_control_consumers(indices_cc_ops, grad.indices.op, new_indices.op)
        update_consumers(values_c_ops, grad.values, new_values)
        # Pass the op (not the tensor) of new_values, matching the indices case above.
        update_control_consumers(values_cc_ops, grad.values.op, new_values.op)
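# Note on the sparse path above: an IndexedSlices gradient is all-gathered component-wise,
# i.e. the indices and values tensors are gathered in two separate collectives so that each
# replica ends up with the concatenation of all replicas' slices. Duplicate indices are
# left as-is; they are effectively summed later when the optimizer applies the
# IndexedSlices with a scatter-add style update.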
def _handle_read(new_graph_item, var_op, partitioned_var):
    """Reroute reads of the original variable to reads of the partitioned variable."""
    partitioned_var_tensor = partitioned_var.as_tensor()
    for op in get_consumers(var_op):
        op = new_graph_item.graph.get_operation_by_name(
            ops.prepend_name_scope(op.name, AUTODIST_TO_DELETE_SCOPE))
        if op.type == "ResourceGather":
            # Only resource variables need to be taken care of here, because ResourceGather
            # consumes the resource tensor rather than the output of a read_var_op.
            # Question: is there any case where op.type == "ResourceGather" but we cannot use
            # embedding_lookup_v2 to reconstruct an op that consumes the partitioned resource?
            # The second input to a ResourceGather op is always the indices, per the op def.
            emb_lookup = embedding_ops.embedding_lookup_v2(partitioned_var, ids=op.inputs[1])
            update_consumers(get_consumers(op), op.outputs[0], emb_lookup)
        if is_read_var_op(op, version=1):
            # Without our modification, reference variables in TF have a read op associated
            # with them. TF can sometimes look for it and expect it to exist
            # (e.g. in graph.as_graph_element), so we add one back to avoid errors.
            # `read_out` is the output tensor of the generated identity op.
            read_out = array_ops.identity(partitioned_var_tensor,
                                          name=ops.prepend_name_scope("read", var_op.name))
            update_consumers(get_consumers(op), op.outputs[0], read_out)
        elif is_read_var_op(op, version=2):
            read_out = array_ops.identity(partitioned_var_tensor,
                                          name=ops.prepend_name_scope("Read/ReadVariableOp", var_op.name))
            update_consumers(get_consumers(op), op.outputs[0], read_out)
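# Note on the ResourceGather branch above: `embedding_ops.embedding_lookup_v2` accepts a
# list of shards (tensors of the same shape except along axis 0) as well as a single
# tensor, so rebuilding the gather as an embedding lookup lets the same consumer read from
# a partitioned variable transparently. Hypothetical illustration of the equivalence:
#     shards = [tf.zeros([5, 16]), tf.zeros([5, 16])]   # two partitions of a 10 x 16 table
#     out = tf.nn.embedding_lookup(shards, [0, 7])      # behaves like a gather on the full table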
def _get_accum_apply_and_agg_grad(var_op, grad, indices, dense_shape):
    """Create a gradient accumulator for `var_op` and reroute gradient consumers to the aggregated gradient."""
    # `num_accum_required` is not defined in this function; it is captured from the enclosing scope.
    if indices is None:
        tensor = variable_utils.get_read_var_tensor(var_op)
        grad_accum = data_flow_ops.ConditionalAccumulator(
            grad.dtype,
            shape=tensor.get_shape(),
            shared_name=var_op.name + "/grad_accum")
        # Get a copy of consumers list before creating accum_apply_op
        grad_consumers = list(grad.consumers())
        accum_apply_op = grad_accum.apply_grad(
            grad, local_step=MAX_INT64,
            name=grad.op.name + '_accum_apply_grad')
        agg_grad = grad_accum.take_grad(num_accum_required,
                                        name=var_op.name + '_take_grad')
        update_consumers(grad_consumers, grad, agg_grad)
        update_control_consumers(get_control_consumers(grad.op),
                                 grad.op, agg_grad.op)
    else:
        grad_indexed_slices = ops.IndexedSlices(
            values=grad, indices=indices, dense_shape=dense_shape)
        grad_accum = data_flow_ops.SparseConditionalAccumulator(
            grad.dtype,
            shape=grad.shape,
            shared_name=var_op.name + "/grad_accum")
        # Get a copy of consumers list before creating accum_apply_op
        indices_consumers = list(indices.consumers())
        grad_consumers = list(grad.consumers())
        accum_apply_op = grad_accum.apply_indexed_slices_grad(
            grad_indexed_slices, local_step=MAX_INT64,
            name=grad.op.name + '_accum_apply_grad')
        agg_grad = grad_accum.take_indexed_slices_grad(
            num_accum_required, name=var_op.name + '_take_grad')
        agg_indices = agg_grad.indices
        if indices.dtype != agg_grad.indices.dtype:
            agg_indices = math_ops.cast(agg_grad.indices, indices.dtype)
        agg_grad = ops.IndexedSlices(values=agg_grad.values,
                                     indices=agg_indices,
                                     dense_shape=agg_grad.dense_shape)
        assert isinstance(agg_grad, ops.IndexedSlices)
        update_consumers(indices_consumers, indices, agg_grad.indices)
        update_consumers(grad_consumers, grad, agg_grad.values)
        update_control_consumers(get_control_consumers(indices.op),
                                 indices.op, agg_grad.indices.op)
        update_control_consumers(get_control_consumers(grad.op),
                                 grad.op, agg_grad.values.op)
    return accum_apply_op, agg_grad
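# A minimal sketch, in TF1-style graph mode, of the ConditionalAccumulator semantics the
# function above relies on: each worker applies its local gradient, `take_grad(n)` blocks
# until `n` gradients have been applied and then returns their mean, and passing MAX_INT64
# as `local_step` keeps applied gradients from being dropped as stale. The function name
# `_accumulator_semantics_sketch` is hypothetical and purely illustrative.
def _accumulator_semantics_sketch():
    import tensorflow.compat.v1 as tf1
    tf1.disable_eager_execution()
    accum = tf1.ConditionalAccumulator(dtype=tf1.float32, shape=())
    apply_a = accum.apply_grad(tf1.constant(1.0), local_step=0)
    apply_b = accum.apply_grad(tf1.constant(3.0), local_step=0)
    avg = accum.take_grad(2)  # waits until two gradients have been applied, then averages
    with tf1.Session() as sess:
        sess.run([apply_a, apply_b])
        print(sess.run(avg))  # 2.0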
def _share_variable(self, graph_item, var_op_name, master_replica=0):
    """
    Share the variable on replica = `master_replica` (default 0).

    Update the inputs of consumers of the variable on replicas > 0 to use the
    variable on replica=`master_replica`.

    Args:
        graph_item: the old graph item
        var_op_name: the name of the variable op of the variable to be shared
        master_replica: the index of the master replica (default 0)
    """
    for i in range(0, self.num_replicas):
        if i == master_replica:
            continue
        this_var_op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
        this_var_op = graph_item.graph.get_operation_by_name(this_var_op_name)

        # Get all read variable ops of this replica's variable
        read_var_ops = get_read_var_ops(this_var_op)

        # Get all consumers of its VarHandleOp,
        # excluding ReadVariableOps and those not in its variable scope
        handle_consumers = set(get_consumers(this_var_op))
        handle_consumers.difference_update(set(read_var_ops))
        handle_consumers.difference_update(
            {con for con in handle_consumers if con.name.startswith(this_var_op_name + '/')})
        # Exclude the `update_op` when updating the consumers of the shared variable, because
        # i) sharing a variable implies sharing its stateful ops correspondingly (so it is fine
        #    to remove the stateful ops on non-master replicas; here we just disconnect them), and
        # ii) a variable cannot correspond to more than one update op for now.
        handle_consumers.difference_update(set(graph_item.all_update_ops))

        # Update the consumers of all read variable ops to use the read variable ops
        # of replica=master_replica
        for read_var_op in read_var_ops:
            new_read_var_op_name = ops.prepend_name_scope(
                ops.strip_name_scope(read_var_op.name, replica_prefix(i)),
                replica_prefix(master_replica))
            new_read_var_op = graph_item.graph.get_operation_by_name(new_read_var_op_name)
            consumers = get_consumers(read_var_op)
            update_consumers(consumers, read_var_op.outputs[0], new_read_var_op.outputs[0])
            update_colocation_group(consumers, read_var_op, new_read_var_op)

        # Update the consumers of the VarHandleOp to use the handle on replica=master_replica
        new_handle_op_name = ops.prepend_name_scope(
            ops.strip_name_scope(this_var_op_name, replica_prefix(i)),
            replica_prefix(master_replica))
        new_handle_op = graph_item.graph.get_operation_by_name(new_handle_op_name)
        handle_consumers = list(handle_consumers)
        update_consumers(handle_consumers, this_var_op.outputs[0], new_handle_op.outputs[0])
        update_colocation_group(handle_consumers, this_var_op, new_handle_op)