示例#1
0
 def testAxisName2d(self):
   """Nested ``gmap`` with axis names 'i' and 'j' must agree with nested ``vmap``.

   The mapped function uses collectives over both named axes, so the check
   only passes if each ``gmap`` level binds its ``axis_name`` the same way
   the corresponding ``vmap`` level does.
   """
   def combine(v):
     # Collectives are resolved against the axis names bound by the maps below.
     summed_over_i = lax.psum(v, 'i')
     max_over_j = lax.pmax(v, 'j')
     return v - summed_over_i + max_over_j

   data = jnp.arange(8 * 8).reshape((8, 8))
   # The same schedule object is reused for the inner and the outer gmap.
   schedule = [('vectorized', None)]

   inner_gmapped = gmap(combine, schedule, axis_name='i')
   actual = gmap(inner_gmapped, schedule, axis_name='j')(data)

   inner_vmapped = vmap(combine, axis_name='i')
   expected = vmap(inner_vmapped, axis_name='j')(data)

   self.assertAllClose(actual, expected)
示例#2
0
    def testBasicSchedules(self, schedule):
        """Check that ``gmap`` run with ``schedule`` gives the same result as ``vmap``.

        Args:
            schedule: schedule specification forwarded unchanged to ``gmap``.
        """
        def per_example(mat):
            product = jnp.dot(jnp.sin(mat), mat.T)
            return product * 4 + mat

        # Input of shape (8, 10, 10); both maps are applied with their defaults.
        batch = jnp.arange(800).reshape((8, 10, 10))

        actual = gmap(per_example, schedule)(batch)
        expected = vmap(per_example)(batch)
        self.assertAllClose(actual, expected)
示例#3
0
    def testAxisName(self, schedule):
        """Check that ``gmap`` binds ``axis_name`` for collectives like ``vmap`` does.

        Args:
            schedule: schedule specification forwarded unchanged to ``gmap``.
        """
        def subtract_axis_sum(v):
            # 'i' is the axis name supplied to gmap / vmap below.
            axis_total = lax.psum(v, 'i')
            return v - axis_total

        values = jnp.arange(8)

        actual = gmap(subtract_axis_sum, schedule, axis_name='i')(values)
        expected = vmap(subtract_axis_sum, axis_name='i')(values)
        self.assertAllClose(actual, expected)
示例#4
0
文件: gmap_test.py 项目: xiaoral2/jax
    def testBasicSchedules(self, schedule):
        """Check that ``gmap`` run with ``schedule`` gives the same result as ``vmap``.

        The test is skipped when any 'parallel' entry of the schedule asks for
        more than ``xla_bridge.device_count()`` devices.

        Args:
            schedule: iterable of ``(loop, n)`` pairs forwarded to ``gmap``; an
                ``n`` of ``None`` is treated as the size of the leading axis
                of the input for the purpose of the device check.

        Raises:
            SkipTest: if a 'parallel' entry exceeds the available device count.
        """
        def per_example(mat):
            product = jnp.dot(jnp.sin(mat), mat.T)
            return product * 4 + mat

        batch = jnp.arange(800).reshape((8, 10, 10))

        def exceeds_devices(loop_kind, size):
            # A missing size falls back to the length of the leading axis.
            effective = size if size is not None else batch.shape[0]
            return loop_kind == 'parallel' and effective > xla_bridge.device_count()

        # any() short-circuits, so the check stops at the first offending entry.
        if any(exceeds_devices(loop_kind, size) for loop_kind, size in schedule):
            raise SkipTest("this test requires more XLA devices")

        expected = vmap(per_example)(batch)
        actual = gmap(per_example, schedule)(batch)
        self.assertAllClose(expected, actual)