def standard_jvp2(jvprules, primitive, primals, tangents, **params):
    """Evaluate ``primitive`` and accumulate its output tangent.

    The primitive is bound to ``primals`` first, because every rule receives
    the primal output as its second argument.

    Args:
      jvprules: one rule per primal; an entry may be ``None`` to indicate that
        the corresponding argument contributes no tangent.
      primitive: object exposing ``bind(*primals, **params)``.
      primals: positional inputs forwarded to ``bind`` and to each rule.
      tangents: one tangent per primal; entries whose type is exactly ``Zero``
        are skipped.
      **params: keyword parameters forwarded to ``bind`` and to each rule.

    Returns:
      ``(val_out, tangent_out)`` where ``tangent_out`` is the fold of all
      contributions with ``add_tangents``, seeded by
      ``Zero.from_value(val_out)``.
    """
    val_out = primitive.bind(*primals, **params)

    # Collect every contribution before folding, so all rules run first.
    contributions = []
    for rule, tangent in zip(jvprules, tangents):
        if rule is None or type(tangent) is Zero:
            continue
        contributions.append(rule(tangent, val_out, *primals, **params))

    total = Zero.from_value(val_out)
    for term in contributions:
        total = add_tangents(total, term)
    return val_out, total
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
    """Two-step generator that expands and re-compresses symbolic-zero tangents.

    The positional arguments hold ``len(nonzeros)`` primals followed by only
    those tangents flagged as nonzero.

    First yield: ``((primals, tangents), {})`` where ``tangents`` has one entry
    per primal, using ``Zero.from_value(primal)`` wherever the flag is false.
    The generator must then be sent ``(primals_out, tangents_out)``.

    Second yield: ``(list(primals_out) + <non-Zero tangents>, out_nonzeros)``
    where ``out_nonzeros[i]`` tells whether ``tangents_out[i]`` is not a
    ``Zero``.
    """
    split = len(nonzeros)
    primals = list(primals_and_nztangents[:split])
    dense_tangents = iter(primals_and_nztangents[split:])

    tangents = []
    for primal, has_tangent in zip(primals, nonzeros):
        if has_tangent:
            tangents.append(next(dense_tangents))
        else:
            tangents.append(Zero.from_value(primal))

    primals_out, tangents_out = yield (primals, tangents), {}

    # Two separate passes over tangents_out, exactly as many as before.
    out_nonzeros = [type(t) is not Zero for t in tangents_out]
    kept_tangents = [t for t, keep in zip(tangents_out, out_nonzeros) if keep]
    yield list(primals_out) + kept_tangents, out_nonzeros
def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes,
                     primals_tangents_in, cotangents_in):
    """Transpose the tangent jaxpr after re-evaluating the primal jaxpr.

    Inputs for which ``is_undefined_primal`` is false are fed to
    ``primal_jaxpr`` to obtain residuals; the remaining inputs are the linear
    (undefined) ones. ``backward_pass`` is then run on ``tangent_jaxpr.jaxpr``
    with ``(*residuals, *linear_inputs)`` and the leading ``len(residuals)``
    results are dropped.

    Returns:
      A list aligned with ``primals_tangents_in``: the next cotangent for each
      undefined-primal slot, ``Zero.from_value(x)`` for every other slot.
    """
    known_inputs = [x for x in primals_tangents_in if not is_undefined_primal(x)]
    linear_inputs = [x for x in primals_tangents_in if is_undefined_primal(x)]

    residuals = core.jaxpr_as_fun(primal_jaxpr)(*known_inputs)
    transposed = backward_pass(
        tangent_jaxpr.jaxpr, reduce_axes, False, (),
        (*residuals, *linear_inputs), cotangents_in)

    # Skip the cotangents that correspond to the residual inputs.
    remaining = iter(transposed[len(residuals):])
    outs = []
    for x in primals_tangents_in:
        if is_undefined_primal(x):
            outs.append(next(remaining))
        else:
            outs.append(Zero.from_value(x))

    # Every cotangent must have been consumed by an undefined-primal slot.
    assert next(remaining, None) is None
    return outs
def jvpfun(instantiate, transform_stack, primals, tangents):
    """Two-step generator that runs a JVP under a fresh ``JVPTrace`` main.

    Args:
      instantiate: a bool, or a sequence of per-output flags; a bool is
        broadcast to the number of output tangents.
      transform_stack: when truthy, the traced call runs inside
        ``source_info_util.transform_name_stack('jvp')``; otherwise inside a
        null context.
      primals: primal inputs, passed through in the first yield.
      tangents: input tangents; any non-``Zero`` tangent whose ``dtype`` is
        ``float0`` is replaced by ``Zero.from_value(t)``.

    First yield: ``((main, primals, tangents), {})``; the generator must then
    be sent ``(out_primals, out_tangents)``.
    Second yield: ``(out_primals, out_tangents)`` with ``instantiate_zeros``
    applied to each tangent whose flag is truthy.
    """
    cleaned = []
    for t in tangents:
        if isinstance(t, Zero) or dtype(t) is not float0:
            cleaned.append(t)
        else:
            cleaned.append(Zero.from_value(t))
    tangents = cleaned

    if transform_stack:
        ctx = source_info_util.transform_name_stack('jvp')
    else:
        ctx = contextlib.nullcontext()

    with core.new_main(JVPTrace) as main, ctx:
        out_primals, out_tangents = yield (main, primals, tangents), {}
        del main  # drop the local reference before the context exits

    if type(instantiate) is bool:
        instantiate = [instantiate] * len(out_tangents)

    materialized = []
    for tangent, wanted in zip(out_tangents, instantiate):
        materialized.append(instantiate_zeros(tangent) if wanted else tangent)
    yield out_primals, materialized
def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
    """Variant of ``matchaxis`` that also accepts symbolic ``Zero`` values.

    Non-``Zero`` inputs are delegated unchanged to ``matchaxis``. For a
    ``Zero`` only the aval is rewritten:

    * ``src == dst``: returned as is.
    * both ``src`` and ``dst`` of type ``int``: aval mapped along ``src`` and
      then unmapped along ``dst``.
    * ``src`` is ``not_mapped`` and ``dst`` is mapped: aval unmapped along
      ``dst``.
    * ``dst`` is ``not_mapped`` and ``sum_match`` is truthy: aval mapped along
      ``src``.

    Raises:
      ValueError: with ``(axis_name, x, src, dst)`` for any other combination.
    """
    # TODO(mattjj): dedup with matchaxis
    if not isinstance(x, Zero):
        return matchaxis(axis_name, sz, src, dst, x, sum_match=sum_match)

    if src == dst:
        return x
    if type(src) == type(dst) == int:
        inner_aval = core.mapped_aval(sz, src, x.aval)
        return Zero(core.unmapped_aval(sz, name, dst, inner_aval))
    if src is not_mapped and dst is not not_mapped:
        return Zero(core.unmapped_aval(sz, name, dst, x.aval))
    if dst is not_mapped and sum_match:
        return Zero(core.mapped_aval(sz, src, x.aval))
    raise ValueError((axis_name, x, src, dst))
def jvpfun(instantiate, primals, tangents):
    """Two-step generator that runs a JVP under a fresh ``JVPTrace`` main.

    Args:
      instantiate: a bool, or a sequence of per-output flags; a bool is
        broadcast to the number of output tangents.
      primals: primal inputs, passed through in the first yield.
      tangents: input tangents; any non-``Zero`` tangent whose ``dtype`` is
        ``float0`` is replaced by ``Zero.from_value(t)``.

    First yield: ``((main, primals, tangents), {})``; the generator must then
    be sent ``(out_primals, out_tangents)``.
    Second yield: ``(out_primals, out_tangents)`` with ``instantiate_zeros``
    applied to each tangent whose flag is truthy.
    """
    cleaned = []
    for t in tangents:
        if isinstance(t, Zero) or dtype(t) is not float0:
            cleaned.append(t)
        else:
            cleaned.append(Zero.from_value(t))
    tangents = cleaned

    with core.new_main(JVPTrace) as main:
        out_primals, out_tangents = yield (main, primals, tangents), {}
        del main  # drop the local reference before the context exits

    if type(instantiate) is bool:
        instantiate = [instantiate] * len(out_tangents)

    materialized = []
    for tangent, wanted in zip(out_tangents, instantiate):
        materialized.append(instantiate_zeros(tangent) if wanted else tangent)
    yield out_primals, materialized
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
    """Two-step generator adapting flat arguments to a pytree-based VJP rule.

    Args:
      in_tree: pytree structure of the primal function's arguments.
      in_avals: one abstract value per flattened primal input; each must
        provide ``at_least_vspace()``.
      out_trees: zero-argument callable returning ``(out_tree, res_tree)``.
      *args: the flattened residuals (``res_tree.num_leaves`` of them)
        followed by the flattened output cotangents.

    First yield: ``((py_res, py_cts_out), {})``, the unflattened residuals and
    output cotangents. The generator must then be sent the rule's result
    ``py_cts_in``.
    Second yield: a flat list with one entry per ``in_avals`` element.

    Raises:
      TypeError: if ``py_cts_in`` is not a tuple, or if mapping over it
        together with the ``in_tree``-shaped dummy raises ``ValueError``.
    """
    out_tree, res_tree = out_trees()
    res, cts_out = split_list(args, [res_tree.num_leaves])
    py_res = tree_unflatten(res_tree, res)
    py_cts_out = tree_unflatten(out_tree, cts_out)
    py_cts_in = yield (py_res, py_cts_out), {}
    # Each top-level None in py_cts_in (an argument for which the rule
    # produces no cotangent) is swapped for the `zero` sentinel below. The
    # sentinel is then repeated once per leaf of the matching subtree of
    # in_tree, and finally turned into a symbolic Zero in the returned list.
    zero = object()  # non-pytree sentinel to replace Nones in py_cts_in
    # Placeholder tree with the structure of in_tree; only its shape is used.
    dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
    cts_in_flat = []
    # Repeat x once for every leaf in the corresponding dummy subtree d.
    append_cts = lambda x, d: cts_in_flat.extend([x] * len(tree_flatten(d)[0]))
    try:
        # A non-tuple result is routed through the same error path as a
        # structure mismatch.
        if not isinstance(py_cts_in, tuple):
            raise ValueError
        tree_multimap(append_cts,
                      tuple(zero if ct is None else ct for ct in py_cts_in), dummy)
    except ValueError:
        _, in_tree2 = tree_flatten(py_cts_in)
        msg = (
            "Custom VJP rule must produce an output with the same container "
            "(pytree) structure as the args tuple of the primal function, "
            "and in particular must produce a tuple of length equal to the "
            "number of arguments to the primal function, but got VJP output "
            "structure {} for primal input structure {}.")
        # `from None` hides the internal ValueError from the user's traceback.
        raise TypeError(msg.format(in_tree2, in_tree)) from None
    # Emit a symbolic Zero for every sentinel entry, and for every input whose
    # aval differs from its own at_least_vspace() (i.e. float0s).
    # TODO(mattjj): change this to check if tangent type represents 0dim vspace
    yield [
        Zero(a.at_least_vspace()) if ct is zero or a != a.at_least_vspace() else ct
        for a, ct in zip(in_avals, cts_in_flat)
    ]
def unmap_zero(zero, in_axis):
    """Return ``zero`` with its aval unmapped along ``in_axis``.

    When ``in_axis`` is ``None`` the value is returned untouched. Otherwise a
    new ``Zero`` is built from ``core.unmapped_aval`` using the ``'axis_size'``
    and ``'axis_name'`` entries of ``params``.

    NOTE(review): ``params`` is a free variable here, presumably supplied by
    an enclosing scope that is not visible in this chunk; confirm.
    """
    if in_axis is None:
        return zero
    unmapped = core.unmapped_aval(
        params['axis_size'], params['axis_name'], in_axis, zero.aval)
    return Zero(unmapped)
def zero_jvp(primitive, primals, tangents, **params):
    """JVP rule whose output tangent is always a symbolic zero.

    Binds ``primitive`` to ``primals`` with ``params`` and pairs the result
    with ``Zero.from_value(result)``. ``tangents`` is accepted but never read.
    """
    primal_out = primitive.bind(*primals, **params)
    tangent_out = Zero.from_value(primal_out)
    return primal_out, tangent_out
def lift(self, val):
    """Wrap ``val`` in a ``JVPTracer`` owned by ``self`` with a zero tangent.

    The tangent is a symbolic ``Zero`` whose aval is
    ``get_aval(val).at_least_vspace()``.
    """
    tangent_aval = get_aval(val).at_least_vspace()
    return JVPTracer(self, val, Zero(tangent_aval))
def read_cotangent(v):
    """Remove and return the cotangent stored for ``v`` in ``ct_env``.

    If nothing is stored, a symbolic ``Zero(v.aval)`` is returned instead. The
    fallback is constructed up front, whether or not it ends up being used.
    """
    fallback = Zero(v.aval)
    return ct_env.pop(v, fallback)
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
                  consts, primals_in, cotangents_in):
    """Propagate cotangents backwards through the equations of ``jaxpr``.

    Args:
      jaxpr: the jaxpr to transpose; its ``constvars``, ``invars``,
        ``outvars`` and ``eqns`` are read.
      reduce_axes: axis names; a cotangent is ``psum``-reduced over each one
        that appears in the cotangent's ``named_shape`` but not in the target
        variable's.
      transform_stack: when truthy, the pass runs inside
        ``source_info_util.transform_name_stack('transpose')``.
      consts: values for ``jaxpr.constvars``.
      primals_in: values for ``jaxpr.invars``; entries for which
        ``is_undefined_primal`` is true are not recorded.
      cotangents_in: cotangents for ``jaxpr.outvars``.

    Returns:
      One cotangent per ``jaxpr.invars`` entry, with ``Zero(v.aval)`` where
      nothing was accumulated.

    NOTE(review): ``map`` is called for its side effects here, so it is
    presumably an eager, list-returning variant rather than the lazy builtin;
    confirm against the file's imports.
    """
    # Shortcut: with only symbolic-zero cotangents there is nothing to propagate.
    if all(type(ct) is Zero for ct in cotangents_in):
        return map(lambda v: Zero(v.aval), jaxpr.invars)

    def write_cotangent(prim, v, ct):
        """Accumulate ``ct`` into ``ct_env[v]``; ``prim`` only labels assertions."""
        # assert v not in primal_env
        assert ct is not Zero, (prim, v.aval)  # check for an old harmless type error
        # Nothing is stored for missing cotangents or for literal variables.
        if ct is None or type(v) is Literal:
            return
        # Symbolic zeros are dropped rather than stored.
        if type(ct) is Zero:
            # FIXME: This triggers a lot of failures!
            # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
            return
        # Named axes carried by the cotangent but absent from the variable.
        axes_to_reduce = tuple(axis_name for axis_name in reduce_axes
                               if axis_name in core.get_aval(ct).named_shape
                               and axis_name not in v.aval.named_shape)
        if axes_to_reduce:
            ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
        ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
        if config.jax_enable_checks:
            # Optional check: joining the variable's aval with the accumulated
            # cotangent's aval must give back the variable's aval, ignoring
            # weak types and named shapes.
            ct_aval = core.get_aval(ct_env[v])
            joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type().strip_named_shape()
            assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval)

    def read_cotangent(v):
        """Pop the cotangent for ``v``, defaulting to ``Zero(v.aval)``."""
        return ct_env.pop(v, Zero(v.aval))

    def read_primal(v):
        """Return a literal's value, a recorded primal, or ``UndefinedPrimal``."""
        if type(v) is Literal:
            return v.val
        else:
            return primal_env.get(v, UndefinedPrimal(v.aval))

    def write_primal(v, val):
        """Record ``val`` for ``v`` unless it is an undefined primal."""
        if not is_undefined_primal(val):
            primal_env[v] = val

    primal_env: Dict[Any, Any] = {}
    write_primal(core.unitvar, core.unit)
    map(write_primal, jaxpr.constvars, consts)
    # FIXME: invars can contain both primal and tangent values, and this line
    #        forces primals_in to contain UndefinedPrimals for tangent values!
    map(write_primal, jaxpr.invars, primals_in)

    ct_env: Dict[Any, Any] = {}
    ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
           else contextlib.nullcontext())
    with ctx:
        # Seed the environment with the cotangents of the jaxpr outputs.
        map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
        # Equations are visited in reverse order.
        for eqn in jaxpr.eqns[::-1]:
            # FIXME: Some invars correspond to tangents
            invals = map(read_primal, eqn.invars)
            if eqn.primitive.multiple_results:
                cts_in = map(read_cotangent, eqn.outvars)
            else:
                # Exactly one outvar is expected; unpack its cotangent.
                cts_in, = map(read_cotangent, eqn.outvars)
            name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
            # NOTE(review): the source was whitespace-collapsed; the extent of
            # this `with` block (through the final write_cotangent call) is
            # reconstructed; confirm against the original file.
            with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack):
                if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
                    # Call/map primitives: 'call_jaxpr' is split off from a
                    # copy of the params and passed separately.
                    cts_in_avals = [v.aval for v in eqn.outvars]
                    params = dict(eqn.params)
                    call_jaxpr = params.pop('call_jaxpr')
                    cts_out = get_primitive_transpose(eqn.primitive)(
                        params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes)
                elif eqn.primitive in reducing_transposes:
                    # These rules additionally receive reduce_axes.
                    cts_out = reducing_transposes[eqn.primitive](
                        reduce_axes, cts_in, *invals, **eqn.params)
                else:
                    cts_out = get_primitive_transpose(eqn.primitive)(
                        cts_in, *invals, **eqn.params)
                # A rule may return the Zero class itself; expand it to one
                # symbolic zero per input variable.
                cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
                # FIXME: Some invars correspond to primals!
                map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)

    cotangents_out = map(read_cotangent, jaxpr.invars)
    return cotangents_out
def recast_to_float0(primal, tangent):
    """Replace ``tangent`` by a symbolic zero when the primal's tangent dtype is float0.

    The tangent dtype is obtained from
    ``core.primal_dtype_to_tangent_dtype(dtype(primal))``. When it equals
    ``float0`` a ``Zero`` with aval ``get_aval(primal).at_least_vspace()`` is
    returned; otherwise ``tangent`` is returned unchanged.
    """
    tangent_dtype = core.primal_dtype_to_tangent_dtype(dtype(primal))
    if tangent_dtype == float0:
        zero_aval = get_aval(primal).at_least_vspace()
        return Zero(zero_aval)
    return tangent
return Zero if out is Zero else (out, None) else: out = rhs_rule(cotangent, x, **kwargs) return Zero if out is Zero else (None, out) def defjvp_zero(primitive): assert isinstance(primitive, Primitive) primitive_jvps[primitive] = partial(zero_jvp, primitive) def zero_jvp(primitive, primals, tangents, **params): r = primitive.bind(*primals, **params) return r, Zero.from_value(r) deflinear2(zeros_like_p, lambda t, _: [Zero.from_value(t)]) deflinear2(add_jaxvals_p, lambda t, *args: (t, t)) def instantiate_zeros(tangent): if type(tangent) is Zero: return zeros_like_aval(tangent.aval) else: return tangent # This function seems similar to instantiate_zeros, but it is sometimes used # to instantiate zero abstract units with a different aval def instantiate_zeros_aval(aval, tangent): if type(tangent) is Zero: assert type(tangent.aval) is core.AbstractUnit or tangent.aval == aval return zeros_like_aval(aval) else:
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
    """Propagate cotangents backwards through the equations of ``jaxpr``.

    Args:
      jaxpr: the jaxpr to transpose; its ``constvars``, ``invars``,
        ``outvars`` and ``eqns`` are read.
      consts: values for ``jaxpr.constvars``.
      primals_in: values for ``jaxpr.invars``; entries for which
        ``is_undefined_primal`` is true are not recorded.
      cotangents_in: cotangents for ``jaxpr.outvars``.

    Returns:
      One cotangent per ``jaxpr.invars`` entry, with ``Zero(v.aval)`` where
      nothing was accumulated.

    NOTE(review): ``map`` is called for its side effects here, so it is
    presumably an eager, list-returning variant rather than the lazy builtin;
    confirm against the file's imports.
    """
    # Shortcut: with only symbolic-zero cotangents there is nothing to propagate.
    if all(type(ct) is Zero for ct in cotangents_in):
        return map(lambda v: Zero(v.aval), jaxpr.invars)

    def write_cotangent(prim, v, ct):
        """Accumulate ``ct`` into ``ct_env[v]``; ``prim`` only labels assertions."""
        # assert v not in primal_env
        assert ct is not Zero, (prim, v.aval)  # check for an old harmless type error
        # Nothing is stored for missing cotangents or for literal variables.
        if ct is None or type(v) is Literal:
            return
        # Symbolic zeros are dropped rather than stored.
        if type(ct) is Zero:
            # FIXME: This triggers a lot of failures!
            # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
            return
        ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
        if config.jax_enable_checks:
            # Optional check: joining the variable's aval with the accumulated
            # cotangent's aval must give back the variable's aval, ignoring
            # weak types and named shapes.
            ct_aval = core.get_aval(ct_env[v])
            joined_aval = core.lattice_join(
                v.aval, ct_aval).strip_weak_type().strip_named_shape()
            assert v.aval.strip_weak_type().strip_named_shape(
            ) == joined_aval, (prim, v.aval, ct_aval)

    def read_cotangent(v):
        """Pop the cotangent for ``v``, defaulting to ``Zero(v.aval)``."""
        return ct_env.pop(v, Zero(v.aval))

    def read_primal(v):
        """Return a literal's value, a recorded primal, or ``UndefinedPrimal``."""
        if type(v) is Literal:
            return v.val
        else:
            return primal_env.get(v, UndefinedPrimal(v.aval))

    def write_primal(v, val):
        """Record ``val`` for ``v`` unless it is an undefined primal."""
        if not is_undefined_primal(val):
            primal_env[v] = val

    primal_env: Dict[Any, Any] = {}
    write_primal(core.unitvar, core.unit)
    map(write_primal, jaxpr.constvars, consts)
    # FIXME: invars can contain both primal and tangent values, and this line
    #        forces primals_in to contain UndefinedPrimals for tangent values!
    map(write_primal, jaxpr.invars, primals_in)

    ct_env: Dict[Any, Any] = {}
    # Seed the environment with the cotangents of the jaxpr outputs.
    map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
    # Equations are visited in reverse order.
    for eqn in jaxpr.eqns[::-1]:
        # FIXME: Some invars correspond to tangents
        invals = map(read_primal, eqn.invars)
        if eqn.primitive.multiple_results:
            cts_in = map(read_cotangent, eqn.outvars)
        else:
            # Exactly one outvar is expected; unpack its cotangent.
            cts_in, = map(read_cotangent, eqn.outvars)
        # NOTE(review): the source was whitespace-collapsed; the extent of
        # this `with` block (ending after the transpose-rule call) is
        # reconstructed; confirm against the original file.
        with source_info_util.user_context(eqn.source_info):
            if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
                # Call/map primitives: the sub-jaxpr is separated from the
                # remaining params by core.extract_call_jaxpr.
                cts_in_avals = [v.aval for v in eqn.outvars]
                call_jaxpr, params = core.extract_call_jaxpr(
                    eqn.primitive, eqn.params)
                cts_out = get_primitive_transpose(eqn.primitive)(params, call_jaxpr,
                                                                 invals, cts_in,
                                                                 cts_in_avals)
            else:
                cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
                                                                 **eqn.params)
        # A rule may return the Zero class itself; expand it to one symbolic
        # zero per input variable.
        cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
        # FIXME: Some invars correspond to primals!
        map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)

    cotangents_out = map(read_cotangent, jaxpr.invars)
    return cotangents_out