示例#1
0
def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
                        false_nconsts):
  """Batching rule for ``cond_p`` with separate true/false branch jaxprs.

  Args:
    args: flat operands, split below as
      ``[pred, *true_consts, *true_ops, *false_consts, *false_ops]``.
    dims: for each entry of ``args``, its batch axis or
      ``batching.not_mapped``.
    true_jaxpr: jaxpr of the true branch.
    false_jaxpr: jaxpr of the false branch.
    true_nconsts: number of leading true-branch operands that are consts.
    false_nconsts: number of leading false-branch operands that are consts.

  Returns:
    A pair ``(outs, out_dims)``. When the predicate is batched, both branches
    are evaluated and combined with ``_cond_pred_bcast_select`` and every
    output is batched along axis 0. Otherwise ``cond_p`` is re-bound with the
    batched branch jaxprs and an output is batched (axis 0) only if either
    branch produces it batched.
  """
  # The batch size must be read from the operands *before* their mapped axes
  # are moved: `dims` describes the original layout, so indexing a moved
  # operand's shape with its original `d` would read a non-batch dimension.
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}

  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  (pred,), true_consts, true_ops, false_consts, false_ops = split_list(
      args, [1, true_nconsts, true_nops, false_nconsts])
  orig_bat = [d is not batching.not_mapped for d in dims]
  (pred_bat,), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
    orig_bat, [1, true_nconsts, true_nops, false_nconsts])

  # First pass: discover which outputs each branch would batch on its own.
  _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, False)
  _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, False)
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]

  # Second pass: force both branches to agree on the joined output batching.
  true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, out_bat)

  if pred_bat:
    # A batched predicate cannot drive a single cond, so run both branches and
    # select between their (fully broadcast) results.
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(*(false_consts + false_ops))
    true_out = [batching.broadcast(x, size, 0) if not b else x
                for x, b in zip(true_out, out_bat)]
    false_out = [batching.broadcast(x, size, 0) if not b else x
                 for x, b in zip(false_out, out_bat)]
    return [_cond_pred_bcast_select(pred, t, f)
            for t, f in zip(true_out, false_out)], [0] * len(true_out)
  else:
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    return cond_p.bind(
      *itertools.chain([pred], true_consts, true_ops, false_consts, false_ops),
      true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
      true_nconsts=len(true_consts), false_nconsts=len(false_consts)), out_dims
示例#2
0
def _while_loop_batching_rule(args, dims, cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  """Batching rule for ``while_p``.

  The set of batched carry entries is found by iterating to a fixed point:
  the body is re-batched with the current guess, and every carry entry is
  additionally marked batched whenever the predicate itself is batched.

  Args:
    args: flat operands, split below as ``[*cond_consts, *body_consts, *init]``.
    dims: for each entry of ``args``, its batch axis or
      ``batching.not_mapped``.
    cond_nconsts: number of cond consts.
    cond_jaxpr: jaxpr of the loop condition.
    body_nconsts: number of body consts.
    body_jaxpr: jaxpr of the loop body.

  Returns:
    ``(outs, out_bdims)`` where each carry output is batched along axis 0 or
    reported as ``batching.not_mapped``.

  Raises:
    FixedPointError: if the batching pattern has not stabilised after 1000
      iterations.
  """
  mapped_sizes = {x.shape[d] for x, d in zip(args, dims)
                  if d is not batching.not_mapped}
  size, = mapped_sizes
  was_mapped = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(
      was_mapped, [cond_nconsts, body_nconsts])

  carry_bat = init_bat
  for _ in range(1000):
    body_jaxpr_batched, body_out_bat = batching.batch_jaxpr(
        body_jaxpr, size, bconst_bat + carry_bat, instantiate=carry_bat)
    cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
        cond_jaxpr, size, cconst_bat + carry_bat, instantiate=False)
    # A batched predicate forces every carry entry to be batched.
    next_carry_bat = _map(partial(operator.or_, pred_bat), body_out_bat)
    if next_carry_bat == carry_bat:
      break
    carry_bat = next_carry_bat
  else:
    raise FixedPointError

  n_consts = cond_nconsts + body_nconsts
  consts, init = split_list(args, [n_consts])
  const_dims, init_dims = split_list(dims, [n_consts])

  # Consts: bring any mapped axis to the front.
  new_consts = []
  for x, d in zip(consts, const_dims):
    if d is not batching.not_mapped and d != 0:
      x = batching.moveaxis(x, d, 0)
    new_consts.append(x)

  # Initial carry: broadcast entries that only became batched through the
  # fixed point, move the axis of entries that were batched from the start.
  new_init = []
  for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat):
    if now_bat and not was_bat:
      x = batching.broadcast(x, size, 0)
    elif now_bat:
      x = batching.moveaxis(x, d, 0)
    new_init.append(x)

  outs = while_p.bind(*new_consts, *new_init,
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  out_bdims = [0 if flag else batching.not_mapped for flag in carry_bat]
  return outs, out_bdims
示例#3
0
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches,
                        linear):
    """Batching rule for ``cond_p`` with an index and a list of branches.

    Args:
        axis_size: size of the mapped axis.
        axis_name: forwarded to ``batching.batch_jaxpr``.
        main_type: forwarded to ``batching.batch_jaxpr``.
        args: ``[index, *ops]``.
        dims: batch axis (or ``batching.not_mapped``) for each entry of
            ``args``.
        branches: the branch jaxprs.
        linear: forwarded unchanged to ``cond_p.bind`` when the index is not
            batched.

    Returns:
        ``(out, out_dims)``. With an unbatched index, ``cond_p`` is re-bound
        on batched branches. With a batched index, every branch is evaluated
        and the results are combined with ``_bcast_select_n``; all outputs
        are then batched along axis 0.
    """
    index, *ops = args
    index_dim, *op_dims = dims

    if index_dim is batching.not_mapped:
        # Unbatched index: keep a real cond and only batch the branches.
        ops_bat = [d is not batching.not_mapped for d in op_dims]
        ops = [
            batching.moveaxis(operand, d, 0) if mapped else operand
            for mapped, operand, d in zip(ops_bat, ops, op_dims)
        ]

        # An output is batched if any branch produces it batched.
        per_branch_out_bat = [
            batching.batch_jaxpr(branch, axis_size, ops_bat, False, axis_name,
                                 main_type)[1] for branch in branches
        ]
        out_bat = [any(flags) for flags in zip(*per_branch_out_bat)]
        branches_batched = tuple(
            batching.batch_jaxpr(branch, axis_size, ops_bat, out_bat,
                                 axis_name, main_type)[0]
            for branch in branches)

        out_dims = [0 if flag else batching.not_mapped for flag in out_bat]
        out = cond_p.bind(index,
                          *ops,
                          branches=branches_batched,
                          linear=linear)
        return out, out_dims

    # Batched index: convert to a select. While we could get away with not
    # broadcasting some operands yet, all outputs must be broadcast together
    # anyway for the select, so we broadcast the input operands for simplicity
    # and leave optimizations to XLA.
    # TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
    fronted = [
        batching.bdim_at_front(x, d, axis_size) for x, d in zip(args, dims)
    ]
    index, ops = fronted[0], fronted[1:]

    in_batched = [True] * len(branches[0].in_avals)
    out_batched = [True] * len(branches[0].out_avals)

    branches_batched = [
        batching.batch_jaxpr(branch, axis_size, in_batched, out_batched,
                             axis_name, main_type)[0] for branch in branches
    ]

    branch_outs = []
    for branch_idx, branch in enumerate(branches_batched):
        # Perform a select on the inputs for safety of reverse-mode autodiff;
        # see https://github.com/google/jax/issues/1052
        is_selected = lax.eq(index, lax._const(index, branch_idx))
        guarded_ops = [
            _bcast_select(is_selected, x, lax.stop_gradient(x)) for x in ops
        ]
        branch_outs.append(core.jaxpr_as_fun(branch)(*guarded_ops))

    out = [
        _bcast_select_n(index, *per_output) for per_output in zip(*branch_outs)
    ]
    return out, [0 if flag else None for flag in out_batched]
示例#4
0
 def batched_jvp_jaxpr_thunk():
   """Build the batched JVP jaxpr and record its output batch dims.

   Reads ``jvp_jaxpr_thunk``, ``in_batched``, ``num_consts``, ``num_out``,
   ``size``, ``axis_name`` and ``main_type`` from the enclosing scope and
   appends the computed output dims to the enclosing ``out_dims2`` list.

   Returns:
     The ``(jaxpr, consts)`` pair of the batched JVP jaxpr.
   """
   closed_jvp = core.ClosedJaxpr(*jvp_jaxpr_thunk())  # consts can be tracers
   _, operand_flags = split_list(in_batched, [num_consts])
   # The JVP jaxpr takes primals followed by tangents, hence the doubling.
   doubled_in = operand_flags * 2
   _, flat_out_flags = batching.batch_jaxpr(closed_jvp, size, doubled_in,
                                            False, axis_name, main_type)
   primal_flags, tangent_flags = split_list(flat_out_flags, [num_out])
   # An output is treated as batched if its primal or its tangent is.
   out_batched = map(op.or_, primal_flags, tangent_flags)
   out_dims2.append([0 if flag else not_mapped for flag in out_batched])
   batched_jvp_jaxpr, _ = batching.batch_jaxpr(
       closed_jvp, size, doubled_in, out_batched * 2,
       axis_name, main_type)
   return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts
示例#5
0
 def batched_fwd_jaxpr_thunk():
     """Build the batched forward-pass jaxpr and record its output batch dims.

     Reads ``fwd_jaxpr_thunk``, ``axis_size``, ``args_batched``, ``axis_name``
     and ``main_type`` from the enclosing scope and appends the computed
     output dims to the enclosing ``out_dims2`` list.

     Returns:
         The ``(jaxpr, consts)`` pair of the batched forward jaxpr.
     """
     closed_fwd = core.ClosedJaxpr(*fwd_jaxpr_thunk())  # consts can be tracers
     vmapped_fwd, fwd_out_flags = batching.batch_jaxpr(
         closed_fwd, axis_size, args_batched, False, axis_name, main_type)
     fwd_out_dims = [0 if flag else not_mapped for flag in fwd_out_flags]
     out_dims2.append(fwd_out_dims)
     return vmapped_fwd.jaxpr, vmapped_fwd.consts
示例#6
0
def _custom_vjp_call_jaxpr_vmap(
    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
    fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
  """Batching rule for ``custom_vjp_call_jaxpr_p``.

  Args:
    args: flat operands; the first ``num_consts`` of them are consts.
    in_dims: batch axis (or ``not_mapped``) for each entry of ``args``.
    axis_name: forwarded to the batching helpers.
    main_type: forwarded to the batching helpers.
    fun_jaxpr: closed jaxpr of the primal function.
    fwd_jaxpr_thunk: thunk producing the forward-pass ``(jaxpr, consts)``.
    bwd: backward-pass function.
    out_trees: forwarded unchanged to the primitive.
    num_consts: number of leading const operands.

  Returns:
    ``(batched_outs, out_dims)``. The output dims come from the forward-pass
    thunk if it was invoked during the bind, otherwise from batching
    ``fun_jaxpr``.
  """
  mapped_sizes = {x.shape[d] for x, d in zip(args, in_dims)
                  if d is not not_mapped}
  axis_size, = mapped_sizes

  # Bring every mapped axis to the front.
  fronted_args = []
  for x, d in zip(args, in_dims):
    if d is not not_mapped and d != 0:
      x = batching.moveaxis(x, d, 0)
    fronted_args.append(x)

  in_batched = [d is not not_mapped for d in in_dims]
  _, args_batched = split_list(in_batched, [num_consts])
  batched_fun_jaxpr, fun_out_flags = batching.batch_jaxpr(
      fun_jaxpr, axis_size, in_batched, False, axis_name, main_type)
  out_dims1 = [0 if flag else not_mapped for flag in fun_out_flags]
  out_dims2 = []  # mutable cell filled in by batched_fwd_jaxpr_thunk

  @pe._memoize
  def batched_fwd_jaxpr_thunk():
    closed_fwd = core.ClosedJaxpr(*fwd_jaxpr_thunk())  # consts can be tracers
    vmapped_fwd, fwd_out_flags = batching.batch_jaxpr(
        closed_fwd, axis_size, args_batched, False, axis_name, main_type)
    out_dims2.append([0 if flag else not_mapped for flag in fwd_out_flags])
    return vmapped_fwd.jaxpr, vmapped_fwd.consts

  fwd_args_batched = [0 if flag else not_mapped for flag in args_batched]
  # Looked up lazily: out_dims2 is only populated once the thunk has run.
  fwd_out_dims = lambda: out_dims2[0]
  batched_bwd = batching.batch_custom_vjp_bwd(
      bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type)

  batched_outs = custom_vjp_call_jaxpr_p.bind(
      *fronted_args, fun_jaxpr=batched_fun_jaxpr,
      fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
      out_trees=out_trees, num_consts=num_consts)

  if out_dims2:
    out_dims = out_dims2[0]
  else:
    out_dims = out_dims1
  return batched_outs, out_dims
示例#7
0
def _custom_jvp_call_jaxpr_vmap(
    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
    jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    num_consts: int):
  """Batching rule for ``custom_jvp_call_jaxpr_p``.

  Args:
    args: flat operands; the first ``num_consts`` of them are consts.
    in_dims: batch axis (or ``not_mapped``) for each entry of ``args``.
    axis_name: forwarded to ``batching.batch_jaxpr``.
    main_type: forwarded to ``batching.batch_jaxpr``.
    fun_jaxpr: closed jaxpr of the primal function.
    jvp_jaxpr_thunk: thunk producing the JVP ``(jaxpr, consts)``.
    num_consts: number of leading const operands.

  Returns:
    ``(batched_outs, out_dims)``. The output dims come from the JVP thunk if
    it was invoked during the bind, otherwise from batching ``fun_jaxpr``.
  """
  mapped_sizes = {x.shape[d] for x, d in zip(args, in_dims)
                  if d is not not_mapped}
  size, = mapped_sizes

  # Bring every mapped axis to the front.
  fronted_args = []
  for x, d in zip(args, in_dims):
    if d is not not_mapped and d != 0:
      x = batching.moveaxis(x, d, 0)
    fronted_args.append(x)
  num_out = len(fun_jaxpr.out_avals)

  in_batched = [d is not not_mapped for d in in_dims]
  batched_fun_jaxpr, fun_out_flags = batching.batch_jaxpr(
      fun_jaxpr, size, in_batched, False, axis_name, main_type)
  out_dims1 = [0 if flag else not_mapped for flag in fun_out_flags]
  out_dims2 = []  # mutable cell updated by batched_jvp_jaxpr_thunk

  @pe._memoize
  def batched_jvp_jaxpr_thunk():
    closed_jvp = core.ClosedJaxpr(*jvp_jaxpr_thunk())  # consts can be tracers
    _, operand_flags = split_list(in_batched, [num_consts])
    # The JVP jaxpr takes primals followed by tangents, hence the doubling.
    doubled_in = operand_flags * 2
    _, flat_out_flags = batching.batch_jaxpr(closed_jvp, size, doubled_in,
                                             False, axis_name, main_type)
    primal_flags, tangent_flags = split_list(flat_out_flags, [num_out])
    # An output is treated as batched if its primal or its tangent is.
    out_batched = map(op.or_, primal_flags, tangent_flags)
    out_dims2.append([0 if flag else not_mapped for flag in out_batched])
    batched_jvp_jaxpr, _ = batching.batch_jaxpr(
        closed_jvp, size, doubled_in, out_batched * 2,
        axis_name, main_type)
    return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts

  batched_outs = custom_jvp_call_jaxpr_p.bind(
      *fronted_args, fun_jaxpr=batched_fun_jaxpr,
      jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk, num_consts=num_consts)

  if out_dims2:
    out_dims = out_dims2[0]
  else:
    out_dims = out_dims1
  return batched_outs, out_dims
示例#8
0
def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
                        num_carry, linear):
    """Batching rule for ``scan_p`` operating on flat argument lists.

    The batched-ness of the carry is found by iterating to a fixed point:
    the body is re-batched with the current guess until the carry outputs are
    batched exactly where the carry inputs are.

    Args:
        args: flat operands, split below as ``[*consts, *init, *xs]``.
        dims: batch axis (or ``batching.not_mapped``) for each entry of
            ``args``.
        forward: forwarded unchanged to ``scan_p.bind``.
        length: forwarded unchanged to ``scan_p.bind``.
        jaxpr: jaxpr of the scan body.
        num_consts: number of const operands.
        num_carry: number of carry operands.
        linear: forwarded unchanged to ``scan_p.bind``.

    Returns:
        ``(outs, out_bdims)``: carry outputs are batched along axis 0 and
        ``ys`` outputs along axis 1 (or ``batching.not_mapped``).

    Raises:
        FixedPointError: if the carry batching has not stabilised after 1000
            iterations.
    """
    num_ys = len(jaxpr.out_avals) - num_carry
    mapped_sizes = {
        x.shape[d]
        for x, d in zip(args, dims) if d is not batching.not_mapped
    }
    size, = mapped_sizes
    was_mapped = [d is not batching.not_mapped for d in dims]
    const_batched, init_batched, xs_batched = split_list(
        was_mapped, [num_consts, num_carry])

    carry_batched = init_batched
    for _ in range(1000):
        jaxpr_batched, batched_out = batching.batch_jaxpr(
            jaxpr,
            size,
            const_batched + carry_batched + xs_batched,
            instantiate=carry_batched + [False] * num_ys)
        carry_batched_out = batched_out[:num_carry]
        ys_batched = batched_out[num_carry:]
        if carry_batched_out == carry_batched:
            break
        carry_batched = carry_batched_out
    else:
        raise FixedPointError

    consts, init, xs = split_list(args, [num_consts, num_carry])
    consts_bdims, init_bdims, xs_bdims = split_list(dims,
                                                    [num_consts, num_carry])

    # Consts: bring any mapped axis to the front.
    new_consts = []
    for x, d in zip(consts, consts_bdims):
        if d is not batching.not_mapped and d != 0:
            x = batching.moveaxis(x, d, 0)
        new_consts.append(x)

    # Initial carry: broadcast entries that only became batched through the
    # fixed point, move the axis of entries that were batched from the start.
    new_init = []
    for x, d, was_batched, now_batched in zip(init, init_bdims, init_batched,
                                              carry_batched):
        if now_batched and not was_batched:
            x = batching.broadcast(x, size, 0)
        elif now_batched:
            x = batching.moveaxis(x, d, 0)
        new_init.append(x)

    # Scanned inputs: the mapped axis goes to position 1.
    new_xs = []
    for x, d in zip(xs, xs_bdims):
        if d is not batching.not_mapped and d != 1:
            x = batching.moveaxis(x, d, 1)
        new_xs.append(x)

    outs = scan_p.bind(*new_consts,
                       *new_init,
                       *new_xs,
                       forward=forward,
                       length=length,
                       jaxpr=jaxpr_batched,
                       num_consts=num_consts,
                       num_carry=num_carry,
                       linear=linear)
    carry_bdims = [0 if flag else batching.not_mapped for flag in carry_batched]
    ys_bdims = [1 if flag else batching.not_mapped for flag in ys_batched]
    return outs, carry_bdims + ys_bdims
示例#9
0
def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr, **params):
  """Batching rule for ``remat_p``.

  Args:
    axis_size: size of the mapped axis.
    axis_name: forwarded to ``batching.batch_jaxpr``.
    main_type: forwarded to ``batching.batch_jaxpr``.
    args: operands of the remat call.
    dims: batch axis (or ``batching.not_mapped``) for each entry of ``args``.
    jaxpr: the jaxpr to batch; must have no ``constvars``.
    **params: remaining primitive parameters, forwarded to ``remat_p.bind``.

  Returns:
    ``(outs, out_dims)`` where each output is batched along axis 0 or
    reported as ``None``.
  """
  assert not jaxpr.constvars
  in_batched = [d is not batching.not_mapped for d in dims]
  # batch_jaxpr is only told *whether* each input is batched, not on which
  # axis, and the outputs below are reported at axis 0. As in the other
  # batching rules in this file, move every mapped axis to the front so the
  # operands actually bound match that layout.
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  jaxpr_ = core.ClosedJaxpr(jaxpr, ())
  jaxpr_batched_, out_batched = batching.batch_jaxpr(
      jaxpr_, axis_size, in_batched, instantiate=False, axis_name=axis_name,
      main_type=main_type)
  jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts
  out_dims = [0 if b else None for b in out_batched]
  return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
示例#10
0
def _scan_batching_rule(batched_args, batch_dims, forward, length, jaxpr):
    """Batching rule for ``scan_p`` operating on ``(consts, init, xs)``.

    The batched-ness of the carry is found by iterating to a fixed point:
    the body is re-batched with the current guess, and the guess is joined
    with the resulting carry-output batching until the two agree according
    to ``_binary_lattice_eq``.

    Args:
        batched_args: the triple ``(consts, init, xs)``.
        batch_dims: the matching triple of batch dims.
        forward: forwarded unchanged to ``scan_p.bind``.
        length: forwarded unchanged to ``scan_p.bind``.
        jaxpr: jaxpr of the scan body.

    Returns:
        ``(core.pack((carry_out, ys)), (carry_out_bdim, ys_bdim))``, with the
        carry dims built for axis 0 and the ``ys`` dims for axis 1.

    Raises:
        FixedPointError: if the carry batching has not stabilised after 1000
            iterations.
    """
    consts, init, xs = batched_args
    consts_bdim, init_bdim, xs_bdim = batch_dims

    # All mapped operands must agree on a single axis size.
    axis_sizes = lax._reduce(set.union,
                             map(batching.dimsize, batch_dims, batched_args))
    size = axis_sizes.pop()
    assert not axis_sizes

    consts_flags = batching.where_batched(consts_bdim)
    init_flags = batching.where_batched(init_bdim)
    xs_flags = batching.where_batched(xs_bdim)

    carry_flags = init_flags
    for _ in range(1000):
        jaxpr_batched, (carry_flags_out, ys_flags) = batching.batch_jaxpr(
            jaxpr,
            size, (consts_flags, carry_flags, xs_flags),
            instantiate=(carry_flags, False))
        if _binary_lattice_eq(carry_flags_out, carry_flags):
            break
        carry_flags = _binary_lattice_join(carry_flags_out, carry_flags)
    else:
        raise FixedPointError

    # Consts and carry use axis 0; scanned inputs use axis 1.
    consts_in = batching.instantiate_bdim(size, 0, consts_flags, consts_bdim,
                                          consts)
    init_in = batching.instantiate_bdim(size, 0, carry_flags, init_bdim, init)
    xs_in = batching.instantiate_bdim(size, 1, xs_flags, xs_bdim, xs)

    carry_out, ys = scan_p.bind(consts_in,
                                init_in,
                                xs_in,
                                forward=forward,
                                length=length,
                                jaxpr=jaxpr_batched)

    carry_out_bdim = batching.bools_to_bdims(0, carry_flags)
    ys_bdim = batching.bools_to_bdims(1, ys_flags)
    return core.pack((carry_out, ys)), (carry_out_bdim, ys_bdim)
示例#11
0
文件: solves.py 项目: xueeinstein/jax
def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims,
                                const_lengths, jaxprs):
    """Batching rule for ``linear_solve_p``.

    Batches the four jaxprs (``matvec``, ``vecmat``, ``solve``, ``solve_t``)
    consistently by iterating to a fixed point on which entries of ``x`` (the
    outputs of ``solve``) and of ``b`` are batched.

    Args:
        axis_size: size of the mapped axis.
        axis_name: forwarded to ``batching.batch_jaxpr``.
        main_type: forwarded to ``batching.batch_jaxpr``.
        args: flat operands; divided into per-jaxpr params and ``b`` by
            ``_split_linear_solve_args``.
        dims: batch axis (or ``batching.not_mapped``) for each entry of
            ``args``.
        const_lengths: passed to ``_split_linear_solve_args`` and forwarded
            unchanged to ``linear_solve_p.bind``.
        jaxprs: the four jaxprs ``(matvec, vecmat, solve, solve_t)``;
            ``vecmat`` and ``solve_t`` may be ``None``.

    Returns:
        ``(outs, out_dims)`` where each output is batched along axis 0 or
        reported as ``batching.not_mapped``.

    Raises:
        AssertionError: if the fixed-point loop exhausts its iterations
            without converging.
    """
    orig_bat = [d is not batching.not_mapped for d in dims]

    # Split operands, their dims and their batched flags the same way.
    params, b = _split_linear_solve_args(args, const_lengths)
    params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
    params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

    (matvec, vecmat, solve, solve_t) = jaxprs
    (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

    # Outputs that solve produces beyond those of matvec are treated as aux.
    num_aux = len(solve.out_avals) - len(matvec.out_avals)
    # Fixpoint computation of which parts of x and b are batched; we need to
    # ensure this is consistent between all four jaxprs
    b_bat = orig_b_bat
    x_bat = [False] * len(solve.out_avals)
    # NOTE(review): the iteration bound presumably relies on each
    # non-converged pass turning at least one x/b flag from False to True,
    # plus one confirming pass -- confirm.
    for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
        # Apply vecmat and solve -> new batched parts of x
        solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
            solve,
            axis_size,
            solve_bat + b_bat,
            instantiate=x_bat,
            axis_name=axis_name,
            main_type=main_type)
        if vecmat is None:
            vecmat_jaxpr_batched = None
            x_bat_out = solve_x_bat
        else:
            vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
                vecmat,
                axis_size,
                vecmat_bat + b_bat,
                instantiate=x_bat,
                axis_name=axis_name,
                main_type=main_type)
            # batch all aux data by default
            x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux,
                             solve_x_bat)

        # Apply matvec and solve_t -> new batched parts of b
        matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
            matvec,
            axis_size,
            matvec_bat + x_bat_out,
            instantiate=b_bat,
            axis_name=axis_name,
            main_type=main_type)
        if solve_t is None:
            solve_t_jaxpr_batched = None
            # Never un-batch an entry of b that arrived batched.
            b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
        else:
            solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
                solve_t,
                axis_size,
                solve_t_bat + x_bat_out,
                instantiate=b_bat,
                axis_name=axis_name,
                main_type=main_type)
            assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
            # Drop the trailing aux flags; only the b part feeds the join.
            solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
            b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat,
                             solve_t_b_bat, orig_b_bat)
        # Converged once neither x nor b batching changed in this pass.
        if x_bat_out == x_bat and b_bat_out == b_bat:
            break
        else:
            x_bat = x_bat_out
            b_bat = b_bat_out
    else:
        # Reached only when the loop ran out of iterations without `break`.
        assert False, "Fixedpoint not reached"

    # The batched jaxprs are the ones built in the final (converged) pass.
    batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched,
                                       vecmat_jaxpr_batched,
                                       solve_jaxpr_batched,
                                       solve_t_jaxpr_batched)

    # Move batched axes to the front
    new_params = [
        batching.moveaxis(x, d, 0)
        if d is not batching.not_mapped and d != 0 else x
        for x, d in zip(_flatten(params), _flatten(params_dims))
    ]
    # Broadcast out b if necessary
    new_b = [
        batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
        batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
        for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
    ]

    outs = linear_solve_p.bind(*(new_params + new_b),
                               const_lengths=const_lengths,
                               jaxprs=batched_jaxprs)
    # Output dims follow solve's batched-output flags from the final pass.
    out_dims = [
        0 if batched else batching.not_mapped for batched in solve_x_bat
    ]
    return outs, out_dims