def test_different_computations(self):
  """Two different computations must produce different cache keys.

  Everything except the computation (compile options and backend) is held
  fixed, so a key collision here would mean the computation itself is not
  reflected in the key.
  """
  computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
  computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
  compile_options = jax._src.lib.xla_bridge.get_compile_options(
      num_replicas=1, num_partitions=1)
  # Pass the backend explicitly, as the other cache-key tests do, so both
  # keys are computed through the same three-argument call.
  backend = jax._src.lib.xla_bridge.get_backend()
  self.assertNotEqual(
      cc.get_cache_key(computation1, compile_options, backend),
      cc.get_cache_key(computation2, compile_options, backend))
def test_same_hash_key(self):
  """Hashing identical inputs twice must yield the same cache key."""
  add_computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
  options = jax._src.lib.xla_bridge.get_compile_options(
      num_replicas=1, num_partitions=1)
  target_backend = jax._src.lib.xla_bridge.get_backend()
  first_key = cc.get_cache_key(add_computation, options, target_backend)
  second_key = cc.get_cache_key(add_computation, options, target_backend)
  self.assertEqual(first_key, second_key)
def test_different_hash_key(self):
  """Different compile options for the same computation give different keys.

  Compares default compile options against the ones returned by
  ``self.filled_compile_options()``; the computation and backend are fixed.
  """
  computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
  compile_options_not_filled = jax._src.lib.xla_bridge.get_compile_options(
      num_replicas=1, num_partitions=1)
  compile_options_filled = self.filled_compile_options()
  # Same backend for both keys (matching the other cache-key tests) so that
  # only the compile options differ.
  backend = jax._src.lib.xla_bridge.get_backend()
  self.assertNotEqual(
      cc.get_cache_key(computation, compile_options_not_filled, backend),
      cc.get_cache_key(computation, compile_options_filled, backend))
def test_xla_flags(self):
  """XLA flags change the cache key, except flags that are excluded from it.

  Flags are supplied both through the ``XLA_FLAGS`` environment variable and
  on the command line (``sys.argv``). Both are restored to their original
  state when the test finishes, even on failure.
  """
  computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
  compile_options = jax._src.lib.xla_bridge.get_compile_options(
      num_replicas=1, num_partitions=1)
  backend = jax._src.lib.xla_bridge.get_backend()

  orig_xla_flags = os.getenv("XLA_FLAGS")
  # Snapshot a *copy*: the test appends to sys.argv in place, so holding a
  # reference to the same list would "restore" the already-mutated list and
  # leak the extra flags into later tests.
  orig_argv = sys.argv[:]
  try:
    os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
    key1 = cc.get_cache_key(computation, compile_options, backend)
    os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=1"
    key2 = cc.get_cache_key(computation, compile_options, backend)
    self.assertNotEqual(key1, key2)

    # Switching the flag back must reproduce the original key.
    os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
    key3 = cc.get_cache_key(computation, compile_options, backend)
    self.assertEqual(key1, key3)

    # Test flag in _xla_flags_to_exclude_from_cache_key
    os.environ["XLA_FLAGS"] = (
        "--xla_gpu_autotune_level=0 --xla_force_host_platform_device_count=8"
    )
    key4 = cc.get_cache_key(computation, compile_options, backend)
    self.assertEqual(key1, key4)

    # Test flags given on command line
    del os.environ["XLA_FLAGS"]
    sys.argv.append("--xla_gpu_autotune_level=0")
    key5 = cc.get_cache_key(computation, compile_options, backend)
    self.assertEqual(key1, key5)
    # An excluded flag on the command line must not change the key either.
    # Recompute the key so the assertion actually observes the new argv.
    sys.argv.append("--xla_force_host_platform_device_count=8")
    key6 = cc.get_cache_key(computation, compile_options, backend)
    self.assertEqual(key1, key6)
  finally:
    if orig_xla_flags is not None:
      os.environ["XLA_FLAGS"] = orig_xla_flags
    else:
      os.environ.pop("XLA_FLAGS", None)
    # Restore in place so any other holder of the sys.argv list object also
    # sees the original contents again.
    sys.argv[:] = orig_argv