def testNewNetworkEndpointFormat(self):
    """Resolve a TPU whose fake API entry uses a `networkEndpoints` list.

    The fake service describes one healthy node with a single endpoint. The
    test expects a cluster spec holding the explicitly supplied coordinator
    plus one worker, and `master()` returning that worker's gRPC address.
    """
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    fake_nodes = {
        node_path: {
            'health': 'HEALTHY',
            'networkEndpoints': [{'ipAddress': '10.2.3.4', 'port': 8470}],
        },
    }
    fake_service = self.mock_service_client(tpu_map=fake_nodes)

    resolver = cluster_resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        coordinator_name='coordinator',
        coordinator_address='10.128.1.5:10203',
        credentials=None,
        service=fake_service)

    actual_cluster_spec = resolver.cluster_spec()
    # The text proto is split over several literals only for readability; the
    # concatenated value is unchanged.
    expected_proto = (
        " job { name: 'coordinator'"
        " tasks { key: 0 value: '10.128.1.5:10203' } }"
        " job { name: 'worker'"
        " tasks { key: 0 value: '10.2.3.4:8470' } } "
    )
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    self.assertEqual('grpc://10.2.3.4:8470', resolver.master())
def testGkeEnvironmentForPod(self):
    """Resolve a four-host TPU pod from the GKE endpoints environment variable.

    With `KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS` set to four comma-separated gRPC
    addresses, the test expects:
      * `_inGke()` to report True,
      * `_gkeEndpoints()` to return the raw variable value,
      * `master()` to be the first endpoint, and
      * `cluster_spec()` to list the four hosts as worker tasks 0-3.

    The variable is removed in a `finally` clause so that a failing assertion
    cannot leak it into other tests running in the same process.
    """
    env_var = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
    endpoints = (
        'grpc://10.120.27.5:8470,'
        'grpc://10.120.27.6:8470,'
        'grpc://10.120.27.7:8470,'
        'grpc://10.120.27.8:8470')

    os.environ[env_var] = endpoints
    try:
        self.assertIn(env_var, os.environ)
        self.assertTrue(cluster_resolver.TPUClusterResolver._inGke())
        self.assertEqual(
            compat.as_bytes(endpoints),
            compat.as_bytes(
                cluster_resolver.TPUClusterResolver._gkeEndpoints()))

        # No arguments: the resolver is expected to pick everything up from
        # the environment variable set above.
        resolver = cluster_resolver.TPUClusterResolver()
        self.assertEqual(
            compat.as_bytes('grpc://10.120.27.5:8470'),
            compat.as_bytes(resolver.master()))

        actual_cluster_spec = resolver.cluster_spec()
        expected_proto = (
            " job { name: 'worker'"
            " tasks { key: 0 value: '10.120.27.5:8470' }"
            " tasks { key: 1 value: '10.120.27.6:8470' }"
            " tasks { key: 2 value: '10.120.27.7:8470' }"
            " tasks { key: 3 value: '10.120.27.8:8470' } } "
        )
        self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    finally:
        # Runs on success and on failure alike; pop() tolerates the variable
        # already being gone.
        os.environ.pop(env_var, None)
def testSimpleSuccessfulRetrieval(self):
    """Resolve a single healthy TPU described by flat `ipAddress`/`port` keys.

    The TPU name is passed as a one-element list. The test expects a cluster
    spec with the explicitly supplied coordinator plus one worker, and
    `master()` returning the worker's gRPC address.
    """
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    fake_nodes = {
        node_path: {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'health': 'HEALTHY',
        },
    }
    fake_service = self.mock_service_client(tpu_map=fake_nodes)

    resolver = cluster_resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu=['test-tpu-1'],
        coordinator_name='coordinator',
        coordinator_address='10.128.1.5:10203',
        credentials=None,
        service=fake_service)

    actual_cluster_spec = resolver.cluster_spec()
    # Split over several literals for readability; the joined text is
    # identical to the single-literal form.
    expected_proto = (
        " job { name: 'coordinator'"
        " tasks { key: 0 value: '10.128.1.5:10203' } }"
        " job { name: 'worker'"
        " tasks { key: 0 value: '10.1.2.3:8470' } } "
    )
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')
def testRetrieveProjectAndZoneFromMetadata(self):
    """Resolve a TPU when `project` and `zone` are both passed as None.

    No coordinator address is supplied either; the expected coordinator task
    is '10.128.1.2' on the resolver's own `_coordinator_port`.
    NOTE(review): the project, zone and coordinator IP presumably come from a
    metadata lookup stubbed elsewhere in this file -- confirm.
    """
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    fake_nodes = {
        node_path: {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'health': 'HEALTHY',
        },
    }
    fake_service = self.mock_service_client(tpu_map=fake_nodes)

    resolver = cluster_resolver.TPUClusterResolver(
        project=None,
        zone=None,
        tpu=['test-tpu-1'],
        credentials=None,
        service=fake_service,
        coordinator_name='coordinator')

    # Template only; the coordinator port is filled in after cluster_spec()
    # has been called, as in the original ordering.
    proto_template = (
        " job { name: 'coordinator'"
        " tasks { key: 0 value: '10.128.1.2:%s' } }"
        " job { name: 'worker'"
        " tasks { key: 0 value: '10.1.2.3:8470' } } "
    )

    actual_cluster_spec = resolver.cluster_spec()
    expected_proto = proto_template % resolver._coordinator_port
    self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
    self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')
def verifyShouldResolve(self, tpu, should_resolve):
    """Assert that a resolver built for `tpu` reports the expected decision.

    Args:
      tpu: Value forwarded as the `tpu` argument of `TPUClusterResolver`.
      should_resolve: Expected return value of `resolver._shouldResolve()`.
    """
    empty_service = self.mock_service_client(tpu_map={})
    resolver = cluster_resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu=tpu,
        coordinator_name=None,
        credentials=None,
        service=empty_service)

    decision = resolver._shouldResolve()
    # Include the TPU value in the failure message to identify the case.
    failure_message = "TPU: '%s'" % tpu
    self.assertEqual(should_resolve, decision, failure_message)
def testGetMasterNoEntries(self):
    """Constructing a resolver with an empty TPU list raises ValueError."""
    no_nodes = {}
    with self.assertRaises(ValueError):
        # The mock service is built inside the context, exactly where the
        # constructor arguments are evaluated.
        cluster_resolver.TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu=[],
            coordinator_name=None,
            credentials=None,
            service=self.mock_service_client(tpu_map=no_nodes))
def testPodResolution(self):
    """Resolve a TPU pod whose fake API entry lists four network endpoints.

    Neither project/zone nor a coordinator address is supplied. The expected
    cluster spec has a coordinator at '10.128.1.2' on the resolver's
    `_coordinator_port` and four workers in endpoint order; `master()` is
    expected to be the first endpoint.
    """
    host_ips = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    fake_nodes = {
        node_path: {
            'health': 'HEALTHY',
            'networkEndpoints': [
                {'ipAddress': ip, 'port': 8470} for ip in host_ips
            ],
        },
    }
    fake_service = self.mock_service_client(tpu_map=fake_nodes)

    resolver = cluster_resolver.TPUClusterResolver(
        tpu='test-tpu-1',
        credentials=None,
        service=fake_service,
        coordinator_name='coordinator')

    # Template only; the port is substituted after cluster_spec() has run.
    proto_template = (
        " job { name: 'coordinator',"
        " tasks { key: 0 value: '10.128.1.2:%s'} }"
        " job { name: 'worker'"
        " tasks { key: 0 value: '10.2.3.4:8470' }"
        " tasks { key: 1 value: '10.2.3.5:8470' }"
        " tasks { key: 2 value: '10.2.3.6:8470' }"
        " tasks { key: 3 value: '10.2.3.7:8470' } } "
    )

    actual_cluster_spec = resolver.cluster_spec()
    expected_proto = proto_template % resolver._coordinator_port
    self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
    self.assertEqual(resolver.master(), 'grpc://10.2.3.4:8470')
def testOverrideTaskTypeAndIndexAndGetMaster(self):
    """`master()` follows task_type/task_index set on the resolver or passed in.

    Against a four-endpoint fake pod the test expects:
      * the first endpoint by default,
      * the fourth endpoint after setting `task_type='worker'` and
        `task_index=3` on the resolver, and
      * the third endpoint with a 'test://' scheme when `task_type`,
        `task_index` and `rpc_layer` are passed to `master()` directly.
    """
    host_ips = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    fake_nodes = {
        node_path: {
            'health': 'HEALTHY',
            'networkEndpoints': [
                {'ipAddress': ip, 'port': 8470} for ip in host_ips
            ],
        },
    }
    fake_service = self.mock_service_client(tpu_map=fake_nodes)

    resolver = cluster_resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        coordinator_name=None,
        credentials=None,
        service=fake_service)

    # Default lookup.
    self.assertEqual(resolver.master(), 'grpc://10.2.3.4:8470')

    # Override through resolver attributes.
    resolver.task_type = 'worker'
    resolver.task_index = 3
    self.assertEqual(resolver.master(), 'grpc://10.2.3.7:8470')

    # Override through call arguments, including the RPC layer.
    overridden_master = resolver.master(
        task_type='worker', task_index=2, rpc_layer='test')
    self.assertEqual(overridden_master, 'test://10.2.3.6:8470')
def testNumAcceleratorsSuccess(self, mock_list_devices):
    """`num_accelerators()` is expected to be 2 for the fake device list below.

    The mocked device listing returns eight TPU devices spread over tasks 0-3,
    two devices per task.

    Args:
      mock_list_devices: Mock whose `return_value` is set to the fake devices.
        NOTE(review): presumably injected by a patch decorator outside this
        block -- confirm.
    """
    tpu_device_names = [
        '/job:tpu_worker/task:0/device:TPU:0',
        '/job:tpu_worker/task:1/device:TPU:1',
        '/job:tpu_worker/task:2/device:TPU:0',
        '/job:tpu_worker/task:3/device:TPU:1',
        '/job:tpu_worker/task:0/device:TPU:4',
        '/job:tpu_worker/task:1/device:TPU:5',
        '/job:tpu_worker/task:2/device:TPU:4',
        '/job:tpu_worker/task:3/device:TPU:5',
    ]
    fake_devices = []
    for device_name in tpu_device_names:
        fake_devices.append(
            session._DeviceAttributes(device_name, 'TPU', 1024, 0))
    mock_list_devices.return_value = fake_devices

    resolver = cluster_resolver.TPUClusterResolver(tpu='')
    self.assertEqual(resolver.num_accelerators(), 2)
def testNotReadyCloudTpu(self):
    """`cluster_spec()` raises RuntimeError for a TPU in state 'CREATING'."""
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    fake_nodes = {
        node_path: {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'state': 'CREATING',
        },
    }
    fake_service = self.mock_service_client(tpu_map=fake_nodes)

    # Construction itself is outside the assertRaises block, so it is
    # expected to succeed; only the cluster_spec() call must fail.
    resolver = cluster_resolver.TPUClusterResolver(
        project=None,
        zone=None,
        tpu='test-tpu-1',
        coordinator_name=None,
        credentials=None,
        service=fake_service)

    with self.assertRaises(RuntimeError):
        resolver.cluster_spec()
def testNumAcceleratorsRetryFailure(self, mock_list_devices):
    """`num_accelerators()` raises RuntimeError when device listing times out.

    Args:
      mock_list_devices: Mock configured to raise `DeadlineExceededError`.
        NOTE(review): presumably injected by a patch decorator outside this
        block -- confirm.
    """
    # Build the resolver first, then arm the mock, as in the original order.
    resolver = cluster_resolver.TPUClusterResolver(tpu='')
    timeout_error = errors.DeadlineExceededError(None, None, 'timeout')
    mock_list_devices.side_effect = timeout_error

    with self.assertRaises(RuntimeError):
        resolver.num_accelerators()
def testEnvironmentAndRpcDetectionForGrpcString(self):
    """A `grpc://` TPU address yields an empty environment and 'grpc' layer.

    `master()` is expected to return the address unchanged.
    """
    tpu_address = 'grpc://10.1.2.3:8470'
    resolver = cluster_resolver.TPUClusterResolver(tpu=tpu_address)

    self.assertEqual(resolver.environment, '')
    self.assertEqual(resolver.rpc_layer, 'grpc')
    self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')
def testEnvironmentAndRpcDetectionForGoogle(self):
    """A `/bns/...` TPU address yields the 'google' environment, no RPC layer."""
    resolver = cluster_resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef')
    self.assertEqual(resolver.environment, 'google')
    # assertIsNone checks identity rather than equality and gives a clearer
    # failure message than assertEqual(value, None).
    self.assertIsNone(resolver.rpc_layer)
def testNoCallComputeMetadata(self):
    """A `/bns/...` TPU address is returned as master with no cluster spec.

    `master()` is expected to equal the address as bytes, and
    `cluster_spec()` is expected to be None.
    NOTE(review): going by the test name, the point is that no compute
    metadata lookup is needed for such addresses -- confirm.
    """
    resolver = cluster_resolver.TPUClusterResolver(tpu='/bns/foo/bar')
    self.assertEqual(compat.as_bytes('/bns/foo/bar'), resolver.master())
    # assertIsNone checks identity rather than equality and gives a clearer
    # failure message than assertEqual(None, value).
    self.assertIsNone(resolver.cluster_spec())
def testCheckRunningInGceWithNoTpuName(self):
    """An empty TPU name raises RuntimeError whose message mentions Google Cloud."""
    expected_message_pattern = '.*Google Cloud.*'
    # NOTE(review): assertRaisesRegexp is a deprecated unittest alias of
    # assertRaisesRegex; confirm the test base class and supported Python
    # versions before switching.
    with self.assertRaisesRegexp(RuntimeError, expected_message_pattern):
        cluster_resolver.TPUClusterResolver(tpu='')