示例#1
0
def setUpModule():
    """Configure the fake CPU device count for every test in this module.

    Makes sure absl flags are parsed, requests `--chex_n_cpu_devices` fake CPU
    devices and asserts that the CPU backend exposes that many devices.
    """
    # `pytest` does not go through absl's entry point, so flags may still be
    # unparsed here; parse only the flags we know about.
    flags_ready = FLAGS.is_parsed()
    if not flags_ready:
        FLAGS(sys.argv, known_only=True)

    n_devices = FLAGS.chex_n_cpu_devices
    fake.set_n_cpu_devices(n_devices)
    asserts.assert_devices_available(n_devices, 'cpu', backend='cpu')
    def test_set_n_cpu_devices(self):
        """Check that `set_n_cpu_devices` is rejected once backends exist.

        Skips instead of failing when backends were already initialized
        before this test ran, because the first call below then cannot
        succeed and the test's premise does not hold.
        """
        try:
            # Should not initialize backends.
            fake.set_n_cpu_devices(4)
        except RuntimeError:
            # Backends were initialized earlier (for instance by a module-level
            # device assertion or another test); this test needs a fresh
            # process, so skip rather than report a spurious failure.
            self.skipTest(
                "set_n_cpu_devices: backend's already been initialized. "
                'Run this test in isolation from others.')

        # Hence, this one does not fail.
        fake.set_n_cpu_devices(6)

        # This assert initializes backends.
        asserts.assert_devices_available(6, 'cpu', backend='cpu')

        # Which means that next call must fail.
        with self.assertRaisesRegex(
                RuntimeError, 'Attempted to set 8 devices, but 6 CPUs.+'):
            fake.set_n_cpu_devices(8)
  def test_set_n_cpu_devices(self):
    """`set_n_cpu_devices` may be re-called only until backends initialize.

    Skipped when backends are already initialized on entry, since the
    scenario cannot be reproduced in that case.
    """
    skip_reason = ("set_n_cpu_devices: backend's already been initialized. "
                   'Run this test in isolation from others.')
    try:
      fake.set_n_cpu_devices(4)  # Must not initialize backends.
    except RuntimeError as err:
      raise unittest.SkipTest(skip_reason) from err

    # Backends are still uninitialized, so changing the count again is fine.
    fake.set_n_cpu_devices(6)

    # Querying the devices initializes backends...
    asserts.assert_devices_available(6, 'cpu', backend='cpu')

    # ...after which asking for a different count has to be rejected.
    expected_error = 'Attempted to set 8 devices, but 6 CPUs.+'
    with self.assertRaisesRegex(RuntimeError, expected_error):
      fake.set_n_cpu_devices(8)
示例#4
0
 def test_tpu_assert(self):
   """Exercise the TPU availability assertions against the visible TPUs."""
   tpu_count = self._device_count('tpu')
   asserts.assert_devices_available(tpu_count, 'tpu')
   if not tpu_count:
     # With no TPUs present, both requesting 3 TPUs and the generic TPU
     # check are expected to fail.
     with self.assertRaisesRegex(AssertionError, 'No 3 TPUs available'):
       asserts.assert_devices_available(3, 'tpu')
     with self.assertRaisesRegex(AssertionError, 'No TPU devices available'):
       asserts.assert_tpu_available()
   else:
     asserts.assert_tpu_available()
   # Requesting TPUs from the CPU backend is expected to fail regardless of
   # the host's hardware.
   with self.assertRaisesRegex(AssertionError, 'No 3 TPUs available'):
     asserts.assert_devices_available(3, 'tpu', backend='cpu')
示例#5
0
 def test_not_less_than(self, devtype):
   """Exercise the `not_less_than=True` mode of `assert_devices_available`."""
   n = self._device_count(devtype)

   def check_at_least(count):
     asserts.assert_devices_available(
         count, devtype, backend=devtype, not_less_than=True)

   if n <= 0:
     # No devices of this type: the test expects the backend lookup itself
     # to fail.
     with self.assertRaisesRegex(RuntimeError, 'Unknown backend'):  # pylint: disable=g-error-prone-assert-raises
       check_at_least(n - 1)
     return

   # A lower bound below the actual count passes; one above it must not.
   check_at_least(n - 1)
   with self.assertRaisesRegex(AssertionError, f'Only {n} < {n + 1}'):
     check_at_least(n + 1)
示例#6
0
def setUpModule():
    """Set up fake CPU devices for this module and verify they are visible."""
    # Called without an argument, so the device count comes from the
    # `chex_n_cpu_devices` flag.
    fake.set_n_cpu_devices()
    # Indexing FLAGS and reading `.value` avoids attribute-style flag access.
    expected_cpus = FLAGS['chex_n_cpu_devices'].value
    asserts.assert_devices_available(expected_cpus, 'cpu', backend='cpu')
示例#7
0
 def test_cpu_assert(self):
   """`assert_devices_available` accepts the CPU device count JAX reports."""
   expected = jax.device_count('cpu')
   asserts.assert_devices_available(expected, 'cpu', backend='cpu')
示例#8
0
 def test_unsupported_device(self):
   """An unrecognized device type is rejected with a `ValueError`."""
   bogus_devtype = 'unsupported_devtype'
   with self.assertRaisesRegex(ValueError, 'Unknown device type'):  # pylint: disable=g-error-prone-assert-raises
     asserts.assert_devices_available(1, bogus_devtype)
示例#9
0
def setUpModule():
    """Set up fake CPU devices for this module and check they are visible.

    The flag is read with `FLAGS['chex_n_cpu_devices'].value` instead of
    `FLAGS.chex_n_cpu_devices`: absl raises `UnparsedFlagAccessError` on
    attribute access before flags are parsed (as happens under `pytest`),
    whereas `.value` returns the current/default value.
    """
    # Has the same effect as `fake.set_n_cpu_devices(FLAGS.chex_n_cpu_devices)`.
    fake.set_n_cpu_devices()
    asserts.assert_devices_available(FLAGS['chex_n_cpu_devices'].value,
                                     'cpu',
                                     backend='cpu')