def _start_check_health_thread(self):
  # Use a dummy all-reduce as a barrier to wait for all workers to be up,
  # otherwise the check health may fail immediately.
  #
  # TODO(b/151232436): change to an explicit barrier if we have it.
  dummy_value = ops.convert_to_tensor([])
  logging.info("Waiting for the cluster, timeout = %s",
               self._check_health_initial_timeout or "inf")
  try:
    self._host_cross_device_ops.reduce(
        reduce_util.ReduceOp.SUM,
        dummy_value,
        dummy_value,
        experimental_hints=collective_util.Hints(
            timeout_seconds=self._check_health_initial_timeout))
    if context.is_async():
      context.async_wait()
  except errors.DeadlineExceededError:
    raise RuntimeError(
        "Timeout waiting for the cluster, timeout is %d seconds" %
        self._check_health_initial_timeout)
  self._check_health_thread_should_stop = threading.Event()
  # Start the thread as daemon to avoid it blocking the program from exiting.
  # We try our best to shut down the thread, but __del__ is not guaranteed to
  # be called when the program exits.
  self._check_health_thread = threading.Thread(
      target=self._check_health, daemon=True)
  self._check_health_thread.start()

def testCreateOptionsViaHints(self):
  with self.assertLogs() as cm:
    options = collective_util.Hints(50, 1)
  self.assertTrue(any("is deprecated" in msg for msg in cm.output))
  self.assertIsInstance(options, collective_util.Options)
  self.assertEqual(options.bytes_per_pack, 50)
  self.assertEqual(options.timeout_seconds, 1)

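# A minimal sketch of the non-deprecated spelling that the test above
# implies: `Hints` is a deprecated alias that forwards to `Options`, so new
# code can construct the options object directly with the same fields.
options = collective_util.Options(bytes_per_pack=50, timeout_seconds=1)
assert options.bytes_per_pack == 50
assert options.timeout_seconds == 1
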
def testInputsAreFunctionArgs(self, communication):
  # Function inputs don't have device placement.
  hints = collective_util.Hints(bytes_per_pack=1)
  collective, devices, _ = self._get_test_objects(
      None,
      None,
      num_gpus=2,
      communication=communication,
      use_strategy_object=False,
      local_mode=True)
  devices = [device_util.canonicalize(d) for d in devices]

  @def_function.function
  def reduce_fn(v):
    self.assertEqual(v.values[0].device, "")
    self.assertEqual(v.values[1].device, "")
    # We only use NCCL for batch reduce with two or more values, so we use
    # two values here.
    reduced = collective.batch_reduce(
        reduce_util.ReduceOp.SUM, [(v, v), (v, v)], experimental_hints=hints)
    self.assertEqual(reduced[0].values[0].device, devices[0])
    self.assertEqual(reduced[0].values[1].device, devices[1])
    self.assertEqual(reduced[1].values[0].device, devices[0])
    self.assertEqual(reduced[1].values[1].device, devices[1])
    # Returning Mirrored only evaluates the primary value, which causes
    # hanging.
    return [reduced[0].values, reduced[1].values]

  v = _make_per_replica([1.0, 2.0], devices)
  reduced = reduce_fn(v)
  self.assertAllEqual(self.evaluate(reduced), [[3.0, 3.0], [3.0, 3.0]])

def testTimeoutReduceSparse(self, communication, required_gpus):
  hints = collective_util.Hints(timeout_seconds=1)
  collective, devices, _ = self._get_test_objects(
      "worker",
      0,
      num_gpus=required_gpus,
      communication=communication,
      use_strategy_object=False)
  remote.connect_to_cluster(
      multi_worker_util.normalize_cluster_spec(self._cluster_spec),
      protocol="grpc")
  devices = [device_util.canonicalize(d) for d in devices]
  v = value_lib.PerReplica([
      _make_indexed_slices([[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0])
  ])

  @def_function.function
  def reduce_sparse():
    collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints)

  # The collective should time out because we only launch it on worker-0,
  # while there are three workers in total.
  with self.assertRaises(errors.DeadlineExceededError):
    reduce_sparse()
  # Reset since collective failures poison the context.
  context._reset_context()  # pylint: disable=protected-access

def testTimeoutBatchReduceDense(self, communication, required_gpus):
  hints = collective_util.Hints(timeout_seconds=1)
  collective, devices, _ = self._get_test_objects(
      "worker",
      0,
      num_gpus=required_gpus,
      communication=communication,
      use_strategy_object=False)
  remote.connect_to_cluster(
      multi_worker_util.normalize_cluster_spec(self._cluster_spec),
      protocol="grpc")
  devices = [device_util.canonicalize(d) for d in devices]
  v = _make_per_replica([1.0], devices)

  @def_function.function
  def batch_reduce_dense():
    collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)], hints)

  # The collective should time out because we only launch it on worker-0,
  # while there are three workers in total.
  with self.assertRaises(errors.DeadlineExceededError):
    batch_reduce_dense()
  # Reset since collective failures poison the context.
  context._reset_context()  # pylint: disable=protected-access

def testReductionDistributed(self, required_gpus, use_strategy_object,
                             bytes_per_pack):
  hints = collective_util.Hints(bytes_per_pack=bytes_per_pack)
  self._run_between_graph_clients(
      self._test_reduction,
      self._cluster_spec,
      required_gpus,
      communication=CollectiveCommunication.RING,
      use_strategy_object=use_strategy_object,
      hints=hints)

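# Hedged usage sketch: how a caller might pass the packing hint to a cross
# device op, assuming the `reduce` signature shown later in this section.
# `cross_device_ops`, `per_replica`, and `destinations` are hypothetical
# stand-ins, not names defined in this file.
hints = collective_util.Hints(bytes_per_pack=4 * 1024 * 1024)  # 4 MiB packs
reduced = cross_device_ops.reduce(
    reduce_util.ReduceOp.SUM, per_replica, destinations,
    experimental_hints=hints)
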
def batch_reduce(self, reduce_op, value_destination_pairs,
                 experimental_hints=None):
  """Reduce PerReplica objects in a batch.

  Reduce each first element in `value_destination_pairs` to each second
  element which indicates the destinations. This can be faster than multiple
  individual `reduce`s because we can fuse several tensors into one or
  multiple packs before reduction.

  Args:
    reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
      the `per_replica_value` will be reduced.
    value_destination_pairs: A list or a tuple of PerReplica objects (or
      tensors with device set if there is one device) and destinations.
    experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
      to perform collective operations.

  Returns:
    a list of Mirrored objects.

  Raises:
    ValueError: if `value_destination_pairs` is not an iterable of tuples of
      PerReplica objects and destinations.
  """
  # TODO(yuefengz): if destinations are different, split into several
  # `_batch_reduce` invocations.
  if not _validate_value_destination_pairs(value_destination_pairs):
    # If the first element of each pair is a tensor, we try to turn it into
    # a PerReplica object.
    value_destination_pairs = _normalize_value_destination_pairs(
        value_destination_pairs)

  for _, d in value_destination_pairs:
    validate_destinations(d)

  # Shortcut when all PerReplica objects contain only one value.
  if self._num_between_graph_workers == 1 and _all_devices_match(
      value_destination_pairs) and len(
          value_destination_pairs[0][0].values) == 1:
    return [
        distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
        for v, _ in value_destination_pairs
    ]

  if experimental_hints is None:
    experimental_hints = collective_util.Hints()
  return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
                                          experimental_hints)

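# Hedged usage sketch for `batch_reduce` above. `cross_device_ops`, `v0`,
# and `v1` are hypothetical stand-ins; each pair reuses the value as its own
# destination, matching the (v, v) pattern in the tests in this section.
reduced_list = cross_device_ops.batch_reduce(
    reduce_util.ReduceOp.SUM,
    [(v0, v0), (v1, v1)],
    experimental_hints=collective_util.Hints(bytes_per_pack=1))
# `reduced_list` holds one Mirrored object per input pair.
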
def replica_fn():
  collective, devices, task_id = self.make_collective(
      num_processes, required_gpus, communication)
  if task_id != 0:
    return
  v = make_per_replica_value(1.0, devices)
  hints = collective_util.Hints(timeout_seconds=1)

  @def_function.function
  def reduce_dense():
    collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints)

  # The collective should time out because we only launch it on worker-0,
  # while there are three workers in total.
  with self.assertRaises(errors.DeadlineExceededError):
    reduce_dense()

def reduce(self, reduce_op, per_replica_value, destinations,
           experimental_hints=None):
  """Reduce `per_replica_value` to `destinations`.

  It runs the reduction operation defined by `reduce_op` and puts the result
  on `destinations`.

  Args:
    reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
      per_replica_value will be reduced.
    per_replica_value: A `tf.distribute.DistributedValues` object or a tensor
      with device set.
    destinations: the reduction destinations.
    experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
      to perform collective operations.

  Returns:
    a Mirrored object.

  Raises:
    ValueError: if per_replica_value can't be converted to a PerReplica
      object or if destinations aren't strings, Variables or
      DistributedValues.
  """
  if not isinstance(per_replica_value, value_lib.DistributedValues):
    per_replica_value = _make_tensor_into_per_replica(per_replica_value)

  validate_destinations(destinations)

  # Shortcut if `per_replica_value` only contains one value.
  if self._num_between_graph_workers == 1 and len(
      per_replica_value.values) == 1 and _devices_match(
          per_replica_value, destinations):
    with ops.device(per_replica_value.values[0].device):
      v = array_ops.identity(per_replica_value.values[0])
    return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

  if experimental_hints is None:
    experimental_hints = collective_util.Hints()
  return self.reduce_implementation(reduce_op, per_replica_value,
                                    destinations, experimental_hints)

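# Hedged usage sketch for `reduce` above, threading a timeout through the
# hints object as the timeout tests in this section do. `cross_device_ops`
# and `per_replica` are hypothetical stand-ins.
mirrored = cross_device_ops.reduce(
    reduce_util.ReduceOp.SUM,
    per_replica,
    per_replica,
    experimental_hints=collective_util.Hints(timeout_seconds=30))
# If the collective does not complete within 30 seconds, the call raises
# errors.DeadlineExceededError.
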
def replica_fn():
  collective, devices, task_id = self.make_collective(
      num_processes, required_gpus, communication)
  if task_id != 0:
    return
  v = make_per_replica_value(
      IndexedSlicesValue(values=[[4., 6.]], indices=[1], dense_shape=[5, 2]),
      devices)
  hints = collective_util.Hints(timeout_seconds=1)

  @def_function.function
  def batch_reduce_sparse():
    collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)], hints)

  # The collective should time out because we only launch it on worker-0,
  # while there are two workers in total.
  with self.assertRaises(errors.DeadlineExceededError):
    batch_reduce_sparse()

def _start_check_health_thread(self):
  if not context.executing_eagerly():
    logging.info("Check health is only supported in eager.")
    return
  # Use a dummy all-reduce as a barrier to wait for all workers to be up,
  # otherwise the check health may fail immediately.
  #
  # Use array_ops.identity to create the dummy tensor so that we have a new
  # Tensor. If we use a constant, it may be cached from a /job:localhost
  # device, which will cause some code that relies on tensor.device to error.
  #
  # TODO(b/151232436): change to an explicit barrier if we have it.
  dummy_value = array_ops.identity([])
  logging.info("Waiting for the cluster, timeout = %s",
               self._check_health_initial_timeout or "inf")
  try:
    self._host_cross_device_ops.reduce(
        reduce_util.ReduceOp.SUM,
        dummy_value,
        dummy_value,
        experimental_hints=collective_util.Hints(
            timeout_seconds=self._check_health_initial_timeout))
    if context.is_async():
      context.async_wait()
  except errors.DeadlineExceededError:
    raise RuntimeError(
        "Timeout waiting for the cluster, timeout is %d seconds" %
        self._check_health_initial_timeout)
  logging.info("Cluster is ready.")
  self._check_health_thread_should_stop = threading.Event()
  # Start the thread as daemon to avoid it blocking the program from exiting.
  # We try our best to shut down the thread, but __del__ is not guaranteed to
  # be called when the program exits.
  self._check_health_thread = threading.Thread(
      target=self._check_health, daemon=True)
  self._check_health_thread.start()

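# A plausible companion shutdown helper, sketched from the
# `_check_health_thread_should_stop` event and `_check_health_thread` set up
# above; the method name and exact cleanup order are assumptions, not taken
# from this section.
def _stop_check_health_thread(self):
  if getattr(self, "_check_health_thread", None):
    # Signal the loop in self._check_health to exit, then wait for it.
    self._check_health_thread_should_stop.set()
    self._check_health_thread.join()
    self._check_health_thread = None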