def testRootsInvalid(self, zeros, nonzeros, dtype, rng_factory):
  """Check jnp.roots(..., strip_zeros=False) on coefficients with leading zeros.

  The coefficient vector is `zeros` zeros followed by `nonzeros` random
  values. Such leading zeros would normally have to be stripped before the
  eigenvalues of the companion matrix are computed. Passing
  strip_zeros=False skips that check, which allows jit transformation but
  yields NaNs for these inputs.
  """
  rng = rng_factory(np.random.RandomState(0))
  p = jnp.concatenate([jnp.zeros(zeros, dtype), rng((nonzeros,), dtype)])
  result = jnp.roots(p, strip_zeros=False)
  if p.size == 1:
    # A single coefficient is a constant polynomial, which has no roots.
    # assertEqual reports the actual size on failure, unlike assertTrue(==).
    self.assertEqual(result.size, 0)
  else:
    self.assertTrue(jnp.any(jnp.isnan(result)))
def testRoots(self, dtype, rng_factory, length, leading, trailing):
  """Compare sorted jnp.roots output with sorted np.roots output.

  Each input is `length` random coefficients, padded with `leading` zeros in
  front and `trailing` zeros behind.
  """
  rng = rng_factory(np.random.RandomState(0))

  def args_maker():
    coeffs = rng((length,), dtype)
    front = jnp.zeros(leading, coeffs.dtype)
    back = jnp.zeros(trailing, coeffs.dtype)
    # A 1-tuple: the single positional argument passed to both functions.
    return (jnp.concatenate([front, coeffs, back]),)

  # Root ordering can differ between implementations, so compare sorted results.
  def jnp_fn(arg):
    return jnp.sort(jnp.roots(arg))

  def np_fn(arg):
    return np.sort(np.roots(arg))

  self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker,
                          check_dtypes=False, tol=3e-6)
def testRoots(self, dtype, rng_factory, length, leading, trailing):
  """Compare sorted roots from `np.roots` against sorted `onp.roots`.

  In this block `np` appears to be jax.numpy and `onp` plain NumPy (inferred
  from the np/onp pairing; confirm against the file's imports). Each input is
  `length` random coefficients padded with `leading` and `trailing` zeros.
  """
  rng = rng_factory()

  def args_maker():
    coeffs = rng((length, ), dtype)
    padded = np.concatenate(
        [np.zeros(leading, coeffs.dtype), coeffs,
         np.zeros(trailing, coeffs.dtype)])
    return (padded,)

  # Root order may differ between implementations, so both sides are sorted
  # with onp.sort (np.sort doesn't deal with complex numbers here).
  def np_fn(arg):
    return onp.sort(np.roots(arg))

  def onp_fn(arg):
    return onp.sort(onp.roots(arg))

  self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=False)
def testRootsNostrip(self, length, dtype, rng_factory, trailing):
  """Compare jnp.roots(strip_zeros=False) with np.roots on inputs with no leading zeros.

  Trailing zeros are appended only when there is at least one coefficient.
  """
  rng = rng_factory(np.random.RandomState(0))

  def args_maker():
    coeffs = rng((length,), dtype)
    if length == 0:
      # Adding trailing zeros would make the input invalid (it would then
      # start with zeros).
      return (coeffs,)
    return (jnp.concatenate([coeffs, jnp.zeros(trailing, coeffs.dtype)]),)

  def jnp_fn(arg):
    return jnp.sort(jnp.roots(arg, strip_zeros=False))

  def np_fn(arg):
    return np.sort(np.roots(arg))

  self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker,
                          check_dtypes=False, tol=1e-6)
def testRootsNostrip(self, length, dtype, rng_factory, trailing):
  """Compare `np.roots(..., strip_zeros=False)` against `onp.roots`.

  Here `np` is jax.numpy (only it accepts the strip_zeros keyword) and `onp`
  is NumPy. Trailing zeros are appended only when length is nonzero.
  """
  rng = rng_factory(onp.random.RandomState(0))

  def args_maker():
    coeffs = rng((length, ), dtype)
    if length == 0:
      # Adding trailing zeros would make the input invalid (it would then
      # start with zeros).
      return (coeffs,)
    return (np.concatenate([coeffs, np.zeros(trailing, coeffs.dtype)]),)

  # Root order may differ between implementations, so both sides are sorted
  # with onp.sort (np.sort doesn't deal with complex numbers here).
  def np_fn(arg):
    return onp.sort(np.roots(arg, strip_zeros=False))

  def onp_fn(arg):
    return onp.sort(onp.roots(arg))

  self._CheckAgainstNumpy(onp_fn, np_fn, args_maker,
                          check_dtypes=False, tol=1e-6)
def roots(p, *, strip_zeros=True):
  """Return the roots of the polynomial whose coefficients are `p`.

  Thin wrapper over ``jnp.roots``. It unwraps `p` with ``_remove_jaxarray``
  and wraps the result in ``JaxArray``.

  Args:
    p: polynomial coefficients, in the form accepted by ``jnp.roots``.
    strip_zeros: forwarded to ``jnp.roots``. With the default (True), leading
      zero coefficients are removed first, which matches the previous
      behaviour of this wrapper. False skips that step, which allows jit
      transformation but can yield NaNs when `p` has leading zeros.

  Returns:
    A ``JaxArray`` holding the roots.
  """
  p = _remove_jaxarray(p)
  return JaxArray(jnp.roots(p, strip_zeros=strip_zeros))