def testNumAcceleratorsFilterTasks(self, mock_list_devices,
                                   mock_eager_list_devices):
    """Check num_accelerators() when a task type and task id are given.

    The fake device list spans two jobs: ``worker1`` task 0 has two TPUs
    and two GPUs, while ``worker2`` tasks 1-4 each have a single device.
    The resolver is queried for three (task_type, task_id) pairs and must
    report only the devices listed for that pair.

    Args:
      mock_list_devices: mock whose ``return_value`` is set to the
        ``session._DeviceAttributes`` form of the fake devices.
      mock_eager_list_devices: mock whose ``return_value`` is set to the
        ``LogicalDevice`` form of the fake devices.
    """
    device_specs = (
        ("/job:worker1/task:0/device:TPU:0", "TPU"),
        ("/job:worker1/task:0/device:TPU:1", "TPU"),
        ("/job:worker1/task:0/device:GPU:0", "GPU"),
        ("/job:worker1/task:0/device:GPU:1", "GPU"),
        ("/job:worker2/task:1/device:TPU:2", "TPU"),
        ("/job:worker2/task:2/device:TPU:3", "TPU"),
        ("/job:worker2/task:3/device:GPU:2", "GPU"),
        ("/job:worker2/task:4/device:GPU:3", "GPU"),
    )
    logical_devices = [
        LogicalDevice(dev_name, dev_type) for dev_name, dev_type in device_specs
    ]

    # Feed the same fake devices to both device-listing mocks.
    mock_eager_list_devices.return_value = logical_devices
    mock_list_devices.return_value = [
        session._DeviceAttributes(dev.name, dev.device_type, 1024, 0)
        for dev in logical_devices
    ]

    cluster_resolver = MockBaseClusterResolver()

    # (task_type, task_id, expected accelerator counts), checked in order.
    expectations = (
        ("worker1", 0, {"TPU": 2, "GPU": 2}),
        ("worker2", 3, {"GPU": 1}),
        ("worker2", 4, {"GPU": 1}),
    )
    for job_name, task_index, expected_counts in expectations:
        actual_counts = cluster_resolver.num_accelerators(
            task_type=job_name, task_id=task_index)
        self.assertEqual(actual_counts, expected_counts)
def testNumAcceleratorsSuccess(self, mock_list_devices,
                               mock_eager_list_devices):
    """Check that num_accelerators() with no arguments reports four GPUs.

    All four fake devices are GPUs on ``/job:worker/task:0``.

    Args:
      mock_list_devices: mock whose ``return_value`` is set to the
        ``session._DeviceAttributes`` form of the fake devices.
      mock_eager_list_devices: mock whose ``return_value`` is set to the
        ``LogicalDevice`` form of the fake devices.
    """
    gpu_names = (
        "/job:worker/task:0/device:GPU:0",
        "/job:worker/task:0/device:GPU:1",
        "/job:worker/task:0/device:GPU:2",
        "/job:worker/task:0/device:GPU:3",
    )
    logical_devices = [LogicalDevice(gpu_name, "GPU") for gpu_name in gpu_names]

    # Feed the same fake devices to both device-listing mocks.
    mock_eager_list_devices.return_value = logical_devices
    mock_list_devices.return_value = [
        session._DeviceAttributes(dev.name, dev.device_type, 1024, 0)
        for dev in logical_devices
    ]

    cluster_resolver = MockBaseClusterResolver()
    accelerator_counts = cluster_resolver.num_accelerators()
    self.assertEqual(accelerator_counts, {"GPU": 4})
def testNumAcceleratorsSuccess(self, mock_list_devices,
                               mock_eager_list_devices):
    """Check that a TPUClusterResolver reports ``{'TPU': 2}``.

    The fake device list holds eight TPU devices spread over tasks 0-3 of
    ``/job:tpu_worker``, two per task. The resolver is built against a
    mocked service client describing one READY/HEALTHY TPU node with four
    network endpoints.

    Args:
      mock_list_devices: mock whose ``return_value`` is set to the
        ``session._DeviceAttributes`` form of the fake devices.
      mock_eager_list_devices: mock whose ``return_value`` is set to the
        ``LogicalDevice`` form of the fake devices.
    """
    tpu_device_names = (
        '/job:tpu_worker/task:0/device:TPU:0',
        '/job:tpu_worker/task:1/device:TPU:1',
        '/job:tpu_worker/task:2/device:TPU:0',
        '/job:tpu_worker/task:3/device:TPU:1',
        '/job:tpu_worker/task:0/device:TPU:4',
        '/job:tpu_worker/task:1/device:TPU:5',
        '/job:tpu_worker/task:2/device:TPU:4',
        '/job:tpu_worker/task:3/device:TPU:5',
    )
    logical_devices = [
        LogicalDevice(tpu_name, 'TPU') for tpu_name in tpu_device_names
    ]

    # Feed the same fake devices to both device-listing mocks.
    mock_eager_list_devices.return_value = logical_devices
    mock_list_devices.return_value = [
        session._DeviceAttributes(dev.name, dev.device_type, 1024, 0)
        for dev in logical_devices
    ]

    # One endpoint entry per address, all on port 8470.
    endpoint_addresses = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
    network_endpoints = [
        {'ipAddress': address, 'port': 8470} for address in endpoint_addresses
    ]
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'state': 'READY',
            'health': 'HEALTHY',
            'networkEndpoints': network_endpoints,
        }
    }

    service_client = self.mock_service_client(tpu_map=tpu_map)
    tpu_resolver = resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        service=service_client)

    self.assertEqual(tpu_resolver.num_accelerators(), {'TPU': 2})
def testNumAcceleratorsFilterTasksByEnvVar(self, mock_list_devices,
                                           mock_eager_list_devices):
    """Check that num_accelerators() picks its task from ``TF_CONFIG``.

    ``TF_CONFIG`` names task type ``worker1`` with index ``"0"``. With no
    arguments the resolver must report the devices listed for
    ``/job:worker1/task:0`` (two TPUs and two GPUs); with an explicit
    ``task_type``/``task_id`` it must report that task's devices instead.

    The previous ``TF_CONFIG`` value is restored (or the variable removed)
    when the test finishes, so the setting does not leak into other tests
    running in the same process.

    Args:
      mock_list_devices: mock whose ``return_value`` is set to the
        ``session._DeviceAttributes`` form of the fake devices.
      mock_eager_list_devices: mock whose ``return_value`` is set to the
        ``LogicalDevice`` form of the fake devices.
    """
    # Remember the caller's environment so it can be put back afterwards.
    saved_tf_config = os.environ.get('TF_CONFIG')
    os.environ['TF_CONFIG'] = """ { "cluster": { "worker1": ["w10:2222"], "worker2": ["w21:2222", "w22:2222", "w23:2222", "w24:2222"] }, "rpc_layer": "grpc", "task": { "type": "worker1", "index": "0" } } """
    try:
        devices = [
            LogicalDevice('/job:worker1/task:0/device:TPU:0', 'TPU'),
            LogicalDevice('/job:worker1/task:0/device:TPU:1', 'TPU'),
            LogicalDevice('/job:worker1/task:0/device:GPU:0', 'GPU'),
            LogicalDevice('/job:worker1/task:0/device:GPU:1', 'GPU'),
            LogicalDevice('/job:worker2/task:1/device:TPU:2', 'TPU'),
            LogicalDevice('/job:worker2/task:2/device:TPU:3', 'TPU'),
            LogicalDevice('/job:worker2/task:3/device:GPU:2', 'GPU'),
            LogicalDevice('/job:worker2/task:4/device:GPU:3', 'GPU'),
        ]
        device_list = [
            session._DeviceAttributes(d.name, d.device_type, 1024, 0)
            for d in devices
        ]
        mock_eager_list_devices.return_value = devices
        mock_list_devices.return_value = device_list

        resolver = TFConfigClusterResolver()

        # By default we read from TF_CONFIG
        self.assertEqual(resolver.num_accelerators(), {'TPU': 2, 'GPU': 2})

        # Override still works when we want it to
        self.assertEqual(
            resolver.num_accelerators(task_type='worker2', task_id=3),
            {'GPU': 1})
    finally:
        # Undo the environment change even if an assertion above failed.
        if saved_tf_config is None:
            os.environ.pop('TF_CONFIG', None)
        else:
            os.environ['TF_CONFIG'] = saved_tf_config