コード例 #1
0
  def testMultipleReplicasPerWorker(self):
    """Checks value selection when the map is built with 2 replicas per worker.

    Two worker devices are registered and the second constructor argument is
    2; a replica context built from 3 is then expected to select the second
    entry ("b") of the values list.
    """
    worker_devices = [
        "/job:worker/replica:0/task:0/device:CPU:0",
        "/job:worker/replica:0/task:2/device:CPU:0",
    ]
    worker_map = values.WorkerDeviceMap(worker_devices, 2)
    # NOTE(review): presumably 3 is a replica id that falls on the second
    # worker (3 // 2 == 1) -- confirm against `ReplicaContext`.
    context = WorkerDeviceMapTest.ReplicaContext(3)

    chosen = worker_map.select_for_current_replica(["a", "b"], context)
    self.assertEqual("b", chosen)
コード例 #2
0
def _get_next_as_optional(iterator, strategy, name=None):
    """Fetches the next element from every worker plus a has-value flag.

    Args:
        iterator: object exposing `_input_workers.worker_devices` and a
            parallel `_iterators` list whose items provide
            `get_next_as_list`.
        strategy: distribution strategy used to reduce the per-worker flags.
        name: optional prefix; when given, each worker's fetch is named
            "<name>_<job>_<task>" using the worker's parsed device spec.

    Returns:
        A `(global_has_value, replicas)` tuple. `global_has_value` is the
        reduced per-worker flag cast to bool and reshaped to a scalar;
        `replicas` holds one `get_next_as_list` element per worker, in
        worker order.
    """
    per_worker_elements = []
    has_value_flags = []
    flag_devices = []
    input_worker_devices = iterator._input_workers.worker_devices  # pylint: disable=protected-access
    for index, worker_device in enumerate(input_worker_devices):
        element_name = None
        if name is not None:
            spec = tf_device.DeviceSpec.from_string(worker_device)
            element_name = "%s_%s_%d" % (name, spec.job, spec.task)

        with ops.device(worker_device):
            worker_iterator = iterator._iterators[index]  # pylint: disable=protected-access
            has_value, element = worker_iterator.get_next_as_list(element_name)
            # Collective all-reduce needs its inputs on explicit devices.
            with ops.device("/cpu:0"):
                # The flags are reduced with SUM, so make them integers.
                flag = math_ops.cast(has_value, dtypes.int32)
                flag_devices.append(flag.device)
                has_value_flags.append(flag)
            # One entry per worker, kept in worker order.
            per_worker_elements.append(element)

    # Combine the flags across workers to learn whether any of them has data.
    # TODO(b/131423105): we should be able to short-cut the all-reduce in some
    # cases.
    if getattr(strategy.extended, "_support_per_replica_values", True):
        replicas_per_worker = len(
            strategy.extended._input_workers._input_worker_devices)  # pylint: disable=protected-access
        flag_device_map = values.WorkerDeviceMap(
            flag_devices, num_replicas_per_worker=replicas_per_worker)
        per_worker_flags = values.PerReplica(flag_device_map, has_value_flags)
        reduced_flag = strategy.reduce(
            reduce_util.ReduceOp.SUM, per_worker_flags, axis=None)
    else:
        # Without per-replica values there must be exactly one flag to use.
        assert len(has_value_flags) == 1
        reduced_flag = has_value_flags[0]

    global_has_value = array_ops.reshape(
        math_ops.cast(reduced_flag, dtypes.bool), [])
    return global_has_value, per_worker_elements
コード例 #3
0
    def testBasic(self):
        """Walks through the `WorkerDeviceMap` API with one replica per worker.

        Covers the device listing, logical-device queries, per-replica and
        per-device value selection, the replica-indexed lookups that are
        expected to raise `ValueError`, and the `repr` output.
        """
        worker_devices = [
            "/job:worker/replica:0/task:0/device:CPU:0",
            "/job:worker/replica:0/task:2/device:CPU:0",
        ]
        second_device = worker_devices[1]
        worker_map = values.WorkerDeviceMap(worker_devices, 1)

        self.assertAllEqual(worker_devices, worker_map.all_devices)

        # Merely reading the property is expected to raise.
        with self.assertRaisesWithPredicateMatch(
                ValueError, "`WorkerDeviceMap` is not indexed by replicas"):
            worker_map.devices_by_replica  # pylint: disable=pointless-statement

        self.assertEqual(1, worker_map.num_logical_devices)
        self.assertEqual(2, worker_map.num_replicas_in_graph)

        logical_device = worker_map.logical_device_from_values(["a", "b"])
        self.assertEqual(0, logical_device)
        self.assertAllEqual(worker_devices,
                            worker_map.logical_to_actual_devices(0))

        context = WorkerDeviceMapTest.ReplicaContext(1)
        for_replica = worker_map.select_for_current_replica(["a", "b"], context)
        self.assertEqual("b", for_replica)

        with self.assertRaisesWithPredicateMatch(
                ValueError, "`WorkerDeviceMap` not indexed by replicas"):
            worker_map.replica_for_device(second_device)

        for_device = worker_map.select_for_device(["a", "b"], second_device)
        self.assertEqual("b", for_device)

        with self.assertRaisesWithPredicateMatch(
                ValueError, "WorkerDeviceMap not indexed by replicas"):
            worker_map.is_device_in_replica(second_device, 1)

        expected_repr = (
            "WorkerDeviceMap(('/job:worker/replica:0/task:0/device:CPU:0', "
            "'/job:worker/replica:0/task:2/device:CPU:0'), "
            "num_replicas_per_worker=1)")
        self.assertEqual(expected_repr, repr(worker_map))