Ejemplo n.º 1
0
    def test_validate_machine_type_service_cluster(self):
        """Mixed TPU machine types validate only when the type check is off."""
        zone = 'europe-west4-a'
        clients = [
            ClientWorker(ip, 'n1-standard-16', zone)
            for ip in ('10.0.0.0', '10.0.0.1')
        ]
        # Deliberately mismatched accelerator types: v3-8 vs v2-8.
        services = [
            ServiceWorker(ip, '8470', accel, zone, 'pytorch-0.2')
            for ip, accel in (('10.0.0.0', 'v3-8'), ('10.0.0.1', 'v2-8'))
        ]

        # Service machine-type check disabled: validation must pass.
        relaxed = Cluster(clients,
                          services,
                          check_service_machine_type=False,
                          client_master_ip='10.0.0.0')
        relaxed.validate()

        # Check left at its default: the mismatch must be rejected.
        strict = Cluster(clients, services, client_master_ip='10.0.0.0')
        with self.assertRaisesRegex(
                RuntimeError,
                'All service_workers must have the same machine_type'):
            strict.validate()
Ejemplo n.º 2
0
 def test_validate_empty_workers(self):
     """Validation rejects a cluster with clients but no service workers."""
     only_client = ClientWorker('10.0.0.0', 'n1-standard-16', 'europe-west4-a')
     cluster = Cluster([only_client], [], client_master_ip='10.0.0.0')
     with self.assertRaisesRegex(
             RuntimeError,
             'Both client_workers and service_workers should not be empty'):
         cluster.validate()
Ejemplo n.º 3
0
    def test_healthy_tpuvm_cluster(self):
        """A READY/HEALTHY TPU VM pod resolves to matching client/service workers.

        The TPU VM flavor of ``get_instance_metadata`` is patched in with a
        context manager. This guarantees that the previous attribute is
        restored on every exit path, including a failing resolution. The old
        ``.start()`` plus manual re-patch never stopped either patcher, and it
        skipped the re-patch whenever an assertion failed, which leaked TPU VM
        metadata into later tests.
        """
        # Using TPUVM flavor of metadata, scoped to cluster resolution only.
        with mock.patch.object(ClusterResolver, 'get_instance_metadata',
                               mock_request_tpuvm_metadata):
            noop_compute_service = build_mock_compute_service({}, {})
            self.mock_discovery.side_effect = build_mock_services_fn(
                noop_compute_service)

            tpu_map = {
                'fake-pod': {
                    'state':
                    'READY',
                    'health':
                    'HEALTHY',
                    'runtime_version':
                    'v2-nightly',
                    'accelerator_type':
                    'v3-32',
                    'api_version':
                    'V2_ALPHA1',
                    'network_endpoints': [{
                        'ipAddress': f'10.1.0.{index}',
                        'port': '8470',
                    } for index in range(4)],
                }
            }
            self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
                tpu_map)

            tpus = list(tpu_map.keys())
            cr = ClusterResolver(tpus)
            cluster = cr.get_cluster()

        # On TPU VMs, each client worker is the TPU host itself, so client
        # and service workers share IPs and the accelerator machine type.
        expected_client_workers = [
            ClientWorker(internal_ip=f'10.1.0.{index}',
                         machine_type='v3-32',
                         zone='fake-zone',
                         hostname=f'{TPUVM_HOSTNAME_PREFIX}{index}')
            for index in range(4)
        ]
        expected_service_workers = [
            ServiceWorker(internal_ip=f'10.1.0.{ip}',
                          port='8470',
                          machine_type='v3-32',
                          zone='fake-zone',
                          runtime_version='v2-nightly',
                          tpu='fake-pod') for ip in range(4)
        ]
        expected = Cluster(expected_client_workers,
                           expected_service_workers,
                           client_master_ip='10.1.0.0')
        self.assertEqual(expected, cluster)
Ejemplo n.º 4
0
 def test_validate_good_cluster(self):
     """Four matching client/service pairs in one zone validate cleanly."""
     zone = 'europe-west4-a'
     ips = [f'10.0.0.{i}' for i in range(4)]
     client_workers = [
         ClientWorker(ip, 'n1-standard-16', zone) for ip in ips[:-1]
     ]
     # One client has an explicit hostname; validation must still pass.
     client_workers.append(
         ClientWorker(ips[-1], 'n1-standard-16', zone, hostname='test'))
     service_workers = [
         ServiceWorker(ip, '8470', 'v3-32', zone, 'pytorch-0.2') for ip in ips
     ]
     cluster = Cluster(client_workers, service_workers)
     cluster.validate()  # Must not raise.
Ejemplo n.º 5
0
 def test_validate_bad_zone_cluster(self):
     """Validation rejects workers that do not all share one zone."""
     home_zone = 'europe-west4-a'
     clients = [
         ClientWorker('10.0.0.0', 'n1-standard-16', home_zone),
         # The second client sits in a different zone from everything else.
         ClientWorker('10.0.0.1', 'n1-standard-16', 'us-central1-b'),
     ]
     services = [
         ServiceWorker(f'10.0.0.{i}', '8470', 'v3-8', home_zone,
                       'pytorch-0.2') for i in range(2)
     ]
     cluster = Cluster(clients, services)
     with self.assertRaisesRegex(RuntimeError,
                                 'All workers must be in the same zone'):
         cluster.validate()
Ejemplo n.º 6
0
 def test_validate_diff_num_workers(self):
     """Three clients against four service workers breaks the 1:1 mapping."""
     zone = 'europe-west4-a'
     client_workers = [
         ClientWorker(f'10.0.0.{n}', 'n1-standard-16', zone) for n in range(3)
     ]
     service_workers = [
         ServiceWorker(f'10.0.0.{n}', '8470', 'v3-32', zone, 'pytorch-0.2')
         for n in range(4)
     ]
     cluster = Cluster(client_workers, service_workers)
     with self.assertRaisesRegex(
             RuntimeError,
             'The client_workers and service_workers must have a 1:1 mapping'):
         cluster.validate()
Ejemplo n.º 7
0
 def test_validate_diff_runtime_versions(self):
     """Validation rejects service workers on mixed runtime versions."""
     zone = 'europe-west4-a'
     # Alternate 0.1 / 0.2 so no single runtime version is shared by all.
     runtimes = ['pytorch-0.1', 'pytorch-0.2', 'pytorch-0.1', 'pytorch-0.2']
     client_workers = [
         ClientWorker(f'10.0.0.{idx}', 'n1-standard-16', zone)
         for idx in range(4)
     ]
     service_workers = [
         ServiceWorker(f'10.0.0.{idx}', '8470', 'v3-32', zone, runtime)
         for idx, runtime in enumerate(runtimes)
     ]
     cluster = Cluster(client_workers, service_workers)
     with self.assertRaisesRegex(
             RuntimeError,
             'All service workers must have the same runtime_version.*'):
         cluster.validate()
Ejemplo n.º 8
0
    def test_healthy_cluster(self):
        """A RUNNING instance group plus a READY/HEALTHY pod resolves fully."""
        # One name/IP pair per worker; every fixture below is derived from these.
        names = ['fake-ig-' + suffix for suffix in 'abcd']
        ips = ['10.0.0.' + digit for digit in '0123']

        list_instances_map = {
            'fake-ig': {
                'kind': 'compute#instanceGroupsListInstances',
                'items': [
                    gen_fake_ig_list_instances_entry(name, 'RUNNING')
                    for name in names
                ],
            },
        }
        instance_resp_map = {
            name: gen_fake_instances_get_entry(name, 'n1-standard-16', ip,
                                               'RUNNING')
            for name, ip in zip(names, ips)
        }
        compute = build_mock_compute_service(instance_resp_map,
                                             list_instances_map)
        self.mock_discovery.side_effect = build_mock_services_fn(compute)

        tpu_map = {
            'fake-pod': {
                'state': 'READY',
                'health': 'HEALTHY',
                'runtime_version': 'pytorch-nightly',
                'accelerator_type': 'v3-32',
                'network_endpoints': [{
                    'ipAddress': ip,
                    'port': '8470'
                } for ip in ips],
            }
        }
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            tpu_map)

        resolver = ClusterResolver(list(tpu_map))
        cluster = resolver.get_cluster()

        clients = [
            ClientWorker(internal_ip=ip,
                         machine_type='n1-standard-16',
                         zone='fake-zone',
                         hostname=name) for name, ip in zip(names, ips)
        ]
        services = [
            ServiceWorker(internal_ip=ip,
                          port='8470',
                          machine_type='v3-32',
                          zone='fake-zone',
                          runtime_version='pytorch-nightly',
                          tpu='fake-pod') for ip in ips
        ]
        self.assertEqual(Cluster(clients, services), cluster)
Ejemplo n.º 9
0
 def test_validate_empty_workers(self):
     """Validation rejects a cluster with no workers of either kind."""
     empty_cluster = Cluster([], [])
     with self.assertRaisesRegex(
             RuntimeError,
             'Both client_workers and service_workers should not be empty'):
         empty_cluster.validate()