Code example #1
  def test_secs_remaining_zero(self):
    timer = _CountDownTimer(0.)
    time.sleep(.01)
    secs_remaining = timer.secs_remaining()
    self.assertEqual(0., secs_remaining)
Code example #2
  def test_secs_remaining_short(self):
    timer = _CountDownTimer(.001)
    time.sleep(.1)
    secs_remaining = timer.secs_remaining()
    self.assertEqual(0., secs_remaining)
Code example #3
  def test_secs_remaining_long(self):
    timer = _CountDownTimer(60)
    time.sleep(.1)
    secs_remaining = timer.secs_remaining()
    self.assertLess(0., secs_remaining)
    self.assertGreater(60., secs_remaining)
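
These three tests exercise a `_CountDownTimer` helper whose definition is not part of this excerpt. A minimal sketch consistent with the behavior the tests assert (a countdown that never goes negative and clamps at zero once the duration elapses) could look like the following; the attribute names are assumptions, not the original implementation.

import time


class _CountDownTimer(object):
  """Counts down from a fixed duration; sketch inferred from the tests above."""

  def __init__(self, duration_secs):
    # Remember when the countdown started so `secs_remaining` can compute
    # how much time has elapsed.
    self._start_time = time.time()
    self._duration_secs = duration_secs

  def secs_remaining(self):
    # Clamp at zero so callers can treat 0 as "timed out", as the tests expect.
    return max(0., self._duration_secs - (time.time() - self._start_time))
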
Code example #4
    def _wait_for_processes(self, wait_processes, kill_processes,
                            timeout_secs):
        """Waits until all `wait_processes` finish, then kills `kill_processes`.

        Fails an assert if a process in `wait_processes` finishes unsuccessfully.
        The processes in `kill_processes` are assumed to never finish so they are
        killed.

        Args:
          wait_processes: A list of _ProcessInfo tuples. This function will wait
            for each to finish.
          kill_processes: A list of _ProcessInfo tuples. Each will be killed once
            every process in `wait_processes` is finished.
          timeout_secs: Seconds to wait before timing out and terminating
            processes.

        Returns:
          A list of strings, one per wait process, containing that process's
          stderr output.

        Raises:
          Exception: When waiting for tasks to finish times out.
        """

        timer = _CountDownTimer(timeout_secs)
        wait_process_stderrs = [None] * len(wait_processes)
        finished_wait_processes = set()
        while len(finished_wait_processes) < len(wait_processes):
            if timer.secs_remaining() == 0:
                logging.error(
                    "Timed out! Outputting logs of unfinished processes:")
                for i, wait_process in enumerate(wait_processes):
                    if i in finished_wait_processes:
                        continue
                    wait_process.stderr.seek(0)
                    wait_process_stderrs[i] = wait_process.stderr.read()
                    logging.info(
                        "stderr for incomplete %s (last %d chars): %s\n",
                        wait_process.name, MAX_OUTPUT_CHARS,
                        wait_process_stderrs[i][-MAX_OUTPUT_CHARS:])
                raise Exception("Timed out waiting for tasks to complete.")
            for i, wait_process in enumerate(wait_processes):
                if i in finished_wait_processes:
                    continue
                ret_code = wait_process.popen.poll()
                if ret_code is None:
                    continue
                logging.info("%s finished", wait_process.name)
                wait_process.stderr.seek(0)
                wait_process_stderrs[i] = wait_process.stderr.read()
                logging.info("stderr for %s (last %d chars): %s\n",
                             wait_process.name, MAX_OUTPUT_CHARS,
                             wait_process_stderrs[i][-MAX_OUTPUT_CHARS:])
                self.assertEqual(0, ret_code)
                finished_wait_processes.add(i)
            for kill_process in kill_processes:
                ret_code = kill_process.popen.poll()
                # Kill processes should not end until we kill them;
                # fail immediately if one returned early.
                self.assertIsNone(ret_code)
            # Delay between polling loops.
            time.sleep(0.25)
        logging.info("All wait processes finished")
        for i, kill_process in enumerate(kill_processes):
            # Kill each kill process.
            kill_process.popen.kill()
            kill_process.popen.wait()
            kill_process.stderr.seek(0)
            logging.info("stderr for %s (last %d chars): %s\n",
                         kill_process.name, MAX_OUTPUT_CHARS,
                         kill_process.stderr.read()[-MAX_OUTPUT_CHARS:])
        return wait_process_stderrs
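
Both versions of `_wait_for_processes` operate on `_ProcessInfo` tuples that are defined elsewhere in the test module. Judging only by the attributes the code reads (`name`, `popen`, `stderr`), a plausible sketch is a namedtuple whose `stderr` field is a seekable file object; the field order, the `_start_process` helper, and its arguments are hypothetical.

import collections
import subprocess
import tempfile

# Field names inferred from the attributes accessed above; the real record may
# carry more information.
_ProcessInfo = collections.namedtuple("_ProcessInfo",
                                      ["name", "popen", "stderr"])


def _start_process(name, args):
  # stderr must be a seekable file object because the code above calls
  # stderr.seek(0) followed by stderr.read(); a temporary file satisfies that.
  stderr_file = tempfile.TemporaryFile(mode="w+")
  popen = subprocess.Popen(args, stderr=stderr_file)
  return _ProcessInfo(name=name, popen=popen, stderr=stderr_file)
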
Code example #5
  def _wait_for_processes(self, wait_processes, kill_processes, timeout_secs):
    """Waits until all `wait_processes` finish, then kills `kill_processes`.

    Fails an assert if a process in `wait_processes` finishes unsuccessfully.
    The processes in `kill_processes` are assumed to never finish so they are
    killed.

    Args:
      wait_processes: A list of _ProcessInfo tuples. This function will wait for
        each to finish.
      kill_processes: A list of _ProcessInfo tuples. Each will be killed once
        every process in `wait_processes` is finished.
      timeout_secs: Seconds to wait before timing out and terminating processes.

    Raises:
      Exception: When waiting for tasks to finish times out.
    """

    timer = _CountDownTimer(timeout_secs)
    finished_wait_processes = set()
    poll_count = {wait_process: 0.0 for wait_process in wait_processes}

    while len(finished_wait_processes) < len(wait_processes):
      if timer.secs_remaining() == 0:
        logging.error("Timed out! Outputting logs of unfinished processes:")
        for i, wait_process in enumerate(wait_processes):
          if i in finished_wait_processes:
            continue
          log_all(wait_process, "incomplete")
        raise Exception("Timed out waiting for tasks to complete.")
      for i, wait_process in enumerate(wait_processes):
        if i in finished_wait_processes:
          continue
        ret_code = wait_process.popen.poll()
        if ret_code is None:
          poll_count[wait_process] += 0.25
          # Log progress roughly every 10 seconds of waiting.
          if poll_count[wait_process] % 10. == 0:
            logging.info("%d secs have elapsed for %s",
                         poll_count[wait_process], wait_process.name)
          continue
        logging.info("%s finished", wait_process.name)
        log_all(wait_process, "completed")
        self.assertEqual(0, ret_code)
        finished_wait_processes.add(i)
      for kill_process in kill_processes:
        ret_code = kill_process.popen.poll()
        # Kill processes should not end until we kill them.
        # If it returns early, note the return code.
        if ret_code is not None:
          logging.error("kill process %s ended with ret_code %d",
                        kill_process.name, ret_code)
          log_all(kill_process, "ended with code {}".format(ret_code))
          self.assertIsNone(ret_code)
      # Delay between polling loops.
      time.sleep(0.25)
    logging.info("All wait processes finished")
    for i, kill_process in enumerate(kill_processes):
      # Kill each kill process.
      kill_process.popen.kill()
      kill_process.popen.wait()
      log_all(kill_process, "killed")
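
This second version delegates stderr reporting to a `log_all` helper that is not shown in the excerpt. Going by the logging pattern in the first version (seek to the start, read, log the last MAX_OUTPUT_CHARS characters), a rough sketch could be the following; the signature and the placeholder value of MAX_OUTPUT_CHARS are assumptions.

import logging

# Placeholder value; the actual constant is defined elsewhere in the module.
MAX_OUTPUT_CHARS = 10000


def log_all(process_info, status):
  """Logs the tail of a process's stderr, e.g. log_all(wait_process, "completed")."""
  process_info.stderr.seek(0)
  stderr_text = process_info.stderr.read()
  logging.info("stderr for %s %s (last %d chars): %s\n", process_info.name,
               status, MAX_OUTPUT_CHARS, stderr_text[-MAX_OUTPUT_CHARS:])
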
Code example #6
  def train(self,
            input_fn,
            hooks=None,
            steps=None,
            max_steps=None,
            saving_listeners=None):
    """See `tf.estimator.Estimator` train."""

    if (steps is not None) and (max_steps is not None):
      raise ValueError("Can not provide both steps and max_steps.")
    if steps is not None and steps <= 0:
      raise ValueError("Must specify steps > 0, given: {}".format(steps))

    if steps is not None:
      max_steps = self._latest_checkpoint_global_step() + steps

    # Each iteration of this AdaNet loop represents an `_Iteration`. The
    # current iteration number is stored as a variable in the checkpoint so
    # that training can be stopped and started at anytime.
    with self._train_loop_context():
      while True:
        current_iteration = self._latest_checkpoint_iteration_number()
        tf.logging.info("Beginning training AdaNet iteration %s",
                        current_iteration)
        self._iteration_ended = False
        result = super(Estimator, self).train(
            input_fn=input_fn,
            hooks=hooks,
            max_steps=max_steps,
            saving_listeners=saving_listeners)

        # If training ended because the maximum number of training steps was
        # reached, exit training.
        if self._latest_checkpoint_global_step() >= max_steps:
          return result

        # If training ended for any reason other than the iteration ending,
        # exit training.
        if not self._iteration_ended:
          return result

        # The chief prepares the next AdaNet iteration, and increments the
        # iteration number by 1.
        if self.config.is_chief:
          # As the chief, store the train hooks and make a placeholder input_fn
          # in order to use them when preparing the next iteration.
          self._train_hooks = hooks
          self._placeholder_input_fn = make_placeholder_input_fn(input_fn)
          self._prepare_next_iteration()

        # This inner loop serves mainly for synchronizing the workers with the
        # chief during distributed training. Workers that finish training early
        # wait for the chief to prepare the next iteration and increment the
        # iteration number. Workers that are slow to finish training move on to
        # the next iteration right away. And workers that go offline and come
        # back online after training ended terminate gracefully.
        wait_for_chief = not self.config.is_chief
        timer = _CountDownTimer(self._worker_wait_timeout_secs)
        while wait_for_chief:
          # If the chief hits max_steps, it will stop training itself and not
          # increment the iteration number, so this is how the worker knows to
          # exit if it wakes up and the chief is gone.
          # TODO: Support steps parameter.
          if self._latest_checkpoint_global_step() >= max_steps:
            return result

          # In distributed training, a worker may end training before the chief
          # overwrites the checkpoint with the incremented iteration number. If
          # that is the case, it should wait for the chief to do so. Otherwise
          # the worker will get stuck waiting for its weights to be initialized.
          next_iteration = self._latest_checkpoint_iteration_number()
          if next_iteration > current_iteration:
            break

          # Check timeout when waiting for potentially downed chief.
          if timer.secs_remaining() == 0:
            tf.logging.error(
                "Chief job did not prepare next iteration after %s secs. It "
                "may have been preempted, been turned down, or crashed. This "
                "worker is now exiting training.",
                self._worker_wait_timeout_secs)
            return result
          tf.logging.info("Waiting for chief to finish")
          time.sleep(5)

        tf.logging.info("Finished training Adanet iteration %s",
                        current_iteration)

        # Stagger starting workers to prevent training instability.
        if not self.config.is_chief:
          task_id = self.config.task_id or 0
          # Wait 5 secs more for each new worker up to 60 secs.
          delay_secs = min(60, task_id * 5)
          tf.logging.info("Waiting %d secs before starting training.",
                          delay_secs)
          time.sleep(delay_secs)