예제 #1
0
    def testNumAcceleratorsFilterTasks(self, mock_list_devices,
                                       mock_eager_list_devices):
        """Count accelerators for one job/task at a time.

        Eight fake devices spread over jobs `worker1` and `worker2` are handed
        to both mocks; every query must report only the devices whose name
        carries the requested job and task.

        Args:
          mock_list_devices: mock that is made to return the
            `session._DeviceAttributes` list built below (its patch target is
            declared outside this method).
          mock_eager_list_devices: mock that is made to return the
            `LogicalDevice` list built below (its patch target is declared
            outside this method).
        """
        # (device name, device type) pairs; worker1/task:0 owns two TPUs and
        # two GPUs, while each worker2 task owns exactly one device.
        device_specs = (
            ("/job:worker1/task:0/device:TPU:0", "TPU"),
            ("/job:worker1/task:0/device:TPU:1", "TPU"),
            ("/job:worker1/task:0/device:GPU:0", "GPU"),
            ("/job:worker1/task:0/device:GPU:1", "GPU"),
            ("/job:worker2/task:1/device:TPU:2", "TPU"),
            ("/job:worker2/task:2/device:TPU:3", "TPU"),
            ("/job:worker2/task:3/device:GPU:2", "GPU"),
            ("/job:worker2/task:4/device:GPU:3", "GPU"),
        )
        logical_devices = [
            LogicalDevice(name, kind) for name, kind in device_specs
        ]

        # Mirror every logical device as a session-level attribute record.
        session_devices = []
        for dev in logical_devices:
            session_devices.append(
                session._DeviceAttributes(dev.name, dev.device_type, 1024, 0))

        mock_eager_list_devices.return_value = logical_devices
        mock_list_devices.return_value = session_devices

        cluster_resolver = MockBaseClusterResolver()

        worker1_task0 = cluster_resolver.num_accelerators(
            task_type="worker1", task_id=0)
        self.assertEqual(worker1_task0, {"TPU": 2, "GPU": 2})

        worker2_task3 = cluster_resolver.num_accelerators(
            task_type="worker2", task_id=3)
        self.assertEqual(worker2_task3, {"GPU": 1})

        worker2_task4 = cluster_resolver.num_accelerators(
            task_type="worker2", task_id=4)
        self.assertEqual(worker2_task4, {"GPU": 1})
예제 #2
0
    def testNumAcceleratorsSuccess(self, mock_list_devices,
                                   mock_eager_list_devices):
        """Report all four GPUs when no task filter is passed.

        Args:
          mock_list_devices: mock that is made to return the
            `session._DeviceAttributes` list built below (its patch target is
            declared outside this method).
          mock_eager_list_devices: mock that is made to return the
            `LogicalDevice` list built below (its patch target is declared
            outside this method).
        """
        # Four GPUs, all attached to /job:worker/task:0.
        gpu_names = (
            "/job:worker/task:0/device:GPU:0",
            "/job:worker/task:0/device:GPU:1",
            "/job:worker/task:0/device:GPU:2",
            "/job:worker/task:0/device:GPU:3",
        )
        logical_devices = [LogicalDevice(name, "GPU") for name in gpu_names]

        # Mirror every logical device as a session-level attribute record.
        session_devices = []
        for dev in logical_devices:
            session_devices.append(
                session._DeviceAttributes(dev.name, dev.device_type, 1024, 0))

        mock_eager_list_devices.return_value = logical_devices
        mock_list_devices.return_value = session_devices

        cluster_resolver = MockBaseClusterResolver()
        accelerator_counts = cluster_resolver.num_accelerators()
        self.assertEqual(accelerator_counts, {"GPU": 4})
예제 #3
0
    def testNumAcceleratorsSuccess(self, mock_list_devices,
                                   mock_eager_list_devices):
        """Expect `{'TPU': 2}` from a TPU resolver fed eight fake TPU devices.

        The fake devices are spread over tasks 0-3 of job `tpu_worker`, two
        per task, and the mocked TPU node advertises four network endpoints.

        Args:
          mock_list_devices: mock that is made to return the
            `session._DeviceAttributes` list built below (its patch target is
            declared outside this method).
          mock_eager_list_devices: mock that is made to return the
            `LogicalDevice` list built below (its patch target is declared
            outside this method).
        """
        # Fake TPU node description handed to the mocked service client: one
        # healthy node with four endpoints on the same port.
        endpoint_ips = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
        node_name = (
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1')
        fake_tpu_nodes = {
            node_name: {
                'state': 'READY',
                'health': 'HEALTHY',
                'networkEndpoints': [
                    {'ipAddress': ip, 'port': 8470} for ip in endpoint_ips
                ],
            },
        }

        tpu_device_names = (
            '/job:tpu_worker/task:0/device:TPU:0',
            '/job:tpu_worker/task:1/device:TPU:1',
            '/job:tpu_worker/task:2/device:TPU:0',
            '/job:tpu_worker/task:3/device:TPU:1',
            '/job:tpu_worker/task:0/device:TPU:4',
            '/job:tpu_worker/task:1/device:TPU:5',
            '/job:tpu_worker/task:2/device:TPU:4',
            '/job:tpu_worker/task:3/device:TPU:5',
        )
        logical_devices = [
            LogicalDevice(name, 'TPU') for name in tpu_device_names
        ]

        # Mirror every logical device as a session-level attribute record.
        session_devices = []
        for dev in logical_devices:
            session_devices.append(
                session._DeviceAttributes(dev.name, dev.device_type, 1024, 0))

        mock_eager_list_devices.return_value = logical_devices
        mock_list_devices.return_value = session_devices

        service_client = self.mock_service_client(tpu_map=fake_tpu_nodes)
        cluster_resolver = resolver.TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu='test-tpu-1',
            service=service_client)

        accelerator_counts = cluster_resolver.num_accelerators()
        self.assertEqual(accelerator_counts, {'TPU': 2})
    def testNumAcceleratorsFilterTasksByEnvVar(self, mock_list_devices,
                                               mock_eager_list_devices):
        """Filter accelerators by the task named in TF_CONFIG, or an override.

        TF_CONFIG is set to describe this process as `worker1` index 0. With
        no arguments `num_accelerators` must count only that task's devices;
        explicit `task_type`/`task_id` arguments must take precedence.

        The previous TF_CONFIG value (or its absence) is restored when the
        test finishes, pass or fail, so the setting cannot leak into other
        tests running in the same process.

        Args:
          mock_list_devices: mock that is made to return the
            `session._DeviceAttributes` list built below (its patch target is
            declared outside this method).
          mock_eager_list_devices: mock that is made to return the
            `LogicalDevice` list built below (its patch target is declared
            outside this method).
        """
        # Remember the caller's TF_CONFIG so it can be put back afterwards.
        previous_tf_config = os.environ.get('TF_CONFIG')
        os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "worker1": ["w10:2222"],
        "worker2": ["w21:2222", "w22:2222", "w23:2222", "w24:2222"]
      },
      "rpc_layer": "grpc",
      "task": {
        "type": "worker1",
        "index": "0"
      }
    }
    """
        try:
            devices = [
                LogicalDevice('/job:worker1/task:0/device:TPU:0', 'TPU'),
                LogicalDevice('/job:worker1/task:0/device:TPU:1', 'TPU'),
                LogicalDevice('/job:worker1/task:0/device:GPU:0', 'GPU'),
                LogicalDevice('/job:worker1/task:0/device:GPU:1', 'GPU'),
                LogicalDevice('/job:worker2/task:1/device:TPU:2', 'TPU'),
                LogicalDevice('/job:worker2/task:2/device:TPU:3', 'TPU'),
                LogicalDevice('/job:worker2/task:3/device:GPU:2', 'GPU'),
                LogicalDevice('/job:worker2/task:4/device:GPU:3', 'GPU'),
            ]
            device_list = [
                session._DeviceAttributes(d.name, d.device_type, 1024, 0)
                for d in devices
            ]
            mock_eager_list_devices.return_value = devices
            mock_list_devices.return_value = device_list

            resolver = TFConfigClusterResolver()

            # By default we read from TF_CONFIG
            self.assertEqual(resolver.num_accelerators(),
                             {'TPU': 2, 'GPU': 2})

            # Override still works when we want it to
            self.assertEqual(
                resolver.num_accelerators(task_type='worker2', task_id=3),
                {'GPU': 1})
        finally:
            # Undo the environment mutation regardless of the test outcome.
            if previous_tf_config is None:
                os.environ.pop('TF_CONFIG', None)
            else:
                os.environ['TF_CONFIG'] = previous_tf_config