def testAxisName2d(self):
  """Nested gmap over two named axes should agree with nested vmap.

  The inner map binds axis name 'i' and the outer map binds 'j'; the
  mapped function uses a collective over each of those names.
  """
  vectorized_schedule = [('vectorized', None)]

  def f(x):
    # psum over 'i' and pmax over 'j' exercise both bound axis names.
    return x - lax.psum(x, 'i') + lax.pmax(x, 'j')

  x = jnp.arange(8 * 8).reshape((8, 8))
  inner = gmap(f, vectorized_schedule, axis_name='i')
  outer = gmap(inner, vectorized_schedule, axis_name='j')
  reference = vmap(vmap(f, axis_name='i'), axis_name='j')
  self.assertAllClose(outer(x), reference(x))
def testBasicSchedules(self, schedule):
  """gmap under `schedule` should produce the same result as vmap.

  Skips (rather than fails) when a 'parallel' loop in the schedule would
  need more XLA devices than are available.

  NOTE(review): another method named `testBasicSchedules` appears later in
  this chunk. If both definitions live in the same class, Python keeps only
  the later one and this body never runs -- consider deleting or renaming
  one of them.
  """
  def f(x):
    return jnp.dot(jnp.sin(x), x.T) * 4 + x

  x = jnp.arange(800).reshape((8, 10, 10))
  for loop, n in schedule:
    # When the loop size is unspecified, approximate it by the leading axis.
    approx_n = x.shape[0] if n is None else n
    if loop == 'parallel' and approx_n > xla_bridge.device_count():
      raise SkipTest("this test requires more XLA devices")
  self.assertAllClose(gmap(f, schedule)(x), vmap(f)(x))
def testAxisName(self, schedule):
  """gmap with axis name 'i' should agree with vmap using the same name."""
  def minus_total(v):
    # Subtract the psum taken across the mapped axis 'i'.
    return v - lax.psum(v, 'i')

  x = jnp.arange(8)
  actual = gmap(minus_total, schedule, axis_name='i')(x)
  expected = vmap(minus_total, axis_name='i')(x)
  self.assertAllClose(actual, expected)
def testBasicSchedules(self, schedule):
  """vmap and gmap under `schedule` should agree on the same input.

  Raises SkipTest if any 'parallel' loop in the schedule asks for more
  XLA devices than are available; an unspecified loop size is
  approximated by the leading axis length of the input.
  """
  def f(x):
    return jnp.dot(jnp.sin(x), x.T) * 4 + x

  x = jnp.arange(800).reshape((8, 10, 10))
  leading = x.shape[0]
  needs_more_devices = any(
      kind == 'parallel'
      and (leading if size is None else size) > xla_bridge.device_count()
      for kind, size in schedule)
  if needs_more_devices:
    raise SkipTest("this test requires more XLA devices")
  self.assertAllClose(vmap(f)(x), gmap(f, schedule)(x))