def test_healthy_pod_service_cluster(self):
    """A single healthy pod resolves to one service worker per endpoint.

    The mocked Cloud TPU client is fed one READY/HEALTHY 'v3-32' pod named
    'fake-pod' exposing four network endpoints (10.0.0.0 .. 10.0.0.3 on port
    8470). The resolver is expected to return exactly one ServiceWorker per
    endpoint, tagged with the pod's accelerator type, runtime version, name
    and zone 'fake-zone' (presumably the zone the mock client reports —
    confirm against build_mock_cloud_tpu_client_library). Ordering is not
    checked.
    """
    pod_name = 'fake-pod'
    endpoint_count = 4

    pod_endpoints = []
    for idx in range(endpoint_count):
        pod_endpoints.append({'ipAddress': f'10.0.0.{idx}', 'port': '8470'})

    fake_tpus = {
        pod_name: {
            'state': 'READY',
            'health': 'HEALTHY',
            'runtime_version': 'pytorch-nightly',
            'accelerator_type': 'v3-32',
            'network_endpoints': pod_endpoints,
        },
    }
    self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(fake_tpus)

    resolver = ClusterResolver(list(fake_tpus))
    actual_workers = resolver.get_service_workers()

    wanted_workers = []
    for idx in range(endpoint_count):
        wanted_workers.append(
            ServiceWorker(
                internal_ip=f'10.0.0.{idx}',
                port='8470',
                machine_type='v3-32',
                zone='fake-zone',
                runtime_version='pytorch-nightly',
                tpu=pod_name,
            ))
    # Order-insensitive: the resolver makes no ordering guarantee visible here.
    self.assertCountEqual(wanted_workers, actual_workers)
def test_healthy_sea_service_cluster(self):
    """A "sea" of 256 single-host TPUs resolves to one worker per TPU.

    Each mocked TPU 'fake-tpu-<i>' is READY/HEALTHY, of type 'v3-8', and
    exposes one endpoint 10.0.0.<i> on port 8470. The resolver is expected
    to return exactly one ServiceWorker per TPU, carrying that TPU's name and
    zone 'fake-zone' (presumably reported by the mock client — confirm
    against build_mock_cloud_tpu_client_library). Ordering is not checked.
    """
    # Wire an empty (no-op) compute service into the discovery mock.
    self.mock_discovery.side_effect = build_mock_services_fn(
        build_mock_compute_service({}, {}))

    tpu_count = 256

    fake_tpus = {}
    for idx in range(tpu_count):
        fake_tpus[f'fake-tpu-{idx}'] = {
            'state': 'READY',
            'health': 'HEALTHY',
            'runtime_version': 'pytorch-nightly',
            'accelerator_type': 'v3-8',
            'network_endpoints': [{
                'ipAddress': f'10.0.0.{idx}',
                'port': '8470',
            }],
        }
    self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(fake_tpus)

    resolver = ClusterResolver(list(fake_tpus))
    actual_workers = resolver.get_service_workers()

    wanted_workers = []
    for idx in range(tpu_count):
        wanted_workers.append(
            ServiceWorker(
                internal_ip=f'10.0.0.{idx}',
                port='8470',
                machine_type='v3-8',
                zone='fake-zone',
                runtime_version='pytorch-nightly',
                tpu=f'fake-tpu-{idx}',
            ))
    # Order-insensitive: the resolver makes no ordering guarantee visible here.
    self.assertCountEqual(wanted_workers, actual_workers)