Example #1
def triangular_solve_batching_rule(batched_args, batch_dims, left_side, lower,
                                   transpose_a, conjugate_a, unit_diagonal):
    x, y = batched_args
    bx, by = batch_dims
    if bx is batching.not_mapped:
        if left_side:
            y = batching.moveaxis(y, by, -1)
            y_flat = y.reshape(y.shape[:-2] + (y.shape[-2] * y.shape[-1], ))
            bdim_out = y.ndim - 1
        else:
            y = batching.moveaxis(y, by, -2)
            y_flat = y.reshape(y.shape[:-3] +
                               (y.shape[-3] * y.shape[-2], y.shape[-1]))
            bdim_out = y.ndim - 2
        out_flat = triangular_solve(x,
                                    y_flat,
                                    left_side=left_side,
                                    lower=lower,
                                    transpose_a=transpose_a,
                                    conjugate_a=conjugate_a,
                                    unit_diagonal=unit_diagonal)
        return out_flat.reshape(y.shape), bdim_out
    else:
        size = next(t.shape[i] for t, i in zip(batched_args, batch_dims)
                    if i is not None)
        x = batching.bdim_at_front(x, bx, size)
        y = batching.bdim_at_front(y, by, size)
        return triangular_solve(x,
                                y,
                                left_side=left_side,
                                lower=lower,
                                transpose_a=transpose_a,
                                conjugate_a=conjugate_a,
                                unit_diagonal=unit_diagonal), 0
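A minimal usage sketch, not part of the original listing: vmapping the public jax.scipy.linalg.solve_triangular wrapper over a stack of systems is one way a rule like the one above gets exercised (the array shapes here are illustrative).

import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# A batch of 4 lower-triangular 3x3 systems A @ x = b, solved in one vmapped call.
A = jnp.tril(jnp.ones((4, 3, 3))) + 2.0 * jnp.eye(3)
b = jnp.ones((4, 3))
x = jax.vmap(lambda a, rhs: solve_triangular(a, rhs, lower=True))(A, b)
print(x.shape)  # (4, 3)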
Example #2
def _while_loop_batching_rule(args, dims, cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])

  # Fixed point: the body can turn an unbatched carry element into a batched
  # one, so iterate until the batched-ness of the carry stabilizes.
  carry_bat = init_bat
  for _ in range(1000):
    batched = bconst_bat + carry_bat
    body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, size, batched, instantiate=carry_bat)
    cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
        cond_jaxpr, size, cconst_bat + carry_bat, instantiate=False)
    carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
    if carry_bat_out == carry_bat:
      break
    else:
      carry_bat = carry_bat_out
  else:
    raise FixedPointError

  consts, init = split_list(args, [cond_nconsts + body_nconsts])
  const_dims, init_dims = split_list(dims, [cond_nconsts + body_nconsts])
  new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
                else x for x, d in zip(consts, const_dims)]
  new_init = [batching.broadcast(x, size, 0) if now_bat and not was_bat
              else batching.moveaxis(x, d, 0) if now_bat else x
              for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat)]

  outs = while_p.bind(*(new_consts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
  return outs, out_bdims
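A small usage sketch, not from the original source: vmapping a lax.while_loop whose trip count depends on the batched carry is the kind of call that reaches a batching rule like this one.

import jax
import jax.numpy as jnp
from jax import lax

def halvings(n):
    # Count how many times n can be halved before reaching zero.
    return lax.while_loop(lambda c: c[0] > 0,
                          lambda c: (c[0] // 2, c[1] + 1),
                          (n, 0))[1]

print(jax.vmap(halvings)(jnp.array([1, 8, 0])))  # [1 4 0]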
Example #3
def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
                        num_carry, linear):
    num_ys = len(jaxpr.out_avals) - num_carry
    size, = {
        x.shape[d]
        for x, d in zip(args, dims) if d is not batching.not_mapped
    }
    orig_batched = [d is not batching.not_mapped for d in dims]
    const_batched, init_batched, xs_batched = split_list(
        orig_batched, [num_consts, num_carry])

    # Fixed point: iterate until the set of batched carry components stabilizes.
    carry_batched = init_batched
    for _ in range(1000):
        batched = const_batched + carry_batched + xs_batched
        jaxpr_batched, batched_out = batching.batch_jaxpr(
            jaxpr, size, batched, instantiate=carry_batched + [False] * num_ys)
        carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[
            num_carry:]
        if carry_batched_out == carry_batched:
            break
        else:
            carry_batched = carry_batched_out
    else:
        raise FixedPointError

    consts, init, xs = split_list(args, [num_consts, num_carry])
    consts_bdims, init_bdims, xs_bdims = split_list(dims,
                                                    [num_consts, num_carry])
    new_consts = [
        batching.moveaxis(x, d, 0)
        if d is not batching.not_mapped and d != 0 else x
        for x, d in zip(consts, consts_bdims)
    ]
    new_init = [
        batching.broadcast(x, size, 0) if now_batched and not was_batched else
        batching.moveaxis(x, d, 0) if now_batched else x
        for x, d, was_batched, now_batched in zip(init, init_bdims,
                                                  init_batched, carry_batched)
    ]
    new_xs = [
        batching.moveaxis(x, d, 1)
        if d is not batching.not_mapped and d != 1 else x
        for x, d in zip(xs, xs_bdims)
    ]
    new_args = new_consts + new_init + new_xs

    outs = scan_p.bind(*new_args,
                       forward=forward,
                       length=length,
                       jaxpr=jaxpr_batched,
                       num_consts=num_consts,
                       num_carry=num_carry,
                       linear=linear)
    carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
    ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
    return outs, carry_bdims + ys_bdims
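A usage sketch, not from the original: a lax.scan-based cumulative sum under vmap is a typical way this rule is reached.

import jax
import jax.numpy as jnp
from jax import lax

def cumsum(xs):
    # Carry the running total and also emit it at every step.
    def step(total, x):
        total = total + x
        return total, total
    _, ys = lax.scan(step, 0.0, xs)
    return ys

batch = jnp.arange(6.0).reshape(2, 3)
print(jax.vmap(cumsum)(batch))  # [[0. 1. 3.] [3. 7. 12.]]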
Example #4
def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,
                                            permutation_size):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return lu_pivots_to_permutation_p.bind(
      x, permutation_size=permutation_size), 0
Example #5
def _custom_vjp_call_jaxpr_vmap(
    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
    fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
  axis_size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
          else x for x, d in zip(args, in_dims)]

  in_batched = [d is not not_mapped for d in in_dims]
  _, args_batched = split_list(in_batched, [num_consts])
  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
      fun_jaxpr, axis_size, in_batched, False, axis_name, main_type)
  out_dims1 = [0 if b else not_mapped for b in out_batched]
  out_dims2 = []  # mutable cell updated by batched_fwd_jaxpr_thunk

  @pe._memoize
  def batched_fwd_jaxpr_thunk():
    fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk())  # consts can be tracers
    batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
        fwd_jaxpr, axis_size, args_batched, False, axis_name, main_type)
    out_dims2.append([0 if b else not_mapped for b in out_batched])
    return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts

  fwd_args_batched = [0 if b else not_mapped for b in args_batched]
  fwd_out_dims = lambda: out_dims2[0]
  batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims,
                                              fwd_args_batched, main_type)

  batched_outs = custom_vjp_call_jaxpr_p.bind(
      *args, fun_jaxpr=batched_fun_jaxpr,
      fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
      out_trees=out_trees, num_consts=num_consts)
  out_dims = out_dims2[0] if out_dims2 else out_dims1
  return batched_outs, out_dims
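A hedged usage sketch, not from the original: a jax.custom_vjp function under jax.vmap (and under vmap-of-grad) is handled by a rule of this shape; the clip_gradient function below is purely illustrative.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(x):
    return x  # identity on the forward pass

def clip_gradient_fwd(x):
    return x, None

def clip_gradient_bwd(_, g):
    return (jnp.clip(g, -1.0, 1.0),)  # clip the incoming cotangent

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

print(jax.vmap(clip_gradient)(jnp.arange(3.0)))                                # [0. 1. 2.]
print(jax.vmap(jax.grad(lambda x: 3.0 * clip_gradient(x)))(jnp.arange(3.0)))   # [1. 1. 1.]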
Example #6
def _custom_jvp_call_jaxpr_vmap(
    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
    jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    num_consts: int):
  size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
          else x for x, d in zip(args, in_dims)]
  num_out = len(fun_jaxpr.out_avals)

  in_batched = [d is not not_mapped for d in in_dims]
  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
      fun_jaxpr, size, in_batched, False, axis_name, main_type)
  out_dims1 = [0 if b else not_mapped for b in out_batched]
  out_dims2 = []  # mutable cell updated by batched_jvp_jaxpr_thunk

  @pe._memoize
  def batched_jvp_jaxpr_thunk():
    jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk())  # consts can be tracers
    _, args_batched = split_list(in_batched, [num_consts])
    _, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2, False,
                                          axis_name, main_type)
    primals_batched, tangents_batched = split_list(all_batched, [num_out])
    out_batched = map(op.or_, primals_batched, tangents_batched)
    out_dims2.append([0 if b else not_mapped for b in out_batched])
    batched_jvp_jaxpr, _ = batching.batch_jaxpr(
        jvp_jaxpr, size, args_batched * 2, out_batched * 2,
        axis_name, main_type)
    return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts

  batched_outs = custom_jvp_call_jaxpr_p.bind(
      *args, fun_jaxpr=batched_fun_jaxpr,
      jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk, num_consts=num_consts)
  out_dims = out_dims2[0] if out_dims2 else out_dims1
  return batched_outs, out_dims
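A usage sketch, not from the original: the log1pexp custom_jvp example from the JAX documentation, once vmapped, goes through a rule like the one above for both the primal and the JVP.

import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    return log1pexp(x), jax.nn.sigmoid(x) * x_dot  # numerically stable tangent

print(jax.vmap(log1pexp)(jnp.arange(3.0)))
print(jax.vmap(jax.grad(log1pexp))(jnp.arange(3.0)))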
Example #7
def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
                        false_nconsts):
  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  (pred,), true_consts, true_ops, false_consts, false_ops = split_list(
      args, [1, true_nconsts, true_nops, false_nconsts])
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  orig_bat = [d is not batching.not_mapped for d in dims]
  (pred_bat,), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
    orig_bat, [1, true_nconsts, true_nops, false_nconsts])

  _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, False)
  _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, False)
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]

  true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, out_bat)

  if pred_bat:
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(*(false_consts + false_ops))
    true_out = [batching.broadcast(x, size, 0) if not b else x
                for x, b in zip(true_out, out_bat)]
    false_out = [batching.broadcast(x, size, 0) if not b else x
                 for x, b in zip(false_out, out_bat)]
    return [_cond_pred_bcast_select(pred, t, f)
            for t, f in zip(true_out, false_out)], [0] * len(true_out)
  else:
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    return cond_p.bind(
      *itertools.chain([pred], true_consts, true_ops, false_consts, false_ops),
      true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
      true_nconsts=len(true_consts), false_nconsts=len(false_consts)), out_dims
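A usage sketch, not from the original: under vmap the predicate of lax.cond becomes batched, so a rule like this falls back to evaluating both branches and selecting between them, as in the pred_bat case above.

import jax
import jax.numpy as jnp
from jax import lax

def abs_via_cond(x):
    return lax.cond(x >= 0, lambda v: v, lambda v: -v, x)

print(jax.vmap(abs_via_cond)(jnp.array([-2.0, 0.5, -1.0])))  # [2.  0.5 1. ]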
Example #8
def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)

  return (eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
                     compute_right_eigenvectors=compute_right_eigenvectors),
          (0,) * (1 + compute_left_eigenvectors + compute_right_eigenvectors))
Example #9
def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
    x, = batched_args
    bd, = batch_dims
    x = batching.moveaxis(x, bd, 0)
    outs = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)

    if compute_uv:
        return outs, (0, 0, 0)
    else:
        return outs, (0, )
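A usage sketch, not from the original: this rule and the decomposition rules in the later examples (eig, qr, lu, eigh, cholesky, fft) all follow the same pattern of moving the batch dimension to the front and re-binding the primitive; vmapping jnp.linalg.svd is one way to hit it.

import jax
import jax.numpy as jnp

mats = jnp.stack([jnp.eye(3) * (i + 1.0) for i in range(4)])   # (4, 3, 3)
u, s, vt = jax.vmap(jnp.linalg.svd)(mats)
print(u.shape, s.shape, vt.shape)  # (4, 3, 3) (4, 3) (4, 3, 3)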
Example #10
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches,
                        linear):
    index, *ops = args
    index_dim, *op_dims = dims

    if index_dim is not batching.not_mapped:
        # Convert to a lax.select. While we could get away with not broadcasting
        # some operands yet, because all outputs must be broadcast together anyway
        # for the select we broadcast the input operands for simplicity and leave
        # optimizations to XLA.
        # TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
        index, *ops = (batching.bdim_at_front(x, d, axis_size)
                       for x, d in zip(args, dims))

        in_batched = [True] * len(branches[0].in_avals)
        out_batched = [True] * len(branches[0].out_avals)

        branches_batched = [
            batching.batch_jaxpr(jaxpr, axis_size, in_batched, out_batched,
                                 axis_name, main_type)[0] for jaxpr in branches
        ]

        branch_outs = []
        for i, jaxpr in enumerate(branches_batched):
            # Perform a select on the inputs for safety of reverse-mode autodiff; see
            # https://github.com/google/jax/issues/1052
            predicate = lax.eq(index, lax._const(index, i))
            ops_ = [
                _bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops
            ]
            branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))
        out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)]
        return out, [0 if b else None for b in out_batched]
    else:
        ops_bat = [d is not batching.not_mapped for d in op_dims]
        ops = [
            batching.moveaxis(x, d, 0) if b else x
            for b, x, d in zip(ops_bat, ops, op_dims)
        ]

        branches_out_bat = [
            batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name,
                                 main_type)[1] for jaxpr in branches
        ]
        out_bat = [any(bat) for bat in zip(*branches_out_bat)]
        branches_batched = tuple(
            batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name,
                                 main_type)[0] for jaxpr in branches)

        out_dims = [0 if b else batching.not_mapped for b in out_bat]
        out = cond_p.bind(index,
                          *ops,
                          branches=branches_batched,
                          linear=linear)
        return out, out_dims
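A usage sketch, not from the original: lax.switch dispatches over the branches of cond_p, so vmapping it with a batched index takes the select path above, while an unbatched index takes the cond_p.bind path.

import jax
import jax.numpy as jnp
from jax import lax

branches = (lambda x: x + 1.0, lambda x: x * 2.0, lambda x: -x)

def apply_branch(i, x):
    return lax.switch(i, branches, x)

idx = jnp.array([0, 2, 1])
xs = jnp.array([1.0, 3.0, 5.0])
print(jax.vmap(apply_branch)(idx, xs))                    # [ 2. -3. 10.]
print(jax.vmap(apply_branch, in_axes=(None, 0))(1, xs))   # [ 2.  6. 10.]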
Example #11
def dex_call_batched(batched_args, batched_dims, func_atom):
    """Batching function for dex primitives.

  Args:
    batched_args: The possibly-batched arguments.
    batched_dims: A sequence of the same length as `batched_args`, where each
      entry indicates the batching axis of the corresponding entry to `args`,
      or None if that argument should not be batched. Not all entries can be
      None.

  Returns:
    2-tuple containing the result of the batched function, and the result axis
    which was batched, which is always zero.
  """
    module = func_atom.module.copy()

    # Move axes so that we only have to deal with the zero axis being batched.
    uniform_batched_args = [
        batching.moveaxis(arg, bd, 0) if bd is not batching.not_mapped else arg
        for arg, bd in zip(batched_args, batched_dims)
    ]

    # This assumes not all entries in batched_dims are None.
    batch_size = next(arg.shape[0]
                      for arg, bd in zip(uniform_batched_args, batched_dims)
                      if bd is not batching.not_mapped)

    # Add the current function atom as a variable in the context, so that we can
    # use it to apply batching.
    func_name = func_atom.name
    assert func_name is not None

    # Only index into the arguments which are batched. `i` is the index used for
    # the Dex for loop constructor.
    batched_fn_params = [
        f"x{param_idx}" if dim is batching.not_mapped else f"x{param_idx}.i"
        for param_idx, dim in enumerate(batched_dims)
    ]

    # This is the actual batching expression
    batched_fn = module.eval(r"\ " +
                             " ".join(f"x{i}"
                                      for i in range(len(batched_args))) +
                             ". " + f"for i:(Fin {batch_size}). {func_name} " +
                             " ".join(batched_fn_params))

    return primitive(batched_fn)(*uniform_batched_args), 0
Example #12
File: ann.py  Project: 0x0is1/jax
def _approx_top_k_batch_rule(batched_args, batch_dims, *, k,
                             reduction_dimension, recall_target, is_max_k,
                             reduction_input_size_override, aggregate_to_topk):
  prototype_arg, new_bdim = next(
      (a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
  new_args = []
  for arg, bdim in zip(batched_args, batch_dims):
    if bdim is None:
      dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
      new_args.append(lax.broadcast_in_dim(arg, prototype_arg.shape, dims))
    else:
      new_args.append(batching.moveaxis(arg, bdim, new_bdim))
  new_reduction_dim = reduction_dimension + (new_bdim <= reduction_dimension)
  bdims = (new_bdim,) * len(new_args)
  return (approx_top_k_p.bind(
      *new_args,
      k=k,
      reduction_dimension=new_reduction_dim,
      recall_target=recall_target,
      is_max_k=is_max_k,
      reduction_input_size_override=reduction_input_size_override,
      aggregate_to_topk=aggregate_to_topk), bdims)
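A usage sketch, not from the original: jax.lax.approx_max_k is the public entry point that binds approx_top_k_p, so vmapping it reaches a rule like this one.

import jax
import jax.numpy as jnp

scores = jnp.arange(24.0).reshape(3, 8)
vals, idxs = jax.vmap(lambda s: jax.lax.approx_max_k(s, k=2))(scores)
print(vals.shape, idxs.shape)  # (3, 2) (3, 2)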
Example #13
def qr_batching_rule(batched_args, batch_dims, full_matrices):
    x, = batched_args
    bd, = batch_dims
    x = batching.moveaxis(x, bd, 0)
    return qr_p.bind(x, full_matrices=full_matrices), (0, 0)
Example #14
def _lu_batching_rule(batched_args, batch_dims):
    x, = batched_args
    bd, = batch_dims
    x = batching.moveaxis(x, bd, 0)
    return lu_p.bind(x), (0, 0)
Example #15
def eigh_batching_rule(batched_args, batch_dims, lower):
    x, = batched_args
    bd, = batch_dims
    x = batching.moveaxis(x, bd, 0)
    return eigh_p.bind(x, lower=lower), (0, 0)
Example #16
def cholesky_batching_rule(batched_args, batch_dims):
    x, = batched_args
    bd, = batch_dims
    x = batching.moveaxis(x, bd, 0)
    return cholesky(x), 0
Example #17
def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return fft(x, fft_type, fft_lengths), 0
Example #18
def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims,
                                const_lengths, jaxprs):
    orig_bat = [d is not batching.not_mapped for d in dims]

    params, b = _split_linear_solve_args(args, const_lengths)
    params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
    params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

    (matvec, vecmat, solve, solve_t) = jaxprs
    (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

    num_aux = len(solve.out_avals) - len(matvec.out_avals)
    # Fixpoint computation of which parts of x and b are batched; we need to
    # ensure this is consistent between all four jaxprs
    b_bat = orig_b_bat
    x_bat = [False] * len(solve.out_avals)
    for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
        # Apply vecmat and solve -> new batched parts of x
        solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
            solve,
            axis_size,
            solve_bat + b_bat,
            instantiate=x_bat,
            axis_name=axis_name,
            main_type=main_type)
        if vecmat is None:
            vecmat_jaxpr_batched = None
            x_bat_out = solve_x_bat
        else:
            vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
                vecmat,
                axis_size,
                vecmat_bat + b_bat,
                instantiate=x_bat,
                axis_name=axis_name,
                main_type=main_type)
            # batch all aux data by default
            x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux,
                             solve_x_bat)

        # Apply matvec and solve_t -> new batched parts of b
        matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
            matvec,
            axis_size,
            matvec_bat + x_bat_out,
            instantiate=b_bat,
            axis_name=axis_name,
            main_type=main_type)
        if solve_t is None:
            solve_t_jaxpr_batched = None
            b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
        else:
            solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
                solve_t,
                axis_size,
                solve_t_bat + x_bat_out,
                instantiate=b_bat,
                axis_name=axis_name,
                main_type=main_type)
            assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
            solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
            b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat,
                             solve_t_b_bat, orig_b_bat)
        if x_bat_out == x_bat and b_bat_out == b_bat:
            break
        else:
            x_bat = x_bat_out
            b_bat = b_bat_out
    else:
        assert False, "Fixedpoint not reached"

    batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched,
                                       vecmat_jaxpr_batched,
                                       solve_jaxpr_batched,
                                       solve_t_jaxpr_batched)

    # Move batched axes to the front
    new_params = [
        batching.moveaxis(x, d, 0)
        if d is not batching.not_mapped and d != 0 else x
        for x, d in zip(_flatten(params), _flatten(params_dims))
    ]
    # Broadcast out b if necessary
    new_b = [
        batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
        batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
        for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
    ]

    outs = linear_solve_p.bind(*(new_params + new_b),
                               const_lengths=const_lengths,
                               jaxprs=batched_jaxprs)
    out_dims = [
        0 if batched else batching.not_mapped for batched in solve_x_bat
    ]
    return outs, out_dims
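A usage sketch, not from the original: lax.custom_linear_solve is the user-facing constructor for linear_solve_p, so vmapping it over a batch of systems is one way this rule is exercised. The explicit dense solver below is purely illustrative.

import jax
import jax.numpy as jnp
from jax import lax

def solve_spd(a, rhs):
    matvec = lambda x: a @ x
    solve = lambda _, b: jnp.linalg.solve(a, b)   # ignore the matvec closure, solve densely
    return lax.custom_linear_solve(matvec, rhs, solve, symmetric=True)

a = jnp.stack([jnp.eye(2) * (i + 1.0) for i in range(3)])   # (3, 2, 2)
rhs = jnp.ones((3, 2))
print(jax.vmap(solve_spd)(a, rhs))   # rows: [1, 1], [0.5, 0.5], [0.333..., 0.333...]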