def _fake_mirrored(value, devices):
  """Create a faked Mirrored object for testing.

  All components of the returned Mirrored have the same objects, which is not
  true in reality.
  """
  devices = _get_devices(devices)
  values = []
  for d in devices:
    with ops.device(d):
      values.append(array_ops.identity(value))
  return distribute_utils.regroup(values, wrap_class=value_lib.Mirrored)
def testNamedTuple(self):

  # We include toy implementations of Scaffold and EstimatorSpec to
  # avoid a dependency on Estimator here.
  class Scaffold(object):
    pass

  class EstimatorSpec(
      collections.namedtuple("EstimatorSpec",
                             ["mode", "loss", "train_op", "scaffold"])):

    def __new__(cls, mode, loss, train_op, scaffold=None):
      return super(EstimatorSpec, cls).__new__(
          cls, mode=mode, loss=loss, train_op=train_op,
          scaffold=scaffold or Scaffold())

  with context.graph_mode(), ops.Graph().as_default():
    created_estimator_specs = []

    for device_id in range(3):
      spec = EstimatorSpec(
          mode=mode_keys.EstimatorModeKeys.TRAIN,
          loss=constant_op.constant(device_id / 2),
          train_op=array_ops.identity(constant_op.constant(device_id)))
      created_estimator_specs.append(spec)

    merged_estimator_spec = distribute_utils.regroup(created_estimator_specs)

    self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
    self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
                     merged_estimator_spec.mode)
    for device_id in range(3):
      self.assertEqual(created_estimator_specs[device_id].loss,
                       merged_estimator_spec.loss.values[device_id])
      self.assertEqual(created_estimator_specs[device_id].train_op,
                       merged_estimator_spec.train_op.values[device_id])
      # Scaffold is populated by `EstimatorSpec.__new__`.
      self.assertEqual(created_estimator_specs[device_id].scaffold,
                       merged_estimator_spec.scaffold.values[device_id])
      self.assertIsInstance(created_estimator_specs[device_id].scaffold,
                            Scaffold)
      # Also test that we can undo the merge using select_replica().
      self.assertEqual(created_estimator_specs[device_id],
                       distribute_utils.select_replica(
                           device_id, merged_estimator_spec))
def simple_broadcast(value, destinations, always_mirrored=False):
  """Broadcast `value` to `destinations` using simple copies."""
  devices = get_devices_from(destinations)
  if len(devices) == 1 and not always_mirrored:
    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        value, devices[0])
  else:
    value_updates = []
    for d in devices:
      value_updates.append(
          cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
    return distribute_utils.regroup(value_updates,
                                    wrap_class=value_lib.Mirrored)
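# NOTE: Illustrative sketch, not part of the original source. The public
# `tf.distribute.CrossDeviceOps.broadcast` entry point is the usual way a
# caller ends up in a simple-copy broadcast like the helper above; the device
# string is an assumption chosen so the example runs on CPU only.
import tensorflow as tf

cross_device_ops = tf.distribute.ReductionToOneDevice()
result = cross_device_ops.broadcast(tf.constant([1.0, 2.0]),
                                    destinations="/cpu:0")
# `result` is a mirrored DistributedValues; with several destination devices
# it would hold one copied component per device.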
def batch_reduce(self, reduce_op, value_destination_pairs,
                 experimental_hints=None):
  """Reduce PerReplica objects in a batch.

  Reduce the first element of each pair in `value_destination_pairs` to the
  corresponding second element, which indicates the destinations. This can be
  faster than multiple individual `reduce`s because we can fuse several
  tensors into one or multiple packs before reduction.

  Args:
    reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
      the `per_replica_value` will be reduced.
    value_destination_pairs: A list or a tuple of PerReplica objects (or
      tensors with device set if there is one device) and destinations.
    experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
      to perform collective operations.

  Returns:
    a list of Mirrored objects.

  Raises:
    ValueError: if `value_destination_pairs` is not an iterable of
      tuples of PerReplica objects and destinations.
  """
  # TODO(yuefengz): if destinations are different, split into several
  # `_batch_reduce` invocations.
  if not _validate_value_destination_pairs(value_destination_pairs):
    # If the first element of each pair is a tensor, we try to turn it into a
    # PerReplica object.
    value_destination_pairs = _normalize_value_destination_pairs(
        value_destination_pairs)

  for _, d in value_destination_pairs:
    validate_destinations(d)

  # Shortcut when all PerReplica objects contain only one value.
  if self._num_between_graph_workers == 1 and _all_devices_match(
      value_destination_pairs) and len(
          value_destination_pairs[0][0].values) == 1:
    return [
        distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
        for v, _ in value_destination_pairs
    ]

  if experimental_hints is None:
    experimental_hints = collective_util.Hints()
  return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
                                          experimental_hints)
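# NOTE: Illustrative usage sketch, not part of the original source. It builds
# two per-replica values with `distribute_utils.regroup` (the internal module
# used throughout this file) and batch-reduces them through the public
# `tf.distribute.ReductionToOneDevice` cross-device ops; the CPU-only device
# placement is an assumption made so the example runs anywhere.
import tensorflow as tf

from tensorflow.python.distribute import distribute_utils

v0 = distribute_utils.regroup((tf.constant(1.0), tf.constant(2.0)))
v1 = distribute_utils.regroup((tf.constant(3.0), tf.constant(4.0)))

cross_device_ops = tf.distribute.ReductionToOneDevice()
reduced = cross_device_ops.batch_reduce(
    tf.distribute.ReduceOp.SUM, [(v0, "/cpu:0"), (v1, "/cpu:0")])
# `reduced` should be a list of two Mirrored values holding 3.0 and 7.0.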
def _call_for_each_replica(self, fn, args, kwargs):
  # For now, `fn` must be an @tf.function.
  # TODO(josh11b): Relax this restriction? Main problem is if
  # (a) executing eagerly, (b) `fn` not @tf.function, and
  # (c) executed frequently.
  assert isinstance(fn, def_function.Function)

  if _outside_run_graph() is not None:
    # Nested case, should just use outer function's context for things like
    # the current replica index.
    # TODO(josh11b): Test this case!
    with MirroredFunctionReplicaContext(self._container_strategy()):
      results = fn(*nest.map_structure(_unwrap_tensors, args),
                   **nest.map_structure(_unwrap_tensors, kwargs))
      return nest.map_structure(_wrap_tensors, results)

  _replica_index.graph_outside_run = ops.get_default_graph()
  return_values = []

  try:
    with MirroredFunctionReplicaContext(self._container_strategy()):
      for index, device in enumerate(self._devices):
        _replica_index.current = index
        with ops.device(device):
          if context.executing_eagerly():
            # NOTE: These functions need to execute concurrently if they
            # use a collective op. This is a particular concern with eager
            # execution.
            with context.execution_mode(context.ASYNC):
              return_values.append(
                  fn(*distribute_utils.select_replica(index, args),
                     **distribute_utils.select_replica(index, kwargs)))
          else:
            return_values.append(
                fn(*distribute_utils.select_replica(index, args),
                   **distribute_utils.select_replica(index, kwargs)))
  finally:
    _replica_index.graph_outside_run = None
    _replica_index.current = None

  return distribute_utils.regroup(return_values)
def _make_per_replica(values, devices, regroup=False):
  devices = _get_devices(devices)
  assert len(values) == len(devices)

  # We simulate the result of regroup called on PerReplica which strips the
  # PerReplica wrapper if it has only one value.
  if len(values) == 1 and regroup:
    with ops.device(devices[0]):
      placed_v = array_ops.identity(values[0])
    return placed_v

  index = []
  for d, v in zip(devices, values):
    with ops.device(d):
      placed_v = array_ops.identity(v)
    index.append(placed_v)
  return distribute_utils.regroup(index)
def reduce(self, reduce_op, per_replica_value, destinations,
           experimental_hints=None):
  """Reduce `per_replica_value` to `destinations`.

  It runs the reduction operation defined by `reduce_op` and puts the result
  on `destinations`.

  Args:
    reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
      per_replica_value will be reduced.
    per_replica_value: A `tf.distribute.DistributedValues` object or a tensor
      with device set.
    destinations: the reduction destinations.
    experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
      to perform collective operations.

  Returns:
    a Mirrored object.

  Raises:
    ValueError: if per_replica_value can't be converted to a PerReplica
      object or if destinations aren't strings, Variables or DistributedValues.
  """
  if not isinstance(per_replica_value, value_lib.DistributedValues):
    per_replica_value = _make_tensor_into_per_replica(per_replica_value)

  validate_destinations(destinations)

  # Shortcut if `per_replica_value` only contains one value.
  if self._num_between_graph_workers == 1 and len(
      per_replica_value.values) == 1 and _devices_match(
          per_replica_value, destinations):
    with ops.device(per_replica_value.values[0].device):
      v = array_ops.identity(per_replica_value.values[0])
    return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

  if experimental_hints is None:
    experimental_hints = collective_util.Hints()
  return self.reduce_implementation(reduce_op, per_replica_value, destinations,
                                    experimental_hints)
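# NOTE: Illustrative sketch, not part of the original source, exercising the
# single-component case described above: when the only component already
# lives on the destination device, `reduce` can take the shortcut and simply
# return an identity copy wrapped as a Mirrored value. The CPU device is an
# assumption.
import tensorflow as tf

cross_device_ops = tf.distribute.ReductionToOneDevice()

with tf.device("/cpu:0"):
  single = tf.constant(5.0)  # accepted as "a tensor with device set"
mirrored = cross_device_ops.reduce(
    tf.distribute.ReduceOp.SUM, single, destinations="/cpu:0")
# `mirrored` should hold one component equal to 5.0 on "/cpu:0".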
def testSameId(self):
  foo = object()
  result = distribute_utils.regroup((("a", foo), ("b", foo)))
  self.assertIsInstance(result, tuple)
  self.assertLen(result, 2)
  self._is_per_replica(result[0], ["a", "b"])
  self.assertIs(foo, result[1])

  # Test select_replica(), should undo the merge done by regroup().
  result_0 = distribute_utils.select_replica(0, result)
  self.assertIsInstance(result_0, tuple)
  self.assertLen(result_0, 2)
  self.assertEqual("a", result_0[0])
  self.assertIs(foo, result_0[1])
  result_1 = distribute_utils.select_replica(1, result)
  self.assertIsInstance(result_1, tuple)
  self.assertLen(result_1, 2)
  self.assertEqual("b", result_1[0])
  self.assertIs(foo, result_1[1])
def testRegroupCollectionsMapping(self):

  class CollectionsMappingBasedClass(collections.Mapping):
    """Class inherited from collections.Mapping."""

    def __init__(self, *args, **kwargs):
      self._d = dict(*args, **kwargs)

    def __getitem__(self, key):
      return self._d.__getitem__(key)

    def __iter__(self):
      return iter(self._d)

    def __len__(self):
      return len(self._d)

  result = distribute_utils.regroup(
      (CollectionsMappingBasedClass(a="a1", b="b1"),
       CollectionsMappingBasedClass(a="a2", b="b2")))
  self.assertIsInstance(result, CollectionsMappingBasedClass)
  self._is_per_replica(result["a"], ["a1", "a2"])
  self._is_per_replica(result["b"], ["b1", "b2"])
def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values,
                                experimental_hints):
  """All-reduce IndexedSlices across all workers in a batch."""
  logging.log_first_n(
      logging.INFO, "Collective batch_all_reduce for IndexedSlices: "
      "%d all-reduces, group_size = %d" %
      (len(per_replica_values), self._group_size), 10)

  # Pass self._communication to the runtime as a communication hint.
  communication_hint = self._communication.value
  # For now, we use NCCL only when batch_size > 1.
  # TODO(b/132575814): switch to NCCL for all collectives when communication
  # is NCCL.
  if self._communication == CollectiveCommunication.NCCL and len(
      per_replica_values) == 1:
    communication_hint = CollectiveCommunication.AUTO.value

  gathered_values = []
  with self._lock, ops.name_scope("allreduce"):
    for per_replica in per_replica_values:
      gathered_values.append(
          cross_device_utils.build_collective_gather_indexed_slices(
              per_replica.values,
              self._devices,
              self._group_size,
              self._collective_keys,
              communication_hint,
              timeout=experimental_hints.timeout_seconds))

  mirrored = []
  for value in gathered_values:
    if reduce_op == reduce_util.ReduceOp.MEAN:
      # Assume each worker has the same number of replicas.
      for i, v in enumerate(value):
        with ops.device(v.device):
          value[i].values = value[i].values / self._group_size
    mirrored.append(
        distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
  return mirrored
def _ungroup_and_make_mirrored(grouped_reduced,
                               destinations,
                               reduce_op,
                               num_between_graph_workers=1):
  """Ungroup results from all-reduce and make Mirrored objects.

  Each all-reduce result will be divided by the number of destinations before
  Mirrored objects are created if reduce_op is "mean".

  Args:
    grouped_reduced: a list of lists; each sublist holds the components for
      one device, each paired with a None. It is the result from
      cross_device_utils.aggregate_gradients_using*.
    destinations: a value to colocate the result with.
    reduce_op: Indicates how values will be aggregated. Accepted values
      are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
    num_between_graph_workers: number of workers in the between-graph
      replication.

  Returns:
    a list of Mirrored objects.
  """
  num_replicas = len(
      get_devices_from(destinations)) * num_between_graph_workers
  index = [[] for _ in range(len(grouped_reduced[0]))]
  for per_replica_reduced in grouped_reduced:
    for i, (v, _) in enumerate(per_replica_reduced):
      if reduce_op == reduce_util.ReduceOp.MEAN:
        with ops.device(v.device):
          index[i].append(v / num_replicas)
      else:
        index[i].append(v)
  return [
      distribute_utils.regroup(v, wrap_class=value_lib.Mirrored)
      for v in index
  ]
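# NOTE: Worked example added for this document (plain Python numbers only,
# no TF ops) to make the data flow above concrete: two devices, two reduced
# values each paired with None, MEAN over 2 * 1 replicas.
grouped_reduced = [
    [(8.0, None), (4.0, None)],  # device 0: reduced values 0 and 1
    [(6.0, None), (2.0, None)],  # device 1: reduced values 0 and 1
]
num_replicas = 2 * 1  # len(devices) * num_between_graph_workers

# Transpose into per-value lists and divide for MEAN, mirroring the loop
# above without the device placement and Mirrored wrapping.
index = [[] for _ in range(len(grouped_reduced[0]))]
for per_replica_reduced in grouped_reduced:
  for i, (v, _) in enumerate(per_replica_reduced):
    index[i].append(v / num_replicas)

# Value 0 becomes [4.0, 3.0] (one component per device), value 1 becomes
# [2.0, 1.0]; each list would then be wrapped as a Mirrored object.
assert index == [[4.0, 3.0], [2.0, 1.0]]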
def _call_for_each_replica(distribution, fn, args, kwargs):
  """Run `fn` in separate threads, once per replica/worker device.

  Args:
    distribution: the DistributionStrategy object.
    fn: function to run (will be run once per replica, each in its own thread).
    args: positional arguments for `fn`.
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas.

  Raises:
    RuntimeError: If `fn()` calls `get_replica_context().merge_call()` a
      different number of times on different replicas.
  """
  # TODO(josh11b): Add this option once we add synchronization to variable
  # creation. Until then, this is pretty unsafe to use.
  run_concurrently = False
  if not context.executing_eagerly():
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}
  devices = distribution.extended.worker_devices

  thread_local_callables = _get_thread_local_configuration_callable()

  # TODO(isaprykin): Create these threads once instead of during every call.
  threads = []
  for index in range(len(devices)):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = _MirroredReplicaThread(distribution, coord, index, devices,
                               variable_creator_fn, fn,
                               distribute_utils.caching_scope_local,
                               distribute_utils.select_replica(index, args),
                               distribute_utils.select_replica(index, kwargs),
                               thread_local_callables)
    threads.append(t)

  for t in threads:
    t.start()

  # When `fn` starts, the `should_run` event is set on the
  # _MirroredReplicaThread (`MRT`) threads. The execution waits until
  # `MRT.has_paused` is set, which indicates that either `fn` is complete or
  # `get_replica_context().merge_call()` was called. If `fn` is complete, then
  # `MRT.done` is set to True. Otherwise, the arguments passed to
  # `get_replica_context().merge_call` from all paused threads are grouped
  # and `merge_fn` is performed. The result of the
  # `get_replica_context().merge_call` is then set to `MRT.merge_result`.
  # Each such `get_replica_context().merge_call` call returns the
  # `MRT.merge_result` for that thread when the `MRT.should_run` event is set
  # again. Execution of `fn` resumes.

  try:
    with coord.stop_on_exception():
      all_done = False
      while not all_done and not coord.should_stop():
        done = []
        if run_concurrently:
          for t in threads:
            t.should_run.set()
          for t in threads:
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        else:
          for t in threads:
            t.should_run.set()
            t.has_paused.wait()
            t.has_paused.clear()
            if coord.should_stop():
              return None
            done.append(t.done)
        if coord.should_stop():
          return None
        all_done = all(done)
        if not all_done:
          if any(done):
            raise RuntimeError("Some replicas made a different number of "
                               "replica_context().merge_call() calls.")
          # get_replica_context().merge_call() case
          merge_args = distribute_utils.regroup(
              tuple(t.merge_args for t in threads))
          merge_kwargs = distribute_utils.regroup(
              tuple(t.merge_kwargs for t in threads))
          # We capture the name_scope of the MRT when we call merge_fn
          # to ensure that if we have opened a name scope in the MRT,
          # it will be respected when executing the merge function. We only
          # capture the name_scope from the first MRT and assume it is
          # the same for all other MRTs.
          mtt_captured_name_scope = threads[0].captured_name_scope
          mtt_captured_var_scope = threads[0].captured_var_scope
          # Capture and merge the control dependencies from all the threads.
          mtt_captured_control_deps = set()
          for t in threads:
            mtt_captured_control_deps.update(t.captured_control_deps)

          with ops.name_scope(mtt_captured_name_scope),\
              ops.control_dependencies(mtt_captured_control_deps), \
              variable_scope.variable_scope(mtt_captured_var_scope):
            merge_result = threads[0].merge_fn(distribution, *merge_args,
                                               **merge_kwargs)
          for r, t in enumerate(threads):
            t.merge_result = distribute_utils.select_replica(r, merge_result)
  finally:
    for t in threads:
      t.should_run.set()
    coord.join(threads)

  return distribute_utils.regroup(tuple(t.main_result for t in threads))
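# NOTE: Illustrative sketch added for this document (not part of the original
# source) of the user-level protocol the comments above implement: every
# replica thread pauses at `merge_call()`, the per-replica arguments are
# regrouped, the merge function runs once in cross-replica context, and its
# result is handed back to each thread before `fn` resumes. It uses only
# public tf.distribute APIs; the default-device MirroredStrategy setup is an
# assumption.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def merge_fn(strategy, per_replica_value):
  # Runs once, with `per_replica_value` holding one component per replica.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_value,
                         axis=None)

def replica_fn():
  ctx = tf.distribute.get_replica_context()
  value = tf.cast(ctx.replica_id_in_sync_group, tf.float32)
  # Pauses this replica until all replicas reach the same merge_call().
  return ctx.merge_call(merge_fn, args=(value,))

total = strategy.run(replica_fn)  # every replica receives the merged sum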
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                        initial_loop_values=None):
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)

  ctx = input_lib.MultiStepContext()

  def body(i, *args):
    """A wrapper around `fn` to create the while loop body."""
    del args
    fn_result = fn(ctx, iterator.get_next())
    for (name, output) in ctx.last_step_outputs.items():
      # Convert all outputs to tensors, potentially from `DistributedValues`.
      ctx.last_step_outputs[name] = self._local_results(output)
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    with ops.control_dependencies([fn_result]):
      return [i + 1] + flat_last_step_outputs

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop. This is useful in cases where we might need to exit
  # these contexts and get back to the outer context to do some things, for
  # e.g. create an op which should be evaluated only once at the end of the
  # loop on the host. One such usage is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  cond = lambda i, *args: i < iterations
  i = constant_op.constant(0)
  loop_result = control_flow_ops.while_loop(
      cond, body, [i] + initial_loop_values, name="",
      parallel_iterations=1, back_prop=False, swap_memory=False,
      return_same_structure=True)
  del self._outer_control_flow_context

  ctx.run_op = control_flow_ops.group(loop_result)

  # Convert the last_step_outputs from a list to the original dict structure
  # of last_step_outputs.
  last_step_tensor_outputs = loop_result[1:]
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been reduced, wrap them in a Mirrored
    # container, else in a PerReplica container.
    if reduce_op is None:
      last_step_tensor_outputs_dict[name] = distribute_utils.regroup(output)
    else:
      assert len(output) == 1
      last_step_tensor_outputs_dict[name] = output[0]

  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
  return ctx
def testWrapAListOfTwoTuples(self):
  result = distribute_utils.regroup([("1", "2"), ("3", "4")])
  self.assertIsInstance(result, tuple)
  self.assertLen(result, 2)
  self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
  self._is_per_replica(result[1], ("2", "4"), values.PerReplica)
def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
  values = [_make_indexed_slices(values, indices, dense_shape, d)
            for d in devices]
  return distribute_utils.regroup(values, wrap_class=value_lib.Mirrored)
def _call_for_each_replica(self, fn, args, kwargs):
  with distribute_lib.ReplicaContext(
      self._container_strategy(),
      replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
    # TODO(rchao): Support multi-replica per worker or sync-group.
    return distribute_utils.regroup((fn(*args, **kwargs),))
def tpu_function(args, kwargs):
  """TF Function used to replicate the user computation."""
  if kwargs is None:
    kwargs = {}

  # Remove `None`s at the end of `args` as they are not replicatable.
  # If there are `None`s in the middle we can't do anything about it, so let
  # those cases fail.
  # For example, when Keras model predict is used the targets are passed as
  # None. We want to handle it here so all client libraries don't have to do
  # this, as other strategies can handle None values better.
  while args and args[-1] is None:
    args = args[:-1]

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(strategy.num_replicas_in_sync):
    replicate_inputs.append(
        [constant_op.constant(i, dtype=dtypes.int32),
         distribute_utils.select_replica(i, args),
         distribute_utils.select_replica(i, kwargs)])

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  if options.experimental_enable_dynamic_batch_size and replicate_inputs:
    maximum_shapes = []
    flattened_list = nest.flatten(replicate_inputs[0])
    for input_tensor in flattened_list:
      if tensor_util.is_tensor(input_tensor):
        rank = input_tensor.shape.rank
      else:
        rank = np.ndim(input_tensor)
      maximum_shape = tensor_shape.TensorShape([None] * rank)
      maximum_shapes.append(maximum_shape)
    maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                           maximum_shapes)
  else:
    maximum_shapes = None

  if options.experimental_bucketizing_dynamic_shape:
    padding_spec = tpu.PaddingSpec.POWER_OF_TWO
  else:
    padding_spec = None

  with strategy.scope():
    replicate_outputs = tpu.replicate(
        replicated_fn,
        replicate_inputs,
        device_assignment=self._device_assignment,
        maximum_shapes=maximum_shapes,
        padding_spec=padding_spec,
        xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=False))

  # Remove all no-ops that may have been added during `tpu.replicate()`.
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0]
        if not isinstance(output, ops.Operation)
    ]

  # Workaround for `tpu.replicate` behaviour when a single `Tensor` is
  # returned.
  if result[0] is None or isinstance(result[0], ops.Operation):
    replicate_outputs = [None] * len(replicate_outputs)
  else:
    replicate_outputs = [
        nest.pack_sequence_as(result[0], nest.flatten(replica_output))
        for replica_output in replicate_outputs
    ]

  return distribute_utils.regroup(replicate_outputs)
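# NOTE: Illustrative sketch added for this document (not part of the original
# source) of how `maximum_shapes` above is derived: every leaf of the first
# replica's inputs contributes a fully-dynamic shape of the same rank, packed
# back into the original nest structure. It uses only public tf.nest APIs and
# made-up inputs.
import numpy as np
import tensorflow as tf

replica_inputs = [tf.constant([[1, 2], [3, 4]]), 7]  # a Tensor and a Python int
maximum_shapes = []
for leaf in tf.nest.flatten(replica_inputs):
  rank = leaf.shape.rank if tf.is_tensor(leaf) else np.ndim(leaf)
  maximum_shapes.append(tf.TensorShape([None] * rank))
maximum_shapes = tf.nest.pack_sequence_as(replica_inputs, maximum_shapes)
# -> [TensorShape([None, None]), TensorShape([])]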
def testOneDevice(self):
  result = distribute_utils.regroup((_nested_value("1"),))
  # On one device regroup() and select_replica() are basically identity.
  self.assertEqual(_nested_value("1"), result)
  self.assertEqual(_nested_value("1"),
                   distribute_utils.select_replica(0, result))
def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values,
                               experimental_hints):
  """All-reduce across all workers in a batch."""

  batch_size = len(per_replica_values)
  # Pass self._communication to the runtime as a communication hint.
  communication = self._communication.value
  # For now, we use NCCL only when batch_size > 1.
  # TODO(b/132575814): switch to NCCL for all collectives when communication
  # is NCCL.
  if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
    communication = CollectiveCommunication.AUTO.value

  # Reverse the lists so that there's a better chance that values follow the
  # order in which they are calculated (e.g. when they're gradients), so as
  # to overlap calculation with communication. However, this may not be
  # optimal for cases like gradients of complicated non-sequential models.
  #
  # Note that we reverse the list before packing so that the first pack won't
  # be too small, since it's more likely for the first few packs to have long
  # queuing time due to concurrent intense computation.
  #
  # TODO(b/147393503): explore solutions for optimal ordering.
  packs = cross_device_utils.pack_by_size(
      list(reversed(per_replica_values)), experimental_hints.bytes_per_pack)

  if batch_size > 1:
    logging.info(
        "Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
        "group_size = %d, communication_hint = %s, num_packs = %d",
        batch_size, len(self._devices), self._group_size, communication,
        len(packs))
  else:
    logging.log_first_n(
        logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
        "num_devices = %d, group_size = %d, communication_hint = %s, "
        "num_packs = %d" % (batch_size, len(self._devices), self._group_size,
                            communication, len(packs)), 10)

  reduced_values = []
  with self._lock:
    for pack in packs:
      # By placing all CollectiveReduce ops in a pack under a single name
      # scope, we ensure they will be picked up by the `ScopedAllocator`
      # grappler optimizer and packed into a single all-reduce.
      with ops.name_scope("allreduce"):
        for per_replica in pack:
          # Add control dependencies per device from the last gradients to the
          # current set, in order to serialize NCCL launches.
          if (communication == CollectiveCommunication.NCCL.value and
              reduced_values):
            control_inputs = list(reduced_values[-1])
          else:
            control_inputs = None
          reduced_values.append(
              cross_device_utils.build_collective_reduce(
                  per_replica.values,
                  self._devices,
                  self._group_size,
                  self._collective_keys,
                  "Add",
                  "Id",
                  communication,
                  control_inputs,
                  executors=self._executors,
                  timeout=experimental_hints.timeout_seconds))

  for e in self._executors:
    e.wait()

  mirrored = []
  # Reverse the order of reduced values to recover the order in the input.
  for value in reversed(reduced_values):
    if reduce_op == reduce_util.ReduceOp.MEAN:
      for i, v in enumerate(value):
        with ops.device(v.device):
          value[i] = v / self._group_size
    mirrored.append(
        distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
  return mirrored
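# NOTE: Pure-Python sketch added for this document (not part of the original
# source) of the ordering trick above: reductions are issued in reverse list
# order, which for gradients better matches the order in which they become
# available, and the results are reversed again at the end so the caller sees
# its original order.
values = ["g0", "g1", "g2", "g3"]        # order in which the caller passes them
launch_order = list(reversed(values))    # order in which reductions are issued
reduced = [v + "_reduced" for v in launch_order]
restored = list(reversed(reduced))       # matches the caller's original order
assert restored == [v + "_reduced" for v in values]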