Example #1 (score: 0)
File: base.py — Project: tsipporah/skflow
    def _restore(self, path):
        """Restores this estimator from given path.

        Note: will rebuild the graph and initialize all parameters,
        and will ignore provided model.

        Args:
            path: Path to checkpoints and other information.

        Raises:
            ValueError: If the restore folder is missing the endpoints
                file, the graph definition, the saver definition, or
                valid checkpoint files.
        """
        # Currently Saver requires absolute path to work correctly.
        path = os.path.abspath(path)

        # Build a fresh graph; all restored ops live inside it.
        self._graph = tf.Graph()
        with self._graph.as_default():
            # 'endpoints' holds the names of the input, output, predictions
            # and loss tensors, one per line, in that order.
            endpoints_filename = os.path.join(path, 'endpoints')
            if not os.path.exists(endpoints_filename):
                raise ValueError("Restore folder doesn't contain endpoints.")
            with open(endpoints_filename) as foutputs:
                endpoints = foutputs.read().split('\n')
            # The graph definition is stored as a text-format protobuf.
            graph_filename = os.path.join(path, 'graph.pbtxt')
            if not os.path.exists(graph_filename):
                raise ValueError(
                    "Restore folder doesn't contain graph definition.")
            with open(graph_filename) as fgraph:
                graph_def = tf.GraphDef()
                text_format.Merge(fgraph.read(), graph_def)
                (self._inp, self._out, self._model_predictions,
                 self._model_loss) = tf.import_graph_def(
                     graph_def, return_elements=endpoints)
            saver_filename = os.path.join(path, 'saver.pbtxt')
            if not os.path.exists(saver_filename):
                raise ValueError(
                    "Restore folder doesn't contain saver definition.")
            with open(saver_filename) as fsaver:
                from tensorflow.python.training import saver_pb2
                saver_def = saver_pb2.SaverDef()
                text_format.Merge(fsaver.read(), saver_def)
                # ??? For some reason the saver def doesn't have import/ prefix.
                # import_graph_def prefixes every imported node name with
                # 'import/', so the saver def's op names must be rewritten
                # to point at the imported copies.
                saver_def.filename_tensor_name = 'import/' + saver_def.filename_tensor_name
                saver_def.restore_op_name = 'import/' + saver_def.restore_op_name
                self._saver = tf.train.Saver(saver_def=saver_def)
            self._session = tf.Session(
                self.tf_master,
                config=tf.ConfigProto(
                    log_device_placement=self.verbose > 1,
                    inter_op_parallelism_threads=self.num_cores,
                    intra_op_parallelism_threads=self.num_cores))
            # Sanity check: raises if the restore op was not imported.
            self._graph.get_operation_by_name('import/save/restore_all')
            checkpoint_path = tf.train.latest_checkpoint(path)
            if checkpoint_path is None:
                raise ValueError(
                    "Missing checkpoint files in %s. Please make sure "
                    "you have a checkpoint file that describes the latest "
                    "checkpoints and that the appropriate checkpoint files "
                    "are there. If you have moved the folder, you need to "
                    "manually update the paths in the checkpoint file."
                    % path)
            self._saver.restore(self._session, checkpoint_path)
        # Set to be initialized.
        self._initialized = True
Example #2 (score: 0)
    def as_saver_def(self):
        """Exports this saver's configuration as a `SaverDef` proto.

        Returns:
          A `SaverDef` proto describing this saver.
        """
        # Populate the proto field-by-field from the saver's state.
        saver_def = saver_pb2.SaverDef()
        saver_def.filename_tensor_name = self._filename_tensor_name
        saver_def.save_tensor_name = self._save_tensor_name
        saver_def.restore_op_name = self._restore_op_name
        saver_def.max_to_keep = self._max_to_keep
        saver_def.keep_checkpoint_every_n_hours = (
            self._keep_checkpoint_every_n_hours)
        saver_def.sharded = self._sharded
        return saver_def
Example #3 (score: 0)
    def build(self,
              names_to_variables,
              reshape=False,
              sharded=False,
              max_to_keep=5,
              keep_checkpoint_every_n_hours=10000.0,
              name=None,
              restore_sequentially=False):
        """Adds save/restore nodes to the graph and creates a SaverDef proto.

        Args:
          names_to_variables: A dictionary mapping name to a Variable.
            Each name will be associated with the corresponding variable
            in the checkpoint.
          reshape: If True, allow restoring parameters from a checkpoint
            where the parameters have a different shape.  Only needed when
            restoring from a Dist-Belief checkpoint, and only sometimes.
          sharded: If True, shard the checkpoints, one per device that has
            Parameters nodes.
          max_to_keep: Maximum number of checkpoints to keep.  As new
            checkpoints are created, old ones are deleted.  If None or 0,
            no checkpoints are deleted.  The number is only roughly
            enforced; e.g. after restarts more than max_to_keep
            checkpoints may be kept.
          keep_checkpoint_every_n_hours: How often checkpoints should be
            kept.  Defaults to 10,000 hours.
          name: string.  Optional name to use as a prefix when adding
            operations.
          restore_sequentially: A Bool, which if true, causes restore of
            different variables to happen sequentially within each device.

        Returns:
          A SaverDef proto.

        Raises:
          TypeError: If 'names_to_variables' is not a dictionary mapping
            string keys to variable Tensors.
          ValueError: If any of the keys or values in 'names_to_variables'
            is not unique.
        """
        saveables = self._ValidateAndSliceInputs(names_to_variables)
        # None means "keep everything", which the proto encodes as 0.
        max_to_keep = 0 if max_to_keep is None else max_to_keep

        with ops.op_scope([s.var for s in saveables], name,
                          "save") as name:
            # Constant placeholder for the checkpoint filename; callers
            # feed the real path when running the save/restore ops.
            filename_tensor = constant_op.constant("model")

            if sharded:
                # One save/restore group per device that owns variables.
                device_groups = self._GroupByDevices(saveables)
                save_tensor = self._AddShardedSaveOps(filename_tensor,
                                                      device_groups)
                restore_op = self._AddShardedRestoreOps(
                    filename_tensor, device_groups, restore_sequentially,
                    reshape)
            else:
                save_tensor = self._AddSaveOps(filename_tensor, saveables)
                restore_op = self._AddRestoreOps(filename_tensor, saveables,
                                                 restore_sequentially, reshape)

        assert restore_op.name.endswith("restore_all"), restore_op.name

        return saver_pb2.SaverDef(
            filename_tensor_name=filename_tensor.name,
            save_tensor_name=save_tensor.name,
            restore_op_name=restore_op.name,
            max_to_keep=max_to_keep,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            sharded=sharded)