def test_create_bad_client_workers(self):
    """Cluster() rejects a client_workers list containing a ServiceWorker."""
    service_workers = [
        ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a',
                      'pytorch-nightly'),
    ]
    # The second entry is deliberately of the wrong type. It is built with the
    # full (internal_ip, port, machine_type, zone, sw_version) signature so the
    # only thing wrong with it is its class.
    client_workers = [
        ClientWorker('10.0.0.1', 'v3-8', 'europe-west4-a'),
        ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a',
                      'pytorch-nightly'),
    ]
    self.assertRaisesRegex(
        ValueError,
        'client_workers argument must be a list of ClientWorker',
        Cluster, client_workers, service_workers)
def test_validate_bad_zone_cluster(self):
    """Cluster.validate() raises when the workers are not all in one zone."""
    client_workers = [
        ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a'),
        # Odd one out: every other worker is in europe-west4-a.
        ClientWorker('10.0.0.1', 'n1-standard-16', 'us-central1-b'),
    ]
    # Service workers share port and sw_version so the zone mismatch is the
    # only inconsistency in this cluster.
    service_workers = [
        ServiceWorker('10.0.0.0', '8470', 'v3-8', 'europe-west4-a',
                      'pytorch-nightly'),
        ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a',
                      'pytorch-nightly'),
    ]
    cluster = Cluster(client_workers, service_workers)
    self.assertRaisesRegex(RuntimeError,
                           'All workers must be in the same zone',
                           cluster.validate)
def test_healthy_sea_service_cluster(self):
    """Resolve 256 separate healthy v3-8 TPUs into ServiceWorkers.

    Fake TPU ``fake-tpu-<i>`` exposes the single endpoint ``10.0.0.<i>``; the
    resolver is expected to return one matching ServiceWorker per TPU, in any
    order.
    """
    num_tpus = 256
    tpu_resp_map = {}
    for idx in range(num_tpus):
        tpu_name = 'fake-tpu-{}'.format(idx)
        tpu_resp_map[tpu_name] = gen_fake_tpu_entry(
            'v3-8', ['10.0.0.{}'.format(idx)], tpu_name, 'READY',
            'pytorch-nightly', health='HEALTHY')

    # The compute mock gets empty response maps; this test only exercises
    # TPU lookups.
    empty_compute = build_mock_compute_service({}, {})
    fake_tpu_api = build_mock_tpu_service(tpu_resp_map)
    self.mock_discovery.side_effect = build_mock_services_fn(
        empty_compute, fake_tpu_api)

    resolver = ClusterResolver(list(tpu_resp_map))
    actual_workers = resolver.get_service_workers()

    expected_workers = []
    for idx in range(num_tpus):
        expected_workers.append(
            ServiceWorker(
                internal_ip='10.0.0.{}'.format(idx),
                port='8470',
                machine_type='v3-8',
                zone='fake-zone',
                sw_version='pytorch-nightly'))
    self.assertCountEqual(expected_workers, actual_workers)
def test_validate_diff_num_workers(self):
    """Cluster.validate() raises when client and service counts differ.

    Three client workers are paired with four service workers, which violates
    the required 1:1 mapping.
    """
    client_workers = [
        ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a'),
        ClientWorker('10.0.0.1', 'n1-standard-16', 'europe-west4-a'),
        ClientWorker('10.0.0.2', 'n1-standard-16', 'europe-west4-a'),
    ]
    # Uniform port/zone/sw_version so the count mismatch is the only fault.
    service_workers = [
        ServiceWorker('10.0.0.0', '8470', 'v3-32', 'europe-west4-a',
                      'pytorch-nightly'),
        ServiceWorker('10.0.0.1', '8470', 'v3-32', 'europe-west4-a',
                      'pytorch-nightly'),
        ServiceWorker('10.0.0.2', '8470', 'v3-32', 'europe-west4-a',
                      'pytorch-nightly'),
        ServiceWorker('10.0.0.3', '8470', 'v3-32', 'europe-west4-a',
                      'pytorch-nightly'),
    ]
    cluster = Cluster(client_workers, service_workers)
    self.assertRaisesRegex(
        RuntimeError,
        'The client_workers and service_workers must have a 1:1 mapping',
        cluster.validate)
def test_validate_good_cluster(self):
    """Cluster.validate() accepts a consistent four-worker v3-32 cluster.

    One client worker carries an explicit hostname; that must not affect
    validation.
    """
    client_workers = [
        ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a'),
        ClientWorker('10.0.0.1', 'n1-standard-16', 'europe-west4-a'),
        ClientWorker('10.0.0.2', 'n1-standard-16', 'europe-west4-a'),
        ClientWorker('10.0.0.3', 'n1-standard-16', 'europe-west4-a',
                     hostname='test'),
    ]
    # Fully specified service workers: same port, machine type, zone and
    # sw_version across the board.
    service_workers = [
        ServiceWorker('10.0.0.0', '8470', 'v3-32', 'europe-west4-a',
                      'pytorch-nightly'),
        ServiceWorker('10.0.0.1', '8470', 'v3-32', 'europe-west4-a',
                      'pytorch-nightly'),
        ServiceWorker('10.0.0.2', '8470', 'v3-32', 'europe-west4-a',
                      'pytorch-nightly'),
        ServiceWorker('10.0.0.3', '8470', 'v3-32', 'europe-west4-a',
                      'pytorch-nightly'),
    ]
    cluster = Cluster(client_workers, service_workers)
    cluster.validate()  # Does not raise exception
def test_validate_machine_type_client_cluster(self):
    """Mixed client machine types fail validate() unless the check is off.

    The same workers are validated twice: once with
    ``check_client_machine_type=False`` (must pass) and once with the default
    (must raise).
    """
    client_workers = [
        ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a'),
        # Differs from the other client's machine type on purpose.
        ClientWorker('10.0.0.1', 'n1-standard-8', 'europe-west4-a'),
    ]
    service_workers = [
        ServiceWorker('10.0.0.0', '8470', 'v3-8', 'europe-west4-a',
                      'pytorch-nightly'),
        ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a',
                      'pytorch-nightly'),
    ]
    no_check_cluster = Cluster(
        client_workers, service_workers, check_client_machine_type=False)
    no_check_cluster.validate()  # Does not raise exception

    check_cluster = Cluster(client_workers, service_workers)
    self.assertRaisesRegex(
        RuntimeError,
        'All client_workers must have the same machine_type',
        check_cluster.validate)
def test_healthy_cluster(self):
    """get_cluster() pairs the instance group's VMs with a v3-32 pod's hosts.

    Four running VMs ``fake-ig-a`` .. ``fake-ig-d`` (IPs 10.0.0.0-3) live in
    instance group ``fake-ig``; the healthy pod ``fake-pod`` exposes four
    endpoints 10.0.0.0-3. The resolved Cluster must equal the one built by
    hand below.
    """
    vm_names = ['fake-ig-' + letter for letter in 'abcd']
    vm_ips = ['10.0.0.' + octet for octet in '0123']

    group_members = [
        gen_fake_ig_list_instances_entry(vm_name, 'RUNNING')
        for vm_name in vm_names
    ]
    list_instances_map = {
        'fake-ig': {
            'kind': 'compute#instanceGroupsListInstances',
            'items': group_members,
        },
    }
    instance_resp_map = {}
    for vm_name, vm_ip in zip(vm_names, vm_ips):
        instance_resp_map[vm_name] = gen_fake_instances_get_entry(
            vm_name, 'n1-standard-16', vm_ip, 'RUNNING')
    fake_compute_api = build_mock_compute_service(instance_resp_map,
                                                  list_instances_map)

    pod_ips = tuple('10.0.0.{}'.format(n) for n in range(4))
    tpu_resp_map = {
        'fake-pod':
            gen_fake_tpu_entry(
                'v3-32', list(pod_ips), 'fake-pod', 'READY',
                'pytorch-nightly', health='HEALTHY'),
    }
    fake_tpu_api = build_mock_tpu_service(tpu_resp_map)
    self.mock_discovery.side_effect = build_mock_services_fn(
        fake_compute_api, fake_tpu_api)

    resolver = ClusterResolver(list(tpu_resp_map))
    actual_cluster = resolver.get_cluster()

    wanted_clients = []
    for vm_name, vm_ip in zip(vm_names, vm_ips):
        wanted_clients.append(
            ClientWorker(
                internal_ip=vm_ip,
                machine_type='n1-standard-16',
                zone='fake-zone',
                hostname=vm_name))
    wanted_services = []
    for pod_ip in pod_ips:
        wanted_services.append(
            ServiceWorker(
                internal_ip=pod_ip,
                port='8470',
                machine_type='v3-32',
                zone='fake-zone',
                sw_version='pytorch-nightly'))
    wanted_cluster = Cluster(wanted_clients, wanted_services)
    self.assertEqual(wanted_cluster, actual_cluster)
def test_validate_diff_sw_versions(self):
    """Cluster.validate() raises when service workers report mixed sw_version.

    Four otherwise-identical v3-32 service workers alternate between
    ``pytorch-0.1`` and ``pytorch-0.2``.
    """
    worker_ips = ['10.0.0.{}'.format(n) for n in range(4)]
    client_workers = [
        ClientWorker(worker_ip, 'n1-standard-16', 'europe-west4-a')
        for worker_ip in worker_ips
    ]
    versions = ['pytorch-0.1', 'pytorch-0.2', 'pytorch-0.1', 'pytorch-0.2']
    service_workers = [
        ServiceWorker(worker_ip, '8470', 'v3-32', 'europe-west4-a', version)
        for worker_ip, version in zip(worker_ips, versions)
    ]
    cluster = Cluster(client_workers, service_workers)
    self.assertRaisesRegex(
        RuntimeError,
        'All service workers must have the same sw_version.*',
        cluster.validate)