def testNormalize(self):
    """Check `_parallelize` on a normalization, and that its jaxpr mentions psum.

    Compares the parallelized function against the plain function. Then traces
    the parallelized function with `make_jaxpr` and asserts that 'psum' appears
    in the printed jaxpr.
    """
    def normalize(vec):
        # Divide every element by the sum over axis 0.
        return vec / vec.sum(0)

    data = onp.arange(4.)
    want = normalize(data)
    got = _parallelize(normalize)(data)
    self.assertAllClose(got, want, check_dtypes=False)

    # The reduction should show up as a 'psum' in the traced jaxpr.
    traced = make_jaxpr(_parallelize(normalize))(data)
    self.assertIn('psum', repr(traced))
def testOuter3(self):
    """Check that `_parallelize` preserves a broadcasting outer product of two inputs."""
    lhs = onp.arange(10)
    rhs = onp.arange(10) * 2

    def outer(a, b):
        # Insert a trailing axis on `a` so that it broadcasts against `b`.
        return a[:, None] * b

    want = outer(lhs, rhs)
    got = _parallelize(outer)(lhs, rhs)
    self.assertAllClose(got, want, check_dtypes=False)
def testAdd2(self):
    """Check that `_parallelize` preserves a function that closes over a constant array.

    Only `other` is passed as an argument. `base` is captured by the closure.
    """
    base = onp.arange(10)
    other = onp.arange(10) * 2

    def add_base(v):
        return base + v

    want = add_base(other)
    got = _parallelize(add_base)(other)
    self.assertAllClose(got, want, check_dtypes=False)
def testTransposeAndAddRank3(self):
    """Check that `_parallelize` preserves `x + x.T` for a rank-3 float32 array."""
    def add_transpose(arr):
        return arr + arr.T

    # Build a (2, 2, 2) float32 array with values 0..7.
    cube = onp.arange(8., dtype=onp.float32).reshape((2, 2, 2))
    want = add_transpose(cube)
    got = _parallelize(add_transpose)(cube)
    self.assertAllClose(got, want, check_dtypes=False)
def testTranspose(self, shape, perm):
    """Check that `_parallelize` preserves `lax.transpose` for a given shape and permutation.

    `shape` and `perm` are supplied by the test parameterization, which is
    outside this view. Presumably `perm` is a valid permutation of
    range(len(shape)); confirm against the decorator.
    """
    def permute(arr):
        return lax.transpose(arr, perm)

    # Fill the array with 0..prod(shape)-1 so that every element is distinct.
    values = onp.reshape(onp.arange(prod(shape)), shape)
    want = permute(values)
    got = _parallelize(permute)(values)
    self.assertAllClose(got, want, check_dtypes=False)
def testCall(self):
    """Check that `_parallelize` handles a `jit`-wrapped identity function.

    The function body is a plain identity. The test exercises applying
    `_parallelize` to an already-jitted callable.
    """
    @jit
    def identity(arr):
        return arr

    cube = onp.arange(8., dtype=onp.float32).reshape((2, 2, 2))
    want = identity(cube)
    got = _parallelize(identity)(cube)
    self.assertAllClose(got, want, check_dtypes=False)