コード例 #1
0
    def get_bias(self, model_dir):
        """Returns the bias of the model.

        Args:
          model_dir: Directory where model parameters, graph and etc. are saved.

        Returns:
          A list of bias values loaded from `model_dir`: one entry per hidden
          layer (`<scope>/hiddenlayer_<i>/biases`, in layer order) followed by
          the logits bias (`<scope>/logits/biases`).
        """
        variable_names = []
        for layer_index, _ in enumerate(self._hidden_units):
            variable_names.append(
                self._scope + "/hiddenlayer_%d/biases" % layer_index)
        variable_names.append(self._scope + "/logits/biases")
        return [load_variable(model_dir, name=var_name)
                for var_name in variable_names]
コード例 #2
0
    def get_bias(self, model_dir):
        """Returns bias of the model.

        Args:
          model_dir: Directory where model parameters, graph and etc. are saved.

        Returns:
          The value of the `<scope>/bias_weight` variable stored in `model_dir`.
        """
        bias_variable_name = self._scope + "/bias_weight"
        return load_variable(model_dir, name=bias_variable_name)
コード例 #3
0
    def get_weights(self, model_dir):
        """Returns weights per feature of the linear part.

        Variables are read from `model_dir`. Only variables whose name starts
        with `self._scope + "/"` are considered. The bias variable
        (`<scope>/bias_weight`) is skipped. Optimizer slot variables are also
        skipped: these are names ending in `/<optimizer name>`, optionally
        followed by `_<digit>`.

        Args:
          model_dir: Directory where model parameters, graph and etc. are saved.

        Returns:
          The weights created by this model (without the optimizer weights).
          If exactly one variable remains, its value is returned directly.
          Otherwise a dict mapping variable name to value is returned; the
          dict may be empty.
        """
        all_variables = [name for name, _ in list_variables(model_dir)]
        scope_prefix = self._scope + "/"
        bias_name = self._scope + "/bias_weight"
        # Escape the optimizer name so that characters such as "." or "+" in a
        # user-supplied name are matched literally instead of as regex syntax.
        optimizer_pattern = re.compile(
            r".*/" + re.escape(self._get_optimizer().get_name()) + r"(_\d)?$")
        values = {}
        for name in all_variables:
            if (name.startswith(scope_prefix)
                    and name != bias_name
                    and not optimizer_pattern.match(name)):
                values[name] = load_variable(model_dir, name)
        if len(values) == 1:
            # A single weight variable is unwrapped for convenience.
            return next(iter(values.values()))
        return values
コード例 #4
0
def _train_internal(graph,
                    output_dir,
                    train_op,
                    loss_op,
                    global_step_tensor,
                    init_op,
                    init_feed_dict,
                    init_fn,
                    log_every_steps,
                    supervisor_is_chief,
                    supervisor_master,
                    supervisor_save_model_secs,
                    keep_checkpoint_max,
                    supervisor_save_summaries_steps,
                    feed_fn,
                    steps,
                    fail_on_nan_loss,
                    monitors,
                    max_steps):
  """Runs the training loop; implementation behind `train` (see train).

  The function proceeds in four stages:
    1. Validate the arguments.
    2. Prepare monitors. Default monitors are created on the chief only.
    3. Create a `Supervisor`-managed session.
    4. Repeatedly run `train_op` and `loss_op` until one of the following:
       the supervisor asks to stop, a monitor requests a stop, `max_steps`
       is reached, or input is exhausted.
  On exit the supervisor is stopped. On the chief, a final checkpoint is
  then saved and `end()` is called on every monitor.

  Args:
    graph: Graph containing the training ops. It is used as the default graph
      while setting up monitors and is passed to the Supervisor.
    output_dir: Directory for checkpoints and summaries (the Supervisor
      `logdir`). Must be non-empty.
    train_op: Op run on every step. Must not be None.
    loss_op: Tensor fetched on every step and checked for NaN. Must not be
      None.
    global_step_tensor: Global step tensor, or None to look it up in `graph`.
    init_op: Initialization op. If falsy, `Supervisor.USE_DEFAULT` is used.
    init_feed_dict: Feed dict passed to the Supervisor for initialization.
    init_fn: Passed through to the Supervisor as `init_fn`.
    log_every_steps: Minimum number of global steps between progress logs.
    supervisor_is_chief: Whether this worker is the chief. Only the chief gets
      a summary writer, the default monitors and the final checkpoint.
    supervisor_master: Master passed to `Supervisor.PrepareSession`.
    supervisor_save_model_secs: Passed to the Supervisor as `save_model_secs`.
    keep_checkpoint_max: Passed to `_make_saver`.
    supervisor_save_summaries_steps: Summary frequency used by the default
      monitors.
    feed_fn: Optional callable returning the feed dict for each step.
    steps: Number of steps to train, counted from the global step stored in
      `output_dir` (0 if none is found). A falsy value (None or 0) imposes no
      limit. Mutually exclusive with `max_steps`.
    fail_on_nan_loss: If True, raise on a NaN loss. Otherwise log a warning
      and keep training.
    monitors: Monitors to run. If empty or None, default monitors are created
      on the chief. On non-chief workers only monitors with
      `run_on_all_workers` set are kept.
    max_steps: Absolute global step at which to stop. Mutually exclusive with
      `steps`.

  Returns:
    The loss value fetched on the last executed step, or None if no step ran.

  Raises:
    ValueError: If any of the following holds: both `steps` and `max_steps`
      are given, `output_dir` is empty, `train_op` or `loss_op` is None, or
      no global step is provided or found in the graph.
    monitors_lib.NanLossDuringTrainingError: If the loss is NaN and
      `fail_on_nan_loss` is True.
    Any other exception raised inside the loop is re-raised after the final
    checkpoint and monitor cleanup has been attempted.
  """
  if (steps is not None) and (max_steps is not None):
    raise ValueError('Can not provide both steps and max_steps.')
  if not output_dir:
    raise ValueError('Output directory should be non-empty %s.' % output_dir)
  if train_op is None:
    raise ValueError('Missing train_op.')
  if loss_op is None:
    raise ValueError('Missing loss_op.')

  with graph.as_default():
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)
    if global_step_tensor is None:
      raise ValueError('No "global_step" was provided or found in the graph.')

    # Get current step from the latest checkpoint, if any; this is the base
    # that `steps` is counted from.
    try:
      start_step = load_variable(output_dir, global_step_tensor.name)
    except (errors.NotFoundError, ValueError):
      start_step = 0

    summary_writer = (get_summary_writer(output_dir)
                      if supervisor_is_chief else None)

    # Add default chief monitors if none were provided.
    if not monitors:
      monitors = monitors_lib.get_default_monitors(
          loss_op=loss_op,
          summary_op=logging_ops.get_summary_op(),
          save_summary_steps=supervisor_save_summaries_steps,
          summary_writer=summary_writer) if supervisor_is_chief else []

    # TODO(ipolosukhin): Replace all functionality of Supervisor
    # with Chief-Exclusive Monitors.
    if not supervisor_is_chief:
      # Prune list of monitor to the ones runnable on all workers.
      monitors = [monitor for monitor in monitors if monitor.run_on_all_workers]

    # Translate a relative `steps` budget into an absolute stopping step.
    if max_steps is None:
      max_steps = (start_step + steps) if steps else None
    # Start monitors, can create graph parts.
    for monitor in monitors:
      monitor.begin(max_steps=max_steps)

  supervisor = tf_supervisor.Supervisor(
      graph,
      init_op=init_op or tf_supervisor.Supervisor.USE_DEFAULT,
      init_feed_dict=init_feed_dict,
      is_chief=supervisor_is_chief,
      logdir=output_dir,
      saver=_make_saver(graph, keep_checkpoint_max),
      global_step=global_step_tensor,
      summary_op=None,
      summary_writer=summary_writer,
      save_model_secs=supervisor_save_model_secs,
      init_fn=init_fn)
  session = supervisor.PrepareSession(master=supervisor_master,
                                      start_standard_services=True)
  supervisor.StartQueueRunners(session)

  with session:
    get_current_step = lambda: session.run(global_step_tensor)

    # Re-read the step from the live session; it may differ from the value
    # loaded from the checkpoint above (e.g. after initialization).
    start_step = get_current_step()
    last_step = start_step
    last_log_step = start_step
    loss_value = None
    logging.info('Training steps [%d,%s)', last_step, 'inf'
                 if max_steps is None else str(max_steps))

    # Holds sys.exc_info() of an unexpected exception so it can be re-raised
    # after the final checkpoint/monitor cleanup below.
    excinfo = None
    try:
      while not supervisor.ShouldStop() and (
          (max_steps is None) or (last_step < max_steps)):
        start_time = time.time()
        feed_dict = feed_fn() if feed_fn is not None else None

        outputs, should_stop = _run_with_monitors(
            session, last_step + 1, [train_op, loss_op], feed_dict, monitors)

        loss_value = outputs[loss_op.name]
        if np.isnan(loss_value):
          failure_message = 'Model diverged with loss = NaN.'
          if fail_on_nan_loss:
            logging.error(failure_message)
            raise monitors_lib.NanLossDuringTrainingError()
          else:
            logging.warning(failure_message)

        if should_stop:
          break

        this_step = get_current_step()

        if this_step <= last_step:
          logging.error(
              'Global step was not incremented by train op at step %s'
              ': new step %d', last_step, this_step)

        last_step = this_step
        is_last_step = (max_steps is not None) and (last_step >= max_steps)
        if is_last_step or (last_step - last_log_step >= log_every_steps):
          logging.info(
              'training step %d, loss = %.5f (%.3f sec/batch).',
              last_step, loss_value, float(time.time() - start_time))
          last_log_step = last_step
    except errors.OutOfRangeError as e:
      logging.warning('Got exception during tf.learn training loop possibly '
                      'due to exhausted input queue %s.', e)
    except StopIteration:
      logging.info('Exhausted input iterator.')
    except BaseException:  # pylint: disable=broad-except
      # Hold on to any other exceptions while we try recording a final
      # checkpoint and summary.
      excinfo = sys.exc_info()
    finally:
      try:
        # Call supervisor.Stop() from within a try block because it re-raises
        # exceptions thrown by the supervised threads.
        supervisor.Stop(close_summary_writer=False)

        # Save one last checkpoint and summaries
        # TODO(wicke): This should be handled by Supervisor

        # In case we encountered an exception in the try block before we updated
        # last_step, update it here (again).
        last_step = get_current_step()
        if supervisor_is_chief:
          ckpt_path = supervisor.save_path
          logging.info('Saving checkpoint for step %d to checkpoint: %s.',
                       last_step, ckpt_path)
          supervisor.saver.save(session, ckpt_path, global_step=last_step)

          # Finish monitors.
          for monitor in monitors:
            monitor.end()

      # catch OutOfRangeError which is thrown when queue is out of data (and for
      # other reasons as well).
      except errors.OutOfRangeError as e:
        logging.warning('OutOfRangeError in tf.learn final checkpoint possibly '
                        'due to exhausted input queue. Note: summary_op is not '
                        'expected to trigger dequeues. %s.', e)
      except BaseException as e:  # pylint: disable=broad-except
        # If we don't already have an exception to re-raise, raise this one.
        if not excinfo:
          raise
        # Otherwise, log this one and raise the other in the finally block.
        logging.error('Got exception during tf.learn final checkpoint %s.', e)
      finally:
        # The exception from the training loop takes precedence over any
        # cleanup failure that was merely logged above.
        if excinfo:
          reraise(*excinfo)
    return loss_value