def test_secs_remaining_zero(self):
  timer = _CountDownTimer(0.)
  time.sleep(.01)
  secs_remaining = timer.secs_remaining()
  self.assertEqual(0., secs_remaining)

def test_secs_remaining_short(self):
  timer = _CountDownTimer(.001)
  time.sleep(.1)
  secs_remaining = timer.secs_remaining()
  self.assertEqual(0., secs_remaining)

def test_secs_remaining_long(self):
  timer = _CountDownTimer(60)
  time.sleep(.1)
  secs_remaining = timer.secs_remaining()
  self.assertLess(0., secs_remaining)
  self.assertGreater(60., secs_remaining)
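# The tests above exercise a small countdown helper. A minimal sketch of what
# `_CountDownTimer` could look like, assuming it records a start time and
# clamps the remaining duration at zero; this is an illustration, not
# necessarily the library's exact implementation.
class _CountDownTimer(object):

  def __init__(self, duration_secs):
    self._start_time = time.time()
    self._duration_secs = duration_secs

  def secs_remaining(self):
    """Returns the seconds left on the timer, never going below zero."""
    diff = self._duration_secs - (time.time() - self._start_time)
    return max(0., diff)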
def _wait_for_processes(self, wait_processes, kill_processes, timeout_secs):
  """Waits until all `wait_processes` finish, then kills `kill_processes`.

  Fails an assert if a process in `wait_processes` finishes unsuccessfully.
  The processes in `kill_processes` are assumed to never finish, so they are
  killed.

  Args:
    wait_processes: A list of _ProcessInfo tuples. This function will wait
      for each to finish.
    kill_processes: A list of _ProcessInfo tuples. Each will be killed once
      every process in `wait_processes` is finished.
    timeout_secs: Seconds to wait before timing out and terminating
      processes.

  Returns:
    A list of strings, one per wait process, containing that process's
    stderr output.

  Raises:
    Exception: When waiting for tasks to finish times out.
  """

  timer = _CountDownTimer(timeout_secs)
  wait_process_stderrs = [None] * len(wait_processes)
  finished_wait_processes = set()
  while len(finished_wait_processes) < len(wait_processes):
    if timer.secs_remaining() == 0:
      logging.error("Timed out! Outputting logs of unfinished processes:")
      for i, wait_process in enumerate(wait_processes):
        if i in finished_wait_processes:
          continue
        wait_process.stderr.seek(0)
        wait_process_stderrs[i] = wait_process.stderr.read()
        logging.info("stderr for incomplete %s (last %d chars): %s\n",
                     wait_process.name, MAX_OUTPUT_CHARS,
                     wait_process_stderrs[i][-MAX_OUTPUT_CHARS:])
      raise Exception("Timed out waiting for tasks to complete.")
    for i, wait_process in enumerate(wait_processes):
      if i in finished_wait_processes:
        continue
      ret_code = wait_process.popen.poll()
      if ret_code is None:
        continue
      logging.info("%s finished", wait_process.name)
      wait_process.stderr.seek(0)
      wait_process_stderrs[i] = wait_process.stderr.read()
      logging.info("stderr for %s (last %d chars): %s\n", wait_process.name,
                   MAX_OUTPUT_CHARS,
                   wait_process_stderrs[i][-MAX_OUTPUT_CHARS:])
      self.assertEqual(0, ret_code)
      finished_wait_processes.add(i)
    for kill_process in kill_processes:
      ret_code = kill_process.popen.poll()
      # Kill processes should not end until we kill them. If one returns
      # early, the assertion below surfaces the failure.
      self.assertIsNone(ret_code)
    # Delay between polling loops.
    time.sleep(0.25)
  logging.info("All wait processes finished")
  for i, kill_process in enumerate(kill_processes):
    # Kill each kill process.
    kill_process.popen.kill()
    kill_process.popen.wait()
    kill_process.stderr.seek(0)
    logging.info("stderr for %s (last %d chars): %s\n", kill_process.name,
                 MAX_OUTPUT_CHARS,
                 kill_process.stderr.read()[-MAX_OUTPUT_CHARS:])
  return wait_process_stderrs
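# `_wait_for_processes` assumes each entry is a record exposing `name`,
# `popen`, and `stderr`, plus a module-level `MAX_OUTPUT_CHARS` cap. A hedged
# sketch of those pieces; the constant's value and the `_start_process`
# helper are illustrative assumptions, not the test suite's actual code.
import collections
import subprocess
import tempfile

MAX_OUTPUT_CHARS = 15000  # Assumed cap on how much stderr to echo into logs.

_ProcessInfo = collections.namedtuple("_ProcessInfo",
                                      ["name", "popen", "stderr"])


def _start_process(name, args):
  """Launches `args` as a subprocess with stderr captured to a temp file."""
  stderr = tempfile.TemporaryFile(mode="w+")
  popen = subprocess.Popen(args, stderr=stderr)
  return _ProcessInfo(name=name, popen=popen, stderr=stderr)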
def _wait_for_processes(self, wait_processes, kill_processes, timeout_secs):
  """Waits until all `wait_processes` finish, then kills `kill_processes`.

  Fails an assert if a process in `wait_processes` finishes unsuccessfully.
  The processes in `kill_processes` are assumed to never finish, so they are
  killed.

  Args:
    wait_processes: A list of _ProcessInfo tuples. This function will wait
      for each to finish.
    kill_processes: A list of _ProcessInfo tuples. Each will be killed once
      every process in `wait_processes` is finished.
    timeout_secs: Seconds to wait before timing out and terminating
      processes.

  Raises:
    Exception: When waiting for tasks to finish times out.
  """

  timer = _CountDownTimer(timeout_secs)
  finished_wait_processes = set()
  poll_count = {wait_process: 0.0 for wait_process in wait_processes}
  while len(finished_wait_processes) < len(wait_processes):
    if timer.secs_remaining() == 0:
      logging.error("Timed out! Outputting logs of unfinished processes:")
      for i, wait_process in enumerate(wait_processes):
        if i in finished_wait_processes:
          continue
        log_all(wait_process, "incomplete")
      raise Exception("Timed out waiting for tasks to complete.")
    for i, wait_process in enumerate(wait_processes):
      if i in finished_wait_processes:
        continue
      ret_code = wait_process.popen.poll()
      if ret_code is None:
        poll_count[wait_process] += 0.25
        # Log progress roughly every 10 seconds of waiting.
        if ((poll_count[wait_process] / 10.) -
            int(poll_count[wait_process] / 10.)) == 0:
          logging.info("%d secs have elapsed for %s",
                       poll_count[wait_process], wait_process.name)
        continue
      logging.info("%s finished", wait_process.name)
      log_all(wait_process, "completed")
      self.assertEqual(0, ret_code)
      finished_wait_processes.add(i)
    for kill_process in kill_processes:
      ret_code = kill_process.popen.poll()
      # Kill processes should not end until we kill them. If one returns
      # early, log its return code before failing the assertion below.
      if ret_code is not None:
        logging.error("kill process %s ended with ret_code %d",
                      kill_process.name, ret_code)
        log_all(kill_process, "ended with code {}".format(ret_code))
      self.assertIsNone(ret_code)
    # Delay between polling loops.
    time.sleep(0.25)
  logging.info("All wait processes finished")
  for i, kill_process in enumerate(kill_processes):
    # Kill each kill process.
    kill_process.popen.kill()
    kill_process.popen.wait()
    log_all(kill_process, "killed")
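# The variant above delegates logging to a `log_all` helper. A minimal sketch,
# assuming it rewinds the process's stderr and echoes its tail alongside a
# status word; the real helper may differ.
def log_all(process_info, status):
  """Logs the tail of a process's stderr together with its status."""
  process_info.stderr.seek(0)
  logging.info("stderr for %s %s (last %d chars): %s\n", process_info.name,
               status, MAX_OUTPUT_CHARS,
               process_info.stderr.read()[-MAX_OUTPUT_CHARS:])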
def train(self,
          input_fn,
          hooks=None,
          steps=None,
          max_steps=None,
          saving_listeners=None):
  """See `tf.estimator.Estimator` train."""

  if (steps is not None) and (max_steps is not None):
    raise ValueError("Can not provide both steps and max_steps.")
  if steps is not None and steps <= 0:
    raise ValueError("Must specify steps > 0, given: {}".format(steps))
  if steps is not None:
    max_steps = self._latest_checkpoint_global_step() + steps

  # Each iteration of this AdaNet loop represents an `_Iteration`. The
  # current iteration number is stored as a variable in the checkpoint so
  # that training can be stopped and started at any time.
  with self._train_loop_context():
    while True:
      current_iteration = self._latest_checkpoint_iteration_number()
      tf.logging.info("Beginning training AdaNet iteration %s",
                      current_iteration)
      self._iteration_ended = False
      result = super(Estimator, self).train(
          input_fn=input_fn,
          hooks=hooks,
          max_steps=max_steps,
          saving_listeners=saving_listeners)

      # If training ended because the maximum number of training steps was
      # reached, exit training.
      if self._latest_checkpoint_global_step() >= max_steps:
        return result

      # If training ended for any reason other than the iteration ending,
      # exit training.
      if not self._iteration_ended:
        return result

      # The chief prepares the next AdaNet iteration, and increments the
      # iteration number by 1.
      if self.config.is_chief:
        # As the chief, store the train hooks and make a placeholder input_fn
        # in order to use them when preparing the next iteration.
        self._train_hooks = hooks
        self._placeholder_input_fn = make_placeholder_input_fn(input_fn)
        self._prepare_next_iteration()

      # This inner loop serves mainly for synchronizing the workers with the
      # chief during distributed training. Workers that finish training early
      # wait for the chief to prepare the next iteration and increment the
      # iteration number. Workers that are slow to finish training quickly
      # move onto the next iteration. And workers that go offline and return
      # online after training ended terminate gracefully.
      wait_for_chief = not self.config.is_chief
      timer = _CountDownTimer(self._worker_wait_timeout_secs)
      while wait_for_chief:
        # If the chief hits max_steps, it will stop training itself and not
        # increment the iteration number, so this is how the worker knows to
        # exit if it wakes up and the chief is gone.
        # TODO: Support steps parameter.
        if self._latest_checkpoint_global_step() >= max_steps:
          return result

        # In distributed training, a worker may end training before the chief
        # overwrites the checkpoint with the incremented iteration number. If
        # that is the case, it should wait for the chief to do so. Otherwise
        # the worker will get stuck waiting for its weights to be initialized.
        next_iteration = self._latest_checkpoint_iteration_number()
        if next_iteration > current_iteration:
          break

        # Check timeout when waiting for a potentially downed chief.
        if timer.secs_remaining() == 0:
          tf.logging.error(
              "Chief job did not prepare next iteration after %s secs. It "
              "may have been preempted, been turned down, or crashed. This "
              "worker is now exiting training.",
              self._worker_wait_timeout_secs)
          return result
        tf.logging.info("Waiting for chief to finish")
        time.sleep(5)

      tf.logging.info("Finished training AdaNet iteration %s",
                      current_iteration)

      # Stagger starting workers to prevent training instability.
      if not self.config.is_chief:
        task_id = self.config.task_id or 0
        # Wait 5 secs more for each new worker up to 60 secs.
        delay_secs = min(60, task_id * 5)
        tf.logging.info("Waiting %d secs before starting training.",
                        delay_secs)
        time.sleep(delay_secs)
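# A hedged usage sketch of the training loop above. `estimator` stands in for
# an already-constructed `adanet.Estimator` and `my_input_fn` for a real input
# pipeline; both are placeholders, not defined here. The point is the call
# pattern: `steps` is converted into an absolute `max_steps` relative to the
# latest checkpoint, and passing both raises a ValueError.
estimator.train(input_fn=my_input_fn, steps=100)  # 100 steps past the latest checkpoint.
estimator.train(input_fn=my_input_fn, max_steps=1000)  # Train until global step 1000.
# estimator.train(input_fn=my_input_fn, steps=100, max_steps=1000)  # Raises ValueError.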