Example #1
0
File: eigh.py Project: cloudhan/jax
  def loop_body(state):
    """Pop one work item off the agenda and dispatch it to a branch.

    `state` is a triple `(agenda, blocks, eigenvectors)`. `agenda.pop()`
    yields `((offset, b), remaining_agenda)`. The branch index is the
    position of the smallest entry of the enclosing-scope `buckets` that
    is not less than `b`; entries below `b` are masked out with the int32
    maximum so that `argmin` skips them.

    Returns whatever the selected callable in `branches` returns when
    invoked with `(offset, b, agenda, blocks, eigenvectors)`.
    """
    agenda, blocks, eigenvectors = state
    (offset, b), agenda = agenda.pop()

    # Mask buckets that are too small for `b` with a sentinel so the
    # argmin below lands on the smallest bucket that is still >= b.
    too_small = buckets < b
    sentinel = jnp.iinfo(jnp.int32).max
    candidates = jnp.where(too_small, sentinel, buckets)
    branch_index = jnp.argmin(candidates)

    return lax.switch(
        branch_index, branches, offset, b, agenda, blocks, eigenvectors)
Example #2
0
def vq(obs, code_book, check_finite=True):
    """Assign every observation to its closest code-book entry.

    Args:
      obs: array-like of rank 1 or 2 holding the observations.
      code_book: array-like with the same rank as `obs`.
      check_finite: accepted but never read by this implementation.

    Returns:
      A pair `(code, dist_min)`: for each observation, the index of the
      code-book row with the smallest distance (as computed by `norm`),
      and that smallest distance.

    Raises:
      ValueError: if `obs` and `code_book` differ in rank, or if the rank
        is neither 1 nor 2.
    """
    _check_arraylike("scipy.cluster.vq.vq", obs, code_book)
    if obs.ndim != code_book.ndim:
        raise ValueError("Observation and code_book should have the same rank")
    obs, code_book = _promote_dtypes_inexact(obs, code_book)

    if obs.ndim == 1:
        # Rank-1 inputs become columns: one single-feature row per entry.
        obs = obs[..., None]
        code_book = code_book[..., None]
    if obs.ndim != 2:
        raise ValueError("ndim different than 1 or 2 are not supported")

    def _distances_to_codes(ob):
        # ob[None] adds a leading axis so the subtraction against every
        # code-book row is an explicit (not implicit) rank promotion.
        return norm(ob[None] - code_book, axis=-1)

    # One row of distances per observation, one column per code-book entry.
    dist = vmap(_distances_to_codes)(obs)
    code = argmin(dist, axis=-1)
    # Per-row lookup dist[i][code[i]] to pull out the winning distance.
    dist_min = vmap(operator.getitem)(dist, code)
    return code, dist_min
Example #3
0
def roots(p, *, strip_zeros=True):
    """Return the roots of the polynomial whose coefficients are `p`.

    Args:
      p: array-like of coefficients; must be rank 1 after promotion to at
        least one dimension.
      strip_zeros: when True (default), the count of leading zero
        coefficients is converted to a concrete Python int and those
        coefficients are sliced off before solving. This requires a
        concrete value, so it cannot be used under tracing. When False,
        the count is passed on unchanged to `_roots_with_zeros`.

    Returns:
      An empty complex array when `p` has fewer than two entries;
      otherwise the result of `_roots_no_zeros` or `_roots_with_zeros`.

    Raises:
      ValueError: if `p` is not rank 1.
    """
    _check_arraylike("roots", p)
    promoted = _promote_dtypes_inexact(p)
    p = atleast_1d(*promoted)
    if p.ndim != 1:
        raise ValueError("Input must be a rank-1 array.")
    if p.size < 2:
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))

    # argmin over the "is zero" mask gives the index of the first nonzero
    # coefficient. If every coefficient is zero, all of them are leading.
    every_coeff_zero = all(p == 0)
    coeff_count = len(p)
    first_nonzero = argmin(p == 0)
    num_leading_zeros = _where(every_coeff_zero, coeff_count, first_nonzero)

    if not strip_zeros:
        return _roots_with_zeros(p, num_leading_zeros)

    # NOTE(review): the message below reads "will be result in", which
    # looks like a typo for "will result in". It is left untouched here
    # because it is a runtime string.
    num_leading_zeros = core.concrete_or_error(
        int, num_leading_zeros,
        "The error occurred in the jnp.roots() function. To use this within a "
        "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
        "will be result in some returned roots being set to NaN.")
    return _roots_no_zeros(p[num_leading_zeros:])
Example #4
0
def _nonzero_range(arr):
    # return start and end s.t. arr[:start] = 0 = arr[end:] padding zeros
    is_zero = arr == 0
    start = argmin(is_zero)
    end = is_zero.size - argmin(is_zero[::-1])
    return start, end