コード例 #1
0
def _parse_input_saver_proto(input_saver, input_binary):
  """Parse a TensorFlow Saver file on disk into a SaverDef proto.

  Args:
    input_saver: Path to the saver file.
    input_binary: If True, the file is a binary proto; otherwise text format.

  Returns:
    The populated SaverDef proto, or -1 when the file does not exist.
  """
  if not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1
  read_mode = "rb" if input_binary else "r"
  with gfile.FastGFile(input_saver, read_mode) as saver_file:
    contents = saver_file.read()
  saver_def = saver_pb2.SaverDef()
  if input_binary:
    saver_def.ParseFromString(contents)
  else:
    text_format.Merge(contents, saver_def)
  return saver_def
コード例 #2
0
 def to_proto(self):
     """Serializes to a SaverDef referencing the current graph."""
     fname_placeholder = array_ops.placeholder(shape=[],
                                               dtype=dtypes.string,
                                               name="saver_filename")
     # TODO(allenl): Add save and restore function names to the proto directly.
     string_signature = (tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),)
     # Autograph is disabled deliberately: it is not needed here, and it
     # creates reference cycles which must be collected when a function is
     # created and destroyed (as in tf.saved_model.save). Leaving it off may
     # also be slightly faster.
     #
     # TODO(b/121302372): We should be able to decorate save() and restore()
     # unconditionally.
     save_fn = def_function.function(self.save,
                                     input_signature=string_signature,
                                     autograph=False)
     restore_fn = def_function.function(self.restore,
                                        input_signature=string_signature,
                                        autograph=False)
     save_tensor = save_fn(fname_placeholder)
     restore_op = restore_fn(fname_placeholder).op
     return saver_pb2.SaverDef(filename_tensor_name=fname_placeholder.name,
                               save_tensor_name=save_tensor.name,
                               restore_op_name=restore_op.name,
                               version=saver_pb2.SaverDef.V2)
コード例 #3
0
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants.

    Loads a GraphDef from `input_graph`, restores variable values from
    `input_checkpoint` (via a SaverDef from `input_saver` when given,
    otherwise via the graph's own restore op), folds the variables into
    constants, and writes the frozen GraphDef to `output_graph`.

    Args:
        input_graph: Path to the GraphDef file (binary or text proto).
        input_saver: Optional path to a SaverDef file; "" to use the graph's
            built-in restore op instead.
        input_binary: True if `input_graph` (and `input_saver`) are binary
            protos; False for text format.
        input_checkpoint: Checkpoint path (may be a V2 prefix).
        output_node_names: Comma-separated node names to keep as outputs.
        restore_op_name: Name of the restore op run when no saver is given.
        filename_tensor_name: Name of the filename placeholder fed with
            `input_checkpoint` when no saver is given.
        output_graph: Path where the frozen binary GraphDef is written.
        clear_devices: If True, strip device placements for portability.
        initializer_nodes: Optional initializer nodes to run after restore.
        variable_names_blacklist: Comma-separated variable names to keep as
            variables (not frozen); "" freezes everything.

    Returns:
        -1 on a validation failure (after printing a message); otherwise
        None (success is reported by writing `output_graph`).
    """

    # --- Input validation: each failure prints a message and returns -1. ---
    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # --- Load the GraphDef from disk (binary or text proto). ---
    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            # NOTE(review): in text mode ("r") f.read() may already be str,
            # in which case .decode("utf-8") would raise — confirm gfile's
            # return type here; the saver read below does NOT decode.
            text_format.Merge(f.read().decode("utf-8"), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    # Import into the default graph so the Session below can see it.
    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver:
            # Restore through an explicitly provided SaverDef.
            with gfile.FastGFile(input_saver, mode) as f:
                saver_def = saver_pb2.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = saver_lib.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            # Restore through the graph's own restore op, feeding the
            # checkpoint path into the filename placeholder.
            sess.run([restore_op_name],
                     {filename_tensor_name: input_checkpoint})
            # NOTE(review): initializer_nodes only run on this branch, not
            # when input_saver is given — confirm that is intentional.
            if initializer_nodes:
                sess.run(initializer_nodes)

        # Rebind the comma-separated string to a list (or None) as expected
        # by convert_variables_to_constants.
        variable_names_blacklist = (variable_names_blacklist.split(",")
                                    if variable_names_blacklist else None)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(","),
            variable_names_blacklist=variable_names_blacklist)

    # Serialize the frozen graph to disk as a binary proto.
    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
コード例 #4
0
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants.

    Updated variant: instead of running a restore op by name, it reads the
    checkpoint directly with a CheckpointReader, maps each checkpoint key to
    the same-named tensor in the imported graph, and restores via a freshly
    built Saver over that mapping. `restore_op_name` and
    `filename_tensor_name` are kept only for signature compatibility.

    Args:
        input_graph: Path to the GraphDef file (binary or text proto).
        input_saver: Optional path to a SaverDef file; "" to build a Saver
            from the checkpoint contents instead.
        input_binary: True if the input protos are binary; False for text.
        input_checkpoint: Checkpoint path (may be a V2 prefix).
        output_node_names: Comma-separated node names to keep as outputs.
        restore_op_name: Unused (see `del` below).
        filename_tensor_name: Unused (see `del` below).
        output_graph: Path where the frozen binary GraphDef is written.
        clear_devices: If True, strip device placements for portability.
        initializer_nodes: Optional initializer nodes to run after restore.
        variable_names_blacklist: Comma-separated variable names to keep as
            variables (not frozen); "" freezes everything.

    Returns:
        -1 on a validation failure (after printing a message); otherwise
        None (success is reported by writing `output_graph`).
    """

    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # --- Input validation: each failure prints a message and returns -1. ---
    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # --- Load the GraphDef from disk (binary or text proto). ---
    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    # Import into the default graph so the Session below can see it.
    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver:
            # Restore through an explicitly provided SaverDef.
            with gfile.FastGFile(input_saver, mode) as f:
                saver_def = saver_pb2.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = saver_lib.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            # Build a {checkpoint_key: graph_tensor} map by probing the graph
            # for each variable stored in the checkpoint, then restore through
            # a Saver constructed over that map.
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    # ":0" selects the first output of the same-named op.
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list)
            saver.restore(sess, input_checkpoint)
            # NOTE(review): initializer_nodes only run on this branch, not
            # when input_saver is given — confirm that is intentional.
            if initializer_nodes:
                sess.run(initializer_nodes)

        # Rebind the comma-separated string to a list (or None) as expected
        # by convert_variables_to_constants.
        variable_names_blacklist = (variable_names_blacklist.split(",")
                                    if variable_names_blacklist else None)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(","),
            variable_names_blacklist=variable_names_blacklist)

    # Serialize the frozen graph to disk as a binary proto.
    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
コード例 #5
0
    def __init__(self,
                 var_list=None,
                 reshape=False,
                 sharded=False,
                 max_to_keep=5,
                 keep_checkpoint_every_n_hours=10000.0,
                 name=None,
                 restore_sequentially=False,
                 saver_def=None,
                 builder=None,
                 defer_build=False,
                 allow_empty=False,
                 write_version=saver_pb2.SaverDef.V2,
                 pad_step_number=False,
                 save_relative_paths=False,
                 filename=None):
        # pylint: disable=too-many-arguments, too-many-locals
        """
        Saver for AutoDist.

        Saves the variables that map to the *original*, *user-declared*,
        *single-node* graph rather than to the transformed graph. A saved
        checkpoint can therefore be loaded either by the original user code
        (to resume single-node training or run inference) or by the
        AutoDist-distributed code (to resume distributed training). In
        contrast, the saved meta_graph_def maps to the *AutoDist transformed*
        graph (TODO).

        This is implemented by writing/loading the contents that map to the
        master_replica (default replica_index=0) of the transformed graph.

        See
        https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
        for a detailed explanation of the signature.
        """
        # TODO(Hao): support constructing saver after AutoDist graph transformation
        if autodist.autodist.get_default_autodist().is_built():
            raise ValueError(
                'Saver must be used before create_distributed_session().')

        # One saver covers every variable in var_list; the base class appends
        # the relevant save/restore ops under a single "saver" scope.
        super(Saver, self).__init__(
            var_list=var_list, reshape=reshape, sharded=sharded,
            max_to_keep=max_to_keep,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            name=name, restore_sequentially=restore_sequentially,
            saver_def=saver_def, builder=builder, defer_build=defer_build,
            allow_empty=allow_empty, write_version=write_version,
            pad_step_number=pad_step_number,
            save_relative_paths=save_relative_paths, filename=filename)

        # tensorflow does not add user-declared savers to collections, so we
        # track a copy of this saver's proto on the graph item's info instead.
        proto_copy = saver_pb2.SaverDef()
        proto_copy.CopyFrom(self.to_proto())
        get_default_graph_item().info.update_savers([proto_copy], replace=False)
コード例 #6
0
    def _build_internal(self,
                        names_to_saveables,
                        reshape=False,
                        sharded=False,
                        max_to_keep=5,
                        keep_checkpoint_every_n_hours=10000.0,
                        name=None,
                        restore_sequentially=False,
                        filename="model",
                        build_save=True,
                        build_restore=True):
        """build() with option to only perform save and restore.

        Builds the save/restore ops for `names_to_saveables` and returns a
        SaverDef describing them. This variant inserts hccl broadcast ops
        (under "Save_Weight_Update_Sharding") before the save ops for
        variables tracked by util.get_all_grad_item() — presumably to sync
        weights across devices before checkpointing (TODO confirm).

        Args:
            names_to_saveables: Dict of names to saveable objects/variables.
            reshape: Allow restoring tensors of a different shape.
            sharded: Shard the checkpoint per device.
            max_to_keep: Max checkpoints to keep; None is coerced to 0.
            keep_checkpoint_every_n_hours: Checkpoint retention cadence.
            name: Optional name scope (defaults to "save").
            restore_sequentially: Build restores in sequence to save memory.
            filename: Default checkpoint filename.
            build_save: Build the save ops.
            build_restore: Build the restore ops.

        Returns:
            A saver_pb2.SaverDef proto for the built ops.

        Raises:
            ValueError: If only one of save/restore is requested outside
                eager execution.
        """
        # Graph mode requires both halves so the SaverDef can name both ops.
        if not context.executing_eagerly() and (not build_save or
                                                not build_restore):
            raise ValueError("save and restore operations need to be built together "
                            " when eager execution is not enabled.")

        saveables = saveable_object_util.validate_and_slice_inputs(
            names_to_saveables)
        if max_to_keep is None:
            max_to_keep = 0

        with ops.name_scope(name, "save",
                            [saveable.op for saveable in saveables]) as name:
            # Add a placeholder string tensor for the filename.
            filename_tensor = array_ops.placeholder_with_default(
                filename or "model", shape=(), name="filename")
            # Keep the name "Const" for backwards compatibility.
            filename_tensor = array_ops.placeholder_with_default(
                filename_tensor, shape=(), name="Const")

            # Add the save ops.
            if sharded:
                per_device = self._GroupByDevices(saveables)
                if build_save:
                    # Chain a broadcast per tracked variable; each broadcast
                    # depends on all previous ones via op_list, and the save
                    # depends on the whole chain.
                    op_list = []
                    with tf.name_scope("Save_Weight_Update_Sharding"):
                        grad_and_var_items = util.get_all_grad_item()
                        for item in grad_and_var_items:
                            if item.var in names_to_saveables:
                                rank_id = item.root_rank_id
                                # rank_id < 0 appears to mean "no root/owner";
                                # skip broadcasting those (TODO confirm).
                                if rank_id >= 0:
                                    with tf.get_default_graph().control_dependencies(op_list):
                                        # NOTE(review): semantics of the third
                                        # arg (2) to hccl_ops.broadcast are not
                                        # visible here — confirm against hccl docs.
                                        out_var = hccl_ops.broadcast([item.var], rank_id, 2, rank_id)
                                    op_list.append(out_var[0].op)
                    if len(op_list) > 0:
                        with tf.get_default_graph().control_dependencies(op_list):
                            save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
                    else:
                        save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
                if build_restore:
                    restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
                                                        restore_sequentially, reshape)
            else:
                if build_save:
                    # Same broadcast-then-save chaining as the sharded branch,
                    # but over unsharded saveables.
                    op_list = []
                    with tf.name_scope("Save_Weight_Update_Sharding"):
                        grad_and_var_items = util.get_all_grad_item()
                        for item in grad_and_var_items:
                            if item.var in names_to_saveables:
                                rank_id = item.root_rank_id
                                if rank_id >= 0:
                                    with tf.get_default_graph().control_dependencies(op_list):
                                        out_var = hccl_ops.broadcast([item.var], rank_id, 2, rank_id)
                                    op_list.append(out_var[0].op)
                    if len(op_list) > 0:
                        with tf.get_default_graph().control_dependencies(op_list):
                            save_tensor = self._AddSaveOps(filename_tensor, saveables)
                    else:
                        save_tensor = self._AddSaveOps(filename_tensor, saveables)
                if build_restore:
                    restore_op = self._AddRestoreOps(filename_tensor, saveables,
                                                restore_sequentially, reshape)

        # In the following use case, it's possible to have restore_ops be called
        # something else:
        # - Build inference graph and export a meta_graph.
        # - Import the inference meta_graph
        # - Extend the inference graph to a train graph.
        # - Export a new meta_graph.
        # Now the second restore_op will be called "restore_all_1".
        # As such, comment out the assert for now until we know whether supporting
        # such usage model makes sense.
        #
        # assert restore_op.name.endswith("restore_all"), restore_op.name
        if context.executing_eagerly():
            # Store the tensor values to the tensor_names.
            save_tensor_name = save_tensor.numpy() if build_save else ""
            return saver_pb2.SaverDef(
                filename_tensor_name=filename_tensor.numpy(),
                save_tensor_name=save_tensor_name,
                restore_op_name="",
                max_to_keep=max_to_keep,
                sharded=sharded,
                keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                version=self._write_version)
        else:
            graph = ops.get_default_graph()
            # Do some sanity checking on collections containing
            # PartitionedVariables. If a saved collection has a PartitionedVariable,
            # the GraphDef needs to include concat ops to get the value (or there'll
            # be a lookup error on load).
            check_collection_list = graph.get_all_collection_keys()
            for collection_type in check_collection_list:
                for element in graph.get_collection(collection_type):
                    if isinstance(element, variables.PartitionedVariable):
                        try:
                            graph.get_operation_by_name(element.name)
                        except KeyError:
                            # Create a concat op for this PartitionedVariable. The user may
                            # not need it, but we'll try looking it up on MetaGraph restore
                            # since it's in a collection.
                            element.as_tensor()
            return saver_pb2.SaverDef(
                filename_tensor_name=filename_tensor.name,
                save_tensor_name=save_tensor.name,
                restore_op_name=restore_op.name,
                max_to_keep=max_to_keep,
                sharded=sharded,
                keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                version=self._write_version)