def testRootsNoStrip(self, dtype, length, leading, trailing):
  """Compares jnp.roots(strip_zeros=False) with a NaN-padded np.roots result.

  The input polynomial is a random coefficient vector with `leading` zeros
  prepended and `trailing` zeros appended. The NumPy reference is padded with
  complex NaN up to ``len(coeffs) - 1`` entries. That padding is assumed to
  match the unstripped output length that jnp.roots produces when
  strip_zeros=False. The jitted function is also checked via
  _CompileAndCheck.
  """
  sample = jtu.rand_some_zero(self.rng())

  def args_maker():
    core = sample((length, ), dtype)
    head = jnp.zeros(leading, core.dtype)
    tail = jnp.zeros(trailing, core.dtype)
    return [jnp.concatenate([head, core, tail])]

  jnp_fun = partial(jnp.roots, strip_zeros=False)

  def np_fun(coeffs):
    complex_dtype = dtypes._to_complex_dtype(coeffs.dtype)
    expected = np.roots(coeffs).astype(complex_dtype)
    missing = len(coeffs) - 1 - len(expected)
    if missing > 0:
      # np.roots strips leading zero coefficients and so returns fewer roots;
      # fill the absent slots with complex NaN to line up with jnp's output.
      expected = np.pad(expected, (0, missing),
                        constant_values=complex(np.nan, np.nan))
    return expected

  # Root order is unspecified, so the comparison is set-based, not elementwise.
  operands = args_maker()
  self.assertSetsAllClose(np_fun(*operands), jnp_fun(*operands))
  self._CompileAndCheck(jnp_fun, args_maker)
def testSparseReshapeMethod(self, shape, new_shape, n_batch, n_dense):
  """Checks that BCOO.reshape agrees with reshaping the equivalent dense array.

  A random int32 array (with some zero entries) of `shape` is converted to
  BCOO using the given batch/dense dimension counts. Both representations are
  reshaped to `new_shape`, and the densified sparse result must equal the
  dense one exactly.
  """
  dense = jtu.rand_some_zero(self.rng())(shape, 'int32')
  sparse = BCOO.fromdense(dense, n_batch=n_batch, n_dense=n_dense)
  expected = dense.reshape(new_shape)
  reshaped_sparse = sparse.reshape(new_shape)
  self.assertArraysEqual(expected, reshaped_sparse.todense())
def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch, n_dense,
                                    dimensions):
  """Checks sparsified lax.reshape with an explicit `dimensions` argument.

  The same sparsified reshape is applied to a random int32 array and to its
  BCOO conversion. The densified sparse result must match the dense result
  exactly.
  """
  dense = jtu.rand_some_zero(self.rng())(shape, 'int32')
  sparse = BCOO.fromdense(dense, n_batch=n_batch, n_dense=n_dense)

  def reshape_fn(x):
    return lax.reshape(x, new_shape, dimensions=dimensions)

  f = self.sparsify(reshape_fn)
  self.assertArraysEqual(f(dense), f(sparse).todense())
def testRoots(self, dtype, length, leading, trailing):
  """Compares jnp.roots (default zero stripping) against np.roots.

  The input is a random coefficient vector of `length` entries, with
  `leading` zeros prepended and `trailing` zeros appended. The NumPy result is
  cast to the complex counterpart of the input dtype before comparison. There
  is no compile check here, unlike the strip_zeros=False variant.
  """
  sample = jtu.rand_some_zero(self.rng())

  def args_maker():
    core = sample((length, ), dtype)
    head = jnp.zeros(leading, core.dtype)
    tail = jnp.zeros(trailing, core.dtype)
    return [jnp.concatenate([head, core, tail])]

  jnp_fun = jnp.roots

  def np_fun(coeffs):
    return np.roots(coeffs).astype(dtypes._to_complex_dtype(coeffs.dtype))

  # Root order is unspecified, so the comparison is set-based, not elementwise.
  operands = args_maker()
  self.assertSetsAllClose(np_fun(*operands), jnp_fun(*operands))
def testSparseConcatenate(self, shapes, func, n_batch):
  """Checks a sparsified `jnp.<func>` applied to a list of BCOO arrays.

  One random int32 array is built per entry in `shapes`, and each is converted
  to BCOO with `n_batch` batch dimensions. The function is assumed to be a
  list-joining op such as concatenate (TODO confirm against the
  parameterization). Its sparsified version, applied to the dense list and to
  the sparse list, must give identical results once the sparse output is
  densified.
  """
  joined = self.sparsify(getattr(jnp, func))
  sample = jtu.rand_some_zero(self.rng())
  dense_inputs = [sample(s, 'int32') for s in shapes]
  sparse_inputs = [BCOO.fromdense(x, n_batch=n_batch) for x in dense_inputs]
  self.assertArraysEqual(joined(dense_inputs), joined(sparse_inputs).todense())