コード例 #1
0
ファイル: polynomial_test.py プロジェクト: romanngg/jax
    def testRootsNoStrip(self, dtype, length, leading, trailing):
        """Compare ``jnp.roots(..., strip_zeros=False)`` with a NumPy reference.

        The test polynomial is ``length`` random coefficients (some may be
        zero) wrapped in ``leading`` zeros in front and ``trailing`` zeros
        behind. The NumPy reference may return fewer than
        ``len(coeffs) - 1`` roots. In that case it is padded with complex NaN
        values up to that length before being compared, as a set, with the
        JAX result.
        """
        sample = jtu.rand_some_zero(self.rng())

        def args_maker():
            core = sample((length, ), dtype)
            pieces = [
                jnp.zeros(leading, core.dtype),
                core,
                jnp.zeros(trailing, core.dtype),
            ]
            return [jnp.concatenate(pieces)]

        jnp_fun = partial(jnp.roots, strip_zeros=False)

        def np_fun(coeffs):
            found = np.roots(coeffs).astype(
                dtypes._to_complex_dtype(coeffs.dtype))
            missing = len(coeffs) - 1 - len(found)
            if missing <= 0:
                return found
            # Fill the absent slots with NaN roots so the reference has the
            # fixed length expected from the non-stripping JAX call.
            return np.pad(found, (0, missing),
                          constant_values=complex(np.nan, np.nan))

        # Roots carry no defined order, so an order-insensitive comparison is used.
        args = args_maker()
        expected = np_fun(*args)
        actual = jnp_fun(*args)
        self.assertSetsAllClose(expected, actual)
        self._CompileAndCheck(jnp_fun, args_maker)
コード例 #2
0
    def testSparseReshapeMethod(self, shape, new_shape, n_batch, n_dense):
        """Check that ``BCOO.reshape`` matches reshaping the dense array.

        A random int32 array (with some zeros) is converted to BCOO using the
        given ``n_batch``/``n_dense`` layout. Both forms are reshaped to
        ``new_shape``, and the densified sparse result must equal the dense one
        exactly.
        """
        dense = jtu.rand_some_zero(self.rng())(shape, 'int32')
        sparse = BCOO.fromdense(dense, n_batch=n_batch, n_dense=n_dense)

        expected = dense.reshape(new_shape)
        actual = sparse.reshape(new_shape).todense()
        self.assertArraysEqual(expected, actual)
コード例 #3
0
    def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch,
                                        n_dense, dimensions):
        """Check a sparsified ``lax.reshape`` with ``dimensions`` against dense.

        The same ``self.sparsify``-wrapped reshape is applied to a dense int32
        array and to its BCOO form. The densified sparse output must equal the
        dense output exactly.
        """
        make = jtu.rand_some_zero(self.rng())
        dense = make(shape, 'int32')
        sparse = BCOO.fromdense(dense, n_batch=n_batch, n_dense=n_dense)

        def reshape_op(x):
            return lax.reshape(x, new_shape, dimensions=dimensions)

        sparse_reshape = self.sparsify(reshape_op)
        self.assertArraysEqual(sparse_reshape(dense),
                               sparse_reshape(sparse).todense())
コード例 #4
0
ファイル: polynomial_test.py プロジェクト: romanngg/jax
    def testRoots(self, dtype, length, leading, trailing):
        """Compare ``jnp.roots`` (default arguments) with ``np.roots``.

        The polynomial is ``length`` random coefficients (some may be zero)
        padded with ``leading`` and ``trailing`` zeros. The NumPy result is cast
        to the complex dtype that corresponds to the input dtype and compared,
        as a set, with the JAX result.

        NOTE(review): unlike the ``strip_zeros=False`` variant, this test does
        not call ``_CompileAndCheck``. That is presumably because stripping
        zeros makes the output length data-dependent; confirm before adding it.
        """
        sample = jtu.rand_some_zero(self.rng())

        def args_maker():
            core = sample((length, ), dtype)
            return [
                jnp.concatenate([
                    jnp.zeros(leading, core.dtype),
                    core,
                    jnp.zeros(trailing, core.dtype),
                ])
            ]

        def np_fun(coeffs):
            complex_dtype = dtypes._to_complex_dtype(coeffs.dtype)
            return np.roots(coeffs).astype(complex_dtype)

        # Roots carry no defined order, so an order-insensitive comparison is used.
        (poly, ) = args_maker()
        self.assertSetsAllClose(np_fun(poly), jnp.roots(poly))
コード例 #5
0
 def testSparseConcatenate(self, shapes, func, n_batch):
     """Check a sparsified ``jnp.<func>`` on BCOO inputs against dense inputs.

     One random int32 array (with some zeros) is drawn per entry of
     ``shapes``. Each array is also converted to BCOO with ``n_batch`` batch
     dimensions. The sparsified ``jnp`` function named by ``func`` is applied
     to both lists, and the densified sparse result must equal the dense one.
     """
     sparse_fn = self.sparsify(getattr(jnp, func))
     make = jtu.rand_some_zero(self.rng())
     dense_inputs = [make(s, 'int32') for s in shapes]
     sparse_inputs = [BCOO.fromdense(x, n_batch=n_batch) for x in dense_inputs]
     expected = sparse_fn(dense_inputs)
     actual = sparse_fn(sparse_inputs).todense()
     self.assertArraysEqual(expected, actual)