Esempio n. 1
0
 def test_create_bad_client_workers(self):
     """Constructing a Cluster with a ServiceWorker among the clients fails.

     The test expects a ValueError whose message names the client_workers
     argument.
     """
     worker_args = ('10.0.0.1', 'v3-8', 'europe-west4-a')
     services = [ServiceWorker(*worker_args)]
     # The second entry is deliberately the wrong worker type.
     mixed_clients = [ClientWorker(*worker_args), ServiceWorker(*worker_args)]
     with self.assertRaisesRegex(
             ValueError,
             'client_workers argument must be a list of ClientWorker'):
         Cluster(mixed_clients, services)
Esempio n. 2
0
 def test_validate_bad_zone_cluster(self):
     """validate() raises RuntimeError when the workers span two zones."""
     service_zone = 'europe-west4-a'
     # The second client sits in a different zone from everything else.
     clients = [
         ClientWorker(ip, 'n1-standard-16', zone)
         for ip, zone in (('10.0.0.0', service_zone),
                          ('10.0.0.1', 'us-central1-b'))
     ]
     services = [
         ServiceWorker(ip, 'v3-8', service_zone)
         for ip in ('10.0.0.0', '10.0.0.1')
     ]
     mixed_zone_cluster = Cluster(clients, services)
     with self.assertRaisesRegex(RuntimeError,
                                 'All workers must be in the same zone'):
         mixed_zone_cluster.validate()
Esempio n. 3
0
  def test_healthy_sea_service_cluster(self):
    """Resolve 256 fake v3-8 TPU entries and compare the service workers.

    Each fake TPU entry carries one IP; the resolver output is compared
    to the expected ServiceWorker list without regard to ordering.
    """
    tpu_entries = {}
    for idx in range(256):
      tpu_name = 'fake-tpu-{}'.format(idx)
      tpu_entries[tpu_name] = gen_fake_tpu_entry(
          'v3-8', ['10.0.0.{}'.format(idx)],
          tpu_name,
          'READY',
          'pytorch-nightly',
          health='HEALTHY')

    # The compute service gets empty response maps in this test.
    self.mock_discovery.side_effect = build_mock_services_fn(
        build_mock_compute_service({}, {}),
        build_mock_tpu_service(tpu_entries))

    resolver = ClusterResolver(list(tpu_entries))
    resolved_workers = resolver.get_service_workers()

    wanted_workers = []
    for idx in range(256):
      wanted_workers.append(
          ServiceWorker(
              internal_ip='10.0.0.{}'.format(idx),
              port='8470',
              machine_type='v3-8',
              zone='fake-zone',
              sw_version='pytorch-nightly'))
    self.assertCountEqual(wanted_workers, resolved_workers)
Esempio n. 4
0
 def test_validate_diff_num_workers(self):
     """validate() raises RuntimeError for 3 clients against 4 services."""
     zone = 'europe-west4-a'
     client_ips = ('10.0.0.0', '10.0.0.1', '10.0.0.2')
     # One more service worker than client workers.
     service_ips = client_ips + ('10.0.0.3',)
     clients = [ClientWorker(ip, 'n1-standard-16', zone) for ip in client_ips]
     services = [ServiceWorker(ip, 'v3-32', zone) for ip in service_ips]
     uneven_cluster = Cluster(clients, services)
     with self.assertRaisesRegex(
             RuntimeError,
             'The client_workers and service_workers must have a 1:1 mapping'):
         uneven_cluster.validate()
Esempio n. 5
0
 def test_validate_good_cluster(self):
     """validate() completes without raising for four matching workers."""
     zone = 'europe-west4-a'
     clients = [
         ClientWorker(ip, 'n1-standard-16', zone)
         for ip in ('10.0.0.0', '10.0.0.1', '10.0.0.2')
     ]
     # The last client additionally carries an explicit hostname.
     clients.append(
         ClientWorker('10.0.0.3', 'n1-standard-16', zone, hostname='test'))
     services = [
         ServiceWorker(ip, 'v3-32', zone)
         for ip in ('10.0.0.0', '10.0.0.1', '10.0.0.2', '10.0.0.3')
     ]
     # No assertion needed: the test passes as long as nothing is raised.
     Cluster(clients, services).validate()
Esempio n. 6
0
    def test_validate_machine_type_client_cluster(self):
        """Mixed client machine types fail validation unless the check is off.

        The same worker lists are validated twice: once with
        check_client_machine_type=False (no exception expected) and once
        with the constructor defaults (RuntimeError expected).
        """
        zone = 'europe-west4-a'
        # The two clients differ only in machine type.
        clients = [
            ClientWorker(ip, machine, zone)
            for ip, machine in (('10.0.0.0', 'n1-standard-16'),
                                ('10.0.0.1', 'n1-standard-8'))
        ]
        services = [
            ServiceWorker(ip, 'v3-8', zone)
            for ip in ('10.0.0.0', '10.0.0.1')
        ]

        lenient = Cluster(clients, services, check_client_machine_type=False)
        lenient.validate()  # Must not raise.

        strict = Cluster(clients, services)
        with self.assertRaisesRegex(
                RuntimeError,
                'All client_workers must have the same machine_type'):
            strict.validate()
Esempio n. 7
0
  def test_healthy_cluster(self):
    """Resolve a cluster from mocked compute/TPU services and compare it.

    Four fake instances ('fake-ig-a' .. 'fake-ig-d') and one fake 'v3-32'
    TPU entry with four IPs are fed to the mocks; the resolved cluster is
    compared to a Cluster built from the matching workers.
    """
    vm_names = ['fake-ig-' + suffix for suffix in 'abcd']
    vm_ips = ['10.0.0.' + digit for digit in '0123']

    ig_listing = {
        'fake-ig': {
            'kind': 'compute#instanceGroupsListInstances',
            'items': [
                gen_fake_ig_list_instances_entry(vm_name, 'RUNNING')
                for vm_name in vm_names
            ],
        },
    }
    vm_details = {}
    for vm_name, vm_ip in zip(vm_names, vm_ips):
      vm_details[vm_name] = gen_fake_instances_get_entry(
          vm_name, 'n1-standard-16', vm_ip, 'RUNNING')
    compute_service = build_mock_compute_service(vm_details, ig_listing)

    pod_entry = gen_fake_tpu_entry(
        'v3-32', ['10.0.0.{}'.format(n) for n in range(4)],
        'fake-pod',
        'READY',
        'pytorch-nightly',
        health='HEALTHY')
    tpu_entries = {'fake-pod': pod_entry}
    tpu_service = build_mock_tpu_service(tpu_entries)
    self.mock_discovery.side_effect = build_mock_services_fn(
        compute_service, tpu_service)

    resolver = ClusterResolver(list(tpu_entries))
    resolved_cluster = resolver.get_cluster()

    wanted_clients = []
    for vm_name, vm_ip in zip(vm_names, vm_ips):
      wanted_clients.append(
          ClientWorker(
              internal_ip=vm_ip,
              machine_type='n1-standard-16',
              zone='fake-zone',
              hostname=vm_name))
    wanted_services = []
    for n in range(4):
      wanted_services.append(
          ServiceWorker(
              internal_ip='10.0.0.{}'.format(n),
              port='8470',
              machine_type='v3-32',
              zone='fake-zone',
              sw_version='pytorch-nightly'))
    self.assertEqual(Cluster(wanted_clients, wanted_services),
                     resolved_cluster)
Esempio n. 8
0
 def test_validate_diff_sw_versions(self):
   """validate() raises RuntimeError when service sw_versions differ."""
   zone = 'europe-west4-a'
   ips = ('10.0.0.0', '10.0.0.1', '10.0.0.2', '10.0.0.3')
   # Service workers alternate between two software versions.
   versions = ('pytorch-0.1', 'pytorch-0.2', 'pytorch-0.1', 'pytorch-0.2')
   clients = [ClientWorker(ip, 'n1-standard-16', zone) for ip in ips]
   services = [
       ServiceWorker(ip, '8470', 'v3-32', zone, version)
       for ip, version in zip(ips, versions)
   ]
   mixed_version_cluster = Cluster(clients, services)
   with self.assertRaisesRegex(
       RuntimeError, 'All service workers must have the same sw_version.*'):
     mixed_version_cluster.validate()