Code example #1
File: experiment.py  Project: zzr199471/tensorflow
  def train(self, delay_secs=None):
    """Fit the estimator using the training data.

    Train the estimator for `self._train_steps` steps, after waiting for
    `delay_secs` seconds. If `self._train_steps` is `None`, train forever.

    Args:
      delay_secs: Start training after this many seconds.

    Returns:
      The trained estimator.
    """
    start = time.time()

    # Start the server, if needed. It's important to start the server before
    # we (optionally) sleep for the case where no device_filters are set.
    # Otherwise, the servers will wait to connect to each other before starting
    # to train. We might as well start as soon as we can.
    config = self._estimator.config
    if (config.environment != run_config.Environment.LOCAL and
        config.environment != run_config.Environment.GOOGLE and
        config.cluster_spec and config.master):
      self._start_server()

    extra_hooks = []
    if delay_secs is None:
      task_id = self._estimator.config.task_id or 0
      if self._delay_workers_by_global_step:
        # Wait 5500 global steps for the second worker. Each worker waits more
        # than the previous one, but the number of extra steps diminishes.
        extra_hooks.append(
            basic_session_run_hooks.GlobalStepWaiterHook(
                int(8000.0 * math.log(task_id + 1))))
        delay_secs = 0
      else:
        # Wait 5 secs more for each new worker up to 60 secs.
        delay_secs = min(60, task_id * 5)

    if delay_secs > 0:
      elapsed_secs = time.time() - start
      remaining = delay_secs - elapsed_secs
      logging.info("Waiting %d secs before starting training.", remaining)
      time.sleep(delay_secs)

    return self._call_train(input_fn=self._train_input_fn,
                            max_steps=self._train_steps,
                            hooks=self._train_monitors + extra_hooks)
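
The snippet above staggers worker start-up in one of two ways: either by sleeping for `min(60, task_id * 5)` seconds, or, when `_delay_workers_by_global_step` is set, by attaching a `GlobalStepWaiterHook` that blocks worker `task_id` until the global step reaches `int(8000.0 * math.log(task_id + 1))`. The small standalone sketch below (an illustration only, not taken from the project) makes the two schedules concrete for the first few task ids.

# Illustration only: evaluate both staggering schedules for a few workers.
import math

for task_id in range(5):
    secs = min(60, task_id * 5)                   # time-based delay
    steps = int(8000.0 * math.log(task_id + 1))   # global-step-based delay
    print(task_id, secs, steps)

# task_id  secs  steps
#   0        0       0
#   1        5    5545   (the "5500 steps" mentioned in the comment above)
#   2       10    8788
#   3       15   11090
#   4       20   12875
# Each worker waits for more steps than the previous one, but the increment
# shrinks (5545, 3243, 2302, 1785, ...), i.e. the gap diminishes.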
Code example #2
  def test_wait_for_step(self):
    with ops.Graph().as_default():
      gstep = variables.get_or_create_global_step()
      hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000)
      hook.begin()
      with session_lib.Session() as sess:
        sess.run(variables_lib.global_variables_initializer())
        waiter = threading.Thread(
            target=hook.before_run,
            args=(session_run_hook.SessionRunContext(
                original_args=None, session=sess),))
        waiter.daemon = True
        waiter.start()
        time.sleep(1.0)
        self.assertTrue(waiter.is_alive())
        sess.run(state_ops.assign(gstep, 500))
        time.sleep(1.0)
        self.assertTrue(waiter.is_alive())
        sess.run(state_ops.assign(gstep, 1100))
        time.sleep(1.2)
        self.assertFalse(waiter.is_alive())
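
The test above drives `GlobalStepWaiterHook` directly: `before_run` is started on a background thread and stays blocked until the global step variable crosses `wait_until_step`. In ordinary use the hook is simply placed in a session's hook list. The sketch below is a hypothetical, minimal example assuming the TF 1.x `tf.train` API; `is_chief`, `master`, and `train_op` are placeholders and not part of the snippet above. On a real cluster only the non-chief workers would carry the hook, so they idle until the chief has advanced the shared global step.

import tensorflow as tf  # assumes TF 1.x

is_chief = True   # placeholder: the chief worker advances the global step
master = ""       # placeholder: empty string runs an in-process master

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)  # stand-in for a real training op

# Non-chief workers get the waiter hook and block in before_run until the
# chief has pushed the shared global step past 1000.
hooks = [] if is_chief else [tf.train.GlobalStepWaiterHook(wait_until_step=1000)]

with tf.train.MonitoredTrainingSession(master=master,
                                       is_chief=is_chief,
                                       hooks=hooks) as sess:
    for _ in range(5):
        sess.run(train_op)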
Code example #3
    def train(self, delay_secs=None):
        """Fit the estimator using the training data.

        Train the estimator for `self._train_steps` steps, after waiting for
        `delay_secs` seconds. If `self._train_steps` is `None`, train forever.

        Args:
          delay_secs: Start training after this many seconds.

        Returns:
          The trained estimator.
        """
        start = time.time()

        # Start the server, if needed. It's important to start the server before
        # we (optionally) sleep for the case where no device_filters are set.
        # Otherwise, the servers will wait to connect to each other before starting
        # to train. We might as well start as soon as we can.
        config = self._estimator.config
        if isinstance(config, run_config.RunConfig):
            if (config.cluster_spec and config.master
                    and config.environment == run_config.Environment.LOCAL):
                logging.warn(
                    "ClusterSpec and master are provided, but environment is "
                    "set to 'local'. Set environment to 'cloud' if you intend "
                    "to use the distributed runtime.")
            if (config.environment != run_config.Environment.LOCAL
                    and config.environment != run_config.Environment.GOOGLE
                    and config.cluster_spec and config.master):
                self._start_server()
        elif config.cluster_spec and config.master:
            raise ValueError(
                'For distributed runtime, Experiment class only works with '
                'tf.contrib.learn.RunConfig for now, but provided {}'.format(
                    type(config)))

        extra_hooks = []
        if delay_secs is None:
            task_id = self._estimator.config.task_id or 0
            if self._delay_workers_by_global_step:
                # Wait 5500 global steps for the second worker. Each worker
                # waits more than the previous one, but the number of extra
                # steps diminishes.
                extra_hooks.append(
                    basic_session_run_hooks.GlobalStepWaiterHook(
                        int(8000.0 * math.log(task_id + 1))))
                delay_secs = 0
            else:
                # Wait 5 secs more for each new worker up to 60 secs.
                delay_secs = min(60, task_id * 5)

        if delay_secs > 0:
            elapsed_secs = time.time() - start
            remaining = delay_secs - elapsed_secs
            logging.info("Waiting %d secs before starting training.",
                         remaining)
            time.sleep(delay_secs)

        return self._call_train(input_fn=self._train_input_fn,
                                max_steps=self._train_steps,
                                hooks=self._train_monitors + extra_hooks,
                                saving_listeners=self._saving_listeners)
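
For context, `train` is normally reached through an `Experiment` built around an estimator and input functions. The following is a hypothetical, minimal usage sketch assuming TF 1.x with `tf.contrib.learn` available; the toy input function and the `LinearRegressor` are placeholders, not part of the code above.

import tensorflow as tf  # assumes TF 1.x with tf.contrib.learn

def toy_input_fn():
    # Hypothetical toy data: y = 2x.
    features = {"x": tf.constant([[1.0], [2.0], [3.0], [4.0]])}
    labels = tf.constant([[2.0], [4.0], [6.0], [8.0]])
    return features, labels

estimator = tf.contrib.learn.LinearRegressor(
    feature_columns=[tf.contrib.layers.real_valued_column("x")])

experiment = tf.contrib.learn.Experiment(
    estimator=estimator,
    train_input_fn=toy_input_fn,
    eval_input_fn=toy_input_fn,
    train_steps=200)

# With delay_secs=0 no start-up delay is applied; otherwise the schedules
# shown in the snippets above decide how long this worker waits.
experiment.train(delay_secs=0)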