def _get_aggregated_sparse_grad(self, graph_item, var_op, grad, reduce_to_device, BFTaggregator):
    """
    Collect every replica's copy of a sparse gradient and aggregate them.

    Args:
        graph_item: graph item whose graph contains the replicated gradient ops.
        var_op: the variable op the gradient belongs to.
        grad: sparse gradient; its `indices`, `values` and `dense_shape` tensors are read.
        reduce_to_device: device handed to the sparse aggregation.
        BFTaggregator: accepted for signature parity with the dense path; unused here.

    Returns:
        The result of `self._aggregate_sparse_gradients`.
    """
    graph = graph_item.graph

    # Replica-agnostic op names of the three IndexedSlices components.
    indices_base, values_base, shape_base = [
        strip_replica_prefix(get_op_name(component.name))
        for component in (grad.indices, grad.values, grad.dense_shape)
    ]

    per_replica_slices = []
    for replica_id in range(self.num_replicas):
        scope = replica_prefix(replica_id)
        indices_op, values_op, shape_op = [
            graph.get_operation_by_name(ops.prepend_name_scope(base_name, scope))
            for base_name in (indices_base, values_base, shape_base)
        ]
        # Pick the same output slot on the replica op as the original tensor used.
        replica_values = values_op.outputs[utils.get_index_from_tensor_name(grad.values.name)]
        replica_indices = indices_op.outputs[utils.get_index_from_tensor_name(grad.indices.name)]
        replica_shape = shape_op.outputs[utils.get_index_from_tensor_name(grad.dense_shape.name)]
        per_replica_slices.append(
            ops.IndexedSlices(replica_values, replica_indices, replica_shape))

    return self._aggregate_sparse_gradients(
        var_op, reduce_to_device, per_replica_slices, values_base)
def _share_initializer(self, graph_item, var_op_name, master_replica=0):
    """
    Make every replica variable initialize from the master replica's initial value.

    Args:
        graph_item: graph item holding the replicated graph.
        var_op_name: variable op name without replica prefix.
        master_replica: index of the replica whose initial value is shared (default 0).
    """
    def _initial_value_tensor(replica_id):
        # Resolve replica variable op -> variable object -> its initial-value tensor.
        replica_var_op = graph_item.graph.get_operation_by_name(
            ops.prepend_name_scope(var_op_name, replica_prefix(replica_id)))
        replica_var = graph_item.trainable_var_op_to_var[replica_var_op]
        return graph_item.graph.get_tensor_by_name(replica_var.initial_value.name)

    master_init_tensor = _initial_value_tensor(master_replica)
    master_init_op = master_init_tensor.op

    # Pin the master init op to task 0 (the chief device).
    chief_device = device_spec.DeviceSpecV2.from_string(master_init_op.device).replace(task=0)
    master_init_op._set_device_from_string(chief_device.to_string())

    for replica_id in range(self.num_replicas):
        if replica_id != master_replica:
            replica_init_op = _initial_value_tensor(replica_id).op
            # The first consumer of the init op is treated as the initializing assign;
            # its input 1 is redirected to the master's initial value.
            assign_op = get_consumers(replica_init_op)[0]
            assign_op._update_input(1, master_init_tensor)
def _collect_dense_gradients(self, graph_item, var_op_name):
    """
    Append collective reduce ops after each replica's dense gradient is calculated.

    Args:
        graph_item: graph item whose gradient info is looked up and whose graph is rewired.
        var_op_name: variable op name without replica prefix.

    Raises:
        ValueError: if `num_replicas * num_workers` is not greater than 1.
    """
    if self.num_replicas * self.num_workers <= 1:
        raise ValueError('CollectiveOps requires collective group size > 1')

    conf = CollectiveOpsConfig()
    conf.group_size = len(self.all_canonical_replica_devices)
    conf.group_key = get_collective_keys().get_group_key(self.all_canonical_replica_devices)
    conf.instance_key = get_collective_keys().get_instance_key(var_op_name)
    conf.merge_op = 'Add'
    conf.final_op = 'Div'
    if self._spec:
        conf.communication_hint = self._spec

    for replica_id in range(self.num_replicas):
        scope = replica_prefix(replica_id)
        replica_var_op_name = ops.prepend_name_scope(var_op_name, scope)
        graph_item.updated = True
        grad, _, _ = graph_item.var_op_name_to_grad_info_v2[replica_var_op_name]
        # TODO (Tairui): (3) Merge of reduction for performance
        # Consumers must be captured before the reduction op is created.
        consumers_before_reduce = get_consumers(grad.op)

        # The trailing "/" is added for name scope reuse.
        group_scope = scope + "/collective-group-{}/".format(self._group)
        with ops.name_scope(group_scope), ops.colocate_with(grad.op):
            # One compressor instance per replica.
            compressor = Compressor.create(self._compressor_type, var_op_name)
            reduced_grad = compressor.reduce(grad, conf)
        update_consumers(consumers_before_reduce, grad, reduced_grad)
def _get_aggregated_dense_grad(self, graph_item, grad_name, reduce_to_device, BFTaggregator):
    """
    Aggregate the replicated copies of a dense gradient with `BFTaggregator`.

    Args:
        graph_item: graph item whose graph contains the replicated gradient ops.
        grad_name: tensor name of the gradient (with or without replica prefix).
        reduce_to_device: device on which the aggregation ops are created (usually CPU).
        BFTaggregator: object exposing `aggregate(list_of_tensors)`.

    Returns:
        Whatever `BFTaggregator.aggregate` returns for the per-replica gradient tensors.
    """
    grad_op_name = strip_replica_prefix(get_op_name(grad_name))
    output_idx = get_index_from_tensor_name(grad_name)
    grad_ops = [
        graph_item.graph.get_operation_by_name(
            ops.prepend_name_scope(grad_op_name, replica_prefix(i)))
        for i in range(self.num_replicas)
    ]

    # Aggregate gradients on `reduce_to_device` (usually CPU)
    with ops.device(reduce_to_device):
        gradients = [grad_op.outputs[output_idx] for grad_op in grad_ops]
        grad_avg = BFTaggregator.aggregate(gradients)
    return grad_avg
def between_graph_apply(self, graph_item, var_name):
    """
    Apply between-graph synchronization to the target ops in the graph.

    Args:
        graph_item: The current graph.
        var_name: the variable to be synchronized.

    Returns:
        graph_item.GraphItem: updated graph item (returned untouched when syncing is disabled).
    """
    if not self._sync:
        return graph_item

    # The variable on replica 0 has been shared, so the original var_name won't work.
    master_var_op_name = ops.prepend_name_scope(get_op_name(var_name), replica_prefix(0))
    gradient, target, update_op = graph_item.var_op_name_to_grad_info[master_var_op_name]

    graph = graph_item.graph
    with graph.as_default():
        proxy = None
        if self._local_replication:
            proxy = self._create_proxy(graph_item, gradient, target)
        if proxy:
            proxy.update_colocation_group(graph_item.get_colocation_op)

        with graph.name_scope(self._BETWEEN_GRAPH_APPLY_SCOPE):
            # With staleness enabled a single gradient is required; otherwise one per worker.
            if self._staleness > 0:
                num_required = 1
            else:
                num_required = self.num_workers
            accumulation = self._get_accumulation_ops(
                graph_item, gradient, target, num_required)
            self._var_op_to_agg_grad, self._var_op_to_accum_apply_op = accumulation
            self.add_sync_op(graph_item, update_op, proxy)
        # Forget the scope name so it can be re-entered without a uniquifying suffix.
        graph._names_in_use.pop(self._BETWEEN_GRAPH_APPLY_SCOPE)
    return graph_item
def in_graph_apply(self, graph_item, var_name):
    """
    Perform in-graph synchronization based on AllReduce and TensorFlow Collective Ops.

    Dense gradients go through `_collect_dense_gradients`; `IndexedSlices` gradients
    go through `_collect_sparse_gradients`.

    Args:
        graph_item (graph_item.GraphItem): the graph_item to be distributed
        var_name (str): the corresponded variable name

    Returns:
        graph_item.GraphItem: The new graph
    """
    # Skip allreduce synchronizer when rank <= 1
    if self.num_replicas * self.num_workers <= 1:
        return graph_item

    var_op_name = get_op_name(var_name)
    master_op_name = ops.prepend_name_scope(var_op_name, replica_prefix(0))

    graph_item.updated = True
    grad, _, _ = graph_item.var_op_name_to_grad_info_v2[master_op_name]
    graph_item.var_queried.append(master_op_name)

    with graph_item.graph.as_default():
        self._share_initializer(graph_item, var_op_name, master_replica=0)
        # Pick the collective path from the gradient type on the master replica.
        if isinstance(grad, ops.IndexedSlices):
            collect = self._collect_sparse_gradients
        else:
            collect = self._collect_dense_gradients
        collect(graph_item, var_op_name)
    return graph_item
def _collect_sparse_gradients(self, graph_item, var_op_name):
    """
    Append collective all-gather ops after each replica's sparse gradient is calculated.

    For every replica, the `indices` and `values` tensors of the gradient are
    all-gathered across the collective group, and the data consumers and control
    consumers of the original tensors/ops are rewired to the gathered ones.

    Args:
        graph_item: graph item whose gradient info is looked up and whose graph is rewired.
        var_op_name: variable op name without replica prefix.

    Raises:
        NotImplementedError: if more than one worker is used and
            `ENV.AUTODIST_INTERNAL_TF` is not set.
        ValueError: if `num_replicas * num_workers` is not greater than 1.
    """
    if self.num_workers > 1 and not ENV.AUTODIST_INTERNAL_TF.value:
        raise NotImplementedError(
            'Currently the collective NCCL AllGather is not supported in TensorFlow release.'
            'Please choose another strategy.')

    conf = {}
    if self._spec:
        conf = {'communication_hint': self._spec}
    if self._compressor_type:
        logging.warning(
            'AllGather currently does not support AutoDist compressor so it skips.'
        )

    group_size = self.num_replicas * self.num_workers
    if group_size <= 1:
        raise ValueError('CollectiveOps requires collective group size > 1')

    for i in range(self.num_replicas):
        op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
        graph_item.updated = True
        grad, _, _ = graph_item.var_op_name_to_grad_info_v2[op_name]
        # TODO (Tairui): (3) Merge of reduction for performance

        # Snapshot the consumers before the all-gather ops exist, so the gather ops
        # (which consume the original tensors) are not rewired onto themselves.
        indices_c_ops = grad.indices.consumers()
        indices_cc_ops = get_control_consumers(grad.indices.op)
        values_c_ops = grad.values.consumers()
        values_cc_ops = get_control_consumers(grad.values.op)

        with ops.name_scope(replica_prefix(i)):
            with ops.colocate_with(grad.indices.op):
                new_indices = collective_ops.all_gather(
                    grad.indices,
                    group_size,
                    get_collective_keys().get_group_key(self.all_canonical_replica_devices),
                    get_collective_keys().get_instance_key(var_op_name + '-indices'),
                    **conf)
            with ops.colocate_with(grad.values.op):
                new_values = collective_ops.all_gather(
                    grad.values,
                    group_size,
                    get_collective_keys().get_group_key(self.all_canonical_replica_devices),
                    get_collective_keys().get_instance_key(var_op_name + '-values'),
                    **conf)

        update_consumers(indices_c_ops, grad.indices, new_indices)
        update_control_consumers(indices_cc_ops, grad.indices.op, new_indices.op)
        update_consumers(values_c_ops, grad.values, new_values)
        # Control inputs are operations, not tensors: pass the op, mirroring the
        # indices handling above.
        update_control_consumers(values_cc_ops, grad.values.op, new_values.op)
def in_graph_apply(self, graph_item, var_name):
    """
    Apply in-graph ps synchronization.

    The variable is shared on replica 0, its gradients are aggregated, and the
    per-replica gradient/variable bookkeeping is replaced by a single entry for
    the master replica.

    Args:
        graph_item: the old graph item
        var_name: the variable name w/o replica prefix

    Returns:
        graph_item.GraphItem
    """
    master_idx = 0
    var_op_name = get_op_name(var_name)

    with graph_item.graph.as_default():
        self._prune_control_dependencies(graph_item, var_op_name, master_replica=master_idx)
        self._share_variable(graph_item, var_op_name, master_replica=master_idx)

        master_var_name = ops.prepend_name_scope(var_name, replica_prefix(master_idx))
        master_var_op_name = get_op_name(master_var_name)
        graph_item.updated = True
        grad, target, update_op = graph_item.var_op_name_to_grad_info_v2[master_var_op_name]
        graph_item.var_queried.append(master_var_op_name)
        agg_grad = self._aggregate_gradients(
            graph_item, old_update_op=update_op, old_grad=grad, old_target=target)

    # update grad_target_pair and variable info
    for replica_id in range(self.num_replicas):
        stale_var_name = ops.prepend_name_scope(var_name, replica_prefix(replica_id))
        graph_item.pop_gradient_info(var_name=stale_var_name)
        if replica_id == master_idx:
            # The master replica keeps its variable entry.
            continue
        graph_item.info.pop_variable(var_name=stale_var_name)

    master_var_tensor = graph_item.graph.get_tensor_by_name(master_var_name)
    graph_item.extend_gradient_info(grads=[agg_grad], targets=[master_var_tensor])
    # TODO(Hao): Prune the graph to use unnecessary nodes
    return graph_item
def _build_proxy_on(self, destination_device):
    """
    Build a proxy of the original variable on `destination_device`.

    The proxy is a non-trainable variable named after the original op under a
    replica prefix (the GPU device index, or 'CPU' for non-GPU devices). It is
    registered in the graph item's info, and an assign op copying the original
    variable's read value is recorded as its init op.

    Args:
        destination_device (DeviceSpecV2): the destination device where the proxy is on.
    """
    device_type = destination_device.device_type
    on_gpu = bool(device_type) and device_type.upper() == 'GPU'
    scope_key = destination_device.device_index if on_gpu else 'CPU'
    proxy_name = ops.prepend_name_scope(self._this_op.name, replica_prefix(scope_key))

    with ops.device(destination_device):
        proxy_var = variable_scope.get_variable(
            proxy_name,
            dtype=self._dtype,
            initializer=self._initial_value,
            trainable=False)

    # Should we update graph_item.info?
    self._graph_item.info.update_variables([proxy_var], replace=False)
    self._proxy_vars.append(proxy_var)

    source_value = get_read_var_tensor(self._this_op)
    self._proxy_var_init_ops.append(proxy_var.assign(source_value))

    self._mirror_all_read_var_ops()
    self._update_all_consumers()
def _update_gradient_consumers(new_graph_item, consumer_ops, control_consumer_ops, old_tensor_name, new_tensor):
    """
    Make gradient's consumers consume the aggregated gradient instead of the original one of replica_0.

    Args:
        new_graph_item: graph item whose graph holds the replica-0 gradient op.
        consumer_ops: ops whose data inputs are switched to `new_tensor`.
        control_consumer_ops: ops whose control inputs are switched to `new_tensor.op`.
        old_tensor_name: tensor name of the original gradient (any replica prefix is stripped).
        new_tensor: the aggregated gradient tensor.
    """
    # Locate the replica-0 twin of the original tensor: same op name, same output slot.
    base_op_name = strip_replica_prefix(get_op_name(old_tensor_name))
    replica_0_op = new_graph_item.graph.get_operation_by_name(
        ops.prepend_name_scope(base_op_name, replica_prefix(0)))
    replica_0_tensor = replica_0_op.outputs[get_index_from_tensor_name(old_tensor_name)]

    update_consumers(consumer_ops, replica_0_tensor, new_tensor)
    update_control_consumers(control_consumer_ops, replica_0_tensor.op, new_tensor.op)
def _remap_feed(self, feed, feed_val=None):
    """
    Remap the feeds to the right element in the transformed graph.

    For example, there are N copies of a placeholder for N replicas
    and we have to feed all of them with tensors.

    Args:
        feed: feed graph element or name
        feed_val: feed value

    Returns:
        A pair `(transformed_feeds, expand_feed_val)`: `transformed_feeds` is a list of
        graph elements, or a list of `(element, value)` pairs when `feed_val` is given;
        `expand_feed_val` is the function used to expand a feed value per replica.
    """
    graph = self._graph_item.graph
    given_by_name = isinstance(feed, str)
    feed_name = feed if given_by_name else feed.name

    try:
        # The feed still exists under its original name.
        transformed_feeds = [graph.as_graph_element(feed_name)]
    except KeyError:
        # Otherwise it was replicated: collect one copy per local replica.
        transformed_feeds = [
            graph.as_graph_element(ops.prepend_name_scope(feed_name, replica_prefix(replica_id)))
            for replica_id in range(self._graph_transformer.num_local_replicas)
        ]

    num_replicated_feeds = self._graph_transformer.num_local_replicas
    if given_by_name:
        feed = transformed_feeds[0]

    def expand_feed_val(feed_val, feed=feed):
        """Given a original feed or replicated feed, expand the feed value."""
        # Replicated placeholders with an undefined (polymorphic) dimension get the value
        # split along that dimension; otherwise every replica is fed the same value.
        # NOTE(review): a falsy result from `_polymorphic_dim` (including 0, if it can
        # return that) selects the "same value" branch - confirm against `_polymorphic_dim`.
        polymorphic_dim = self._polymorphic_dim(feed)
        if not polymorphic_dim:
            return [feed_val] * num_replicated_feeds
        return np.array_split(np.asarray(feed_val), num_replicated_feeds, axis=polymorphic_dim)

    if feed_val is None:
        return transformed_feeds, expand_feed_val
    paired_feeds = list(zip(transformed_feeds, expand_feed_val(feed_val)))
    return paired_feeds, expand_feed_val
def _aggregate_sparse_gradients(self, var_op, reduce_to_device, indexed_slices_grads, values_op_name):
    """
    Aggregate per-replica sparse gradients through a `SparseConditionalAccumulator`.

    Args:
        var_op: variable op; `var_op.outputs[0].shape` is used as the accumulator shape.
        reduce_to_device: device on which the accumulator and the derived ops are created.
        indexed_slices_grads: per-replica `IndexedSlices` gradients, indexed by replica id
            for `range(self.num_replicas)`.
        values_op_name: op name used as the name scope of every op created here.

    Returns:
        ops.IndexedSlices: the gradient taken from the accumulator; indices and dense
        shape are cast back to the dtypes of the first input gradient when they differ.
    """
    with ops.device(reduce_to_device):
        # One accumulator per gradient; the op name doubles as the shared name.
        grad_accum_op_name = ops.prepend_name_scope(
            values_op_name, u"%sAccum" % AUTODIST_PREFIX)
        grad_accum = data_flow_ops.SparseConditionalAccumulator(
            dtype=indexed_slices_grads[0].values.dtype,
            shape=var_op.outputs[0].shape,
            shared_name=grad_accum_op_name,
            name=grad_accum_op_name)
        # Every replica applies its gradient with MAX_INT64 as the local step argument.
        accum_apply_ops = [
            grad_accum.apply_indexed_slices_grad(
                indexed_slices_grads[i],
                MAX_INT64,
                name=ops.prepend_name_scope(
                    values_op_name, u"%s-Accum-Apply" % replica_prefix(i)))
            for i in range(self.num_replicas)
        ]
        take_grad_op_name = ops.prepend_name_scope(
            values_op_name, u"%sTake-Grad" % AUTODIST_PREFIX)
        # Take the gradient only after all replicas have applied theirs.
        with ops.control_dependencies(accum_apply_ops):
            take_grad = grad_accum.take_indexed_slices_grad(
                self.num_replicas, name=take_grad_op_name)
        new_indices = take_grad.indices
        new_values = take_grad.values
        new_dense_shape = take_grad.dense_shape
        # Restore the original dtypes if the accumulator returned different ones.
        if indexed_slices_grads[0].indices.dtype != new_indices.dtype:
            new_indices = math_ops.cast(
                new_indices,
                indexed_slices_grads[0].indices.dtype,
                name=ops.prepend_name_scope(
                    values_op_name, u"%sTake-Grad-Cast-Indices" % AUTODIST_PREFIX))
        if indexed_slices_grads[
                0].dense_shape.dtype != new_dense_shape.dtype:
            new_dense_shape = math_ops.cast(
                new_dense_shape,
                indexed_slices_grads[0].dense_shape.dtype,
                name=ops.prepend_name_scope(
                    values_op_name, u"%sTake-Grad-Cast-Shape" % AUTODIST_PREFIX))
    return ops.IndexedSlices(new_values, new_indices, new_dense_shape)
def _prune_control_dependencies(self, graph_item, var_op_name, master_replica=0):
    """
    Prune the control dependencies between the train_op on non-master replica and update op.

    Since the replicator will replicate the entire graph, the update op on non-master replica
    will also be replicated. If the train_op on non-master replica is fetched (which is the case
    in our current feed-fetch remap implementation), it will trigger those update ops and result
    in an unnecessary update over the trainable variables.
    This function prunes the control dependencies between train_op and any variable that bases on
    a PS syncer to avoid this situation.

    Args:
        graph_item: graph item providing `var_op_name_to_grad_info`.
        var_op_name: variable op name without replica prefix.
        master_replica: index of the replica that is left untouched (default 0).
    """
    non_master_replicas = [
        replica_id for replica_id in range(self.num_replicas)
        if replica_id != master_replica
    ]
    for replica_id in non_master_replicas:
        replica_var_op_name = ops.prepend_name_scope(var_op_name, replica_prefix(replica_id))
        _grad, _target, update_op = graph_item.var_op_name_to_grad_info[replica_var_op_name]
        source_op = self._get_optimizer_source_op(update_op)
        # Detach the source op from everything that depends on it via control edges.
        dependents = get_control_consumers(source_op)
        remove_from_control_consumers(dependents, source_op)
def replicate(self, graph_item):
    """
    Replicate the entire graph as many times as the number of local replicas.

    A fresh graph is created and the original graph def is imported into it once
    per replica under that replica's name prefix. Savers are re-scoped to replica 0;
    gradient info, variables and table initializers are re-scoped for every replica.

    Args:
        graph_item: the original graph item

    Returns:
        The new graph item
    """
    item = GraphItem(graph=ops.Graph())
    fwd_ctx, bwd_ctx = self._collect_while_context(graph_item.graph)
    with item.graph.as_default():
        gdef = graph_item.graph.as_graph_def()
        for i in range(self._num_local_replicas):
            # Replicate ops, placed by this replica's device placer.
            with ops.device(self._replica_device_placer(replica_id=i)):
                import_graph_def(gdef, name=replica_prefix(i))

            # Replicate while_loop context (control_flow) if needed.
            # The order matters -- We must replicate bwd context first, then forward context.
            # TODO(Zeya): To handle cases when there are nested while loops, in which we must replicate
            # parent context first and then child context. See:
            # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_ops.py#L938
            if bwd_ctx:
                for ctx in bwd_ctx:
                    # Constructed only for its side effects; the result is discarded.
                    _ = WhileContext(context_def=ctx.to_proto(), grad_state=ctx._grad_state,
                                     import_scope=replica_prefix(i))
            if fwd_ctx:
                for ctx in fwd_ctx:
                    _ = WhileContext(context_def=ctx.to_proto(), grad_state=ctx._grad_state,
                                     import_scope=replica_prefix(i))

        # update saver: savers are re-scoped to the master replica only
        master_replica = 0
        if graph_item.info.savers:
            item.info.update_savers(
                [Saver.from_proto(proto, import_scope=replica_prefix(master_replica)).to_proto()
                 for proto in graph_item.info.savers],
                replace=False
            )

        # update gradient info: re-scope every (gradient, target) name pair per replica
        for i in range(self._num_local_replicas):
            for g_name, t_name in graph_item.grad_target_name_pairs.items():
                if isinstance(g_name, tuple):
                    # A tuple gradient name carries three component names, each re-scoped.
                    # NOTE(review): presumably the components of a sparse gradient - confirm.
                    new_g_name = (
                        ops.prepend_name_scope(g_name[0], replica_prefix(i)),
                        ops.prepend_name_scope(g_name[1], replica_prefix(i)),
                        ops.prepend_name_scope(g_name[2], replica_prefix(i)))
                else:
                    new_g_name = ops.prepend_name_scope(g_name, replica_prefix(i))
                new_t_name = ops.prepend_name_scope(t_name, replica_prefix(i))
                item.extend_gradient_info_by_names(
                    grads=[new_g_name],
                    targets=[new_t_name]
                )
            # Re-scope variable protos and table initializers for this replica.
            item.info.update_variables(
                [_from_proto_fn(proto, import_scope=replica_prefix(i)).to_proto()
                 for proto in graph_item.info.variables],
                replace=False
            )
            item.info.update_table_initializers(
                [ops.prepend_name_scope(tb_init, replica_prefix(i))
                 for tb_init in graph_item.info.table_initializers],
                replace=False
            )
    return item
def test_strip_replica_prefix():
    """Stripping the replica prefix must undo `prepend_name_scope` for op, control-input and tensor names."""
    prefix = replica_prefix(12)
    sample_names = ('my_op', '^my_op', 'my_tensor:0')
    for original in sample_names:
        scoped = ops.prepend_name_scope(original, prefix)
        assert strip_replica_prefix(scoped) == original
def _remap_fetch(self, fetch):
    """
    Remap the user-provided fetches to the right list of fetches after graph transformations.

    Cases:
        * If original fetch exists (which is not affected by graph transformation), fetch the original.
        * Otherwise, for fetches that are train_ops, fetch them on all replicas;
        * for fetches with a polymorphic dimension, fetch them on all replicas and
          concatenate the results along that dimension;
        * for other fetches, only fetch it on master replica.
            * For example, for partitioned vars, it corresponds to the concat one as_tensor on the first replica.

    Args:
        fetch: fetch graph element or name

    Returns:
        A pair `(transformed_fetch, contract_fn)`: the list of elements to fetch and a
        function that reduces the list of fetched values to a single result.
    """
    remap = self._remap_element
    fetch_type = type(fetch)
    fetch_name = fetch if isinstance(fetch, str) else fetch.name

    def take_first(fetched_vals):
        return fetched_vals[0]

    contract_fn = take_first

    try:
        transformed_fetch = [remap(fetch_type, fetch_name)]
    except KeyError:
        def fetch_on_all_replicas():
            return [
                remap(fetch_type, ops.prepend_name_scope(fetch_name, replica_prefix(replica_id)))
                for replica_id in range(self._graph_transformer.num_local_replicas)
            ]

        master_replica_fetch = remap(
            fetch_type, ops.prepend_name_scope(fetch_name, replica_prefix(0)))
        polymorphic_dim = self._polymorphic_dim(master_replica_fetch)

        # Train-op detection:
        # In TF2: train_op as AssignAddVariableOp
        # In TF1 (being deprecated): no_op with a groups of stateful ops as control dependencies
        # TODO(unless deprecating): make the checking as strict as possible
        looks_like_train_op = isinstance(master_replica_fetch, ops.Operation) and (
            master_replica_fetch.op_def.is_stateful
            or master_replica_fetch.op_def.name == 'NoOp')

        if looks_like_train_op:
            transformed_fetch = fetch_on_all_replicas()
            logging.debug('Fetch mapped from {} to {}'.format(
                fetch, transformed_fetch))
        elif polymorphic_dim:
            transformed_fetch = fetch_on_all_replicas()

            def concat_along_polymorphic_dim(fetch_vals):
                return np.concatenate(fetch_vals, axis=polymorphic_dim)

            contract_fn = concat_along_polymorphic_dim
        else:
            transformed_fetch = [master_replica_fetch]

    return transformed_fetch, contract_fn
def _share_variable(self, graph_item, var_op_name, master_replica=0):
    """
    Share the variable on the replica = `master_replica` (default to 0).

    Update inputs of consumers of the variable on replica > 0 to variable on replica=`master_replica`.
    Both the consumers of the replica's read-variable ops and the remaining consumers of
    its variable handle are rewired, and their colocation groups are updated to match.

    Args:
        graph_item: the old graph item
        var_op_name: the name of the variable op of the variable to be shared
        master_replica: the index of master replica (default to 0)
    """
    for i in range(0, self.num_replicas):
        if i == master_replica:
            continue
        this_var_op_name = ops.prepend_name_scope(var_op_name,
                                                  replica_prefix(i))
        this_var_op = graph_item.graph.get_operation_by_name(
            this_var_op_name)

        # Get all read variable ops to this replica variable
        read_var_ops = get_read_var_ops(this_var_op)

        # Get all consumers of its VarhandleOp,
        # excluding ReadVariableOps and ops inside the variable's own name scope
        # (those whose name starts with "<this_var_op_name>/")
        handle_consumers = set(get_consumers(this_var_op))
        handle_consumers.difference_update(set(read_var_ops))
        handle_consumers.difference_update({
            con
            for con in handle_consumers
            if con.name.startswith(this_var_op_name + '/')
        })
        # We exclude the `update_op` when updating the consumers on the shared variables.
        # Because i) sharing variable indicates sharing its stateful ops correspondingly
        # (so it is ok to remove stateful ops in none-master replica but we just disconnect it)
        # ii) A variable cannot correspond to more than one update ops for now.
        handle_consumers.difference_update(set(graph_item.all_update_ops))

        # update the consumers of all read variable ops to use the read variable ops of replica=master_replica
        for read_var_op in read_var_ops:
            # Same op name, with this replica's prefix swapped for the master's.
            new_read_var_op_name = ops.prepend_name_scope(
                ops.strip_name_scope(read_var_op.name, replica_prefix(i)),
                replica_prefix(master_replica))
            new_read_var_op = graph_item.graph.get_operation_by_name(
                new_read_var_op_name)
            consumers = get_consumers(read_var_op)
            update_consumers(consumers, read_var_op.outputs[0],
                             new_read_var_op.outputs[0])
            update_colocation_group(consumers, read_var_op, new_read_var_op)

        # update the consumers of VarhandleOp to use the handle on replica=master_replica
        new_handle_op_name = ops.prepend_name_scope(
            ops.strip_name_scope(this_var_op_name, replica_prefix(i)),
            replica_prefix(master_replica))
        new_handle_op = graph_item.graph.get_operation_by_name(
            new_handle_op_name)
        handle_consumers = list(handle_consumers)
        update_consumers(handle_consumers, this_var_op.outputs[0],
                         new_handle_op.outputs[0])
        update_colocation_group(handle_consumers, this_var_op, new_handle_op)