예제 #1
0
def isinf(x):
    """Element-wise test for positive or negative infinity.

    Real floating inputs are infinite where their absolute value equals
    ``inf``. Complex inputs are infinite where either the real part or the
    imaginary part is. Any other dtype cannot hold an infinity, so an
    all-False boolean array with the shape of ``x`` is returned.
    """
    _check_arraylike("isinf", x)
    kind = dtypes.dtype(x)

    def _abs_is_inf(part):
        # |part| == inf, with the constant built in part's own dtype.
        return lax.eq(lax.abs(part), _constant_like(part, np.inf))

    if dtypes.issubdtype(kind, np.floating):
        return _abs_is_inf(x)
    if dtypes.issubdtype(kind, np.complexfloating):
        return lax.bitwise_or(_abs_is_inf(lax.real(x)),
                              _abs_is_inf(lax.imag(x)))
    return lax.full_like(x, False, dtype=np.bool_)
예제 #2
0
파일: lax_linalg.py 프로젝트: tpanthera/jax
    def body(k, state):
        """Run elimination step ``k`` of an LU factorization with row pivoting.

        NOTE(review): the ``(k, state)`` signature suggests this is the body
        function of a ``lax.fori_loop`` over ``k``. Confirm against the
        enclosing function, which is not visible here.

        Args:
          k: index of the current pivot column (and pivot row).
          state: tuple ``(pivot, perm, a, error)``:
            pivot: ``pivot[k]`` is set to the chosen pivot row ``i``.
            perm: row permutation; entries ``k`` and ``i`` are swapped.
            a: working matrix. It is updated functionally and returned.
            error: OR-accumulated flag that becomes true when a pivot is
              exactly zero.

        Returns:
          The updated ``(pivot, perm, a, error)`` tuple.

        ``m`` and ``n`` come from the enclosing scope. From their use below
        they are the row count and column count of ``a``. The sub-matrix
        updates are written with full-size ``np.where`` masks rather than
        ``a[k+1:, ...]`` slices, so every array keeps the same shape on
        every step.
        """
        pivot, perm, a, error = state
        m_idx = np.arange(m)
        n_idx = np.arange(n)

        # Pivot search in column k. Complex entries are ranked by
        # |Re| + |Im| instead of the modulus.
        # NOTE(review): this presumably mirrors LAPACK's cabs1
        # convention -- confirm.
        if np.issubdtype(a.dtype, np.complexfloating):
            t = a[:, k]
            magnitude = np.abs(np.real(t)) + np.abs(np.imag(t))
        else:
            magnitude = np.abs(a[:, k])
        # Rows above k are excluded from the search by masking them to -inf.
        i = np.argmax(np.where(m_idx >= k, magnitude, -np.inf))
        pivot = ops.index_update(pivot, ops.index[k], i)

        # Swap rows k and i of a. When i == k this is a no-op.
        a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])

        # Apply the same row swap to the permutation vector.
        perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

        # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
        x = a[k, k]
        # The pivot has the largest magnitude at or below row k. If it is
        # zero, the whole column is zero from row k down.
        error = error | lax.eq(x, np._constant_like(a, 0))
        a = ops.index_update(a, ops.index[:, k],
                             np.where(m_idx > k, a[:, k] / x, a[:, k]))

        # a[k+1:, k+1:] -= np.outer(a[k+1:, k], a[k, k+1:])
        a = a - np.where(
            (m_idx[:, None] > k) & (n_idx > k), np.outer(a[:, k], a[k, :]),
            np.array(0, dtype=a.dtype))
        return pivot, perm, a, error
예제 #3
0
def sinc(x):
    """Normalized sinc, ``sin(pi * x) / (pi * x)``, computed element-wise.

    The input is first passed through ``_promote_dtypes_inexact``
    (presumably promoting integers to a floating dtype). Where ``x == 0``
    the result comes from ``_sinc_maclaurin(0, pi_x)``, which is defined
    elsewhere (presumably the series value at the origin -- confirm).
    """
    _check_arraylike("sinc", x)
    (x,) = _promote_dtypes_inexact(x)
    pi_x = lax.mul(_lax_const(x, np.pi), x)
    at_origin = lax.eq(x, _lax_const(x, 0))
    # Replace the zero denominator by 1 so the division branch never
    # computes 0/0. This presumably keeps NaNs out of the branch that is
    # not selected, and out of its gradient.
    denom = _where(at_origin, _lax_const(x, 1), pi_x)
    quotient = lax.div(lax.sin(denom), denom)
    return _where(at_origin, _sinc_maclaurin(0, pi_x), quotient)
예제 #4
0
 def fn(x1, x2):
   """Element-wise comparison of ``x1`` and ``x2`` using ``lax_fn``.

   ``numpy_fn`` and ``lax_fn`` are closed over from the enclosing scope,
   which is not visible here. Both arguments go through ``_promote_args``,
   with ``numpy_fn.__name__`` passed as the op name. Complex operands are
   ordered lexicographically on the (real, imag) pair: ``lax_fn`` compares
   the real parts, except where those are equal, in which case it compares
   the imaginary parts.
   """
   x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
   if not dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
     return lax_fn(x1, x2)
   real1 = lax.real(x1)
   real2 = lax.real(x2)
   # Ties on the real part are broken by the imaginary parts.
   real_tie = lax.eq(real1, real2)
   return lax.select(real_tie,
                     lax_fn(lax.imag(x1), lax.imag(x2)),
                     lax_fn(real1, real2))
예제 #5
0
def _isposneginf(infinity, x, out):
  if out is not None:
    raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.")
  dtype = dtypes.dtype(x)
  if dtypes.issubdtype(dtype, np.floating):
    return lax.eq(x, _constant_like(x, infinity))
  elif dtypes.issubdtype(dtype, np.complexfloating):
    raise ValueError("isposinf/isneginf are not well defined for complex types")
  else:
    return lax.full_like(x, False, dtype=np.bool_)
예제 #6
0
def _power(x1, x2):
  """Element-wise ``x1 ** x2`` after promoting both arguments together.

  Non-integer dtypes are delegated to ``lax.pow``. Integer dtypes use
  square-and-multiply over the low 6 bits of the exponent.

  NOTE(review): negative exponents are not special-cased. Only their low 6
  bits are consumed by the loop below, so confirm this is acceptable for
  the callers.
  """
  x1, x2 = _promote_args("power", x1, x2)
  if dtypes.issubdtype(dtypes.dtype(x1), np.integer):
    # TODO(phawkins): add integer pow support to XLA.
    num_bits = 6  # Anything more would overflow for any x1 > 1
    zero = _constant_like(x2, 0)
    one = _constant_like(x2, 1)
    base, exponent = x1, x2
    # Seed the accumulator so that 0 ** e is 0 for e != 0, while 0 ** 0
    # (like every other base) starts from 1.
    result = _where(lax.bitwise_and(lax.eq(base, zero), lax.ne(exponent, zero)),
                    zero, one)
    for _ in range(num_bits):
      low_bit = lax.bitwise_and(exponent, one)
      result = _where(low_bit, lax.mul(result, base), result)
      base = lax.mul(base, base)
      exponent = lax.shift_right_logical(exponent, one)
    return result
  return lax.pow(x1, x2)
예제 #7
0
파일: nlm.py 프로젝트: bryanhpchiang/nlm
 def compare(patch_y, patch_x):
     """Score one candidate patch, or return ``(0., 0.)`` if it is skipped.

     The patch center is ``(patch_y, patch_x)`` shifted by
     ``filter_radius``. The patch is skipped when its center falls outside
     ``[pad, _h + pad) x [pad, _w + pad)`` (presumably the un-padded image
     area) or coincides with ``(win_center_y, win_center_x)``. Otherwise
     ``_compare`` is applied to the center coordinates. All of these names
     are closed over from an enclosing scope that is not visible here.
     """
     center_y = patch_y + filter_radius
     center_x = patch_x + filter_radius
     # Skip if patch is out of image boundaries or this is the center patch
     outside = (lax.lt(center_y, pad) | lax.ge(center_y, _h + pad)
                | lax.lt(center_x, pad) | lax.ge(center_x, _w + pad))
     is_window_center = (lax.eq(center_y, win_center_y)
                         & lax.eq(center_x, win_center_x))
     skip = outside | is_window_center
     return lax.cond(skip, lambda _: (0., 0.), _compare,
                     (center_y, center_x))