def test_validate_machine_type_service_cluster(self):
    """Mixed service machine types fail validation unless the check is off.

    Builds two service workers with different machine types ('v3-8' and
    'v2-8') and expects `Cluster.validate` to pass when
    `check_service_machine_type=False` and to raise `RuntimeError` otherwise.
    """
    zone = 'europe-west4-a'
    clients = [
        ClientWorker(ip, 'n1-standard-16', zone)
        for ip in ('10.0.0.0', '10.0.0.1')
    ]
    services = [
        ServiceWorker(ip, '8470', accelerator, zone, 'pytorch-0.2')
        for ip, accelerator in (('10.0.0.0', 'v3-8'), ('10.0.0.1', 'v2-8'))
    ]

    # With the machine-type check disabled this must not raise.
    relaxed = Cluster(
        clients,
        services,
        check_service_machine_type=False,
        client_master_ip='10.0.0.0',
    )
    relaxed.validate()

    # With the check left at its default the mismatch is reported.
    strict = Cluster(clients, services, client_master_ip='10.0.0.0')
    with self.assertRaisesRegex(
            RuntimeError,
            'All service_workers must have the same machine_type'):
        strict.validate()
def test_validate_bad_zone_cluster(self):
    """Validation is expected to fail when workers span more than one zone.

    One client worker is placed in 'us-central1-b' while every other worker
    is in 'europe-west4-a'.
    """
    clients = [
        ClientWorker(ip, 'n1-standard-16', zone)
        for ip, zone in (
            ('10.0.0.0', 'europe-west4-a'),
            ('10.0.0.1', 'us-central1-b'),
        )
    ]
    services = [
        ServiceWorker(ip, '8470', 'v3-8', 'europe-west4-a', 'pytorch-0.2')
        for ip in ('10.0.0.0', '10.0.0.1')
    ]
    mixed_zone_cluster = Cluster(clients, services)

    with self.assertRaisesRegex(
            RuntimeError, 'All workers must be in the same zone'):
        mixed_zone_cluster.validate()
def test_healthy_vm_list_client_cluster(self):
    """Client workers are resolved from an explicitly supplied VM list.

    Four RUNNING fake instances are registered with the mocked compute
    service; `ClusterResolver.get_client_workers` is expected to return one
    `ClientWorker` per instance, in any order.
    """
    # Arrange
    suffix_and_octet = list(zip('abcd', '0123'))
    instance_resp_map = {}
    for suffix, octet in suffix_and_octet:
        instance_name = 'fake-ig-' + suffix
        instance_resp_map[instance_name] = gen_fake_instances_get_entry(
            instance_name, 'n1-standard-16', '10.0.0.' + octet, 'RUNNING')
    # No instance-group listings are needed: the VM names are given directly.
    compute_service = build_mock_compute_service(instance_resp_map, {})
    self.mock_discovery.side_effect = build_mock_services_fn(compute_service)

    # Act
    resolver = ClusterResolver(
        ['fake-tpu'],
        vms=['fake-ig-a', 'fake-ig-b', 'fake-ig-c', 'fake-ig-d'],
    )
    resolved_workers = resolver.get_client_workers()

    # Assert
    expected = []
    for suffix, octet in suffix_and_octet:
        expected.append(
            ClientWorker(
                internal_ip='10.0.0.' + octet,
                machine_type='n1-standard-16',
                zone='fake-zone',
                hostname='fake-ig-' + suffix,
            ))
    self.assertCountEqual(expected, resolved_workers)
def add_tpu_worker(tpu_name):
    """Append one worker per network endpoint of the given TPU to `workers`.

    Relies on `workers` and `as_client_worker` from the enclosing scope
    (not visible in this block). When `as_client_worker` is truthy each
    endpoint becomes a `ClientWorker`; otherwise it becomes a
    `ServiceWorker`.

    Args:
      tpu_name: Name passed to `cloud_tpu_client.Client`; it is then replaced
        by the name the client reports.

    Raises:
      RuntimeError: If the TPU state is not 'READY' or its health is not
        'HEALTHY'.
    """
    ctc = cloud_tpu_client.Client(tpu=tpu_name)
    tpu_name = ctc.name()

    # Bail out early on a TPU that is not fully up.
    if ctc.state() != 'READY':
        not_ready_msg = ('TPU {tpu_name} is not READY yet. '
                         'Re-run when all TPUs are READY')
        raise RuntimeError(not_ready_msg.format(tpu_name=tpu_name))
    if ctc.health() != 'HEALTHY':
        not_healthy_msg = ('TPU {tpu_name} is not HEALTHY yet. '
                           'Re-run when all TPUs are HEALTHY')
        raise RuntimeError(not_healthy_msg.format(tpu_name=tpu_name))

    # Attributes shared by every endpoint of this TPU.
    runtime_version = ctc.runtime_version()
    machine_type = ctc.accelerator_type()
    zone = ClusterResolver._parse_resource_url(ctc._full_name(), 'locations')

    for endpoint in ctc.network_endpoints():
        if not as_client_worker:
            workers.append(
                ServiceWorker(
                    internal_ip=endpoint['ipAddress'],
                    port=endpoint['port'],
                    machine_type=machine_type,
                    zone=zone,
                    runtime_version=runtime_version,
                    tpu=tpu_name,
                ))
            continue

        # Client workers take their hostname from the endpoint's VM link.
        hostname = ClusterResolver._parse_resource_url(
            endpoint['greenVmSelflink'], 'instances')
        workers.append(
            ClientWorker(
                internal_ip=endpoint['ipAddress'],
                machine_type=machine_type,
                zone=zone,
                hostname=hostname,
            ))
def test_validate_empty_workers(self):
    """Validation is expected to fail when the service worker list is empty."""
    lone_client = ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a')
    cluster = Cluster([lone_client], [], client_master_ip='10.0.0.0')

    with self.assertRaisesRegex(
            RuntimeError,
            'Both client_workers and service_workers should not be empty'):
        cluster.validate()
def test_create_bad_service_workers(self):
    """Constructing a Cluster with non-ServiceWorker service workers fails.

    The client worker list is deliberately passed for both arguments so the
    `service_workers` argument holds `ClientWorker` objects.
    """
    only_clients = [
        ClientWorker(
            '10.0.0.1', 'n1-standard-16', 'europe-west4-a', hostname='test'),
    ]

    with self.assertRaisesRegex(
            ValueError,
            'service_workers argument must be a list of ServiceWorker'):
        Cluster(only_clients, only_clients)
def test_healthy_tpuvm_cluster(self):
    """A healthy TPUVM pod resolves to matching client and service workers.

    `ClusterResolver.get_instance_metadata` is patched with the TPUVM
    metadata flavor for the duration of the test. The resolver is expected
    to produce, for each of the four network endpoints, a `ClientWorker`
    and a `ServiceWorker` sharing that endpoint's IP, with
    `client_master_ip` set to the first endpoint.
    """
    # Using TPUVM flavor of metadata.
    mock.patch.object(ClusterResolver, 'get_instance_metadata',
                      mock_request_tpuvm_metadata).start()
    try:
        noop_compute_service = build_mock_compute_service({}, {})
        self.mock_discovery.side_effect = build_mock_services_fn(
            noop_compute_service)
        tpu_map = {
            'fake-pod': {
                'state': 'READY',
                'health': 'HEALTHY',
                'runtime_version': 'v2-nightly',
                'accelerator_type': 'v3-32',
                'api_version': 'V2_ALPHA1',
                'network_endpoints': [{
                    'ipAddress': f'10.1.0.{index}',
                    'port': '8470',
                } for index in range(4)],
            }
        }
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            tpu_map)

        tpus = list(tpu_map.keys())
        cr = ClusterResolver(tpus)
        cluster = cr.get_cluster()

        expected_client_workers = [
            ClientWorker(
                internal_ip=f'10.1.0.{index}',
                machine_type='v3-32',
                zone='fake-zone',
                hostname=f'{TPUVM_HOSTNAME_PREFIX}{index}')
            for index in range(4)
        ]
        expected_service_workers = [
            ServiceWorker(
                internal_ip=f'10.1.0.{ip}',
                port='8470',
                machine_type='v3-32',
                zone='fake-zone',
                runtime_version='v2-nightly',
                tpu='fake-pod') for ip in range(4)
        ]
        expected = Cluster(
            expected_client_workers,
            expected_service_workers,
            client_master_ip='10.1.0.0')
        self.assertEqual(expected, cluster)
    finally:
        # Re-install the non-TPUVM metadata mock even when the test body
        # raises, so the TPUVM flavor cannot leak into later tests.
        # NOTE(review): presumably this mirrors what the fixture's setUp
        # installs -- confirm against the test class.
        mock.patch.object(ClusterResolver, 'get_instance_metadata',
                          mock_request_metadata).start()
def test_validate_diff_num_workers(self):
    """Validation is expected to fail when worker counts differ (3 vs 4)."""
    zone = 'europe-west4-a'
    clients = [
        ClientWorker(f'10.0.0.{octet}', 'n1-standard-16', zone)
        for octet in range(3)
    ]
    services = [
        ServiceWorker(f'10.0.0.{octet}', '8470', 'v3-32', zone, 'pytorch-0.2')
        for octet in range(4)
    ]
    unbalanced = Cluster(clients, services)

    with self.assertRaisesRegex(
            RuntimeError,
            'The client_workers and service_workers must have a 1:1 mapping'):
        unbalanced.validate()
def test_validate_good_cluster(self):
    """A cluster with four matching client/service workers validates cleanly.

    The last client worker additionally carries an explicit hostname.
    """
    zone = 'europe-west4-a'
    clients = [
        ClientWorker(f'10.0.0.{octet}', 'n1-standard-16', zone)
        for octet in range(3)
    ]
    clients.append(
        ClientWorker('10.0.0.3', 'n1-standard-16', zone, hostname='test'))
    services = [
        ServiceWorker(f'10.0.0.{octet}', '8470', 'v3-32', zone, 'pytorch-0.2')
        for octet in range(4)
    ]

    # Does not raise exception
    Cluster(clients, services).validate()
def test_validate_diff_runtime_versions(self):
    """Validation is expected to fail when service runtime versions differ.

    The service workers alternate between 'pytorch-0.1' and 'pytorch-0.2'.
    """
    zone = 'europe-west4-a'
    clients = [
        ClientWorker(f'10.0.0.{octet}', 'n1-standard-16', zone)
        for octet in range(4)
    ]
    runtime_versions = (
        'pytorch-0.1', 'pytorch-0.2', 'pytorch-0.1', 'pytorch-0.2')
    services = [
        ServiceWorker(f'10.0.0.{octet}', '8470', 'v3-32', zone, version)
        for octet, version in enumerate(runtime_versions)
    ]
    cluster = Cluster(clients, services)

    with self.assertRaisesRegex(
            RuntimeError,
            'All service workers must have the same runtime_version.*'):
        cluster.validate()
def test_create_bad_client_workers(self):
    """Constructing a Cluster with a non-ClientWorker client entry fails.

    The client list mixes a `ClientWorker` with a `ServiceWorker`.
    """
    service = ServiceWorker(
        '10.0.0.1', '8470', 'v3-8', 'europe-west4-a', 'pytorch-0.2')
    stray_service = ServiceWorker(
        '10.0.0.1', '8470', 'v3-8', 'europe-west4-a', 'pytorch-0.2')
    mixed_clients = [
        ClientWorker('10.0.0.1', 'v3-8', 'europe-west4-a'),
        stray_service,
    ]

    with self.assertRaisesRegex(
            ValueError,
            'client_workers argument must be a list of ClientWorker'):
        Cluster(mixed_clients, [service])
def add_client_worker(request_id, resp, exception):
    """Callback for each request in BatchHttpRequest.

    Converts one instance response into a `ClientWorker` and appends it to
    `workers`. Relies on `self` and `workers` from the enclosing scope (not
    visible in this block).

    Args:
      request_id: Identifier of the batched request; unused here.
      resp: Instance response dict; the keys read are 'selfLink', 'status',
        'networkInterfaces', 'machineType' and 'zone'.
      exception: Error reported for the request, or None on success.

    Raises:
      RuntimeError: If the instance status is not 'RUNNING'.
      Exception: Re-raises `exception` when it is not None.
    """
    if exception is not None:
        raise exception

    hostname = self._parse_resource_url(resp['selfLink'], 'instances')
    is_running = resp['status'] == 'RUNNING'
    if not is_running:
        not_running_msg = ('Instance {hostname} is not running yet. '
                           'Re-run when all VMs are running')
        raise RuntimeError(not_running_msg.format(hostname=hostname))

    # The first network interface supplies the worker's internal IP.
    internal_ip = resp['networkInterfaces'][0]['networkIP']
    machine_type = self._parse_resource_url(
        resp['machineType'], 'machineTypes')
    zone = self._parse_resource_url(resp['zone'], 'zones')
    workers.append(
        ClientWorker(
            internal_ip=internal_ip,
            machine_type=machine_type,
            zone=zone,
            hostname=hostname,
        ))
def test_healthy_cluster(self):
    """A healthy instance group plus a healthy TPU pod resolve to a Cluster.

    Four RUNNING fake instances and a READY/HEALTHY 'fake-pod' with four
    network endpoints are mocked; `ClusterResolver.get_cluster` is expected
    to return a `Cluster` holding one client and one service worker per IP.
    """
    suffixes = 'abcd'
    suffix_and_octet = list(zip(suffixes, '0123'))

    # Mocked compute API: one instance group listing four instances.
    group_items = []
    for suffix in suffixes:
        group_items.append(
            gen_fake_ig_list_instances_entry('fake-ig-' + suffix, 'RUNNING'))
    list_instances_map = {
        'fake-ig': {
            'kind': 'compute#instanceGroupsListInstances',
            'items': group_items,
        },
    }
    instance_resp_map = {}
    for suffix, octet in suffix_and_octet:
        instance_name = 'fake-ig-' + suffix
        instance_resp_map[instance_name] = gen_fake_instances_get_entry(
            instance_name, 'n1-standard-16', '10.0.0.' + octet, 'RUNNING')
    compute_service = build_mock_compute_service(instance_resp_map,
                                                 list_instances_map)
    self.mock_discovery.side_effect = build_mock_services_fn(compute_service)

    # Mocked TPU API: a single healthy pod with four endpoints.
    endpoints = []
    for octet in range(4):
        endpoints.append({'ipAddress': f'10.0.0.{octet}', 'port': '8470'})
    tpu_map = {
        'fake-pod': {
            'state': 'READY',
            'health': 'HEALTHY',
            'runtime_version': 'pytorch-nightly',
            'accelerator_type': 'v3-32',
            'network_endpoints': endpoints,
        }
    }
    self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map)

    resolver = ClusterResolver(list(tpu_map))
    resolved_cluster = resolver.get_cluster()

    expected_clients = []
    for suffix, octet in suffix_and_octet:
        expected_clients.append(
            ClientWorker(
                internal_ip='10.0.0.' + octet,
                machine_type='n1-standard-16',
                zone='fake-zone',
                hostname='fake-ig-' + suffix,
            ))
    expected_services = []
    for octet in range(4):
        expected_services.append(
            ServiceWorker(
                internal_ip=f'10.0.0.{octet}',
                port='8470',
                machine_type='v3-32',
                zone='fake-zone',
                runtime_version='pytorch-nightly',
                tpu='fake-pod',
            ))
    expected = Cluster(expected_clients, expected_services)
    self.assertEqual(expected, resolved_cluster)