def distributed_function(x, y, sample_weights, learning_phase=None):
  """A single step of the distributed execution across replicas.

  NOTE(review): this closes over `strategy`, `devices`,
  `per_replica_function`, `unwrap_outputs` and `mode` from the enclosing
  scope — it is meant to be defined inside a function that provides them.

  Args:
    x: input features, one entry per replica (nested structure).
    y: targets, same structure as `x`; falsy when absent.
    sample_weights: per-sample weights; falsy when absent.
    learning_phase: unused; kept for call-signature compatibility.

  Returns:
    The outputs of `per_replica_function` unwrapped across replicas; a
    reduced loss tensor is included except in predict mode.
  """
  del learning_phase  # Unused: mode is fixed by `per_replica_function`.

  # TODO(b/129653859): Simplify after PerReplica can be the input of
  # `def_function.function`. `regroup` calls and re-wrapping in
  # PerReplica won't be needed then.
  if isinstance(strategy, one_device_strategy.OneDeviceStrategy):
    device_map = values.SingleDeviceMap(devices[0])
    # Single device: no per-replica wrapping needed, pass values through.
    wrap_class = lambda d, x: x
  else:
    device_map = values.ReplicaDeviceMap(devices)
    wrap_class = values.PerReplica

  # Transform each lists of lists of values into per replica objects
  # in the case of mirrored strategy. For example, for 2 replicas:
  #   [[x0, y0], [x1, y1]] > [PerReplica(d0:x0, d1:x1),
  #                           PerReplica(d0:y0, d1:y1)]
  x = values.regroup(device_map, x, wrap_class)
  y = values.regroup(device_map, y, wrap_class) if y else None
  sample_weights = values.regroup(device_map, sample_weights,
                                  wrap_class) if sample_weights else None

  # Call `Model.{train,test,predict}_on_batch` on every replica passing
  # PerReplicas as arguments. On every replica inside this call, each
  # PerReplica object will return the value for that replica. The outputs
  # are PerReplicas too.
  outputs = strategy.experimental_run_v2(
      per_replica_function, args=(x, y, sample_weights))

  # Out of PerReplica outputs reduce or pick values to return.
  all_outputs = unwrap_outputs(
      strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
  return all_outputs
def get_next(self, name=None):
  """Returns the next input from the iterator for all replicas.

  Args:
    name: optional base name for the fetch ops; each worker's job and task
      id are appended to it.

  Returns:
    The per-replica values regrouped via the input workers' device map.
  """
  if not self._enable_get_next_as_optional:
    # Fast path: fetch directly from every worker's iterator; an exhausted
    # iterator raises OutOfRange from the fetch itself.
    replicas = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      if name is not None:
        d = tf_device.DeviceSpec.from_string(worker)
        new_name = "%s_%s_%d" % (name, d.job, d.task)
      else:
        new_name = None
      with ops.device(worker):
        # Make `replicas` a flat list of values across all replicas.
        replicas.extend(
            self._iterators[i].get_next_as_list_deprecated(new_name))
    return values.regroup(self._input_workers.device_map, replicas)

  out_of_range_replicas = []

  def out_of_range_fn(worker_index, device):
    """This function will throw an OutOfRange error."""
    # As this will be only called when there is no data left, so calling
    # get_next() will trigger an OutOfRange error.
    data = self._iterators[worker_index].get_next(device)
    out_of_range_replicas.append(data)
    return data

  global_has_value, replicas = _get_next_as_optional(self, self._strategy)
  results = []
  for i, worker in enumerate(self._input_workers.worker_devices):
    with ops.device(worker):
      devices = self._input_workers.compute_devices_for_worker(i)
      for j, device in enumerate(devices):
        with ops.device(device):
          # pylint: disable=undefined-loop-variable
          # pylint: disable=cell-var-from-loop
          # It is fine for the lambda to capture variables from the loop as
          # the lambda is executed in the loop as well.
          result = control_flow_ops.cond(
              global_has_value,
              lambda: replicas[i][j],
              lambda: out_of_range_fn(i, device))
          # pylint: enable=cell-var-from-loop
          # pylint: enable=undefined-loop-variable
          results.append(result)
  replicas = results

  # Some dimensions in `replicas` will become unknown after we conditionally
  # return the real tensors or the dummy tensors. We fix the input shapes by
  # using the shapes from `out_of_range_replicas` because it is calling
  # get_next() inside.
  flattened_replicas = nest.flatten(replicas)
  for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
    flattened_replicas[i].set_shape(replica_data.get_shape())
  replicas = nest.pack_sequence_as(replicas, flattened_replicas)

  return values.regroup(self._input_workers.device_map, replicas)
def get_next(self, name=None):
  """Returns the next input from the iterator for all replicas.

  Args:
    name: optional base name for the fetch ops; each worker's job and task
      id are appended to it.

  Returns:
    The per-replica values regrouped via the input workers' device map.
  """
  if not self._enable_get_next_as_optional:
    # Fast path: fetch directly from every worker's iterator; an exhausted
    # iterator raises OutOfRange from the fetch itself.
    replicas = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      if name is not None:
        d = tf_device.DeviceSpec.from_string(worker)
        new_name = "%s_%s_%d" % (name, d.job, d.task)
      else:
        new_name = None
      with ops.device(worker):
        # Make `replicas` a flat list of values across all replicas.
        replicas.extend(
            self._iterators[i].get_next_as_list_deprecated(new_name))
    return values.regroup(self._input_workers.device_map, replicas)

  out_of_range_replicas = []

  def out_of_range_fn(worker_index, device):
    """This function will throw an OutOfRange error."""
    # As this will be only called when there is no data left, so calling
    # get_next() will trigger an OutOfRange error.
    data = self._iterators[worker_index].get_next(device)
    out_of_range_replicas.append(data)
    return data

  global_has_value, replicas = _get_next_as_optional(self, self._strategy)
  results = []
  for i, worker in enumerate(self._input_workers.worker_devices):
    with ops.device(worker):
      devices = self._input_workers.compute_devices_for_worker(i)
      for j, device in enumerate(devices):
        with ops.device(device):
          # pylint: disable=undefined-loop-variable
          # pylint: disable=cell-var-from-loop
          # It is fine for the lambda to capture variables from the loop as
          # the lambda is executed in the loop as well.
          result = control_flow_ops.cond(
              global_has_value,
              lambda: replicas[i][j],
              lambda: out_of_range_fn(i, device))
          # pylint: enable=cell-var-from-loop
          # pylint: enable=undefined-loop-variable
          results.append(result)
  replicas = results

  # Some dimensions in `replicas` will become unknown after we conditionally
  # return the real tensors or the dummy tensors. We fix the input shapes by
  # using the shapes from `out_of_range_replicas` because it is calling
  # get_next() inside.
  flattened_replicas = nest.flatten(replicas)
  for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
    flattened_replicas[i].set_shape(replica_data.get_shape())
  replicas = nest.pack_sequence_as(replicas, flattened_replicas)

  return values.regroup(self._input_workers.device_map, replicas)
def testWrapAListOfTwoTuples(self):
  """A list of two 2-tuples regroups into a 2-tuple of PerReplica values."""
  replica_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
  regrouped = values.regroup(replica_map, [("1", "2"), ("3", "4")])
  self.assertIsInstance(regrouped, tuple)
  self.assertEqual(len(regrouped), 2)
  # Element 0 gathers the first item from each replica's tuple, element 1
  # gathers the second.
  self._is_per_replica(regrouped[0], ("1", "3"), values.PerReplica)
  self._is_per_replica(regrouped[1], ("2", "4"), values.PerReplica)
def testNested(self):
  """regroup() recurses into tuples/lists/dicts and wraps the leaves."""
  device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
  result = values.regroup(device_map,
                          (_nested_value("1"), _nested_value("2")))
  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  self._is_per_replica(result[0], ["a1", "a2"])
  self._is_per_replica(result[2], ["h1", "h2"])

  self.assertIsInstance(result[1], list)
  self.assertEqual(3, len(result[1]))
  self._is_per_replica(result[1][0], ["b1", "b2"])
  self._is_per_replica(result[1][2], ["g1", "g2"])

  self.assertIsInstance(result[1][1], dict)
  self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
  self._is_per_replica(result[1][1]["c"], ["d1", "d2"])
  self._is_per_replica(result[1][1]["e"], ["f1", "f2"])

  # Also test that we can undo the merge using select_replica()
  self.assertEqual(_nested_value("1"),
                   values.select_replica(0, result))
  self.assertEqual(_nested_value("2"),
                   values.select_replica(1, result))
  # select_device_mirrored() should fail due to non-mirrored values
  with self.assertRaises(TypeError):
    values.select_device_mirrored(_device_str(0), result)
  with self.assertRaises(TypeError):
    values.select_device_mirrored(_device_str(1), result)
def testWrapClass(self):
  """regroup() honors an explicit wrap_class (here values.Mirrored)."""
  # Normally a mirrored value would be the same across devices, but
  # for a test it is convenient to be able to tell the values apart.
  device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
  result = values.regroup(device_map,
                          (_nested_value("1"), _nested_value("2")),
                          values.Mirrored)
  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
  self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)

  self.assertIsInstance(result[1], list)
  self.assertEqual(3, len(result[1]))
  self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored)
  self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored)

  self.assertIsInstance(result[1][1], dict)
  self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
  self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
  self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored)

  # Also test that we can undo the merge using select_replica()
  self.assertEqual(_nested_value("1"),
                   values.select_replica(0, result))
  self.assertEqual(_nested_value("2"),
                   values.select_replica(1, result))
  # Values are marked as mirrored, so select_device_mirrored() is allowed.
  self.assertEqual(_nested_value("1"),
                   values.select_device_mirrored(_device_str(0), result))
  self.assertEqual(_nested_value("2"),
                   values.select_device_mirrored(_device_str(1), result))
def testNamedTupleEstimatorSpec(self):
  """regroup() rebuilds a namedtuple (EstimatorSpec) field by field."""
  with context.graph_mode(), ops.Graph().as_default():
    devices = []
    created_estimator_specs = []

    for device_id in range(3):
      spec = model_fn_lib.EstimatorSpec(
          mode=model_fn_lib.ModeKeys.TRAIN,
          loss=constant_op.constant(device_id / 2),
          train_op=array_ops.identity(constant_op.constant(device_id)))
      devices.append(_device_str(device_id))
      created_estimator_specs.append(spec)

    device_map = values.ReplicaDeviceMap(devices)
    merged_estimator_spec = values.regroup(
        device_map, created_estimator_specs)

    self.assertTrue(
        isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec))
    self.assertEqual(model_fn_lib.ModeKeys.TRAIN,
                     merged_estimator_spec.mode)
    for device_id in range(3):
      d = _device_str(device_id)
      self.assertEqual(created_estimator_specs[device_id].loss,
                       merged_estimator_spec.loss.get(d))
      self.assertEqual(created_estimator_specs[device_id].train_op,
                       merged_estimator_spec.train_op.get(d))
      # Scaffold is populated by `EstimatorSpec.__new__`.
      self.assertEqual(created_estimator_specs[device_id].scaffold,
                       merged_estimator_spec.scaffold.get(d))
      # Also test that we can undo the merge using select_replica()
      self.assertEqual(created_estimator_specs[device_id],
                       values.select_replica(device_id,
                                             merged_estimator_spec))
def reduce(self, reduce_op, per_replica_value, destinations):
  """Reduce `per_replica_value` to `destinations`.

  It runs the reduction operation defined by `reduce_op` and put the result
  on `destinations`.

  Args:
    reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
      per_replica_value will be reduced.
    per_replica_value: a PerReplica object or a tensor with device set.
    destinations: the reduction destinations.

  Returns:
    a Mirrored object.

  Raises:
    ValueError: if per_replica_value can't be converted to a PerReplica
      object or if destinations aren't strings, Variables or
      DistributedValues.
  """
  if not isinstance(per_replica_value, value_lib.DistributedValues):
    per_replica_value = _make_tensor_into_per_replica(per_replica_value)

  validate_destinations(destinations)

  # Shortcut if `per_replica_value` only contains one value: no cross-device
  # communication is needed, just re-wrap the single value as Mirrored.
  if self._num_between_graph_workers == 1 and len(
      per_replica_value.values) == 1 and _devices_match(
          per_replica_value, destinations):
    return value_lib.regroup(
        per_replica_value.device_map,
        per_replica_value.values,
        wrap_class=value_lib.Mirrored)

  return self.reduce_implementation(reduce_op, per_replica_value,
                                    destinations)
def reduce_implementation(self, reduce_op, per_replica_value, destinations):
  """Runs the all-reduce and places copies of the result on `destinations`.

  Args:
    reduce_op: how `per_replica_value` is reduced (passed to
      `_batch_all_reduce`).
    per_replica_value: the distributed value to reduce.
    destinations: where the reduced result should live.

  Returns:
    A Mirrored value whose components sit on the destination devices.
  """
  all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0]
  device_map, logical_device = get_device_map_from(destinations)
  devices = device_map.logical_to_actual_devices(logical_device)

  # Already a Mirrored value on exactly the requested devices: done.
  if (isinstance(all_reduced, value_lib.Mirrored) and
      all_reduced.device_map is device_map and
      all_reduced.logical_device == logical_device):
    return all_reduced

  # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
  # utility to access component for a particular device.
  if not isinstance(all_reduced, value_lib.Mirrored):
    all_reduced = value_lib.Mirrored(
        value_lib.SingleDeviceMap(all_reduced.device), [all_reduced])

  index = []
  with ops.control_dependencies(all_reduced.values):
    for d in devices:
      with ops.device(d):
        if d in all_reduced.devices:
          index.append(array_ops.identity(all_reduced.get(d)))
        else:
          # TODO(josh11b): Once we add support for model parallelism, get the
          # copy from the corresponding replica instead of the primary.
          index.append(array_ops.identity(all_reduced.primary))

  return value_lib.regroup(device_map, index, wrap_class=value_lib.Mirrored)
def testNested(self):
  """regroup() recurses into tuples/lists/dicts and wraps the leaves."""
  device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
  result = values.regroup(device_map,
                          (_nested_value("1"), _nested_value("2")))
  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  self._is_per_replica(result[0], ["a1", "a2"])
  self._is_per_replica(result[2], ["h1", "h2"])

  self.assertIsInstance(result[1], list)
  self.assertEqual(3, len(result[1]))
  self._is_per_replica(result[1][0], ["b1", "b2"])
  self._is_per_replica(result[1][2], ["g1", "g2"])

  self.assertIsInstance(result[1][1], dict)
  self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
  self._is_per_replica(result[1][1]["c"], ["d1", "d2"])
  self._is_per_replica(result[1][1]["e"], ["f1", "f2"])

  # Also test that we can undo the merge using select_replica()
  self.assertEqual(_nested_value("1"),
                   values.select_replica(0, result))
  self.assertEqual(_nested_value("2"),
                   values.select_replica(1, result))
  # select_device_mirrored() should fail due to non-mirrored values
  with self.assertRaises(TypeError):
    values.select_device_mirrored(_device_str(0), result)
  with self.assertRaises(TypeError):
    values.select_device_mirrored(_device_str(1), result)
def testWrapClass(self):
  """regroup() honors an explicit wrap_class (here values.Mirrored)."""
  # Normally a mirrored value would be the same across devices, but
  # for a test it is convenient to be able to tell the values apart.
  device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
  result = values.regroup(device_map,
                          (_nested_value("1"), _nested_value("2")),
                          values.Mirrored)
  self.assertIsInstance(result, tuple)
  self.assertEqual(3, len(result))
  self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored)
  self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored)

  self.assertIsInstance(result[1], list)
  self.assertEqual(3, len(result[1]))
  self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored)
  self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored)

  self.assertIsInstance(result[1][1], dict)
  self.assertEqual(set(["c", "e"]), set(result[1][1].keys()))
  self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored)
  self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored)

  # Also test that we can undo the merge using select_replica()
  self.assertEqual(_nested_value("1"),
                   values.select_replica(0, result))
  self.assertEqual(_nested_value("2"),
                   values.select_replica(1, result))
  # Values are marked as mirrored, so select_device_mirrored() is allowed.
  self.assertEqual(_nested_value("1"),
                   values.select_device_mirrored(_device_str(0), result))
  self.assertEqual(_nested_value("2"),
                   values.select_device_mirrored(_device_str(1), result))
def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
  """Builds a Mirrored value with one IndexedSlices copy per device."""
  slices_per_device = []
  for device in devices:
    slices_per_device.append(
        _make_indexed_slices(values, indices, dense_shape, device))
  device_map = value_lib.ReplicaDeviceMap(devices)
  return value_lib.regroup(
      device_map, slices_per_device, wrap_class=value_lib.Mirrored)
def testNamedTupleEstimatorSpec(self):
  """regroup() rebuilds a namedtuple (EstimatorSpec) field by field."""
  with context.graph_mode(), ops.Graph().as_default():
    devices = []
    created_estimator_specs = []

    for device_id in range(3):
      spec = model_fn_lib.EstimatorSpec(
          mode=model_fn_lib.ModeKeys.TRAIN,
          loss=constant_op.constant(device_id / 2),
          train_op=array_ops.identity(constant_op.constant(device_id)))
      devices.append(_device_str(device_id))
      created_estimator_specs.append(spec)

    device_map = values.ReplicaDeviceMap(devices)
    merged_estimator_spec = values.regroup(
        device_map, created_estimator_specs)

    self.assertTrue(
        isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec))
    self.assertEqual(model_fn_lib.ModeKeys.TRAIN,
                     merged_estimator_spec.mode)
    for device_id in range(3):
      d = _device_str(device_id)
      self.assertEqual(created_estimator_specs[device_id].loss,
                       merged_estimator_spec.loss.get(d))
      self.assertEqual(created_estimator_specs[device_id].train_op,
                       merged_estimator_spec.train_op.get(d))
      # Scaffold is populated by `EstimatorSpec.__new__`.
      self.assertEqual(created_estimator_specs[device_id].scaffold,
                       merged_estimator_spec.scaffold.get(d))
      # Also test that we can undo the merge using select_replica()
      self.assertEqual(created_estimator_specs[device_id],
                       values.select_replica(device_id,
                                             merged_estimator_spec))
def testMirroredContainer(self):
  """regroup() must return a pre-built Mirrored container unchanged."""
  if context.num_gpus() < 1 and context.executing_eagerly():
    self.skipTest("A GPU is not available for this test in eager mode.")
  components, device_map, mirrored = _make_mirrored()
  self.assertIs(mirrored, values.regroup(device_map, components))
def _experimental_distribute_values_from_function(self, value_fn):
  """Builds a distributed value by calling `value_fn` once per replica."""
  num_replicas = self._num_replicas_in_sync
  per_replica_values = [
      value_fn(distribute_lib.ValueContext(replica_id, num_replicas))
      for replica_id in range(num_replicas)
  ]
  # `always_wrap` keeps the result wrapped even for a single replica.
  return values.regroup(per_replica_values, always_wrap=True)
def testOneDevice(self): result = values.regroup({_device_str(0): _nested_value("1")}) # On one device regroup() and select_device() are basically identity. self.assertEqual(_nested_value("1"), result) self.assertEqual(_nested_value("1"), values.select_device(_device_str(0), result)) # The one exception has to do with MirroredVariables. d = "/device:CPU:0" with ops.device(d): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) index = {d: v} mirrored = values.MirroredVariable(index, v, variable_scope.VariableAggregation.SUM) result = values.regroup(index) self.assertIs(mirrored, result)
def _tpu_run(strategy, fn, args, kwargs):
  """Common implementation of TPUStrategy.experimental_run_v2.

  Args:
    strategy: the TPUStrategy whose replicas run `fn`.
    fn: the user computation to replicate.
    args: positional inputs as distributed values; sliced per replica.
    kwargs: keyword inputs as distributed values; may be None.

  Returns:
    The per-replica outputs of `fn` regrouped into distributed values.
  """
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  if kwargs is None:
    kwargs = {}

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(strategy.num_replicas_in_sync):
    replicate_inputs.append([
        constant_op.constant(i, dtype=dtypes.int32),
        values.select_replica(i, args),
        values.select_replica(i, kwargs)
    ])

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  if replicate_inputs:
    maximum_shapes = []
    flattened_list = nest.flatten(replicate_inputs[0])
    for input_tensor in flattened_list:
      maximum_shapes.append(input_tensor.get_shape())
    maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                           maximum_shapes)
  else:
    maximum_shapes = None

  with strategy.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs,
                                      maximum_shapes=maximum_shapes)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0] if tensor_util.is_tensor(output)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_output))
      for replica_output in replicate_outputs
  ]

  device_map = strategy.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                        initial_loop_values=None):
  """Runs `fn` for `iterations` steps inside a single tf.while_loop.

  Args:
    fn: step function taking (MultiStepContext, inputs).
    iterator: input iterator providing per-step inputs via get_next().
    iterations: number of loop iterations to run.
    initial_loop_values: optional (possibly nested) initial loop state.

  Returns:
    The MultiStepContext `ctx` with `run_op` and last-step outputs set.
  """
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)

  ctx = values.MultiStepContext()

  def body(i, *args):
    """A wrapper around `fn` to create the while loop body."""
    del args
    fn_inputs = iterator.get_next()
    if not isinstance(fn_inputs, tuple):
      fn_inputs = (fn_inputs,)
    fn_result = fn(ctx, fn_inputs)
    for (name, output) in ctx.last_step_outputs.items():
      # Convert all outputs to tensors, potentially from `DistributedValues`.
      ctx.last_step_outputs[name] = self._unwrap(output)
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    with ops.control_dependencies([fn_result]):
      return [i + 1] + flat_last_step_outputs

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop. This is useful in cases where we might need to exit
  # these contexts and get back to the outer context to do some things, for
  # e.g. create an op which should be evaluated only once at the end of the
  # loop on the host. One such usage is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  cond = lambda i, *args: i < iterations
  i = constant_op.constant(0)
  loop_result = control_flow_ops.while_loop(
      cond, body, [i] + initial_loop_values, name="",
      parallel_iterations=1, back_prop=False, swap_memory=False,
      return_same_structure=True)
  del self._outer_control_flow_context

  ctx.run_op = control_flow_ops.group(loop_result)

  # Convert the last_step_outputs from a list to the original dict structure
  # of last_step_outputs.
  last_step_tensor_outputs = loop_result[1:]
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been reduced, wrap them in a Mirrored
    # container, else in a PerReplica container.
    if reduce_op is None:
      last_step_tensor_outputs_dict[name] = values.regroup(
          {d: t for d, t in zip(self._devices, output)}, values.PerReplica)
    else:
      assert len(output) == 1
      last_step_tensor_outputs_dict[name] = output[0]

  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
  return ctx
def tpu_function(args, kwargs):
  """TF Function used to replicate the user computation.

  NOTE(review): closes over `strategy`, `fn` and `self` from the enclosing
  scope. Unlike the plain-tensor path, non-tensor inputs get their maximum
  shape via `np.shape`.
  """
  if kwargs is None:
    kwargs = {}

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(strategy.num_replicas_in_sync):
    replicate_inputs.append(
        [constant_op.constant(i, dtype=dtypes.int32),
         values.select_replica(i, args),
         values.select_replica(i, kwargs)])

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  if replicate_inputs:
    maximum_shapes = []
    flattened_list = nest.flatten(replicate_inputs[0])
    for input_tensor in flattened_list:
      if tensor_util.is_tensor(input_tensor):
        maximum_shape = input_tensor.get_shape()
      else:
        maximum_shape = tensor_shape.TensorShape(np.shape(input_tensor))
      maximum_shapes.append(maximum_shape)
    maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                           maximum_shapes)
  else:
    maximum_shapes = None

  with strategy.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs,
                                      maximum_shapes=maximum_shapes)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0] if tensor_util.is_tensor(output)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  if result[0] is None:
    replicate_outputs = [None] * len(replicate_outputs)
  else:
    replicate_outputs = [
        nest.pack_sequence_as(result[0], nest.flatten(replica_output))
        for replica_output in replicate_outputs
    ]

  device_map = self._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def testOneDevice(self):
  """With a single replica, regroup()/select_replica() are identities."""
  device_map = values.ReplicaDeviceMap((_device_str(0),))
  result = values.regroup(device_map, (_nested_value("1"),))
  # On one device regroup() and select_replica() are basically identity.
  self.assertEqual(_nested_value("1"), result)
  self.assertEqual(_nested_value("1"), values.select_replica(0, result))

  # The one exception has to do with MirroredVariables.
  d = "/device:CPU:0"
  with ops.device(d):
    v = variable_scope.get_variable(
        name="v", initializer=1., use_resource=True)
  device_map = values.ReplicaDeviceMap((d,))
  mirrored = values.MirroredVariable(None, device_map, (v,),
                                     variable_scope.VariableAggregation.SUM)
  result = values.regroup(device_map, (v,))
  self.assertIs(mirrored, result)
def testOneDevice(self):
  """With a single replica, regroup()/select_replica() are identities."""
  device_map = values.ReplicaDeviceMap((_device_str(0),))
  result = values.regroup(device_map, (_nested_value("1"),))
  # On one device regroup() and select_replica() are basically identity.
  self.assertEqual(_nested_value("1"), result)
  self.assertEqual(_nested_value("1"), values.select_replica(0, result))

  # The one exception has to do with MirroredVariables.
  d = "/device:CPU:0"
  with ops.device(d):
    v = variable_scope.get_variable(
        name="v", initializer=1., use_resource=True)
  device_map = values.ReplicaDeviceMap((d,))
  mirrored = values.MirroredVariable(None, device_map, (v,),
                                     variable_scope.VariableAggregation.SUM)
  result = values.regroup(device_map, (v,))
  self.assertIs(mirrored, result)
def _tpu_run(strategy, fn, args, kwargs):
  """Common implementation of TPUStrategy.experimental_run_v2.

  Args:
    strategy: the TPUStrategy whose replicas run `fn`.
    fn: the user computation to replicate.
    args: positional inputs as distributed values; sliced per replica.
    kwargs: keyword inputs as distributed values; may be None.

  Returns:
    The per-replica outputs of `fn` regrouped into distributed values.
  """
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  if kwargs is None:
    kwargs = {}

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(strategy.num_replicas_in_sync):
    replicate_inputs.append(
        [constant_op.constant(i, dtype=dtypes.int32),
         values.select_replica(i, args),
         values.select_replica(i, kwargs)])

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  if replicate_inputs:
    maximum_shapes = []
    flattened_list = nest.flatten(replicate_inputs[0])
    for input_tensor in flattened_list:
      maximum_shapes.append(input_tensor.get_shape())
    maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                           maximum_shapes)
  else:
    maximum_shapes = None

  with strategy.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs,
                                      maximum_shapes=maximum_shapes)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0] if tensor_util.is_tensor(output)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_output))
      for replica_output in replicate_outputs
  ]

  device_map = strategy.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def per_device_dataset(batch, devices):
  """Splits `batch` along its leading axis, one slice per device."""

  def slicer(idx):
    # Factory binds `idx` immediately, so each device keeps its own index.
    return lambda feature: feature[idx]

  index = {}
  for idx, device in enumerate(devices):
    index[device] = nest.map_structure(slicer(idx), batch)
  return values.regroup(index)
def _fake_mirrored(value, devices):
  """Create a faked Mirrored object for testing.

  All components of the returned Mirrored hold the same object, which is
  not true in reality.
  """
  devices = _get_devices(devices)
  components = []
  for device in devices:
    with ops.device(device):
      components.append(array_ops.identity(value))
  return value_lib.regroup(components, wrap_class=value_lib.Mirrored)
def per_device_dataset(iterator, devices):
  """Pulls one batch from `iterator` and splits it across `devices`.

  Args:
    iterator: an object with a `get_next()` method returning a (possibly
      nested) batch whose leading dimension indexes replicas — TODO confirm
      against callers.
    devices: device strings; slice `i` of the batch goes to `devices[i]`.

  Returns:
    The per-device structures regrouped into distributed values.
  """
  batch = iterator.get_next()
  # Removed a leftover debug `print(batch)`: it polluted stdout on every
  # call (and in graph mode merely printed a symbolic Tensor repr).
  index = {}

  def get_ith(i):
    # Factory binds `i` now, avoiding the late-binding-closure pitfall.
    return lambda x: x[i]

  for i, d in enumerate(devices):
    index[d] = nest.map_structure(get_ith(i), batch)
  return values.regroup(index)
def per_device_dataset(iterator, devices):
  """Splits the next batch from `iterator` into per-device features."""
  batch = iterator.get_next()

  def slicer(idx):
    # Factory binds `idx` immediately, so each device keeps its own index.
    return lambda feature: feature[idx]

  index = {}
  for idx, device in enumerate(devices):
    index[device] = nest.map_structure(slicer(idx), batch)
  return values.regroup(index)
def testNamedTuple(self):
  """regroup() preserves namedtuple subclasses with a custom __new__."""

  # We include toy implementations of Scaffold and EstimatorSpec to
  # avoid a dependency on Estimator here.

  class Scaffold(object):
    pass

  class EstimatorSpec(collections.namedtuple(
      "EstimatorSpec", ["mode", "loss", "train_op", "scaffold"])):

    def __new__(cls, mode, loss, train_op, scaffold=None):
      return super(EstimatorSpec, cls).__new__(
          cls, mode=mode, loss=loss, train_op=train_op,
          scaffold=scaffold or Scaffold())

  with context.graph_mode(), ops.Graph().as_default():
    devices = []
    created_estimator_specs = []

    for device_id in range(3):
      spec = EstimatorSpec(
          mode=mode_keys.EstimatorModeKeys.TRAIN,
          loss=constant_op.constant(device_id / 2),
          train_op=array_ops.identity(constant_op.constant(device_id)))
      devices.append(_device_str(device_id))
      created_estimator_specs.append(spec)

    device_map = values.ReplicaDeviceMap(devices)
    merged_estimator_spec = values.regroup(
        device_map, created_estimator_specs)

    self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
    self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
                     merged_estimator_spec.mode)
    for device_id in range(3):
      d = _device_str(device_id)
      self.assertEqual(created_estimator_specs[device_id].loss,
                       merged_estimator_spec.loss.get(d))
      self.assertEqual(created_estimator_specs[device_id].train_op,
                       merged_estimator_spec.train_op.get(d))
      # Scaffold is populated by `EstimatorSpec.__new__`.
      self.assertEqual(created_estimator_specs[device_id].scaffold,
                       merged_estimator_spec.scaffold.get(d))
      self.assertIsInstance(created_estimator_specs[device_id].scaffold,
                            Scaffold)
      # Also test that we can undo the merge using select_replica()
      self.assertEqual(created_estimator_specs[device_id],
                       values.select_replica(device_id,
                                             merged_estimator_spec))
def loop_body(has_data, data, state): """Executes `reduce_fn` in a loop till the dataset is empty.""" # data is list of lists here. where each list corresponds to one worker. # TODO(b/130570614): Add support for the multiworker and TPU pods use # case. if self._input_workers.num_workers == 1: data = data[0] else: raise ValueError("Dataset iteration within a tf.function is" " not supported for multiple workers.") per_replica_data = values.regroup(self._input_workers.device_map, data) state = reduce_fn(state, per_replica_data) has_data, data = _get_next_as_optional(iterator, self._strategy) return has_data, data, state
def loop_body(has_data, data, state):
  """Executes `reduce_fn` in a loop till the dataset is empty.

  NOTE(review): closes over `self`, `reduce_fn` and `iterator` from the
  enclosing scope.
  """
  del has_data  # Unused.
  # data is list of lists here. where each list corresponds to one worker.
  # TODO(b/130570614): Add support for the multiworker and TPU pods use
  # case.
  if self._input_workers.num_workers == 1:
    data = data[0]
  else:
    raise ValueError("Dataset iteration within a tf.function is"
                     " not supported for multiple workers.")
  state = reduce_fn(state, values.regroup(data))
  has_data, data = _get_next_as_optional(iterator, self._strategy)
  return has_data, data, state
def get_next(self, name=None):
  """Returns the next input from the iterator for all replicas.

  Args:
    name: optional base name for the fetch ops; each worker's job and task
      id are appended to it.

  Returns:
    The flat per-replica values regrouped via the workers' device map.
  """
  replicas = []
  for i, worker in enumerate(self._input_workers.worker_devices):
    if name is not None:
      d = tf_device.DeviceSpec.from_string(worker)
      new_name = "%s_%s_%d" % (name, d.job, d.task)
    else:
      new_name = None
    with ops.device(worker):
      # Make `replicas` a flat list of values across all replicas.
      replicas.extend(self._iterators[i].get_next_as_list(new_name))

  return values.regroup(self._input_workers.device_map, replicas)
def simple_broadcast(value, destinations, always_mirrored=False):
  """Broadcast `value` to `destinations` using simple copies.

  Args:
    value: the tensor (or IndexedSlices) to copy.
    destinations: target placement; resolved through `get_device_map_from`.
    always_mirrored: if True, wrap the result in a Mirrored value even when
      there is only one destination device.

  Returns:
    A plain copy for a single device (unless `always_mirrored`), otherwise
    a Mirrored value with one copy per destination device.
  """
  device_map, logical_device = get_device_map_from(destinations)
  devices = device_map.logical_to_actual_devices(logical_device)
  if len(devices) == 1 and not always_mirrored:
    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        value, devices[0])
  else:
    value_updates = []
    for d in devices:
      value_updates.append(
          cross_device_utils.copy_tensor_or_indexed_slices_to_device(
              value, d))
    return value_lib.regroup(
        device_map, value_updates, wrap_class=value_lib.Mirrored)
def get_next(self, name=None):
  """Scatter the input across hosts and devices."""
  flat_values = []
  for worker, it in zip(self._input_workers.worker_devices, self._iterators):
    if name is None:
      new_name = None
    else:
      spec = tf_device.DeviceSpec.from_string(worker)
      new_name = "%s_%s_%d" % (name, spec.job, spec.task)
    with ops.device(worker):
      # extend() flattens the per-worker values into one list indexed by
      # replica.
      flat_values.extend(it.get_next_as_list(name=new_name))
  return values.regroup(self._input_workers.device_map, flat_values)
def get_next(self, name=None):
  """Scatter the input across hosts and devices."""
  replica_values = []
  worker_iterator_pairs = zip(self._input_workers.worker_devices,
                              self._iterators)
  for worker, iterator in worker_iterator_pairs:
    new_name = None
    if name is not None:
      device_spec = tf_device.DeviceSpec.from_string(worker)
      new_name = "%s_%s_%d" % (name, device_spec.job, device_spec.task)
    with ops.device(worker):
      worker_data = iterator.get_next_as_list(name=new_name)
      # Flatten across workers so the list is indexed by replica.
      replica_values.extend(worker_data)
  return values.regroup(self._input_workers.device_map, replica_values)
def _make_per_replica(values, devices, regroup=False):
  """Places each value on its device and regroups them into a PerReplica."""
  devices = _get_devices(devices)
  assert len(values) == len(devices)
  # We simulate the result of regroup called on PerReplica which strips the
  # PerReplica wrapper if it has only one value.
  if regroup and len(values) == 1:
    with ops.device(devices[0]):
      return array_ops.identity(values[0])
  placed_values = []
  for device, value in zip(devices, values):
    with ops.device(device):
      placed_values.append(array_ops.identity(value))
  return value_lib.regroup(placed_values)
def experimental_run_v2(self, fn, args=(), kwargs=None):
  """Runs `fn` once per TPU replica via `tpu.replicate`.

  Args:
    fn: function traced once and replicated; receives per-replica slices of
      `args`/`kwargs` selected with `values.select_replica`.
    args: positional arguments for `fn` (may contain distributed values).
    kwargs: keyword arguments for `fn`, or None for none.

  Returns:
    Per-replica outputs regrouped into distributed values.

  Raises:
    NotImplementedError: if called eagerly outside a tf.function.
  """
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions."
    )

  if kwargs is None:
    kwargs = {}

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      # Side effect: record the structured result so the flattened outputs
      # of tpu.replicate() can be re-packed after tracing.
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(self.num_replicas_in_sync):
    replicate_inputs.append([
        constant_op.constant(i, dtype=dtypes.int32),
        values.select_replica(i, args),
        values.select_replica(i, kwargs)
    ])

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0] if tensor_util.is_tensor(output)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_output))
      for replica_output in replicate_outputs
  ]

  device_map = self.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def experimental_run(self, fn, input_iterator=None):
  """Runs `fn` once per TPU replica, optionally fed from `input_iterator`.

  Args:
    fn: function to replicate; called with no arguments when
      `input_iterator` is None, otherwise with that replica's next element.
    input_iterator: optional iterator providing per-replica inputs.

  Returns:
    Per-replica outputs regrouped into distributed values.

  Raises:
    NotImplementedError: in eager mode, or when
      `_disable_training_loop_on_host` is set.
  """
  if context.executing_eagerly():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy.")
  if self.extended._disable_training_loop_on_host:  # pylint: disable=protected-access
    raise NotImplementedError(
        "`experimental_run` is not compatible with "
        "`_disable_training_loop_on_host=True`")

  if input_iterator is None:
    inputs = []
  else:
    inputs = input_iterator.get_next()

  # One-slot cell capturing the structured result of `fn` during tracing;
  # used below to re-pack the flattened outputs of tpu.replicate().
  result = [None]

  def replicated_fn(replica_id, inputs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      if input_iterator is None:
        result[0] = fn()
      else:
        result[0] = fn(inputs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(self.num_replicas_in_sync):
    replicate_inputs.append([
        constant_op.constant(i, dtype=dtypes.int32),
        values.select_replica(i, inputs)
    ])

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_outputs))
      for replica_outputs in replicate_outputs
  ]

  device_map = self.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def experimental_run_v2(self, fn, args=(), kwargs=None):
  """See base class."""
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  if kwargs is None:
    kwargs = {}

  # Single-slot container that captures the structured output of `fn` as a
  # side effect of tracing; used to re-pack the flattened tensors that
  # `tpu.replicate()` returns.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  # Build one [replica_id, args, kwargs] triple per replica.
  replicate_inputs = [
      [constant_op.constant(idx, dtype=dtypes.int32),
       values.select_replica(idx, args),
       values.select_replica(idx, kwargs)]
      for idx in range(self.num_replicas_in_sync)
  ]

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [out for out in result[0] if tensor_util.is_tensor(out)]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_output))
      for replica_output in replicate_outputs
  ]

  device_map = self.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def testSameId(self):
  """regroup() must pass through an object shared by every replica as-is."""
  shared_obj = object()
  device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
  grouped = values.regroup(device_map, (("a", shared_obj),
                                        ("b", shared_obj)))
  self.assertIsInstance(grouped, tuple)
  self.assertEqual(2, len(grouped))
  self._is_per_replica(grouped[0], ["a", "b"])
  # Identical objects across replicas are not wrapped in PerReplica.
  self.assertIs(shared_obj, grouped[1])

  # Test select_replica(), should undo the merge done by regroup().
  replica0 = values.select_replica(0, grouped)
  self.assertIsInstance(replica0, tuple)
  self.assertEqual(2, len(replica0))
  self.assertEqual("a", replica0[0])
  self.assertIs(shared_obj, replica0[1])
  replica1 = values.select_replica(1, grouped)
  self.assertIsInstance(replica1, tuple)
  self.assertEqual(2, len(replica1))
  self.assertEqual("b", replica1[0])
  self.assertIs(shared_obj, replica1[1])
def testSameId(self):
  """Checks regroup()/select_replica() round-trip with a shared object."""
  foo = object()
  device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
  result = values.regroup(device_map, (("a", foo), ("b", foo)))
  self.assertIsInstance(result, tuple)
  self.assertEqual(2, len(result))
  # Differing values ("a" vs "b") get merged into a PerReplica...
  self._is_per_replica(result[0], ["a", "b"])
  # ...while the identical object is passed through unwrapped.
  self.assertIs(foo, result[1])

  # Test select_replica(), should undo the merge done by regroup().
  result_0 = values.select_replica(0, result)
  self.assertIsInstance(result_0, tuple)
  self.assertEqual(2, len(result_0))
  self.assertEqual("a", result_0[0])
  self.assertIs(foo, result_0[1])
  result_1 = values.select_replica(1, result)
  self.assertIsInstance(result_1, tuple)
  self.assertEqual(2, len(result_1))
  self.assertEqual("b", result_1[0])
  self.assertIs(foo, result_1[1])
def batch_reduce(self, reduce_op, value_destination_pairs): """Reduce PerReplica objects in a batch. Reduce each first element in `value_destination_pairs` to each second element which indicates the destinations. Args: reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how the `per_replica_value` will be reduced. value_destination_pairs: a list or a tuple of PerReplica objects (or tensors with device set if there is one device) and destinations. Returns: a list of Mirrored objects. Raises: ValueError: if `value_destination_pairs` is not an iterable of tuples of PerReplica objects and destinations. """ # TODO(yuefengz): if destinations are different, split into several # `_batch_reduce` invocations. if not _validate_value_destination_pairs(value_destination_pairs): # If the first element of each pair is a tensor, we try to turn it into a # PerReplica object. value_destination_pairs = _normalize_value_destination_pairs( value_destination_pairs) for _, d in value_destination_pairs: validate_destinations(d) # Shortcut all PerReplica objects only contain one value. if self._num_between_graph_workers == 1 and _all_devices_match( value_destination_pairs) and len( value_destination_pairs[0][0].values) == 1: return [ value_lib.regroup( v.device_map, v.values, wrap_class=value_lib.Mirrored) for v, _ in value_destination_pairs ] return self.batch_reduce_implementation(reduce_op, value_destination_pairs)
def experimental_run(self, fn, input_iterator=None):
  """See base class."""
  if context.executing_eagerly():
    raise NotImplementedError("Eager mode not supported in TPUStrategy.")
  if self.extended._disable_training_loop_on_host:  # pylint: disable=protected-access
    raise NotImplementedError(
        "`experimental_run` is not compatible with "
        "`_disable_training_loop_on_host=True`")

  inputs = [] if input_iterator is None else input_iterator.get_next()

  # One-element cell that captures the structured result of the user fn so
  # the flattened TPU outputs can be re-packed below.
  result = [None]

  def replicated_fn(replica_id, inputs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      result[0] = fn() if input_iterator is None else fn(inputs)
    return result[0]

  # One [replica_id, inputs] pair per replica.
  replicate_inputs = [
      [constant_op.constant(idx, dtype=dtypes.int32),
       values.select_replica(idx, inputs)]
      for idx in range(self.num_replicas_in_sync)
  ]

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_outputs))
      for replica_outputs in replicate_outputs
  ]

  device_map = self.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def get_next(self, name=None):
  """Returns the next input from the iterator for all replicas."""
  # Legacy path: fetch unconditionally; OutOfRange propagates from get_next.
  if not self._enable_get_next_as_optional:
    replicas = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      if name is not None:
        d = tf_device.DeviceSpec.from_string(worker)
        new_name = "%s_%s_%d" % (name, d.job, d.task)
      else:
        new_name = None
      with ops.device(worker):
        # Make `replicas` a flat list of values across all replicas.
        replicas.extend(
            self._iterators[i].get_next_as_list_deprecated(new_name))
    return values.regroup(self._input_workers.device_map, replicas)

  # Optional path: each worker reports a has-value flag plus its elements.
  replicas = []
  worker_has_values = []
  for i, worker in enumerate(self._input_workers.worker_devices):
    if name is not None:
      d = tf_device.DeviceSpec.from_string(worker)
      new_name = "%s_%s_%d" % (name, d.job, d.task)
    else:
      new_name = None
    with ops.device(worker):
      worker_has_value, next_element = (
          self._iterators[i].get_next_as_list(new_name))
      worker_has_values.append(worker_has_value)
      # Make `replicas` a flat list of values across all replicas.
      replicas.append(next_element)

  out_of_range_replicas = []

  def out_of_range_fn(worker_index, device):
    """This function will throw an OutOfRange error."""
    # As this will be only called when there is no data left, so calling
    # get_next() will trigger an OutOfRange error.
    data = self._iterators[worker_index].get_next(device)
    out_of_range_replicas.append(data)
    return data

  # `global_has_value` indicates whether there is data in this global batch.
  # We do a all-reduce across all the workers in the multi-worker case.
  # TODO(b/126259107): Do strategy.reduce for CollectiveAllReduceStrategy.
  if len(worker_has_values) > 1:
    with ops.device(self._input_workers.compute_devices_for_worker(0)[0]):
      # Place the tf.reduce_any op in device 0 to minimize communication
      # cost.
      # TODO(b/128545270): Investigate why placing it on worker 0 will cause
      # the entire data to copy back from device to host.
      global_has_value = math_ops.reduce_any(worker_has_values)
  else:
    global_has_value = worker_has_values[0]

  # On each device, either return the real element or deliberately trigger
  # OutOfRange via out_of_range_fn when the global batch is exhausted.
  results = []
  for i, worker in enumerate(self._input_workers.worker_devices):
    with ops.device(worker):
      devices = self._input_workers.compute_devices_for_worker(i)
      for j, device in enumerate(devices):
        with ops.device(device):
          # pylint: disable=undefined-loop-variable
          # pylint: disable=cell-var-from-loop
          # It is fine for the lambda to capture variables from the loop as
          # the lambda is executed in the loop as well.
          result = control_flow_ops.cond(global_has_value,
                                         lambda: replicas[i][j],
                                         lambda: out_of_range_fn(i, device))
          # pylint: enable=cell-var-from-loop
          # pylint: enable=undefined-loop-variable
          results.append(result)
  replicas = results

  # Some dimensions in `replicas` will become unknown after we conditionally
  # return the real tensors or the dummy tensors. We fix the input shapes by
  # using the shapes from `out_of_range_replicas` because it is calling
  # get_next() inside.
  flattened_replicas = nest.flatten(replicas)
  for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
    flattened_replicas[i].set_shape(replica_data.get_shape())
  replicas = nest.pack_sequence_as(replicas, flattened_replicas)

  return values.regroup(self._input_workers.device_map, replicas)
def get_next(self, name=None):
  """Returns the single worker's next values regrouped across replicas."""
  # This iterator only supports a single input worker.
  assert self._input_workers.num_workers == 1
  return values.regroup(self._input_workers.device_map,
                        self.get_next_as_list(name))
def testMirroredContainer(self):
  """regroup() should return an existing Mirrored container unchanged."""
  if context.num_gpus() < 1 and context.executing_eagerly():
    self.skipTest("A GPU is not available for this test in eager mode.")
  v, device_map, mirrored = _make_mirrored()
  regrouped = values.regroup(device_map, v)
  self.assertIs(mirrored, regrouped)
def _call_for_each_replica(distribution, fn, args, kwargs): """Run `fn` in separate threads, once per replica/worker device. Args: distribution: the DistributionStrategy object. fn: function to run (will be run once per device, each in its own thread). args: positional arguments for `fn` kwargs: keyword arguments for `fn`. Returns: Merged return value of `fn` across all replicas. Raises: RuntimeError: If fn() calls get_replica_context().merge_call() a different number of times from the available devices. """ # TODO(josh11b): Add this option once we add synchronization to variable # creation. Until then, this is pretty unsafe to use. run_concurrently = False if not context.executing_eagerly(): # Needed for per-thread device, etc. contexts in graph mode. ops.get_default_graph().switch_to_thread_local() coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) shared_variable_store = {} # TODO(isaprykin): Create these threads once instead of during every run() # call. threads = [] for index, d in enumerate(distribution.extended.worker_devices): variable_creator_fn = shared_variable_creator.make_fn( shared_variable_store, index) t = MirroredExtended._MirroredReplicaThread( # pylint: disable=protected-access distribution, coord, d, variable_creator_fn, fn, *values.select_device(d, args), **values.select_device(d, kwargs)) threads.append(t) for t in threads: t.start() # When `fn` starts `should_run` event is set on _MirroredReplicaThread # (`MRT`) threads. The execution waits until # `MRT.has_paused` is set, which indicates that either `fn` is # complete or a `get_replica_context().merge_call()` is called. If `fn` is # complete, then `MRT.done` is set to True. Otherwise, arguments # of `get_replica_context().merge_call` from all paused threads are grouped # and the `merge_fn` is performed. Results of the # `get_replica_context().merge_call` are then set to `MRT.merge_result`. 
# Each such `get_replica_context().merge_call` call returns the # `MRT.merge_result` for that thread when `MRT.should_run` event # is reset again. Execution of `fn` resumes. try: with coord.stop_on_exception(): all_done = False while not all_done and not coord.should_stop(): done = [] if run_concurrently: for t in threads: t.should_run.set() for t in threads: t.has_paused.wait() t.has_paused.clear() if coord.should_stop(): return None done.append(t.done) else: for t in threads: t.should_run.set() t.has_paused.wait() t.has_paused.clear() if coord.should_stop(): return None done.append(t.done) if coord.should_stop(): return None all_done = all(done) if not all_done: if any(done): raise RuntimeError("Some replicas made a different number of " "replica_context().merge_call() calls.") # get_replica_context().merge_call() case merge_args = values.regroup({t.device: t.merge_args for t in threads}) merge_kwargs = values.regroup( {t.device: t.merge_kwargs for t in threads}) # We capture the name_scope of the MRT when we call merge_fn # to ensure that if we have opened a name scope in the MRT, # it will be respected when executing the merge function. We only # capture the name_scope from the first MRT and assume it is # the same for all other MRTs. mtt_captured_name_scope = threads[0].captured_name_scope with ops.name_scope(mtt_captured_name_scope): merge_result = threads[0].merge_fn(distribution, *merge_args, **merge_kwargs) for t in threads: t.merge_result = values.select_device(t.device, merge_result) finally: for t in threads: t.should_run.set() coord.join(threads) return values.regroup({t.device: t.main_result for t in threads})