Example #1
  def testDistributedFunctionBothServersReplaced(self):
    """Tests that replacing servers works correctly.

    We create two servers, t1 and t2. We first replace t2, then we replace t1.

    Among other things, this ensures that both already-existing and
    restarted workers have their context view IDs correctly updated.
    """
    with ops.device(self.device_local):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      with ops.device(self.device_t1):
        mul = math_ops.matmul(i, i)
      with ops.device(self.device_t2):
        add = mul + i
      return add - i

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    # Replace task2
    context.update_server_def(server_def=self.server_def_s1_s3)
    for device in (self.device_t1, self.device_t2):
      with ops.device(device):
        y = worker_fn(x1)
      np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Then replace task1
    context.update_server_def(server_def=self.server_def_s4_s3)
    for device in (self.device_t1, self.device_t2):
      with ops.device(device):
        y = worker_fn(x1)
      np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
Example #2
 def update_server_def_fn():
   for i in range(num_calls):
     lock.acquire()
     context.update_server_def(
         server_def=(self.server_def_s1_s2
                     if i % 2 == 0 else self.server_def_s1_s3))
     lock.release()
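
This fragment assumes an enclosing test that defines `num_calls`, a `threading.Lock` named `lock`, and the `self.server_def_*` protos, and that runs the function on its own thread while other threads execute remote ops. A minimal sketch of such a harness follows; the names `make_update_fn`, `ctx`, `sd_a`, and `sd_b` are illustrative, not part of the original test:

import threading

num_calls = 10
lock = threading.Lock()

def make_update_fn(ctx, server_defs):
  # Alternates between two ServerDefs under the lock, mirroring the
  # fragment above; `ctx` stands in for TensorFlow's eager `context` module.
  def update_server_def_fn():
    for i in range(num_calls):
      lock.acquire()
      ctx.update_server_def(server_def=server_defs[i % 2])
      lock.release()
  return update_server_def_fn

# Hypothetical usage alongside worker threads:
# t = threading.Thread(target=make_update_fn(context, [sd_a, sd_b]))
# t.start(); t.join()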
Example #3
  def testPendingNodesServerReplaced(self):
    """Update cluster when nodes are still pending on remote workers."""
    with ops.device(self.device_local):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    # Add enough ops so they are pending when changing the cluster
    num_nodes = 10
    ret = [None] * num_nodes
    for i in range(num_nodes):
      with ops.device(self.device_t1):
        ret[i] = worker_fn(x1)
    # While nodes are still pending on worker s1, replace worker s2 with s3.
    context.update_server_def(server_def=self.server_def_s1_s3)
    with ops.device(self.device_t2):
      y = worker_fn(x1)
    for i in range(num_nodes):
      np.testing.assert_array_equal([[2, 2], [2, 2]], ret[i].numpy())
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
Example #4
  def testServerReplaced(self):
    """Replace remote host_port for a task, and run ops on cluster."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    context.update_server_def(server_def=self.server_def_s1_s3)
    with ops.device(self.device_t2):
      y = math_ops.matmul(x1, x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
Example #5
  def testParameterServerMultiExecutors(self):
    context.update_server_def(server_def=self.server_def_s1_s2_s3_s4)

    with ops.device(self.device_t1):
      v1 = variables.Variable(initial_value=0.)
    with ops.device(self.device_t2):
      v2 = variables.Variable(initial_value=10.)

    @def_function.function
    def worker_fn():
      x1 = v1.read_value()
      x2 = v2.read_value()
      grad = (x1 + x2) * 0.1
      v1.assign_add(grad)
      v2.assign_sub(grad)
      return v1 + v2

    worker_fn.get_concrete_function()

    executor_t3 = executor.new_executor(enable_async=False)
    executor_t4 = executor.new_executor(enable_async=False)

    num_calls = 10
    self._coord = coordinator.Coordinator()

    def thread_fn(executor_obj, device, results):
      with self._coord.stop_on_exception():
        for i in range(num_calls):
          with context.executor_scope(executor_obj):
            with ops.device(device):
              results[i] = worker_fn()

    def update_server_def_fn():
      with self._coord.stop_on_exception():
        for _ in range(30):
          context.update_server_def(self.server_def_s1_s2_s3_s4)

    t3_results = [None] * num_calls
    t4_results = [None] * num_calls
    threads = []
    threads.append(
        threading.Thread(
            target=thread_fn, args=(executor_t3, self.device_t3, t3_results)))
    threads.append(
        threading.Thread(
            target=thread_fn, args=(executor_t4, self.device_t4, t4_results)))
    threads.append(threading.Thread(target=update_server_def_fn))
    for t in threads:
      t.start()
    self._coord.join(threads)

    # Cannot assert individual values since the results are non-deterministic.
    # By summing up the values we ensure that they are all reasonable and valid
    # numbers (not `None` or `NaN`).
    total = np.sum(t3_results + t4_results)
    self.assertGreater(total, 0)
Example #6
  def testFunctionServerReplaced(self):
    """Replace remote host_port for a task, and run functions on cluster."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1_s3)
    with ops.device(self.device_t2):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
Example #7
  def testDistributedFunctionServerAdded(self):
    """Add a server to cluster, and run distributed function on it."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      with ops.device(self.device_t2):
        mul = math_ops.matmul(i, i)
      return mul - array_ops.zeros_like(mul)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1_s2_s3)
    with ops.device(self.device_t3):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
Example #8
  def testServerAdded(self):
    """Add a server to cluster, and run remote ops on it."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    context.update_server_def(server_def=self.server_def_s1_s2_s3)
    with ops.device(self.device_t3):
      x2 = array_ops.ones([2, 2])

    # Test new server accessing resources on old server
    with ops.device(self.device_t3):
      y = math_ops.matmul(x1, x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Test old server accessing resources on new server
    with ops.device(self.device_t2):
      y = math_ops.matmul(x1, x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
Example #9
  def testServerRemoved(self):
    """Remove a server from cluster, and run ops on cluster."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])
    with ops.device(self.device_t2):
      x2 = array_ops.ones([2, 2])

    with ops.device(self.device_t1):
      y = math_ops.matmul(x1, x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    context.update_server_def(server_def=self.server_def_s1)
    with ops.device(self.device_t1):
      y = math_ops.matmul(x1, x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Running ops on removed server s2 throws an exception
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device(self.device_t2):
        y = math_ops.matmul(x1, x2)
    self.assertIn("unknown device", cm.exception.message)
Example #10
  def testFunctionServerAdded(self):
    """Add a server to cluster, and run remote function on it."""
    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1_s2_s3)
    with ops.device(self.device_t3):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    with ops.device(self.device_t3):
      x2 = array_ops.ones([2, 2])
    with ops.device(self.device_t1):
      y = worker_fn(x2)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
Example #11
  def testFunctionServerRemoved(self):
    """Remove a server from cluster, and run functions on cluster."""
    @def_function.function
    def worker_fn(i):
      return math_ops.matmul(i, i)

    with ops.device(self.device_t1):
      x1 = array_ops.ones([2, 2])

    # Forces function tracing and registration
    worker_fn.get_concrete_function(x1)

    context.update_server_def(server_def=self.server_def_s1)

    with ops.device(self.device_t1):
      y = worker_fn(x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Running functions on removed server s2 throws an exception
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device(self.device_t2):
        y = worker_fn(x1)
    self.assertIn("unknown device", cm.exception.message)
Example #12
 def update_server_def_fn():
   with self._coord.stop_on_exception():
     for _ in range(30):
       context.update_server_def(self.server_def_s1_s2_s3_s4)
Example #13
 def update_server_def_fn():
   with self._coord.stop_on_exception():
     for i in range(num_calls):
       context.update_server_def(
           server_def=(self.server_def_s1_s2_s3
                       if i % 2 == 0 else self.server_def_s1_s2))
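
Like the fragment in Example #2, this one assumes an enclosing test that defines `num_calls`, `self._coord` (a coordinator), and the `self.server_def_*` protos; Example #5 above shows the full pattern of launching such a function on its own `threading.Thread` and joining through the coordinator.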
Example #14
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True):
    """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this more
  than once will work, but will invalidate any tensor handles on the old remote
  devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified, will
      use the default from `python/platform/remote_utils.py`.
    make_master_device_default: If True and a cluster resolver is passed, will
      automatically enter the master task device scope, which indicates the
      master becomes the default device to run ops. It won't do anything if
      a cluster spec is passed. Will throw an error if the caller is currently
      already in some device scope.
  """
    if not context.executing_eagerly():
        raise ValueError(
            "`tf.config.experimental_connect_to_cluster` can only be called in "
            "eager mode.")
    protocol = protocol or remote_utils.get_default_communication_protocol()
    if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
        cluster_spec = cluster_spec_or_resolver
    elif isinstance(cluster_spec_or_resolver,
                    cluster_resolver.ClusterResolver):
        if cluster_spec_or_resolver.master() in _LOCAL_MASTERS:
            # Do nothing if the master is local.
            return
        cluster_spec = cluster_spec_or_resolver.cluster_spec()
    else:
        raise ValueError(
            "`cluster_spec_or_resolver` must be a `ClusterSpec` or a "
            "`ClusterResolver`.")

    cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())

    # Automatically add local job, if not part of the cluster spec.
    if job_name not in cluster_spec.jobs:
        local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
        job_def = cluster_def.job.add()
        job_def.name = job_name
        # TODO(fishx): Update this to make sure the remote worker has a valid
        # ip address to connect with the local client.
        job_def.tasks[0] = "localhost:{}".format(local_port)

    server_def = ServerDef(cluster=cluster_def,
                           job_name=job_name,
                           task_index=task_index,
                           protocol=protocol,
                           default_session_config=context.context().config)

    if context.get_server_def() is None:
        context.set_server_def(server_def)
    else:
        context.update_server_def(server_def)

    if make_master_device_default and isinstance(
            cluster_spec_or_resolver, cluster_resolver.ClusterResolver
    ) and cluster_spec_or_resolver.master():
        master = cluster_spec_or_resolver.master()
        master_job_name = None
        master_task_id = None
        for job_name in cluster_spec.jobs:
            for task_id in cluster_spec.task_indices(job_name):
                task_address = cluster_spec.task_address(job_name, task_id)
                if master in task_address or task_address in master:
                    master_job_name = job_name
                    master_task_id = task_id
                    break

        if not master_job_name:
            raise ValueError(
                "`make_master_device_default` is set to True but cannot find "
                "master %s in the cluster" % master)

        master_device = "/job:{}/replica:0/task:{}".format(
            master_job_name, master_task_id)
        master_device = device_util.canonicalize(master_device)
        current_device = device_util.current()
        if current_device:
            current_device = device_util.canonicalize(current_device)
        if current_device and current_device != master_device:
            raise ValueError(
                "`connect_to_cluster` is called inside existing device "
                "scope %s, which is different from the master device "
                "scope %s to enter. This is not allowed." %
                (current_device, master_device))
        # TODO(b/138389076): Think of the entering device scope behavior in the
        # failure recovery case when dealing with preemptions.
        if not current_device:
            logging.info("Entering into master device scope: %s",
                         master_device)
            ops.device(master_device).__enter__()
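
A minimal usage sketch for this API, assuming a remote `tf.distribute.Server` is already listening at the address below (the address and job name are illustrative):

import tensorflow as tf

cluster_spec = tf.train.ClusterSpec({"worker": ["10.0.0.1:8470"]})
tf.config.experimental_connect_to_cluster(cluster_spec)

# Remote devices are now addressable from the local eager context.
with tf.device("/job:worker/replica:0/task:0"):
  x = tf.ones([2, 2])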
Example #15
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True,
                       cluster_device_filters=None):
    """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this more
  than once will work, but will invalidate any tensor handles on the old remote
  devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.

  Device filters can be specified to isolate groups of remote tasks, avoiding
  undesired accesses between workers. Accessing resources or launching
  ops / functions on filtered remote devices results in errors (unknown
  devices). For any remote task, if no device filter is present, all cluster
  devices will be visible; if any device filter is specified, it can only
  see devices matching at least one filter. Devices on the task itself are
  always visible. Device filters can be partially specified.

  For example, for a cluster set up for parameter server training, the following
  device filters might be specified:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  # For any worker, only the devices on PS nodes and itself are visible
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  # Similarly for any ps, only the devices on workers and itself are visible
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified, will
      use the default from `python/platform/remote_utils.py`.
    make_master_device_default: If True and a cluster resolver is passed, will
      automatically enter the master task device scope, which indicates the
      master becomes the default device to run ops. It won't do anything if
      a cluster spec is passed. Will throw an error if the caller is currently
      already in some device scope.
    cluster_device_filters: an instance of
      `tf.train.experimental.ClusterDeviceFilters` that specifies device
      filters for the remote tasks in the cluster.
  """
    if not context.executing_eagerly():
        raise ValueError(
            "`tf.config.experimental_connect_to_cluster` can only be called in "
            "eager mode.")
    protocol = protocol or remote_utils.get_default_communication_protocol()
    if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
        cluster_spec = cluster_spec_or_resolver
    elif isinstance(cluster_spec_or_resolver,
                    cluster_resolver.ClusterResolver):
        if cluster_spec_or_resolver.master() in _LOCAL_MASTERS:
            # Do nothing if the master is local.
            return
        cluster_spec = cluster_spec_or_resolver.cluster_spec()
    else:
        raise ValueError(
            "`cluster_spec_or_resolver` must be a `ClusterSpec` or a "
            "`ClusterResolver`.")

    cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
    if cluster_device_filters:
        if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
            cluster_device_filters = copy.deepcopy(
                cluster_device_filters._as_cluster_device_filters())  # pylint: disable=protected-access
        else:
            raise ValueError("`cluster_device_filters` must be an instance of "
                             "`tf.train.experimental.ClusterDeviceFilters`.")

    # Automatically add local job, if not part of the cluster spec.
    if job_name not in cluster_spec.jobs:
        local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
        job_def = cluster_def.job.add()
        job_def.name = job_name

        ipstr = _get_local_ip_address(local_port)
        if ipstr:
            job_def.tasks[0] = "{}:{}".format(ipstr, local_port)
        else:
            job_def.tasks[0] = "localhost:{}".format(local_port)

    server_def = ServerDef(cluster=cluster_def,
                           job_name=job_name,
                           task_index=task_index,
                           protocol=protocol,
                           default_session_config=context.context().config,
                           cluster_device_filters=cluster_device_filters)

    if context.get_server_def() is None:
        context.set_server_def(server_def)
    else:
        context.update_server_def(server_def)

    if make_master_device_default and isinstance(
            cluster_spec_or_resolver, cluster_resolver.ClusterResolver
    ) and cluster_spec_or_resolver.master():
        master = cluster_spec_or_resolver.master()
        master_job_name = None
        master_task_id = None
        for job_name in cluster_spec.jobs:
            for task_id in cluster_spec.task_indices(job_name):
                task_address = cluster_spec.task_address(job_name, task_id)
                if master in task_address or task_address in master:
                    master_job_name = job_name
                    master_task_id = task_id
                    break

        if not master_job_name:
            raise ValueError(
                "`make_master_device_default` is set to True but cannot find "
                "master %s in the cluster" % master)

        master_device = "/job:{}/replica:0/task:{}".format(
            master_job_name, master_task_id)
        master_device = device_util.canonicalize(master_device)
        current_device = device_util.current()
        if current_device:
            current_device = device_util.canonicalize(current_device)
        if current_device and current_device != master_device:
            raise ValueError(
                "`connect_to_cluster` is called inside existing device "
                "scope %s, which is different from the master device "
                "scope %s to enter. This is not allowed." %
                (current_device, master_device))
        # TODO(b/138389076): Think of the entering device scope behavior in the
        # failure recovery case when dealing with preemptions.
        if not current_device:
            logging.info("Entering into master device scope: %s",
                         master_device)
            ops.device(master_device).__enter__()
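
With this version, device filters can be attached at connection time. A minimal sketch, with illustrative addresses, using the `ClusterDeviceFilters` calls shown in the docstring above:

import tensorflow as tf

cluster_spec = tf.train.ClusterSpec({
    "worker": ["10.0.0.1:8470"],  # hypothetical addresses; servers must
    "ps": ["10.0.0.2:8470"],      # already be running at each of them
})

# The worker sees only the ps devices (plus its own), as in the docstring.
cdf = tf.config.experimental.ClusterDeviceFilters()
cdf.set_device_filters("worker", 0, ["/job:ps"])

tf.config.experimental_connect_to_cluster(
    cluster_spec, cluster_device_filters=cdf)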