def testNewNetworkEndpointFormat(self):
    """Check resolution of a node described by a `networkEndpoints` list.

    The resolver is given an explicit coordinator name and address; both the
    resulting cluster spec and the master address are verified.
    """
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    fake_nodes = {
        node_path: {
            'state': 'READY',
            'health': 'HEALTHY',
            'networkEndpoints': [{'ipAddress': '10.2.3.4', 'port': 8470}],
        },
    }

    cluster_resolver = resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        coordinator_name='coordinator',
        coordinator_address='10.128.1.5:10203',
        credentials=None,
        service=self.mock_service_client(tpu_map=fake_nodes),
    )

    spec = cluster_resolver.cluster_spec()
    # One coordinator task plus one worker task for the single endpoint.
    expected_proto = (
        " job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }"
        " job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } } "
    )
    self._verifyClusterSpecEquality(spec, expected_proto)
    self.assertEqual('grpc://10.2.3.4:8470', cluster_resolver.master())
def testRetrieveProjectAndZoneFromMetadata(self):
    """Check resolution when `project` and `zone` are both passed as None.

    The expected coordinator task is `10.128.1.2` on the resolver's own
    `_coordinator_port`; the worker task comes from the fake node entry.
    """
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    fake_nodes = {
        node_path: {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'state': 'READY',
            'health': 'HEALTHY',
        },
    }

    cluster_resolver = resolver.TPUClusterResolver(
        project=None,
        zone=None,
        tpu=['test-tpu-1'],
        credentials=None,
        service=self.mock_service_client(tpu_map=fake_nodes),
        coordinator_name='coordinator',
    )

    spec = cluster_resolver.cluster_spec()
    # The coordinator port is chosen by the resolver, so it is substituted in.
    expected_proto = (
        " job { name: 'coordinator' tasks { key: 0 value: '10.128.1.2:%s' } }"
        " job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } } "
    ) % cluster_resolver._coordinator_port
    self._verifyClusterSpecEquality(spec, str(expected_proto))
    self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
def testNoCallComputeMetadata(self):
    """Check that a `grpc://` address passed as `tpu` is used directly.

    No project, zone or service client is supplied; the master address and
    the single-worker cluster spec must both reflect the given address.
    """
    grpc_address = 'grpc://10.1.2.3:8470'
    cluster_resolver = resolver.TPUClusterResolver(tpu=grpc_address)

    self.assertEqual(grpc_address, cluster_resolver.master())

    expected_spec = server_lib.ClusterSpec({'worker': ['10.1.2.3:8470']})
    self.assertEqual(
        expected_spec.as_dict(),
        cluster_resolver.cluster_spec().as_dict(),
    )
def testTpuTopology(self):
    """Check the topology is unset at first and usable after being set.

    After `set_tpu_topology` receives a serialized `TopologyProto`, the
    `tpu_hardware_feature` property must be a `TPUHardwareFeature`.
    """
    cluster_resolver = resolver.TPUClusterResolver(tpu='local')
    self.assertIsNone(cluster_resolver._tpu_topology)

    # Test set with tpu topology proto.
    topology = topology_pb2.TopologyProto(mesh_shape=[1, 1, 1, 1])
    cluster_resolver.set_tpu_topology(
        serialized_tpu_topology=topology.SerializeToString())

    self.assertIsInstance(
        cluster_resolver.tpu_hardware_feature,
        topology_pb2.TPUHardwareFeature,
    )
def verifyShouldResolve(self, tpu, should_resolve):
    """Assert whether the resolver's cloud TPU client reports the API available.

    Args:
      tpu: Value forwarded as the `tpu` argument of `TPUClusterResolver`.
      should_resolve: Expected result of `_cloud_tpu_client.api_available()`.
    """
    cluster_resolver = resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu=tpu,
        coordinator_name=None,
        credentials=None,
        service=self.mock_service_client(tpu_map={}),
    )

    api_available = cluster_resolver._cloud_tpu_client.api_available()
    # The message identifies which `tpu` value failed the check.
    self.assertEqual(should_resolve, api_available, "TPU: '%s'" % tpu)
def testGetMasterNoEntries(self):
    """Check that constructing a resolver with an empty `tpu` list raises."""
    no_nodes = {}

    with self.assertRaises(ValueError):
        # Everything, including building the fake service, stays inside the
        # context so the set of calls that may raise is unchanged.
        resolver.TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu=[],
            coordinator_name=None,
            credentials=None,
            service=self.mock_service_client(tpu_map=no_nodes),
        )
def testNumAcceleratorsSuccess(self, mock_list_devices,
                               mock_eager_list_devices):
    """Check `num_accelerators()` against a fixed set of fake TPU devices.

    Eight logical TPU devices spread over four `tpu_worker` tasks (two per
    task) are reported by both mocks; the expected result is `{'TPU': 2}`.

    Args:
      mock_list_devices: Mock whose return value is the list of
        `session._DeviceAttributes` built here.
      mock_eager_list_devices: Mock whose return value is the list of
        `LogicalDevice` objects built here.
    """
    # (task index, TPU device index) pairs, in the order they are reported.
    placements = [
        (0, 0), (1, 1), (2, 0), (3, 1),
        (0, 4), (1, 5), (2, 4), (3, 5),
    ]
    logical_devices = [
        LogicalDevice('/job:tpu_worker/task:%d/device:TPU:%d' % pair, 'TPU')
        for pair in placements
    ]
    device_attributes = [
        session._DeviceAttributes(dev.name, dev.device_type, 1024, 0)
        for dev in logical_devices
    ]
    mock_eager_list_devices.return_value = logical_devices
    mock_list_devices.return_value = device_attributes

    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    worker_ips = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
    fake_nodes = {
        node_path: {
            'state': 'READY',
            'health': 'HEALTHY',
            'networkEndpoints': [
                {'ipAddress': ip, 'port': 8470} for ip in worker_ips
            ],
        },
    }

    cluster_resolver = resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        service=self.mock_service_client(tpu_map=fake_nodes),
    )
    self.assertEqual(cluster_resolver.num_accelerators(), {'TPU': 2})
def testPodResolution(self):
    """Check resolution of a node that exposes four network endpoints.

    Each endpoint must become one worker task, in order, and the master
    address must be that of the first endpoint.
    """
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    worker_ips = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
    fake_nodes = {
        node_path: {
            'state': 'READY',
            'health': 'HEALTHY',
            'networkEndpoints': [
                {'ipAddress': ip, 'port': 8470} for ip in worker_ips
            ],
        },
    }

    cluster_resolver = resolver.TPUClusterResolver(
        tpu='test-tpu-1',
        credentials=None,
        service=self.mock_service_client(tpu_map=fake_nodes),
        coordinator_name='coordinator',
    )

    spec = cluster_resolver.cluster_spec()
    # The coordinator port is chosen by the resolver, so it is substituted in.
    expected_proto = (
        " job { name: 'coordinator', tasks { key: 0 value: '10.128.1.2:%s'} }"
        " job { name: 'worker'"
        " tasks { key: 0 value: '10.2.3.4:8470' }"
        " tasks { key: 1 value: '10.2.3.5:8470' }"
        " tasks { key: 2 value: '10.2.3.6:8470' }"
        " tasks { key: 3 value: '10.2.3.7:8470' }"
        " } "
    ) % cluster_resolver._coordinator_port
    self._verifyClusterSpecEquality(spec, str(expected_proto))
    self.assertEqual(cluster_resolver.master(), 'grpc://10.2.3.4:8470')
def testGkeEnvironmentForDonut(self):
    """Check resolution of one endpoint from KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS.

    The resolver is built with no arguments while the environment variable
    holds a single `grpc://` address; the master address and the one-worker
    cluster spec must match that address.

    The variable is removed in a `finally` clause so that a failing
    assertion cannot leave it set in the process environment, where it
    could affect tests that run afterwards.
    """
    env_var = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
    os.environ[env_var] = 'grpc://10.120.27.5:8470'
    try:
        self.assertIn(env_var, os.environ)
        cluster_resolver = resolver.TPUClusterResolver()
        self.assertEqual(
            compat.as_bytes('grpc://10.120.27.5:8470'),
            compat.as_bytes(cluster_resolver.master()))

        actual_cluster_spec = cluster_resolver.cluster_spec()
        expected_proto = (
            " job { name: 'worker' tasks { key: 0 value: '10.120.27.5:8470' } } "
        )
        self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    finally:
        # pop() with a default cannot raise KeyError, so cleanup never masks
        # the assertion failure that triggered it.
        os.environ.pop(env_var, None)
def testNotReadyCloudTpu(self):
    """Check that `cluster_spec()` raises for a node in state 'CREATING'."""
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    fake_nodes = {
        node_path: {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'state': 'CREATING',
        },
    }

    cluster_resolver = resolver.TPUClusterResolver(
        project=None,
        zone=None,
        tpu='test-tpu-1',
        coordinator_name=None,
        credentials=None,
        service=self.mock_service_client(tpu_map=fake_nodes),
    )

    # Construction succeeds; only asking for the cluster spec must fail.
    with self.assertRaises(RuntimeError):
        cluster_resolver.cluster_spec()
def testOverrideTaskTypeAndIndexAndGetMaster(self):
    """Check that setting `task_type` and `task_id` changes `master()`.

    With four endpoints, the master is the first endpoint before any
    override and the fourth once `task_type='worker'` and `task_id=3`.
    """
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    worker_ips = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
    fake_nodes = {
        node_path: {
            'state': 'READY',
            'health': 'HEALTHY',
            'networkEndpoints': [
                {'ipAddress': ip, 'port': 8470} for ip in worker_ips
            ],
        },
    }

    cluster_resolver = resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        coordinator_name=None,
        credentials=None,
        service=self.mock_service_client(tpu_map=fake_nodes),
    )

    self.assertEqual(cluster_resolver.master(), 'grpc://10.2.3.4:8470')

    cluster_resolver.task_type = 'worker'
    cluster_resolver.task_id = 3
    self.assertEqual(cluster_resolver.master(), 'grpc://10.2.3.7:8470')
def testFailedMetadata(self):
    """Check the error raised when the requested TPU is not in the fake map.

    The resolver asks for 'nonexistent-tpu' while only 'test-tpu-1' is
    registered; `cluster_spec()` must raise a `ValueError` whose text
    contains 'Could not lookup TPU metadata'.
    """
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    fake_nodes = {
        node_path: {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'health': 'HEALTHY',
        },
    }

    cluster_resolver = resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='nonexistent-tpu',
        coordinator_name='coordinator',
        coordinator_address='10.128.1.5:10203',
        credentials=None,
        service=self.mock_service_client(tpu_map=fake_nodes),
    )

    with self.assertRaises(ValueError) as raised:
        cluster_resolver.cluster_spec()

    # Checked after the context exits, once the exception has been captured.
    error_text = str(raised.exception)
    self.assertIn('Could not lookup TPU metadata', error_text)
def testNumAcceleratorsRetryFailure(self, mock_list_devices,
                                    mock_eager_list_devices):
    """Check `num_accelerators()` raises when device listing keeps timing out.

    Both mocks are set to raise `DeadlineExceededError`; the call is
    expected to surface a `RuntimeError`.

    Args:
      mock_list_devices: Mock configured here to raise on every call.
      mock_eager_list_devices: Mock configured here to raise on every call.
    """
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    worker_ips = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
    fake_nodes = {
        node_path: {
            'health': 'HEALTHY',
            'networkEndpoints': [
                {'ipAddress': ip, 'port': 8470} for ip in worker_ips
            ],
        },
    }

    cluster_resolver = resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        service=self.mock_service_client(tpu_map=fake_nodes),
    )

    # Each mock gets its own exception instance.
    mock_list_devices.side_effect = errors.DeadlineExceededError(
        None, None, 'timeout')
    mock_eager_list_devices.side_effect = errors.DeadlineExceededError(
        None, None, 'timeout')

    with self.assertRaises(RuntimeError):
        cluster_resolver.num_accelerators()
def testRpcDetectionForGrpcString(self):
    """Check that a `grpc://` string given as `tpu` is returned by `master()`."""
    cluster_resolver = resolver.TPUClusterResolver(tpu='grpc://10.1.2.3:8470')
    master_address = cluster_resolver.master()
    self.assertEqual(master_address, 'grpc://10.1.2.3:8470')
def testCheckRunningInGceWithNoTpuName(self):
    """Check that an empty `tpu` string is rejected with a `ValueError`."""
    expected_message = 'Please provide a TPU Name to connect to.*'
    with self.assertRaisesRegex(ValueError, expected_message):
        resolver.TPUClusterResolver(tpu='')
def testLocalTpuResolver(self):
    """Check that `get_master()` is the empty string for `tpu='local'`."""
    local_resolver = resolver.TPUClusterResolver(tpu='local')
    master_address = local_resolver.get_master()
    self.assertEqual(master_address, '')