def test_validate_good_cluster(self):
    """Check that a well-formed 4x4 cluster validates without raising.

    All client workers share one machine type and zone, and all service
    workers share one machine type, zone and software version.
    """
    zone = 'europe-west4-a'
    client_workers = [
        ClientWorker('10.0.0.0', 'n1-standard-16', zone),
        ClientWorker('10.0.0.1', 'n1-standard-16', zone),
        ClientWorker('10.0.0.2', 'n1-standard-16', zone),
        ClientWorker('10.0.0.3', 'n1-standard-16', zone, hostname='test'),
    ]
    # ServiceWorker takes (internal_ip, port, machine_type, zone,
    # sw_version), as the other tests in this file show. Keyword arguments
    # stop the machine type from being bound to ``port`` by position.
    service_workers = [
        ServiceWorker(
            internal_ip=ip,
            port='8470',
            machine_type='v3-32',
            zone=zone,
            sw_version='pytorch-0.1')
        for ip in ('10.0.0.0', '10.0.0.1', '10.0.0.2', '10.0.0.3')
    ]
    cluster = Cluster(client_workers, service_workers)
    cluster.validate()  # Does not raise exception
def test_validate_machine_type_client_cluster(self):
    """Check the client machine-type rule and its opt-out flag.

    The two client workers have different machine types. With
    ``check_client_machine_type=False`` validation must pass; with the
    default settings it must raise ``RuntimeError``.
    """
    zone = 'europe-west4-a'
    client_workers = [
        ClientWorker('10.0.0.0', 'n1-standard-16', zone),
        ClientWorker('10.0.0.1', 'n1-standard-8', zone),
    ]
    # ServiceWorker takes (internal_ip, port, machine_type, zone,
    # sw_version), as the other tests in this file show. Keyword arguments
    # stop the machine type from being bound to ``port`` by position.
    service_workers = [
        ServiceWorker(
            internal_ip=ip,
            port='8470',
            machine_type='v3-8',
            zone=zone,
            sw_version='pytorch-0.1')
        for ip in ('10.0.0.0', '10.0.0.1')
    ]

    no_check_cluster = Cluster(
        client_workers, service_workers, check_client_machine_type=False)
    no_check_cluster.validate()  # Does not raise exception

    check_cluster = Cluster(client_workers, service_workers)
    self.assertRaisesRegex(
        RuntimeError,
        'All client_workers must have the same machine_type',
        check_cluster.validate)
def test_healthy_cluster(self):
    """Resolve a cluster from mocked compute/TPU services and compare it.

    Four fake instances in one instance group and a fake 'v3-32' TPU with
    four endpoints are served through the mocked discovery client. The
    cluster returned by ``ClusterResolver.get_cluster`` must equal the
    hand-built expected ``Cluster``.
    """
    suffixes = 'abcd'
    last_octets = '0123'
    instance_names = ['fake-ig-' + suffix for suffix in suffixes]
    client_ips = ['10.0.0.' + octet for octet in last_octets]
    tpu_ips = ['10.0.0.{}'.format(n) for n in range(4)]

    # Fake compute API responses: group listing, then per-instance details.
    group_items = []
    for name in instance_names:
        group_items.append(gen_fake_ig_list_instances_entry(name, 'RUNNING'))
    list_instances_map = {
        'fake-ig': {
            'kind': 'compute#instanceGroupsListInstances',
            'items': group_items,
        },
    }
    instance_resp_map = {}
    for name, ip in zip(instance_names, client_ips):
        instance_resp_map[name] = gen_fake_instances_get_entry(
            name, 'n1-standard-16', ip, 'RUNNING')
    compute_service = build_mock_compute_service(instance_resp_map,
                                                 list_instances_map)

    # Fake TPU API response for a single pod.
    tpu_resp_map = {
        'fake-pod':
            gen_fake_tpu_entry(
                'v3-32',
                tpu_ips,
                'fake-pod',
                'READY',
                'pytorch-nightly',
                health='HEALTHY'),
    }
    tpu_service = build_mock_tpu_service(tpu_resp_map)

    self.mock_discovery.side_effect = build_mock_services_fn(
        compute_service, tpu_service)

    resolver = ClusterResolver(list(tpu_resp_map))
    cluster = resolver.get_cluster()

    expected_client_workers = []
    for name, ip in zip(instance_names, client_ips):
        expected_client_workers.append(
            ClientWorker(
                internal_ip=ip,
                machine_type='n1-standard-16',
                zone='fake-zone',
                hostname=name))
    expected_service_workers = []
    for ip in tpu_ips:
        expected_service_workers.append(
            ServiceWorker(
                internal_ip=ip,
                port='8470',
                machine_type='v3-32',
                zone='fake-zone',
                sw_version='pytorch-nightly'))

    expected = Cluster(expected_client_workers, expected_service_workers)
    self.assertEqual(expected, cluster)
def test_validate_bad_zone_cluster(self):
    """Check that validation fails when workers span more than one zone.

    One client worker is placed in 'us-central1-b' while every other
    worker is in 'europe-west4-a'; ``validate`` must raise
    ``RuntimeError``.
    """
    client_workers = [
        ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a'),
        ClientWorker('10.0.0.1', 'n1-standard-16', 'us-central1-b'),
    ]
    # ServiceWorker takes (internal_ip, port, machine_type, zone,
    # sw_version), as the other tests in this file show. Keyword arguments
    # make sure the zone really is bound to ``zone``, not ``machine_type``.
    service_workers = [
        ServiceWorker(
            internal_ip=ip,
            port='8470',
            machine_type='v3-8',
            zone='europe-west4-a',
            sw_version='pytorch-0.1')
        for ip in ('10.0.0.0', '10.0.0.1')
    ]
    cluster = Cluster(client_workers, service_workers)
    self.assertRaisesRegex(RuntimeError,
                           'All workers must be in the same zone',
                           cluster.validate)
def test_validate_diff_num_workers(self):
    """Check that validation fails when worker counts differ.

    Three client workers are paired with four service workers, so
    ``validate`` must raise ``RuntimeError`` about the 1:1 mapping.
    """
    zone = 'europe-west4-a'
    client_workers = [
        ClientWorker(ip, 'n1-standard-16', zone)
        for ip in ('10.0.0.0', '10.0.0.1', '10.0.0.2')
    ]
    # ServiceWorker takes (internal_ip, port, machine_type, zone,
    # sw_version), as the other tests in this file show. Keyword arguments
    # stop the machine type from being bound to ``port`` by position.
    service_workers = [
        ServiceWorker(
            internal_ip=ip,
            port='8470',
            machine_type='v3-32',
            zone=zone,
            sw_version='pytorch-0.1')
        for ip in ('10.0.0.0', '10.0.0.1', '10.0.0.2', '10.0.0.3')
    ]
    cluster = Cluster(client_workers, service_workers)
    self.assertRaisesRegex(
        RuntimeError,
        'The client_workers and service_workers must have a 1:1 mapping',
        cluster.validate)
def test_validate_diff_sw_versions(self):
    """Check that validation fails when service workers mix sw versions.

    The four service workers alternate between 'pytorch-0.1' and
    'pytorch-0.2'; ``validate`` must raise ``RuntimeError``.
    """
    worker_ips = ('10.0.0.0', '10.0.0.1', '10.0.0.2', '10.0.0.3')
    sw_versions = ('pytorch-0.1', 'pytorch-0.2', 'pytorch-0.1', 'pytorch-0.2')

    client_workers = [
        ClientWorker(ip, 'n1-standard-16', 'europe-west4-a')
        for ip in worker_ips
    ]
    service_workers = [
        ServiceWorker(ip, '8470', 'v3-32', 'europe-west4-a', version)
        for ip, version in zip(worker_ips, sw_versions)
    ]

    cluster = Cluster(client_workers, service_workers)
    with self.assertRaisesRegex(
            RuntimeError,
            'All service workers must have the same sw_version.*'):
        cluster.validate()
def test_validate_empty_workers(self):
    """Check that a cluster with no workers at all fails validation."""
    empty_cluster = Cluster([], [])
    expected_message = (
        'Both client_workers and service_workers should not be empty')
    with self.assertRaisesRegex(RuntimeError, expected_message):
        empty_cluster.validate()