def evaluate_once(master,
                  checkpoint_path,
                  logdir,
                  num_evals=1,
                  initial_op=None,
                  initial_op_feed_dict=None,
                  eval_op=None,
                  eval_op_feed_dict=None,
                  final_op=None,
                  final_op_feed_dict=None,
                  summary_op=_USE_DEFAULT,
                  summary_op_feed_dict=None,
                  variables_to_restore=None,
                  session_config=None,
                  hooks=None):
    """Evaluates the model at the given checkpoint path.

  Args:
    master: The BNS address of the TensorFlow master.
    checkpoint_path: The path to a checkpoint to use for evaluation.
    logdir: The directory where the TensorFlow summaries are written to.
    num_evals: The number of times to run `eval_op`.
    initial_op: An operation run at the beginning of evaluation.
    initial_op_feed_dict: A feed dictionary to use when executing `initial_op`.
    eval_op: An operation run `num_evals` times.
    eval_op_feed_dict: The feed dictionary to use when executing the `eval_op`.
    final_op: An operation to execute after all of the `eval_op` executions. The
      value of `final_op` is returned.
    final_op_feed_dict: A feed dictionary to use when executing `final_op`.
    summary_op: The summary op to evaluate after running TF-Slim's metric ops.
      By default, `summary_op` is set to `tf.summary.merge_all()`.
    summary_op_feed_dict: An optional feed dictionary to use when running the
      `summary_op`.
    variables_to_restore: A list of TensorFlow variables to restore during
      evaluation. If the argument is left as `None` then
      slim.variables.GetVariablesToRestore() is used.
    session_config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.
    hooks: A list of additional `SessionRunHook` objects to pass during the
      evaluation.

  Returns:
    The value of `final_op` or `None` if `final_op` is `None`.
  """
    if summary_op == _USE_DEFAULT:
        summary_op = summary.merge_all()

    all_hooks = [
        evaluation.StopAfterNEvalsHook(num_evals),
    ]

    if summary_op is not None:
        all_hooks.append(
            evaluation.SummaryAtEndHook(log_dir=logdir,
                                        summary_op=summary_op,
                                        feed_dict=summary_op_feed_dict))
    if hooks is not None:
        all_hooks.extend(hooks)

    saver = None
    if variables_to_restore is not None:
        saver = tf_saver.Saver(variables_to_restore)

    return evaluation.evaluate_once(checkpoint_path,
                                    master=master,
                                    scaffold=monitored_session.Scaffold(
                                        init_op=initial_op,
                                        init_feed_dict=initial_op_feed_dict,
                                        saver=saver),
                                    eval_ops=eval_op,
                                    feed_dict=eval_op_feed_dict,
                                    final_ops=final_op,
                                    final_ops_feed_dict=final_op_feed_dict,
                                    hooks=all_hooks,
                                    config=session_config)
    def test_raise_when_scaffold_and_summary_op_both_present(self):
        with self.assertRaises(ValueError):
            basic_session_run_hooks.SummarySaverHook(
                scaffold=monitored_session.Scaffold(),
                summary_op=self.summary_op)
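For reference, a minimal sketch of the two valid ways to construct a `SummarySaverHook`: pass either an explicit `summary_op` or a `scaffold`, but not both (the test above shows that passing both raises `ValueError`). The output directory is a placeholder.

```python
import tensorflow as tf
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session

tf.summary.scalar('dummy', tf.constant(1.0))  # ensure merge_all() is not None

# Either pass an explicit summary op ...
hook_with_op = basic_session_run_hooks.SummarySaverHook(
    save_steps=100,
    output_dir='/tmp/summaries',  # placeholder
    summary_op=tf.summary.merge_all())

# ... or defer to the scaffold's summary op, but never both.
hook_with_scaffold = basic_session_run_hooks.SummarySaverHook(
    save_steps=100,
    output_dir='/tmp/summaries',  # placeholder
    scaffold=monitored_session.Scaffold())
```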
Example #3
def train(train_op,
          logdir,
          master='',
          is_chief=True,
          scaffold=None,
          hooks=None,
          chief_only_hooks=None,
          save_checkpoint_secs=600,
          save_summaries_steps=100,
          config=None):
    """Runs the training loop.

  Args:
    train_op: A `Tensor` that, when executed, will apply the gradients and
      return the loss value.
    logdir: The directory where the graph and checkpoints are saved.
    master: The URL of the master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    scaffold: A `tf.train.Scaffold` instance.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      training loop.
    chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
      inside the training loop for the chief trainer only.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If `save_checkpoint_secs` is set to
      `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If
      `save_summaries_steps` is set to `None`, then the default summary saver
      isn't used.
    config: An instance of `tf.ConfigProto`.

  Returns:
    The value of the loss function after training.

  Raises:
    ValueError: If `logdir` is `None` and either `save_checkpoint_secs` or
      `save_summaries_steps` is not `None`.
  """
    # TODO(nsilberman): move this logic into monitored_session.py
    scaffold = scaffold or monitored_session.Scaffold()

    hooks = hooks or []

    if is_chief:
        session_creator = monitored_session.ChiefSessionCreator(
            scaffold=scaffold,
            checkpoint_dir=logdir,
            master=master,
            config=config)

        if chief_only_hooks:
            hooks.extend(chief_only_hooks)

        hooks.append(
            basic_session_run_hooks.StepCounterHook(output_dir=logdir))

        if save_summaries_steps:
            if logdir is None:
                raise ValueError(
                    'logdir cannot be None when save_summaries_steps is not None')
            hooks.append(
                basic_session_run_hooks.SummarySaverHook(
                    scaffold=scaffold,
                    save_steps=save_summaries_steps,
                    output_dir=logdir))

        if save_checkpoint_secs:
            if logdir is None:
                raise ValueError(
                    'logdir cannot be None when save_checkpoint_secs is not None')
            hooks.append(
                basic_session_run_hooks.CheckpointSaverHook(
                    logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
    else:
        session_creator = monitored_session.WorkerSessionCreator(
            scaffold=scaffold, master=master, config=config)

    with monitored_session.MonitoredSession(session_creator=session_creator,
                                            hooks=hooks) as session:
        loss = None
        while not session.should_stop():
            loss = session.run(train_op)
    return loss
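A minimal usage sketch of `train` above, assuming a toy quadratic loss. The train op is wrapped so that running it returns the loss value, matching the docstring (TF-Slim's `create_train_op` does this for real models); the log directory is a placeholder.

```python
import tensorflow as tf

x = tf.Variable(5.0)
loss = tf.square(x)
tf.summary.scalar('loss', loss)
global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.GradientDescentOptimizer(0.1)

# Running `train_op` applies the gradients and returns the loss value.
with tf.control_dependencies([optimizer.minimize(loss, global_step=global_step)]):
    train_op = tf.identity(loss)

final_loss = train(
    train_op,
    logdir='/tmp/train_logs',  # placeholder
    hooks=[tf.train.StopAtStepHook(last_step=100)],
    save_checkpoint_secs=60,
    save_summaries_steps=10)
```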
Example #4
def evaluate_repeatedly(
    checkpoint_dir,
    master='',
    scaffold=None,
    eval_ops=None,
    feed_dict=None,
    final_ops=None,
    final_ops_feed_dict=None,
    variables_to_restore=None,
    eval_interval_secs=60,
    hooks=None,
    config=None,
    max_number_of_evaluations=None,
    timeout=None):
  """Repeatedly searches for a checkpoint in `checkpoint_dir` and evaluates it.

  During a single evaluation, the `eval_ops` is run until the session is
  interrupted or requested to finish. This is typically requested via a
  `tf.contrib.training.StopAfterNEvalsHook` which results in `eval_ops` running
  the requested number of times.

  Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of
  `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is
  evaluated a single time after `eval_ops` has finished running and the fetched
  values of `final_ops` are returned. If `final_ops` is left as `None`, then
  `None` is returned.

  One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record
  summaries after the `eval_ops` have run. If `eval_ops` is `None`, the
  summaries run immediately after the model checkpoint has been restored.

  Note that `evaluate_once` creates a local variable used to track the number of
  evaluations run via `tf.contrib.training.get_or_create_eval_step`.
  Consequently, if a custom local init op is provided via a `scaffold`, the
  caller should ensure that the local init op also initializes the eval step.

  Args:
    checkpoint_dir: The directory where checkpoints are stored.
    master: The BNS address of the TensorFlow master.
    scaffold: A `tf.train.Scaffold` instance for initializing and restoring
      variables. Note that `scaffold.init_fn` is used by the function
      to restore the checkpoint. If you supply a custom init_fn, then it must
      also take care of restoring the model from its checkpoint.
    eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
      to `Tensors`, which is run until the session is requested to stop,
      commonly done by a `tf.contrib.training.StopAfterNEvalsHook`.
    feed_dict: The feed dictionary to use when executing the `eval_ops`.
    final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
      to `Tensors`.
    final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
    variables_to_restore: A list of TensorFlow variables to restore during
      evaluation. If the argument is left as `None` then
      tf.contrib.framework.get_variables_to_restore() is used.
    eval_interval_secs: The minimum number of seconds between evaluations.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      evaluation loop.
    config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.
    max_number_of_evaluations: The maximum times to run the evaluation. If left
      as `None`, then evaluation runs indefinitely.
    timeout: The maximum amount of time to wait between checkpoints. If left as
      `None`, then the process will wait indefinitely.

  Returns:
    The fetched values of `final_ops` or `None` if `final_ops` is `None`.
  """
  eval_step = get_or_create_eval_step()

  if eval_ops is not None:
    update_eval_step = state_ops.assign_add(eval_step, 1)

    if isinstance(eval_ops, dict):
      eval_ops['update_eval_step'] = update_eval_step
    elif isinstance(eval_ops, (tuple, list)):
      eval_ops = list(eval_ops) + [update_eval_step]
    else:
      eval_ops = [eval_ops, update_eval_step]

  # Must come before the scaffold check.
  if scaffold and scaffold.saver:
    saver = scaffold.saver
  else:
    saver = tf_saver.Saver(
        variables_to_restore or variables.get_variables_to_restore())

  scaffold = scaffold or monitored_session.Scaffold()

  # Prepare the run hooks.
  hooks = hooks or []

  final_ops_hook = _FinalOpsHook(final_ops, final_ops_feed_dict)
  hooks.append(final_ops_hook)

  num_evaluations = 0
  for checkpoint_path in checkpoints_iterator(
      checkpoint_dir, eval_interval_secs, timeout):

    session_creator = monitored_session.ChiefSessionCreator(
        scaffold=_scaffold_with_init(scaffold, saver, checkpoint_path),
        checkpoint_dir=None,
        master=master,
        config=config)

    with monitored_session.MonitoredSession(
        session_creator=session_creator, hooks=hooks) as session:
      logging.info(
          'Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                    time.gmtime()))
      if eval_ops is not None:
        while not session.should_stop():
          session.run(eval_ops, feed_dict)

      logging.info(
          'Finished evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                    time.gmtime()))
    num_evaluations += 1

    reached_max = num_evaluations >= max_number_of_evaluations
    if max_number_of_evaluations and reached_max:
      return final_ops_hook.final_ops_values

  logging.info('Timed-out waiting for a checkpoint.')
  return final_ops_hook.final_ops_values
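A minimal usage sketch of `evaluate_repeatedly` above, assuming a single metric and the `tf.contrib.training` hooks; the checkpoint directory is a placeholder.

```python
import tensorflow as tf

labels = tf.constant([1, 0, 1, 1])
predictions = tf.constant([1, 0, 0, 1])
accuracy, update_op = tf.metrics.accuracy(labels, predictions)

# Re-evaluates every new checkpoint found in `checkpoint_dir`, running the
# update op 10 times per evaluation, for at most 5 evaluations.
final_values = evaluate_repeatedly(
    checkpoint_dir='/tmp/train_logs',  # placeholder
    eval_ops=update_op,
    final_ops={'accuracy': accuracy},
    hooks=[tf.contrib.training.StopAfterNEvalsHook(10)],
    eval_interval_secs=60,
    max_number_of_evaluations=5,
    timeout=600)
```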
  def _model_fn_from_saved_model(self, features, labels, mode):
    """Load a SavedModel graph and return an EstimatorSpec."""
    # TODO(kathywu): Model function loads placeholders from the graph. Calling
    # export_all_saved_models creates another placeholder for the inputs, on top
    # of the original placeholders. There should be a way to avoid this.
    self._validate_mode(mode)

    g = ops.get_default_graph()
    if training_util.get_global_step(g) is not None:
      raise RuntimeError(
          'Graph must not contain a global step tensor before the SavedModel is'
          ' loaded. Please make sure that the input function does not create a '
          'global step.')

    # Extract SignatureDef for information about the input and output tensors.
    signature_def = self._get_signature_def_for_mode(mode)

    # Generate input map for replacing the inputs in the SavedModel graph with
    # the provided features and labels.
    input_map = _generate_input_map(signature_def, features, labels)

    # Create a list of the names of output tensors. When the graph is loaded,
    # names of the output tensors may be remapped. This ensures that the correct
    # tensors are returned in the EstimatorSpec.
    output_tensor_names = [
        value.name for value in six.itervalues(signature_def.outputs)]

    # Load the graph. `output_tensors` contains output `Tensors` in the same
    # order as the `output_tensor_names` list.
    tags = model_fn_lib.EXPORT_TAG_MAP[mode]
    _, output_tensors = self.saved_model_loader.load_graph(
        g, tags, input_map=input_map, return_elements=output_tensor_names)

    # Create a scaffold from the MetaGraphDef that contains ops to initialize
    # the graph. This should mirror the steps from _add_meta_graph_for_mode(),
    # which creates a MetaGraphDef from the EstimatorSpec's scaffold.
    scaffold = monitored_session.Scaffold(
        local_init_op=loader_impl._get_main_op_tensor(  # pylint: disable=protected-access
            self._get_meta_graph_def_for_mode(mode)))

    # Ensure that a global step tensor has been created.
    global_step_tensor = training_util.get_global_step(g)
    training_util.assert_global_step(global_step_tensor)

    # Extract values to return in the EstimatorSpec.
    output_map = dict(zip(output_tensor_names, output_tensors))
    outputs = {key: output_map[value.name]
               for key, value in six.iteritems(signature_def.outputs)}

    loss, predictions, metrics = _validate_and_extract_outputs(
        mode, outputs, signature_def.method_name)

    train_op = ops.get_collection(constants.TRAIN_OP_KEY)
    if len(train_op) > 1:
      raise RuntimeError('Multiple ops found in the train_op collection.')
    train_op = None if not train_op else train_op[0]

    _clear_saved_model_collections()
    return model_fn_lib.EstimatorSpec(
        scaffold=scaffold,
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions,
        eval_metric_ops=metrics)
Example #6
def train(args):
    """Train CIFAR-10 for a number of steps.

  Args:
    args: The command line arguments.
  """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Create the global step.
        global_step = tf.contrib.framework.create_global_step()

        # Calculate the learning rate schedule.
        num_batches_per_epoch = (cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
                                 args.batch_size)
        decay_steps = int(num_batches_per_epoch * cifar10.NUM_EPOCHS_PER_DECAY)

        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(cifar10.INITIAL_LEARNING_RATE,
                                        global_step,
                                        decay_steps,
                                        cifar10.LEARNING_RATE_DECAY_FACTOR,
                                        staircase=True)

        # Create an optimizer that performs gradient descent.
        opt = tf.train.GradientDescentOptimizer(lr)

        # Calculate the gradients for each model tower.
        tower_grads = []
        for i in xrange(args.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
                    # Calculate the loss for one tower of the CIFAR model. This function
                    # constructs the entire CIFAR model but shares the variables across
                    # all towers.
                    loss = tower_loss(scope, args)

                    # Reuse variables for the next tower.
                    tf.get_variable_scope().reuse_variables()

                    # Retain the summaries from the final tower.
                    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                  scope)

                    # Calculate the gradients for the batch of data on this CIFAR tower.
                    grads = opt.compute_gradients(loss)

                    # Keep track of the gradients across all towers.
                    tower_grads.append(grads)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = average_gradients(tower_grads)

        # Add a summary to track the learning rate.
        summaries.append(tf.summary.scalar('learning_rate', lr))

        # Add histograms for gradients.
        for grad, var in grads:
            if grad is not None:
                summaries.append(
                    tf.summary.histogram(var.op.name + '/gradients', grad))

        # Apply the gradients to adjust the shared variables.
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            summaries.append(tf.summary.histogram(var.op.name, var))

        # Track the moving averages of all trainable variables.
        # To understand why the following line is necessary, see:
        # https://github.com/carpedm20/DCGAN-tensorflow/issues/59
        with tf.variable_scope(tf.get_variable_scope(), reuse=False):
            variable_averages = tf.train.ExponentialMovingAverage(
                cifar10.MOVING_AVERAGE_DECAY, global_step)
            variables_averages_op = variable_averages.apply(
                tf.trainable_variables())

        # Group all updates to into a single train op.
        train_op = tf.group(apply_gradient_op, variables_averages_op)

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge(summaries)

        scaffold = monitored_session.Scaffold(summary_op=summary_op)

        # allow_soft_placement must be set to True to build towers on GPU, as some
        # of the ops do not have GPU implementations.
        session_creator = monitored_session.ChiefSessionCreator(
            scaffold,
            checkpoint_dir=args.train_dir,
            config=tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=args.log_device_placement))

        hooks = [
            # Hook to save the model every N steps and at the end.
            basic_session_run_hooks.CheckpointSaverHook(
                args.train_dir,
                checkpoint_basename=CHECKPOINT_BASENAME,
                save_steps=args.checkpoint_interval_steps,
                scaffold=scaffold),

            # Hook to save a summary every N steps.
            basic_session_run_hooks.SummarySaverHook(
                save_steps=args.summary_interval_steps,
                output_dir=args.train_dir,
                scaffold=scaffold),

            # Hook to stop at step N.
            basic_session_run_hooks.StopAtStepHook(
                last_step=args.train_max_steps)
        ]

        # Start a new monitored session. This will automatically restart the
        # sessions if the parameter servers are preempted.
        with monitored_session.MonitoredSession(
                session_creator=session_creator, hooks=hooks) as sess:

            while not sess.should_stop():
                start_time = time.time()
                _, loss_value, global_step_value = sess.run(
                    [train_op, loss, global_step])
                duration = time.time() - start_time

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                if global_step_value % 10 == 0:
                    num_examples_per_step = args.batch_size * args.num_gpus
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = duration / args.num_gpus

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    logging.info(format_str %
                                 (datetime.now(), global_step_value,
                                  loss_value, examples_per_sec, sec_per_batch))
Example #7
  def _model_fn(features, labels, mode):
    """Function that returns predictions, training loss, and training op."""
    weights = None
    if weights_name and weights_name in features:
      weights = features.pop(weights_name)

    # If we're doing eval, optionally ignore device_assigner.
    # Also ignore device assigner if we're exporting (mode == INFER)
    dev_assn = device_assigner
    if (mode == model_fn_lib.ModeKeys.INFER or
        (local_eval and mode == model_fn_lib.ModeKeys.EVAL)):
      dev_assn = None

    graph_builder = graph_builder_class(params,
                                        device_assigner=dev_assn)
    inference = {}
    if (mode == model_fn_lib.ModeKeys.EVAL or
        mode == model_fn_lib.ModeKeys.INFER):
      inference[eval_metrics.INFERENCE_PROB_NAME] = (
          graph_builder.inference_graph(features))

      if not params.regression:
        inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax(
            inference[eval_metrics.INFERENCE_PROB_NAME], 1)

      if report_feature_importances:
        inference[eval_metrics.FEATURE_IMPORTANCE_NAME] = (
            graph_builder.feature_importances())

    # labels might be None if we're doing prediction (which brings up the
    # question of why we force everything to adhere to a single model_fn).
    loss_deps = []
    training_graph = None
    training_hooks = []
    scaffold = None
    if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN:
      training_graph = control_flow_ops.group(
          graph_builder.training_graph(
              features, labels, input_weights=weights,
              num_trainers=num_trainers,
              trainer_id=trainer_id),
          state_ops.assign_add(contrib_framework.get_global_step(), 1))
      loss_deps.append(training_graph)
      if hasattr(graph_builder, 'finalize_training'):
        finalize_listener = EveryCheckpointPreSaveListener(
            graph_builder.finalize_training())
        scaffold = monitored_session.Scaffold()
        training_hooks.append(
            basic_session_run_hooks.CheckpointSaverHook(
                model_dir, save_secs=600, save_steps=None,
                scaffold=scaffold,
                listeners=[finalize_listener]))

    training_loss = None
    if (mode == model_fn_lib.ModeKeys.EVAL or
        mode == model_fn_lib.ModeKeys.TRAIN):
      with ops.control_dependencies(loss_deps):
        training_loss = graph_builder.training_loss(
            features, labels, name=LOSS_NAME)

    # Put weights back in
    if weights is not None:
      features[weights_name] = weights

    if early_stopping_rounds:
      training_hooks.append(TensorForestLossHook(early_stopping_rounds))

    return model_fn_lib.ModelFnOps(
        mode=mode,
        predictions=inference,
        loss=training_loss,
        train_op=training_graph,
        training_hooks=training_hooks,
        scaffold=scaffold)
Example #8
def PartialRestoreSession(
        master='',  # pylint: disable=invalid-name
        is_chief=True,
        checkpoint_dir=None,
        restore_var_list=None,
        scaffold=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=600,
        save_summaries_steps=monitored_session.USE_DEFAULT,
        save_summaries_secs=monitored_session.USE_DEFAULT,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100):
    """Creates a `MonitoredSession` for training.

    Supports partial restoration from checkpoints with parameter
    `restore_var_list`, by adding `CheckpointRestorerHook`.

  For a chief, this utility sets proper session initializer/restorer. It also
  creates hooks related to checkpoint and summary saving. For workers, this
  utility sets proper session creator which waits for the chief to
  initialize/restore. Please check `tf.train.MonitoredSession` for more
  information.


  Args:
    master: `String` the TensorFlow master to use.
    is_chief: If `True`, it will take care of initializing and recovering the
      underlying TensorFlow session. If `False`, it will wait on a chief to
      initialize or recover the TensorFlow session.
    checkpoint_dir: A string.  Optional path to a directory where to restore
      variables.
    restore_var_list: An optional list of variables to restore from the
      checkpoint, if not all variables should be recovered. Useful when changing
      network structures during training, e.g., fine-tuning a pretrained model
      with new layers.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If
      not specified, a default one is created. It's used to finalize the graph.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: List of `SessionRunHook` objects. These hooks are
      activated only if `is_chief == True` and ignored otherwise.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If `save_checkpoint_secs` is set to
      `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used. Default 100.
    save_summaries_secs: The frequency, in secs, that the summaries are written
      to disk using a default summary saver.  If both `save_summaries_steps` and
      `save_summaries_secs` are set to `None`, then the default summary saver
      isn't used. Default not enabled.
    config: An instance of `tf.ConfigProto` used to configure the session. It is
      the `config` argument passed to the `tf.Session` constructor.
    stop_grace_period_secs: Number of seconds given to threads to stop after
      `close()` has been called.
    log_step_count_steps: The frequency, in number of global steps, that the
      global step/sec is logged.

  Returns:
    A `MonitoredSession` object.
  """
    if save_summaries_steps == monitored_session.USE_DEFAULT \
            and save_summaries_secs == monitored_session.USE_DEFAULT:
        save_summaries_steps = 100
        save_summaries_secs = None
    elif save_summaries_secs == monitored_session.USE_DEFAULT:
        save_summaries_secs = None
    elif save_summaries_steps == monitored_session.USE_DEFAULT:
        save_summaries_steps = None

    scaffold = scaffold or monitored_session.Scaffold()
    if not is_chief:
        session_creator = monitored_session.WorkerSessionCreator(
            scaffold=scaffold, master=master, config=config)
        return monitored_session.MonitoredSession(
            session_creator=session_creator,
            hooks=hooks or [],
            stop_grace_period_secs=stop_grace_period_secs)

    all_hooks = []
    if chief_only_hooks:
        all_hooks.extend(chief_only_hooks)
    if restore_var_list is None:
        restore_checkpoint_dir = checkpoint_dir
    else:
        restore_checkpoint_dir = None
        all_hooks.append(
            CheckpointRestorerHook(checkpoint_dir, var_list=restore_var_list))
        all_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        missing_vars = [v for v in all_vars if v not in restore_var_list]
        logging.warning("MonitoredTrainingSession not restoring %s",
                        missing_vars)
    session_creator = monitored_session.ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_dir=restore_checkpoint_dir,
        master=master,
        config=config)

    if checkpoint_dir:
        all_hooks.append(
            basic_session_run_hooks.StepCounterHook(
                output_dir=checkpoint_dir, every_n_steps=log_step_count_steps))

        if (save_summaries_steps
                and save_summaries_steps > 0) or (save_summaries_secs
                                                  and save_summaries_secs > 0):
            all_hooks.append(
                basic_session_run_hooks.SummarySaverHook(
                    scaffold=scaffold,
                    save_steps=save_summaries_steps,
                    save_secs=save_summaries_secs,
                    output_dir=checkpoint_dir))
        if save_checkpoint_secs and save_checkpoint_secs > 0:
            all_hooks.append(
                basic_session_run_hooks.CheckpointSaverHook(
                    checkpoint_dir,
                    save_secs=save_checkpoint_secs,
                    scaffold=scaffold))

    if hooks:
        all_hooks.extend(hooks)
    return monitored_session.MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks,
        stop_grace_period_secs=stop_grace_period_secs)
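A hypothetical fine-tuning sketch for `PartialRestoreSession` above: only `old_var` is restored from the checkpoint directory, while `new_var` and the global step are freshly initialized (the restore itself is assumed to be handled by the `CheckpointRestorerHook` referenced in the function). The directory path is a placeholder.

```python
import tensorflow as tf

old_var = tf.get_variable('old_var', shape=[10])   # restored from the checkpoint
new_var = tf.get_variable('new_var', shape=[10])   # freshly initialized
loss = tf.reduce_sum(tf.square(old_var - new_var))
global_step = tf.train.get_or_create_global_step()
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
    loss, global_step=global_step)

with PartialRestoreSession(
        checkpoint_dir='/tmp/pretrained_ckpt',  # placeholder
        restore_var_list=[old_var],
        hooks=[tf.train.StopAtStepHook(last_step=100)]) as sess:
    while not sess.should_stop():
        sess.run(train_op)
```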
Example #9
    def model_fn(features, labels, mode):
        """model_fn for keras Estimator."""
        model = _clone_and_build_model(mode=mode,
                                       keras_model=keras_model,
                                       custom_objects=custom_objects,
                                       features=features,
                                       labels=labels,
                                       optimizer_config=optimizer_config)
        model_output_names = []
        # We need to make sure that the output names of the last layer in the
        # model are the same for each of the cloned models. This is required for
        # mirrored strategy when we call regroup.
        if distribution_strategy_context.has_strategy():
            for name in model.output_names:
                name = re.compile(r'_\d$').sub('', name)
                model_output_names.append(name)
        else:
            model_output_names = model.output_names

        # Get inputs to EstimatorSpec
        predictions = dict(zip(model_output_names, model.outputs))

        loss = None
        train_op = None
        eval_metric_ops = None

        # Set loss and metric only during train and evaluate.
        if mode is not ModeKeys.PREDICT:
            if mode is ModeKeys.TRAIN:
                model._make_train_function()  # pylint: disable=protected-access
            else:
                model._make_test_function()  # pylint: disable=protected-access
            loss = model.total_loss

            eval_metric_ops = _convert_keras_metrics_to_estimator(model)

        # Set train_op only during train.
        if mode is ModeKeys.TRAIN:
            train_op = model.train_function.updates_op

        if not model._is_graph_network:
            # Reset model state to original state,
            # to avoid `model_fn` being destructive for the initial model argument.
            models.in_place_subclassed_model_state_restoration(keras_model)

        scaffold = None
        if save_object_ckpt:
            model._track_trackable(training_util.get_global_step(),
                                   'estimator_global_step')
            # Create saver that maps variable names to object-checkpoint keys.
            object_graph = graph_view.ObjectGraphView(model)
            var_list = object_graph.frozen_saveable_objects()
            saver = saver_lib.Saver(var_list=var_list, sharded=True)
            saver._object_restore_saver = trackable_util.frozen_saver(model)
            scaffold = monitored_session.Scaffold(saver=saver)

        return model_fn_lib.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            export_outputs={
                _DEFAULT_SERVING_KEY: export_lib.PredictOutput(predictions)
            },
            scaffold=scaffold)
Example #10
    def __new__(cls,
                mode,
                predictions=None,
                loss=None,
                train_op=None,
                eval_metric_ops=None,
                export_outputs=None,
                training_chief_hooks=None,
                training_hooks=None,
                scaffold=None):
        """Creates a validated `EstimatorSpec` instance.

    Depending on the value of `mode`, different arguments are required. Namely
    * For `mode == ModeKeys.TRAIN`: required fields are `loss` and `train_op`.
    * For `mode == ModeKeys.EVAL`: required field is `loss`.
    * For `mode == ModeKeys.PREDICT`: required field is `predictions`.

    model_fn can populate all arguments independent of mode. In this case, some
    arguments will be ignored by `Estimator`. E.g. `train_op` will be ignored
    in eval and infer modes. Example:

    ```python
    def my_model_fn(mode, features, labels):
      predictions = ...
      loss = ...
      train_op = ...
      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=predictions,
          loss=loss,
          train_op=train_op)
    ```

    Alternatively, model_fn can just populate the arguments appropriate to the
    given mode. Example:

    ```python
    def my_model_fn(mode, features, labels):
      if (mode == tf.estimator.ModeKeys.TRAIN or
          mode == tf.estimator.ModeKeys.EVAL):
        loss = ...
      else:
        loss = None
      if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = ...
      else:
        train_op = None
      if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = ...
      else:
        predictions = None

      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=predictions,
          loss=loss,
          train_op=train_op)
    ```

    Args:
      mode: A `ModeKeys`. Specifies if this is training, evaluation or
        prediction.
      predictions: Predictions `Tensor` or dict of `Tensor`.
      loss: Training loss `Tensor`. Must be either scalar, or with shape `[1]`.
      train_op: Op for the training step.
      eval_metric_ops: Dict of metric results keyed by name. The values of the
        dict are the results of calling a metric function, namely a
        `(metric_tensor, update_op)` tuple.
      export_outputs: Describes the output signature to be exported to
        `SavedModel` and used during serving.
        A dict `{name: (signature_method_name, predictions)}` where:
        * name: An arbitrary name for this output.
        * signature_method_name: One of the *_METHOD_NAME constants defined in
          `signature_constants`, such as
          `tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME`. Describes
          the type of `SignatureDef` to be exported.
        * predictions: Predictions `Tensor` or dict of `Tensor`.
        Single-headed models only need to specify one entry in this dictionary.
        Multi-headed models should specify one entry for each head.
      training_chief_hooks: A list of `tf.train.SessionRunHook` objects to
        run on the chief worker during training.
      training_hooks: A list of `tf.train.SessionRunHook` objects to run on
        all workers during training.
      scaffold: A `tf.train.Scaffold` object that can be used to set
        initialization, saver, and more to be used in training.

    Returns:
      A validated `EstimatorSpec` object.

    Raises:
      ValueError: If validation fails.
      TypeError: If any of the arguments is not the expected type.
    """
        # Validate train_op.
        if train_op is None:
            if mode == ModeKeys.TRAIN:
                raise ValueError('Missing train_op.')
        else:
            _check_is_tensor_or_operation(train_op, 'train_op')

        # Validate loss.
        if loss is None:
            if mode in (ModeKeys.TRAIN, ModeKeys.EVAL):
                raise ValueError('Missing loss.')
        else:
            loss = _check_is_tensor(loss, 'loss')
            loss_shape = loss.get_shape()
            if loss_shape.num_elements() not in (None, 1):
                raise ValueError('Loss must be scalar, given: {}'.format(loss))
            if not loss_shape.is_compatible_with(tensor_shape.scalar()):
                loss = array_ops.reshape(loss, [])

        # Validate predictions.
        if predictions is None:
            if mode == ModeKeys.PREDICT:
                raise ValueError('Missing predictions.')
            predictions = {}
        else:
            if isinstance(predictions, dict):
                predictions = {
                    k: _check_is_tensor(v, 'predictions[{}]'.format(k))
                    for k, v in six.iteritems(predictions)
                }
            else:
                predictions = _check_is_tensor(predictions, 'predictions')

        # Validate eval_metric_ops.
        if eval_metric_ops is None:
            eval_metric_ops = {}
        else:
            if not isinstance(eval_metric_ops, dict):
                raise TypeError(
                    'eval_metric_ops must be a dict, given: {}'.format(
                        eval_metric_ops))
            for key, metric_value in six.iteritems(eval_metric_ops):
                if (not isinstance(metric_value, tuple)
                        or len(metric_value) != 2):
                    raise TypeError(
                        'Values of eval_metric_ops must be (metric_tensor, update_op) '
                        'tuples, given: {} for key: {}'.format(
                            metric_value, key))
                _check_is_tensor_or_operation(
                    metric_value[0], 'eval_metric_ops[{}]'.format(key))
                _check_is_tensor_or_operation(
                    metric_value[1], 'eval_metric_ops[{}]'.format(key))

        # Validate export_outputs.
        if export_outputs is not None:
            if not isinstance(export_outputs, dict):
                raise TypeError(
                    'export_outputs must be dict, given: {}'.format(
                        export_outputs))
            for v in six.itervalues(export_outputs):
                if not isinstance(v, tuple) or len(v) != 2:
                    raise TypeError(
                        'Values in export_outputs must be 2-tuple, given: {}'.
                        format(export_outputs))
                if v[0] not in (signature_constants.CLASSIFY_METHOD_NAME,
                                signature_constants.PREDICT_METHOD_NAME,
                                signature_constants.REGRESS_METHOD_NAME):
                    raise ValueError(
                        'Invalid signature_method_name in export_outputs, '
                        'given: {}'.format(export_outputs))

        # Validate that all tensors and ops are from the default graph.
        default_graph = ops.get_default_graph()
        for value in _prediction_values(predictions):
            if value.graph is not default_graph:
                raise ValueError(
                    'prediction values must be from the default graph.')
        if loss is not None and loss.graph is not default_graph:
            raise ValueError('loss must be from the default graph.')
        if train_op is not None and train_op.graph is not default_graph:
            raise ValueError('train_op must be from the default graph.')
        for value in _eval_metric_ops_values(eval_metric_ops):
            if value.graph is not default_graph:
                raise ValueError(
                    'eval_metric_ops values must be from the default graph.')

        # Validate hooks.
        if training_chief_hooks is None:
            training_chief_hooks = []
        if training_hooks is None:
            training_hooks = []
        for hook in training_hooks + training_chief_hooks:
            if not isinstance(hook, session_run_hook.SessionRunHook):
                raise TypeError(
                    'All hooks must be SessionRunHook instances, given: {}'.
                    format(hook))

        scaffold = scaffold or monitored_session.Scaffold()
        # Validate scaffold.
        if not isinstance(scaffold, monitored_session.Scaffold):
            raise TypeError(
                'scaffold must be tf.train.Scaffold. Given: {}'.format(
                    scaffold))

        return super(EstimatorSpec,
                     cls).__new__(cls,
                                  predictions=predictions,
                                  loss=loss,
                                  train_op=train_op,
                                  eval_metric_ops=eval_metric_ops,
                                  export_outputs=export_outputs,
                                  training_chief_hooks=training_chief_hooks,
                                  training_hooks=training_hooks,
                                  scaffold=scaffold)
Example #11
def evaluate_once(checkpoint_path,
                  master='',
                  scaffold=None,
                  eval_ops=None,
                  feed_dict=None,
                  final_ops=None,
                  final_ops_feed_dict=None,
                  variables_to_restore=None,
                  hooks=None,
                  config=None):
    """Evaluates the model at the given checkpoint path.

  During a single evaluation, the `eval_ops` is run until the session is
  interrupted or requested to finish. This is typically requested via a
  `tf.contrib.training.StopAfterNEvalsHook` which results in `eval_ops` running
  the requested number of times.

  Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of
  `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is evaluated
  a single time after `eval_ops` has finished running and the fetched values of
  `final_ops` are returned. If `final_ops` is left as `None`, then `None` is
  returned.

  One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record
  summaries after the `eval_ops` have run. If `eval_ops` is `None`, the summaries
  run immediately after the model checkpoint has been restored.

  Note that `evaluate_once` creates a local variable used to track the number of
  evaluations run via `tf.contrib.training.get_or_create_eval_step`.
  Consequently, if a custom local init op is provided via a `scaffold`, the
  caller should ensure that the local init op also initializes the eval step.

  Args:
    checkpoint_path: The path to a checkpoint to use for evaluation.
    master: The BNS address of the TensorFlow master.
    scaffold: A `tf.train.Scaffold` instance for initializing and restoring
      variables. Note that `scaffold.init_fn` is used by the function
      to restore the checkpoint. If you supply a custom init_fn, then it must
      also take care of restoring the model from its checkpoint.
    eval_ops: An operation that is run until the session is requested to stop,
      commonly done by a `tf.contrib.training.StopAfterNEvalsHook`.
    feed_dict: The feed dictionary to use when executing the `eval_ops`.
    final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to
      `Tensors`.
    final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
    variables_to_restore: A list of TensorFlow variables to restore during
      evaluation. If the argument is left as `None` then
      tf.contrib.framework.get_variables_to_restore() is used.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      evaluation loop.
    config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.

  Returns:
    The fetched values of `final_ops` or `None` if `final_ops` is `None`.
  """
    eval_step = get_or_create_eval_step()

    if eval_ops is not None:
        eval_ops = control_flow_ops.with_dependencies([eval_ops],
                                                      state_ops.assign_add(
                                                          eval_step, 1))

    # Must come before the scaffold check.
    if scaffold and scaffold.saver:
        saver = scaffold.saver
    else:
        saver = tf_saver.Saver(variables_to_restore
                               or variables.get_variables_to_restore(),
                               write_version=saver_pb2.SaverDef.V2)

    scaffold = scaffold or monitored_session.Scaffold()
    scaffold = _scaffold_with_init(scaffold, saver, checkpoint_path)

    logging.info('Starting evaluation at ' +
                 time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))

    # Prepare the session creator.
    session_creator = monitored_session.ChiefSessionCreator(
        scaffold=scaffold, checkpoint_dir=None, master=master, config=config)

    # Prepare the run hooks.
    hooks = hooks or []

    final_ops_hook = _FinalOpsHook(final_ops, final_ops_feed_dict)
    hooks.append(final_ops_hook)

    with monitored_session.MonitoredSession(session_creator=session_creator,
                                            hooks=hooks) as session:
        if eval_ops is not None:
            while not session.should_stop():
                session.run(eval_ops, feed_dict)

    logging.info('Finished evaluation at ' +
                 time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
    return final_ops_hook.final_ops_values
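A minimal usage sketch of this `evaluate_once` variant; unlike the TF-Slim wrapper shown earlier, the length of the evaluation loop is controlled explicitly with a `StopAfterNEvalsHook`. The checkpoint path is a placeholder.

```python
import tensorflow as tf

labels = tf.constant([1, 0, 1, 1])
predictions = tf.constant([1, 0, 0, 1])
accuracy, update_op = tf.metrics.accuracy(labels, predictions)

final_values = evaluate_once(
    checkpoint_path='/tmp/train_logs/model.ckpt-5000',  # placeholder
    eval_ops=update_op,
    final_ops={'accuracy': accuracy},
    hooks=[tf.contrib.training.StopAfterNEvalsHook(10)])
```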
Example #12
    def testTrainWithInitFromCheckpoint(self):
        logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
        logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')

        if gfile.Exists(logdir1):  # For running on jenkins.
            gfile.DeleteRecursively(logdir1)
        if gfile.Exists(logdir2):  # For running on jenkins.
            gfile.DeleteRecursively(logdir2)

        # First, train the model one step (make sure the error is high).
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            train_op = self.create_train_op()
            saver = saver_lib.Saver()
            loss = training.train(
                train_op,
                logdir1,
                hooks=[
                    basic_session_run_hooks.CheckpointSaverHook(logdir1,
                                                                save_steps=1,
                                                                saver=saver),
                    basic_session_run_hooks.StopAtStepHook(num_steps=1),
                ],
                save_checkpoint_secs=None,
                save_summaries_steps=None)
            self.assertGreater(loss, .5)

        # Next, train the model to convergence.
        with ops.Graph().as_default():
            random_seed.set_random_seed(1)
            train_op = self.create_train_op()
            saver = saver_lib.Saver()
            loss = training.train(
                train_op,
                logdir1,
                hooks=[
                    basic_session_run_hooks.CheckpointSaverHook(logdir1,
                                                                save_steps=300,
                                                                saver=saver),
                    basic_session_run_hooks.StopAtStepHook(num_steps=300),
                ],
                save_checkpoint_secs=None,
                save_summaries_steps=None)
            self.assertIsNotNone(loss)
            self.assertLess(loss, .02)

        # Finally, advance the model a single step and validate that the loss is
        # still low.
        with ops.Graph().as_default():
            random_seed.set_random_seed(2)
            train_op = self.create_train_op()

            model_variables = variables_lib2.global_variables()
            model_path = saver_lib.latest_checkpoint(logdir1)

            assign_fn = variables_lib.assign_from_checkpoint_fn(
                model_path, model_variables)

            def init_fn(_, session):
                assign_fn(session)

            loss = training.train(
                train_op,
                None,
                scaffold=monitored_session.Scaffold(init_fn=init_fn),
                hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)],
                save_checkpoint_secs=None,
                save_summaries_steps=None)

            self.assertIsNotNone(loss)
            self.assertLess(loss, .02)
Example #13
def mnist_model(learning_rate, use_two_conv, use_two_fc, hparam):
    if FLAGS.log_dir is None or FLAGS.log_dir == "":
        raise ValueError("Must specify an explicit `log_dir`")
    if FLAGS.data_dir is None or FLAGS.data_dir == "":
        raise ValueError("Must specify an explicit `data_dir`")

    tf.reset_default_graph()
    device, target = device_and_target()
    with tf.device(device):
        global_step = slim.get_or_create_global_step()
        # Setup placeholders, and reshape the data
        x = tf.placeholder(tf.float32, shape=[None, 784], name="x")
        x_image = tf.reshape(x, [-1, 28, 28, 1])
        tf.summary.image('input', x_image, 3)
        y = tf.placeholder(tf.float32, shape=[None, 10], name="labels")

        if use_two_conv:
            conv1 = conv_layer(x_image, 1, 32, "conv1")
            conv_out = conv_layer(conv1, 32, 64, "conv2")
        else:
            conv1 = conv_layer(x_image, 1, 64, "conv")
            conv_out = tf.nn.max_pool(conv1,
                                      ksize=[1, 2, 2, 1],
                                      strides=[1, 2, 2, 1],
                                      padding="SAME")

        flattened = tf.reshape(conv_out, [-1, 7 * 7 * 64])

        if use_two_fc:
            fc1 = fc_layer(flattened, 7 * 7 * 64, 1024, "fc1")
            embedding_input = fc1
            embedding_size = 1024
            logits = fc_layer(fc1, 1024, 10, "fc2")
        else:
            embedding_input = flattened
            embedding_size = 7 * 7 * 64
            logits = fc_layer(flattened, 7 * 7 * 64, 10, "fc")

        with tf.name_scope("cross_entropy"):
            xent = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=y),
                                  name="cross_entropy")
            tf.summary.scalar("cross_entropy", xent)

        with tf.name_scope("train"):
            train_step = tf.train.AdamOptimizer(learning_rate).minimize(xent)

        with tf.name_scope("accuracy"):
            correct_prediction = tf.equal(tf.argmax(logits, 1),
                                          tf.argmax(y, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
            tf.summary.scalar("accuracy", accuracy)
            accuracy = tf.Print(accuracy, [accuracy], message="Accuracy = ")
        summ = tf.summary.merge_all()

        embedding = tf.Variable(tf.zeros([1024, embedding_size]),
                                name="test_embedding")
        assignment = embedding.assign(embedding_input)
        saver = tf.train.Saver(save_relative_paths=True)
        scaffold = monitored_session.Scaffold(saver=saver)

        writer = tf.summary.FileWriter(FLAGS.log_dir)

        config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
        embedding_config = config.embeddings.add()
        embedding_config.tensor_name = embedding.name
        embedding_config.sprite.image_path = './sprite_1024.png'
        embedding_config.metadata_path = './labels_1024.tsv'
        # Specify the width and height of a single thumbnail.
        embedding_config.sprite.single_image_dim.extend([28, 28])

    with tf.train.MonitoredTrainingSession(master=target,
                                           is_chief=(FLAGS.task_index == 0),
                                           checkpoint_dir=FLAGS.log_dir,
                                           scaffold=scaffold) as sess:

        writer.add_graph(sess.graph)
        tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
            writer, config)

        for i in range(2001):
            batch = mnist.train.next_batch(100)
            if FLAGS.task_index == 0:
                if i % 5 == 0:
                    [train_accuracy, s] = sess.run([accuracy, summ],
                                                   feed_dict={
                                                       x: batch[0],
                                                       y: batch[1]
                                                   })
                    writer.add_summary(s, i)
                if i % 500 == 0:
                    sess.run(assignment,
                             feed_dict={
                                 x: mnist.test.images[:1024],
                                 y: mnist.test.labels[:1024]
                             })
            sess.run(train_step, feed_dict={x: batch[0], y: batch[1]})
Example #14
    def _train_model(self, input_fn, hooks):
        all_hooks = []
        self._graph = ops.Graph()
        with self._graph.as_default() as g, g.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step = training.get_or_create_global_step(g)
            features, labels = input_fn()
            estimator_spec = self._call_model_fn(features, labels,
                                                 ModeKeys.TRAIN)
            all_hooks.extend([
                plx_hooks.NanTensorHook(estimator_spec.loss),
                plx_hooks.LoggingTensorHook(
                    {
                        'loss': estimator_spec.loss,
                        'step': global_step
                    },
                    every_n_iter=100)
            ])
            all_hooks.extend(hooks)
            all_hooks.extend(estimator_spec.training_hooks)

            scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
            if not (scaffold.saver
                    or ops.get_collection(ops.GraphKeys.SAVERS)):
                ops.add_to_collection(
                    ops.GraphKeys.SAVERS,  # TODO remove non restorable vars
                    saver.Saver(
                        sharded=True,  # TODO `var_list`
                        max_to_keep=self._config.keep_checkpoint_max,
                        defer_build=True))

            chief_hooks = []
            if self._config.save_checkpoints_secs or self._config.save_checkpoints_steps:
                saver_hook_exists = any([
                    isinstance(h, plx_hooks.CheckpointSaverHook)
                    for h in (all_hooks + estimator_spec.training_hooks +
                              chief_hooks +
                              estimator_spec.training_chief_hooks)
                ])
                if not saver_hook_exists:
                    chief_hooks = [
                        plx_hooks.CheckpointSaverHook(
                            self._model_dir,
                            save_secs=self._config.save_checkpoints_secs,
                            save_steps=self._config.save_checkpoints_steps,
                            scaffold=scaffold)
                    ]
            with monitored_session.MonitoredTrainingSession(
                    master=self._config.master,
                    is_chief=self._config.is_chief,
                    checkpoint_dir=self._model_dir,
                    scaffold=scaffold,
                    hooks=all_hooks + estimator_spec.training_hooks,
                    chief_only_hooks=chief_hooks +
                    estimator_spec.training_chief_hooks,
                    save_checkpoint_secs=0,  # Saving is handled by a hook.
                    save_summaries_steps=self._config.save_summary_steps,
                    config=self._session_config) as mon_sess:
                loss = None
                while not mon_sess.should_stop():
                    _, loss = mon_sess.run(
                        [estimator_spec.train_op, estimator_spec.loss])
            summary_io.SummaryWriterCache.clear()
            return loss
def evaluation_loop(master,
                    checkpoint_dir,
                    logdir,
                    num_evals=1,
                    initial_op=None,
                    initial_op_feed_dict=None,
                    init_fn=None,
                    eval_op=None,
                    eval_op_feed_dict=None,
                    final_op=None,
                    final_op_feed_dict=None,
                    summary_op=_USE_DEFAULT,
                    summary_op_feed_dict=None,
                    variables_to_restore=None,
                    eval_interval_secs=60,
                    max_number_of_evaluations=None,
                    session_config=None,
                    timeout=None,
                    hooks=None):
    """Runs TF-Slim's Evaluation Loop.

  Args:
    master: The BNS address of the TensorFlow master.
    checkpoint_dir: The directory where checkpoints are stored.
    logdir: The directory where the TensorFlow summaries are written to.
    num_evals: The number of times to run `eval_op`.
    initial_op: An operation run at the beginning of evaluation.
    initial_op_feed_dict: A feed dictionary to use when executing `initial_op`.
    init_fn: An optional callable to be executed after `init_op` is called. The
      callable must accept one argument, the session being initialized.
    eval_op: An operation run `num_evals` times.
    eval_op_feed_dict: The feed dictionary to use when executing the `eval_op`.
    final_op: An operation to execute after all of the `eval_op` executions. The
      value of `final_op` is returned.
    final_op_feed_dict: A feed dictionary to use when executing `final_op`.
    summary_op: The summary op to evaluate after running TF-Slim's metric ops.
      By default the summary_op is set to tf.summary.merge_all().
    summary_op_feed_dict: An optional feed dictionary to use when running the
      `summary_op`.
    variables_to_restore: A list of TensorFlow variables to restore during
      evaluation. If the argument is left as `None` then
      slim.variables.GetVariablesToRestore() is used.
    eval_interval_secs: The minimum number of seconds between evaluations.
    max_number_of_evaluations: The maximum number of evaluation iterations to
      run. If the value is left as `None`, the evaluation continues indefinitely.
    session_config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.
    timeout: The maximum amount of time to wait between checkpoints. If left as
      `None`, then the process will wait indefinitely.
    hooks: A list of additional `SessionRunHook` objects to pass during
      repeated evaluations.

  Returns:
    The value of `final_op` or `None` if `final_op` is `None`.
  """
    if summary_op == _USE_DEFAULT:
        summary_op = summary.merge_all()

    all_hooks = [
        evaluation.StopAfterNEvalsHook(num_evals),
    ]

    if summary_op is not None:
        all_hooks.append(
            evaluation.SummaryAtEndHook(log_dir=logdir,
                                        summary_op=summary_op,
                                        feed_dict=summary_op_feed_dict))

    if hooks is not None:
        # Add custom hooks if provided.
        all_hooks.extend(hooks)

    saver = None
    if variables_to_restore is not None:
        saver = tf_saver.Saver(variables_to_restore)

    return evaluation.evaluate_repeatedly(
        checkpoint_dir,
        master=master,
        scaffold=monitored_session.Scaffold(
            init_op=initial_op,
            init_feed_dict=initial_op_feed_dict,
            init_fn=init_fn,
            saver=saver),
        eval_ops=eval_op,
        feed_dict=eval_op_feed_dict,
        final_ops=final_op,
        final_ops_feed_dict=final_op_feed_dict,
        eval_interval_secs=eval_interval_secs,
        hooks=all_hooks,
        config=session_config,
        max_number_of_evaluations=max_number_of_evaluations,
        timeout=timeout)
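
The `eval_op` passed to this loop is typically a streaming-metric update op whose value accumulates across the `num_evals` runs. A minimal, hedged sketch of that mechanism, assuming TF 1.x with `tf.metrics` (the constant labels and predictions are illustrative):

```python
import tensorflow as tf

labels = tf.constant([1, 0, 1, 1])
predictions = tf.constant([1, 0, 0, 1])

# tf.metrics return a (value_tensor, update_op) pair; the update op is the kind
# of `eval_op` that evaluation_loop runs `num_evals` times per checkpoint.
accuracy, update_op = tf.metrics.accuracy(labels=labels, predictions=predictions)

with tf.Session() as sess:
    # Streaming metric state lives in local variables.
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)        # one evaluation step
    print(sess.run(accuracy))  # 0.75
```
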
Example #16
def _monitored_train(graph,
                     output_dir,
                     train_op,
                     loss_op,
                     global_step_tensor=None,
                     init_op=None,
                     init_feed_dict=None,
                     init_fn=None,
                     log_every_steps=10,
                     supervisor_is_chief=True,
                     supervisor_master='',
                     supervisor_save_model_secs=600,
                     supervisor_save_model_steps=None,
                     keep_checkpoint_max=5,
                     supervisor_save_summaries_steps=100,
                     feed_fn=None,
                     steps=None,
                     fail_on_nan_loss=True,
                     hooks=None,
                     max_steps=None):
    """Train a model via monitored_session.

  Given `graph`, a directory to write outputs to (`output_dir`), and some ops,
  run a training loop. The given `train_op` performs one step of training on the
  model and is expected to increment the `global_step_tensor`, a scalar integer
  tensor counting training steps. The `loss_op` represents the objective
  function of the training. This function uses a `MonitoredSession` to
  initialize the graph (from a checkpoint if one is available in `output_dir`),
  write summaries defined in the graph, and write regular checkpoints as defined
  by `supervisor_save_model_secs`.

  Training continues until `global_step_tensor` evaluates to `max_steps`, or, if
  `fail_on_nan_loss` is `True`, until `loss_op` evaluates to `NaN`, in which
  case a `NanLossDuringTrainingError` is raised.

  Args:
    graph: A graph to train. It is expected that this graph is not in use
      elsewhere.
    output_dir: A directory to write outputs to.
    train_op: An op that performs one training step when run.
    loss_op: A scalar loss tensor.
    global_step_tensor: A tensor representing the global step. If none is given,
      one is extracted from the graph using the same logic as in `Supervisor`.
    init_op: An op that initializes the graph. If `None`, the `Scaffold`
      default is used.
    init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
      This feed dictionary will be used when `init_op` is evaluated.
    init_fn: Optional callable passed to the `Scaffold` to initialize the model.
    log_every_steps: Output logs regularly. The logs contain timing data and the
      current loss. A `0` or negative value disables logging.
    supervisor_is_chief: Whether the current process is the chief supervisor in
      charge of restoring the model and running standard services.
    supervisor_master: The master string to use when preparing the session.
    supervisor_save_model_secs: Save checkpoints every this many seconds. Cannot
      be specified together with `supervisor_save_model_steps`.
    supervisor_save_model_steps: Save checkpoints every this many steps. Cannot
      be specified together with `supervisor_save_model_secs`.
    keep_checkpoint_max: The maximum number of recent checkpoint files to
      keep. As new files are created, older files are deleted. If None or 0,
      all checkpoint files are kept. This is simply passed as the max_to_keep
      arg to the `tf.train.Saver` constructor.
    supervisor_save_summaries_steps: Save summaries every
      `supervisor_save_summaries_steps` steps when training.
    feed_fn: A function that is called every iteration to produce a `feed_dict`
      passed to `session.run` calls. Optional.
    steps: Trains for this many steps (e.g. current global step + `steps`).
    fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if `loss_op`
      evaluates to `NaN`. If false, continue training as if nothing happened.
    hooks: List of `SessionRunHook` subclass instances. Used for callbacks
      inside the training loop.
    max_steps: Number of total steps for which to train the model. If `None`,
      train forever. Two calls to fit(steps=100) mean 200 training iterations.
      In contrast, with two calls to fit(max_steps=100), the second call performs
      no iterations, since the first call already did all 100 steps.

  Returns:
    The final loss value.

  Raises:
    ValueError: If `output_dir`, `train_op`, `loss_op`, or `global_step_tensor`
      is not provided. See `tf.contrib.framework.get_global_step` for how we
      look up the latter if not provided explicitly.
    NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss ever
      evaluates to `NaN`.
    ValueError: If both `steps` and `max_steps` are not `None`.
  """
    if (steps is not None) and (max_steps is not None):
        raise ValueError('Can not provide both steps and max_steps.')
    if not output_dir:
        raise ValueError('Output directory should be non-empty %s.' %
                         output_dir)
    if train_op is None:
        raise ValueError('Missing train_op.')
    if loss_op is None:
        raise ValueError('Missing loss_op.')
    if hooks is None:
        hooks = []
    if not isinstance(hooks, list):
        raise ValueError('Hooks should be a list.')
    with graph.as_default():
        global_step_tensor = contrib_variables.assert_or_get_global_step(
            graph, global_step_tensor)
    if global_step_tensor is None:
        raise ValueError(
            'No "global_step" was provided or found in the graph.')

    if max_steps is not None:
        try:
            start_step = checkpoints.load_variable(output_dir,
                                                   global_step_tensor.name)
            if max_steps <= start_step:
                logging.info(
                    'Skipping training since max_steps has already been reached.')
                return None
        except:  # pylint: disable=bare-except
            pass

    # Adapted SessionRunHooks such as ExportMonitor depend on the
    # CheckpointSaverHook being executed before they run.
    # The `hooks` param consists of deprecated monitor hooks (such as
    # ExportMonitor), so they are appended after the basic_session_run_hooks.
    all_hooks = []
    with graph.as_default():
        all_hooks.append(
            basic_session_run_hooks.NanTensorHook(
                loss_op, fail_on_nan_loss=fail_on_nan_loss))
        if log_every_steps > 0:
            all_hooks.append(
                basic_session_run_hooks.LoggingTensorHook(
                    {
                        'loss': loss_op.name,
                        'step': global_step_tensor.name
                    },
                    every_n_iter=log_every_steps))

        def make_saver():
            return tf_saver.Saver(sharded=True,
                                  max_to_keep=keep_checkpoint_max,
                                  defer_build=True)

        scaffold = monitored_session.Scaffold(
            init_op=init_op,
            init_feed_dict=init_feed_dict,
            init_fn=init_fn,
            saver=monitored_session.Scaffold.get_or_default(
                'saver', ops.GraphKeys.SAVERS, make_saver))

        if not supervisor_is_chief:
            session_creator = monitored_session.WorkerSessionCreator(
                scaffold=scaffold, master=supervisor_master)
        else:
            session_creator = monitored_session.ChiefSessionCreator(
                scaffold=scaffold,
                checkpoint_dir=output_dir,
                master=supervisor_master)
            summary_writer = summary_io.SummaryWriterCache.get(output_dir)
            all_hooks.append(
                basic_session_run_hooks.StepCounterHook(
                    summary_writer=summary_writer))
            all_hooks.append(
                basic_session_run_hooks.SummarySaverHook(
                    save_steps=supervisor_save_summaries_steps,
                    summary_writer=summary_writer,
                    scaffold=scaffold))
            if (supervisor_save_model_secs is not None
                    or supervisor_save_model_steps is not None):
                all_hooks.append(
                    basic_session_run_hooks.CheckpointSaverHook(
                        output_dir,
                        save_secs=supervisor_save_model_secs,
                        save_steps=supervisor_save_model_steps,
                        scaffold=scaffold))

        if steps is not None or max_steps is not None:
            all_hooks.append(
                basic_session_run_hooks.StopAtStepHook(steps, max_steps))
        all_hooks.extend(hooks)

        with monitored_session.MonitoredSession(
                session_creator=session_creator,
                hooks=all_hooks) as super_sess:
            loss = None
            while not super_sess.should_stop():
                _, loss = super_sess.run([train_op, loss_op],
                                         feed_fn() if feed_fn else None)
            return loss
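
The session handling above boils down to a `ChiefSessionCreator` (or `WorkerSessionCreator`) plus hooks driving a `while not should_stop()` loop. A minimal sketch of that pattern, assuming TF 1.x; the toy loss and the `/tmp/mon_train_demo` directory are illustrative, not taken from the original code:

```python
import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
weights = tf.Variable([1.0, 2.0])
loss = tf.reduce_sum(tf.square(weights))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
    loss, global_step=global_step)

scaffold = tf.train.Scaffold()
session_creator = tf.train.ChiefSessionCreator(
    scaffold=scaffold, checkpoint_dir='/tmp/mon_train_demo')
hooks = [tf.train.StopAtStepHook(last_step=5)]  # stop once global_step reaches 5

with tf.train.MonitoredSession(session_creator=session_creator,
                               hooks=hooks) as sess:
    loss_value = None
    while not sess.should_stop():
        _, loss_value = sess.run([train_op, loss])
```
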
Example #17
  def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metric_ops=None,
              export_outputs=None,
              training_chief_hooks=None,
              training_hooks=None,
              scaffold=None,
              evaluation_hooks=None):
    """Creates a validated `EstimatorSpec` instance.

    Depending on the value of `mode`, different arguments are required. Namely
    * For `mode == ModeKeys.TRAIN`: required fields are `loss` and `train_op`.
    * For `mode == ModeKeys.EVAL`: required field is `loss`.
    * For `mode == ModeKeys.PREDICT`: required field is `predictions`.

    model_fn can populate all arguments independently of mode. In this case, some
    arguments will be ignored by an `Estimator`. E.g. `train_op` will be
    ignored in eval and infer modes. Example:

    ```python
    def my_model_fn(mode, features, labels):
      predictions = ...
      loss = ...
      train_op = ...
      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=predictions,
          loss=loss,
          train_op=train_op)
    ```

    Alternatively, model_fn can just populate the arguments appropriate to the
    given mode. Example:

    ```python
    def my_model_fn(mode, features, labels):
      if (mode == tf.estimator.ModeKeys.TRAIN or
          mode == tf.estimator.ModeKeys.EVAL):
        loss = ...
      else:
        loss = None
      if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = ...
      else:
        train_op = None
      if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = ...
      else:
        predictions = None

      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=predictions,
          loss=loss,
          train_op=train_op)
    ```

    Args:
      mode: A `ModeKeys`. Specifies if this is training, evaluation or
        prediction.
      predictions: Predictions `Tensor` or dict of `Tensor`.
      loss: Training loss `Tensor`. Must be either scalar, or with shape `[1]`.
      train_op: Op for the training step.
      eval_metric_ops: Dict of metric results keyed by name. The values of the
        dict are the results of calling a metric function, namely a
        `(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated
        without any impact on state (typically it is a pure computation based
        on variables). For example, it should not trigger the `update_op` or
        require any input fetching.
      export_outputs: Describes the output signatures to be exported to
        `SavedModel` and used during serving.
        A dict `{name: output}` where:
        * name: An arbitrary name for this output.
        * output: an `ExportOutput` object such as `ClassificationOutput`,
            `RegressionOutput`, or `PredictOutput`.
        Single-headed models only need to specify one entry in this dictionary.
        Multi-headed models should specify one entry for each head, one of
        which must be named using
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.
      training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to
        run on the chief worker during training.
      training_hooks: Iterable of `tf.train.SessionRunHook` objects to run
        on all workers during training.
      scaffold: A `tf.train.Scaffold` object that can be used to set
        initialization, saver, and more to be used in training.
      evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to
        run during evaluation.

    Returns:
      A validated `EstimatorSpec` object.

    Raises:
      ValueError: If validation fails.
      TypeError: If any of the arguments is not the expected type.
    """
    # Validate train_op.
    if train_op is None:
      if mode == ModeKeys.TRAIN:
        raise ValueError('Missing train_op.')
    else:
      _check_is_tensor_or_operation(train_op, 'train_op')

    # Validate loss.
    if loss is None:
      if mode in (ModeKeys.TRAIN, ModeKeys.EVAL):
        raise ValueError('Missing loss.')
    else:
      loss = _check_is_tensor(loss, 'loss')
      loss_shape = loss.get_shape()
      if loss_shape.num_elements() not in (None, 1):
        raise ValueError('Loss must be scalar, given: {}'.format(loss))
      if not loss_shape.is_compatible_with(tensor_shape.scalar()):
        loss = array_ops.reshape(loss, [])

    # Validate predictions.
    if predictions is None:
      if mode == ModeKeys.PREDICT:
        raise ValueError('Missing predictions.')
      predictions = {}
    else:
      if isinstance(predictions, dict):
        predictions = {
            k: _check_is_tensor(v, 'predictions[{}]'.format(k))
            for k, v in six.iteritems(predictions)
        }
      else:
        predictions = _check_is_tensor(predictions, 'predictions')

    # Validate eval_metric_ops.
    if eval_metric_ops is None:
      eval_metric_ops = {}
    else:
      if not isinstance(eval_metric_ops, dict):
        raise TypeError(
            'eval_metric_ops must be a dict, given: {}'.format(eval_metric_ops))
      for key, metric_value_and_update in six.iteritems(eval_metric_ops):
        if (not isinstance(metric_value_and_update, tuple) or
            len(metric_value_and_update) != 2):
          raise TypeError(
              'Values of eval_metric_ops must be (metric_value, update_op) '
              'tuples, given: {} for key: {}'.format(
                  metric_value_and_update, key))
        metric_value, metric_update = metric_value_and_update
        for metric_value_member in nest.flatten(metric_value):
          # Allow (possibly nested) tuples for metric values, but require that
          # each of them be Tensors or Operations.
          _check_is_tensor_or_operation(metric_value_member,
                                        'eval_metric_ops[{}]'.format(key))
        _check_is_tensor_or_operation(metric_update,
                                      'eval_metric_ops[{}]'.format(key))

    # Validate export_outputs.
    if export_outputs is not None:
      if not isinstance(export_outputs, dict):
        raise TypeError('export_outputs must be dict, given: {}'.format(
            export_outputs))
      for v in six.itervalues(export_outputs):
        if not isinstance(v, ExportOutput):
          raise TypeError(
              'Values in export_outputs must be ExportOutput objects. '
              'Given: {}'.format(export_outputs))
      # Note export_outputs is allowed to be empty.
      if len(export_outputs) == 1:
        (key, value), = export_outputs.items()
        if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          export_outputs[
              signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value
      if len(export_outputs) > 1:
        if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
            not in export_outputs):
          raise ValueError(
              'Multiple export_outputs were provided, but none of them is '
              'specified as the default.  Do this by naming one of them with '
              'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.')

    # Validate that all tensors and ops are from the default graph.
    default_graph = ops.get_default_graph()

    # We enumerate possible error causes here to aid in debugging.
    error_message_template = (
        '{0} with "{1}" must be from the default graph. '
        'Possible causes of this error include: \n\n'
        '1) {0} was created outside the context of the default graph.'
        '\n\n'
        '2) The object passed through to EstimatorSpec was not created '
        'in the most recent call to "model_fn".')

    if isinstance(predictions, dict):
      for key, value in six.iteritems(predictions):
        if value.graph is not default_graph:
          raise ValueError(error_message_template.format(
              'prediction values',
              '{0}: {1}'.format(key, value.name)))
    elif predictions is not None:
      # 'predictions' must be a single Tensor.
      if predictions.graph is not default_graph:
        raise ValueError(error_message_template.format(
            'prediction values', predictions.name))

    if loss is not None and loss.graph is not default_graph:
      raise ValueError(error_message_template.format('loss', loss.name))
    if train_op is not None and train_op.graph is not default_graph:
      raise ValueError(error_message_template.format('train_op', train_op.name))
    for key, value in list(six.iteritems(eval_metric_ops)):
      values = nest.flatten(value)
      for value in values:
        if value.graph is not default_graph:
          raise ValueError(error_message_template.format(
              'eval_metric_ops',
              '{0}: {1}'.format(key, value.name)))

    # Validate hooks.
    training_chief_hooks = tuple(training_chief_hooks or [])
    training_hooks = tuple(training_hooks or [])
    evaluation_hooks = tuple(evaluation_hooks or [])
    for hook in training_hooks + training_chief_hooks + evaluation_hooks:
      if not isinstance(hook, session_run_hook.SessionRunHook):
        raise TypeError(
            'All hooks must be SessionRunHook instances, given: {}'.format(
                hook))

    scaffold = scaffold or monitored_session.Scaffold()
    # Validate scaffold.
    if not isinstance(scaffold, monitored_session.Scaffold):
      raise TypeError(
          'scaffold must be tf.train.Scaffold. Given: {}'.format(scaffold))

    return super(EstimatorSpec, cls).__new__(
        cls,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=export_outputs,
        training_chief_hooks=training_chief_hooks,
        training_hooks=training_hooks,
        scaffold=scaffold,
        evaluation_hooks=evaluation_hooks)
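
A hedged sketch of how the `eval_metric_ops` validation above is usually satisfied: each value is a `(metric_tensor, update_op)` tuple such as the ones returned by `tf.metrics`. The `eval_model_fn` below is a hypothetical model_fn intended only for `ModeKeys.EVAL` (it supplies no `train_op`):

```python
import tensorflow as tf

def eval_model_fn(features, labels, mode):
    """Hypothetical model_fn that only fills the fields required for EVAL."""
    predictions = tf.zeros_like(labels)  # trivial stand-in "model"
    loss = tf.losses.mean_squared_error(
        labels=tf.cast(labels, tf.float32),
        predictions=tf.cast(predictions, tf.float32))
    # Each value is a (metric_tensor, update_op) tuple, which is exactly what
    # the eval_metric_ops validation above expects.
    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(labels=labels, predictions=predictions),
    }
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
```
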
Example #18
    def _train_model(self, env, first_update, update_frequency, hooks):
        all_hooks = []
        self._graph = ops.Graph()
        with self._graph.as_default() as g, g.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step = training.get_or_create_global_step(g)
            global_episode = get_or_create_global_episode(g)
            global_timestep = get_or_create_global_timestep(g)
            update_episode_op = tf.assign_add(global_episode, 1)
            update_timestep_op = tf.assign_add(global_timestep, 1)
            no_run_hooks = tf.no_op(name='no_run_hooks')
            with ops.device('/cpu:0'):
                features, labels = self._prepare_input_fn(Modes.TRAIN, env)
            estimator_spec = self._call_model_fn(features, labels, Modes.TRAIN)
            ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
            all_hooks.extend([
                plx_hooks.NanTensorHook(estimator_spec.loss),
                plx_hooks.StepLoggingTensorHook(
                    {
                        'loss': estimator_spec.loss,
                        'step': global_step,
                        'timestep': global_timestep,
                        'global_episode': global_episode,
                        'max_reward': labels['max_reward'],
                        'min_reward': labels['min_reward'],
                        'total_reward': labels['total_reward'],
                    },
                    every_n_iter=100)
            ])
            all_hooks.extend(hooks)
            all_hooks.extend(estimator_spec.training_hooks)

            scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
            if not (scaffold.saver
                    or ops.get_collection(ops.GraphKeys.SAVERS)):
                ops.add_to_collection(
                    ops.GraphKeys.SAVERS,  # TODO remove non restorable vars
                    saver.Saver(
                        sharded=True,  # TODO `var_list`
                        max_to_keep=self._config.keep_checkpoint_max,
                        defer_build=True))

            chief_hooks = [
                plx_hooks.EpisodeLoggingTensorHook(
                    {
                        'loss': estimator_spec.loss,
                        'step': global_step,
                        'global_timestep': global_timestep,
                        'global_episode': global_episode,
                        'max_reward': labels['max_reward'],
                        'min_reward': labels['min_reward'],
                        'total_reward': labels['total_reward'],
                    },
                    every_n_episodes=1),  # TODO: save every episode?
                plx_hooks.EpisodeCounterHook(output_dir=self.model_dir)
            ]
            if self._config.save_checkpoints_secs or self._config.save_checkpoints_steps:
                saver_hook_exists = any([
                    isinstance(h, plx_hooks.EpisodeCheckpointSaverHook)
                    for h in (all_hooks + chief_hooks +
                              list(estimator_spec.training_chief_hooks))
                ])
                if not saver_hook_exists:
                    chief_hooks += [
                        plx_hooks.EpisodeCheckpointSaverHook(
                            self._model_dir,
                            save_episodes=1,  # TODO: save every episode?
                            scaffold=scaffold)
                    ]
            if self._config.save_summary_steps:
                saver_hook_exists = any([
                    isinstance(h, plx_hooks.EpisodeSummarySaverHook)
                    for h in (all_hooks + chief_hooks +
                              list(estimator_spec.training_chief_hooks))
                ])
                if not saver_hook_exists:
                    chief_hooks += [
                        plx_hooks.EpisodeSummarySaverHook(
                            scaffold=scaffold,
                            save_episodes=1,  # TODO: save every episode?
                            output_dir=self._model_dir,
                        )
                    ]
            with monitored_session.MonitoredTrainingSession(
                    master=self._config.master,
                    is_chief=self._config.is_chief,
                    checkpoint_dir=self._model_dir,
                    scaffold=scaffold,
                    hooks=all_hooks,
                    chief_only_hooks=chief_hooks +
                    list(estimator_spec.training_chief_hooks),
                    save_checkpoint_secs=0,  # Saving checkpoints is handled by a hook.
                    save_summaries_steps=0,  # Saving summaries is handled by a hook.
                    config=self._session_config) as mon_sess:
                loss = None
                while not mon_sess.should_stop():
                    loss = self.run_episode(
                        env=env,
                        sess=mon_sess,
                        features=features,
                        labels=labels,
                        no_run_hooks=no_run_hooks,
                        global_step=global_step,
                        update_episode_op=update_episode_op,
                        update_timestep_op=update_timestep_op,
                        first_update=first_update,
                        update_frequency=update_frequency,
                        estimator_spec=estimator_spec)
            summary_io.SummaryWriterCache.clear()
            return loss
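
The `global_episode` and `global_timestep` counters above come from project-specific helpers (`get_or_create_global_episode`, `get_or_create_global_timestep`); the underlying pattern is just a non-trainable scalar variable bumped with `tf.assign_add`. A minimal TF 1.x sketch:

```python
import tensorflow as tf

# Non-trainable scalar counter, analogous to the global_episode variable the
# helper above creates.
global_episode = tf.get_variable(
    'global_episode', shape=[], dtype=tf.int64,
    initializer=tf.zeros_initializer(), trainable=False)
update_episode_op = tf.assign_add(global_episode, 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(update_episode_op))  # 1
    print(sess.run(update_episode_op))  # 2
```
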