Exemplo n.º 1
0
 def test_get_no_executable(self):
     """A freshly initialized cache reports a miss (None) for a computation."""
     with tempfile.TemporaryDirectory() as tmpdir:
         cc.initialize_cache(tmpdir)
         computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
         compile_options = jax.lib.xla_bridge.get_compile_options(
             num_replicas=1, num_partitions=1)
         # Nothing has been put into the new cache, so the lookup must miss.
         # assertIsNone checks identity with None and, on failure, reports the
         # unexpected value instead of a generic equality mismatch.
         self.assertIsNone(cc.get_executable(computation, compile_options))
Exemplo n.º 2
0
 def test_jit(self):
     """Each distinct jit argument signature adds one file to the disk cache."""
     with tempfile.TemporaryDirectory() as cache_dir:
         cc.initialize_cache(cache_dir)
         square = jit(lambda x: x * x)
         # An int argument followed by a float argument: the test expects the
         # cache directory to hold 1 file after the first call and 2 after the
         # second.
         for expected_count, arg in enumerate((1, 1.0), start=1):
             square(arg)
             self.assertEqual(len(os.listdir(cache_dir)), expected_count)
Exemplo n.º 3
0
 def test_pmap(self):
     """Each distinct input dtype adds one pmap executable to the disk cache."""
     with tempfile.TemporaryDirectory() as cache_dir:
         cc.initialize_cache(cache_dir)
         # Subtract the cross-device sum over axis 'i' from each shard.
         subtract_total = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
         for expected_count, dtype in enumerate((np.int64, np.float32),
                                                start=1):
             # One element per device so the mapped leading axis fits.
             operand = np.arange(jax.device_count(), dtype=dtype)
             subtract_total(operand)
             self.assertEqual(len(os.listdir(cache_dir)), expected_count)
Exemplo n.º 4
0
 def test_diff_executables(self):
     """Different computations are cached and retrieved as distinct entries."""
     with tempfile.TemporaryDirectory() as tmpdir:
         cc.initialize_cache(tmpdir)
         computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
         computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
         compile_options = jax.lib.xla_bridge.get_compile_options(
             num_replicas=1, num_partitions=1)
         backend = jax.lib.xla_bridge.get_backend()
         executable1 = backend.compile(computation1, compile_options)
         executable2 = backend.compile(computation2, compile_options)
         cc.put_executable(computation1, compile_options, executable1)
         cc.put_executable(computation2, compile_options, executable2)
         cached1 = cc.get_executable(computation1, compile_options)
         cached2 = cc.get_executable(computation2, compile_options)
         # Both entries were just stored, so both lookups must hit. Without
         # these checks, a miss (None) on one key would still satisfy the
         # inequality below and hide a broken cache.
         self.assertIsNotNone(cached1)
         self.assertIsNotNone(cached2)
         # NOTE(review): if executables compare by identity, this holds for any
         # two distinct objects; comparing execution results would be stronger.
         self.assertNotEqual(cached1, cached2)
Exemplo n.º 5
0
 def test_put_executable(self):
   """An executable round-tripped through the cache computes the same output."""
   with tempfile.TemporaryDirectory() as tmpdir:
     cc.initialize_cache(tmpdir)
     computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
     bridge = jax._src.lib.xla_bridge
     client = jax._src.lib.xla_client
     compile_options = bridge.get_compile_options(
         num_replicas=1, num_partitions=1)
     backend = bridge.get_backend()
     executable = backend.compile(computation, compile_options)
     cc.put_executable("alambda", computation, compile_options, executable,
                       backend)
     restored = cc.get_executable(computation, compile_options, backend)
     # Two scalar int32 operands; presumably these match the signature traced
     # from the Python ints above (x64 disabled) — confirm if x64 is enabled.
     args = (np.array(1, dtype=np.int32), np.array(2, dtype=np.int32))
     expected = client.execute_with_python_values(executable, args, backend)
     actual = client.execute_with_python_values(restored, args, backend)
     self.assertEqual(expected, actual)
Exemplo n.º 6
0
 def test_xmap(self):
   """Each distinct input dtype adds one xmap executable to the disk cache.

   Skipped unless at least two local devices are available.
   """
   # Decide whether to skip before creating a temp dir or initializing the
   # cache, so a skipped run performs no cache setup at all.
   if len(jax.local_devices()) < 2:
     raise SkipTest("Test requires 2 devices")
   with tempfile.TemporaryDirectory() as tmpdir:
     cc.initialize_cache(tmpdir)
     def f(x):
       return x * 2
     # NOTE(review): resource axis 'x' must be provided by an enclosing mesh
     # (e.g. a decorator outside this view) — confirm.
     g = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
              axis_resources={'a': 'x'})
     x = np.arange(8, dtype=np.int64).reshape((2, 2, 2))
     g(x)
     self.assertEqual(len(os.listdir(tmpdir)), 1)
     # A float32 input is a new signature, so a second entry is expected.
     x = np.arange(8, dtype=np.float32).reshape((2, 2, 2))
     g(x)
     self.assertEqual(len(os.listdir(tmpdir)), 2)
Exemplo n.º 7
0
  def test_pjit(self):
    """Each distinct input dtype adds one pjit executable to the disk cache."""
    with tempfile.TemporaryDirectory() as cache_dir:
      cc.initialize_cache(cache_dir)

      def f(x, y):
        return x + y

      # Same as decorating with @partial(pjit, ...): both inputs are
      # partitioned over resource axis 'x' (presumably supplied by an
      # enclosing mesh outside this view — confirm).
      f = pjit(f,
               in_axis_resources=(P('x'), P('x')),
               out_axis_resources=None)

      shape = (8, 8)
      for expected_count, dtype in enumerate((np.int64, np.float32), start=1):
        x = np.arange(prod(shape), dtype=dtype).reshape(shape)
        f(x, x + 1)
        self.assertEqual(len(os.listdir(cache_dir)), expected_count)