Example #1
def ignore_errors_jaxpr(jaxpr, error):
    """Constructs a jaxpr which takes two extra args but ignores them."""
    err_aval = core.raise_to_shaped(core.get_aval(error.err))
    code_aval = core.raise_to_shaped(core.get_aval(error.code))
    consts = jaxpr.consts
    jaxpr = jaxpr.jaxpr
    new_vars = core.gensym([jaxpr])
    new_invars = (new_vars(err_aval), new_vars(code_aval), *jaxpr.invars)
    new_jaxpr = core.Jaxpr(jaxpr.constvars, new_invars, jaxpr.outvars,
                           jaxpr.eqns)
    return core.ClosedJaxpr(new_jaxpr, consts)
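The common thread in these examples is the core.get_aval / core.raise_to_shaped pair. A minimal sketch of what the pair does, assuming an older JAX release where raise_to_shaped is still exported from jax.core (it has since been deprecated):

import jax.numpy as jnp
from jax import core

x = jnp.ones((2, 3), dtype=jnp.float32)
aval = core.get_aval(x)              # the abstract value carried by x
shaped = core.raise_to_shaped(aval)  # drop concreteness, keep shape and dtype
print(shaped)                        # a ShapedArray: float32[2,3]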
Example #2
 def array_to_spvalue(arg):
     if isinstance(arg, BCOO):
         return spenv.sparse(arg.shape, arg.data, arg.indices)
     elif core.get_aval(arg) is core.abstract_unit:
         return spenv.unit()
     else:
         return spenv.dense(arg)
Example #3
def typecheck_atom(env, x):
  if isinstance(x, Literal):
    return core.raise_to_shaped(core.get_aval(x.val))
  elif isinstance(x, Var):
    return typecheck_type(env, x.aval)
  else:
    raise TypeError(f'atom of unexpected type {x}')
Example #4
def _check_output_dtype_jacfwd(holomorphic, x):
    aval = core.get_aval(x)
    if holomorphic:
        if not dtypes.issubdtype(aval.dtype, np.complexfloating):
            raise TypeError(
                "jacfwd with holomorphic=True requires outputs with complex dtype, "
                f"but got {aval.dtype.name}.")
Example #5
 def array_to_argspec(arg):
   if isinstance(arg, BCOO):
     return ArgSpec(arg.shape, spenv.push(arg.data), spenv.push(arg.indices))
   elif core.get_aval(arg) is core.abstract_unit:
     return ArgSpec((), None, None)
   else:
     return ArgSpec(np.shape(arg), spenv.push(arg), None)
Example #6
def _sparsify_jaxpr(spenv, jaxpr, *argspecs):
    # TODO(jakevdp): currently this approach discards all information about
    #   shared data & indices when generating the sparsified jaxpr. The
    #   current approach produces valid sparsified while loops, but they
    #   don't work in corner cases (see associated TODO in sparsify_test.py)
    out_tree = None

    @lu.wrap_init
    def wrapped(*args_flat):
        nonlocal out_tree
        args = tree_unflatten(in_tree, args_flat)
        argspecs = arrays_to_argspecs(spenv, args)
        result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, argspecs, spenv)
        out = argspecs_to_arrays(spenv, result)
        out_flat, out_tree = tree_flatten(out)
        return out_flat

    args = argspecs_to_arrays(spenv, argspecs)
    args_flat, in_tree = tree_flatten(args)
    avals_flat = [
        core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat
    ]
    sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
    sp_jaxpr = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(sp_jaxpr), consts)
    return sp_jaxpr, out_tree
Example #7
def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  out_tree = None

  @lu.wrap_init
  def wrapped(*args_flat):
    # TODO(frostig,jakevdp): This closes over `spenv`, which can bring
    # in buffers from the "outer scope" as constants. Is this a
    # problem for primitives like cond and while_loop, which always
    # convert constvars to invars when staging out their subjaxprs?
    nonlocal out_tree
    args = tree_unflatten(in_tree, args_flat)
    spvalues = arrays_to_spvalues(spenv, args)
    result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, spvalues, spenv)
    out = spvalues_to_arrays(spenv, result)
    out_flat, out_tree = tree_flatten(out)
    return out_flat

  args = spvalues_to_arrays(spenv, spvalues)
  args_flat, in_tree = tree_flatten(args)
  avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
  sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
  sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
  return sp_jaxpr, out_tree
Example #8
    def trace_to_jaxpr_finalize(in_tracers,
                                out_tracers,
                                trace,
                                instantiate=True):
        # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
        instantiate = [instantiate] * len(out_tracers)
        out_tracers = safe_map(trace.full_raise,
                               safe_map(core.full_lower, out_tracers))
        out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
                               instantiate, out_tracers)
        jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
        out_pvals = [t.pval for t in out_tracers]
        # TODO: this is from partial_eval.trace_to_jaxpr. Share.
        assert not env

        # TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
        out_avals = safe_map(abstract_arrays.raise_to_shaped,
                             unzip2(out_pvals)[0])
        const_avals = tuple(
            abstract_arrays.raise_to_shaped(core.get_aval(c)) for c in consts)

        in_pvals = [t.pval for t in in_tracers]
        in_avals = tuple(
            safe_map(abstract_arrays.raise_to_shaped,
                     unzip2(in_pvals)[0]))

        typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), (),
                                      const_avals + in_avals, out_avals)
        return typed_jaxpr, consts
Example #9
def _custom_vjp_call_jaxpr_jvp(primals, tangents, *,
                               fun_jaxpr: core.ClosedJaxpr,
                               fwd_jaxpr_thunk: Callable[[],
                                                         Tuple[core.Jaxpr,
                                                               Sequence[Any]]],
                               bwd: lu.WrappedFun, out_trees: Callable,
                               num_consts: int):
    _, args = split_list(primals, [num_consts])
    consts_dot, args_dot = split_list(tangents, [num_consts])
    if any(type(t) is not Zero for t in consts_dot):
        raise ad.CustomVJPException()
    fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk()  # consts can be tracers!
    out_tree, res_tree = out_trees()
    args_dot = map(ad.instantiate_zeros, args_dot)
    # Cast float0 to zeros with the primal dtype because custom vjp rules don't
    # currently handle float0s
    args_dot = map(ad.replace_float0s, args, args_dot)
    res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
    res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
    avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
    tangents_out = ad.custom_lin_p.bind(*res,
                                        *args_dot,
                                        num_res=res_tree.num_leaves,
                                        bwd=bwd,
                                        avals_out=avals_out)
    tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
    return primals_out, tangents_out
Example #10
 def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:
     if not self.fwd or not self.bwd:
         msg = "No VJP defined for custom_vjp function {} using defvjp."
         raise AttributeError(msg.format(self.__name__))
     args = _resolve_kwargs(self.fun, args, kwargs)
     if self.nondiff_argnums:
         for i in self.nondiff_argnums:
             _check_for_tracers(args[i])
         nondiff_argnums = set(self.nondiff_argnums)
         dyn_argnums = [
             i for i in range(len(args)) if i not in nondiff_argnums
         ]
         f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums,
                                        args)
         static_args = [args[i] for i in self.nondiff_argnums]
         fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args)
         bwd = _add_args(lu.wrap_init(self.bwd), static_args)
     else:
         f_, dyn_args = lu.wrap_init(self.fun), args
         fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
     args_flat, in_tree = tree_flatten(dyn_args)
     in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
     flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
     flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
     flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
     out_flat = custom_vjp_call_p.bind(flat_fun,
                                       flat_fwd,
                                       flat_bwd,
                                       *args_flat,
                                       out_trees=out_trees)
     fst, aux = lu.merge_linear_aux(out_tree, out_trees)
     out_tree = aux if fst else aux[0]
     return tree_unflatten(out_tree, out_flat)
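Example #10 is the internal __call__ behind jax.custom_vjp. For context, a minimal sketch of the same machinery through the public API; the gradient-clipping rule here is illustrative, not from the source:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_grad(x):
  return x

def clip_grad_fwd(x):
  return x, None                    # (output, residuals)

def clip_grad_bwd(_, g):
  return (jnp.clip(g, -1.0, 1.0),)  # clip the incoming cotangent

clip_grad.defvjp(clip_grad_fwd, clip_grad_bwd)
print(jax.grad(lambda x: 5.0 * clip_grad(x))(0.0))  # cotangent 5.0 clipped to 1.0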
Example #11
class ErrorTracer(core.Tracer):
  def __init__(self, trace, val):
    self._trace = trace
    self.val = val
    core.get_aval(val), val  # eager validity check: get_aval raises on non-JAX types
  aval = property(lambda self: core.get_aval(self.val))
  full_lower = lambda self: self
Example #12
 def fwd(*args, **kwargs):
     ans, rule = fun(*args, **kwargs)
     ans_flat, out_tree = tree_flatten((ans, ))
     rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
     ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
     jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
     return ans, Residuals(jaxpr, in_tree(), out_tree, consts)
Example #13
def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
  assert not jaxpr.constvars
  cell = lambda: None

  @lu.wrap_init
  def transposed(*args):
    in_primals, out_cts = tree_unflatten(treedef, args)
    in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x) else
                pe.PartialVal.known(x) for x in in_primals]
    primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
    t_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(primal_fun, in_pvals, False)
    dummy_args = [ad.UndefinedPrimal(v.aval) for v in t_jaxpr.invars]
    in_cts = ad.backward_pass(t_jaxpr, reduce_axes, False, consts, dummy_args,
                              out_cts)
    in_cts_ = iter(in_cts)
    in_cts = [next(in_cts_) if ad.is_undefined_primal(x)
              else ad_util.Zero(x.aval) for x in in_primals]
    assert next(in_cts_, None) is None
    in_cts, cell.treedef = tree_flatten(in_cts)
    return in_cts

  args, treedef = tree_flatten((in_primals, out_cts))
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
  transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
  transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr_)
  in_cts = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params)
  return tree_unflatten(cell.treedef, in_cts)  # type: ignore
Example #14
def bdim_at_front(x, bdim, size):
    if core.get_aval(x) is core.abstract_unit:
        return core.unit
    if bdim is not_mapped:
        return broadcast(x, size, 0)
    else:
        return moveaxis(x, bdim, 0)
Example #15
def broadcast(x, sz, axis):
    if core.get_aval(x) is core.abstract_unit:
        return core.unit
    shape = list(np.shape(x))
    shape.insert(axis, sz)
    broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
    return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)
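For reference, a sketch of what broadcast(x, sz, axis) computes, in terms of the public jax.lax.broadcast_in_dim it calls (the values here are made up for illustration):

import jax
import numpy as np

x = np.arange(6.0).reshape(2, 3)
# Inserting a new axis of size 4 at position 1 maps the old axes (0, 1)
# to positions (0, 2) of the (2, 4, 3) result, as in broadcast(x, 4, 1).
y = jax.lax.broadcast_in_dim(x, shape=(2, 4, 3), broadcast_dimensions=(0, 2))
assert y.shape == (2, 4, 3)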
Example #16
 def aval(self):
     aval = raise_to_shaped(core.get_aval(self.val))
     if self.batch_dim is not_mapped or aval is core.abstract_unit:
         return aval
     else:
         return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim,
                                 aval)
Example #17
def maybe_bdim_at_front(x, bdim):
    if core.get_aval(x) is core.abstract_unit:
        return core.unit
    if bdim is not_mapped:
        return x
    else:
        return util.moveaxis(x, bdim, 0)
Example #18
 def new(cls, val):
     if val is jax_core.unit:
         return InverseAndILDJ.unknown(jax_core.abstract_unit)
     val = np.array(val)
     aval = jax_core.get_aval(val)
     aval = abstract_arrays.raise_to_shaped(aval)
     ndslice = NDSlice.new(val, np.zeros_like(val))
     return InverseAndILDJ(aval, frozenset([ndslice]))
Example #19
def _jaxtupletree_select(pred, on_true, on_false):
  aval = core.get_aval(on_true)
  if type(aval) is core.AbstractTuple:
    return core.pack(map(partial(_jaxtupletree_select, pred), on_true, on_false))
  elif isinstance(aval, UnshapedArray):
    return lax.select(pred, on_true, on_false)
  else:
    raise TypeError(aval)
Example #20
def _flatten_jvp(in_tree, *args):
    primals_in, tangents_in = split_list(args, [len(args) // 2])
    py_primals = tree_unflatten(in_tree, primals_in)
    py_tangents = tree_unflatten(in_tree, tangents_in)
    pair_out = yield (py_primals, py_tangents), {}
    if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
        msg = (
            "Custom JVP rule must produce a pair (list or tuple of length two) "
            "representing primal and tangent outputs, got {}.")
        raise TypeError(msg.format(pair_out))
    py_primals_out, py_tangents_out = pair_out
    primals_out, out_tree = tree_flatten(py_primals_out)
    tangents_out, out_tree2 = tree_flatten(py_tangents_out)
    if out_tree != out_tree2:
        msg = (
            "Custom JVP rule must produce primal and tangent outputs with equal "
            "container (pytree) structures, but got {} and {} respectively.")
        raise TypeError(msg.format(out_tree, out_tree2))
    # TODO(mattjj): compare primals' tangent types to tangent objects' types
    primal_avals_out = [
        raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape()
        for x in primals_out
    ]
    tangent_avals_out = [
        raise_to_shaped(core.get_aval(t), weak_type=False).strip_named_shape()
        for t in tangents_out
    ]
    if primal_avals_out != tangent_avals_out:
        if len(primal_avals_out) == 1:
            (av1, ), (av2, ) = primal_avals_out, tangent_avals_out
            msg = (
                "Custom JVP rule must produce primal and tangent outputs with "
                "equal shapes and dtypes, but got {} and {} respectively.")
            raise TypeError(msg.format(av1.str_short(), av2.str_short()))
        else:
            msg = (
                "Custom JVP rule must produce primal and tangent outputs with "
                "equal shapes and dtypes, but got:\n{}")
            disagreements = (
                "  primal {} for tangent {}".format(av1.str_short(),
                                                    av2.str_short())
                for av1, av2 in zip(primal_avals_out, tangent_avals_out)
                if av1 != av2)
            raise TypeError(msg.format('\n'.join(disagreements)))
    yield primals_out + tangents_out, out_tree
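_flatten_jvp enforces the contract that a custom JVP rule returns a (primals, tangents) pair with matching pytree structure and matching shapes/dtypes. A minimal rule satisfying that contract through the public jax.custom_jvp API (illustrative, not from the source):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  y = jnp.sin(x)
  return y, jnp.cos(x) * t  # primal and tangent outputs: same tree, shape, dtype

print(jax.jvp(f, (1.0,), (1.0,)))  # (sin(1.0), cos(1.0))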
Example #21
def _initial_style_jaxpr(fun, in_tree, in_avals):
  in_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True)
  out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
  const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
  typed_jaxpr = core.TypedJaxpr(pe.closure_convert_jaxpr(jaxpr),
                                (), const_avals + in_avals, out_avals)
  return typed_jaxpr, consts, out_tree()
Example #22
def _custom_ivjp(fun, ivjp, args):
  in_avals = [raise_to_shaped(get_aval(x)) for x in args]
  fun_jaxpr = _initial_style_jaxpr(fun, in_avals)
  try:
    ivjp_jaxpr = _initial_style_jaxpr(ivjp, in_avals + fun_jaxpr.out_avals * 2)
  except RecursionError:
    raise ValueError("Calls to {} from its custom ivjp aren't supported yet".format(fun.__name__))
  return custom_ivjp_p.bind(*args, fun_jaxpr=fun_jaxpr,
                                   ivjp_jaxpr=ivjp_jaxpr)
Example #23
 def __init__(self, trace, val, batch_dim: Optional[int]):
     if config.jax_enable_checks:
         assert type(batch_dim) in (int, NotMapped)
         if type(batch_dim) is int:
             aval = raise_to_shaped(core.get_aval(val))
             assert aval is core.abstract_unit or 0 <= batch_dim < len(
                 aval.shape)  # type: ignore
     self._trace = trace
     self.val = val
     self.batch_dim = batch_dim
Example #24
 def fun_remat(*args, **kwargs):
   args_flat, in_tree = tree_flatten((args, kwargs))
   flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
   in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
   debug = pe.debug_info(fun, in_tree, False, "checkpoint")
   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
   out_flat = remat_p.bind(
       *consts, *args_flat, jaxpr=pe.convert_constvars_jaxpr(jaxpr),
       prevent_cse=prevent_cse, differentiated=False, policy=policy)
   return tree_unflatten(out_tree(), out_flat)
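Example #24 is the staging step behind jax.checkpoint / jax.remat. A short sketch of the user-facing behavior it implements, via the public API:

import jax
import jax.numpy as jnp

@jax.checkpoint
def g(x):
  return jnp.sin(jnp.sin(x))

# Residuals of g are not saved; the forward pass is recomputed during the
# backward pass, trading compute for memory.
print(jax.grad(g)(2.0))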
Example #25
 def write_cotangent(prim, v, ct):
   # assert v not in primal_env
   assert ct is not Zero, (prim, v.aval)  # check for an old harmless type error
   if ct is None or type(v) is Literal:
     return
   if type(ct) is Zero:
     # FIXME: This triggers a lot of failures!
     # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
     return
   axes_to_reduce = tuple(axis_name for axis_name in reduce_axes
                          if axis_name in core.get_aval(ct).named_shape
                          and axis_name not in v.aval.named_shape)
   if axes_to_reduce:
     ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
   ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
   if config.jax_enable_checks:
     ct_aval = core.get_aval(ct_env[v])
     joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type().strip_named_shape()
     assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval)
Example #26
def shaped_abstractify(x):
  try:
    return core.raise_to_shaped(core.get_aval(x))
  except TypeError:
    pass

  weak_type = getattr(x, 'weak_type', False)
  named_shape = getattr(x, 'named_shape', {})
  return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,
                          named_shape=named_shape)
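shaped_abstractify first tries the get_aval/raise_to_shaped path and falls back to duck-typed shape/dtype attributes. A usage sketch, assuming the function above is in scope:

import numpy as np

print(shaped_abstractify(np.zeros((4, 2), np.float32)))  # ShapedArray, shape (4, 2)
print(shaped_abstractify(3.0))  # scalar aval; weak_type=True on typical JAX versions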
Example #27
 def handle_sow(self, *values, name, tag, tree, mode):
   """Stores a sow in the reaps dictionary."""
   del tag
   if name in self.reaps:
     raise ValueError(f'Variable has already been reaped: {name}')
   avals = tree_util.tree_unflatten(
       tree,
       [abstract_arrays.raise_to_shaped(jax_core.get_aval(v)) for v in values])
   self.reaps[name] = Reap(
       tree_util.tree_unflatten(tree, values), dict(mode=mode, aval=avals))
   return values
Example #28
 def __init__(self, trace, val, batch_dim: Optional[int],
              source_info: Optional[source_info_util.SourceInfo] = None):
   if config.jax_enable_checks:
     assert type(batch_dim) in (int, NotMapped)
     if type(batch_dim) is int:
       aval = raise_to_shaped(core.get_aval(val))
       assert batch_dim is not_mapped or 0 <= batch_dim < len(aval.shape)  # type: ignore
   self._trace = trace
   self.val = val
   self.batch_dim = batch_dim
   self.source_info = source_info
Example #29
 def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
   primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
   tangents_in = map(instantiate_zeros, tangents_in)
   res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
   out_tree, res_tree = out_trees()
   res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
   avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
   tangents_out = custom_lin_p.bind(
       *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
       out_avals=avals_out)
   tangents_out = map(recast_to_float0, primals_out, tangents_out)
   return map(partial(JVPTracer, self), primals_out, tangents_out)
Example #30
 def get_arg(a, unknown):
     if unknown:
         return tree_flatten(
             (
                 tree_map(
                     lambda x: PartialVal.unknown(get_aval(x).at_least_vspace()), a
                 ),
                 {},
             )
         )[0]
     else:
         return PartialVal.known(a)
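Example #30 leans on the internal PartialVal type from jax.interpreters.partial_eval. A short sketch of the two constructors it uses (internal API; details vary across JAX versions):

import numpy as np
from jax import core
from jax.interpreters import partial_eval as pe

known = pe.PartialVal.known(np.float32(1.0))  # carries a concrete value
unknown = pe.PartialVal.unknown(              # carries only shape/dtype
    core.ShapedArray((), np.dtype('float32')))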