def _start_check_health_thread(self):
  """Waits for every worker to come up, then starts the health-check thread.

  An all-reduce over an empty tensor is used as a cluster-wide barrier, so
  that the first health check does not fail just because some workers have
  not started yet. After the barrier completes, `self._check_health` is run
  on a daemon thread.

  Raises:
    RuntimeError: if the barrier raises `DeadlineExceededError`, i.e. it did
      not complete within `self._check_health_initial_timeout` seconds.
  """
  # TODO(b/151232436): change to an explicit barrier if we have it.
  initial_timeout = self._check_health_initial_timeout
  barrier_tensor = ops.convert_to_tensor([])
  logging.info("Waiting for the cluster, timeout = %s",
               initial_timeout or "inf")
  try:
    barrier_hints = collective_util.Hints(timeout_seconds=initial_timeout)
    self._host_cross_device_ops.reduce(
        reduce_util.ReduceOp.SUM,
        barrier_tensor,
        barrier_tensor,
        experimental_hints=barrier_hints)
    # In async eager mode, wait for pending ops so that a deadline error is
    # raised inside this try block.
    if context.is_async():
      context.async_wait()
  except errors.DeadlineExceededError:
    raise RuntimeError(
        "Timeout waiting for the cluster, timeout is %d seconds" %
        initial_timeout)

  stop_event = threading.Event()
  self._check_health_thread_should_stop = stop_event
  # The thread is a daemon so it never keeps the process alive: shutting it
  # down is best effort, and __del__ is not guaranteed to run at exit.
  # NOTE(review): no `args` are given to the thread target. If this class's
  # `_check_health` takes (device, group_key, instance_key), the thread will
  # die with a TypeError as soon as it starts -- confirm.
  checker_thread = threading.Thread(target=self._check_health, daemon=True)
  self._check_health_thread = checker_thread
  checker_thread.start()
def _check_health(self, device, group_key, instance_key):
  """Repeatedly runs a dummy all-reduce to detect unhealthy peers.

  The loop runs until `self._check_health_thread_should_stop` is set or a
  check fails. When a check fails, all collective ops in the current context
  are aborted and the loop exits.

  Args:
    device: device on which the check all-reduce is placed.
    group_key: pre-allocated collective group key for the check all-reduce.
    instance_key: pre-allocated collective instance key for the check
      all-reduce.
  """
  # We need to use a large enough value so that the all-reduce forms a
  # complete RING. In RING implementation, when value is too small, the
  # all-reduce may degrade into broadcasts. This means that some worker
  # failure may not be detected.
  value = array_ops.ones((32, 32), dtype=dtypes.float32)
  first = True
  while True:
    if self._check_health_thread_should_stop.is_set():
      return
    timeout = None
    if first:
      # For the first check health we set timeout since it may need to do
      # group resolution, which may hang if the cluster is never healthy.
      timeout = self._check_health_initial_timeout
      first = False
    try:
      # We use a dummy all-reduce as a way to check the health of a cluster.
      # For RING it should be able to detect failed workers in the cluster if
      # the values are large enough.
      #
      # We're not using CrossDeviceOps because we need to run it with
      # pre-allocated group and instance keys.
      #
      # TODO(b/151232436): Replace the reduce with a check health op once we
      # add that.
      with ops.device(device):
        collective_ops.all_reduce(
            value,
            group_size=self._num_workers,
            group_key=group_key,
            instance_key=instance_key,
            merge_op="Add",
            final_op="Id",
            subdiv_offsets=[0],
            communication_hint="ring",
            timeout=timeout)
        if context.is_async():
          context.async_wait()
    except (errors.UnavailableError, errors.DeadlineExceededError,
            errors.FailedPreconditionError, errors.CancelledError) as e:
      # TODO(b/151232436): Always raise UnavailableError when a peer fails.
      # Now there could be many kinds of errors:
      # - Unavailable: when the peer is not reachable, e.g. it's down.
      # - FailedPrecondition: when the peer has restarted.
      # - DeadlineExceeded: when the first check health exceeds the deadline,
      #   e.g. the peers take too long to be ready.
      # - Cancelled: when failures in organic collectives aborts first,
      #   outgoing RPCs may be aborted with Cancelled.
      logging.error("Cluster check alive failed, aborting collectives")
      context.context().abort_collective_ops(
          errors.UNAVAILABLE, "cluster check alive failed: %s" % e)
      # Collectives are aborted now, so there is nothing left to monitor.
      # Continuing would re-run the all-reduce against the aborted context,
      # and likely re-abort and re-log on every interval.
      return
    except Exception as e:  # pylint: disable=broad-except
      logging.exception("Unexpected exception in check alive.")
      context.context().abort_collective_ops(
          errors.INTERNAL, "unexecpted exception in check alive: %s" % e)
      return
    # Pause between checks, but wake up immediately if a stop is requested
    # (Event.wait returns True once the event is set).
    if self._check_health_thread_should_stop.wait(
        self._check_health_interval):
      return
def _start_check_health_thread(self):
  """Blocks until the cluster is reachable, then launches the health checker.

  A RING all-reduce over an empty tensor serves as a barrier, so that the
  health check does not fail immediately because some workers are still
  starting. Once it completes, "Cluster is ready." is logged and
  `self._check_health` is started on a daemon thread.

  Raises:
    RuntimeError: if the barrier raises `DeadlineExceededError`, i.e. it did
      not complete within `self._check_health_initial_timeout` seconds.
  """
  # TODO(b/151232436): change to an explicit barrier if we have it.
  initial_timeout = self._check_health_initial_timeout
  # array_ops.identity yields a brand-new Tensor. A constant might instead be
  # a cached one placed on a /job:localhost device, which makes code that
  # relies on `tensor.device` fail.
  barrier_tensor = array_ops.identity([])
  logging.info("Waiting for the cluster, timeout = %s",
               initial_timeout or "inf")
  try:
    barrier_options = collective_util.Options(
        timeout_seconds=initial_timeout,
        implementation=collective_util.CommunicationImplementation.RING)
    self._host_cross_device_ops.reduce(
        reduce_util.ReduceOp.SUM,
        barrier_tensor,
        barrier_tensor,
        options=barrier_options)
    # In async eager mode, wait for pending ops so that a deadline error is
    # raised inside this try block.
    if context.is_async():
      context.async_wait()
  except errors.DeadlineExceededError:
    raise RuntimeError(
        "Timeout waiting for the cluster, timeout is %d seconds" %
        initial_timeout)
  logging.info("Cluster is ready.")

  stop_event = threading.Event()
  self._check_health_thread_should_stop = stop_event
  # Daemon thread: it must not keep the program alive, since shutdown is best
  # effort and __del__ is not guaranteed to run when the program exits.
  # NOTE(review): the thread target receives no `args`. If this class's
  # `_check_health` takes (device, group_key, instance_key), the thread will
  # fail with a TypeError on start -- confirm.
  checker_thread = threading.Thread(target=self._check_health, daemon=True)
  self._check_health_thread = checker_thread
  checker_thread.start()
def _internal_py_func(func,
                      inp,
                      Tout,
                      stateful=None,
                      eager=False,
                      is_grad_func=False,
                      name=None):
  """Registers `func` and emits the op that invokes it.

  See documentation for py_func and eager_py_func.

  Raises:
    ValueError: if `func` is not callable.
  """
  if not callable(func):
    raise ValueError("Expected func to be callable, got func of type {}".format(
        type(func)))

  # A single output dtype is wrapped in a list; remember which form the
  # caller used so the return value can be unwrapped to match.
  is_list_or_tuple = isinstance(Tout, (list, tuple))
  if not is_list_or_tuple:
    Tout = [Tout]

  if eager:
    func = EagerFunc(func, Tout, is_grad_func)

  token = _py_funcs.insert(func)

  # The registered function lives as long as the outermost enclosing graph:
  # step out of nested function graphs until the graph no longer changes.
  def _enclosing_graph(g):
    if isinstance(g, function._FuncGraph):  # pylint: disable=protected-access
      return g._outer_graph  # pylint: disable=protected-access
    if isinstance(g, func_graph.FuncGraph):
      return g.outer_graph
    return g

  graph = ops.get_default_graph()
  parent = _enclosing_graph(graph)
  while parent is not graph:
    graph = parent
    parent = _enclosing_graph(graph)

  # TODO(zhifengc): Consider adding a Graph method to collect
  # `cleanup` objects in one of its member.
  if not hasattr(graph, "_py_funcs_used_in_graph"):
    graph._py_funcs_used_in_graph = []  # pylint: disable=protected-access

  # Keep the function referenced by the graph so it stays alive exactly as
  # long as the graph; after that it is left to the garbage collector.
  graph._py_funcs_used_in_graph.append(func)  # pylint: disable=protected-access

  if eager:
    result = gen_script_ops.eager_py_func(
        input=inp,
        token=token,
        is_async=context.is_async(),
        Tout=Tout,
        name=name)
  else:
    make_op = (gen_script_ops.py_func if stateful
               else gen_script_ops.py_func_stateless)
    result = make_op(input=inp, token=token, Tout=Tout, name=name)

  return result if is_list_or_tuple else result[0]
def __call__(self, device, token, args):
  """Passes `args` to `self._func`, which is executed eagerly.

  The call runs on a dedicated executor (async when the current context is
  async). A GradientTape watches every floating-point input tensor, and the
  tape, inputs and outputs are stored in `tape_cache` under `token`.

  Args:
    device: device onto which the returned tensors are copied.
    token: key (converted to bytes) under which the tape entry is stored.
    args: positional arguments for `self._func`; each may be a nested
      structure of tensors.

  Returns:
    The converted output(s) of `self._func`: a list if it returned a tuple or
    list, None if it returned None, otherwise a single tensor.
  """
  func_executor = executor.Executor(context.is_async())
  with context.executor_scope(func_executor):
    with context.eager_mode(), backprop.GradientTape() as tape:
      # Only watch tensors with a floating dtype.
      for tensor in args:
        for t in nest.flatten(tensor):
          if t.dtype.is_floating:
            tape.watch(t)
      ret = self._func(*args)
      # Use tf.identity to copy the returned tensors to device if necessary.
      with ops.device(device):
        if isinstance(ret, (tuple, list)):
          outputs = [
              array_ops.identity(self._convert(x, dtype=dtype))
              for (x, dtype) in zip(ret, self._out_dtypes)
          ]
        elif ret is None:
          outputs = None
        else:
          outputs = array_ops.identity(
              self._convert(ret, dtype=self._out_dtypes[0]))
    tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
  # Wait for every op dispatched to `func_executor` to finish, so that errors
  # from asynchronously executed ops are raised from this call. This must
  # happen before returning; placed after `return` it never runs.
  func_executor.wait()
  return outputs
def _internal_py_func(func,
                      inp,
                      Tout,
                      stateful=None,
                      use_eager_py_func=False,
                      is_grad_func=False,
                      name=None):
  """Registers `func` and emits the op that invokes it.

  See documentation for py_func and eager_py_func.

  Raises:
    ValueError: if `func` is not callable.
  """
  if not callable(func):
    raise ValueError(
        f"Expected func to be callable. Received func={func} of type "
        f"{type(func)}.")

  original_func = func
  func = autograph.do_not_convert(func)
  inp = list(inp)

  # Normalize Tout into a list of dtypes / TypeSpecs, remembering whether the
  # caller passed a sequence so the result can be unwrapped to match.
  is_list_or_tuple = isinstance(Tout, (list, tuple))
  Tout = [
      _as_dtype_or_type_spec(t) for t in (Tout if is_list_or_tuple else [Tout])
  ]

  # CompositeTensor inputs or TypeSpec outputs need extra wrapping; this is
  # only done on the eager_py_func path.
  handle_composite_tensors = False
  if use_eager_py_func:
    handle_composite_tensors = (
        any(isinstance(v, composite_tensor.CompositeTensor) for v in inp) or
        any(isinstance(t, type_spec.TypeSpec) for t in Tout))

  if handle_composite_tensors:
    func, inp, Tout, out_structure = _wrap_for_composites(func, inp, Tout)

  if use_eager_py_func:
    func = EagerFunc(func, Tout, is_grad_func)

  # Tying the registered function's lifetime with the current default graph is
  # not reliable. For example, Estimator-based binaries may switch graphs in
  # between model training and evaluation, via saved_model. Those binaries work
  # because the original function is global, and break once the registered
  # function is an anonymous lambda, like the one produced by do_not_convert.
  # To avoid breaking those cases, we attach the wrapper to the original
  # function so that their lifetime is connected.
  # TODO(b/144286616): Remove this.
  if tf_inspect.isfunction(original_func):
    # original_func may be a descriptor
    # (https://docs.python.org/3/howto/descriptor.html), and attributes cannot
    # be attached to those, hence the check.
    original_func.ag_dnc_wrapper__ = func

  token = _py_funcs.insert(func)

  # Also tie the registered function to the outermost enclosing graph: step
  # out of nested function graphs until the graph no longer changes.
  def _enclosing_graph(g):
    if isinstance(g, function._FuncGraph):  # pylint: disable=protected-access
      return g._outer_graph  # pylint: disable=protected-access
    if isinstance(g, func_graph.FuncGraph):
      return g.outer_graph
    return g

  graph = ops.get_default_graph()
  parent = _enclosing_graph(graph)
  while parent is not graph:
    graph = parent
    parent = _enclosing_graph(graph)

  # TODO(zhifengc): Consider adding a Graph method to collect
  # `cleanup` objects in one of its member.
  if not hasattr(graph, "_py_funcs_used_in_graph"):
    graph._py_funcs_used_in_graph = []  # pylint: disable=protected-access

  # Keep the function referenced by the graph so it stays alive exactly as
  # long as the graph; after that it is left to the garbage collector.
  graph._py_funcs_used_in_graph.append(func)  # pylint: disable=protected-access

  if use_eager_py_func:
    result = gen_script_ops.eager_py_func(
        input=inp,
        token=token,
        is_async=context.is_async(),
        Tout=Tout,
        name=name)
  else:
    make_op = (gen_script_ops.py_func if stateful
               else gen_script_ops.py_func_stateless)
    result = make_op(input=inp, token=token, Tout=Tout, name=name)

  # Rebuild composite outputs from the flat tensors the op produced.
  if handle_composite_tensors and Tout:
    result = nest.pack_sequence_as(
        out_structure, result, expand_composites=True)

  return result if is_list_or_tuple else result[0]
def _internal_py_func(func,
                      inp,
                      Tout,
                      stateful=None,
                      eager=False,
                      is_grad_func=False,
                      name=None,
                      use_tape_cache=True):
  """Registers `func` and emits the op that invokes it.

  See documentation for py_func and eager_py_func. `use_tape_cache` is
  forwarded to `EagerFunc` when `eager` is true.

  Raises:
    ValueError: if `func` is not callable.
  """
  if not callable(func):
    raise ValueError(
        "Expected func to be callable, got func of type {}".format(
            type(func)))

  original_func = func
  func = autograph.do_not_convert(func)

  # A single output dtype is wrapped in a list; remember which form the
  # caller used so the return value can be unwrapped to match.
  is_list_or_tuple = isinstance(Tout, (list, tuple))
  if not is_list_or_tuple:
    Tout = [Tout]

  if eager:
    func = EagerFunc(func, Tout, is_grad_func, use_tape_cache=use_tape_cache)

  # Tying the registered function's lifetime with the current default graph is
  # not reliable. For example, Estimator-based binaries may switch graphs in
  # between model training and evaluation, via saved_model. Those binaries work
  # because the original function is global, and break once the registered
  # function is an anonymous lambda, like the one produced by do_not_convert.
  # To avoid breaking those cases, we attach the wrapper to the original
  # function so that their lifetime is connected.
  # TODO(b/144286616): Remove this.
  if tf_inspect.isfunction(original_func):
    # original_func may be a descriptor
    # (https://docs.python.org/3/howto/descriptor.html), and attributes cannot
    # be attached to those, hence the check.
    original_func.ag_dnc_wrapper__ = func

  token = _py_funcs.insert(func)

  # Also tie the registered function to the outermost enclosing graph: step
  # out of nested function graphs until the graph no longer changes.
  def _enclosing_graph(g):
    if isinstance(g, function._FuncGraph):  # pylint: disable=protected-access
      return g._outer_graph  # pylint: disable=protected-access
    if isinstance(g, func_graph.FuncGraph):
      return g.outer_graph
    return g

  graph = ops.get_default_graph()
  parent = _enclosing_graph(graph)
  while parent is not graph:
    graph = parent
    parent = _enclosing_graph(graph)

  # TODO(zhifengc): Consider adding a Graph method to collect
  # `cleanup` objects in one of its member.
  if not hasattr(graph, "_py_funcs_used_in_graph"):
    graph._py_funcs_used_in_graph = []  # pylint: disable=protected-access

  # Keep the function referenced by the graph so it stays alive exactly as
  # long as the graph; after that it is left to the garbage collector.
  graph._py_funcs_used_in_graph.append(func)  # pylint: disable=protected-access

  if eager:
    result = gen_script_ops.eager_py_func(
        input=inp,
        token=token,
        is_async=context.is_async(),
        Tout=Tout,
        name=name)
  else:
    make_op = (gen_script_ops.py_func if stateful
               else gen_script_ops.py_func_stateless)
    result = make_op(input=inp, token=token, Tout=Tout, name=name)

  return result if is_list_or_tuple else result[0]