Ejemplo n.º 1
0
def roots(p, *, strip_zeros=True):
    """Compute the roots of a polynomial from its coefficient vector.

    Args:
      p: polynomial coefficients; must be rank-1 after ``atleast_1d``.
        Trailing zero coefficients are counted as roots at 0, so the
        coefficients are expected highest power first.
      strip_zeros: when True (default), leading and trailing zero
        coefficients are removed before calling ``_roots_no_zeros``.
        When False, ``p`` is handed to ``_roots_no_zeros`` unchanged. This
        is unsafe when ``p`` has leading zeros, because they are not removed.

    Returns:
      An array of roots. One zero root per trailing zero coefficient is
      appended after the remaining roots. If the stripped polynomial has
      fewer than two coefficients, only those zero roots are returned, in
      ``p``'s dtype. An empty array is returned when every coefficient is
      zero, or when ``strip_zeros`` is False and ``p`` has at most one
      element.

    Raises:
      ValueError: if ``p`` is not rank-1.
    """
    # ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251
    coeffs = atleast_1d(p)
    if coeffs.ndim != 1:
        raise ValueError("Input must be a rank-1 array.")

    if not strip_zeros:
        # Unsafe path: leading zeros are left in place.
        return _roots_no_zeros(coeffs) if coeffs.size > 1 else array([])

    if all(coeffs == 0):
        return array([])

    lo, hi = _nonzero_range(coeffs)
    # Each trailing zero coefficient contributes one root at 0.
    zero_root_count = coeffs.size - hi
    trimmed = coeffs[lo:hi]

    if trimmed.size < 2:
        return zeros(zero_root_count, trimmed.dtype)
    return hstack((_roots_no_zeros(trimmed),
                   zeros(zero_root_count, trimmed.dtype)))
Ejemplo n.º 2
0
def polyint(p, m=1, k=None):
  """Return the coefficients of the ``m``-th antiderivative of a polynomial.

  Args:
    p: polynomial coefficients, highest power first (the divisor grid below
      pairs ``p[0]`` with the highest output power).
    m: order of integration. Must be a concrete, non-negative integer
      (converted with ``operator.index``).
    k: integration constants. A scalar or length-1 array is broadcast to
      length ``m``; otherwise the length must be exactly ``m``. Defaults to 0.

  Returns:
    Coefficients of length ``len(p) + m`` in an inexact dtype. When
    ``m == 0``, ``p`` is returned after dtype promotion.

  Raises:
    ValueError: if ``m`` is negative or ``k`` has an incompatible shape.
  """
  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
  k = 0 if k is None else k
  _check_arraylike("polyint", p, k)
  p, k = _promote_dtypes_inexact(p, k)
  if m < 0:
    raise ValueError("Order of integral must be positive (see polyder)")
  k = atleast_1d(k)
  if len(k) == 1:
    k = full((m,), k[0])
  if k.shape != (m,):
    raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
  if m == 0:
    return p
  else:
    # grid[r, j] = j - r for output power j (ascending) and r in [0, m).
    # Taking prod over r of max(1, j - r) gives the falling factorial
    # j*(j-1)*...*(j-m+1) for j >= m, and the matching divisor for the
    # integration-constant slots. Build it in p's (inexact) dtype: with
    # arange's default integer dtype these products overflow and wrap
    # around for modest len(p) + m.
    grid = (arange(len(p) + m, dtype=p.dtype)[np.newaxis, :]
            - arange(m, dtype=p.dtype)[:, np.newaxis])
    # Reverse so the divisors run from the highest power down, matching p.
    coeff = maximum(1, grid).prod(0)[::-1]
    return true_divide(concatenate((p, k)), coeff)
Ejemplo n.º 3
0
def roots(p, *, strip_zeros=True):
    """Return the roots of the polynomial with coefficients ``p``.

    Args:
      p: array-like of polynomial coefficients, highest power first (leading
        zeros are treated as non-contributing and may be stripped). Promoted
        to an inexact dtype and required to be rank-1.
      strip_zeros: if True (default), leading zero coefficients are sliced
        off before calling ``_roots_no_zeros``. This needs a concrete count
        of leading zeros, so it fails inside a JIT-compiled context. If
        False, ``p`` and the (possibly traced) leading-zero count are passed
        to ``_roots_with_zeros``. According to the error message below,
        leading zeros can then leave some returned roots set to NaN.

    Returns:
      An empty array of the complex counterpart of ``p``'s dtype when ``p``
      has fewer than two coefficients. Otherwise, the result of
      ``_roots_no_zeros`` / ``_roots_with_zeros`` (presumably a complex
      array of roots; confirm against those helpers).

    Raises:
      ValueError: if ``p`` is not rank-1.
      An error from ``core.concrete_or_error`` when ``strip_zeros`` is True
      and the leading-zero count is not concrete (e.g. under ``jit``).
    """
    _check_arraylike("roots", p)
    p = atleast_1d(*_promote_dtypes_inexact(p))
    if p.ndim != 1:
        raise ValueError("Input must be a rank-1 array.")
    if p.size < 2:
        # Zero or one coefficient: no roots to find.
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))
    # argmin over the boolean "is zero" mask finds the first nonzero
    # coefficient. If every coefficient is zero, argmin would report 0, so
    # substitute len(p). The array-level _where keeps this traceable for
    # the strip_zeros=False path.
    num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))

    if strip_zeros:
        # Slicing requires a concrete integer, which is unavailable under JIT.
        num_leading_zeros = core.concrete_or_error(
            int, num_leading_zeros,
            "The error occurred in the jnp.roots() function. To use this within a "
            "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
            "will result in some returned roots being set to NaN.")
        return _roots_no_zeros(p[num_leading_zeros:])
    else:
        return _roots_with_zeros(p, num_leading_zeros)
Ejemplo n.º 4
0
def poly(seq_of_zeros):
    """Build the coefficients of the monic polynomial with the given roots.

    Args:
      seq_of_zeros: array-like of roots (1-D), or a non-empty square 2-D
        array whose eigenvalues are used as the roots. Promoted to an
        inexact dtype.

    Returns:
      Coefficients in order of decreasing power, with a leading coefficient
      of 1. An empty input yields a 0-d array holding 1.

    Raises:
      ValueError: if the input is neither 1-D nor a non-empty square 2-D
        array.
    """
    _check_arraylike('poly', seq_of_zeros)
    promoted, = _promote_dtypes_inexact(seq_of_zeros)
    zs = atleast_1d(promoted)

    shape = zs.shape
    is_nonempty_square = len(shape) == 2 and shape[0] == shape[1] and shape[0] != 0
    if is_nonempty_square:
        # Deferred import: importing linalg at module load would be circular.
        from jax._src.numpy import linalg
        zs = linalg.eigvals(zs)

    if zs.ndim != 1:
        raise ValueError("input must be 1d or non-empty square 2d array.")

    dt = zs.dtype
    if len(zs) == 0:
        return ones((), dtype=dt)

    # Expand prod_i (x - z_i) one linear factor at a time: convolving the
    # running coefficients with [1, -z] multiplies the polynomial by (x - z).
    coeffs = ones((1, ), dtype=dt)
    for z in zs:
        coeffs = convolve(coeffs, array([1, -z], dtype=dt), mode='full')
    return coeffs