def testAllReduceDense(self, num_processes, required_gpus, implementation,
                           reduce_op, prefer_unique_instance_key):
        if (required_gpus == 0
                and implementation == CommunicationImplementation.NCCL):
            self.skipTest("Skip CPU + NCCL combination")
        if (num_processes == 2
                and implementation == CommunicationImplementation.NCCL):
            self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                          "physical GPUs for every process.")
        options = self.RunOptions(
            num_processes=num_processes,
            gpus_per_process=required_gpus,
            reduce_op=reduce_op,
            communication_options=collective_util.Options(
                implementation=implementation),
            prefer_unique_instance_key=prefer_unique_instance_key)
        group_size = options.num_processes * (options.gpus_per_process or 1)

        inputs_data = [1.0, 2.0, 3.0, 4.0]
        inputs = inputs_data[0:group_size]

        if group_size == 1:
            expect = 1.0
        elif group_size == 2:
            expect = 3.0 if reduce_op == ReduceOp.SUM else 1.5
        elif group_size == 4:
            expect = 10.0 if reduce_op == ReduceOp.SUM else 2.5

        self.reduce_and_verify(inputs, expect, options)
    def __init__(self, cluster_resolver=None, communication_options=None):
        """Creates the strategy.

    Args:
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
      communication_options: optional
        `tf.distribute.experimental.CommunicationOptions`. This configures the
        default options for cross device communications. It can be overridden by
        options provided to the communication APIs like
        `tf.distribute.ReplicaContext.all_reduce`. See
        `tf.distribute.experimental.CommunicationOptions` for details.
    """
        if communication_options is None:
            communication_options = collective_util.Options()
        super(CollectiveAllReduceStrategy, self).__init__(
            CollectiveAllReduceExtended(
                self,
                cluster_resolver=cluster_resolver,
                communication_options=communication_options))

        distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
            "MultiWorkerMirroredStrategy")
        # pylint: disable=protected-access
        distribute_lib.distribution_strategy_replica_gauge.get_cell(
            "num_workers").set(self.extended._num_workers)
        distribute_lib.distribution_strategy_replica_gauge.get_cell(
            "num_replicas_per_worker").set(self.extended._num_gpus_per_worker)
    def __init__(self,
                 communication=collective_util.CommunicationImplementation.AUTO,
                 cluster_resolver=None):
        """Creates the strategy.

    Args:
      communication: optional
        `tf.distribute.experimental.CommunicationImplementation`. This is a hint
        on the preferred collective communication implementation. Possible
        values include `AUTO`, `RING`, and `NCCL`.
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
    """
        communication_options = collective_util.Options(
            implementation=communication)
        super(CollectiveAllReduceStrategy, self).__init__(
            CollectiveAllReduceExtended(
                self,
                cluster_resolver=cluster_resolver,
                communication_options=communication_options))

        distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
            "MultiWorkerMirroredStrategy")
        # pylint: disable=protected-access
        distribute_lib.distribution_strategy_replica_gauge.get_cell(
            "num_workers").set(self.extended._num_workers)
        distribute_lib.distribution_strategy_replica_gauge.get_cell(
            "num_replicas_per_worker").set(self.extended._num_gpus_per_worker)
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
                use_collective_v2)
            collective, devices, _ = self.make_collective(
                num_processes, required_gpus)
            options = collective_util.Options(implementation=implementation)
            value = constant_op.constant([[[1, 2], [1, 2]]],
                                         dtype=dtypes.float32)

            def gather_fn():
                per_replica_value = make_per_replica_value(value, devices)
                gathered_values = collective._gather(per_replica_value,
                                                     per_replica_value,
                                                     axis=axis,
                                                     options=options)
                gathered_values = self.as_list(gathered_values)
                # Skip checking devices in eager. In eager the device attribute doesn't
                # reflect the actual device of the tensor.
                if not context.executing_eagerly():
                    self.assertAllEqual(devices,
                                        [v.device for v in gathered_values])
                return [ops.convert_to_tensor(v) for v in gathered_values]

            group_size = num_processes * (required_gpus or 1)
            expect = array_ops.concat([value] * group_size, axis=axis)
            per_replica_expect = [ops.convert_to_tensor(expect)] * len(devices)

            if func_mode == "eager":
                result = gather_fn()
                self.assertAllClose(result, per_replica_expect)

            if func_mode == "func_graph":
                result = def_function.function(gather_fn)()
                self.assertAllClose(result, per_replica_expect)
    def testBatchAllReduceDense(self, num_processes, required_gpus,
                                implementation, reduce_op,
                                use_scoped_allocator, use_collective_v2):
        if (required_gpus == 0
                and implementation == CommunicationImplementation.NCCL):
            self.skipTest("Skip CPU + NCCL combination")
        if (num_processes == 2
                and implementation == CommunicationImplementation.NCCL):
            self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                          "physical GPUs for every process.")

        options = self.RunOptions(
            num_processes=num_processes,
            gpus_per_process=required_gpus,
            reduce_op=reduce_op,
            communication_options=collective_util.Options(
                implementation=implementation),
            use_scoped_allocator=use_scoped_allocator,
            use_collective_v2=use_collective_v2)
        group_size = options.num_processes * (options.gpus_per_process or 1)

        inputs_data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
        inputs = inputs_data[0:group_size]

        if group_size == 1:
            expect = [1.0, 2.0]
        elif group_size == 2:
            expect = [4.0, 6.0] if reduce_op == ReduceOp.SUM else [2.0, 3.0]
        elif group_size == 4:
            expect = [16.0, 20.0] if reduce_op == ReduceOp.SUM else [4.0, 5.0]

        self.batch_reduce_and_verify(inputs, expect, options)
    def replica_fn():
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      group_size = num_processes * (required_gpus or 1)

      @def_function.function
      def collective_batch_all_reduce():
        results = []
        for replica_id, device in enumerate(devices):
          with ops.device(device):
            value = (IndexedSlices(
                array_ops.identity([[1.]]), array_ops.identity([0]),
                array_ops.identity([5, 1])), array_ops.identity(1.0),
                     IndexedSlices(
                         array_ops.identity([[3.]]), array_ops.identity([2]),
                         array_ops.identity([5, 1])), array_ops.identity(2.0))
            results.append(
                collective._all_reduce(reduce_op, value, replica_id, options))
        return results

      got = collective_batch_all_reduce()
      expect = [
          (IndexedSlices([[1. * group_size]], [0], [5, 1]), 1.0 * group_size,
           IndexedSlices([[3. * group_size]], [2], [5, 1]), 2.0 * group_size)
      ] * len(devices)
      self.assertAllClose(
          nest.map_structure(ops.convert_to_tensor, got),
          nest.map_structure(ops.convert_to_tensor, expect))
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
                use_collective_v2)
            collective, devices, _ = self.make_collective(
                num_processes, required_gpus)
            options = collective_util.Options(implementation=implementation)

            @def_function.function
            def reduce_fn(v):
                # Function inputs don't have device placement.
                self.assertEqual(v.values[0].device, "")
                self.assertEqual(v.values[1].device, "")
                # We only use NCCL for batch reduce with two or more values, so we use
                # two values here.
                reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                                  [(v, v), (v, v)], options)
                self.assertEqual(reduced[0].values[0].device, devices[0])
                self.assertEqual(reduced[0].values[1].device, devices[1])
                self.assertEqual(reduced[1].values[0].device, devices[0])
                self.assertEqual(reduced[1].values[1].device, devices[1])
                # Returning Mirrored only evaluates the primary value, which
                # causes hanging, so return the per-replica values instead.
                return [reduced[0].values, reduced[1].values]

            v = make_per_replica_value(1.0, devices)
            reduced = reduce_fn(v)
            self.assertAllClose(reduced, [[2.0, 2.0], [2.0, 2.0]])
  def make_collective(self, num_processes, gpu_per_process):
    """Returns collectives and other info to be used in tests.

    Args:
      num_processes: an integer indicating the number of processes that
        participate in the collective.
      gpu_per_process: number of GPUs (0 if no GPUs) used by each process.

    Returns:
     A tuple of (collective, devices, pid) where collective is an instance
     of `CollectiveAllReduce`, devices are a list of local devices (str)
     attached to the current process, and pid is the id of this process among
     all participating processes.
    """

    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
    devices = [
        "/job:worker/replica:0/task:%d/device:CPU:0" % cluster_resolver.task_id
    ]
    if gpu_per_process > 0:
      devices = [
          "/job:worker/replica:0/task:%d/device:GPU:%d" %
          (cluster_resolver.task_id, i) for i in range(gpu_per_process)
      ]
    group_size = num_processes * len(devices)
    collective = cross_device_ops_lib.CollectiveAllReduce(
        devices=devices,
        group_size=group_size,
        options=collective_util.Options())
    return collective, devices, cluster_resolver.task_id
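
  # An illustrative sketch (hypothetical test body, names only assumed from the
  # surrounding snippets) of how the helper above is typically consumed: each
  # process builds its local collective and then reduces its per-replica value.
  #
  #   collective, devices, pid = self.make_collective(
  #       num_processes=2, gpu_per_process=0)
  #   options = collective_util.Options(
  #       implementation=collective_util.CommunicationImplementation.RING)
  #   per_replica = make_per_replica_value(1.0, devices)
  #   reduced = collective.reduce(
  #       reduce_util.ReduceOp.SUM, per_replica, per_replica, options)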
    def _get_test_object(self,
                         task_type,
                         task_id,
                         num_gpus=0,
                         num_tpus=0,
                         use_devices_arg=False):
        strategy, target = create_test_objects(cluster_spec=self._cluster_spec,
                                               task_type=task_type,
                                               task_id=task_id,
                                               num_gpus=num_gpus,
                                               num_tpus=num_tpus)

        if use_devices_arg and num_gpus > 0:
            devices = ['GPU:%d' % i for i in range(num_gpus)]
            # Temporary workaround to manually set the `_extended` field until
            # device initialization is exposed as a public interface.
            strategy._extended = CollectiveAllReduceExtended(
                container_strategy=strategy,
                cluster_resolver=None,
                communication_options=collective_util.Options(),
                devices=devices)
            # Manually set the field since the workaround bypasses the base
            # constructor, resulting in the absence of this field.
            strategy._extended._retrace_functions_for_each_device = (num_gpus >
                                                                     1)

        return strategy, target
    def _start_check_health_thread(self):
        # Use a dummy all-reduce as a barrier to wait for all workers to be up,
        # otherwise the check health may fail immediately.

        # Use array_ops.identity to create the dummy tensor so that we have a new
        # Tensor. If we use a constant it may be cached on a /job:localhost device,
        # which will cause code that relies on tensor.device to error.
        #
        # TODO(b/151232436): change to an explicit barrier if we have it.
        dummy_value = array_ops.identity([])
        logging.info("Waiting for the cluster, timeout = %s",
                     self._check_health_initial_timeout or "inf")
        try:
            self._host_cross_device_ops.reduce(
                reduce_util.ReduceOp.SUM,
                dummy_value,
                dummy_value,
                options=collective_util.Options(
                    timeout_seconds=self._check_health_initial_timeout,
                    implementation=collective_util.CommunicationImplementation.
                    RING))
            if context.is_async():
                context.async_wait()
        except errors.DeadlineExceededError:
            raise RuntimeError(
                "Timeout waiting for the cluster, timeout is %d seconds" %
                self._check_health_initial_timeout)
        logging.info("Cluster is ready.")
        self._check_health_thread_should_stop = threading.Event()
        # Start the thread as a daemon to avoid blocking program exit. We try our
        # best to shut down the thread, but __del__ is not guaranteed to be called
        # when the program exits.
        self._check_health_thread = threading.Thread(target=self._check_health,
                                                     daemon=True)
        self._check_health_thread.start()
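
    # A plausible counterpart to the method above, sketched here for context: it
    # stops the health-check thread using the event and thread handle created in
    # `_start_check_health_thread`. This is an illustrative sketch, not
    # necessarily the exact upstream implementation.
    def _stop_check_health_thread(self):
        if getattr(self, "_check_health_thread", None):
            logging.info("Stopping the check health thread.")
            self._check_health_thread_should_stop.set()
            self._check_health_thread.join()
            self._check_health_thread = None
            logging.info("Check health thread stopped.")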
  def __init__(self, container_strategy, devices=None, cross_device_ops=None):
    super(MirroredExtended, self).__init__(container_strategy)
    if context.executing_eagerly():
      if devices and not _is_device_list_single_worker(devices):
        raise RuntimeError("In-graph multi-worker training with "
                           "`MirroredStrategy` is not supported in eager mode.")
      else:
        if TFConfigClusterResolver().cluster_spec().as_dict():
          # If you are executing in eager mode, only the single machine code
          # path is supported.
          logging.info("Initializing local devices since in-graph multi-worker "
                       "training with `MirroredStrategy` is not supported in "
                       "eager mode. TF_CONFIG will be ignored when "
                       "initializing `MirroredStrategy`.")
        devices = devices or all_local_devices()
    else:
      devices = devices or all_devices()

    assert devices, ("Got an empty `devices` list and unable to recognize "
                     "any local devices.")
    self._cross_device_ops = cross_device_ops
    self._communication_options = collective_util.Options()
    self._initialize_strategy(devices)

    # TODO(b/128995245): Enable last partial batch support in graph mode.
    if ops.executing_eagerly_outside_functions():
      self.experimental_enable_get_next_as_optional = True

    # Flag to turn on VariablePolicy.
    self._use_var_policy = False
    def testStrategyInitializationError(self):
        with self.assertRaisesRegex(
                ValueError,
                'cluster_resolver and devices cannot be set at the same time'):
            _ = collective_all_reduce_strategy.CollectiveAllReduceExtended(
                container_strategy=None,
                cluster_resolver=multi_worker_test_base.
                create_in_process_cluster(num_workers=3, num_ps=0),
                communication_options=collective_util.Options(),
                devices=['GPU:0', 'GPU:1'])
    def testAllReduceSparse(self, num_processes, required_gpus, implementation,
                            reduce_op, prefer_unique_instance_key):
        if (required_gpus == 0
                and implementation == CommunicationImplementation.NCCL):
            self.skipTest("Skip CPU + NCCL combination")
        if (num_processes == 2
                and implementation == CommunicationImplementation.NCCL):
            self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                          "physical GPUs for every process.")
        options = self.RunOptions(
            mode=["func_graph"],  # Sparse reduce is not supported in eager.
            num_processes=num_processes,
            gpus_per_process=required_gpus,
            reduce_op=reduce_op,
            communication_options=collective_util.Options(
                implementation=implementation),
            prefer_unique_instance_key=prefer_unique_instance_key)
        group_size = options.num_processes * (options.gpus_per_process or 1)

        inputs_data = [
            IndexedSlicesValue(values=[[1.], [2.]],
                               indices=[0, 1],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[3.], [4.]],
                               indices=[1, 2],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[5.], [6.]],
                               indices=[7, 8],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[7.], [8.]],
                               indices=[3, 2],
                               dense_shape=[10, 1]),
        ]
        inputs = inputs_data[0:group_size]

        if group_size == 1:
            expect = IndexedSlices(values=[[1.], [2.]],
                                   indices=[0, 1],
                                   dense_shape=[10, 1])
        elif group_size == 2:
            expect = IndexedSlices(values=[[1.], [2.], [3.], [4.]],
                                   indices=[0, 1, 1, 2],
                                   dense_shape=[10, 1])
        elif group_size == 4:
            expect = IndexedSlices(values=[[1.], [2.], [3.], [4.], [5.], [6.],
                                           [7.], [8.]],
                                   indices=[0, 1, 1, 2, 7, 8, 3, 2],
                                   dense_shape=[10, 1])

        self.reduce_and_verify(inputs, expect, options)
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
                use_collective_v2)
            collective, devices, _ = self.make_collective(
                num_processes, required_gpus)
            options = collective_util.Options(implementation=implementation)

            # We would like to simulate the following sequence:
            #   thread-0  device0                 device1
            #   thread-1          device0 device1
            # If the kernel launch sequence is as-is the program will deadlock since
            # NCCL requires the launch order to be the same on each device.
            v0 = make_per_replica_value(1.0, devices)
            v1 = make_per_replica_value(2.0, devices)

            # Add a delay to collective_ops.all_reduce according to the input
            # tensor's index in `sequence`.
            sequence = [v0.values[0], v1.values[0], v1.values[1], v0.values[1]]
            all_reduce = collective_ops.all_reduce

            def delayed_all_reduce(input_tensor, *args, **kwargs):
                for idx, v in enumerate(sequence):
                    if input_tensor is v:
                        time.sleep(idx)
                        break
                return all_reduce(input_tensor, *args, **kwargs)

            with test.mock.patch.object(collective_ops, "all_reduce",
                                        delayed_all_reduce):
                # We only use NCCL for batch reduce with two or more values, so we use
                # two values here.

                def thread_fn():
                    reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                                      [(v0, v0),
                                                       (v0, v0)], options)
                    self.assertAllEqual(reduced[0].values, [2.0, 2.0])
                    self.assertAllEqual(reduced[1].values, [2.0, 2.0])

                t = threading.Thread(target=thread_fn)
                t.start()
                reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                                  [(v1, v1),
                                                   (v1, v1)], options)
                self.assertAllEqual(reduced[0].values, [4.0, 4.0])
                self.assertAllEqual(reduced[1].values, [4.0, 4.0])
                t.join()
    def replica_fn():
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(1.0, devices)
      options = collective_util.Options(
          timeout_seconds=1, implementation=implementation)

      @def_function.function
      def reduce_dense():
        collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)

      # The collective should time out because we only launch it on worker-0,
      # while there are three workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        reduce_dense()
    def __init__(self,
                 communication=collective_util.CommunicationImplementation.AUTO,
                 cluster_resolver=None):
        """Creates the strategy.

    Args:
      communication: optional
        `tf.distribute.experimental.CommunicationImplementation`. This is a hint
        on the preferred collective communication implementation. Possible
        values include `AUTO`, `RING`, and `NCCL`.
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
    """
        communication_options = collective_util.Options(
            implementation=communication)
        super(_CollectiveAllReduceStrategyExperimental,
              self).__init__(cluster_resolver, communication_options)
    def __init__(self,
                 communication=collective_util.CommunicationImplementation.AUTO,
                 cluster_resolver=None):
        """Initializes the object."""
        communication_options = collective_util.Options(
            implementation=communication)
        super(CollectiveAllReduceStrategyV1, self).__init__(
            CollectiveAllReduceExtended(
                self,
                cluster_resolver=cluster_resolver,
                communication_options=communication_options))
        distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
            "MultiWorkerMirroredStrategy")
        # pylint: disable=protected-access
        distribute_lib.distribution_strategy_replica_gauge.get_cell(
            "num_workers").set(self.extended._num_workers)
        distribute_lib.distribution_strategy_replica_gauge.get_cell(
            "num_gpu_per_worker").set(self.extended._num_gpus_per_worker)
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = True
            collective, devices, _ = self.make_collective(
                num_processes, required_gpus)
            options = collective_util.Options(implementation=implementation)
            value = make_per_replica_value(constant_op.constant([1.]), devices)

            @def_function.function
            def reduce_fn():
                def cond_body():
                    reduced = collective.reduce(reduce_util.ReduceOp.SUM,
                                                value, value, options)
                    return math_ops.add_n(self.as_list(reduced)) / len(devices)

                return control_flow_ops.cond(array_ops.identity(False),
                                             cond_body, cond_body)

            num_replicas = num_processes * len(devices)
            self.assertAllEqual(reduce_fn(), [1. * num_replicas])
    def testAllReduceSparse(self, num_processes, required_gpus, implementation,
                            reduce_op, use_collective_v2):
        options = self.RunOptions(
            mode=["func_graph"],  # Sparse reduce is not supported in eager.
            num_processes=num_processes,
            gpus_per_process=required_gpus,
            reduce_op=reduce_op,
            communication_options=collective_util.Options(
                implementation=implementation),
            use_collective_v2=use_collective_v2)
        group_size = options.num_processes * (options.gpus_per_process or 1)

        inputs_data = [
            IndexedSlicesValue(values=[[1.], [2.]],
                               indices=[0, 1],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[3.], [4.]],
                               indices=[1, 2],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[5.], [6.]],
                               indices=[7, 8],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[7.], [8.]],
                               indices=[3, 2],
                               dense_shape=[10, 1]),
        ]
        inputs = inputs_data[0:group_size]

        if group_size == 1:
            expect = IndexedSlices(values=[[1.], [2.]],
                                   indices=[0, 1],
                                   dense_shape=[10, 1])
        elif group_size == 2:
            expect = IndexedSlices(values=[[1.], [2.], [3.], [4.]],
                                   indices=[0, 1, 1, 2],
                                   dense_shape=[10, 1])
        elif group_size == 4:
            expect = IndexedSlices(values=[[1.], [2.], [3.], [4.], [5.], [6.],
                                           [7.], [8.]],
                                   indices=[0, 1, 1, 2, 7, 8, 3, 2],
                                   dense_shape=[10, 1])

        self.reduce_and_verify(inputs, expect, options)
        def replica_fn():
            collective, devices, _ = self.make_collective(
                num_processes, required_gpus)
            options = collective_util.Options(implementation=implementation)
            group_size = num_processes * (required_gpus or 1)

            @def_function.function
            def collective_all_reduce():
                results = []
                for replica_id, device in enumerate(devices):
                    with ops.device(device):
                        value = constant_op.constant(1.0)
                        results.append(
                            collective._all_reduce(reduce_op, value,
                                                   replica_id, options))
                return results

            got = collective_all_reduce()
            if reduce_op == ReduceOp.SUM:
                expect = [1.0 * group_size] * len(devices)
            elif reduce_op == ReduceOp.MEAN:
                expect = [1.0] * len(devices)
            self.assertAllClose(got, expect)

            @def_function.function
            def collective_batch_all_reduce():
                results = []
                for replica_id, device in enumerate(devices):
                    with ops.device(device):
                        value = (constant_op.constant(1.0),
                                 constant_op.constant(2.0))
                        results.append(
                            collective._all_reduce(reduce_op, value,
                                                   replica_id, options))
                return results

            got = collective_batch_all_reduce()
            if reduce_op == ReduceOp.SUM:
                expect = [(1.0 * group_size, 2.0 * group_size)] * len(devices)
            elif reduce_op == ReduceOp.MEAN:
                expect = [(1.0, 2.0)] * len(devices)
            self.assertAllClose(got, expect)
  def testAllReduceDense(self, num_processes, required_gpus, implementation,
                         reduce_op):
    options = self.RunOptions(
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation))
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = [1.0, 2.0, 3.0, 4.0]
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = 1.0
    elif group_size == 2:
      expect = 3.0 if reduce_op == ReduceOp.SUM else 1.5
    elif group_size == 4:
      expect = 10.0 if reduce_op == ReduceOp.SUM else 2.5

    self.reduce_and_verify(inputs, expect, options)
    def replica_fn():
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(
          IndexedSlicesValue(
              values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
      options = collective_util.Options(
          timeout_seconds=1, implementation=implementation)

      @def_function.function
      def batch_reduce_sparse():
        collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)],
                                options)

      # The collective should time out because we only launch it on worker-0,
      # while there are two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        batch_reduce_sparse()
        def replica_fn():
            CollectiveReplicaLauncher._prefer_unique_instance_key = (
                prefer_unique_instance_key)
            collective, devices, task_id = self.make_collective(
                num_processes, required_gpus)
            if task_id != 0:
                return

            v = make_per_replica_value(1.0, devices)
            options = collective_util.Options(timeout_seconds=1,
                                              implementation=implementation)

            @def_function.function
            def batch_reduce_dense():
                return collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                               [(v, v), (v, v)], options)

            # The collective should time out because we only launch it on worker-0,
            # while there are two workers in total.
            with self.assertRaises(errors.DeadlineExceededError):
                batch_reduce_dense()
    def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
        """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
        # This implementation avoids using `merge_call` and just launches collective
        # ops in one replica.
        if options is None:
            options = collective_util.Options()

        if context.executing_eagerly():
            # In eager mode, falls back to the default implementation that uses
            # `merge_call`. Replica functions are running sequentially in eager mode,
            # and due to the blocking nature of collective ops, execution will hang if
            # collective ops are to be launched sequentially.
            return super()._replica_ctx_all_reduce(reduce_op, value, options)

        replica_context = ds_context.get_replica_context()
        assert replica_context, (
            "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
            "replica context")
        return self._cross_device_ops._all_reduce(  # pylint: disable=protected-access
            reduce_op,
            value,
            replica_context._replica_id,  # pylint: disable=protected-access
            options)
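
# A usage sketch of the replica-context all-reduce path implemented above, via
# the public API (`tf.distribute.ReplicaContext.all_reduce` inside
# `Strategy.run`); `MirroredStrategy` is only assumed here to keep the example
# self-contained and runnable on a single machine.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def train_step():
    def replica_fn():
        ctx = tf.distribute.get_replica_context()
        # Each replica contributes 1.0; the all-reduce returns the group sum.
        return ctx.all_reduce(tf.distribute.ReduceOp.SUM, tf.constant(1.0))
    return strategy.run(replica_fn)

print(train_step())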
        def replica_fn():
            cross_device_ops_lib.CollectiveAllReduce._limited_nccl = False
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = True
            cross_device_utils.CollectiveReplicaLauncher._use_ordering_token = True
            collective, devices, _ = self.make_collective(
                num_processes, required_gpus)
            options = collective_util.Options(
                implementation=CommunicationImplementation.NCCL)

            v_dense = make_per_replica_value([1.0, 1.0], devices)
            v_sparse = make_per_replica_value([
                IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
                IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
            ], devices)

            @def_function.function
            def nested_dense():
                collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense,
                                  options)

            @def_function.function
            def nested_sparse():
                collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                                  options)

            # All collectives, function calls, if clauses and while loops should be
            # chained by control dependencies, so that the execution order is
            # deterministic.
            @def_function.function
            def f():
                # pylint: disable=pointless-statement
                collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                                  options)
                # reducing dense value.
                collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense,
                                  options)
                # reducing sparse value.
                collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                                  options)
                # reduce dense value in nested tf.function.
                nested_dense()
                # reduce sparse value in nested tf.function.
                nested_sparse()
                # reduce dense value in tf.cond.
                if array_ops.identity(1.0) > array_ops.identity(2.0):
                    collective.reduce(reduce_util.ReduceOp.SUM, v_dense,
                                      v_dense, options)
                else:
                    v_dense
                # reduce sparse value in tf.cond.
                if array_ops.identity(1.0) > array_ops.identity(2.0):
                    v_sparse
                else:
                    collective.reduce(reduce_util.ReduceOp.SUM, v_sparse,
                                      v_sparse, options)
                # reduce dense value in tf.while_loop.
                i = array_ops.identity(1)
                while i < 3:
                    collective.reduce(reduce_util.ReduceOp.SUM, v_dense,
                                      v_dense, options)
                    i += 1
                # reduce sparse value in tf.while_loop.
                i = array_ops.identity(1)
                while i < 3:
                    collective.reduce(reduce_util.ReduceOp.SUM, v_sparse,
                                      v_sparse, options)
                    i += 1
                # reducing dense and sparse value again.
                collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense,
                                  options)
                collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                                  options)
                # pylint: enable=pointless-statement

            graph = f.get_concrete_function().graph
            should_be_ordered = set([
                "CollectiveReduceV2", "CollectiveGatherV2", "If", "While",
                "StatefulPartitionedCall"
            ])
            nodes_by_device = {}
            for op in graph.get_operations():
                if op.type in should_be_ordered:
                    if op.device not in nodes_by_device:
                        nodes_by_device[op.device] = []
                    nodes_by_device[op.device].append(op)
            order = test_util.topological_sort_operations(
                graph.get_operations())
            for device in devices:
                device = device_util.canonicalize(device)
                # Those function ops don't have device annotations, but they contain
                # collectives for both devices so we always include them.
                operations = nodes_by_device[device] + nodes_by_device[""]
                # Verify that we get all types of nodes we want.
                self.assertEqual(set(op.type for op in operations),
                                 should_be_ordered)
                test_util.assert_sequential_execution(order, operations)
    def testBatchAllReduceSparse(self, num_processes, required_gpus,
                                 implementation, reduce_op,
                                 use_scoped_allocator, use_collective_v2):
        if (required_gpus == 0
                and implementation == CommunicationImplementation.NCCL):
            self.skipTest("Skip CPU + NCCL combination")
        if (num_processes == 2
                and implementation == CommunicationImplementation.NCCL):
            self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                          "physical GPUs for every process.")

        options = self.RunOptions(
            mode=["func_graph"],  # Sparse reduce is not supported in eager.
            num_processes=num_processes,
            gpus_per_process=required_gpus,
            reduce_op=reduce_op,
            communication_options=collective_util.Options(
                implementation=implementation),
            use_scoped_allocator=use_scoped_allocator,
            use_collective_v2=use_collective_v2)
        group_size = options.num_processes * (options.gpus_per_process or 1)

        inputs_data = ([
            IndexedSlicesValue(values=[[1.], [2.]],
                               indices=[0, 1],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[3.], [4.]],
                               indices=[1, 2],
                               dense_shape=[5, 1])
        ], [
            IndexedSlicesValue(values=[[5.], [6.]],
                               indices=[1, 2],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[7.], [8.]],
                               indices=[0, 1],
                               dense_shape=[5, 1])
        ], [
            IndexedSlicesValue(values=[[9.], [10.]],
                               indices=[3, 4],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[11.], [12.]],
                               indices=[3, 4],
                               dense_shape=[5, 1])
        ], [
            IndexedSlicesValue(values=[[13.], [14.]],
                               indices=[8, 9],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[15.], [16.]],
                               indices=[3, 4],
                               dense_shape=[5, 1])
        ])
        inputs = inputs_data[0:group_size]

        if group_size == 1:
            expect = [
                IndexedSlices(values=[[1.], [2.]],
                              indices=[0, 1],
                              dense_shape=[10, 1]),
                IndexedSlicesValue(values=[[3.], [4.]],
                                   indices=[1, 2],
                                   dense_shape=[5, 1])
            ]
        elif group_size == 2:
            expect = [
                IndexedSlices(values=[[1.], [2.], [5.], [6.]],
                              indices=[0, 1, 1, 2],
                              dense_shape=[10, 1]),
                IndexedSlices(values=[[3.], [4.], [7.], [8.]],
                              indices=[1, 2, 3, 4],
                              dense_shape=[5, 1])
            ]
        elif group_size == 4:
            expect = [
                IndexedSlices(values=[[1.], [2.], [5.], [6.], [9.], [10.],
                                      [13.], [14.]],
                              indices=[0, 1, 1, 2, 3, 4, 8, 9],
                              dense_shape=[10, 1]),
                IndexedSlices(values=[[3.], [4.], [7.], [8.], [11.], [12.],
                                      [15.], [16.]],
                              indices=[1, 2, 0, 1, 3, 4, 3, 4],
                              dense_shape=[5, 1])
            ]
        self.batch_reduce_and_verify(inputs, expect, options)
class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
    def setUp(self):
        super().setUp()
        # Enabling collectives can be done in "setUpClass", but requires using
        # different collective_keys in different tests as collectives are reused
        # across tests. Always resetting collective ops before each test offers
        # better test isolation.
        global_mpr_1p.runner.run(enable_collective_ops)
        global_mpr_2p.runner.run(enable_collective_ops)

    def make_collective(self, num_processes, gpu_per_process):
        """Returns collectives and other info to be used in tests.

    Args:
      num_processes: an integer indicating the number of processes that
        participate in the collective.
      gpu_per_process: number of GPUs (0 if no GPUs) used by each process.

    Returns:
     A tuple of (collective, devices, pid) where collective is an instance
     of `CollectiveAllReduce`, devices are a list of local devices (str)
     attached to the current process, and pid is the id of this process among
     all participating processes.
    """

        cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
        devices = [
            "/job:worker/replica:0/task:%d/device:CPU:0" %
            cluster_resolver.task_id
        ]
        if gpu_per_process > 0:
            devices = [
                "/job:worker/replica:0/task:%d/device:GPU:%d" %
                (cluster_resolver.task_id, i) for i in range(gpu_per_process)
            ]
        group_size = num_processes * len(devices)
        collective = cross_device_ops_lib.CollectiveAllReduce(
            devices=devices, group_size=group_size)
        return collective, devices, cluster_resolver.task_id

    def as_list(self, value):
        """An utility to convert a `Mirrored`, `Tensor` or `IndexedSlices` to a list.

    The reason it exists is to provide a uniformed view of returned value of
    "reduce" calls, especially across tf.function boundaries. Returning
    `Mirrored` from a tf.function will only evaluate the primary value, which
    makes collective ops of non-primary device being pruned, and will eventually
    cause hanging.

    Args:
      value: the value to convert, can be one of `Mirrored`, `Tensor` and
        `IndexedSlices`.

    Returns:
      A list of `Tensor` or `IndexedSlices`.
    """
        if isinstance(value, ops.Tensor):
            return [value]
        elif isinstance(value, IndexedSlices):
            return [value]
        elif isinstance(value, value_lib.Mirrored):
            return value.values
        else:
            raise ValueError("unwrap: unsupported input type: %s" %
                             type(value))

    RunOptions = collections.namedtuple(  # pylint: disable=invalid-name
        "RunOptions",
        [
            "mode",  # A list of str from ["eager", "func_graph"]
            "num_processes",
            "gpus_per_process",
            "reduce_op",
            "communication_options",
            "use_scoped_allocator",
            "use_collective_v2",
        ])
    RunOptions.__new__.__defaults__ = (["eager",
                                        "func_graph"], 2, 0, ReduceOp.SUM,
                                       collective_util.Options(), True, False)
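
    # An illustrative sketch (values are made up) of how `RunOptions` is used by
    # the tests below: override only the fields a test cares about and inherit
    # the defaults declared above for the rest.
    #
    #   options = self.RunOptions(
    #       mode=["func_graph"],
    #       num_processes=2,
    #       gpus_per_process=1,
    #       reduce_op=ReduceOp.MEAN,
    #       communication_options=collective_util.Options(
    #           implementation=CommunicationImplementation.RING))
    #   group_size = options.num_processes * (options.gpus_per_process or 1)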

    def reduce_and_verify(self, inputs, expect, options):
        """Reduce the given `inputs` and verify the output matches `expect`.

    Args:
      inputs: a list of `Tensor` or `IndexedSlices`, where i-th value will be
        fed to i-th replica.
      expect: a `Tensor` or `IndexedSlices`. This should be the expected value
        for one replica.
      options: a `RunOptions` instance.
    """
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
                options.use_collective_v2)
            collective, devices, pid = self.make_collective(
                options.num_processes, options.gpus_per_process)

            def reduce_fn():
                value_fn = lambda device_idx: inputs[pid * len(devices) +
                                                     device_idx]
                per_replica_value = make_per_replica_value(value_fn, devices)
                reduced_values = collective.reduce(
                    options.reduce_op, per_replica_value, per_replica_value,
                    options.communication_options)
                reduced_values = self.as_list(reduced_values)
                self.assertAllEqual(devices,
                                    [v.device for v in reduced_values])
                return [ops.convert_to_tensor(v) for v in reduced_values]

            per_replica_expect = [ops.convert_to_tensor(expect)] * len(devices)

            if "eager" in options.mode:
                got = reduce_fn()
                self.assertAllClose(got, per_replica_expect)

            if "func_graph" in options.mode:
                got = def_function.function(reduce_fn)()
                self.assertAllClose(got, per_replica_expect)

        get_global_mpr(options.num_processes).run(replica_fn)

    def batch_reduce_and_verify(self, inputs, expect, options):
        """Batch reduce the given `inputs` and verify the output matches `expect`.

    Args:
      inputs: a 2-level nested list of `Tensor` or `IndexedSlices`, where i-th
        value will be fed to i-th replica.
      expect: a list of `Tensor` or `IndexedSlices`. This should be the expected
        value for one replica.
      options: a `RunOptions` instance.
    """
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_scoped_allocator = (
                options.use_scoped_allocator)
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
                options.use_collective_v2)
            collective, devices, pid = self.make_collective(
                options.num_processes, options.gpus_per_process)

            def batch_reduce_fn():
                batch_size = len(inputs[0])
                value_dst_pairs = []
                for i in range(batch_size):

                    def value_fn(device_idx, idx=i):
                        return inputs[pid * len(devices) + device_idx][idx]

                    per_replica_value = make_per_replica_value(
                        value_fn, devices)
                    value_dst_pairs.append(
                        (per_replica_value, per_replica_value))
                reduced_values = collective.batch_reduce(
                    options.reduce_op, value_dst_pairs,
                    options.communication_options)
                reduced_values = [self.as_list(v) for v in reduced_values]
                for v in reduced_values:
                    self.assertAllEqual(devices, [t.device for t in v])
                return nest.map_structure(ops.convert_to_tensor,
                                          reduced_values)

            per_replica_expect = nest.map_structure(
                lambda x: [ops.convert_to_tensor(x)] * len(devices), expect)

            if "eager" in options.mode:
                got = batch_reduce_fn()
                self.assertAllClose(got, per_replica_expect)

            if "func_graph" in options.mode:
                got = def_function.function(batch_reduce_fn)()
                self.assertAllClose(got, per_replica_expect)

        get_global_mpr(options.num_processes).run(replica_fn)

    @combinations.generate(
        combinations.combine(
            num_processes=[1, 2],
            required_gpus=[0, 1, 2],
            implementation=[
                # NCCL is only used for batch reduce, so we are not including
                # NCCL combination here.
                CommunicationImplementation.AUTO,
                CommunicationImplementation.RING
            ],
            reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
            use_collective_v2=[True, False]))
    def testAllReduceDense(self, num_processes, required_gpus, implementation,
                           reduce_op, use_collective_v2):
        options = self.RunOptions(
            num_processes=num_processes,
            gpus_per_process=required_gpus,
            reduce_op=reduce_op,
            communication_options=collective_util.Options(
                implementation=implementation),
            use_collective_v2=use_collective_v2)
        group_size = options.num_processes * (options.gpus_per_process or 1)

        inputs_data = [1.0, 2.0, 3.0, 4.0]
        inputs = inputs_data[0:group_size]

        if group_size == 1:
            expect = 1.0
        elif group_size == 2:
            expect = 3.0 if reduce_op == ReduceOp.SUM else 1.5
        elif group_size == 4:
            expect = 10.0 if reduce_op == ReduceOp.SUM else 2.5

        self.reduce_and_verify(inputs, expect, options)

    @combinations.generate(
        combinations.combine(
            num_processes=[1, 2],
            required_gpus=[0, 1, 2],
            implementation=[
                # NCCL is only used for batch reduce, so we are not including
                # NCCL combination here.
                CommunicationImplementation.AUTO,
                CommunicationImplementation.RING
            ],
            # TODO(b/166682130): add MEAN reduce once the bug is fixed.
            reduce_op=ReduceOp.SUM,
            use_collective_v2=[True, False]))
    def testAllReduceSparse(self, num_processes, required_gpus, implementation,
                            reduce_op, use_collective_v2):
        options = self.RunOptions(
            mode=["func_graph"],  # Sparse reduce is not supported in eager.
            num_processes=num_processes,
            gpus_per_process=required_gpus,
            reduce_op=reduce_op,
            communication_options=collective_util.Options(
                implementation=implementation),
            use_collective_v2=use_collective_v2)
        group_size = options.num_processes * (options.gpus_per_process or 1)

        inputs_data = [
            IndexedSlicesValue(values=[[1.], [2.]],
                               indices=[0, 1],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[3.], [4.]],
                               indices=[1, 2],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[5.], [6.]],
                               indices=[7, 8],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[7.], [8.]],
                               indices=[3, 2],
                               dense_shape=[10, 1]),
        ]
        inputs = inputs_data[0:group_size]

        if group_size == 1:
            expect = IndexedSlices(values=[[1.], [2.]],
                                   indices=[0, 1],
                                   dense_shape=[10, 1])
        elif group_size == 2:
            expect = IndexedSlices(values=[[1.], [2.], [3.], [4.]],
                                   indices=[0, 1, 1, 2],
                                   dense_shape=[10, 1])
        elif group_size == 4:
            expect = IndexedSlices(values=[[1.], [2.], [3.], [4.], [5.], [6.],
                                           [7.], [8.]],
                                   indices=[0, 1, 1, 2, 7, 8, 3, 2],
                                   dense_shape=[10, 1])

        self.reduce_and_verify(inputs, expect, options)

    @combinations.generate(
        combinations.combine(use_collective_v2=[True, False]))
    def testAllReduceSparseVariableLength(self, use_collective_v2):
        # One device per process, 2 processes, 2 replicas in total.
        inputs = [
            IndexedSlicesValue(values=[[1.]], indices=[0],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[2.], [3.], [4.]],
                               indices=[0, 1, 2],
                               dense_shape=[10, 1]),
        ]
        expect = IndexedSlices(values=[[1.], [2.], [3.], [4.]],
                               indices=[0, 0, 1, 2],
                               dense_shape=[10, 1])
        self.reduce_and_verify(
            inputs,
            expect,
            self.RunOptions(
                mode=["func_graph"
                      ],  # Sparse reduce is not supported in eager.
                num_processes=2,
                reduce_op=ReduceOp.SUM,
                use_collective_v2=use_collective_v2))

    @combinations.generate(
        combinations.combine(num_processes=[1, 2],
                             required_gpus=[0, 1, 2],
                             implementation=[
                                 CommunicationImplementation.AUTO,
                                 CommunicationImplementation.RING,
                                 CommunicationImplementation.NCCL
                             ],
                             reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
                             use_scoped_allocator=[True, False],
                             use_collective_v2=[True, False]))
    def testBatchAllReduceDense(self, num_processes, required_gpus,
                                implementation, reduce_op,
                                use_scoped_allocator, use_collective_v2):
        if (required_gpus == 0
                and implementation == CommunicationImplementation.NCCL):
            self.skipTest("Skip CPU + NCCL combination")
        if (num_processes == 2
                and implementation == CommunicationImplementation.NCCL):
            self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                          "physical GPUs for every process.")

        options = self.RunOptions(
            num_processes=num_processes,
            gpus_per_process=required_gpus,
            reduce_op=reduce_op,
            communication_options=collective_util.Options(
                implementation=implementation),
            use_scoped_allocator=use_scoped_allocator,
            use_collective_v2=use_collective_v2)
        group_size = options.num_processes * (options.gpus_per_process or 1)

        inputs_data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
        inputs = inputs_data[0:group_size]

        if group_size == 1:
            expect = [1.0, 2.0]
        elif group_size == 2:
            expect = [4.0, 6.0] if reduce_op == ReduceOp.SUM else [2.0, 3.0]
        elif group_size == 4:
            expect = [16.0, 20.0] if reduce_op == ReduceOp.SUM else [4.0, 5.0]

        self.batch_reduce_and_verify(inputs, expect, options)

    @combinations.generate(
        combinations.combine(
            num_processes=[1, 2],
            required_gpus=[0, 1, 2],
            implementation=[
                CommunicationImplementation.AUTO,
                CommunicationImplementation.RING,
                CommunicationImplementation.NCCL,
            ],
            # TODO(b/166682130): add MEAN reduce once the bug is fixed.
            reduce_op=ReduceOp.SUM,
            use_scoped_allocator=[True, False],
            use_collective_v2=[True, False]))
    def testBatchAllReduceSparse(self, num_processes, required_gpus,
                                 implementation, reduce_op,
                                 use_scoped_allocator, use_collective_v2):
        if (required_gpus == 0
                and implementation == CommunicationImplementation.NCCL):
            self.skipTest("Skip CPU + NCCL combination")
        if (num_processes == 2
                and implementation == CommunicationImplementation.NCCL):
            self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                          "physical GPUs for every process.")

        options = self.RunOptions(
            mode=["func_graph"],  # Sparse reduce is not supported in eager.
            num_processes=num_processes,
            gpus_per_process=required_gpus,
            reduce_op=reduce_op,
            communication_options=collective_util.Options(
                implementation=implementation),
            use_scoped_allocator=use_scoped_allocator,
            use_collective_v2=use_collective_v2)
        group_size = options.num_processes * (options.gpus_per_process or 1)

        inputs_data = ([
            IndexedSlicesValue(values=[[1.], [2.]],
                               indices=[0, 1],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[3.], [4.]],
                               indices=[1, 2],
                               dense_shape=[5, 1])
        ], [
            IndexedSlicesValue(values=[[5.], [6.]],
                               indices=[1, 2],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[7.], [8.]],
                               indices=[0, 1],
                               dense_shape=[5, 1])
        ], [
            IndexedSlicesValue(values=[[9.], [10.]],
                               indices=[3, 4],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[11.], [12.]],
                               indices=[3, 4],
                               dense_shape=[5, 1])
        ], [
            IndexedSlicesValue(values=[[13.], [14.]],
                               indices=[8, 9],
                               dense_shape=[10, 1]),
            IndexedSlicesValue(values=[[15.], [16.]],
                               indices=[3, 4],
                               dense_shape=[5, 1])
        ])
        inputs = inputs_data[0:group_size]

        if group_size == 1:
            expect = [
                IndexedSlices(values=[[1.], [2.]],
                              indices=[0, 1],
                              dense_shape=[10, 1]),
                IndexedSlices(values=[[3.], [4.]],
                              indices=[1, 2],
                              dense_shape=[5, 1])
            ]
        elif group_size == 2:
            expect = [
                IndexedSlices(values=[[1.], [2.], [5.], [6.]],
                              indices=[0, 1, 1, 2],
                              dense_shape=[10, 1]),
                IndexedSlices(values=[[3.], [4.], [7.], [8.]],
                              indices=[1, 2, 0, 1],
                              dense_shape=[5, 1])
            ]
        elif group_size == 4:
            expect = [
                IndexedSlices(values=[[1.], [2.], [5.], [6.], [9.], [10.],
                                      [13.], [14.]],
                              indices=[0, 1, 1, 2, 3, 4, 8, 9],
                              dense_shape=[10, 1]),
                IndexedSlices(values=[[3.], [4.], [7.], [8.], [11.], [12.],
                                      [15.], [16.]],
                              indices=[1, 2, 0, 1, 3, 4, 3, 4],
                              dense_shape=[5, 1])
            ]

        self.batch_reduce_and_verify(inputs, expect, options)

    @combinations.generate(
        combinations.combine(num_processes=[1, 2],
                             required_gpus=[0, 1, 2],
                             axis=[0, 1, 2],
                             func_mode=["eager", "func_graph"],
                             implementation=[
                                 CommunicationImplementation.NCCL,
                                 CommunicationImplementation.AUTO,
                                 CommunicationImplementation.RING
                             ],
                             use_collective_v2=[True, False]))
    def testAllGatherSameShape(self, num_processes, required_gpus,
                               implementation, func_mode, axis,
                               use_collective_v2):
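        """Verifies all-gather of same-shaped tensors along the given axis."""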
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
                use_collective_v2)
            collective, devices, _ = self.make_collective(
                num_processes, required_gpus)
            options = collective_util.Options(implementation=implementation)
            value = constant_op.constant([[[1, 2], [1, 2]]],
                                         dtype=dtypes.float32)

            def gather_fn():
                per_replica_value = make_per_replica_value(value, devices)
                gathered_values = collective._gather(per_replica_value,
                                                     per_replica_value,
                                                     axis=axis,
                                                     options=options)
                gathered_values = self.as_list(gathered_values)
                # Skip checking devices in eager. In eager the device attribute doesn't
                # reflect the actual device of the tensor.
                if not context.executing_eagerly():
                    self.assertAllEqual(devices,
                                        [v.device for v in gathered_values])
                return [ops.convert_to_tensor(v) for v in gathered_values]

            group_size = num_processes * (required_gpus or 1)
            expect = array_ops.concat([value] * group_size, axis=axis)
            per_replica_expect = [ops.convert_to_tensor(expect)] * len(devices)

            if func_mode == "eager":
                result = gather_fn()
                self.assertAllClose(result, per_replica_expect)

            if func_mode == "func_graph":
                result = def_function.function(gather_fn)()
                self.assertAllClose(result, per_replica_expect)

        get_global_mpr(num_processes).run(replica_fn)

    @combinations.generate(
        combinations.combine(num_processes=[1, 2],
                             required_gpus=[0, 1, 2],
                             implementation=[CommunicationImplementation.RING])
    )
    def testCollectiveV2ControlFlow(self, num_processes, required_gpus,
                                    implementation):
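        """Verifies that collective V2 reduce works inside tf.cond."""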
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = True
            collective, devices, _ = self.make_collective(
                num_processes, required_gpus)
            options = collective_util.Options(implementation=implementation)
            value = make_per_replica_value(constant_op.constant([1.]), devices)

            @def_function.function
            def reduce_fn():
                def cond_body():
                    reduced = collective.reduce(reduce_util.ReduceOp.SUM,
                                                value, value, options)
                    return math_ops.add_n(self.as_list(reduced)) / len(devices)

                return control_flow_ops.cond(array_ops.identity(False),
                                             cond_body, cond_body)

            num_replicas = num_processes * len(devices)
            self.assertAllEqual(reduce_fn(), [1. * num_replicas])

        get_global_mpr(num_processes).run(replica_fn)

    @combinations.generate(
        combinations.combine(num_processes=1,
                             required_gpus=2,
                             implementation=[
                                 CommunicationImplementation.NCCL,
                                 CommunicationImplementation.RING
                             ],
                             use_collective_v2=[True, False]))
    def testMultiThreadedCollectiveLaunchNoInterleave(self, num_processes,
                                                      required_gpus,
                                                      implementation,
                                                      use_collective_v2):
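        """Verifies concurrent batch_reduce calls don't interleave launches."""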
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
                use_collective_v2)
            collective, devices, _ = self.make_collective(
                num_processes, required_gpus)
            options = collective_util.Options(implementation=implementation)

            # We would like to simulate the following sequence:
            #   thread-0  device0                 device1
            #   thread-1          device0 device1
            # If the kernel launch sequence is as-is, the program will deadlock
            # since NCCL requires the launch order to be the same on each device.
            v0 = make_per_replica_value(1.0, devices)
            v1 = make_per_replica_value(2.0, devices)

            # Add a delay to collective_ops.all_reduce according to the input
            # tensor's index in `sequence`.
            sequence = [v0.values[0], v1.values[0], v1.values[1], v0.values[1]]
            all_reduce = collective_ops.all_reduce

            def delayed_all_reduce(input_tensor, *args, **kwargs):
                for idx, v in enumerate(sequence):
                    if input_tensor is v:
                        time.sleep(idx)
                        break
                return all_reduce(input_tensor, *args, **kwargs)

            with test.mock.patch.object(collective_ops, "all_reduce",
                                        delayed_all_reduce):
                # We only use NCCL for batch reduce with two or more values, so we use
                # two values here.

                def thread_fn():
                    reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                                      [(v0, v0),
                                                       (v0, v0)], options)
                    self.assertAllEqual(reduced[0].values, [2.0, 2.0])
                    self.assertAllEqual(reduced[1].values, [2.0, 2.0])

                t = threading.Thread(target=thread_fn)
                t.start()
                reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                                  [(v1, v1),
                                                   (v1, v1)], options)
                self.assertAllEqual(reduced[0].values, [4.0, 4.0])
                self.assertAllEqual(reduced[1].values, [4.0, 4.0])
                t.join()

        get_global_mpr(num_processes).run(replica_fn)

    @combinations.generate(
        combinations.combine(num_processes=1,
                             required_gpus=2,
                             implementation=[
                                 CommunicationImplementation.NCCL,
                                 CommunicationImplementation.RING
                             ],
                             use_collective_v2=[True, False]))
    def testInputsAreFunctionArgs(self, num_processes, required_gpus,
                                  implementation, use_collective_v2):
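        """Verifies reducing per-replica values passed as tf.function args."""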
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
                use_collective_v2)
            collective, devices, _ = self.make_collective(
                num_processes, required_gpus)
            options = collective_util.Options(implementation=implementation)

            @def_function.function
            def reduce_fn(v):
                # Function inputs don't have device placement.
                self.assertEqual(v.values[0].device, "")
                self.assertEqual(v.values[1].device, "")
                # We only use NCCL for batch reduce with two or more values, so we use
                # two values here.
                reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                                  [(v, v), (v, v)], options)
                self.assertEqual(reduced[0].values[0].device, devices[0])
                self.assertEqual(reduced[0].values[1].device, devices[1])
                self.assertEqual(reduced[1].values[0].device, devices[0])
                self.assertEqual(reduced[1].values[1].device, devices[1])
                # Returning Mirrored only evaluates the primary value, which
                # causes hanging.
                return [reduced[0].values, reduced[1].values]

            v = make_per_replica_value(1.0, devices)
            reduced = reduce_fn(v)
            self.assertAllClose(reduced, [[2.0, 2.0], [2.0, 2.0]])

        get_global_mpr(num_processes).run(replica_fn)

    @combinations.generate(
        combinations.combine(num_processes=2,
                             required_gpus=[0, 1],
                             implementation=[CommunicationImplementation.RING],
                             use_collective_v2=[True, False]))
    def testTimeoutReduceDense(self, num_processes, implementation,
                               required_gpus, use_collective_v2):
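        """Verifies that a dense reduce times out when peers don't join."""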
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
                use_collective_v2)
            collective, devices, task_id = self.make_collective(
                num_processes, required_gpus)
            if task_id != 0:
                return

            v = make_per_replica_value(1.0, devices)
            options = collective_util.Options(timeout_seconds=1,
                                              implementation=implementation)

            @def_function.function
            def reduce_dense():
                return collective.reduce(reduce_util.ReduceOp.SUM, v, v,
                                         options)

            # The collective should time out because we only launch it on worker-0,
            # while there are two workers in total.
            with self.assertRaises(errors.DeadlineExceededError):
                reduce_dense()

        get_global_mpr(num_processes).run(replica_fn)

    @combinations.generate(
        combinations.combine(num_processes=2,
                             required_gpus=[0, 1],
                             implementation=[CommunicationImplementation.RING],
                             use_collective_v2=[True, False]))
    def testTimeoutBatchReduceDense(self, num_processes, implementation,
                                    required_gpus, use_collective_v2):
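        """Verifies that a dense batch_reduce times out when peers don't join."""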
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
                use_collective_v2)
            collective, devices, task_id = self.make_collective(
                num_processes, required_gpus)
            if task_id != 0:
                return

            v = make_per_replica_value(1.0, devices)
            options = collective_util.Options(timeout_seconds=1,
                                              implementation=implementation)

            @def_function.function
            def batch_reduce_dense():
                return collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                               [(v, v), (v, v)], options)

            # The collective should time out because we only launch it on worker-0,
            # while there are two workers in total.
            with self.assertRaises(errors.DeadlineExceededError):
                batch_reduce_dense()

        get_global_mpr(num_processes).run(replica_fn)

    @combinations.generate(
        combinations.combine(num_processes=2,
                             required_gpus=[0, 1],
                             implementation=[CommunicationImplementation.RING],
                             use_collective_v2=[True, False]))
    def testTimeoutReduceSparse(self, num_processes, implementation,
                                required_gpus, use_collective_v2):
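        """Verifies that a sparse reduce times out when peers don't join."""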
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
                use_collective_v2)
            collective, devices, task_id = self.make_collective(
                num_processes, required_gpus)
            if task_id != 0:
                return

            v = make_per_replica_value(
                IndexedSlicesValue(values=[[4., 6.]],
                                   indices=[1],
                                   dense_shape=[5, 2]), devices)
            options = collective_util.Options(timeout_seconds=1,
                                              implementation=implementation)

            @def_function.function
            def reduce_sparse():
                return collective.reduce(reduce_util.ReduceOp.SUM, v, v,
                                         options)

            # The collective should time out because we only launch it on worker-0,
            # while there are two workers in total.
            with self.assertRaises(errors.DeadlineExceededError):
                reduce_sparse()

        get_global_mpr(num_processes).run(replica_fn)

    @combinations.generate(
        combinations.combine(num_processes=2,
                             required_gpus=[0, 1],
                             implementation=[CommunicationImplementation.RING],
                             use_collective_v2=[True, False]))
    def testTimeoutBatchReduceSparse(self, num_processes, required_gpus,
                                     implementation, use_collective_v2):
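        """Verifies that a sparse batch_reduce times out when peers don't join."""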
        def replica_fn():
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
                use_collective_v2)
            collective, devices, task_id = self.make_collective(
                num_processes, required_gpus)
            if task_id != 0:
                return

            v = make_per_replica_value(
                IndexedSlicesValue(values=[[4., 6.]],
                                   indices=[1],
                                   dense_shape=[5, 2]), devices)
            options = collective_util.Options(timeout_seconds=1,
                                              implementation=implementation)

            @def_function.function
            def batch_reduce_sparse():
                return collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                               [(v, v), (v, v)], options)

            # The collective should time out because we only launch it on worker-0,
            # while there are two workers in total.
            with self.assertRaises(errors.DeadlineExceededError):
                batch_reduce_sparse()

        get_global_mpr(num_processes).run(replica_fn)

    @combinations.generate(
        combinations.combine(num_processes=1, required_gpus=2))
    def testNcclOrdering(self, num_processes, required_gpus):
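        """Verifies that NCCL collectives execute in a deterministic order.

        With ordering tokens enabled, collectives, function calls, conds and
        while loops on each device should be chained by control dependencies.
        """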
        def replica_fn():
            cross_device_ops_lib.CollectiveAllReduce._limited_nccl = False
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = True
            cross_device_utils.CollectiveReplicaLauncher._use_ordering_token = True
            collective, devices, _ = self.make_collective(
                num_processes, required_gpus)
            options = collective_util.Options(
                implementation=CommunicationImplementation.NCCL)

            v_dense = make_per_replica_value([1.0, 1.0], devices)
            v_sparse = make_per_replica_value([
                IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
                IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
            ], devices)

            @def_function.function
            def nested_dense():
                collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense,
                                  options)

            @def_function.function
            def nested_sparse():
                collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                                  options)

            # All collectives, function calls, cond branches, and while loops
            # should be chained by control dependencies so that the execution
            # order is deterministic.
            @def_function.function
            def f():
                # pylint: disable=pointless-statement
                collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                                  options)
                # reducing dense value.
                collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense,
                                  options)
                # reducing sparse value.
                collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                                  options)
                # reduce dense value in nested tf.function.
                nested_dense()
                # reduce sparse value in nested tf.function.
                nested_sparse()
                # reduce dense value in tf.cond.
                if array_ops.identity(1.0) > array_ops.identity(2.0):
                    collective.reduce(reduce_util.ReduceOp.SUM, v_dense,
                                      v_dense, options)
                else:
                    v_dense
                # reduce sparse value in tf.cond.
                if array_ops.identity(1.0) > array_ops.identity(2.0):
                    v_sparse
                else:
                    collective.reduce(reduce_util.ReduceOp.SUM, v_sparse,
                                      v_sparse, options)
                # reduce dense value in tf.while_loop.
                i = array_ops.identity(1)
                while i < 3:
                    collective.reduce(reduce_util.ReduceOp.SUM, v_dense,
                                      v_dense, options)
                    i += 1
                # reduce sparse value in tf.while_loop.
                i = array_ops.identity(1)
                while i < 3:
                    collective.reduce(reduce_util.ReduceOp.SUM, v_sparse,
                                      v_sparse, options)
                    i += 1
                # reducing dense and sparse value again.
                collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense,
                                  options)
                collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                                  options)
                # pylint: enable=pointless-statement

            graph = f.get_concrete_function().graph
            should_be_ordered = set([
                "CollectiveReduceV2", "CollectiveGatherV2", "If", "While",
                "StatefulPartitionedCall"
            ])
            nodes_by_device = {}
            for op in graph.get_operations():
                if op.type in should_be_ordered:
                    if op.device not in nodes_by_device:
                        nodes_by_device[op.device] = []
                    nodes_by_device[op.device].append(op)
            order = test_util.topological_sort_operations(
                graph.get_operations())
            for device in devices:
                device = device_util.canonicalize(device)
                # These function ops don't have device annotations, but they
                # contain collectives for both devices, so we always include them.
                operations = nodes_by_device[device] + nodes_by_device[""]
                # Verify that we get all types of nodes we want.
                self.assertEqual(set(op.type for op in operations),
                                 should_be_ordered)
                test_util.assert_sequential_execution(order, operations)

        get_global_mpr(num_processes).run(replica_fn)

    def _get_test_objects(self,
                          task_type,
                          task_id,
                          num_gpus=0,
                          communication=CollectiveCommunication.AUTO,
                          use_strategy_object=False,
                          local_mode=False):
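        """Builds either a strategy object or raw collective ops for tests.

        Returns a tuple of (strategy_or_collective_all_reduce, devices,
        master_target), where master_target is "" in local mode.
        """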
        collective_keys = cross_device_utils.CollectiveKeys(
            group_key_start=10 + CollectiveAllReduceTest.collective_key_base)
        if local_mode:
            if num_gpus:
                devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
            else:
                devices = ["/device:CPU:0"]

            comm_options = collective_util.Options(
                implementation=communication)
            if use_strategy_object:
                strategy = (
                    mwms_lib.CollectiveAllReduceStrategy._from_local_devices(
                        devices, comm_options))  # pylint: disable=protected-access
                return strategy, devices, ""
            else:
                collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
                    devices=devices,
                    group_size=len(devices),
                    options=comm_options,
                    collective_keys=collective_keys)
                return collective_all_reduce_ops, devices, ""
        else:
            # NCCL requires physical GPUs for every replica, which we can't do
            # with the simulated multi-host setup right now.
            assert communication != CollectiveCommunication.NCCL
            if num_gpus:
                devices = [
                    "/job:%s/task:%d/replica:0/device:GPU:%d" %
                    (task_type, task_id, i) for i in range(num_gpus)
                ]
            else:
                devices = [
                    "/job:%s/task:%d/replica:0/device:CPU:0" %
                    (task_type, task_id)
                ]

            comm_options = collective_util.Options(
                implementation=communication)
            if use_strategy_object:
                resolver = cluster_resolver.SimpleClusterResolver(
                    cluster_spec=multi_worker_util.normalize_cluster_spec(
                        self._cluster_spec),
                    task_type=task_type,
                    task_id=task_id,
                    num_accelerators={"GPU": num_gpus})
                strategy = mwms_lib.CollectiveAllReduceStrategy(
                    communication_options=comm_options,
                    cluster_resolver=resolver)
                return (strategy, devices,
                        "grpc://" + self._cluster_spec[task_type][task_id])
            else:
                collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
                    devices=devices,
                    group_size=len(devices) * NUM_WORKERS,
                    options=comm_options,
                    collective_keys=collective_keys)
                return (collective_all_reduce_ops, devices,
                        "grpc://" + self._cluster_spec[task_type][task_id])