Code example #1
def _all_devices():
  devices = []
  tfconfig = TFConfigClusterResolver()
  if tfconfig.cluster_spec().as_dict():
    devices = _cluster_spec_to_device_list(tfconfig.cluster_spec(),
                                           context.num_gpus())
  return devices if devices else all_local_devices()
Code example #2
def all_devices():
  devices = []
  tfconfig = TFConfigClusterResolver()
  if tfconfig.cluster_spec().as_dict():
    devices = _cluster_spec_to_device_list(tfconfig.cluster_spec(),
                                           context.num_gpus())
  return devices if devices else all_local_devices()
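The two variants above fall back to local devices when TF_CONFIG carries no cluster. _cluster_spec_to_device_list and all_local_devices are TensorFlow internals; as a rough standalone sketch (the helper name, iteration order, and device-string details below are illustrative assumptions, not the library's implementation), the mapping from a cluster dict to per-task device strings looks roughly like this:

# Illustrative sketch only: approximates what a cluster-spec-to-device-list
# helper could produce; the real helper is a TensorFlow internal and details
# may differ between versions.
def sketch_cluster_spec_to_device_list(cluster_dict, num_gpus_per_worker):
  devices = []
  for task_type in ('chief', 'worker'):
    for task_id, _ in enumerate(cluster_dict.get(task_type, [])):
      if num_gpus_per_worker == 0:
        # Fall back to the CPU device when no GPUs are requested.
        devices.append('/job:%s/task:%d/device:CPU:0' % (task_type, task_id))
      else:
        devices.extend(
            '/job:%s/task:%d/device:GPU:%d' % (task_type, task_id, gpu_id)
            for gpu_id in range(num_gpus_per_worker))
  return devices


print(sketch_cluster_spec_to_device_list(
    {'worker': ['worker0:2222', 'worker1:2222']}, num_gpus_per_worker=2))
# ['/job:worker/task:0/device:GPU:0', '/job:worker/task:0/device:GPU:1',
#  '/job:worker/task:1/device:GPU:0', '/job:worker/task:1/device:GPU:1']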
Code example #3
        def task_function(start_events, finish_events):
            cluster_resolver = TFConfigClusterResolver()
            cluster_spec = cluster_resolver.cluster_spec()
            task_type = cluster_resolver.task_type
            task_id = cluster_resolver.task_id
            rpc_layer = cluster_resolver.rpc_layer

            logging.info(
                'Starting server with cluster_spec = %r, task_type = %r, '
                'task_id = %r, rpc_layer = %r', cluster_spec, task_type,
                task_id, rpc_layer)

            # TODO(yuefengz): support GPU clusters.
            server_config = config_pb2.ConfigProto()
            server_config.device_count['GPU'] = 0

            # Set the environment variable to prevent hanging upon job failure and
            # restart. Note that it defaults to 'use_caller' at Google, but defaults
            # to False in OSS.
            os.environ['GRPC_FAIL_FAST'] = 'use_caller'

            server_lib.Server(cluster_spec,
                              job_name=task_type,
                              protocol=rpc_layer,
                              task_index=task_id,
                              config=server_config,
                              start=True)

            start_event = start_events[task_type][task_id]
            start_event.set()

            finish_event = finish_events[task_type][task_id]
            finish_event.wait()

            os._exit(0)  # pylint: disable=protected-access
Code example #4
        def task_function(start_events, finish_events):
            cluster_resolver = TFConfigClusterResolver()
            cluster_spec = cluster_resolver.cluster_spec()
            task_type = cluster_resolver.task_type
            task_id = cluster_resolver.task_id
            rpc_layer = cluster_resolver.rpc_layer

            logging.info(
                'Starting server with cluster_spec = %r, task_type = %r, '
                'task_id = %r, rpc_layer = %r', cluster_spec, task_type,
                task_id, rpc_layer)

            # TODO(yuefengz): support GPU clusters.
            server_config = config_pb2.ConfigProto()
            server_config.device_count['GPU'] = 0

            server_lib.Server(cluster_spec,
                              job_name=task_type,
                              protocol=rpc_layer,
                              task_index=task_id,
                              config=server_config,
                              start=True)

            start_event = start_events[task_type][task_id]
            start_event.set()

            finish_event = finish_events[task_type][task_id]
            finish_event.wait()

            os._exit(0)  # pylint: disable=protected-access
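As a usage sketch that is not part of the source above: a per-task process would typically export its own TF_CONFIG and then bring up the same kind of in-process server through the public API. The addresses and ports are placeholders, and TF 2.x is assumed.

# Minimal sketch: parse TF_CONFIG and start an in-process server, mirroring
# task_function above (run one such process per task; ports must be free).
import json
import os

import tensorflow as tf

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'worker': ['localhost:12345', 'localhost:12346']},
    'task': {'type': 'worker', 'index': 0},
})

resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
server = tf.distribute.Server(
    resolver.cluster_spec(),
    job_name=resolver.task_type,
    task_index=resolver.task_id,
    protocol=resolver.rpc_layer or 'grpc',
    start=True)
server.join()  # Blocks until the server shuts down, as a real worker would.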
Code example #5
def get_num_workers():
  cluster_resolver = TFConfigClusterResolver()
  cluster_spec = cluster_resolver.cluster_spec().as_dict()
  if cluster_spec:
    task_type = cluster_resolver.task_type
    return int(multi_worker_util.worker_count(cluster_spec, task_type))
  return 1
Code example #6
def get_num_workers():
    cluster_resolver = TFConfigClusterResolver()
    cluster_spec = cluster_resolver.cluster_spec().as_dict()
    if cluster_spec:
        task_type = cluster_resolver.task_type
        return int(multi_worker_util.worker_count(cluster_spec, task_type))
    return 1
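A standalone approximation of the same count, assuming the common convention that a 'chief' task is also counted as a worker; multi_worker_util.worker_count applies additional rules (for example around evaluators) that are omitted here.

# Sketch: count workers straight from the TF_CONFIG cluster dict.
import json
import os


def sketch_get_num_workers():
  tf_config = json.loads(os.environ.get('TF_CONFIG') or '{}')
  cluster = tf_config.get('cluster', {})
  if not cluster:
    return 1  # No cluster configured: single-worker (local) training.
  return len(cluster.get('worker', [])) + len(cluster.get('chief', []))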
Code example #7
def maybe_shard_dataset(dataset):
  """Shard the dataset if running in multi-node environment."""
  cluster_resolver = TFConfigClusterResolver()
  cluster_spec = cluster_resolver.cluster_spec().as_dict()
  if cluster_spec:
    dataset = dataset.shard(
        multi_worker_util.worker_count(cluster_spec,
                                       cluster_resolver.task_type),
        multi_worker_util.id_in_cluster(
            cluster_spec, cluster_resolver.task_type, cluster_resolver.task_id))
  return dataset
Code example #8
def maybe_shard_dataset(dataset):
    """Shard the dataset if running in multi-node environment."""
    cluster_resolver = TFConfigClusterResolver()
    cluster_spec = cluster_resolver.cluster_spec().as_dict()
    if cluster_spec:
        dataset = dataset.shard(
            multi_worker_util.worker_count(cluster_spec,
                                           cluster_resolver.task_type),
            multi_worker_util.id_in_cluster(cluster_spec,
                                            cluster_resolver.task_type,
                                            cluster_resolver.task_id))
    return dataset
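For intuition, a small TF 2.x illustration of the dataset.shard(num_shards, index) call used above: worker i of n sees every n-th element starting at position i.

import tensorflow as tf

dataset = tf.data.Dataset.range(9)
print(list(dataset.shard(num_shards=3, index=0).as_numpy_iterator()))  # [0, 3, 6]
print(list(dataset.shard(num_shards=3, index=1).as_numpy_iterator()))  # [1, 4, 7]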
Code example #9
  def __init__(self, container_strategy, num_gpus_per_worker):
    # Use TFConfigClusterResolver to parse TF_CONFIG. We don't want to change
    # the constructor's interface to allow customized cluster resolver. Use
    # SimpleClusterResolver to override num_accelerators.
    tfconfig = TFConfigClusterResolver()
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=tfconfig.cluster_spec(),
        task_type=tfconfig.task_type,
        task_id=tfconfig.task_id,
        num_accelerators=num_gpus_per_worker)
    super(CollectiveAllReduceExtended, self).__init__(
        container_strategy, cluster_resolver=cluster_resolver)
Code example #10
  def __init__(self, container_strategy, num_gpus_per_worker):
    # Use TFConfigClusterResolver to parse TF_CONFIG. We don't want to change
    # the constructor's interface to allow customized cluster resolver. Use
    # SimpleClusterResolver to override num_accelerators.
    tfconfig = TFConfigClusterResolver()
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=tfconfig.cluster_spec(),
        task_type=tfconfig.task_type,
        task_id=tfconfig.task_id,
        num_accelerators=num_gpus_per_worker)
    super(ParameterServerExtended, self).__init__(
        container_strategy, cluster_resolver=cluster_resolver)
Code example #11
def batch_and_maybe_shard_dataset(dataset, global_batch_size):
  """Shard the dataset if running in multi-node environment."""

  cluster_resolver = TFConfigClusterResolver()
  cluster_spec = cluster_resolver.cluster_spec().as_dict()
  if cluster_spec:
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    num_workers = int(multi_worker_util.worker_count(cluster_spec, task_type))
    id_in_cluster = int(
        multi_worker_util.id_in_cluster(cluster_spec, task_type, task_id))
    dataset = dataset.shard(num_workers, id_in_cluster)
  return dataset.batch(global_batch_size)
Code example #12
def batch_and_maybe_shard_dataset(dataset, global_batch_size):
    """Shard the dataset if running in multi-node environment."""

    cluster_resolver = TFConfigClusterResolver()
    cluster_spec = cluster_resolver.cluster_spec().as_dict()
    if cluster_spec:
        task_type = cluster_resolver.task_type
        task_id = cluster_resolver.task_id
        num_workers = int(
            multi_worker_util.worker_count(cluster_spec, task_type))
        id_in_cluster = int(
            multi_worker_util.id_in_cluster(cluster_spec, task_type, task_id))
        dataset = dataset.shard(num_workers, id_in_cluster)
    return dataset.batch(global_batch_size)
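A quick TF 2.x illustration of the shard-then-batch order used above: each worker first selects its slice of elements and then forms its own batches.

import tensorflow as tf

dataset = tf.data.Dataset.range(8)
worker0 = dataset.shard(num_shards=2, index=0).batch(2)
print([batch.numpy().tolist() for batch in worker0])  # [[0, 2], [4, 6]]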
Code example #13
  def __init__(self, cluster_resolver=None):
    """Initializes this strategy.

    Args:
      cluster_resolver: Optional
        `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a
        `tf.distribute.cluster_resolver.TFConfigClusterResolver`.
    """
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    if not cluster_resolver.cluster_spec():
      raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
    extended = ParameterServerStrategyExtended(
        self, cluster_resolver=cluster_resolver)
    super(ParameterServerStrategy, self).__init__(extended)
Code example #14
  def __init__(self,
               container_strategy,
               num_gpus_per_worker,
               communication):
    # Use TFConfigClusterResolver to parse TF_CONFIG. We don't want to change
    # the constructor's interface to allow customized cluster resolver. Use
    # SimpleClusterResolver to override num_accelerators.
    tfconfig = TFConfigClusterResolver()
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=tfconfig.cluster_spec(),
        task_type=tfconfig.task_type,
        task_id=tfconfig.task_id,
        num_accelerators={"GPU": num_gpus_per_worker},
        rpc_layer=tfconfig.rpc_layer)
    super(CollectiveAllReduceExtended, self).__init__(
        container_strategy,
        communication=communication,
        cluster_resolver=cluster_resolver)
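For comparison, a standalone sketch of the same wrapping pattern with an explicit cluster spec instead of TF_CONFIG; note that the earlier variants pass num_accelerators as a plain integer, while this one uses the per-device-type dictionary form. Addresses and the GPU count are placeholders, and TF 2.x is assumed.

import tensorflow as tf

cluster_spec = tf.train.ClusterSpec(
    {'worker': ['localhost:12345', 'localhost:12346']})
resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
    cluster_spec=cluster_spec,
    task_type='worker',
    task_id=0,
    num_accelerators={'GPU': 1},  # Dictionary form: device type -> count.
    rpc_layer='grpc')
print(resolver.num_accelerators())  # {'GPU': 1}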
Code example #15
  def __init__(self, cluster_resolver=None):
    """Initializes this strategy with an optional `cluster_resolver`.

    Args:
      cluster_resolver: Optional
        `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a
        `tf.distribute.cluster_resolver.TFConfigClusterResolver`.
    """
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    if not cluster_resolver.cluster_spec():
      raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
    extended = ParameterServerStrategyExtended(
        self, cluster_resolver=cluster_resolver)
    super(ParameterServerStrategy, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "ParameterServerStrategy")
    distribute_lib.distribution_strategy_replica_gauge.get_cell("num_ps").set(
        len(self.extended.parameter_devices))
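To see the guard in isolation: with no TF_CONFIG set, TFConfigClusterResolver reports an empty cluster spec, which is exactly the condition that makes the constructors above raise. A minimal TF 2.x sketch:

import os

import tensorflow as tf

os.environ.pop('TF_CONFIG', None)
resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
if not resolver.cluster_spec().as_dict():
  print('Empty cluster spec: a strategy built this way would raise ValueError.')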
Code example #16
  def testNormalClusterSpecRead(self):
    os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "task": {
        "type": "ps",
        "index": 0
      }
    }
    """

    cluster_resolver = TFConfigClusterResolver()
    expected_proto = """
    job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
                     tasks { key: 1 value: 'ps1:2222' } }
    job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
                         tasks { key: 1 value: 'worker1:2222' }
                         tasks { key: 2 value: 'worker2:2222' } }
    """
    actual_cluster_spec = cluster_resolver.cluster_spec()
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
Code example #17
    def testNormalClusterSpecRead(self):
        os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "task": {
        "type": "ps",
        "index": 0
      }
    }
    """

        cluster_resolver = TFConfigClusterResolver()
        expected_proto = """
    job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
                     tasks { key: 1 value: 'ps1:2222' } }
    job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
                         tasks { key: 1 value: 'worker1:2222' }
                         tasks { key: 2 value: 'worker2:2222' } }
    """
        actual_cluster_spec = cluster_resolver.cluster_spec()
        self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
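A companion sketch (TF 2.x assumed) showing what the resolver reports for the same TF_CONFIG used in the test above:

import json
import os

import tensorflow as tf

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'ps': ['ps0:2222', 'ps1:2222'],
        'worker': ['worker0:2222', 'worker1:2222', 'worker2:2222'],
    },
    'task': {'type': 'ps', 'index': 0},
})

resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
print(resolver.cluster_spec().as_dict())
# {'ps': ['ps0:2222', 'ps1:2222'],
#  'worker': ['worker0:2222', 'worker1:2222', 'worker2:2222']}
print(resolver.task_type, resolver.task_id)  # ps 0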