Example #1
    def _train_model(self,
                     input_fn,
                     steps,
                     feed_fn=None,
                     init_op=None,
                     init_feed_fn=None,
                     init_fn=None,
                     device_fn=None,
                     monitors=None,
                     log_every_steps=100,
                     fail_on_nan_loss=True,
                     max_steps=None):
        # TODO(wicke): Remove this once Model and associated code are gone.
        if hasattr(self._config, 'execution_mode'):
            if self._config.execution_mode not in ('all', 'train'):
                return

            # Stagger startup of worker sessions based on task id.
            sleep_secs = min(
                self._config.training_worker_max_startup_secs,
                self._config.task *
                self._config.training_worker_session_startup_stagger_secs)
            if sleep_secs:
                logging.info('Waiting %d secs before starting task %d.',
                             sleep_secs, self._config.task)
                time.sleep(sleep_secs)

        # Device allocation
        device_fn = device_fn or self._device_fn

        self._graph = ops.Graph()
        with self._graph.as_default() as g, g.device(device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step = contrib_framework.create_global_step(g)
            features, targets = input_fn()
            self._check_inputs(features, targets)
            train_op, loss_op = self._get_train_ops(features, targets)

            # Add default monitors.
            if monitors is None:
                monitors = []

            is_chief = self._config.task == 0

            if is_chief:
                monitors += monitors_lib.get_default_monitors(
                    loss_op=loss_op,
                    summary_op=logging_ops.get_summary_op(),
                    save_summary_steps=self._config.save_summary_steps,
                    summary_writer=graph_actions.get_summary_writer(
                        self._model_dir))
            else:
                monitors = []

            # Setup monitors.
            for monitor in monitors:
                monitor.set_estimator(self)

            return graph_actions.train(
                graph=g,
                output_dir=self._model_dir,
                train_op=train_op,
                loss_op=loss_op,
                global_step_tensor=global_step,
                init_op=init_op,
                init_feed_dict=init_feed_fn()
                if init_feed_fn is not None else None,
                init_fn=init_fn,
                log_every_steps=log_every_steps,
                supervisor_is_chief=is_chief,
                supervisor_master=self._config.master,
                supervisor_save_model_secs=self._config.save_checkpoints_secs,
                keep_checkpoint_max=self._config.keep_checkpoint_max,
                feed_fn=feed_fn,
                steps=steps,
                fail_on_nan_loss=fail_on_nan_loss,
                monitors=monitors,
                max_steps=max_steps)
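
Note: `input_fn` above is called with no arguments inside the freshly created graph and must return a `(features, targets)` pair (tensors, or a dict of feature tensors), as the line `features, targets = input_fn()` shows. A minimal sketch with made-up constant data:

import tensorflow as tf

def input_fn():
    # Hypothetical input_fn for illustration: returns (features, targets)
    # as constant tensors. Real pipelines would read files or queues instead.
    features = tf.constant([[0.1, 0.2], [0.3, 0.4]])
    targets = tf.constant([[1.0], [0.0]])
    return features, targets
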
Example #2
def _train_internal(graph,
                    output_dir,
                    train_op,
                    loss_op,
                    global_step_tensor,
                    init_op,
                    init_feed_dict,
                    init_fn,
                    log_every_steps,
                    supervisor_is_chief,
                    supervisor_master,
                    supervisor_save_model_secs,
                    keep_checkpoint_max,
                    supervisor_save_summaries_steps,
                    feed_fn,
                    steps,
                    fail_on_nan_loss,
                    monitors,
                    max_steps):
  """See train."""
  if (steps is not None) and (max_steps is not None):
    raise ValueError('Can not provide both steps and max_steps.')
  if not output_dir:
    raise ValueError('Output directory should be non-empty %s.' % output_dir)
  if train_op is None:
    raise ValueError('Missing train_op.')
  if loss_op is None:
    raise ValueError('Missing loss_op.')

  with graph.as_default():
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)
    if global_step_tensor is None:
      raise ValueError('No "global_step" was provided or found in the graph.')

    # Get current step.
    try:
      start_step = load_variable(output_dir, global_step_tensor.name)
    except (errors.NotFoundError, ValueError):
      start_step = 0

    summary_writer = (get_summary_writer(output_dir)
                      if supervisor_is_chief else None)

    # Add default chief monitors if none were provided.
    if not monitors:
      monitors = monitors_lib.get_default_monitors(
          loss_op=loss_op,
          summary_op=logging_ops.get_summary_op(),
          save_summary_steps=supervisor_save_summaries_steps,
          summary_writer=summary_writer) if supervisor_is_chief else []

    # TODO(ipolosukhin): Replace all functionality of Supervisor
    # with Chief-Exclusive Monitors.
    if not supervisor_is_chief:
      # Prune the list of monitors to the ones runnable on all workers.
      monitors = [monitor for monitor in monitors if monitor.run_on_all_workers]

    if max_steps is None:
      max_steps = (start_step + steps) if steps else None
    # Start monitors, can create graph parts.
    for monitor in monitors:
      monitor.begin(max_steps=max_steps)

  supervisor = tf_supervisor.Supervisor(
      graph,
      init_op=init_op or tf_supervisor.Supervisor.USE_DEFAULT,
      init_feed_dict=init_feed_dict,
      is_chief=supervisor_is_chief,
      logdir=output_dir,
      saver=_make_saver(graph, keep_checkpoint_max),
      global_step=global_step_tensor,
      summary_op=None,
      summary_writer=summary_writer,
      save_model_secs=supervisor_save_model_secs,
      init_fn=init_fn)
  session = supervisor.PrepareSession(master=supervisor_master,
                                      start_standard_services=True)
  supervisor.StartQueueRunners(session)

  with session:
    get_current_step = lambda: session.run(global_step_tensor)

    start_step = get_current_step()
    last_step = start_step
    last_log_step = start_step
    loss_value = None
    logging.info('Training steps [%d,%s)', last_step, 'inf'
                 if max_steps is None else str(max_steps))

    excinfo = None
    try:
      while not supervisor.ShouldStop() and (
          (max_steps is None) or (last_step < max_steps)):
        start_time = time.time()
        feed_dict = feed_fn() if feed_fn is not None else None

        outputs, should_stop = _run_with_monitors(
            session, last_step + 1, [train_op, loss_op], feed_dict, monitors)

        loss_value = outputs[loss_op.name]
        if np.isnan(loss_value):
          failure_message = 'Model diverged with loss = NaN.'
          if fail_on_nan_loss:
            logging.error(failure_message)
            raise monitors_lib.NanLossDuringTrainingError()
          else:
            logging.warning(failure_message)

        if should_stop:
          break

        this_step = get_current_step()

        if this_step <= last_step:
          logging.error(
              'Global step was not incremented by train op at step %s'
              ': new step %d', last_step, this_step)

        last_step = this_step
        is_last_step = (max_steps is not None) and (last_step >= max_steps)
        if is_last_step or (last_step - last_log_step >= log_every_steps):
          logging.info(
              'training step %d, loss = %.5f (%.3f sec/batch).',
              last_step, loss_value, float(time.time() - start_time))
          last_log_step = last_step
    except errors.OutOfRangeError as e:
      logging.warn('Got exception during tf.learn training loop possibly '
                   'due to exhausted input queue %s.', e)
    except StopIteration:
      logging.info('Exhausted input iterator.')
    except BaseException as e:  # pylint: disable=broad-except
      # Hold on to any other exceptions while we try recording a final
      # checkpoint and summary.
      excinfo = sys.exc_info()
    finally:
      try:
        # Call supervisor.Stop() from within a try block because it re-raises
        # exceptions thrown by the supervised threads.
        supervisor.Stop(close_summary_writer=False)

        # Save one last checkpoint and summaries
        # TODO(wicke): This should be handled by Supervisor

        # In case we encountered an exception in the try block before we updated
        # last_step, update it here (again).
        last_step = get_current_step()
        if supervisor_is_chief:
          ckpt_path = supervisor.save_path
          logging.info('Saving checkpoint for step %d to checkpoint: %s.',
                       last_step, ckpt_path)
          supervisor.saver.save(session, ckpt_path, global_step=last_step)

          # Finish monitors.
          for monitor in monitors:
            monitor.end()

      # catch OutOfRangeError which is thrown when queue is out of data (and for
      # other reasons as well).
      except errors.OutOfRangeError as e:
        logging.warn('OutOfRangeError in tf.learn final checkpoint possibly '
                     'due to exhausted input queue. Note: summary_op is not '
                     'expected to trigger dequeues. %s.', e)
      except BaseException as e:  # pylint: disable=broad-except
        # If we don't already have an exception to re-raise, raise this one.
        if not excinfo:
          raise
        # Otherwise, log this one and raise the other in the finally block.
        logging.error('Got exception during tf.learn final checkpoint %s.', e)
      finally:
        if excinfo:
          reraise(*excinfo)
    return loss_value
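
Note on `steps` vs. `max_steps` in `_train_internal` above: `steps` is counted relative to the global step restored from a checkpoint, while `max_steps` is an absolute target. A small sketch with assumed numbers:

# Assumed values for illustration only.
start_step = 1200  # global step loaded from a checkpoint in output_dir
steps = 500        # train for 500 additional steps
max_steps = (start_step + steps) if steps else None  # same logic as above
assert max_steps == 1700  # the loop stops once the global step reaches 1700
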
Example #3
def train(graph,
          output_dir,
          train_op,
          loss_op,
          global_step_tensor=None,
          init_op=None,
          init_feed_dict=None,
          init_fn=None,
          log_every_steps=10,
          supervisor_is_chief=True,
          supervisor_master='',
          supervisor_save_model_secs=600,
          supervisor_save_summaries_steps=100,
          feed_fn=None,
          steps=None,
          fail_on_nan_loss=True,
          monitors=None):
  """Train a model.

  Given `graph`, a directory to write outputs to (`output_dir`), and some ops,
  run a training loop. The given `train_op` performs one step of training on the
  model and is expected to increment the `global_step_tensor`, a scalar integer
  tensor counting training steps. The `loss_op` represents the objective
  function of the training. This function uses `Supervisor` to initialize the
  graph (from a checkpoint if one is available in `output_dir`), write summaries
  defined in the graph, and write regular checkpoints as defined by
  `supervisor_save_model_secs`.

  Training continues until the global step reaches its starting value plus
  `steps`, or, if `fail_on_nan_loss`, until `loss_op` evaluates to `NaN`. In
  that case a `NanLossDuringTrainingError` is raised.

  Args:
    graph: A graph to train. It is expected that this graph is not in use
      elsewhere.
    output_dir: A directory to write outputs to.
    train_op: An op that performs one training step when run.
    loss_op: A scalar loss tensor.
    global_step_tensor: A tensor representing the global step. If none is given,
      one is extracted from the graph using the same logic as in `Supervisor`.
    init_op: An op that initializes the graph. If `None`, use `Supervisor`'s
      default.
    init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
      This feed dictionary will be used when `init_op` is evaluated.
    init_fn: Optional callable passed to Supervisor to initialize the model.
    log_every_steps: Output logs regularly. The logs contain timing data and the
      current loss.
    supervisor_is_chief: Whether the current process is the chief supervisor in
      charge of restoring the model and running standard services.
    supervisor_master: The master string to use when preparing the session.
    supervisor_save_model_secs: Save a checkpoint every
      `supervisor_save_model_secs` seconds when training.
    supervisor_save_summaries_steps: Save summaries every
      `supervisor_save_summaries_steps` steps when training.
    feed_fn: A function that is called every iteration to produce a `feed_dict`
      passed to `session.run` calls. Optional.
    steps: Trains for this many steps (e.g. current global step + `steps`).
    fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if `loss_op`
      evaluates to `NaN`. If false, continue training as if nothing happened.
    monitors: List of `BaseMonitor` subclass instances. Used for callbacks
      inside the training loop.

  Returns:
    The final loss value.

  Raises:
    ValueError: If `global_step_tensor` is not provided. See
        `tf.contrib.framework.get_global_step` for how we look it up if not
        provided explicitly.
    NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss ever
        evaluates to `NaN`.
  """
  if not output_dir:
    raise ValueError('Output directory should be non-empty.')

  with graph.as_default():
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)
    if global_step_tensor is None:
      raise ValueError('No "global_step" was provided or found in the graph.')

    # Get current step.
    try:
      start_step = checkpoints.load_variable(
          output_dir, global_step_tensor.name)
    except (errors.NotFoundError, ValueError):
      start_step = 0

    summary_writer = (get_summary_writer(output_dir)
                      if supervisor_is_chief else None)

    # TODO(ipolosukhin): Replace all functionality of Supervisor with Monitors.
    if not supervisor_is_chief:
      # monitors should run only on the chief.
      monitors = []
    elif not monitors:
      monitors = monitors_lib.get_default_monitors(
          loss_op=loss_op,
          summary_op=logging_ops.get_summary_op(),
          save_summary_steps=supervisor_save_summaries_steps,
          summary_writer=summary_writer)

    max_steps = (start_step + steps) if steps else None
    # Start monitors, can create graph parts.
    for monitor in monitors:
      monitor.begin(max_steps=max_steps)

  supervisor = tf_supervisor.Supervisor(
      graph,
      init_op=init_op or tf_supervisor.Supervisor.USE_DEFAULT,
      init_feed_dict=init_feed_dict,
      is_chief=supervisor_is_chief,
      logdir=output_dir,
      saver=_make_saver(graph),
      global_step=global_step_tensor,
      summary_op=None,
      summary_writer=summary_writer,
      save_model_secs=supervisor_save_model_secs,
      init_fn=init_fn)
  session = supervisor.PrepareSession(master=supervisor_master,
                                      start_standard_services=True)
  supervisor.StartQueueRunners(session)

  with session:
    get_current_step = lambda: session.run(global_step_tensor)

    start_step = get_current_step()
    last_step = start_step
    last_log_step = start_step
    loss_value = None
    logging.info('Training steps [%d,%s)', last_step, 'inf'
                 if max_steps is None else str(max_steps))

    excinfo = None
    try:
      while not supervisor.ShouldStop() and (
          (max_steps is None) or (last_step < max_steps)):
        start_time = time.time()
        feed_dict = feed_fn() if feed_fn is not None else None

        outputs, should_stop = _run_with_monitors(
            session, last_step + 1, [train_op, loss_op], feed_dict, monitors)

        loss_value = outputs[loss_op.name]
        if np.isnan(loss_value):
          failure_message = 'Model diverged with loss = NaN.'
          if fail_on_nan_loss:
            logging.error(failure_message)
            raise NanLossDuringTrainingError()
          else:
            logging.warning(failure_message)

        if should_stop:
          break

        this_step = get_current_step()

        if this_step <= last_step:
          logging.error(
              'Global step was not incremented by train op at step %s'
              ': new step %d', last_step, this_step)

        last_step = this_step
        is_last_step = (max_steps is not None) and (last_step >= max_steps)
        if is_last_step or (last_step - last_log_step >= log_every_steps):
          logging.info(
              'training step %d, loss = %.5f (%.3f sec/batch).',
              last_step, loss_value, float(time.time() - start_time))
          last_log_step = last_step
    except errors.OutOfRangeError as e:
      logging.warn('Got exception during tf.learn training loop possibly '
                   'due to exhausted input queue %s.', e)
    except BaseException as e:  # pylint: disable=broad-except
      # Hold on to any other exceptions while we try recording a final
      # checkpoint and summary.
      excinfo = sys.exc_info()
    finally:
      try:
        # Call supervisor.Stop() from within a try block because it re-raises
        # exceptions thrown by the supervised threads.
        supervisor.Stop(close_summary_writer=False)

        # Save one last checkpoint and summaries
        # TODO(wicke): This should be handled by Supervisor

        # In case we encountered an exception in the try block before we updated
        # last_step, update it here (again).
        last_step = get_current_step()
        if supervisor_is_chief:
          ckpt_path = supervisor.save_path
          logging.info('Saving checkpoint for step %d to checkpoint: %s.',
                       last_step, ckpt_path)
          supervisor.saver.save(session, ckpt_path, global_step=last_step)

          # Finish monitors.
          for monitor in monitors:
            monitor.end()

      # catch OutOfRangeError which is thrown when queue is out of data (and for
      # other reasons as well).
      except errors.OutOfRangeError as e:
        logging.warn('OutOfRangeError in tf.learn final checkpoint possibly '
                     'due to exhausted input queue. Note: summary_op is not '
                     'expected to trigger dequeues. %s.', e)
      except BaseException as e:  # pylint: disable=broad-except
        # If we don't already have an exception to re-raise, raise this one.
        if not excinfo:
          raise
        # Otherwise, log this one and raise the other in the finally block.
        logging.error('Got exception during tf.learn final checkpoint %s.', e)
      finally:
        if excinfo:
          reraise(*excinfo)
    return loss_value
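
Note: a minimal sketch of calling the module-level `train` defined above (`graph_actions.train` in tf.contrib.learn). The graph uses standard TF 1.x ops, and the output directory is a placeholder:

import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework

g = tf.Graph()
with g.as_default():
  global_step = contrib_framework.create_global_step(g)
  x = tf.random_normal([32, 10])
  w = tf.Variable(tf.zeros([10, 1]))
  loss_op = tf.reduce_mean(tf.square(tf.matmul(x, w)))
  # The train op must increment the global step, as the loop above checks.
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss_op, global_step=global_step)

final_loss = train(
    graph=g,
    output_dir='/tmp/tf_learn_example',  # placeholder directory
    train_op=train_op,
    loss_op=loss_op,
    steps=100)
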
Example #4
  def _train_model(self,
                   input_fn,
                   steps,
                   feed_fn=None,
                   init_op=None,
                   init_feed_fn=None,
                   init_fn=None,
                   device_fn=None,
                   monitors=None,
                   log_every_steps=100,
                   fail_on_nan_loss=True):
    if self._config.execution_mode not in ('all', 'train'):
      return

    # Stagger startup of worker sessions based on task id.
    sleep_secs = min(self._config.training_worker_max_startup_secs,
                     self._config.task *
                     self._config.training_worker_session_startup_stagger_secs)
    if sleep_secs:
      logging.info('Waiting %d secs before starting task %d.', sleep_secs,
                   self._config.task)
      time.sleep(sleep_secs)

    # Device allocation
    device_fn = device_fn or self._device_fn

    self._graph = ops.Graph()
    with self._graph.as_default() as g, g.device(device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = contrib_framework.create_global_step(g)
      features, targets = input_fn()
      self._check_inputs(features, targets)
      train_op, loss_op = self._get_train_ops(features, targets)

      # Add default monitors.
      if monitors is None:
        monitors = []
      monitors += monitors_lib.get_default_monitors(
          loss_op=loss_op,
          summary_op=logging_ops.get_summary_op(),
          save_summary_steps=100,
          summary_writer=graph_actions.get_summary_writer(self._model_dir))

      is_chief = self._config.task == 0
      if not is_chief:
        # Run monitors only on chief.
        monitors = []

      # Setup monitors.
      for monitor in monitors:
        monitor.set_estimator(self)

      return train(
          graph=g,
          output_dir=self._model_dir,
          train_op=train_op,
          loss_op=loss_op,
          global_step_tensor=global_step,
          init_op=init_op,
          init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
          init_fn=init_fn,
          log_every_steps=log_every_steps,
          supervisor_is_chief=is_chief,
          supervisor_master=self._config.master,
          feed_fn=feed_fn,
          max_steps=steps,
          fail_on_nan_loss=fail_on_nan_loss,
          monitors=monitors)
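
Note: the `monitors` argument accepts `BaseMonitor` subclass instances; the code above calls `set_estimator(self)` on each, and the training loop calls `begin(max_steps=...)` and `end()`. A minimal sketch of a custom monitor, assuming the monitors module is importable as shown:

from tensorflow.contrib.learn.python.learn import monitors as monitors_lib
import logging

class LoggingMonitor(monitors_lib.BaseMonitor):
  """Hypothetical monitor that only logs the lifecycle hooks used above."""

  def begin(self, max_steps=None):
    super(LoggingMonitor, self).begin(max_steps=max_steps)
    logging.info('Training will run until step %s.', max_steps)

  def end(self):
    super(LoggingMonitor, self).end()
    logging.info('Training finished.')

An instance could then be passed in through the `monitors` list handled above, e.g. monitors=[LoggingMonitor()].
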
Example #5
def train(graph,
          output_dir,
          train_op,
          loss_op,
          global_step_tensor=None,
          init_op=None,
          init_feed_dict=None,
          init_fn=None,
          log_every_steps=10,
          supervisor_is_chief=True,
          supervisor_master='',
          supervisor_save_model_secs=600,
          supervisor_save_summaries_steps=100,
          feed_fn=None,
          steps=None,
          fail_on_nan_loss=True,
          monitors=None):
    """Train a model.

  Given `graph`, a directory to write outputs to (`output_dir`), and some ops,
  run a training loop. The given `train_op` performs one step of training on the
  model and is expected to increment the `global_step_tensor`, a scalar integer
  tensor counting training steps. The `loss_op` represents the objective
  function of the training. This function uses `Supervisor` to initialize the
  graph (from a checkpoint if one is available in `output_dir`), write summaries
  defined in the graph, and write regular checkpoints as defined by
  `supervisor_save_model_secs`.

  Training continues until the global step reaches its starting value plus
  `steps`, or, if `fail_on_nan_loss`, until `loss_op` evaluates to `NaN`. In
  that case a `NanLossDuringTrainingError` is raised.

  Args:
    graph: A graph to train. It is expected that this graph is not in use
      elsewhere.
    output_dir: A directory to write outputs to.
    train_op: An op that performs one training step when run.
    loss_op: A scalar loss tensor.
    global_step_tensor: A tensor representing the global step. If none is given,
      one is extracted from the graph using the same logic as in `Supervisor`.
    init_op: An op that initializes the graph. If `None`, use `Supervisor`'s
      default.
    init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
      This feed dictionary will be used when `init_op` is evaluated.
    init_fn: Optional callable passed to Supervisor to initialize the model.
    log_every_steps: Output logs regularly. The logs contain timing data and the
      current loss.
    supervisor_is_chief: Whether the current process is the chief supervisor in
      charge of restoring the model and running standard services.
    supervisor_master: The master string to use when preparing the session.
    supervisor_save_model_secs: Save a checkpoint every
      `supervisor_save_model_secs` seconds when training.
    supervisor_save_summaries_steps: Save summaries every
      `supervisor_save_summaries_steps` steps when training.
    feed_fn: A function that is called every iteration to produce a `feed_dict`
      passed to `session.run` calls. Optional.
    steps: Trains for this many steps (e.g. current global step + `steps`).
    fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if `loss_op`
      evaluates to `NaN`. If false, continue training as if nothing happened.
    monitors: List of `BaseMonitor` subclass instances. Used for callbacks
      inside the training loop.

  Returns:
    The final loss value.

  Raises:
    ValueError: If `global_step_tensor` is not provided. See
        `tf.contrib.framework.get_global_step` for how we look it up if not
        provided explicitly.
    NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss ever
        evaluates to `NaN`.
  """
    if not output_dir:
        raise ValueError('Output directory should be non-empty.')

    with graph.as_default():
        global_step_tensor = contrib_variables.assert_or_get_global_step(
            graph, global_step_tensor)
        if global_step_tensor is None:
            raise ValueError(
                'No "global_step" was provided or found in the graph.')

        # Get current step.
        try:
            start_step = checkpoints.load_variable(output_dir,
                                                   global_step_tensor.name)
        except (errors.NotFoundError, ValueError):
            start_step = 0

        summary_writer = (get_summary_writer(output_dir)
                          if supervisor_is_chief else None)

        # TODO(ipolosukhin): Replace all functionality of Supervisor with Monitors.
        if not supervisor_is_chief:
            # monitors should run only on the chief.
            monitors = []
        elif not monitors:
            monitors = monitors_lib.get_default_monitors(
                loss_op=loss_op,
                summary_op=logging_ops.get_summary_op(),
                save_summary_steps=supervisor_save_summaries_steps,
                summary_writer=summary_writer)

        # Start monitors, can create graph parts.
        for monitor in monitors:
            monitor.begin(max_steps=(start_step + steps) if steps else None)

    supervisor = tf_supervisor.Supervisor(
        graph,
        init_op=init_op or tf_supervisor.Supervisor.USE_DEFAULT,
        init_feed_dict=init_feed_dict,
        is_chief=supervisor_is_chief,
        logdir=output_dir,
        saver=_make_saver(graph),
        global_step=global_step_tensor,
        summary_op=None,
        summary_writer=summary_writer,
        save_model_secs=supervisor_save_model_secs,
        init_fn=init_fn)
    session = supervisor.PrepareSession(master=supervisor_master,
                                        start_standard_services=True)
    supervisor.StartQueueRunners(session)

    with session:
        get_current_step = lambda: session.run(global_step_tensor)

        start_step = get_current_step()
        max_steps = (start_step + steps) if steps else None
        last_step = start_step
        last_log_step = start_step
        loss_value = None
        logging.info('Training steps [%d,%s)', last_step,
                     'inf' if max_steps is None else str(max_steps))

        excinfo = None
        try:
            while not supervisor.ShouldStop() and ((max_steps is None) or
                                                   (last_step < max_steps)):
                start_time = time.time()
                feed_dict = feed_fn() if feed_fn is not None else None

                outputs, should_stop = _run_with_monitors(
                    session, last_step + 1, [train_op, loss_op], feed_dict,
                    monitors)

                loss_value = outputs[loss_op.name]
                if np.isnan(loss_value):
                    failure_message = 'Model diverged with loss = NaN.'
                    if fail_on_nan_loss:
                        logging.error(failure_message)
                        raise NanLossDuringTrainingError()
                    else:
                        logging.warning(failure_message)

                if should_stop:
                    break

                this_step = get_current_step()

                if this_step <= last_step:
                    logging.error(
                        'Global step was not incremented by train op at step %s'
                        ': new step %d', last_step, this_step)

                last_step = this_step
                is_last_step = (max_steps
                                is not None) and (last_step >= max_steps)
                if is_last_step or (last_step - last_log_step >=
                                    log_every_steps):
                    logging.info(
                        'training step %d, loss = %.5f (%.3f sec/batch).',
                        last_step, loss_value, float(time.time() - start_time))
                    last_log_step = last_step
        except errors.OutOfRangeError as e:
            logging.warn(
                'Got exception during tf.learn training loop possibly '
                'due to exhausted input queue %s.', e)
        except BaseException as e:  # pylint: disable=broad-except
            # Hold on to any other exceptions while we try recording a final
            # checkpoint and summary.
            excinfo = sys.exc_info()
        finally:
            try:
                # Call supervisor.Stop() from within a try block because it re-raises
                # exceptions thrown by the supervised threads.
                supervisor.Stop(close_summary_writer=False)

                # Save one last checkpoint and summaries
                # TODO(wicke): This should be handled by Supervisor

                # In case we encountered an exception in the try block before we updated
                # last_step, update it here (again).
                last_step = get_current_step()
                if supervisor_is_chief:
                    ckpt_path = supervisor.save_path
                    logging.info(
                        'Saving checkpoint for step %d to checkpoint: %s.',
                        last_step, ckpt_path)
                    supervisor.saver.save(session,
                                          ckpt_path,
                                          global_step=last_step)

                    # Finish monitors.
                    for monitor in monitors:
                        monitor.end()

            # catch OutOfRangeError which is thrown when queue is out of data (and for
            # other reasons as well).
            except errors.OutOfRangeError as e:
                logging.warn(
                    'OutOfRangeError in tf.learn final checkpoint possibly '
                    'due to exhausted input queue. Note: summary_op is not '
                    'expected to trigger dequeues. %s.', e)
            except BaseException as e:  # pylint: disable=broad-except
                # If we don't already have an exception to re-raise, raise this one.
                if not excinfo:
                    raise
                # Otherwise, log this one and raise the other in the finally block.
                logging.error(
                    'Got exception during tf.learn final checkpoint %s.', e)
            finally:
                if excinfo:
                    reraise(*excinfo)
        return loss_value
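
Note: with `fail_on_nan_loss=True` (the default) a `NaN` loss raises `NanLossDuringTrainingError`; with `False` it is only logged and training continues. A hedged sketch of handling the failure at the call site, reusing a graph and ops built as in the sketch after Example #3:

try:
    final_loss = train(
        graph=g,  # g, train_op, loss_op assumed from the earlier sketch
        output_dir='/tmp/tf_learn_example',  # placeholder directory
        train_op=train_op,
        loss_op=loss_op,
        steps=1000,
        fail_on_nan_loss=True)
except NanLossDuringTrainingError:
    # The model diverged; a caller might lower the learning rate and retry.
    final_loss = None
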
Example #6
  def _train_model(self,
                   input_fn,
                   steps,
                   feed_fn=None,
                   init_op=None,
                   init_feed_fn=None,
                   init_fn=None,
                   device_fn=None,
                   monitors=None,
                   log_every_steps=100,
                   fail_on_nan_loss=True):
    # TODO(wicke): This is a hack and needs to go.
    if self._config.execution_mode not in ('all', 'train'):
      return

    if not self._model_dir:
      raise ValueError('Estimator\'s model_dir should be non-empty.')

    # Stagger startup of worker sessions based on task id.
    sleep_secs = min(self._config.training_worker_max_startup_secs,
                     self._config.task *
                     self._config.training_worker_session_startup_stagger_secs)
    if sleep_secs:
      logging.info('Waiting %d secs before starting task %d.', sleep_secs,
                   self._config.task)
      time.sleep(sleep_secs)

    # Device allocation
    device_fn = device_fn or self._device_fn

    self._graph = ops.Graph()
    with self._graph.as_default() as g, g.device(device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = contrib_framework.create_global_step(g)
      features, targets = input_fn()
      self._check_inputs(features, targets)
      train_op, loss_op = self._get_train_ops(features, targets)

      # Add default monitors.
      if monitors is None:
        monitors = []
      monitors += monitors_lib.get_default_monitors(
          loss_op=loss_op,
          summary_op=logging_ops.get_summary_op(),
          save_summary_steps=100,
          summary_writer=graph_actions.get_summary_writer(self._model_dir))

      is_chief = self._config.task == 0
      if not is_chief:
        # Run monitors only on chief.
        monitors = []

      # Setup monitors.
      for monitor in monitors:
        monitor.set_estimator(self)

      return graph_actions.train(
          graph=g,
          output_dir=self._model_dir,
          train_op=train_op,
          loss_op=loss_op,
          global_step_tensor=global_step,
          init_op=init_op,
          init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
          init_fn=init_fn,
          log_every_steps=log_every_steps,
          supervisor_is_chief=is_chief,
          supervisor_master=self._config.master,
          feed_fn=feed_fn,
          max_steps=steps,
          fail_on_nan_loss=fail_on_nan_loss,
          monitors=monitors)
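
Note on the worker startup stagger in `_train_model` above: each non-chief worker sleeps in proportion to its task id, capped by a configured maximum, so worker sessions do not all start at once. With assumed config values:

# Assumed RunConfig values for illustration only.
training_worker_session_startup_stagger_secs = 5
training_worker_max_startup_secs = 120
task = 3  # this worker's task id; the chief has task 0 and never sleeps
sleep_secs = min(training_worker_max_startup_secs,
                 task * training_worker_session_startup_stagger_secs)
assert sleep_secs == 15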