コード例 #1
0
ファイル: jet.py プロジェクト: jamestwebber/jax
def _abs_taylor_rule(x, series_in, **params):
  """Taylor-series (jet) rule for ``lax.abs_p``.

  Args:
    x: 1-tuple holding the primal input.
    series_in: list with one entry per primal input (here exactly one), each
      entry being the list of higher-order series terms for that input.
    **params: primitive parameters; forwarded only to ``lax.abs_p.bind``.

  Returns:
    ``(primal_out, series_out)``: ``primal_out`` is ``abs_p`` applied to ``x``,
    and each output term is the matching input term multiplied elementwise by
    -1 where ``x < 0`` and by +1 elsewhere (including ``x == 0``).
  """
  x, = x
  zero = lax.full_like(x, 0, shape=())
  primal_out = lax.abs_p.bind(x, **params)
  # +1 is chosen at x == 0 (lax.sign would yield 0 there and zero the series).
  negs = lax.select(lax.lt(x, zero), lax.full_like(x, -1), lax.full_like(x, 1))
  fix_sign = lambda y: negs * y
  # `params` belong to the primitive, not to this elementwise scaling; passing
  # them to `fix_sign` would raise TypeError whenever they are non-empty.
  series_out = [fix_sign(*terms_in) for terms_in in zip(*series_in)]
  return primal_out, series_out
コード例 #2
0
ファイル: util.py プロジェクト: xueeinstein/jax
def _where(condition, x=None, y=None):
  """Three-argument ``where``: pick from ``x`` where ``condition`` holds, else ``y``.

  Args:
    condition: array-like; non-boolean input is converted via
      ``lax.ne(condition, lax_internal._zero(condition))``.
    x, y: array-likes to select from; both are required.

  Returns:
    ``lax.select`` of ``x``/``y`` after ``_promote_dtypes`` and
    ``_broadcast_arrays``. If the broadcast shape is statically empty,
    the (empty) broadcast ``x`` is returned without calling ``lax.select``.

  Raises:
    ValueError: if ``x`` or ``y`` is ``None``.
  """
  if x is None or y is None:
    raise ValueError("Either both or neither of the x and y arguments should "
                     "be provided to jax.numpy.where, got {} and {}."
                     .format(x, y))
  if not np.issubdtype(_dtype(condition), np.bool_):
    condition = lax.ne(condition, lax_internal._zero(condition))
  x, y = _promote_dtypes(x, y)
  condition, x, y = _broadcast_arrays(condition, x, y)
  try:
    is_always_empty = core.is_empty_shape(np.shape(x))
  except Exception:  # can fail with dynamic shapes; don't trap KeyboardInterrupt etc.
    is_always_empty = False
  return lax.select(condition, x, y) if not is_always_empty else x
コード例 #3
0
ファイル: jet.py プロジェクト: jamestwebber/jax
def _lax_min_taylor_rule(primal_in, series_in):
    """Taylor-series (jet) rule for an elementwise minimum of two inputs.

    The primal output is taken from ``x`` where ``x < y`` and from ``y``
    elsewhere. For every series order, the output term follows the same
    choice, except where ``x == y`` (ties): there it is the average of the
    two input terms.

    Args:
      primal_in: pair ``(x, y)`` of primal inputs.
      series_in: pair of series-term lists, one for ``x`` and one for ``y``.

    Returns:
      ``(primal_out, series_out)``.
    """
    x, y = primal_in
    x_lt_y = x < y   # less-than mask: where x is the minimum
    x_eq_y = x == y  # tie mask
    primal_out = lax.select(x_lt_y, x, y)

    def select_min_and_avg_eq(x_i, y_i):
        """Take x_i where x < y, y_i otherwise, and the average where x == y."""
        min_i = lax.select(x_lt_y, x_i, y_i)
        return lax.select(x_eq_y, (x_i + y_i) / 2, min_i)

    series_out = [select_min_and_avg_eq(*terms_in) for terms_in in zip(*series_in)]
    return primal_out, series_out
コード例 #4
0
 def select_min_and_avg_eq(x_i, y_i):
     """Blend two series terms using masks captured from the enclosing scope.

     Takes ``x_i`` where ``xgy`` is true and ``y_i`` elsewhere, then overrides
     positions where ``xey`` is true with the average ``(x_i + y_i) / 2``.

     NOTE(review): ``xgy`` and ``xey`` are free variables defined in an
     enclosing scope not shown here. The ``min`` in this function's name
     implies ``xgy`` should be an ``x < y`` mask, despite the name reading as
     "x greater than y" — confirm against the caller.
     """
     tie_avg = (x_i + y_i) / 2
     picked = lax.select(xgy, x_i, y_i)
     return lax.select(xey, tie_avg, picked)
コード例 #5
0
def _select_taylor_rule(primal_in, series_in, **params):
    """Taylor-series (jet) rule for ``lax.select_p``.

    Args:
      primal_in: triple ``(b, x, y)`` — predicate and the two branches.
      series_in: per-input series-term lists for ``b``, ``x`` and ``y``.
      **params: primitive parameters; forwarded only to ``select_p.bind``.

    Returns:
      ``(primal_out, series_out)``. Each output term selects between the
      matching ``x``/``y`` terms using the primal predicate ``b``; the
      predicate's own series terms are ignored.
    """
    b, x, y = primal_in
    primal_out = lax.select_p.bind(b, x, y, **params)
    sel = lambda _, x_i, y_i: lax.select(b, x_i, y_i)
    # `params` belong to the primitive; `sel` accepts no keyword arguments, so
    # forwarding them would raise TypeError whenever they are non-empty.
    series_out = [sel(*terms_in) for terms_in in zip(*series_in)]
    return primal_out, series_out
コード例 #6
0
ファイル: windowed_reductions.py プロジェクト: wayfeng/jax
 def reducer(x, y):
     """Combine two ``(key, value)`` pairs elementwise.

     ``select_prim`` (a primitive from the enclosing scope) is bound on the
     two keys to get a mask. Where the mask is true, both key and value are
     taken from ``x``; elsewhere they are taken from ``y``.
     """
     (key_x, val_x), (key_y, val_y) = x, y
     take_x = select_prim.bind(key_x, key_y)
     chosen_key = lax.select(take_x, key_x, key_y)
     chosen_val = lax.select(take_x, val_x, val_y)
     return chosen_key, chosen_val
コード例 #7
0
def assert_func(error: Error, pred: Bool, msg: str) -> Error:
    """Fold a check that fails wherever ``pred`` is false into ``error``.

    Args:
      error: accumulated error state (uses its ``err``, ``code`` and ``msgs``).
      pred: predicate expected to hold.
      msg: message registered under a freshly allocated code from ``next_code``.

    Returns:
      A new ``Error``. Its flag is ``error.err`` OR-ed with ``not pred``. Its
      code keeps ``error.code`` where ``error.err`` is already set and uses the
      new code otherwise. Its messages are ``{new_code: msg}`` plus
      ``error.msgs``, with existing entries winning on a key clash.
    """
    new_code = next_code()
    combined_err = error.err | jnp.logical_not(pred)
    # An earlier failure keeps its own code; otherwise report this check's code.
    combined_code = lax.select(error.err, error.code, new_code)
    registry = {new_code: msg, **error.msgs}
    return Error(combined_err, combined_code, registry)
コード例 #8
0
def assert_discharge_rule(error, pred, code, *, msgs):
    """Merge an assert with a pre-assigned ``code`` into ``error``.

    Returns:
      An empty list followed by a new ``Error``. Its flag is ``error.err``
      OR-ed with ``not pred``. Its code is ``error.code`` where ``error.err``
      is already set and ``code`` otherwise. Its messages are ``error.msgs``
      updated with ``msgs``, with ``msgs`` winning on a key clash.
    """
    merged_err = error.err | jnp.logical_not(pred)
    # Keep the code of a pre-existing failure; otherwise adopt `code`.
    merged_code = lax.select(error.err, error.code, code)
    merged_msgs = {**error.msgs, **msgs}
    return [], Error(merged_err, merged_code, merged_msgs)
コード例 #9
0
ファイル: conditionals.py プロジェクト: xueeinstein/jax
def _bcast_select(pred, on_true, on_false):
    """Elementwise select that first broadcasts a differently-ranked ``pred``.

    When ``pred``'s rank differs from ``on_true``'s, ``pred`` is expanded with
    ``lax.broadcast_in_dim`` to ``on_true``'s shape. Its dimensions map onto
    the leading dimensions of that shape. The result is passed to
    ``lax.select``.
    """
    pred_rank = np.ndim(pred)
    if pred_rank == np.ndim(on_true):
        return lax.select(pred, on_true, on_false)
    leading_axes = list(range(pred_rank))
    expanded = lax.broadcast_in_dim(pred, np.shape(on_true), leading_axes)
    return lax.select(expanded, on_true, on_false)