Example #1
 def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
   if not self.fwd or not self.bwd:
     msg = "No VJP defined for custom_vjp function {} using defvjp."
     raise AttributeError(msg.format(self.__name__))
   args = _resolve_kwargs(self.fun, args, kwargs)
   if config.jax_enable_custom_vjp_by_custom_transpose:
     if self.nondiff_argnums:
       raise NotImplementedError(
           'nondiff_argnums not implemented for new custom_vjp')
     return custom_vjp_by_custom_transpose(self.fun, self.fwd, self.bwd)(*args)
   else:
     if self.nondiff_argnums:
       for i in self.nondiff_argnums: _check_for_tracers(args[i])
       nondiff_argnums = set(self.nondiff_argnums)
       dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
       f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums,
                                      args, require_static_args_hashable=False)
       static_args = [args[i] for i in self.nondiff_argnums]
       fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args,
                                require_static_args_hashable=False)
       bwd = _add_args(lu.wrap_init(self.bwd), static_args)
     else:
       f_, dyn_args = lu.wrap_init(self.fun), args
       fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
     args_flat, in_tree = tree_flatten(dyn_args)
     in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
     flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
     flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
     flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
     out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
                                       *args_flat, out_trees=out_trees)
     fst, aux = lu.merge_linear_aux(out_tree, out_trees)
     out_tree = aux if fst else aux[0]
     return tree_unflatten(out_tree, out_flat)
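For context, here is a minimal usage sketch of the public jax.custom_vjp API whose call path this __call__ method implements; the function and rule names below are illustrative, not taken from the example above:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.sin(x)

def f_fwd(x):
  # forward pass: return the primal output plus residuals for the backward pass
  return jnp.sin(x), jnp.cos(x)

def f_bwd(residuals, g):
  # backward pass: map the output cotangent g to a tuple of input cotangents
  cos_x = residuals
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)
print(jax.grad(f)(1.0))  # cos(1.0) ~= 0.5403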
Example #2
def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
                    cotangent_in_avals, reduce_axes):
    # backward_pass can only transpose linear computations, but the call_jaxpr embedded in
    # remat contains primal (non-linear) equations too. Hence, we have to eliminate those
    # (in this case via partial_eval) before we call into backward_pass again.
    typed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, [])
    unknowns = map(is_undefined_primal, primals_in)
    primal_jaxpr, tangent_jaxpr, out_unknowns = \
      pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True)  # type: ignore

    def do_transpose(primals_in, cotangents_in):
        # NOTE: This is passing in undefined primals in place of tangent arguments, but it
        #       should all work out, because we're only computing the primal part here.
        residuals = core.jaxpr_as_fun(primal_jaxpr)(
            *primals_in)[len(cotangents_in):]
        # Now that we have a purely linear jaxpr, we can transpose it
        cotangents_out = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, (),
                                       primals_in + residuals, cotangents_in)
        # backward_pass will return cotangents computed for all invars, but some of them
        # are residuals appended by partial eval, so we need to skip those before we return.
        return cotangents_out[:len(primals_in)]

    flat_args, in_tree_def = tree_flatten((primals_in, cotangents_in))
    flat_do_transpose, out_tree = flatten_fun_nokwargs(
        lu.wrap_init(do_transpose), in_tree_def)
    flat_cotangents_out = pe.remat_call_p.bind(flat_do_transpose, *flat_args,
                                               **params)
    return tree_unflatten(out_tree(), flat_cotangents_out)
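The comment above appeals to backward_pass only being able to transpose linear computations. As a rough illustration of what transposing a linear function means, here is a small sketch using the public jax.linear_transpose API (the constants are arbitrary):

import jax

f = lambda x: 3.0 * x                 # linear in x
f_t = jax.linear_transpose(f, 1.0)    # needs an example input to fix shape/dtype
print(f_t(18.0))                      # one-element tuple containing 54.0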
Example #3
  def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
    in_batched_ps, in_batched_ts = in_batched

    mutually_batched = tree_map(operator.and_, in_batched_ps, in_batched_ts)
    extra_batched_ps = tree_map(lambda pb, tb: 0 if pb and not tb else None,
                                in_batched_ps, in_batched_ts)
    extra_batched_ts = tree_map(lambda pb, tb: 0 if tb and not pb else None,
                                in_batched_ps, in_batched_ts)

    out_mutually_batched = lu.Store()
    flat_ps_ts, tree_ps_ts = tree_flatten((primals, tangents))
    flat_extra_batched_ps_ts, tree_ps_ts2 = tree_flatten(
        (extra_batched_ps, extra_batched_ts),
        is_leaf=lambda x: x is None)

    # TODO(frostig): assert these also equal:
    #   treedef_tuple((in_tree, in_tree))
    # once https://github.com/google/jax/issues/9066 is fixed
    assert tree_ps_ts == tree_ps_ts2
    del tree_ps_ts2

    def to_jvp(*primals):
      out, out_batched = call_rule(rule, axis_size, mutually_batched, primals)
      check_vmap_rule_trees(
          rule, out_tree, tree_structure(out), tree_structure(out_batched))
      out_mutually_batched.store(out_batched)
      return out

    def to_vmap_over_extra_batched_dims(primals, tangents):
      return jax.jvp(to_jvp, primals, tangents)

    to_vmap_over_extra_batched_dims_flat, out_tree2 = flatten_fun_nokwargs(
        lu.wrap_init(to_vmap_over_extra_batched_dims),
        tree_ps_ts)

    flat_out_ps_ts, flat_out_axes = vmap_unrestricted(
        to_vmap_over_extra_batched_dims_flat, *flat_ps_ts,
        in_axes=flat_extra_batched_ps_ts,
        axis_name=core.no_axis_name, axis_size=axis_size)

    n, ragged = divmod(len(flat_out_ps_ts), 2)
    assert not ragged
    flat_out_ps, flat_out_ts = flat_out_ps_ts[:n], flat_out_ps_ts[n:]
    flat_out_axes_p, flat_out_axes_t = flat_out_axes[:n], flat_out_axes[n:]
    flat_out_ps = map(maybe_bdim_at_front, flat_out_ps, flat_out_axes_p)
    flat_out_extra_batched_ps = [d is not not_mapped for d in flat_out_axes_p]
    flat_out_ts = map(maybe_bdim_at_front, flat_out_ts, flat_out_axes_t)
    flat_out_extra_batched_ts = [d is not not_mapped for d in flat_out_axes_t]

    out_ps, out_ts = tree_unflatten(
        out_tree2(), [*flat_out_ps, *flat_out_ts])
    out_extra_batched_ps, out_extra_batched_ts = tree_unflatten(
        out_tree2(), [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts])

    out_batched_ps = tree_map(
        operator.or_, out_mutually_batched.val, out_extra_batched_ps)
    out_batched_ts = tree_map(
        operator.or_, out_mutually_batched.val, out_extra_batched_ts)

    return (out_ps, out_ts), (out_batched_ps, out_batched_ts)
Example #4
 def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
     if not self.jvp:
         msg = "No JVP defined for custom_jvp function {} using defjvp."
         raise AttributeError(msg.format(self.__name__))
     args = _resolve_kwargs(self.fun, args, kwargs)
     if self.nondiff_argnums:
         nondiff_argnums = set(self.nondiff_argnums)
         args = tuple(
             _stop_gradient(x) if i in nondiff_argnums else x
             for i, x in enumerate(args))
         diff_argnums = [
             i for i in range(len(args)) if i not in nondiff_argnums
         ]
         f_, dyn_args = argnums_partial(lu.wrap_init(self.fun),
                                        diff_argnums,
                                        args,
                                        require_static_args_hashable=False)
         static_args = [args[i] for i in self.nondiff_argnums]
         jvp = _add_args(lu.wrap_init(self.jvp), static_args)
     else:
         f_, dyn_args = lu.wrap_init(self.fun), args
         jvp = lu.wrap_init(self.jvp)
     args_flat, in_tree = tree_flatten(dyn_args)
     flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
     flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
     out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
     _, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
     return tree_unflatten(out_tree, out_flat)
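For context, a minimal usage sketch of the public jax.custom_jvp API whose call path this __call__ method implements; the names below are illustrative:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  # custom JVP rule: return the primal output and its tangent
  x, = primals
  x_dot, = tangents
  return jnp.sin(x), jnp.cos(x) * x_dot

print(jax.jvp(f, (1.0,), (1.0,)))  # (sin(1.0), cos(1.0))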
Example #5
 def fwd(*args, **kwargs):
     ans, rule = fun(*args, **kwargs)
     ans_flat, out_tree = tree_flatten((ans, ))
     rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
     ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
     jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
     return ans, Residuals(jaxpr, in_tree(), out_tree, consts)
Example #6
 def wrapped(spenv: SparsifyEnv, *spvalues: SparsifyValue, **params: Any) -> Tuple[Sequence[SparsifyValue], bool]:
   spvalues_flat, in_tree = tree_flatten(spvalues, is_leaf=_is_spvalue)
   in_avals_flat = spvalues_to_avals(spenv, spvalues_flat)
   wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree)
   jaxpr, out_avals_flat, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
   result = eval_sparse(jaxpr, consts, spvalues_flat, spenv)
   if len(out_avals_flat) != len(result):
     raise Exception("Internal: eval_sparse does not return expected number of arguments. "
                     "Got {result} for avals {out_avals_flat}")
   return result, out_tree()
Example #7
File: ad.py Project: jbampton/jax
def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
                    cotangent_in_avals, reduce_axes):
  call_jaxpr = _close_jaxpr(call_jaxpr)
  unknowns = map(is_undefined_primal, primals_in)
  primal_jaxpr, tangent_jaxpr, _ = \
    pe.partial_eval_jaxpr(call_jaxpr, unknowns=unknowns, instantiate=True)  # type: ignore
  args, in_tree_def = tree_flatten((primals_in, cotangents_in))
  transpose = lu.hashable_partial(lu.wrap_init(_remat_transpose), primal_jaxpr,
                                  tangent_jaxpr, reduce_axes)
  flat_transpose, out_tree = flatten_fun_nokwargs(transpose, in_tree_def)
  flat_cotangents_out = pe.remat_call_p.bind(flat_transpose, *args, **params)
  return tree_unflatten(out_tree(), flat_cotangents_out)
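remat_transpose is the transpose rule that fires when differentiating through a rematerialized call. A minimal sketch of the user-facing side, assuming the public jax.checkpoint (also exposed as jax.remat) decorator:

import jax
import jax.numpy as jnp

@jax.checkpoint  # recompute intermediates on the backward pass instead of storing them
def f(x):
  return jnp.sin(jnp.sin(x))

print(jax.grad(f)(2.0))  # same gradient value as without checkpointing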
Example #8
 def __call__(self, *args, **kwargs):
   if self.ivjp is None:
     msg = "No IVJP defined for custom_vjp function {}. Did you forget to use defivjp?"
     raise AttributeError(msg.format(self.__name__))
   args = custom_derivatives._resolve_kwargs(self.fun, args, kwargs)
   # TODO: Support nondiff_argnums
   fun, ivjp = lu.wrap_init(self.fun), lu.wrap_init(self.ivjp)
   args_flat, in_tree = tree_flatten(args)
   flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
   flat_ivjp = _flatten_ivjp(ivjp, in_tree, out_tree)
   out_flat = _custom_ivjp(flat_fun, flat_ivjp, args_flat)
   return tree_unflatten(out_tree(), out_flat)
Example #9
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
    all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
    fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                              reduce_axes)
    fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
    new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
    update_params = call_transpose_param_updaters.get(primitive)
    if update_params:
        new_params = update_params(new_params, map(is_undefined_primal, args),
                                   [type(x) is not Zero for x in ct])
    out_flat = primitive.bind(fun, *all_args, **new_params)
    return tree_unflatten(out_tree(), out_flat)
Example #10
 def fwd(*args):
   flat_args, in_tree = tree_flatten(args)
   in_pvals = tuple(pe.PartialVal.unknown(raise_to_shaped(get_aval(arg))) for arg in flat_args)
   fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun_flat, in_pvals)
   # TODO: Don't warn if consts contain JVP tracers?
   if consts:
     warnings.warn("Values that an @invertible function closes over will not have their " +
                   "gradients computed correctly (their uses inside this function will be ignored)!")
   # TODO: This requires the body to be jittable, but this shouldn't be necessary.
   #       Is there a way to trace a jaxpr while running it?
   flat_outs = core.eval_jaxpr(jaxpr, consts, *flat_args)
   return tree_unflatten(out_tree(), flat_outs), (flat_args, flat_outs, consts, DontFlatten((jaxpr, in_tree)))
Example #11
 def __call__(self, *args, **kwargs):
   assert not kwargs
   args_flat, in_tree = tree_flatten(args)
   flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
   in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
   debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
   assert not len(consts)
   closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
   out_flat = custom_vmap_p.bind(*consts, *args_flat,
                                 call=closed_call,
                                 rule=self.vmap_rule,
                                 in_tree=in_tree)
   return tree_unflatten(out_tree(), out_flat)
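This __call__ belongs to the custom-vmap mechanism. A hedged sketch of typical usage, assuming the jax.custom_batching.custom_vmap API and a rule signature of (axis_size, in_batched, *args) returning (outputs, out_batched):

import jax
import jax.numpy as jnp
from jax import custom_batching

@custom_batching.custom_vmap
def f(x):
  return jnp.sin(x)

@f.def_vmap
def f_vmap_rule(axis_size, in_batched, xs):
  xs_batched, = in_batched
  # handle the whole batch explicitly; report which outputs carry a batch axis
  return jnp.sin(xs), xs_batched

print(jax.vmap(f)(jnp.arange(3.0)))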
Example #12
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                            reduce_axes, False)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  if not config.jax_experimental_name_stack:
    params = dict(params, name=wrap_name(params['name'], 'transpose'))
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    params = update_params(params, map(is_undefined_primal, args),
                           [type(x) is not Zero for x in ct])
  if config.jax_dynamic_shapes:
    in_type = [(core.raise_to_shaped(core.get_aval(x)), True) for x in all_args]
    fun = lu.annotate(fun, tuple(in_type))
  out_flat = primitive.bind(fun, *all_args, **params)
  return tree_unflatten(out_tree(), out_flat)
Example #13
File: ad.py Project: John1Tang/jax
def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
    all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
    fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                              reduce_axes, False)
    fun, nz_arg_cts = nonzero_outputs(fun)
    fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
    # Preserve axis for primal arguments, skip tangents (represented as undefined primals).
    in_axes, out_axes = params['in_axes'], params['out_axes']
    new_in_axes = (*[
        axis for axis, x in zip(in_axes, args) if not is_undefined_primal(x)
    ], *[axis for axis, x in zip(out_axes, ct) if type(x) is not Zero])
    # The interim strategy we use below (until avals-with-names) only works
    # when all outputs are mapped.
    assert all(out_axis is not None for out_axis in out_axes), out_axes
    # NOTE: This assumes that the output cotangents being zero is a deterministic
    #       function of which input cotangents were zero.
    @as_hashable_function(closure=(in_axes, tuple(type(c) is Zero
                                                  for c in ct)))
    def out_axes_thunk():
        return tuple(axis or 0 for axis, nz in zip(in_axes, nz_arg_cts())
                     if nz)

    new_params = dict(params,
                      name=wrap_name(params['name'], 'transpose'),
                      in_axes=new_in_axes,
                      out_axes_thunk=out_axes_thunk)
    del new_params['out_axes']
    update_params = call_transpose_param_updaters.get(primitive)
    if update_params:
        new_params = update_params(new_params, map(is_undefined_primal, args),
                                   [type(x) is not Zero for x in ct])
    out_flat = primitive.bind(fun, *all_args, **new_params)
    arg_cts = tree_unflatten(out_tree(), out_flat)

    # The freevars are being fanned out (not mapped). During transpose the
    # dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
    assert len(in_axes) == len(arg_cts)

    def unmap_zero(zero, in_axis):
        return (zero if in_axis is None else Zero(
            core.unmapped_aval(params['axis_size'], params['axis_name'],
                               in_axis, zero.aval)))

    arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
               arg_ct if in_axis is not None else arg_ct.sum(0)
               for arg_ct, in_axis in zip(arg_cts, in_axes))
    return tuple(arg_cts)
Example #14
def _closure_convert_for_avals(fun, in_tree, in_avals):
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  out_tree = out_tree()

  (closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts)
  num_consts = len(hoisted_consts)

  def converted_fun(*args_hconsts):
    num_args = len(args_hconsts) - num_consts
    args, hoisted_consts = split_list(args_hconsts, [num_args])
    consts = merge(closure_consts, hoisted_consts)
    all_args, in_tree2 = tree_flatten(tuple(args))
    assert in_tree == in_tree2
    out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
    return tree_unflatten(out_tree, out_flat)

  return converted_fun, hoisted_consts
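_closure_convert_for_avals underlies the public jax.closure_convert utility. A minimal sketch of that user-facing API, under the assumption that the converted function takes the original arguments followed by the hoisted constants:

import jax
import jax.numpy as jnp

def example(c, x):
  fn = lambda y: c * jnp.sin(y)                   # fn closes over c
  closed_fn, consts = jax.closure_convert(fn, x)  # hoist closed-over values that need tracing
  return closed_fn(x, *consts)                    # call as closed_fn(*args, *consts)

print(example(2.0, 1.0))  # 2 * sin(1.0)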
Example #15
    def __call__(self, residual_arg, linear_arg):
        res_arg, lin_arg = residual_arg, linear_arg
        _, res_tree = tree_flatten(res_arg)
        _, lin_tree = tree_flatten(lin_arg)
        args_flat, in_tree = tree_flatten((res_arg, lin_arg))

        flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun),
                                                  in_tree)
        in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
        debug = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
        jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
        assert not len(consts)
        closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
        out_flat = custom_transpose_p.bind(*consts,
                                           *args_flat,
                                           call=closed_call,
                                           rule=self.transpose,
                                           lin_tree=lin_tree,
                                           res_tree=res_tree,
                                           out_tree=out_tree())
        return tree_unflatten(out_tree(), out_flat)
Example #16
def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
                linear_args):
    """Call a linear function, with a custom implementation for its transpose.

  The type signatures of ``fun`` and ``fun_transpose`` are:

  .. code-block:: haskell

    fun           :: r -> a -o b
    fun_transpose :: r -> b -o a

  where the ``-o`` arrow indicates a linear function, ``r`` is the
  residual input type and ``a`` is the linear input type.

  The functions ``fun`` and ``fun_transpose`` are coupled as
  transposes of one another. Specifically, the transpose of a
  ``linear_call`` primitive is another ``linear_call`` to
  ``fun_transpose``, with ``fun`` as its custom transposition.

  For example:

  >>> def f(r, x):
  ...   return x / r

  >>> def t(r, t):
  ...   return t / r

  >>> def div_add(x, denom):
  ...   return x + linear_call(f, t, denom, x)

  >>> def transpose(f, x_example):
  ...   def transposed(y):
  ...     x, = jax.linear_transpose(f, x_example)(y)
  ...     return x
  ...   return transposed

  >>> div_add(9., 3.)
  DeviceArray(12., dtype=float32, weak_type=True)

  >>> transpose(partial(div_add, denom=3.), 1.)(18.)  # custom
  DeviceArray(24., dtype=float32, weak_type=True)

  >>> transpose(lambda x: x + x / 3., 1.)(18.)  # reference
  DeviceArray(24., dtype=float32, weak_type=True)

  The above definition of ``f`` illustrates the purpose of a residual
  argument: division is linear in one of its inputs (the dividend
  ``x``) but not the other (the divisor ``r``).

  As another example:

  >>> def custom_id(x):
  ...   def f(_, x): return x
  ...   def t(_, t): return 7.
  ...   return linear_call(f, t, (), x)
  >>> custom_id(1.)
  1.0
  >>> transpose(custom_id, 1.)(1.)
  7.0
  >>> transpose(transpose(custom_id, 1.), 1.)(1.)
  1.0
  >>> transpose(transpose(transpose(custom_id, 1.), 1.), 1.)(1.)
  7.0

  Args:
    fun: a Python callable specifying a linear function. It should
      take two arguments: one of "residual" inputs (type ``r``),
      i.e. inputs in which the function is not necessarily linear, and
      one of "linear" inputs (type ``a``).  It should return output
      whose components are linear in the linear input (type ``b``).
    fun_transpose: a Python callable specifying a structurally linear
      function that is the transpose of ``fun`` with respect to its
      linear inputs. Its first argument is the same residual inputs
      (``r``) as ``fun``. Its second argument is of type
      ``b``. Finally, its output is of type ``a`` and each of its
      components is linear in its second argument (the ``b`` inputs).
    residual_args: Argument in which ``fun`` and ``fun_transpose`` are
      not necessarily linear. Not involved in transposition.
    linear_args: Argument in which ``fun`` and ``fun_transpose`` are
      linear and with respect to which the two are transposes.

  Returns:
    The call result, i.e. ``fun(residual_args, linear_args)``.

  """
    operands_res, res_tree = tree_flatten(residual_args)
    operands_lin, lin_tree = tree_flatten(linear_args)

    f_in_tree = treedef_tuple((res_tree, lin_tree))
    f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

    res_avals = map(abstractify, operands_res)
    lin_avals = map(abstractify, operands_lin)
    f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
    f_jaxpr = _close_jaxpr(f_jaxpr)
    out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals)

    t_in_tree = treedef_tuple((res_tree, out_tree()))
    t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose),
                                         t_in_tree)

    t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals))
    t_jaxpr = _close_jaxpr(t_jaxpr)

    if t_out_tree() != lin_tree:
        raise TypeError(
            'transpose output pytree structure must match that of linear inputs, '
            f'got output structure {t_out_tree()} '
            f'and input structure {lin_tree}.')

    out = linear_call_p.bind(*f_consts,
                             *t_consts,
                             *operands_res,
                             *operands_lin,
                             callee=f_jaxpr,
                             transpose=t_jaxpr,
                             num_callee_consts=len(f_consts),
                             num_transpose_consts=len(t_consts),
                             num_res=len(operands_res))

    return tree_unflatten(out_tree(), out)
Example #17
def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotangents_in):
  if all(type(ct) is ad.Zero for ct in cotangents_in):
    return map(lambda v: ad.Zero(v.aval), jaxpr.invars)

  def write_cotangent(v, ct):
    # assert v not in primal_env
    if ct is not None and type(v) is not Literal:
      ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct

  def read_cotangent(v):
    return ct_env.get(v, ad.Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, ad.UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if type(v) is Literal:
      return
    primal_env.setdefault(v, val)

  # Invert while computing cotangents
  ct_env: Dict[Any, Any] = {}
  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.invars, primals_in)
  map(write_primal, jaxpr.outvars, primals_out)
  map(write_primal, jaxpr.constvars, consts)
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.invars)
    primals_out = map(read_primal, eqn.outvars)
    cts_in = map(read_cotangent, eqn.outvars)
    should_invert = any(type(primal) is not ad.UndefinedPrimal
                        for primal in primals_out)
    should_vjp = any(type(ct) is not ad.Zero for ct in cts_in)
    assert not eqn.primitive.call_primitive

    # Skip primal equations that are only jvp coefficients and don't affect
    # primal outputs.
    if not should_invert and not should_vjp:
      continue

    def abstract(value):
      return raise_to_shaped(value.aval if ad.is_undefined_primal(value) else get_aval(value))

    # Get the ivjp_jaxpr
    if eqn.primitive is custom_ivjp_p:
      ivjp_jaxpr = eqn.params['ivjp_jaxpr']
    else:
      if eqn.primitive in primitive_ivjps:
        complete_ivjp = lu.wrap_init(primitive_ivjps[eqn.primitive])
      else:
        complete_ivjp = lu.wrap_init(partial(synthesize_ivjp, eqn, map(ad.is_undefined_primal, primals_in)))
      _, in_tree = tree_flatten(
          tuple(map(abstract, x) for x in (primals_in, primals_out, primals_out)))
      complete_ivjp_flat, _ = flatten_fun_nokwargs(complete_ivjp, in_tree)

      in_avals = map(abstract, primals_in + primals_out + primals_out)
      # TODO: Actually we do know some of the inputs, because they might be literals!
      ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
          complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
      assert not ivjp_jaxpr.constvars  # That might happen some time, but don't bother until then
      ivjp_jaxpr = core.ClosedJaxpr(ivjp_jaxpr, [])

    # Once we know what the ivjp can do exactly, we have to isolate the part we are
    # actually able to compute with the values we have at hand.
    num_inputs = len(eqn.invars)
    unknowns = (map(ad.is_undefined_primal, primals_in) +
                map(ad.is_undefined_primal, primals_out) +
                [False] * len(cts_in))
    jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(  # type: ignore
        ivjp_jaxpr, unknowns, instantiate=False)  # type:ignore
    unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
    # Make sure we're able to compute all cotangents. We don't really care if we
    # can reconstruct the primals or not, although failure to do so might result in
    # failing to compute cotangents later.
    assert not any(unknown_cotangents)
    # Remove residual outputs -- we won't be computing the unknown jaxpr anyway.
    num_outputs = len(jaxpr_unknown.jaxpr.outvars)
    jaxpr_known.jaxpr.outvars = jaxpr_known.jaxpr.outvars[:num_outputs]
    # TODO: We could drop the outputs that correspond to primals that we already know.
    #       This only matters in eager mode, so leaving it out for now...
    ivjp = core.jaxpr_as_fun(jaxpr_known)
    rec_primals_in, cts_out = split_list(ivjp(*primals_in, *primals_out, *cts_in),
                                         [num_inputs])
    # Unknown rec_primals_in are core.units, so we have to replace them
    # with UnknownPrimals because that's what write_primal will ignore.
    rec_primals_in = [prev if unknown else rec
                      for prev, rec, unknown
                      in zip(primals_in, rec_primals_in, unknown_rec_primals_in)]
    map(write_primal, eqn.invars, rec_primals_in)
    map(write_cotangent, [v for v in eqn.invars if type(v) is not Literal], cts_out)

  # NOTE: We keep the cotangents associated with primal variables, while the contract of a
  #       transpose is to return them in positions associated with tangent variables, which
  #       is what causes this whole confusion.
  return map(read_cotangent, jaxpr.invars)
Example #18
 def wrapped_fun(*args):
   args_flat, in_tree  = tree_flatten(args)
   f = lu.wrap_init(fun)
   flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
   out_flat = callback_fun(flat_fun, args_flat, callback, strip_calls)
   return tree_unflatten(out_tree(), out_flat)
Example #19
 def _wrapped(*args):
   args_flat, in_tree = tree_flatten(args, is_leaf=_is_bcoo)
   wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
   out = sparsify_fun(wrapped_fun, args_flat)
   return tree_unflatten(out_tree(), out)
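All of the examples above share the same basic shape: wrap a Python callable, flatten its pytree arguments, run (or trace) the flat function, and rebuild the output pytree. Below is a minimal, self-contained sketch of that pattern; the imports are internal JAX utilities whose paths vary across versions (e.g. jax.linear_util vs. jax.extend.linear_util):

from jax import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.tree_util import tree_flatten, tree_unflatten

def apply_via_flat(fun, *args):
  # 1. flatten pytree arguments into a flat list plus a treedef
  args_flat, in_tree = tree_flatten(args)
  # 2. wrap fun so it accepts the flat arguments; out_tree is a thunk that is
  #    only filled in after the wrapped function has been called
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  # 3. call (or trace/bind) the flat function
  out_flat = flat_fun.call_wrapped(*args_flat)
  # 4. rebuild the output pytree
  return tree_unflatten(out_tree(), out_flat)

# e.g. apply_via_flat(lambda d: {'y': d['x'] * 2.0}, {'x': 3.0}) -> {'y': 6.0}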