Пример #1
0
 def testNested(self):
     """A psum over both nested pmap axes sums all four ones, in either nesting order."""
     def double_psum(v):
         return psum(psum(v, 'i'), 'j')

     ones = onp.ones((2, 2))
     # (inner axis name, outer axis name) for each nesting order.
     results = [pmap(pmap(double_psum, inner), outer)(ones)
                for inner, outer in (('i', 'j'), ('j', 'i'))]
     target = 4 * onp.ones((2, 2))
     for result in results:
         self.assertAllClose(result, target, check_dtypes=False)
Пример #2
0
 def testSplitBasic(self):
     """After splitting axis 'i' into ('j', 'k'), the psum should match a full sum of sin(x)."""
     def sum_of_sines(v):
         return psum(np.sin(v), 'i')

     data = onp.ones((2, 2))
     split_fn = axisvar_split(sum_of_sines, 'i', ('j', 'k'))
     mapped = pmap(pmap(split_fn, 'j'), 'k')
     result = mapped(data)
     self.assertAllClose(result, onp.sum(onp.sin(data)), check_dtypes=False)
Пример #3
0
 def testTupleInput(self):
     """pjit accepts a one-element tuple argument; the result should be x minus its column sums."""
     def centered(args):
         first = args[0]
         return first - psum(first, 'i')

     data = onp.arange(8., dtype=onp.float32).reshape(4, 2)
     parallel_fn = pjit(centered, axis_name='i', in_axes=0, out_axes=0, mesh_axis=0)
     result = parallel_fn((data, ))
     self.assertAllClose(result, data - data.sum(0), check_dtypes=False)
Пример #4
0
    def testSum(self):
        """Check papply(np.sum) against an explicit psum, both as a jaxpr and numerically.

        papply returns the transformed function together with the axis name it
        introduced; the traced jaxpr of that function must match the jaxpr of
        ``psum(x, axis_name)``, and running it under pmap must give the plain sum.
        """
        pfun, axis_name = papply(np.sum)

        jaxpr = make_jaxpr(pfun)(onp.zeros(5))
        expected_jaxpr = make_jaxpr(lambda x: psum(x, axis_name))(onp.zeros(5))
        # assertEqual instead of a bare `assert`: bare asserts are stripped
        # under `python -O`, which would silently skip this check.
        self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

        ans = pmap(pfun, axis_name)(onp.arange(3.))
        expected = onp.sum(onp.arange(3.))
        self.assertAllClose(ans, expected, check_dtypes=False)
Пример #5
0
    def testLogSoftmax(self):
        """Check papply on a log-softmax: its inner np.sum should become a psum.

        The traced jaxpr of the papplied function must match the jaxpr of the
        same expression written with ``psum(..., axis_name)``, and running it
        under pmap must reproduce the unparallelized result.
        """
        def fun(x):
            return x - np.log(np.sum(np.exp(x)))

        pfun, axis_name = papply(fun)

        jaxpr = make_jaxpr(pfun)(onp.zeros(5))
        expected_jaxpr = make_jaxpr(
            lambda x: x - np.log(psum(np.exp(x), axis_name)))(onp.zeros(5))
        # assertEqual instead of a bare `assert`: bare asserts are stripped
        # under `python -O`, which would silently skip this check.
        self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

        ans = pmap(pfun, axis_name)(onp.arange(1., 5.))
        expected = fun(onp.arange(1., 5.))
        self.assertAllClose(ans, expected, check_dtypes=False)
Пример #6
0
 def testLogSoftmax(self):
     """A log-softmax normalized by psum over 'i' should match the NumPy reference."""
     def log_softmax(v):
         return v - np.log(psum(np.exp(v), 'i'))

     logits = onp.log(onp.arange(1., 10., dtype=onp.float32))
     result = pmap(log_softmax, axis_name='i')(logits)
     reference = logits - onp.log(onp.sum(onp.exp(logits)))
     self.assertAllClose(result, reference, check_dtypes=False)
Пример #7
0
 def testReduceSum(self):
     """A psum over 'i' of four ones should give 4 at every position."""
     result = pmap(lambda v: psum(v, 'i'), axis_name='i')(onp.ones(4))
     self.assertAllClose(result, onp.full(4, 4.), check_dtypes=False)