def linear_jvp(primitive, primals, tangents, **params):
    """JVP rule that reuses ``primitive`` itself as its tangent map.

    The primal output is ``primitive`` bound to ``primals``.  The output
    tangent is the same primitive bound to the input tangents, after
    materializing any symbolic zeros.  This is valid when the primitive is
    linear in its inputs.

    Special case: if every input tangent is a symbolic ``Zero``, the
    primitive is bound only once.  The output tangent is then a symbolic
    ``Zero`` built from the primal output.

    Returns:
        A ``(primal_out, tangent_out)`` pair.
    """
    primal_out = primitive.bind(*primals, **params)
    if all(type(tangent) is Zero for tangent in tangents):
        return primal_out, Zero.from_value(primal_out)
    # Some tangents carry data, so symbolic zeros must become concrete values
    # before they can be passed to bind.
    dense_tangents = [instantiate_zeros(tangent) for tangent in tangents]
    return primal_out, primitive.bind(*dense_tangents, **params)
def standard_jvp2(jvprules, primitive, primals, tangents, **params):
    """JVP rule assembled from per-input tangent rules.

    ``jvprules`` is paired position-wise with ``tangents``.  Each rule is
    called as ``rule(tangent, primal_out, *primals, **params)``.  An input is
    skipped when its rule is ``None`` or its tangent is a symbolic ``Zero``.
    The contributions are summed with ``add_tangents``, starting from a
    symbolic ``Zero`` built from the primal output.  If nothing contributes,
    that symbolic zero is the result.

    Returns:
        A ``(primal_out, tangent_out)`` pair.
    """
    val_out = primitive.bind(*primals, **params)
    contributions = []
    for rule, tangent in zip(jvprules, tangents):
        if rule is None or type(tangent) is Zero:
            continue
        contributions.append(rule(tangent, val_out, *primals, **params))
    # All rule calls run before the zero seed is created, so the call order
    # is unchanged.
    total = Zero.from_value(val_out)
    for contribution in contributions:
        total = add_tangents(total, contribution)
    return val_out, total
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
    """Adapt a JVP function to a flat signature that omits zero tangents.

    The flat input is ``len(nonzeros)`` primals followed by one tangent for
    each ``True`` entry of ``nonzeros``.  Positions whose flag is false get a
    symbolic ``Zero`` built from the matching primal.

    Generator protocol (presumably driven by a linear_util transformation
    wrapper; the decorator is not visible here -- confirm):

    * First yield: ``((primals, tangents), {})``.
    * Receives: ``(primals_out, tangents_out)``.
    * Second yield: the flat list ``primals_out`` plus the non-``Zero``
      output tangents, together with a boolean list that marks which output
      tangents were non-zero.
    """
    num_primals = len(nonzeros)
    primals = list(primals_and_nztangents[:num_primals])
    supplied_tangents = iter(primals_and_nztangents[num_primals:])

    tangents = []
    for primal, is_nonzero in zip(primals, nonzeros):
        if is_nonzero:
            tangents.append(next(supplied_tangents))
        else:
            tangents.append(Zero.from_value(primal))

    primals_out, tangents_out = yield (primals, tangents), {}

    out_nonzeros = [type(t) is not Zero for t in tangents_out]
    kept_tangents = [t for t in tangents_out if type(t) is not Zero]
    yield list(primals_out) + kept_tangents, out_nonzeros
def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes,
                     primals_tangents_in, cotangents_in):
    """Transpose a rematerialized linear computation.

    The inputs are split into known primals and undefined (linear) inputs.
    Residuals are recomputed by evaluating ``primal_jaxpr`` on the known
    primals.  ``backward_pass`` then runs over ``tangent_jaxpr`` with the
    residuals followed by the linear inputs as arguments.

    Returns:
        A list aligned with ``primals_tangents_in``.  Each undefined-primal
        position holds its cotangent from ``backward_pass``.  Each known
        primal position holds ``Zero.from_value(x)``.
    """
    primals_in = []
    tangents_in = []
    for x in primals_tangents_in:
        if is_undefined_primal(x):
            tangents_in.append(x)
        else:
            primals_in.append(x)

    residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)
    all_cotangents = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, False, (),
                                   (*residuals, *tangents_in), cotangents_in)
    # The leading entries correspond to the residual arguments; only the
    # trailing ones belong to the linear inputs.
    linear_cotangents = iter(all_cotangents[len(residuals):])

    outs = []
    for x in primals_tangents_in:
        if is_undefined_primal(x):
            outs.append(next(linear_cotangents))
        else:
            outs.append(Zero.from_value(x))
    # Every linear-input cotangent must have been consumed.
    assert next(linear_cotangents, None) is None
    return outs
def jvpfun(instantiate, transform_stack, primals, tangents):
    """Run a JVP trace over the wrapped function (generator protocol).

    Any tangent whose dtype is ``float0`` and that is not already a symbolic
    ``Zero`` is replaced by ``Zero.from_value(t)``.  A new ``JVPTrace`` main
    is then opened and ``(main, primals, tangents)`` is yielded to the
    wrapped computation, which sends back ``(out_primals, out_tangents)``.
    When ``transform_stack`` is truthy, the name-stack context
    ``transform_name_stack('jvp')`` is also active while the trace runs.
    NOTE(review): the driving decorator (presumably a linear_util
    transformation) is not visible here.

    Args:
        instantiate: a bool applied to all outputs, or a per-output sequence
            of bools.  Where true, a symbolic-zero output tangent is
            materialized with ``instantiate_zeros``.
        transform_stack: whether to push the ``'jvp'`` transform name.
        primals: primal inputs.
        tangents: tangent inputs.

    Yields:
        First ``((main, primals, tangents), {})``, then
        ``(out_primals, out_tangents)``.
    """
    # Compare dtypes with ``==``, not ``is``: equal NumPy dtypes (especially
    # structured ones such as float0) need not be the same object, so an
    # identity check could miss a float0 tangent.
    tangents = [Zero.from_value(t)
                if not isinstance(t, Zero) and dtype(t) == float0 else t
                for t in tangents]
    ctx = (source_info_util.transform_name_stack('jvp') if transform_stack
           else contextlib.nullcontext())
    with core.new_main(JVPTrace) as main, ctx:
        out_primals, out_tangents = yield (main, primals, tangents), {}
        # Drop the local reference before the context exits; presumably so no
        # reference to the main trace outlives it (leak checks) -- confirm.
        del main
    if type(instantiate) is bool:
        instantiate = [instantiate] * len(out_tangents)
    out_tangents = [instantiate_zeros(t) if inst else t
                    for t, inst in zip(out_tangents, instantiate)]
    yield out_primals, out_tangents
def jvpfun(instantiate, primals, tangents):
    """Run a JVP trace over the wrapped function (generator protocol).

    NOTE(review): this definition rebinds the module-level name ``jvpfun``.
    Any earlier ``jvpfun`` in the same module (for example one taking
    ``transform_stack``) is shadowed -- confirm this is intended.

    Any tangent whose dtype is ``float0`` and that is not already a symbolic
    ``Zero`` is replaced by ``Zero.from_value(t)``.  A new ``JVPTrace`` main
    is then opened and ``(main, primals, tangents)`` is yielded, after which
    ``(out_primals, out_tangents)`` is received.

    Args:
        instantiate: a bool applied to all outputs, or a per-output sequence
            of bools.  Where true, a symbolic-zero output tangent is
            materialized with ``instantiate_zeros``.
        primals: primal inputs.
        tangents: tangent inputs.

    Yields:
        First ``((main, primals, tangents), {})``, then
        ``(out_primals, out_tangents)``.
    """
    # ``==`` rather than ``is``: equal NumPy dtypes need not be identical
    # objects, so an identity check could miss a float0 tangent.
    tangents = [
        Zero.from_value(t) if not isinstance(t, Zero) and dtype(t) == float0
        else t
        for t in tangents
    ]
    with core.new_main(JVPTrace) as main:
        out_primals, out_tangents = yield (main, primals, tangents), {}
        # Release the local reference to the main trace before the context
        # exits (presumably for trace-leak checking -- confirm).
        del main
    if type(instantiate) is bool:
        instantiate = [instantiate] * len(out_tangents)
    out_tangents = [
        instantiate_zeros(t) if inst else t
        for t, inst in zip(out_tangents, instantiate)
    ]
    yield out_primals, out_tangents
def zero_jvp(primitive, primals, tangents, **params):
    """JVP rule for a primitive whose output tangent is always zero.

    ``primitive`` is bound to ``primals``.  The input ``tangents`` are
    ignored, and the output tangent is a symbolic ``Zero`` built from the
    primal output.

    Returns:
        A ``(primal_out, tangent_out)`` pair.
    """
    primal_out = primitive.bind(*primals, **params)
    tangent_out = Zero.from_value(primal_out)
    return primal_out, tangent_out
return Zero if out is Zero else (out, None) else: out = rhs_rule(cotangent, x, **kwargs) return Zero if out is Zero else (None, out) def defjvp_zero(primitive): assert isinstance(primitive, Primitive) primitive_jvps[primitive] = partial(zero_jvp, primitive) def zero_jvp(primitive, primals, tangents, **params): r = primitive.bind(*primals, **params) return r, Zero.from_value(r) deflinear2(zeros_like_p, lambda t, _: [Zero.from_value(t)]) deflinear2(add_jaxvals_p, lambda t, *args: (t, t)) def instantiate_zeros(tangent): if type(tangent) is Zero: return zeros_like_aval(tangent.aval) else: return tangent # This function seems similar to instantiate_zeros, but it is sometimes used # to instantiate zero abstract units with a different aval def instantiate_zeros_aval(aval, tangent): if type(tangent) is Zero: assert type(tangent.aval) is core.AbstractUnit or tangent.aval == aval return zeros_like_aval(aval) else: