def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None):
  sess_config = config_pb2.ConfigProto()
  if num_gpus is None:
    num_gpus = len(tf_config.list_logical_devices('GPU'))

  if cluster_spec and task_type and task_id is not None:
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=ClusterSpec(cluster_spec),
        task_type=task_type,
        task_id=task_id,
        num_accelerators={'GPU': num_gpus})
    target = 'grpc://' + cluster_spec[task_type][task_id]
  else:
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec({}), num_accelerators={'GPU': num_gpus})
    target = ''

  strategy = mwms_lib.CollectiveAllReduceStrategy(
      cluster_resolver=cluster_resolver)
  sess_config = strategy.update_config_proto(sess_config)

  return strategy, target, sess_config
  def cluster_spec(self):
    """Returns a ClusterSpec based on the TF_CONFIG environment variable.

    Returns:
      A ClusterSpec with information from the TF_CONFIG environment variable.
    """
    tf_config = _load_tf_config()
    if 'cluster' not in tf_config:
      return ClusterSpec({})
    return ClusterSpec(tf_config['cluster'])
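
For context, a hedged sketch of the TF_CONFIG value a resolver like this expects; the hostnames below are placeholders, and the actual parsing happens in _load_tf_config, which is not shown here.

import json
import os

# Hypothetical TF_CONFIG for a cluster with two workers and one parameter server.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ['worker0.example.com:2222', 'worker1.example.com:2222'],
        'ps': ['ps0.example.com:2222'],
    },
    'task': {'type': 'worker', 'index': 0},
})
# With this in place, cluster_spec() above would return
# ClusterSpec({'worker': [...], 'ps': [...]}).
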
Example #3
    def cluster_spec(self):
        """Returns a ClusterSpec based on the SageMaker environment variables.

    Returns:
      A ClusterSpec with information from the SageMaker environment variables.
    """
        tf_config = _load_tf_config(self._port)
        if 'cluster' not in tf_config:
            return ClusterSpec({})
        return ClusterSpec(tf_config['cluster'])
def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None):
    sess_config = config_pb2.ConfigProto()
    if num_gpus is None:
        num_gpus = context.num_gpus()

    if cluster_spec and task_type and task_id is not None:
        cluster_resolver = SimpleClusterResolver(
            cluster_spec=multi_worker_util.normalize_cluster_spec(
                cluster_spec),
            task_type=task_type,
            task_id=task_id,
            num_accelerators={'GPU': num_gpus})
        target = 'grpc://' + cluster_spec[task_type][task_id]
    else:
        cluster_resolver = SimpleClusterResolver(
            ClusterSpec({}), num_accelerators={'GPU': num_gpus})
        target = ''

    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
        cluster_resolver=cluster_resolver)
    sess_config = strategy.update_config_proto(sess_config)

    return strategy, target, sess_config
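
A minimal usage sketch of the helper above, exercising only the local branch (no cluster_spec) so no servers need to be running; num_gpus=0 keeps it CPU-only.

# Local case: empty ClusterSpec, empty target, default session config.
strategy, target, sess_config = create_test_objects(num_gpus=0)
assert target == ''
# With a cluster_spec dict such as {'worker': ['localhost:12345', ...]} and
# task_type='worker', task_id=0, target would instead be 'grpc://localhost:12345'.
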
Example #5
  def testArbitraryCurrentTaskType(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1, has_chief=True)
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
    with self.assertRaisesRegexp(ValueError, "Unrecognized task_type: foobar"):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
def gen_server(cluster_spec, job_name: str, task_index: int,
               cpu_device_num: int):
    """
    Start a TensorFlow server.

    Args:
        cluster_spec (dict): TensorFlow ClusterSpec dict
        job_name: TensorFlow job name
        task_index: TensorFlow task index
        cpu_device_num: The number of CPU devices
    """
    _clean_stale_servers()

    # TODO: The following config should be less hard-coded and based on the strategy
    experimental = config_pb2.ConfigProto.Experimental(
        collective_nccl=True, collective_group_leader=DEFAULT_GROUP_LEADER)
    s = Server(ClusterSpec(cluster_spec),
               job_name=job_name,
               task_index=task_index,
               config=config_pb2.ConfigProto(
                   experimental=experimental,
                   device_count={"CPU": cpu_device_num},
                   inter_op_parallelism_threads=0,
                   intra_op_parallelism_threads=0))
    return s
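
A hedged usage sketch for gen_server; the addresses are placeholders, and it assumes the module-level helpers referenced above (_clean_stale_servers, DEFAULT_GROUP_LEADER) are defined.

# Hypothetical usage: start an in-process server for worker 0 of a two-worker cluster.
cluster = {'worker': ['localhost:2222', 'localhost:2223']}
server = gen_server(cluster, job_name='worker', task_index=0, cpu_device_num=1)
print(server.target)  # a grpc:// address that sessions or resolvers can point at
# server.join() would block this process so it keeps serving requests.
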
Example #7
    def test_dataset_creator_usage_in_parameter_server_model_fit(self):
        cluster_def = multi_worker_test_base.create_in_process_cluster(
            num_workers=2, num_ps=1, rpc_layer="grpc")
        cluster_def["chief"] = [
            "localhost:%d" % multi_worker_test_base.pick_unused_port()
        ]
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc"))
        with strategy.scope():
            model = sequential.Sequential([core_layers.Dense(10)])
        model.compile(gradient_descent.SGD(), loss="mse")

        def dataset_fn(input_context):
            global_batch_size = 64
            batch_size = input_context.get_per_replica_batch_size(
                global_batch_size)
            dataset = dataset_ops.DatasetV2.from_tensors(([1.], [1.])).repeat()
            dataset = dataset.shard(input_context.num_input_pipelines,
                                    input_context.input_pipeline_id)
            dataset = dataset.batch(batch_size)
            dataset = dataset.prefetch(2)
            return dataset

        history = model.fit(dataset_creator.DatasetCreator(dataset_fn),
                            epochs=10,
                            steps_per_epoch=10,
                            verbose=0)
        self.assertLen(history.history["loss"], 10)
Example #8
  def setUpClass(cls):
    super(ParameterServerStrategyV2Test, cls).setUpClass()
    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=2, num_ps=3)
    cls.cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def))
    remote.connect_to_cluster(cls.cluster_resolver.cluster_spec(),
                              job_name="chief")
Example #9
def make_parameter_server_cluster(num_workers, num_ps):
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
  cluster_def["chief"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  return SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc")
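
A short usage sketch for the helper above; it assumes parameter_server_strategy_v2 is imported as in the neighboring examples.

# Hypothetical usage: build a ParameterServerStrategyV2 from the in-process cluster.
resolver = make_parameter_server_cluster(num_workers=2, num_ps=1)
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(resolver)
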
def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None,
                        num_tpus=None):
    if num_gpus is None:
        num_gpus = context.num_gpus()
    if num_tpus is None:
        num_tpus = context.context().list_physical_devices('TPU')
    if num_tpus:
        tpu_strategy_util.initialize_tpu_system()

    if cluster_spec and task_type and task_id is not None:
        cluster_resolver = SimpleClusterResolver(
            cluster_spec=multi_worker_util.normalize_cluster_spec(
                cluster_spec),
            task_type=task_type,
            task_id=task_id,
            num_accelerators={
                'GPU': num_gpus,
                'TPU': num_tpus
            })
        target = 'grpc://' + cluster_spec[task_type][task_id]
    else:
        cluster_resolver = SimpleClusterResolver(ClusterSpec({}),
                                                 num_accelerators={
                                                     'GPU': num_gpus,
                                                     'TPU': num_tpus
                                                 })
        target = ''

    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
        cluster_resolver=cluster_resolver)

    return strategy, target
Example #11
  def testClusterCoordinatorMetrics(self):

    metric_utils.enable_metrics = True

    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=1, num_ps=1, rpc_layer=self.get_rpc_layer())
    cluster_def['chief'] = [
        'localhost:%d' % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer=self.get_rpc_layer())
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    cluster = coordinator_lib.Cluster(strategy)

    @def_function.function
    def func():
      time.sleep(0.5)
      return 3

    result = cluster.schedule(func, args=None, kwargs=None)
    result = cluster.schedule(func, args=None, kwargs=None)
    cluster.join()
    self.assertEqual(result.fetch(), 3)

    # Tracing, closure execution, and remote_value fetching should be executed
    # exactly once for running this function.
    metric_tracing = metric_utils.get_metric_summary('function_tracing')
    self.assertEqual(metric_tracing['num'], 1)
    # Tracing time should be longer than the sleep time in Python function.
    self.assertGreater(metric_tracing['sum'], 0.5)
    metric_closure = metric_utils.get_metric_summary('closure_execution')
    self.assertEqual(metric_closure['num'], 2)
    metric_remote_value = metric_utils.get_metric_summary('remote_value_fetch')
    self.assertEqual(metric_remote_value['num'], 2)
    def cluster_spec(self):
        """Returns a ClusterSpec object based on the latest instance group info.

    Returns:
      A ClusterSpec containing host information retrieved from Slurm's
        environment variables.
    """
        hostlist = self._resolve_hostnames()

        task_list = []
        self._cluster_allocation = {}

        for host in hostlist:
            for port_offset in range(self._tasks_per_node):

                host_addr = '%s:%d' % (host, self._port_base + port_offset)
                task_list.append(host_addr)

        cluster_rank_offset_start = 0
        cluster_rank_offset_end = 0

        for task_type, num_tasks in self._jobs.items():
            cluster_rank_offset_end = cluster_rank_offset_start + num_tasks

            self._cluster_allocation[task_type] = (
                task_list[cluster_rank_offset_start:cluster_rank_offset_end])

            if cluster_rank_offset_start <= self._rank < cluster_rank_offset_end:
                self.task_type = task_type
                self.task_id = self._rank - cluster_rank_offset_start

            cluster_rank_offset_start = cluster_rank_offset_end

        return ClusterSpec(self._cluster_allocation)
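
The rank-to-task bookkeeping above is easier to see in isolation; this is the same offset arithmetic with illustrative values.

# Illustrative values only: an ordered job spec and this process's Slurm rank.
jobs = {'ps': 1, 'worker': 3}
rank = 2

offset = 0
for task_type, num_tasks in jobs.items():
    if offset <= rank < offset + num_tasks:
        print(task_type, rank - offset)  # prints: worker 1
    offset += num_tasks
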
Example #13
  def testLessThanOneWorker(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=0, num_ps=1, has_chief=True)
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
    with self.assertRaisesRegexp(ValueError,
                                 "There must be at least one worker."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
Example #14
    def cluster_spec(self):
        """Returns a ClusterSpec object based on the latest instance group info.

    This returns a ClusterSpec object for use based on information from the
    specified initialization parameters and Slurm environment variables. The
    cluster specification is resolved each time this function is called. The
    resolver extract hostnames of nodes by scontrol and pack tasks in that
    order until a node a has number of tasks that is equal to specification.
    GPUs on nodes are allocated to tasks by specification through setting
    CUDA_VISIBLE_DEVICES environment variable.

    Returns:
      A ClusterSpec containing host information retrieved from Slurm's
        environment variables.
    """
        hostlist = self._resolve_hostnames()

        task_list = []
        self._gpu_allocation = []
        self._cluster_allocation = {}

        for host in hostlist:
            for port_offset, gpu_offset in zip(
                    range(self._tasks_per_node),
                    range(0, self._gpus_per_node, self._gpus_per_task)):

                host_addr = '%s:%d' % (host, self._port_base + port_offset)
                task_list.append(host_addr)
                gpu_id_list = []

                for gpu_id in range(gpu_offset,
                                    gpu_offset + self._gpus_per_task):
                    gpu_id_list.append(str(gpu_id))

                self._gpu_allocation.append(','.join(gpu_id_list))

        cluster_rank_offset_start = 0
        cluster_rank_offset_end = 0

        for job_name, num_tasks in self._jobs.items():
            cluster_rank_offset_end = cluster_rank_offset_start + num_tasks

            self._cluster_allocation[job_name] = \
              task_list[cluster_rank_offset_start:cluster_rank_offset_end]

            if self._rank >= cluster_rank_offset_start and \
                self._rank < cluster_rank_offset_end:

                self._job_name = job_name
                self._task_index = self._rank - cluster_rank_offset_start

            cluster_rank_offset_start = cluster_rank_offset_end

        if self._auto_set_gpu is True:
            os.environ['CUDA_VISIBLE_DEVICES'] = self._gpu_allocation[
                self._rank]

        return ClusterSpec(self._cluster_allocation)
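
A standalone sketch of the GPU-to-task mapping computed above, using illustrative per-node values.

# Illustrative values: 2 tasks per node, 4 GPUs per node, 2 GPUs per task.
tasks_per_node, gpus_per_node, gpus_per_task = 2, 4, 2

gpu_allocation = []
for _task_idx, gpu_offset in zip(range(tasks_per_node),
                                 range(0, gpus_per_node, gpus_per_task)):
    gpu_allocation.append(','.join(
        str(gpu_id) for gpu_id in range(gpu_offset, gpu_offset + gpus_per_task)))

print(gpu_allocation)  # ['0,1', '2,3'] for one node; the entry at a task's rank
                       # becomes its CUDA_VISIBLE_DEVICES, as in the resolver above.
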
Example #15
  def testArbitraryCurrentTaskType(self):
    cluster_def = multi_worker_test_base._create_cluster(
        num_workers=1, num_ps=1)
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
    with self.assertRaisesRegexp(ValueError, "Unrecognized task_type: foobar"):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
Example #16
def make_client(num_workers, num_ps):
    # TODO(rchao): Test the internal rpc_layer version.
    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def),
                                             rpc_layer="grpc")
    return parameter_server_client.ParameterServerClient(cluster_resolver)
def make_coordinator(num_workers, num_ps):
    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def),
                                             rpc_layer="grpc")
    return tf.distribute.experimental.coordinator.ClusterCoordinator(
        tf.distribute.experimental.ParameterServerStrategy(cluster_resolver))
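
A hedged sketch of driving the returned coordinator, using only the public tf.distribute APIs this example already relies on.

# Hypothetical usage: schedule a trivial tf.function onto the in-process workers.
coordinator = make_coordinator(num_workers=2, num_ps=1)

with coordinator.strategy.scope():
    counter = tf.Variable(0.0)  # placed on the parameter servers

@tf.function
def step_fn():
    counter.assign_add(1.0)
    return counter.read_value()

remote_value = coordinator.schedule(step_fn)
coordinator.join()           # wait for all scheduled closures to finish
print(remote_value.fetch())  # -> 1.0
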
Example #18
  def testArbitraryJobName(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1, has_chief=True)
    cluster_def["some_arbitrary_name"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc")
    with self.assertRaisesRegexp(ValueError, "Disallowed task type found in"):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
Example #19
  def testLessThanOneWorker(self):
    cluster_def = multi_worker_test_base._create_cluster(
        num_workers=0, num_ps=1)
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
    with self.assertRaisesRegexp(ValueError,
                                 "There must be at least one worker."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
Example #20
def make_client(num_workers, num_ps):
    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def),
                                             rpc_layer="grpc")
    return client_lib.Client(
        parameter_server_strategy_v2.ParameterServerStrategyV2(
            cluster_resolver))
Example #21
    def _create_parameter_server():

        cluster_def = multi_worker_test_base.create_in_process_cluster(
            num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
        resolver = cluster_resolver.SimpleClusterResolver(
            ClusterSpec(cluster_def),
            num_accelerators={"GPU": required_gpus},
            rpc_layer="grpc")
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            resolver,
            variable_partitioner=sharded_variable.FixedShardsPartitioner(2))
        return strategy
def make_coordinator(num_workers, num_ps):
  # TODO(rchao): Test the internal rpc_layer version.
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
  cluster_def['chief'] = [
      'localhost:%d' % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer='grpc')
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)
  return coordinator_lib.ClusterCoordinator(strategy)
Example #23
  def testMoreThanOneChief(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1)
    chief_ports = [multi_worker_test_base.pick_unused_port() for _ in range(3)]
    cluster_def["chief"] = ["localhost:%s" % port for port in chief_ports]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def),
        rpc_layer="grpc",
        task_type="chief",
        task_id=1)
    with self.assertRaisesRegexp(ValueError,
                                 "There must be at most one 'chief' job."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
Example #24
    def test_dataset_creator_usage_in_parameter_server_model_fit(self):
        cluster_def = multi_worker_test_base.create_in_process_cluster(
            num_workers=2, num_ps=1, rpc_layer="grpc")
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc"))
        with strategy.scope():
            model = sequential.Sequential([core_layers.Dense(10)])
        model.compile(gradient_descent.SGD(), loss="mse")

        history = model.fit(dataset_creator.DatasetCreator(
            self._get_dataset_fn()),
                            epochs=10,
                            steps_per_epoch=10,
                            verbose=0)
        self.assertLen(history.history["loss"], 10)
Example #25
    def cluster_spec(self):
        """Retrieve the current state of the cluster and return a ClusterSpec.

    Returns:
      A ClusterSpec representing the state of the cluster at the moment this
      function is called.

    Implementors of this function must take care in ensuring that the
    ClusterSpec returned is up-to-date at the time of calling this function.
    This usually means retrieving the information from the underlying cluster
    management system every time this function is invoked and reconstructing
    a cluster_spec, rather than attempting to cache anything.
    """
        # We currently only use the worker task type.
        # We currently do not support in dynamic changes of the cluster environment.
        return ClusterSpec({"worker": self._peers})
Example #26
    def cluster_spec(self):
        """Returns a ClusterSpec object based on the latest instance group info.

    This returns a ClusterSpec object for use based on information from the
    specified instance group. We will retrieve the information from the GCE APIs
    every time this method is called.

    Returns:
      A ClusterSpec containing host information retrieved from GCE.
    """
        request_body = {'instanceState': 'RUNNING'}
        request = self._service.instanceGroups().listInstances(
            project=self._project,
            zone=self._zone,
            instanceGroups=self._instance_group,
            body=request_body,
            orderBy='name')

        worker_list = []

        while request is not None:
            response = request.execute()

            items = response['items']
            for instance in items:
                instance_name = instance['instance'].split('/')[-1]

                instance_request = self._service.instances().get(
                    project=self._project,
                    zone=self._zone,
                    instance=instance_name)

                if instance_request is not None:
                    instance_details = instance_request.execute()
                    ip_address = instance_details['networkInterfaces'][0][
                        'networkIP']
                    instance_url = '%s:%s' % (ip_address, self._port)
                    worker_list.append(instance_url)

            request = self._service.instanceGroups().listInstances_next(
                previous_request=request, previous_response=response)

        worker_list.sort()
        return ClusterSpec({self._job_name: worker_list})
Example #27
    def testClientMetrics(self):
        if sys.version_info >= (3, 8) and platform.system() == 'Windows':
            # TODO(b/165013260): Fix this
            self.skipTest(
                'Test is currently broken on Windows with Python 3.8')

        metric_utils.enable_metrics = True

        cluster_def = multi_worker_test_base.create_in_process_cluster(
            num_workers=1, num_ps=1, rpc_layer=self.get_rpc_layer())
        cluster_def['chief'] = [
            'localhost:%d' % multi_worker_test_base.pick_unused_port()
        ]
        cluster_resolver = SimpleClusterResolver(
            ClusterSpec(cluster_def), rpc_layer=self.get_rpc_layer())
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            cluster_resolver)
        cluster = client.Cluster(strategy)

        @def_function.function
        def func():
            time.sleep(0.5)
            return 3

        result = cluster.schedule(func, args=None, kwargs=None)
        result = cluster.schedule(func, args=None, kwargs=None)
        cluster.join()
        self.assertEqual(result._get_value().numpy(), 3)

        # Tracing, closure execution, and remote_value fetching should be executed
        # exactly once for running this function.
        metric_tracing = metric_utils.get_metric_summary('function_tracing')
        self.assertEqual(metric_tracing['num'], 1)
        # Tracing time should be longer than the sleep time in Python function.
        self.assertGreater(metric_tracing['sum'], 0.5)
        metric_closure = metric_utils.get_metric_summary('closure_execution')
        self.assertEqual(metric_closure['num'], 2)
        metric_remote_value = metric_utils.get_metric_summary(
            'remote_value_fetch')
        self.assertEqual(metric_remote_value['num'], 2)
Example #28
def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None,
                        sess_config=None,
                        use_core_strategy=False):
    sess_config = sess_config or config_pb2.ConfigProto()
    if num_gpus is None:
        num_gpus = context.num_gpus()
    if use_core_strategy:
        if cluster_spec and task_type and task_id is not None:
            cluster_resolver = SimpleClusterResolver(
                cluster_spec=multi_worker_util.normalize_cluster_spec(
                    cluster_spec),
                task_type=task_type,
                task_id=task_id,
                num_accelerators={'GPU': num_gpus})
            target = 'grpc://' + cluster_spec[WORKER][task_id]
        else:
            cluster_resolver = SimpleClusterResolver(
                ClusterSpec({}), num_accelerators={'GPU': num_gpus})
            target = ''

        distribution = MockCoreParameterServerStrategy(cluster_resolver)
        sess_config = copy.deepcopy(sess_config)
        sess_config = distribution.update_config_proto(sess_config)
    else:
        distribution = parameter_server_strategy.ParameterServerStrategy(
            num_gpus_per_worker=num_gpus)
        if task_type:
            sess_config = copy.deepcopy(sess_config)
            distribution.configure(session_config=sess_config,
                                   cluster_spec=cluster_spec,
                                   task_type=task_type,
                                   task_id=task_id)
            target = 'grpc://' + cluster_spec[WORKER][task_id]
        else:
            target = ''

    return distribution, target, sess_config
Example #29
    def cluster_spec(self):
        """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs.
    """
        worker_list = []

        for tpu_name in self._tpu_names:
            full_name = 'projects/%s/locations/%s/nodes/%s' % (
                self._project, self._zone, tpu_name)
            request = self._service.projects().locations().nodes().get(
                name=full_name)
            response = request.execute()

            instance_url = '%s:%s' % (response['ipAddress'], response['port'])
            worker_list.append(instance_url)

        return ClusterSpec({self._job_name: worker_list})
def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None,
                        use_core_strategy=False):
    sess_config = config_pb2.ConfigProto()
    if num_gpus is None:
        num_gpus = context.num_gpus()
    if use_core_strategy:
        if cluster_spec and task_type and task_id is not None:
            cluster_resolver = SimpleClusterResolver(
                cluster_spec=multi_worker_util.normalize_cluster_spec(
                    cluster_spec),
                task_type=task_type,
                task_id=task_id,
                num_accelerators=num_gpus)
            target = 'grpc://' + cluster_spec[task_type][task_id]
        else:
            cluster_resolver = SimpleClusterResolver(ClusterSpec({}),
                                                     num_accelerators=num_gpus)
            target = ''

        strategy = MockCollectiveAllReduceStrategy(cluster_resolver)
        sess_config = strategy.update_config_proto(sess_config)
    else:
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
            num_gpus_per_worker=num_gpus)
        if task_type and task_id is not None:
            strategy.configure(session_config=sess_config,
                               cluster_spec=cluster_spec,
                               task_type=task_type,
                               task_id=task_id)
            target = 'grpc://' + cluster_spec[task_type][task_id]
        else:
            target = ''

    return strategy, target, sess_config