Example no. 1
0
def _load_global_step_from_checkpoint_dir(checkpoint_dir):
  try:
    checkpoint_reader = training.NewCheckpointReader(
        training.latest_checkpoint(checkpoint_dir))
    return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
  except:  # pylint: disable=bare-except
    return 0
Example no. 2
0
  def _restore(self, path):
    """Restores this estimator from given path.

    Note: will rebuild the graph and initialize all parameters,
    and will ignore provided model.

    Args:
      path: Path to checkpoints and other information.

    Raises:
      ValueError: If the restore folder is missing the endpoints, graph
        definition, or saver definition files, or if no checkpoint files
        can be found under `path`.
    """
    # Currently Saver requires absolute path to work correctly.
    path = os.path.abspath(path)

    self._graph = ops.Graph()
    with self._graph.as_default():
      # Endpoint tensor names (one per line) identify the graph elements to
      # re-bind after importing the graph definition.
      endpoints_filename = os.path.join(path, 'endpoints')
      if not os.path.exists(endpoints_filename):
        raise ValueError("Restore folder doesn't contain endpoints.")
      with gfile.Open(endpoints_filename) as foutputs:
        endpoints = foutputs.read().split('\n')
      # Rebuild the graph from its text-format protobuf and recover the
      # endpoint tensors/ops by name.
      graph_filename = os.path.join(path, 'graph.pbtxt')
      if not os.path.exists(graph_filename):
        raise ValueError("Restore folder doesn't contain graph definition.")
      with gfile.Open(graph_filename) as fgraph:
        graph_def = graph_pb2.GraphDef()
        text_format.Merge(fgraph.read(), graph_def)
        (self._inp, self._out, self._model_predictions,
         self._model_loss) = importer.import_graph_def(
             graph_def, name='', return_elements=endpoints)
      # Rebuild the Saver from its serialized definition.
      saver_filename = os.path.join(path, 'saver.pbtxt')
      if not os.path.exists(saver_filename):
        raise ValueError("Restore folder doesn't contain saver definition.")
      with gfile.Open(saver_filename) as fsaver:
        saver_def = train.SaverDef()
        text_format.Merge(fsaver.read(), saver_def)
        self._saver = train.Saver(saver_def=saver_def)

      # Restore trainer
      self._global_step = self._graph.get_tensor_by_name('global_step:0')
      self._train = self._graph.get_operation_by_name('train')

      # Restore summaries.
      self._summaries = self._graph.get_operation_by_name(
          'MergeSummary/MergeSummary')

      # Restore session.
      if not isinstance(self._config, RunConfig):
        self._config = RunConfig(verbose=self.verbose)
      self._session = session.Session(self._config.tf_master,
                                      config=self._config.tf_config)
      checkpoint_path = train.latest_checkpoint(path)
      if checkpoint_path is None:
        raise ValueError(
            'Missing checkpoint files in %s. Please '
            'make sure you have a checkpoint file that describes '
            'the latest checkpoints and that the checkpoints themselves '
            'are there. If you have moved the folder, you need to '
            'manually update the paths in the checkpoint file.' % path)
      self._saver.restore(self._session, checkpoint_path)
    # Set to be initialized.
    self._initialized = True
Example no. 3
0
def _load_global_step_from_checkpoint_dir(checkpoint_dir):
  try:
    checkpoint_reader = training.NewCheckpointReader(
        training.latest_checkpoint(checkpoint_dir))
    return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
  except:  # pylint: disable=bare-except
    return 0
Example no. 4
0
def _load_global_step_from_checkpoint_dir(checkpoint_dir):
    # noinspection PyBroadException
    try:
        checkpoint_reader = training.NewCheckpointReader(
            training.latest_checkpoint(checkpoint_dir))
        return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
    except:
        return 0
Example no. 5
0
 def _load_global_step(self):
     """Returns the global step stored in the latest checkpoint, or 0.

     Any failure while locating or reading the checkpoint in
     ``self._checkpoint_dir`` is printed and ignored, yielding 0.
     """
     try:
         latest = tf_training.latest_checkpoint(self._checkpoint_dir)
         reader = tf_training.NewCheckpointReader(latest)
         return reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
     except Exception as e:
         print("Ignored: " + str(e.args))
         return 0
Example no. 6
0
    def _restore(self, path):
        """Restores this estimator from given path.

        Note: will rebuild the graph and initialize all parameters,
        and will ignore provided model.

        Args:
            path: Path to checkpoints and other information.

        Raises:
            ValueError: If the restore folder is missing the endpoints,
                graph definition, or saver definition files, or if no
                checkpoint files can be found under `path`.
        """
        # Currently Saver requires absolute path to work correctly.
        path = os.path.abspath(path)

        self._graph = ops.Graph()
        with self._graph.as_default():
            # Endpoint tensor names (one per line) identify which graph
            # elements to re-bind after importing the graph definition.
            endpoints_filename = os.path.join(path, 'endpoints')
            if not os.path.exists(endpoints_filename):
                raise ValueError("Restore folder doesn't contain endpoints.")
            with gfile.Open(endpoints_filename) as foutputs:
                endpoints = foutputs.read().split('\n')
            # Rebuild the graph from its text-format protobuf and recover
            # the endpoint tensors/ops by name.
            graph_filename = os.path.join(path, 'graph.pbtxt')
            if not os.path.exists(graph_filename):
                raise ValueError(
                    "Restore folder doesn't contain graph definition.")
            with gfile.Open(graph_filename) as fgraph:
                graph_def = graph_pb2.GraphDef()
                text_format.Merge(fgraph.read(), graph_def)
                (self._inp, self._out, self._model_predictions,
                 self._model_loss) = importer.import_graph_def(
                     graph_def, name='', return_elements=endpoints)
            # Rebuild the Saver from its serialized definition.
            saver_filename = os.path.join(path, 'saver.pbtxt')
            if not os.path.exists(saver_filename):
                raise ValueError(
                    "Restore folder doesn't contain saver definition.")
            with gfile.Open(saver_filename) as fsaver:
                saver_def = train.SaverDef()
                text_format.Merge(fsaver.read(), saver_def)
                self._saver = train.Saver(saver_def=saver_def)

            # Restore trainer
            self._global_step = self._graph.get_tensor_by_name('global_step:0')
            self._train = self._graph.get_operation_by_name('train')

            # Restore summaries.
            self._summaries = self._graph.get_operation_by_name(
                'MergeSummary/MergeSummary')

            # Restore session.
            if not isinstance(self._config, RunConfig):
                self._config = RunConfig(verbose=self.verbose)
            self._session = session.Session(self._config.tf_master,
                                            config=self._config.tf_config)
            checkpoint_path = train.latest_checkpoint(path)
            if checkpoint_path is None:
                raise ValueError(
                    "Missing checkpoint files in %s. Please "
                    "make sure you have a checkpoint file that describes "
                    "the latest checkpoints and that the checkpoints "
                    "themselves are there. If you have moved the folder, "
                    "you need to manually update the paths in the "
                    "checkpoint file." % path)
            self._saver.restore(self._session, checkpoint_path)
        # Set to be initialized.
        self._initialized = True