def test_method_wrapper(self, *args, **kwargs):
    """Run the wrapped test unless the flag currently holds the skip value.

    `flag_name`, `skip_value` and `test_method` are not defined in this
    function; they are expected to come from the enclosing scope.

    Raises:
        unittest.SkipTest: if the value read for `flag_name` compares equal
            to `skip_value`.
    """
    flag_value = config._read(flag_name)
    should_skip = flag_value == skip_value
    if not should_skip:
        return test_method(self, *args, **kwargs)

    # Fall back to a placeholder when the wrapped callable has no __name__.
    test_name = getattr(test_method, '__name__', '[unknown test]')
    reason = f"{test_name} not supported when FLAGS.{flag_name} is {flag_value}"
    raise unittest.SkipTest(reason)
def setUp(self):
    """Apply the default config overrides and create a per-test RNG.

    For every entry in `self._default_config`, the current config value is
    recorded in `self._original_config` before the override is applied.
    `self._rng` is seeded from the test method name.
    """
    super().setUp()

    saved = {}
    # Bind the dict before filling it so a partially-populated record is
    # still reachable from the instance if reading or updating a key fails.
    self._original_config = saved
    for name, override in self._default_config.items():
        saved[name] = config._read(name)
        config.update(name, override)

    # Seed from the adler32 checksum of the test name rather than hash():
    #  a) the checksum is identical from run to run, whereas hash() of a
    #     string is randomized per process;
    #  b) the checksum is an unsigned 32-bit integer, which is within the
    #     seed range RandomState accepts.
    name_bytes = self._testMethodName.encode()
    seed = zlib.adler32(name_bytes)
    self._rng = npr.RandomState(seed)
def test_jax_platforms_flag(self):
    """Check that the `jax_platforms` flag limits which backends are usable.

    Two factories are registered (the numeric argument is presumably a
    priority; confirm against `_register_factory`). With the flag set to
    "cpu,platform_A", the default backend must be "cpu", exactly two
    backends must be present in `xb._backends`, "platform_A" must be
    retrievable, and requesting "platform_B" must raise RuntimeError.
    The flag's previous value is restored afterwards.
    """
    self._register_factory("platform_A", 20)
    self._register_factory("platform_B", 10)

    saved_platforms = config._read("jax_platforms")
    try:
        config.FLAGS.jax_platforms = "cpu,platform_A"

        default_backend = xb.get_backend()
        self.assertEqual(default_backend.platform, "cpu")
        # Only specified backends initialized.
        self.assertEqual(len(xb._backends), 2)

        self.assertEqual(xb.get_backend("platform_A").platform, "platform_A")

        # platform_B is registered but not listed in the flag.
        with self.assertRaisesRegex(RuntimeError, "Unknown backend platform_B"):
            xb.get_backend("platform_B")
    finally:
        # Restore the flag even if an assertion above failed.
        config.FLAGS.jax_platforms = saved_platforms