예제 #1
0
    def test_validate_machine_type_service_cluster(self):
        """Mixed TPU accelerator types fail validation unless the check is off.

        The service workers use both ``v3-8`` and ``v2-8``. With
        ``check_service_machine_type=False`` validate() must accept the
        cluster; constructed without that flag, validate() is expected to
        raise RuntimeError.
        """
        zone = 'europe-west4-a'
        client_workers = [
            ClientWorker(f'10.0.0.{i}', 'n1-standard-16', zone)
            for i in range(2)
        ]
        service_workers = [
            ServiceWorker(f'10.0.0.{i}', '8470', accelerator, zone,
                          'pytorch-0.2')
            for i, accelerator in enumerate(('v3-8', 'v2-8'))
        ]

        lenient = Cluster(client_workers,
                          service_workers,
                          check_service_machine_type=False,
                          client_master_ip='10.0.0.0')
        lenient.validate()  # Must not raise.

        strict = Cluster(client_workers,
                         service_workers,
                         client_master_ip='10.0.0.0')
        with self.assertRaisesRegex(
                RuntimeError,
                'All service_workers must have the same machine_type'):
            strict.validate()
예제 #2
0
 def test_create_bad_client_workers(self):
     """Cluster() rejects a client_workers list that contains a ServiceWorker."""

     def make_service_worker():
         return ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a',
                              'pytorch-0.2')

     service_workers = [make_service_worker()]
     # The second entry is deliberately the wrong type.
     client_workers = [
         ClientWorker('10.0.0.1', 'v3-8', 'europe-west4-a'),
         make_service_worker(),
     ]
     with self.assertRaisesRegex(
             ValueError,
             'client_workers argument must be a list of ClientWorker'):
         Cluster(client_workers, service_workers)
예제 #3
0
 def test_validate_bad_zone_cluster(self):
     """validate() rejects a cluster whose workers span more than one zone."""
     client_zones = ('europe-west4-a', 'us-central1-b')
     client_workers = [
         ClientWorker(f'10.0.0.{i}', 'n1-standard-16', zone_name)
         for i, zone_name in enumerate(client_zones)
     ]
     service_workers = [
         ServiceWorker(f'10.0.0.{i}', '8470', 'v3-8', 'europe-west4-a',
                       'pytorch-0.2')
         for i in range(2)
     ]
     cluster = Cluster(client_workers, service_workers)
     with self.assertRaisesRegex(RuntimeError,
                                 'All workers must be in the same zone'):
         cluster.validate()
예제 #4
0
    def test_healthy_pod_service_cluster(self):
        """A healthy v3-32 pod yields one ServiceWorker per network endpoint."""
        hosts = range(4)
        tpu_map = {
            'fake-pod': {
                'state': 'READY',
                'health': 'HEALTHY',
                'runtime_version': 'pytorch-nightly',
                'accelerator_type': 'v3-32',
                'network_endpoints': [
                    {'ipAddress': f'10.0.0.{host}', 'port': '8470'}
                    for host in hosts
                ],
            }
        }
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            tpu_map)

        service_workers = ClusterResolver(list(tpu_map)).get_service_workers()

        expected = [
            ServiceWorker(internal_ip=f'10.0.0.{host}',
                          port='8470',
                          machine_type='v3-32',
                          zone='fake-zone',
                          runtime_version='pytorch-nightly',
                          tpu='fake-pod') for host in hosts
        ]
        self.assertCountEqual(expected, service_workers)
예제 #5
0
파일: cluster.py 프로젝트: sieginglion/xla
        def add_service_worker(tpu_name):
            """Resolve one TPU and append a ServiceWorker per network endpoint.

            Appends to ``workers`` from the enclosing scope. Raises
            RuntimeError if the TPU's state is not 'READY' or its health is
            not 'HEALTHY'.
            """
            ctc = cloud_tpu_client.Client(tpu=tpu_name)
            # From here on, use the name as reported by the client.
            tpu_name = ctc.name()
            readiness_checks = (('READY', ctc.state), ('HEALTHY', ctc.health))
            for wanted, probe in readiness_checks:
                if probe() != wanted:
                    raise RuntimeError(f'TPU {tpu_name} is not {wanted} yet. '
                                       f'Re-run when all TPUs are {wanted}')

            runtime_version = ctc.runtime_version()
            machine_type = ctc.accelerator_type()
            zone = self._parse_resource_url(ctc._full_name(), 'locations')
            for endpoint in ctc.network_endpoints():
                workers.append(
                    ServiceWorker(internal_ip=endpoint['ipAddress'],
                                  port=endpoint['port'],
                                  machine_type=machine_type,
                                  zone=zone,
                                  runtime_version=runtime_version,
                                  tpu=tpu_name))
예제 #6
0
    def test_healthy_tpuvm_cluster(self):
        """get_cluster() on a TPU VM treats the TPU hosts as client workers.

        The test expects each network endpoint of ``fake-pod`` to appear both
        as a ClientWorker (hostname ``TPUVM_HOSTNAME_PREFIX`` + index) and as
        a ServiceWorker, with the first endpoint as the client master.
        """
        # Use TPUVM-flavored instance metadata only for the body of this test.
        # The context manager restores the previous attribute on exit,
        # including when an assertion fails, so this override cannot leak
        # into other tests.
        with mock.patch.object(ClusterResolver, 'get_instance_metadata',
                               mock_request_tpuvm_metadata):
            noop_compute_service = build_mock_compute_service({}, {})
            self.mock_discovery.side_effect = build_mock_services_fn(
                noop_compute_service)

            tpu_map = {
                'fake-pod': {
                    'state': 'READY',
                    'health': 'HEALTHY',
                    'runtime_version': 'v2-nightly',
                    'accelerator_type': 'v3-32',
                    'api_version': 'V2_ALPHA1',
                    'network_endpoints': [{
                        'ipAddress': f'10.1.0.{index}',
                        'port': '8470',
                    } for index in range(4)],
                }
            }
            self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
                tpu_map)

            tpus = list(tpu_map.keys())
            cr = ClusterResolver(tpus)
            cluster = cr.get_cluster()

            expected_client_workers = [
                ClientWorker(internal_ip=f'10.1.0.{index}',
                             machine_type='v3-32',
                             zone='fake-zone',
                             hostname=f'{TPUVM_HOSTNAME_PREFIX}{index}')
                for index in range(4)
            ]
            expected_service_workers = [
                ServiceWorker(internal_ip=f'10.1.0.{ip}',
                              port='8470',
                              machine_type='v3-32',
                              zone='fake-zone',
                              runtime_version='v2-nightly',
                              tpu='fake-pod') for ip in range(4)
            ]
            expected = Cluster(expected_client_workers,
                               expected_service_workers,
                               client_master_ip='10.1.0.0')
            self.assertEqual(expected, cluster)
예제 #7
0
 def test_validate_diff_num_workers(self):
     """validate() rejects clusters whose client/service worker counts differ."""
     zone = 'europe-west4-a'
     client_workers = [
         ClientWorker(f'10.0.0.{i}', 'n1-standard-16', zone) for i in range(3)
     ]
     service_workers = [
         ServiceWorker(f'10.0.0.{i}', '8470', 'v3-32', zone, 'pytorch-0.2')
         for i in range(4)
     ]
     cluster = Cluster(client_workers, service_workers)
     with self.assertRaisesRegex(
             RuntimeError,
             'The client_workers and service_workers must have a 1:1 mapping'):
         cluster.validate()
예제 #8
0
 def test_validate_good_cluster(self):
     """A consistent 4x4 cluster passes validate() without raising."""
     zone = 'europe-west4-a'
     client_workers = [
         ClientWorker(f'10.0.0.{i}', 'n1-standard-16', zone) for i in range(3)
     ]
     # One worker carries an explicit hostname; validation must still pass.
     client_workers.append(
         ClientWorker('10.0.0.3', 'n1-standard-16', zone, hostname='test'))
     service_workers = [
         ServiceWorker(f'10.0.0.{i}', '8470', 'v3-32', zone, 'pytorch-0.2')
         for i in range(4)
     ]
     Cluster(client_workers, service_workers).validate()  # Must not raise.
예제 #9
0
 def test_validate_diff_runtime_versions(self):
     """validate() rejects service workers with mismatched runtime versions."""
     zone = 'europe-west4-a'
     client_workers = [
         ClientWorker(f'10.0.0.{i}', 'n1-standard-16', zone) for i in range(4)
     ]
     # Alternate between two runtime versions across the four workers.
     runtimes = ('pytorch-0.1', 'pytorch-0.2') * 2
     service_workers = [
         ServiceWorker(f'10.0.0.{i}', '8470', 'v3-32', zone, runtime)
         for i, runtime in enumerate(runtimes)
     ]
     cluster = Cluster(client_workers, service_workers)
     with self.assertRaisesRegex(
             RuntimeError,
             'All service workers must have the same runtime_version.*'):
         cluster.validate()
예제 #10
0
        def add_tpu_worker(tpu_name):
            """Resolve one TPU and append a worker per network endpoint.

            Appends to ``workers`` from the enclosing scope. When the
            enclosing ``as_client_worker`` is truthy, each endpoint becomes a
            ClientWorker whose hostname comes from
            ClusterResolver._get_internal_ip_to_hostname_mapping; otherwise
            each endpoint becomes a ServiceWorker. Raises RuntimeError if the
            TPU's state is not 'READY' or its health is not 'HEALTHY'.
            """
            ctc = cloud_tpu_client.Client(tpu=tpu_name)
            # From here on, use the name as reported by the client.
            tpu_name = ctc.name()
            readiness_checks = (('READY', ctc.state), ('HEALTHY', ctc.health))
            for wanted, probe in readiness_checks:
                if probe() != wanted:
                    raise RuntimeError(f'TPU {tpu_name} is not {wanted} yet. '
                                       f'Re-run when all TPUs are {wanted}')

            runtime_version = ctc.runtime_version()
            machine_type = ctc.accelerator_type()
            zone = ClusterResolver._parse_resource_url(ctc._full_name(),
                                                       'locations')
            endpoints = ctc.network_endpoints()

            if not as_client_worker:
                for ep in endpoints:
                    workers.append(
                        ServiceWorker(internal_ip=ep['ipAddress'],
                                      port=ep['port'],
                                      machine_type=machine_type,
                                      zone=zone,
                                      runtime_version=runtime_version,
                                      tpu=tpu_name))
                return

            hostnames = ClusterResolver._get_internal_ip_to_hostname_mapping(
                tpu_name, zone, len(endpoints))
            for ep in endpoints:
                ip = ep['ipAddress']
                workers.append(
                    ClientWorker(internal_ip=ip,
                                 machine_type=machine_type,
                                 zone=zone,
                                 hostname=hostnames[ip]))
예제 #11
0
    def test_healthy_sea_service_cluster(self):
        """Many single-host v3-8 TPUs each resolve to one ServiceWorker."""
        num_tpus = 256
        self.mock_discovery.side_effect = build_mock_services_fn(
            build_mock_compute_service({}, {}))

        tpu_map = {}
        for idx in range(num_tpus):
            tpu_map[f'fake-tpu-{idx}'] = {
                'state': 'READY',
                'health': 'HEALTHY',
                'runtime_version': 'pytorch-nightly',
                'accelerator_type': 'v3-8',
                'network_endpoints': [{
                    'ipAddress': f'10.0.0.{idx}',
                    'port': '8470'
                }],
            }
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            tpu_map)

        service_workers = ClusterResolver(list(tpu_map)).get_service_workers()

        expected = [
            ServiceWorker(internal_ip=f'10.0.0.{idx}',
                          port='8470',
                          machine_type='v3-8',
                          zone='fake-zone',
                          runtime_version='pytorch-nightly',
                          tpu=f'fake-tpu-{idx}') for idx in range(num_tpus)
        ]
        self.assertCountEqual(expected, service_workers)
예제 #12
0
    def test_healthy_cluster(self):
        """Instance-group VMs plus a v3-32 pod resolve to a matching Cluster."""
        vm_names = ['fake-ig-' + suffix for suffix in 'abcd']
        vm_ips = ['10.0.0.' + digit for digit in '0123']

        list_instances_map = {
            'fake-ig': {
                'kind': 'compute#instanceGroupsListInstances',
                'items': [
                    gen_fake_ig_list_instances_entry(name, 'RUNNING')
                    for name in vm_names
                ],
            },
        }
        instance_resp_map = {
            name: gen_fake_instances_get_entry(name, 'n1-standard-16', ip,
                                               'RUNNING')
            for name, ip in zip(vm_names, vm_ips)
        }
        self.mock_discovery.side_effect = build_mock_services_fn(
            build_mock_compute_service(instance_resp_map, list_instances_map))

        tpu_map = {
            'fake-pod': {
                'state': 'READY',
                'health': 'HEALTHY',
                'runtime_version': 'pytorch-nightly',
                'accelerator_type': 'v3-32',
                'network_endpoints': [{
                    'ipAddress': ip,
                    'port': '8470'
                } for ip in vm_ips],
            }
        }
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            tpu_map)

        cluster = ClusterResolver(list(tpu_map)).get_cluster()

        expected_client_workers = [
            ClientWorker(internal_ip=ip,
                         machine_type='n1-standard-16',
                         zone='fake-zone',
                         hostname=name)
            for name, ip in zip(vm_names, vm_ips)
        ]
        expected_service_workers = [
            ServiceWorker(internal_ip=ip,
                          port='8470',
                          machine_type='v3-32',
                          zone='fake-zone',
                          runtime_version='pytorch-nightly',
                          tpu='fake-pod') for ip in vm_ips
        ]
        self.assertEqual(
            Cluster(expected_client_workers, expected_service_workers),
            cluster)