예제 #1
0
    def testDeviceDetails(self):
        """Checks the dict returned by config.get_device_details.

        The CPU device must report an empty details dict. When a GPU is
        available, its details must contain a non-empty 'device_name' string.
        They must also contain a (major, minor) 'compute_capability' tuple,
        except on ROCm builds where that key must be absent. With two or more
        GPUs, a device obtained via get_visible_devices is also checked.
        """
        (cpu, ) = config.list_physical_devices('CPU')
        details = config.get_device_details(cpu)
        self.assertEqual(details, {})

        if not test_util.is_gpu_available():
            return

        gpus = config.list_physical_devices('GPU')
        details = config.get_device_details(gpus[0])
        self.assertIsInstance(details['device_name'], str)
        self.assertNotEmpty(details['device_name'])
        if test.is_built_with_rocm():
            # AMD GPUs do not have a compute capability
            self.assertNotIn('compute_capability', details)
        else:
            cc = details['compute_capability']
            self.assertIsInstance(cc, tuple)
            major, minor = cc
            self.assertGreater(major, 0)
            self.assertGreaterEqual(minor, 0)

        # Test GPU returned from get_visible_devices. Only gpus[1] is needed,
        # so two GPUs suffice. The previous `> 2` guard skipped this check on
        # 2-GPU machines.
        if len(gpus) > 1:
            config.set_visible_devices(gpus[1], 'GPU')
            (visible_gpu, ) = config.get_visible_devices('GPU')
            details = config.get_device_details(visible_gpu)
            self.assertIsInstance(details['device_name'], str)
예제 #2
0
    def testDeviceDetailsErrors(self):
        """get_device_details rejects logical devices and unknown physical ones."""
        # A LogicalDevice must not be accepted where a PhysicalDevice is expected.
        logical = config.list_logical_devices()
        with self.assertRaisesRegex(ValueError,
                                    'must be a tf.config.PhysicalDevice'):
            config.get_device_details(logical[0])

        # A hand-constructed PhysicalDevice (CPU index 100, presumably not a
        # real device) was never returned by list_physical_devices, so it is
        # expected to be rejected as well.
        unknown_cpu = context.PhysicalDevice('/physical_device:CPU:100', 'CPU')
        expected_msg = ('The PhysicalDevice must be one obtained from '
                        'calling `tf.config.list_physical_devices`')
        with self.assertRaisesRegex(ValueError, expected_msg):
            config.get_device_details(unknown_cpu)
예제 #3
0
def log_device_compatibility_check(policy_name):
    """Log whether the available GPUs support the given dtype policy.

    Only the first invocation does any work; every later call returns without
    doing anything, because a module-level flag is set on the first call. The
    per-GPU details are gathered here. The actual check and any logging are
    delegated to ``_log_device_compatibility_check``, which (per the original
    contract) currently only logs for the policy mixed_float16.

    Args:
      policy_name: The name of the dtype policy.
    """
    global _logged_compatibility_check
    if not _logged_compatibility_check:
        # The flag is flipped before the devices are queried, so the check is
        # attempted at most once even if device enumeration below raises.
        _logged_compatibility_check = True
        physical_gpus = config.list_physical_devices('GPU')
        details = list(map(config.get_device_details, physical_gpus))
        _log_device_compatibility_check(policy_name, details)