# NOTE(review): this chunk begins mid-expression — the keyword arguments below
# close a call started on an earlier, unseen line of the enclosing JVP rule.
    call=jvp_call, rule=jvp_of_rule_rule, in_tree=jvp_in_tree,
    out_tree=jvp_out_tree)
  # The flat outputs hold primals followed by tangents in equal halves;
  # split the list down the middle.
  assert len(outs) % 2 == 0, len(outs)
  out_primals, out_tangents = util.split_list(outs, [len(outs) // 2])
  return out_primals, out_tangents

# Primitive wiring for custom_vmap: impl, abstract eval, the custom batching
# rule, a JVP rule, and an initial-style MLIR lowering that lowers by tracing
# custom_vmap_impl.
custom_vmap_p = core.Primitive('custom_vmap_call')
custom_vmap_p.multiple_results = True
custom_vmap_p.def_impl(custom_vmap_impl)
custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
xla.register_initial_style_primitive(custom_vmap_p)
mlir.register_lowering(custom_vmap_p, mlir.lower_fun(
    custom_vmap_impl, multiple_results=True))


# -- custom vmap applications

def tree_split(mask, tree):
  """Split ``tree`` in two by a boolean ``mask`` tree of the same structure.

  Leaves where the mask is truthy are kept in ``lhs`` (and replaced by None
  in ``rhs``); leaves where it is falsy are kept in ``rhs`` (None in ``lhs``).
  """
  lhs = tree_map(lambda l, x: x if l else None, mask, tree)
  rhs = tree_map(lambda l, x: None if l else x, mask, tree)
  return lhs, rhs

def tree_merge(mask, lhs_tree, rhs_tree):
  """Inverse of ``tree_split``: pick each leaf from lhs or rhs per ``mask``."""
  return tree_map(lambda l, x_l, x_r: x_l if l else x_r,
                  mask, lhs_tree, rhs_tree)
# NOTE(review): this chunk begins mid-function — the body below is the tail of
# the custom_transpose transpose rule; `args`, `cts`, `call_in_tree`,
# `out_tree`, `lin_tree`, and `transpose` are bound by the unseen signature.
  # The flat args are (residual inputs, linear inputs); only the residuals
  # are needed here, and none of them may be an undefined primal.
  res_arg, lin_arg = tree_unflatten(call_in_tree, args)
  del lin_arg
  assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

  # Materialize symbolic zero cotangents before handing them to the
  # user-supplied transpose rule.
  cts = [ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct
         for ct in cts]
  ct_out = tree_unflatten(out_tree, cts)
  ct_lin = transpose(res_arg, ct_out)
  # Validate that the rule produced cotangents of the expected tree shape.
  check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))
  ct_lin_flat, _ = tree_flatten(
      tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None),
      is_leaf=lambda x: x is None)
  # Residual inputs get no cotangent (None); linear inputs get theirs.
  return [None] * len(tree_leaves(res_arg)) + ct_lin_flat

def custom_transpose_lowering(*args, call_jaxpr, **params):
  # Lower by simply running the primal call jaxpr as a function.
  return core.jaxpr_as_fun(call_jaxpr)(*args)

# Primitive wiring for custom_transpose: custom typecheck, the transpose rule
# above, and an initial-style MLIR lowering via the primal jaxpr.
custom_transpose_p = CustomTransposePrimitive('custom_transpose_call')
core.custom_typechecks[custom_transpose_p] = custom_transpose_typecheck
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
mlir.register_lowering(
    custom_transpose_p,
    mlir.lower_fun(custom_transpose_lowering, multiple_results=True))
xla.register_initial_style_primitive(custom_transpose_p)
# NOTE(review): this chunk begins mid-function — the line below is the tail of
# the custom bind for `cond` (presumably `cond_bind`, given the registration
# further down — confirm against the unseen signature).
  return core.AxisPrimitive.bind(cond_p, *args, branches=branches,
                                 linear=linear)

# Primitive wiring for `cond`: impl via XLA, effectful abstract eval, custom
# bind, JVP, reducing transpose, partial eval, axis-aware batching, custom
# typecheck; the custom partial-eval-for-remat rule is explicitly
# not implemented.
cond_p = core.AxisPrimitive('cond')
cond_p.multiple_results = True
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.reducing_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
xla.register_initial_style_primitive(cond_p)
core.custom_typechecks[cond_p] = _cond_typecheck
pe.partial_eval_jaxpr_custom_rules[cond_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond')

def _cond_lowering(ctx, index, *args, branches, linear):
  del linear  # Unused.
  # Collect the ordered effects across all branches; each one needs a token
  # threaded through the lowered op.
  joined_effects = core.join_effects(*(branch.effects for branch in branches))
  ordered_effects = [eff for eff in joined_effects
                     if eff in core.ordered_effects]
  num_tokens = len(ordered_effects)
  tokens_in = ctx.tokens_in.subset(ordered_effects)
  output_token_types = [mlir.token_type() for _ in ordered_effects]
  # NOTE(review): `_cond_lowering` is truncated here — the rest of its body
  # lies beyond this chunk.
# NOTE(review): this chunk begins mid-function — the `]` below closes a list
# started on an earlier, unseen line of the linear_solve batching rule;
# `b`, `b_dims`, `orig_b_bat`, `b_bat`, `new_params`, `axis_size`,
# `const_lengths`, `batched_jaxprs`, and `solve_x_bat` are bound earlier.
  ]
  # Broadcast out b if necessary
  new_b = [
      # Newly batched and wasn't before: materialize the batch dim at axis 0.
      batching.broadcast(x, axis_size, 0) if now_bat and not was_bat
      # Already batched on some other axis: move its batch dim to axis 0.
      else batching.moveaxis(x, d, 0) if now_bat and d != 0
      else x
      for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
  ]

  outs = linear_solve_p.bind(*(new_params + new_b),
                             const_lengths=const_lengths,
                             jaxprs=batched_jaxprs)
  # Batched outputs carry their batch dim at axis 0; unbatched are not mapped.
  out_dims = [0 if batched else batching.not_mapped
              for batched in solve_x_bat]
  return outs, out_dims

# Primitive wiring for custom_linear_solve: impl, abstract eval, JVP,
# initial-style MLIR lowering via the impl, transpose, axis-aware batching;
# the custom partial-eval-for-remat rule is explicitly not implemented.
linear_solve_p = core.AxisPrimitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
mlir.register_lowering(
    linear_solve_p,
    mlir.lower_fun(_custom_linear_solve_impl, multiple_results=True))
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'linear_solve')
# NOTE(review): this chunk begins inside a nested thunk of the vmap rule for
# custom_jvp_call_jaxpr; `all_batched`, `num_out`, `jvp_jaxpr`, `axis_size`,
# `args_batched`, `axis_name`, `main_type`, `out_dims2`, and `not_mapped`
# are bound by unseen enclosing code.
    # An output is batched if either its primal or its tangent half is.
    primals_batched, tangents_batched = split_list(all_batched, [num_out])
    out_batched = map(op.or_, primals_batched, tangents_batched)
    out_dims2.append([0 if b else not_mapped for b in out_batched])
    # The jvp jaxpr takes/returns (primals, tangents), hence the doubled
    # batching masks.
    batched_jvp_jaxpr, _ = batching.batch_jaxpr(
        jvp_jaxpr, axis_size, args_batched * 2, out_batched * 2,
        axis_name, main_type)
    return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts

  batched_outs = custom_jvp_call_jaxpr_p.bind(
      *args, fun_jaxpr=batched_fun_jaxpr,
      jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk, num_consts=num_consts)
  # If the thunk ran it recorded the out dims; otherwise fall back to the
  # dims computed from the primal jaxpr alone.
  out_dims = out_dims2[0] if out_dims2 else out_dims1
  return batched_outs, out_dims

batching.axis_primitive_batchers[custom_jvp_call_jaxpr_p] = \
    _custom_jvp_call_jaxpr_vmap
xla.register_initial_style_primitive(custom_jvp_call_jaxpr_p)

# If a (multi)linear function is defined with a custom jvp, then
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
# already been linearized, we can drop the jvp rule.
def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr,
                                     jvp_jaxpr_thunk, num_consts):
  del jvp_jaxpr_thunk, num_consts
  # Transpose the primal jaxpr directly via the generic backward pass.
  return ad.backward_pass(
      fun_jaxpr.jaxpr, reduce_axes, False, fun_jaxpr.consts, args, cts)
ad.reducing_transposes[custom_jvp_call_jaxpr_p] = \
    _custom_jvp_call_jaxpr_transpose

def custom_jvp_jaxpr_custom_partial_eval_rule(
    saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
    eqn: core.JaxprEqn
  ) -> Tuple[Optional[core.JaxprEqn], core.JaxprEqn, List[bool], List[bool],
             List[core.Var]]:
  # NOTE(review): definition truncated here — the body lies beyond this chunk.