def predict(actions, state):
  """Runs `model.predict` data-parallel across the replicas of `strategy`.

  `actions` and every tensor in `state` are split into
  `strategy.num_replicas_in_sync` equal shards along axis 0 (tf.split with an
  integer count requires that axis to divide evenly). One shard goes to each
  replica, and the per-replica outputs are concatenated back along axis 0.

  Args:
    actions: batched actions tensor.
    state: dict-like mapping of names to batched state tensors.

  Returns:
    A dict keyed like `model.predict`'s output; each value is the concatenation
    of that output's per-replica results along axis 0.
  """
  state = state.copy()
  replica_count = strategy.num_replicas_in_sync
  action_shards = tf.split(actions, replica_count)
  state_shards = {
      name: tf.split(tensor, replica_count) for name, tensor in state.items()
  }
  device_map = values.ReplicaDeviceMap(strategy.extended.worker_devices)
  per_replica_actions = values.PerReplica(device_map, tuple(action_shards))
  # Re-group the shards so replica `idx` receives a dict of its own slices.
  per_replica_state = values.PerReplica(
      device_map,
      tuple({name: shards[idx] for name, shards in state_shards.items()}
            for idx in range(replica_count)))
  per_replica_out = strategy.experimental_run_v2(
      model.predict, args=(per_replica_actions, per_replica_state))
  # Pull every output's per-replica components back and stitch them together
  # in replica order.
  return {
      name: tf.concat(strategy.experimental_local_results(out), axis=0)
      for name, out in per_replica_out.items()
  }
def observe(images, actions, rewards, state):
  """Runs `model.observe` data-parallel and returns the last-step posteriors.

  `images` are cast to float32 and mapped through x / 255 - 0.5. This assumes
  pixel values in [0, 255]; TODO confirm at the call site. `images`, `actions`
  and every tensor in `state` are split into `strategy.num_replicas_in_sync`
  equal shards along axis 0 (tf.split with an integer count requires that axis
  to divide evenly), and each replica receives one shard.

  Args:
    images: batched image tensor.
    actions: batched actions tensor.
    rewards: rewards tensor; only `rewards[:, -1]` is used.
    state: dict-like mapping of names to batched state tensors.

  Returns:
    A dict of the posterior tensors produced by `model.observe`, concatenated
    across replicas along axis 0 and then indexed at -1 along axis 1. It also
    holds a 'rewards' entry equal to `rewards[:, -1]`.
  """
  # tf.cast replaces the deprecated tf.to_float, which is not part of the
  # TF 2.x top-level API. The resulting float32 tensor is the same.
  images = tf.cast(images, tf.float32) / 255.0 - 0.5
  # Break down the inputs along the batch dimension to form equal sized
  # tensors in each replica.
  num_replicas = strategy.num_replicas_in_sync
  images = tf.split(images, num_replicas)
  actions = tf.split(actions, num_replicas)
  state = {
      key: tf.split(value, num_replicas) for key, value in state.items()
  }
  devices = values.ReplicaDeviceMap(strategy.extended.worker_devices)
  dist_images = values.PerReplica(devices, tuple(images))
  dist_actions = values.PerReplica(devices, tuple(actions))
  dist_state = values.PerReplica(
      devices,
      tuple({key: value[i] for key, value in state.items()}
            for i in range(num_replicas)))
  # model.observe returns a pair; only its second element is used here.
  _, dist_posteriors = strategy.experimental_run_v2(
      model.observe, args=(dist_actions, dist_images, dist_state))
  posteriors = {
      key: tf.concat(strategy.experimental_local_results(value), axis=0)
      for key, value in dist_posteriors.items()
  }
  # Keep only the last entry along axis 1 (presumably the time axis).
  posteriors = {key: value[:, -1] for key, value in posteriors.items()}
  posteriors['rewards'] = rewards[:, -1]
  return posteriors
def testCondWithValuesNotConvertibleToTensor(self):
  """`cond` rejects PerReplica branches whose values are Python sets."""
  device_map = values.SingleDeviceMap("CPU")
  true_branch = values.PerReplica(device_map, ({"a"},))
  false_branch = values.PerReplica(device_map, ({"b", "c"},))
  predicate = array_ops.placeholder(dtypes.bool, [])
  with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
    control_flow_ops.cond(
        predicate, lambda: true_branch, lambda: false_branch)
def testCondWithValuesNotConvertibleToTensor(self):
  """`cond` rejects PerReplica branches whose values are Python sets."""
  true_branch = values_lib.PerReplica(({"a"},))
  false_branch = values_lib.PerReplica(({"b", "c"},))
  predicate = array_ops.placeholder(dtypes.bool, [])
  with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
    control_flow_ops.cond(
        predicate, lambda: true_branch, lambda: false_branch)
def testPassPerReplica(self, distribution):
  """A PerReplica argument flows through a defun-wrapped model call."""

  @function.defun
  def fn1(mock_model, factor):
    return mock_model(factor)

  # Multiplier the expected values assume the mock model applies.
  mock_scale = 1.25
  replica_devices = values.ReplicaDeviceMap(
      ("/device:CPU:0", "/device:GPU:0"))
  inputs = values.PerReplica(replica_devices, (5.0, 3.0))
  expected = values.PerReplica(
      replica_devices, (5.0 * mock_scale, 3.0 * mock_scale))
  self._call_and_check(distribution, fn1, [inputs], expected, [fn1])
def testCondWithValuesConvertibleToTensor(self):
  """`cond` accepts PerReplica branches holding tensor-convertible strings."""
  branch_a = values_lib.PerReplica(("a",))
  branch_b = values_lib.PerReplica(("b",))
  predicate = array_ops.placeholder_with_default(True, [])
  outcome = control_flow_ops.cond(
      predicate, lambda: branch_a, lambda: branch_b)
  self.assertLen(outcome.values, 1)
  # The predicate defaults to True, so the first branch wins.
  self.assertAllEqual(outcome.values[0], "a")
def _gather(strategy, value):
  """Collects every replica component of `value` into a single tensor.

  A non-distributed `value` is first converted with `ops.convert_to_tensor`
  and wrapped in a one-component PerReplica. For strategies that are not
  CollectiveAllReduce, the components are stacked. Otherwise each component
  gets a new leading axis and `strategy.gather` joins them along it.
  """
  # pylint: disable=protected-access
  if not isinstance(value, values.DistributedValues):
    value = values.PerReplica([ops.convert_to_tensor(value)])
  components = value._values
  is_collective = isinstance(
      strategy.extended,
      collective_all_reduce_strategy.CollectiveAllReduceExtended)
  if not is_collective:
    return array_ops.stack(components)
  assert len(strategy.extended.worker_devices) == len(components)
  expanded = [array_ops.expand_dims_v2(c, axis=0) for c in components]
  return strategy.gather(values.PerReplica(expanded), axis=0)
def testCondWithTensorValues(self):
  """`cond` picks between PerReplica branches of differently shaped tensors."""
  scalar_branch = values_lib.PerReplica((constant_op.constant("a"),))
  vector_branch = values_lib.PerReplica((constant_op.constant(["b", "c"]),))
  predicate = array_ops.placeholder_with_default(True, [])
  outcome = control_flow_ops.cond(
      predicate, lambda: scalar_branch, lambda: vector_branch)
  self.assertLen(outcome.values, 1)
  # The predicate defaults to True, so the scalar branch is returned.
  self.assertAllEqual(outcome.values[0], "a")
def testCondWithValuesConvertibleToTensor(self):
  """`cond` on PerReplica branches keeps the device map and logical device."""
  cpu_map = values.SingleDeviceMap("CPU")
  branch_a = values.PerReplica(cpu_map, ("a",))
  branch_b = values.PerReplica(cpu_map, ("b",))
  predicate = array_ops.placeholder_with_default(True, [])
  outcome = control_flow_ops.cond(
      predicate, lambda: branch_a, lambda: branch_b)
  self.assertEqual(branch_a.device_map, outcome.device_map)
  self.assertEqual(branch_a.logical_device, outcome.logical_device)
  self.assertLen(outcome.values, 1)
  self.assertAllEqual(outcome.values[0], "a")
def testTimeoutReduceSparse(self, communication, required_gpus):
  """A sparse collective reduce launched on only worker-0 times out."""
  timeout_hints = collective_util.Hints(timeout_seconds=1)
  collective, devices, _ = self._get_test_objects(
      "worker",
      0,
      num_gpus=required_gpus,
      communication=communication,
      use_strategy_object=False)
  cluster = multi_worker_util.normalize_cluster_spec(self._cluster_spec)
  remote.connect_to_cluster(cluster, protocol="grpc")
  devices = [device_util.canonicalize(d) for d in devices]
  sparse_value = value_lib.PerReplica([
      _make_indexed_slices([[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0])
  ])

  @def_function.function
  def reduce_sparse():
    collective.reduce(
        reduce_util.ReduceOp.SUM, sparse_value, sparse_value, timeout_hints)

  # Only worker-0 launches the collective, but there are three workers in
  # total, so the reduce should hit the timeout.
  with self.assertRaises(errors.DeadlineExceededError):
    reduce_sparse()

  # Collective failures poison the context, so reset it.
  context._reset_context()  # pylint: disable=protected-access
def testTypeSpec(self):
  """The PerReplica type spec holds one TensorSpec per component."""
  per_replica = values_lib.PerReplica((constant_op.constant(1.),))
  expected_specs = (tensor_spec.TensorSpec([], dtypes.float32),)
  self.assertEqual(per_replica._type_spec._value_specs, expected_specs)
def testFunctionCanReturnPerReplica(self):
  """A tf.function returns a new PerReplica with equal values and type spec."""
  identity_fn = def_function.function(lambda x: x)
  original = values_lib.PerReplica((constant_op.constant(1.),))
  returned = identity_fn(original)
  self.assertIsNot(original, returned)
  nest.map_structure(
      self.assertAllEqual, original, returned, expand_composites=True)
  self.assertEqual(original._type_spec, returned._type_spec)
def testShapeInvariantToComponentsExplicitShape(self):
  """An explicit shape invariant is repeated once per component."""
  first = constant_op.constant([1., 1., 1.])
  second = constant_op.constant([2., 2., 2.])
  per_replica = values.PerReplica(
      values.SingleDeviceMap("CPU"), (first, second))
  invariant = [None]
  self.assertEqual(
      per_replica._shape_invariant_to_components(shape=invariant),
      (invariant, invariant))
def testShapeInvariantToComponents(self):
  """Without a shape argument, each component's own shape is returned."""
  first = constant_op.constant(1.)
  second = constant_op.constant(2.)
  per_replica = values.PerReplica(
      values.SingleDeviceMap("CPU"), (first, second))
  expected = (first.shape, second.shape)
  self.assertEqual(per_replica._shape_invariant_to_components(), expected)
def nccl_gradient_sync(input, layer_id, construct_log):  # pylint: disable=redefined-builtin
  """Mean-reduces gradients across towers and stores them in `construct_log`.

  Reads `construct_log["tower_gradients"]` (one sequence of (grad, var) pairs
  per tower, aligned by position) and `construct_log["tower_devices"]`. It
  then writes `construct_log["gradients"]`, a flat list of (grad, var) pairs.
  The list holds two kinds of entries:
    * pairs whose first-tower gradient is None, passed through unchanged;
    * for every other position, one pair per device, holding the mean-reduced
      gradient on that device and that tower's variable.

  Args:
    input: passed through and returned unchanged.
    layer_id: used to build the name scope "gradient_sync_<layer_id>".
    construct_log: dict read from and written to as described above.

  Returns:
    `input`, unchanged.
  """
  with tf.name_scope("gradient_sync_" + str(layer_id)):
    from tensorflow.python.distribute import values as value_lib  # pylint: disable=g-import-not-at-top
    nccl = tf.contrib.distribute.AllReduceCrossDeviceOps(
        all_reduce_alg='hierarchical_copy')
    tower_gradients = construct_log["tower_gradients"]
    destinations = construct_log["tower_devices"]
    # Transpose so each entry groups the same position across all towers.
    grad_var_towers = list(zip(*tower_gradients))
    synchronized_grad_vars = []
    batch_reduce_vals = []
    valid_grad_var_towers = []
    for tgv in grad_var_towers:
      # Only the first tower's gradient is checked for None.
      if tgv[0][0] is not None:
        per_replica = value_lib.PerReplica(
            {device: gv[0] for device, gv in zip(destinations, tgv)})
        batch_reduce_vals.append((per_replica, destinations))
        valid_grad_var_towers.append(tgv)
      else:
        synchronized_grad_vars.extend(tgv)
    # When every gradient is None there is nothing to reduce. Some
    # AllReduceCrossDeviceOps versions also index the first value of the
    # batch and raise on an empty list, so skip the call entirely.
    if batch_reduce_vals:
      batch_mirrored = nccl.batch_reduce(
          tf.distribute.ReduceOp.MEAN, batch_reduce_vals)
      for tgv, mirrored in zip(valid_grad_var_towers, batch_mirrored):
        for device, gv in zip(destinations, tgv):
          with tf.device(device):
            synchronized_grad_vars.append((mirrored.get(device), gv[1]))
    construct_log["gradients"] = synchronized_grad_vars
  return input
def testDoesNotTriggerFunctionTracing(self):
  """Rebuilding a PerReplica from its spec components must not retrace `f`."""
  trace_log = []

  @def_function.function
  def f(x):
    trace_log.append(None)  # Runs only while tracing.
    return x

  per_replica = values.PerReplica(
      values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))

  # Warm-up call so `f` is traced exactly once before the checks.
  f(per_replica)
  self.assertNotEmpty(trace_log)
  del trace_log[:]

  spec = per_replica._type_spec
  for _ in range(5):
    doubled = [component * 2 for component in spec._to_components(per_replica)]
    per_replica = spec._from_components(doubled)

    output = f(per_replica)
    self.assertIsInstance(output, values.PerReplica)
    self.assertAllEqual(output._values, per_replica._values)
    self.assertAllEqual(output._device_map, per_replica._device_map)
    self.assertAllEqual(output._logical_device, per_replica._logical_device)
    self.assertEmpty(trace_log)  # `f` must not have been re-traced.
def testStrategyExtendedUpdate(self, distribution, synchronization,
                               aggregation):
  """`extended.update` applies assign-style functions per component."""
  if len(distribution.extended.parameter_devices) != 2:
    self.skipTest("n/a: needs exactly two parameter devices")
  is_on_write = (
      synchronization == variables_lib.VariableSynchronization.ON_WRITE)
  if is_on_write and aggregation != variables_lib.VariableAggregation.NONE:
    self.skipTest(
        "n/a: doesn't apply to ON_WRITE variable with aggregation")

  with distribution.scope():
    v = variables_lib.Variable(
        0., synchronization=synchronization, aggregation=aggregation)
  value = values_lib.PerReplica([1., 2.])

  # Each update fn is paired with the component values expected afterwards.
  # The steps run in order and each builds on the previous state.
  steps = (
      (lambda var, val: var.assign(val), [1., 2.]),
      (lambda var, val: var.assign_add(val), [2., 4.]),
      (lambda var, val: var.assign_sub(val), [1., 2.]),
      # Ignores `val` and adds twice the current value (read two ways).
      (lambda var, val: var.assign_add(var.value() + var.read_value()),
       [3., 6.]),
  )
  for update_fn, expected in steps:
    self.evaluate(
        distribution.extended.update(v, update_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), expected)
def testMirroredStratParaAsync(self):
  """Tests RNG/MirroredStrategy interaction #3: one RNG per replica.

  Splitting the global generator outside `strategy.scope()` gives n
  independent generators, where n is the number of replicas. Handing one to
  each replica gives the replicas distinct random-number streams.
  """
  shape = [3, 4]
  dtype = dtypes.int32
  per_device_gens = random.get_global_generator().split(count=2)
  devices = ["/cpu:0", test_util.gpu_device_name()]
  strat = MirroredStrategy(devices=devices)
  # PerReplica routes generator i to replica i.
  gens = dist_values.PerReplica(
      device_map=dist_values.ReplicaDeviceMap(devices),
      values=[[g] for g in per_device_gens])
  with strat.scope():

    def draw_twice(gen):
      first = gen.uniform_full_int(shape=shape, dtype=dtype)
      second = gen.uniform_full_int(shape=shape, dtype=dtype)
      return array_ops.stack([first, second])

    results = strat.extended.call_for_each_replica(
        fn=draw_twice, args=gens)
    replica_outputs = results.values
    self.assertAllEqual(2, len(replica_outputs))
    self.assertAllDifferent(replica_outputs)
def _get_indexed_slices(self,
                        devices,
                        start_i,
                        variable_length,
                        as_per_replica=True):
  """Builds one IndexedSlices per device from fixed test tables.

  Device i takes row `start_i + i` of either the fixed-length tables or the
  variable-length tables, depending on `variable_length`. All slices use
  dense shape [10, 2].

  Returns:
    A PerReplica wrapping the slices if `as_per_replica`, else the list.
  """
  dense_shape = [10, 2]
  if variable_length:
    # Rows whose number of values/indices differs from row to row.
    value_table = ([[1., 2.], [3., 4.]], [[3., 4.]], [[2., 1.]], [[0., 0.]],
                   [[3., 1.], [2., 1.]], [[2., 1.]])
    index_table = ([1, 2], [2], [3], [4], [5, 6], [6])
  else:
    value_table = ([[1., 2.]], [[3., 4.]], [[2., 1.]], [[0., 0.]], [[3., 1.]],
                   [[2., 1.]])
    index_table = ([1], [2], [3], [4], [5], [6])
  slices = [
      _make_indexed_slices(value_table[start_i + offset],
                           index_table[start_i + offset], dense_shape, device)
      for offset, device in enumerate(devices)
  ]
  if not as_per_replica:
    return slices
  return value_lib.PerReplica(slices)
def __init__(self,
             container_strategy,
             tpu_cluster_resolver,
             steps_per_run,
             num_cores=None):
  """Initializes TPU metadata, the replica device index and run settings.

  Args:
    container_strategy: forwarded to the base-class constructor.
    tpu_cluster_resolver: stored, and used to fetch TPU system metadata.
    steps_per_run: stored as `self.steps_per_run`.
    num_cores: optional override stored as `self._num_cores_override`.
  """
  super(TPUExtended, self).__init__(container_strategy)
  self._tpu_cluster_resolver = tpu_cluster_resolver
  self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
  # TODO(sourabhbajaj): Change this from num_cores to metadata_override
  self._num_cores_override = num_cores

  # TODO(jhseu): Switch to DeviceAssignment to support pods and model
  # parallelism.
  # Map each TPU device name to its position in the full metadata device list.
  tpu_index = {}
  for position, device in enumerate(self._tpu_metadata.devices):
    if "device:TPU:" in device.name:
      tpu_index[device.name] = position
  self._device_index = values.PerReplica(tpu_index)
  self._host_device = self.get_host_cpu_device(0)
  self._tpu_devices = sorted(tpu_index)
  # Only create variables for the number of replicas we're running.
  self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]

  # TODO(sourabhbajaj): Remove this once performance of running one step
  # at a time is comparable to multiple steps.
  self.steps_per_run = steps_per_run
  self._require_static_shapes = True
def _gather(strategy, value):
  """Collects every replica component of `value` into a single tensor.

  A non-distributed `value` must be a `core.Tensor`; it is wrapped in a
  one-component PerReplica. For strategies that are not CollectiveAllReduce,
  the components are stacked. Otherwise each component gets a new leading
  axis and a tf.function runs a collective gather across the worker devices.
  """
  # pylint: disable=protected-access
  if not isinstance(value, values.DistributedValues):
    assert isinstance(value, core.Tensor)
    value = values.PerReplica([value])
  extended = strategy.extended
  if not isinstance(
      extended, collective_all_reduce_strategy.CollectiveAllReduceExtended):
    return array_ops.stack(value._values)
  assert len(extended.worker_devices) == len(value._values)
  expanded = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
  keys = extended._collective_keys
  worker_devices = extended.worker_devices
  group_size = strategy.num_replicas_in_sync

  @def_function.function
  def gather_fn():
    gathered = cross_device_utils.build_collective_gather(
        expanded, worker_devices, group_size, keys)
    return distribute_utils.update_regroup(extended, gathered, group=True)

  return gather_fn()
def make_per_replica_value(value_fn, devices):
  """Builds a `PerReplica` with one component placed on each of `devices`.

  Args:
    value_fn: callable that takes `device_idx` and returns the value for
      `devices[device_idx]`. An `IndexedSlicesValue` result is rebuilt as an
      `IndexedSlices` from identity copies of its fields. Any other result is
      copied with `array_ops.identity`.
    devices: list of device strings to place the components on.

  Returns:
    A `PerReplica` of the per-device copies, in `devices` order.
  """
  components = []
  for idx, device in enumerate(devices):
    # value_fn itself runs outside the device scope.
    raw = value_fn(idx)
    with ops.device(device):
      if isinstance(raw, indexed_slices.IndexedSlicesValue):
        component = indexed_slices.IndexedSlices(
            values=array_ops.identity(raw.values),
            indices=array_ops.identity(raw.indices),
            dense_shape=array_ops.identity(raw.dense_shape))
      else:
        component = array_ops.identity(raw)
    components.append(component)
  return value_lib.PerReplica(components)
def _initialize_local(self, num_gpus, devices):
  """Configures the strategy for local (single-machine) training.

  Exactly one of `num_gpus` / `devices` may be given. If neither is given,
  `num_gpus` defaults to `context.num_gpus()`. `num_gpus` == 0 selects
  "/device:CPU:0"; otherwise GPUs 0..num_gpus-1 are used.

  Raises:
    ValueError: if both `devices` and `num_gpus` are specified.
  """
  self._cluster_spec = None
  if devices is not None and num_gpus is not None:
    raise ValueError(
        "Must only specify one of `devices` and `num_gpus`.")
  if devices is None:
    if num_gpus is None:
      num_gpus = context.num_gpus()
    devices = (["/device:CPU:0"] if num_gpus == 0 else
               ["/device:GPU:%d" % d for d in range(num_gpus)])
  self._num_gpus = num_gpus
  # TODO(yuefengz): consider setting the default device.

  assert devices, "Must specify at least one device."
  assert len(set(devices)) == len(devices), (
      "No duplicates allowed in `devices` argument.")
  # TODO(josh11b): Require at least 2 devices?
  self._devices = [device_util.resolve(d) for d in devices]
  self._canonical_device_set = set(self._devices)
  self._device_index = values.PerReplica(
      {device: index for index, device in enumerate(devices)})
def _initialize_multi_worker(self, devices):
  """Configures the strategy for multi-worker training.

  Records resolved devices and a device-to-index PerReplica. It then builds
  the worker list from the "chief" and "worker" jobs, using the first worker
  as the default device, and creates a MultiWorkerAllReduce cross-device op.
  """
  self._local_mode = False
  assert devices, "Must specify at least one device."
  assert len(set(devices)) == len(devices), (
      "No duplicates allowed in `devices` argument.")
  # TODO(josh11b): Require at least 2 devices?
  self._devices = tuple(device_util.resolve(d) for d in devices)
  self._canonical_device_set = set(self._devices)
  self._device_index = values.PerReplica(
      {device: index for index, device in enumerate(devices)})

  # Presumably maps job name -> per-task device lists; verify against
  # _group_device_list.
  grouped = _group_device_list(devices)
  self._workers = []
  self._worker_devices = []
  for job in ("chief", "worker"):
    for task, task_devices in enumerate(grouped.get(job, [])):
      worker_name = "/job:%s/task:%d" % (job, task)
      self._workers.append(worker_name)
      self._worker_devices.append((worker_name, task_devices))

  # Setting `_default_device` adds a device scope in distribution.scope(),
  # pinned to the first worker. Ops a user places under
  # `with tf.device("/cpu:0")` inside the scope therefore land on that
  # worker's CPU, e.g. "/job:worker/task:0/device:CPU:0". This is not used in
  # replica mode.
  self._default_device = self._workers[0]

  self._inferred_cross_device_ops = cross_device_ops_lib.MultiWorkerAllReduce(
      self._workers, _infer_num_gpus_per_worker(self._devices))
def testInconsistentShape(self):
  """pack_by_size keeps everything in one pack when a shape is unknown."""

  def ones():
    return array_ops.ones([10, 10], dtype=dtypes.float32)

  fully_defined = value_lib.PerReplica([ones(), ones()])
  # The second component is a Keras Input with no batch size.
  partially_unknown = value_lib.PerReplica([
      ones(),
      input_layer.Input(shape=(10), batch_size=None, dtype=dtypes.float32),
  ])
  per_replica_values = [fully_defined, partially_unknown]
  packs = cross_device_utils.pack_by_size(
      per_replica_values, bytes_per_pack=1)
  self.assertLen(packs, 1)
  self.assertEqual(packs[0], per_replica_values)
def testContainsIndexedSlices_PerReplica(self):
  """contains_indexed_slices detects IndexedSlices inside a PerReplica."""
  first = math_ops._as_indexed_slices(
      constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
  second = math_ops._as_indexed_slices(
      constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
  wrapped = value_lib.PerReplica((first, second))
  self.assertTrue(cross_device_utils.contains_indexed_slices(wrapped))
def testToComponents(self):
  """Components are the values; metadata is (device_map, logical_device)."""
  cpu_map = values.SingleDeviceMap("CPU")
  components = (constant_op.constant(1.),)
  per_replica = values.PerReplica(cpu_map, components)
  expected_logical_device = 0
  self.assertEqual(per_replica._to_components(), components)
  self.assertEqual(per_replica._component_metadata(),
                   (cpu_map, expected_logical_device))
def testTypeSpecRoundTrip(self):
  """to_components followed by from_components reproduces the values."""
  per_replica = values_lib.PerReplica((constant_op.constant(1.),))
  type_spec = per_replica._type_spec
  rebuilt = type_spec._from_components(type_spec._to_components(per_replica))
  self.assertAllEqual(per_replica.values, rebuilt.values)
def testFunctionCanReturnPerReplica(self):
  """A tf.function returns a new PerReplica with equal components/metadata."""
  identity_fn = def_function.function(lambda x: x)
  original = values.PerReplica(
      values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))
  returned = identity_fn(original)
  self.assertIsNot(original, returned)
  for expected, actual in zip(original._to_components(),
                              returned._to_components()):
    self.assertAllEqual(expected, actual)
  self.assertEqual(original._component_metadata(),
                   returned._component_metadata())
def testTypeSpec(self):
  """The type spec records value specs, device map and logical device."""
  cpu_map = values.SingleDeviceMap("CPU")
  per_replica = values.PerReplica(cpu_map, (constant_op.constant(1.),))
  type_spec = per_replica._type_spec
  expected_specs = (tensor_spec.TensorSpec([], dtypes.float32),)
  self.assertEqual(type_spec._value_specs, expected_specs)
  self.assertEqual(type_spec._device_map, per_replica.device_map)
  self.assertEqual(type_spec._logical_device, per_replica.logical_device)