def testRetrieveProjectAndZoneFromMetadata(self):
    """Single healthy node resolved with project and zone left as None.

    Expects a 'coordinator' task on 10.128.1.2 (port taken from the resolver's
    _coordinator_port) plus one 'worker' task at the node's address.
    """
    node_key = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    node_info = {
        'ipAddress': '10.1.2.3',
        'port': '8470',
        'health': 'HEALTHY'
    }
    fake_service = self.mock_service_client(tpu_map={node_key: node_info})

    resolver = TPUClusterResolver(
        project=None,
        zone=None,
        tpu=['test-tpu-1'],
        credentials=None,
        service=fake_service,
        coordinator_name='coordinator')

    spec = resolver.cluster_spec()
    proto_template = """
    job {
      name: 'coordinator'
      tasks { key: 0 value: '10.128.1.2:%s' }
    }
    job {
      name: 'worker'
      tasks { key: 0 value: '10.1.2.3:8470' }
    }
    """
    # The port is read only after cluster_spec() has been computed.
    expected = proto_template % resolver._coordinator_port
    self._verifyClusterSpecEquality(spec, expected)
    self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')
  def testGkeEnvironmentForPod(self):
    """Resolves a GKE pod from KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS.

    The variable is set only for the duration of the test. Its previous value
    (or absence) is restored in a `finally` clause, so a failing assertion
    cannot leak the setting into later tests.
    """
    env_key = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
    saved_value = os.environ.get(env_key)
    os.environ[env_key] = ('grpc://10.120.27.5:8470,'
                           'grpc://10.120.27.6:8470,'
                           'grpc://10.120.27.7:8470,'
                           'grpc://10.120.27.8:8470')
    try:
      self.assertIn(env_key, os.environ)
      self.assertTrue(TPUClusterResolver._inGke())
      self.assertEqual(
          compat.as_bytes('grpc://10.120.27.5:8470,'
                          'grpc://10.120.27.6:8470,'
                          'grpc://10.120.27.7:8470,'
                          'grpc://10.120.27.8:8470'),
          compat.as_bytes(TPUClusterResolver._gkeEndpoints()))

      resolver = TPUClusterResolver()
      # master() is expected to be the first endpoint in the list.
      self.assertEqual(
          compat.as_bytes('grpc://10.120.27.5:8470'),
          compat.as_bytes(resolver.master()))
      actual_cluster_spec = resolver.cluster_spec()
      expected_proto = """
    job {
      name: 'worker'
      tasks { key: 0 value: '10.120.27.5:8470' }
      tasks { key: 1 value: '10.120.27.6:8470' }
      tasks { key: 2 value: '10.120.27.7:8470' }
      tasks { key: 3 value: '10.120.27.8:8470' }
    }
    """
      self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    finally:
      # Restore the pre-test environment even if an assertion above failed.
      if saved_value is None:
        os.environ.pop(env_key, None)
      else:
        os.environ[env_key] = saved_value
Пример #3
0
    def testRetrieveProjectAndZoneFromMetadata(self):
        """Resolves one healthy node when project and zone are both None.

        Expects a 'coordinator' task on 10.128.1.2 (port read from the
        resolver's _coordinator_port) and a single 'worker' task.
        """
        nodes = {
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                {'ipAddress': '10.1.2.3', 'port': '8470', 'health': 'HEALTHY'},
        }
        resolver = TPUClusterResolver(
            project=None,
            zone=None,
            tpu=['test-tpu-1'],
            credentials=None,
            service=self.mock_service_client(tpu_map=nodes),
            coordinator_name='coordinator')

        cluster = resolver.cluster_spec()
        want = """
    job {
      name: 'coordinator'
      tasks { key: 0 value: '10.128.1.2:%s' }
    }
    job {
      name: 'worker'
      tasks { key: 0 value: '10.1.2.3:8470' }
    }
    """ % resolver._coordinator_port
        self._verifyClusterSpecEquality(cluster, want)
        self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')
Пример #4
0
    def testNewNetworkEndpointFormat(self):
        """Reads the worker address from the 'networkEndpoints' list.

        The node carries no top-level ipAddress/port. Its single
        'networkEndpoints' entry must become worker task 0 and the master.
        """
        endpoints = [{'ipAddress': '10.2.3.4', 'port': 8470}]
        node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
        service = self.mock_service_client(
            tpu_map={node_path: {'health': 'HEALTHY',
                                 'networkEndpoints': endpoints}})

        resolver = TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu='test-tpu-1',
            coordinator_name='coordinator',
            coordinator_address='10.128.1.5:10203',
            credentials=None,
            service=service)

        self._verifyClusterSpecEquality(
            resolver.cluster_spec(), """
    job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
    job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } }
    """)
        self.assertEqual('grpc://10.2.3.4:8470', resolver.master())
Пример #5
0
    def testGkeEnvironmentForPod(self):
        """Resolves a GKE pod from KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS.

        The variable is set only for the duration of the test. Its previous
        value (or absence) is restored in a `finally` clause, so a failing
        assertion cannot leak the setting into later tests.
        """
        env_key = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
        saved_value = os.environ.get(env_key)
        os.environ[env_key] = (
            'grpc://10.120.27.5:8470,'
            'grpc://10.120.27.6:8470,'
            'grpc://10.120.27.7:8470,'
            'grpc://10.120.27.8:8470')
        try:
            self.assertIn(env_key, os.environ)
            self.assertTrue(TPUClusterResolver._inGke())
            self.assertEqual(
                compat.as_bytes('grpc://10.120.27.5:8470,'
                                'grpc://10.120.27.6:8470,'
                                'grpc://10.120.27.7:8470,'
                                'grpc://10.120.27.8:8470'),
                compat.as_bytes(TPUClusterResolver._gkeEndpoints()))

            resolver = TPUClusterResolver()
            # master() is expected to be the first endpoint in the list.
            self.assertEqual(compat.as_bytes('grpc://10.120.27.5:8470'),
                             compat.as_bytes(resolver.master()))
            actual_cluster_spec = resolver.cluster_spec()
            expected_proto = """
    job {
      name: 'worker'
      tasks { key: 0 value: '10.120.27.5:8470' }
      tasks { key: 1 value: '10.120.27.6:8470' }
      tasks { key: 2 value: '10.120.27.7:8470' }
      tasks { key: 3 value: '10.120.27.8:8470' }
    }
    """
            self._verifyClusterSpecEquality(actual_cluster_spec,
                                            expected_proto)
        finally:
            # Restore the pre-test environment even if an assertion failed.
            if saved_value is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = saved_value
  def testNewNetworkEndpointFormat(self):
    """Worker address comes from the node's 'networkEndpoints' list.

    The node has no top-level ipAddress/port. Its one endpoint is expected as
    worker task 0 and as the master address.
    """
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    node_info = {
        'health': 'HEALTHY',
        'networkEndpoints': [{'ipAddress': '10.2.3.4', 'port': 8470}],
    }
    stub_service = self.mock_service_client(tpu_map={node_path: node_info})

    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        coordinator_name='coordinator',
        coordinator_address='10.128.1.5:10203',
        credentials=None,
        service=stub_service)

    spec = resolver.cluster_spec()
    wanted = """
    job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
    job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } }
    """
    self._verifyClusterSpecEquality(spec, wanted)
    self.assertEqual('grpc://10.2.3.4:8470', resolver.master())
Пример #7
0
    def testSimpleSuccessfulRetrieval(self):
        """One healthy node plus an explicitly given coordinator address."""
        node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
        node_info = {'ipAddress': '10.1.2.3', 'port': '8470', 'health': 'HEALTHY'}
        resolver = TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu=['test-tpu-1'],
            coordinator_name='coordinator',
            coordinator_address='10.128.1.5:10203',
            credentials=None,
            service=self.mock_service_client(tpu_map={node_path: node_info}))

        expected_proto = """
    job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } }
    """
        actual = resolver.cluster_spec()
        self._verifyClusterSpecEquality(actual, expected_proto)
        self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')
  def testSimpleSuccessfulRetrieval(self):
    """Healthy single node; coordinator address is supplied by the caller."""
    healthy_node = {
        'ipAddress': '10.1.2.3',
        'port': '8470',
        'health': 'HEALTHY'
    }
    service = self.mock_service_client(tpu_map={
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
            healthy_node
    })

    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu=['test-tpu-1'],
        coordinator_name='coordinator',
        coordinator_address='10.128.1.5:10203',
        credentials=None,
        service=service)

    got = resolver.cluster_spec()
    want = """
    job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } }
    """
    self._verifyClusterSpecEquality(got, want)
    self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')
Пример #9
0
    def testNotReadyCloudTpu(self):
        """A node in state 'CREATING' makes cluster_spec() raise RuntimeError."""
        creating_node = {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'state': 'CREATING'
        }
        service = self.mock_service_client(tpu_map={
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                creating_node
        })
        resolver = TPUClusterResolver(
            project=None,
            zone=None,
            tpu='test-tpu-1',
            coordinator_name=None,
            credentials=None,
            service=service)

        # Callable form: assertRaises invokes cluster_spec and checks the type.
        self.assertRaises(RuntimeError, resolver.cluster_spec)
  def testNotReadyCloudTpu(self):
    """cluster_spec() raises RuntimeError while the node is still 'CREATING'."""
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    pending = {'ipAddress': '10.1.2.3', 'port': '8470', 'state': 'CREATING'}

    resolver = TPUClusterResolver(
        project=None,
        zone=None,
        tpu='test-tpu-1',
        coordinator_name=None,
        credentials=None,
        service=self.mock_service_client(tpu_map={node_path: pending}))

    self.assertRaises(RuntimeError, resolver.cluster_spec)
Пример #11
0
    def testPodResolution(self):
        """A node with four networkEndpoints yields four worker tasks.

        Workers are keyed 0-3 in endpoint order, master() is the first
        endpoint, and a 'coordinator' task is expected on 10.128.1.2 at the
        resolver's _coordinator_port.
        """
        pod_ips = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
        pod_node = {
            'health': 'HEALTHY',
            'networkEndpoints': [{'ipAddress': ip, 'port': 8470}
                                 for ip in pod_ips],
        }
        service = self.mock_service_client(tpu_map={
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1':
                pod_node,
        })

        resolver = TPUClusterResolver(
            tpu='test-tpu-1',
            credentials=None,
            service=service,
            coordinator_name='coordinator')

        cluster = resolver.cluster_spec()
        template = """
    job {
      name: 'coordinator',
      tasks { key: 0 value: '10.128.1.2:%s'}
    }
    job {
      name: 'worker'
      tasks { key: 0 value: '10.2.3.4:8470' }
      tasks { key: 1 value: '10.2.3.5:8470' }
      tasks { key: 2 value: '10.2.3.6:8470' }
      tasks { key: 3 value: '10.2.3.7:8470' }
    }
    """
        self._verifyClusterSpecEquality(
            cluster, template % resolver._coordinator_port)
        self.assertEqual(resolver.master(), 'grpc://10.2.3.4:8470')
  def testPodResolution(self):
    """Four pod endpoints resolve to worker tasks 0-3 plus a coordinator.

    The coordinator task is expected on 10.128.1.2 at the resolver's
    _coordinator_port; master() is the first endpoint.
    """
    addresses = ['10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7']
    endpoint_list = []
    for address in addresses:
      endpoint_list.append({'ipAddress': address, 'port': 8470})
    node_path = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
    tpu_nodes = {
        node_path: {'health': 'HEALTHY', 'networkEndpoints': endpoint_list}
    }

    resolver = TPUClusterResolver(
        tpu='test-tpu-1',
        credentials=None,
        service=self.mock_service_client(tpu_map=tpu_nodes),
        coordinator_name='coordinator')

    got_spec = resolver.cluster_spec()
    proto_fmt = """
    job {
      name: 'coordinator',
      tasks { key: 0 value: '10.128.1.2:%s'}
    }
    job {
      name: 'worker'
      tasks { key: 0 value: '10.2.3.4:8470' }
      tasks { key: 1 value: '10.2.3.5:8470' }
      tasks { key: 2 value: '10.2.3.6:8470' }
      tasks { key: 3 value: '10.2.3.7:8470' }
    }
    """
    want_spec = proto_fmt % resolver._coordinator_port
    self._verifyClusterSpecEquality(got_spec, want_spec)
    self.assertEqual(resolver.master(), 'grpc://10.2.3.4:8470')
Пример #13
0
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  In eager mode the initialization op is wrapped in a defun and run on the
  TPU_SYSTEM device that corresponds to the first (sorted) visible TPU device.
  Otherwise it is run through a Session connected to the resolver's master.
  The resulting topology is recorded in `_INITIALIZED_TPU_SYSTEMS` under the
  resolver's TPU name.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster. If None,
        `TPUClusterResolver("")` is used.
  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    # Pass tpu_name so the "%s" placeholder is actually filled in; logging
    # only %-formats the message when arguments are supplied.
    logging.warning("TPU system %s has already been initialized. "
                    "Reinitializing the TPU can cause previously created "
                    "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system.")

  if context.executing_eagerly():
    # This function looks as it is for the following non-intuitive reasons.
    # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
    # DistributedTPURewritePass. This pass actually adds real ops that
    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    # Sorted so that tpu_devices[0] is chosen deterministically.
    tpu_devices = sorted(
        x for x in context.list_devices() if "device:TPU:" in x)

    if not tpu_devices:
      raise RuntimeError("Could not find any TPU devices")

    # Replace the remote TPU device with the remote TPU_SYSTEM system device. As
    # in the remote TPU device case, we will try to compile it instead of
    # running through optimization passes and TF Executor, but TPU_SYSTEM should
    # work.
    tpu_system_device = tpu_devices[0].replace("TPU", "TPU_SYSTEM")

    with ops.device(tpu_system_device):
      output = _tpu_init_fn()
    serialized_topology = output.numpy()
  else:
    # Graph mode: run the initialization op in a fresh graph and Session,
    # configured with the resolver's cluster spec when one is available.
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology
Пример #14
0
 def testNoCallComputeMetadata(self):
     """A '/bns/...' TPU path is used verbatim as master, with no cluster spec."""
     bns_path = '/bns/foo/bar'
     bns_resolver = TPUClusterResolver(tpu=bns_path)
     self.assertEqual(bns_path, bns_resolver.master())
     self.assertEqual(None, bns_resolver.cluster_spec())
Пример #15
0
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Runs the TPU system initialization op in one of two ways:

  * Eagerly: wrapped in a defun and placed on the TPU_SYSTEM device for `job`.
  * In graph mode: through a Session connected to the resolver's master.

  The resulting topology is recorded in `_INITIALIZED_TPU_SYSTEMS` under the
  resolver's TPU name. In the eager path this also clears the eager context
  caches and sets the context mirroring policy to `context.MIRRORING_ALL`.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster. If None, a
        `TPUClusterResolver("")` is created and, when executing eagerly, the
        job of the current device scope (if any) is used for placement.
  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution or if run in a
        tf.function.
  """
  # "<job>/replica:0/task:0" placement prefix for the init op; stays None
  # unless one of the branches below sets it.
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified, and running eagerly, execute the init
    # ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "TPU system %s has already been initialized. "
        "Reinitializing the TPU can cause previously created "
        "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function looks as it is for the following non-intuitive reasons.
    # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
    # DistributedTPURewritePass. This pass actually adds real ops that
    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place the tpu.initialize_system in the first worker to
      # avoid the output node match multiple devices error.
      # Note: this overrides any job derived from the current device scope
      # above.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_init_fn():
      # In TF1, we usually close chips when compilation fails to clear the data
      # in infeed. In TF2, we don't need to do this because infeed is no longer
      # used, so user can recover from TPU compilation failures more smoothly.
      return tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)

    # The TPU_SYSTEM device must match the device used in tpu.initialize_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      output = _tpu_init_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access

    serialized_topology = output.numpy()

    # TODO(b/134094971): Remove this when lazy tensor copy in multi-device
    # function has been implemented.
    context.context().mirroring_policy = context.MIRRORING_ALL
  elif not ops.executing_eagerly_outside_functions():
    # Not eager, and not eager outside functions either: plain graph mode.
    # Run the init op in a fresh graph/Session against the resolver's master,
    # with the resolver's cluster spec when one is available.
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    # Not eager here but eager outside functions: per the message below, this
    # is the call-from-inside-a-tf.function case, which is rejected.
    raise RuntimeError("initialize_tpu_system is not supported within "
                       "tf.functions.")

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  # Remember the topology so a later re-initialization can be warned about.
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology
Пример #16
0
def shutdown_tpu_system(cluster_resolver=None):
    """Shuts down the TPU devices.

    This will clear all caches, even those that are maintained through
    sequential calls to tf.tpu.experimental.initialize_tpu_system, such as the
    compilation cache.

    Args:
      cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
          which provides information about the TPU cluster. If None, a
          `TPUClusterResolver("")` is created and, when executing eagerly, the
          job of the current device scope (if any) is used for placement.

    Raises:
      RuntimeError: If no TPU devices found for eager execution or if run in a
          tf.function.
    """
    # "<job>/replica:0/task:0" placement prefix for the shutdown op; stays None
    # unless one of the branches below sets it.
    job = None
    if cluster_resolver is None:
        # If no cluster resolver is specified, and running eagerly, execute the
        # shutdown ops in the current device scope.
        if context.executing_eagerly():
            curr_device = device.DeviceSpec.from_string(
                context.context().device_name)
            if curr_device.job is not None:
                job = "{}/replica:0/task:0".format(curr_device.job)

        cluster_resolver = TPUClusterResolver("")
    assert isinstance(cluster_resolver, TPUClusterResolver)

    tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
    if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
        # Pass tpu_name so the "%s" placeholder is actually filled in; logging
        # only %-formats the message when arguments are supplied.
        logging.warning(
            "You are shutting down a TPU system %s that has not been "
            "initialized.", tpu_name)

    logging.info("Shutting down the TPU system: %s", tpu_name)

    if context.executing_eagerly():
        # This function looks as it is for the following non-intuitive reasons.
        # tpu.shutdown_system creates a dummy op whose sole purpose is to trigger
        # DistributedTPURewritePass. This pass actually adds real ops that
        # shutdown the TPU system. Thus, we can't simply run tpu.shutdown_system
        # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
        if tpu_name not in _LOCAL_MASTERS:
            # Explicitly place the tpu.shutdown_system in the first worker to
            # avoid the output node match multiple devices error.
            job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

        @function.defun
        def _tpu_shutdown_fn():
            tpu.shutdown_system(job=job)

        # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
        # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
        # devices available.
        with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
            _tpu_shutdown_fn()

        # Clear out the eager context caches since the memory is invalid now.
        logging.info("Clearing out eager caches")
        context.context()._clear_caches()  # pylint: disable=protected-access
    elif not ops.executing_eagerly_outside_functions():
        # Graph mode: run the shutdown op in a fresh graph and Session.
        master = cluster_resolver.master()
        cluster_spec = cluster_resolver.cluster_spec()

        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        if cluster_spec:
            session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

        with ops.Graph().as_default():
            with session_lib.Session(config=session_config,
                                     target=master) as sess:
                sess.run(tpu.shutdown_system())
    else:
        # The message must name this function, not initialize_tpu_system.
        raise RuntimeError("shutdown_tpu_system is not supported within "
                           "tf.functions.")

    logging.info("Finished shutting down TPU system.")
    if tpu_name in _INITIALIZED_TPU_SYSTEMS:
        del _INITIALIZED_TPU_SYSTEMS[tpu_name]
 def testNoCallComputeMetadata(self):
   """A '/bns/...' TPU path is echoed by master() and gives no cluster spec."""
   bns_address = '/bns/foo/bar'
   tpu_resolver = TPUClusterResolver(tpu=bns_address)
   self.assertEqual(bns_address, tpu_resolver.master())
   self.assertEqual(None, tpu_resolver.cluster_spec())
Пример #18
0
def initialize_tpu_system(cluster_resolver=None):
    """Initialize the TPU devices.

    In eager mode the initialization op is wrapped in a defun and placed on
    the TPU_SYSTEM device matching `job`. Otherwise it is run through a
    Session connected to the resolver's master. The resulting topology is
    recorded in `_INITIALIZED_TPU_SYSTEMS` under the resolver's TPU name.

    Args:
      cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
          which provides information about the TPU cluster. If None,
          `TPUClusterResolver("")` is used.
    Returns:
      The tf.tpu.Topology object for the topology of the TPU cluster.

    Raises:
      RuntimeError: If no TPU devices found for eager execution.
    """
    if cluster_resolver is None:
        cluster_resolver = TPUClusterResolver("")
    assert isinstance(cluster_resolver, TPUClusterResolver)

    tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
    if tpu_name in _INITIALIZED_TPU_SYSTEMS:
        # Pass tpu_name so the "%s" placeholder is actually filled in; logging
        # only %-formats the message when arguments are supplied.
        logging.warning("TPU system %s has already been initialized. "
                        "Reinitializing the TPU can cause previously created "
                        "variables on TPU to be lost.", tpu_name)

    logging.info("Initializing the TPU system: %s", tpu_name)

    if context.executing_eagerly():
        # This function looks as it is for the following non-intuitive reasons.
        # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
        # DistributedTPURewritePass. This pass actually adds real ops that
        # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
        # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
        job = None
        if tpu_name not in _LOCAL_MASTERS:
            # Explicitly place the tpu.initialize_system in the first worker to
            # avoid the output node match multiple devices error.
            job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

        @function.defun
        def _tpu_init_fn():
            return tpu.initialize_system(job=job)

        # The TPU_SYSTEM device must match the device used in tpu.initialize_system
        # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
        # devices available.
        with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
            output = _tpu_init_fn()

        # Clear out the eager context caches since the memory is invalid now.
        logging.info("Clearing out eager caches")
        context.context()._clear_caches()  # pylint: disable=protected-access

        serialized_topology = output.numpy()
    else:
        # Graph mode: run the init op in a fresh graph and Session, configured
        # with the resolver's cluster spec when one is available.
        master = cluster_resolver.master()
        cluster_spec = cluster_resolver.cluster_spec()

        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        if cluster_spec:
            session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

        with ops.Graph().as_default():
            with session_lib.Session(config=session_config,
                                     target=master) as sess:
                serialized_topology = sess.run(tpu.initialize_system())

    logging.info("Finished initializing TPU system.")
    tpu_topology = topology.Topology(serialized=serialized_topology)
    _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

    return tpu_topology
Пример #19
0
 def testNoCallComputeMetadata(self):
   """A '/bns/...' TPU path is returned (as bytes) by master(); no cluster spec."""
   bns_path = '/bns/foo/bar'
   resolver = TPUClusterResolver(tpu=bns_path)
   expected_master = compat.as_bytes(bns_path)
   self.assertEqual(expected_master, resolver.master())
   self.assertEqual(None, resolver.cluster_spec())