def testDeviceDetails(self):
  """Checks `config.get_device_details` for CPU and, if present, GPU devices.

  The CPU device is expected to report an empty details dict. When a GPU is
  available, the first GPU's details must contain a non-empty string
  'device_name'. On ROCm builds no 'compute_capability' key may be present.
  Otherwise 'compute_capability' must be a (major, minor) tuple with
  major > 0 and minor >= 0. With at least two GPUs, the details of a device
  obtained through `config.get_visible_devices` are also checked.
  """
  (cpu,) = config.list_physical_devices('CPU')
  details = config.get_device_details(cpu)
  self.assertEqual(details, {})

  if not test_util.is_gpu_available():
    return

  gpus = config.list_physical_devices('GPU')
  details = config.get_device_details(gpus[0])
  self.assertIsInstance(details['device_name'], str)
  self.assertNotEmpty(details['device_name'])
  if test.is_built_with_rocm():
    # AMD GPUs do not have a compute capability
    self.assertNotIn('compute_capability', details)
  else:
    cc = details['compute_capability']
    self.assertIsInstance(cc, tuple)
    major, minor = cc
    self.assertGreater(major, 0)
    self.assertGreaterEqual(minor, 0)

  # Test a GPU returned from get_visible_devices. Selecting gpus[1] only
  # requires a second GPU to exist, so run this whenever there are two or
  # more GPUs rather than skipping two-GPU machines.
  # NOTE(review): the visible-device configuration is not restored in this
  # block; presumably a decorator/fixture outside this view resets the
  # context between tests -- confirm.
  if len(gpus) >= 2:
    config.set_visible_devices(gpus[1], 'GPU')
    (visible_gpu,) = config.get_visible_devices('GPU')
    details = config.get_device_details(visible_gpu)
    self.assertIsInstance(details['device_name'], str)
def testDeviceDetailsErrors(self):
  """Verifies `config.get_device_details` rejects unsupported arguments.

  Two failure modes are exercised, both of which must raise ValueError:
    * passing a logical device instead of a physical one;
    * passing a PhysicalDevice constructed by hand (CPU:100) rather than one
      returned by `tf.config.list_physical_devices`.
  """
  # Case 1: a logical device is the wrong type entirely.
  all_logical = config.list_logical_devices()
  with self.assertRaisesRegex(ValueError,
                              'must be a tf.config.PhysicalDevice'):
    config.get_device_details(all_logical[0])

  # Case 2: right type, but not a device the runtime actually enumerated.
  unlisted_device = context.PhysicalDevice('/physical_device:CPU:100', 'CPU')
  expected_message = ('The PhysicalDevice must be one obtained from '
                      'calling `tf.config.list_physical_devices`')
  with self.assertRaisesRegex(ValueError, expected_message):
    config.get_device_details(unlisted_device)
def log_device_compatibility_check(policy_name):
  """Runs the device compatibility log for a dtype policy at most once.

  The first invocation collects `config.get_device_details` for every
  physical GPU and forwards the list to `_log_device_compatibility_check`.
  Every later invocation returns immediately, because the module-level flag
  `_logged_compatibility_check` is set before the work is done. Per the
  existing contract only the mixed_float16 policy currently produces a log;
  that filtering is presumably done inside the helper, which is not visible
  here.

  Args:
    policy_name: The name of the dtype policy.
  """
  global _logged_compatibility_check
  if not _logged_compatibility_check:
    # Mark as done up front so the log is attempted only once.
    _logged_compatibility_check = True
    physical_gpus = config.list_physical_devices('GPU')
    details_per_gpu = [
        config.get_device_details(device) for device in physical_gpus
    ]
    _log_device_compatibility_check(policy_name, details_per_gpu)