def _abs_taylor_rule(x, series_in, **params): x, = x primal_out = lax.abs_p.bind(x, **params) negs = lax.select(lax.lt(x, 0.0), lax.full_like(x, -1), lax.full_like(x, 1.0)) fix_sign = lambda y: negs * y series_out = [fix_sign(*terms_in, **params) for terms_in in zip(*series_in)] return primal_out, series_out
def masked(*args):
  """One step of a masked scan body.

  Splits the flat argument list into the dynamic length, the closed-over
  constants, the iteration counter ``i``, the carry, and the per-step inputs,
  runs the wrapped ``fun``, and keeps the carry frozen once ``i`` reaches
  ``dynamic_length`` so padding iterations leave the carry untouched.
  """
  [dynamic_length], consts, [i], carry, xs = split_list(
      args, [1, num_consts, 1, num_carry])
  outputs = fun(*(consts + carry + xs))
  carry_out, ys = split_list(outputs, [num_carry])
  # Beyond the dynamic length, discard the freshly computed carry and keep
  # the previous one, making the extra iterations no-ops on the carry.
  carry_out = [lax.select(i < dynamic_length, fresh, old)
               for fresh, old in zip(carry_out, carry)]
  return [i + 1] + carry_out + ys
def _jaxtupletree_select(pred, on_true, on_false):
  """Recursively apply ``lax.select`` across a JaxTuple tree.

  Tuple avals are descended into element-wise (packing the results back up);
  array avals are selected directly. Any other aval type is an error.
  """
  aval = core.get_aval(on_true)
  if type(aval) is core.AbstractTuple:
    recurse = partial(_jaxtupletree_select, pred)
    return core.pack(map(recurse, on_true, on_false))
  elif isinstance(aval, UnshapedArray):
    return lax.select(pred, on_true, on_false)
  else:
    raise TypeError(aval)
def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts,
                        false_nconsts):
  """Batching (vmap) rule for ``cond_p``.

  If the predicate is batched, both branches are executed and the results are
  merged elementwise with ``lax.select``; otherwise ``cond_p`` is re-bound
  with branch jaxprs batched to a common output-batching pattern.
  """
  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  args = [
      batching.moveaxis(x, d, 0)
      if d is not batching.not_mapped and d != 0 else x
      for x, d in zip(args, dims)
  ]
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  (pred, ), true_consts, true_ops, false_consts, false_ops = split_list(
      args, [1, true_nconsts, true_nops, false_nconsts])
  # All batched args must agree on the mapped-axis size.
  size, = {
      x.shape[d]
      for x, d in zip(args, dims) if d is not batching.not_mapped
  }
  orig_bat = [d is not batching.not_mapped for d in dims]
  # Bug fix: this destructuring previously bound the consts' flags to
  # t_bat/f_bat and the operands' flags to tconst_bat/fconst_bat (swapped),
  # so batch_jaxpr below received flags ordered ops + consts. The branch
  # jaxprs take consts first — see the jaxpr_as_fun calls, which pass
  # (true_consts + true_ops) — so the flags must be consts + ops.
  (pred_bat, ), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
      orig_bat, [1, true_nconsts, len(true_ops), false_nconsts])

  # First pass: discover which outputs come out batched in either branch.
  _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat,
                                         False)
  _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size,
                                          fconst_bat + f_bat, False)
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]

  # Second pass: batch both branches to the merged output pattern.
  true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size,
                                               tconst_bat + t_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size,
                                                fconst_bat + f_bat, out_bat)

  if pred_bat:
    # Batched predicate: run both branches and select elementwise.
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(
        *(false_consts + false_ops))
    true_out = [
        batching.broadcast(x, size, 0) if not b else x
        for x, b in zip(true_out, out_bat)
    ]
    false_out = [
        batching.broadcast(x, size, 0) if not b else x
        for x, b in zip(false_out, out_bat)
    ]
    return [lax.select(pred, t, f)
            for t, f in zip(true_out, false_out)], [0] * len(true_out)
  else:
    # Unbatched predicate: keep the cond, with batched branch jaxprs.
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    return cond_p.bind(
        *itertools.chain([pred], true_consts, true_ops, false_consts,
                         false_ops),
        true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
        true_nconsts=len(true_consts),
        false_nconsts=len(false_consts)), out_dims
def _lax_min_taylor_rule(primal_in, series_in): x, y = primal_in xgy = x < y # less than mask xey = x == y # equal to mask primal_out = lax.select(xgy, x, y) def select_min_and_avg_eq(x_i, y_i): """Select x where x>y or average when x==y""" min_i = lax.select(xgy, x_i, y_i) min_i = lax.select(xey, (x_i + y_i)/2, min_i) return min_i series_out = [select_min_and_avg_eq(*terms_in) for terms_in in zip(*series_in)] return primal_out, series_out
def select_min_and_avg_eq(x_i, y_i):
  """Combine one pair of series terms for an elementwise min.

  Uses the enclosing-scope masks ``xgy`` and ``xey`` (presumably the
  "x is the minimum" and "x equals y" masks — confirm against the caller):
  takes ``x_i`` where ``xgy`` holds, ``y_i`` otherwise, and averages the two
  terms where the primals tie.
  """
  picked = lax.select(xgy, x_i, y_i)
  return lax.select(xey, (x_i + y_i) / 2, picked)
def _select_taylor_rule(primal_in, series_in, **params): b, x, y = primal_in primal_out = lax.select_p.bind(b, x, y, **params) sel = lambda _, x, y: lax.select(b, x, y) series_out = [sel(*terms_in, **params) for terms_in in zip(*series_in)] return primal_out, series_out
def _cond_pred_bcast_select(pred, x, y): bcast_pred = lax.broadcast_in_dim(pred, onp.shape(x), list(range(onp.ndim(pred)))) return lax.select(bcast_pred, x, y)