def testMultipleReplicasPerWorker(self):
  """Selecting for replica context 3 with two replicas per worker yields "b".

  With two worker devices and `num_replicas_per_worker=2`, the value chosen
  for replica context 3 is expected to be the second worker's entry.
  """
  worker_devices = [
      "/job:worker/replica:0/task:0/device:CPU:0",
      "/job:worker/replica:0/task:2/device:CPU:0",
  ]
  per_worker_map = values.WorkerDeviceMap(worker_devices, 2)
  context = WorkerDeviceMapTest.ReplicaContext(3)
  selected = per_worker_map.select_for_current_replica(["a", "b"], context)
  self.assertEqual("b", selected)
def _get_next_as_optional(iterator, strategy, name=None):
  """Returns a global "has value" indicator and the next input from the iterator.

  For each input worker of `iterator`, this fetches the next element via
  `get_next_as_list` together with that worker's "has value" flag. The flags
  are cast to int32 and summed across workers (via `strategy.reduce` when the
  strategy supports per-replica values). The result is then cast to a scalar
  bool.

  Args:
    iterator: Distributed iterator. Reads its protected `_input_workers` (for
      `worker_devices`) and `_iterators` (indexed in the same order as
      `worker_devices`, each exposing `get_next_as_list(name)`).
    strategy: Distribution strategy used for the cross-worker reduction.
      `strategy.extended._support_per_replica_values` (treated as True when
      the attribute is absent) selects between an all-reduce and using the
      single worker's flag directly.
    name: Optional name prefix. When not None, each worker's
      `get_next_as_list` call receives `"<name>_<job>_<task>"`, built from that
      worker's device spec. Otherwise it receives None.

  Returns:
    A tuple `(global_has_value, replicas)`:
      global_has_value: scalar (shape `[]`) bool tensor. It is True when the
        reduced sum of per-worker flags is non-zero, i.e. at least one worker
        still has input (assuming each flag is 0/1 — TODO confirm).
      replicas: list with one `next_element` entry per input worker, in
        worker order.
  """
  replicas = []
  worker_has_values = []
  worker_devices = []
  for i, worker in enumerate(iterator._input_workers.worker_devices):  # pylint: disable=protected-access
    if name is not None:
      # Give each worker's op a distinct name derived from its job and task.
      # NOTE(review): assumes the device string carries a task index; `%d`
      # raises TypeError if `d.task` is None.
      d = tf_device.DeviceSpec.from_string(worker)
      new_name = "%s_%s_%d" % (name, d.job, d.task)
    else:
      new_name = None

    with ops.device(worker):
      worker_has_value, next_element = (
          iterator._iterators[i].get_next_as_list(new_name))  # pylint: disable=protected-access
      # Collective all-reduce requires explicit devices for inputs.
      with ops.device("/cpu:0"):
        # Converting to integers for all-reduce.
        worker_has_value = math_ops.cast(worker_has_value, dtypes.int32)
        worker_devices.append(worker_has_value.device)
        worker_has_values.append(worker_has_value)
      # `replicas` receives one `next_element` per worker.
      # NOTE(review): if `next_element` is itself a list (as the name
      # `get_next_as_list` suggests), `replicas` is nested by worker rather
      # than flat across replicas. Confirm against callers.
      replicas.append(next_element)

  # Run an all-reduce to see whether any worker has values.
  # TODO(b/131423105): we should be able to short-cut the all-reduce in some
  # cases.
  if getattr(strategy.extended, "_support_per_replica_values", True):
    # Wrap the per-worker int32 flags so `strategy.reduce` can sum them.
    # NOTE(review): `num_replicas_per_worker` comes from the length of
    # `strategy.extended._input_workers._input_worker_devices`, not from
    # `iterator._input_workers`. Confirm that the two agree and that this
    # length is the intended per-worker replica count.
    worker_has_values = values.PerReplica(
        values.WorkerDeviceMap(
            worker_devices,
            num_replicas_per_worker=len(
                strategy.extended._input_workers._input_worker_devices)),  # pylint: disable=protected-access
        worker_has_values)
    global_has_value = strategy.reduce(
        reduce_util.ReduceOp.SUM, worker_has_values, axis=None)
  else:
    # No reduction path: exactly one worker's flag is expected.
    # NOTE(review): `assert` is stripped under `python -O`. In that case only
    # the first worker's flag would be used, with no error.
    assert len(worker_has_values) == 1
    global_has_value = worker_has_values[0]
  # Collapse to a scalar bool: a non-zero count means some worker has data.
  global_has_value = array_ops.reshape(
      math_ops.cast(global_has_value, dtypes.bool), [])
  return global_has_value, replicas
def testBasic(self):
  """Exercises the WorkerDeviceMap API with one replica per worker."""
  worker_devices = [
      "/job:worker/replica:0/task:0/device:CPU:0",
      "/job:worker/replica:0/task:2/device:CPU:0",
  ]
  dmap = values.WorkerDeviceMap(worker_devices, 1)

  self.assertAllEqual(worker_devices, dmap.all_devices)

  # Reading the per-replica device list is expected to raise for this map.
  # NOTE(review): the three expected error messages in this test differ in
  # wording. Presumably each mirrors the exact text raised by the
  # corresponding WorkerDeviceMap member; confirm before unifying them.
  with self.assertRaisesWithPredicateMatch(
      ValueError, "`WorkerDeviceMap` is not indexed by replicas"):
    _ = dmap.devices_by_replica

  self.assertEqual(1, dmap.num_logical_devices)
  self.assertEqual(2, dmap.num_replicas_in_graph)
  self.assertEqual(0, dmap.logical_device_from_values(["a", "b"]))
  self.assertAllEqual(worker_devices, dmap.logical_to_actual_devices(0))

  context = WorkerDeviceMapTest.ReplicaContext(1)
  self.assertEqual(
      "b", dmap.select_for_current_replica(["a", "b"], context))

  second_device = worker_devices[1]
  with self.assertRaisesWithPredicateMatch(
      ValueError, "`WorkerDeviceMap` not indexed by replicas"):
    dmap.replica_for_device(second_device)
  self.assertEqual("b", dmap.select_for_device(["a", "b"], second_device))
  with self.assertRaisesWithPredicateMatch(
      ValueError, "WorkerDeviceMap not indexed by replicas"):
    dmap.is_device_in_replica(second_device, 1)

  expected_repr = (
      "WorkerDeviceMap(('/job:worker/replica:0/task:0/device:CPU:0', "
      "'/job:worker/replica:0/task:2/device:CPU:0'), "
      "num_replicas_per_worker=1)")
  self.assertEqual(expected_repr, repr(dmap))