Ejemplo n.º 1
0
    def testNewNetworkEndpointFormat(self):
        """Resolve a TPU whose record lists its address under 'networkEndpoints'.

        The fake service returns one healthy node with a single endpoint; the
        cluster spec must hold the explicit coordinator plus one worker, and
        master() must return the endpoint as a grpc:// URL.
        """
        endpoint = {
            'ipAddress': '10.2.3.4',
            'port': 8470,
        }
        fake_service = self.mock_service_client(tpu_map={
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
                'health': 'HEALTHY',
                'networkEndpoints': [endpoint],
            },
        })

        tpu_resolver = cluster_resolver.TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu='test-tpu-1',
            coordinator_name='coordinator',
            coordinator_address='10.128.1.5:10203',
            credentials=None,
            service=fake_service)

        resolved_spec = tpu_resolver.cluster_spec()
        wanted_proto = """
    job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
    job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } }
    """
        self._verifyClusterSpecEquality(resolved_spec, wanted_proto)
        self.assertEqual('grpc://10.2.3.4:8470', tpu_resolver.master())
Ejemplo n.º 2
0
    def testGkeEnvironmentForPod(self):
        """Resolve a four-endpoint pod advertised through the GKE env var.

        Sets KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS to four comma-separated grpc
        addresses, then checks that the resolver reports a GKE environment,
        returns the raw endpoint string, uses the first endpoint as master and
        lists all four endpoints as 'worker' tasks in the cluster spec.
        """
        env_var = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
        endpoints = ('grpc://10.120.27.5:8470,'
                     'grpc://10.120.27.6:8470,'
                     'grpc://10.120.27.7:8470,'
                     'grpc://10.120.27.8:8470')
        os.environ[env_var] = endpoints

        # The variable is process-global state: remove it even when an
        # assertion below fails, otherwise it leaks into every later test.
        try:
            self.assertIn(env_var, os.environ)
            self.assertTrue(cluster_resolver.TPUClusterResolver._inGke())
            self.assertEqual(
                compat.as_bytes(endpoints),
                compat.as_bytes(
                    cluster_resolver.TPUClusterResolver._gkeEndpoints()))

            resolver = cluster_resolver.TPUClusterResolver()
            self.assertEqual(compat.as_bytes('grpc://10.120.27.5:8470'),
                             compat.as_bytes(resolver.master()))
            actual_cluster_spec = resolver.cluster_spec()
            expected_proto = """
    job {
      name: 'worker'
      tasks { key: 0 value: '10.120.27.5:8470' }
      tasks { key: 1 value: '10.120.27.6:8470' }
      tasks { key: 2 value: '10.120.27.7:8470' }
      tasks { key: 3 value: '10.120.27.8:8470' }
    }
    """
            self._verifyClusterSpecEquality(actual_cluster_spec,
                                            expected_proto)
        finally:
            os.environ.pop(env_var, None)
Ejemplo n.º 3
0
    def testSimpleSuccessfulRetrieval(self):
        """Resolve one healthy TPU given explicit project, zone and coordinator.

        The node record carries 'ipAddress' and 'port' directly; the cluster
        spec must contain the coordinator and a single worker task.
        """
        node_record = {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'health': 'HEALTHY',
        }
        fake_service = self.mock_service_client(tpu_map={
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                node_record,
        })

        tpu_resolver = cluster_resolver.TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu=['test-tpu-1'],
            coordinator_name='coordinator',
            coordinator_address='10.128.1.5:10203',
            credentials=None,
            service=fake_service)

        resolved_spec = tpu_resolver.cluster_spec()
        wanted_proto = """
    job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } }
    """
        self._verifyClusterSpecEquality(resolved_spec, wanted_proto)
        self.assertEqual(tpu_resolver.master(), 'grpc://10.1.2.3:8470')
Ejemplo n.º 4
0
    def testRetrieveProjectAndZoneFromMetadata(self):
        """Resolve a TPU when project and zone are passed as None.

        The expected coordinator address is 10.128.1.2 on the resolver's own
        `_coordinator_port`. NOTE(review): project, zone and the coordinator
        IP presumably come from a metadata stub set up outside this method.
        """
        node_record = {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'health': 'HEALTHY',
        }
        fake_service = self.mock_service_client(tpu_map={
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                node_record,
        })

        tpu_resolver = cluster_resolver.TPUClusterResolver(
            project=None,
            zone=None,
            tpu=['test-tpu-1'],
            credentials=None,
            service=fake_service,
            coordinator_name='coordinator')

        resolved_spec = tpu_resolver.cluster_spec()
        # The coordinator port is chosen by the resolver, so it is
        # substituted into the expected proto rather than hard-coded.
        wanted_proto = """
    job {
      name: 'coordinator'
      tasks { key: 0 value: '10.128.1.2:%s' }
    }
    job {
      name: 'worker'
      tasks { key: 0 value: '10.1.2.3:8470' }
    }
    """ % tpu_resolver._coordinator_port
        self._verifyClusterSpecEquality(resolved_spec, str(wanted_proto))
        self.assertEqual(tpu_resolver.master(), 'grpc://10.1.2.3:8470')
Ejemplo n.º 5
0
 def verifyShouldResolve(self, tpu, should_resolve):
     """Assert that a resolver built for `tpu` reports `should_resolve`.

     Args:
       tpu: Value forwarded as the resolver's `tpu` argument.
       should_resolve: Expected result of the resolver's `_shouldResolve()`.
     """
     empty_service = self.mock_service_client(tpu_map={})
     tpu_resolver = cluster_resolver.TPUClusterResolver(
         project='test-project',
         zone='us-central1-c',
         tpu=tpu,
         coordinator_name=None,
         credentials=None,
         service=empty_service)
     # Name the TPU in the failure message so the offending input is visible.
     failure_note = "TPU: '%s'" % tpu
     self.assertEqual(should_resolve, tpu_resolver._shouldResolve(),
                      failure_note)
Ejemplo n.º 6
0
    def testGetMasterNoEntries(self):
        """Constructing a resolver with an empty TPU list raises ValueError."""
        empty_service = self.mock_service_client(tpu_map={})
        resolver_kwargs = dict(
            project='test-project',
            zone='us-central1-c',
            tpu=[],
            coordinator_name=None,
            credentials=None,
            service=empty_service)

        with self.assertRaises(ValueError):
            cluster_resolver.TPUClusterResolver(**resolver_kwargs)
Ejemplo n.º 7
0
    def testPodResolution(self):
        """Resolve a TPU node that exposes four network endpoints.

        Every endpoint must appear as a 'worker' task in listed order, next
        to a coordinator at 10.128.1.2 on the resolver's `_coordinator_port`,
        and master() must return the first endpoint.
        """
        worker_ips = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
        fake_service = self.mock_service_client(tpu_map={
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
                'health': 'HEALTHY',
                'networkEndpoints': [
                    {'ipAddress': ip, 'port': 8470} for ip in worker_ips
                ],
            },
        })

        tpu_resolver = cluster_resolver.TPUClusterResolver(
            tpu='test-tpu-1',
            credentials=None,
            service=fake_service,
            coordinator_name='coordinator')

        resolved_spec = tpu_resolver.cluster_spec()
        # The coordinator port is chosen by the resolver, so it is
        # substituted into the expected proto rather than hard-coded.
        wanted_proto = """
    job {
      name: 'coordinator',
      tasks { key: 0 value: '10.128.1.2:%s'}
    }
    job {
      name: 'worker'
      tasks { key: 0 value: '10.2.3.4:8470' }
      tasks { key: 1 value: '10.2.3.5:8470' }
      tasks { key: 2 value: '10.2.3.6:8470' }
      tasks { key: 3 value: '10.2.3.7:8470' }
    }
    """ % tpu_resolver._coordinator_port
        self._verifyClusterSpecEquality(resolved_spec, str(wanted_proto))
        self.assertEqual(tpu_resolver.master(), 'grpc://10.2.3.4:8470')
  def testOverrideTaskTypeAndIndexAndGetMaster(self):
    """master() follows task_type/task_index set as attributes or arguments.

    With four endpoints the default master is the first one; setting
    `task_index = 3` selects the fourth, and per-call arguments select the
    third while `rpc_layer` replaces the URL scheme.
    """
    worker_ips = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
    fake_service = self.mock_service_client(tpu_map={
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'health': 'HEALTHY',
            'networkEndpoints': [
                {'ipAddress': ip, 'port': 8470} for ip in worker_ips
            ],
        },
    })

    tpu_resolver = cluster_resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        coordinator_name=None,
        credentials=None,
        service=fake_service)

    # Default: first endpoint.
    self.assertEqual(tpu_resolver.master(), 'grpc://10.2.3.4:8470')

    # Attribute overrides persist on the resolver.
    tpu_resolver.task_type = 'worker'
    tpu_resolver.task_index = 3
    self.assertEqual(tpu_resolver.master(), 'grpc://10.2.3.7:8470')

    # Per-call overrides, including the RPC scheme.
    overridden_master = tpu_resolver.master(
        task_type='worker', task_index=2, rpc_layer='test')
    self.assertEqual(overridden_master, 'test://10.2.3.6:8470')
Ejemplo n.º 9
0
    def testNumAcceleratorsSuccess(self, mock_list_devices):
        """num_accelerators() returns 2 for the mocked eight-device listing.

        Args:
          mock_list_devices: Mock whose return value is set to the fake
            device list. NOTE(review): presumably injected by a patch
            decorator that is not visible here.
        """
        tpu_device_names = (
            '/job:tpu_worker/task:0/device:TPU:0',
            '/job:tpu_worker/task:1/device:TPU:1',
            '/job:tpu_worker/task:2/device:TPU:0',
            '/job:tpu_worker/task:3/device:TPU:1',
            '/job:tpu_worker/task:0/device:TPU:4',
            '/job:tpu_worker/task:1/device:TPU:5',
            '/job:tpu_worker/task:2/device:TPU:4',
            '/job:tpu_worker/task:3/device:TPU:5',
        )
        mock_list_devices.return_value = [
            session._DeviceAttributes(device_name, 'TPU', 1024, 0)
            for device_name in tpu_device_names
        ]

        tpu_resolver = cluster_resolver.TPUClusterResolver(tpu='')
        self.assertEqual(tpu_resolver.num_accelerators(), 2)
Ejemplo n.º 10
0
    def testNotReadyCloudTpu(self):
        """cluster_spec() raises RuntimeError when the node state is 'CREATING'."""
        node_record = {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'state': 'CREATING',
        }
        fake_service = self.mock_service_client(tpu_map={
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                node_record,
        })

        tpu_resolver = cluster_resolver.TPUClusterResolver(
            project=None,
            zone=None,
            tpu='test-tpu-1',
            coordinator_name=None,
            credentials=None,
            service=fake_service)

        # Construction succeeds; the failure surfaces when the spec is built.
        with self.assertRaises(RuntimeError):
            tpu_resolver.cluster_spec()
Ejemplo n.º 11
0
 def testNumAcceleratorsRetryFailure(self, mock_list_devices):
     """num_accelerators() raises RuntimeError when device listing times out.

     Args:
       mock_list_devices: Mock configured to raise DeadlineExceededError on
         every call. NOTE(review): presumably injected by a patch decorator
         that is not visible here.
     """
     tpu_resolver = cluster_resolver.TPUClusterResolver(tpu='')
     timeout_error = errors.DeadlineExceededError(None, None, 'timeout')
     mock_list_devices.side_effect = timeout_error
     with self.assertRaises(RuntimeError):
         tpu_resolver.num_accelerators()
Ejemplo n.º 12
0
 def testEnvironmentAndRpcDetectionForGrpcString(self):
     """A grpc:// address gives an empty environment and 'grpc' RPC layer."""
     grpc_address = 'grpc://10.1.2.3:8470'
     tpu_resolver = cluster_resolver.TPUClusterResolver(tpu=grpc_address)
     self.assertEqual(tpu_resolver.environment, '')
     self.assertEqual(tpu_resolver.rpc_layer, 'grpc')
     # The address is returned unchanged as the master.
     self.assertEqual(tpu_resolver.master(), 'grpc://10.1.2.3:8470')
Ejemplo n.º 13
0
 def testEnvironmentAndRpcDetectionForGoogle(self):
     """A /bns/ address gives the 'google' environment and no RPC layer."""
     resolver = cluster_resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef')
     self.assertEqual(resolver.environment, 'google')
     # Identity check: an object that merely compares equal to None must
     # not satisfy this assertion.
     self.assertIsNone(resolver.rpc_layer)
Ejemplo n.º 14
0
 def testNoCallComputeMetadata(self):
     """A /bns/ address is returned as master and yields no cluster spec."""
     resolver = cluster_resolver.TPUClusterResolver(tpu='/bns/foo/bar')
     self.assertEqual(compat.as_bytes('/bns/foo/bar'), resolver.master())
     # Identity check: an object that merely compares equal to None must
     # not satisfy this assertion.
     self.assertIsNone(resolver.cluster_spec())
Ejemplo n.º 15
0
 def testCheckRunningInGceWithNoTpuName(self):
     """An empty TPU name raises RuntimeError whose message mentions Google Cloud."""
     message_pattern = '.*Google Cloud.*'
     empty_tpu_name = ''
     with self.assertRaisesRegexp(RuntimeError, message_pattern):
         cluster_resolver.TPUClusterResolver(tpu=empty_tpu_name)