Beispiel #1
0
 def testHandlesByteStrings(self):
     """Text and byte-string constructor arguments give the same full name."""
     from_text = client.Client(
         tpu='tpu_name', zone='zone', project='project')._full_name()
     from_bytes = client.Client(
         tpu=b'tpu_name', zone=b'zone', project=b'project')._full_name()
     self.assertEqual(from_text, from_bytes)
Beispiel #2
0
    def testRecoverableHBMOOMNoAPI(self):
        """An HBM out-of-memory symptom is recoverable for a grpc:// client.

        The client is addressed by a literal grpc:// endpoint; the mocked
        service still carries a READY node reporting an HBM_OUT_OF_MEMORY
        symptom, and recoverable() is expected to return True.
        """
        node_path = (
            'projects/test-project/locations/us-central1-c/nodes/tpu_name')
        hbm_oom_node = {
            'state': 'READY',
            'symptoms': [{
                'createTime': '2000-01-01T00:29:30.123456Z',
                'symptomType': 'HBM_OUT_OF_MEMORY',
                'details': ('The TPU HBM has run OOM at timestamp '
                            '2020-05-29T04:51:32.038721+00:00'),
                'workerId': '0',
            }],
        }
        cases = [({node_path: hbm_oom_node}, True)]

        for node_map, expected in cases:
            tpu_client = client.Client(
                tpu='grpc://1.2.3.4:8470',
                service=self.mock_service_client(tpu_map=node_map))
            self.assertEqual(expected, tpu_client.recoverable())
Beispiel #3
0
    def testRecoverableOOMDisabled(self):
        """A runtime OOM symptom is recoverable when runtime_oom_exit is off.

        FLAGS.runtime_oom_exit is process-global state. It is switched off for
        the duration of the checks and switched back on in a ``finally`` so a
        failing assertion cannot leak the override into later tests.
        """
        test_cases = [
            ({
                'projects/test-project/locations/us-central1-c/nodes/tpu_name':
                {
                    'state':
                    'READY',
                    'symptoms': [{
                        'createTime':
                        '2000-01-01T00:29:30.123456Z',
                        'symptomType':
                        'OUT_OF_MEMORY',
                        'details':
                        'The TPU runtime has run OOM at timestamp '
                        '2020-05-29T04:51:32.038721+00:00',
                        'workerId':
                        '0'
                    }]
                }
            }, True),
        ]

        FLAGS.runtime_oom_exit = False
        try:
            for tpu_map, want in test_cases:
                c = client.Client(
                    tpu='tpu_name',
                    service=self.mock_service_client(tpu_map=tpu_map))
                self.assertEqual(want, c.recoverable())
        finally:
            # Restore the value the test has always reset to (presumably the
            # flag's default -- confirm against the flag definition).
            FLAGS.runtime_oom_exit = True
Beispiel #4
0
    def testWaitForHealthy(self):
        """wait_for_healthy returns before the timeout once health is HEALTHY.

        time.time and time.sleep are patched onto the fake-clock helpers
        self._mock_time / self._mock_sleep, so only simulated time passes.
        Each patcher's stop() is registered with addCleanup so time.* is not
        left patched for tests that run after this one.
        """
        time_patcher = mock.patch.object(time, 'time', autospec=True)
        time_mock = time_patcher.start()
        self.addCleanup(time_patcher.stop)
        time_mock.side_effect = self._mock_time
        sleep_patcher = mock.patch.object(time, 'sleep', autospec=True)
        sleep_mock = sleep_patcher.start()
        self.addCleanup(sleep_patcher.stop)
        sleep_mock.side_effect = self._mock_sleep

        health_timeseries = (['UNHEALTHY_MAINTENANCE'] * 30 +
                             ['TIMEOUT'] * 10 + [None] * 20 + ['HEALTHY'] * 30)
        tpu_map = {
            'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
                'ipAddress': '10.1.2.3',
                'port': '8470',
                'state': 'READY',
                'health': health_timeseries,
            },
        }

        c = client.Client(tpu='tpu_name',
                          service=self.mock_service_client(tpu_map=tpu_map))

        # Doesn't throw RuntimeError as TPU becomes HEALTHY before timeout.
        timeout = 80
        interval = 5
        return_time = 60
        c.wait_for_healthy(timeout_s=timeout, interval=interval)
        self.assertEqual(time.time(), return_time)
        # call_count is an int; compare against a whole number of sleeps.
        self.assertEqual(sleep_mock.call_count, return_time // interval)
Beispiel #5
0
    def testWaitForHealthyRaisesError(self):
        """wait_for_healthy raises RuntimeError if health never turns HEALTHY.

        time.time and time.sleep are patched onto the fake-clock helpers
        self._mock_time / self._mock_sleep; each patcher's stop() is
        registered with addCleanup so the patches do not outlive this test.
        """
        time_patcher = mock.patch.object(time, 'time', autospec=True)
        time_mock = time_patcher.start()
        self.addCleanup(time_patcher.stop)
        time_mock.side_effect = self._mock_time
        sleep_patcher = mock.patch.object(time, 'sleep', autospec=True)
        sleep_mock = sleep_patcher.start()
        self.addCleanup(sleep_patcher.stop)
        sleep_mock.side_effect = self._mock_sleep

        # Mock timeseries where takes longer than timeout.
        health_timeseries = ['UNHEALTHY_MAINTENANCE'] * 50 + ['TIMEOUT'] * 50
        tpu_map = {
            'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
                'ipAddress': '10.1.2.3',
                'port': '8470',
                'state': 'READY',
                'health': health_timeseries,
            },
        }

        c = client.Client(tpu='tpu_name',
                          service=self.mock_service_client(tpu_map=tpu_map))

        # Throws RuntimeError as the TPU never becomes HEALTHY before timeout.
        with self.assertRaisesRegex(
                RuntimeError,
                'Timed out waiting for TPU .* to become healthy'):
            c.wait_for_healthy(timeout_s=80, interval=5)
    def testConfigureTpuVersion(self, urlopen):
        """configure_tpu_version requests the version on every endpoint.

        ``urlopen`` is a mock injected by the harness (presumably a
        mock.patch decorator outside this view). The first positional argument
        of each recorded call is a request whose full_url is checked.
        """
        node_path = (
            'projects/test-project/locations/us-central1-c/nodes/tpu_name')
        ready_node = {
            'state': 'READY',
            'networkEndpoints': [
                {'ipAddress': '1.2.3.4'},
                {'ipAddress': '5.6.7.8'},
            ],
        }
        tpu_client = client.Client(
            tpu='tpu_name',
            project='test-project',
            zone='us-central1-c',
            service=self.mock_service_client(tpu_map={node_path: ready_node}))
        tpu_client.configure_tpu_version('1.15')

        requested_urls = sorted(
            entry[0][0].full_url for entry in urlopen.call_args_list)
        expected_urls = [
            'http://1.2.3.4:8475/requestversion/1.15',
            'http://5.6.7.8:8475/requestversion/1.15',
        ]
        self.assertEqual(expected_urls, requested_urls)
Beispiel #7
0
 def testInitializeWithoutMetadata(self):
     """Explicit tpu/project/zone are stored as given and enable API mode."""
     tpu_client = client.Client(tpu='tpu_name', project='project', zone='zone')
     self.assertEqual('tpu_name', tpu_client._tpu)
     self.assertEqual(True, tpu_client._use_api)
     # Service and credentials are not created eagerly by the constructor.
     for unset_attr in ('_service', '_credentials'):
         self.assertIsNone(getattr(tpu_client, unset_attr))
     self.assertEqual('project', tpu_client._project)
     self.assertEqual('zone', tpu_client._zone)
     self.assertIsNone(tpu_client._discovery_url)
Beispiel #8
0
 def testRecoverableNoState(self):
     """A node description without a 'state' key is reported recoverable."""
     node_path = 'projects/test-project/locations/us-central1-c/nodes/tpu_name'
     stateless_node = {'ipAddress': '10.1.2.3', 'port': '8470'}
     mocked_service = self.mock_service_client(
         tpu_map={node_path: stateless_node})
     tpu_client = client.Client(tpu='tpu_name', service=mocked_service)
     self.assertEqual(True, tpu_client.recoverable())
Beispiel #9
0
 def testInitializeNoArgumentsWithEnvironmentVariable(self):
     """With no tpu argument, the TPU name is taken from $TPU_NAME.

     TPU_NAME is set through mock.patch.dict so os.environ is restored when
     the block exits. Writing os.environ directly would leak the variable
     into later tests, e.g. ones expecting an error when no name is given.
     """
     tpu_map = {
         'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
             'ipAddress': '10.1.2.3',
             'port': '8470',
             'health': 'HEALTHY'
         }
     }
     with mock.patch.dict(os.environ, {'TPU_NAME': 'tpu_name'}):
         c = client.Client(service=self.mock_service_client(tpu_map=tpu_map))
         self.assertClientContains(c)
Beispiel #10
0
 def testInitializeTpuName(self):
     """A Client built from a TPU name and a mocked service is well-formed."""
     node_path = 'projects/test-project/locations/us-central1-c/nodes/tpu_name'
     healthy_node = {
         'ipAddress': '10.1.2.3',
         'port': '8470',
         'health': 'HEALTHY',
     }
     tpu_client = client.Client(
         tpu='tpu_name',
         service=self.mock_service_client(tpu_map={node_path: healthy_node}))
     self.assertClientContains(tpu_client)
Beispiel #11
0
 def testNetworkEndpointsNotReadyWithApi(self):
     """network_endpoints() raises RuntimeError when the node has no state."""
     node_path = 'projects/test-project/locations/us-central1-c/nodes/tpu_name'
     stateless_node = {'ipAddress': '10.1.2.3', 'port': '8470'}
     tpu_client = client.Client(
         tpu='tpu_name',
         service=self.mock_service_client(tpu_map={node_path: stateless_node}))
     with self.assertRaisesRegex(RuntimeError,
                                 'TPU .* is not yet ready; state: "None"'):
         tpu_client.network_endpoints()
Beispiel #12
0
 def testInitializeIpAddress(self):
     """A grpc:// address is used directly, with _use_api set to False."""
     address = 'grpc://1.2.3.4:8470'
     tpu_client = client.Client(tpu=address)
     self.assertEqual(address, tpu_client._tpu)
     self.assertEqual(False, tpu_client._use_api)
     for unset_attr in ('_service', '_credentials', '_project', '_zone',
                        '_discovery_url'):
         self.assertIsNone(getattr(tpu_client, unset_attr))
     # The endpoint is parsed out of the address itself.
     expected_endpoints = [{'ipAddress': '1.2.3.4', 'port': '8470'}]
     self.assertEqual(expected_endpoints, tpu_client.network_endpoints())
Beispiel #13
0
 def testAcceleratorTypeApi(self):
     """accelerator_type() returns the node's acceleratorType field."""
     node_path = 'projects/test-project/locations/us-central1-c/nodes/tpu_name'
     node = dict(
         ipAddress='10.1.2.3',
         port='8470',
         state='PREEMPTED',
         health='HEALTHY',
         acceleratorType='v3-8',
         tensorflowVersion='nightly',
     )
     tpu_client = client.Client(
         tpu='tpu_name',
         service=self.mock_service_client(tpu_map={node_path: node}))
     self.assertEqual('v3-8', tpu_client.accelerator_type())
Beispiel #14
0
 def testGetTpuVersion(self, urlopen):
     """runtime_version() reads currentVersion from the /requestversion reply.

     ``urlopen`` is a mock injected by the harness. Its response yields an
     empty JSON object first, then one carrying currentVersion.
     """
     tpu_client = client.Client(tpu='grpc://1.2.3.4:8470')
     fake_response = mock.Mock()
     fake_response.read.side_effect = [
         '{}',
         '{"currentVersion": "someVersion"}',
     ]
     urlopen.return_value = fake_response

     self.assertIsNone(tpu_client.runtime_version(),
                       'Missing key should be handled.')
     self.assertEqual('someVersion', tpu_client.runtime_version(),
                      'Should return configured version.')

     # assertCountEqual is order-insensitive, so no pre-sorting is needed.
     requested_urls = [entry[0][0].full_url
                       for entry in urlopen.call_args_list]
     self.assertCountEqual(['http://1.2.3.4:8475/requestversion'] * 2,
                           requested_urls)
Beispiel #15
0
 def testInitializeNoArgumentsWithTPUEnvironmentVariableTPUConfig(self):
     """With no tpu argument, the client is configured from $TPU_CONFIG.

     TPU_CONFIG holds a JSON object with project, zone and tpu_node_name. It
     is set through mock.patch.dict so os.environ is restored when the block
     exits rather than leaking into later tests.
     """
     tpu_config = json.dumps({
         'project': 'test-project',
         'zone': 'us-central1-c',
         'tpu_node_name': 'tpu_name',
     })
     tpu_map = {
         'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
             'ipAddress': '10.1.2.3',
             'port': '8470',
             'state': 'READY',
             'health': 'HEALTHY',
         }
     }
     with mock.patch.dict(os.environ, {'TPU_CONFIG': tpu_config}):
         c = client.Client(service=self.mock_service_client(tpu_map=tpu_map))
         self.assertClientContains(c)
Beispiel #16
0
 def baseConfigureTpuVersion(self):
     """Return a Client for a READY node exposing two network endpoints.

     This is a fixture helper rather than a test, since its name lacks the
     ``test`` prefix. The client targets 'tpu_name' in test-project /
     us-central1-c, with endpoints 1.2.3.4 and 5.6.7.8.
     """
     node_path = 'projects/test-project/locations/us-central1-c/nodes/tpu_name'
     endpoints = [{'ipAddress': ip} for ip in ('1.2.3.4', '5.6.7.8')]
     mocked_service = self.mock_service_client(
         tpu_map={node_path: {'state': 'READY', 'networkEndpoints': endpoints}})
     return client.Client(tpu='tpu_name',
                          project='test-project',
                          zone='us-central1-c',
                          service=mocked_service)
Beispiel #17
0
 def testRecoverableNoApiAccess(self):
     """A client addressed by grpc:// (no API access) is recoverable."""
     grpc_client = client.Client(tpu='grpc://1.2.3.4:8470')
     self.assertEqual(True, grpc_client.recoverable())
Beispiel #18
0
 def testInitializeNoArguments(self):
     """Constructing a Client with no arguments raises ValueError.

     This assumes no TPU name is available from the environment either.
     TODO(review): confirm that setUp clears TPU_NAME / TPU_CONFIG.
     """
     expected_message = 'Please provide a TPU Name to connect to.'
     with self.assertRaisesRegex(ValueError, expected_message):
         client.Client()
Beispiel #19
0
 def testInitializeMultiElementTpuArray(self):
     """A list of several TPU names is rejected with NotImplementedError."""
     several_tpus = ['multiple', 'elements']
     with self.assertRaisesRegex(
             NotImplementedError,
             'Using multiple TPUs in a single session is not yet implemented'):
         client.Client(tpu=several_tpus)