def train(self, delay_secs=None):
    """Fit the estimator using the training data.

    Train the estimator for `self._train_steps` steps, after waiting for
    `delay_secs` seconds. If `self._train_steps` is `None`, train forever.

    The delay is measured from the moment this method is entered, so time
    already spent starting the server counts towards it.

    Args:
      delay_secs: Start training after this many seconds. When `None`, the
        delay is derived from the worker's `task_id`: either a
        `GlobalStepWaiterHook` is added (if
        `self._delay_workers_by_global_step` is set) or a wall-clock delay of
        5 seconds per task id, capped at 60 seconds, is used.

    Returns:
      The trained estimator.
    """
    start = time.time()

    # Start the server, if needed. It's important to start the server before
    # we (optionally) sleep for the case where no device_filters are set.
    # Otherwise, the servers will wait to connect to each other before starting
    # to train. We might as well start as soon as we can.
    config = self._estimator.config
    if (config.environment != run_config.Environment.LOCAL and
            config.environment != run_config.Environment.GOOGLE and
            config.cluster_spec and config.master):
        self._start_server()

    extra_hooks = []
    if delay_secs is None:
        task_id = self._estimator.config.task_id or 0
        if self._delay_workers_by_global_step:
            # Wait roughly 5500 global steps (8000 * ln 2) for the second
            # worker. Each worker waits more than the previous one, but with
            # a diminishing number of additional steps.
            extra_hooks.append(
                basic_session_run_hooks.GlobalStepWaiterHook(
                    int(8000.0 * math.log(task_id + 1))))
            delay_secs = 0
        else:
            # Wait 5 secs more for each new worker up to 60 secs.
            delay_secs = min(60, task_id * 5)

    if delay_secs > 0:
        elapsed_secs = time.time() - start
        remaining = delay_secs - elapsed_secs
        # Sleep only for the part of the delay that has not already elapsed;
        # a non-positive remainder means the delay is already satisfied (and
        # time.sleep rejects negative values).
        if remaining > 0:
            logging.info("Waiting %d secs before starting training.",
                         remaining)
            time.sleep(remaining)

    return self._call_train(input_fn=self._train_input_fn,
                            max_steps=self._train_steps,
                            hooks=self._train_monitors + extra_hooks)
def test_wait_for_step(self):
    """Check that `GlobalStepWaiterHook.before_run` blocks until the target.

    The hook is built with `wait_until_step=1000` and its `before_run` is
    executed on a daemon thread. The thread must still be alive while the
    global step is at its initial value and after it is set to 500, and must
    have finished once the global step has been set to 1100.
    """
    with ops.Graph().as_default():
        global_step = variables.get_or_create_global_step()
        waiter_hook = basic_session_run_hooks.GlobalStepWaiterHook(
            wait_until_step=1000)
        waiter_hook.begin()

        with session_lib.Session() as session:
            session.run(variables_lib.global_variables_initializer())

            # Run before_run on a background thread so the test can observe
            # whether it is still blocked.
            run_context = session_run_hook.SessionRunContext(
                original_args=None, session=session)
            blocked_thread = threading.Thread(
                target=waiter_hook.before_run, args=(run_context,))
            blocked_thread.daemon = True
            blocked_thread.start()

            # Global step still at its initial value: must keep waiting.
            time.sleep(1.0)
            self.assertTrue(blocked_thread.is_alive())

            # Below the threshold: must keep waiting.
            session.run(state_ops.assign(global_step, 500))
            time.sleep(1.0)
            self.assertTrue(blocked_thread.is_alive())

            # Past the threshold: the waiter thread is expected to finish.
            session.run(state_ops.assign(global_step, 1100))
            time.sleep(1.2)
            self.assertFalse(blocked_thread.is_alive())
def train(self, delay_secs=None):
    """Fit the estimator using the training data.

    Train the estimator for `self._train_steps` steps, after waiting for
    `delay_secs` seconds. If `self._train_steps` is `None`, train forever.

    The delay is measured from the moment this method is entered, so time
    already spent validating the config and starting the server counts
    towards it.

    Args:
      delay_secs: Start training after this many seconds. When `None`, the
        delay is derived from the worker's `task_id`: either a
        `GlobalStepWaiterHook` is added (if
        `self._delay_workers_by_global_step` is set) or a wall-clock delay of
        5 seconds per task id, capped at 60 seconds, is used.

    Returns:
      The trained estimator.

    Raises:
      ValueError: If the estimator's config is not a `run_config.RunConfig`
        but has both `cluster_spec` and `master` set.
    """
    start = time.time()

    # Start the server, if needed. It's important to start the server before
    # we (optionally) sleep for the case where no device_filters are set.
    # Otherwise, the servers will wait to connect to each other before starting
    # to train. We might as well start as soon as we can.
    config = self._estimator.config
    if isinstance(config, run_config.RunConfig):
        if (config.cluster_spec and config.master and
                config.environment == run_config.Environment.LOCAL):
            logging.warn(
                "ClusterSpec and master are provided, but environment is "
                "set to 'local'. Set environment to 'cloud' if you intend "
                "to use the distributed runtime.")
        if (config.environment != run_config.Environment.LOCAL and
                config.environment != run_config.Environment.GOOGLE and
                config.cluster_spec and config.master):
            self._start_server()
    elif config.cluster_spec and config.master:
        # The two literals are concatenated, so the first needs a trailing
        # space to keep the words apart.
        raise ValueError(
            'For distributed runtime, Experiment class only works with '
            'tf.contrib.learn.RunConfig for now, but provided {}'.format(
                type(config)))

    extra_hooks = []
    if delay_secs is None:
        task_id = self._estimator.config.task_id or 0
        if self._delay_workers_by_global_step:
            # Wait roughly 5500 global steps (8000 * ln 2) for the second
            # worker. Each worker waits more than the previous one, but with
            # a diminishing number of additional steps.
            extra_hooks.append(
                basic_session_run_hooks.GlobalStepWaiterHook(
                    int(8000.0 * math.log(task_id + 1))))
            delay_secs = 0
        else:
            # Wait 5 secs more for each new worker up to 60 secs.
            delay_secs = min(60, task_id * 5)

    if delay_secs > 0:
        elapsed_secs = time.time() - start
        remaining = delay_secs - elapsed_secs
        # Sleep only for the part of the delay that has not already elapsed;
        # a non-positive remainder means the delay is already satisfied (and
        # time.sleep rejects negative values).
        if remaining > 0:
            logging.info("Waiting %d secs before starting training.",
                         remaining)
            time.sleep(remaining)

    return self._call_train(input_fn=self._train_input_fn,
                            max_steps=self._train_steps,
                            hooks=self._train_monitors + extra_hooks,
                            saving_listeners=self._saving_listeners)