def test_validate_machine_type_service_cluster(self):
    """Mismatched service machine types fail validation only when checked.

    Builds two service workers whose machine types differ ('v3-8' vs
    'v2-8'). A Cluster built with check_service_machine_type=False must
    validate cleanly. A Cluster built with the default setting must raise
    RuntimeError whose message matches the same-machine_type pattern.
    """
    zone = 'europe-west4-a'
    client_workers = [
        ClientWorker(ip, 'n1-standard-16', zone)
        for ip in ('10.0.0.0', '10.0.0.1')
    ]
    # The differing machine types are the point of this test.
    service_workers = [
        ServiceWorker(ip, '8470', machine_type, zone, 'pytorch-0.2')
        for ip, machine_type in (('10.0.0.0', 'v3-8'), ('10.0.0.1', 'v2-8'))
    ]

    lenient = Cluster(client_workers, service_workers,
                      check_service_machine_type=False,
                      client_master_ip='10.0.0.0')
    lenient.validate()  # Must not raise with the check disabled.

    strict = Cluster(client_workers, service_workers,
                     client_master_ip='10.0.0.0')
    with self.assertRaisesRegex(
            RuntimeError,
            'All service_workers must have the same machine_type'):
        strict.validate()
def test_validate_good_cluster(self):
    """A consistent four-node cluster passes validation without raising.

    All client workers share one machine type and zone, and the last one
    carries an explicit hostname. All service workers share the 'v3-32'
    machine type and the 'pytorch-0.2' runtime.
    """
    zone = 'europe-west4-a'
    ips = ['10.0.0.0', '10.0.0.1', '10.0.0.2', '10.0.0.3']

    # The first three clients use default hostnames; the last sets one explicitly.
    client_workers = [ClientWorker(ip, 'n1-standard-16', zone) for ip in ips[:-1]]
    client_workers.append(
        ClientWorker(ips[-1], 'n1-standard-16', zone, hostname='test'))

    service_workers = [
        ServiceWorker(ip, '8470', 'v3-32', zone, 'pytorch-0.2') for ip in ips
    ]

    Cluster(client_workers, service_workers).validate()  # Must not raise.