def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:
    """Evaluate the wrapped function by binding ``custom_vjp_call_p``.

    Keyword arguments are merged into the positional arguments by
    ``_resolve_kwargs``. When ``self.nondiff_argnums`` is non-empty, the
    arguments at those positions are passed through ``_check_for_tracers``,
    the remaining positions are selected as dynamic arguments with
    ``argnums_partial``, and the static values are attached to the backward
    rule with ``_add_args``.

    Returns:
        The flat outputs of the bind, unflattened with the output tree
        obtained from ``lu.merge_linear_aux``.

    Raises:
        AttributeError: if ``self.fwd`` or ``self.bwd`` is falsy.
    """
    if not (self.fwd and self.bwd):
        msg = "No VJP defined for custom_vjp function {} using defvjp."
        raise AttributeError(msg.format(self.__name__))

    args = _resolve_kwargs(self.fun, args, kwargs)
    if not self.nondiff_argnums:
        # No static positions: every argument is a dynamic operand.
        f_, dyn_args = lu.wrap_init(self.fun), args
        fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
    else:
        for static_idx in self.nondiff_argnums:
            _check_for_tracers(args[static_idx])
        static_positions = set(self.nondiff_argnums)
        dyn_argnums = [
            pos for pos in range(len(args)) if pos not in static_positions
        ]
        f_, dyn_args = argnums_partial(
            lu.wrap_init(self.fun), dyn_argnums, args)
        static_args = [args[static_idx] for static_idx in self.nondiff_argnums]
        fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args)
        bwd = _add_args(lu.wrap_init(self.bwd), static_args)

    args_flat, in_tree = tree_flatten(dyn_args)
    in_avals = [
        core.raise_to_shaped(core.get_aval(leaf)) for leaf in args_flat
    ]
    flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
    flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
    flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
    out_flat = custom_vjp_call_p.bind(
        flat_fun, flat_fwd, flat_bwd, *args_flat, out_trees=out_trees)

    # merge_linear_aux returns a flag plus a value; when the flag is false
    # the value is indexed at 0 to obtain the output tree.
    fst, aux = lu.merge_linear_aux(out_tree, out_trees)
    if fst:
        result_tree = aux
    else:
        result_tree = aux[0]
    return tree_unflatten(result_tree, out_flat)
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
    """Evaluate the wrapped function by binding ``custom_jvp_call_p``.

    Keyword arguments are merged into the positional arguments by
    ``_resolve_kwargs``. When ``self.nondiff_argnums`` is non-empty, the
    arguments at those positions are replaced by ``_stop_gradient`` of
    themselves, the remaining positions are selected as dynamic arguments
    with ``argnums_partial``, and the static values are attached to the JVP
    rule with ``_add_args``.

    Returns:
        The flat outputs of the bind, unflattened with the output tree
        obtained from ``lu.merge_linear_aux``.

    Raises:
        AttributeError: if ``self.jvp`` is falsy.
    """
    if not self.jvp:
        msg = "No JVP defined for custom_jvp function {} using defjvp."
        raise AttributeError(msg.format(self.__name__))

    args = _resolve_kwargs(self.fun, args, kwargs)
    if not self.nondiff_argnums:
        # No static positions: every argument is a dynamic operand.
        f_, dyn_args = lu.wrap_init(self.fun), args
        jvp = lu.wrap_init(self.jvp)
    else:
        nondiff_argnums = set(self.nondiff_argnums)
        processed = []
        for pos, value in enumerate(args):
            if pos in nondiff_argnums:
                processed.append(_stop_gradient(value))
            else:
                processed.append(value)
        args = tuple(processed)
        diff_argnums = [
            pos for pos in range(len(args)) if pos not in nondiff_argnums
        ]
        f_, dyn_args = argnums_partial(
            lu.wrap_init(self.fun), diff_argnums, args,
            require_static_args_hashable=False)
        # Static values are read after the _stop_gradient pass above.
        static_args = [args[pos] for pos in self.nondiff_argnums]
        jvp = _add_args(lu.wrap_init(self.jvp), static_args)

    args_flat, in_tree = tree_flatten(dyn_args)
    flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
    flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
    out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
    _, result_tree = lu.merge_linear_aux(out_tree1, out_tree2)
    return tree_unflatten(result_tree, out_flat)
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    """Handle a custom-JVP call for this trace, threading the error state.

    The current ``error`` attribute is popped off ``self.main``; its ``err``
    and ``code`` are passed as the leading operands of ``prim.bind`` and its
    message items are given to both subtrace wrappers. The first two bind
    results are taken as the new ``err`` and ``code`` and stored back on
    ``self.main`` as a fresh ``Error``; the remaining results are wrapped in
    ``CheckifyTracer``.
    """
    in_vals = [tracer.val for tracer in tracers]
    prev_error = popattr(self.main, 'error')
    msg_items = tuple(prev_error.msgs.items())

    fun, msgs1 = checkify_subtrace(fun, self.main, msg_items)
    jvp, msgs2 = checkify_custom_jvp_subtrace(jvp, self.main, msg_items)

    err, code, *out_vals = prim.bind(
        fun, jvp, prev_error.err, prev_error.code, *in_vals)

    # Only the merged messages are needed; the accompanying flag is unused.
    _, out_msgs = lu.merge_linear_aux(msgs1, msgs2)
    self.main.error = Error(err, code, out_msgs)
    return [CheckifyTracer(self, val) for val in out_vals]
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    """Handle a custom-JVP call for this batching trace.

    Each tracer contributes its ``val`` and ``batch_dim``. Both ``fun`` and
    ``jvp`` are wrapped with the batching subtraces before ``prim.bind`` is
    called on the raw values. The outputs are re-wrapped as ``BatchTracer``
    objects, paired with the output dims from ``lu.merge_linear_aux``.
    """
    in_vals, in_dims = unzip2(
        (tracer.val, tracer.batch_dim) for tracer in tracers)
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
    out_vals = prim.bind(fun, jvp, *in_vals)

    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
        # In this case the dims are expected to be one sequence repeated
        # twice; keep only the first half.
        first_half = out_dims[:len(out_dims) // 2]
        assert out_dims == first_half * 2
        out_dims = first_half

    return [
        BatchTracer(self, val, dim) for val, dim in zip(out_vals, out_dims)
    ]
def bind(self, fun, jvp, *args):
    """Bind this primitive with ``fun``, its ``jvp`` rule, and operands.

    The operands are fully lowered, the top trace over them is located, and
    both callables are passed through ``core.process_env_traces``. The top
    trace's ``process_custom_jvp_call`` produces the outputs, which are
    fully lowered and post-processed by ``_apply_todos``.
    """
    args = map(core.full_lower, args)
    top_trace = core.find_top_trace(args)
    # Short-circuits to `top_trace` itself when it is falsy.
    level = top_trace and top_trace.level

    fun, env_trace_todo1 = core.process_env_traces(
        fun, self, level, (), None)
    jvp, env_trace_todo2 = core.process_env_traces(
        jvp, self, level, (), None)

    tracers = map(top_trace.full_raise, args)  # type: ignore
    outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)  # type: ignore

    _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
    lowered_outs = map(core.full_lower, outs)
    return _apply_todos(env_trace_todo, lowered_outs)
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
    """Handle a custom-VJP call for this batching trace.

    Each tracer contributes its ``val`` and ``batch_dim``. ``fun`` and
    ``fwd`` are wrapped with ``batch_subtrace`` and ``bwd`` with
    ``batch_custom_vjp_bwd`` before ``prim.bind`` is called on the raw
    values. The outputs are re-wrapped as ``BatchTracer`` objects carrying
    the current source info.

    Raises:
        ValueError: from the single-element unpacking if the mapped inputs
            do not yield exactly one distinct axis size.
    """
    in_vals, in_dims = unzip2(
        (tracer.val, tracer.batch_dim) for tracer in tracers)

    # Sizes of every mapped input along its batch dimension; the
    # single-element unpacking requires them to agree.
    mapped_sizes = {
        val.shape[dim]
        for val, dim in zip(in_vals, in_dims)
        if dim is not not_mapped
    }
    (axis_size,) = mapped_sizes

    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
    bwd = batch_custom_vjp_bwd(
        bwd, self.axis_name, axis_size, out_dims2, in_dims,
        self.main.trace_type)

    out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)

    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
        # Keep the trailing entries; the modulo makes the start index 0 when
        # len(out_vals) is a multiple of len(out_dims).
        start = -len(out_vals) % len(out_dims)
        out_dims = out_dims[start:]

    src = source_info_util.current()
    return [
        BatchTracer(self, val, dim, src)
        for val, dim in zip(out_vals, out_dims)
    ]
def bind(self, fun, fwd, bwd, *args, out_trees):
    """Bind this primitive with ``fun``, its ``fwd``/``bwd`` rules, and operands.

    The operands are fully lowered, the top trace over them is located, and
    ``fun`` and ``fwd`` are passed through ``core.process_env_traces``
    (``bwd`` is forwarded unchanged). The top trace's
    ``process_custom_vjp_call`` runs inside ``core.maybe_new_sublevel``; its
    outputs are fully lowered and post-processed by ``_apply_todos``.
    """
    args = map(core.full_lower, args)
    top_trace = core.find_top_trace(args)
    # Short-circuits to `top_trace` itself when it is falsy.
    level = top_trace and top_trace.level

    fun, env_trace_todo1 = core.process_env_traces(
        fun, self, level, (), None)
    fwd, env_trace_todo2 = core.process_env_traces(
        fwd, self, level, (), None)

    tracers = map(top_trace.full_raise, args)  # type: ignore
    with core.maybe_new_sublevel(top_trace):
        outs = top_trace.process_custom_vjp_call(
            self, fun, fwd, bwd, tracers, out_trees=out_trees)

    _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
    lowered_outs = map(core.full_lower, outs)
    return _apply_todos(env_trace_todo, lowered_outs)
def bind(self, fun, fwd, bwd, *args, out_trees):
    """Bind this primitive with ``fun``, its ``fwd``/``bwd`` rules, and operands.

    The operands are fully lowered and the top trace over them is located.
    ``fun`` goes through ``process_env_traces`` and ``fwd`` through
    ``process_env_traces_fwd``. The backward rule is handed to the top trace
    behind a wrapper that looks ``bwd`` up at call time, so the transformed
    ``bwd`` assigned near the end of this method is what the wrapper calls.
    """
    args = map(core.full_lower, args)
    top_trace = core.find_top_trace(args)
    # Short-circuits to `top_trace` itself when it is falsy.
    level = top_trace and top_trace.level

    fun, env_trace_todo1 = process_env_traces(fun, self, level, False)
    fwd, env_trace_todo2 = process_env_traces_fwd(fwd, level, out_trees)
    tracers = map(top_trace.full_raise, args)  # type: ignore

    # NOTE: the lambda reads the local name `bwd` when it is invoked, not
    # when it is created, so the rebinding of `bwd` below is visible to it.
    bwd_ = lu.wrap_init(lambda *bwd_args: bwd.call_wrapped(*bwd_args))
    outs = top_trace.process_custom_vjp_call(
        self, fun, fwd, bwd_, tracers, out_trees=out_trees)

    fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
    if not fst:
        # Here the merged value is a pair: the todo list and a transform to
        # apply to the backward rule.
        env_trace_todo, bwd_transform = env_trace_todo
        bwd = _apply_bwd_transform(bwd_transform, bwd)
    return _apply_todos(env_trace_todo, map(core.full_lower, outs))
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    """Handle a custom-VJP call for this trace, threading the error state.

    The current ``error`` attribute is popped off ``self.main``; its ``err``
    and ``code`` are passed as leading operands of ``prim.bind`` and its
    message items are given to both subtrace wrappers. Depending on the flag
    returned by ``lu.merge_linear_aux``, the new ``err``/``code`` are either
    split off the front of the bind results or carried over unchanged from
    the incoming error. A fresh ``Error`` is stored back on ``self.main`` and
    the outputs are wrapped in ``CheckifyTracer``.
    """
    in_vals = [tracer.val for tracer in tracers]
    prev_error = popattr(self.main, 'error')
    msg_items = tuple(prev_error.msgs.items())

    fun, msgs1 = checkify_subtrace(fun, self.main, msg_items)
    fwd, msgs2 = checkify_custom_vjp_subtrace(fwd, self.main, msg_items)

    out = prim.bind(
        fun, fwd, bwd, prev_error.err, prev_error.code, *in_vals,
        out_trees=out_trees)

    fst, out_msgs = lu.merge_linear_aux(msgs1, msgs2)
    if not fst:
        # Forward input error values to output.
        err, code = prev_error.err, prev_error.code
    else:
        err, code, *out = out

    self.main.error = Error(err, code, out_msgs)
    return [CheckifyTracer(self, val) for val in out]