def testZeroItemsInClusterSpecMasterRead(self):
    """An empty TF_CONFIG resolves to an empty master address."""
    os.environ['TF_CONFIG'] = """
    {}
    """

    resolver = TFConfigClusterResolver()
    self.assertEqual('', resolver.master())
예제 #2
0
def _all_devices():
  """Returns cluster devices from TF_CONFIG, falling back to local devices."""
  resolver = TFConfigClusterResolver()
  cluster_spec = resolver.cluster_spec()
  if cluster_spec.as_dict():
    cluster_devices = _cluster_spec_to_device_list(cluster_spec,
                                                   context.num_gpus())
    if cluster_devices:
      return cluster_devices
  return all_local_devices()
def get_num_workers():
  """Returns the worker count from TF_CONFIG, or 1 when running locally."""
  resolver = TFConfigClusterResolver()
  spec_dict = resolver.cluster_spec().as_dict()
  if not spec_dict:
    return 1
  return int(multi_worker_util.worker_count(spec_dict, resolver.task_type))
  def testOneItemInClusterSpecMasterRead(self):
    """A cluster with a single worker and no task resolves an empty master."""
    os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "worker": ["worker0:2222"]
      }
    }
    """

    resolver = TFConfigClusterResolver()
    self.assertEqual('', resolver.master())
예제 #5
0
def maybe_shard_dataset(dataset):
  """Shard the dataset if running in multi-node environment."""
  resolver = TFConfigClusterResolver()
  spec_dict = resolver.cluster_spec().as_dict()
  if not spec_dict:
    # No cluster described in TF_CONFIG: leave the dataset unsharded.
    return dataset
  num_shards = multi_worker_util.worker_count(spec_dict, resolver.task_type)
  shard_index = multi_worker_util.id_in_cluster(spec_dict, resolver.task_type,
                                                resolver.task_id)
  return dataset.shard(num_shards, shard_index)
 def __init__(self, container_strategy, num_gpus_per_worker):
   """Initializes the extended strategy from the TF_CONFIG environment.

   Args:
     container_strategy: The strategy object that owns this extended object.
     num_gpus_per_worker: Accelerator count to report per worker.
   """
   # Use TFConfigClusterResolver to parse TF_CONFIG. We don't want to change
   # the constructor's interface to allow customized cluster resolver. Use
   # SimpleClusterResolver to override num_accelerators.
   tfconfig = TFConfigClusterResolver()
   # NOTE(review): a sibling variant of this constructor passes a mapping
   # ({'GPU': num_gpus_per_worker}) for num_accelerators; confirm which form
   # SimpleClusterResolver expects here.
   cluster_resolver = SimpleClusterResolver(
       cluster_spec=tfconfig.cluster_spec(),
       task_type=tfconfig.task_type,
       task_id=tfconfig.task_id,
       num_accelerators=num_gpus_per_worker)
   super(ParameterServerExtended, self).__init__(
       container_strategy, cluster_resolver=cluster_resolver)
예제 #7
0
 def __init__(self, container_strategy, num_gpus_per_worker):
     """Initializes the extended strategy from the TF_CONFIG environment.

     TF_CONFIG is parsed with `TFConfigClusterResolver`; the result is
     wrapped in a `SimpleClusterResolver` so the reported GPU count can be
     overridden without changing this constructor's interface.
     """
     tf_config = TFConfigClusterResolver()
     resolver = SimpleClusterResolver(
         cluster_spec=tf_config.cluster_spec(),
         task_type=tf_config.task_type,
         task_id=tf_config.task_id,
         num_accelerators={'GPU': num_gpus_per_worker})
     super(ParameterServerExtended,
           self).__init__(container_strategy, cluster_resolver=resolver)
def batch_and_maybe_shard_dataset(dataset, global_batch_size):
  """Shard the dataset if running in multi-node environment."""

  resolver = TFConfigClusterResolver()
  spec_dict = resolver.cluster_spec().as_dict()
  if spec_dict:
    # TF_CONFIG describes a cluster: give each worker its own shard.
    shard_count = int(
        multi_worker_util.worker_count(spec_dict, resolver.task_type))
    shard_index = int(
        multi_worker_util.id_in_cluster(spec_dict, resolver.task_type,
                                        resolver.task_id))
    dataset = dataset.shard(shard_count, shard_index)
  return dataset.batch(global_batch_size)
예제 #9
0
 def __init__(self, container_strategy, num_gpus_per_worker, communication):
     """Initializes the extended strategy from the TF_CONFIG environment.

     Args:
         container_strategy: The strategy object that owns this extended
             object.
         num_gpus_per_worker: Number of GPUs to report per worker.
         communication: Collective communication implementation forwarded to
             the base class.
     """
     # Use TFConfigClusterResolver to parse TF_CONFIG. We don't want to change
     # the constructor's interface to allow customized cluster resolver. Use
     # SimpleClusterResolver to override num_accelerators.
     tfconfig = TFConfigClusterResolver()
     cluster_resolver = SimpleClusterResolver(
         cluster_spec=tfconfig.cluster_spec(),
         task_type=tfconfig.task_type,
         task_id=tfconfig.task_id,
         num_accelerators={"GPU": num_gpus_per_worker})
     super(CollectiveAllReduceExtended,
           self).__init__(container_strategy,
                          communication=communication,
                          cluster_resolver=cluster_resolver)
예제 #10
0
def batch_and_maybe_shard_dataset(dataset, global_batch_size):
    """Shard the dataset if running in multi-node environment."""

    resolver = TFConfigClusterResolver()
    spec_dict = resolver.cluster_spec().as_dict()
    if not spec_dict:
        # Single-node run: no sharding needed.
        return dataset.batch(global_batch_size)
    workers = int(
        multi_worker_util.worker_count(spec_dict, resolver.task_type))
    index = int(
        multi_worker_util.id_in_cluster(spec_dict, resolver.task_type,
                                        resolver.task_id))
    return dataset.shard(workers, index).batch(global_batch_size)
    def testParameterOverrides(self):
        """Constructor args override TF_CONFIG; later attribute writes win."""
        os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "rpc_layer": "grpc",
      "task": {
        "type": "ps",
        "index": 1
      }
    }
    """

        # Constructor keyword arguments take precedence over the "task"
        # entry in TF_CONFIG (which names ps/index 1).
        cluster_resolver = TFConfigClusterResolver(task_type='ps',
                                                   task_index=0,
                                                   num_accelerators=8)

        self.assertEqual('grpc://ps0:2222', cluster_resolver.master())
        self.assertEqual('ps', cluster_resolver.task_type)
        self.assertEqual(0, cluster_resolver.task_index)
        self.assertEqual(8, cluster_resolver.num_accelerators())

        # Attribute assignments after construction override both TF_CONFIG
        # and the constructor arguments.
        cluster_resolver.task_type = 'worker'
        cluster_resolver.task_index = 1
        cluster_resolver.rpc_layer = 'test'

        self.assertEqual('test://worker1:2222', cluster_resolver.master())
        self.assertEqual('worker', cluster_resolver.task_type)
        self.assertEqual(1, cluster_resolver.task_index)
        self.assertEqual('test', cluster_resolver.rpc_layer)
  def __init__(self, cluster_resolver=None):
    """Initializes this strategy.

    Args:
      cluster_resolver: Optional
        `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a
        `tf.distribute.cluster_resolver.TFConfigClusterResolver`.

    Raises:
      ValueError: If the cluster spec resolved from `cluster_resolver` is
        empty.
    """
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    if not cluster_resolver.cluster_spec():
      raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
    extended = ParameterServerStrategyExtended(
        self, cluster_resolver=cluster_resolver)
    super(ParameterServerStrategy, self).__init__(extended)
예제 #13
0
  def __init__(self, cluster_resolver=None):
    """Initializes this strategy.

    Args:
      cluster_resolver: Optional
        `tf.distribute.cluster_resolver.ClusterResolver`. When `None`, a
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.

    Raises:
      ValueError: If the resolved cluster spec is empty.
    """
    resolver = (TFConfigClusterResolver()
                if cluster_resolver is None else cluster_resolver)
    if not resolver.cluster_spec():
      raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
    super(ParameterServerStrategy, self).__init__(
        ParameterServerStrategyExtended(self, cluster_resolver=resolver))
예제 #14
0
  def __init__(self, container_strategy, devices=None, cross_device_ops=None):
    """Initializes the mirrored extended strategy.

    Args:
      container_strategy: The `MirroredStrategy` that owns this extended
        object.
      devices: Optional list of device strings to mirror across. Defaults to
        all local devices in eager mode, or all known devices otherwise.
      cross_device_ops: Optional cross-device-ops implementation used for
        cross-replica reductions.

    Raises:
      RuntimeError: If `devices` spans multiple workers while executing
        eagerly, which is unsupported.
    """
    super(MirroredExtended, self).__init__(container_strategy)
    if context.executing_eagerly():
      if devices and not _is_device_list_single_worker(devices):
        raise RuntimeError("In-graph multi-worker training with "
                           "`MirroredStrategy` is not supported in eager mode.")
      else:
        if TFConfigClusterResolver().cluster_spec().as_dict():
          # if you are executing in eager mode, only the single machine code
          # path is supported.
          # Fixed duplicated word in the original message ("when when").
          logging.info("Initializing local devices since in-graph multi-worker "
                       "training with `MirroredStrategy` is not supported in "
                       "eager mode. TF_CONFIG will be ignored when "
                       "initializing `MirroredStrategy`.")
        devices = devices or all_local_devices()
    else:
      devices = devices or all_devices()

    assert devices, ("Got an empty `devices` list and unable to recognize "
                     "any local devices.")
    self._cross_device_ops = cross_device_ops
    self._communication_options = collective_util.Options()
    self._initialize_strategy(devices)

    # TODO(b/128995245): Enable last partial batch support in graph mode.
    if ops.executing_eagerly_outside_functions():
      self.experimental_enable_get_next_as_optional = True

    # Flag to turn on VariablePolicy.
    self._use_var_policy = False
  def testSpecifiedTaskTypeAndIndexMasterRead(self):
    """master(task_type, task_id) overrides the task entry in TF_CONFIG."""
    os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "task": {
        "type": "ps",
        "index": 0
      }
    }
    """

    cluster_resolver = TFConfigClusterResolver()
    # TF_CONFIG names ps/0, but explicit arguments select worker/1.
    self.assertEqual('worker1:2222', cluster_resolver.master('worker', 1))
예제 #16
0
    def testAutomaticMasterRead(self):
        """master() with no arguments resolves the task named in TF_CONFIG."""
        os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "task": {
        "type": "ps",
        "index": 0
      }
    }
    """

        resolver = TFConfigClusterResolver()
        self.assertEqual('ps0:2222', resolver.master())
예제 #17
0
    def testSpecifiedTaskTypeAndIndexMasterRead(self):
        """master(task_type, task_id) overrides the TF_CONFIG task entry."""
        os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "task": {
        "type": "ps",
        "index": 0
      }
    }
    """

        resolver = TFConfigClusterResolver()
        self.assertEqual('worker1:2222', resolver.master('worker', 1))
  def testAutomaticMasterRead(self):
    """master() with no arguments resolves the task named in TF_CONFIG."""
    os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "task": {
        "type": "ps",
        "index": 0
      }
    }
    """

    cluster_resolver = TFConfigClusterResolver()
    self.assertEqual('ps0:2222', cluster_resolver.master())
        def fn(functions_scheduled_event, test_finished_event):
            """Subprocess body: schedules PS work, then verifies PS-failure errors.

            Runs in a separate process per cluster task. After scheduling work
            and signaling `functions_scheduled_event`, it expects an
            `UnavailableError` once the PS is killed, and confirms subsequent
            worker execution also surfaces a PS failure before setting
            `test_finished_event`.
            """
            # TODO(b/170664373): This is needed for TF2 parameter server training in
            # OSS. Remove this when resolved.
            os.environ["GRPC_FAIL_FAST"] = "use_caller"

            cluster_resolver = TFConfigClusterResolver()
            if cluster_resolver.task_type != "chief":
                utils.start_server(cluster_resolver, "grpc")
            strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
                cluster_resolver)
            ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)

            with strategy.scope():
                v = variables.Variable(initial_value=0, dtype=dtypes.int32)

            @def_function.function
            def worker_fn():
                # An ever-running function.
                for _ in math_ops.range(100000):
                    v.assign_add(1)

            # Keep the two workers occupied.
            ps_coordinator.schedule(worker_fn)
            ps_coordinator.schedule(worker_fn)
            # Now the main process can terminate.
            functions_scheduled_event.set()

            # Verified that join and schedule indeed raise UnavailableError.
            try:
                if test_join:
                    ps_coordinator.join()
                if test_schedule:
                    while ps_coordinator.cluster._closure_queue._error is None:
                        time.sleep(1)
                    ps_coordinator.schedule(worker_fn)
            except errors.UnavailableError:
                # The following verifies that after PS fails, continue executing
                # functions on workers should fail and indicate it's PS failure.
                for worker_id in range(3):
                    with ops.device(
                            "/job:worker/replica:0/task:{}".format(worker_id)):
                        try:
                            # Executing a function after PS fails should result in a PS
                            # failure.
                            worker_fn()
                        except Exception as e:  # pylint: disable=broad-except
                            if coordinator_lib._is_ps_failure(e):
                                if worker_id < 2:
                                    continue
                                logging.info(
                                    "_test_translate_ps_failure_error ends properly."
                                )
                                # Now we can safely exit the test.
                                test_finished_event.set()
                                return
                        raise RuntimeError(
                            "Executing a function after PS fails, should "
                            "result in a PS failure.")

            raise RuntimeError("UnavailableError supposed to be raised.")
예제 #20
0
 def _from_local_devices(
     cls, devices,
     communication=cross_device_ops_lib.CollectiveCommunication.AUTO):
   """A convenience method to create an object with a list of devices."""
   strategy = cls(communication)
   strategy.extended._initialize_local(  # pylint: disable=protected-access
       TFConfigClusterResolver(), devices=devices)
   return strategy
  def testParameterOverrides(self):
    """Constructor args override TF_CONFIG; later attribute writes win.

    NOTE(review): this snippet uses `num_accelerators_per_worker`, an older
    resolver API than the dict-returning `num_accelerators()` used elsewhere
    in this file — presumably from an earlier TF version; confirm which API
    the resolver under test exposes.
    """
    os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "rpc_layer": "grpc",
      "task": {
        "type": "ps",
        "index": 1
      }
    }
    """

    # Constructor keyword arguments take precedence over the "task" entry in
    # TF_CONFIG (which names ps/index 1).
    cluster_resolver = TFConfigClusterResolver(task_type='ps', task_index=0,
                                               num_accelerators_per_worker=8)

    self.assertEqual('grpc://ps0:2222', cluster_resolver.master())
    self.assertEqual('ps', cluster_resolver.task_type)
    self.assertEqual(0, cluster_resolver.task_index)
    self.assertEqual(8, cluster_resolver.num_accelerators_per_worker())

    # Attribute assignments after construction override everything above.
    cluster_resolver.task_type = 'worker'
    cluster_resolver.task_index = 1
    cluster_resolver.rpc_layer = 'test'

    self.assertEqual('test://worker1:2222', cluster_resolver.master())
    self.assertEqual('worker', cluster_resolver.task_type)
    self.assertEqual(1, cluster_resolver.task_index)
    self.assertEqual('test', cluster_resolver.rpc_layer)
  def __init__(self,
               container_strategy,
               cluster_resolver=None):
    """Initializes the parameter-server extended strategy.

    Args:
      container_strategy: The strategy object that owns this extended object.
      cluster_resolver: Optional cluster resolver. Defaults to a
        `TFConfigClusterResolver` constructed at call time.
    """
    # Build the default resolver lazily: a default argument of
    # `TFConfigClusterResolver()` would be evaluated once at import time
    # (reading TF_CONFIG too early) and shared by every instance.
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    self._initialize_strategy(cluster_resolver)

    # We typically don't need to do all-reduce in this strategy.
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))
 def __init__(self,
              container_strategy,
              cluster_resolver=None):
     """Initializes the collective all-reduce extended strategy.

     Args:
         container_strategy: The strategy object that owns this extended
             object.
         cluster_resolver: Optional cluster resolver. Defaults to a
             `TFConfigClusterResolver` constructed at call time.
     """
     # Build the default resolver lazily: a default argument of
     # `TFConfigClusterResolver()` would be evaluated once at import time
     # (reading TF_CONFIG too early) and shared by every instance.
     if cluster_resolver is None:
         cluster_resolver = TFConfigClusterResolver()
     distribute_lib.DistributionStrategyExtended.__init__(
         self, container_strategy)
     self._cross_device_ops = None
     self._initialize_strategy(cluster_resolver)
     assert isinstance(self._get_cross_device_ops(),
                       cross_device_ops_lib.CollectiveAllReduce)
 def __init__(self, container_strategy, communication, cluster_resolver):
     """Initializes the collective all-reduce extended strategy.

     Args:
         container_strategy: The strategy object that owns this extended
             object.
         communication: A `CollectiveCommunication` enum member selecting the
             collective implementation.
         cluster_resolver: Optional cluster resolver; falls back to a
             `TFConfigClusterResolver` when falsy.
     """
     cluster_resolver = cluster_resolver or TFConfigClusterResolver()
     distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
     assert isinstance(communication,
                       cross_device_ops_lib.CollectiveCommunication)
     self._communication = communication
     self._initialize_strategy(cluster_resolver)
     assert isinstance(self._get_cross_device_ops(),
                       cross_device_ops_lib.CollectiveAllReduce)
 def __init__(self,
              container_strategy,
              num_gpus_per_worker,
              communication):
   """Initializes the extended strategy from the TF_CONFIG environment.

   TF_CONFIG is parsed with `TFConfigClusterResolver` and re-wrapped in a
   `SimpleClusterResolver` so the GPU count can be overridden without
   changing this constructor's interface.
   """
   tf_config = TFConfigClusterResolver()
   resolver = SimpleClusterResolver(
       cluster_spec=tf_config.cluster_spec(),
       task_type=tf_config.task_type,
       task_id=tf_config.task_id,
       num_accelerators={"GPU": num_gpus_per_worker},
       rpc_layer=tf_config.rpc_layer)
   super(CollectiveAllReduceExtended, self).__init__(
       container_strategy,
       communication=communication,
       cluster_resolver=resolver)
예제 #26
0
        def proc_func(functions_scheduled_event, test_finished_event):
            """Subprocess body: schedules PS work, then verifies PS-failure errors.

            Runs in a separate process per cluster task. After scheduling work
            and signaling `functions_scheduled_event`, it expects a
            `ParameterServerFailureError` once the PS is killed, and confirms
            subsequent worker execution also surfaces a PS failure before
            setting `test_finished_event`.
            """
            cluster_resolver = TFConfigClusterResolver()
            if cluster_resolver.task_type != "chief":
                utils.start_server(cluster_resolver, "grpc")
            ps_client = parameter_server_client.ParameterServerClient(
                cluster_resolver)
            with ps_client._strategy.scope():
                v = variables.Variable(initial_value=0, dtype=dtypes.int32)

            @def_function.function
            def worker_fn():
                # An ever-running function.
                for _ in math_ops.range(100000):
                    v.assign_add(1)

            # Keep the two workers occupied.
            ps_client.schedule(worker_fn)
            ps_client.schedule(worker_fn)
            # Now the main process can terminate.
            functions_scheduled_event.set()

            # Verified that join and schedule indeed raise
            # ParameterServerFailureError.
            try:
                if test_join:
                    ps_client.join()
                if test_schedule:
                    while ps_client.cluster._closure_queue._error is None:
                        time.sleep(1)
                    ps_client.schedule(worker_fn)
            except client.ParameterServerFailureError:
                # The following verifies that after PS fails, continue executing
                # functions on workers should fail and indicate it's PS failure.
                for worker_id in range(3):
                    with ops.device(
                            "/job:worker/replica:0/task:{}".format(worker_id)):
                        try:
                            # Executing a function after PS fails should result in a PS
                            # failure.
                            worker_fn()
                        except Exception as e:  # pylint: disable=broad-except
                            if client._is_ps_failure(e):
                                if worker_id < 2:
                                    continue
                                logging.info(
                                    "_test_translate_ps_failure_error ends properly."
                                )
                                # Now we can safely exit the test.
                                test_finished_event.set()
                                return
                        raise RuntimeError(
                            "Executing a function after PS fails, should "
                            "result in a PS failure.")

            raise RuntimeError(
                "ParameterServerFailureError supposed to be raised.")
 def __init__(self, container_strategy, cluster_resolver,
              communication_options):
     """Initializes the collective all-reduce extended strategy.

     Args:
         container_strategy: The strategy object that owns this extended
             object.
         cluster_resolver: Optional cluster resolver; falls back to a
             `TFConfigClusterResolver` when falsy.
         communication_options: Collective communication options stored for
             later use.
     """
     self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
     distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
     self._communication_options = communication_options
     self._initialize_strategy(self._cluster_resolver)
     self._cfer_fn_cache = weakref.WeakKeyDictionary()
     self.experimental_enable_get_next_as_optional = True
     assert isinstance(self._cross_device_ops,
                       cross_device_ops_lib.CollectiveAllReduce)
  def __init__(self, cluster_resolver=None):
    """Initializes this strategy with an optional `cluster_resolver`.

    Args:
      cluster_resolver: Optional
        `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a
        `tf.distribute.cluster_resolver.TFConfigClusterResolver`.

    Raises:
      ValueError: If the cluster spec resolved from `cluster_resolver` is
        empty.
    """
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    if not cluster_resolver.cluster_spec():
      raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
    extended = ParameterServerStrategyExtended(
        self, cluster_resolver=cluster_resolver)
    super(ParameterServerStrategy, self).__init__(extended)
    # Record usage for telemetry dashboards.
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "ParameterServerStrategy")
    distribute_lib.distribution_strategy_replica_gauge.get_cell("num_ps").set(
        len(self.extended.parameter_devices))
 def __init__(self, container_strategy, communication, cluster_resolver):
     """Initializes the collective all-reduce extended strategy.

     Args:
         container_strategy: The strategy object that owns this extended
             object.
         communication: A `CollectiveCommunication` enum member selecting the
             collective implementation.
         cluster_resolver: Optional cluster resolver; falls back to a
             `TFConfigClusterResolver` when falsy.
     """
     self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
     distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
     assert isinstance(communication,
                       cross_device_ops_lib.CollectiveCommunication)
     self._communication = communication
     self._initialize_strategy(self._cluster_resolver)
     self._cfer_fn_cache = weakref.WeakKeyDictionary()
     assert isinstance(self._cross_device_ops,
                       cross_device_ops_lib.CollectiveAllReduce)
 def __init__(self, cluster_resolver=None):
     """Initializes this strategy.

     Args:
         cluster_resolver: Optional cluster resolver. Defaults to a
             `TFConfigClusterResolver`.
     """
     # The `cluster_resolver` must be set so that
     # `ParameterServerStrategyExtended` will keep num_gpus for `configure`
     # method.
     if cluster_resolver is None:
         cluster_resolver = TFConfigClusterResolver()
     extended = parameter_server_strategy.ParameterServerStrategyExtended(
         self, cluster_resolver=cluster_resolver)
     super(ParameterServerStrategy, self).__init__(extended)
    def testNumAcceleratorsFilterTasksByEnvVar(self, mock_list_devices,
                                               mock_eager_list_devices):
        """num_accelerators() counts only devices of the TF_CONFIG task.

        Mocked device lists span two jobs; the default count must include only
        the worker1/task0 devices, while explicit task_type/task_id arguments
        select a different task's devices.
        """
        os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "worker1": ["w10:2222"],
        "worker2": ["w21:2222", "w22:2222", "w23:2222", "w24:2222"]
      },
      "rpc_layer": "grpc",
      "task": {
        "type": "worker1",
        "index": "0"
      }
    }
    """

        # Devices across two jobs; only worker1/task0 matches the TF_CONFIG
        # task above (2 TPUs + 2 GPUs).
        devices = [
            LogicalDevice('/job:worker1/task:0/device:TPU:0', 'TPU'),
            LogicalDevice('/job:worker1/task:0/device:TPU:1', 'TPU'),
            LogicalDevice('/job:worker1/task:0/device:GPU:0', 'GPU'),
            LogicalDevice('/job:worker1/task:0/device:GPU:1', 'GPU'),
            LogicalDevice('/job:worker2/task:1/device:TPU:2', 'TPU'),
            LogicalDevice('/job:worker2/task:2/device:TPU:3', 'TPU'),
            LogicalDevice('/job:worker2/task:3/device:GPU:2', 'GPU'),
            LogicalDevice('/job:worker2/task:4/device:GPU:3', 'GPU'),
        ]
        device_list = [
            session._DeviceAttributes(d.name, d.device_type, 1024, 0)
            for d in devices
        ]
        mock_eager_list_devices.return_value = devices
        mock_list_devices.return_value = device_list

        resolver = TFConfigClusterResolver()

        # By default we read from TF_CONFIG
        self.assertEqual(resolver.num_accelerators(), {'TPU': 2, 'GPU': 2})

        # Override still works when we want it to
        self.assertEqual(
            resolver.num_accelerators(task_type='worker2', task_id=3),
            {'GPU': 1})
        def task_function(start_events, finish_events):
            """Subprocess body: starts a TF server for this task, then waits.

            Reads the task's identity from TF_CONFIG, starts an in-process
            server (optionally configured for collective ops), signals its
            start event, and blocks on its finish event before exiting the
            process.
            """
            cluster_resolver = TFConfigClusterResolver()
            cluster_spec = cluster_resolver.cluster_spec()
            task_type = cluster_resolver.task_type
            task_id = cluster_resolver.task_id
            rpc_layer = cluster_resolver.rpc_layer

            # TODO(yuefengz): support GPU clusters.
            server_config = config_pb2.ConfigProto()
            server_config.device_count['GPU'] = 0

            if collective_leader:
                server_config.experimental.collective_group_leader = collective_leader
                server_config.experimental.collective_nccl = False

                logging.info(
                    'Enabling collective ops with cluster_spec = %r, task_type = %r, '
                    'task_id = %r, rpc_layer = %r, collective_leader = %s',
                    cluster_spec, task_type, task_id, rpc_layer,
                    collective_leader)
            else:
                logging.info(
                    'Starting server with cluster_spec = %r, task_type = %r, '
                    'task_id = %r, rpc_layer = %r', cluster_spec, task_type,
                    task_id, rpc_layer)

            server_lib.Server(cluster_spec,
                              job_name=task_type,
                              protocol=rpc_layer,
                              task_index=task_id,
                              config=server_config,
                              start=True)

            start_event = start_events[task_type][task_id]
            start_event.set()

            finish_event = finish_events[task_type][task_id]
            finish_event.wait()

            # Hard exit: skip interpreter teardown so the server process dies
            # immediately once the test signals completion.
            os._exit(0)  # pylint: disable=protected-access
 def testTaskIndexOverride(self):
     """A task_id passed to the constructor overrides TF_CONFIG's index."""
     os.environ['TF_CONFIG'] = """
 {
   "cluster": {
     "worker": ["worker0:2222", "worker1:2222"]
   },
   "task": {
     "type": "worker",
     "index": "0"
   }
 }
 """
     resolver = TFConfigClusterResolver(task_id=1)
     self.assertEqual(1, resolver.task_id)
  def testNormalClusterSpecRead(self):
    """cluster_spec() reflects the cluster section of TF_CONFIG."""
    os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "task": {
        "type": "ps",
        "index": 0
      }
    }
    """

    cluster_resolver = TFConfigClusterResolver()
    # Expected ClusterDef in text proto form, mirroring the JSON above.
    expected_proto = """
    job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
                     tasks { key: 1 value: 'ps1:2222' } }
    job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
                         tasks { key: 1 value: 'worker1:2222' }
                         tasks { key: 2 value: 'worker2:2222' } }
    """
    actual_cluster_spec = cluster_resolver.cluster_spec()
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
예제 #35
0
    def testNormalClusterSpecRead(self):
        """cluster_spec() reflects the cluster section of TF_CONFIG."""
        os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "task": {
        "type": "ps",
        "index": 0
      }
    }
    """

        cluster_resolver = TFConfigClusterResolver()
        # Expected ClusterDef in text proto form, mirroring the JSON above.
        expected_proto = """
    job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
                     tasks { key: 1 value: 'ps1:2222' } }
    job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
                         tasks { key: 1 value: 'worker1:2222' }
                         tasks { key: 2 value: 'worker2:2222' } }
    """
        actual_cluster_spec = cluster_resolver.cluster_spec()
        self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
예제 #36
0
  def __init__(self, cluster_resolver=None):
    """Initializes this strategy with an optional `cluster_resolver`.

    Args:
      cluster_resolver: Optional
        `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a
        `tf.distribute.cluster_resolver.TFConfigClusterResolver`.
    """
    resolver = (TFConfigClusterResolver()
                if cluster_resolver is None else cluster_resolver)
    extended = ParameterServerStrategyExtended(self, cluster_resolver=resolver)
    super(ParameterServerStrategyV1, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "ParameterServerStrategy")
예제 #37
0
    def task_function(start_events, finish_events):
      """Subprocess body: starts a TF server for this task, then waits.

      Reads the task's identity from TF_CONFIG, starts an in-process server,
      signals its start event, and blocks on its finish event before exiting
      the process.
      """
      cluster_resolver = TFConfigClusterResolver()
      cluster_spec = cluster_resolver.cluster_spec()
      task_type = cluster_resolver.task_type
      task_id = cluster_resolver.task_id
      rpc_layer = cluster_resolver.rpc_layer

      logging.info(
          'Starting server with cluster_spec = %r, task_type = %r, '
          'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id,
          rpc_layer)

      # TODO(yuefengz): support GPU clusters.
      server_config = config_pb2.ConfigProto()
      server_config.device_count['GPU'] = 0

      # Set the environment variable to prevent hanging upon job failure and
      # restart. Note that it defaults to 'use_caller' at Google, but defaults
      # to False in OSS.
      os.environ['GRPC_FAIL_FAST'] = 'use_caller'

      server_lib.Server(
          cluster_spec,
          job_name=task_type,
          protocol=rpc_layer,
          task_index=task_id,
          config=server_config,
          start=True)

      start_event = start_events[task_type][task_id]
      start_event.set()

      finish_event = finish_events[task_type][task_id]
      finish_event.wait()

      # Hard exit: skip interpreter teardown so the server process dies
      # immediately once the test signals completion.
      os._exit(0)  # pylint: disable=protected-access
 def testTaskTypeCastToString(self):
   """A numeric task type in TF_CONFIG is exposed as a string."""
   os.environ['TF_CONFIG'] = """
   {
     "cluster": {
       "123456": ["ps0:2222", "ps1:2222"],
       "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
     },
     "rpc_layer": "grpc",
     "task": {
       "type": 123456,
       "index": 0
     }
   }
   """
   resolver = TFConfigClusterResolver()
   self.assertEqual('123456', resolver.task_type)
 def testTaskIndexCastToInteger(self):
   """A string task index in TF_CONFIG is exposed as an integer."""
   os.environ['TF_CONFIG'] = """
   {
     "cluster": {
       "ps": ["ps0:2222", "ps1:2222"],
       "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
     },
     "rpc_layer": "grpc",
     "task": {
       "type": "ps",
       "index": "1"
     }
   }
   """
   resolver = TFConfigClusterResolver()
   self.assertEqual(1, resolver.task_id)
예제 #40
0
 def __init__(self, container_strategy, cluster_resolver,
              communication_options):
     """Initializes the collective all-reduce extended strategy.

     Args:
         container_strategy: The strategy object that owns this extended
             object.
         cluster_resolver: Optional `ClusterResolver`; falls back to a
             `TFConfigClusterResolver` when falsy.
         communication_options: A `tf.distribute.experimental.CommunicationOptions`
             instance.

     Raises:
         ValueError: If `communication_options` is not a
             `collective_util.Options`, or the resolved cluster resolver is
             not a `ClusterResolver`.
     """
     if not isinstance(communication_options, collective_util.Options):
         raise ValueError("communication_options must be an instance of "
                          "tf.distribute.experimental.CommunicationOptions")
     self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
     if not isinstance(self._cluster_resolver, ClusterResolver):
         raise ValueError("cluster_resolver must be an instance of "
                          "tf.distribute.cluster_resolver.ClusterResolver")
     distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
     self._communication_options = communication_options
     self._collective_key_base = container_strategy._collective_key_base  # pylint: disable=protected-access
     self._initialize_strategy(self._cluster_resolver)
     self._cfer_fn_cache = weakref.WeakKeyDictionary()
     self.experimental_enable_get_next_as_optional = True
     assert isinstance(self._cross_device_ops,
                       cross_device_ops_lib.CollectiveAllReduce)
예제 #41
0
    def testTaskTypeIndexRpcRead(self):
        """task_type, task_id and rpc_layer are all read from TF_CONFIG."""
        os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "rpc_layer": "grpc",
      "task": {
        "type": "ps",
        "index": 0
      }
    }
    """

        resolver = TFConfigClusterResolver()
        self.assertEqual('ps', resolver.task_type)
        self.assertEqual(0, resolver.task_id)
        self.assertEqual('grpc', resolver.rpc_layer)
        def fn(first_fetch_occurred_event, worker_terminated_event):
            """Subprocess body: fetches a scheduled result before and after a
            worker is terminated, to exercise coordinator fault tolerance."""
            os.environ["GRPC_FAIL_FAST"] = "use_caller"

            cluster_resolver = TFConfigClusterResolver()
            if cluster_resolver.task_type != "chief":
                utils.start_server(cluster_resolver, "grpc")
            strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
                cluster_resolver)
            ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)

            with strategy.scope():
                v = variables.Variable(initial_value=0, dtype=dtypes.int32)

            @def_function.function
            def worker_fn():
                return v + 1, v - 1

            remote_value = ps_coordinator.schedule(worker_fn)
            logging.info("result (1st fetch): %r", remote_value.fetch())
            first_fetch_occurred_event.set()
            # The test harness terminates a worker here; the second fetch must
            # still succeed from the coordinator's cached/recovered state.
            worker_terminated_event.wait()
            logging.info("result (2nd fetch): %r", remote_value.fetch())
        def fn(functions_scheduled_event):
            """Subprocess body: schedules many functions, joins, and checks the
            last result to exercise coordinator scheduling under load."""
            # TODO(b/170664373): This is needed for TF2 parameter server training in
            # OSS. Remove this when resolved.
            os.environ["GRPC_FAIL_FAST"] = "use_caller"

            cluster_resolver = TFConfigClusterResolver()
            if cluster_resolver.task_type != "chief":
                utils.start_server(cluster_resolver, "grpc")
            strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
                cluster_resolver)
            ps_client = coordinator_lib.ClusterCoordinator(strategy)

            with strategy.scope():
                v = variables.Variable(initial_value=1)

                @def_function.function
                def worker_fn(input_tensor):
                    def replica_fn(input_tensor):
                        return input_tensor + v

                    run_result = strategy.run(replica_fn,
                                              args=(input_tensor, ))
                    check_ops.assert_equal_v2(run_result, 4)
                    return run_result

            # Schedule a large number of calls; only every 500th is logged to
            # keep the output readable.
            for i in range(5000):
                if i % 500 == 0:
                    logging.info("Scheduling function-{}...".format(i))
                result = ps_client.schedule(worker_fn,
                                            args=(constant_op.constant(3), ))
            functions_scheduled_event.set()
            logging.info("Joining...")
            ps_client.join()
            logging.info("Finished joining.")
            if result.fetch() != 4:
                raise AssertionError(
                    "Unexpected RemoteValue result: {}".format(result.fetch()))
            logging.info("testStrategyRun succeeded")