Exemplo n.º 1
0
  def testRootsInvalid(self, zeros, nonzeros, dtype, rng_factory):
    rng = rng_factory(np.random.RandomState(0))

    # The polynomial coefficients here start with zero and would have to
    # be stripped before computing eigenvalues of the companion matrix.
    # Setting strip_zeros=False skips this check,
    # allowing jit transformation but yielding nan's for these inputs.
    p = jnp.concatenate([jnp.zeros(zeros, dtype), rng((nonzeros,), dtype)])

    if p.size == 1:
      # polynomial = const has no roots
      self.assertTrue(jnp.roots(p, strip_zeros=False).size == 0)
    else:
      self.assertTrue(jnp.any(jnp.isnan(jnp.roots(p, strip_zeros=False))))
Exemplo n.º 2
0
  def testRoots(self, dtype, rng_factory, length, leading, trailing):
    rng = rng_factory(np.random.RandomState(0))

    def args_maker():
      p = rng((length,), dtype)
      return jnp.concatenate(
        [jnp.zeros(leading, p.dtype), p, jnp.zeros(trailing, p.dtype)]),

    jnp_fn = lambda arg: jnp.sort(jnp.roots(arg))
    np_fn = lambda arg: np.sort(np.roots(arg))
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=3e-6)
Exemplo n.º 3
0
    def testRoots(self, dtype, rng_factory, length, leading, trailing):
        rng = rng_factory()

        def args_maker():
            p = rng((length, ), dtype)
            return np.concatenate(
                [np.zeros(leading, p.dtype), p,
                 np.zeros(trailing, p.dtype)]),

        # order may differ (np.sort doesn't deal with complex numbers)
        np_fn = lambda arg: onp.sort(np.roots(arg))
        onp_fn = lambda arg: onp.sort(onp.roots(arg))
        self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=False)
Exemplo n.º 4
0
  def testRootsNostrip(self, length, dtype, rng_factory, trailing):
    rng = rng_factory(np.random.RandomState(0))

    def args_maker():
      p = rng((length,), dtype)
      if length != 0:
        return jnp.concatenate([p, jnp.zeros(trailing, p.dtype)]),
      else:
        # adding trailing would make input invalid (start with zeros)
        return p,

    jnp_fn = lambda arg: jnp.sort(jnp.roots(arg, strip_zeros=False))
    np_fn = lambda arg: np.sort(np.roots(arg))
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker,
                            check_dtypes=False, tol=1e-6)
Exemplo n.º 5
0
    def testRootsNostrip(self, length, dtype, rng_factory, trailing):
        rng = rng_factory(onp.random.RandomState(0))

        def args_maker():
            p = rng((length, ), dtype)
            if length != 0:
                return np.concatenate([p, np.zeros(trailing, p.dtype)]),
            else:
                # adding trailing would make input invalid (start with zeros)
                return p,

        # order may differ (np.sort doesn't deal with complex numbers)
        np_fn = lambda arg: onp.sort(np.roots(arg, strip_zeros=False))
        onp_fn = lambda arg: onp.sort(onp.roots(arg))
        self._CheckAgainstNumpy(onp_fn,
                                np_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-6)
Exemplo n.º 6
0
def roots(p, *, strip_zeros=True):
  """Return the roots of the polynomial whose coefficients are ``p``.

  ``p`` is unwrapped with ``_remove_jaxarray`` and passed to ``jnp.roots``.
  The result is wrapped back into a ``JaxArray``.

  Args:
    p: polynomial coefficients, in the same convention as ``jnp.roots``.
    strip_zeros: forwarded to ``jnp.roots``. The default ``True`` preserves
      this wrapper's previous behavior. ``False`` skips stripping leading
      zero coefficients, which permits jit tracing but yields NaN roots for
      inputs that start with zeros.

  Returns:
    A ``JaxArray`` holding the roots.
  """
  p = _remove_jaxarray(p)
  return JaxArray(jnp.roots(p, strip_zeros=strip_zeros))