Example #1
 def _start_check_health_thread(self):
     # Use a dummy all-reduce as a barrier to wait for all workers to be up,
     # otherwise the check health may fail immediately.
     #
     # TODO(b/151232436): change to an explicit barrier if we have it.
     dummy_value = ops.convert_to_tensor([])
     logging.info("Waiting for the cluster, timeout = %s",
                  self._check_health_initial_timeout or "inf")
     try:
         self._host_cross_device_ops.reduce(
             reduce_util.ReduceOp.SUM,
             dummy_value,
             dummy_value,
             experimental_hints=collective_util.Hints(
                 timeout_seconds=self._check_health_initial_timeout))
         if context.is_async():
             context.async_wait()
     except errors.DeadlineExceededError:
         raise RuntimeError(
             "Timeout waiting for the cluster, timeout is %d seconds" %
             self._check_health_initial_timeout)
     self._check_health_thread_should_stop = threading.Event()
     # Start the thread as daemon to avoid it blocking the program from exiting.
     # We try our best to shut down the thread, but __del__ is not guaranteed
     # to be called when the program exits.
     self._check_health_thread = threading.Thread(target=self._check_health,
                                                  daemon=True)
     self._check_health_thread.start()
Example #2
 def testCreateOptionsViaHints(self):
   with self.assertLogs() as cm:
     options = collective_util.Hints(50, 1)
   self.assertTrue(any("is deprecated" in msg for msg in cm.output))
   self.assertIsInstance(options, collective_util.Options)
   self.assertEqual(options.bytes_per_pack, 50)
   self.assertEqual(options.timeout_seconds, 1)
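
The test above exercises the deprecation path: constructing Hints simply forwards its arguments to the newer Options class, so both spellings configure the same two knobs. A minimal sketch of that equivalence, assuming collective_util from tensorflow.python.distribute is importable in your build:

from tensorflow.python.distribute import collective_util

# Deprecated spelling: logs a deprecation warning and returns an Options
# instance carrying the same fields.
hints = collective_util.Hints(bytes_per_pack=50, timeout_seconds=1)

# Non-deprecated spelling with the same configuration.
options = collective_util.Options(bytes_per_pack=50, timeout_seconds=1)

assert isinstance(hints, collective_util.Options)
assert hints.bytes_per_pack == options.bytes_per_pack == 50
assert hints.timeout_seconds == options.timeout_seconds == 1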
Example #3
  def testInputsAreFunctionArgs(self, communication):
    # Function inputs don't have device placement.
    hints = collective_util.Hints(bytes_per_pack=1)
    collective, devices, _ = self._get_test_objects(
        None,
        None,
        num_gpus=2,
        communication=communication,
        use_strategy_object=False,
        local_mode=True)
    devices = [device_util.canonicalize(d) for d in devices]

    @def_function.function
    def reduce_fn(v):
      self.assertEqual(v.values[0].device, "")
      self.assertEqual(v.values[1].device, "")
      # We only use NCCL for batch reduce with two or more values, so we use two
      # values here.
      reduced = collective.batch_reduce(
          reduce_util.ReduceOp.SUM, [(v, v), (v, v)], experimental_hints=hints)
      self.assertEqual(reduced[0].values[0].device, devices[0])
      self.assertEqual(reduced[0].values[1].device, devices[1])
      self.assertEqual(reduced[1].values[0].device, devices[0])
      self.assertEqual(reduced[1].values[1].device, devices[1])
      # Returning Mirrored only evaluates the primary value, which would cause
      # hanging, so return the per-device components instead.
      return [reduced[0].values, reduced[1].values]

    v = _make_per_replica([1.0, 2.0], devices)
    reduced = reduce_fn(v)
    self.assertAllEqual(self.evaluate(reduced), [[3.0, 3.0], [3.0, 3.0]])
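
The first comment in reduce_fn ("Function inputs don't have device placement") is easy to verify on its own. A standalone sketch using only the public tf.function API, independent of the collectives machinery:

import tensorflow as tf

# Eager tensors carry a concrete device string, but the placeholder arguments
# seen while tracing a tf.function do not.
traced_devices = []

@tf.function
def record_device(x):
  traced_devices.append(x.device)  # recorded at trace time; empty for arguments
  return x + 1.0

record_device(tf.constant(1.0))
print(repr(traced_devices[0]))  # ''
print(tf.constant(1.0).device)  # e.g. '/job:localhost/replica:0/task:0/device:CPU:0'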
Example #4
    def testTimeoutReduceSparse(self, communication, required_gpus):
        hints = collective_util.Hints(timeout_seconds=1)
        collective, devices, _ = self._get_test_objects(
            "worker",
            0,
            num_gpus=required_gpus,
            communication=communication,
            use_strategy_object=False)
        remote.connect_to_cluster(multi_worker_util.normalize_cluster_spec(
            self._cluster_spec),
                                  protocol="grpc")
        devices = [device_util.canonicalize(d) for d in devices]
        v = value_lib.PerReplica([
            _make_indexed_slices([[4., 6.], [5., 6.]], [1, 3], [5, 2],
                                 devices[0])
        ])

        @def_function.function
        def reduce_sparse():
            collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints)

        # The collective should time out because we only launch it on worker-0,
        # while there are three workers in total.
        with self.assertRaises(errors.DeadlineExceededError):
            reduce_sparse()

        # Reset since collective failures poison the context.
        context._reset_context()  # pylint: disable=protected-access
Example #5
    def testTimeoutBatchReduceDense(self, communication, required_gpus):
        hints = collective_util.Hints(timeout_seconds=1)
        collective, devices, _ = self._get_test_objects(
            "worker",
            0,
            num_gpus=required_gpus,
            communication=communication,
            use_strategy_object=False)
        remote.connect_to_cluster(multi_worker_util.normalize_cluster_spec(
            self._cluster_spec),
                                  protocol="grpc")
        devices = [device_util.canonicalize(d) for d in devices]
        v = _make_per_replica([1.0], devices)

        @def_function.function
        def batch_reduce_dense():
            collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)],
                                    hints)

        # The collective should time out because we only launch it on worker-0,
        # while there are three workers in total.
        with self.assertRaises(errors.DeadlineExceededError):
            batch_reduce_dense()

        # Reset since collective failures poison the context.
        context._reset_context()  # pylint: disable=protected-access
Example #6
 def testReductionDistributed(self, required_gpus, use_strategy_object,
                              bytes_per_pack):
   hints = collective_util.Hints(bytes_per_pack=bytes_per_pack)
   self._run_between_graph_clients(
       self._test_reduction,
       self._cluster_spec,
       required_gpus,
       communication=CollectiveCommunication.RING,
       use_strategy_object=use_strategy_object,
       hints=hints)
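
The only Hints field this test varies is bytes_per_pack, which asks the cross-device ops to fuse the tensors of a batched reduction into packs of roughly that many bytes before launching collectives. A minimal sketch of constructing the hint; the concrete pack size is illustrative, not a recommendation:

from tensorflow.python.distribute import collective_util

# Fuse tensors into roughly 4 MiB packs per collective launch; 0 leaves the
# packing decision to the implementation.
hints = collective_util.Hints(bytes_per_pack=4 * 1024 * 1024)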
Example #7
    def batch_reduce(self,
                     reduce_op,
                     value_destination_pairs,
                     experimental_hints=None):
        """Reduce PerReplica objects in a batch.

    Reduce each first element in `value_destination_pairs` to each second
    element which indicates the destinations.

    This can be faster than multiple individual `reduce`s because we can
    fuse several tensors into one or multiple packs before reduction.

    Args:
      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how the
        `per_replica_value` will be reduced.
      value_destination_pairs: A list or a tuple of PerReplica objects (or
        tensors with device set if there is one device) and destinations.
      experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
        to perform collective operations.

    Returns:
      a list of Mirrored objects.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of PerReplica objects and destinations.
    """
        # TODO(yuefengz): if destinations are different, split into several
        # `_batch_reduce` invocations.
        if not _validate_value_destination_pairs(value_destination_pairs):
            # If the first element of each pair is a tensor, we try to turn it into a
            # PerReplica object.
            value_destination_pairs = _normalize_value_destination_pairs(
                value_destination_pairs)

        for _, d in value_destination_pairs:
            validate_destinations(d)

        # Shortcut when all PerReplica objects only contain one value.
        if self._num_between_graph_workers == 1 and _all_devices_match(
                value_destination_pairs) and len(
                    value_destination_pairs[0][0].values) == 1:
            return [
                distribute_utils.regroup(v.values,
                                         wrap_class=value_lib.Mirrored)
                for v, _ in value_destination_pairs
            ]

        if experimental_hints is None:
            experimental_hints = collective_util.Hints()
        return self.batch_reduce_implementation(reduce_op,
                                                value_destination_pairs,
                                                experimental_hints)
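
A hedged usage sketch for the method above. It assumes a TensorFlow build from the same era as this code, where batch_reduce still accepts experimental_hints (later renamed to options), and uses the simplest public CrossDeviceOps, tf.distribute.ReductionToOneDevice. With a single device, plain tensors are accepted in place of PerReplica values, as the docstring notes:

import tensorflow as tf
from tensorflow.python.distribute import collective_util

cross_device_ops = tf.distribute.ReductionToOneDevice()
with tf.device("/cpu:0"):
  t = tf.constant([1.0, 2.0])

# Two (value, destination) pairs reduced in one batched call.
hints = collective_util.Hints(bytes_per_pack=0)
reduced = cross_device_ops.batch_reduce(
    tf.distribute.ReduceOp.SUM, [(t, t), (t, t)], experimental_hints=hints)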
Example #8
    def replica_fn():
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus, communication)
      if task_id != 0:
        return

      v = make_per_replica_value(1.0, devices)
      hints = collective_util.Hints(timeout_seconds=1)

      @def_function.function
      def reduce_dense():
        collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints)

      # The collective should time out because we only launch it on worker-0,
      # while there are three workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        reduce_dense()
Example #9
    def reduce(self,
               reduce_op,
               per_replica_value,
               destinations,
               experimental_hints=None):
        """Reduce `per_replica_value` to `destinations`.

    It runs the reduction operation defined by `reduce_op` and put the
    result on `destinations`.

    Args:
      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
        per_replica_value will be reduced.
      per_replica_value: A `tf.distribute.DistributedValues` object or a tensor
        with device set.
      destinations: the reduction destinations.
      experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
        to perform collective operations.

    Returns:
      a Mirrored object.

    Raises:
      ValueError: if per_replica_value can't be converted to a PerReplica
        object or if destinations aren't strings, Variables or DistributedValues
    """
        if not isinstance(per_replica_value, value_lib.DistributedValues):
            per_replica_value = _make_tensor_into_per_replica(
                per_replica_value)

        validate_destinations(destinations)

        # Shortcut if `per_replica_value` only contains one value.
        if self._num_between_graph_workers == 1 and len(
                per_replica_value.values) == 1 and _devices_match(
                    per_replica_value, destinations):
            with ops.device(per_replica_value.values[0].device):
                v = array_ops.identity(per_replica_value.values[0])
            return distribute_utils.regroup((v, ),
                                            wrap_class=value_lib.Mirrored)

        if experimental_hints is None:
            experimental_hints = collective_util.Hints()
        return self.reduce_implementation(reduce_op, per_replica_value,
                                          destinations, experimental_hints)
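
Under the same assumptions as the batch_reduce sketch above, reduce takes a single value and an explicit destination, which may be a device string:

import tensorflow as tf
from tensorflow.python.distribute import collective_util

cross_device_ops = tf.distribute.ReductionToOneDevice()
with tf.device("/cpu:0"):
  t = tf.constant([1.0, 2.0])

mirrored = cross_device_ops.reduce(
    tf.distribute.ReduceOp.SUM,
    t,                      # stands in for a PerReplica value on one device
    destinations="/cpu:0",  # may also be a tensor, variable or DistributedValues
    experimental_hints=collective_util.Hints(timeout_seconds=10))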
Example #10
    def replica_fn():
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus, communication)
      if task_id != 0:
        return

      v = make_per_replica_value(
          IndexedSlicesValue(
              values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
      hints = collective_util.Hints(timeout_seconds=1)

      @def_function.function
      def batch_reduce_sparse():
        collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)],
                                hints)

      # The collective should time out because we only launch it on worker-0,
      # while there are two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        batch_reduce_sparse()
Example #11
  def _start_check_health_thread(self):
    if not context.executing_eagerly():
      logging.info("Check health is only supported in eager.")
      return
    # Use a dummy all-reduce as a barrier to wait for all workers to be up,
    # otherwise the check health may fail immediately.

    # Use array_ops.identity to create the dummy tensor so that we have a new
    # Tensor. If we use a constant, it may be a cached tensor from a
    # /job:localhost device, which will cause some code that relies on
    # tensor.device to error.
    #
    # TODO(b/151232436): change to an explicit barrier if we have it.
    dummy_value = array_ops.identity([])
    logging.info("Waiting for the cluster, timeout = %s",
                 self._check_health_initial_timeout or "inf")
    try:
      self._host_cross_device_ops.reduce(
          reduce_util.ReduceOp.SUM,
          dummy_value,
          dummy_value,
          experimental_hints=collective_util.Hints(
              timeout_seconds=self._check_health_initial_timeout))
      if context.is_async():
        context.async_wait()
    except errors.DeadlineExceededError:
      raise RuntimeError(
          "Timeout waiting for the cluster, timeout is %d seconds" %
          self._check_health_initial_timeout)
    logging.info("Cluster is ready.")
    self._check_health_thread_should_stop = threading.Event()
    # Start the thread as daemon to avoid it blocking the program from exiting.
    # We try our best to shut down the thread, but __del__ is not guaranteed
    # to be called when the program exits.
    self._check_health_thread = threading.Thread(
        target=self._check_health,
        daemon=True)
    self._check_health_thread.start()
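
The tail of the method shows a common shutdown pattern: a threading.Event that the owner can set to stop the loop, plus daemon=True so a thread that is never stopped cannot keep the process alive. A standalone sketch of just that pattern; the loop body is illustrative, while the real _check_health periodically checks the other workers:

import threading
import time

stop_event = threading.Event()

def check_health_loop():
  # Illustrative body; the real loop checks peer workers instead of sleeping.
  while not stop_event.is_set():
    time.sleep(1)

# daemon=True means a forgotten thread never blocks interpreter exit; setting
# the event and joining is still the clean way to stop it.
thread = threading.Thread(target=check_health_loop, daemon=True)
thread.start()

# ... later, during teardown:
stop_event.set()
thread.join()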