Example 1
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
                   donated_invars, inline):
    del inline  # Only used at tracing time
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
                                 *unsafe_map(arg_spec, args))
    try:
        out = compiled_fun(*args)
    except FloatingPointError:
        assert config.jax_debug_nans or config.jax_debug_infs  # compiled_fun can only raise in this case
        print(
            "Invalid value encountered in the output of a jit/pmap-ed function. "
            "Calling the de-optimized version.")
        # We want to run the wrapped function again (after _xla_callable already ran
        # it), but linear_util.WrappedFun instances are meant to be run only once.
        # In addition to re-executing the Python code, which is usually undesirable
        # but which config.jax_debug_nans is meant to opt into, we'll be re-executing
        # any linear_util.py-style side effects, i.e. re-populating Stores created
        # by any transformation_with_aux's applied to fun. Since this is
        # intentional here, to avoid "Store occupied" errors we clone the WrappedFun
        # with empty stores.
        stores = [lu.Store() for _ in fun.stores]
        clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params)
        with core.new_sublevel():
            _ = clone.call_wrapped(*args)  # probably won't return
    return out
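The comment block above describes cloning a linear_util.WrappedFun with fresh stores so it can be run a second time. A minimal sketch of that pattern in isolation, assuming the internal jax.linear_util API from the JAX version these snippets come from (the wrapped function and the aux transformation here are hypothetical, not part of the snippet):

from jax import linear_util as lu

@lu.transformation_with_aux
def _with_marker(*args):
    # Hypothetical aux transformation: runs the function, then stores a marker.
    ans = yield args, {}
    yield ans, "ran"

def run_twice(f, *args):
    wrapped, aux = _with_marker(lu.wrap_init(f))
    first = wrapped.call_wrapped(*args)  # fills the aux Store
    # Calling wrapped.call_wrapped again would raise "Store occupied",
    # so clone it with empty stores, as _xla_call_impl does above.
    stores = [lu.Store() for _ in wrapped.stores]
    clone = lu.WrappedFun(wrapped.f, wrapped.transforms, stores, wrapped.params)
    second = clone.call_wrapped(*args)
    return first, second
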
Example 2
def bind(self, f, *args, **params):
    top_trace = jax_core.find_top_trace(args)
    trace_stack = jax_core.thread_local_state.trace_state.trace_stack
    level = (trace_stack.next_level(True)
             if top_trace is None else top_trace.level)
    params_tuple = tuple(params.items())
    f, env_trace_todo = jax_core.process_env_traces(
        f, self, level, params_tuple)
    if top_trace is None:
        with jax_core.new_sublevel():
            outs = self.impl(f, *args, **params)
    else:
        tracers = safe_map(top_trace.full_raise, args)
        if (isinstance(top_trace, batching.BatchTrace)
                and self in custom_batch_rules):
            outs = custom_batch_rules[self](top_trace, f, tracers, params)
        else:
            if isinstance(top_trace, ad.JVPTrace):
                prim = self.subcall('jvp')
            else:
                prim = self
            outs = safe_map(
                jax_core.full_lower,
                top_trace.process_call(prim, f, tracers, params))
    return jax_core.apply_todos(env_trace_todo(), outs)
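This bind follows the standard call-primitive dispatch pattern from jax.core: find the highest-level trace among the arguments, lift everything into that trace, and let the trace decide how to process the call; only when no trace is active does it fall through to impl under a fresh sublevel. A stripped-down sketch of just that skeleton (the custom batching/JVP special cases above are omitted, and the names track jax.core as of these snippets):

from jax import core as jax_core
from jax.util import safe_map

def call_bind(primitive, f, *args, **params):
    top_trace = jax_core.find_top_trace(args)
    if top_trace is None:
        # No transformation is active: run the concrete implementation
        # one sublevel down so nested calls stay distinct.
        with jax_core.new_sublevel():
            return primitive.impl(f, *args, **params)
    # Otherwise lift the arguments to the top trace and defer to it.
    tracers = safe_map(top_trace.full_raise, args)
    outs = top_trace.process_call(primitive, f, tracers, params)
    return safe_map(jax_core.full_lower, outs)
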
Example 3
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
                   donated_invars, inline, keep_unused: bool):
    del inline  # Only used at tracing time
    arg_specs = unsafe_map(arg_spec, args)
    if fun.in_type is not None:
        arg_specs = [(None, *xs) for _, *xs in arg_specs]
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
                                 keep_unused, *arg_specs)
    try:
        return compiled_fun(*args)
    except FloatingPointError:
        assert config.jax_debug_nans or config.jax_debug_infs  # compiled_fun can only raise in this case
        print(
            "Invalid value encountered in the output of a jit-decorated function. "
            "Calling the de-optimized version.")
        # We want to run the wrapped function again (after _xla_callable already ran
        # it), but linear_util.WrappedFun instances are meant to be run only once.
        # In addition to re-executing the Python code, which is usually undesirable
        # but which config.jax_debug_nans is meant to opt into, we'll be
        # re-executing any linear_util.py-style side effects, i.e. re-populating
        # Stores created by any transformation_with_aux's applied to fun. Since this
        # is intentional here, to avoid "Store occupied" errors we clone the
        # WrappedFun with empty stores.
        stores = [lu.Store() for _ in fun.stores]
        clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params,
                              fun.in_type)

        with core.new_sublevel():
            _ = clone.call_wrapped(*args)  # may raise, not return

        # If control reaches this line, we got a NaN on the output of `compiled_fun`
        # but not `clone.call_wrapped` on the same arguments. Let's tell the user.
        fun_info = pe.fun_sourceinfo(fun.f)
        msg = (
            "An invalid value was encountered in the output of the "
            f"`jit`-decorated function {fun_info}. Because "
            "config.jax_debug_nans and/or config.jax_debug_infs is set, the "
            "de-optimized function (i.e., the function as if the `jit` "
            "decorator were removed) was called in an attempt to get a more "
            "precise error message. However, the de-optimized function did not "
            "produce invalid values during its execution. This behavior can "
            "result from `jit` optimizations causing the invalud value to be "
            "produced. It may also arise from having nan/inf constants as "
            "outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
            "\n\n"
            "It may be possible to avoid the invalid value by removing the "
            "`jit` decorator, at the cost of losing optimizations. "
            "\n\n"
            "If you see this error, consider opening a bug report at "
            "https://github.com/google/jax.")
        raise FloatingPointError(msg)
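For reference, the error path above is only reachable from user code when NaN/Inf checking is enabled. A small, self-contained reproduction (this exercises the FloatingPointError branch; the exact message text depends on the JAX version):

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

@jax.jit
def f(x):
    return jnp.log(x)  # NaN for negative inputs

try:
    f(-1.0)
except FloatingPointError as e:
    # With jax_debug_nans set, the compiled call raises and JAX re-runs
    # the de-optimized function to localize the invalid value.
    print(e)
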
Example 4
def _nest_impl(f, *args, **_):
    with jax_core.new_sublevel():
        return f.call_wrapped(*args)
Example 5
def impl(self, fun, fwd, bwd, *args, out_trees):
    del fwd, bwd, out_trees
    with core.new_sublevel():
        return fun.call_wrapped(*args)
Example 6
def impl(self, fun, _, *args):
    with core.new_sublevel():
        return fun.call_wrapped(*args)
Example 7
def impl(self, f, *args, **params):
    del params
    with jax_core.new_sublevel():
        return f.call_wrapped(*args)
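Examples 4 through 7 are all the same default impl rule for a call-like primitive: open a fresh sublevel and run the wrapped function directly, ignoring any trace-time parameters (Examples 5 and 6 appear to be custom-derivative call primitives that drop their fwd/bwd or jvp arguments at impl time). A minimal hypothetical primitive wired up with this pattern, using jax.core's CallPrimitive as it existed when these snippets were written:

from jax import core as jax_core

# Hypothetical call primitive; its impl mirrors Examples 4-7.
my_call_p = jax_core.CallPrimitive("my_call")

def _my_call_impl(f, *args, **params):
    del params  # only meaningful while tracing
    with jax_core.new_sublevel():
        return f.call_wrapped(*args)

my_call_p.def_impl(_my_call_impl)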