Example #1
def create_local_cluster(num_workers, num_ps, protocol="grpc",
                         worker_config=None, ps_config=None):
  """Create and start local servers and return the associated `Server` objects.

  Example:
  ```python
  workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)

  worker_sessions = [tf.Session(w.target) for w in workers]

  with tf.device("/job:ps/task:0"):
    ...
  with tf.device("/job:ps/task:1"):
    ...
  with tf.device("/job:worker/task:0"):
    ...
  with tf.device("/job:worker/task:1"):
    ...

  worker_sessions[0].run(...)
  ```

  Args:
    num_workers: Number of worker servers to start.
    num_ps: Number of PS servers to start.
    protocol: Communication protocol. Allowed values are documented in
      `tf.train.Server`.
    worker_config: (optional) ConfigProto to initialize workers. Can be used
      to instantiate multiple devices etc.
    ps_config: (optional) ConfigProto to initialize PS servers.

  Returns:
    A tuple `(worker_servers, ps_servers)`.  `worker_servers` is a list
    of `num_workers` objects of type `tf.train.Server` (all running locally);
    and `ps_servers` is a list of `num_ps` objects of similar type.

  Raises:
    ImportError: if the portpicker module was not found at load time.
  """
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
  cluster_dict = {
      "worker": ["localhost:%s" % port for port in worker_ports],
      "ps": ["localhost:%s" % port for port in ps_ports]
  }
  cs = server_lib.ClusterSpec(cluster_dict)

  workers = [
      server_lib.Server(
          cs, job_name="worker", protocol=protocol, task_index=ix,
          config=worker_config, start=True)
      for ix in range(num_workers)
  ]
  ps_servers = [
      server_lib.Server(
          cs, job_name="ps", protocol=protocol, task_index=ix,
          config=ps_config, start=True)
      for ix in range(num_ps)
  ]

  return workers, ps_servers
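
For orientation, here is a minimal, self-contained sketch of how the wrapper above is typically used from the public API, assuming graph-mode TensorFlow 1.x (or `tf.compat.v1` under 2.x); the variable and op placement are illustrative only:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Two in-process worker servers and one in-process PS server, all on localhost.
workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=1)

# Place a variable on the PS job and a compute op on a worker.
with tf.device("/job:ps/task:0"):
    v = tf.get_variable("v", initializer=10.0)
with tf.device("/job:worker/task:0"):
    doubled = v * 2.0

# Any worker's target can serve as the session master for the whole cluster.
with tf.Session(workers[0].target) as sess:
    sess.run(v.initializer)
    print(sess.run(doubled))  # 20.0
```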
Example #2
def _run_std_server(cluster_spec=None,
                    task_type=None,
                    task_id=None,
                    session_config=None,
                    rpc_layer=None,
                    environment=None):
    """Runs a standard server."""
    # Check if the Server is already running. If so, assert that no configuration
    # options have changed, and return the existing Server. This allows us to
    # call `run_distribute_coordinator` multiple times.
    if getattr(_thread_local, "server", None) is not None:
        assert _thread_local.cluster_spec == cluster_spec
        assert _thread_local.task_type == task_type
        assert _thread_local.task_id == task_id
        assert _thread_local.session_config_str == repr(session_config)
        assert _thread_local.rpc_layer == rpc_layer
        assert _thread_local.environment == environment
        return _thread_local.server
    else:
        # This method is not thread-safe.
        _thread_local.server_started = True
        _thread_local.cluster_spec = cluster_spec
        _thread_local.task_type = task_type
        _thread_local.task_id = task_id
        _thread_local.session_config_str = repr(session_config)
        _thread_local.rpc_layer = rpc_layer
        _thread_local.environment = environment

    assert cluster_spec
    target = cluster_spec.task_address(task_type, task_id)
    if rpc_layer:
        target = rpc_layer + "://" + target

    class _FakeServer(object):
        """A fake server that runs a master session."""
        def start(self):
            # A TensorFlow server starts when a remote session is created.
            logging.info(
                "Creating a remote session to start a TensorFlow server, "
                "target = %r, session_config=%r", target, session_config)
            session.Session(target=target, config=session_config)

        def join(self):
            while True:
                time.sleep(5)

    if environment == "google":
        server = _FakeServer()
    else:
        if session_config:
            logging.info(
                "Starting standard TensorFlow server, target = %r, session_config= "
                "%r", target, session_config)
        else:
            logging.info("Starting standard TensorFlow server, target = %r",
                         target)
        cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type)
        server = server_lib.Server(cluster_spec,
                                   job_name=task_type,
                                   task_index=task_id,
                                   config=session_config,
                                   protocol=rpc_layer)

    server.start()
    _thread_local.server = server
    return server
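`_run_std_server` is a private helper, but its non-`"google"` branch boils down to constructing and starting a `tf.train.Server` for the local task. A rough standalone equivalent (the cluster layout, ports, and task identity below are made-up placeholders):

```python
import tensorflow.compat.v1 as tf

cluster_spec = tf.train.ClusterSpec({
    "worker": ["localhost:2222", "localhost:2223"],
    "ps": ["localhost:2224"],
})

# Build and start the in-process server for worker task 0; `protocol` plays
# the role of `rpc_layer` above and `config` that of `session_config`.
server = tf.train.Server(cluster_spec,
                         job_name="worker",
                         task_index=0,
                         protocol="grpc",
                         config=tf.ConfigProto(),
                         start=False)
server.start()
# server.join()  # would block this thread, like a dedicated server process
```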
Example #3
 def testTwoServersSamePort(self):
     # Starting a server with the same target as the cached server should fail.
     server = self._cached_server
     with self.assertRaises(errors_impl.UnknownError):
         _ = server_lib.Server(
             {"local_2": [server.target[len("grpc://"):]]})
Example #4
 def testInvalidHostname(self):
     with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "port"):
         _ = server_lib.Server({"local": ["localhost"]},
                               job_name="local",
                               task_index=0)
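For contrast, the same constructor call goes through once the address includes a port. A hedged sketch; port `0` asks the runtime to pick any free port, so the example should not collide with other processes:

```python
import tensorflow.compat.v1 as tf

server = tf.train.Server({"local": ["localhost:0"]},
                         job_name="local",
                         task_index=0)
print(server.target)  # e.g. "grpc://localhost:<picked port>"
```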
Example #5
def _create_cluster(num_workers,
                    num_ps,
                    has_chief=False,
                    has_eval=False,
                    protocol='grpc',
                    worker_config=None,
                    ps_config=None,
                    eval_config=None):
    """Creates and starts local servers and returns the cluster_spec dict."""
    if _portpicker_import_error:
        raise _portpicker_import_error  # pylint: disable=raising-bad-type
    worker_ports = [pick_unused_port() for _ in range(num_workers)]
    ps_ports = [pick_unused_port() for _ in range(num_ps)]

    cluster_dict = {}
    if num_workers > 0:
        cluster_dict['worker'] = [
            'localhost:%s' % port for port in worker_ports
        ]
    if num_ps > 0:
        cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
    if has_eval:
        cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
    if has_chief:
        cluster_dict['chief'] = ['localhost:%s' % pick_unused_port()]

    cs = server_lib.ClusterSpec(cluster_dict)

    for i in range(num_workers):
        server_lib.Server(cs,
                          job_name='worker',
                          protocol=protocol,
                          task_index=i,
                          config=worker_config,
                          start=True)

    for i in range(num_ps):
        server_lib.Server(cs,
                          job_name='ps',
                          protocol=protocol,
                          task_index=i,
                          config=ps_config,
                          start=True)

    if has_chief:
        server_lib.Server(cs,
                          job_name='chief',
                          protocol=protocol,
                          task_index=0,
                          config=worker_config,
                          start=True)

    if has_eval:
        server_lib.Server(cs,
                          job_name='evaluator',
                          protocol=protocol,
                          task_index=0,
                          config=eval_config,
                          start=True)

    return cluster_dict
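A hedged usage sketch: one common consumer of the returned `cluster_dict` is the `TF_CONFIG` environment variable read by TensorFlow's distribution strategies; the task assignment below is an arbitrary example:

```python
import json
import os

# Start 2 workers, 1 PS and a chief, then describe this process as worker 0.
cluster_dict = _create_cluster(num_workers=2, num_ps=1, has_chief=True)
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": cluster_dict,
    "task": {"type": "worker", "index": 0},
})
```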
Example #6
def _create_cluster(num_workers,
                    num_ps,
                    has_chief=False,
                    has_eval=False,
                    protocol='grpc',
                    worker_config=None,
                    ps_config=None,
                    eval_config=None,
                    worker_name='worker',
                    ps_name='ps',
                    chief_name='chief'):
    """Creates and starts local servers and returns the cluster_spec dict."""

    worker_ports = [pick_unused_port() for _ in range(num_workers)]
    ps_ports = [pick_unused_port() for _ in range(num_ps)]

    cluster_dict = {}
    if num_workers > 0:
        cluster_dict[worker_name] = [
            'localhost:%s' % port for port in worker_ports
        ]
    if num_ps > 0:
        cluster_dict[ps_name] = ['localhost:%s' % port for port in ps_ports]
    if has_eval:
        cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
    if has_chief:
        cluster_dict[chief_name] = ['localhost:%s' % pick_unused_port()]

    cs = server_lib.ClusterSpec(cluster_dict)

    for i in range(num_workers):
        server_lib.Server(cs,
                          job_name=worker_name,
                          protocol=protocol,
                          task_index=i,
                          config=worker_config,
                          start=True)

    for i in range(num_ps):
        server_lib.Server(cs,
                          job_name=ps_name,
                          protocol=protocol,
                          task_index=i,
                          config=ps_config,
                          start=True)

    if has_chief:
        server_lib.Server(cs,
                          job_name=chief_name,
                          protocol=protocol,
                          task_index=0,
                          config=worker_config,
                          start=True)

    if has_eval:
        server_lib.Server(cs,
                          job_name='evaluator',
                          protocol=protocol,
                          task_index=0,
                          config=eval_config,
                          start=True)

    return cluster_dict
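This variant differs from the previous one only in letting callers rename the worker, PS, and chief jobs (and in skipping the portpicker import check). A short example with arbitrary custom names:

```python
# Start a cluster whose jobs use non-default names, e.g. for code paths that
# must not hard-code "worker"/"ps"/"chief".
cluster_dict = _create_cluster(num_workers=1,
                               num_ps=1,
                               has_chief=True,
                               worker_name="my_worker",
                               ps_name="my_ps",
                               chief_name="my_chief")
print(sorted(cluster_dict))  # ['my_chief', 'my_ps', 'my_worker']
```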
Example #7
def _get_workers(num_workers, staleness):
    worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
    cluster_dict = {
        'worker': ['localhost:%s' % port for port in worker_ports],
        'ps': ['localhost:%s' % portpicker.pick_unused_port()]
    }
    cs = server_lib.ClusterSpec(cluster_dict)
    workers = [
        server_lib.Server(cs, job_name='worker', task_index=ix, start=True)
        for ix in range(num_workers)
    ]
    server_lib.Server(cs, job_name='ps', task_index=0, start=True)

    sessions = []
    graphs = []
    train_ops = []

    # To simulate staleness, two queues are maintained: one for computing
    # gradients and one for applying them. In the compute phase, all workers
    # except the chief compute their gradients first, and the chief computes
    # its gradients only after the others have finished. In the apply phase,
    # the chief applies its gradients first, and then the other workers apply
    # theirs one by one. As a result, the chief always has staleness 0, while
    # each other worker ends up with a unique staleness value in
    # [1, num_workers).
    for worker_id in range(num_workers):
        graph = ops.Graph()
        with graph.as_default():
            global_step = training_util.create_global_step()
            var_0 = variables.VariableV1(0.0, name='v0')
            var_1 = variables.VariableV1(1.0, name='v1')
            compute_gradients_queue = data_flow_ops.FIFOQueue(
                -1,
                global_step.dtype.base_dtype,
                shapes=(),
                name='compute_gradients_queue',
                shared_name='compute_gradients_queue')
            apply_gradients_queue = data_flow_ops.FIFOQueue(
                -1,
                global_step.dtype.base_dtype,
                shapes=(),
                name='apply_gradients_queue',
                shared_name='apply_gradients_queue')

            # The gradient of the loss w.r.t. var_0 and var_1 is -1.0, so a
            # 1.0-learning-rate SGD step increases each variable by 1.0.
            loss = 0 - var_0 - var_1
            sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
            stale_check_opt = (
                drop_stale_gradient_optimizer.DropStaleGradientOptimizer(
                    sgd_opt, staleness))

            # Compute gradients.
            if worker_id == 0:
                with ops.control_dependencies(
                    [compute_gradients_queue.dequeue_many(num_workers - 1)]):
                    grad_and_vars = stale_check_opt.compute_gradients(loss)
            else:
                grad_and_vars = stale_check_opt.compute_gradients(loss)
                with ops.control_dependencies([t[0] for t in grad_and_vars]):
                    worker_enqueue_op = compute_gradients_queue.enqueue(
                        global_step)

            # Apply gradients.
            if worker_id == 0:
                with ops.control_dependencies([
                        stale_check_opt.apply_gradients(
                            grad_and_vars, global_step)
                ]):
                    train_op = apply_gradients_queue.enqueue(global_step)
            else:
                with ops.control_dependencies([worker_enqueue_op]):
                    with ops.control_dependencies(
                        [apply_gradients_queue.dequeue()]):
                        with ops.control_dependencies([
                                stale_check_opt.apply_gradients(
                                    grad_and_vars, global_step)
                        ]):
                            train_op = apply_gradients_queue.enqueue(
                                global_step)

            sess = session.Session(workers[worker_id].target)

        sessions.append(sess)
        graphs.append(graph)
        train_ops.append(train_op)

    return sessions, graphs, train_ops
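A hedged sketch of driving the returned sessions and train ops: every worker's train op has to run concurrently, because the chief blocks on the compute-gradients queue and the other workers block on the apply-gradients queue. Running each op in its own thread is one way to satisfy that; the worker count and staleness below are arbitrary:

```python
import threading

sessions, graphs, train_ops = _get_workers(num_workers=3, staleness=1)

threads = []
for sess, train_op in zip(sessions, train_ops):
    # Each worker runs its train op once; the shared FIFO queues enforce the
    # chief-first apply order described in the comments above.
    t = threading.Thread(target=lambda s=sess, op=train_op: s.run(op))
    t.start()
    threads.append(t)
for t in threads:
    t.join()
```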