def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
  """Invoke the custom_vjp function with its registered fwd/bwd rules.

  Keyword arguments are first merged into the positional arguments by
  ``_resolve_kwargs``. The call is then either delegated to
  ``custom_vjp_by_custom_transpose`` (when the corresponding config flag is
  set) or flattened and bound through ``custom_vjp_call_p``.

  Raises:
    AttributeError: if ``self.fwd`` or ``self.bwd`` is falsy.
    NotImplementedError: if ``self.nondiff_argnums`` is set while
      ``config.jax_enable_custom_vjp_by_custom_transpose`` is enabled.
  """
  if not (self.fwd and self.bwd):
    raise AttributeError(
        "No VJP defined for custom_vjp function {} using defvjp."
        .format(self.__name__))
  args = _resolve_kwargs(self.fun, args, kwargs)

  if config.jax_enable_custom_vjp_by_custom_transpose:
    if self.nondiff_argnums:
      raise NotImplementedError(
          'nondiff_argnums not implemented for new custom_vjp')
    return custom_vjp_by_custom_transpose(self.fun, self.fwd, self.bwd)(*args)

  if not self.nondiff_argnums:
    # Every argument is dynamic: wrap the three callables as they are.
    primal_fun, dyn_args = lu.wrap_init(self.fun), args
    fwd_rule, bwd_rule = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
  else:
    # The non-differentiable positions are checked for tracers up front.
    for argnum in self.nondiff_argnums:
      _check_for_tracers(args[argnum])
    static_positions = set(self.nondiff_argnums)
    dyn_argnums = [
        i for i in range(len(args)) if i not in static_positions
    ]
    primal_fun, dyn_args = argnums_partial(
        lu.wrap_init(self.fun), dyn_argnums, args,
        require_static_args_hashable=False)
    # Static values are collected in the order given by nondiff_argnums.
    static_args = [args[i] for i in self.nondiff_argnums]
    fwd_rule, _ = argnums_partial(
        lu.wrap_init(self.fwd), dyn_argnums, args,
        require_static_args_hashable=False)
    bwd_rule = _add_args(lu.wrap_init(self.bwd), static_args)

  args_flat, in_tree = tree_flatten(dyn_args)
  in_avals = [
      core.raise_to_shaped(core.get_aval(leaf)) for leaf in args_flat
  ]
  flat_fun, out_tree = flatten_fun_nokwargs(primal_fun, in_tree)
  flat_fwd, out_trees = _flatten_fwd(fwd_rule, in_tree)
  flat_bwd = _flatten_bwd(bwd_rule, in_tree, in_avals, out_trees)
  out_flat = custom_vjp_call_p.bind(
      flat_fun, flat_fwd, flat_bwd, *args_flat, out_trees=out_trees)
  # merge_linear_aux reports which of the two aux stores supplied the value;
  # the second store's payload is a sequence whose first entry is the tree.
  is_first, aux = lu.merge_linear_aux(out_tree, out_trees)
  if is_first:
    result_tree = aux
  else:
    result_tree = aux[0]
  return tree_unflatten(result_tree, out_flat)
Esempio n. 2
0
 def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
     """Invoke the custom_jvp function with its registered JVP rule.

     Keyword arguments are merged into the positional arguments by
     ``_resolve_kwargs``; the dynamic arguments are then flattened and bound
     through ``custom_jvp_call_p``.

     Raises:
         AttributeError: if ``self.jvp`` is falsy.
     """
     if not self.jvp:
         raise AttributeError(
             "No JVP defined for custom_jvp function {} using defjvp."
             .format(self.__name__))
     args = _resolve_kwargs(self.fun, args, kwargs)

     if not self.nondiff_argnums:
         # Every argument is dynamic: wrap the callables as they are.
         primal_fun, dyn_args = lu.wrap_init(self.fun), args
         jvp_rule = lu.wrap_init(self.jvp)
     else:
         static_positions = set(self.nondiff_argnums)
         # Arguments at non-differentiable positions go through
         # _stop_gradient; the others are passed along untouched.
         args = tuple(
             _stop_gradient(arg) if pos in static_positions else arg
             for pos, arg in enumerate(args))
         diff_argnums = [
             pos for pos in range(len(args)) if pos not in static_positions
         ]
         primal_fun, dyn_args = argnums_partial(
             lu.wrap_init(self.fun), diff_argnums, args,
             require_static_args_hashable=False)
         # Static values are collected in the order given by nondiff_argnums.
         static_args = [args[pos] for pos in self.nondiff_argnums]
         jvp_rule = _add_args(lu.wrap_init(self.jvp), static_args)

     args_flat, in_tree = tree_flatten(dyn_args)
     flat_fun, primal_out_tree = flatten_fun_nokwargs(primal_fun, in_tree)
     flat_jvp, jvp_out_tree = _flatten_jvp(jvp_rule, in_tree)
     out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
     _, result_tree = lu.merge_linear_aux(primal_out_tree, jvp_out_tree)
     return tree_unflatten(result_tree, out_flat)
Esempio n. 3
0
 def value_and_jacfwd_f(*args, **kwargs):
     """Return ``(y, jacobian)`` for ``fun`` using ``_jvp`` under ``vmap``.

     The arguments selected by ``argnums`` are treated as dynamic; their
     dtypes, and the dtype of the output ``y``, are validated by the
     ``_check_*_dtype_jacfwd`` helpers before the result is unravelled.
     """
     wrapped_fun = lu.wrap_init(fun, kwargs)
     partial_fun, diff_args = argnums_partial(
         wrapped_fun, argnums, args, require_static_args_hashable=False)

     check_input = partial(_check_input_dtype_jacfwd, holomorphic)
     tree_map(check_input, diff_args)

     push_forward = partial(_jvp, partial_fun, diff_args)
     # out_axes=(None, -1): the first output is left unbatched, the second is
     # batched along its last axis.
     batched_push = vmap(push_forward, out_axes=(None, -1))
     y, jacobian = batched_push(_std_basis(diff_args))

     check_output = partial(_check_output_dtype_jacfwd, holomorphic)
     tree_map(check_output, y)

     # With a single integer argnums the lone dynamic argument is unwrapped.
     if isinstance(argnums, int):
         example_args = diff_args[0]
     else:
         example_args = diff_args
     unravel = partial(_jacfwd_unravel, example_args)
     return y, tree_map(unravel, y, jacobian)
Esempio n. 4
0
    def wrapped(*args, **kwargs):
        """Flatten the positional arguments and dispatch to ``sharded_call``.

        Static arguments (those listed in ``static_argnums``) are partially
        applied to ``fun``; the remaining arguments are flattened and passed
        to ``sharded_call`` together with the flattened partitioning specs.

        Raises:
            NotImplementedError: if any keyword arguments are passed.
            ValueError: if ``static_argnums`` names a position that was not
                supplied positionally.
        """
        if kwargs:
            raise NotImplementedError(
                "sharded_jit over kwargs not yet supported")

        f = lu.wrap_init(fun)
        if static_argnums:
            if max(static_argnums) >= len(args):
                # Pluralise for every count except exactly one, so that a
                # call with zero positional arguments reads correctly.
                raise ValueError(
                    f"jitted function has static_argnums={static_argnums}"
                    f" but was called with only {len(args)} positional "
                    f"argument{'s' if len(args) != 1 else ''}. "
                    "All static broadcasted arguments must be passed positionally."
                )
            dyn_argnums = [
                i for i in range(len(args)) if i not in static_argnums
            ]
            f, args = argnums_partial(f, dyn_argnums, args)

        # kwargs is always empty here (see the guard above); it is still
        # included so that in_tree has the (args, kwargs) structure that
        # flatten_fun is given below.
        args_flat, in_tree = tree_flatten((args, kwargs))
        # children()[0] is the subtree for the positional arguments.
        in_parts_flat = tuple(
            flatten_axes("sharded_jit in_parts",
                         in_tree.children()[0], in_parts))
        if local_in_parts is not None:
            local_in_parts_flat = tuple(
                flatten_axes("sharded_jit local_in_parts",
                             in_tree.children()[0], local_in_parts))
        else:
            local_in_parts_flat = None

        flat_fun, out_tree = flatten_fun(f, in_tree)
        # TODO(skye): having a function-typed param in a primitive seems dicey, is
        # there a better way?
        # The thunks defer flattening of the output specs until out_tree() is
        # called inside them.
        out_parts_thunk = HashableFunction(lambda: tuple(
            flatten_axes("sharded_jit out_parts", out_tree(), out_parts)),
                                           closure=out_parts)
        # NOTE(review): this is a truthiness test, whereas local_in_parts
        # above is compared against None — confirm the asymmetry is intended.
        if local_out_parts:
            local_out_parts_thunk = HashableFunction(lambda: tuple(
                flatten_axes("sharded_jit local_out_parts", out_tree(),
                             local_out_parts)),
                                                     closure=local_out_parts)
        else:
            local_out_parts_thunk = HashableFunction(lambda: None,
                                                     closure=None)

        out = sharded_call(flat_fun,
                           *args_flat,
                           nparts=nparts,
                           in_parts=in_parts_flat,
                           out_parts_thunk=out_parts_thunk,
                           local_in_parts=local_in_parts_flat,
                           local_out_parts_thunk=local_out_parts_thunk,
                           local_nparts=local_nparts,
                           name=flat_fun.__name__)
        return tree_unflatten(out_tree(), out)
Esempio n. 5
0
 def value_and_jacrev_f(*args, **kwargs):
     """Return ``fun``'s output together with its Jacobian via ``_vjp``.

     The arguments selected by ``argnums`` are treated as dynamic; their
     dtypes, and the dtype of the output ``y``, are validated by the
     ``_check_*_dtype_jacrev`` helpers.

     Returns:
         ``(y, jac)`` when ``has_aux`` is false, otherwise ``((y, aux), jac)``
         where ``aux`` is the extra value returned by ``_vjp``.
     """
     f = lu.wrap_init(fun, kwargs)
     f_partial, dyn_args = argnums_partial(
         f, argnums, args, require_static_args_hashable=False)
     tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int),
              dyn_args)
     if has_aux:
         y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
     else:
         y, pullback = _vjp(f_partial, *dyn_args)
     tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
     jac = vmap(pullback)(_std_basis(y))
     # With a single integer argnums, both the pullback result and the
     # dynamic arguments are unwrapped to their only element.
     if isinstance(argnums, int):
         jac = jac[0]
         example_args = dyn_args[0]
     else:
         example_args = dyn_args
     jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
     # Computed once and shared by both return shapes.
     jac_out = tree_transpose(tree_structure(example_args),
                              tree_structure(y), jac_tree)
     if has_aux:
         return (y, aux), jac_out
     return y, jac_out