Example #1
0
    def testCompilationCache(self):
        """Check that repeated calls to a sharded_jit function reuse one compilation.

        The function is partitioned two ways on both input and output, so the
        test is skipped when fewer than two local devices are available.

        Raises:
            unittest.SkipTest: if the host exposes fewer than 2 local devices.
        """
        # Function-scope imports so the guard does not depend on what the
        # enclosing module happens to import at the top of the file.
        from unittest import SkipTest

        import jax

        # P(2) asks for two partitions; without two devices the call below
        # cannot be placed, so skip rather than report a spurious failure.
        if jax.local_device_count() < 2:
            raise SkipTest("requires 2 devices")

        f = lambda x: x + 1
        sharded_f = sharded_jit(f, in_parts=P(2), out_parts=P(2))
        shape = (2, )
        x = np.arange(prod(shape), dtype=np.float32).reshape(shape)

        # Two identical calls are made under a context manager that is given
        # an expected compilation count of 1; the second call is expected to
        # be served from the cache.
        with jtu.assert_num_jit_and_pmap_compilations(1):
            sharded_f(x)
            sharded_f(x)
Example #2
0
    def testCompilationCache(self):
        """Call a two-way sharded_jit function twice and expect one compilation.

        Skipped (via SkipTest) when fewer than two local devices are present,
        since both the input and the output are split into two partitions.
        """
        enough_devices = jax.local_device_count() >= 2
        if not enough_devices:
            raise SkipTest("requires 2 devices")

        # A length-2 float32 vector: one element per partition.
        dims = (2, )
        operand = np.arange(prod(dims), dtype=np.float32).reshape(dims)

        add_one_sharded = sharded_jit(
            lambda v: v + 1,
            in_parts=P(2),
            out_parts=P(2),
        )

        # Both invocations use the same argument, so the context manager is
        # told to expect a single jit/pmap compilation in total.
        with jtu.assert_num_jit_and_pmap_compilations(1):
            for _ in range(2):
                add_one_sharded(operand)