def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
    """Call the wrapped function through the custom-VJP primitive.

    Keyword arguments are first folded into positional form with
    ``_resolve_kwargs``. Depending on
    ``config.jax_enable_custom_vjp_by_custom_transpose`` the call is either
    delegated to ``custom_vjp_by_custom_transpose`` or bound through
    ``custom_vjp_call_p``.

    Raises:
        AttributeError: if ``self.fwd`` or ``self.bwd`` is falsy, i.e. no
            forward/backward pair has been registered.
        NotImplementedError: if the custom-transpose path is enabled and
            ``self.nondiff_argnums`` is non-empty.

    Returns:
        The flat outputs of the bind, unflattened with the output tree
        recovered from ``lu.merge_linear_aux``.
    """
    if not (self.fwd and self.bwd):
        raise AttributeError(
            "No VJP defined for custom_vjp function {} using defvjp.".format(
                self.__name__))

    args = _resolve_kwargs(self.fun, args, kwargs)

    # Alternative implementation, selected by a config flag.
    if config.jax_enable_custom_vjp_by_custom_transpose:
        if self.nondiff_argnums:
            raise NotImplementedError(
                'nondiff_argnums not implemented for new custom_vjp')
        transposed = custom_vjp_by_custom_transpose(
            self.fun, self.fwd, self.bwd)
        return transposed(*args)

    if not self.nondiff_argnums:
        # Every positional argument is dynamic.
        primal_fun, dyn_args = lu.wrap_init(self.fun), args
        fwd_rule, bwd_rule = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
    else:
        # Each non-differentiable argument goes through the tracer check
        # before anything is partially applied.
        for position in self.nondiff_argnums:
            _check_for_tracers(args[position])
        static_positions = set(self.nondiff_argnums)
        dyn_positions = [
            position for position, _ in enumerate(args)
            if position not in static_positions
        ]
        primal_fun, dyn_args = argnums_partial(
            lu.wrap_init(self.fun), dyn_positions, args,
            require_static_args_hashable=False)
        static_args = [args[position] for position in self.nondiff_argnums]
        fwd_rule, _ = argnums_partial(
            lu.wrap_init(self.fwd), dyn_positions, args,
            require_static_args_hashable=False)
        # The backward rule receives the static arguments via _add_args.
        bwd_rule = _add_args(lu.wrap_init(self.bwd), static_args)

    leaves, in_tree = tree_flatten(dyn_args)
    in_avals = [core.raise_to_shaped(core.get_aval(leaf)) for leaf in leaves]
    flat_fun, primal_out_tree = flatten_fun_nokwargs(primal_fun, in_tree)
    flat_fwd, fwd_out_trees = _flatten_fwd(fwd_rule, in_tree)
    flat_bwd = _flatten_bwd(bwd_rule, in_tree, in_avals, fwd_out_trees)
    out_flat = custom_vjp_call_p.bind(
        flat_fun, flat_fwd, flat_bwd, *leaves, out_trees=fwd_out_trees)

    # NOTE(review): the flag presumably says which of the two stores was
    # populated; when it is falsy the payload is indexed with [0] to get the
    # output tree -- confirm against lu.merge_linear_aux.
    is_first, payload = lu.merge_linear_aux(primal_out_tree, fwd_out_trees)
    result_tree = payload if is_first else payload[0]
    return tree_unflatten(result_tree, out_flat)
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
    """Call the wrapped function through the custom-JVP primitive.

    Keyword arguments are folded into positional form with
    ``_resolve_kwargs``. Arguments listed in ``self.nondiff_argnums`` are
    passed through ``_stop_gradient`` and handed to the JVP rule as static
    arguments; the remaining arguments are flattened and bound through
    ``custom_jvp_call_p``.

    Raises:
        AttributeError: if ``self.jvp`` is falsy, i.e. no JVP rule has been
            registered.

    Returns:
        The flat outputs of the bind, unflattened with the output tree
        obtained from ``lu.merge_linear_aux``.
    """
    if not self.jvp:
        raise AttributeError(
            "No JVP defined for custom_jvp function {} using defjvp.".format(
                self.__name__))

    args = _resolve_kwargs(self.fun, args, kwargs)

    if not self.nondiff_argnums:
        # Every positional argument is differentiable.
        primal_fun, dyn_args = lu.wrap_init(self.fun), args
        jvp_rule = lu.wrap_init(self.jvp)
    else:
        static_positions = set(self.nondiff_argnums)
        # Non-differentiable arguments are replaced by their
        # _stop_gradient'ed versions; everything else is left untouched.
        args = tuple(
            _stop_gradient(value) if position in static_positions else value
            for position, value in enumerate(args))
        diff_positions = [
            position for position, _ in enumerate(args)
            if position not in static_positions
        ]
        primal_fun, dyn_args = argnums_partial(
            lu.wrap_init(self.fun), diff_positions, args,
            require_static_args_hashable=False)
        static_args = [args[position] for position in self.nondiff_argnums]
        jvp_rule = _add_args(lu.wrap_init(self.jvp), static_args)

    leaves, in_tree = tree_flatten(dyn_args)
    flat_fun, primal_out_tree = flatten_fun_nokwargs(primal_fun, in_tree)
    flat_jvp, jvp_out_tree = _flatten_jvp(jvp_rule, in_tree)
    out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *leaves)

    # Only the second element of the merge result (the output tree) is used.
    _, result_tree = lu.merge_linear_aux(primal_out_tree, jvp_out_tree)
    return tree_unflatten(result_tree, out_flat)
def value_and_jacfwd_f(*args, **kwargs):
    """Evaluate ``fun`` and its forward-mode Jacobian in one pass.

    ``fun``, ``argnums`` and ``holomorphic`` are free variables taken from
    the enclosing scope (not visible in this block).

    Returns:
        A pair ``(y, jac)``: the primal output ``y`` and the result of
        mapping ``_jacfwd_unravel`` over ``y`` and the batched tangents.
    """
    wrapped_fun = lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = argnums_partial(
        wrapped_fun, argnums, args, require_static_args_hashable=False)

    # Validate the dtype of every differentiated input leaf.
    check_input = partial(_check_input_dtype_jacfwd, holomorphic)
    tree_map(check_input, dyn_args)

    # Push the standard basis through the JVP in a single vmapped call; the
    # primal output is unbatched (None) and tangents are stacked on the last
    # axis (-1).
    batched_jvp = vmap(partial(_jvp, f_partial, dyn_args), out_axes=(None, -1))
    y, jac = batched_jvp(_std_basis(dyn_args))

    check_output = partial(_check_output_dtype_jacfwd, holomorphic)
    tree_map(check_output, y)

    # With a single integer argnum the lone dynamic argument is unwrapped.
    if isinstance(argnums, int):
        example_args = dyn_args[0]
    else:
        example_args = dyn_args

    unravel = partial(_jacfwd_unravel, example_args)
    return y, tree_map(unravel, y, jac)
def wrapped(*args, **kwargs):
    """Run ``fun`` through ``sharded_call`` with flattened partition specs.

    ``fun``, ``static_argnums``, ``in_parts``, ``local_in_parts``,
    ``out_parts``, ``local_out_parts``, ``nparts`` and ``local_nparts`` are
    free variables taken from the enclosing scope (not visible in this
    block).

    Raises:
        NotImplementedError: if any keyword argument is passed.
        ValueError: if ``static_argnums`` refers to a position that was not
            supplied positionally.

    Returns:
        The output of ``sharded_call`` unflattened with the output tree of
        the flattened function.
    """
    if kwargs:
        raise NotImplementedError(
            "sharded_jit over kwargs not yet supported")

    # Validate before wrapping so a bad call fails fast.
    num_args = len(args)
    if static_argnums and max(static_argnums) >= num_args:
        # Pluralize for every count except exactly one (so "0 arguments").
        raise ValueError(
            f"jitted function has static_argnums={static_argnums}"
            f" but was called with only {num_args} positional "
            f"argument{'s' if num_args != 1 else ''}. "
            "All static broadcasted arguments must be passed positionally."
        )

    f = lu.wrap_init(fun)
    if static_argnums:
        dyn_argnums = [i for i in range(num_args) if i not in static_argnums]
        f, args = argnums_partial(f, dyn_argnums, args)

    args_flat, in_tree = tree_flatten((args, kwargs))
    # in_tree describes the (args, kwargs) pair; child 0 is the args subtree.
    in_parts_flat = tuple(
        flatten_axes("sharded_jit in_parts", in_tree.children()[0], in_parts))
    if local_in_parts is not None:
        local_in_parts_flat = tuple(
            flatten_axes("sharded_jit local_in_parts",
                         in_tree.children()[0], local_in_parts))
    else:
        local_in_parts_flat = None

    flat_fun, out_tree = flatten_fun(f, in_tree)
    # TODO(skye): having a function-typed param in a primitive seems dicey, is
    # there a better way?
    # The thunks defer out_tree() until they are invoked.
    out_parts_thunk = HashableFunction(
        lambda: tuple(
            flatten_axes("sharded_jit out_parts", out_tree(), out_parts)),
        closure=out_parts)
    # NOTE(review): this is a truthiness test, whereas local_in_parts above is
    # compared against None -- confirm the asymmetry is intentional.
    if local_out_parts:
        local_out_parts_thunk = HashableFunction(
            lambda: tuple(
                flatten_axes("sharded_jit local_out_parts", out_tree(),
                             local_out_parts)),
            closure=local_out_parts)
    else:
        local_out_parts_thunk = HashableFunction(lambda: None, closure=None)

    out = sharded_call(
        flat_fun,
        *args_flat,
        nparts=nparts,
        in_parts=in_parts_flat,
        out_parts_thunk=out_parts_thunk,
        local_in_parts=local_in_parts_flat,
        local_out_parts_thunk=local_out_parts_thunk,
        local_nparts=local_nparts,
        name=flat_fun.__name__)
    return tree_unflatten(out_tree(), out)
def value_and_jacrev_f(*args, **kwargs):
    """Evaluate ``fun`` and its reverse-mode Jacobian in one pass.

    ``fun``, ``argnums``, ``holomorphic``, ``allow_int`` and ``has_aux`` are
    free variables taken from the enclosing scope (not visible in this
    block).

    Returns:
        ``(y, jac)`` when ``has_aux`` is falsy, otherwise ``((y, aux), jac)``,
        where ``y`` is the primal output, ``aux`` the auxiliary value returned
        by ``_vjp`` and ``jac`` the transposed Jacobian tree.
    """
    f = lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = argnums_partial(
        f, argnums, args, require_static_args_hashable=False)

    # Validate the dtype of every differentiated input leaf.
    tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int),
             dyn_args)

    if has_aux:
        y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    else:
        y, pullback = _vjp(f_partial, *dyn_args)
    tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)

    # Pull the standard basis of the output back through the VJP in a single
    # vmapped call.
    jac = vmap(pullback)(_std_basis(y))

    # With a single integer argnum, unwrap the one-element results.
    single_argnum = isinstance(argnums, int)
    if single_argnum:
        jac = jac[0]
    example_args = dyn_args[0] if single_argnum else dyn_args

    jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
    # Transpose once; the result is shared by both return shapes.
    jac_tree = tree_transpose(
        tree_structure(example_args), tree_structure(y), jac_tree)

    if has_aux:
        return (y, aux), jac_tree
    return y, jac_tree