Example #1
0
def _abs_taylor_rule(x, series_in, **params):
  """Taylor-mode (jet) rule for ``abs``.

  Args:
    x: one-element sequence holding the primal input.
    series_in: per-input sequences of higher-order coefficients; only the
      single input's coefficients are consumed.
    **params: parameters of ``abs_p``; forwarded only to the primal bind.

  Returns:
    ``(primal_out, series_out)`` where ``primal_out = |x|`` and each series
    coefficient is multiplied by the sign of ``x`` (``x == 0`` counts as +1).
  """
  x, = x
  primal_out = lax.abs_p.bind(x, **params)
  # Compare against a zero with x's own shape and dtype rather than a float
  # literal, so integer inputs do not hit a dtype mismatch.
  negs = lax.select(lax.lt(x, lax.full_like(x, 0)),
                    lax.full_like(x, -1), lax.full_like(x, 1))
  # ``params`` belong to ``abs_p`` only; the per-order sign fix takes exactly
  # one term, so they must not be forwarded here.
  series_out = [negs * term for (term,) in zip(*series_in)]
  return primal_out, series_out
Example #2
0
 def masked(*args):
   """Run ``fun`` once, keeping the old carry once ``i`` reaches ``dynamic_length``.

   ``args`` is laid out as ``[dynamic_length, *consts, i, *carry, *xs]``.
   Returns ``[i + 1, *carry_out, *ys]``, where each ``carry_out`` entry is the
   freshly computed carry while ``i < dynamic_length`` and the incoming carry
   otherwise.
   """
   (dynamic_length,), consts, (i,), carry, xs = split_list(
       args, [1, num_consts, 1, num_carry])
   step_out = fun(*consts, *carry, *xs)
   new_carry, ys = split_list(step_out, [num_carry])
   kept_carry = [lax.select(i < dynamic_length, updated, previous)
                 for updated, previous in zip(new_carry, carry)]
   return [i + 1, *kept_carry, *ys]
Example #3
0
def _jaxtupletree_select(pred, on_true, on_false):
  """Apply ``lax.select`` over a (possibly nested) JaxTuple tree.

  Tuple-valued operands (abstract value exactly ``core.AbstractTuple``) are
  handled component by component and re-packed. ``UnshapedArray`` leaves go
  straight to ``lax.select``.

  Raises:
    TypeError: if ``on_true`` has any other kind of abstract value.
  """
  aval = core.get_aval(on_true)
  if type(aval) is core.AbstractTuple:
    select_component = partial(_jaxtupletree_select, pred)
    return core.pack(map(select_component, on_true, on_false))
  if isinstance(aval, UnshapedArray):
    return lax.select(pred, on_true, on_false)
  raise TypeError(aval)
Example #4
0
def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
                        false_nconsts):
    """Batching (vmap) rule for ``cond_p``.

    Args:
      args: flat operands laid out as ``[pred, *true_consts, *true_ops,
        *false_consts, *false_ops]``.
      dims: batch dimension of each entry of ``args``, or
        ``batching.not_mapped`` for unbatched entries.
      true_jaxpr, false_jaxpr: branch jaxprs; each takes its consts followed
        by its operands.
      true_nconsts, false_nconsts: number of leading const inputs of each
        branch jaxpr.

    Returns:
      ``(outs, out_dims)``. If the predicate is batched, both branches are
      evaluated and merged per example with ``lax.select``. Otherwise a
      single batched ``cond_p`` is emitted.
    """
    # TODO: maybe avoid moving arg axes to front if we're promoting to select?
    # Move every batched axis to position 0 so all batched values line up.
    args = [
        batching.moveaxis(x, d, 0)
        if d is not batching.not_mapped and d != 0 else x
        for x, d in zip(args, dims)
    ]
    true_nops = len(true_jaxpr.in_avals) - true_nconsts
    (pred, ), true_consts, true_ops, false_consts, false_ops = split_list(
        args, [1, true_nconsts, true_nops, false_nconsts])
    # All batched operands must agree on a single batch size.
    size, = {
        x.shape[d]
        for x, d in zip(args, dims) if d is not batching.not_mapped
    }
    orig_bat = [d is not batching.not_mapped for d in dims]
    # Same layout as ``args``: for each branch the consts come before the
    # operands, matching the jaxpr input order used by ``batch_jaxpr`` below.
    (pred_bat, ), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
        orig_bat,
        [1, true_nconsts, len(true_ops), false_nconsts])

    # First pass: find out which outputs each branch produces batched.
    _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size,
                                           tconst_bat + t_bat, False)
    _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size,
                                            fconst_bat + f_bat, False)
    # Both branches must agree, so an output is batched if either branch
    # batches it.
    out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]

    # Second pass: re-batch both branches, forcing the agreed output batching.
    true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size,
                                                 tconst_bat + t_bat, out_bat)
    false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size,
                                                  fconst_bat + f_bat, out_bat)

    if pred_bat:
        # A batched predicate may pick a different branch per example, so
        # evaluate both branches and choose elementwise.
        true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts +
                                                           true_ops))
        false_out = core.jaxpr_as_fun(false_jaxpr_batched)(*(false_consts +
                                                             false_ops))
        true_out = [
            batching.broadcast(x, size, 0) if not b else x
            for x, b in zip(true_out, out_bat)
        ]
        false_out = [
            batching.broadcast(x, size, 0) if not b else x
            for x, b in zip(false_out, out_bat)
        ]
        # ``pred`` has shape (size,) while every output has shape (size, ...).
        # lax.select requires a scalar or same-shape predicate, so broadcast
        # ``pred`` along the leading batch axis of each output first.
        outs = [
            lax.select(lax.broadcast_in_dim(pred, t.shape, (0,)), t, f)
            for t, f in zip(true_out, false_out)
        ]
        return outs, [0] * len(outs)
    else:
        out_dims = [0 if b else batching.not_mapped for b in out_bat]
        return cond_p.bind(*itertools.chain([pred], true_consts, true_ops,
                                            false_consts, false_ops),
                           true_jaxpr=true_jaxpr_batched,
                           false_jaxpr=false_jaxpr_batched,
                           true_nconsts=len(true_consts),
                           false_nconsts=len(false_consts)), out_dims
Example #5
0
File: jet.py Project: yangliuy/jax
def _lax_min_taylor_rule(primal_in, series_in):
    """Taylor-mode (jet) rule for ``min(x, y)``.

    Args:
      primal_in: ``(x, y)`` primal inputs.
      series_in: coefficient sequences for ``x`` and ``y``.

    Returns:
      ``(primal_out, series_out)``. ``primal_out`` takes ``x`` where ``x < y``
      and ``y`` elsewhere. Each series coefficient follows the same choice,
      except where ``x == y``: there the two inputs' coefficients are
      averaged, presumably as a symmetric choice at the non-differentiable tie.
    """
    x, y = primal_in
    xly = x < y   # x-less-than-y mask
    xey = x == y  # x-equal-to-y mask
    primal_out = lax.select(xly, x, y)

    def select_min_and_avg_eq(x_i, y_i):
        """Pick x_i where x<y, y_i where x>y, and their average where x==y."""
        min_i = lax.select(xly, x_i, y_i)
        min_i = lax.select(xey, (x_i + y_i) / 2, min_i)
        return min_i

    series_out = [select_min_and_avg_eq(*terms_in)
                  for terms_in in zip(*series_in)]
    return primal_out, series_out
Example #6
0
File: jet.py Project: nhanwei/jax
 def select_min_and_avg_eq(x_i, y_i):
     """Combine one pair of series terms.

     Takes ``x_i`` where the enclosing ``xgy`` mask is set and ``y_i``
     elsewhere, then uses the mean of the two wherever the enclosing ``xey``
     mask is set. NOTE(review): for a min rule ``xgy`` is presumably the
     ``x < y`` mask despite its name; confirm against the enclosing function.
     """
     picked = lax.select(xgy, x_i, y_i)
     mean = (x_i + y_i) / 2
     return lax.select(xey, mean, picked)
Example #7
0
File: jet.py Project: nhanwei/jax
def _select_taylor_rule(primal_in, series_in, **params):
    """Taylor-mode (jet) rule for ``select``.

    Args:
      primal_in: ``(b, x, y)``, the predicate and the two branch values.
      series_in: coefficient sequences for ``b``, ``x`` and ``y``.
      **params: parameters of ``select_p``; forwarded only to the primal bind.

    Returns:
      ``(primal_out, series_out)``: ``select(b, x, y)`` and, for every order,
      ``select(b, x_i, y_i)``. The predicate's own coefficients are ignored.
    """
    b, x, y = primal_in
    primal_out = lax.select_p.bind(b, x, y, **params)
    # ``params`` describe the primitive, not the per-order selection. The old
    # ``lambda _, x, y`` accepted no extra keywords, so forwarding non-empty
    # params raised TypeError.
    series_out = [lax.select(b, x_i, y_i) for _, x_i, y_i in zip(*series_in)]
    return primal_out, series_out
Example #8
0
def _cond_pred_bcast_select(pred, x, y):
  """``lax.select`` with ``pred`` broadcast along the leading axes of ``x``.

  ``pred`` is mapped onto the first ``ndim(pred)`` dimensions of ``x``'s shape
  and broadcast over the remaining ones. This lets a per-example predicate
  choose between whole higher-rank slices of ``x`` and ``y``.
  """
  pred_axes = list(range(onp.ndim(pred)))
  full_pred = lax.broadcast_in_dim(pred, onp.shape(x), pred_axes)
  return lax.select(full_pred, x, y)