def loop_body(state):
    """Take one entry off the agenda and dispatch it to a bucketed branch.

    ``state`` is unpacked as ``(agenda, blocks, eigenvectors)``. ``agenda.pop()``
    yields a pair whose first element is ``(offset, b)`` and whose second element
    is the remaining agenda (presumably a functional stack/queue — confirm
    against the agenda type's definition).

    The branch is chosen from the enclosing-scope ``buckets`` array. Every bucket
    smaller than ``b`` is replaced by the int32 maximum, so ``argmin`` selects the
    smallest bucket value that is ``>= b``. Ties go to the lower index. If all
    buckets are smaller than ``b``, every entry equals the sentinel and index 0 is
    picked.

    Returns:
      Whatever the selected entry of the enclosing-scope ``branches`` returns when
      called with ``(offset, b, agenda, blocks, eigenvectors)`` (presumably the
      next loop state — verify against ``branches``).
    """
    agenda, blocks, eigenvectors = state
    head, remaining = agenda.pop()
    offset, size = head

    # Sentinel used to disqualify buckets that are too small to hold `size`.
    too_small_marker = jnp.iinfo(jnp.int32).max
    candidate_sizes = jnp.where(buckets < size, too_small_marker, buckets)
    branch_index = jnp.argmin(candidate_sizes)

    return lax.switch(
        branch_index, branches, offset, size, remaining, blocks, eigenvectors
    )
def vq(obs, code_book, check_finite=True):
    """Assign every observation to its nearest code-book entry.

    Args:
      obs: observations, of rank 1 or 2. Rank-1 inputs are treated as a batch
        of scalar features.
      code_book: code vectors, with the same rank as ``obs``.
      check_finite: accepted but not used in this body. The name passed to
        ``_check_arraylike`` suggests the signature mirrors
        ``scipy.cluster.vq.vq``.

    Returns:
      ``(code, dist_min)``. ``code[i]`` is the index of the code-book row with
      the smallest ``norm(..., axis=-1)`` distance to ``obs[i]``, and
      ``dist_min[i]`` is that smallest distance.

    Raises:
      ValueError: if the ranks of ``obs`` and ``code_book`` differ, or if the
        common rank is neither 1 nor 2.
    """
    _check_arraylike("scipy.cluster.vq.vq", obs, code_book)
    if obs.ndim != code_book.ndim:
        raise ValueError("Observation and code_book should have the same rank")
    obs, code_book = _promote_dtypes_inexact(obs, code_book)

    if obs.ndim == 1:
        # Give scalar features a trailing axis so both inputs become 2-D.
        obs = obs[..., None]
        code_book = code_book[..., None]
    if obs.ndim != 2:
        raise ValueError("ndim different than 1 or 2 are not supported")

    def _distances_to_code_book(single_obs):
        # single_obs[None] adds a leading axis so the broadcast against every
        # code-book row is explicit rather than relying on implicit rank promotion.
        return norm(single_obs[None] - code_book, axis=-1)

    all_distances = vmap(_distances_to_code_book)(obs)
    nearest = argmin(all_distances, axis=-1)
    nearest_distance = vmap(operator.getitem)(all_distances, nearest)
    return nearest, nearest_distance
def roots(p, *, strip_zeros=True):
    """Return the roots of the polynomial whose coefficients are ``p``.

    Args:
      p: polynomial coefficients. After dtype promotion and ``atleast_1d`` it
        must be rank 1. Zeros at the front of ``p`` are treated as "leading
        zeros" (see ``strip_zeros``).
      strip_zeros: if True, leading zeros are removed before solving. This
        needs a concrete (non-traced) count of leading zeros, so it fails inside
        a JIT-compiled context. If False, the count stays a traced value and
        ``_roots_with_zeros`` handles it. The error message warns that this can
        set some returned roots to NaN.

    Returns:
      The roots. When ``p`` has fewer than two coefficients, the result is an
      empty array whose dtype is the complex counterpart of ``p.dtype``.

    Raises:
      ValueError: if ``p`` is not rank 1.
      An error from ``core.concrete_or_error`` if ``strip_zeros`` is True but
      the leading-zero count is not concrete (presumably jax's
      ConcretizationTypeError — confirm).
    """
    _check_arraylike("roots", p)
    p = atleast_1d(*_promote_dtypes_inexact(p))
    if p.ndim != 1:
        raise ValueError("Input must be a rank-1 array.")
    if p.size < 2:
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))

    # argmin over the boolean mask gives the index of the first non-zero
    # coefficient. An all-zero mask would also give 0, so that case is mapped
    # explicitly to len(p).
    num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))

    if strip_zeros:
        # The slice below needs a concrete Python int, not a traced value.
        num_leading_zeros = core.concrete_or_error(
            int,
            num_leading_zeros,
            "The error occurred in the jnp.roots() function. To use this within a "
            "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
            "will result in some returned roots being set to NaN.")
        return _roots_no_zeros(p[num_leading_zeros:])
    else:
        return _roots_with_zeros(p, num_leading_zeros)
def _nonzero_range(arr):
    """Return ``(start, end)`` such that ``arr[:start]`` and ``arr[end:]`` are all zeros.

    ``start`` is the index of the first non-zero entry, and ``end`` is one past
    the index of the last non-zero entry. Both rely on ``argmin`` of a boolean
    mask returning the first False.

    Note: for an all-zero ``arr`` every mask entry is True, so both ``argmin``
    calls return 0 and the result is ``(0, arr.size)`` rather than an empty
    range. An empty ``arr`` is presumably an error, since ``argmin`` of an
    empty array raises in NumPy — confirm for the ``argmin`` in scope.
    """
    zero_mask = arr == 0
    first_nonzero = argmin(zero_mask)
    # Scan the reversed mask to count trailing zeros.
    trailing_zeros = argmin(zero_mask[::-1])
    one_past_last_nonzero = zero_mask.size - trailing_zeros
    return first_nonzero, one_past_last_nonzero