Esempio n. 1
0
    def testNewNetworkEndpointFormat(self):
        """Resolve a node that reports its address via `networkEndpoints`.

        With an explicit coordinator address, the cluster spec contains that
        coordinator plus one worker built from the single endpoint, and
        master() points at that endpoint over grpc.
        """
        node_name = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
        endpoint = {'ipAddress': '10.2.3.4', 'port': 8470}
        node_map = {
            node_name: dict(
                state='READY', health='HEALTHY', networkEndpoints=[endpoint]),
        }
        mock_service = self.mock_service_client(tpu_map=node_map)

        tpu_resolver = resolver.TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu='test-tpu-1',
            coordinator_name='coordinator',
            coordinator_address='10.128.1.5:10203',
            credentials=None,
            service=mock_service)

        spec = tpu_resolver.cluster_spec()
        want = """
    job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
    job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } }
    """
        self._verifyClusterSpecEquality(spec, want)
        self.assertEqual('grpc://10.2.3.4:8470', tpu_resolver.master())
Esempio n. 2
0
    def testRetrieveProjectAndZoneFromMetadata(self):
        """Resolve a TPU with project and zone left as None.

        Per the test name, project and zone are expected to come from the
        (mocked) metadata service. The coordinator port is read from the
        resolver after cluster_spec() has run.
        """
        node_map = {
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                dict(ipAddress='10.1.2.3', port='8470', state='READY',
                     health='HEALTHY'),
        }
        mock_service = self.mock_service_client(tpu_map=node_map)
        tpu_resolver = resolver.TPUClusterResolver(
            project=None,
            zone=None,
            tpu=['test-tpu-1'],
            credentials=None,
            service=mock_service,
            coordinator_name='coordinator')

        template = """
    job {
      name: 'coordinator'
      tasks { key: 0 value: '10.128.1.2:%s' }
    }
    job {
      name: 'worker'
      tasks { key: 0 value: '10.1.2.3:8470' }
    }
    """
        spec = tpu_resolver.cluster_spec()
        want = str(template % tpu_resolver._coordinator_port)
        self._verifyClusterSpecEquality(spec, want)
        self.assertEqual(tpu_resolver.master(), 'grpc://10.1.2.3:8470')
Esempio n. 3
0
 def testNoCallComputeMetadata(self):
   """A literal grpc:// TPU address is used directly for master and worker."""
   address = 'grpc://10.1.2.3:8470'
   tpu_resolver = resolver.TPUClusterResolver(tpu=address)
   self.assertEqual(address, tpu_resolver.master())
   want = server_lib.ClusterSpec({'worker': ['10.1.2.3:8470']})
   got = tpu_resolver.cluster_spec()
   self.assertEqual(want.as_dict(), got.as_dict())
    def testTpuTopology(self):
        """Setting a serialized topology exposes a TPUHardwareFeature."""
        tpu_resolver = resolver.TPUClusterResolver(tpu='local')
        # No topology is known before one is explicitly set.
        self.assertIsNone(tpu_resolver._tpu_topology)

        topology = topology_pb2.TopologyProto(mesh_shape=[1, 1, 1, 1])
        tpu_resolver.set_tpu_topology(
            serialized_tpu_topology=topology.SerializeToString())
        self.assertIsInstance(tpu_resolver.tpu_hardware_feature,
                              topology_pb2.TPUHardwareFeature)
Esempio n. 5
0
 def verifyShouldResolve(self, tpu, should_resolve):
     """Assert whether the resolver's Cloud TPU client reports the API usable.

     Args:
       tpu: value passed as the `tpu` argument of TPUClusterResolver.
       should_resolve: expected return value of `api_available()`.
     """
     mock_service = self.mock_service_client(tpu_map={})
     tpu_resolver = resolver.TPUClusterResolver(
         project='test-project',
         zone='us-central1-c',
         tpu=tpu,
         coordinator_name=None,
         credentials=None,
         service=mock_service)
     api_ok = tpu_resolver._cloud_tpu_client.api_available()
     failure_msg = "TPU: '%s'" % tpu
     self.assertEqual(should_resolve, api_ok, failure_msg)
Esempio n. 6
0
    def testGetMasterNoEntries(self):
        """Constructing a resolver with an empty TPU list raises ValueError."""
        resolver_kwargs = dict(
            project='test-project',
            zone='us-central1-c',
            tpu=[],
            coordinator_name=None,
            credentials=None,
        )
        with self.assertRaises(ValueError):
            # The mock service is built inside the context, as before.
            resolver.TPUClusterResolver(
                service=self.mock_service_client(tpu_map={}),
                **resolver_kwargs)
Esempio n. 7
0
    def testNumAcceleratorsSuccess(self, mock_list_devices,
                                   mock_eager_list_devices):
        """num_accelerators() returns {'TPU': 2} for the mocked devices.

        The mocks report eight TPU devices, two on each of four `tpu_worker`
        tasks, for a four-endpoint node.
        """
        # (task index, TPU device ordinal) for each mocked device.
        placements = [(0, 0), (1, 1), (2, 0), (3, 1),
                      (0, 4), (1, 5), (2, 4), (3, 5)]
        devices = [
            LogicalDevice(
                '/job:tpu_worker/task:%d/device:TPU:%d' % (task, ordinal),
                'TPU') for task, ordinal in placements
        ]
        mock_eager_list_devices.return_value = devices
        mock_list_devices.return_value = [
            session._DeviceAttributes(dev.name, dev.device_type, 1024, 0)
            for dev in devices
        ]

        node_map = {
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
                'state': 'READY',
                'health': 'HEALTHY',
                'networkEndpoints': [
                    {'ipAddress': '10.2.3.%d' % octet, 'port': 8470}
                    for octet in range(4, 8)
                ],
            }
        }

        tpu_resolver = resolver.TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu='test-tpu-1',
            service=self.mock_service_client(tpu_map=node_map))
        self.assertEqual(tpu_resolver.num_accelerators(), {'TPU': 2})
Esempio n. 8
0
    def testPodResolution(self):
        """A node with four endpoints yields four worker tasks.

        The coordinator uses the resolver's own port, and master() points at
        the first endpoint.
        """
        endpoints = [{'ipAddress': '10.2.3.%d' % octet, 'port': 8470}
                     for octet in (4, 5, 6, 7)]
        node_map = {
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
                'state': 'READY',
                'health': 'HEALTHY',
                'networkEndpoints': endpoints,
            }
        }
        mock_service = self.mock_service_client(tpu_map=node_map)
        pod_resolver = resolver.TPUClusterResolver(
            tpu='test-tpu-1',
            credentials=None,
            service=mock_service,
            coordinator_name='coordinator')

        spec = pod_resolver.cluster_spec()
        template = """
    job {
      name: 'coordinator',
      tasks { key: 0 value: '10.128.1.2:%s'}
    }
    job {
      name: 'worker'
      tasks { key: 0 value: '10.2.3.4:8470' }
      tasks { key: 1 value: '10.2.3.5:8470' }
      tasks { key: 2 value: '10.2.3.6:8470' }
      tasks { key: 3 value: '10.2.3.7:8470' }
    }
    """
        want = str(template % pod_resolver._coordinator_port)
        self._verifyClusterSpecEquality(spec, want)
        self.assertEqual(pod_resolver.master(), 'grpc://10.2.3.4:8470')
Esempio n. 9
0
    def testGkeEnvironmentForDonut(self):
        """Resolve master and cluster spec from KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS.

        The environment variable is set only for this test. It is restored in
        a `finally` block so that a failing assertion (or a failing resolver
        constructor) cannot leak it into later tests. Any value that existed
        before the test is put back rather than deleted.
        """
        env_key = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
        saved_value = os.environ.get(env_key)
        os.environ[env_key] = 'grpc://10.120.27.5:8470'
        try:
            self.assertIn(env_key, os.environ)

            cluster_resolver = resolver.TPUClusterResolver()
            self.assertEqual(compat.as_bytes('grpc://10.120.27.5:8470'),
                             compat.as_bytes(cluster_resolver.master()))
            actual_cluster_spec = cluster_resolver.cluster_spec()
            expected_proto = """
    job {
      name: 'worker'
      tasks { key: 0 value: '10.120.27.5:8470' }
    }
    """
            self._verifyClusterSpecEquality(actual_cluster_spec,
                                            expected_proto)
        finally:
            # Restore the pre-test state: remove the variable if it was
            # absent, otherwise put back its previous value.
            if saved_value is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = saved_value
Esempio n. 10
0
    def testNotReadyCloudTpu(self):
        """cluster_spec() raises RuntimeError for a node in state CREATING."""
        creating_node = dict(ipAddress='10.1.2.3', port='8470',
                             state='CREATING')
        node_map = {
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                creating_node,
        }
        mock_service = self.mock_service_client(tpu_map=node_map)
        pending_resolver = resolver.TPUClusterResolver(
            project=None, zone=None, tpu='test-tpu-1', coordinator_name=None,
            credentials=None, service=mock_service)

        with self.assertRaises(RuntimeError):
            pending_resolver.cluster_spec()
Esempio n. 11
0
    def testOverrideTaskTypeAndIndexAndGetMaster(self):
        """master() follows task_type/task_id once they are overridden.

        Before the override it returns the first endpoint. After selecting
        task 3 of the 'worker' job it returns the fourth endpoint.
        """
        addresses = ['10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7']
        node_map = {
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
                'state': 'READY',
                'health': 'HEALTHY',
                'networkEndpoints': [
                    {'ipAddress': ip, 'port': 8470} for ip in addresses
                ],
            }
        }
        mock_service = self.mock_service_client(tpu_map=node_map)
        tpu_resolver = resolver.TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu='test-tpu-1',
            coordinator_name=None,
            credentials=None,
            service=mock_service)

        self.assertEqual(tpu_resolver.master(), 'grpc://10.2.3.4:8470')

        # Point the resolver at the last worker task.
        tpu_resolver.task_type = 'worker'
        tpu_resolver.task_id = 3
        self.assertEqual(tpu_resolver.master(), 'grpc://10.2.3.7:8470')
Esempio n. 12
0
    def testFailedMetadata(self):
        """Resolving an unknown TPU name fails with a descriptive ValueError."""
        known_nodes = {
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                dict(ipAddress='10.1.2.3', port='8470', health='HEALTHY'),
        }
        missing_resolver = resolver.TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu='nonexistent-tpu',
            coordinator_name='coordinator',
            coordinator_address='10.128.1.5:10203',
            credentials=None,
            service=self.mock_service_client(tpu_map=known_nodes))

        with self.assertRaises(ValueError) as raised:
            missing_resolver.cluster_spec()

        error_text = str(raised.exception)
        self.assertIn('Could not lookup TPU metadata', error_text)
Esempio n. 13
0
    def testNumAcceleratorsRetryFailure(self, mock_list_devices,
                                        mock_eager_list_devices):
        """num_accelerators() raises RuntimeError when device listing times out.

        Both device-listing mocks are configured to raise
        DeadlineExceededError whenever they are called.
        """
        node_map = {
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
                'health': 'HEALTHY',
                'networkEndpoints': [
                    {'ipAddress': '10.2.3.%d' % octet, 'port': 8470}
                    for octet in range(4, 8)
                ],
            }
        }
        tpu_resolver = resolver.TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu='test-tpu-1',
            service=self.mock_service_client(tpu_map=node_map))

        def _timeout():
            # A fresh exception instance for each mock.
            return errors.DeadlineExceededError(None, None, 'timeout')

        mock_list_devices.side_effect = _timeout()
        mock_eager_list_devices.side_effect = _timeout()
        with self.assertRaises(RuntimeError):
            tpu_resolver.num_accelerators()
Esempio n. 14
0
 def testRpcDetectionForGrpcString(self):
     """A grpc:// TPU string is returned unchanged by master()."""
     grpc_target = 'grpc://10.1.2.3:8470'
     tpu_resolver = resolver.TPUClusterResolver(tpu=grpc_target)
     self.assertEqual(tpu_resolver.master(), grpc_target)
Esempio n. 15
0
 def testCheckRunningInGceWithNoTpuName(self):
     """An empty TPU name is rejected with a ValueError asking for a name."""
     expected_pattern = 'Please provide a TPU Name to connect to.*'
     with self.assertRaisesRegex(ValueError, expected_pattern):
         resolver.TPUClusterResolver(tpu='')
Esempio n. 16
0
 def testLocalTpuResolver(self):
   """With tpu='local', get_master() returns an empty string."""
   local_resolver = resolver.TPUClusterResolver(tpu='local')
   master_address = local_resolver.get_master()
   self.assertEqual(master_address, '')