Example #1
                              call=jvp_call,
                              rule=jvp_of_rule_rule,
                              in_tree=jvp_in_tree,
                              out_tree=jvp_out_tree)
    assert len(outs) % 2 == 0, len(outs)
    out_primals, out_tangents = util.split_list(outs, [len(outs) // 2])
    return out_primals, out_tangents


custom_vmap_p = core.Primitive('custom_vmap_call')
custom_vmap_p.multiple_results = True
custom_vmap_p.def_impl(custom_vmap_impl)
custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
xla.register_initial_style_primitive(custom_vmap_p)
mlir.register_lowering(custom_vmap_p,
                       mlir.lower_fun(custom_vmap_impl, multiple_results=True))

# -- custom vmap applications


def tree_split(mask, tree):
    lhs = tree_map(lambda l, x: x if l else None, mask, tree)
    rhs = tree_map(lambda l, x: None if l else x, mask, tree)
    return lhs, rhs


def tree_merge(mask, lhs_tree, rhs_tree):
    return tree_map(lambda l, x_l, x_r: x_l if l else x_r,
                    mask, lhs_tree, rhs_tree)
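
The registrations above install `custom_vmap_p`'s impl, abstract-eval, batching, JVP, and lowering rules. For orientation, here is a minimal sketch of the user-facing entry point, assuming the public `jax.custom_batching.custom_vmap` API, where a `def_vmap` rule receives the axis size, per-input batched flags, and the (possibly batched) arguments:

import jax
import jax.numpy as jnp
from jax import custom_batching

@custom_batching.custom_vmap
def f(x):
  return jnp.sin(x)

@f.def_vmap
def f_vmap_rule(axis_size, in_batched, x):
  x_batched, = in_batched
  # sin is elementwise, so the unbatched computation already handles
  # a batched operand; report the output as batched whenever x is.
  return jnp.sin(x), x_batched

y = jax.vmap(f)(jnp.arange(3.0))  # dispatches to f_vmap_rule
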
Example #2
    res_arg, lin_arg = tree_unflatten(call_in_tree, args)
    del lin_arg
    assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

    cts = [
        ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct
        for ct in cts
    ]
    ct_out = tree_unflatten(out_tree, cts)
    ct_lin = transpose(res_arg, ct_out)
    check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))
    ct_lin_flat, _ = tree_flatten(tree_broadcast(lin_tree,
                                                 ct_lin,
                                                 is_leaf=lambda x: x is None),
                                  is_leaf=lambda x: x is None)
    return [None] * len(tree_leaves(res_arg)) + ct_lin_flat


def custom_transpose_lowering(*args, call_jaxpr, **params):
    return core.jaxpr_as_fun(call_jaxpr)(*args)


custom_transpose_p = CustomTransposePrimitive('custom_transpose_call')
core.custom_typechecks[custom_transpose_p] = custom_transpose_typecheck
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
mlir.register_lowering(
    custom_transpose_p,
    mlir.lower_fun(custom_transpose_lowering, multiple_results=True))
xla.register_initial_style_primitive(custom_transpose_p)
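
The user-facing wrapper for this primitive is not part of JAX's stable API. The sketch below is an assumption based on the internal `jax._src.custom_transpose.custom_transpose` decorator, with the decorator argument supplying example outputs (fixing `out_types`) and the rule signature `transpose(res_arg, ct_out)` taken from the transpose rule above:

import jax
import jax.numpy as jnp
# Internal module; import path and decorator signature are assumptions:
from jax._src.custom_transpose import custom_transpose

# Residual args come first, then linear args. The example output
# (jnp.ones(2)) fixes the output types, since transposition must know
# them without evaluating f.
@custom_transpose(jnp.ones(2))
def f(r, x):
  return x / r

# Transpose rule: residuals plus an output cotangent in, a cotangent
# for the linear input x out.
@f.def_transpose
def f_transpose(r, ct):
  return ct / r

r = 2.0 * jnp.ones(2)
ct_x, = jax.linear_transpose(lambda x: f(r, x), jnp.zeros(2))(jnp.ones(2))
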
Example #3
    return core.AxisPrimitive.bind(cond_p,
                                   *args,
                                   branches=branches,
                                   linear=linear)


cond_p = core.AxisPrimitive('cond')
cond_p.multiple_results = True
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.reducing_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
xla.register_initial_style_primitive(cond_p)
core.custom_typechecks[cond_p] = _cond_typecheck
pe.partial_eval_jaxpr_custom_rules[cond_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond')


def _cond_lowering(ctx, index, *args, branches, linear):
    del linear  # Unused.
    joined_effects = core.join_effects(*(branch.effects
                                         for branch in branches))
    ordered_effects = [
        eff for eff in joined_effects if eff in core.ordered_effects
    ]
    num_tokens = len(ordered_effects)
    tokens_in = ctx.tokens_in.subset(ordered_effects)
    output_token_types = [mlir.token_type() for _ in ordered_effects]
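
`cond_p` is the primitive behind the public `jax.lax.cond`. For reference, a minimal use of that API, showing what becomes the `index` operand and the `branches` jaxprs seen by the lowering rule above:

import jax
from jax import lax

def f(x):
  return lax.cond(x > 0, lambda x: x + 1.0, lambda x: x - 1.0, x)

# Under jit, the predicate is converted to the integer `index` operand
# of cond_p, and the two lambdas are traced into the `branches` jaxprs.
print(jax.jit(f)(3.0))  # 4.0
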
Example #4
File: solves.py Project: xueeinstein/jax
    ]
    # Broadcast out b if necessary
    new_b = [
        batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
        batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
        for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
    ]

    outs = linear_solve_p.bind(*(new_params + new_b),
                               const_lengths=const_lengths,
                               jaxprs=batched_jaxprs)
    out_dims = [
        0 if batched else batching.not_mapped for batched in solve_x_bat
    ]
    return outs, out_dims


linear_solve_p = core.AxisPrimitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
mlir.register_lowering(
    linear_solve_p,
    mlir.lower_fun(_custom_linear_solve_impl, multiple_results=True))
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'linear_solve')
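
`linear_solve_p` is bound by the public `jax.lax.custom_linear_solve`. A minimal sketch of that entry point, with a dense inner solver chosen purely for illustration:

import jax.numpy as jnp
from jax import lax

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
b = jnp.array([1.0, 2.0])

def matvec(x):
  return A @ x

def solve(matvec, b):
  # Any inner solver works here; a dense solve keeps the sketch short.
  return jnp.linalg.solve(A, b)

# symmetric=True lets the transpose rule reuse `solve` for A^T, so no
# separate transpose_solve is needed in this case.
x = lax.custom_linear_solve(matvec, b, solve, symmetric=True)
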
Example #5

    primals_batched, tangents_batched = split_list(all_batched, [num_out])
    out_batched = map(op.or_, primals_batched, tangents_batched)
    out_dims2.append([0 if b else not_mapped for b in out_batched])
    batched_jvp_jaxpr, _ = batching.batch_jaxpr(
        jvp_jaxpr, axis_size, args_batched * 2, out_batched * 2,
        axis_name, main_type)
    return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts

  batched_outs = custom_jvp_call_jaxpr_p.bind(
      *args, fun_jaxpr=batched_fun_jaxpr,
      jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk, num_consts=num_consts)
  out_dims = out_dims2[0] if out_dims2 else out_dims1
  return batched_outs, out_dims
batching.axis_primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap

xla.register_initial_style_primitive(custom_jvp_call_jaxpr_p)

# If a (multi)linear function is defined with a custom jvp, then
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
# already been linearized, we can drop the jvp rule.
def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr,
                                     jvp_jaxpr_thunk, num_consts):
  del jvp_jaxpr_thunk, num_consts
  return ad.backward_pass(
      fun_jaxpr.jaxpr, reduce_axes, False, fun_jaxpr.consts, args, cts)
ad.reducing_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose

def custom_jvp_jaxpr_custom_partial_eval_rule(
    saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
    eqn: core.JaxprEqn
  ) -> Tuple[Optional[core.JaxprEqn], core.JaxprEqn, List[bool], List[bool], List[core.Var]]:
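
All of this plumbing backs the public `jax.custom_jvp` decorator, whose applications are staged into jaxprs as `custom_jvp_call_jaxpr_p` in this version of JAX. A minimal sketch of the user-level API for reference:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  # Return the primal output and its tangent; once staged to a jaxpr,
  # this is the rule the batching and transpose machinery above handles.
  return jnp.sin(x), jnp.cos(x) * x_dot

print(jax.grad(f)(0.0))  # 1.0
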