Exemple #1
0
 def _start_check_health_thread(self):
     """Waits for all workers to be up, then starts the health-check thread.

     A dummy all-reduce through `self._host_cross_device_ops` is used as a
     barrier so that the first health check does not fail immediately while
     peers are still starting.

     Raises:
       RuntimeError: if the barrier all-reduce raises
         `errors.DeadlineExceededError`; the timeout passed to it is
         `self._check_health_initial_timeout`.
     """
     # Use a dummy all-reduce as a barrier to wait for all workers to be up,
     # otherwise the check health may fail immediately.
     #
     # TODO(b/151232436): change to an explicit barrier if we have it.
     dummy_value = ops.convert_to_tensor([])
     logging.info("Waiting for the cluster, timeout = %s",
                  self._check_health_initial_timeout or "inf")
     try:
         self._host_cross_device_ops.reduce(
             reduce_util.ReduceOp.SUM,
             dummy_value,
             dummy_value,
             experimental_hints=collective_util.Hints(
                 timeout_seconds=self._check_health_initial_timeout))
         # In async mode, wait for pending ops while still inside the try so
         # that an error from the reduce is caught by the handler below.
         if context.is_async():
             context.async_wait()
     except errors.DeadlineExceededError as e:
         # Use %s rather than %d: the timeout may be a float, or None (see the
         # "inf" fallback above). %d would truncate the former and raise a
         # TypeError on the latter, hiding the real timeout error.
         raise RuntimeError(
             "Timeout waiting for the cluster, timeout is %s seconds" %
             self._check_health_initial_timeout) from e
     self._check_health_thread_should_stop = threading.Event()
     # Start the thread as daemon to avoid it blocking the program from exiting.
     # We try best to shutdown the thread but __del__ is not guaranteed to be
     # called when program exits.
     # NOTE(review): `self._check_health` is started without arguments; confirm
     # its signature takes none.
     self._check_health_thread = threading.Thread(target=self._check_health,
                                                  daemon=True)
     self._check_health_thread.start()
Exemple #2
0
 def _check_health(self, device, group_key, instance_key):
     """Repeatedly runs a dummy all-reduce to detect failed workers.

     Loops until `self._check_health_thread_should_stop` is set.

     Each iteration does the following:
       - Issues a RING all-reduce of a (32, 32) float32 tensor on `device`,
         using the given pre-allocated keys.
       - Waits up to `self._check_health_interval` seconds before the next
         check.

     Only the first all-reduce is given a timeout
     (`self._check_health_initial_timeout`).

     Failure handling:
       - Errors that signal a peer problem abort all collective ops in the
         eager context with UNAVAILABLE, and checking continues.
       - Any other exception aborts them with INTERNAL and ends the loop.

     Args:
       device: device on which the health-check all-reduce is placed.
       group_key: pre-allocated collective group key.
       instance_key: pre-allocated collective instance key.
     """
     first = True
     # We need to use a large enough value so that the all-reduce forms a
     # complete RING. In RING implementation, when value is too small, the
     # all-reduce may degrade into broadcasts. This means that some worker
     # failure may not be detected.
     value = array_ops.ones((32, 32), dtype=dtypes.float32)
     while True:
         if self._check_health_thread_should_stop.is_set():
             return
         timeout = None
         if first:
             # For the first check health we set timeout since it may need to do
             # group resolution, which may hang if the cluster is never healthy.
             timeout = self._check_health_initial_timeout
             first = False
         try:
             # We use a dummy all-reduce as a way to check the health of a
             # cluster. For RING it should be able to detect failed workers in
             # the cluster if the values are large enough.
             #
             # We're not using CrossDeviceOps because we need to run it with
             # pre-allocated group and instance keys.
             #
             # TODO(b/151232436): Replace the reduce with a check health op once we
             # add that.
             with ops.device(device):
                 collective_ops.all_reduce(value,
                                           group_size=self._num_workers,
                                           group_key=group_key,
                                           instance_key=instance_key,
                                           merge_op="Add",
                                           final_op="Id",
                                           subdiv_offsets=[0],
                                           communication_hint="ring",
                                           timeout=timeout)
                 if context.is_async():
                     context.async_wait()
         except (errors.UnavailableError, errors.DeadlineExceededError,
                 errors.FailedPreconditionError,
                 errors.CancelledError) as e:
             # TODO(b/151232436): Always raise UnavailableError when a peer fails.
             # Now there could be many kinds of errors:
             # - Unavailable: when the peer is not reachable, e.g. it's down.
             # - FailedPrecondition: when the peer has restarted.
             # - DeadlineExceeded: when the first check health exceeds the deadline,
             #   e.g. the peers take too long to be ready.
             # - Cancelled: when failures in organic collectives aborts first,
             #   outgoing RPCs may be aborted with Cancelled.
             # Include the error in the log; previously it only reached the
             # abort message.
             logging.error(
                 "Cluster check alive failed, aborting collectives: %s", e)
             context.context().abort_collective_ops(
                 errors.UNAVAILABLE, "cluster check alive failed: %s" % e)
         except Exception as e:  # pylint: disable=broad-except
             logging.exception("Unexpected exception in check alive.")
             context.context().abort_collective_ops(
                 errors.INTERNAL,
                 "unexpected exception in check alive: %s" % e)
             return
         # Block on the stop event instead of time.sleep() so that a stop
         # request wakes this thread immediately rather than after a full
         # interval. Event.wait() returns True once the event has been set.
         if self._check_health_thread_should_stop.wait(
                 self._check_health_interval):
             return
Exemple #3
0
    def _start_check_health_thread(self):
        """Waits for all workers to be up, then starts the health-check thread.

        A dummy RING all-reduce through `self._host_cross_device_ops` serves as
        a barrier, so the first health check does not fail merely because
        peers are still starting.

        Raises:
          RuntimeError: if the barrier all-reduce raises
            `errors.DeadlineExceededError`; the timeout passed to it is
            `self._check_health_initial_timeout`.
        """
        # Use a dummy all-reduce as a barrier to wait for all workers to be up,
        # otherwise the check health may fail immediately.

        # Use array_ops.identity to create the dummy tensor so that we have a new
        # Tensor. A constant may be cached on a /job:localhost device, which
        # will cause some code that relies on tensor.device to error.
        #
        # TODO(b/151232436): change to an explicit barrier if we have it.
        dummy_value = array_ops.identity([])
        logging.info("Waiting for the cluster, timeout = %s",
                     self._check_health_initial_timeout or "inf")
        try:
            self._host_cross_device_ops.reduce(
                reduce_util.ReduceOp.SUM,
                dummy_value,
                dummy_value,
                options=collective_util.Options(
                    timeout_seconds=self._check_health_initial_timeout,
                    implementation=collective_util.CommunicationImplementation.
                    RING))
            # In async mode, wait for pending ops while still inside the try so
            # that an error from the reduce is caught by the handler below.
            if context.is_async():
                context.async_wait()
        except errors.DeadlineExceededError as e:
            # Use %s rather than %d: the timeout may be a float, or None (see
            # the "inf" fallback above). %d would truncate the former and raise
            # a TypeError on the latter, hiding the real timeout error.
            raise RuntimeError(
                "Timeout waiting for the cluster, timeout is %s seconds" %
                self._check_health_initial_timeout) from e
        logging.info("Cluster is ready.")
        self._check_health_thread_should_stop = threading.Event()
        # Start the thread as daemon to avoid it blocking the program from exiting.
        # We try best to shutdown the thread but __del__ is not guaranteed to be
        # called when program exits.
        self._check_health_thread = threading.Thread(target=self._check_health,
                                                     daemon=True)
        self._check_health_thread.start()
Exemple #4
0
def _internal_py_func(func,
                      inp,
                      Tout,
                      stateful=None,
                      eager=False,
                      is_grad_func=False,
                      name=None):
  """Registers `func` and returns the outputs of an op that invokes it.

  See documentation for py_func and eager_py_func.

  Args:
    func: Python callable to invoke from the graph.
    inp: inputs passed to the generated op.
    Tout: a single output type, or a list/tuple of output types.
    stateful: when `eager` is falsy, selects `py_func` (truthy) over
      `py_func_stateless` (falsy).
    eager: if truthy, wrap `func` in `EagerFunc` and emit `eager_py_func`.
    is_grad_func: forwarded to `EagerFunc`.
    name: optional op name.

  Returns:
    The generated op's outputs when `Tout` is a list or tuple, otherwise its
    first output.

  Raises:
    ValueError: if `func` is not callable.
  """
  if not callable(func):
    raise ValueError("Expected func to be callable, got func of type {}".format(
        type(func)))

  is_list_or_tuple = isinstance(Tout, (list, tuple))
  if not is_list_or_tuple:
    Tout = [Tout]

  if eager:
    func = EagerFunc(func, Tout, is_grad_func)

  token = _py_funcs.insert(func)

  # Tie the registered function's lifetime to the outermost enclosing graph:
  # step out of (legacy or new-style) function graphs until a step no longer
  # changes the graph.
  graph = ops.get_default_graph()
  while True:
    if isinstance(graph, function._FuncGraph):  # pylint: disable=protected-access
      parent = graph._outer_graph  # pylint: disable=protected-access
    elif isinstance(graph, func_graph.FuncGraph):
      parent = graph.outer_graph
    else:
      parent = graph
    if parent is graph:
      break
    graph = parent

  # TODO(zhifengc): Consider adding a Graph method to collect
  # `cleanup` objects in one of its members.
  if not hasattr(graph, "_py_funcs_used_in_graph"):
    graph._py_funcs_used_in_graph = []  # pylint: disable=protected-access

  # The graph holds a reference to `func`, so it stays alive while the graph
  # does; once the graph is gone, the garbage collector may reclaim it.
  graph._py_funcs_used_in_graph.append(func)  # pylint: disable=protected-access

  if eager:
    result = gen_script_ops.eager_py_func(
        input=inp,
        token=token,
        is_async=context.is_async(),
        Tout=Tout,
        name=name)
  else:
    py_func_op = (gen_script_ops.py_func if stateful
                  else gen_script_ops.py_func_stateless)
    result = py_func_op(input=inp, token=token, Tout=Tout, name=name)

  if is_list_or_tuple:
    return result
  return result[0]
Exemple #5
0
    def __call__(self, device, token, args):
        """Passes `args` to `self._func`, which is executed eagerly.

        The call runs under a dedicated executor and a `GradientTape` that
        watches every floating-point tensor found in `args`. The tape, the
        inputs and the outputs are stored in `tape_cache` under
        `compat.as_bytes(token)`.

        Args:
          device: device the returned tensors are copied to via identity.
          token: key for the `tape_cache` entry.
          args: positional arguments for `self._func`; each may be a nested
            structure of tensors.

        Returns:
          A list of tensors if `self._func` returns a tuple or list, None if
          it returns None, otherwise a single tensor.
        """
        func_executor = executor.Executor(context.is_async())
        with context.executor_scope(func_executor):
            with context.eager_mode(), backprop.GradientTape() as tape:
                # Only watch tensors with a floating dtype.
                for tensor in args:
                    for t in nest.flatten(tensor):
                        if t.dtype.is_floating:
                            tape.watch(t)
                ret = self._func(*args)
                # Use tf.identity to copy the returned tensors to device if necessary.
                with ops.device(device):
                    if isinstance(ret, (tuple, list)):
                        outputs = [
                            array_ops.identity(self._convert(x, dtype=dtype))
                            for (x, dtype) in zip(ret, self._out_dtypes)
                        ]
                    elif ret is None:
                        outputs = None
                    else:
                        outputs = array_ops.identity(
                            self._convert(ret, dtype=self._out_dtypes[0]))
            tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
        # Wait for the ops dispatched to `func_executor` to finish before
        # handing the outputs back. The return must come after this call;
        # returning inside the scope would make the wait unreachable. The wait
        # presumably only blocks when the executor is async.
        func_executor.wait()
        return outputs
Exemple #6
0
def _internal_py_func(func,
                      inp,
                      Tout,
                      stateful=None,
                      use_eager_py_func=False,
                      is_grad_func=False,
                      name=None):
    """Registers `func` and returns the outputs of an op that invokes it.

    See documentation for py_func and eager_py_func.

    Args:
      func: Python callable to invoke from the graph.
      inp: iterable of inputs; materialized into a list.
      Tout: a single output dtype/TypeSpec, or a list/tuple of them.
      stateful: when `use_eager_py_func` is falsy, selects `py_func` (truthy)
        over `py_func_stateless` (falsy).
      use_eager_py_func: if truthy, wrap `func` in `EagerFunc` and emit
        `eager_py_func`. Only this path flattens CompositeTensor inputs and
        TypeSpec outputs.
      is_grad_func: forwarded to `EagerFunc`.
      name: optional op name.

    Returns:
      The generated op's outputs when `Tout` is a list or tuple, otherwise
      its first output. On the composite path the outputs are repacked into
      the structure returned by `_wrap_for_composites`.

    Raises:
      ValueError: if `func` is not callable.
    """
    if not callable(func):
        raise ValueError(
            f"Expected func to be callable. Received func={func} of type "
            f"{type(func)}.")

    unwrapped_func = func
    func = autograph.do_not_convert(func)
    inp = list(inp)

    # Normalize Tout to a list of dtypes / TypeSpecs.
    is_list_or_tuple = isinstance(Tout, (list, tuple))
    if not is_list_or_tuple:
        Tout = [Tout]
    Tout = [_as_dtype_or_type_spec(t) for t in Tout]

    # Composite inputs/outputs are only flattened on the eager path.
    handle_composite_tensors = False
    if use_eager_py_func:
        has_composite_input = any(
            isinstance(v, composite_tensor.CompositeTensor) for v in inp)
        handle_composite_tensors = has_composite_input or any(
            isinstance(t, type_spec.TypeSpec) for t in Tout)
    if handle_composite_tensors:
        func, inp, Tout, out_structure = _wrap_for_composites(func, inp, Tout)

    if use_eager_py_func:
        func = EagerFunc(func, Tout, is_grad_func)

    # Tying the registered function's lifetime with the current default graph is
    # not reliable. For example, Estimator-based binaries may switch graphs in
    # between model training and evaluation, via saved_model. Those binaries work
    # because the original function is global, and break once the registered
    # function is an anonymous lambda, like the one produced by do_not_convert.
    # To avoid breaking those cases, we attach the wrapper to the original
    # function so that their lifetime is connected.
    # TODO(b/144286616): Remove this.
    if tf_inspect.isfunction(unwrapped_func):
        # Only plain functions get the attribute: the callable may be a
        # descriptor (https://docs.python.org/3/howto/descriptor.html), which
        # cannot hold attributes.
        unwrapped_func.ag_dnc_wrapper__ = func

    token = _py_funcs.insert(func)

    # Tie the registered function's lifetime to the outermost enclosing graph:
    # step out of (legacy or new-style) function graphs until a step no longer
    # changes the graph.
    graph = ops.get_default_graph()
    while True:
        if isinstance(graph, function._FuncGraph):  # pylint: disable=protected-access
            parent = graph._outer_graph  # pylint: disable=protected-access
        elif isinstance(graph, func_graph.FuncGraph):
            parent = graph.outer_graph
        else:
            parent = graph
        if parent is graph:
            break
        graph = parent

    # TODO(zhifengc): Consider adding a Graph method to collect
    # `cleanup` objects in one of its members.
    if not hasattr(graph, "_py_funcs_used_in_graph"):
        graph._py_funcs_used_in_graph = []  # pylint: disable=protected-access

    # The graph holds a reference to `func`, so it stays alive while the graph
    # does; once the graph is gone, the garbage collector may reclaim it.
    graph._py_funcs_used_in_graph.append(func)  # pylint: disable=protected-access

    if use_eager_py_func:
        result = gen_script_ops.eager_py_func(input=inp,
                                              token=token,
                                              is_async=context.is_async(),
                                              Tout=Tout,
                                              name=name)
    else:
        py_func_op = (gen_script_ops.py_func if stateful
                      else gen_script_ops.py_func_stateless)
        result = py_func_op(input=inp, token=token, Tout=Tout, name=name)

    if handle_composite_tensors and Tout:
        result = nest.pack_sequence_as(out_structure,
                                       result,
                                       expand_composites=True)

    if is_list_or_tuple:
        return result
    return result[0]
Exemple #7
0
def _internal_py_func(func,
                      inp,
                      Tout,
                      stateful=None,
                      eager=False,
                      is_grad_func=False,
                      name=None,
                      use_tape_cache=True):
    """Registers `func` and returns the outputs of an op that invokes it.

    See documentation for py_func and eager_py_func.

    Args:
      func: Python callable to invoke from the graph.
      inp: inputs passed to the generated op.
      Tout: a single output type, or a list/tuple of output types.
      stateful: when `eager` is falsy, selects `py_func` (truthy) over
        `py_func_stateless` (falsy).
      eager: if truthy, wrap `func` in `EagerFunc` and emit `eager_py_func`.
      is_grad_func: forwarded to `EagerFunc`.
      name: optional op name.
      use_tape_cache: forwarded to `EagerFunc`.

    Returns:
      The generated op's outputs when `Tout` is a list or tuple, otherwise its
      first output.

    Raises:
      ValueError: if `func` is not callable.
    """
    if not callable(func):
        raise ValueError(
            "Expected func to be callable, got func of type {}".format(
                type(func)))

    unwrapped_func = func
    func = autograph.do_not_convert(func)

    is_list_or_tuple = isinstance(Tout, (list, tuple))
    if not is_list_or_tuple:
        Tout = [Tout]

    if eager:
        func = EagerFunc(func,
                         Tout,
                         is_grad_func,
                         use_tape_cache=use_tape_cache)

    # Tying the registered function's lifetime with the current default graph is
    # not reliable. For example, Estimator-based binaries may switch graphs in
    # between model training and evaluation, via saved_model. Those binaries work
    # because the original function is global, and break once the registered
    # function is an anonymous lambda, like the one produced by do_not_convert.
    # To avoid breaking those cases, we attach the wrapper to the original
    # function so that their lifetime is connected.
    # TODO(b/144286616): Remove this.
    if tf_inspect.isfunction(unwrapped_func):
        # Only plain functions get the attribute: the callable may be a
        # descriptor (https://docs.python.org/3/howto/descriptor.html), which
        # cannot hold attributes.
        unwrapped_func.ag_dnc_wrapper__ = func

    token = _py_funcs.insert(func)

    # Tie the registered function's lifetime to the outermost enclosing graph:
    # step out of (legacy or new-style) function graphs until a step no longer
    # changes the graph.
    graph = ops.get_default_graph()
    while True:
        if isinstance(graph, function._FuncGraph):  # pylint: disable=protected-access
            parent = graph._outer_graph  # pylint: disable=protected-access
        elif isinstance(graph, func_graph.FuncGraph):
            parent = graph.outer_graph
        else:
            parent = graph
        if parent is graph:
            break
        graph = parent

    # TODO(zhifengc): Consider adding a Graph method to collect
    # `cleanup` objects in one of its members.
    if not hasattr(graph, "_py_funcs_used_in_graph"):
        graph._py_funcs_used_in_graph = []  # pylint: disable=protected-access

    # The graph holds a reference to `func`, so it stays alive while the graph
    # does; once the graph is gone, the garbage collector may reclaim it.
    graph._py_funcs_used_in_graph.append(func)  # pylint: disable=protected-access

    if eager:
        result = gen_script_ops.eager_py_func(input=inp,
                                              token=token,
                                              is_async=context.is_async(),
                                              Tout=Tout,
                                              name=name)
    else:
        py_func_op = (gen_script_ops.py_func if stateful
                      else gen_script_ops.py_func_stateless)
        result = py_func_op(input=inp, token=token, Tout=Tout, name=name)

    if is_list_or_tuple:
        return result
    return result[0]