def testNumAcceleratorsRetryFailure(self, mock_list_devices,
                                    mock_eager_list_devices):
    """Checks that `num_accelerators()` raises when both mocks time out.

    Args:
      mock_list_devices: Mock configured here to raise
        `DeadlineExceededError` whenever it is called.
      mock_eager_list_devices: Mock configured the same way.

    NOTE(review): both mocks are presumably injected by patch decorators
    that sit outside this view -- confirm what they replace.
    """
    # The resolver is built before the failures are installed, so
    # construction itself never sees a timeout raised by these mocks.
    resolver = TPUClusterResolver(tpu='')

    # Each mock gets its own exception instance.
    for failing_mock in (mock_list_devices, mock_eager_list_devices):
        failing_mock.side_effect = errors.DeadlineExceededError(
            None, None, 'timeout')

    self.assertRaises(RuntimeError, resolver.num_accelerators)
Example #2
0
    def testNumAcceleratorsSuccess(self, mock_list_devices):
        """Checks that `num_accelerators()` returns 2 for the mocked devices.

        The mock's return value is a list of eight TPU device entries: two
        devices for each of tasks 0-3 of the `tpu_worker` job.

        Args:
          mock_list_devices: Mock whose `return_value` is set to that list.
            NOTE(review): presumably injected by a patch decorator outside
            this view -- confirm what it replaces.
        """
        tpu_names = (
            '/job:tpu_worker/task:0/device:TPU:0',
            '/job:tpu_worker/task:1/device:TPU:1',
            '/job:tpu_worker/task:2/device:TPU:0',
            '/job:tpu_worker/task:3/device:TPU:1',
            '/job:tpu_worker/task:0/device:TPU:4',
            '/job:tpu_worker/task:1/device:TPU:5',
            '/job:tpu_worker/task:2/device:TPU:4',
            '/job:tpu_worker/task:3/device:TPU:5',
        )
        # One device-attributes entry per name; only the name varies.
        mock_list_devices.return_value = [
            session._DeviceAttributes(tpu_name, 'TPU', 1024, 0)
            for tpu_name in tpu_names
        ]

        cluster_resolver = TPUClusterResolver(tpu='')
        accelerator_count = cluster_resolver.num_accelerators()
        self.assertEqual(accelerator_count, 2)
  def testNumAcceleratorsSuccess(self, mock_list_devices):
    """Verifies `num_accelerators()` is 2 when eight TPU devices are listed.

    The listed names cover tasks 0-3 of the `tpu_worker` job, with two TPU
    devices named for every task.

    Args:
      mock_list_devices: Mock that is made to return the device entries
        built below. NOTE(review): presumably supplied by a patch decorator
        that is outside this view -- confirm what it replaces.
    """
    mocked_devices = []
    for device_name in (
        '/job:tpu_worker/task:0/device:TPU:0',
        '/job:tpu_worker/task:1/device:TPU:1',
        '/job:tpu_worker/task:2/device:TPU:0',
        '/job:tpu_worker/task:3/device:TPU:1',
        '/job:tpu_worker/task:0/device:TPU:4',
        '/job:tpu_worker/task:1/device:TPU:5',
        '/job:tpu_worker/task:2/device:TPU:4',
        '/job:tpu_worker/task:3/device:TPU:5',
    ):
      # Every entry shares the same trailing arguments; only the name differs.
      mocked_devices.append(
          session._DeviceAttributes(device_name, 'TPU', 1024, 0))
    mock_list_devices.return_value = mocked_devices

    self.assertEqual(TPUClusterResolver(tpu='').num_accelerators(), 2)