Exemplo n.º 1
0
 def run_nccl_broadcast_double_send():
     """Launch a `broadcast_send` on every device in `self._devices`.

     Each device builds its own constant from `tensor_value` and sends it
     under the same `group_key` / `instance_key`. Nothing is returned.
     """
     for dev in self._devices:
         with ops.device(dev):
             payload = constant_op.constant(tensor_value)
             collective_ops.broadcast_send(
                 payload, payload.shape, payload.dtype,
                 self._group_size, group_key, instance_key)
Exemplo n.º 2
0
    def _multi_worker_init(**kwargs):
        """Initialise the plugin on one replica of a multi-replica job.

        Replica 0 obtains the NCCL unique id and the global seed and sends
        both with `broadcast_send`; every other replica obtains them with
        `broadcast_recv`. The exchanged values are then passed to
        `kit_lib.plugin_init`.

        Args:
            **kwargs: `global_batch_size` is required; `seed` is optional.

        Returns:
            The value returned by `kit_lib.plugin_init`.
        """
        replica_ctx = get_replica_context()
        global_id = replica_ctx.replica_id_in_sync_group
        requested_seed = kwargs.get("seed", None)

        # Step 1: exchange the NCCL unique id (instance key 2).
        if global_id == 0:
            nccl_id = collective_ops.broadcast_send(
                kit_lib.get_nccl_unique_id(),
                TensorShape([32]),
                int32,
                group_size=replica_ctx.num_replicas_in_sync,
                group_key=1,
                instance_key=2)
        else:
            nccl_id = collective_ops.broadcast_recv(
                TensorShape([32]),
                int32,
                group_size=replica_ctx.num_replicas_in_sync,
                group_key=1,
                instance_key=2)

        # Step 2: exchange the global seed (instance key 3).
        if global_id == 0:
            # A falsy user seed falls back to a freshly generated one.
            chief_seed = requested_seed or kit_lib.gen_random_seed()
            synced_seed = collective_ops.broadcast_send(
                chief_seed,
                TensorShape([1]),
                int64,
                group_size=replica_ctx.num_replicas_in_sync,
                group_key=1,
                instance_key=3)
        else:
            synced_seed = collective_ops.broadcast_recv(
                TensorShape([1]),
                int64,
                group_size=replica_ctx.num_replicas_in_sync,
                group_key=1,
                instance_key=3)

            # Warn when this replica asked for a seed that differs from the
            # one received from replica 0.
            if requested_seed and requested_seed != synced_seed:
                logging.warning(
                    "The seed: {} is not consistent with that from cheif-node: {}, "
                    "and the seed from cheif-node will be used.".format(
                        requested_seed, synced_seed))

        visible_devices = _get_visible_devices()
        return kit_lib.plugin_init(
            global_id,
            replica_ctx.num_replicas_in_sync,
            nccl_id,
            synced_seed,
            visible_devices,
            global_batch_size=kwargs['global_batch_size'])
    def testBasicNcclBroadcast(self):
        # Sends a tensor from GPU 0, receives it on GPU 1, and checks that
        # both fetched results match the original values.
        tensor_value = [0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1]
        group_size = 2
        group_key = 1
        instance_key = 1
        src_device, dst_device = [
            '/GPU:{}'.format(idx) for idx in range(group_size)
        ]

        with self.session(config=self._configure(group_size)) as sess:
            if not test_util.is_gpu_available(cuda_only=True):
                self.skipTest('No GPU available')

            with ops.device(src_device):
                src = constant_op.constant(tensor_value)
                send_out = collective_ops.broadcast_send(
                    src, src.shape, src.dtype, group_size, group_key,
                    instance_key)

            with ops.device(dst_device):
                # Only the shape and dtype of this constant are used.
                like = constant_op.constant(tensor_value)
                recv_out = collective_ops.broadcast_recv(
                    like.shape, like.dtype, group_size, group_key,
                    instance_key)

            results = sess.run([send_out, recv_out])

        for result in results:
            self.assertAllClose(result, tensor_value, rtol=1e-5, atol=1e-5)
Exemplo n.º 4
0
            def initial_value_fn():
                """Return the initial value, broadcast from the chief if needed.

                With more than one worker the chief sends its tensor and
                returns an identity of it gated on the send; every other
                worker returns the received tensor. With a single worker the
                converted initial value is returned directly.
                """
                # Only the first device participates in the broadcast of initial values.
                group_key = self._collective_keys.get_group_key([device])
                group_size = self._num_workers
                collective_instance_key = (
                    self._collective_keys.get_variable_instance_key())

                with ops.device(device):
                    value = kwargs["initial_value"]
                    if callable(value):
                        value = value()
                    assert not callable(value)
                    value = ops.convert_to_tensor(
                        value, dtype=kwargs.get("dtype", None))

                    # Single worker: nothing to broadcast.
                    if self._num_workers <= 1:
                        return value

                    if not self._is_chief:
                        return collective_ops.broadcast_recv(
                            value.shape, value.dtype, group_size, group_key,
                            collective_instance_key)

                    send_op = collective_ops.broadcast_send(
                        value, value.shape, value.dtype, group_size,
                        group_key, collective_instance_key)
                    # Tie the returned tensor to the send through a control
                    # dependency.
                    with ops.control_dependencies([send_op]):
                        return array_ops.identity(value)
Exemplo n.º 5
0
    def broadcast_variables(self):
        """Overwrite `self._embedding_weights` with replica 0's copy.

        Does nothing when only one replica is in sync. Otherwise replica 0
        sends an identity of its weights and every other replica receives a
        tensor of the same shape and dtype; the result is assigned back to
        `self._embedding_weights`.
        """
        replica_ctx = tf.distribute.get_replica_context()
        g_replica_id = replica_ctx.replica_id_in_sync_group
        if replica_ctx.num_replicas_in_sync == 1:
            return

        variable = tf.identity(self._embedding_weights)
        # Keyword arguments shared by the send and receive sides.
        collective_args = dict(
            group_size=replica_ctx.num_replicas_in_sync,
            group_key=2,
            instance_key=2 + self._uid,
            timeout=5)

        if g_replica_id == 0:
            values = collective_ops.broadcast_send(
                variable, variable.shape, variable.dtype, **collective_args)
        else:
            values = collective_ops.broadcast_recv(
                variable.shape, variable.dtype, **collective_args)
        self._embedding_weights.assign(values)
            def initial_value_fn():
                """Return the initial value, broadcast from the chief if needed.

                With more than one worker the chief sends its (unwrapped)
                tensor and returns an identity of it gated on the send; every
                other worker returns the received tensor. With a single
                worker the converted initial value is returned directly.
                """
                # Only the first device participates in the broadcast of initial values.
                group_key = self._collective_keys.get_group_key([device])
                group_size = self._num_workers
                collective_instance_key = (
                    self._collective_keys.get_variable_instance_key())

                with ops.device(device):
                    value = kwargs["initial_value"]
                    if callable(value):
                        value = value()
                    assert not callable(value)
                    value = ops.convert_to_tensor(
                        value, dtype=kwargs.get("dtype", None))

                    # Single worker: nothing to broadcast.
                    if self._num_workers <= 1:
                        return value

                    if not self._is_chief:
                        return collective_ops.broadcast_recv(
                            value.shape, value.dtype, group_size, group_key,
                            collective_instance_key)

                    # Unwrap `value` if it is a `CheckpointInitialValue`.
                    # TODO(b/138130844): Revert the following check once
                    # `CheckpointInitialValue` class is removed.
                    if isinstance(value, trackable.CheckpointInitialValue):
                        value = value.wrapped_value

                    send_op = collective_ops.broadcast_send(
                        value, value.shape, value.dtype, group_size,
                        group_key, collective_instance_key)
                    # Tie the returned tensor to the send through a control
                    # dependency.
                    with ops.control_dependencies([send_op]):
                        return array_ops.identity(value)
Exemplo n.º 7
0
 def send():
     """Broadcast `c * 3` and return an identity of `c` gated on that send."""
     sent = collective_ops.broadcast_send(
         c * 3, c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
     # The returned tensor carries a control dependency on the send op.
     with ops.control_dependencies([sent.op]):
         return array_ops.identity(c)
Exemplo n.º 8
0
 def run_basic_nccl_broadcast():
   """Build a send on device 0 and a receive on device 1.

   Returns:
     A two-element list: the send output followed by the receive output.
   """
   with ops.device(self._devices[0]):
     src = constant_op.constant(tensor_value)
     sent = collective_ops.broadcast_send(
         src, src.shape, src.dtype, self._group_size, group_key,
         instance_key)
   with ops.device(self._devices[1]):
     # Only the shape and dtype of this constant are used by the receiver.
     template = constant_op.constant(tensor_value)
     received = collective_ops.broadcast_recv(
         template.shape, template.dtype, self._group_size, group_key,
         instance_key)
   return [sent, received]
Exemplo n.º 9
0
    def testAbortInstanceParamsResolution(self, device, communication):
        # Checks that aborting collective ops while a broadcast is waiting
        # surfaces an UnavailableError("peer down"), that later collectives
        # fail the same way, and that collectives can be issued again once
        # they are re-enabled.
        if communication == "NCCL":
            self.skipTest("b/171358086: cannot test multi worker NCCL")
        dev0 = "/device:%s:0" % device
        cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
        enable_collective_ops_with_barrier(cluster_resolver)
        group_size = 2
        group_key = 100
        instance_key = 100
        in_tensor = constant_op.constant([1.])

        # First perform a normal all-reduce to complete the group resolution.
        with ops.device(dev0):
            collective_ops.all_reduce(in_tensor, group_size, group_key,
                                      instance_key)

        # We use broadcast to test aborting instance resolution since only broadcast
        # waits for the group.

        # Only task 1 runs the abort scenario; other tasks go straight to the
        # re-enable step below.
        if cluster_resolver.task_id == 1:

            def abort_fn():
                # Sleep for two seconds, then abort all collective ops with
                # the error that the assertions below expect.
                time.sleep(2)
                context.context().abort_collective_ops(errors.UNAVAILABLE,
                                                       "peer down")

            t = threading.Thread(target=abort_fn)
            t.start()

            # Use a different instance key to trigger another instance resolution.
            instance_key = 101
            with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
                # This hangs on params resolution since we're only launching one
                # collective for a group size of 2.
                with ops.device(dev0):
                    collective_ops.broadcast_send(in_tensor, (1, ),
                                                  dtypes.float32, group_size,
                                                  group_key, instance_key)

            # After abortion, subsequent collectives should fail immediately.
            with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
                with ops.device(dev0):
                    collective_ops.broadcast_send(in_tensor, (1, ),
                                                  dtypes.float32, group_size,
                                                  group_key, instance_key)

            # Wait for the abort thread to finish before continuing.
            t.join()

        # Enable collective ops again in order to reset the collective executor.
        enable_collective_ops_with_barrier(cluster_resolver)
        # Reassign instance_key so that it's the same on each worker.
        instance_key = 100
        with ops.device(dev0):
            # Task 0 sends and every other task receives.
            if cluster_resolver.task_id == 0:
                collective_ops.broadcast_send(in_tensor, (1, ), dtypes.float32,
                                              group_size, group_key,
                                              instance_key)
            else:
                collective_ops.broadcast_recv(
                    (1, ), dtypes.float32, group_size, group_key, instance_key)
          def _overridden_initial_value_fn(device=d, index=i):
            """Return the initial value, sent by the chief's index-0 call.

            `device` and `index` are bound as defaults so each generated
            function keeps the values current at definition time.
            """
            with ops.device(device):
              value = initial_value_fn()
              assert not callable(value)
              value = ops.convert_to_tensor(value)

              is_sender = self._is_chief and index == 0
              if not is_sender:
                return collective_ops.broadcast_recv(
                    value.shape, value.dtype, group_size, group_key,
                    collective_instance_key)

              send_op = collective_ops.broadcast_send(
                  value, value.shape, value.dtype, group_size, group_key,
                  collective_instance_key)
              # Gate the returned tensor on the send via a control dependency.
              with ops.control_dependencies([send_op]):
                return array_ops.identity(value)
          def _overridden_initial_value_fn(device=d, index=i):
            """Return the initial value, sent by the chief's index-0 call.

            `device` and `index` are bound as defaults so each generated
            function keeps the values current at definition time.
            """
            with ops.device(device):
              tensor = initial_value_fn()
              assert not callable(tensor)
              tensor = ops.convert_to_tensor(tensor)

              if not (self._is_chief and index == 0):
                # Every other caller receives the broadcast value.
                return collective_ops.broadcast_recv(
                    tensor.shape, tensor.dtype, group_size, group_key,
                    collective_instance_key)

              # The chief's index-0 call sends, then returns its own value
              # gated on the send via a control dependency.
              bcast = collective_ops.broadcast_send(
                  tensor, tensor.shape, tensor.dtype, group_size, group_key,
                  collective_instance_key)
              with ops.control_dependencies([bcast]):
                return array_ops.identity(tensor)
Exemplo n.º 12
0
 def _testCollectiveBroadcast(self, t0):
   """Broadcast `t0` from /CPU:0 to /CPU:1 and check both outputs.

   Args:
     t0: value turned into a constant and broadcast; both the sender's
       output and the receiver's output are compared against it.
   """
   group_key = 1
   instance_key = 1
   # `test_session` is deprecated in favour of `session`, which the sibling
   # helper of the same name already uses.
   with self.session(
       config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
     with ops.device('/CPU:0'):
       in0 = constant_op.constant(t0)
       out0 = collective_ops.broadcast_send(in0, in0.shape, in0.dtype,
                                            2, group_key, instance_key)
     with ops.device('/CPU:1'):
       # Only the shape and dtype of this constant are used by the receiver.
       c1 = constant_op.constant(t0)
       out1 = collective_ops.broadcast_recv(c1.shape, c1.dtype,
                                            2, group_key, instance_key)
     run_options = config_pb2.RunOptions()
     run_options.experimental.collective_graph_key = 1
     results = sess.run([out0, out1], options=run_options)
   self.assertAllClose(results[0], t0, rtol=1e-5, atol=1e-5)
   self.assertAllClose(results[1], t0, rtol=1e-5, atol=1e-5)
  def testNcclBroadcastDoubleSend(self):
    # Issues a broadcast_send from both GPUs under the same group/instance
    # key and expects the run to fail with an "already has source" error.
    tensor_value = [0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1]
    group_size = 2
    group_key = 1
    instance_key = 1
    devices = ['/GPU:{}'.format(i) for i in range(group_size)]

    with self.session(config=self._configure(group_size)) as sess:
      if not test_util.is_gpu_available(cuda_only=True):
        self.skipTest('No GPU available')
      collectives = []
      for device in devices:
        with ops.device(device):
          t = constant_op.constant(tensor_value)
          collectives.append(collective_ops.broadcast_send(
              t, t.shape, t.dtype, group_size, group_key, instance_key))
      # `assertRaisesRegexp` is a deprecated alias that was removed in
      # Python 3.12; `assertRaisesRegex` is the supported spelling.
      with self.assertRaisesRegex(errors.InternalError, 'already has source'):
        sess.run(collectives)
Exemplo n.º 14
0
 def _testCollectiveBroadcast(self, t0):
   """Broadcast `t0` from /CPU:0 to /CPU:1 and check both outputs.

   Args:
     t0: value turned into a constant and broadcast; both the sender's
       output and the receiver's output are compared against it.
   """
   group_key = 1
   instance_key = 1
   two_cpu_config = config_pb2.ConfigProto(device_count={'CPU': 2})
   with self.session(config=two_cpu_config) as sess:
     with ops.device('/CPU:0'):
       source = constant_op.constant(t0)
       sent = collective_ops.broadcast_send(
           source, source.shape, source.dtype, 2, group_key, instance_key)
     with ops.device('/CPU:1'):
       # Only the shape and dtype of this constant are used by the receiver.
       template = constant_op.constant(t0)
       received = collective_ops.broadcast_recv(
           template.shape, template.dtype, 2, group_key, instance_key)
     run_options = config_pb2.RunOptions()
     run_options.experimental.collective_graph_key = 1
     sent_value, received_value = sess.run(
         [sent, received], options=run_options)
   self.assertAllClose(sent_value, t0, rtol=1e-5, atol=1e-5)
   self.assertAllClose(received_value, t0, rtol=1e-5, atol=1e-5)
Exemplo n.º 15
0
    def _broadcast_implementation(self, initial_value, device):
        """Broadcast `initial_value` from the chief to the other workers.

        Args:
            initial_value: tensor whose value, shape and dtype are used for
                the broadcast.
            device: device used to look up the collective group key.

        Returns:
            `initial_value` unchanged when there is at most one worker; on
            the chief, an identity of it gated on the send; otherwise the
            received tensor.
        """
        # This is an extension point for overriding, try to keep a stable API.

        if self._num_workers <= 1:
            return initial_value

        # Only the first device participates in the broadcast of initial values.
        group_key = self._collective_keys.get_group_key([device])
        group_size = self._num_workers
        collective_instance_key = (
            self._collective_keys.get_variable_instance_key())

        if not self._is_chief:
            return collective_ops.broadcast_recv(
                initial_value.shape, initial_value.dtype, group_size,
                group_key, collective_instance_key)

        send_op = collective_ops.broadcast_send(
            initial_value, initial_value.shape, initial_value.dtype,
            group_size, group_key, collective_instance_key)
        # Gate the returned tensor on the send via a control dependency.
        with ops.control_dependencies([send_op]):
            return array_ops.identity(initial_value)
Exemplo n.º 16
0
# NOTE(review): this excerpt starts mid-script; `sess`, `v1`, `v2`, `v3` and
# `sum_reduce` are defined above the visible lines.
# The positional arguments after the tensor are presumably
# (group_size, group_key, instance_key, merge_op, final_op) - confirm against
# the collective_ops.all_reduce signature.
with tf.device('gpu:1'):
    sum_reduce.append(collective_ops.all_reduce(v2, 2, 0, 1, 'Add', 'Id'))
print(sess.run(sum_reduce))

# Same pair of devices, this time with 'Add' followed by 'Div' (an average,
# going by the variable name).
average_reduce = []
with tf.device('gpu:0'):
    average_reduce.append(collective_ops.all_reduce(v1, 2, 1, 1, 'Add', 'Div'))
with tf.device('gpu:1'):
    average_reduce.append(collective_ops.all_reduce(v2, 2, 1, 1, 'Add', 'Div'))
print(sess.run(average_reduce))

print('==========================')

# Broadcast v1 from gpu:0 to gpu:1; the receiver only needs v1's shape and
# dtype.
bcast = []
# Disabled variant that would send v0 from the CPU instead:
# with tf.device('cpu:0'):
#     bcast.append(collective_ops.broadcast_send(v0, v0.shape, v0.dtype, 2, 3, 1))
with tf.device('gpu:0'):
    bcast.append(collective_ops.broadcast_send(v1, v1.shape, v1.dtype, 2, 3, 2))
with tf.device('gpu:1'):
    bcast.append(collective_ops.broadcast_recv(v1.shape, v1.dtype, 2, 3, 2))

print(sess.run(bcast))

print('==========================')

# A second 'Add'/'Div' reduction, now combining v1 with v3 under different
# key values (4 and 3) from the earlier reductions.
average_reduce = []
with tf.device('gpu:0'):
    average_reduce.append(collective_ops.all_reduce(v1, 2, 4, 3, 'Add', 'Div'))
with tf.device('gpu:1'):
    average_reduce.append(collective_ops.all_reduce(v3, 2, 4, 3, 'Add', 'Div'))
print(sess.run(average_reduce))
Exemplo n.º 17
0
def broadcast_send(t, shape, dtype, group_size, group_key, instance_key):
    """Forward all arguments unchanged to `collective_ops.broadcast_send`.

    Returns:
        Whatever `collective_ops.broadcast_send` returns.
    """
    return collective_ops.broadcast_send(
        t,
        shape,
        dtype,
        group_size=group_size,
        group_key=group_key,
        instance_key=instance_key)