コード例 #1
0
ファイル: parallel_test.py プロジェクト: zhucer2003/jax
    def testNormalize(self):
        def f(x):
            return x / x.sum(0)

        x = onp.arange(4.)
        expected = f(x)
        ans = _parallelize(f)(x)
        self.assertAllClose(ans, expected, check_dtypes=False)

        jaxpr = make_jaxpr(_parallelize(f))(x)
        self.assertIn('psum', repr(jaxpr))
コード例 #2
0
 def testOuter3(self):
   x = onp.arange(10)
   y = 2 * onp.arange(10)
   def f(x, y): return x[:, None] * y
   expected = f(x, y)
   ans = _parallelize(f)(x, y)
   self.assertAllClose(ans, expected, check_dtypes=False)
コード例 #3
0
 def testAdd2(self):
   x = onp.arange(10)
   y = 2 * onp.arange(10)
   def f(y): return x + y
   expected = f(y)
   ans = _parallelize(f)(y)
   self.assertAllClose(ans, expected, check_dtypes=False)
コード例 #4
0
ファイル: parallel_test.py プロジェクト: zhucer2003/jax
    def testTransposeAndAddRank3(self):
        def fun(x):
            return x + x.T

        x = onp.reshape(onp.arange(8., dtype=onp.float32), (2, 2, 2))
        expected = fun(x)
        ans = _parallelize(fun)(x)
        self.assertAllClose(ans, expected, check_dtypes=False)
コード例 #5
0
ファイル: parallel_test.py プロジェクト: zhucer2003/jax
    def testTranspose(self, shape, perm):
        def fun(x):
            return lax.transpose(x, perm)

        x = onp.arange(prod(shape)).reshape(shape)
        expected = fun(x)
        ans = _parallelize(fun)(x)
        self.assertAllClose(ans, expected, check_dtypes=False)
コード例 #6
0
ファイル: parallel_test.py プロジェクト: zhucer2003/jax
    def testCall(self):
        @jit
        def fun(x):
            return x

        x = onp.reshape(onp.arange(8., dtype=onp.float32), (2, 2, 2))
        expected = fun(x)
        ans = _parallelize(fun)(x)
        self.assertAllClose(ans, expected, check_dtypes=False)