Exemplo n.º 1
0
 def body_fun(state):
   """Advance the outer counter and refresh ``out`` via ``inner_loop``.

   ``state`` is a 4-tuple ``(i, num, arr, out)``. The returned tuple has
   ``i + 1`` in place of ``i``, ``num`` and ``arr`` unchanged, and
   ``inner_loop(i, arr, out)`` (called with the *old* ``i``) in place of
   ``out``.
   """
   idx, num, data, result = state
   # Compute the increment first, matching the original evaluation order.
   next_idx = lax.add(idx, 1)
   next_result = inner_loop(idx, data, result)
   return (next_idx, num, data, next_result)
Exemplo n.º 2
0
 def body_fun(state):
   """Feed ``arr[i][j]`` to ``update_entry`` and step the column index.

   ``state`` is ``(i, j, arr, out)``. Returns ``(i, j + 1, arr, new_out)``,
   where ``new_out = update_entry(out, arr[i][j], i, j)``.
   """
   row, col, src, dst = state
   # keepdims=False drops the indexed axis at each step.
   row_slice = lax.dynamic_index_in_dim(src, row, 0, False)
   element = lax.dynamic_index_in_dim(row_slice, col, 0, False)
   new_dst = update_entry(dst, element, row, col)
   next_col = lax.add(col, 1)
   return (row, next_col, src, new_dst)
Exemplo n.º 3
0
 def loop_body(state):
   """Set ``effect[0]`` and step ``(pos, count)`` by ``(1, inc)``."""
   # The flag is written on every call, before the state is unpacked.
   effect[0] = True
   position, total = state
   next_position = lax.add(position, 1)
   next_total = lax.add(total, inc)
   return next_position, next_total
Exemplo n.º 4
0
 def loop_body(state):
   """Set ``effect[0]``, then step ``(pos, count)`` through a jitted closure.

   The jitted helper receives ``pos`` and ``inc`` as arguments but reads
   ``count`` from the enclosing scope (it comes from ``state``), so the
   compiled function closes over a value of this loop body.
   """
   effect[0] = True
   pos, count = state

   def step(p, delta):
     # ``count`` is captured, not passed in, on purpose.
     return (lax.add(p, 1), lax.add(count, delta))

   return api.jit(step)(pos, inc)
Exemplo n.º 5
0
 def body_fun(state):
   """Step ``i`` and replace ``count`` with ``inner_loop(i, count)``.

   ``state`` is ``(num, i, count)``; ``num`` passes through unchanged and
   ``inner_loop`` is called with the old ``i``.
   """
   num, idx, count = state
   next_idx = lax.add(idx, 1)
   next_count = inner_loop(idx, count)
   return (num, next_idx, next_count)
Exemplo n.º 6
0
 def body_fun(state):
   """Increment ``j`` and ``count`` by one; ``i`` passes through unchanged."""
   outer, inner, tally = state
   next_inner = lax.add(inner, 1)
   next_tally = lax.add(tally, 1)
   return outer, next_inner, next_tally
Exemplo n.º 7
0
 def body_fun(i, state):
   """Add ``arr[i]`` to the running total.

   ``state`` is a 3-tuple ``(arr, total, _)``; the third slot is ignored on
   input and returned as ``()``.
   """
   arr, total, _ = state
   # Scalar element at position i along axis 0 (keepdims=False).
   element = lax.dynamic_index_in_dim(arr, i, 0, False)
   new_total = lax.add(total, element)
   return (arr, new_total, ())
Exemplo n.º 8
0
 def loop_body(state):
   """Increment both ``pos`` and ``count`` by one."""
   position, tally = state
   next_position = lax.add(position, 1)
   next_tally = lax.add(tally, 1)
   return next_position, next_tally
Exemplo n.º 9
0
 def body_fun(i, tot):
   """Return ``num + (tot + i)``, with the inner sum evaluated first."""
   partial = lax.add(tot, i)
   return lax.add(num, partial)
Exemplo n.º 10
0
 def body_fun(i, state):
   """Add ``state['arr'][i]`` to ``state['total']``.

   Returns a new dict with keys ``'arr'`` (unchanged) and ``'total'``.
   """
   arr = state['arr']
   total = state['total']
   picked = lax.dynamic_index_in_dim(arr, i, 0, False)
   new_total = lax.add(total, picked)
   return {'arr': arr, 'total': new_total}
Exemplo n.º 11
0
 def body_fun(state):
   """Add ``arr[i]`` to ``total`` and advance ``i`` by one.

   ``state`` is ``(arr, num, i, total)``; ``arr`` and ``num`` pass through.
   """
   arr, num, idx, total = state
   element = lax.dynamic_index_in_dim(arr, idx, 0, False)
   next_idx = lax.add(idx, 1)
   next_total = lax.add(total, element)
   return arr, num, next_idx, next_total
Exemplo n.º 12
0
def expit(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``.

    The argument is first promoted to an inexact dtype via
    ``_promote_args_inexact``; the constant 1 is built to match ``x``.
    """
    # NOTE(review): for very negative x, exp(-x) can overflow to inf. The
    # forward value is then 0, but derivatives through this expression may
    # not be finite. Confirm whether gradients in that range matter.
    [x] = _promote_args_inexact("expit", x)
    one = _lax_const(x, 1)
    denominator = lax.add(one, lax.exp(lax.neg(x)))
    return lax.div(one, denominator)
Exemplo n.º 13
0
def logpdf(x, loc=0, scale=1):
    """Log-density of the uniform distribution on ``[loc, loc + scale]``.

    Returns ``-log(scale)`` where ``loc <= x <= loc + scale`` and ``-inf``
    where ``x > loc + scale`` or ``x < loc``.
    """
    x, loc, scale = _promote_args_inexact("uniform.logpdf", x, loc, scale)
    log_probs = lax.neg(lax.log(scale))
    upper = lax.add(loc, scale)
    # NOTE(review): a NaN x fails both comparisons below, so it receives
    # log_probs rather than NaN. Rewriting this mask as a logical_and of
    # ge/le would flip that case, so the or-of-outside tests is kept.
    outside = logical_or(lax.gt(x, upper), lax.lt(x, loc))
    return where(outside, -inf, log_probs)
Exemplo n.º 14
0
def logpdf(x, loc=0, scale=1):
    """Log-density of the Laplace distribution.

    Computes ``-(|x - loc| / scale + log(2 * scale))`` after promoting the
    arguments to an inexact dtype.
    """
    x, loc, scale = _promote_args_inexact("laplace.logpdf", x, loc, scale)
    two = lax._const(x, 2)
    distance = lax.abs(lax.sub(x, loc))
    linear_term = lax.div(distance, scale)
    log_normalizer = lax.log(lax.mul(two, scale))
    return lax.neg(lax.add(linear_term, log_normalizer))
Exemplo n.º 15
0
 def f():
     """Return ``lax.add`` applied to the float constants 3. and 4."""
     total = lax.add(3., 4.)
     return total