def test_validate_machine_type_service_cluster(self):
    """Cluster.validate flags mixed TPU machine types only when the check is on."""
    clients = [
        ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a'),
        ClientWorker('10.0.0.1', 'n1-standard-16', 'europe-west4-a'),
    ]
    # Two service workers with *different* machine types (v3-8 vs v2-8).
    services = [
        ServiceWorker('10.0.0.0', '8470', 'v3-8', 'europe-west4-a',
                      'pytorch-0.2'),
        ServiceWorker('10.0.0.1', '8470', 'v2-8', 'europe-west4-a',
                      'pytorch-0.2'),
    ]
    # With the machine-type check disabled, validation passes.
    relaxed = Cluster(
        clients,
        services,
        check_service_machine_type=False,
        client_master_ip='10.0.0.0')
    relaxed.validate()  # Does not raise exception
    # With the default (strict) check, the mismatch is an error.
    strict = Cluster(clients, services, client_master_ip='10.0.0.0')
    self.assertRaisesRegex(
        RuntimeError,
        'All service_workers must have the same machine_type',
        strict.validate)
def test_create_bad_client_workers(self):
    """Constructing a Cluster rejects non-ClientWorker items in client_workers."""
    services = [
        ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a',
                      'pytorch-0.2'),
    ]
    # The second entry is a ServiceWorker, which is invalid in this list.
    clients = [
        ClientWorker('10.0.0.1', 'v3-8', 'europe-west4-a'),
        ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a',
                      'pytorch-0.2'),
    ]
    self.assertRaisesRegex(
        ValueError,
        'client_workers argument must be a list of ClientWorker',
        Cluster, clients, services)
def test_validate_bad_zone_cluster(self):
    """validate raises when workers are spread across more than one zone."""
    clients = [
        ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a'),
        # Deliberately placed in a different zone from everything else.
        ClientWorker('10.0.0.1', 'n1-standard-16', 'us-central1-b'),
    ]
    services = [
        ServiceWorker('10.0.0.0', '8470', 'v3-8', 'europe-west4-a',
                      'pytorch-0.2'),
        ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a',
                      'pytorch-0.2'),
    ]
    bad_cluster = Cluster(clients, services)
    self.assertRaisesRegex(
        RuntimeError, 'All workers must be in the same zone',
        bad_cluster.validate)
def test_healthy_pod_service_cluster(self):
    """A healthy v3-32 pod resolves to one ServiceWorker per network endpoint."""
    endpoints = [{
        'ipAddress': f'10.0.0.{i}',
        'port': '8470'
    } for i in range(4)]
    tpu_map = {
        'fake-pod': {
            'state': 'READY',
            'health': 'HEALTHY',
            'runtime_version': 'pytorch-nightly',
            'accelerator_type': 'v3-32',
            'network_endpoints': endpoints,
        }
    }
    self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map)
    resolver = ClusterResolver(list(tpu_map))
    workers = resolver.get_service_workers()
    expected = [
        ServiceWorker(
            internal_ip=f'10.0.0.{i}',
            port='8470',
            machine_type='v3-32',
            zone='fake-zone',
            runtime_version='pytorch-nightly',
            tpu='fake-pod') for i in range(4)
    ]
    self.assertCountEqual(expected, workers)
def add_service_worker(tpu_name):
    # Closure: appends to `workers` and calls `self._parse_resource_url`
    # from the enclosing scope.
    ctc = cloud_tpu_client.Client(tpu=tpu_name)
    tpu_name = ctc.name()
    # Only READY+HEALTHY TPUs may join the cluster; fail fast otherwise.
    if ctc.state() != 'READY':
        raise RuntimeError(f'TPU {tpu_name} is not READY yet. '
                           'Re-run when all TPUs are READY')
    if ctc.health() != 'HEALTHY':
        raise RuntimeError(f'TPU {tpu_name} is not HEALTHY yet. '
                           'Re-run when all TPUs are HEALTHY')
    zone = self._parse_resource_url(ctc._full_name(), 'locations')
    runtime = ctc.runtime_version()
    accel_type = ctc.accelerator_type()
    # One ServiceWorker per TPU network endpoint.
    for endpoint in ctc.network_endpoints():
        workers.append(
            ServiceWorker(
                internal_ip=endpoint['ipAddress'],
                port=endpoint['port'],
                machine_type=accel_type,
                zone=zone,
                runtime_version=runtime,
                tpu=tpu_name))
def test_healthy_tpuvm_cluster(self):
    """Resolving a TPU-VM pod yields co-located client and service workers."""
    # Swap in the TPUVM flavor of instance metadata for this test.
    mock.patch.object(ClusterResolver, 'get_instance_metadata',
                      mock_request_tpuvm_metadata).start()
    self.mock_discovery.side_effect = build_mock_services_fn(
        build_mock_compute_service({}, {}))
    tpu_map = {
        'fake-pod': {
            'state': 'READY',
            'health': 'HEALTHY',
            'runtime_version': 'v2-nightly',
            'accelerator_type': 'v3-32',
            'api_version': 'V2_ALPHA1',
            'network_endpoints': [{
                'ipAddress': f'10.1.0.{i}',
                'port': '8470',
            } for i in range(4)],
        }
    }
    self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map)
    cluster = ClusterResolver(list(tpu_map)).get_cluster()
    clients = [
        ClientWorker(
            internal_ip=f'10.1.0.{i}',
            machine_type='v3-32',
            zone='fake-zone',
            hostname=f'{TPUVM_HOSTNAME_PREFIX}{i}') for i in range(4)
    ]
    services = [
        ServiceWorker(
            internal_ip=f'10.1.0.{i}',
            port='8470',
            machine_type='v3-32',
            zone='fake-zone',
            runtime_version='v2-nightly',
            tpu='fake-pod') for i in range(4)
    ]
    self.assertEqual(
        Cluster(clients, services, client_master_ip='10.1.0.0'), cluster)
    # Restore the regular (non-TPUVM) metadata mock for later tests.
    mock.patch.object(ClusterResolver, 'get_instance_metadata',
                      mock_request_metadata).start()
def test_validate_diff_num_workers(self):
    """validate rejects clusters whose client/service worker counts differ."""
    clients = [
        ClientWorker(f'10.0.0.{i}', 'n1-standard-16', 'europe-west4-a')
        for i in range(3)
    ]
    # One more service worker than client workers breaks the 1:1 mapping.
    services = [
        ServiceWorker(f'10.0.0.{i}', '8470', 'v3-32', 'europe-west4-a',
                      'pytorch-0.2') for i in range(4)
    ]
    cluster = Cluster(clients, services)
    self.assertRaisesRegex(
        RuntimeError,
        'The client_workers and service_workers must have a 1:1 mapping',
        cluster.validate)
def test_validate_good_cluster(self):
    """A fully consistent 4x4 cluster validates without raising."""
    clients = [
        ClientWorker(f'10.0.0.{i}', 'n1-standard-16', 'europe-west4-a')
        for i in range(3)
    ]
    # A hostname on one worker should not affect validation.
    clients.append(
        ClientWorker('10.0.0.3', 'n1-standard-16', 'europe-west4-a',
                     hostname='test'))
    services = [
        ServiceWorker(f'10.0.0.{i}', '8470', 'v3-32', 'europe-west4-a',
                      'pytorch-0.2') for i in range(4)
    ]
    Cluster(clients, services).validate()  # Does not raise exception
def test_validate_diff_runtime_versions(self):
    """validate rejects service workers running mixed runtime versions."""
    clients = [
        ClientWorker(f'10.0.0.{i}', 'n1-standard-16', 'europe-west4-a')
        for i in range(4)
    ]
    # Alternate between two runtime versions to trigger the mismatch.
    versions = ['pytorch-0.1', 'pytorch-0.2', 'pytorch-0.1', 'pytorch-0.2']
    services = [
        ServiceWorker(f'10.0.0.{i}', '8470', 'v3-32', 'europe-west4-a', v)
        for i, v in enumerate(versions)
    ]
    cluster = Cluster(clients, services)
    self.assertRaisesRegex(
        RuntimeError,
        'All service workers must have the same runtime_version.*',
        cluster.validate)
def add_tpu_worker(tpu_name):
    # Closure: reads `as_client_worker` and appends to `workers` from the
    # enclosing scope.
    ctc = cloud_tpu_client.Client(tpu=tpu_name)
    tpu_name = ctc.name()
    # Only READY+HEALTHY TPUs may join the cluster; fail fast otherwise.
    if ctc.state() != 'READY':
        raise RuntimeError(f'TPU {tpu_name} is not READY yet. '
                           'Re-run when all TPUs are READY')
    if ctc.health() != 'HEALTHY':
        raise RuntimeError(f'TPU {tpu_name} is not HEALTHY yet. '
                           'Re-run when all TPUs are HEALTHY')
    zone = ClusterResolver._parse_resource_url(ctc._full_name(), 'locations')
    runtime = ctc.runtime_version()
    accel_type = ctc.accelerator_type()
    endpoints = ctc.network_endpoints()
    if as_client_worker:
        # Endpoints double as client hosts; map each IP to its hostname.
        ip_to_host_name = ClusterResolver._get_internal_ip_to_hostname_mapping(
            tpu_name, zone, len(endpoints))
        for endpoint in endpoints:
            ip = endpoint['ipAddress']
            workers.append(
                ClientWorker(
                    internal_ip=ip,
                    machine_type=accel_type,
                    zone=zone,
                    hostname=ip_to_host_name[ip]))
    else:
        for endpoint in endpoints:
            workers.append(
                ServiceWorker(
                    internal_ip=endpoint['ipAddress'],
                    port=endpoint['port'],
                    machine_type=accel_type,
                    zone=zone,
                    runtime_version=runtime,
                    tpu=tpu_name))
def test_healthy_sea_service_cluster(self):
    """256 single-endpoint v3-8 TPUs each resolve to their own ServiceWorker."""
    self.mock_discovery.side_effect = build_mock_services_fn(
        build_mock_compute_service({}, {}))
    tpu_map = {
        f'fake-tpu-{i}': {
            'state': 'READY',
            'health': 'HEALTHY',
            'runtime_version': 'pytorch-nightly',
            'accelerator_type': 'v3-8',
            'network_endpoints': [{
                'ipAddress': f'10.0.0.{i}',
                'port': '8470'
            }],
        } for i in range(256)
    }
    self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map)
    workers = ClusterResolver(list(tpu_map)).get_service_workers()
    expected = [
        ServiceWorker(
            internal_ip=f'10.0.0.{i}',
            port='8470',
            machine_type='v3-8',
            zone='fake-zone',
            runtime_version='pytorch-nightly',
            tpu=f'fake-tpu-{i}') for i in range(256)
    ]
    self.assertCountEqual(expected, workers)
def test_healthy_cluster(self):
    """A healthy 4-VM instance group plus a v3-32 pod forms a full cluster."""
    names_and_ips = list(zip('abcd', '0123'))
    list_instances_map = {
        'fake-ig': {
            'kind': 'compute#instanceGroupsListInstances',
            'items': [
                gen_fake_ig_list_instances_entry('fake-ig-' + c, 'RUNNING')
                for c, _ in names_and_ips
            ],
        },
    }
    instance_resp_map = {
        'fake-ig-' + c: gen_fake_instances_get_entry(
            'fake-ig-' + c, 'n1-standard-16', '10.0.0.' + ip, 'RUNNING')
        for c, ip in names_and_ips
    }
    self.mock_discovery.side_effect = build_mock_services_fn(
        build_mock_compute_service(instance_resp_map, list_instances_map))
    tpu_map = {
        'fake-pod': {
            'state': 'READY',
            'health': 'HEALTHY',
            'runtime_version': 'pytorch-nightly',
            'accelerator_type': 'v3-32',
            'network_endpoints': [{
                'ipAddress': f'10.0.0.{i}',
                'port': '8470'
            } for i in range(4)],
        }
    }
    self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map)
    cluster = ClusterResolver(list(tpu_map)).get_cluster()
    expected_clients = [
        ClientWorker(
            internal_ip='10.0.0.' + ip,
            machine_type='n1-standard-16',
            zone='fake-zone',
            hostname='fake-ig-' + c) for c, ip in names_and_ips
    ]
    expected_services = [
        ServiceWorker(
            internal_ip=f'10.0.0.{i}',
            port='8470',
            machine_type='v3-32',
            zone='fake-zone',
            runtime_version='pytorch-nightly',
            tpu='fake-pod') for i in range(4)
    ]
    self.assertEqual(Cluster(expected_clients, expected_services), cluster)