Example #1
0
  def __init__(self, object_graph_proto, saved_model_proto, export_dir):
    """Rebuilds the saved object graph and restores its state.

    Args:
      object_graph_proto: Proto describing the saved object graph; kept as
        `self._proto`.
      saved_model_proto: SavedModel proto. Only `meta_graphs[0]` is read.
      export_dir: Directory the SavedModel was exported to; kept as
        `self._export_dir`.
    """
    first_meta_graph = saved_model_proto.meta_graphs[0]
    graph_def = first_meta_graph.graph_def
    self._asset_file_def = first_meta_graph.asset_file_def
    # NodeDef attributes keyed by node name.
    self._operation_attributes = dict(
        (node_def.name, node_def.attr) for node_def in graph_def.node)
    self._proto = object_graph_proto
    self._export_dir = export_dir

    loaded_functions = function_deserialization.load_function_def_library(
        graph_def.library)
    # Replace each concrete function with a `_WrapperFunction` so that it can
    # handle both in-replica and cross-replica calls.
    for fn_name in list(loaded_functions):
      loaded_functions[fn_name] = _WrapperFunction(loaded_functions[fn_name])
    self._concrete_functions = loaded_functions

    self._load_all()
    # TODO(b/124045874): There are limitations with functions whose captures
    # trigger other functions to be executed. For now it is only guaranteed to
    # work if the captures of a function only trigger functions without
    # captures.
    self._setup_functions_structures()
    self._setup_functions_captures()
    self._restore_checkpoint()

    # Initialize every restored CapturableResource; outside eager mode the
    # returned op is also registered under TABLE_INITIALIZERS.
    restored_resources = (
        candidate for candidate in self._nodes
        if isinstance(candidate, tracking.CapturableResource))
    for res in restored_resources:
      init_op = res._initialize()  # pylint: disable=protected-access
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
Example #2
0
  def __init__(self, object_graph_proto, saved_model_proto, export_dir,
               ckpt_options):
    """Rebuilds the saved object graph and restores it from the checkpoint.

    Args:
      object_graph_proto: Proto describing the saved object graph; kept as
        `self._proto`.
      saved_model_proto: SavedModel proto. Only `meta_graphs[0]` is read.
      export_dir: Directory the SavedModel was exported to; kept as
        `self._export_dir`.
      ckpt_options: Checkpoint options; kept as `self._checkpoint_options`.
    """
    first_meta_graph = saved_model_proto.meta_graphs[0]
    graph_def = first_meta_graph.graph_def
    self._asset_file_def = first_meta_graph.asset_file_def
    # NodeDef attributes keyed by node name.
    self._operation_attributes = dict(
        (node_def.name, node_def.attr) for node_def in graph_def.node)
    self._proto = object_graph_proto
    self._export_dir = export_dir
    loaded_functions = function_deserialization.load_function_def_library(
        graph_def.library)
    self._checkpoint_options = ckpt_options

    # Replace each concrete function with a `_WrapperFunction` so that it can
    # handle both in-replica and cross-replica calls.
    for fn_name in list(loaded_functions):
      loaded_functions[fn_name] = _WrapperFunction(loaded_functions[fn_name])
    self._concrete_functions = loaded_functions

    self._load_all()
    self._restore_checkpoint()

    # Initialize every restored CapturableResource; outside eager mode the
    # returned op is also registered under TABLE_INITIALIZERS.
    restored_resources = (
        candidate for candidate in self._nodes
        if isinstance(candidate, tracking.CapturableResource))
    for res in restored_resources:
      init_op = res._initialize()  # pylint: disable=protected-access
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
Example #3
0
    def __init__(self, object_graph_proto, saved_model_proto, export_dir):
        """Rebuilds the saved objects, wires up functions, restores state.

        Args:
            object_graph_proto: Proto describing the saved object graph; kept
                as `self._proto`.
            saved_model_proto: SavedModel proto. Only `meta_graphs[0]` is read.
            export_dir: Directory the SavedModel was exported to; kept as
                `self._export_dir`.
        """
        meta_graph = saved_model_proto.meta_graphs[0]
        self._asset_file_def = meta_graph.asset_file_def
        # NodeDef attributes keyed by node name.
        self._operation_attributes = {
            node.name: node.attr
            for node in meta_graph.graph_def.node
        }
        self._proto = object_graph_proto
        self._export_dir = export_dir
        self._concrete_functions = (
            function_deserialization.load_function_def_library(
                meta_graph.graph_def.library))
        self._load_all()
        # TODO(b/124045874): There are limitations with functions whose captures
        # trigger other functions to be executed. For now it is only guaranteed to
        # work if the captures of a function only trigger functions without
        # captures.
        self._setup_functions_structures()
        self._setup_functions_captures()
        self._restore_checkpoint()

        # Initialize every restored TrackableResource and register the returned
        # op under TABLE_INITIALIZERS.
        # NOTE(review): the op is added to the collection unconditionally, with
        # no `context.executing_eagerly()` check -- confirm this is intended
        # when executing eagerly.
        for node in self._nodes:
            if isinstance(node, tracking.TrackableResource):
                init_op = node.initialize()
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      init_op)
Example #4
0
    def __init__(self, object_graph_proto, saved_model_proto, export_dir,
                 ckpt_options, save_options, filters):
        """Loads the (optionally filtered) saved object graph and its state.

        Args:
            object_graph_proto: Proto describing the saved object graph; kept
                as `self._proto`.
            saved_model_proto: SavedModel proto. Only `meta_graphs[0]` is read.
            export_dir: Directory the SavedModel was exported to; kept as
                `self._export_dir`.
            ckpt_options: Checkpoint options; kept as
                `self._checkpoint_options`.
            save_options: Options kept as `self._save_options`; when its
                `experimental_skip_checkpoint` is truthy, `_restore_checkpoint`
                is not called.
            filters: User-supplied node filters, kept as `self._node_filters`.
                When it is a dict keyed by node path, tuple values are stored
                as-is and any other value `v` is stored as `(v, setattr)` in
                `self._loaded_nodes`.
        """
        meta_graph = saved_model_proto.meta_graphs[0]
        self._asset_file_def = meta_graph.asset_file_def
        # NodeDef attributes keyed by node name.
        self._operation_attributes = {
            node.name: node.attr
            for node in meta_graph.graph_def.node
        }
        self._proto = object_graph_proto
        self._export_dir = export_dir
        # `self._proto` must already be set: it is passed as
        # `saved_object_graph`. Wrapping with `_WrapperFunction` is delegated
        # to the library loader here.
        self._concrete_functions = (
            function_deserialization.load_function_def_library(
                library=meta_graph.graph_def.library,
                saved_object_graph=self._proto,
                wrapper_function=_WrapperFunction))
        # Store a set of all concrete functions that have been set up with
        # captures.
        self._restored_concrete_functions = set()
        self._checkpoint_options = ckpt_options
        self._save_options = save_options

        self._pretty_printer = checkpoint.ObjectGraphProtoPrettyPrinter(
            self._proto)

        # Stores user-defined node_filters argument.
        self._node_filters = filters
        # Stores map of string paths to integers.
        self._node_path_to_id = self._convert_node_paths_to_ints()
        self._loaded_nodes = {}
        if isinstance(filters, dict):
            # If node_filters is a dict, then the values may contain already created
            # trackable objects. In this case, create a dictionary mapping node IDs to
            # the already created nodes. This dict will be updated in
            # `_retrieve_all_filtered_nodes` with tracked children.
            for node_path, node in filters.items():
                if isinstance(node, tuple):
                    self._loaded_nodes[self._node_path_to_id[node_path]] = node
                else:
                    # Normalize bare objects to a `(node, setattr)` pair so
                    # every entry is a tuple.
                    self._loaded_nodes[self._node_path_to_id[node_path]] = (
                        node, setattr)

        # Get a list of all integer node ids to load, or None if all nodes should be
        # loaded. This list includes ids of child nodes.
        self._filtered_nodes = self._retrieve_all_filtered_nodes()

        # Order all nodes or filtered nodes using the dependencies.
        self._ordered_node_ids = self._generate_ordered_node_ids()

        self._load_all()

        if not save_options.experimental_skip_checkpoint:
            self._restore_checkpoint()
        # Initialize every restored CapturableResource; outside eager mode the
        # returned op is also registered under TABLE_INITIALIZERS.
        for node in self._nodes:
            if isinstance(node, resource.CapturableResource):
                init_op = node._initialize()  # pylint: disable=protected-access
                if not context.executing_eagerly():
                    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                          init_op)
Example #5
0
 def __init__(self, object_graph_proto, saved_model_proto, export_dir):
   """Rebuilds the saved objects, binds function captures, restores state.

   Args:
     object_graph_proto: Proto describing the saved object graph; kept as
       `self._proto`.
     saved_model_proto: SavedModel proto. Only `meta_graphs[0]` is read.
     export_dir: Directory the SavedModel was exported to; kept as
       `self._export_dir`.
   """
   first_meta_graph = saved_model_proto.meta_graphs[0]
   self._asset_file_def = first_meta_graph.asset_file_def
   self._proto = object_graph_proto
   self._export_dir = export_dir
   function_library = first_meta_graph.graph_def.library
   self._functions = function_deserialization.load_function_def_library(
       function_library)
   # Build the objects, then bind function captures, then restore values.
   self._load_all()
   self._bind_function_captures()
   self._restore_checkpoint()
Example #6
0
 def __init__(self, object_graph_proto, saved_model_proto, export_dir):
   """Rebuilds the saved objects, sets up their functions, restores state.

   Args:
     object_graph_proto: Proto describing the saved object graph; kept as
       `self._proto`.
     saved_model_proto: SavedModel proto. Only `meta_graphs[0]` is read.
     export_dir: Directory the SavedModel was exported to; kept as
       `self._export_dir`.
   """
   first_meta_graph = saved_model_proto.meta_graphs[0]
   self._asset_file_def = first_meta_graph.asset_file_def
   self._proto = object_graph_proto
   self._export_dir = export_dir
   function_library = first_meta_graph.graph_def.library
   self._concrete_functions = (
       function_deserialization.load_function_def_library(function_library))
   # Build the objects, then set up functions, then restore values.
   self._load_all()
   self._setup_functions()
   self._restore_checkpoint()
Example #7
0
    def __init__(self, object_graph_proto, saved_model_proto, export_dir,
                 ckpt_options, save_options, filters):
        """Loads the (optionally filtered) saved object graph and its state.

        Args:
            object_graph_proto: Proto describing the saved object graph; kept
                as `self._proto`.
            saved_model_proto: SavedModel proto. Only `meta_graphs[0]` is read.
            export_dir: Directory the SavedModel was exported to; kept as
                `self._export_dir`.
            ckpt_options: Checkpoint options; kept as
                `self._checkpoint_options`.
            save_options: Options kept as `self._save_options`.
            filters: User-supplied node filters, kept as `self._node_filters`.
                When it is a dict keyed by node path, tuple values are stored
                as-is and any other value `v` is stored as `(v, setattr)` in
                `self._loaded_nodes`.
        """
        first_meta_graph = saved_model_proto.meta_graphs[0]
        graph_def = first_meta_graph.graph_def
        self._asset_file_def = first_meta_graph.asset_file_def
        # NodeDef attributes keyed by node name.
        self._operation_attributes = dict(
            (node_def.name, node_def.attr) for node_def in graph_def.node)
        self._proto = object_graph_proto
        self._export_dir = export_dir
        self._concrete_functions = (
            function_deserialization.load_function_def_library(
                graph_def.library))
        self._checkpoint_options = ckpt_options
        self._save_options = save_options

        # Keep the raw node_filters argument, then map its string paths to
        # integer node ids.
        self._node_filters = filters
        self._node_path_to_id = self._convert_node_paths_to_ints()
        self._loaded_nodes = {}
        if isinstance(filters, dict):
            # Dict values may be trackable objects the caller already created.
            # Seed `_loaded_nodes` (node id -> entry) with them; it is extended
            # with tracked dependencies in `_retrieve_all_filtered_nodes`.
            for node_path, user_node in filters.items():
                node_id = self._node_path_to_id[node_path]
                if not isinstance(user_node, tuple):
                    user_node = (user_node, setattr)
                self._loaded_nodes[node_id] = user_node

        # Integer ids of every node to load (children included), or None when
        # all nodes should be loaded.
        self._filtered_nodes = self._retrieve_all_filtered_nodes()

        # Replace each concrete function with a `_WrapperFunction` so that it
        # can handle both in-replica and cross-replica calls.
        functions_by_name = self._concrete_functions
        for fn_name in list(functions_by_name):
            functions_by_name[fn_name] = _WrapperFunction(
                functions_by_name[fn_name])

        self._load_all()
        self._restore_checkpoint()

        # Initialize every restored CapturableResource; outside eager mode the
        # returned op is also registered under TABLE_INITIALIZERS.
        restored_resources = (
            candidate for candidate in self._nodes
            if isinstance(candidate, tracking.CapturableResource))
        for res in restored_resources:
            init_op = res._initialize()  # pylint: disable=protected-access
            if not context.executing_eagerly():
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      init_op)
Example #8
0
 def __init__(self, object_graph_proto, saved_model_proto, export_dir):
   """Rebuilds the saved objects, sets up their functions, restores state.

   Args:
     object_graph_proto: Proto describing the saved object graph; kept as
       `self._proto`.
     saved_model_proto: SavedModel proto. Only `meta_graphs[0]` is read.
     export_dir: Directory the SavedModel was exported to; kept as
       `self._export_dir`.
   """
   first_meta_graph = saved_model_proto.meta_graphs[0]
   graph_def = first_meta_graph.graph_def
   self._asset_file_def = first_meta_graph.asset_file_def
   # NodeDef attributes keyed by node name.
   self._operation_attributes = dict(
       (node_def.name, node_def.attr) for node_def in graph_def.node)
   self._proto = object_graph_proto
   self._export_dir = export_dir
   self._concrete_functions = (
       function_deserialization.load_function_def_library(graph_def.library))
   # Build the objects, then set up functions, then restore values.
   self._load_all()
   self._setup_functions()
   self._restore_checkpoint()
Example #9
0
  def __init__(self, object_graph_proto, saved_model_proto, export_dir):
    """Rebuilds the saved object graph and restores its state.

    Args:
      object_graph_proto: Proto describing the saved object graph; kept as
        `self._proto`.
      saved_model_proto: SavedModel proto. Only `meta_graphs[0]` is read.
      export_dir: Directory the SavedModel was exported to; kept as
        `self._export_dir`.
    """
    first_meta_graph = saved_model_proto.meta_graphs[0]
    graph_def = first_meta_graph.graph_def
    self._asset_file_def = first_meta_graph.asset_file_def
    # NodeDef attributes keyed by node name.
    self._operation_attributes = dict(
        (node_def.name, node_def.attr) for node_def in graph_def.node)
    self._proto = object_graph_proto
    self._export_dir = export_dir
    self._concrete_functions = (
        function_deserialization.load_function_def_library(graph_def.library))
    self._load_all()
    # TODO(b/124045874): There are limitations with functions whose captures
    # trigger other functions to be executed. For now it is only guaranteed to
    # work if the captures of a function only trigger functions without
    # captures.
    self._setup_functions_structures()
    self._setup_functions_captures()
    self._restore_checkpoint()

    # Initialize every restored CapturableResource; outside eager mode the
    # returned op is also registered under TABLE_INITIALIZERS.
    restored_resources = (
        candidate for candidate in self._nodes
        if isinstance(candidate, tracking.CapturableResource))
    for res in restored_resources:
      init_op = res._initialize()  # pylint: disable=protected-access
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
Example #10
0
    def load(self, tags):
        """Creates an object from the MetaGraph identified by `tags`.

        The MetaGraph's functions and GraphDef nodes are rewritten with a
        per-load name suffix. The graph is wrapped with `wrap_function` and its
        variables are restored. The result is exposed as attributes of a new
        `tracking.AutoTrackable` root.

        Args:
            tags: Tags selecting the MetaGraphDef; passed to
                `self.get_meta_graph_def_from_tags`.

        Returns:
            A `tracking.AutoTrackable` carrying `initializer`, `asset_paths`,
            `signatures`, `variables`, `tensorflow_version`,
            `tensorflow_git_version`, `graph` and `prune` attributes.
        """
        meta_graph_def = self.get_meta_graph_def_from_tags(tags)
        # Unique per-load suffix (built from `ops.uid()`) used when renaming
        # functions and shared names below.
        load_shared_name_suffix = "_load_{}".format(ops.uid())
        functions = function_deserialization.load_function_def_library(
            meta_graph_def.graph_def.library,
            load_shared_name_suffix=load_shared_name_suffix)
        # Replace existing functions in the MetaGraphDef with renamed functions so
        # we don't have duplicates or name collisions.
        meta_graph_def.graph_def.library.Clear()
        for function in functions.values():
            meta_graph_def.graph_def.library.function.add().CopyFrom(
                function.function_def)
        # We've renamed functions and shared names. We need the same operation on
        # the GraphDef itself for consistency.
        for node_def in meta_graph_def.graph_def.node:
            function_deserialization.fix_node_def(
                node_def,
                functions,
                load_shared_name_suffix,
                debug_name="MetaGraph import")

        # Out-parameter: `load_graph` is expected to store its saver in this
        # one-element list, which is unpacked right after wrapping.
        load_graph_returns = [None]
        wrapped = wrap_function.wrap_function(functools.partial(
            self.load_graph, load_graph_returns, meta_graph_def),
                                              signature=[])
        saver, = load_graph_returns
        self.restore_variables(wrapped, saver)
        with wrapped.graph.as_default():
            # Fall back to `Scaffold.default_local_init_op()` when
            # `get_init_op` returns a falsy value.
            init_op = loader_impl.get_init_op(
                meta_graph_def
            ) or monitored_session.Scaffold.default_local_init_op()
            # Add a dummy Tensor we know we can fetch to add control dependencies to.
            init_anchor = constant_op.constant(0., name="dummy_fetch")

        root = tracking.AutoTrackable()
        # Each asset tensor in the wrapped graph becomes a feed of the pruned
        # init function; its value (presumably the asset file path -- confirm
        # against `get_asset_tensors`) is tracked as a `tracking.Asset`.
        asset_feed_tensors = []
        asset_paths = []
        for tensor_name, value in loader_impl.get_asset_tensors(
                self._export_dir, meta_graph_def).items():
            asset_feed_tensors.append(
                wrapped.graph.as_graph_element(tensor_name))
            asset_paths.append(tracking.Asset(value))
        init_fn = wrapped.prune(
            feeds=asset_feed_tensors,
            fetches=[init_anchor,
                     wrapped.graph.as_graph_element(init_op)])
        initializer = _Initializer(init_fn, asset_paths)
        # pylint: disable=protected-access
        local_init_op, _ = initializer._initialize()
        # pylint: enable=protected-access
        with ops.init_scope():
            # Outside eager mode, register the init op as a table initializer
            # and make it the initializer op of every local variable in the
            # wrapped graph.
            if not context.executing_eagerly():
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      local_init_op)
                for variable in wrapped.graph.get_collection_ref(
                        ops.GraphKeys.LOCAL_VARIABLES):
                    # pylint: disable=protected-access
                    variable._initializer_op = local_init_op
                    # pylint: enable=protected-access
        root.initializer = initializer
        root.asset_paths = asset_paths
        signature_functions = self._extract_signatures(wrapped, meta_graph_def)

        root.signatures = signature_serialization.create_signature_map(
            signature_functions)
        root.variables = list(wrapped.graph.variables)
        root.tensorflow_version = (
            meta_graph_def.meta_info_def.tensorflow_version)
        root.tensorflow_git_version = (
            meta_graph_def.meta_info_def.tensorflow_git_version)
        root.graph = wrapped.graph
        root.prune = wrapped.prune
        return root