def roots(p, *, strip_zeros=True):
  """Return the roots of a polynomial from its coefficient vector.

  Port of NumPy's ``roots``: a trailing run of zero coefficients is treated
  as that many roots at zero, which are factored out before solving and then
  re-appended to the result.

  NOTE(review): another ``def roots`` appears later in this file; if both are
  at module scope the later one rebinds the name and this version is dead
  code — confirm and remove one of them.

  Args:
    p: rank-1 array-like of coefficients, highest degree first (a trailing
      zero coefficient corresponds to a root at 0).
    strip_zeros: when True, leading and trailing zero coefficients are removed
      before solving and the trailing ones are re-added as explicit zero
      roots. When False, ``p`` goes to ``_roots_no_zeros`` untouched, which is
      unsafe if ``p`` has leading zeros.

  Returns:
    1-D array of roots. It is empty when ``p`` is entirely zero, or when
    ``strip_zeros`` is False and ``p`` has fewer than two coefficients.

  Raises:
    ValueError: if ``p`` is not rank 1 after ``atleast_1d``.
  """
  # ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251
  coeffs = atleast_1d(p)
  if coeffs.ndim != 1:
    raise ValueError("Input must be a rank-1 array.")

  if not strip_zeros:
    # Unsafe shortcut: leading zeros are NOT removed on this path.
    return _roots_no_zeros(coeffs) if coeffs.size > 1 else array([])

  if all(coeffs == 0):
    return array([])

  # [first_nz, stop_nz) spans the nonzero core of the coefficient vector;
  # every zero after it contributes one root at 0.
  first_nz, stop_nz = _nonzero_range(coeffs)
  num_zero_roots = coeffs.size - stop_nz
  trimmed = coeffs[first_nz:stop_nz]

  if trimmed.size >= 2:
    nontrivial = _roots_no_zeros(trimmed)
    # Append the factored-out roots at zero.
    return hstack((nontrivial, zeros(num_zero_roots, trimmed.dtype)))
  # A single nonzero coefficient: the only roots are the zeros factored out.
  return zeros(num_zero_roots, trimmed.dtype)
def polyint(p, m=1, k=None):
  """Return the coefficients of the ``m``-th antiderivative of polynomial ``p``.

  Args:
    p: array-like of polynomial coefficients, highest degree first.
    m: order of integration. Must be concrete (not traced) and convertible
      with ``operator.index``; ``m == 0`` returns the promoted ``p`` unchanged.
    k: integration constants. ``None`` means 0. A scalar or length-1 array is
      broadcast to all ``m`` steps; otherwise it must be a rank-1 array of
      length ``m``. The constants are appended after ``p`` and scaled by the
      same divisors, so ``k[-1]`` becomes the constant term of the result.

  Returns:
    Array of ``len(p) + m`` coefficients in ``p``'s promoted inexact dtype.

  Raises:
    ValueError: if ``m`` is negative, or ``k`` has an incompatible shape.
  """
  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
  k = 0 if k is None else k
  _check_arraylike("polyint", p, k)
  p, k = _promote_dtypes_inexact(p, k)
  if m < 0:
    raise ValueError("Order of integral must be positive (see polyder)")
  k = atleast_1d(k)
  if len(k) == 1:
    k = full((m,), k[0])
  if k.shape != (m,):
    raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
  if m == 0:
    return p
  # Row j holds the exponent each coefficient has before the (j+1)-th
  # integration; clamping to 1 leaves the appended constants unscaled at the
  # steps before they exist. The column-wise product is the total divisor.
  grid = (arange(len(p) + m, 0, -1)[np.newaxis, :] - 1
          - arange(m)[:, np.newaxis])
  # Cast the small integer grid to p's inexact dtype BEFORE multiplying: the
  # products are factorial-like (e.g. 13! for len(p) == 1, m == 13) and would
  # silently overflow the default integer dtype, corrupting the result.
  coeff = maximum(1, grid).astype(p.dtype).prod(0)
  return true_divide(concatenate((p, k)), coeff)
def roots(p, *, strip_zeros=True):
  """Return the roots of the polynomial with coefficients ``p``.

  Args:
    p: rank-1 array-like of coefficients; promoted to an inexact dtype.
    strip_zeros: when True (default), leading zero coefficients are sliced
      off before solving. The slice length must be concrete, so this raises
      under JIT. When False, the leading-zero count stays traced and is passed
      to ``_roots_with_zeros``; as the error message below warns, leading
      zeros then cause some returned roots to be NaN. Trailing zeros are not
      stripped on either path.

  Returns:
    1-D array of roots. For fewer than two coefficients, an empty array of the
    complex counterpart of ``p``'s dtype.

  Raises:
    ValueError: if ``p`` is not rank 1 after ``atleast_1d``.
  """
  _check_arraylike("roots", p)
  p = atleast_1d(*_promote_dtypes_inexact(p))
  if p.ndim != 1:
    raise ValueError("Input must be a rank-1 array.")
  if p.size < 2:
    return array([], dtype=dtypes._to_complex_dtype(p.dtype))
  # argmin over the boolean mask finds the first nonzero coefficient. For an
  # all-zero p, argmin would return 0, so use len(p) (strip everything).
  num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))
  if strip_zeros:
    # Slicing needs a concrete length; fail with guidance when traced.
    num_leading_zeros = core.concrete_or_error(
        int, num_leading_zeros,
        "The error occurred in the jnp.roots() function. To use this within a "
        "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
        "will result in some returned roots being set to NaN.")
    return _roots_no_zeros(p[num_leading_zeros:])
  else:
    return _roots_with_zeros(p, num_leading_zeros)
def poly(seq_of_zeros):
  """Return the coefficients of the monic polynomial with the given roots.

  Args:
    seq_of_zeros: array-like of roots, promoted to an inexact dtype. A
      non-empty square 2-D array is treated as a matrix and replaced by its
      eigenvalues, so the result is that matrix's characteristic polynomial.

  Returns:
    1-D array of ``len(roots) + 1`` coefficients, highest degree first, with
    leading coefficient 1. When there are no roots, a 0-d array equal to 1.

  Raises:
    ValueError: if the (possibly eigenvalue-reduced) input is not 1-D.
  """
  _check_arraylike('poly', seq_of_zeros)
  (zs,) = _promote_dtypes_inexact(seq_of_zeros)
  zs = atleast_1d(zs)

  is_nonempty_square = (
      zs.ndim == 2 and zs.shape[0] == zs.shape[1] and zs.shape[0] != 0)
  if is_nonempty_square:
    # import at runtime to avoid circular import
    from jax._src.numpy import linalg
    zs = linalg.eigvals(zs)

  if zs.ndim != 1:
    raise ValueError("input must be 1d or non-empty square 2d array.")

  coef_dtype = zs.dtype
  num_roots = len(zs)
  if num_roots == 0:
    return ones((), dtype=coef_dtype)

  coeffs = ones((1,), dtype=coef_dtype)
  for i in range(num_roots):
    # Multiply the running product by the linear factor (x - zs[i]).
    linear_factor = array([1, -zs[i]], dtype=coef_dtype)
    coeffs = convolve(coeffs, linear_factor, mode='full')
  return coeffs