def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
                   donated_invars, inline):
  """Compile ``fun`` via ``_xla_callable`` and execute it on ``args``.

  Args:
    fun: the wrapped function to compile and run.
    *args: runtime arguments; their ``arg_spec`` values key the compilation.
    device, backend, name, donated_invars: forwarded to ``_xla_callable``.
    inline: ignored here; only used at tracing time.

  Returns:
    The outputs of the compiled function.

  Raises:
    FloatingPointError: if the compiled function raises it (only possible when
      ``config.jax_debug_nans``/``config.jax_debug_infs`` is set). The
      de-optimized function is re-run first so that it can raise a more precise
      error; if it does not raise, a FloatingPointError describing the
      discrepancy is raised, chained from the original error.
  """
  del inline  # Only used at tracing time
  compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
                               *unsafe_map(arg_spec, args))
  try:
    return compiled_fun(*args)
  except FloatingPointError as err:
    assert config.jax_debug_nans or config.jax_debug_infs  # compiled_fun can only raise in this case
    print(
        "Invalid value encountered in the output of a jit/pmap-ed function. "
        "Calling the de-optimized version.")
    # We want to run the wrapped function again (after _xla_callable already ran
    # it), but linear_util.WrappedFun instances are meant to be run only once.
    # In addition to re-executing the Python code, which is usually undesirable
    # but which config.jax_debug_nans is meant to opt into, we'll be re-executing
    # any linear_util.py-style side effects, i.e. re-populating Stores created
    # by any transformation_with_aux's applied to fun. Since this is
    # intentional here, to avoid "Store occupied" errors we clone the WrappedFun
    # with empty stores.
    stores = [lu.Store() for _ in fun.stores]
    clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params)
    with core.new_sublevel():
      clone.call_wrapped(*args)  # probably won't return
    # Reaching this point means the compiled function produced an invalid value
    # but the de-optimized one did not. Previously this fell through to
    # `return out` with `out` unbound (UnboundLocalError); report it properly.
    raise FloatingPointError(
        "An invalid value was encountered in the output of the jit/pmap-ed "
        f"function {name}, but the de-optimized version did not produce "
        "invalid values when re-run with config.jax_debug_nans and/or "
        "config.jax_debug_infs set. This can result from optimizations in the "
        "compiled function, or from the function returning nan/inf constants."
    ) from err
def bind(self, f, *args, **params):
  """Bind this call-style primitive to ``f`` applied to ``args``.

  When no trace is found among ``args``, ``self.impl`` is evaluated inside a
  new sublevel at the trace stack's next level. Otherwise the arguments are
  raised to tracers of the top trace and the call is dispatched:

  * to ``custom_batch_rules[self]`` when the top trace is a
    ``batching.BatchTrace`` and a custom rule is registered for ``self``;
  * otherwise to ``top_trace.process_call``, using the ``'jvp'`` subcall of
    this primitive for ``ad.JVPTrace`` and ``self`` for any other trace, with
    each output lowered via ``jax_core.full_lower``.

  Pending environment-trace work from ``jax_core.process_env_traces`` is
  applied to the outputs before returning them.
  """
  top_trace = jax_core.find_top_trace(args)
  stack = jax_core.thread_local_state.trace_state.trace_stack
  if top_trace is None:
    level = stack.next_level(True)
  else:
    level = top_trace.level

  f, env_trace_todo = jax_core.process_env_traces(
      f, self, level, tuple(params.items()))

  if top_trace is None:
    with jax_core.new_sublevel():
      outs = self.impl(f, *args, **params)
  else:
    tracers = safe_map(top_trace.full_raise, args)
    has_custom_batching = (isinstance(top_trace, batching.BatchTrace)
                           and self in custom_batch_rules)
    if has_custom_batching:
      outs = custom_batch_rules[self](top_trace, f, tracers, params)
    else:
      if isinstance(top_trace, ad.JVPTrace):
        call_prim = self.subcall('jvp')
      else:
        call_prim = self
      raw_outs = top_trace.process_call(call_prim, f, tracers, params)
      outs = safe_map(jax_core.full_lower, raw_outs)

  return jax_core.apply_todos(env_trace_todo(), outs)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
                   donated_invars, inline, keep_unused: bool):
  """Compile ``fun`` via ``_xla_callable`` and execute it on ``args``.

  Args:
    fun: the wrapped function to compile and run.
    *args: runtime arguments; their ``arg_spec`` values key the compilation.
    device, backend, name, donated_invars, keep_unused: forwarded unchanged to
      ``_xla_callable``.
    inline: ignored here; only used at tracing time.

  Returns:
    Whatever the compiled function returns for ``args``.

  Raises:
    FloatingPointError: if the compiled function raises it (only possible when
      ``config.jax_debug_nans``/``config.jax_debug_infs`` is set). The
      de-optimized function is re-run first; any error it raises propagates.
      If it completes without raising, a FloatingPointError explaining the
      discrepancy is raised instead.
  """
  del inline  # Only used at tracing time
  arg_specs = unsafe_map(arg_spec, args)
  if fun.in_type is not None:
    # When `fun` carries an explicit input type, the first component of each
    # arg spec is replaced with None (presumably because in_type already
    # describes it -- TODO confirm).
    arg_specs = [(None, *xs) for _, *xs in arg_specs]
  compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
                               keep_unused, *arg_specs)
  try:
    return compiled_fun(*args)
  except FloatingPointError:
    assert config.jax_debug_nans or config.jax_debug_infs  # compiled_fun can only raise in this case
    print(
        "Invalid value encountered in the output of a jit-decorated function. "
        "Calling the de-optimized version.")
    # We want to run the wrapped function again (after _xla_callable already ran
    # it), but linear_util.WrappedFun instances are meant to be run only once.
    # In addition to re-executing the Python code, which is usually undesirable
    # but which config.jax_debug_nans is meant to opt into, we'll be
    # re-executing any linear_util.py-style side effects, i.e. re-populating
    # Stores created by any transformation_with_aux's applied to fun. Since this
    # is intentional here, to avoid "Store occupied" errors we clone the
    # WrappedFun with empty stores.
    stores = [lu.Store() for _ in fun.stores]
    clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params,
                          fun.in_type)
    with core.new_sublevel():
      _ = clone.call_wrapped(*args)  # may raise, not return

    # If control reaches this line, we got a NaN on the output of `compiled_fun`
    # but not `clone.call_wrapped` on the same arguments. Let's tell the user.
    fun_info = pe.fun_sourceinfo(fun.f)
    msg = (
        "An invalid value was encountered in the output of the "
        f"`jit`-decorated function {fun_info}. Because "
        "config.jax_debug_nans and/or config.jax_debug_infs is set, the "
        "de-optimized function (i.e., the function as if the `jit` "
        "decorator were removed) was called in an attempt to get a more "
        "precise error message. However, the de-optimized function did not "
        "produce invalid values during its execution. This behavior can "
        "result from `jit` optimizations causing the invalid value to be "
        "produced. It may also arise from having nan/inf constants as "
        "outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
        "\n\n"
        "It may be possible to avoid the invalid value by removing the "
        "`jit` decorator, at the cost of losing optimizations. "
        "\n\n"
        "If you see this error, consider opening a bug report at "
        "https://github.com/google/jax.")
    raise FloatingPointError(msg)
def _nest_impl(f, *args, **_):
  """Evaluate ``f`` on the positional ``args`` inside a new sublevel.

  Any keyword arguments are accepted and ignored.
  """
  sublevel = jax_core.new_sublevel()
  with sublevel:
    outputs = f.call_wrapped(*args)
  return outputs
def impl(self, fun, fwd, bwd, *args, out_trees):
  """Evaluate ``fun`` directly on ``args`` inside a new sublevel.

  ``fwd``, ``bwd`` and ``out_trees`` are accepted but not used by this
  evaluation rule.
  """
  del out_trees, bwd, fwd  # unused here
  sublevel = core.new_sublevel()
  with sublevel:
    result = fun.call_wrapped(*args)
  return result
def impl(self, fun, _, *args):
  """Evaluate ``fun`` on ``args`` inside a new sublevel.

  The second positional argument is accepted but ignored.
  """
  sublevel = core.new_sublevel()
  with sublevel:
    result = fun.call_wrapped(*args)
  return result
def impl(self, f, *args, **params):
  """Evaluate ``f`` on the positional ``args`` inside a new sublevel.

  Keyword ``params`` are accepted but not used by this evaluation rule.
  """
  del params  # unused here
  sublevel = jax_core.new_sublevel()
  with sublevel:
    outputs = f.call_wrapped(*args)
  return outputs