Ejemplo n.º 1
0
 def testNumAcceleratorsRetryFailure(self, mock_list_devices,
                                     mock_eager_list_devices):
   """Checks that `num_accelerators()` raises when device listing times out.

   Both device-listing mocks are configured to raise
   `DeadlineExceededError`; calling `num_accelerators()` on a resolver built
   with an empty `tpu` argument is then expected to raise `RuntimeError`.

   Args:
     mock_list_devices: Mock whose `side_effect` is set to a timeout error.
     mock_eager_list_devices: Mock whose `side_effect` is set to a timeout
       error.
   """
   tpu_resolver = resolver.TPUClusterResolver(tpu='')
   # Each mock receives its own freshly constructed error instance.
   for device_lister in (mock_list_devices, mock_eager_list_devices):
     device_lister.side_effect = errors.DeadlineExceededError(
         None, None, 'timeout')
   with self.assertRaises(RuntimeError):
     tpu_resolver.num_accelerators()
  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """Polls until a `Session` on `master` reports the model as ready.

    Each attempt opens a new `Session` against `master`, tries the local
    init op, and then checks whether the model is ready.  When both succeed
    the session is returned.  Otherwise the session is closed and, provided
    the countdown still covers another pause of `self._recovery_wait_secs`,
    the method sleeps for that long and tries again.

    This method does not initialize or restore the model itself; it only
    waits, so something else has to make the model ready.

    Args:
      master: `String` representation of the TensorFlow master to use.  It is
        also stored on `self._target`.
      config: Optional ConfigProto proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become
        available.  `None` is treated as infinity, which is also the default.

    Returns:
      The `Session` for which local initialization succeeded and the model
      was reported ready.

    Raises:
      tf.DeadlineExceededError: if the time left on the countdown is shorter
        than one more recovery wait after an unsuccessful attempt.
    """
    self._target = master

    max_wait_secs = float("Inf") if max_wait_secs is None else max_wait_secs
    countdown = _CountDownTimer(max_wait_secs)

    while True:
      candidate = session.Session(
          self._target, graph=self._graph, config=config)
      # Stays None when local init fails and readiness is never checked.
      model_status = None
      local_init_ok, local_init_status = self._try_run_local_init_op(
          candidate)
      if local_init_ok:
        # Local init passing is a precondition for the model-ready check.
        model_ok, model_status = self._model_ready(candidate)
        if model_ok:
          return candidate

      # This attempt failed; release the session before deciding on a retry.
      self._safe_close(candidate)

      # Give up when the next sleep would overrun the time that is left.
      secs_left_after_wait = (
          countdown.secs_remaining() - self._recovery_wait_secs)
      if secs_left_after_wait < 0:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))

      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   local_init_status, model_status)
      time.sleep(self._recovery_wait_secs)
Ejemplo n.º 3
0
    def testNumAcceleratorsRetryFailure(self, mock_list_devices,
                                        mock_eager_list_devices):
        """Checks that `num_accelerators()` raises when device listing times out.

        A resolver is built against a mocked service client describing one
        healthy TPU node with four network endpoints. Both device-listing
        mocks are configured to raise `DeadlineExceededError`, and the call
        to `num_accelerators()` is expected to raise `RuntimeError`.

        Args:
            mock_list_devices: Mock whose `side_effect` is set to a timeout
                error.
            mock_eager_list_devices: Mock whose `side_effect` is set to a
                timeout error.
        """
        endpoint_ips = ('10.2.3.4', '10.2.3.5', '10.2.3.6', '10.2.3.7')
        node_name = (
            'projects/test-project/locations/us-central1-c/nodes/test-tpu-1')
        tpu_map = {
            node_name: {
                'health': 'HEALTHY',
                # One endpoint entry per address, all on the same port.
                'networkEndpoints': [
                    {'ipAddress': ip, 'port': 8470} for ip in endpoint_ips
                ],
            },
        }

        fake_service = self.mock_service_client(tpu_map=tpu_map)
        tpu_resolver = resolver.TPUClusterResolver(
            project='test-project',
            zone='us-central1-c',
            tpu='test-tpu-1',
            service=fake_service)

        # Each mock receives its own freshly constructed error instance.
        for device_lister in (mock_list_devices, mock_eager_list_devices):
            device_lister.side_effect = errors.DeadlineExceededError(
                None, None, 'timeout')

        with self.assertRaises(RuntimeError):
            tpu_resolver.num_accelerators()