示例#1
0
def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
    """Save a SavedObjectGraph proto for the objects in `saveable_view`.

    The proto's object-graph structure is filled by
    `saveable_view.fill_object_graph_proto`, each entry of
    `saveable_view.nodes` is then serialized into the proto node at the same
    position, and the serialized proto is written to
    `<export_dir>/<constants.EXTRA_ASSETS_DIRECTORY>/object_graph.pb`.

    Args:
      saveable_view: Object exposing `nodes` (the ordered objects to save) and
        `fill_object_graph_proto(proto)`.
      export_dir: Directory of the SavedModel being written.
      asset_file_def_index: Forwarded unchanged to `_write_object_proto`.
    """
    # SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
    # checkpoint. It will eventually go into the SavedModel.
    proto = saved_object_graph_pb2.SavedObjectGraph()
    saveable_view.fill_object_graph_proto(proto)

    # Node index lookup keyed by identity. Besides each node itself, the
    # handle tensor of a resource variable and the asset-path handle of a
    # TrackableAsset map to their owner's index; the map is handed to
    # `_write_object_proto` below.
    node_ids = util.ObjectIdentityDictionary()
    for i, obj in enumerate(saveable_view.nodes):
        node_ids[obj] = i
        if resource_variable_ops.is_resource_variable(obj):
            node_ids[obj.handle] = i
        elif isinstance(obj, tracking.TrackableAsset):
            node_ids[obj.asset_path.handle] = i

    # `proto.nodes` was populated in the same order as `saveable_view.nodes`.
    for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
        _write_object_proto(obj, obj_proto, asset_file_def_index, node_ids)

    extra_asset_dir = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
    file_io.recursive_create_dir(extra_asset_dir)
    object_graph_filename = os.path.join(extra_asset_dir,
                                         compat.as_bytes("object_graph.pb"))
    file_io.write_string_to_file(object_graph_filename,
                                 proto.SerializeToString())
示例#2
0
def _write_object_graph(root, export_dir, asset_file_def_index):
    """Save a SavedObjectGraph proto for `root`.

    Args:
      root: Root object; the objects `util.find_objects` reaches from it become
        the nodes of the graph, in the order it returns them.
      export_dir: Directory of the SavedModel being written; the proto is
        written to `<export_dir>/<constants.EXTRA_ASSETS_DIRECTORY>/
        object_graph.pb`.
      asset_file_def_index: Forwarded unchanged to `_write_object_proto`.
    """
    # SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
    # checkpoint. It will eventually go into the SavedModel.
    proto = saved_object_graph_pb2.SavedObjectGraph()

    checkpointable_objects, checkpoint_node_ids, slot_variables = (
        util.find_objects(root))
    util.fill_object_graph_proto(checkpointable_objects, checkpoint_node_ids,
                                 slot_variables, proto)

    # A separate identity-keyed lookup (distinct from `checkpoint_node_ids`):
    # besides each object, the handle tensor of a resource variable and the
    # asset-path handle of a TrackableAsset map to their owner's index. It is
    # consumed by `add_polymorphic_functions_to_object_graph_proto` below.
    node_ids = util.ObjectIdentityDictionary()
    for i, obj in enumerate(checkpointable_objects):
        node_ids[obj] = i
        if resource_variable_ops.is_resource_variable(obj):
            node_ids[obj.handle] = i
        elif isinstance(obj, tracking.TrackableAsset):
            node_ids[obj.asset_path.handle] = i

    # `proto.nodes` was populated in the same order as `checkpointable_objects`.
    for obj, obj_proto in zip(checkpointable_objects, proto.nodes):
        _write_object_proto(obj, obj_proto, asset_file_def_index)

    function_serialization.add_polymorphic_functions_to_object_graph_proto(
        checkpointable_objects, proto, node_ids)

    extra_asset_dir = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
    file_io.recursive_create_dir(extra_asset_dir)
    object_graph_filename = os.path.join(extra_asset_dir,
                                         compat.as_bytes("object_graph.pb"))
    file_io.write_string_to_file(object_graph_filename,
                                 proto.SerializeToString())
示例#3
0
    def __init__(self, root):
        """Collect objects reachable from `root` plus the functions they expose.

        Sets `nodes`, `node_ids` and `slot_variables` from
        `util.find_objects(root)`. Every function listed on one of those
        objects (via `self._list_functions`) that is not already a node is
        appended to `nodes` with the next free id. `functions` maps each
        original object to what `_list_functions` returned for it, and each
        newly added function to an empty dict.
        """
        reachable, ids, slots = util.find_objects(root)
        self.nodes = reachable
        self.node_ids = ids
        self.slot_variables = slots
        self.functions = util.ObjectIdentityDictionary()

        def register_function_node(fn):
            """Append `fn` as a new node unless it already has a node id."""
            if fn in self.node_ids:
                return
            self.node_ids[fn] = len(self.nodes)
            self.nodes.append(fn)
            # Functions are not scanned for further functions: attributes of
            # concrete functions sometimes hold functions, which is unhelpful.
            self.functions[fn] = {}

        # Walk a snapshot so the function nodes appended above are not visited.
        for owner in list(self.nodes):
            exposed = self._list_functions(owner)
            self.functions[owner] = exposed
            for fn in exposed.values():
                register_function_node(fn)
                if not isinstance(fn, def_function.Function):
                    continue
                # Listing the concrete functions is done for its side effects:
                #  - populate the cache for functions that have an
                #  input_signature and have not been called.
                #  - force side effects of creating concrete functions, e.g.
                #  create variables on first run.
                fn._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
示例#4
0
文件: save.py 项目: zaazad/tensorflow
  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas the
    objects in `self.nodes` will be from an eager context. Resource mapping
    adds resource handle ops to the main GraphDef of a SavedModel, which allows
    the C++ loader API to interact with variables.

    Side effects:
      - `self.captured_tensor_node_ids` records, for each resource handle,
        variable handle and asset-path handle found on `self.nodes`, the id of
        the node that owns it.
      - For every EagerTensor captured by a function in
        `self.concrete_functions` whose dtype is not in `_UNCOPIABLE_DTYPES`
        and which is not yet in `self.captured_tensor_node_ids`, a
        `_CapturedConstant` node is appended to `self.nodes` and registered in
        `self.node_ids` and `self.captured_tensor_node_ids`.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from objects in `self.nodes` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `self.nodes`, and from copied eager captures, to newly created
          tensors in the current graph.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from `self.nodes`.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = util.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})
    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.TrackableResource):
        new_resource = obj.create_resource()
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif resource_variable_ops.is_resource_variable(obj):
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.TrackableAsset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path.handle] = node_id

    for concrete_function in self.concrete_functions:
      for capture in concrete_function.captured_inputs:
        # Skip handles already mapped by the loop above, and captures already
        # copied for an earlier function (the copy registers itself in
        # `captured_tensor_node_ids`).
        if (isinstance(capture, ops.EagerTensor)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          copied_tensor = constant_op.constant(capture.numpy())
          # The new node's id is its position once appended to `self.nodes`.
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
示例#5
0
  def __init__(self, root):
    """Gather objects reachable from `root` and their polymorphic functions.

    Sets `nodes`, `node_ids` and `slot_variables` from
    `util.find_objects(root)`. Each polymorphic function listed on a node (via
    `self._list_polymorphic_functions`) that is not already a node is appended
    to `nodes` with the next free id. The scan walks `nodes` while it grows, so
    function nodes appended here are in turn scanned for polymorphic functions.
    `polymorphic_functions` maps every scanned node to its listing.
    """
    reachable, ids, slots = util.find_objects(root)
    self.nodes = reachable
    self.node_ids = ids
    self.slot_variables = slots
    self.polymorphic_functions = util.ObjectIdentityDictionary()

    # Explicit worklist: `len(self.nodes)` is re-read each pass, so nodes
    # appended during the scan are visited too.
    position = 0
    while position < len(self.nodes):
      current = self.nodes[position]
      position += 1
      listed = self._list_polymorphic_functions(current)
      self.polymorphic_functions[current] = listed
      for fn in listed.values():
        if fn not in self.node_ids:
          self.node_ids[fn] = len(self.nodes)
          self.nodes.append(fn)
        # Listing the concrete functions is done for its side effects:
        #  - populate the cache for polymorphic functions that have an
        #  input_signature and have not been called.
        #  - force side effects of creation of concrete functions, e.g. create
        #  variables on first run.
        fn._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
示例#6
0
def _map_resources(accessible_objects):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.

    Args:
      accessible_objects: A list of objects, some of which may contain
        resources, to create replacements for.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from object in `accessible_objects` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from accessible_objects.
    """
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    replacements = util.ObjectIdentityDictionary()
    handle_map = {}
    assets = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})
    for candidate in accessible_objects:
        if isinstance(candidate, tracking.TrackableResource):
            # The new resource is created before the old handle is read.
            handle_map[candidate.resource_handle] = (
                candidate.create_resource())
            continue
        if resource_variable_ops.is_resource_variable(candidate):
            copied = resource_variable_ops.copy_to_graph_uninitialized(
                candidate)
            replacements[candidate] = copied
            handle_map[candidate.handle] = copied.handle
            continue
        if isinstance(candidate, tracking.TrackableAsset):
            # Records the asset in `assets` and its handle in `handle_map`.
            _process_asset(candidate, assets, handle_map)
    return replacements, handle_map, assets