Example #1
    def testRemote(self):
        gpus = config.list_logical_devices('GPU')
        self.assertNotEqual(len(gpus), 0)

        context.ensure_initialized()

        gpus = config.list_logical_devices('GPU')
        self.assertNotEqual(len(gpus), 0)
        for gpu in gpus:
            self.assertIsNotNone(gpu.name)

        context.ensure_initialized()

        job_name = 'test'
        cluster_def = cluster_pb2.ClusterDef()
        job_def = cluster_def.job.add()
        job_def.name = job_name
        job_def.tasks[0] = 'localhost:0'

        server_def = tensorflow_server_pb2.ServerDef(cluster=cluster_def,
                                                     job_name=job_name,
                                                     task_index=0,
                                                     protocol='grpc')

        context.set_server_def(server_def)

        gpus = config.list_logical_devices('GPU')
        for gpu in gpus:
            self.assertIsNotNone(gpu.name)
Example #2
  def testRemote(self):
    gpus = config.list_logical_devices('GPU')
    self.assertNotEqual(len(gpus), 0)

    context.ensure_initialized()

    gpus = config.list_logical_devices('GPU')
    self.assertNotEqual(len(gpus), 0)
    for gpu in gpus:
      self.assertIsNotNone(gpu.name)

    context.ensure_initialized()

    job_name = 'test'
    cluster_def = cluster_pb2.ClusterDef()
    job_def = cluster_def.job.add()
    job_def.name = job_name
    job_def.tasks[0] = 'localhost:0'

    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_def, job_name=job_name, task_index=0, protocol='grpc')

    context.set_server_def(server_def)

    gpus = config.list_logical_devices('GPU')
    for gpu in gpus:
      self.assertIsNotNone(gpu.name)
Example #3
def connect_to_remote_host(remote_host=None, job_name="worker"):
    """Connects to a single machine to enable remote execution on it.

  Will make devices on the remote host available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  Using the default job_name of worker, you can schedule ops to run remotely as
  follows:
  ```python
  # Enable eager execution, and connect to the remote host.
  tf.compat.v1.enable_eager_execution()
  tf.contrib.eager.connect_to_remote_host("exampleaddr.com:9876")

  with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
    # The following tensors should be resident on the remote device, and the op
    # will also execute remotely.
    x1 = array_ops.ones([2, 2])
    x2 = array_ops.ones([2, 2])
    y = math_ops.matmul(x1, x2)
  ```

  Args:
    remote_host: A single remote server address, or a list of them, in
      host-port format.
    job_name: The job name under which the new server will be accessible.

  Raises:
    ValueError: if remote_host is None.
  """
    if not remote_host:
        raise ValueError("Must provide at least one remote_host")

    remote_host = nest.flatten(remote_host)
    grpc_prefix = "grpc://"

    local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()

    cluster_def = ClusterDef()
    job_def = cluster_def.job.add()
    job_def.name = "localhost"
    # TODO(fishx): Update this to make sure remote worker has valid ip address
    # to connect with local.
    job_def.tasks[0] = "localhost:{}".format(local_port)

    job_def = cluster_def.job.add()
    job_def.name = job_name
    for i in range(len(remote_host)):
        if remote_host[i].startswith(grpc_prefix):
            job_def.tasks[i] = remote_host[i][len(grpc_prefix):]
        else:
            job_def.tasks[i] = remote_host[i]

    server_def = ServerDef(cluster=cluster_def,
                           job_name="localhost",
                           task_index=0,
                           protocol="grpc")

    # TODO(nareshmodi): Make this default since it works in more situations.
    os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
    context.set_server_def(server_def)
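This variant of `connect_to_remote_host` accepts either a single address or a list, and registers the remote hosts as tasks 0..N-1 of the `worker` job while the local client gets its own `localhost` job. A minimal usage sketch follows; the host addresses are placeholders, and `connect_to_remote_host` is assumed to be importable from wherever this helper lives (`tf.contrib.eager.connect_to_remote_host` in TF 1.x, per the docstring above).
```python
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops, math_ops

tf.compat.v1.enable_eager_execution()

# Hypothetical endpoints; replace with real worker addresses.
connect_to_remote_host(["host1.example.com:9876", "host2.example.com:9876"])

# In this list-based variant the first remote host is /job:worker/task:0
# and the second is /job:worker/task:1; the local client is /job:localhost.
with ops.device("/job:worker/replica:0/task:1/device:CPU:0"):
  y = math_ops.matmul(array_ops.ones([2, 2]), array_ops.ones([2, 2]))
```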
Example #4
    def testServerDefChanged(self):
        """Update server def, and run ops on new cluster."""
        context.set_server_def(server_def=get_server_def(
            ALT_JOB_NAME,
            local_server_port=0,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))

        with ops.device("job:%s/replica:0/task:1/device:CPU:0" % ALT_JOB_NAME):
            x1 = array_ops.ones([2, 2])
        y = math_ops.matmul(x1, x1)
        np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

        # Set the server def back to JOB_NAME
        context.set_server_def(server_def=get_server_def(
            JOB_NAME,
            local_server_port=0,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))

        with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
            x1 = array_ops.ones([2, 2])
        y = math_ops.matmul(x1, x1)
        np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
Example #5
  def testServerDefChanged(self):
    """Update server def, and run ops on new cluster."""
    context.set_server_def(
        server_def=get_server_def(
            ALT_JOB_NAME,
            local_server_port=0,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))

    with ops.device("job:%s/replica:0/task:1/device:CPU:0" % ALT_JOB_NAME):
      x1 = array_ops.ones([2, 2])
    y = math_ops.matmul(x1, x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())

    # Set the server def back to JOB_NAME
    context.set_server_def(
        server_def=get_server_def(
            JOB_NAME,
            local_server_port=0,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))

    with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
      x1 = array_ops.ones([2, 2])
    y = math_ops.matmul(x1, x1)
    np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
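The `get_server_def` helper that these tests call is not included in this listing. Judging from its call sites (a job name, a local server port, a list of remote addresses, and a task index), it presumably builds a single-job `ServerDef` roughly as sketched below; treat this as a reconstruction under those assumptions, not the actual test helper.
```python
from tensorflow.core.protobuf import cluster_pb2, tensorflow_server_pb2


def get_server_def(job_name, local_server_port, remote_server_addresses,
                   task_index):
  """Sketch of the test helper: one job whose task 0 is the local server."""
  cluster_def = cluster_pb2.ClusterDef()
  job_def = cluster_def.job.add()
  job_def.name = job_name
  # Task 0 is the local server; the remote servers follow as tasks 1..N.
  job_def.tasks[0] = "localhost:{}".format(local_server_port)
  for i, address in enumerate(remote_server_addresses):
    job_def.tasks[i + 1] = address

  return tensorflow_server_pb2.ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=task_index,
      protocol="grpc")
```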
Example #6
 def setUp(self):
     # Start the local server.
     context.set_server_def(server_def=get_server_def(
         JOB_NAME,
         local_server_port=0,
         remote_server_addresses=[
             self._cached_server1_target, self._cached_server2_target
         ],
         task_index=0))
Example #7
 def setUp(self):
     super(RemoteExecutionTest, self).setUp()
     local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
     context.set_server_def(server_def=get_server_def(
         JOB_NAME,
         local_server_port=local_port,
         remote_server_addresses=[
             self._cached_server1_target, self._cached_server2_target
         ],
         task_index=0))
Example #8
 def setUp(self):
     # Start the local server.
     local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
     context.set_server_def(server_def=get_server_def(
         JOB_NAME,
         local_server_port=local_port,
         remote_server_addresses=[
             self._cached_server1_target, self._cached_server2_target
         ],
         task_index=0))
Example #9
 def setUp(self):
   # Start the local server.
   context.set_server_def(
       server_def=get_server_def(
           JOB_NAME,
           local_server_port=0,
           remote_server_addresses=[
               self._cached_server1_target, self._cached_server2_target
           ],
           task_index=0))
Example #10
 def setUp(self):
     super(RemoteReplicateTest, self).setUp()
     # Start the local server.
     local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
     context.set_server_def(server_def=_get_server_def(
         JOB_NAME,
         local_server_port=local_port,
         remote_server_addresses=[
             self._cached_server1_target, self._cached_server2_target
         ],
         task_index=0))
Example #11
 def setUp(self):
     super(DynamicClusterTest, self).setUp()
     os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = str(False)
     local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
     context.set_server_def(server_def=get_server_def(
         JOB_NAME,
         local_server_port=local_port,
         remote_server_addresses=[
             self._cached_server1_target, self._cached_server2_target
         ],
         task_index=0))
Example #12
def connect_to_remote_host(remote_host=None, job_name="worker"):
  """Connects to a single machine to enable remote execution on it.

  Will make devices on the remote host available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  Using the default job_name of worker, you can schedule ops to run remotely as
  follows:
  ```python
  # Enable eager execution, and connect to the remote host.
  tf.compat.v1.enable_eager_execution()
  tf.contrib.eager.connect_to_remote_host("exampleaddr.com:9876")

  with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
    # The following tensors should be resident on the remote device, and the op
    # will also execute remotely.
    x1 = array_ops.ones([2, 2])
    x2 = array_ops.ones([2, 2])
    y = math_ops.matmul(x1, x2)
  ```

  Args:
    remote_host: The address of the remote server in host-port format.
    job_name: The job name under which the new server will be accessible.

  Raises:
    ValueError: if remote_host is None.
  """
  if remote_host is None:
    raise ValueError("Must provide an remote_host")

  grpc_prefix = "grpc://"
  if remote_host.startswith(grpc_prefix):
    remote_host = remote_host[len(grpc_prefix):]

  cluster_def = ClusterDef()
  job_def = cluster_def.job.add()
  job_def.name = job_name
  job_def.tasks[0] = "127.0.0.1:0"
  job_def.tasks[1] = remote_host

  server_def = ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=0,
      protocol="grpc")

  # TODO(nareshmodi): Make this default since it works in more situations.
  os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
  context.set_server_def(server_def)
Example #13
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None):
    """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this more
  than once will work, but will invalidate any tensor handles on the old remote
  devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified, will
      use the default from `python/platform/remote_utils.py`.
  """
    protocol = protocol or remote_utils.get_default_communication_protocol()
    if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
        cluster_spec = cluster_spec_or_resolver
    elif isinstance(cluster_spec_or_resolver,
                    cluster_resolver.ClusterResolver):
        cluster_spec = cluster_spec_or_resolver.cluster_spec()
    else:
        raise ValueError(
            "`cluster_spec_or_resolver` must be a `ClusterSpec` or a "
            "`ClusterResolver`.")

    cluster_def = cluster_spec.as_cluster_def()

    # Automatically add local job, if not part of the cluster spec.
    if job_name not in cluster_spec.jobs:
        local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
        job_def = cluster_def.job.add()
        job_def.name = job_name
        # TODO(fishx): Update this to make sure remote worker has valid ip address
        # to connect with local.
        job_def.tasks[0] = "localhost:{}".format(local_port)

    server_def = ServerDef(cluster=cluster_def,
                           job_name=job_name,
                           task_index=task_index,
                           protocol=protocol)

    # TODO(nareshmodi): Make this default since it works in more situations.
    os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
    context.set_server_def(server_def)
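As a usage sketch for this function via its public wrapper `tf.config.experimental_connect_to_cluster` (the name referenced in a later example), with placeholder worker addresses:
```python
import tensorflow as tf

# Hypothetical worker endpoints.
cluster_spec = tf.train.ClusterSpec(
    {"worker": ["10.0.0.1:8470", "10.0.0.2:8470"]})
tf.config.experimental_connect_to_cluster(cluster_spec)

# After connecting, ops can be placed on the remote workers explicitly.
with tf.device("/job:worker/replica:0/task:0/device:CPU:0"):
  y = tf.matmul(tf.ones([2, 2]), tf.ones([2, 2]))
```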
Example #14
  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(RemoteExecutionTest, self).__init__(methodName)
    self._cached_server1 = server_lib.Server.create_local_server()
    self._cached_server2 = server_lib.Server.create_local_server()

    os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"

    self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
    self._cached_server2_target = self._cached_server2.target[len("grpc://"):]

    # Start the local server.
    context.set_server_def(
        server_def=get_server_def(
            JOB_NAME,
            local_server_port=0,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))
Example #15
    def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
        super(RemoteExecutionTest, self).__init__(methodName)
        self._cached_server1 = server_lib.Server.create_local_server()
        self._cached_server2 = server_lib.Server.create_local_server()

        os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"

        self._cached_server1_target = self._cached_server1.target[len("grpc://"
                                                                      ):]
        self._cached_server2_target = self._cached_server2.target[len("grpc://"
                                                                      ):]

        # Start the local server.
        context.set_server_def(server_def=get_server_def(
            JOB_NAME,
            local_server_port=0,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))
Example #16
    def setUp(self):
        super(EagerClusterReplicateTest, self).setUp()

        if context.context().use_tfrt:
            self.skipTest(
                "b/171412104: This test requires distributed support.")

        # TODO(b/171412104): Move create server to __init__ once tfrt support it.
        self._cached_server1 = server_lib.Server.create_local_server()
        self._cached_server2 = server_lib.Server.create_local_server()
        self._cached_server1_target = self._cached_server1.target[len("grpc://"
                                                                      ):]
        self._cached_server2_target = self._cached_server2.target[len("grpc://"
                                                                      ):]

        # Start the local server.
        local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
        context.set_server_def(server_def=_get_server_def(
            self._job_name,
            local_server_port=local_port,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))
Example #17
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True):
    """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this more
  than once will work, but will invalidate any tensor handles on the old remote
  devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified, will
      use the default from `python/platform/remote_utils.py`.
    make_master_device_default: If True and a cluster resolver is passed, will
      automatically enter the master task device scope, which indicates the
      master becomes the default device to run ops. It won't do anything if
      a cluster spec is passed. Will throw an error if the caller is currently
      already in some device scope.
  """
    protocol = protocol or remote_utils.get_default_communication_protocol()
    if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
        cluster_spec = cluster_spec_or_resolver
    elif isinstance(cluster_spec_or_resolver,
                    cluster_resolver.ClusterResolver):
        cluster_spec = cluster_spec_or_resolver.cluster_spec()
    else:
        raise ValueError(
            "`cluster_spec_or_resolver` must be a `ClusterSpec` or a "
            "`ClusterResolver`.")

    cluster_def = cluster_spec.as_cluster_def()

    # Automatically add local job, if not part of the cluster spec.
    if job_name not in cluster_spec.jobs:
        local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
        job_def = cluster_def.job.add()
        job_def.name = job_name
        # TODO(fishx): Update this to make sure remote worker has valid ip address
        # to connect with local.
        job_def.tasks[0] = "localhost:{}".format(local_port)

    server_def = ServerDef(cluster=cluster_def,
                           job_name=job_name,
                           task_index=task_index,
                           protocol=protocol)

    # TODO(nareshmodi): Make this default since it works in more situations.
    os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
    context.set_server_def(server_def)

    if make_master_device_default and isinstance(
            cluster_spec_or_resolver, cluster_resolver.ClusterResolver
    ) and cluster_spec_or_resolver.master():
        master = cluster_spec_or_resolver.master()
        master_job_name = None
        master_task_id = None
        for job_name in cluster_spec.jobs:
            for task_id in cluster_spec.task_indices(job_name):
                task_address = cluster_spec.task_address(job_name, task_id)
                if master in task_address or task_address in master:
                    master_job_name = job_name
                    master_task_id = task_id
                    break

        if not master_job_name:
            raise ValueError(
                "`make_master_device_default` is set to True but cannot find "
                "master %s in the cluster" % master)

        master_device = "/job:{}/replica:0/task:{}".format(
            master_job_name, master_task_id)
        if not _device_stack_is_empty():
            raise ValueError(
                "`connect_to_cluster` should not be called inside "
                "an existing device scope")
        logging.info("Entering into master device scope: %s", master_device)
        # TODO(b/138389076): Think of the entering device scope behavior in the
        # failure recovery case when dealing with preemptions.
        ops.device(master_device).__enter__()
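When a `ClusterResolver` (rather than a bare `ClusterSpec`) is passed and `make_master_device_default=True`, the master task's device scope is entered, so subsequent ops default to the master. A hedged sketch, assuming `TF_CONFIG` is set in the environment so that `TFConfigClusterResolver` can report the cluster and its master:
```python
import tensorflow as tf

# Assumes TF_CONFIG has been set by the cluster manager, so the resolver
# knows the cluster layout and which task is the master/chief.
resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)

# No explicit tf.device() needed: with make_master_device_default=True the
# master task's device scope has already been entered.
y = tf.matmul(tf.ones([2, 2]), tf.ones([2, 2]))
```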
Example #18
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True):
    """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this more
  than once will work, but will invalidate any tensor handles on the old remote
  devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified, will
      use the default from `python/platform/remote_utils.py`.
    make_master_device_default: If True and a cluster resolver is passed, will
      automatically enter the master task device scope, which indicates the
      master becomes the default device to run ops. It won't do anything if
      a cluster spec is passed. Will throw an error if the caller is currently
      already in some device scope.
  """
    if not context.executing_eagerly():
        raise ValueError(
            "`tf.config.experimental_connect_to_cluster` can only be called in "
            "eager mode.")
    protocol = protocol or remote_utils.get_default_communication_protocol()
    if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
        cluster_spec = cluster_spec_or_resolver
    elif isinstance(cluster_spec_or_resolver,
                    cluster_resolver.ClusterResolver):
        if cluster_spec_or_resolver.master() in _LOCAL_MASTERS:
            # Do nothing if the master is local.
            return
        cluster_spec = cluster_spec_or_resolver.cluster_spec()
    else:
        raise ValueError(
            "`cluster_spec_or_resolver` must be a `ClusterSpec` or a "
            "`ClusterResolver`.")

    cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())

    # Automatically add local job, if not part of the cluster spec.
    if job_name not in cluster_spec.jobs:
        local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
        job_def = cluster_def.job.add()
        job_def.name = job_name
        # TODO(fishx): Update this to make sure remote worker has valid ip address
        # to connect with local.
        job_def.tasks[0] = "localhost:{}".format(local_port)

    server_def = ServerDef(cluster=cluster_def,
                           job_name=job_name,
                           task_index=task_index,
                           protocol=protocol,
                           default_session_config=context.context().config)

    if context.get_server_def() is None:
        context.set_server_def(server_def)
    else:
        context.update_server_def(server_def)

    if make_master_device_default and isinstance(
            cluster_spec_or_resolver, cluster_resolver.ClusterResolver
    ) and cluster_spec_or_resolver.master():
        master = cluster_spec_or_resolver.master()
        master_job_name = None
        master_task_id = None
        for job_name in cluster_spec.jobs:
            for task_id in cluster_spec.task_indices(job_name):
                task_address = cluster_spec.task_address(job_name, task_id)
                if master in task_address or task_address in master:
                    master_job_name = job_name
                    master_task_id = task_id
                    break

        if not master_job_name:
            raise ValueError(
                "`make_master_device_default` is set to True but cannot find "
                "master %s in the cluster" % master)

        master_device = "/job:{}/replica:0/task:{}".format(
            master_job_name, master_task_id)
        master_device = device_util.canonicalize(master_device)
        current_device = device_util.current()
        if current_device:
            current_device = device_util.canonicalize(current_device)
        if current_device and current_device != master_device:
            raise ValueError(
                "`connect_to_cluster` is called inside existing device "
                "scope %s, which is different from the master device "
                "scope %s to enter. This is not allowed." %
                (current_device, master_device))
        # TODO(b/138389076): Think of the entering device scope behavior in the
        # failure recovery case when dealing with preemptions.
        if not current_device:
            logging.info("Entering into master device scope: %s",
                         master_device)
            ops.device(master_device).__enter__()
Example #19
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True,
                       cluster_device_filters=None):
    """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this more
  than once will work, but will invalidate any tensor handles on the old remote
  devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.

  Device filters can be specified to isolate groups of remote tasks to avoid
  undesired accesses between workers. Workers accessing resources or launching
  ops / functions on filtered remote devices will result in errors (unknown
  devices). For any remote task, if no device filter is present, all cluster
  devices will be visible; if any device filter is specified, it can only
  see devices matching at least one filter. Devices on the task itself are
  always visible. Device filters can be partially specified.

  For example, for a cluster set up for parameter server training, the following
  device filters might be specified:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  # For any worker, only the devices on PS nodes and itself are visible
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  # Similarly for any ps, only the devices on workers and itself are visible
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified, will
      use the default from `python/platform/remote_utils.py`.
    make_master_device_default: If True and a cluster resolver is passed, will
      automatically enter the master task device scope, which indicates the
      master becomes the default device to run ops. It won't do anything if
      a cluster spec is passed. Will throw an error if the caller is currently
      already in some device scope.
    cluster_device_filters: an instance of
      `tf.train.experimental.ClusterDeviceFilters` that specifies device
      filters for the remote tasks in the cluster.
  """
    if not context.executing_eagerly():
        raise ValueError(
            "`tf.config.experimental_connect_to_cluster` can only be called in "
            "eager mode.")
    protocol = protocol or remote_utils.get_default_communication_protocol()
    if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
        cluster_spec = cluster_spec_or_resolver
    elif isinstance(cluster_spec_or_resolver,
                    cluster_resolver.ClusterResolver):
        if cluster_spec_or_resolver.master() in _LOCAL_MASTERS:
            # Do nothing if the master is local.
            return
        cluster_spec = cluster_spec_or_resolver.cluster_spec()
    else:
        raise ValueError(
            "`cluster_spec_or_resolver` must be a `ClusterSpec` or a "
            "`ClusterResolver`.")

    cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
    if cluster_device_filters:
        if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
            cluster_device_filters = copy.deepcopy(
                cluster_device_filters._as_cluster_device_filters())  # pylint: disable=protected-access
        else:
            raise ValueError("`cluster_device_filters` must be an instance of "
                             "`tf.train.experimental.ClusterDeviceFilters`.")

    # Automatically add local job, if not part of the cluster spec.
    if job_name not in cluster_spec.jobs:
        local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
        job_def = cluster_def.job.add()
        job_def.name = job_name

        ipstr = _get_local_ip_address(local_port)
        if ipstr:
            job_def.tasks[0] = "{}:{}".format(ipstr, local_port)
        else:
            job_def.tasks[0] = "localhost:{}".format(local_port)

    server_def = ServerDef(cluster=cluster_def,
                           job_name=job_name,
                           task_index=task_index,
                           protocol=protocol,
                           default_session_config=context.context().config,
                           cluster_device_filters=cluster_device_filters)

    if context.get_server_def() is None:
        context.set_server_def(server_def)
    else:
        context.update_server_def(server_def)

    if make_master_device_default and isinstance(
            cluster_spec_or_resolver, cluster_resolver.ClusterResolver
    ) and cluster_spec_or_resolver.master():
        master = cluster_spec_or_resolver.master()
        master_job_name = None
        master_task_id = None
        for job_name in cluster_spec.jobs:
            for task_id in cluster_spec.task_indices(job_name):
                task_address = cluster_spec.task_address(job_name, task_id)
                if master in task_address or task_address in master:
                    master_job_name = job_name
                    master_task_id = task_id
                    break

        if not master_job_name:
            raise ValueError(
                "`make_master_device_default` is set to True but cannot find "
                "master %s in the cluster" % master)

        master_device = "/job:{}/replica:0/task:{}".format(
            master_job_name, master_task_id)
        master_device = device_util.canonicalize(master_device)
        current_device = device_util.current()
        if current_device:
            current_device = device_util.canonicalize(current_device)
        if current_device and current_device != master_device:
            raise ValueError(
                "`connect_to_cluster` is called inside existing device "
                "scope %s, which is different from the master device "
                "scope %s to enter. This is not allowed." %
                (current_device, master_device))
        # TODO(b/138389076): Think of the entering device scope behavior in the
        # failure recovery case when dealing with preemptions.
        if not current_device:
            logging.info("Entering into master device scope: %s",
                         master_device)
            ops.device(master_device).__enter__()
Example #20
import tensorflow as tf
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python.eager import context
from tensorflow.python.training.server_lib import ClusterSpec

cluster_def = ClusterSpec({'worker': ['127.0.0.1:15293']}).as_cluster_def()
# 15293 is just some random available port

server_def = ServerDef(cluster=cluster_def,
                       job_name='worker',
                       task_index=0,
                       protocol='grpc')

v = tf.Variable(3)

print(v.device)
# > /job:localhost/replica:0/task:0/device:CPU:0

context.set_server_def(server_def)

####################################
print(v.device)
# > Segmentation fault (core dumped)
####################################
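The crash is consistent with the note in the docstrings above: changing the server def invalidates tensor handles created under the old context, and `v` here still points at a `/job:localhost/...` device. A workaround sketch, under the same assumptions as the snippet above, is to set the server def first and create variables afterwards, so the handle is born on the new device:
```python
import tensorflow as tf
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python.eager import context
from tensorflow.python.training.server_lib import ClusterSpec

cluster_def = ClusterSpec({'worker': ['127.0.0.1:15293']}).as_cluster_def()
server_def = ServerDef(cluster=cluster_def,
                       job_name='worker',
                       task_index=0,
                       protocol='grpc')

# Connect before creating any variables, so no handles are left pointing
# at the old, purely local context.
context.set_server_def(server_def)

v = tf.Variable(3)
print(v.device)
# Expected to be a /job:worker/... device rather than /job:localhost/...
```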