示例#1
0
 def test_validate_bad_zone_cluster(self):
     """Validation rejects a cluster whose client workers span two zones."""
     zone = 'europe-west4-a'
     client_workers = [
         ClientWorker('10.0.0.0', 'n1-standard-16', zone),
         # Deliberately placed in a different zone from every other worker.
         ClientWorker('10.0.0.1', 'n1-standard-16', 'us-central1-b'),
     ]
     service_workers = [
         ServiceWorker('10.0.0.{}'.format(idx), 'v3-8', zone)
         for idx in range(2)
     ]
     cluster = Cluster(client_workers, service_workers)
     with self.assertRaisesRegex(RuntimeError,
                                 'All workers must be in the same zone'):
         cluster.validate()
示例#2
0
  def test_healthy_vm_list_client_cluster(self):
    """Client workers are resolved from an explicit list of VM names."""
    ig_names_and_ips = [('fake-ig-' + c, ip) for c, ip in zip('abcd', '0123')]

    # Arrange: the instance-group listing map stays empty; each named VM
    # lookup returns a RUNNING n1-standard-16 instance on 10.0.0.<n>.
    list_instances_map = {}
    instance_resp_map = {}
    for name, ip in ig_names_and_ips:
      instance_resp_map[name] = gen_fake_instances_get_entry(
          name, 'n1-standard-16', '10.0.0.' + ip, 'RUNNING')
    compute_service = build_mock_compute_service(instance_resp_map,
                                                 list_instances_map)
    self.mock_discovery.side_effect = build_mock_services_fn(
        compute_service, build_mock_tpu_service({}))

    # Act: resolve client workers from the VM names directly.
    vms = [name for name, _ in ig_names_and_ips]
    vm_cluster = ClusterResolver(['fake-tpu'], vms=vms).get_client_workers()

    # Assert: order-insensitive comparison against the expected workers.
    expected = [
        ClientWorker(
            internal_ip='10.0.0.' + ip,
            machine_type='n1-standard-16',
            zone='fake-zone',
            hostname=name) for name, ip in ig_names_and_ips
    ]
    self.assertCountEqual(expected, vm_cluster)
示例#3
0
 def test_create_bad_service_workers(self):
   """Cluster refuses ClientWorker objects in the service_workers slot."""
   worker = ClientWorker(
       '10.0.0.1', 'n1-standard-16', 'europe-west4-a', hostname='test')
   # The same list of client workers is passed for both arguments.
   mislabeled = [worker]
   with self.assertRaisesRegex(
       ValueError, 'service_workers argument must be a list of ServiceWorker'):
     Cluster(mislabeled, mislabeled)
示例#4
0
 def test_validate_diff_num_workers(self):
     """Validation fails when client and service worker counts differ."""
     zone = 'europe-west4-a'
     # Three clients versus four service workers: no 1:1 pairing exists.
     client_workers = [
         ClientWorker('10.0.0.%d' % i, 'n1-standard-16', zone)
         for i in range(3)
     ]
     service_workers = [
         ServiceWorker('10.0.0.%d' % i, 'v3-32', zone) for i in range(4)
     ]
     cluster = Cluster(client_workers, service_workers)
     with self.assertRaisesRegex(
         RuntimeError,
         'The client_workers and service_workers must have a 1:1 mapping'):
         cluster.validate()
示例#5
0
 def test_validate_good_cluster(self):
     """A same-zone, same-type, 1:1 cluster passes validation."""
     zone = 'europe-west4-a'
     client_workers = [
         ClientWorker('10.0.0.{}'.format(n), 'n1-standard-16', zone)
         for n in range(3)
     ]
     # The final client worker additionally carries an explicit hostname.
     client_workers.append(
         ClientWorker('10.0.0.3', 'n1-standard-16', zone, hostname='test'))
     service_workers = [
         ServiceWorker('10.0.0.{}'.format(n), 'v3-32', zone)
         for n in range(4)
     ]
     Cluster(client_workers, service_workers).validate()  # must not raise
示例#6
0
    def test_validate_machine_type_client_cluster(self):
        """Mixed client machine types fail only when the check is enabled."""
        zone = 'europe-west4-a'
        client_workers = [
            ClientWorker('10.0.0.0', 'n1-standard-16', zone),
            ClientWorker('10.0.0.1', 'n1-standard-8', zone),
        ]
        service_workers = [
            ServiceWorker('10.0.0.{}'.format(k), 'v3-8', zone)
            for k in range(2)
        ]

        # With check_client_machine_type=False, validation succeeds.
        Cluster(client_workers,
                service_workers,
                check_client_machine_type=False).validate()

        # The default-constructed cluster rejects heterogeneous clients.
        strict_cluster = Cluster(client_workers, service_workers)
        with self.assertRaisesRegex(
                RuntimeError,
                'All client_workers must have the same machine_type'):
            strict_cluster.validate()
示例#7
0
  def test_healthy_cluster(self):
    """A RUNNING instance group plus a READY pod resolve to the full cluster."""
    suffix_ip_pairs = list(zip('abcd', '0123'))
    vm_names = ['fake-ig-' + c for c, _ in suffix_ip_pairs]

    list_instances_map = {
        'fake-ig': {
            'kind': 'compute#instanceGroupsListInstances',
            'items': [
                gen_fake_ig_list_instances_entry(name, 'RUNNING')
                for name in vm_names
            ],
        },
    }
    instance_resp_map = {}
    for c, ip in suffix_ip_pairs:
      vm_name = 'fake-ig-' + c
      instance_resp_map[vm_name] = gen_fake_instances_get_entry(
          vm_name, 'n1-standard-16', '10.0.0.' + ip, 'RUNNING')
    compute_service = build_mock_compute_service(instance_resp_map,
                                                 list_instances_map)

    pod_entry = gen_fake_tpu_entry(
        'v3-32', ['10.0.0.{}'.format(i) for i in range(4)],
        'fake-pod',
        'READY',
        'pytorch-nightly',
        health='HEALTHY')
    tpu_resp_map = {'fake-pod': pod_entry}
    tpu_service = build_mock_tpu_service(tpu_resp_map)
    self.mock_discovery.side_effect = build_mock_services_fn(
        compute_service, tpu_service)

    cluster = ClusterResolver(list(tpu_resp_map)).get_cluster()

    expected_client_workers = [
        ClientWorker(internal_ip='10.0.0.' + ip,
                     machine_type='n1-standard-16',
                     zone='fake-zone',
                     hostname='fake-ig-' + c)
        for c, ip in suffix_ip_pairs
    ]
    expected_service_workers = [
        ServiceWorker(internal_ip='10.0.0.{}'.format(i),
                      port='8470',
                      machine_type='v3-32',
                      zone='fake-zone',
                      sw_version='pytorch-nightly')
        for i in range(4)
    ]
    self.assertEqual(
        Cluster(expected_client_workers, expected_service_workers), cluster)
示例#8
0
 def test_validate_diff_sw_versions(self):
   """Validation fails when service workers run different sw_versions."""
   zone = 'europe-west4-a'
   client_workers = [
       ClientWorker('10.0.0.{}'.format(i), 'n1-standard-16', zone)
       for i in range(4)
   ]
   # Alternate two versions across the pod so it is not homogeneous.
   versions = ['pytorch-0.1', 'pytorch-0.2', 'pytorch-0.1', 'pytorch-0.2']
   service_workers = [
       ServiceWorker('10.0.0.{}'.format(i), '8470', 'v3-32', zone, version)
       for i, version in enumerate(versions)
   ]
   cluster = Cluster(client_workers, service_workers)
   with self.assertRaisesRegex(
       RuntimeError, 'All service workers must have the same sw_version.*'):
     cluster.validate()
示例#9
0
 def test_create_bad_client_workers(self):
     """Cluster rejects a client_workers list that contains a ServiceWorker.

     The first client entry is a well-formed ClientWorker; the second is a
     ServiceWorker placed in the client list, which the Cluster constructor
     is expected to reject with a ValueError.
     """
     service_workers = [
         ServiceWorker('10.0.0.1', 'v3-8', 'europe-west4-a'),
     ]
     client_workers = [
         # Client VMs use a GCE machine type (as in the other tests), not a
         # TPU accelerator type, so the only invalid element is the one below.
         ClientWorker('10.0.0.1', 'n1-standard-16', 'europe-west4-a'),
         # Wrong type on purpose: this is what Cluster must reject.
         ServiceWorker('10.0.0.1', 'v3-8', 'europe-west4-a'),
     ]
     self.assertRaisesRegex(
         ValueError,
         'client_workers argument must be a list of ClientWorker', Cluster,
         client_workers, service_workers)