Example 1: _cond_batching_rule, the vmap batching rule for JAX's cond primitive. (All four examples are excerpts from JAX's internal control-flow module and rely on its module-level helpers such as batching, split_list, and the control-flow primitives, so they are not standalone scripts.)
def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
                        false_nconsts):
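  # Batching (vmap) rule for the cond primitive: move every batched axis to the
  # front, batch both branch jaxprs so they agree on which outputs carry a batch
  # dimension, then either rebind cond_p (unbatched predicate) or run both
  # branches and select (batched predicate).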
  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  (pred,), true_consts, true_ops, false_consts, false_ops = split_list(
      args, [1, true_nconsts, true_nops, false_nconsts])
  orig_bat = [d is not batching.not_mapped for d in dims]
  (pred_bat,), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
    orig_bat, [1, true_nconsts, true_nops, false_nconsts])

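  # Batch each branch jaxpr to see which outputs come out batched, take the
  # union, then re-batch both branches so their outputs agree on batch dims.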
  _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, False)
  _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, False)
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]

  true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, out_bat)

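  # A batched predicate cannot drive a single cond; instead evaluate both
  # batched branches and select between their outputs per batch element.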
  if pred_bat:
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(*(false_consts + false_ops))
    true_out = [batching.broadcast(x, size, 0) if not b else x
                for x, b in zip(true_out, out_bat)]
    false_out = [batching.broadcast(x, size, 0) if not b else x
                 for x, b in zip(false_out, out_bat)]
    return [_cond_pred_bcast_select(pred, t, f)
            for t, f in zip(true_out, false_out)], [0] * len(true_out)
  else:
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    return cond_p.bind(
      *itertools.chain([pred], true_consts, true_ops, false_consts, false_ops),
      true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
      true_nconsts=len(true_consts), false_nconsts=len(false_consts)), out_dims
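For orientation, a minimal usage sketch of what exercises this rule from user code: jax.vmap over lax.cond with a batched predicate takes the select path above. The sketch assumes the present-day public lax.cond signature (pred, true_fun, false_fun, operand), which may differ from the internal multi-operand form handled in this snapshot.

import jax
import jax.numpy as jnp
from jax import lax

def branchy(pred, x):
  # Per-example choice between two branches of lax.cond.
  return lax.cond(pred, lambda v: v + 1.0, lambda v: v - 1.0, x)

preds = jnp.array([True, False, True])
xs = jnp.arange(3.0)
# vmapping over pred makes the predicate batched, so the rule falls back to
# evaluating both branches and selecting elementwise.
print(jax.vmap(branchy)(preds, xs))   # [1. 0. 3.]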
Example 2: _while_loop_batching_rule, the vmap batching rule for the while primitive behind lax.while_loop.
def _while_loop_batching_rule(args, dims, cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
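  # Batching (vmap) rule for the while primitive: find a fixed point of which
  # carry values need a batch dimension, batch the cond/body jaxprs accordingly,
  # and rebind while_p on inputs with their batch axes moved to the front.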
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])

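  # Fixpoint: a carry position is batched if its initial value is batched, if
  # one body iteration makes it batched, or if the predicate itself is batched
  # (a batched predicate forces every carry to be selected per batch element).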
  carry_bat = init_bat
  for _ in range(1000):
    batched = bconst_bat + carry_bat
    body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, size, batched, instantiate=carry_bat)
    cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
        cond_jaxpr, size, cconst_bat + carry_bat, instantiate=False)
    carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
    if carry_bat_out == carry_bat:
      break
    else:
      carry_bat = carry_bat_out
  else:
    raise FixedPointError

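  # Move batched axes to the front, broadcasting any carry that only becomes
  # batched through the fixpoint.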
  consts, init = split_list(args, [cond_nconsts + body_nconsts])
  const_dims, init_dims = split_list(dims, [cond_nconsts + body_nconsts])
  new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
                else x for x, d in zip(consts, const_dims)]
  new_init = [batching.broadcast(x, size, 0) if now_bat and not was_bat
              else batching.moveaxis(x, d, 0) if now_bat else x
              for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat)]

  outs = while_p.bind(*(new_consts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
  return outs, out_bdims
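A usage sketch (public jax.vmap / lax.while_loop API only): vmapping over a loop bound makes the predicate batched, which is exactly the case the fixpoint above handles by forcing the carry to be batched.

import jax
import jax.numpy as jnp
from jax import lax

def count_to(limit):
  # Increment a counter until it reaches limit.
  return lax.while_loop(lambda i: i < limit, lambda i: i + 1, 0)

# Each batch element runs a different number of iterations, so the loop
# predicate is batched and the carry is batched with it.
print(jax.vmap(count_to)(jnp.array([3, 5, 7])))   # [3 5 7]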
Example 3: _scan_batching_rule, the vmap batching rule for the scan primitive behind lax.scan.
def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
                        num_carry, linear):
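    # Batching (vmap) rule for the scan primitive: as with while, compute a
    # fixed point of which carry values are batched, then rebind scan_p with
    # batch axes at axis 0 for consts/carries and axis 1 for the scanned xs
    # (axis 0 of the xs is the scan length).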
    num_ys = len(jaxpr.out_avals) - num_carry
    size, = {
        x.shape[d]
        for x, d in zip(args, dims) if d is not batching.not_mapped
    }
    orig_batched = [d is not batching.not_mapped for d in dims]
    const_batched, init_batched, xs_batched = split_list(
        orig_batched, [num_consts, num_carry])

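    # Fixpoint: a carry position is batched if its initial value is batched or
    # if one application of the scanned body makes it batched.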
    carry_batched = init_batched
    for _ in range(1000):
        batched = const_batched + carry_batched + xs_batched
        jaxpr_batched, batched_out = batching.batch_jaxpr(
            jaxpr, size, batched, instantiate=carry_batched + [False] * num_ys)
        carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[
            num_carry:]
        if carry_batched_out == carry_batched:
            break
        else:
            carry_batched = carry_batched_out
    else:
        raise FixedPointError

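    # Move batch axes into position: axis 0 for consts and carries, axis 1 for
    # the scanned xs, broadcasting carries that became batched in the fixpoint.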
    consts, init, xs = split_list(args, [num_consts, num_carry])
    consts_bdims, init_bdims, xs_bdims = split_list(dims,
                                                    [num_consts, num_carry])
    new_consts = [
        batching.moveaxis(x, d, 0)
        if d is not batching.not_mapped and d != 0 else x
        for x, d in zip(consts, consts_bdims)
    ]
    new_init = [
        batching.broadcast(x, size, 0) if now_batched and not was_batched else
        batching.moveaxis(x, d, 0) if now_batched else x
        for x, d, was_batched, now_batched in zip(init, init_bdims,
                                                  init_batched, carry_batched)
    ]
    new_xs = [
        batching.moveaxis(x, d, 1)
        if d is not batching.not_mapped and d != 1 else x
        for x, d in zip(xs, xs_bdims)
    ]
    new_args = new_consts + new_init + new_xs

    outs = scan_p.bind(*new_args,
                       forward=forward,
                       length=length,
                       jaxpr=jaxpr_batched,
                       num_consts=num_consts,
                       num_carry=num_carry,
                       linear=linear)
    carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
    ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
    return outs, carry_bdims + ys_bdims
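A usage sketch (public API only): vmapping a lax.scan produces stacked outputs whose batch dimension sits at axis 1 inside the primitive, since axis 0 is the scan axis; that is why ys_bdims above is 1 for batched outputs.

import jax
import jax.numpy as jnp
from jax import lax

def cumsum(xs):
    # Running sum via scan: the carry is the partial sum, ys stacks it per step.
    def step(carry, x):
        carry = carry + x
        return carry, carry
    _, ys = lax.scan(step, 0.0, xs)
    return ys

print(jax.vmap(cumsum)(jnp.ones((4, 3))))   # shape (4, 3), each row [1. 2. 3.]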
Example 4: _linear_solve_batching_rule, the vmap batching rule for the linear_solve primitive behind lax.custom_linear_solve.
def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims,
                                const_lengths, jaxprs):
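    # Batching (vmap) rule for the linear_solve primitive (the machinery behind
    # lax.custom_linear_solve): find a consistent batching of x and b across the
    # matvec/vecmat/solve/solve_t jaxprs, then rebind the primitive.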
    orig_bat = [d is not batching.not_mapped for d in dims]

    params, b = _split_linear_solve_args(args, const_lengths)
    params_dims, b_dims = _split_linear_solve_args(dims, const_lengths)
    params_bat, orig_b_bat = _split_linear_solve_args(orig_bat, const_lengths)

    (matvec, vecmat, solve, solve_t) = jaxprs
    (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

    num_aux = len(solve.out_avals) - len(matvec.out_avals)
    # Fixpoint computation of which parts of x and b are batched; we need to
    # ensure this is consistent between all four jaxprs
    b_bat = orig_b_bat
    x_bat = [False] * len(solve.out_avals)
    for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
        # Apply vecmat and solve -> new batched parts of x
        solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
            solve,
            axis_size,
            solve_bat + b_bat,
            instantiate=x_bat,
            axis_name=axis_name,
            main_type=main_type)
        if vecmat is None:
            vecmat_jaxpr_batched = None
            x_bat_out = solve_x_bat
        else:
            vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
                vecmat,
                axis_size,
                vecmat_bat + b_bat,
                instantiate=x_bat,
                axis_name=axis_name,
                main_type=main_type)
            # batch all aux data by default
            x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux,
                             solve_x_bat)

        # Apply matvec and solve_t -> new batched parts of b
        matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
            matvec,
            axis_size,
            matvec_bat + x_bat_out,
            instantiate=b_bat,
            axis_name=axis_name,
            main_type=main_type)
        if solve_t is None:
            solve_t_jaxpr_batched = None
            b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
        else:
            solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
                solve_t,
                axis_size,
                solve_t_bat + x_bat_out,
                instantiate=b_bat,
                axis_name=axis_name,
                main_type=main_type)
            assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
            solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
            b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat,
                             solve_t_b_bat, orig_b_bat)
        if x_bat_out == x_bat and b_bat_out == b_bat:
            break
        else:
            x_bat = x_bat_out
            b_bat = b_bat_out
    else:
        assert False, "Fixedpoint not reached"

    batched_jaxprs = _LinearSolveTuple(matvec_jaxpr_batched,
                                       vecmat_jaxpr_batched,
                                       solve_jaxpr_batched,
                                       solve_t_jaxpr_batched)

    # Move batched axes to the front
    new_params = [
        batching.moveaxis(x, d, 0)
        if d is not batching.not_mapped and d != 0 else x
        for x, d in zip(_flatten(params), _flatten(params_dims))
    ]
    # Broadcast out b if necessary
    new_b = [
        batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else
        batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
        for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
    ]

    outs = linear_solve_p.bind(*(new_params + new_b),
                               const_lengths=const_lengths,
                               jaxprs=batched_jaxprs)
    out_dims = [
        0 if batched else batching.not_mapped for batched in solve_x_bat
    ]
    return outs, out_dims
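A usage sketch (assuming the public lax.custom_linear_solve wrapper, which binds this primitive): vmapping over only the right-hand side keeps the solver parameters unbatched, matching the case where only orig_b_bat starts out batched in the fixpoint above.

import jax
import jax.numpy as jnp
from jax import lax

def solve_diag(d, b):
    # Solve diag(d) @ x = b; the custom solve just divides elementwise.
    matvec = lambda x: d * x
    solve = lambda _matvec, rhs: rhs / d
    return lax.custom_linear_solve(matvec, b, solve, transpose_solve=solve)

d = jnp.array([1.0, 2.0, 4.0])
bs = jnp.ones((5, 3))                                  # batch of right-hand sides
print(jax.vmap(solve_diag, in_axes=(None, 0))(d, bs))  # shape (5, 3)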