def _connect_to_cluster(self, coordinator_name):
  if coordinator_name in ["worker", "ps"]:
    raise ValueError("coordinator name should not be 'worker' or 'ps'.")
  cluster_spec = self._cluster_resolver.cluster_spec()
  self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
  self._num_ps = len(cluster_spec.as_dict().get("ps", ()))

  device_filters = server_lib.ClusterDeviceFilters()
  # For any worker, only the devices on ps and coordinator nodes are visible.
  for i in range(self._num_workers):
    device_filters.set_device_filters(
        "worker", i, ["/job:ps", "/job:%s" % coordinator_name])
  # Similarly for any ps, only the devices on workers and coordinator are
  # visible.
  for i in range(self._num_ps):
    device_filters.set_device_filters(
        "ps", i, ["/job:worker", "/job:%s" % coordinator_name])

  # Allow at most one outstanding RPC for each worker at a certain time. This
  # is to simplify worker failure handling in the runtime.
  os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False"

  logging.info("%s is now connecting to cluster with cluster_spec: %r",
               self.__class__.__name__, cluster_spec)
  remote.connect_to_cluster(
      cluster_spec,
      job_name=coordinator_name,
      protocol=self._cluster_resolver.rpc_layer,
      cluster_device_filters=device_filters)

  distribute_lib.distribution_strategy_replica_gauge.get_cell(
      "ps_strategy_num_workers").set(self._num_workers)
  distribute_lib.distribution_strategy_replica_gauge.get_cell(
      "ps_strategy_num_ps").set(self._num_ps)
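# A minimal sketch (an illustrative assumption, not part of this module) of
# the resolver shape `_connect_to_cluster` consumes: a cluster_spec() with
# "worker" and "ps" jobs plus an rpc_layer. The helper name and the host:port
# addresses below are placeholders.
def _example_cluster_resolver():
  """Builds a toy two-worker/one-ps resolver, for illustration only."""
  import tensorflow as tf

  cluster_spec = tf.train.ClusterSpec({
      "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
      "ps": ["ps0.example.com:2222"],
  })
  # rpc_layer feeds the `protocol=` argument that _connect_to_cluster passes
  # to remote.connect_to_cluster above.
  return tf.distribute.cluster_resolver.SimpleClusterResolver(
      cluster_spec, rpc_layer="grpc")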
def DISABLED_testSimpleParameterServerWithDeviceFilters(self):
  cluster_device_filters = server_lib.ClusterDeviceFilters()
  for i in range(2):
    cluster_device_filters.set_device_filters('my_worker', i, ['/job:my_ps'])
    cluster_device_filters.set_device_filters('my_ps', i, ['/job:my_worker'])
  remote.connect_to_cluster(
      self._cluster, cluster_device_filters=cluster_device_filters)

  with ops.device('/job:my_ps/task:0/device:CPU:0'):
    v1 = variables.Variable(initial_value=0)
  with ops.device('/job:my_ps/task:1/device:CPU:0'):
    v2 = variables.Variable(initial_value=10)

  @def_function.function
  def worker_fn():
    v1.assign_add(1)
    v2.assign_sub(2)
    return v1.read_value() + v2.read_value()

  with ops.device('/job:my_worker/task:0/device:CPU:0'):
    self.assertAllEqual(worker_fn(), 9)
  with ops.device('/job:my_worker/task:1/device:CPU:0'):
    self.assertAllEqual(worker_fn(), 8)

  # The following remote call would fail because the ps nodes cannot see each
  # other due to the device filters.
  with self.assertRaises(errors.InvalidArgumentError) as cm:
    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      worker_fn().numpy()
  self.assertIn('/job:my_ps/replica:0/task:1/device:CPU:0 unknown device',
                cm.exception.message)

  with self.assertRaises(errors.InvalidArgumentError) as cm:
    with ops.device('/job:my_ps/task:1/device:CPU:0'):
      worker_fn().numpy()
  self.assertIn('/job:my_ps/replica:0/task:0/device:CPU:0 unknown device',
                cm.exception.message)

  with ops.device('/job:my_worker/task:0/device:CPU:0'):
    self.assertAllEqual(worker_fn(), 7)
  with ops.device('/job:my_worker/task:1/device:CPU:0'):
    self.assertAllEqual(worker_fn(), 6)

  # Explicitly delete variables to avoid triggering errors when being GC'ed
  # in subsequent tests.
  del v1, v2
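# A minimal sketch, assuming in-process servers, of how a cluster like
# `self._cluster` above could be stood up with the custom 'my_worker' and
# 'my_ps' job names used in the test. The helper name, port-picking approach,
# and server layout are illustrative assumptions, not the original harness.
def _example_start_cluster():
  """Starts two my_worker and two my_ps in-process servers, for illustration."""
  import portpicker
  import tensorflow as tf

  cluster_spec = tf.train.ClusterSpec({
      'my_worker': [
          'localhost:%d' % portpicker.pick_unused_port() for _ in range(2)
      ],
      'my_ps': [
          'localhost:%d' % portpicker.pick_unused_port() for _ in range(2)
      ],
  })
  # Each server binds the port assigned to its (job, task) slot in the spec.
  servers = [
      tf.distribute.Server(cluster_spec, job_name=job, task_index=i, start=True)
      for job in ('my_worker', 'my_ps')
      for i in range(2)
  ]
  return cluster_spec, servers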
def __init__(self, cluster_resolver, client_name="chief"):
  """Initializes the cluster instance and connects to the remote cluster."""
  if client_name in ["worker", "ps"]:
    raise ValueError("Client name should not be 'worker' or 'ps'.")
  cluster_spec = cluster_resolver.cluster_spec()

  self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
  self._num_ps = len(cluster_spec.as_dict().get("ps", ()))
  device_filters = server_lib.ClusterDeviceFilters()
  # For any worker, only the devices on PS and chief nodes are visible.
  for i in range(self._num_workers):
    device_filters.set_device_filters(
        "worker", i, ["/job:ps", "/job:%s" % client_name])
  # Similarly for any ps, only the devices on workers and chief are visible.
  for i in range(self._num_ps):
    device_filters.set_device_filters(
        "ps", i, ["/job:worker", "/job:%s" % client_name])

  context.context().mirroring_policy = context.MIRRORING_ALL

  # Allow at most one outstanding RPC for each worker at a certain time. This
  # is to simplify worker failure handling in the runtime.
  os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False"

  remote.connect_to_cluster(
      cluster_spec,
      job_name=client_name,
      protocol=cluster_resolver.rpc_layer,
      cluster_device_filters=device_filters)

  self._cancellation_mgr = cancellation.CancellationManager()
  self._closure_queue = _CoordinatedClosureQueue(self._cancellation_mgr)
  self.failure_handler = WorkerPreemptionHandler(context.get_server_def())
  worker_device_strings = [
      "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
  ]
  self.workers = [
      Worker(i, w, self) for i, w in enumerate(worker_device_strings)
  ]
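# A minimal usage sketch, assuming the enclosing class is named `Client` and
# that a TF_CONFIG-style environment is available; the class name and cluster
# addresses are illustrative placeholders, not values from this module:
#
#   os.environ["TF_CONFIG"] = json.dumps({
#       "cluster": {
#           "chief": ["chief0:2222"],
#           "worker": ["worker0:2222", "worker1:2222"],
#           "ps": ["ps0:2222"],
#       },
#       "task": {"type": "chief", "index": 0},
#   })
#   resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
#   client = Client(resolver, client_name="chief")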