Example #1
0
    def train(self,
              input_fn,
              steps=None,
              hooks=None,
              max_steps=None,
              saving_listeners=None,
              save_best_ckpt=False):
        """Trains the model with data from `input_fn` and returns `self`.

        Args:
          input_fn: Input function, forwarded unchanged to `_train_model`.
          steps: Optional step count handed to `training.StopAtStepHook`. Must
            be > 0 when given and cannot be combined with `max_steps`.
          hooks: Optional hooks; validated by
            `estimator_lib._check_hooks_type` before use.
          max_steps: Optional step limit handed to `training.StopAtStepHook`.
            Must be > 0 when given and cannot be combined with `steps`. When
            the global step loaded from the checkpoint in `self._model_dir` is
            already >= `max_steps`, training is skipped.
          saving_listeners: Optional listeners; validated by
            `estimator_lib._check_listeners_type` before use.
          save_best_ckpt: Flag forwarded unchanged to `_train_model`.

        Returns:
          `self`, for chaining.

        Raises:
          ValueError: If both `steps` and `max_steps` are not `None`.
          ValueError: If either `steps` or `max_steps` is <= 0.
        """
        with context.graph_mode():
            if (steps is not None) and (max_steps is not None):
                raise ValueError('Can not provide both steps and max_steps.')
            if steps is not None and steps <= 0:
                raise ValueError(
                    'Must specify steps > 0, given: {}'.format(steps))
            if max_steps is not None and max_steps <= 0:
                raise ValueError(
                    'Must specify max_steps > 0, given: {}'.format(max_steps))

            if max_steps is not None:
                # Nothing to do when a previous run already reached the limit.
                start_step = _load_global_step_from_checkpoint_dir(
                    self._model_dir)
                if max_steps <= start_step:
                    logging.info(
                        'Skipping training since max_steps has already saved.')
                    return self

            hooks = estimator_lib._check_hooks_type(hooks)
            # StopAtStepHook refuses to be built when both limits are None, so
            # only install it when the caller actually asked for a limit.
            # Without this guard a plain `train(input_fn)` call would raise.
            if steps is not None or max_steps is not None:
                hooks.append(training.StopAtStepHook(steps, max_steps))

            saving_listeners = estimator_lib._check_listeners_type(
                saving_listeners)
            loss = self._train_model(input_fn, hooks, saving_listeners,
                                     save_best_ckpt)
            logging.info('Loss for final step: %s.', loss)
            return self
Example #2
0
    def fit(self, input_fn, hooks=None, steps=None, max_steps=None):
        """Trains a model given training data input_fn.

        Args:
          input_fn: Input function returning a tuple of:
              features - `Tensor` or dictionary of string feature name to
                `Tensor`.
              labels - `Tensor` or dictionary of `Tensor` with labels.
          hooks: Iterable of `SessionRunHook` subclass instances. Used for
            callbacks inside the training loop. The iterable is copied, so the
            caller's object is never modified.
          steps: Number of steps for which to train model. If `None`, train
            forever or train until input_fn generates the `OutOfRange` or
            `StopIteration` error. 'steps' works incrementally. If you call
            two times fit(steps=10) then training occurs in total 20 steps. If
            `OutOfRange` or `StopIteration` error occurs in the middle,
            training stops before 20 steps. If you don't want to have
            incremental behaviour please set `max_steps` instead. If set,
            `max_steps` must be `None`.
          max_steps: Number of total steps for which to train model. If `None`,
            train forever or train until input_fn generates the `OutOfRange`
            or `StopIteration` error. If set, `steps` must be `None`. If
            `OutOfRange` or `StopIteration` error occurs in the middle,
            training stops before `max_steps` steps.

            Two calls to `fit(steps=100)` means 200 training
            iterations. On the other hand, two calls to `fit(max_steps=100)`
            means that the second call will not do any iteration since first
            call did all 100 steps.

        Returns:
          `self`, for chaining.

        Raises:
          ValueError: If both `steps` and `max_steps` are not `None`.
          ValueError: If either `steps` or `max_steps` is <= 0.
        """
        if (steps is not None) and (max_steps is not None):
            raise ValueError('Can not provide both steps and max_steps.')
        # Zero is rejected as well, so the messages must say "> 0".
        if steps is not None and steps <= 0:
            raise ValueError(
                'Must specify steps > 0, given: {}'.format(steps))
        if max_steps is not None and max_steps <= 0:
            raise ValueError(
                'Must specify max_steps > 0, given: {}'.format(max_steps))

        if max_steps is not None:
            # Nothing to do when a previous run already reached the limit.
            start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
            if max_steps <= start_step:
                logging.info(
                    'Skipping training since max_steps has already saved.')
                return self

        # Copy into a real list so that appending below neither mutates the
        # caller's object nor fails when a tuple was passed.
        hooks = list(hooks) if hooks else []
        if steps is not None or max_steps is not None:
            hooks.append(training.StopAtStepHook(steps, max_steps))

        loss = self._train_model(input_fn=input_fn, hooks=hooks)
        logging.info('Loss for final step: %s.', loss)
        return self
Example #3
0
def make_stop_at_checkpoint_step_hook(estimator,
                                      last_step,
                                      wait_after_file_check_secs=30):
    """Picks the stop hook that matches the estimator's chief status.

    Args:
      estimator: Object exposing `config.is_chief` and `model_dir`.
      last_step: Step value forwarded as `last_step` to whichever hook is
        built.
      wait_after_file_check_secs: Forwarded to `_StopAtCheckpointStepHook`
        for the non-chief case; ignored for the chief. Defaults to 30.

    Returns:
      `training.StopAtStepHook` when the estimator is the chief, otherwise a
      `_StopAtCheckpointStepHook` bound to the estimator's model directory.
    """
    if not estimator.config.is_chief:
        # Non-chief workers get the checkpoint-based hook.
        checkpoint_hook_args = dict(
            model_dir=estimator.model_dir,
            last_step=last_step,
            wait_after_file_check_secs=wait_after_file_check_secs,
        )
        return _StopAtCheckpointStepHook(**checkpoint_hook_args)

    return training.StopAtStepHook(last_step=last_step)
Example #4
0
 def _convert_train_steps_to_hooks(self, steps, max_steps):
   if steps is not None or max_steps is not None:
     return [training.StopAtStepHook(steps, max_steps)]
   else:
     return []