def train(self, input_fn, steps=None, hooks=None, max_steps=None,
          saving_listeners=None, save_best_ckpt=False):
    """Trains a model given training data `input_fn`.

    Args:
      input_fn: Input function returning training features/labels.
      steps: Number of incremental steps for which to train. Must be > 0
        if set; if set, `max_steps` must be `None`.
      hooks: List of `SessionRunHook` subclass instances run inside the
        training loop.
      max_steps: Total number of steps for which to train. Must be > 0 if
        set; if set, `steps` must be `None`. Training is skipped when the
        checkpointed global step already reached `max_steps`.
      saving_listeners: List of checkpoint-saver listener objects
        (validated by `estimator_lib._check_listeners_type`).
      save_best_ckpt: Forwarded to `_train_model`; presumably enables
        saving the best-scoring checkpoint — confirm against _train_model.

    Returns:
      `self`, for chaining.

    Raises:
      ValueError: If both `steps` and `max_steps` are set, or if either
        is <= 0.
    """
    with context.graph_mode():
        if (steps is not None) and (max_steps is not None):
            raise ValueError('Can not provide both steps and max_steps.')
        if steps is not None and steps <= 0:
            raise ValueError(
                'Must specify steps > 0, given: {}'.format(steps))
        if max_steps is not None and max_steps <= 0:
            raise ValueError(
                'Must specify max_steps > 0, given: {}'.format(max_steps))

        if max_steps is not None:
            start_step = _load_global_step_from_checkpoint_dir(
                self._model_dir)
            if max_steps <= start_step:
                logging.info(
                    'Skipping training since max_steps has already saved.')
                return self

        hooks = estimator_lib._check_hooks_type(hooks)
        # BUG FIX: only add the stop hook when a stopping step was given.
        # StopAtStepHook raises ValueError when both arguments are None,
        # which made the "train forever" call train(input_fn) crash. This
        # now matches fit() and _convert_train_steps_to_hooks().
        if steps is not None or max_steps is not None:
            hooks.append(training.StopAtStepHook(steps, max_steps))
        saving_listeners = estimator_lib._check_listeners_type(
            saving_listeners)

        loss = self._train_model(input_fn, hooks, saving_listeners,
                                 save_best_ckpt)
        logging.info('Loss for final step: %s.', loss)
        return self
def fit(self, input_fn, hooks=None, steps=None, max_steps=None):
    """Trains a model given training data input_fn.

    Args:
      input_fn: Input function returning a tuple of:
          features - `Tensor` or dictionary of string feature name to
            `Tensor`.
          labels - `Tensor` or dictionary of `Tensor` with labels.
      hooks: List of `SessionRunHook` subclass instances. Used for
        callbacks inside the training loop.
      steps: Number of steps for which to train model. If `None`, train
        forever or train until input_fn generates the `OutOfRange` or
        `StopIteration` error. 'steps' works incrementally. If you call
        two times fit(steps=10) then training occurs in total 20 steps.
        If `OutOfRange` or `StopIteration` error occurs in the middle,
        training stops before 20 steps. If you don't want to have
        incremental behaviour please set `max_steps` instead. If set,
        `max_steps` must be `None`.
      max_steps: Number of total steps for which to train model. If
        `None`, train forever or train until input_fn generates the
        `OutOfRange` or `StopIteration` error. If set, `steps` must be
        `None`. If `OutOfRange` or `StopIteration` error occurs in the
        middle, training stops before `max_steps` steps. Two calls to
        `fit(steps=100)` means 200 training iterations. On the other
        hand, two calls to `fit(max_steps=100)` means that the second
        call will not do any iteration since first call did all 100
        steps.

    Returns:
      `self`, for chaining.

    Raises:
      ValueError: If both `steps` and `max_steps` are not `None`.
      ValueError: If either `steps` or `max_steps` is <= 0.
    """
    if (steps is not None) and (max_steps is not None):
        raise ValueError('Can not provide both steps and max_steps.')
    # BUG FIX: the messages previously said ">= 0" although the checks
    # reject values <= 0; the actual requirement is "> 0". Now consistent
    # with the docstring and with train().
    if steps is not None and steps <= 0:
        raise ValueError(
            'Must specify steps > 0, given: {}'.format(steps))
    if max_steps is not None and max_steps <= 0:
        raise ValueError(
            'Must specify max_steps > 0, given: {}'.format(max_steps))

    if max_steps is not None:
        start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
        if max_steps <= start_step:
            logging.info(
                'Skipping training since max_steps has already saved.')
            return self

    # Copy so the caller's hook list is never mutated.
    hooks = hooks[:] if hooks else []
    if steps is not None or max_steps is not None:
        hooks.append(training.StopAtStepHook(steps, max_steps))

    loss = self._train_model(input_fn=input_fn, hooks=hooks)
    logging.info('Loss for final step: %s.', loss)
    return self
def make_stop_at_checkpoint_step_hook(estimator,
                                      last_step,
                                      wait_after_file_check_secs=30):
    """Creates a proper StopAtCheckpointStepHook based on chief status."""
    if not estimator.config.is_chief:
        # Non-chief workers get the checkpoint-based variant; presumably it
        # waits on checkpoint files written by the chief — confirm against
        # _StopAtCheckpointStepHook.
        return _StopAtCheckpointStepHook(
            model_dir=estimator.model_dir,
            last_step=last_step,
            wait_after_file_check_secs=wait_after_file_check_secs)
    # The chief stops directly on the global step.
    return training.StopAtStepHook(last_step=last_step)
def _convert_train_steps_to_hooks(self, steps, max_steps): if steps is not None or max_steps is not None: return [training.StopAtStepHook(steps, max_steps)] else: return []