def _abs_taylor_rule(x, series_in, **params):
  """Taylor-mode (jet) rule for ``lax.abs_p``.

  Each output Taylor coefficient is the matching input coefficient multiplied
  elementwise by the sign of the primal ``x``. Points where ``x == 0`` use +1.

  Args:
    x: one-element sequence holding the primal operand.
    series_in: one-element sequence holding the list of Taylor coefficients of
      that operand; ``zip(*series_in)`` yields them one at a time.
    **params: parameters of the ``abs`` primitive; forwarded only to
      ``lax.abs_p.bind``.

  Returns:
    ``(primal_out, series_out)``: ``|x|`` and the list of sign-corrected
    coefficients, in the same order as the inputs.
  """
  x, = x
  zero = lax.full_like(x, 0, shape=())
  primal_out = lax.abs_p.bind(x, **params)
  # -1 where x < 0, +1 elsewhere (including x == 0).
  negs = lax.select(lax.lt(x, zero), lax.full_like(x, -1), lax.full_like(x, 1))

  def fix_sign(y):
    return negs * y

  # ``params`` belong to the primitive, not to the sign flip. Forwarding them
  # to ``fix_sign`` would raise TypeError whenever ``params`` is non-empty.
  series_out = [fix_sign(*terms_in) for terms_in in zip(*series_in)]
  return primal_out, series_out
def _where(condition, x=None, y=None):
  """Elementwise choice between ``x`` (where ``condition``) and ``y``.

  A non-boolean ``condition`` is converted to a mask via ``condition != 0``.
  ``x`` and ``y`` are promoted to a common dtype, then all three operands are
  broadcast together before selecting.

  Args:
    condition: selection mask (any dtype).
    x: values taken where ``condition`` is true.
    y: values taken where ``condition`` is false.

  Returns:
    The selected array. When the broadcast shape is known to be empty, the
    broadcast ``x`` is returned directly and no ``lax.select`` is emitted.

  Raises:
    ValueError: if ``x`` or ``y`` is None.
  """
  if x is None or y is None:
    raise ValueError("Either both or neither of the x and y arguments should "
                     "be provided to jax.numpy.where, got {} and {}."
                     .format(x, y))
  if not np.issubdtype(_dtype(condition), np.bool_):
    condition = lax.ne(condition, lax_internal._zero(condition))
  x, y = _promote_dtypes(x, y)
  condition, x, y = _broadcast_arrays(condition, x, y)
  try:
    is_always_empty = core.is_empty_shape(np.shape(x))
  except Exception:  # can fail with dynamic shapes; fall back to a real select
    is_always_empty = False
  return lax.select(condition, x, y) if not is_always_empty else x
def _lax_min_taylor_rule(primal_in, series_in):
  """Taylor-mode (jet) rule for elementwise ``lax.min``.

  Args:
    primal_in: pair ``(x, y)`` of primal operands.
    series_in: pair of per-operand Taylor coefficient lists; ``zip(*series_in)``
      pairs the k-th coefficient of ``x`` with the k-th coefficient of ``y``.

  Returns:
    ``(primal_out, series_out)``. ``primal_out`` is the elementwise minimum.
    Each ``series_out`` entry takes the ``x`` coefficient where ``x < y`` and
    the ``y`` coefficient where ``x > y``. Where ``x == y`` it takes the mean
    of the two.
  """
  x, y = primal_in
  x_lt_y = x < y   # x is the strict minimum here
  x_eq_y = x == y  # ties between x and y
  primal_out = lax.select(x_lt_y, x, y)

  def select_min_and_avg_eq(x_i, y_i):
    """Select x_i where x < y, y_i where x > y, and their mean where x == y."""
    min_i = lax.select(x_lt_y, x_i, y_i)
    # At ties the minimum is not differentiable; split the coefficient evenly.
    min_i = lax.select(x_eq_y, (x_i + y_i) / 2, min_i)
    return min_i

  series_out = [select_min_and_avg_eq(*terms_in)
                for terms_in in zip(*series_in)]
  return primal_out, series_out
def select_min_and_avg_eq(x_i, y_i):
  """Select x_i where x < y, y_i where x > y, and their mean where x == y.

  NOTE(review): ``xgy`` and ``xey`` are free variables resolved from an outer
  scope. Presumably they are the ``x < y`` and ``x == y`` masks of the primal
  operands (``xgy`` is selected as the x-branch, ``xey`` as the averaging
  branch); confirm at the definition site.
  """
  min_i = lax.select(xgy, x_i, y_i)
  # Where the operands tie, average the two coefficients.
  min_i = lax.select(xey, (x_i + y_i) / 2, min_i)
  return min_i
def _select_taylor_rule(primal_in, series_in, **params):
  """Taylor-mode (jet) rule for ``lax.select_p``.

  Args:
    primal_in: triple ``(b, x, y)`` of predicate and the two branches.
    series_in: per-operand Taylor coefficient lists; ``zip(*series_in)``
      yields ``(b_k, x_k, y_k)`` for each order k.
    **params: parameters of the select primitive; forwarded only to
      ``lax.select_p.bind``.

  Returns:
    ``(primal_out, series_out)`` where each ``series_out`` entry picks ``x_k``
    where ``b`` is true and ``y_k`` elsewhere.
  """
  b, x, y = primal_in
  primal_out = lax.select_p.bind(b, x, y, **params)

  def sel(_, x_k, y_k):
    # The predicate's own coefficient (first slot) is ignored; the primal
    # predicate ``b`` decides which branch's coefficient is kept.
    return lax.select(b, x_k, y_k)

  # ``params`` are not forwarded to ``sel``: it takes no keyword arguments,
  # so forwarding them would raise TypeError for any non-empty ``params``.
  series_out = [sel(*terms_in) for terms_in in zip(*series_in)]
  return primal_out, series_out
def reducer(x, y):
  """Merge two ``(key, value)`` pairs, keeping the pair picked by a predicate.

  ``select_prim`` is a free variable from an enclosing scope. Its result on
  the two keys is used as the ``lax.select`` predicate: where it is true, the
  pair from ``x`` is kept; elsewhere the pair from ``y`` is kept. Presumably it
  is a comparison primitive; confirm at the definition site.
  """
  (key_x, val_x), (key_y, val_y) = x, y
  keep_x = select_prim.bind(key_x, key_y)
  chosen_key = lax.select(keep_x, key_x, key_y)
  chosen_val = lax.select(keep_x, val_x, val_y)
  return chosen_key, chosen_val
def assert_func(error: Error, pred: Bool, msg: str) -> Error:
  """Fold the check ``pred`` into ``error`` under a freshly allocated code.

  Args:
    error: the accumulated error state.
    pred: condition that must hold; the error flag is set where it is false.
    msg: message registered under the new code.

  Returns:
    A new ``Error``:
    - Its flag is ``error.err | ~pred``.
    - Its code is the existing one if ``error.err`` was already set, otherwise
      the new code.
    - Its messages are the new ``{code: msg}`` entry merged with
      ``error.msgs``; existing entries win on key collision.
  """
  fresh_code = next_code()
  failed_now = jnp.logical_not(pred)
  new_err = error.err | failed_now
  new_code = lax.select(error.err, error.code, fresh_code)
  new_msgs = {fresh_code: msg, **error.msgs}
  return Error(new_err, new_code, new_msgs)
def assert_discharge_rule(error, pred, code, *, msgs):
  """Discharge an assertion into the functional error state.

  Args:
    error: the accumulated error state.
    pred: condition that must hold; the error flag is set where it is false.
    code: code to record if no earlier error was set.
    msgs: message mapping merged over ``error.msgs``; keys in ``msgs`` win.

  Returns:
    ``([], new_error)``: an empty output list plus the updated ``Error``.
  """
  combined_err = error.err | jnp.logical_not(pred)
  # An already-set error keeps its code; otherwise ``code`` is taken.
  combined_code = lax.select(error.err, error.code, code)
  combined_msgs = {**error.msgs, **msgs}
  return [], Error(combined_err, combined_code, combined_msgs)
def _bcast_select(pred, on_true, on_false):
  """Elementwise ``lax.select`` that tolerates a lower-rank predicate.

  If ``pred`` has fewer (or more) dimensions than ``on_true``, its axes are
  mapped onto the leading axes of ``on_true`` with ``lax.broadcast_in_dim``
  before selecting.
  """
  pred_rank = np.ndim(pred)
  if pred_rank == np.ndim(on_true):
    return lax.select(pred, on_true, on_false)
  leading_axes = list(range(pred_rank))
  widened = lax.broadcast_in_dim(pred, np.shape(on_true), leading_axes)
  return lax.select(widened, on_true, on_false)