return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts

    batched_outs = custom_jvp_call_jaxpr_p.bind(
        *args,
        fun_jaxpr=batched_fun_jaxpr,
        jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk,
        num_consts=num_consts)
    out_dims = out_dims2[0] if out_dims2 else out_dims1
    return batched_outs, out_dims


# Batching rule for custom_jvp_call_jaxpr_p.
batching.axis_primitive_batchers[custom_jvp_call_jaxpr_p] = (
    _custom_jvp_call_jaxpr_vmap)

# XLA lowering: _custom_jvp_call_jaxpr_impl wrapped by
# xla.lower_fun_initial_style.
xla.initial_style_translations[custom_jvp_call_jaxpr_p] = (
    xla.lower_fun_initial_style(_custom_jvp_call_jaxpr_impl))


def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr,
                                     jvp_jaxpr_thunk, num_consts):
    """Transpose a ``custom_jvp_call_jaxpr`` by transposing its primal jaxpr.

    When a (multi)linear function is defined with a custom JVP,
    ``custom_jvp_call_jaxpr`` can show up inside jaxprs that get transposed.
    At that point the function has already been linearized, so the custom JVP
    rule is irrelevant. Only ``fun_jaxpr`` is transposed.

    Args:
      reduce_axes: forwarded unchanged to ``ad.backward_pass``.
      cts: output cotangents, forwarded to ``ad.backward_pass``.
      *args: the primitive's operands, forwarded as the primal inputs.
      fun_jaxpr: closed jaxpr whose ``.jaxpr`` and ``.consts`` are transposed.
      jvp_jaxpr_thunk: unused here (the JVP rule is dropped).
      num_consts: unused here.

    Returns:
      Whatever ``ad.backward_pass`` returns for these inputs.
    """
    # The JVP rule and const count are not needed after linearization.
    del jvp_jaxpr_thunk, num_consts
    body_jaxpr, body_consts = fun_jaxpr.jaxpr, fun_jaxpr.consts
    return ad.backward_pass(body_jaxpr, reduce_axes, body_consts, args, cts)


# Transpose rule (reduce_axes-aware table) for custom_jvp_call_jaxpr_p.
ad.reducing_transposes[custom_jvp_call_jaxpr_p] = (
    _custom_jvp_call_jaxpr_transpose)
Example #2
0
        return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts

    batched_outs = custom_jvp_call_jaxpr_p.bind(
        *args,
        fun_jaxpr=batched_fun_jaxpr,
        jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk,
        num_consts=num_consts)
    out_dims = out_dims2[0] if out_dims2 else out_dims1
    return batched_outs, out_dims


# Batching rule for custom_jvp_call_jaxpr_p.
batching.initial_style_batchers[custom_jvp_call_jaxpr_p] = (
    _custom_jvp_call_jaxpr_vmap)

# XLA lowering: _custom_jvp_call_jaxpr_impl wrapped by
# xla.lower_fun_initial_style.
xla.initial_style_translations[custom_jvp_call_jaxpr_p] = (
    xla.lower_fun_initial_style(_custom_jvp_call_jaxpr_impl))


def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk,
                                     num_consts):
    """Transpose a ``custom_jvp_call_jaxpr`` by transposing its primal jaxpr.

    When a (multi)linear function is defined with a custom JVP,
    ``custom_jvp_call_jaxpr`` can show up inside jaxprs that get transposed.
    At that point the function has already been linearized, so the custom JVP
    rule is irrelevant. Only ``fun_jaxpr`` is transposed.

    Args:
      cts: output cotangents, forwarded to ``ad.backward_pass``.
      *args: the primitive's operands, forwarded as the primal inputs.
      fun_jaxpr: closed jaxpr whose ``.jaxpr`` and ``.consts`` are transposed.
      jvp_jaxpr_thunk: unused here (the JVP rule is dropped).
      num_consts: unused here.

    Returns:
      Whatever ``ad.backward_pass`` returns for these inputs.
    """
    # The JVP rule and const count are not needed after linearization.
    del jvp_jaxpr_thunk, num_consts
    body_jaxpr, body_consts = fun_jaxpr.jaxpr, fun_jaxpr.consts
    return ad.backward_pass(body_jaxpr, body_consts, args, cts)


# Transpose rule for custom_jvp_call_jaxpr_p.
ad.primitive_transposes[custom_jvp_call_jaxpr_p] = (
    _custom_jvp_call_jaxpr_transpose)

### VJPs