Exemplo n.º 1
0
    def test_validate_machine_type_service_cluster(self):
        """Exercise the ``check_service_machine_type`` switch of ``Cluster``.

        The two service workers use different machine types ('v3-8' and
        'v2-8'). Validation is expected to pass when the check is disabled
        and to raise ``RuntimeError`` when the flag is not passed.
        """
        clients = [
            ClientWorker(ip, 'n1-standard-16', 'europe-west4-a')
            for ip in ('10.0.0.0', '10.0.0.1')
        ]
        # Deliberately mismatched service machine types.
        services = [
            ServiceWorker(ip, '8470', machine_type, 'europe-west4-a',
                          'pytorch-0.2')
            for ip, machine_type in (('10.0.0.0', 'v3-8'),
                                     ('10.0.0.1', 'v2-8'))
        ]

        # With the check turned off, validate() must not raise.
        lenient_cluster = Cluster(clients,
                                  services,
                                  check_service_machine_type=False,
                                  client_master_ip='10.0.0.0')
        lenient_cluster.validate()

        strict_cluster = Cluster(clients,
                                 services,
                                 client_master_ip='10.0.0.0')
        with self.assertRaisesRegex(
                RuntimeError,
                'All service_workers must have the same machine_type'):
            strict_cluster.validate()
Exemplo n.º 2
0
 def test_validate_bad_zone_cluster(self):
     """Expect validate() to reject client workers that sit in two zones."""
     # The second client is placed in a different zone than all the others.
     clients = [
         ClientWorker(ip, 'n1-standard-16', zone)
         for ip, zone in (('10.0.0.0', 'europe-west4-a'),
                          ('10.0.0.1', 'us-central1-b'))
     ]
     services = [
         ServiceWorker(ip, '8470', 'v3-8', 'europe-west4-a', 'pytorch-0.2')
         for ip in ('10.0.0.0', '10.0.0.1')
     ]
     mixed_zone_cluster = Cluster(clients, services)
     with self.assertRaisesRegex(RuntimeError,
                                 'All workers must be in the same zone'):
         mixed_zone_cluster.validate()
Exemplo n.º 3
0
    def test_healthy_vm_list_client_cluster(self):
        """Resolve client workers from an explicit list of VM names.

        Four fake RUNNING instances ('fake-ig-a' .. 'fake-ig-d') are served
        by the mocked compute service, and ``get_client_workers()`` is
        expected to return one matching ``ClientWorker`` per instance.
        """
        # Arrange
        hosts_and_ips = [('fake-ig-' + c, '10.0.0.' + ip)
                         for c, ip in zip('abcd', '0123')]
        instance_resp_map = {}
        for host, address in hosts_and_ips:
            instance_resp_map[host] = gen_fake_instances_get_entry(
                host, 'n1-standard-16', address, 'RUNNING')
        # No instance groups are involved when the VMs are listed explicitly.
        list_instances_map = {}
        compute_service = build_mock_compute_service(instance_resp_map,
                                                     list_instances_map)
        self.mock_discovery.side_effect = build_mock_services_fn(
            compute_service)

        # Act
        vms = [host for host, _ in hosts_and_ips]
        resolver = ClusterResolver(['fake-tpu'], vms=vms)
        resolved_workers = resolver.get_client_workers()

        # Assert
        expected = []
        for host, address in hosts_and_ips:
            expected.append(
                ClientWorker(internal_ip=address,
                             machine_type='n1-standard-16',
                             zone='fake-zone',
                             hostname=host))
        self.assertCountEqual(expected, resolved_workers)
Exemplo n.º 4
0
    def add_tpu_worker(tpu_name):
      """Append one worker per network endpoint of a TPU to `workers`.

      Reads `as_client_worker` and appends to `workers`, both taken from the
      enclosing scope. When `as_client_worker` is truthy a `ClientWorker` is
      built for each endpoint, otherwise a `ServiceWorker`.

      Args:
        tpu_name: Value passed as `tpu=` to `cloud_tpu_client.Client`. It is
          replaced by the client's `name()` for messages and for the `tpu`
          field of service workers.

      Raises:
        RuntimeError: If the TPU state is not 'READY' or its health is not
          'HEALTHY'.
      """
      client = cloud_tpu_client.Client(tpu=tpu_name)
      tpu_name = client.name()

      # State is checked before health, so a TPU that is not READY is
      # reported as such without its health being queried.
      if client.state() != 'READY':
        not_ready = ('TPU {tpu_name} is not READY yet. '
                     'Re-run when all TPUs are READY')
        raise RuntimeError(not_ready.format(tpu_name=tpu_name))
      if client.health() != 'HEALTHY':
        not_healthy = ('TPU {tpu_name} is not HEALTHY yet. '
                       'Re-run when all TPUs are HEALTHY')
        raise RuntimeError(not_healthy.format(tpu_name=tpu_name))

      runtime_version = client.runtime_version()
      accelerator = client.accelerator_type()
      tpu_zone = ClusterResolver._parse_resource_url(client._full_name(),
                                                     'locations')

      for ep in client.network_endpoints():
        if not as_client_worker:
          workers.append(
              ServiceWorker(
                  internal_ip=ep['ipAddress'],
                  port=ep['port'],
                  machine_type=accelerator,
                  zone=tpu_zone,
                  runtime_version=runtime_version,
                  tpu=tpu_name))
          continue

        # Client workers take their hostname from the endpoint's VM self link.
        vm_hostname = ClusterResolver._parse_resource_url(
            ep['greenVmSelflink'], 'instances')
        workers.append(
            ClientWorker(
                internal_ip=ep['ipAddress'],
                machine_type=accelerator,
                zone=tpu_zone,
                hostname=vm_hostname))
Exemplo n.º 5
0
 def test_validate_empty_workers(self):
     """Expect validate() to reject a cluster with no service workers."""
     lone_client = ClientWorker('10.0.0.0', 'n1-standard-16',
                                'europe-west4-a')
     cluster = Cluster([lone_client], [], client_master_ip='10.0.0.0')
     expected_message = (
         'Both client_workers and service_workers should not be empty')
     with self.assertRaisesRegex(RuntimeError, expected_message):
         cluster.validate()
Exemplo n.º 6
0
 def test_create_bad_service_workers(self):
     """Expect Cluster() to reject ClientWorker objects as service workers."""
     only_clients = [
         ClientWorker('10.0.0.1',
                      'n1-standard-16',
                      'europe-west4-a',
                      hostname='test'),
     ]
     # The same client list is passed in both positions on purpose.
     with self.assertRaisesRegex(
             ValueError,
             'service_workers argument must be a list of ServiceWorker'):
         Cluster(only_clients, only_clients)
Exemplo n.º 7
0
    def test_healthy_tpuvm_cluster(self):
        """Resolve a cluster while instance metadata has the TPUVM flavor.

        A single healthy 'fake-pod' TPU with four network endpoints is served
        by the mocked TPU client. The resolved cluster is expected to hold
        four client workers (machine type 'v3-32', hostnames built from
        ``TPUVM_HOSTNAME_PREFIX`` plus the endpoint index) and four service
        workers, with '10.1.0.0' as the client master IP.
        """
        noop_compute_service = build_mock_compute_service({}, {})
        self.mock_discovery.side_effect = build_mock_services_fn(
            noop_compute_service)

        tpu_map = {
            'fake-pod': {
                'state':
                'READY',
                'health':
                'HEALTHY',
                'runtime_version':
                'v2-nightly',
                'accelerator_type':
                'v3-32',
                'api_version':
                'V2_ALPHA1',
                'network_endpoints': [{
                    'ipAddress': f'10.1.0.{index}',
                    'port': '8470',
                } for index in range(4)],
            }
        }
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            tpu_map)

        tpus = list(tpu_map.keys())
        # Use the TPUVM flavor of metadata only while resolving. The context
        # manager puts back whatever `get_instance_metadata` was before, even
        # when resolution raises, so the override cannot leak out of the test.
        with mock.patch.object(ClusterResolver, 'get_instance_metadata',
                               mock_request_tpuvm_metadata):
            cr = ClusterResolver(tpus)
            cluster = cr.get_cluster()

        expected_client_workers = [
            ClientWorker(internal_ip=f'10.1.0.{index}',
                         machine_type='v3-32',
                         zone='fake-zone',
                         hostname=f'{TPUVM_HOSTNAME_PREFIX}{index}')
            for index in range(4)
        ]
        expected_service_workers = [
            ServiceWorker(internal_ip=f'10.1.0.{ip}',
                          port='8470',
                          machine_type='v3-32',
                          zone='fake-zone',
                          runtime_version='v2-nightly',
                          tpu='fake-pod') for ip in range(4)
        ]
        expected = Cluster(expected_client_workers,
                           expected_service_workers,
                           client_master_ip='10.1.0.0')
        self.assertEqual(expected, cluster)
Exemplo n.º 8
0
 def test_validate_diff_num_workers(self):
     """Expect validate() to reject 3 client workers against 4 service workers."""
     clients = [
         ClientWorker(ip, 'n1-standard-16', 'europe-west4-a')
         for ip in ('10.0.0.0', '10.0.0.1', '10.0.0.2')
     ]
     # One more service worker than there are clients.
     services = [
         ServiceWorker(ip, '8470', 'v3-32', 'europe-west4-a', 'pytorch-0.2')
         for ip in ('10.0.0.0', '10.0.0.1', '10.0.0.2', '10.0.0.3')
     ]
     unbalanced_cluster = Cluster(clients, services)
     with self.assertRaisesRegex(
             RuntimeError,
             'The client_workers and service_workers must have a 1:1 mapping'):
         unbalanced_cluster.validate()
Exemplo n.º 9
0
 def test_validate_good_cluster(self):
     """Expect validate() to accept four matching client/service workers."""
     clients = [
         ClientWorker(ip, 'n1-standard-16', 'europe-west4-a')
         for ip in ('10.0.0.0', '10.0.0.1', '10.0.0.2')
     ]
     # The last client additionally carries an explicit hostname.
     clients.append(
         ClientWorker('10.0.0.3',
                      'n1-standard-16',
                      'europe-west4-a',
                      hostname='test'))
     services = [
         ServiceWorker(ip, '8470', 'v3-32', 'europe-west4-a', 'pytorch-0.2')
         for ip in ('10.0.0.0', '10.0.0.1', '10.0.0.2', '10.0.0.3')
     ]
     healthy_cluster = Cluster(clients, services)
     # Passing means no exception is raised here.
     healthy_cluster.validate()
Exemplo n.º 10
0
 def test_validate_diff_runtime_versions(self):
     """Expect validate() to reject service workers with mixed runtime versions."""
     addresses = ('10.0.0.0', '10.0.0.1', '10.0.0.2', '10.0.0.3')
     clients = [
         ClientWorker(ip, 'n1-standard-16', 'europe-west4-a')
         for ip in addresses
     ]
     # Runtime versions alternate between two values across the workers.
     versions = ('pytorch-0.1', 'pytorch-0.2', 'pytorch-0.1', 'pytorch-0.2')
     services = [
         ServiceWorker(ip, '8470', 'v3-32', 'europe-west4-a', version)
         for ip, version in zip(addresses, versions)
     ]
     mixed_version_cluster = Cluster(clients, services)
     with self.assertRaisesRegex(
             RuntimeError,
             'All service workers must have the same runtime_version.*'):
         mixed_version_cluster.validate()
Exemplo n.º 11
0
 def test_create_bad_client_workers(self):
     """Expect Cluster() to reject a client list that contains a ServiceWorker."""
     services = [
         ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a',
                       'pytorch-0.2'),
     ]
     valid_client = ClientWorker('10.0.0.1', 'v3-8', 'europe-west4-a')
     # A ServiceWorker in the client list is the invalid entry under test.
     intruder = ServiceWorker('10.0.0.1', '8470', 'v3-8', 'europe-west4-a',
                              'pytorch-0.2')
     polluted_clients = [valid_client, intruder]
     with self.assertRaisesRegex(
             ValueError,
             'client_workers argument must be a list of ClientWorker'):
         Cluster(polluted_clients, services)
Exemplo n.º 12
0
 def add_client_worker(request_id, resp, exception):
     """Callback for each request in BatchHttpRequest.

     Builds a `ClientWorker` from one instance response and appends it to
     `workers` from the enclosing scope.

     Args:
       request_id: Not used by this callback.
       resp: Instance response; the keys 'selfLink', 'status',
         'networkInterfaces', 'machineType' and 'zone' are read.
       exception: Re-raised as-is when it is not None.

     Raises:
       RuntimeError: If the instance status is not 'RUNNING'.
     """
     if exception is not None:
         raise exception

     parse = self._parse_resource_url
     hostname = parse(resp['selfLink'], 'instances')
     if resp['status'] != 'RUNNING':
         not_running = ('Instance {hostname} is not running yet. '
                        'Re-run when all VMs are running')
         raise RuntimeError(not_running.format(hostname=hostname))

     # The internal IP comes from the first network interface.
     workers.append(
         ClientWorker(
             internal_ip=resp['networkInterfaces'][0]['networkIP'],
             machine_type=parse(resp['machineType'], 'machineTypes'),
             zone=parse(resp['zone'], 'zones'),
             hostname=hostname))
Exemplo n.º 13
0
    def test_healthy_cluster(self):
        """Resolve a full cluster from one instance group and one TPU pod.

        The mocked compute service lists four RUNNING instances in 'fake-ig'
        and the mocked TPU client reports a healthy 'fake-pod' with four
        endpoints. ``get_cluster()`` is expected to pair them into a
        ``Cluster`` of four client workers and four service workers.
        """
        group_listing = [
            gen_fake_ig_list_instances_entry('fake-ig-' + c, 'RUNNING')
            for c in 'abcd'
        ]
        list_instances_map = {
            'fake-ig': {
                'kind': 'compute#instanceGroupsListInstances',
                'items': group_listing,
            },
        }
        instance_resp_map = {}
        for c, ip in zip('abcd', '0123'):
            instance_resp_map['fake-ig-' + c] = gen_fake_instances_get_entry(
                'fake-ig-' + c, 'n1-standard-16', '10.0.0.' + ip, 'RUNNING')
        compute_service = build_mock_compute_service(instance_resp_map,
                                                     list_instances_map)
        self.mock_discovery.side_effect = build_mock_services_fn(
            compute_service)

        pod_endpoints = [{
            'ipAddress': f'10.0.0.{ip}',
            'port': '8470'
        } for ip in range(4)]
        tpu_map = {
            'fake-pod': {
                'state': 'READY',
                'health': 'HEALTHY',
                'runtime_version': 'pytorch-nightly',
                'accelerator_type': 'v3-32',
                'network_endpoints': pod_endpoints,
            }
        }
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            tpu_map)

        resolver = ClusterResolver(list(tpu_map.keys()))
        resolved = resolver.get_cluster()

        expected_client_workers = []
        for c, ip in zip('abcd', '0123'):
            expected_client_workers.append(
                ClientWorker(internal_ip='10.0.0.' + ip,
                             machine_type='n1-standard-16',
                             zone='fake-zone',
                             hostname='fake-ig-' + c))
        expected_service_workers = []
        for ip in range(4):
            expected_service_workers.append(
                ServiceWorker(internal_ip=f'10.0.0.{ip}',
                              port='8470',
                              machine_type='v3-32',
                              zone='fake-zone',
                              runtime_version='pytorch-nightly',
                              tpu='fake-pod'))
        expected = Cluster(expected_client_workers, expected_service_workers)
        self.assertEqual(expected, resolved)