def _primal_tangent_shapes_match(primal, tangent):
  if type(tangent) is not Zero:
    primal_aval = raise_to_shaped(get_aval(primal), weak_type=False)
    tangent_aval = raise_to_shaped(get_aval(tangent), weak_type=False)
    assert primal_aval.shape == tangent_aval.shape, (primal_aval.shape, tangent_aval.shape)
    expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
    assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)
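# Illustrative sketch only (not part of the source): the agreement this check
# enforces. For inexact primals the tangent dtype equals the primal dtype, and
# for a well-formed JVP the shapes match too.
import jax
import jax.numpy as jnp

y, y_dot = jax.jvp(lambda x: x * 2.0, (jnp.float32(3.0),), (jnp.float32(1.0),))
assert y.shape == y_dot.shape and y.dtype == y_dot.dtype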
def _flatten_jvp(in_tree, *args):
  primals_in, tangents_in = split_list(args, [len(args) // 2])
  py_primals = tree_unflatten(in_tree, primals_in)
  py_tangents = tree_unflatten(in_tree, tangents_in)
  pair_out = yield (py_primals, py_tangents), {}
  if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
    msg = ("Custom JVP rule must produce a pair (list or tuple of length two) "
           "representing primal and tangent outputs, got {}.")
    raise TypeError(msg.format(pair_out))
  py_primals_out, py_tangents_out = pair_out
  primals_out, out_tree = tree_flatten(py_primals_out)
  tangents_out, out_tree2 = tree_flatten(py_tangents_out)
  if out_tree != out_tree2:
    msg = ("Custom JVP rule must produce primal and tangent outputs with equal "
           "container (pytree) structures, but got {} and {} respectively.")
    raise TypeError(msg.format(out_tree, out_tree2))
  primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False)
                      for x in primals_out]
  tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False)
                       for t in tangents_out]
  if primal_avals_out != tangent_avals_out:
    if len(primal_avals_out) == 1:
      (av1,), (av2,) = primal_avals_out, tangent_avals_out
      msg = ("Custom JVP rule must produce primal and tangent outputs with "
             "equal shapes and dtypes, but got {} and {} respectively.")
      raise TypeError(msg.format(av1.str_short(), av2.str_short()))
    else:
      msg = ("Custom JVP rule must produce primal and tangent outputs with "
             "equal shapes and dtypes, but got:\n{}")
      disagreements = (
          "  primal {} for tangent {}".format(av1.str_short(), av2.str_short())
          for av1, av2 in zip(primal_avals_out, tangent_avals_out)
          if av1 != av2)
      raise TypeError(msg.format('\n'.join(disagreements)))
  yield primals_out + tangents_out, out_tree
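# Illustrative sketch only: a well-formed custom JVP rule returns the
# (primal_out, tangent_out) pair that _flatten_jvp validates, with matching
# pytree structure, shapes, and dtypes.
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return jnp.sin(x), jnp.cos(x) * x_dot  # the pair checked above

jax.jvp(f, (1.0,), (2.0,))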
def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
  f, msgs = checkify_subtrace(f)
  f = checkify_traceable(f, tuple(error.msgs.items()), enabled_errors)
  err_aval = core.raise_to_shaped(core.get_aval(error.err))
  code_aval = core.raise_to_shaped(core.get_aval(error.code))
  avals_in = [err_aval, code_aval, *in_avals]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
def checkify_jaxpr(jaxpr, error):
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  f, msgs = check_errors_subtrace(f)
  f = check_errors_traceable(f, tuple(error.msgs.items()))
  err_aval = core.raise_to_shaped(core.get_aval(error.err))
  code_aval = core.raise_to_shaped(core.get_aval(error.code))
  avals_in = [err_aval, code_aval, *jaxpr.in_avals]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
def ignore_errors_jaxpr(jaxpr, error):
  """Constructs a jaxpr which takes two extra args but ignores them."""
  err_aval = core.raise_to_shaped(core.get_aval(error.err))
  code_aval = core.raise_to_shaped(core.get_aval(error.code))
  consts = jaxpr.consts
  jaxpr = jaxpr.jaxpr
  new_vars = core.gensym([jaxpr])
  new_invars = (new_vars(err_aval), new_vars(code_aval), *jaxpr.invars)
  new_jaxpr = core.Jaxpr(jaxpr.constvars, new_invars, jaxpr.outvars, jaxpr.eqns)
  return core.ClosedJaxpr(new_jaxpr, consts)
def aval(self):
  aval = raise_to_shaped(core.get_aval(self.val))
  if self.batch_dim is not_mapped or aval is core.abstract_unit:
    return aval
  else:
    return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
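# Illustrative sketch only: under vmap, a tracer batched along axis 1 of a
# (3, 8, 4) value exposes an aval of shape (3, 4) -- the mapped axis is
# stripped, which is what core.mapped_aval computes above.
import jax
import jax.numpy as jnp

def per_example(x):
  assert x.shape == (3, 4)  # the batched axis is not visible here
  return x.sum()

jax.vmap(per_example, in_axes=1)(jnp.ones((3, 8, 4)))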
def _sparsify_jaxpr(spenv, jaxpr, *argspecs):
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  out_tree = None

  @lu.wrap_init
  def wrapped(*args_flat):
    nonlocal out_tree
    args = tree_unflatten(in_tree, args_flat)
    argspecs = arrays_to_argspecs(spenv, args)
    result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, argspecs, spenv)
    out = argspecs_to_arrays(spenv, result)
    out_flat, out_tree = tree_flatten(out)
    return out_flat

  args = argspecs_to_arrays(spenv, argspecs)
  args_flat, in_tree = tree_flatten(args)
  avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
  sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
  sp_jaxpr = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(sp_jaxpr), consts)
  return sp_jaxpr, out_tree
def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis,
                              axis_index_groups):
  input_aval = raise_to_shaped(x)
  shape = list(input_aval.shape)
  size = shape.pop(split_axis)
  shape.insert(concat_axis, size)
  return ShapedArray(tuple(shape), input_aval.dtype, weak_type=False)
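# Illustrative sketch only: the pop/insert above moves the size of split_axis
# to concat_axis. E.g. shape (8, 2, 3) with split_axis=0, concat_axis=2
# yields (2, 3, 8).
shape = [8, 2, 3]
size = shape.pop(0)    # shape is now [2, 3]
shape.insert(2, size)  # shape is now [2, 3, 8]
assert tuple(shape) == (2, 3, 8)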
def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  out_tree = None

  @lu.wrap_init
  def wrapped(*args_flat):
    # TODO(frostig,jakevdp): This closes over `spenv`, which can bring
    # in buffers from the "outer scope" as constants. Is this a
    # problem for primitives like cond and while_loop, which always
    # convert constvars to invars when staging out their subjaxprs?
    nonlocal out_tree
    args = tree_unflatten(in_tree, args_flat)
    spvalues = arrays_to_spvalues(spenv, args)
    result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, spvalues, spenv)
    out = spvalues_to_arrays(spenv, result)
    out_flat, out_tree = tree_flatten(out)
    return out_flat

  args = spvalues_to_arrays(spenv, spvalues)
  args_flat, in_tree = tree_flatten(args)
  avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
  sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
  sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
  return sp_jaxpr, out_tree
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:
  if not self.fwd or not self.bwd:
    msg = "No VJP defined for custom_vjp function {} using defvjp."
    raise AttributeError(msg.format(self.__name__))
  args = _resolve_kwargs(self.fun, args, kwargs)
  if self.nondiff_argnums:
    for i in self.nondiff_argnums:
      _check_for_tracers(args[i])
    nondiff_argnums = set(self.nondiff_argnums)
    dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
    f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args)
    static_args = [args[i] for i in self.nondiff_argnums]
    fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args)
    bwd = _add_args(lu.wrap_init(self.bwd), static_args)
  else:
    f_, dyn_args = lu.wrap_init(self.fun), args
    fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
  args_flat, in_tree = tree_flatten(dyn_args)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
  flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
  flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
  flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
  out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
                                    *args_flat, out_trees=out_trees)
  fst, aux = lu.merge_linear_aux(out_tree, out_trees)
  out_tree = aux if fst else aux[0]
  return tree_unflatten(out_tree, out_flat)
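# Illustrative sketch only: the AttributeError path above fires if a
# custom_vjp function is called before defvjp has supplied both rules.
import jax

@jax.custom_vjp
def g(x):
  return x * 2.0

def g_fwd(x):
  return g(x), x      # (primal output, residuals)

def g_bwd(res, ct):
  return (2.0 * ct,)  # one cotangent per primal input

g.defvjp(g_fwd, g_bwd)
assert jax.grad(g)(3.0) == 2.0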
def _cond_transpose(reduce_axes, cts, *args, branches, linear):
  index, *ops = args
  in_avals = _map(raise_to_shaped, branches[0].in_avals)
  num_res = len(ops) - sum(linear)

  branches_trans = tuple(
      _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes) for jaxpr in branches)
  lin_in_avals = [raise_to_shaped(a, weak_type=False)
                  for a, l in zip(in_avals, linear) if l]
  assert all(core.typematch(out_aval, lin_in_aval)
             for jaxpr in branches_trans
             for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))

  res = ops[:num_res]
  cts = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
  linear_trans = (False,) * num_res + (True,) * len(cts)

  out = cond_p.bind(index, *res, *cts, branches=branches_trans,
                    linear=linear_trans)
  assert all(_map(core.typecheck, lin_in_avals, out))

  out_iter = iter(out)
  out = [next(out_iter) if l else None for l in linear]
  assert next(out_iter, None) is None
  return [None] + out
def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
  assert not jaxpr.constvars
  cell = lambda: None

  @lu.wrap_init
  def transposed(*args):
    in_primals, out_cts = tree_unflatten(treedef, args)
    in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x)
                else pe.PartialVal.known(x) for x in in_primals]
    primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
    t_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(primal_fun, in_pvals, False)
    dummy_args = [ad.UndefinedPrimal(v.aval) for v in t_jaxpr.invars]
    in_cts = ad.backward_pass(t_jaxpr, reduce_axes, False, consts, dummy_args,
                              out_cts)
    in_cts_ = iter(in_cts)
    in_cts = [next(in_cts_) if ad.is_undefined_primal(x)
              else ad_util.Zero(x.aval) for x in in_primals]
    assert next(in_cts_, None) is None
    in_cts, cell.treedef = tree_flatten(in_cts)
    return in_cts

  args, treedef = tree_flatten((in_primals, out_cts))
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
  transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
  transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr_)
  in_cts = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params)
  return tree_unflatten(cell.treedef, in_cts)  # type: ignore
def typecheck_atom(env, x):
  if isinstance(x, Literal):
    return core.raise_to_shaped(core.get_aval(x.val))
  elif isinstance(x, Var):
    return typecheck_type(env, x.aval)
  else:
    raise TypeError(f'atom of unexpected type {x}')
def _custom_vjp_call_jaxpr_jvp(
    primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
    fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
  _, args = split_list(primals, [num_consts])
  consts_dot, args_dot = split_list(tangents, [num_consts])
  if any(type(t) is not Zero for t in consts_dot):
    raise ad.CustomVJPException()
  fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk()  # consts can be tracers!
  out_tree, res_tree = out_trees()
  args_dot = map(ad.instantiate_zeros, args_dot)
  # Cast float0 to zeros with the primal dtype because custom vjp rules don't
  # currently handle float0s
  args_dot = map(ad.replace_float0s, args, args_dot)
  res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
  res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
  avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
  tangents_out = ad.custom_lin_p.bind(
      *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
      avals_out=avals_out)
  tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
  return primals_out, tangents_out
def array_result_handler(sticky_device: Optional[Device],
                         aval: core.ShapedArray):
  if aval.dtype == dtypes.float0:
    return lambda _, __: np.zeros(aval.shape, dtypes.float0)
  aval = core.raise_to_shaped(aval)
  handler = lambda _, b: _maybe_create_array_from_da(b, aval, sticky_device)
  handler.args = aval, sticky_device  # for C++ dispatch path in api.py
  return handler
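# Illustrative sketch only: the float0 branch above covers gradients with
# respect to integer inputs, which JAX represents as float0-dtype zeros.
import jax
from jax import dtypes

ct = jax.grad(lambda x, i: x * 2.0, argnums=1, allow_int=True)(1.0, 2)
assert ct.dtype == dtypes.float0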
def _custom_ivjp(fun, ivjp, args):
  in_avals = [raise_to_shaped(get_aval(x)) for x in args]
  fun_jaxpr = _initial_style_jaxpr(fun, in_avals)
  try:
    ivjp_jaxpr = _initial_style_jaxpr(ivjp, in_avals + fun_jaxpr.out_avals * 2)
  except RecursionError:
    raise ValueError("Calls to {} from its custom ivjp aren't supported yet"
                     .format(fun.__name__))
  return custom_ivjp_p.bind(*args, fun_jaxpr=fun_jaxpr, ivjp_jaxpr=ivjp_jaxpr)
def shaped_abstractify(x):
  try:
    return core.raise_to_shaped(core.get_aval(x))
  except TypeError:
    pass

  weak_type = getattr(x, 'weak_type', False)
  named_shape = getattr(x, 'named_shape', {})
  return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,
                          named_shape=named_shape)
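# Illustrative sketch only (assuming shaped_abstractify is in scope): the
# common path raises a concrete value to its shaped abstract value.
import jax.numpy as jnp

aval = shaped_abstractify(jnp.ones((2, 3)))
assert aval.shape == (2, 3)
assert str(aval.dtype) == 'float32'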
def fun_remat(*args, **kwargs):
  args_flat, in_tree = tree_flatten((args, kwargs))
  flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
  debug = pe.debug_info(fun, in_tree, False, "checkpoint")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  out_flat = remat_p.bind(
      *consts, *args_flat, jaxpr=pe.convert_constvars_jaxpr(jaxpr),
      prevent_cse=prevent_cse, differentiated=False, policy=policy)
  return tree_unflatten(out_tree(), out_flat)
def __init__(self, trace, val, batch_dim: Optional[int]):
  if config.jax_enable_checks:
    assert type(batch_dim) in (int, NotMapped)
    if type(batch_dim) is int:
      aval = raise_to_shaped(core.get_aval(val))
      assert aval is core.abstract_unit or 0 <= batch_dim < len(aval.shape)  # type: ignore
  self._trace = trace
  self.val = val
  self.batch_dim = batch_dim
def __init__(self, trace, val, batch_dim: Optional[int],
             source_info: Optional[source_info_util.SourceInfo] = None):
  if config.jax_enable_checks:
    assert type(batch_dim) in (int, NotMapped)
    if type(batch_dim) is int:
      aval = raise_to_shaped(core.get_aval(val))
      assert batch_dim is not_mapped or 0 <= batch_dim < len(aval.shape)  # type: ignore
  self._trace = trace
  self.val = val
  self.batch_dim = batch_dim
  self.source_info = source_info
def f_jitted(*args, **kwargs):
  args, in_tree = tree_flatten((args, kwargs))
  f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
  jaxpr, consts, unconverted_binders = trace_to_jaxpr_dynamic(f, in_avals)
  num_consts = len(consts)
  args = [*consts, *args]
  dim_vals, args = _extract_dim_vals(jaxpr.in_dim_binders, jaxpr.in_binders,
                                     unconverted_binders, args)
  out_flat = dynamic_xla_call_p.bind(*dim_vals, *args, jaxpr=jaxpr,
                                     num_consts=num_consts)
  return tree_unflatten(out_tree(), out_flat)
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
  primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
  tangents_in = map(instantiate_zeros, tangents_in)
  res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
  out_tree, res_tree = out_trees()
  res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
  avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
  tangents_out = custom_lin_p.bind(
      *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
      out_avals=avals_out)
  tangents_out = map(recast_to_float0, primals_out, tangents_out)
  return map(partial(JVPTracer, self), primals_out, tangents_out)
def _swap_abstract_eval(ref_aval: ShapedArrayRef,
                        val_aval: core.AbstractValue,
                        *idx: int):
  if not isinstance(ref_aval, ShapedArrayRef):
    raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
  val_aval = core.raise_to_shaped(val_aval)
  assert isinstance(val_aval, core.ShapedArray)
  expected_output_shape = ref_aval.shape[len(idx):]
  if expected_output_shape != val_aval.shape:
    raise ValueError("Invalid shape for `swap`. "
                     f"Ref shape: {ref_aval.shape}. "
                     f"Value shape: {val_aval.shape}. "
                     f"Indices: {idx}. ")
  return core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype), {State}
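# Illustrative sketch only: each leading index consumes one dimension of the
# Ref, so a Ref of shape (4, 3, 2) swapped with one index takes and returns
# values of shape (3, 2).
ref_shape = (4, 3, 2)
idx = (1,)
assert ref_shape[len(idx):] == (3, 2)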
def fwd(*args):
  flat_args, in_tree = tree_flatten(args)
  in_pvals = tuple(pe.PartialVal.unknown(raise_to_shaped(get_aval(arg)))
                   for arg in flat_args)
  fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun_flat, in_pvals)
  # TODO: Don't warn if consts contain JVP tracers?
  if consts:
    warnings.warn("Values that an @invertible function closes over will not "
                  "have their gradients computed correctly (their uses inside "
                  "this function will be ignored)!")
  # TODO: This requires the body to be jittable, but this shouldn't be
  # necessary. Is there a way to trace a jaxpr while running it?
  flat_outs = core.eval_jaxpr(jaxpr, consts, *flat_args)
  return tree_unflatten(out_tree(), flat_outs), (flat_args, flat_outs, consts,
                                                 DontFlatten((jaxpr, in_tree)))
def __call__(self, *args, **kwargs):
  assert not kwargs
  args_flat, in_tree = tree_flatten(args)
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
  debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  assert not len(consts)
  closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
  out_flat = custom_vmap_p.bind(*consts, *args_flat, call=closed_call,
                                rule=self.vmap_rule, in_tree=in_tree)
  return tree_unflatten(out_tree(), out_flat)
def _lu_abstract_eval(operand):
  operand = raise_to_shaped(operand)
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to LU decomposition must have ndims >= 2")

    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    pivot = operand.update(shape=batch_dims + (min(m, n),), dtype=jnp.int32)
    perm = operand.update(shape=batch_dims + (m,), dtype=jnp.int32)
  else:
    pivot = operand
    perm = operand
  return operand, pivot, perm
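# Illustrative sketch only: the concrete output shapes this rule predicts,
# checked against lax.linalg.lu for a single (5, 3) matrix.
import jax.numpy as jnp
from jax.lax.linalg import lu

out, pivots, perm = lu(jnp.ones((5, 3)))
assert out.shape == (5, 3)
assert pivots.shape == (3,)  # min(m, n)
assert perm.shape == (5,)    # m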
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                            reduce_axes, False)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  if not config.jax_experimental_name_stack:
    params = dict(params, name=wrap_name(params['name'], 'transpose'))
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    params = update_params(params, map(is_undefined_primal, args),
                           [type(x) is not Zero for x in ct])
  if config.jax_dynamic_shapes:
    in_type = [(core.raise_to_shaped(core.get_aval(x)), True)
               for x in all_args]
    fun = lu.annotate(fun, tuple(in_type))
  out_flat = primitive.bind(fun, *all_args, **params)
  return tree_unflatten(out_tree(), out_flat)
def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size):
  pivots = raise_to_shaped(pivots)
  if isinstance(pivots, ShapedArray):
    if pivots.ndim < 1 or pivots.dtype != np.dtype(np.int32):
      raise ValueError(
          'Argument to lu_pivots_to_permutation must have rank >= 1 and dtype '
          'int32. Got shape={} and dtype={}'.format(pivots.shape, pivots.dtype))

    if permutation_size < pivots.shape[-1]:
      raise ValueError(
          'Output permutation size {} has to be at least as large as the '
          'trailing dimension of the pivots. Got pivots shape {}'.format(
              permutation_size, pivots.shape))
    batch_dims = pivots.shape[:-1]
    permutations = pivots.update(shape=batch_dims + (permutation_size,))
  else:
    permutations = pivots
  return permutations
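# Illustrative sketch only: batch dims are preserved and the trailing
# dimension becomes permutation_size.
import jax.numpy as jnp
from jax.lax.linalg import lu_pivots_to_permutation

pivots = jnp.zeros((4, 3), jnp.int32)
perm = lu_pivots_to_permutation(pivots, 5)
assert perm.shape == (4, 5)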
def __call__(self, residual_arg, linear_arg):
  res_arg, lin_arg = residual_arg, linear_arg
  _, res_tree = tree_flatten(res_arg)
  _, lin_tree = tree_flatten(lin_arg)
  args_flat, in_tree = tree_flatten((res_arg, lin_arg))
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
  debug = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  assert not len(consts)
  closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
  out_flat = custom_transpose_p.bind(*consts, *args_flat, call=closed_call,
                                     rule=self.transpose, lin_tree=lin_tree,
                                     res_tree=res_tree, out_tree=out_tree())
  return tree_unflatten(out_tree(), out_flat)
def abstractify(x):
  return core.raise_to_shaped(core.get_aval(x))
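# Illustrative sketch only (assuming abstractify is in scope): maps a
# concrete value to the ShapedArray describing it.
import jax.numpy as jnp

aval = abstractify(jnp.arange(6).reshape(2, 3))
assert aval.shape == (2, 3)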