def create_local_cluster(num_workers,
                         num_ps,
                         protocol="grpc",
                         worker_config=None,
                         ps_config=None):
    """Build a localhost cluster and return its `Server` objects.

    One `Server` is created (with `start=True`) for every worker task and
    every PS task. Each task listens on `localhost` at a port obtained from
    `portpicker.pick_unused_port()`.

    Example:
    ```python
    workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)

    worker_sessions = [tf.Session(w.target) for w in workers]

    with tf.device("/job:ps/task:0"):
      ...
    with tf.device("/job:ps/task:1"):
      ...
    with tf.device("/job:worker/task:0"):
      ...
    with tf.device("/job:worker/task:1"):
      ...

    worker_sessions[0].run(...)
    ```

    Args:
      num_workers: How many servers to create for the "worker" job.
      num_ps: How many servers to create for the "ps" job.
      protocol: Communication protocol handed to each server. Allowed values
        are documented in the documentation of `tf.train.Server`.
      worker_config: (optional) ConfigProto passed as `config` to every worker
        server. Can be used to instantiate multiple devices etc.
      ps_config: (optional) ConfigProto passed as `config` to every PS server.

    Returns:
      A tuple `(worker_servers, ps_servers)`: a list with `num_workers`
      `tf.train.Server` objects followed by a list with `num_ps` of them, all
      running locally.

    Raises:
      ImportError: if portpicker module was not found at load time
    """
    if _portpicker_import_error:
        raise _portpicker_import_error  # pylint: disable=raising-bad-type

    def _pick_addresses(count):
        # One freshly picked local port per task of the job.
        return [
            "localhost:%s" % portpicker.pick_unused_port()
            for _ in range(count)
        ]

    # Worker ports are picked before PS ports.
    worker_addresses = _pick_addresses(num_workers)
    ps_addresses = _pick_addresses(num_ps)
    cluster_spec = server_lib.ClusterSpec({
        "worker": worker_addresses,
        "ps": ps_addresses,
    })

    def _launch(job_name, task_index, config):
        # Every task shares the same cluster spec and protocol.
        return server_lib.Server(
            cluster_spec,
            job_name=job_name,
            protocol=protocol,
            task_index=task_index,
            config=config,
            start=True)

    worker_servers = [
        _launch("worker", task_index, worker_config)
        for task_index in range(num_workers)
    ]
    ps_servers = [
        _launch("ps", task_index, ps_config) for task_index in range(num_ps)
    ]
    return worker_servers, ps_servers
def _run_std_server(cluster_spec=None,
                    task_type=None,
                    task_id=None,
                    session_config=None,
                    rpc_layer=None,
                    environment=None):
    """Start a standard server for this task, or return the cached one.

    The server and the configuration it was created with are remembered on
    `_thread_local`. When a server is already stored there, the call only
    asserts that it is being made with the same configuration and returns
    that server; this is what allows `run_distribute_coordinator` to be
    called multiple times.

    Args:
      cluster_spec: Cluster description; must be truthy and provide
        `task_address(task_type, task_id)`.
      task_type: Job name of the task this server represents.
      task_id: Index of the task within its job.
      session_config: Optional session configuration, passed on to the server
        (or to the remote session when `environment` is "google").
      rpc_layer: Optional protocol name. When set it is also prepended to the
        target as `<rpc_layer>://`.
      environment: When equal to "google", a `_FakeServer` that opens a
        remote session is used instead of a `server_lib.Server`.

    Returns:
      The server object, after its `start()` method has been called (or the
      previously stored server on repeated calls).

    Raises:
      AssertionError: if `cluster_spec` is falsy, or if a server is already
        stored and any argument differs from the stored configuration.
    """
    # Requested configuration, keyed by the `_thread_local` attribute that
    # holds it. The session config is compared through its `repr` string.
    requested = (
        ("cluster_spec", cluster_spec),
        ("task_type", task_type),
        ("task_id", task_id),
        ("session_config_str", repr(session_config)),
        ("rpc_layer", rpc_layer),
        ("environment", environment),
    )

    existing = getattr(_thread_local, "server", None)
    if existing is not None:
        # Already running: no configuration option may have changed.
        for attr_name, value in requested:
            assert getattr(_thread_local, attr_name) == value
        return existing

    # This method is not thread-safe.
    _thread_local.server_started = True
    for attr_name, value in requested:
        setattr(_thread_local, attr_name, value)

    assert cluster_spec
    target = cluster_spec.task_address(task_type, task_id)
    if rpc_layer:
        target = rpc_layer + "://" + target

    class _FakeServer(object):
        """A fake server that runs a master session."""

        def start(self):
            # Creating a remote session is what starts the TensorFlow server.
            logging.info(
                "Creating a remote session to start a TensorFlow server, target"
                " = %r, session_config=%r", target, session_config)
            session.Session(target=target, config=session_config)

        def join(self):
            # Block forever, waking up every five seconds.
            while True:
                time.sleep(5)

    if environment != "google":
        if not session_config:
            logging.info(
                "Starting standard TensorFlow server, target = %r", target)
        else:
            logging.info(
                "Starting standard TensorFlow server, target = %r, "
                "session_config= %r", target, session_config)
        cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type)
        server = server_lib.Server(
            cluster_spec,
            job_name=task_type,
            task_index=task_id,
            config=session_config,
            protocol=rpc_layer)
    else:
        server = _FakeServer()

    server.start()
    _thread_local.server = server
    return server
def testTwoServersSamePort(self):
    """A second server on the cached server's address must raise UnknownError."""
    cached = self._cached_server
    scheme = "grpc://"
    with self.assertRaises(errors_impl.UnknownError):
        # Drop the leading scheme so only the address part of the target is
        # reused for the new single-task cluster.
        taken_address = cached.target[len(scheme):]
        server_lib.Server({"local_2": [taken_address]})
def testInvalidHostname(self):
    """An address without a port must raise InvalidArgumentError matching "port"."""
    # "localhost" deliberately carries no ":<port>" suffix.
    portless_cluster = {"local": ["localhost"]}
    with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "port"):
        server_lib.Server(portless_cluster, job_name="local", task_index=0)
def _create_cluster(num_workers,
                    num_ps,
                    has_chief=False,
                    has_eval=False,
                    protocol='grpc',
                    worker_config=None,
                    ps_config=None,
                    eval_config=None):
    """Start a local cluster and return the dict describing it.

    Args:
      num_workers: Number of 'worker' tasks to start.
      num_ps: Number of 'ps' tasks to start.
      has_chief: Whether to add and start a single 'chief' task.
      has_eval: Whether to add and start a single 'evaluator' task.
      protocol: Protocol passed to every server.
      worker_config: Config passed to the worker servers and to the chief.
      ps_config: Config passed to the PS servers.
      eval_config: Config passed to the evaluator server.

    Returns:
      A dict mapping each job name that has tasks to its list of
      'localhost:<port>' addresses.

    Raises:
      ImportError: if portpicker could not be imported at load time.
    """
    if _portpicker_import_error:
        raise _portpicker_import_error  # pylint: disable=raising-bad-type

    def _fresh_address():
        return 'localhost:%s' % pick_unused_port()

    # Ports are picked in this order: workers, PS, evaluator, chief.
    worker_addresses = [_fresh_address() for _ in range(num_workers)]
    ps_addresses = [_fresh_address() for _ in range(num_ps)]

    cluster_dict = {}
    if num_workers > 0:
        cluster_dict['worker'] = worker_addresses
    if num_ps > 0:
        cluster_dict['ps'] = ps_addresses
    if has_eval:
        cluster_dict['evaluator'] = [_fresh_address()]
    if has_chief:
        cluster_dict['chief'] = [_fresh_address()]

    cluster_spec = server_lib.ClusterSpec(cluster_dict)

    # (job name, task index, config) for every server, in start order:
    # workers, PS, chief, evaluator.
    launch_plan = [('worker', i, worker_config) for i in range(num_workers)]
    launch_plan += [('ps', i, ps_config) for i in range(num_ps)]
    if has_chief:
        launch_plan.append(('chief', 0, worker_config))
    if has_eval:
        launch_plan.append(('evaluator', 0, eval_config))

    for job_name, task_index, config in launch_plan:
        server_lib.Server(
            cluster_spec,
            job_name=job_name,
            protocol=protocol,
            task_index=task_index,
            config=config,
            start=True)

    return cluster_dict
def _create_cluster(num_workers,
                    num_ps,
                    has_chief=False,
                    has_eval=False,
                    protocol='grpc',
                    worker_config=None,
                    ps_config=None,
                    eval_config=None,
                    worker_name='worker',
                    ps_name='ps',
                    chief_name='chief'):
    """Start a local cluster with configurable job names; return its dict.

    Args:
      num_workers: Number of worker tasks to start.
      num_ps: Number of PS tasks to start.
      has_chief: Whether to add and start a single chief task.
      has_eval: Whether to add and start a single 'evaluator' task.
      protocol: Protocol passed to every server.
      worker_config: Config passed to the worker servers and to the chief.
      ps_config: Config passed to the PS servers.
      eval_config: Config passed to the evaluator server.
      worker_name: Job name used for the worker tasks.
      ps_name: Job name used for the PS tasks.
      chief_name: Job name used for the chief task.

    Returns:
      A dict mapping each job name that has tasks to its list of
      'localhost:<port>' addresses.
    """

    def _address(port):
        return 'localhost:%s' % port

    # Worker and PS ports are picked up front; the evaluator and chief ports
    # are picked only when those jobs are requested (evaluator first).
    worker_ports = [pick_unused_port() for _ in range(num_workers)]
    ps_ports = [pick_unused_port() for _ in range(num_ps)]

    cluster_dict = {}
    if num_workers > 0:
        cluster_dict[worker_name] = list(map(_address, worker_ports))
    if num_ps > 0:
        cluster_dict[ps_name] = list(map(_address, ps_ports))
    if has_eval:
        cluster_dict['evaluator'] = [_address(pick_unused_port())]
    if has_chief:
        cluster_dict[chief_name] = [_address(pick_unused_port())]

    cluster_spec = server_lib.ClusterSpec(cluster_dict)

    def _start_server(job_name, task_index, config):
        # All servers share the cluster spec and protocol and start at once.
        server_lib.Server(
            cluster_spec,
            job_name=job_name,
            protocol=protocol,
            task_index=task_index,
            config=config,
            start=True)

    for worker_index in range(num_workers):
        _start_server(worker_name, worker_index, worker_config)
    for ps_index in range(num_ps):
        _start_server(ps_name, ps_index, ps_config)
    if has_chief:
        _start_server(chief_name, 0, worker_config)
    if has_eval:
        _start_server('evaluator', 0, eval_config)

    return cluster_dict
def _get_workers(num_workers, staleness):
    """Start a local cluster and build one training graph per worker.

    Starts `num_workers` 'worker' servers and a single 'ps' server on
    localhost, then builds a separate graph for every worker. Each graph
    minimizes `0 - v0 - v1` with a `DropStaleGradientOptimizer` wrapping plain
    SGD (learning rate 1.0), and uses two named queues to force a fixed
    ordering of gradient computation and application across workers.

    Args:
      num_workers: Number of worker servers, graphs and sessions to create.
      staleness: Passed as the second argument to `DropStaleGradientOptimizer`
        (presumably the staleness threshold -- confirm against that class).

    Returns:
      A tuple `(sessions, graphs, train_ops)` of three lists, each with one
      entry per worker, in worker-id order.
    """
    worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
    cluster_dict = {
        'worker': ['localhost:%s' % port for port in worker_ports],
        'ps': ['localhost:%s' % portpicker.pick_unused_port()]
    }
    cs = server_lib.ClusterSpec(cluster_dict)
    workers = [
        server_lib.Server(cs, job_name='worker', task_index=ix, start=True)
        for ix in range(num_workers)
    ]
    # The PS server is started but its handle is not kept.
    server_lib.Server(cs, job_name='ps', task_index=0, start=True)

    sessions = []
    graphs = []
    train_ops = []

    # To simulate stale cases, two queues are maintained: one for computing and
    # one for applying gradients. In the compute phase, all workers except the
    # chief worker compute gradients together, and the chief worker computes
    # only after all other workers' computation has finished. In the apply
    # phase, the chief worker applies its gradients first, then all other
    # workers apply their gradients one by one. Therefore, the chief worker
    # will always have 0 staleness, and each of the other workers will have a
    # unique staleness value from [1, num_workers).
    for worker_id in range(num_workers):
        graph = ops.Graph()
        with graph.as_default():
            global_step = training_util.create_global_step()
            var_0 = variables.VariableV1(0.0, name='v0')
            var_1 = variables.VariableV1(1.0, name='v1')
            # Both queues are given a `shared_name`, so every worker's graph
            # names the same queue (presumably resolving to one shared queue
            # across sessions -- confirm against the FIFOQueue docs).
            compute_gradients_queue = data_flow_ops.FIFOQueue(
                -1,
                global_step.dtype.base_dtype,
                shapes=(),
                name='compute_gradients_queue',
                shared_name='compute_gradients_queue')
            apply_gradients_queue = data_flow_ops.FIFOQueue(
                -1,
                global_step.dtype.base_dtype,
                shapes=(),
                name='apply_gradients_queue',
                shared_name='apply_gradients_queue')

            # Gradients for loss on var_0 and var_1 will be 1.0.
            loss = 0 - var_0 - var_1
            sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
            stale_check_opt = (
                drop_stale_gradient_optimizer.DropStaleGradientOptimizer(
                    sgd_opt, staleness))

            # Compute gradients.
            if worker_id == 0:
                # Chief: wait for one queue element from each of the other
                # workers before computing gradients.
                with ops.control_dependencies(
                    [compute_gradients_queue.dequeue_many(num_workers - 1)]):
                    grad_and_vars = stale_check_opt.compute_gradients(loss)
            else:
                # Non-chief: compute gradients, then signal completion by
                # enqueueing the global step.
                grad_and_vars = stale_check_opt.compute_gradients(loss)
                with ops.control_dependencies([t[0] for t in grad_and_vars]):
                    worker_enqueue_op = compute_gradients_queue.enqueue(
                        global_step)

            # Apply gradients.
            if worker_id == 0:
                # Chief: apply, then put a token on the apply queue.
                with ops.control_dependencies([
                    stale_check_opt.apply_gradients(
                        grad_and_vars, global_step)
                ]):
                    train_op = apply_gradients_queue.enqueue(global_step)
            else:
                # Non-chief: the nesting order matters. Signal that gradients
                # were computed, take the token from the apply queue, apply,
                # then put a token back for the next worker.
                with ops.control_dependencies([worker_enqueue_op]):
                    with ops.control_dependencies(
                        [apply_gradients_queue.dequeue()]):
                        with ops.control_dependencies([
                            stale_check_opt.apply_gradients(
                                grad_and_vars, global_step)
                        ]):
                            train_op = apply_gradients_queue.enqueue(
                                global_step)

            # Created while `graph` is the default graph; the session is
            # connected to this worker's own server.
            sess = session.Session(workers[worker_id].target)

        sessions.append(sess)
        graphs.append(graph)
        train_ops.append(train_op)

    return sessions, graphs, train_ops