def setUpModule():
  """Fakes `--chex_n_cpu_devices` CPU devices and checks they are visible."""
  # absltest.main() parses flags before this hook runs, but pytest does not,
  # so parse them here (ignoring unknown pytest arguments) when needed.
  flags_ready = FLAGS.is_parsed()
  if not flags_ready:
    FLAGS(sys.argv, known_only=True)

  fake.set_n_cpu_devices(FLAGS.chex_n_cpu_devices)
  asserts.assert_devices_available(
      FLAGS.chex_n_cpu_devices, 'cpu', backend='cpu')
def test_set_n_cpu_devices(self):
  """Checks `set_n_cpu_devices` only works until a backend is initialized.

  Repeated calls are accepted while no backend exists. Once
  `assert_devices_available` initializes the CPU backend, a further call
  with a different count must raise RuntimeError.

  If a backend was already initialized before this test started (e.g. by
  another test in the same process), the very first call fails. That says
  nothing about the code under test, so the test is skipped instead of
  failed.
  """
  try:
    # Should not initialize backends.
    fake.set_n_cpu_devices(4)
  except RuntimeError:
    # `skipTest` raises unittest.SkipTest; the RuntimeError stays attached
    # as the exception context for debugging.
    self.skipTest(
        "set_n_cpu_devices: backend's already been initialized. "
        'Run this test in isolation from others.')

  # Hence, this one does not fail.
  fake.set_n_cpu_devices(6)

  # This assert initializes backends.
  asserts.assert_devices_available(6, 'cpu', backend='cpu')

  # Which means that next call must fail.
  with self.assertRaisesRegex(RuntimeError,
                              'Attempted to set 8 devices, but 6 CPUs.+'):
    fake.set_n_cpu_devices(8)
def test_set_n_cpu_devices(self):
  """Checks `set_n_cpu_devices` is rejected once a backend exists.

  Skips (rather than fails) when a backend was initialized before the test
  ran, because then even the first call is expected to raise.
  """
  skip_reason = ("set_n_cpu_devices: backend's already been initialized. "
                 'Run this test in isolation from others.')
  try:
    # No backend should be created by this call.
    fake.set_n_cpu_devices(4)
  except RuntimeError as err:
    raise unittest.SkipTest(skip_reason) from err

  # Still no backend, so changing the count again is allowed.
  fake.set_n_cpu_devices(6)

  # Asserting on devices forces backend initialization...
  asserts.assert_devices_available(6, 'cpu', backend='cpu')

  # ...after which changing the device count has to raise.
  expected_error = 'Attempted to set 8 devices, but 6 CPUs.+'
  with self.assertRaisesRegex(RuntimeError, expected_error):
    fake.set_n_cpu_devices(8)
def test_tpu_assert(self):
  """Checks TPU availability assertions against the TPUs actually present."""
  tpu_count = self._device_count('tpu')
  asserts.assert_devices_available(tpu_count, 'tpu')

  if not tpu_count:
    # Without TPUs, both a specific count and plain availability must fail.
    with self.assertRaisesRegex(AssertionError, 'No 3 TPUs available'):
      asserts.assert_devices_available(3, 'tpu')
    with self.assertRaisesRegex(AssertionError, 'No TPU devices available'):
      asserts.assert_tpu_available()
  else:
    asserts.assert_tpu_available()

  # Asking the CPU backend for TPUs must fail in either case.
  with self.assertRaisesRegex(AssertionError, 'No 3 TPUs available'):
    asserts.assert_devices_available(3, 'tpu', backend='cpu')
def test_not_less_than(self, devtype):
  """Checks `not_less_than=True` treats the device count as a lower bound.

  Args:
    devtype: device type string passed to `_device_count` and used as the
      backend name (presumably supplied by a parameterized decorator outside
      this view — verify).
  """
  count = self._device_count(devtype)

  if count <= 0:
    # No devices of this type: the backend itself cannot be found.
    with self.assertRaisesRegex(RuntimeError, 'Unknown backend'):  # pylint: disable=g-error-prone-assert-raises
      asserts.assert_devices_available(
          count - 1, devtype, backend=devtype, not_less_than=True)
    return

  # Requiring fewer devices than present passes...
  asserts.assert_devices_available(
      count - 1, devtype, backend=devtype, not_less_than=True)
  # ...while requiring more than present fails.
  with self.assertRaisesRegex(AssertionError, f'Only {count} < {count + 1}'):
    asserts.assert_devices_available(
        count + 1, devtype, backend=devtype, not_less_than=True)
def setUpModule():
  """Fakes CPU devices with the default count and checks they are visible."""
  fake.set_n_cpu_devices()
  # Index access reads the flag value without the attribute-access check
  # for parsed flags.
  expected_cpus = FLAGS['chex_n_cpu_devices'].value
  asserts.assert_devices_available(expected_cpus, 'cpu', backend='cpu')
def test_cpu_assert(self):
  """Checks the assertion accepts exactly the number of CPUs JAX reports."""
  asserts.assert_devices_available(
      jax.device_count('cpu'), 'cpu', backend='cpu')
def test_unsupported_device(self):
  """Checks an unrecognised device type string is rejected with ValueError."""
  bogus_devtype = 'unsupported_devtype'
  with self.assertRaisesRegex(ValueError, 'Unknown device type'):  # pylint: disable=g-error-prone-assert-raises
    asserts.assert_devices_available(1, bogus_devtype)
def setUpModule():
  """Fakes `--chex_n_cpu_devices` CPU devices and checks they are visible.

  Under `pytest`, absl flags are not parsed before this hook runs, and
  reading `FLAGS.chex_n_cpu_devices` by attribute would then raise
  `UnparsedFlagAccessError`. Parse them first (ignoring pytest's own
  arguments). This is a no-op when running under `absltest.main()`.
  """
  import sys  # pylint: disable=g-import-not-at-top

  if not FLAGS.is_parsed():
    FLAGS(sys.argv, known_only=True)

  # Has the same effect as `fake.set_n_cpu_devices(FLAGS.chex_n_cpu_devices)`.
  fake.set_n_cpu_devices()
  asserts.assert_devices_available(
      FLAGS.chex_n_cpu_devices, 'cpu', backend='cpu')