def testNested(self):
  """Wrap a function in two nested pjit maps and check the double psum result."""
  def summed_over_both_axes(a, b):
    # `b` is accepted (the mapped function is called with two args) but unused.
    inner = psum(a, 'i')
    return psum(inner, 'j')

  # Inner map names axis 'i', outer map names axis 'j' and emits along axis 1.
  mapped = pjit(pjit(summed_over_both_axes, 'i'), 'j', out_axes=1)
  data = onp.ones((3, 4), onp.float32)
  result = mapped(data, data)
  # Input is all ones of shape (3, 4); 12 == 3 * 4. The (4, 3) output shape
  # presumably comes from out_axes=1 on the outer map -- TODO confirm.
  want = 12 * onp.ones((4, 3), onp.float32)
  self.assertAllClose(result, want, check_dtypes=True)
def testTupleInput(self):
  """Check that pjit handles an argument passed as a one-element tuple."""
  def centered(pair):
    # Only the first tuple element is used.
    first = pair[0]
    return first - psum(first, 'i')

  data = onp.arange(8., dtype=onp.float32).reshape(4, 2)
  mapped = pjit(centered, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
  result = mapped((data,))
  # Reference: subtract the column sums (reduction over axis 0).
  want = data - data.sum(0)
  self.assertAllClose(result, want, check_dtypes=False)
def testForwardModeAutodiff(self):
  """jvp through pjit should agree with jvp through pmap on the same function."""
  def fn(v):
    return np.cos(v - psum(np.sin(v), 'i'))

  primal = np.ones(4)
  # The primal point is reused as the tangent direction.
  tangent = primal
  want = jvp(pmap(fn, 'i'), (primal,), (tangent,))
  result = jvp(pjit(fn, axis_name='i'), (primal,), (tangent,))
  self.assertAllClose(result, want, check_dtypes=False)
def testReverseModeAutodiff(self):
  """grad through pjit should match grad through pmap and an unmapped reference."""
  def fn(v):
    return v - psum(v, 'i')

  v0 = np.ones(4)
  grad_via_pmap = grad(lambda v: np.sum(pmap(fn, 'i')(v)))(v0)
  # Unmapped reference: the psum replaced by a plain sum over the whole vector.
  grad_reference = grad(lambda v: np.sum(v - np.sum(v)))(v0)
  mapped = pjit(fn, axis_name='i')
  result = grad(lambda v: np.sum(mapped(v)))(v0)
  for want in (grad_via_pmap, grad_reference):
    self.assertAllClose(result, want, check_dtypes=False)