コード例 #1
0
    def testNested(self):
        """Nest two pjit wrappers (axis 'i' inside axis 'j') around a double psum.

        The input is a 3x4 all-ones float32 array, passed as both operands.
        The test expects a float32 result of shape (4, 3) in which every
        element equals 12 (= 3 * 4, the total of the input). Dtypes are checked.
        NOTE(review): the transposed (4, 3) layout presumably comes from
        out_axes=1 on the outer 'j' mapping -- confirm against pjit semantics.
        """
        def reduce_both_axes(a, b):
            # Only the first operand is reduced; the second is accepted so the
            # mapped function takes two arguments, but it is never used.
            inner_total = psum(a, 'i')
            return psum(inner_total, 'j')

        mapped = pjit(pjit(reduce_both_axes, 'i'), 'j', out_axes=1)

        ones = onp.ones((3, 4), onp.float32)
        result = mapped(ones, ones)
        want = onp.full((4, 3), 12, onp.float32)
        self.assertAllClose(result, want, check_dtypes=True)
コード例 #2
0
 def testTupleInput(self):
     """pjit over axis 'i' should accept a tuple as its argument.

     The mapped function receives a 1-tuple and uses only its first
     element, subtracting that element's psum over 'i'. The result is
     compared with the NumPy reference ``x - x.sum(0)`` for a 4x2 float32
     input, with dtype checking disabled.
     """
     # A named def instead of a lambda bound to a name (PEP 8 / E731).
     def center(xs):
         # xs is the 1-tuple passed below; only xs[0] participates.
         first = xs[0]
         return first - psum(first, 'i')

     x = onp.arange(8., dtype=onp.float32).reshape(4, 2)
     mapped = pjit(center, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
     ans = mapped((x,))
     expected = x - x.sum(0)
     self.assertAllClose(ans, expected, check_dtypes=False)
コード例 #3
0
    def testForwardModeAutodiff(self):
        """jvp through pjit should match jvp through pmap for the same function.

        The function combines elementwise ops with a psum over axis 'i'.
        Both mapped versions are differentiated with jvp at np.ones(4), using
        np.ones(4) as the tangent. The two jvp results are compared with dtype
        checking disabled.
        """
        def fn(v):
            total = psum(np.sin(v), 'i')
            return np.cos(v - total)

        point = np.ones(4)
        primals, tangents = (point,), (point,)

        reference = jvp(pmap(fn, 'i'), primals, tangents)
        ans = jvp(pjit(fn, axis_name='i'), primals, tangents)

        self.assertAllClose(ans, reference, check_dtypes=False)
コード例 #4
0
    def testReverseModeAutodiff(self):
        """grad through pjit should match both a pmap version and a dense reference.

        The mapped function is ``v - psum(v, 'i')``. The gradient of the sum
        of its outputs at np.ones(4) is computed three ways: through pmap,
        through the equivalent unmapped expression ``v - np.sum(v)``, and
        through pjit. The pjit gradient must match each reference, with dtype
        checking disabled.
        """
        def fn(v):
            return v - psum(v, 'i')

        point = np.ones(4)
        via_pmap = grad(lambda v: np.sum(pmap(fn, 'i')(v)))(point)
        via_dense = grad(lambda v: np.sum(v - np.sum(v)))(point)

        mapped = pjit(fn, axis_name='i')
        ans = grad(lambda v: np.sum(mapped(v)))(point)

        for want in (via_pmap, via_dense):
            self.assertAllClose(ans, want, check_dtypes=False)