示例#1
0
def _scan_polymorphic_shape_rule(shape_exprs, forward, length, jaxpr,
                                 num_consts, num_carry, linear):
    """Compute the output shape expressions of a ``scan`` application.

    Carry outputs reuse the shape expressions of the initial carry inputs.
    Every per-iteration output ``y`` gets ``length`` prepended as its leading
    (stacked) dimension.

    Args:
      shape_exprs: one shape expression per operand, ordered consts, init
        carry, xs.
      forward, linear: not used here; accepted so the signature matches the
        primitive's params.
      length: the scan length, used as the leading dim of each stacked output.
      jaxpr: scan body whose ``out_avals`` are the carry outputs followed by
        the ys.
      num_consts, num_carry: sizes of the const and carry operand groups.

    Returns:
      List of the carry shape expressions followed by the stacked ys shapes.
    """
    # Only the carry group of the inputs contributes to the result.
    _, init_shexprs, _ = split_list(shape_exprs, [num_consts, num_carry])
    per_step_avals = jaxpr.out_avals[num_carry:]
    stacked_shapes = [ShapeExpr(length, *aval.shape) for aval in per_step_avals]
    return init_shexprs + stacked_shapes
示例#2
0
def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,
              linear):
    """JVP rule for the ``scan`` primitive.

    Runs a single ``scan_p`` over the primal operands together with the
    tangents that are not symbolic zeros, using a JVP'd version of ``jaxpr``.

    Args:
      primals: scan operands, ordered consts, init carry, xs.
      tangents: tangents matching ``primals``; ``ad_util.zero`` marks a
        symbolic zero.
      forward, length: scan params, forwarded unchanged to ``scan_p.bind``.
      jaxpr: scan body; inputs are consts + carry + xs, outputs carry + ys.
      num_consts, num_carry: sizes of the const and carry operand groups.
      linear: per-operand linearity flags of the original scan.

    Returns:
      ``(primals_out, tangents_out)``, both ordered carry + ys; tangent
      outputs that are zero are returned as ``ad_util.zero``.

    Raises:
      FixedPointError: if the carry non-zero pattern has not stabilized
        after 1000 iterations.
    """
    num_xs = len(jaxpr.in_avals) - num_carry - num_consts
    num_ys = len(jaxpr.out_avals) - num_carry
    # True wherever the incoming tangent is not a symbolic zero.
    nonzeros = [t is not ad_util.zero for t in tangents]
    const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry])

    # Fixed point over the carry: a carry tangent must be treated as non-zero
    # if it is non-zero on input or if the body produces a non-zero tangent
    # for it. Re-run the JVP with the updated pattern until it stops changing;
    # jaxpr_jvp / nonzeros_out from the final iteration match that pattern.
    carry_nz = init_nz
    for _ in range(1000):
        nonzeros = const_nz + carry_nz + xs_nz
        jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(jaxpr,
                                               nonzeros,
                                               instantiate=carry_nz +
                                               [False] * num_ys)
        carry_nz_out, ys_nz = nonzeros_out[:num_carry], nonzeros_out[
            num_carry:]
        if carry_nz_out == carry_nz:
            break
        else:
            carry_nz = carry_nz_out
    else:
        raise FixedPointError
    # Carries promoted to non-zero by the fixed point may still hold a
    # symbolic-zero tangent; materialize those as concrete zeros.
    tangents = [
        ad.instantiate_zeros(x, t) if t is ad_util.zero and nz else t
        for x, t, nz in zip(primals, tangents, nonzeros)
    ]

    consts, init, xs = split_list(primals, [num_consts, num_carry])
    all_tangents = split_list(tangents, [num_consts, num_carry])
    # NOTE(review): _prune_zeros is defined elsewhere; presumably it drops the
    # remaining symbolic-zero tangents from each group -- confirm.
    consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents)

    # Reorder jaxpr_jvp's binders so that its inputs match the operand order
    # passed to scan_p.bind below (consts, consts_dot, init, init_dot, xs,
    # xs_dot), and its outputs come out as carry, carry_dot, ys, ys_dot.
    jaxpr_jvp_rearranged = ad.rearrange_binders(
        jaxpr_jvp, [num_consts, num_carry, num_xs],
        [len(consts_dot), len(init_dot),
         len(xs_dot)], [num_carry, num_ys],
        [len(init_dot), sum(nonzeros_out) - len(init_dot)])

    # Primal operands keep their original linearity; tangent operands are
    # always linear.
    consts_linear, init_linear, xs_linear = split_list(linear,
                                                       [num_consts, num_carry])
    jaxpr_jvp_linear = (consts_linear + [True] * len(consts_dot) +
                        init_linear + [True] * len(init_dot) + xs_linear +
                        [True] * len(xs_dot))

    # Tangents of consts/carry join their primal groups, so both the const
    # and the carry counts grow by the number of tangents in those groups.
    out_flat = scan_p.bind(*(consts + consts_dot + init + init_dot + xs +
                             xs_dot),
                           forward=forward,
                           length=length,
                           jaxpr=jaxpr_jvp_rearranged,
                           num_consts=num_consts + len(consts_dot),
                           num_carry=num_carry + len(init_dot),
                           linear=jaxpr_jvp_linear)

    carry, carry_dot, ys, ys_dot = split_list(
        out_flat, [num_carry, len(init_dot), num_ys])
    primals_out = carry + ys
    # Re-expand the pruned tangent outputs, filling zero positions with
    # ad_util.zero.
    tangents_out = iter(carry_dot + ys_dot)
    tangents_out = [
        next(tangents_out) if nz else ad_util.zero for nz in nonzeros_out
    ]
    return primals_out, tangents_out
示例#3
0
def _scan_sparse(spenv, *spvalues, jaxpr, num_consts, num_carry, **params):
  """Sparsification rule for ``scan_p``.

  Sparsifies the scan body against the const/carry sparse values, flattens
  the sparse values to arrays, expands the per-operand ``linear`` flags to
  match, binds ``lax.scan_p`` on the flattened operands and converts the
  results back to sparse values.

  Raises:
    NotImplementedError: if the scan has any ``xs`` operands.
  """
  const_spvalues, carry_spvalues, xs_spvalues = split_list(
    spvalues, [num_consts, num_carry])
  if xs_spvalues:
    # TODO(jakevdp): we don't want to pass xs_spvalues, we want to pass one row
    # of xs spvalues. How to do this?
    raise NotImplementedError("sparse rule for scan with x values.")
  # Past this point xs_spvalues is empty, so every xs-derived value is empty.
  sp_jaxpr, _ = _sparsify_jaxpr(spenv, jaxpr, *const_spvalues, *carry_spvalues, *xs_spvalues)

  consts, _ = tree_flatten(spvalues_to_arrays(spenv, const_spvalues))
  carry, carry_tree = tree_flatten(spvalues_to_arrays(spenv, carry_spvalues))
  xs, xs_tree = tree_flatten(spvalues_to_arrays(spenv, xs_spvalues))

  # params['linear'] has one entry per arg; expand it to match the sparsified args.
  const_linear, carry_linear, xs_linear = split_list(
    params.pop('linear'), [num_consts, num_carry])
  sp_linear = tuple([
    *_duplicate_for_sparse_spvalues(const_spvalues, const_linear),
    *_duplicate_for_sparse_spvalues(carry_spvalues, carry_linear),
    *_duplicate_for_sparse_spvalues(xs_spvalues, xs_linear)])

  # The const/carry counts are those of the *flattened* arrays, not of the
  # original sparse operands.
  out = lax.scan_p.bind(*consts, *carry, *xs, jaxpr=sp_jaxpr, linear=sp_linear,
                        num_consts=len(consts), num_carry=len(carry), **params)
  # scan outputs are the flat carry followed by the stacked ys.
  carry_out = tree_unflatten(carry_tree, out[:len(carry)])
  # NOTE(review): the trailing outputs are ys, but they are unflattened with
  # xs_tree, which here describes an empty xs list. This looks like it only
  # works when the body has no ys outputs -- confirm. The out tree returned by
  # _sparsify_jaxpr (discarded above) may be the intended structure.
  xs_out = tree_unflatten(xs_tree, out[len(carry):])
  return arrays_to_spvalues(spenv, carry_out + xs_out)
示例#4
0
def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
                        false_nconsts):
  """Batching (vmap) rule for ``cond_p``.

  Args:
    args: operands ordered [pred, *true_consts, *true_ops, *false_consts,
      *false_ops].
    dims: batch axis of each operand, or ``batching.not_mapped``.
    true_jaxpr, false_jaxpr: the branch bodies.
    true_nconsts, false_nconsts: number of leading const operands of each
      branch.

  Returns:
    ``(outs, out_dims)``. If the predicate is batched, both branches are
    evaluated and merged elementwise with ``_cond_pred_bcast_select``, and
    every output is batched along axis 0. Otherwise a single ``cond_p`` is
    bound on the batched branch jaxprs, and each output is batched along
    axis 0 or not mapped.

  Raises:
    ValueError: if the mapped operands disagree on the batch size (or none
      is mapped).
  """
  # Read the batch size *before* moving axes. After ``moveaxis`` the batch
  # axis is 0, so indexing ``x.shape[d]`` with the original ``d`` would read
  # the wrong dimension whenever ``d != 0``.
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  (pred,), true_consts, true_ops, false_consts, false_ops = split_list(
      args, [1, true_nconsts, true_nops, false_nconsts])
  orig_bat = [d is not batching.not_mapped for d in dims]
  (pred_bat,), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
    orig_bat, [1, true_nconsts, true_nops, false_nconsts])

  # First pass: find which outputs each branch batches on its own. The union
  # is then forced on both branches so that they agree on output batching.
  _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, False)
  _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, False)
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]

  true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, out_bat)

  if pred_bat:
    # A batched predicate can pick different branches per example: evaluate
    # both branches, broadcast unbatched outputs, then select elementwise.
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(*(false_consts + false_ops))
    true_out = [batching.broadcast(x, size, 0) if not b else x
                for x, b in zip(true_out, out_bat)]
    false_out = [batching.broadcast(x, size, 0) if not b else x
                 for x, b in zip(false_out, out_bat)]
    return [_cond_pred_bcast_select(pred, t, f)
            for t, f in zip(true_out, false_out)], [0] * len(true_out)
  else:
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    return cond_p.bind(
      *itertools.chain([pred], true_consts, true_ops, false_consts, false_ops),
      true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
      true_nconsts=len(true_consts), false_nconsts=len(false_consts)), out_dims
示例#5
0
def _scan_transpose(cts, *args, **kwargs):
  """Transpose rule for the ``scan`` primitive.

  Only scans whose consts and carries are all linear, and whose xs are a
  block of linear operands followed by nonlinear residuals, are supported.
  The transpose is itself a scan, run in the opposite direction. It carries
  the const and carry cotangents and consumes the ys cotangents plus the
  residuals.

  Args:
    cts: output cotangents, ordered carry + ys (instantiated to concrete
      zeros below where needed).
    *args: scan operands ordered consts, init, linear xs, residual xs.
    **kwargs: scan params forward, length, num_consts, num_carry, jaxpr,
      linear.

  Returns:
    Cotangents for consts, init and the linear xs, followed by ``None`` for
    each residual.

  Raises:
    NotImplementedError: if the linearity pattern is not supported.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])

  # We can only transpose scans for which the nonlinear values appear in xs.
  # Because num_lin counts the linear xs, requiring the first num_lin xs to be
  # linear also forces every remaining xs to be nonlinear (the residuals).
  consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
  num_lin = sum(xs_lin)
  if not all(consts_lin) or not all(init_lin) or not all(xs_lin[:num_lin]):
    raise NotImplementedError(
        "scan transpose requires all consts and carries to be linear and the "
        "linear xs to precede the nonlinear (residual) xs; got "
        f"linear={linear}")

  consts, init, xs, res = split_list(args, [num_consts, num_carry, num_lin])
  assert not any(r is ad.undefined_primal for r in res)

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  ct_carry, ct_ys = split_list(cts, [num_carry])
  ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry)
  ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys)
  # Const cotangents start at zero and are accumulated in the carry.
  ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[:num_consts])

  #       jaxpr :: [T d] -> [T c] -> [T a, res] -> ([T c], [T b])
  # jaxpr_trans :: [] -> [CT d, CT c] -> [CT b, res] -> ([CT d, CT c], [CT a])
  jaxpr_trans = _transpose_jaxpr(num_consts, len(res), jaxpr)
  linear_trans = ([True] * (len(ct_consts) + len(ct_carry) + len(ct_ys))
                  + [False] * len(res))

  outs = scan_p.bind(
      *(ct_consts + ct_carry + ct_ys + res), forward=not forward, length=length,
      jaxpr=jaxpr_trans, num_consts=0, num_carry=num_consts+num_carry,
      linear=linear_trans)
  ct_consts, ct_init, ct_xs = split_list(outs, [num_consts, num_carry])
  return ct_consts + ct_init + ct_xs + [None] * len(res)
示例#6
0
def scan_bind(*args, **kwargs):
    """Type-check a ``scan`` application and bind ``scan_p``.

    The checks are ``assert`` statements, so they disappear under
    ``python -O``. They verify that:
      * there is one ``linear`` flag per operand;
      * consts and init carry type-check against ``jaxpr.in_avals``;
      * the body's carry outputs match its carry inputs;
      * ``core.check_jaxpr`` accepts the body.
    The xs operands are not type-checked.
    """
    forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
        kwargs,
        ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])
    in_consts, in_carry, in_xs = split_list(args, [num_consts, num_carry])
    assert len(linear) == len(args)

    const_avals, carry_in_avals, x_avals = split_list(
        jaxpr.in_avals, [num_consts, num_carry])
    # NOTE(review): the stacked xs avals are computed, but the matching
    # check (typecheck of in_xs against them) is disabled in this version.
    stacked_x_avals = _map(partial(_promote_aval_rank, length), x_avals)
    assert all(_map(typecheck, const_avals, in_consts))
    assert all(_map(typecheck, carry_in_avals, in_carry))

    # Loop invariant: carry out has the same types as carry in.
    carry_out_avals, _ = split_list(jaxpr.out_avals, [num_carry])
    assert all(_map(typematch, carry_in_avals, carry_out_avals))

    # Data-flow sanity of the body itself.
    core.check_jaxpr(jaxpr.jaxpr)

    params = dict(forward=forward, length=length, jaxpr=jaxpr,
                  num_consts=num_consts, num_carry=num_carry, linear=linear)
    return core.Primitive.bind(scan_p, *args, **params)
示例#7
0
def _while_loop_batching_rule(args, dims, cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  """Batching (vmap) rule for ``while_p``.

  Args:
    args: operands ordered [*cond_consts, *body_consts, *init_carry].
    dims: batch axis of each operand, or ``batching.not_mapped``.
    cond_nconsts, cond_jaxpr: const count and jaxpr of the loop predicate.
    body_nconsts, body_jaxpr: const count and jaxpr of the loop body.

  Returns:
    ``(outs, out_bdims)``: the final carry values, each batched along axis 0
    or not mapped.

  Raises:
    FixedPointError: if the carry batching pattern has not stabilized after
      1000 iterations.
  """
  # All mapped operands must share one batch size (single-element set unpack).
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])

  # Fixed point over which carries are batched. A carry is batched if its init
  # is batched or the body batches it. If the predicate is batched, every carry
  # becomes batched (OR with pred_bat), presumably because batch elements may
  # stop at different iterations.
  carry_bat = init_bat
  for _ in range(1000):
    batched = bconst_bat + carry_bat
    body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, size, batched, instantiate=carry_bat)
    cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
        cond_jaxpr, size, cconst_bat + carry_bat, instantiate=False)
    carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
    if carry_bat_out == carry_bat:
      break
    else:
      carry_bat = carry_bat_out
  else:
    raise FixedPointError

  # Consts: move any batch axis to 0. Carries: if the fixed point made a carry
  # batched although its input was not, broadcast it along a new axis 0;
  # otherwise move its existing batch axis to 0.
  consts, init = split_list(args, [cond_nconsts + body_nconsts])
  const_dims, init_dims = split_list(dims, [cond_nconsts + body_nconsts])
  new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
                else x for x, d in zip(consts, const_dims)]
  new_init = [batching.broadcast(x, size, 0) if now_bat and not was_bat
              else batching.moveaxis(x, d, 0) if now_bat else x
              for x, d, was_bat, now_bat in zip(init, init_dims, init_bat, carry_bat)]

  outs = while_p.bind(*(new_consts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  out_bdims = [0 if b else batching.not_mapped for b in carry_bat]
  return outs, out_bdims
示例#8
0
 def body_fun(i, vals):
   # One scan step as a fori_loop body: read row `pos` of every xs, run the
   # scan body, and write its per-step outputs into row `pos` of the ys.
   pos = length - i - 1 if not forward else i  # reversed scans walk backwards
   cur_carry, ys_acc = split_list(vals, [num_carry])
   x_row = _map(partial(_index_array, pos), x_avals, xs)
   step_outs = core.jaxpr_as_fun(jaxpr)(*(consts + cur_carry + x_row))
   new_carry, y_row = split_list(step_outs, [num_carry])
   updated_ys = _map(partial(_update_array, pos), y_avals, ys_acc, y_row)
   return new_carry + updated_ys
示例#9
0
 def masked(*args):
   # Operand layout: [dynamic_length, *consts, i, *carry, *xs], where `i` is
   # the step counter threaded through the carry.
   (dynamic_length,), body_consts, (i,), carry, xs = split_list(
       args, [1, num_consts, 1, num_carry])
   step_out = fun(*(body_consts + carry + xs))
   proposed_carry, ys = split_list(step_out, [num_carry])
   # Steps at or past dynamic_length keep the previous carry (padding steps).
   kept_carry = [lax.select(i < dynamic_length, nxt, prev)
                 for nxt, prev in zip(proposed_carry, carry)]
   return [i + 1] + kept_carry + ys
示例#10
0
def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
    """XLA translation rule for ``while_p``.

    Lowers to one XLA ``While`` whose loop state is a single tuple holding the
    cond consts, the body consts and the carry. Only the carry elements are
    returned, as a tuple.

    When the predicate is batched (``cond_jaxpr`` returns a non-scalar), the
    loop continues while any element of the predicate is true. Each carry
    element is then combined with its previous value via
    ``_pred_bcast_select`` using the per-element predicate (presumably keeping
    old values where the predicate is false -- confirm against that helper).
    """
    cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts = split_dict(
        kwargs, ["cond_jaxpr", "body_jaxpr", "cond_nconsts", "body_nconsts"])
    cond_consts, body_consts, init_vals = split_list(
        args, [cond_nconsts, body_nconsts])
    # A non-scalar predicate output means the loop was batched.
    batched = bool(cond_jaxpr.out_avals[0].shape)

    # Since jaxprs don't have tuples and have multiple return values, but we need
    # the HLO While loop to take a single tuple input and output a single boolean
    # (for the cond computation) or a single tuple output (for the body
    # computation), we build XLA computations that handle the tuple munging before
    # generating a Call into the computations formed from the jaxprs.

    init_carry = c.Tuple(*(cond_consts + body_consts + init_vals))

    # Cond computation: x = cond consts, z = carry; body consts are ignored.
    cond_c = xb.make_computation_builder("cond_computation")
    cond_carry = cond_c.ParameterWithShape(c.GetShape(init_carry))
    cond_carry_elts = [
        cond_c.GetTupleElement(cond_carry, i) for i in range(len(args))
    ]
    x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
    cond_outs = cond_c.Call(
        xla.jaxpr_computation(cond_jaxpr.jaxpr, axis_env, cond_jaxpr.literals,
                              (), *_map(cond_c.GetShape, x + z)), x + z)
    pred = cond_c.GetTupleElement(cond_outs, 0)
    if batched:
        # OR-reduce the batched predicate over all its dims to a scalar.
        scalar = xla_client.Shape.array_shape(onp.dtype(onp.bool_), ())
        or_ = xla.primitive_computation(lax.or_p, scalar, scalar)
        pred = cond_c.Reduce(pred, cond_c.Constant(onp.array(False)), or_,
                             list(range(cond_jaxpr.out_avals[0].ndim)))

    # Body computation: y = body consts, z = carry; consts are passed through.
    body_c = xb.make_computation_builder("body_computation")
    body_carry = body_c.ParameterWithShape(c.GetShape(init_carry))
    body_carry_elts = [
        body_c.GetTupleElement(body_carry, i) for i in range(len(args))
    ]
    x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
    body_out = body_c.Call(
        xla.jaxpr_computation(body_jaxpr.jaxpr, axis_env, body_jaxpr.literals,
                              (), *_map(body_c.GetShape, y + z)), y + z)
    new_z = [
        body_c.GetTupleElement(body_out, i) for i in range(len(init_vals))
    ]
    if batched:
        # Re-evaluate the per-element predicate on the pre-step carry and use
        # it to choose between the new and the old carry values.
        body_cond_outs = body_c.Call(
            xla.jaxpr_computation(cond_jaxpr.jaxpr, axis_env,
                                  cond_jaxpr.literals, (),
                                  *_map(body_c.GetShape, x + z)), x + z)
        body_pred = body_c.GetTupleElement(body_cond_outs, 0)
        new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
        assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape,
                                                    z)  # no broadcast
    new_carry = body_c.Tuple(*(x + y + new_z))

    ans = c.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
    ans_elts = [c.GetTupleElement(ans, i) for i in range(len(args))]
    _, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
    return c.Tuple(*z)
示例#11
0
def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
                        num_carry, linear):
    """Batching (vmap) rule for the ``scan`` primitive.

    Args:
      args: operands ordered consts, init carry, xs.
      dims: batch axis of each operand, or ``batching.not_mapped``.
      forward, length, num_consts, num_carry, linear: scan params, forwarded
        unchanged.
      jaxpr: scan body; outputs are carry + ys.

    Returns:
      ``(outs, out_bdims)``. Carry outputs are batched along axis 0 and ys
      outputs along axis 1 (axis 0 is the scan axis), or not mapped.

    Raises:
      FixedPointError: if the carry batching pattern has not stabilized after
        1000 iterations.
    """
    num_ys = len(jaxpr.out_avals) - num_carry
    # All mapped operands must share one batch size (single-element set unpack).
    size, = {
        x.shape[d]
        for x, d in zip(args, dims) if d is not batching.not_mapped
    }
    orig_batched = [d is not batching.not_mapped for d in dims]
    const_batched, init_batched, xs_batched = split_list(
        orig_batched, [num_consts, num_carry])

    # Fixed point over which carries are batched: a carry is batched if its
    # init is batched or the body produces a batched carry-out for it.
    carry_batched = init_batched
    for _ in range(1000):
        batched = const_batched + carry_batched + xs_batched
        jaxpr_batched, batched_out = batching.batch_jaxpr(
            jaxpr, size, batched, instantiate=carry_batched + [False] * num_ys)
        carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[
            num_carry:]
        if carry_batched_out == carry_batched:
            break
        else:
            carry_batched = carry_batched_out
    else:
        raise FixedPointError

    consts, init, xs = split_list(args, [num_consts, num_carry])
    consts_bdims, init_bdims, xs_bdims = split_list(dims,
                                                    [num_consts, num_carry])
    # Consts: batch axis to 0.
    new_consts = [
        batching.moveaxis(x, d, 0)
        if d is not batching.not_mapped and d != 0 else x
        for x, d in zip(consts, consts_bdims)
    ]
    # Carries: broadcast along a new axis 0 if the fixed point made them
    # batched although the input was not; otherwise move the batch axis to 0.
    new_init = [
        batching.broadcast(x, size, 0) if now_batched and not was_batched else
        batching.moveaxis(x, d, 0) if now_batched else x
        for x, d, was_batched, now_batched in zip(init, init_bdims,
                                                  init_batched, carry_batched)
    ]
    # xs: axis 0 is the scan axis, so the batch axis goes to 1.
    new_xs = [
        batching.moveaxis(x, d, 1)
        if d is not batching.not_mapped and d != 1 else x
        for x, d in zip(xs, xs_bdims)
    ]
    new_args = new_consts + new_init + new_xs

    outs = scan_p.bind(*new_args,
                       forward=forward,
                       length=length,
                       jaxpr=jaxpr_batched,
                       num_consts=num_consts,
                       num_carry=num_carry,
                       linear=linear)
    carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
    ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
    return outs, carry_bdims + ys_bdims
示例#12
0
 def transposed(*cbar_bbar_res):
   # Operands: cotangents of c, cotangents of b, then the residuals.
   ct_c, ct_b, residuals = split_list(cbar_bbar_res, [num_c, num_b])
   # c and a are the linear inputs being transposed; residuals are known.
   primal_args = [ad.undefined_primal] * (num_c + num_a) + residuals
   _, ct_inputs = ad.backward_pass(jaxpr.jaxpr, jaxpr.literals, (),
                                   primal_args, ct_b)
   ct_c_body, ct_a, _ = split_list(ct_inputs, [num_c, num_a])
   ct_a = _map(ad.instantiate_zeros_aval, a_avals, ct_a)
   # The carried c cotangent accumulates the contribution from this step.
   summed_ct_c = _map(ad.add_tangents, ct_c, ct_c_body)
   ct_c = _map(ad.instantiate_zeros_aval, c_avals, summed_ct_c)
   return ct_c + ct_a
示例#13
0
def _scan_polymorphic_shape_rule(shape_exprs, forward, length, jaxpr,
                                 num_consts, num_carry, linear):
    """Compute the output shape expressions of a ``scan`` application.

    Polymorphic (``Id``) dimensions are allowed only as the leading (scan)
    axis of the xs operands; anywhere else raises ``NotImplementedError``.
    Carry outputs reuse the initial carry shape expressions, and every ys
    output is stacked with ``length`` as its leading dimension.
    """
    const_shexprs, init_shexprs, xs_shexprs = split_list(
        shape_exprs, [num_consts, num_carry])

    def mentions_id(dims):
        return any(type(d) is Id for d in dims)

    unsupported = (any(mentions_id(s) for s in const_shexprs)
                   or any(mentions_id(s) for s in init_shexprs)
                   or any(mentions_id(s[1:]) for s in xs_shexprs))
    if unsupported:
        raise NotImplementedError

    per_step_avals = split_list(jaxpr.out_avals, [num_carry])[1]
    stacked_shapes = [ShapeExpr(length, *aval.shape) for aval in per_step_avals]
    return init_shexprs + stacked_shapes
示例#14
0
def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length,
                       jaxpr, num_consts, num_carry, linear):
  """Masking rule for ``scan``.

  Runs the scan for the padded (static) length of the xs. The logical length
  is passed in as an extra const, and a step counter is threaded as an extra
  carry so that the masked body can freeze the carry once the counter
  reaches the logical length.

  Returns:
    ``(outputs, out_shape)``: the scan outputs without the counter carry,
    and the polymorphic output shape expressions.
  """
  out_shape = _scan_polymorphic_shape_rule(shape_exprs, forward, length, jaxpr,
                                           num_consts, num_carry, linear)
  dynamic_length = masking.eval_dim_expr(shape_envs.logical, length)
  masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
  consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
  # Every padded xs must share one leading length; it is the static length.
  max_length, = {x.shape[0] for x in xs}
  const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])

  # Extra const: dynamic_length. Extra carry: the step counter, starting at 0.
  # Both extras are non-linear.
  masked_operands = [dynamic_length] + consts + [0] + init + xs
  masked_linear = [False] + const_linear + [False] + init_linear + xs_linear
  out_vals = scan_p.bind(
      *masked_operands,
      forward=forward, length=max_length, jaxpr=masked_jaxpr,
      num_consts=1 + num_consts, num_carry=1 + num_carry,
      linear=masked_linear)
  # Drop the counter, which is the first carry output.
  return out_vals[1:], out_shape
示例#15
0
def _scan_partial_eval(trace, *tracers, **kwargs):
  """Partial-evaluation rule for the ``scan`` primitive.

  Splits the scan into a known part (``jaxpr_1``), evaluated now on the known
  inputs, and an unknown part (``jaxpr_2``), staged out as a new ``scan_p``
  equation. The known scan also emits residuals, which become extra inputs of
  the staged scan.

  Raises:
    FixedPointError: if the carry unknown pattern has not stabilized after
      1000 iterations.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])
  num_xs = len(jaxpr.in_avals) - num_carry - num_consts
  num_ys = len(jaxpr.out_avals) - num_carry

  # An input is unknown when its partial value has an abstract component.
  # NOTE(review): original_unknowns is never read below.
  unknowns = original_unknowns = [t.pval[0] is not None for t in tracers]
  const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

  # Fixed point over the carry: a carry is unknown if its init is unknown or
  # the body makes its carry-out unknown.
  carry_uk = init_uk
  for _ in range(1000):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out, ys_uk = out_uk[:num_carry], out_uk[num_carry:]
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = carry_uk_out
  else:
    raise FixedPointError

  # Known scan: concrete values for known inputs, core.unit placeholders for
  # unknown ones. Staged scan: real tracers for unknown inputs, unit literals
  # for known ones.
  in_consts = [core.unit if uk else t.pval[1] for uk, t in zip(unknowns, tracers)]
  new_tracers = [trace.instantiate_const(t) if uk else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns, tracers)]

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  out_avals = carry_avals + ys_avals
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)]

  # Positions filled with unit placeholders are marked linear (presumably
  # harmless since they carry no data -- confirm).
  linear_1 = [lin or uk for uk, lin in zip(unknowns, linear)]
  out_flat = scan_p.bind(
      *in_consts, forward=forward, length=length, jaxpr=jaxpr_1,
      num_consts=num_consts, num_carry=num_carry, linear=linear_1)
  # The known scan outputs carry, ys, then the residuals for the staged scan.
  out_carry, ys, residuals = split_list(out_flat, [num_carry, num_ys])
  out_consts = out_carry + ys
  residual_tracers = _map(trace.new_instantiated_const, residuals)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]
  # In the staged scan the known positions hold units; residuals are
  # appended as non-linear inputs.
  linear_2 = ([lin or not uk for uk, lin in zip(unknowns, linear)]
              + [False] * len(residual_tracers))
  eqn = pe.new_jaxpr_eqn(new_tracers + residual_tracers, out_tracers, scan_p,
                         (), dict(forward=forward, length=length, jaxpr=jaxpr_2,
                                  num_consts=num_consts, num_carry=num_carry,
                                  linear=linear_2))
  for t in out_tracers: t.recipe = eqn
  return out_tracers
示例#16
0
def initial_ildj(incells, outcells, *, jaxpr, num_consts, **_):
    """Propagate inverse/ILDJ cells through ``jaxpr``.

    Returns:
      ``(const_cells + new_incells, new_outcells, state)``. The const cells
      are passed through unchanged; the other cells are read back from the
      propagation environment for ``jaxpr``'s invars and outvars. ``state``
      is whatever ``propagate.propagate`` returned.
    """
    const_cells, arg_cells = jax_util.split_list(incells, [num_consts])
    env, state = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr,
                                     const_cells, arg_cells, outcells)  # pytype: disable=wrong-arg-types
    read = env.read
    return (const_cells + [read(v) for v in jaxpr.invars],
            [read(v) for v in jaxpr.outvars],
            state)
示例#17
0
def _cond_translation_rule(c, axis_env, pred, *args, **kwargs):
    """XLA translation rule for ``cond_p``.

    Each branch is lowered to an XLA computation that takes a single tuple
    operand (the branch consts followed by its operands), unpacks it and
    calls the lowered branch jaxpr. The two branches are then combined with
    ``c.Conditional`` on ``pred``.
    """
    backend = kwargs.pop("backend", None)
    true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
        kwargs, ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"])
    true_nops = len(true_jaxpr.in_avals) - true_nconsts
    true_consts, true_ops, false_consts, false_ops = split_list(
        args, [true_nconsts, true_nops, false_nconsts])

    def branch_computation(name, branch_jaxpr, operand_shape):
        # Separate builder for the branch; the outer builder `c` is untouched.
        sub = xb.make_computation_builder(name)
        packed = sub.ParameterWithShape(operand_shape)
        elems = [sub.GetTupleElement(packed, k)
                 for k in range(len(branch_jaxpr.in_avals))]
        lowered = xla.jaxpr_computation(branch_jaxpr.jaxpr, backend,
                                        axis_env, branch_jaxpr.literals, (),
                                        *_map(sub.GetShape, elems))
        return sub.Build(sub.Call(lowered, elems))

    true_operand = c.Tuple(*(true_consts + true_ops))
    true_comp = branch_computation("true_comp", true_jaxpr,
                                   c.GetShape(true_operand))

    false_operand = c.Tuple(*(false_consts + false_ops))
    false_comp = branch_computation("false_comp", false_jaxpr,
                                    c.GetShape(false_operand))

    return c.Conditional(pred, true_operand, true_comp, false_operand,
                         false_comp)
示例#18
0
 def go_scan(body, length, xs, init, consts, reverse):
     """Interpreted scan: run ``body`` once per step and stack the ys.

     ``body`` is interpreted on ``(*consts, *carry, *x)`` each step, and its
     outputs are split into the new carry (first ``len(init)`` outputs) and
     the per-step ys. Returns ``[*final_carry, *stacked_ys]``.
     """
     n_carry = len(init)
     if xs is None:
         # NOTE(review): without xs each step gets a None placeholder; this list
         # still flows through `_reverse`/`_zip` -- confirm they accept None.
         xs = [None] * length
     if reverse:
         xs = list(map(_reverse, xs))
     carry = init
     per_step_ys = []
     for x in _zip(xs):
         step_out = _interpret_jaxpr(body, (), *consts, *carry, *x)
         carry, y = split_list(step_out, [n_carry])
         per_step_ys.append(y)
     _, y_outvars = split_list(body.outvars, [n_carry])
     stacked = list(map(lambda outvar, column: _stack(outvar, column, reverse),
                        y_outvars, zip(*per_step_ys)))
     return [*carry, *stacked]
示例#19
0
def _while_sparse(spenv, *argspecs, cond_jaxpr, cond_nconsts, body_jaxpr,
                  body_nconsts):
    """Sparsification rule for ``while_p``.

    Sparsifies the cond and body jaxprs against their consts plus the loop
    carry, binds ``lax.while_p`` on the flattened array operands, and
    rebuilds the sparse outputs using the body's output tree.
    """
    cond_const_specs, body_const_specs, carry_specs = split_list(
        argspecs, [cond_nconsts, body_nconsts])

    cond_sp_jaxpr, _ = _sparsify_jaxpr(spenv, cond_jaxpr, *cond_const_specs,
                                       *carry_specs)
    body_sp_jaxpr, out_tree = _sparsify_jaxpr(spenv, body_jaxpr,
                                              *body_const_specs, *carry_specs)

    def flat_arrays(specs):
        leaves, _ = tree_flatten(argspecs_to_arrays(spenv, specs))
        return leaves

    cond_consts = flat_arrays(cond_const_specs)
    body_consts = flat_arrays(body_const_specs)
    init_vals = flat_arrays(carry_specs)

    out_flat = lax.while_p.bind(*cond_consts, *body_consts, *init_vals,
                                cond_nconsts=len(cond_consts),
                                cond_jaxpr=cond_sp_jaxpr,
                                body_nconsts=len(body_consts),
                                body_jaxpr=body_sp_jaxpr)
    return arrays_to_argspecs(spenv, tree_unflatten(out_tree, out_flat))
示例#20
0
文件: ode.py 项目: tudorcebere/jax
 def converted_fun(y, t, *hconsts_args):
     # The first num_consts extra args are hoisted consts; the rest are the
     # user's extra arguments.
     hoisted, user_args = split_list(hconsts_args, [num_consts])
     all_consts = merge(closure_consts, hoisted)
     flat_inputs, seen_tree = tree_flatten((y, t, *user_args))
     assert in_tree == seen_tree
     flat_outputs = core.eval_jaxpr(jaxpr, all_consts, *flat_inputs)
     return tree_unflatten(out_tree, flat_outputs)
示例#21
0
def initial_ildj(incells, outcells, *, jaxpr, num_consts, **_):
    """Propagate inverse/ILDJ cells through ``jaxpr``.

    Returns:
      ``(const_cells + new_incells, new_outcells, None)``. The const cells
      are passed through unchanged; the other cells are read back from the
      propagation environment for ``jaxpr``'s invars and outvars.
    """
    const_cells, arg_cells = jax_util.split_list(incells, [num_consts])
    env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr,
                              const_cells, arg_cells, outcells)
    read = env.read
    return (const_cells + [read(v) for v in jaxpr.invars],
            [read(v) for v in jaxpr.outvars],
            None)
示例#22
0
def random_variable_log_prob(flat_incells, val, *, num_consts, in_tree, **_):
    """Registers Oryx distributions with the log_prob transformation.

    Drops the leading ``num_consts`` cells and unflattens the rest with
    ``in_tree`` into a pair whose second element is the distribution.

    Returns:
      ``None`` if any ``InverseAndILDJ`` cell after the first is not yet
      ``top()`` (not fully known); otherwise ``dist.log_prob(val)``.
    """
    _, flat_incells = jax_util.split_list(flat_incells, [num_consts])
    _, dist = tree_util.tree_unflatten(in_tree, flat_incells)
    # The filter must test each `cell`, not the outer `val`. Testing `val`
    # made the filter loop-invariant, so either every element (cell or not)
    # had `.top()` called on it or none was checked at all.
    if any(not cell.top() for cell in flat_incells[1:]
           if isinstance(cell, InverseAndILDJ)):
        return None
    return dist.log_prob(val)
示例#23
0
def _custom_cell_scan_impl(flat_cell, *args, **kwargs):
    """lax_control_flow._scan_impl, but allowing for a custom cell function.

    Re-traces ``flat_cell`` into a new scan body, using the consts, the init
    carry and one leading-axis slice of each xs as abstract inputs. It then
    binds ``scan_p`` with that body. The new body's closed-over constants
    replace the original consts, and every operand is marked non-linear.
    All other scan params in ``kwargs`` are forwarded unchanged.
    """

    reverse, length, num_consts, num_carry, jaxpr, linear, unroll = split_dict(
        kwargs, ["reverse", "length", "num_consts", "num_carry", "jaxpr", "linear", "unroll"])

    consts, init, xs = split_list(args, [num_consts, num_carry])
    _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
    # Slice row 0 of every xs to get per-step inputs for tracing. `map` may be
    # the Python 3 builtin, which returns an iterator that cannot be added to
    # a list, so materialize it explicitly.
    cell_args = consts + init + list(map(partial(_index_array, 0), x_avals, xs))

    jaxpr, new_consts = _flat_initial_style_jaxpr(wrap_init(flat_cell), _abstractified(cell_args))

    args = list(new_consts) + init + xs
    kwargs['jaxpr'] = jaxpr
    kwargs['num_consts'] = len(new_consts)
    kwargs['linear'] = (False,) * len(args)

    return scan_p.bind(*args, **kwargs)
def random_variable_log_prob(flat_invals, val, **params):
    """Registers Oryx distributions with the log_prob transformation.

    Leading values beyond ``params['num_args']`` are treated as consts and
    dropped. The rest are unflattened with ``params['in_tree']`` into a pair
    whose second element is the distribution.

    Returns:
      ``None`` if any ``InverseAndILDJ`` entry after the first is unknown;
      otherwise ``dist.log_prob(val)``.
    """
    num_consts = len(flat_invals) - params['num_args']
    _, flat_invals = jax_util.split_list(flat_invals, [num_consts])
    _, dist = tree_util.tree_unflatten(params['in_tree'], flat_invals)
    has_unknown = any(
        entry.is_unknown()
        for entry in flat_invals[1:]
        if isinstance(entry, InverseAndILDJ))
    if has_unknown:
        return None
    return dist.log_prob(val)
示例#25
0
 def ildj_rule(incells, outcells, *, in_tree, out_tree, num_args, **_):
     """Inverse/ILDJ propagation rule for a call primitive with a custom inverse.

     Two directions are handled:
       * every output cell is known (``top()``) and some input cell is not:
         call ``f_ildj(invals, outvals, outildjs)`` to recover input values
         and their ILDJs;
       * every input cell is known and some output cell is not: evaluate
         ``self`` forward to fill in the output cells.
     In every other case the cells are returned unchanged.

     Returns:
       ``(const_incells + incells, outcells, None)`` with updated cells.
       NOTE(review): the leading wrapped-function incell is dropped and not
       included in the returned incells -- confirm the caller expects this.
     """
     # First incell is a wrapped function because prim is a call primitive.
     incells = incells[1:]
     num_consts = len(incells) - num_args
     const_incells, incells = jax_util.split_list(incells, [num_consts])
     # out_tree is passed as a thunk; call it to get the tree definition.
     out_tree = out_tree()
     if (all(outcell.top() for outcell in outcells)
             and any(not incell.top() for incell in incells)):
         # Backward direction: outputs known, some inputs unknown.
         flat_outvals = [outcell.val for outcell in outcells]
         flat_outildjs = [outcell.ildj for outcell in outcells]
         outvals = tree_util.tree_unflatten(out_tree, flat_outvals)
         outildjs = tree_util.tree_unflatten(out_tree, flat_outildjs)
         # Unknown inputs are presented to f_ildj as None.
         flat_invals = [
             None if not incell.top() else incell.val
             for incell in incells
         ]
         invals = tree_util.tree_unflatten(in_tree, flat_invals)
         try:
             new_invals, new_ildjs = f_ildj(invals, outvals, outildjs)
         except NonInvertibleError:
             return const_incells + incells, outcells, None
         # We need to flatten the output from `f_ildj` using
         # `tree_util.tree_flatten` but if the user returns `None` (when
         # inversion is not possible), JAX will remove `None`s from the flattened
         # version and the number of `new_incells` will not match the old
         # `incells`. We use the private `_replace_nones` feature in JAX to
         # replace it with a sentinel that won't be removed when flattening.
         none_ = object()
         new_invals = tree_util._replace_nones(none_, new_invals)  # pylint: disable=protected-access
         new_ildjs = tree_util._replace_nones(none_, new_ildjs)  # pylint: disable=protected-access
         new_flat_invals = tree_util.tree_leaves(new_invals)
         new_flat_ildjs = tree_util.tree_leaves(new_ildjs)
         # NOTE(review): NDSlice.new is also called for `none_` sentinel
         # entries; those slices are discarded below. Confirm NDSlice.new
         # tolerates a sentinel value.
         inslices = [
             NDSlice.new(inval, ildj)
             for inval, ildj in zip(new_flat_invals, new_flat_ildjs)
         ]
         # Replace only the inputs that f_ildj actually recovered; keep the
         # old cell where it returned None (now the sentinel).
         new_incells = []
         for new_flat_inval, old_incell, inslice in zip(
                 new_flat_invals, incells, inslices):
             if new_flat_inval is not none_:
                 new_incells.append(
                     InverseAndILDJ(old_incell.aval, [inslice]))
             else:
                 new_incells.append(old_incell)
         return const_incells + new_incells, outcells, None
     elif (all(incell.top() for incell in incells)
           and any(not outcell.top() for outcell in outcells)):
         # Forward direction: inputs known, some outputs unknown.
         flat_invals = [incell.val for incell in incells]
         invals = tree_util.tree_unflatten(in_tree, flat_invals)
         outvals = self(*invals)
         flat_outvals = tree_util.tree_leaves(outvals)
         outcells = [
             InverseAndILDJ.new(outval) for outval in flat_outvals
         ]
         return const_incells + incells, outcells, None
     return const_incells + incells, outcells, None
示例#26
0
def _scan_impl(*args, **kwargs):
  """Reference implementation of ``scan`` as a ``fori_loop``.

  Preallocates one empty buffer per ys output (with ``length`` rows). Each
  loop step reads row ``pos`` of every xs, runs the body on consts + carry +
  that row, and writes the body's per-step outputs into row ``pos`` of the
  buffers. When ``forward`` is false, rows are visited from last to first.

  Returns:
    The final carry followed by the filled ys buffers.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])

  consts, init, xs = split_list(args, [num_consts, num_carry])
  x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])[2]
  y_avals = split_list(jaxpr.out_avals, [num_carry])[1]

  def step(k, state):
    pos = length - k - 1 if not forward else k
    carry, ys_acc = split_list(state, [num_carry])
    x_row = _map(partial(_index_array, pos), x_avals, xs)
    outs = core.jaxpr_as_fun(jaxpr)(*(consts + carry + x_row))
    new_carry, y_row = split_list(outs, [num_carry])
    return new_carry + _map(partial(_update_array, pos), y_avals, ys_acc, y_row)

  ys_buffers = _map(partial(_empty_array, length), y_avals)
  return fori_loop(0, length, step, init + ys_buffers)
示例#27
0
def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                     cond_nconsts, body_nconsts):
    """Reaps the body of a while loop to get the reaps of the final iteration.

    The loop is re-traced so that every value reaped inside ``body_jaxpr`` is
    carried through the loop as extra loop state. After the loop finishes,
    those carried values are sown again.

    Args:
      trace: The active ``HarvestTrace``. Its dynamic context supplies the
        harvest settings (tag, allowlist, blocklist, exclusive).
      *tracers: Flat operands of the while primitive, laid out as
        ``cond consts + body consts + initial carry``.
      cond_jaxpr: Jaxpr for the loop predicate.
      body_jaxpr: Jaxpr for the loop body.
      cond_nconsts: Number of leading operands that are predicate constants.
      body_nconsts: Number of operands after those that are body constants.

    Returns:
      The loop's final carry values, lifted into ``trace`` via ``trace.pure``.

    Raises:
      ValueError: If any value harvested in the body is not in ``'clobber'``
        mode.
    """
    # Split the flat operands into predicate consts, body consts and carry.
    cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
        tracers, [cond_nconsts, body_nconsts])
    # Only the carry avals are needed; the body-const avals are discarded.
    _, init_avals = tree_util.tree_map(lambda x: x.aval,
                                       (body_const_tracers, init_tracers))
    # Unwrap every tracer to its underlying value (``.val``).
    cond_const_vals, body_const_vals, init_vals = tree_util.tree_map(
        lambda x: x.val,
        (cond_const_tracers, body_const_tracers, init_tracers))
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    # Per-name metadata (read below: 'mode' and 'aval') for values harvested
    # inside the loop body.
    body_metadata = _get_harvest_metadata(body_jaxpr, settings,
                                          *(body_const_tracers + init_tracers))
    # Only the final iteration's value is kept, so only 'clobber' mode is
    # accepted inside a while loop.
    for k, meta in body_metadata.items():
        mode = meta['mode']
        if mode != 'clobber':
            raise ValueError(
                f'Must use clobber mode for \'{k}\' inside of a `while_loop`.')
    reap_avals = {k: v['aval'] for k, v in body_metadata.items()}

    cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr)
    body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
    reap_settings = dict(tag=settings.tag,
                         allowlist=settings.allowlist,
                         blocklist=settings.blocklist,
                         exclusive=settings.exclusive)

    # The augmented carry is ``(original carry, reaps dict)``. The predicate
    # only looks at the original carry and ignores the reap slot.
    def new_cond(carry, _):
        return cond_fun(*(cond_const_vals + carry))

    # Run the original body while collecting its reaps. The incoming reaps
    # are ignored: each iteration overwrites them with fresh ones.
    def new_body(carry, _):
        carry, reaps = call_and_reap(
            body_fun, **reap_settings)(*(body_const_vals + carry))
        return (carry, reaps)

    # Trace the new cond/body over the flattened augmented carry.
    new_in_avals, new_in_tree = tree_util.tree_flatten(
        (init_avals, reap_avals))
    new_cond_jaxpr, cond_consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_cond, new_in_tree, tuple(new_in_avals))
    new_body_jaxpr, body_consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_body, new_in_tree, tuple(new_in_avals))
    # Zeros of the right shape/dtype seed the reap slots before iteration 1.
    dummy_reap_vals = tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype),
                                         reap_avals)
    new_in_vals = tree_util.tree_leaves((init_vals, dummy_reap_vals))
    out = lax.while_p.bind(*(cond_consts + body_consts + new_in_vals),
                           cond_nconsts=len(cond_consts),
                           body_nconsts=len(body_consts),
                           cond_jaxpr=new_cond_jaxpr,
                           body_jaxpr=new_body_jaxpr)
    # Lift the raw loop outputs back into the harvest trace.
    out = jax_util.safe_map(trace.pure, out)
    out, reaps = tree_util.tree_unflatten(out_tree, out)
    # Re-sow the final-iteration reaps so the enclosing harvest observes them.
    for k, v in reaps.items():
        sow(v, name=k, tag=settings.tag, mode=body_metadata[k]['mode'])
    return out
示例#28
0
 def new_body(carry, x):
     """Scan body that runs a planted `body_fun` for one iteration.

     `x` arrives as a pair: the original scanned inputs and this iteration's
     plants. Those plants are merged with `clobber_plants`, with entries from
     `clobber_plants` taking precedence on key collisions.

     Returns:
       A `(carry_out, ys)` pair split after the first `num_carry` outputs.
     """
     scanned, iteration_plants = x
     merged_plants = dict(iteration_plants)
     merged_plants.update(clobber_plants)
     flat_args = const_vals + tree_util.tree_leaves((carry, scanned))
     planted_body = plant(body_fun,
                          tag=settings.tag,
                          allowlist=settings.allowlist,
                          blocklist=settings.blocklist,
                          exclusive=settings.exclusive)
     results = planted_body(merged_plants, *flat_args)
     new_carry, ys = jax_util.split_list(results, [num_carry])
     return new_carry, ys
示例#29
0
def _cond_impl(pred, *args, **kwargs):
    """Evaluates a ``cond`` primitive eagerly by running exactly one branch.

    Args:
      pred: Branch selector. The true branch runs when it is truthy.
      *args: Flat operands laid out as
        ``true_consts + true_ops + false_consts + false_ops``.
      **kwargs: Expected to hold ``true_jaxpr``, ``false_jaxpr``,
        ``true_nconsts`` and ``false_nconsts`` (unpacked via ``split_dict``).

    Returns:
      The outputs of the selected branch's jaxpr.
    """
    true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
        kwargs, ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"])
    # The true branch is invoked with ``true_consts + true_ops``, so its
    # ``in_avals`` count covers both. The operand count is therefore the
    # total minus the constants. Using the full length would make
    # ``true_ops`` swallow the leading false-branch constants whenever
    # ``true_nconsts > 0``.
    true_nops = len(true_jaxpr.in_avals) - true_nconsts
    true_consts, true_ops, false_consts, false_ops = split_list(
        args,
        [true_nconsts, true_nops, false_nconsts])

    if pred:
        return core.jaxpr_as_fun(true_jaxpr)(*(true_consts + true_ops))
    else:
        return core.jaxpr_as_fun(false_jaxpr)(*(false_consts + false_ops))
示例#30
0
def bijector_ildj_rule(incells, outcells, *, in_tree, num_consts, direction,
                       num_bijector, **_):
    """Inverse/ILDJ rule for bijectors.

    Once every cell holding the bijector's own flattened leaves is known, the
    rule propagates information across the bijector application:

    * If the input cell is bottom and the output cell is not, the input is
      recovered with the inverse map and the ILDJ is accumulated.
    * If the output cell is unknown and the input is not, the output is
      computed with the forward map.

    In every other state the cells are returned unchanged.

    Args:
      incells: Cells laid out as ``consts + bijector leaves + argument leaves``.
      outcells: Single-element list holding the output cell.
      in_tree: Treedef of the ``(bijector, argument)`` pair for the
        non-constant cells.
      num_consts: Number of leading constant cells.
      direction: ``'forward'`` or ``'inverse'``. Selects which bijector method
        acts as the forward map.
      num_bijector: Number of flattened leaves belonging to the bijector.

    Returns:
      ``(new_incells, new_outcells, None)`` with the same layout as the inputs.

    Raises:
      ValueError: If ``direction`` is neither ``'forward'`` nor ``'inverse'``.
    """
    const_incells, flat_incells = jax_util.split_list(incells, [num_consts])
    flat_bijector_cells, arg_incells = jax_util.split_list(
        flat_incells, [num_bijector])
    # Nothing can be computed until the bijector's own leaves are known.
    if any(not cell.top() for cell in flat_bijector_cells):
        return (const_incells + flat_incells, outcells, None)
    flat_inproxies = safe_map(_CellProxy, flat_incells)
    _, inproxy = tree_util.tree_unflatten(in_tree, flat_inproxies)
    bijector_vals = [cell.val for cell in flat_bijector_cells]
    # Rebuild the bijector object. The argument leaves are not needed here,
    # so ``None`` placeholders stand in for them.
    bijector, _ = tree_util.tree_unflatten(
        in_tree, bijector_vals + [None] * len(arg_incells))
    if direction == 'forward':
        forward_func = bijector.forward
        inv_func = bijector.inverse
        ildj_func = bijector.inverse_log_det_jacobian
    elif direction == 'inverse':
        forward_func = bijector.inverse
        inv_func = bijector.forward
        ildj_func = bijector.forward_log_det_jacobian
    else:
        raise ValueError('Bijector direction must be '
                         '"forward" or "inverse".')

    outcell, = outcells
    incell = inproxy.cell
    # Default to "no new information". This covers any state where neither
    # condition below holds; previously ``new_outcells`` was left unbound
    # (UnboundLocalError) in that case.
    new_arg_incells = arg_incells
    new_outcells = outcells
    if incell.bottom() and not outcell.bottom():
        val, ildj = outcell.val, outcell.ildj
        inildj = ildj + ildj_func(val, np.ndim(val))
        ndslice = NDSlice.new(inv_func(val), inildj)
        new_arg_incells = [InverseAndILDJ(incell.aval, [ndslice])]
    elif outcell.is_unknown() and not incell.is_unknown():
        new_outcells = [InverseAndILDJ.new(forward_func(incell.val))]
    # Rebuild from the argument cells only. ``flat_incells`` already begins
    # with the bijector cells, so prepending them to it would duplicate them.
    new_incells = flat_bijector_cells + new_arg_incells
    return (const_incells + new_incells, new_outcells, None)