Example No. 1
  def export(self, last_checkpoint, output_dir):
    """Builds a prediction graph and xports the model.

    Args:
      last_checkpoint: Path to the latest checkpoint file from training.
      output_dir: Path to the folder to be used to output the model.
    """
    logging.info('Exporting prediction graph to %s', output_dir)
    with tf.Session(graph=tf.Graph()) as sess:
      # Build and save prediction meta graph and trained variable values.
      inputs, outputs = self.build_prediction_graph()
      signature_def_map = {
        'serving_default': signature_def_utils.predict_signature_def(inputs, outputs)
      }
      init_op = tf.global_variables_initializer()
      sess.run(init_op)
      self.restore_from_checkpoint(sess, self.inception_checkpoint_file,
                                   last_checkpoint)
      init_op_serving = control_flow_ops.group(
          variables.local_variables_initializer(),
          data_flow_ops.tables_initializer())

      builder = saved_model_builder.SavedModelBuilder(output_dir)
      builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING],
          signature_def_map=signature_def_map,
          legacy_init_op=init_op_serving)
      builder.save(False)
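A minimal sketch of reloading the model written by export() above, assuming TF 1.x and that output_dir is the same directory passed to export(); the signature key matches the 'serving_default' entry in the signature_def_map:

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

with tf.Session(graph=tf.Graph()) as sess:
  # Load the meta graph and restore the variables tagged for serving.
  meta_graph = tf.saved_model.loader.load(sess, [tag_constants.SERVING], output_dir)
  # Inspect the exported prediction signature (input/output tensor names).
  signature = meta_graph.signature_def['serving_default']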
Example No. 2
    def export(self, last_checkpoint, output_dir):
        """Builds a prediction graph and xports the model.

    Args:
      last_checkpoint: Path to the latest checkpoint file from training.
      output_dir: Path to the folder to be used to output the model.
    """
        logging.info('Exporting prediction graph to %s', output_dir)
        with tf.Session(graph=tf.Graph()) as sess:
            # Build and save prediction meta graph and trained variable values.
            inputs, outputs = self.build_prediction_graph()
            signature_def_map = {
                'serving_default':
                signature_def_utils.predict_signature_def(inputs, outputs)
            }
            init_op = tf.global_variables_initializer()
            sess.run(init_op)
            self.restore_from_checkpoint(sess, self.inception_checkpoint_file,
                                         last_checkpoint)
            init_op_serving = control_flow_ops.group(
                variables.local_variables_initializer(),
                data_flow_ops.tables_initializer())

            builder = saved_model_builder.SavedModelBuilder(output_dir)
            builder.add_meta_graph_and_variables(
                sess, [tag_constants.SERVING],
                signature_def_map=signature_def_map,
                legacy_init_op=init_op_serving)
            builder.save(False)
Example No. 3
def _get_local_init_op():
  local_init_op = _get_first_op_from_collection(
      ops.GraphKeys.LOCAL_INIT_OP)
  if local_init_op is None:
    op_list = [variables.local_variables_initializer(),
               data_flow_ops.tables_initializer()]
    if op_list:
      local_init_op = control_flow_ops.group(*op_list)
      ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
  return local_init_op
Example No. 4
def _get_local_init_op():
  local_init_op = _get_first_op_from_collection(
      ops.GraphKeys.LOCAL_INIT_OP)
  if local_init_op is None:
    op_list = [variables.local_variables_initializer(),
               data_flow_ops.tables_initializer()]
    if op_list:
      local_init_op = control_flow_ops.group(*op_list)
      ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
  return local_init_op
Example No. 5
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
  """Exports graph via session_bundle, by creating a Session."""
  with graph.as_default():
    with tf_session.Session('') as session:
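      # Note: the two calls below only build initializer ops; they are not run in
      # this session. Initialization of local variables and tables at load time is
      # handled by the grouped init_op passed to export.init() further down.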
      variables.local_variables_initializer()
      data_flow_ops.tables_initializer()
      saver.restore(session, checkpoint_path)

      export = exporter.Exporter(saver)
      export.init(init_op=control_flow_ops.group(
          variables.local_variables_initializer(),
          data_flow_ops.tables_initializer()),
                  default_graph_signature=default_graph_signature,
                  named_graph_signatures=named_graph_signatures,
                  assets_collection=ops.get_collection(
                      ops.GraphKeys.ASSET_FILEPATHS))
      return export.export(export_dir, contrib_variables.get_global_step(),
                           session, exports_to_keep=exports_to_keep)
Example No. 6
def main_op():
  """Returns a main op to init variables and tables.

  Returns the main op including the group of ops that initializes all
  variables, initializes local variables, and initializes all tables.

  Returns:
    The set of ops to be run as part of the main op upon the load operation.
  """
  init = variables.global_variables_initializer()
  init_local = variables.local_variables_initializer()
  init_tables = tf_data_flow_ops.tables_initializer()
  return control_flow_ops.group(init, init_local, init_tables)
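A hedged sketch of wiring main_op() into a SavedModel export so variables and tables are initialized when the model is loaded; it assumes a TF 1.x SavedModelBuilder, and sess, export_dir, and signature_def_map are placeholders built elsewhere:

builder = saved_model_builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(
    sess, [tag_constants.SERVING],
    signature_def_map=signature_def_map,
    main_op=main_op())  # run automatically when the SavedModel is loaded
builder.save()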
Example No. 7
def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
    """Run `output_dict` tensors with each input in `feed_dicts`.

  If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
  init all variables.

  Args:
    output_dict: A `dict` mapping string names to `Tensor` objects to run.
      Tensors must all be from the same graph.
    feed_dicts: Iterable of `dict` objects of input values to feed.
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.

  Yields:
    A sequence of dicts of values read from `output_dict` tensors, one item
    yielded for each item in `feed_dicts`. Keys are the same as `output_dict`,
    values are the results read from the corresponding `Tensor` in
    `output_dict`.

  Raises:
    ValueError: if `output_dict` or `feed_dicts` is None or empty.
  """
    if not output_dict:
        raise ValueError('output_dict is invalid: %s.' % output_dict)
    if not feed_dicts:
        raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)

    graph = contrib_ops.get_graph_from_inputs(output_dict.values())
    with graph.as_default() as g:
        with tf_session.Session('') as session:
            session.run(
                resources.initialize_resources(resources.shared_resources() +
                                               resources.local_resources()))
            if restore_checkpoint_path:
                _restore_from_checkpoint(session, g, restore_checkpoint_path)
            else:
                session.run(variables.global_variables_initializer())
            session.run(variables.local_variables_initializer())
            session.run(data_flow_ops.tables_initializer())
            coord = coordinator.Coordinator()
            threads = None
            try:
                threads = queue_runner.start_queue_runners(session,
                                                           coord=coord)
                for f in feed_dicts:
                    yield session.run(output_dict, f)
            finally:
                coord.request_stop()
                if threads:
                    coord.join(threads, stop_grace_period_secs=120)
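A minimal usage sketch for run_feeds_iter, assuming tensorflow is imported as tf; the graph and tensor names are purely illustrative:

x = tf.placeholder(tf.float32, shape=[], name='x')
outputs = {'squared': tf.square(x)}
# One result dict is yielded per feed dict, keyed like output_dict.
for result in run_feeds_iter(outputs, [{x: 2.0}, {x: 3.0}]):
  print(result['squared'])  # 4.0, then 9.0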
Example No. 8
 def testBuildSequenceInputInput(self):
     sequence_input = dynamic_rnn_estimator.build_sequence_input(
         self.GetColumnsToTensors(), self.sequence_feature_columns,
         self.context_feature_columns)
     with self.test_session() as sess:
         sess.run(variables.global_variables_initializer())
         sess.run(data_flow_ops.tables_initializer())
         sequence_input_val = sess.run(sequence_input)
     expected_shape = np.array([
         3,  # expected batch size
         2,  # padded sequence length
         3 + 8 + 2  # location keys + embedding dim + measurement dimension
     ])
     self.assertAllEqual(expected_shape, sequence_input_val.shape)
Example No. 9
 def testBuildSequenceInputInput(self):
   sequence_input = dynamic_rnn_estimator.build_sequence_input(
       self.GetColumnsToTensors(), self.sequence_feature_columns,
       self.context_feature_columns)
   with self.test_session() as sess:
     sess.run(variables.global_variables_initializer())
     sess.run(data_flow_ops.tables_initializer())
     sequence_input_val = sess.run(sequence_input)
   expected_shape = np.array([
       3,  # expected batch size
       2,  # padded sequence length
       3 + 8 + 2  # location keys + embedding dim + measurement dimension
   ])
   self.assertAllEqual(expected_shape, sequence_input_val.shape)
Example No. 10
def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
  """Run `output_dict` tensors with each input in `feed_dicts`.

  If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
  init all variables.

  Args:
    output_dict: A `dict` mapping string names to `Tensor` objects to run.
      Tensors must all be from the same graph.
    feed_dicts: Iterable of `dict` objects of input values to feed.
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.

  Yields:
    A sequence of dicts of values read from `output_dict` tensors, one item
    yielded for each item in `feed_dicts`. Keys are the same as `output_dict`,
    values are the results read from the corresponding `Tensor` in
    `output_dict`.

  Raises:
    ValueError: if `output_dict` or `feed_dicts` is None or empty.
  """
  if not output_dict:
    raise ValueError('output_dict is invalid: %s.' % output_dict)
  if not feed_dicts:
    raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)

  graph = contrib_ops.get_graph_from_inputs(output_dict.values())
  with graph.as_default() as g:
    with tf_session.Session('') as session:
      session.run(
          resources.initialize_resources(resources.shared_resources() +
                                         resources.local_resources()))
      if restore_checkpoint_path:
        _restore_from_checkpoint(session, g, restore_checkpoint_path)
      else:
        session.run(variables.global_variables_initializer())
      session.run(variables.local_variables_initializer())
      session.run(data_flow_ops.tables_initializer())
      coord = coordinator.Coordinator()
      threads = None
      try:
        threads = queue_runner.start_queue_runners(session, coord=coord)
        for f in feed_dicts:
          yield session.run(output_dict, f)
      finally:
        coord.request_stop()
        if threads:
          coord.join(threads, stop_grace_period_secs=120)
Example No. 11
    def testConstructRNN(self):
        initial_state = None
        sequence_input = dynamic_rnn_estimator.build_sequence_input(
            self.GetColumnsToTensors(), self.sequence_feature_columns,
            self.context_feature_columns)
        activations_t, final_state_t = dynamic_rnn_estimator.construct_rnn(
            initial_state, sequence_input, self.rnn_cell,
            self.mock_target_column.num_label_columns)

        # Obtain values of activations and final state.
        with session.Session() as sess:
            sess.run(variables.global_variables_initializer())
            sess.run(data_flow_ops.tables_initializer())
            activations, final_state = sess.run([activations_t, final_state_t])

        expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS])
        self.assertAllEqual(expected_activations_shape, activations.shape)
        expected_state_shape = np.array([3, self.NUM_RNN_CELL_UNITS])
        self.assertAllEqual(expected_state_shape, final_state.shape)
Example No. 12
    def _model_fn_scaffold(features, labels, mode):
      _, _ = features, labels
      my_int = variables.Variable(1, name='my_int',
                                  collections=[ops.GraphKeys.LOCAL_VARIABLES])
      scores = constant_op.constant([3.])
      with ops.control_dependencies(
          [variables.local_variables_initializer(),
           data_flow_ops.tables_initializer()]):
        assign_op = state_ops.assign(my_int, 12345)

      # local_init_op must be an Operation, not a Tensor.
      custom_local_init_op = control_flow_ops.group(assign_op)
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          predictions=constant_op.constant([[1.]]),
          loss=constant_op.constant(0.),
          train_op=constant_op.constant(0.),
          scaffold=training.Scaffold(local_init_op=custom_local_init_op),
          export_outputs={'test': export_output.ClassificationOutput(scores)})
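A short sketch of how a model_fn like this is typically exercised, assuming the core tf.estimator API; the dummy input_fn and step count are placeholders, and the Scaffold's local_init_op runs when the training session is created:

est = tf.estimator.Estimator(model_fn=_model_fn_scaffold)
est.train(
    input_fn=lambda: ({'x': tf.constant([[1.]])}, tf.constant([[1.]])),
    steps=1)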
Example No. 13
  def _init_local_init_op(self, local_init_op=USE_DEFAULT):
    """Initializes local_init_op.

    Args:
      local_init_op: `Operation` run for every new supervisor instance. If set
      to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
      collection. If the collection is empty, create an op that initializes
      all local variables and all tables.
    """
    if local_init_op is Supervisor.USE_DEFAULT:
      local_init_op = self._get_first_op_from_collection(
          ops.GraphKeys.LOCAL_INIT_OP)
      if local_init_op is None:
        op_list = [variables.local_variables_initializer(),
                   data_flow_ops.tables_initializer()]
        if op_list:
          local_init_op = control_flow_ops.group(*op_list)
          ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
    self._local_init_op = local_init_op
Example No. 14
  def testConstructRNN(self):
    initial_state = None
    sequence_input = dynamic_rnn_estimator.build_sequence_input(
        self.GetColumnsToTensors(), self.sequence_feature_columns,
        self.context_feature_columns)
    activations_t, final_state_t = dynamic_rnn_estimator.construct_rnn(
        initial_state, sequence_input, self.rnn_cell,
        self.mock_target_column.num_label_columns)

    # Obtain values of activations and final state.
    with session.Session() as sess:
      sess.run(variables.global_variables_initializer())
      sess.run(data_flow_ops.tables_initializer())
      activations, final_state = sess.run([activations_t, final_state_t])

    expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS])
    self.assertAllEqual(expected_activations_shape, activations.shape)
    expected_state_shape = np.array([3, self.NUM_RNN_CELL_UNITS])
    self.assertAllEqual(expected_state_shape, final_state.shape)
Example No. 15
    def _model_fn_scaffold(features, labels, mode):
      _, _ = features, labels
      my_int = variables.Variable(1, name='my_int',
                                  collections=[ops.GraphKeys.LOCAL_VARIABLES])
      scores = constant_op.constant([3.])
      with ops.control_dependencies(
          [variables.local_variables_initializer(),
           data_flow_ops.tables_initializer()]):
        assign_op = state_ops.assign(my_int, 12345)

      # local_init_op must be an Operation, not a Tensor.
      custom_local_init_op = control_flow_ops.group(assign_op)
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          predictions=constant_op.constant([[1.]]),
          loss=constant_op.constant(0.),
          train_op=constant_op.constant(0.),
          scaffold=training.Scaffold(local_init_op=custom_local_init_op),
          export_outputs={'test': export.ClassificationOutput(scores)})
Example No. 16
  def _init_local_init_op(self, local_init_op=USE_DEFAULT):
    """Initializes local_init_op.

    Args:
      local_init_op: `Operation` run for every new supervisor instance. If set
      to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
      collection. If the collection is empty, create an op that initializes
      all local variables and all tables.
    """
    if local_init_op is Supervisor.USE_DEFAULT:
      local_init_op = self._get_first_op_from_collection(
          ops.GraphKeys.LOCAL_INIT_OP)
      if local_init_op is None:
        op_list = [variables.local_variables_initializer(),
                   data_flow_ops.tables_initializer()]
        if op_list:
          local_init_op = control_flow_ops.group(*op_list)
          ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
    self._local_init_op = local_init_op
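A hedged sketch of overriding the default described above by passing an explicit local_init_op to the Supervisor (TF 1.x; the logdir path is a placeholder):

custom_local_init_op = tf.group(tf.local_variables_initializer(),
                                tf.tables_initializer())
sv = tf.train.Supervisor(logdir='/tmp/sv_logs',  # placeholder path
                         local_init_op=custom_local_init_op)
with sv.managed_session('') as sess:
  pass  # training loop would go here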
Example No. 17
 def _default_local_init_op():
   return control_flow_ops.group(variables.local_variables_initializer(),
                                 data_flow_ops.tables_initializer())
Example No. 18
    def export_fn(estimator,
                  export_dir_base,
                  checkpoint_path=None,
                  eval_result=None):
        with ops.Graph().as_default() as g:
            contrib_variables.create_global_step(g)

            input_ops = serving_from_csv_input(train_config, args, keep_target)
            model_fn_ops = estimator._call_model_fn(
                input_ops.features, None, model_fn_lib.ModeKeys.INFER)
            output_fetch_tensors = make_output_tensors(
                train_config=train_config,
                args=args,
                input_ops=input_ops,
                model_fn_ops=model_fn_ops,
                keep_target=keep_target)

            signature_def_map = {
                'serving_default':
                signature_def_utils.predict_signature_def(
                    input_ops.default_inputs, output_fetch_tensors)
            }

            if not checkpoint_path:
                # Locate the latest checkpoint
                checkpoint_path = saver.latest_checkpoint(estimator._model_dir)
            if not checkpoint_path:
                raise NotFittedError("Couldn't find trained model at %s." %
                                     estimator._model_dir)

            export_dir = saved_model_export_utils.get_timestamped_export_dir(
                export_dir_base)

            with tf_session.Session('') as session:
                # variables.initialize_local_variables()
                variables.local_variables_initializer()
                data_flow_ops.tables_initializer()
                saver_for_restore = saver.Saver(variables.global_variables(),
                                                sharded=True)
                saver_for_restore.restore(session, checkpoint_path)

                init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    data_flow_ops.tables_initializer())

                # Perform the export
                builder = saved_model_builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    session, [tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    assets_collection=ops.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS),
                    legacy_init_op=init_op)
                builder.save(False)

            # Add the extra assets
            if assets_extra:
                assets_extra_path = os.path.join(
                    compat.as_bytes(export_dir),
                    compat.as_bytes('assets.extra'))
                for dest_relative, source in assets_extra.items():
                    dest_absolute = os.path.join(
                        compat.as_bytes(assets_extra_path),
                        compat.as_bytes(dest_relative))
                    dest_path = os.path.dirname(dest_absolute)
                    gfile.MakeDirs(dest_path)
                    gfile.Copy(source, dest_absolute)

        # only keep the last 3 models
        saved_model_export_utils.garbage_collect_exports(
            python_portable_string(export_dir_base), exports_to_keep=3)

        # save the last model to the model folder.
        # export_dir_base = A/B/intermediate_models/
        if keep_target:
            final_dir = os.path.join(args.job_dir, 'evaluation_model')
        else:
            final_dir = os.path.join(args.job_dir, 'model')
        if file_io.is_directory(final_dir):
            file_io.delete_recursively(final_dir)
        file_io.recursive_create_dir(final_dir)
        _recursive_copy(export_dir, final_dir)

        return export_dir
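A hedged sketch of how an export_fn like this is usually registered, assuming tf.contrib.learn's ExportStrategy; the strategy name is illustrative:

from tensorflow.contrib.learn import ExportStrategy

# Typically passed to an Experiment via export_strategies=[csv_export_strategy].
csv_export_strategy = ExportStrategy(name='csv_serving', export_fn=export_fn)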
Example No. 19
 def _default_local_init_op():
   return control_flow_ops.group(variables.local_variables_initializer(),
                                 data_flow_ops.tables_initializer())
Example No. 20
  def export_fn(estimator, export_dir_base, checkpoint_path=None, eval_result=None):
    with ops.Graph().as_default() as g:
      contrib_variables.create_global_step(g)

      input_ops = serving_from_csv_input(train_config, args, keep_target)
      model_fn_ops = estimator._call_model_fn(input_ops.features,
                                              None,
                                              model_fn_lib.ModeKeys.INFER)
      output_fetch_tensors = make_output_tensors(
          train_config=train_config,
          args=args,
          input_ops=input_ops,
          model_fn_ops=model_fn_ops,
          keep_target=keep_target)

      signature_def_map = {
        'serving_default': signature_def_utils.predict_signature_def(input_ops.default_inputs,
                                                                     output_fetch_tensors)
      }

      if not checkpoint_path:
        # Locate the latest checkpoint
        checkpoint_path = saver.latest_checkpoint(estimator._model_dir)
      if not checkpoint_path:
        raise NotFittedError("Couldn't find trained model at %s."
                             % estimator._model_dir)

      export_dir = saved_model_export_utils.get_timestamped_export_dir(
          export_dir_base)

      with tf_session.Session('') as session:
        # variables.initialize_local_variables()
        variables.local_variables_initializer()
        data_flow_ops.tables_initializer()
        saver_for_restore = saver.Saver(
            variables.global_variables(),
            sharded=True)
        saver_for_restore.restore(session, checkpoint_path)

        init_op = control_flow_ops.group(
            variables.local_variables_initializer(),
            data_flow_ops.tables_initializer())

        # Perform the export
        builder = saved_model_builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(
            session, [tag_constants.SERVING],
            signature_def_map=signature_def_map,
            assets_collection=ops.get_collection(
                ops.GraphKeys.ASSET_FILEPATHS),
            legacy_init_op=init_op)
        builder.save(False)

      # Add the extra assets
      if assets_extra:
        assets_extra_path = os.path.join(compat.as_bytes(export_dir),
                                         compat.as_bytes('assets.extra'))
        for dest_relative, source in assets_extra.items():
          dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
                                       compat.as_bytes(dest_relative))
          dest_path = os.path.dirname(dest_absolute)
          gfile.MakeDirs(dest_path)
          gfile.Copy(source, dest_absolute)

    # only keep the last 3 models
    saved_model_export_utils.garbage_collect_exports(
        python_portable_string(export_dir_base),
        exports_to_keep=3)

    # save the last model to the model folder.
    # export_dir_base = A/B/intermediate_models/
    if keep_target:
      final_dir = os.path.join(args.job_dir, 'evaluation_model')
    else:
      final_dir = os.path.join(args.job_dir, 'model')
    if file_io.is_directory(final_dir):
      file_io.delete_recursively(final_dir)
    file_io.recursive_create_dir(final_dir)
    _recursive_copy(export_dir, final_dir)

    return export_dir
Example No. 21
    def export_fn(estimator,
                  export_dir_base,
                  checkpoint_path=None,
                  eval_result=None):
        with ops.Graph().as_default() as g:
            contrib_variables.create_global_step(g)

            input_ops = feature_transforms.build_csv_serving_tensors(
                args.output_dir_from_analysis_step, features, schema, stats,
                keep_target)
            model_fn_ops = estimator._call_model_fn(
                input_ops.features, None, model_fn_lib.ModeKeys.INFER)
            output_fetch_tensors = make_prediction_output_tensors(
                args=args,
                features=features,
                input_ops=input_ops,
                model_fn_ops=model_fn_ops,
                keep_target=keep_target)

            # Don't use signature_def_utils.predict_signature_def as that renames
            # tensor names if there is only 1 input/output tensor!
            signature_inputs = {
                key: tf.saved_model.utils.build_tensor_info(tensor)
                for key, tensor in six.iteritems(input_ops.default_inputs)
            }
            signature_outputs = {
                key: tf.saved_model.utils.build_tensor_info(tensor)
                for key, tensor in six.iteritems(output_fetch_tensors)
            }
            signature_def_map = {
                'serving_default':
                signature_def_utils.build_signature_def(
                    signature_inputs, signature_outputs,
                    tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
            }

            if not checkpoint_path:
                # Locate the latest checkpoint
                checkpoint_path = saver.latest_checkpoint(estimator._model_dir)
            if not checkpoint_path:
                raise ValueError("Couldn't find trained model at %s." %
                                 estimator._model_dir)

            export_dir = saved_model_export_utils.get_timestamped_export_dir(
                export_dir_base)

            with tf_session.Session('') as session:
                variables.local_variables_initializer()
                data_flow_ops.tables_initializer()
                saver_for_restore = saver.Saver(variables.global_variables(),
                                                sharded=True)
                saver_for_restore.restore(session, checkpoint_path)

                init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    data_flow_ops.tables_initializer())

                # Perform the export
                builder = saved_model_builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    session, [tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    assets_collection=ops.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS),
                    legacy_init_op=init_op)
                builder.save(False)

            # Add the extra assets
            if assets_extra:
                assets_extra_path = os.path.join(
                    compat.as_bytes(export_dir),
                    compat.as_bytes('assets.extra'))
                for dest_relative, source in assets_extra.items():
                    dest_absolute = os.path.join(
                        compat.as_bytes(assets_extra_path),
                        compat.as_bytes(dest_relative))
                    dest_path = os.path.dirname(dest_absolute)
                    file_io.recursive_create_dir(dest_path)
                    file_io.copy(source, dest_absolute)

        # only keep the last 3 models
        saved_model_export_utils.garbage_collect_exports(export_dir_base,
                                                         exports_to_keep=3)

        # save the last model to the model folder.
        # export_dir_base = A/B/intermediate_models/
        if keep_target:
            final_dir = os.path.join(args.job_dir, 'evaluation_model')
        else:
            final_dir = os.path.join(args.job_dir, 'model')
        if file_io.is_directory(final_dir):
            file_io.delete_recursively(final_dir)
        file_io.recursive_create_dir(final_dir)
        recursive_copy(export_dir, final_dir)

        return export_dir
Example No. 22
def train(train_op,
          logdir,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          trace_every_n_steps=None):
  """Runs a training loop using a TensorFlow supervisor.

  When the sync_optimizer is supplied, gradient updates are applied
  synchronously. Otherwise, gradient updates are applied asynchronously.

  Args:
    train_op: A `Tensor` that, when executed, will apply the gradients and
      return the loss value.
    logdir: The directory where training logs are written to. If None, model
      checkpoints and summaries will not be written.
    train_step_fn: The function to call in order to execute a single gradient
      step. The function must take exactly four arguments: the current
      session, the `train_op` `Tensor`, a global step `Tensor` and a dictionary.
    train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By
      default, two `Boolean`, scalar ops called "should_stop" and "should_log"
      are provided.
    log_every_n_steps: The frequency, in terms of global steps, that the loss
      and global step are logged.
    graph: The graph to pass to the supervisor. If no graph is supplied the
      default graph is used.
    master: The address of the tensorflow master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    global_step: The `Tensor` representing the global step. If left as `None`,
      then slim.variables.get_or_create_global_step() is used.
    number_of_steps: The max number of gradient steps to take during training,
      as measured by 'global_step': training will stop if global_step is
      greater than 'number_of_steps'. If the value is left as None, training
      proceeds indefinitely.
    init_op: The initialization operation. If left to its default value, then
      the session is initialized by calling `tf.global_variables_initializer()`.
    init_feed_dict: A feed dictionary to use when executing the `init_op`.
    local_init_op: The local initialization operation. If left to its default
      value, then the session is initialized by calling
      `tf.local_variables_initializer()` and `tf.tables_initializer()`.
    init_fn: An optional callable to be executed after `init_op` is called. The
      callable must accept one argument, the session being initialized.
    ready_op: Operation to check if the model is ready to use. If left to its
      default value, then the session checks for readiness by calling
      `tf.report_uninitialized_variables()`.
    summary_op: The summary operation.
    save_summaries_secs: How often, in seconds, to save summaries.
    summary_writer: `SummaryWriter` to use.  Can be `None`
      to indicate that no summaries should be written. If unset, we
      create a SummaryWriter.
    startup_delay_steps: The number of steps to wait for before beginning. Note
      that this must be 0 if a sync_optimizer is supplied.
    saver: Saver to save checkpoints. If None, a default one will be created
      and used.
    save_interval_secs: How often, in seconds, to save the model to `logdir`.
    sync_optimizer: an instance of tf.train.SyncReplicasOptimizer. If the
      argument is supplied, gradient updates will be synchronous. If left as
      `None`, gradient updates will be asynchronous.
    session_config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.
    trace_every_n_steps: produce and save a `Timeline` in Chrome trace format
      and add it to the summaries every `trace_every_n_steps`. If None, no trace
      information will be produced or saved.

  Returns:
    the value of the loss function after training.

  Raises:
    ValueError: if `train_op` is empty or if `startup_delay_steps` is
      non-zero when `sync_optimizer` is supplied, if `number_of_steps` is
      negative, or if `trace_every_n_steps` is not `None` and no `logdir` is
      provided.
  """
  if train_op is None:
    raise ValueError('train_op cannot be None.')

  if logdir is None:
    if summary_op != _USE_DEFAULT:
      raise ValueError('Cannot provide summary_op because logdir=None')
    if saver is not None:
      raise ValueError('Cannot provide saver because logdir=None')
    if trace_every_n_steps is not None:
      raise ValueError('Cannot provide trace_every_n_steps because '
                       'logdir=None')

  if sync_optimizer is not None and startup_delay_steps > 0:
    raise ValueError(
        'startup_delay_steps must be zero when sync_optimizer is supplied.')

  if number_of_steps is not None and number_of_steps <= 0:
    raise ValueError(
        '`number_of_steps` must be either None or a positive number.')

  graph = graph or ops.get_default_graph()
  with graph.as_default():
    if global_step is None:
      global_step = variables.get_or_create_global_step()
    saver = saver or tf_saver.Saver()

    with ops.name_scope('init_ops'):
      if init_op == _USE_DEFAULT:
        init_op = tf_variables.global_variables_initializer()

      if ready_op == _USE_DEFAULT:
        ready_op = tf_variables.report_uninitialized_variables()

      if local_init_op == _USE_DEFAULT:
        local_init_op = control_flow_ops.group(
            tf_variables.local_variables_initializer(),
            data_flow_ops.tables_initializer())

      if sync_optimizer is not None and isinstance(
          sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
        with ops.control_dependencies([local_init_op] if local_init_op is
                                      not None else []):
          if is_chief:
            local_init_op = sync_optimizer.chief_init_op
          else:
            local_init_op = sync_optimizer.local_step_init_op
        ready_for_local_init_op = sync_optimizer.ready_for_local_init_op
      else:
        ready_for_local_init_op = None

    if summary_op == _USE_DEFAULT:
      summary_op = summary.merge_all()

    if summary_writer == _USE_DEFAULT:
      summary_writer = supervisor.Supervisor.USE_DEFAULT

    if is_chief and sync_optimizer is not None:
      if not isinstance(sync_optimizer,
                        (sync_replicas_optimizer.SyncReplicasOptimizer)):
        raise ValueError(
            '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.')

      # Need to create these BEFORE the supervisor finalizes the graph:
      init_tokens_op = sync_optimizer.get_init_tokens_op()
      chief_queue_runner = sync_optimizer.get_chief_queue_runner()

    if train_step_kwargs == _USE_DEFAULT:
      with ops.name_scope('train_step'):
        train_step_kwargs = {}

        if number_of_steps:
          should_stop_op = math_ops.greater_equal(global_step, number_of_steps)
        else:
          should_stop_op = constant_op.constant(False)
        train_step_kwargs['should_stop'] = should_stop_op
        train_step_kwargs['should_log'] = math_ops.equal(
            math_ops.mod(global_step, log_every_n_steps), 0)
        if is_chief and trace_every_n_steps is not None:
          train_step_kwargs['should_trace'] = math_ops.equal(
              math_ops.mod(global_step, trace_every_n_steps), 0)
          train_step_kwargs['logdir'] = logdir

  sv = supervisor.Supervisor(
      graph=graph,
      is_chief=is_chief,
      logdir=logdir,
      init_op=init_op,
      init_feed_dict=init_feed_dict,
      local_init_op=local_init_op,
      ready_for_local_init_op=ready_for_local_init_op,
      ready_op=ready_op,
      summary_op=summary_op,
      summary_writer=summary_writer,
      global_step=global_step,
      saver=saver,
      save_summaries_secs=save_summaries_secs,
      save_model_secs=save_interval_secs,
      init_fn=init_fn)

  if summary_writer is not None:
    train_step_kwargs['summary_writer'] = sv.summary_writer

  should_retry = True
  while should_retry:
    try:
      should_retry = False
      with sv.managed_session(
          master, start_standard_services=False, config=session_config) as sess:
        logging.info('Starting Session.')
        if is_chief:
          if logdir:
            sv.start_standard_services(sess)
        elif startup_delay_steps > 0:
          _wait_for_step(sess, global_step,
                         min(startup_delay_steps, number_of_steps or
                             sys.maxint))
        sv.start_queue_runners(sess)
        logging.info('Starting Queues.')
        if is_chief and sync_optimizer is not None:
          sv.start_queue_runners(sess, [chief_queue_runner])
          sess.run(init_tokens_op)
        try:
          while not sv.should_stop():
            total_loss, should_stop = train_step_fn(
                sess, train_op, global_step, train_step_kwargs)
            if should_stop:
              logging.info('Stopping Training.')
              break
        except errors.OutOfRangeError:
          # OutOfRangeError is thrown when epoch limit per
          # tf.train.limit_epochs is reached.
          logging.info('Caught OutOfRangeError. Stopping Training.')
        if logdir and sv.is_chief:
          logging.info('Finished training! Saving model to disk.')
          sv.saver.save(sess, sv.save_path, global_step=sv.global_step)

    except errors.AbortedError:
      # Always re-run on AbortedError as it indicates a restart of one of the
      # distributed tensorflow servers.
      logging.info('Retrying training!')
      should_retry = True

  return total_loss
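A minimal usage sketch for this training loop, assuming my_train_op was built elsewhere from the model's loss; the logdir and step counts are placeholders:

total_loss = train(my_train_op,
                   logdir='/tmp/train_logs',  # placeholder
                   number_of_steps=1000,
                   log_every_n_steps=10,
                   save_summaries_secs=300,
                   save_interval_secs=600)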
Example No. 23
  def export_fn(estimator, export_dir_base, checkpoint_path=None, eval_result=None):
    with ops.Graph().as_default() as g:
      contrib_variables.create_global_step(g)

      input_ops = feature_transforms.build_csv_serving_tensors(
          args.analysis, features, schema, stats, keep_target)
      model_fn_ops = estimator._call_model_fn(input_ops.features,
                                              None,
                                              model_fn_lib.ModeKeys.INFER)
      output_fetch_tensors = make_prediction_output_tensors(
          args=args,
          features=features,
          input_ops=input_ops,
          model_fn_ops=model_fn_ops,
          keep_target=keep_target)

      # Don't use signature_def_utils.predict_signature_def as that renames
      # tensor names if there is only 1 input/output tensor!
      signature_inputs = {key: tf.saved_model.utils.build_tensor_info(tensor)
                          for key, tensor in six.iteritems(input_ops.default_inputs)}
      signature_outputs = {key: tf.saved_model.utils.build_tensor_info(tensor)
                           for key, tensor in six.iteritems(output_fetch_tensors)}
      signature_def_map = {
          'serving_default':
              signature_def_utils.build_signature_def(
                  signature_inputs,
                  signature_outputs,
                  tf.saved_model.signature_constants.PREDICT_METHOD_NAME)}

      if not checkpoint_path:
        # Locate the latest checkpoint
        checkpoint_path = saver.latest_checkpoint(estimator._model_dir)
      if not checkpoint_path:
        raise ValueError("Couldn't find trained model at %s."
                         % estimator._model_dir)

      export_dir = saved_model_export_utils.get_timestamped_export_dir(
          export_dir_base)

      with tf_session.Session('') as session:
        variables.local_variables_initializer()
        data_flow_ops.tables_initializer()
        saver_for_restore = saver.Saver(
            variables.global_variables(),
            sharded=True)
        saver_for_restore.restore(session, checkpoint_path)

        init_op = control_flow_ops.group(
            variables.local_variables_initializer(),
            data_flow_ops.tables_initializer())

        # Perform the export
        builder = saved_model_builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(
            session, [tag_constants.SERVING],
            signature_def_map=signature_def_map,
            assets_collection=ops.get_collection(
                ops.GraphKeys.ASSET_FILEPATHS),
            legacy_init_op=init_op)
        builder.save(False)

      # Add the extra assets
      if assets_extra:
        assets_extra_path = os.path.join(compat.as_bytes(export_dir),
                                         compat.as_bytes('assets.extra'))
        for dest_relative, source in assets_extra.items():
          dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
                                       compat.as_bytes(dest_relative))
          dest_path = os.path.dirname(dest_absolute)
          file_io.recursive_create_dir(dest_path)
          file_io.copy(source, dest_absolute)

    # only keep the last 3 models
    saved_model_export_utils.garbage_collect_exports(
        export_dir_base,
        exports_to_keep=3)

    # save the last model to the model folder.
    # export_dir_base = A/B/intermediate_models/
    if keep_target:
      final_dir = os.path.join(args.job_dir, 'evaluation_model')
    else:
      final_dir = os.path.join(args.job_dir, 'model')
    if file_io.is_directory(final_dir):
      file_io.delete_recursively(final_dir)
    file_io.recursive_create_dir(final_dir)
    recursive_copy(export_dir, final_dir)

    return export_dir
Example No. 24
def do_training(train_op, init_fn=None, summary_op=None, lr=None):
    global savers
    graph = ops.get_default_graph()
    with graph.as_default():
        global_step = variables.get_or_create_global_step()
        saver = tf_saver.Saver(max_to_keep=0)

        with ops.name_scope('init_ops'):
            init_op = tf_variables.global_variables_initializer()

            ready_op = tf_variables.report_uninitialized_variables()

            local_init_op = control_flow_ops.group(
                tf_variables.local_variables_initializer(),
                data_flow_ops.tables_initializer())

        summary_writer = supervisor.Supervisor.USE_DEFAULT
        with ops.name_scope('train_step'):
            train_step_kwargs = {}

            if FLAGS.max_number_of_steps is not None:
                should_stop_op = math_ops.greater_equal(
                    global_step, FLAGS.max_number_of_steps)
            else:
                should_stop_op = constant_op.constant(False)
            train_step_kwargs['should_stop'] = should_stop_op
            if FLAGS.log_every_n_steps > 0:
                train_step_kwargs['should_log'] = math_ops.equal(
                    math_ops.mod(global_step, FLAGS.log_every_n_steps), 0)
        prefix = "loc/net"
        lp = len(prefix)
        vdic = {
            "InceptionV2" + v.op.name[lp:]: v
            for v in tf.trainable_variables()
            if v.name.startswith(prefix) and v.name.find("Logits/") < 0
        }
        _saver = tf_saver.Saver(vdic)
        savers.append(_saver)
        for i in xrange(NUM_STN):
            prefix = "stn%d/net" % i
            lp = len(prefix)
            vdic = {
                "InceptionV2" + v.op.name[lp:]: v
                for v in tf.trainable_variables()
                if v.name.startswith(prefix) and v.name.find("Logits/") < 0
            }
            # saver = tf.train.Saver(vdic)
            _saver = tf_saver.Saver(vdic)
            savers.append(_saver)
    prt("savers %d" % len(savers))

    is_chief = True
    logdir = FLAGS.train_dir

    sv = supervisor.Supervisor(graph=graph,
                               is_chief=is_chief,
                               logdir=logdir,
                               init_op=init_op,
                               init_feed_dict=None,
                               local_init_op=local_init_op,
                               ready_for_local_init_op=None,
                               ready_op=ready_op,
                               summary_op=summary_op,
                               summary_writer=summary_writer,
                               global_step=global_step,
                               saver=saver,
                               save_summaries_secs=FLAGS.save_summaries_secs,
                               save_model_secs=FLAGS.save_interval_secs,
                               init_fn=init_fn)

    if summary_writer is not None:
        train_step_kwargs['summary_writer'] = sv.summary_writer

    with sv.managed_session('', start_standard_services=False,
                            config=None) as sess:
        logging.info('Starting Session.')
        if is_chief:
            if logdir:
                sv.start_standard_services(sess)
        elif startup_delay_steps > 0:
            _wait_for_step(
                sess, global_step,
                min(startup_delay_steps, number_of_steps or sys.maxint))
        sv.start_queue_runners(sess)
        logging.info('Starting Queues.')
        try:
            while not sv.should_stop():
                total_loss, global_step_value, should_stop = train_step(
                    sess, train_op, global_step, lr, train_step_kwargs)
                current_epoch = int(
                    math.ceil(float(global_step_value) / FLAGS.steps_in_epoch))
                if global_step_value > 0 and global_step_value % FLAGS.save_every_n_steps == 0:
                    sv.saver.save(sess,
                                  sv.save_path,
                                  global_step=sv.global_step)

                if should_stop:
                    logging.info('Stopping Training.')
                    break
        except errors.OutOfRangeError:
            # OutOfRangeError is thrown when epoch limit per
            # tf.train.limit_epochs is reached.
            logging.info('Caught OutOfRangeError. Stopping Training.')
        if logdir and sv.is_chief:
            logging.info('Finished training! Saving model to disk.')
            sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
Example No. 25
def train(train_op,
          logdir,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          trace_every_n_steps=None):
    """Runs a training loop using a TensorFlow supervisor.

  When the sync_optimizer is supplied, gradient updates are applied
  synchronously. Otherwise, gradient updates are applied asynchronously.

  Args:
    train_op: A `Tensor` that, when executed, will apply the gradients and
      return the loss value.
    logdir: The directory where training logs are written to. If None, model
      checkpoints and summaries will not be written.
    train_step_fn: The function to call in order to execute a single gradient
      step. The function must take exactly four arguments: the current
      session, the `train_op` `Tensor`, a global step `Tensor` and a dictionary.
    train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By
      default, two `Boolean`, scalar ops called "should_stop" and "should_log"
      are provided.
    log_every_n_steps: The frequency, in terms of global steps, that the loss
      and global step are logged.
    graph: The graph to pass to the supervisor. If no graph is supplied the
      default graph is used.
    master: The address of the tensorflow master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    global_step: The `Tensor` representing the global step. If left as `None`,
      then slim.variables.get_or_create_global_step() is used.
    number_of_steps: The max number of gradient steps to take during training.
      If the value is left as None, training proceeds indefinitely.
    init_op: The initialization operation. If left to its default value, then
      the session is initialized by calling `tf.global_variables_initializer()`.
    init_feed_dict: A feed dictionary to use when executing the `init_op`.
    local_init_op: The local initialization operation. If left to its default
      value, then the session is initialized by calling
      `tf.local_variables_initializer()` and `tf.tables_initializer()`.
    init_fn: An optional callable to be executed after `init_op` is called. The
      callable must accept one argument, the session being initialized.
    ready_op: Operation to check if the model is ready to use. If left to its
      default value, then the session checks for readiness by calling
      `tf.report_uninitialized_variables()`.
    summary_op: The summary operation.
    save_summaries_secs: How often, in seconds, to save summaries.
    summary_writer: `SummaryWriter` to use.  Can be `None`
      to indicate that no summaries should be written. If unset, we
      create a SummaryWriter.
    startup_delay_steps: The number of steps to wait for before beginning. Note
      that this must be 0 if a sync_optimizer is supplied.
    saver: Saver to save checkpoints. If None, a default one will be created
      and used.
    save_interval_secs: How often, in seconds, to save the model to `logdir`.
    sync_optimizer: an instance of tf.train.SyncReplicasOptimizer. If the
      argument is supplied, gradient updates will be synchronous. If left as
      `None`, gradient updates will be asynchronous.
    session_config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.
    trace_every_n_steps: produce and save a `Timeline` in Chrome trace format
      and add it to the summaries every `trace_every_n_steps`. If None, no trace
      information will be produced or saved.

  Returns:
    the value of the loss function after training.

  Raises:
    ValueError: if `train_op` is empty or if `startup_delay_steps` is
      non-zero when `sync_optimizer` is supplied, if `number_of_steps` is
      negative, or if `trace_every_n_steps` is not `None` and no `logdir` is
      provided.
  """
    if train_op is None:
        raise ValueError('train_op cannot be None.')

    if logdir is None:
        if summary_op != _USE_DEFAULT:
            raise ValueError('Cannot provide summary_op because logdir=None')
        if saver is not None:
            raise ValueError('Cannot provide saver because logdir=None')
        if trace_every_n_steps is not None:
            raise ValueError('Cannot provide trace_every_n_steps because '
                             'logdir=None')

    if sync_optimizer is not None and startup_delay_steps > 0:
        raise ValueError(
            'startup_delay_steps must be zero when sync_optimizer is supplied.'
        )

    if number_of_steps is not None and number_of_steps <= 0:
        raise ValueError(
            '`number_of_steps` must be either None or a positive number.')

    graph = graph or ops.get_default_graph()
    with graph.as_default():
        if global_step is None:
            global_step = variables.get_or_create_global_step()
        saver = saver or tf_saver.Saver()

        with ops.name_scope('init_ops'):
            if init_op == _USE_DEFAULT:
                init_op = tf_variables.global_variables_initializer()

            if ready_op == _USE_DEFAULT:
                ready_op = tf_variables.report_uninitialized_variables()

            if local_init_op == _USE_DEFAULT:
                local_init_op = control_flow_ops.group(
                    tf_variables.local_variables_initializer(),
                    data_flow_ops.tables_initializer())

            if sync_optimizer is not None and isinstance(
                    sync_optimizer,
                    sync_replicas_optimizer.SyncReplicasOptimizer):
                with ops.control_dependencies(
                    [local_init_op] if local_init_op is not None else []):
                    if is_chief:
                        local_init_op = sync_optimizer.chief_init_op
                    else:
                        local_init_op = sync_optimizer.local_step_init_op
                ready_for_local_init_op = sync_optimizer.ready_for_local_init_op
            else:
                ready_for_local_init_op = None

        if summary_op == _USE_DEFAULT:
            summary_op = summary.merge_all()

        if summary_writer == _USE_DEFAULT:
            summary_writer = supervisor.Supervisor.USE_DEFAULT

        if is_chief and sync_optimizer is not None:
            if not isinstance(sync_optimizer,
                              (sync_replicas_optimizer.SyncReplicasOptimizer)):
                raise ValueError(
                    '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.'
                )

            # Need to create these BEFORE the supervisor finalizes the graph:
            init_tokens_op = sync_optimizer.get_init_tokens_op()
            chief_queue_runner = sync_optimizer.get_chief_queue_runner()

        if train_step_kwargs == _USE_DEFAULT:
            with ops.name_scope('train_step'):
                train_step_kwargs = {}

                if number_of_steps:
                    should_stop_op = math_ops.greater_equal(
                        global_step, number_of_steps)
                else:
                    should_stop_op = constant_op.constant(False)
                train_step_kwargs['should_stop'] = should_stop_op
                train_step_kwargs['should_log'] = math_ops.equal(
                    math_ops.mod(global_step, log_every_n_steps), 0)
                if is_chief and trace_every_n_steps is not None:
                    train_step_kwargs['should_trace'] = math_ops.equal(
                        math_ops.mod(global_step, trace_every_n_steps), 0)
                    train_step_kwargs['logdir'] = logdir

    sv = supervisor.Supervisor(graph=graph,
                               is_chief=is_chief,
                               logdir=logdir,
                               init_op=init_op,
                               init_feed_dict=init_feed_dict,
                               local_init_op=local_init_op,
                               ready_for_local_init_op=ready_for_local_init_op,
                               ready_op=ready_op,
                               summary_op=summary_op,
                               summary_writer=summary_writer,
                               global_step=global_step,
                               saver=saver,
                               save_summaries_secs=save_summaries_secs,
                               save_model_secs=save_interval_secs,
                               init_fn=init_fn)

    if summary_writer is not None:
        train_step_kwargs['summary_writer'] = sv.summary_writer

    should_retry = True
    while should_retry:
        try:
            should_retry = False
            with sv.managed_session(master,
                                    start_standard_services=False,
                                    config=session_config) as sess:
                logging.info('Starting Session.')
                if is_chief:
                    if logdir:
                        sv.start_standard_services(sess)
                elif startup_delay_steps > 0:
                    _wait_for_step(
                        sess, global_step,
                        min(startup_delay_steps, number_of_steps
                            or sys.maxint))
                sv.start_queue_runners(sess)
                logging.info('Starting Queues.')
                if is_chief and sync_optimizer is not None:
                    sv.start_queue_runners(sess, [chief_queue_runner])
                    sess.run(init_tokens_op)
                try:
                    while not sv.should_stop():
                        total_loss, should_stop = train_step_fn(
                            sess, train_op, global_step, train_step_kwargs)
                        if should_stop:
                            logging.info('Stopping Training.')
                            break
                except errors.OutOfRangeError:
                    # OutOfRangeError is thrown when epoch limit per
                    # tf.train.limit_epochs is reached.
                    logging.info('Caught OutOfRangeError. Stopping Training.')
                if logdir and sv.is_chief:
                    logging.info('Finished training! Saving model to disk.')
                    sv.saver.save(sess,
                                  sv.save_path,
                                  global_step=sv.global_step)

        except errors.AbortedError:
            # Always re-run on AbortedError as it indicates a restart of one of the
            # distributed tensorflow servers.
            logging.info('Retrying training!')
            should_retry = True

    return total_loss
Example No. 26
def _initialized_session():
    sess = session.Session()
    sess.run(variables_lib.global_variables_initializer())
    sess.run(data_flow_ops.tables_initializer())
    return sess
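A short usage sketch, assuming the current default graph already defines the lookup tables and variables of interest; some_output is a placeholder tensor built beforehand:

sess = _initialized_session()
try:
  # Global variables and lookup tables are already initialized here.
  value = sess.run(some_output)
finally:
  sess.close()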