Exemplo n.º 1
0
 def test_different_computations(self):
     """Cache keys of an add and a multiply computation must differ.

     Both computations are hashed with the same compile options, so any
     difference in the key comes from the computation itself.
     """
     add_computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
     mul_computation = jax.xla_computation(lambda x, y: x * y)(2, 2)
     options = jax.lib.xla_bridge.get_compile_options(
         num_replicas=1, num_partitions=1)

     add_key = cc.get_cache_key(add_computation, options)
     mul_key = cc.get_cache_key(mul_computation, options)
     self.assertNotEqual(add_key, mul_key)
Exemplo n.º 2
0
 def test_same_hash_key(self):
   """Hashing identical inputs twice must yield the same cache key."""
   add_computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
   bridge = jax._src.lib.xla_bridge
   options = bridge.get_compile_options(num_replicas=1, num_partitions=1)
   backend = bridge.get_backend()

   first_key = cc.get_cache_key(add_computation, options, backend)
   second_key = cc.get_cache_key(add_computation, options, backend)
   self.assertEqual(first_key, second_key)
Exemplo n.º 3
0
 def test_different_hash_key(self):
     """Different compile options must produce different cache keys.

     The same computation is hashed once with the options returned by
     ``get_compile_options`` and once with the options returned by
     ``self.filled_compile_options()``.
     """
     add_computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
     default_options = jax.lib.xla_bridge.get_compile_options(
         num_replicas=1, num_partitions=1)
     filled_options = self.filled_compile_options()

     default_key = cc.get_cache_key(add_computation, default_options)
     filled_key = cc.get_cache_key(add_computation, filled_options)
     self.assertNotEqual(default_key, filled_key)
Exemplo n.º 4
0
    def test_xla_flags(self):
        """Check how XLA flags influence the compilation cache key.

        Covers, in order:
          * different ``XLA_FLAGS`` values give different keys;
          * restoring the same ``XLA_FLAGS`` value gives the same key again;
          * adding ``--xla_force_host_platform_device_count`` (expected to be
            in ``_xla_flags_to_exclude_from_cache_key``) leaves the key
            unchanged;
          * the same holds when the flags are passed through ``sys.argv``
            instead of the environment.

        ``XLA_FLAGS`` and ``sys.argv`` are restored on exit, even on failure.
        """
        computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
        compile_options = jax._src.lib.xla_bridge.get_compile_options(
            num_replicas=1, num_partitions=1)
        backend = jax._src.lib.xla_bridge.get_backend()

        orig_xla_flags = os.getenv("XLA_FLAGS")
        # Snapshot a *copy*: the test appends to sys.argv in place, so holding
        # a reference to the same list would make the restore a no-op.
        orig_argv = list(sys.argv)
        try:
            os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
            key1 = cc.get_cache_key(computation, compile_options, backend)
            os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=1"
            key2 = cc.get_cache_key(computation, compile_options, backend)
            self.assertNotEqual(key1, key2)

            os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
            key3 = cc.get_cache_key(computation, compile_options, backend)
            self.assertEqual(key1, key3)

            # Test flag in _xla_flags_to_exclude_from_cache_key
            os.environ["XLA_FLAGS"] = (
                "--xla_gpu_autotune_level=0 --xla_force_host_platform_device_count=8"
            )
            key4 = cc.get_cache_key(computation, compile_options, backend)
            self.assertEqual(key1, key4)

            # Test flags given on command line
            del os.environ["XLA_FLAGS"]
            sys.argv.append("--xla_gpu_autotune_level=0")
            key5 = cc.get_cache_key(computation, compile_options, backend)
            self.assertEqual(key1, key5)

            # Recompute the key after adding the excluded flag; comparing the
            # stale key5 again would not exercise the command-line path.
            sys.argv.append("--xla_force_host_platform_device_count=8")
            key6 = cc.get_cache_key(computation, compile_options, backend)
            self.assertEqual(key1, key6)

        finally:
            if orig_xla_flags is not None:
                os.environ["XLA_FLAGS"] = orig_xla_flags
            else:
                os.environ.pop("XLA_FLAGS", None)
            # Restore in place so every holder of the sys.argv list object
            # sees the original arguments again.
            sys.argv[:] = orig_argv