def testSimpleSuccessfulRetrieval(self):
        """Single healthy TPU plus an explicit coordinator address.

        Expects a 'coordinator' job at the supplied address, a one-task
        'worker' job built from the node's ipAddress/port, and master()
        pointing at that worker over grpc.
        """
        node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
        healthy_node = {'ipAddress': '10.1.2.3', 'port': '8470',
                        'health': 'HEALTHY'}

        resolver = TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu=['test-tpu-1'],
            coordinator_name='coordinator',
            coordinator_address='10.128.1.5:10203',
            credentials=None,
            service=self.mock_service_client(
                tpu_map={node_path: healthy_node}))

        spec = resolver.cluster_spec()
        expected_proto = """
    job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } }
    """
        self._verifyClusterSpecEquality(spec, expected_proto)
        self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')
    def testMultipleSuccessfulRetrieval(self):
        """Two TPUs named via tpu_names resolve to one 'tpu_worker' job.

        Task indices follow the order of tpu_names (test-tpu-2 first), not
        the insertion order of the mocked node map.
        """
        nodes_by_path = {}
        nodes_by_path[
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
        ] = {'ipAddress': '10.1.2.3', 'port': '8470'}
        nodes_by_path[
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-2'
        ] = {'ipAddress': '10.4.5.6', 'port': '8470'}

        resolver = TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu_names=['test-tpu-2', 'test-tpu-1'],
            credentials=None,
            service=self.mock_service_client(tpu_map=nodes_by_path))

        spec = resolver.cluster_spec()
        expected_proto = """
    job { name: 'tpu_worker' tasks { key: 0 value: '10.4.5.6:8470' }
                             tasks { key: 1 value: '10.1.2.3:8470' } }
    """
        self._verifyClusterSpecEquality(spec, expected_proto)
  def testRetrieveProjectAndZoneFromMetadata(self):
    """Resolves a coordinator and worker when project and zone are None.

    Per the test name, project and zone are expected to come from (mocked)
    metadata. No coordinator_address is given, so the expected coordinator
    task uses the resolver's own _coordinator_port. The coordinator host
    10.128.1.2 is presumably supplied by the fixture's metadata stub.
    """
    healthy_node = {'ipAddress': '10.1.2.3', 'port': '8470',
                    'health': 'HEALTHY'}
    service = self.mock_service_client(tpu_map={
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
            healthy_node
    })
    resolver = TPUClusterResolver(
        project=None,
        zone=None,
        tpu=['test-tpu-1'],
        credentials=None,
        service=service,
        coordinator_name='coordinator')

    # cluster_spec() runs before _coordinator_port is read, as in the
    # original ordering.
    spec = resolver.cluster_spec()
    template = """
    job {
      name: 'coordinator'
      tasks { key: 0 value: '10.128.1.2:%s' }
    }
    job {
      name: 'worker'
      tasks { key: 0 value: '10.1.2.3:8470' }
    }
    """
    self._verifyClusterSpecEquality(spec,
                                    template % resolver._coordinator_port)
  def testGkeEnvironmentForPod(self):
    """Resolves a four-host pod from KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS.

    Checks that:
    - the GKE environment is detected;
    - the raw endpoint list is reported back unchanged;
    - master() is the first endpoint;
    - cluster_spec() has one 'worker' task per endpoint, in listed order.

    The variable is restored in a ``finally`` clause. Previously a failing
    assertion skipped the trailing ``del`` and leaked the variable into
    every later test.
    """
    env_key = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
    endpoints = ('grpc://10.120.27.5:8470,'
                 'grpc://10.120.27.6:8470,'
                 'grpc://10.120.27.7:8470,'
                 'grpc://10.120.27.8:8470')
    previous = os.environ.get(env_key)
    os.environ[env_key] = endpoints
    try:
      self.assertIn(env_key, os.environ)
      self.assertTrue(TPUClusterResolver._inGke())
      self.assertEqual(
          compat.as_bytes(endpoints),
          compat.as_bytes(TPUClusterResolver._gkeEndpoints()))

      tpu_cluster_resolver = TPUClusterResolver()
      self.assertEqual(
          compat.as_bytes('grpc://10.120.27.5:8470'),
          compat.as_bytes(tpu_cluster_resolver.master()))
      actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
      expected_proto = """
    job {
      name: 'worker'
      tasks { key: 0 value: '10.120.27.5:8470' }
      tasks { key: 1 value: '10.120.27.6:8470' }
      tasks { key: 2 value: '10.120.27.7:8470' }
      tasks { key: 3 value: '10.120.27.8:8470' }
    }
    """
      self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    finally:
      # Put back whatever was set before the test (usually nothing).
      if previous is None:
        os.environ.pop(env_key, None)
      else:
        os.environ[env_key] = previous
  def testNewNetworkEndpointFormat(self):
    """A node described via 'networkEndpoints' (integer port) resolves.

    The single endpoint becomes worker task 0 and master(). The coordinator
    uses the explicitly supplied coordinator_address.
    """
    endpoint = {'ipAddress': '10.2.3.4', 'port': 8470}
    node = {'health': 'HEALTHY', 'networkEndpoints': [endpoint]}
    service = self.mock_service_client(tpu_map={
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': node
    })

    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        coordinator_name='coordinator',
        coordinator_address='10.128.1.5:10203',
        credentials=None,
        service=service)

    spec = resolver.cluster_spec()
    expected_proto = """
    job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
    job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } }
    """
    self._verifyClusterSpecEquality(spec, expected_proto)
    master_address = resolver.master()
    self.assertEqual('grpc://10.2.3.4:8470', master_address)
  def testMultipleSuccessfulRetrieval(self):
    """Two TPUs named via tpu_names resolve to one 'tpu_worker' job.

    Task indices follow the order of tpu_names (test-tpu-2 first), not the
    insertion order of the mocked node map.
    """
    nodes_by_path = {}
    nodes_by_path[
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    ] = {'ipAddress': '10.1.2.3', 'port': '8470'}
    nodes_by_path[
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-2'
    ] = {'ipAddress': '10.4.5.6', 'port': '8470'}

    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu_names=['test-tpu-2', 'test-tpu-1'],
        credentials=None,
        service=self.mock_service_client(tpu_map=nodes_by_path))

    spec = resolver.cluster_spec()
    expected_proto = """
    job { name: 'tpu_worker' tasks { key: 0 value: '10.4.5.6:8470' }
                             tasks { key: 1 value: '10.1.2.3:8470' } }
    """
    self._verifyClusterSpecEquality(spec, expected_proto)
  def testSimpleSuccessfulRetrieval(self):
    """Single healthy TPU plus an explicit coordinator address.

    Expects a 'coordinator' job at the supplied address, a one-task 'worker'
    job built from the node's ipAddress/port, and master() pointing at that
    worker over grpc.
    """
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    healthy_node = {'ipAddress': '10.1.2.3', 'port': '8470',
                    'health': 'HEALTHY'}

    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu=['test-tpu-1'],
        coordinator_name='coordinator',
        coordinator_address='10.128.1.5:10203',
        credentials=None,
        service=self.mock_service_client(tpu_map={node_path: healthy_node}))

    spec = resolver.cluster_spec()
    expected_proto = """
    job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } }
    """
    self._verifyClusterSpecEquality(spec, expected_proto)
    self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')
 def testGkeEnvironment(self):
   """A single KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS entry is reported as GKE master.

   The variable is restored in ``finally``. Previously a failed assertion
   skipped the trailing ``del`` and leaked it into other tests.
   """
   env_key = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
   previous = os.environ.get(env_key)
   os.environ[env_key] = 'grpc://10.120.27.5:8470'
   try:
     self.assertIn(env_key, os.environ)
     self.assertTrue(TPUClusterResolver._inGke())
     self.assertEqual(
         compat.as_bytes('grpc://10.120.27.5:8470'),
         compat.as_bytes(TPUClusterResolver._gkeMaster()))
   finally:
     # Put back whatever was set before the test (usually nothing).
     if previous is None:
       os.environ.pop(env_key, None)
     else:
       os.environ[env_key] = previous
 def verifyShouldResolve(self, tpu, should_resolve):
   """Assert that a resolver built for `tpu` reports `should_resolve`.

   The resolver uses an empty mocked node map and no coordinator. The `tpu`
   value is echoed in the assertion message on failure.
   """
   resolver = TPUClusterResolver(
       project='test-project',
       zone='us-central1-c',
       tpu=tpu,
       coordinator_name=None,
       credentials=None,
       service=self.mock_service_client(tpu_map={}))
   resolved = resolver._shouldResolve()
   self.assertEqual(should_resolve, resolved, "TPU: '%s'" % tpu)
  def testGetMasterNoEntries(self):
    """get_master() raises ValueError for an empty tpu_names list."""
    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu_names=[],
        credentials=None,
        service=self.mock_service_client(tpu_map={}))
    self.assertRaises(ValueError, resolver.get_master)
    def testGetMasterNoEntries(self):
        """get_master() raises ValueError for an empty tpu_names list."""
        resolver = TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu_names=[],
            credentials=None,
            service=self.mock_service_client(tpu_map={}))
        self.assertRaises(ValueError, resolver.get_master)
    def testPodResolution(self):
        """A four-endpoint pod plus coordinator resolves without project/zone.

        project and zone are omitted, presumably so they come from mocked
        metadata. Each networkEndpoint becomes a 'worker' task in listed
        order. The coordinator uses the resolver's _coordinator_port, and
        master() is the first endpoint.
        """
        addresses = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
        pod_node = {
            'health': 'HEALTHY',
            'networkEndpoints': [
                {'ipAddress': address, 'port': 8470} for address in addresses
            ],
        }
        service = self.mock_service_client(tpu_map={
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                pod_node
        })
        resolver = TPUClusterResolver(
            tpu='test-tpu-1',
            credentials=None,
            service=service,
            coordinator_name='coordinator')

        # cluster_spec() runs before _coordinator_port is read, as in the
        # original ordering.
        spec = resolver.cluster_spec()
        template = """
    job {
      name: 'coordinator',
      tasks { key: 0 value: '10.128.1.2:%s'}
    }
    job {
      name: 'worker'
      tasks { key: 0 value: '10.2.3.4:8470' }
      tasks { key: 1 value: '10.2.3.5:8470' }
      tasks { key: 2 value: '10.2.3.6:8470' }
      tasks { key: 3 value: '10.2.3.7:8470' }
    }
    """
        self._verifyClusterSpecEquality(spec,
                                        template % resolver._coordinator_port)
        self.assertEqual(resolver.master(), 'grpc://10.2.3.4:8470')
  def testPodResolution(self):
    """A four-endpoint pod plus coordinator resolves without project/zone.

    project and zone are omitted, presumably so they come from mocked
    metadata. Each networkEndpoint becomes a 'worker' task in listed order.
    The coordinator uses the resolver's _coordinator_port, and master() is
    the first endpoint.
    """
    addresses = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
    pod_node = {
        'health': 'HEALTHY',
        'networkEndpoints': [
            {'ipAddress': address, 'port': 8470} for address in addresses
        ],
    }
    service = self.mock_service_client(tpu_map={
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
            pod_node
    })
    resolver = TPUClusterResolver(
        tpu='test-tpu-1',
        credentials=None,
        service=service,
        coordinator_name='coordinator')

    # cluster_spec() runs before _coordinator_port is read, as in the
    # original ordering.
    spec = resolver.cluster_spec()
    template = """
    job {
      name: 'coordinator',
      tasks { key: 0 value: '10.128.1.2:%s'}
    }
    job {
      name: 'worker'
      tasks { key: 0 value: '10.2.3.4:8470' }
      tasks { key: 1 value: '10.2.3.5:8470' }
      tasks { key: 2 value: '10.2.3.6:8470' }
      tasks { key: 3 value: '10.2.3.7:8470' }
    }
    """
    self._verifyClusterSpecEquality(spec, template % resolver._coordinator_port)
    self.assertEqual(resolver.master(), 'grpc://10.2.3.4:8470')
  def testPodResolutionNoCoordinator(self):
    """With coordinator_name=None only a four-task 'worker' job is produced.

    Tasks appear in the order of the node's networkEndpoints.
    """
    addresses = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
    pod_node = {
        'health': 'HEALTHY',
        'networkEndpoints': [
            {'ipAddress': address, 'port': 8470} for address in addresses
        ],
    }
    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        coordinator_name=None,
        credentials=None,
        service=self.mock_service_client(tpu_map={
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                pod_node
        }))

    spec = resolver.cluster_spec()
    expected_proto = """
    job {
      name: 'worker'
      tasks { key: 0 value: '10.2.3.4:8470' }
      tasks { key: 1 value: '10.2.3.5:8470' }
      tasks { key: 2 value: '10.2.3.6:8470' }
      tasks { key: 3 value: '10.2.3.7:8470' }
    }
    """
    self._verifyClusterSpecEquality(spec, expected_proto)
  def testNotReadyCloudTpu(self):
    """cluster_spec() raises RuntimeError for a node in state 'CREATING'."""
    creating_node = {'ipAddress': '10.1.2.3', 'port': '8470',
                     'state': 'CREATING'}
    service = self.mock_service_client(tpu_map={
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
            creating_node
    })
    resolver = TPUClusterResolver(
        project=None,
        zone=None,
        tpu='test-tpu-1',
        coordinator_name=None,
        credentials=None,
        service=service)

    self.assertRaises(RuntimeError, resolver.cluster_spec)
    def testNotReadyCloudTpu(self):
        """cluster_spec() raises RuntimeError for a node in state 'CREATING'."""
        creating_node = {'ipAddress': '10.1.2.3', 'port': '8470',
                         'state': 'CREATING'}
        service = self.mock_service_client(tpu_map={
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                creating_node
        })
        resolver = TPUClusterResolver(
            project=None,
            zone=None,
            tpu='test-tpu-1',
            coordinator_name=None,
            credentials=None,
            service=service)

        self.assertRaises(RuntimeError, resolver.cluster_spec)
  def testGetMasterMultipleEntries(self):
    """get_master() returns the first TPU listed in tpu_names (test-tpu-2)."""
    nodes_by_path = {}
    nodes_by_path[
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    ] = {'ipAddress': '10.1.2.3', 'port': '8470'}
    nodes_by_path[
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-2'
    ] = {'ipAddress': '10.4.5.6', 'port': '8470'}

    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu_names=['test-tpu-2', 'test-tpu-1'],
        credentials=None,
        service=self.mock_service_client(tpu_map=nodes_by_path))
    master_address = resolver.get_master()
    self.assertEqual('grpc://10.4.5.6:8470', master_address)
  def testGetMasterNoEntries(self):
    """Constructing a resolver with tpu=[] raises ValueError immediately."""
    constructor_kwargs = dict(
        project='test-project',
        zone='us-central1-c',
        tpu=[],
        coordinator_name=None,
        credentials=None)

    with self.assertRaises(ValueError):
      TPUClusterResolver(
          service=self.mock_service_client(tpu_map={}), **constructor_kwargs)
    def testGetMasterMultipleEntries(self):
        """get_master() returns the first TPU listed in tpu_names (test-tpu-2)."""
        nodes_by_path = {}
        nodes_by_path[
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
        ] = {'ipAddress': '10.1.2.3', 'port': '8470'}
        nodes_by_path[
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-2'
        ] = {'ipAddress': '10.4.5.6', 'port': '8470'}

        resolver = TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu_names=['test-tpu-2', 'test-tpu-1'],
            credentials=None,
            service=self.mock_service_client(tpu_map=nodes_by_path))
        master_address = resolver.get_master()
        self.assertEqual('grpc://10.4.5.6:8470', master_address)
  def testRetrieveProjectAndZoneFromMetadata(self):
    """Resolves a node by name when project and zone are passed as None.

    Per the test name, project and zone are expected to come from (mocked)
    metadata. The single node becomes task 0 of the 'tpu_worker' job.
    """
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    service = self.mock_service_client(
        tpu_map={node_path: {'ipAddress': '10.1.2.3', 'port': '8470'}})
    resolver = TPUClusterResolver(
        project=None,
        zone=None,
        tpu_names=['test-tpu-1'],
        credentials=None,
        service=service)

    spec = resolver.cluster_spec()
    expected_proto = """
    job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } }
    """
    self._verifyClusterSpecEquality(spec, expected_proto)
    def testGkeEnvironmentForPod(self):
        """Resolves a four-host pod from KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS.

        Checks that:
        - the GKE environment is detected;
        - the raw endpoint list is reported back unchanged;
        - master() is the first endpoint;
        - cluster_spec() has one 'worker' task per endpoint, in listed order.

        The variable is restored in a ``finally`` clause. Previously a
        failing assertion skipped the trailing ``del`` and leaked the
        variable into every later test.
        """
        env_key = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
        endpoints = ('grpc://10.120.27.5:8470,'
                     'grpc://10.120.27.6:8470,'
                     'grpc://10.120.27.7:8470,'
                     'grpc://10.120.27.8:8470')
        previous = os.environ.get(env_key)
        os.environ[env_key] = endpoints
        try:
            self.assertIn(env_key, os.environ)
            self.assertTrue(TPUClusterResolver._inGke())
            self.assertEqual(
                compat.as_bytes(endpoints),
                compat.as_bytes(TPUClusterResolver._gkeEndpoints()))

            tpu_cluster_resolver = TPUClusterResolver()
            self.assertEqual(compat.as_bytes('grpc://10.120.27.5:8470'),
                             compat.as_bytes(tpu_cluster_resolver.master()))
            actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
            expected_proto = """
    job {
      name: 'worker'
      tasks { key: 0 value: '10.120.27.5:8470' }
      tasks { key: 1 value: '10.120.27.6:8470' }
      tasks { key: 2 value: '10.120.27.7:8470' }
      tasks { key: 3 value: '10.120.27.8:8470' }
    }
    """
            self._verifyClusterSpecEquality(actual_cluster_spec,
                                            expected_proto)
        finally:
            # Put back whatever was set before the test (usually nothing).
            if previous is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous
    def testOverrideTaskTypeAndIndexAndGetMaster(self):
        """master() honours task_type/task_index attributes and arguments.

        By default master() is the first network endpoint. Setting
        task_type='worker' and task_index=3 selects the fourth endpoint.
        Explicit call arguments, including rpc_layer, take precedence for
        that call.
        """
        addresses = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
        pod_node = {
            'health': 'HEALTHY',
            'networkEndpoints': [
                {'ipAddress': address, 'port': 8470} for address in addresses
            ],
        }
        resolver = TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu='test-tpu-1',
            coordinator_name=None,
            credentials=None,
            service=self.mock_service_client(tpu_map={
                'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                    pod_node
            }))

        default_master = resolver.master()
        self.assertEqual(default_master, 'grpc://10.2.3.4:8470')

        resolver.task_type = 'worker'
        resolver.task_index = 3
        self.assertEqual(resolver.master(), 'grpc://10.2.3.7:8470')

        explicit_master = resolver.master(
            task_type='worker', task_index=2, rpc_layer='test')
        self.assertEqual(explicit_master, 'test://10.2.3.6:8470')
  def testOverrideTaskTypeAndIndexAndGetMaster(self):
    """master() honours task_type/task_index attributes and arguments.

    By default master() is the first network endpoint. Setting
    task_type='worker' and task_index=3 selects the fourth endpoint.
    Explicit call arguments, including rpc_layer, take precedence for that
    call.
    """
    addresses = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
    pod_node = {
        'health': 'HEALTHY',
        'networkEndpoints': [
            {'ipAddress': address, 'port': 8470} for address in addresses
        ],
    }
    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        coordinator_name=None,
        credentials=None,
        service=self.mock_service_client(tpu_map={
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                pod_node
        }))

    default_master = resolver.master()
    self.assertEqual(default_master, 'grpc://10.2.3.4:8470')

    resolver.task_type = 'worker'
    resolver.task_index = 3
    self.assertEqual(resolver.master(), 'grpc://10.2.3.7:8470')

    explicit_master = resolver.master(
        task_type='worker', task_index=2, rpc_layer='test')
    self.assertEqual(explicit_master, 'test://10.2.3.6:8470')
 def testEnvironmentAndRpcDetectionForGrpcString(self):
   """A grpc:// URL yields environment '', rpc_layer 'grpc', and itself as master."""
   grpc_url = 'grpc://10.1.2.3:8470'
   resolver = TPUClusterResolver(tpu=grpc_url)
   self.assertEqual(resolver.environment, '')
   self.assertEqual(resolver.rpc_layer, 'grpc')
   self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')
# Example #25
# 0
  # NOTE(review): this is the tail of a model_fn whose `def` line is not part
  # of this snippet. `y`, `my_label` and `mode` presumably are its prediction
  # tensor, label tensor and mode argument; confirm against the full function.
  # Loss sub-graph: sum of squared differences between predictions and labels.
  loss = tf.reduce_sum(tf.square(y - my_label))
  # Training sub-graph: one gradient-descent step (learning rate 0.01),
  # grouped with an explicit global-step increment because minimize() is
  # not given global_step here.
  global_step = tf.train.get_global_step()
  optimizer = tf.train.GradientDescentOptimizer(0.01)
  train = tf.group(optimizer.minimize(loss),
                   tf.assign_add(global_step, 1))
  # EstimatorSpec connects the sub-graphs built above (predictions, loss,
  # train op) to the Estimator.
  # NOTE(review): the script below wraps this in a TPUEstimator with
  # use_tpu=True, which typically expects a TPUEstimatorSpec; confirm that
  # returning a plain EstimatorSpec works there.
  return tf.estimator.EstimatorSpec(
      mode=mode, predictions=y,
      loss=loss,
      train_op=train)

# Point a TPU RunConfig at the master of the TPU named by $TPU_NAME
# (os.environ[...] raises KeyError if the variable is unset).
tpu_config = tf.contrib.tpu.RunConfig(
    master=TPUClusterResolver(tpu=[os.environ['TPU_NAME']]).get_master())
# RunConfig.replace() returns a *new* RunConfig instead of mutating in place,
# so its result must be re-bound. Previously the result was discarded and
# model_dir was never applied.
tpu_config = tpu_config.replace(model_dir='.')

# Load the customized model_fn into a TPUEstimator.
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model, config=tpu_config, use_tpu=True,
    predict_batch_size=4, train_batch_size=4)
# Define our data sets.
x_train = np.array([1., 2., 3., 4.])
y_train = np.array([0., -1., -2., -3.])
x_eval = np.array([2., 5., 8., 1.])
y_eval = np.array([-1.01, -4.1, -7, 0.])

#input_fn = tf.estimator.inputs.numpy_input_fn({"x": x_train}, y_train, batch_size=4, num_epochs=10, shuffle=False)
#eval_input_fn = tf.estimator.inputs.numpy_input_fn({"x":x_eval}, y_eval, batch_size=4, num_epochs=1, shuffle=False)

# Train.
# NOTE(review): Estimator.train() expects input_fn to be a callable, but this
# passes the *result* of predict_input_fn(), which is defined outside this
# snippet. Confirm it returns a callable and that a predict-named input
# function is really intended for training.
estimator.train(input_fn=predict_input_fn(), steps=10)
 def testNoCallComputeMetadata(self):
   """A '/bns/...' TPU path is used as-is and yields no cluster spec.

   master() returns the path as bytes and cluster_spec() returns None. Per
   the test name, no metadata lookup is expected.
   """
   bns_path = '/bns/foo/bar'
   resolver = TPUClusterResolver(tpu=bns_path)
   self.assertEqual(compat.as_bytes(bns_path), resolver.master())
   spec = resolver.cluster_spec()
   self.assertEqual(None, spec)
 def testEnvironmentDiscoveryUrl(self):
     """_environmentDiscoveryUrl() returns the TPU_API_DISCOVERY_URL value.

     The variable's previous value is restored afterwards. Previously it
     was left set, leaking into every subsequent test in the process.
     """
     env_key = 'TPU_API_DISCOVERY_URL'
     previous = os.environ.get(env_key)
     os.environ[env_key] = 'https://{api}.internal/{apiVersion}'
     try:
       self.assertEqual('https://{api}.internal/{apiVersion}',
                        TPUClusterResolver._environmentDiscoveryUrl())
     finally:
       if previous is None:
         os.environ.pop(env_key, None)
       else:
         os.environ[env_key] = previous
 def testEnvironmentAndRpcDetectionForGoogle(self):
     """A '/bns/...' path is classified as the 'google' environment, no RPC layer."""
     bns_path = '/bns/ab/cd/ef'
     resolver = TPUClusterResolver(tpu=bns_path)
     self.assertEqual(resolver.environment, 'google')
     self.assertEqual(resolver.rpc_layer, None)
 def testEnvironmentAndRpcDetectionForGrpcString(self):
     """A grpc:// URL yields environment '', rpc_layer 'grpc', and itself as master."""
     grpc_url = 'grpc://10.1.2.3:8470'
     resolver = TPUClusterResolver(tpu=grpc_url)
     self.assertEqual(resolver.environment, '')
     self.assertEqual(resolver.rpc_layer, 'grpc')
     self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')
 def testDiscoveryUrl(self):
   """_discoveryUrl() picks up the TPU_API_DISCOVERY_URL override.

   The variable's previous value is restored afterwards. Previously it was
   left set, leaking into every subsequent test in the process.
   """
   env_key = 'TPU_API_DISCOVERY_URL'
   previous = os.environ.get(env_key)
   os.environ[env_key] = 'https://{api}.internal/{apiVersion}'
   try:
     self.assertEqual('https://{api}.internal/{apiVersion}',
                      TPUClusterResolver._discoveryUrl())
   finally:
     if previous is None:
       os.environ.pop(env_key, None)
     else:
       os.environ[env_key] = previous