Example #1
    def post_process_custom_vjp_call_fwd(self, out_tracers, out_trees):
        vals, dims, srcs = unzip3(
            (t.val, t.batch_dim, t.source_info) for t in out_tracers)
        axis_size, = {
            x.shape[d]
            for x, d in zip(vals, dims) if d is not not_mapped
        }
        main, trace_type = self.main, self.main.trace_type
        axis_name = self.axis_name
        _, res_tree = out_trees()
        num_res = res_tree.num_leaves
        res_dims, primal_dims = split_list(dims, [num_res])
        _, primal_srcs = split_list(srcs, [num_res])

        def todo(vals):
            trace = main.with_cur_sublevel()
            return map(partial(BatchTracer, trace), vals, primal_dims,
                       primal_srcs)

        def bwd_transform(bwd):
            return batch_custom_vjp_bwd(bwd, axis_name, axis_size, dims,
                                        (None, ), trace_type)

        return vals, todo, bwd_transform
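Nearly every example on this page pivots on the same helper: `split_list(xs, [n0, n1, ...])` cuts a list into consecutive groups of the given sizes, plus a final group holding the remainder. A standalone sketch of the behavior (illustrative only; the real helper lives in `jax.util`):

def split_list(lst, sizes):
    # Peel off one prefix group per requested size, then keep the remainder.
    groups, start = [], 0
    for size in sizes:
        groups.append(lst[start:start + size])
        start += size
    groups.append(lst[start:])
    return groups

assert split_list([1, 2, 3, 4, 5], [2, 1]) == [[1, 2], [3], [4, 5]]

So in Example #1, `split_list(dims, [num_res])` separates the residuals' batch dimensions from those of the primal outputs.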
Example #2
def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr,
                           body_nconsts, body_jaxpr):
    # TODO(lenamartens): fix when an error occurs in the cond function and it then returns False.
    checked_body_jaxpr, msgs_ = checkify_while_body_jaxpr(
        cond_jaxpr, body_jaxpr, error)
    compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error)
    c_consts, b_consts, carry = split_list(in_flat,
                                           [cond_nconsts, body_nconsts])
    new_in_flat = [*c_consts, *b_consts, error.err, error.code, *carry]
    err, code, *out = control_flow.while_p.bind(*new_in_flat,
                                                cond_nconsts=cond_nconsts,
                                                cond_jaxpr=compat_cond_jaxpr,
                                                body_nconsts=body_nconsts,
                                                body_jaxpr=checked_body_jaxpr)
    new_msgs = {**error.msgs, **msgs_}
    return out, Error(err, code, new_msgs)
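The checkify pattern in this rule is to widen the loop carry with the error state: `error.err` and `error.code` ride in front of the real carry, the checked body jaxpr updates them, and the cond jaxpr is rewritten to accept (and ignore) the extra values. A plain-Python schematic of the same shape, with `cond` and `body` as hypothetical stand-ins for the jaxprs:

def checked_while(cond, body, err, code, *carry):
    # Thread (err, code) through the carry; the checked body keeps the
    # first error it encounters.
    state = (err, code, *carry)
    while cond(*state[2:]):
        state = body(*state)
    return state[0], state[1], list(state[2:])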
Example #3
def function_effect_lowering(ctx, *, effect):
  def _f(ctx):
    ctx.set_tokens_out(ctx.tokens_in)
    return []
  func = mlir._emit_lowering_rule_as_fun(_f, ctx)

  output_types = map(mlir.aval_to_ir_types, ctx.avals_out)
  token_types = [mlir.token_type() for _ in ctx.tokens_in.items()]
  output_types = [*token_types, *output_types]
  flat_output_types = util.flatten(output_types)
  call = mlir.func_dialect.CallOp(flat_output_types,
                                  mlir.ir.FlatSymbolRefAttr.get(func.name.value),
                                  mlir.flatten_lowering_ir_args(ctx.tokens_in.tokens()))
  tokens, out = util.split_list(call.results, [len(ctx.tokens_in)])
  ctx.set_tokens_out(mlir.TokenSet(zip(ctx.tokens_in.effects(), tokens)))
  return out
Example #4
def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear):
    nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
    # We need to find out which `Ref`s have nonzero tangents after running the
    # for loop. Ordinarily we do this with a fixed point on the body jaxpr but
    # a `for` body jaxpr is stateful and has no outputs. We therefore discharge
    # the state effect from the jaxpr and we will now have a "symmetric" jaxpr
    # where the inputs line up with the outputs. We use this discharged jaxpr
    # for the fixed point.
    discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
    for _ in range(len(nonzero_tangents)):
        _, out_nonzero_tangents = ad.jvp_jaxpr(
            core.ClosedJaxpr(discharged_jaxpr, body_consts),
            [False] + nonzero_tangents, instantiate=nonzero_tangents)
        if out_nonzero_tangents == nonzero_tangents:
            break
        nonzero_tangents = map(operator.or_, nonzero_tangents,
                               out_nonzero_tangents)
    else:
        raise Exception("Invalid fixpoint")
    tangents = [
        ad.instantiate_zeros(t) if inst else t
        for t, inst in zip(tangents, nonzero_tangents)
    ]
    tangents = [t for t in tangents if type(t) is not ad_util.Zero]
    closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
    jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, [False] + nonzero_tangents, [])
    jvp_jaxpr, jvp_consts = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts
    jvp_which_linear = ((False, ) * len(jvp_consts) + which_linear +
                        (True, ) * len(tangents))
    out_flat = for_p.bind(*jvp_consts,
                          *primals,
                          *tangents,
                          jaxpr=jvp_jaxpr,
                          nsteps=nsteps,
                          reverse=reverse,
                          which_linear=jvp_which_linear)
    # `out_flat` includes constant inputs into the `for_loop` which are
    # converted into outputs as well. We don't care about these in AD so we
    # throw them out.
    _, out_primals, out_tangents = split_list(
        out_flat, [len(jvp_consts), len(primals)])
    out_tangents_iter = iter(out_tangents)
    out_tangents = [
        next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
        for p, nz in zip(out_primals, nonzero_tangents)
    ]
    return out_primals, out_tangents
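The for/else loop above is a standard JAX fixpoint: rerun the nonzero-tangent analysis until the flags stop changing, which is guaranteed to happen because flags only ever flip from False to True. A generic sketch, with `analyze` standing in for the `ad.jvp_jaxpr` call:

def fixpoint_nonzeros(analyze, nonzeros):
    # Flags are monotone (False -> True only), so at most len(nonzeros)
    # flips can occur; the extra iteration is the confirmation pass.
    for _ in range(len(nonzeros) + 1):
        out = analyze(nonzeros)
        if out == nonzeros:
            return nonzeros
        nonzeros = [a or b for a, b in zip(nonzeros, out)]
    raise Exception("Invalid fixpoint")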
Example #5
File: ad.py Project: John1Tang/jax
 def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
     primals_in, tangents_in = unzip2(
         (t.primal, t.tangent) for t in tracers)
     tangents_in = map(instantiate_zeros, tangents_in)
     res_and_primals_out = fwd.call_wrapped(
         *map(core.full_lower, primals_in))
     out_tree, res_tree = out_trees()
     res, primals_out = split_list(res_and_primals_out,
                                   [res_tree.num_leaves])
     avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
     tangents_out = custom_lin_p.bind(*res,
                                      *tangents_in,
                                      num_res=res_tree.num_leaves,
                                      bwd=bwd,
                                      out_avals=avals_out)
     tangents_out = map(recast_to_float0, primals_out, tangents_out)
     return map(partial(JVPTracer, self), primals_out, tangents_out)
Example #6
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll):
    consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
    checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors)
    new_linear = (False, False, *linear)
    new_in_flat = [*consts, error.err, error.code, *carry, *xs]
    err, code, *outs = lax.scan_p.bind(*new_in_flat,
                                       reverse=reverse,
                                       length=length,
                                       jaxpr=checked_jaxpr,
                                       num_consts=len(consts),
                                       num_carry=len(carry) + 2,
                                       linear=new_linear,
                                       unroll=unroll)
    new_msgs = {**error.msgs, **msgs_}
    return outs, Error(err, code, new_msgs)
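The index bookkeeping is the delicate part: the error value and code are spliced in at the head of the carry block, so the scan sees num_carry = len(carry) + 2 and threads them through every iteration while the xs are still consumed one slice at a time. The reshuffle in isolation, with made-up values:

in_flat = ['c0', 'c1', 'k0', 'x0', 'x1']  # 2 consts, 1 carry, 2 xs
consts, carry, xs = in_flat[:2], in_flat[2:3], in_flat[3:]
new_in_flat = [*consts, 'err', 'code', *carry, *xs]
assert new_in_flat == ['c0', 'c1', 'err', 'code', 'k0', 'x0', 'x1']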
Example #7
def _reduce_window_abstract_eval_rule(
    *avals, jaxpr, consts, window_dimensions, window_strides, padding,
    base_dilation, window_dilation):
  operand_avals, init_val_avals = util.split_list(avals, [len(avals) // 2])
  if any(o.dtype != iv.dtype for o, iv in zip(operand_avals, init_val_avals)):
    msg = ("reduce_window got inconsistent dtypes for operands and init_values:"
           " got operand dtypes {} and init_value dtypes {}.")
    raise TypeError(msg.format([o.dtype for o in operand_avals],
                               [iv.dtype for iv in init_val_avals]))
  if any(len(v.shape) != 0 for v in init_val_avals):
    msg = ("reduce_window expected init_values to be scalars but init_values "
           "have shapes {}.")
    raise TypeError(msg.format([v.shape for v in init_val_avals]))
  out_shape = _common_reduce_window_shape_rule(
    operand_avals[0], window_dimensions, window_strides, padding,
    base_dilation, window_dilation)
  return tuple(ShapedArray(out_shape, op.dtype) for op in operand_avals)
Example #8
def _flatten_jvp(in_tree, *args):
    primals_in, tangents_in = split_list(args, [len(args) // 2])
    py_primals = tree_unflatten(in_tree, primals_in)
    py_tangents = tree_unflatten(in_tree, tangents_in)
    pair_out = yield (py_primals, py_tangents), {}
    if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
        msg = (
            "Custom JVP rule must produce a pair (list or tuple of length two) "
            "representing primal and tangent outputs, got {}.")
        raise TypeError(msg.format(pair_out))
    py_primals_out, py_tangents_out = pair_out
    primals_out, out_tree = tree_flatten(py_primals_out)
    tangents_out, out_tree2 = tree_flatten(py_tangents_out)
    if out_tree != out_tree2:
        msg = (
            "Custom JVP rule must produce primal and tangent outputs with equal "
            "container (pytree) structures, but got {} and {} respectively.")
        raise TypeError(msg.format(out_tree, out_tree2))
    # TODO(mattjj): compare primals' tangent types to tangent objects' types
    primal_avals_out = [
        raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape()
        for x in primals_out
    ]
    tangent_avals_out = [
        raise_to_shaped(core.get_aval(t), weak_type=False).strip_named_shape()
        for t in tangents_out
    ]
    if primal_avals_out != tangent_avals_out:
        if len(primal_avals_out) == 1:
            (av1, ), (av2, ) = primal_avals_out, tangent_avals_out
            msg = (
                "Custom JVP rule must produce primal and tangent outputs with "
                "equal shapes and dtypes, but got {} and {} respectively.")
            raise TypeError(msg.format(av1.str_short(), av2.str_short()))
        else:
            msg = (
                "Custom JVP rule must produce primal and tangent outputs with "
                "equal shapes and dtypes, but got:\n{}")
            disagreements = (
                "  primal {} for tangent {}".format(av1.str_short(),
                                                    av2.str_short())
                for av1, av2 in zip(primal_avals_out, tangent_avals_out)
                if av1 != av2)
            raise TypeError(msg.format('\n'.join(disagreements)))
    yield primals_out + tangents_out, out_tree
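The checks in `_flatten_jvp` are visible from the user side: a `jax.custom_jvp` rule must return a (primal_out, tangent_out) pair whose two halves have matching pytree structure and matching shapes/dtypes. A minimal conforming rule:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    primal_out = jnp.sin(x)
    tangent_out = jnp.cos(x) * x_dot  # same shape/dtype as primal_out
    return primal_out, tangent_out

print(jax.jvp(f, (1.0,), (1.0,)))  # (sin(1.0), cos(1.0))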
Example #9
def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
  checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error)
  checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr)
  # Check if the first cond application will error.
  cond_err, cond_code, _ = checked_cond_fun(error.err, error.code, *in_flat)

  checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error)
  compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error)
  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
  new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, *carry]
  err, code, *out = lax.while_p.bind(
      *new_in_flat,
      cond_nconsts=cond_nconsts,
      cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=body_nconsts,
      body_jaxpr=checked_body_jaxpr)
  new_msgs = {**error.msgs, **msgs_body, **msgs_cond}
  return out, Error(err, code, new_msgs)
Example #10
def _custom_vjp_call_jaxpr_vmap(
        args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
        fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
        bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
    axis_size, = {
        x.shape[d]
        for x, d in zip(args, in_dims) if d is not not_mapped
    }
    args = [
        batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x
        for x, d in zip(args, in_dims)
    ]

    in_batched = [d is not not_mapped for d in in_dims]
    _, args_batched = split_list(in_batched, [num_consts])
    batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
        fun_jaxpr, axis_size, in_batched, False, axis_name, main_type)
    out_dims1 = [0 if b else not_mapped for b in out_batched]
    out_dims2 = []

    @pe._memoize
    def batched_fwd_jaxpr_thunk():
        fwd_jaxpr = core.ClosedJaxpr(
            *fwd_jaxpr_thunk())  # consts can be tracers
        batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
            fwd_jaxpr, axis_size, args_batched, False, axis_name, main_type)
        out_dims2.append([0 if b else not_mapped for b in out_batched])
        return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts

    fwd_args_batched = [0 if b else not_mapped for b in args_batched]
    fwd_out_dims = lambda: out_dims2[0]
    batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size,
                                                fwd_out_dims, fwd_args_batched,
                                                main_type)

    batched_outs = custom_vjp_call_jaxpr_p.bind(
        *args,
        fun_jaxpr=batched_fun_jaxpr,
        fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk,
        bwd=batched_bwd,
        out_trees=out_trees,
        num_consts=num_consts)
    out_dims = out_dims2[0] if out_dims2 else out_dims1
    return batched_outs, out_dims
Example #11
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll):
    consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
    checked_jaxpr_, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors)
    tomove = ([False] * 3 + [True] * len(consts) +
              [False] * (len(carry) + len(xs)))
    checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
    new_linear = (False, False, False, *linear)
    new_in_flat = [*consts, error.err, error.code, error.payload, *carry, *xs]
    err, code, payload, *outs = lax.scan_p.bind(*new_in_flat,
                                                reverse=reverse,
                                                length=length,
                                                jaxpr=checked_jaxpr,
                                                num_consts=len(consts),
                                                num_carry=len(carry) + 3,
                                                linear=new_linear,
                                                unroll=unroll)
    new_msgs = {**error.msgs, **msgs_}
    return outs, Error(err, code, new_msgs, payload)
Example #12
def _cond_lowering(ctx, index, *args, branches, linear):
    del linear  # Unused.
    joined_effects = core.join_effects(*(branch.effects
                                         for branch in branches))
    ordered_effects = [
        eff for eff in joined_effects if eff in core.ordered_effects
    ]
    num_tokens = len(ordered_effects)
    tokens_in = ctx.tokens_in.subset(ordered_effects)
    output_token_types = [mlir.token_type() for _ in ordered_effects]
    output_types = [
        *output_token_types, *_map(mlir.aval_to_ir_types, ctx.avals_out)
    ]
    flat_output_types = util.flatten(output_types)

    # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
    # have no arguments; the computation within the block uses implicit
    # captures.
    case_op = mhlo.CaseOp(flat_output_types,
                          index=index,
                          num_branches=len(branches))
    name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
    for i, jaxpr in enumerate(branches):
        branch = case_op.regions[i].blocks.append()
        with ir.InsertionPoint(branch):
            sub_ctx = ctx.module_context.replace(
                name_stack=xla.extend_name_stack(name_stack,
                                                 f'branch_{i}_fun'))
            out_vals, tokens_out = mlir.jaxpr_subcomp(
                sub_ctx, jaxpr.jaxpr, tokens_in,
                _map(mlir.ir_constants, jaxpr.consts),
                *_map(mlir.wrap_singleton_ir_values, args))
            out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
            out_vals = [*out_tokens, *out_vals]
            mhlo.ReturnOp(util.flatten(out_vals))

    tokens_and_outputs = util.unflatten(case_op.results,
                                        _map(len, output_types))
    tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
    ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
    return outputs
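`util.unflatten(xs, ns)`, used at the end to separate tokens from outputs, is the inverse of `util.flatten`: it chops a flat list back into sublists of the given lengths. A standalone sketch of the behavior:

def unflatten(xs, ns):
    # Rebuild sublists of lengths ns from the flat list xs.
    xs_iter = iter(xs)
    out = [[next(xs_iter) for _ in range(n)] for n in ns]
    assert next(xs_iter, None) is None  # every element consumed
    return out

assert unflatten([1, 2, 3, 4], [1, 3]) == [[1], [2, 3, 4]]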
Example #13
def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
    assert not jaxpr.constvars
    in_nonzeros = [type(t) is not ad_util.Zero for t in tangents]
    jaxpr_ = core.ClosedJaxpr(jaxpr, ())
    jaxpr_jvp_, out_nonzeros = ad.jvp_jaxpr(jaxpr_, in_nonzeros, False)
    nonzero_tangents = [t for t in tangents if type(t) is not ad_util.Zero]
    jaxpr_jvp = pe.convert_constvars_jaxpr(jaxpr_jvp_.jaxpr)
    outs = remat_p.bind(*jaxpr_jvp_.consts,
                        *primals,
                        *nonzero_tangents,
                        jaxpr=jaxpr_jvp,
                        prevent_cse=prevent_cse,
                        differentiated=differentiated,
                        policy=policy)
    out_primals, out_tangents_ = split_list(outs, [len(jaxpr.outvars)])
    out_tangents_ = iter(out_tangents_)
    out_tangents = [
        next(out_tangents_) if nz else ad_util.Zero.from_value(p)
        for p, nz in zip(out_primals, out_nonzeros)
    ]
    return out_primals, out_tangents
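The iterator at the end is a compact scatter: computed tangents are consumed in order wherever the nonzero flag is set, and symbolic zeros fill the remaining slots. The idiom in isolation:

out_nonzeros = [True, False, True]
computed = iter(['t0', 't2'])
out = [next(computed) if nz else 'zero' for nz in out_nonzeros]
assert out == ['t0', 'zero', 't2']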
Example #14
def _while_callback_rule(trace, *tracers, cond_jaxpr, body_jaxpr, cond_nconsts,
                         body_nconsts):
    cond_const_tracers, body_const_tracers, init_tracers = split_list(
        tracers, [cond_nconsts, body_nconsts])
    init_avals = safe_map(lambda x: x.aval, init_tracers)
    cond_const_vals, body_const_vals, init_vals = tree_map(
        lambda x: x.val,
        (cond_const_tracers, body_const_tracers, init_tracers))

    body_fun = jaxpr_as_fun(body_jaxpr)
    cond_fun = jaxpr_as_fun(cond_jaxpr)

    def cond(*carry):
        return cond_fun(*it.chain(cond_const_vals, carry))

    def body(*carry):
        return body_fun(*it.chain(body_const_vals, carry))

    main = trace.main
    new_cond = callback_transform(cond,
                                  main.callback,
                                  strip_calls=main.strip_calls)  # type: ignore
    new_body = callback_transform(body,
                                  main.callback,
                                  strip_calls=main.strip_calls)  # type: ignore
    in_tree = tree_structure(init_avals)

    new_cond_jaxpr, new_cond_consts, _ = lcf._initial_style_jaxpr(
        new_cond, in_tree, tuple(init_avals))
    new_body_jaxpr, new_body_consts, _ = lcf._initial_style_jaxpr(
        new_body, in_tree, tuple(init_avals))
    out = lcf.while_p.bind(*it.chain(new_cond_consts, new_body_consts,
                                     init_vals),
                           cond_nconsts=len(new_cond_consts),
                           body_nconsts=len(new_body_consts),
                           cond_jaxpr=new_cond_jaxpr,
                           body_jaxpr=new_body_jaxpr)
    return safe_map(trace.pure, out)
Example #15
def _scan_callback_rule(trace, *tracers, reverse, length, num_consts,
                        num_carry, jaxpr, linear, unroll):
    const_tracers, carry_tracers, xs_tracers = split_list(
        tracers, [num_consts, num_carry])
    carry_avals, xs_avals = tree_map(lambda x: x.aval,
                                     (carry_tracers, xs_tracers))
    const_vals, carry_vals, xs_vals = tree_map(
        lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))

    x_tracers = [t[0] for t in xs_tracers]
    x_avals = [t.aval for t in x_tracers]

    body_fun = jaxpr_as_fun(jaxpr)

    def new_body(carry, x):
        flat_args = tree_leaves((carry, x))
        out = body_fun(*(const_vals + flat_args))
        out_carry, y = split_list(out, [num_carry])
        return out_carry, y

    main = trace.main
    new_body = callback_transform(new_body,
                                  main.callback,
                                  strip_calls=main.strip_calls)  # type: ignore
    in_tree = tree_structure(tuple(carry_avals + xs_avals))
    new_jaxpr, new_consts, _ = lcf._initial_style_jaxpr(
        new_body, in_tree, tuple(carry_avals + x_avals))
    vals = tuple(it.chain(new_consts, carry_vals, xs_vals))
    out_vals = lax.scan_p.bind(*vals,
                               reverse=reverse,
                               length=length,
                               num_consts=len(new_consts),
                               num_carry=num_carry,
                               jaxpr=new_jaxpr,
                               linear=linear,
                               unroll=unroll)
    return safe_map(trace.pure, out_vals)
Example #16
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
    out_tree, res_tree = out_trees()
    res, cts_out = split_list(args, [res_tree.num_leaves])
    py_res = tree_unflatten(res_tree, res)
    py_cts_out = tree_unflatten(out_tree, cts_out)
    py_cts_in = yield (py_res, py_cts_out), {}
    # For each None in py_cts_in, indicating an argument for which the rule
    # produces no cotangent, we replace it with a pytree with the structure of the
    # corresponding subtree of in_tree and with leaves of a non-pytree sentinel
    # object, to be replaced with Nones in the final returned result.
    zero = object()  # non-pytree sentinel to replace Nones in py_cts_in
    dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
    cts_in_flat = []
    append_cts = lambda x, d: cts_in_flat.extend([x] * len(tree_flatten(d)[0]))
    try:
        if not isinstance(py_cts_in, tuple):
            raise ValueError
        tree_multimap(append_cts,
                      tuple(zero if ct is None else ct for ct in py_cts_in),
                      dummy)
    except ValueError:
        _, in_tree2 = tree_flatten(py_cts_in)
        msg = (
            "Custom VJP rule must produce an output with the same container "
            "(pytree) structure as the args tuple of the primal function, "
            "and in particular must produce a tuple of length equal to the "
            "number of arguments to the primal function, but got VJP output "
            "structure {} for primal input structure {}.")
        raise TypeError(msg.format(in_tree2, in_tree)) from None
    # Ignore any None cotangents, and any corresponding to inputs for which the
    # type doesn't equal the tangent type (i.e. float0s)
    # TODO(mattjj): change this to check if tangent type represents 0dim vspace
    yield [
        Zero(a.at_least_vspace())
        if ct is zero or a != a.at_least_vspace() else ct
        for a, ct in zip(in_avals, cts_in_flat)
    ]
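The None handling above is what lets a user-level bwd rule decline to produce a cotangent for an argument. A minimal sketch, assuming the standard `jax.custom_vjp` API:

import jax

@jax.custom_vjp
def scale(x, w):
    return x * w

def scale_fwd(x, w):
    return scale(x, w), (x, w)

def scale_bwd(res, g):
    x, w = res
    return (g * w, None)  # None: no cotangent produced for w

scale.defvjp(scale_fwd, scale_bwd)

print(jax.grad(scale)(3.0, 2.0))  # 2.0, differentiating with respect to x only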
Example #17
def _linear_call_impl(*args, callee, transpose, num_callee_consts,
                      num_transpose_consts, num_res):
  del transpose
  consts, _, operands_res, operands_lin = split_list(
      args, [num_callee_consts, num_transpose_consts, num_res])
  return core.eval_jaxpr(callee.jaxpr, (), *consts, *operands_res, *operands_lin)
Example #18
def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree):
    def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
        in_batched_ps, in_batched_ts = in_batched

        mutually_batched = tree_map(operator.and_, in_batched_ps,
                                    in_batched_ts)
        extra_batched_ps = tree_map(
            lambda pb, tb: 0 if pb and not tb else None,
            in_batched_ps, in_batched_ts)
        extra_batched_ts = tree_map(
            lambda pb, tb: 0 if tb and not pb else None,
            in_batched_ps, in_batched_ts)

        out_mutually_batched = lu.Store()
        flat_ps_ts, tree_ps_ts = tree_flatten((primals, tangents))
        flat_extra_batched_ps_ts, tree_ps_ts2 = tree_flatten(
            (extra_batched_ps, extra_batched_ts), is_leaf=lambda x: x is None)

        # TODO(frostig): assert these also equal:
        #   treedef_tuple((in_tree, in_tree))
        # once https://github.com/google/jax/issues/9066 is fixed
        assert tree_ps_ts == tree_ps_ts2
        del tree_ps_ts2

        def to_jvp(*primals):
            out, out_batched = call_rule(rule, axis_size, mutually_batched,
                                         primals)
            check_vmap_rule_trees(rule, out_tree, tree_structure(out),
                                  tree_structure(out_batched))
            out_mutually_batched.store(out_batched)
            return out

        def to_vmap_over_extra_batched_dims(primals, tangents):
            return jax.jvp(to_jvp, primals, tangents)

        to_vmap_over_extra_batched_dims_flat, out_tree2 = flatten_fun_nokwargs(
            lu.wrap_init(to_vmap_over_extra_batched_dims), tree_ps_ts)

        flat_out_ps_ts, flat_out_axes = vmap_unrestricted(
            to_vmap_over_extra_batched_dims_flat,
            *flat_ps_ts,
            in_axes=flat_extra_batched_ps_ts,
            axis_name=core.no_axis_name,
            axis_size=axis_size)

        n, ragged = divmod(len(flat_out_ps_ts), 2)
        assert not ragged
        flat_out_ps, flat_out_ts = flat_out_ps_ts[:n], flat_out_ps_ts[n:]
        flat_out_axes_p, flat_out_axes_t = flat_out_axes[:n], flat_out_axes[n:]
        flat_out_ps = map(maybe_bdim_at_front, flat_out_ps, flat_out_axes_p)
        flat_out_extra_batched_ps = [
            d is not not_mapped for d in flat_out_axes_p
        ]
        flat_out_ts = map(maybe_bdim_at_front, flat_out_ts, flat_out_axes_t)
        flat_out_extra_batched_ts = [
            d is not not_mapped for d in flat_out_axes_t
        ]

        out_ps, out_ts = tree_unflatten(out_tree2(),
                                        [*flat_out_ps, *flat_out_ts])
        out_extra_batched_ps, out_extra_batched_ts = tree_unflatten(
            out_tree2(),
            [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts])

        out_batched_ps = tree_map(operator.or_, out_mutually_batched.val,
                                  out_extra_batched_ps)
        out_batched_ts = tree_map(operator.or_, out_mutually_batched.val,
                                  out_extra_batched_ts)

        return (out_ps, out_ts), (out_batched_ps, out_batched_ts)

    tangents = map(ad.instantiate_zeros, tangents)
    jvp_call, _ = ad.jvp_jaxpr(call, [True] * len(primals), True)
    jvp_in_tree = treedef_tuple((in_tree, in_tree))
    jvp_out_tree = treedef_tuple((out_tree, out_tree))
    outs = custom_vmap_p.bind(*primals,
                              *tangents,
                              call=jvp_call,
                              rule=jvp_of_rule_rule,
                              in_tree=jvp_in_tree,
                              out_tree=jvp_out_tree)
    assert len(outs) % 2 == 0, len(outs)
    out_primals, out_tangents = util.split_list(outs, [len(outs) // 2])
    return out_primals, out_tangents
Example #19
def jvp_traceable(*primals_and_tangents):
  n = len(primals_and_tangents)
  primals, tangents = split_list(primals_and_tangents, [n // 2])
  primals_out, tangents_out = yield (primals, tangents), {}
  yield (*primals_out, *tangents_out)
Example #20
File: ad.py Project: jbampton/jax
def _perm(primal_counts, tangent_counts, lst):
  n = sum(primal_counts)
  primals, tangents = lst[:n], lst[n:]
  primal_groups = split_list(primals, primal_counts[:-1])
  tangent_groups = split_list(tangents, tangent_counts[:-1])
  return _interleave(primal_groups, tangent_groups)
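`_interleave`, defined alongside `_perm`, splices the primal and tangent groups back together in alternating order. An illustrative standalone version (an assumption about its definition, based on how `_perm` uses it):

def _interleave(xs, ys):
    # [x_group_0, y_group_0, x_group_1, y_group_1, ...], flattened.
    assert len(xs) == len(ys)
    return [e for pair in zip(xs, ys) for group in pair for e in group]

assert _interleave([[1, 2], [3]], [[10, 20], [30]]) == [1, 2, 10, 20, 3, 30]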
Example #21
File: ad.py Project: jbampton/jax
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals):
  res, _ = split_list(invals, [num_res])
  cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
  cts_in = bwd.call_wrapped(*res, *cts_out)
  return [None] * num_res + list(cts_in)
Example #22
def _split_linear_solve_args(args, const_lengths):
    params_list = split_list(args, list(const_lengths))
    return _LinearSolveTuple(*params_list[:-1]), params_list[-1]
Example #23
 def new_body(*vals):
   out = body_fun(*vals)
   out_carry, y = split_list(out, [num_carry])
   return out_carry, y
Example #24
def _cond_partial_eval(trace, *tracers, branches, linear):
    in_unknowns = [t.pval[0] is not None for t in tracers]
    index_uk, *ops_uk = in_unknowns

    if index_uk:
        # When the branch index is unknown, we stage out the whole cond.
        # TODO(mattjj): remove this path when old remat is removed
        params = dict(branches=branches, linear=linear)
        return trace.default_process_primitive(cond_p, tracers, params)

    branches_out_uks = []
    for branch_jaxpr in branches:
        _, _, out_uks, _ = pe.partial_eval_jaxpr_nounits(branch_jaxpr,
                                                         ops_uk,
                                                         instantiate=False)
        branches_out_uks.append(out_uks)
    out_uks = [any(uks) for uks in zip(*branches_out_uks)]

    branches_known, branches_unknown, branch_res_avals = [], [], []
    for branch_jaxpr in branches:
        branch_jaxpr_known, branch_jaxpr_unknown, _, res_avals = \
            pe.partial_eval_jaxpr_nounits(branch_jaxpr, ops_uk, instantiate=out_uks)
        branches_known.append(branch_jaxpr_known)
        branches_unknown.append(branch_jaxpr_unknown)
        branch_res_avals.append(res_avals)

    all_res_avals, res_avals_per_branch = _merge_branch_residuals(
        branch_res_avals)
    num_res = len(all_res_avals)

    num_known_outs = len(out_uks) - sum(out_uks)
    branches_known = _join_cond_outputs(branches_known, all_res_avals,
                                        res_avals_per_branch, num_known_outs)
    branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
        branches_unknown, all_res_avals, res_avals_per_branch)
    assert all(
        all(_map(core.typematch, j.out_avals, branches_known[0].out_avals))
        for j in branches_known[1:])

    in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()]
    linear_known = [l for l, uk in zip(linear, ops_uk) if not uk]
    out_consts_res = cond_p.bind(*in_consts,
                                 branches=branches_known,
                                 linear=tuple(linear_known))
    out_consts, res = split_list(out_consts_res,
                                 [len(out_consts_res) - num_res])

    index_tracer = trace.instantiate_const(tracers[0])
    ops_tracers = [
        trace.instantiate_const(t)
        for uk, t in zip(in_unknowns[1:], tracers[1:]) if uk
    ]
    res_tracers = _map(trace.new_instantiated_const, res)
    out_tracers = [
        pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
        for aval in branches_unknown[0].out_avals
    ]
    linear_unknown = ([False] * num_res +
                      [l for l, uk in zip(linear, in_unknowns[1:]) if uk])
    params = dict(branches=branches_unknown, linear=tuple(linear_unknown))
    name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
    source = source_info_util.current().replace(name_stack=name_stack)
    eqn = pe.new_eqn_recipe([index_tracer] + res_tracers + ops_tracers,
                            out_tracers, cond_p, params, core.no_effects,
                            source)
    for t in out_tracers:
        t.recipe = eqn
    return util.merge_lists(out_uks, out_consts, out_tracers)
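`util.merge_lists(which, xs, ys)` on the last line is the inverse of a partition: it rebuilds a single list, drawing the next element from `ys` where the flag is True and from `xs` where it is False. A standalone sketch of the behavior:

def merge_lists(which, xs, ys):
    # Select from ys for True flags and xs for False, consuming each in order.
    xs_iter, ys_iter = iter(xs), iter(ys)
    return [next(ys_iter) if w else next(xs_iter) for w in which]

assert merge_lists([False, True, False], ['a', 'b'], ['B']) == ['a', 'B', 'b']

Here the known constant outputs fill the False slots and the staged-out unknown tracers fill the True slots.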
Example #25
def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims,
                                const_lengths, jaxprs):
    orig_bat = [d is not batching.not_mapped for d in dims]

    params, b = _split_linear_solve_args(args, const_lengths)
    params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
    params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

    (matvec, vecmat, solve, solve_t) = jaxprs
    (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

    num_aux = len(solve.out_avals) - len(matvec.out_avals)
    # Fixpoint computation of which parts of x and b are batched; we need to
    # ensure this is consistent between all four jaxprs
    b_bat = orig_b_bat
    x_bat = [False] * len(solve.out_avals)
    for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
        # Apply vecmat and solve -> new batched parts of x
        solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
            solve,
            axis_size,
            solve_bat + b_bat,
            instantiate=x_bat,
            axis_name=axis_name,
            main_type=main_type)
        if vecmat is None:
            vecmat_jaxpr_batched = None
            x_bat_out = solve_x_bat
        else:
            vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
                vecmat,
                axis_size,
                vecmat_bat + b_bat,
                instantiate=x_bat,
                axis_name=axis_name,
                main_type=main_type)
            # batch all aux data by default
            x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux,
                             solve_x_bat)

        # Apply matvec and solve_t -> new batched parts of b
        matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
            matvec,
            axis_size,
            matvec_bat + x_bat_out,
            instantiate=b_bat,
            axis_name=axis_name,
            main_type=main_type)
        if solve_t is None:
            solve_t_jaxpr_batched = None
            b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
        else:
            solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
                solve_t,
                axis_size,
                solve_t_bat + x_bat_out,
                instantiate=b_bat,
                axis_name=axis_name,
                main_type=main_type)
            assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
            solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
            b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat,
                             solve_t_b_bat, orig_b_bat)
        if x_bat_out == x_bat and b_bat_out == b_bat:
            break
        else:
            x_bat = x_bat_out
            b_bat = b_bat_out
    else:
        assert False, "Fixedpoint not reached"

    batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched,
                                       vecmat_jaxpr_batched,
                                       solve_jaxpr_batched,
                                       solve_t_jaxpr_batched)

    # Move batched axes to the front
    new_params = [
        batching.moveaxis(x, d, 0)
        if d is not batching.not_mapped and d != 0 else x
        for x, d in zip(_flatten(params), _flatten(params_dims))
    ]
    # Broadcast out b if necessary
    new_b = [
        batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
        batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
        for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
    ]

    outs = linear_solve_p.bind(*(new_params + new_b),
                               const_lengths=const_lengths,
                               jaxprs=batched_jaxprs)
    out_dims = [
        0 if batched else batching.not_mapped for batched in solve_x_bat
    ]
    return outs, out_dims
Example #26
def _split_root_args(args, const_lengths):
    params_list = split_list(args, list(const_lengths))
    return _RootTuple(*params_list[:-1]), params_list[-1]
Example #27
def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotangents_in):
  if all(type(ct) is ad.Zero for ct in cotangents_in):
    return map(lambda v: ad.Zero(v.aval), jaxpr.invars)

  def write_cotangent(v, ct):
    # assert v not in primal_env
    if ct is not None and type(v) is not Literal:
      ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct

  def read_cotangent(v):
    return ct_env.get(v, ad.Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, ad.UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if type(v) is Literal:
      return
    primal_env.setdefault(v, val)

  # Invert while computing cotangents
  ct_env: Dict[Any, Any] = {}
  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.invars, primals_in)
  map(write_primal, jaxpr.outvars, primals_out)
  map(write_primal, jaxpr.constvars, consts)
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.invars)
    primals_out = map(read_primal, eqn.outvars)
    cts_in = map(read_cotangent, eqn.outvars)
    should_invert = any(type(primal) is not ad.UndefinedPrimal
                        for primal in primals_out)
    should_vjp = any(type(ct) is not ad.Zero for ct in cts_in)
    assert not eqn.primitive.call_primitive

    # Skip primal equations that are only jvp coefficients and don't affect
    # primal outputs.
    if not should_invert and not should_vjp:
      continue

    def abstract(value):
      return raise_to_shaped(value.aval if ad.is_undefined_primal(value) else get_aval(value))

    # Get the ivjp_jaxpr
    if eqn.primitive is custom_ivjp_p:
      ivjp_jaxpr = eqn.params['ivjp_jaxpr']
    else:
      if eqn.primitive in primitive_ivjps:
        complete_ivjp = lu.wrap_init(primitive_ivjps[eqn.primitive])
      else:
        complete_ivjp = lu.wrap_init(partial(synthesize_ivjp, eqn, map(ad.is_undefined_primal, primals_in)))
      _, in_tree = tree_flatten(
          tuple(map(abstract, x) for x in (primals_in, primals_out, primals_out)))
      complete_ivjp_flat, _ = flatten_fun_nokwargs(complete_ivjp, in_tree)

      in_avals = map(abstract, primals_in + primals_out + primals_out)
      # TODO: Actually we do know some of the inputs, because they might be literals!
      ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
          complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
      assert not ivjp_jaxpr.constvars  # That might happen some time, but don't bother until then
      ivjp_jaxpr = core.ClosedJaxpr(ivjp_jaxpr, [])

    # Once we know what the ivjp can do exactly, we have to isolate the part we are
    # actually able to compute with the values we have at hand.
    num_inputs = len(eqn.invars)
    unknowns = (map(ad.is_undefined_primal, primals_in) +
                map(ad.is_undefined_primal, primals_out) +
                [False] * len(cts_in))
    jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(  # type: ignore
        ivjp_jaxpr, unknowns, instantiate=False)  # type:ignore
    unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
    # Make sure we're able to compute all cotangents. We don't really care
    # whether we can reconstruct the primals or not, although failure to do so
    # might result in failing to compute cotangents later.
    assert not any(unknown_cotangents)
    # Remove residual outputs -- we won't be computing the unknown jaxpr anyway.
    num_outputs = len(jaxpr_unknown.jaxpr.outvars)
    jaxpr_known.jaxpr.outvars = jaxpr_known.jaxpr.outvars[:num_outputs]
    # TODO: We could drop the outputs that correspond to primals that we already know.
    #       This only matters in eager mode, so leaving it out for now...
    ivjp = core.jaxpr_as_fun(jaxpr_known)
    rec_primals_in, cts_out = split_list(ivjp(*primals_in, *primals_out, *cts_in),
                                         [num_inputs])
    # Unknown rec_primals_in are core.units, so we have to replace them with
    # the prior UndefinedPrimals, which is what write_primal will ignore.
    rec_primals_in = [prev if unknown else rec
                      for prev, rec, unknown
                      in zip(primals_in, rec_primals_in, unknown_rec_primals_in)]
    map(write_primal, eqn.invars, rec_primals_in)
    map(write_cotangent, [v for v in eqn.invars if type(v) is not Literal], cts_out)

  # NOTE: We keep the cotangents associated with primal variables, while the contract of a
  #       transpose is to return them in positions associated with tangent variables, which
  #       is what causes this whole confusion.
  return map(read_cotangent, jaxpr.invars)
Example #28
 def new_body(carry, x):
     flat_args = tree_leaves((carry, x))
     out = body_fun(*(const_vals + flat_args))
     out_carry, y = split_list(out, [num_carry])
     return out_carry, y
Example #29
 def partitioner(bufs):
   dim_bufs, *grouped_bufs = split_list(bufs, split_sizes)
   dims_dict = dict(it.chain(
       zip(in_dim_binders, in_dim_vals),
       zip(out_dimvars, map(dim_handler, out_dimvars, dim_bufs))))
   return dims_dict, grouped_bufs
Example #30
 def _hoist(i, *consts_args):
     const_refs, args = split_list(consts_args, [num_consts])
     # We immediately read the const values out of the `Ref`s.
     consts = [r[()] for r in const_refs]
     return core.eval_jaxpr(jaxpr, consts, i, *args)
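One detail worth noting in this last example: `r[()]` indexes a `Ref` with the empty tuple, which selects the entire underlying value, mirroring NumPy's empty-tuple indexing:

import numpy as np

x = np.arange(4.0)
assert (x[()] == x).all()  # empty-tuple indexing yields the whole array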