def testNumAcceleratorsRetryFailure(self, mock_list_devices,
                                    mock_eager_list_devices):
    """Expects num_accelerators() to raise RuntimeError when listing times out.

    Both device-listing mocks are configured to raise DeadlineExceededError
    on every call.
    """
    resolver = TPUClusterResolver(tpu='')
    # Construct a separate exception object for each mock.
    for listing_mock in (mock_list_devices, mock_eager_list_devices):
        listing_mock.side_effect = errors.DeadlineExceededError(
            None, None, 'timeout')
    self.assertRaises(RuntimeError, resolver.num_accelerators)
def testNumAcceleratorsSuccess(self, mock_list_devices):
    """Expects num_accelerators() to return 2 for the mocked device listing.

    The mocked listing holds eight TPU devices spread over tasks 0-3, with
    two devices per task.
    """
    names = (
        '/job:tpu_worker/task:0/device:TPU:0',
        '/job:tpu_worker/task:1/device:TPU:1',
        '/job:tpu_worker/task:2/device:TPU:0',
        '/job:tpu_worker/task:3/device:TPU:1',
        '/job:tpu_worker/task:0/device:TPU:4',
        '/job:tpu_worker/task:1/device:TPU:5',
        '/job:tpu_worker/task:2/device:TPU:4',
        '/job:tpu_worker/task:3/device:TPU:5',
    )
    mock_list_devices.return_value = [
        session._DeviceAttributes(device_name, 'TPU', 1024, 0)
        for device_name in names
    ]
    resolver = TPUClusterResolver(tpu='')
    self.assertEqual(resolver.num_accelerators(), 2)
# A second, identical `testNumAcceleratorsSuccess` was removed from here: a later
# `def` of the same name rebinds it in the class body, silently discarding the
# earlier definition (pyflakes F811).