Example #1
0
 def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:
     """Call the wrapped function through ``custom_vjp_call_p``.

     Arguments are first resolved against ``self.fun``'s signature. Positions
     listed in ``self.nondiff_argnums`` are checked with
     ``_check_for_tracers`` and handed to ``self.bwd`` as extra static
     arguments via ``_add_args``. The remaining arguments are flattened and
     bound together with flattened versions of ``self.fun``, ``self.fwd`` and
     ``self.bwd``.

     Raises:
       AttributeError: if ``self.fwd`` or ``self.bwd`` has not been set
         (no VJP was registered with ``defvjp``).
     """
     if not self.fwd or not self.bwd:
         # ``__name__`` is not guaranteed to exist on this object (presumably
         # it is copied from the wrapped callable, which may lack one). Fall
         # back to ``str(self.fun)`` so this intended error is raised instead
         # of an unrelated attribute-lookup failure.
         name = getattr(self, '__name__', str(self.fun))
         msg = "No VJP defined for custom_vjp function {} using defvjp."
         raise AttributeError(msg.format(name))
     args = _resolve_kwargs(self.fun, args, kwargs)
     if self.nondiff_argnums:
         # Each non-differentiable argument goes through _check_for_tracers
         # (presumably rejecting tracer values).
         for i in self.nondiff_argnums:
             _check_for_tracers(args[i])
         nondiff_argnums = set(self.nondiff_argnums)
         dyn_argnums = [
             i for i in range(len(args)) if i not in nondiff_argnums
         ]
         # fun and fwd keep only the dyn_argnums positions as explicit inputs.
         f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums,
                                        args)
         static_args = [args[i] for i in self.nondiff_argnums]
         fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args)
         # bwd receives the static arguments explicitly via _add_args.
         bwd = _add_args(lu.wrap_init(self.bwd), static_args)
     else:
         f_, dyn_args = lu.wrap_init(self.fun), args
         fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
     args_flat, in_tree = tree_flatten(dyn_args)
     in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
     flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
     flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
     flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
     out_flat = custom_vjp_call_p.bind(flat_fun,
                                       flat_fwd,
                                       flat_bwd,
                                       *args_flat,
                                       out_trees=out_trees)
     # merge_linear_aux reports via ``fst`` which thunk was populated. When
     # it was ``out_trees`` (the fwd path), the aux is a tuple whose first
     # entry is the output tree.
     fst, aux = lu.merge_linear_aux(out_tree, out_trees)
     out_tree = aux if fst else aux[0]
     return tree_unflatten(out_tree, out_flat)
Example #2
0
 def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
     """Call the wrapped function through ``custom_jvp_call_p``.

     Arguments are first resolved against ``self.fun``'s signature. Positions
     listed in ``self.nondiff_argnums`` are wrapped with ``_stop_gradient``
     and then handed to ``self.jvp`` as extra static arguments via
     ``_add_args``. The remaining arguments are flattened and bound together
     with flattened versions of ``self.fun`` and ``self.jvp``.

     Raises:
       AttributeError: if ``self.jvp`` has not been set (no JVP was
         registered with ``defjvp``).
     """
     if not self.jvp:
         # ``__name__`` is not guaranteed to exist on this object (presumably
         # it is copied from the wrapped callable, which may lack one). Fall
         # back to ``str(self.fun)`` so this intended error is raised instead
         # of an unrelated attribute-lookup failure.
         name = getattr(self, '__name__', str(self.fun))
         msg = "No JVP defined for custom_jvp function {} using defjvp."
         raise AttributeError(msg.format(name))
     args = _resolve_kwargs(self.fun, args, kwargs)
     if self.nondiff_argnums:
         nondiff_argnums = set(self.nondiff_argnums)
         # Non-differentiable positions are replaced by their _stop_gradient
         # values. The static_args collected below are these wrapped values.
         args = tuple(
             _stop_gradient(x) if i in nondiff_argnums else x
             for i, x in enumerate(args))
         diff_argnums = [
             i for i in range(len(args)) if i not in nondiff_argnums
         ]
         # The static (non-diff) arguments are not required to be hashable.
         f_, dyn_args = argnums_partial(lu.wrap_init(self.fun),
                                        diff_argnums,
                                        args,
                                        require_static_args_hashable=False)
         static_args = [args[i] for i in self.nondiff_argnums]
         jvp = _add_args(lu.wrap_init(self.jvp), static_args)
     else:
         f_, dyn_args = lu.wrap_init(self.fun), args
         jvp = lu.wrap_init(self.jvp)
     args_flat, in_tree = tree_flatten(dyn_args)
     flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
     flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
     out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
     # Whichever of the two thunks was populated supplies the output tree.
     _, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
     return tree_unflatten(out_tree, out_flat)
Example #3
0
 def process_custom_jvp_call(self, prim, fun, jvp, tracers):
     """Bind a custom_jvp call under checkify, threading the error state.

     The pending ``Error`` is popped off ``self.main``. Its ``err`` and
     ``code`` are passed to ``prim.bind`` ahead of the primal values, and the
     first two bind outputs become the new error stored back on
     ``self.main``. The remaining outputs are wrapped as ``CheckifyTracer``s.
     """
     primals = [tracer.val for tracer in tracers]
     prev_error = popattr(self.main, 'error')
     prev_msgs = tuple(prev_error.msgs.items())
     checked_fun, fun_msgs = checkify_subtrace(fun, self.main, prev_msgs)
     checked_jvp, jvp_msgs = checkify_custom_jvp_subtrace(
         jvp, self.main, prev_msgs)
     new_err, new_code, *results = prim.bind(
         checked_fun, checked_jvp, prev_error.err, prev_error.code, *primals)
     # Whichever of the two subtraces populated its messages thunk wins.
     _, merged_msgs = lu.merge_linear_aux(fun_msgs, jvp_msgs)
     setattr(self.main, 'error', Error(new_err, new_code, merged_msgs))
     return [CheckifyTracer(self, val) for val in results]
Example #4
0
 def process_custom_jvp_call(self, prim, fun, jvp, tracers):
     """Bind a custom_jvp call on unbatched values and rewrap as BatchTracers.

     ``fun`` and ``jvp`` are wrapped so the output batch dims can be
     recovered. The dims come from whichever of the two was traced.
     """
     vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
     batched_fun, fun_out_dims = batch_subtrace(fun, self.main, dims)
     batched_jvp, jvp_out_dims = batch_custom_jvp_subtrace(jvp, self.main, dims)
     results = prim.bind(batched_fun, batched_jvp, *vals)
     from_fun, result_dims = lu.merge_linear_aux(fun_out_dims, jvp_out_dims)
     if not from_fun:
         # The jvp path reports two identical halves of dims (presumably
         # primals then tangents). Keep only the first half.
         half = len(result_dims) // 2
         assert result_dims == result_dims[:half] * 2
         result_dims = result_dims[:half]
     return [BatchTracer(self, v, d) for v, d in zip(results, result_dims)]
Example #5
0
 def bind(self, fun, jvp, *args):
   """Bind this custom_jvp primitive at the top trace found among ``args``.

   Both ``fun`` and ``jvp`` are wrapped with ``core.process_env_traces``.
   After the call, whichever env-trace todo was populated is applied to the
   lowered outputs.
   """
   lowered_args = map(core.full_lower, args)
   top_trace = core.find_top_trace(lowered_args)
   top_level = top_trace and top_trace.level
   wrapped_fun, fun_todo = core.process_env_traces(
       fun, self, top_level, (), None)
   wrapped_jvp, jvp_todo = core.process_env_traces(
       jvp, self, top_level, (), None)
   tracers = map(top_trace.full_raise, lowered_args)  # type: ignore
   outs = top_trace.process_custom_jvp_call(  # type: ignore
       self, wrapped_fun, wrapped_jvp, tracers)
   _, todo = lu.merge_linear_aux(fun_todo, jvp_todo)
   return _apply_todos(todo, map(core.full_lower, outs))
Example #6
0
 def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
   """Bind a custom_vjp call on unbatched values and rewrap as BatchTracers.

   ``fun`` and ``fwd`` are wrapped with ``batch_subtrace``, and ``bwd`` with
   ``batch_custom_vjp_bwd``. Output batch dims come from whichever of
   ``fun``/``fwd`` was actually traced.
   """
   in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
   # All mapped inputs must share one size along their batch dim. The
   # single-element unpack raises ValueError if they disagree (or if no
   # input is mapped).
   axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims)
                 if d is not not_mapped}
   fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
   fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
   bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
                              out_dims2, in_dims, self.main.trace_type)
   out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
   fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
   if not fst:
     # fwd was traced. Its dims include extra leading entries (presumably
     # residuals), so keep only the trailing len(out_vals) dims. A plain
     # subtraction is used instead of ``-len(out_vals) % len(out_dims)``,
     # which raised ZeroDivisionError when out_dims was empty.
     out_dims = out_dims[len(out_dims) - len(out_vals):]
   src = source_info_util.current()
   return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]
Example #7
0
 def bind(self, fun, fwd, bwd, *args, out_trees):
     """Bind this custom_vjp primitive at the top trace found among ``args``.

     ``fun`` and ``fwd`` are wrapped with ``core.process_env_traces``, while
     ``bwd`` is passed through unchanged. The trace's
     ``process_custom_vjp_call`` runs inside ``core.maybe_new_sublevel``.
     Whichever env-trace todo was populated is applied to the lowered outputs.
     """
     lowered_args = map(core.full_lower, args)
     top_trace = core.find_top_trace(lowered_args)
     top_level = top_trace and top_trace.level
     wrapped_fun, fun_todo = core.process_env_traces(
         fun, self, top_level, (), None)
     wrapped_fwd, fwd_todo = core.process_env_traces(
         fwd, self, top_level, (), None)
     tracers = map(top_trace.full_raise, lowered_args)  # type: ignore
     with core.maybe_new_sublevel(top_trace):
         outs = top_trace.process_custom_vjp_call(
             self, wrapped_fun, wrapped_fwd, bwd, tracers, out_trees=out_trees)
     _, todo = lu.merge_linear_aux(fun_todo, fwd_todo)
     return _apply_todos(todo, map(core.full_lower, outs))
 def bind(self, fun, fwd, bwd, *args, out_trees):
   """Bind this custom_vjp primitive at the top trace found among ``args``.

   ``fun`` is wrapped with ``process_env_traces`` and ``fwd`` with
   ``process_env_traces_fwd``. After the call, the two todo thunks are merged
   and the result is applied to the lowered outputs with ``_apply_todos``.
   When the fwd path populated its thunk, that thunk also carries a
   ``bwd_transform``, which is applied to ``bwd`` (see the closure note
   below).
   """
   args = map(core.full_lower, args)
   top_trace = core.find_top_trace(args)
   fun, env_trace_todo1 = process_env_traces(
       fun, self, top_trace and top_trace.level, False)
   fwd, env_trace_todo2 = process_env_traces_fwd(
     fwd, top_trace and top_trace.level, out_trees)
   tracers = map(top_trace.full_raise, args)  # type: ignore
   # NOTE: this lambda deliberately closes over the *variable* ``bwd`` (Python
   # late binding), not its current value. The ``else`` branch below rebinds
   # ``bwd`` after process_custom_vjp_call returns, so any call made through
   # ``bwd_`` after that point reaches the transformed backward rule. Do not
   # pass ``bwd`` directly or drop the rebinding below.
   bwd_ = lu.wrap_init(lambda *args: bwd.call_wrapped(*args))
   outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
                                            out_trees=out_trees)
   fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
   if fst:
     return _apply_todos(env_trace_todo, map(core.full_lower, outs))
   else:
     # fwd path: its aux is a pair (env-trace todo, bwd transform).
     env_trace_todo, bwd_transform = env_trace_todo
     # Looks unused but is not: the new ``bwd`` is what the ``bwd_`` closure
     # above will call. Presumably bwd_ is invoked later (e.g. during
     # transposition) -- TODO confirm.
     bwd = _apply_bwd_transform(bwd_transform, bwd)
     return _apply_todos(env_trace_todo, map(core.full_lower, outs))
Example #9
0
 def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
     """Bind a custom_vjp call under checkify, threading the error state.

     The pending ``Error`` is popped off ``self.main``, and its ``err`` and
     ``code`` are passed to ``prim.bind`` ahead of the primal values. If
     ``fun`` was the traced path, the first two bind outputs become the new
     error. Otherwise the incoming error value and code are forwarded as-is
     and the bind outputs are returned unchanged.
     """
     primals = [tracer.val for tracer in tracers]
     prev_error = popattr(self.main, 'error')
     prev_msgs = tuple(prev_error.msgs.items())
     checked_fun, fun_msgs = checkify_subtrace(fun, self.main, prev_msgs)
     checked_fwd, fwd_msgs = checkify_custom_vjp_subtrace(
         fwd, self.main, prev_msgs)
     results = prim.bind(checked_fun, checked_fwd, bwd,
                         prev_error.err, prev_error.code, *primals,
                         out_trees=out_trees)
     used_fun, merged_msgs = lu.merge_linear_aux(fun_msgs, fwd_msgs)
     if not used_fun:
         # forward input error values to output
         new_err, new_code = prev_error.err, prev_error.code
     else:
         new_err, new_code, *results = results
     setattr(self.main, 'error', Error(new_err, new_code, merged_msgs))
     return [CheckifyTracer(self, val) for val in results]