Example #1
0
    def test_healthy_pod_service_cluster(self):
        """A single healthy v3-32 pod yields one ServiceWorker per endpoint.

        The fake pod 'fake-pod' advertises four network endpoints
        (10.0.0.0 through 10.0.0.3, all on port 8470). The resolver is
        expected to report exactly one worker for each of them, tagged with
        the pod's name. The 'fake-zone' value is presumably what the mock
        client library reports; confirm against
        build_mock_cloud_tpu_client_library.
        """
        pod_name = 'fake-pod'
        worker_ips = [f'10.0.0.{octet}' for octet in range(4)]

        endpoints = []
        for addr in worker_ips:
            endpoints.append({'ipAddress': addr, 'port': '8470'})

        pod_details = {
            'state': 'READY',
            'health': 'HEALTHY',
            'runtime_version': 'pytorch-nightly',
            'accelerator_type': 'v3-32',
            'network_endpoints': endpoints,
        }
        tpu_map = {pod_name: pod_details}
        mock_client_lib = build_mock_cloud_tpu_client_library(tpu_map)
        self.mock_ctc.side_effect = mock_client_lib

        resolver = ClusterResolver(list(tpu_map))
        service_workers = resolver.get_service_workers()

        expected = []
        for addr in worker_ips:
            expected.append(
                ServiceWorker(
                    internal_ip=addr,
                    port='8470',
                    machine_type='v3-32',
                    zone='fake-zone',
                    runtime_version='pytorch-nightly',
                    tpu=pod_name,
                ))
        self.assertCountEqual(expected, service_workers)
Example #2
0
    def test_healthy_sea_service_cluster(self):
        """A "sea" of 256 independent healthy v3-8 TPUs yields 256 workers.

        Each fake TPU 'fake-tpu-N' owns a single endpoint 10.0.0.N:8470. The
        discovery mock is wired to a compute service built from empty inputs
        (build_mock_compute_service({}, {})), which presumably exposes no
        instances; verify against that helper.
        """
        self.mock_discovery.side_effect = build_mock_services_fn(
            build_mock_compute_service({}, {}))

        num_tpus = 256
        tpu_map = {}
        for idx in range(num_tpus):
            endpoint = {'ipAddress': f'10.0.0.{idx}', 'port': '8470'}
            tpu_map[f'fake-tpu-{idx}'] = {
                'state': 'READY',
                'health': 'HEALTHY',
                'runtime_version': 'pytorch-nightly',
                'accelerator_type': 'v3-8',
                'network_endpoints': [endpoint],
            }
        mock_client_lib = build_mock_cloud_tpu_client_library(tpu_map)
        self.mock_ctc.side_effect = mock_client_lib

        resolver = ClusterResolver(list(tpu_map))
        service_workers = resolver.get_service_workers()

        # One worker per TPU, each carrying that TPU's single endpoint.
        expected = [
            ServiceWorker(
                internal_ip=f'10.0.0.{idx}',
                port='8470',
                machine_type='v3-8',
                zone='fake-zone',
                runtime_version='pytorch-nightly',
                tpu=f'fake-tpu-{idx}',
            ) for idx in range(num_tpus)
        ]
        self.assertCountEqual(expected, service_workers)