Пример #1
0
 def _serialize_gathered_objects(self, trackable_objects, path_to_root,
                                 object_map=None):
   """Builds the object-graph proto and saveables for already-gathered nodes.

   Args:
     trackable_objects: Ordered list of gathered objects; each object's index
       in this list is used as its node id. `_serialize_slot_variables` is
       handed this same list (it may append slot variables to it).
     path_to_root: Mapping from each object to the path from which its name
       prefix is derived via `_object_prefix_from_path`.
     object_map: Optional mapping forwarded unchanged to
       `_add_attributes_to_object_graph`.

   Returns:
     A tuple `(named_saveable_objects, object_graph_proto, feed_additions)`.
   """
   names_by_object = object_identity.ObjectIdentityDictionary()
   for gathered, path in path_to_root.items():
     names_by_object[gathered] = _object_prefix_from_path(path)

   ids_by_object = object_identity.ObjectIdentityDictionary()
   for index, gathered in enumerate(trackable_objects):
     ids_by_object[gathered] = index

   slot_variable_refs = _serialize_slot_variables(
       trackable_objects=trackable_objects,
       node_ids=ids_by_object,
       object_names=names_by_object)
   graph_proto = self._fill_object_graph_proto(
       trackable_objects=trackable_objects,
       node_ids=ids_by_object,
       slot_variables=slot_variable_refs)
   saveables, feeds = self._add_attributes_to_object_graph(
       trackable_objects=trackable_objects,
       object_graph_proto=graph_proto,
       node_ids=ids_by_object,
       object_names=names_by_object,
       object_map=object_map)
   return saveables, graph_proto, feeds
Пример #2
0
 def __init__(self, root):
     """Initializes the augmented view rooted at `root`.

     A weak-key saveables cache is created only when neither
     `context.executing_eagerly()` nor `ops.inside_function()` is true;
     otherwise `None` is passed to the base class as the cache.

     Args:
       root: The root object of the view, forwarded to the base class.
     """
     eager_or_in_function = (
         context.executing_eagerly() or ops.inside_function())
     saveables_cache = (
         None if eager_or_in_function
         else object_identity.ObjectIdentityWeakKeyDictionary())
     super(_AugmentedGraphView, self).__init__(root, saveables_cache)
     # Keyed by object: object -> {name: dependency}.
     self._extra_dependencies = object_identity.ObjectIdentityDictionary()
     self._functions = object_identity.ObjectIdentityDictionary()
Пример #3
0
 def __init__(self, root):
     """Initializes the augmented view and its per-object bookkeeping.

     Args:
       root: The root object of the view, forwarded to the base class together
         with a saveables cache. The cache is a weak-key identity dictionary
         unless eager execution is enabled or `ops.inside_function()` is true,
         in which case it is `None`.
     """
     if context.executing_eagerly() or ops.inside_function():
         saveables_cache = None
     else:
         saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
     super(_AugmentedGraphView, self).__init__(root, saveables_cache)

     # Keyed by object: object -> {name: dependency}.
     self._extra_dependencies = object_identity.ObjectIdentityDictionary()
     self._functions = object_identity.ObjectIdentityDictionary()
     # One cache shared by every object in this object graph; it is handed to
     # each trackable's `_list_extra_dependencies_for_serialization` and
     # `_list_functions_for_serialization` calls.
     self._serialization_cache = object_identity.ObjectIdentityDictionary()
Пример #4
0
    def __init__(self, checkpoint_view):
        """Builds the node list that will be serialized for `checkpoint_view`.

        Starts from the objects, node ids and slot variables reported by
        `checkpoint_view.objects_ids_and_slot_variables()`. It then appends
        every function listed by `checkpoint_view.list_functions(node)` for
        those objects as an additional node. Concrete functions are collected
        into `self.concrete_functions`, deduplicated by their `name`.

        `self.nodes` is the very list returned by the checkpoint view, so
        appending function nodes also mutates that list.

        Args:
          checkpoint_view: The view whose objects are being saved.
        """
        self.checkpoint_view = checkpoint_view
        gathered, ids, slots = (
            self.checkpoint_view.objects_ids_and_slot_variables())
        self.nodes = gathered
        self.node_ids = ids
        self.captured_tensor_node_ids = (
            object_identity.ObjectIdentityDictionary())
        self.slot_variables = slots
        self.concrete_functions = []

        # Only the original objects are scanned; function nodes appended below
        # are not themselves asked for functions.
        original_nodes = list(self.nodes)
        seen_names = set()
        for owner in original_nodes:
            for fn in checkpoint_view.list_functions(owner).values():
                if fn not in self.node_ids:
                    self.node_ids[fn] = len(self.nodes)
                    self.nodes.append(fn)
                # For a `def_function.Function`, listing the concrete functions
                # is also done for its side effects. It populates the cache for
                # functions that have an input_signature but were never called.
                # It also forces work done on first creation, e.g. creating
                # variables on first run.
                concretes = (
                    fn._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
                    if isinstance(fn, def_function.Function) else [fn])
                for concrete in concretes:
                    if concrete.name in seen_names:
                        continue
                    seen_names.add(concrete.name)
                    self.concrete_functions.append(concrete)
Пример #5
0
  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas the
    objects in `self.nodes` may come from an eager context. Resource mapping
    adds resource handle ops to the main GraphDef of a SavedModel, which allows
    the C++ loader API to interact with variables.

    Tensors captured by `self.concrete_functions` that are not already mapped
    are copied into the graph as constants. Each one is registered as a new
    `_CapturedConstant` node in `self.nodes` / `self.node_ids`.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from objects in `self.nodes` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `self.nodes` (and captured constant tensors) to newly created
          tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from `self.nodes`.

    Raises:
      ValueError: If a function captures a tensor whose value
        `tensor_util.constant_value` cannot compute (i.e. it returns None).
        Such a tensor cannot be copied into the export graph as a constant.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})
    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.TrackableResource):
        new_resource = obj._create_resource()  # pylint: disable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif resource_variable_ops.is_resource_variable(obj):
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.TrackableAsset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    # Remaining captured tensors not mapped above are copied in as constants.
    for concrete_function in self.concrete_functions:
      for capture in concrete_function.captured_inputs:
        if (tensor_util.is_tensor(capture)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          capture_value = tensor_util.constant_value(capture)
          if capture_value is None:
            # Passing None on to `constant_op.constant` would fail with a
            # generic error; name the offending function and tensor instead.
            raise ValueError(
                "Attempted to save a function %s which references a symbolic "
                "Tensor %s that is not a simple constant. This is not "
                "supported." % (concrete_function.name, capture))
          copied_tensor = constant_op.constant(capture_value)
          capture_node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = capture_node_id
          self.node_ids[node] = capture_node_id
          # Recording the capture here also prevents copying the same tensor
          # twice when several functions capture it.
          self.captured_tensor_node_ids[capture] = capture_node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
Пример #6
0
  def objects_ids_and_slot_variables(self):
    """Traverses the object graph and lists every reachable object.

    Objects come from `self._breadth_first_traversal()`. Each one is assigned
    a node id equal to its position in that ordering, and then slot variables
    are gathered. Slot variables are included only if both the variable they
    are slotting for and the optimizer are dependencies of the root (i.e. if
    they would be saved with a checkpoint).

    Returns:
      A tuple of (trackable objects, object -> node id, slot variables)
    """
    ordered_objects, paths = self._breadth_first_traversal()

    prefixes = object_identity.ObjectIdentityDictionary()
    for reachable, path in paths.items():
      prefixes[reachable] = _object_prefix_from_path(path)

    ids = object_identity.ObjectIdentityDictionary()
    for position, reachable in enumerate(ordered_objects):
      ids[reachable] = position

    slots = _serialize_slot_variables(
        trackable_objects=ordered_objects,
        node_ids=ids,
        object_names=prefixes)
    return ordered_objects, ids, slots
Пример #7
0
def _serialize_slot_variables(trackable_objects, node_ids, object_names):
  """Gather and name slot variables.

  Handles every gathered object that is an `optimizer_v1.Optimizer` or has a
  `_create_or_restore_slot_variable` attribute. For each of its slot names, it
  asks for the slot of every originally gathered object. Each slot variable
  found is registered as a new node.

  Args:
    trackable_objects: List of gathered objects, indexed by node id. Slot
      variables found here are appended to it in place.
    node_ids: Object -> node id mapping; updated in place for slot variables.
    object_names: Object -> name prefix mapping; updated in place with the
      checkpoint name of each slot variable.

  Returns:
    A mapping from optimizer object to a list of `SlotVariableReference`
    protos describing its slot variables.

  Raises:
    NotImplementedError: If a slot variable has checkpoint dependencies of its
      own, or already has a node id.
  """
  # Snapshot: slot variables appended below must not be scanned themselves.
  candidates = list(trackable_objects)
  refs_by_optimizer = object_identity.ObjectIdentityDictionary()
  for obj in candidates:
    looks_like_optimizer = (
        isinstance(obj, optimizer_v1.Optimizer)
        # TODO(b/110718070): Fix Keras imports.
        or hasattr(obj, "_create_or_restore_slot_variable"))
    if not looks_like_optimizer:
      continue
    naming_scheme = _slot_variable_naming_for_optimizer(
        optimizer_path=object_names[obj])
    for slot_name in obj.get_slot_names():
      for variable_node_id, variable in enumerate(candidates):
        try:
          slot = obj.get_slot(variable, slot_name)
        except (AttributeError, KeyError):
          slot = None
        if slot is None:
          continue
        slot._maybe_initialize_trackable()  # pylint: disable=protected-access
        if slot._checkpoint_dependencies:  # pylint: disable=protected-access
          # TODO(allenl): Gather dependencies of slot variables.
          raise NotImplementedError(
              "Currently only variables with no dependencies can be saved as "
              "slot variables. File a feature request if this limitation "
              "bothers you.")
        if slot in node_ids:
          raise NotImplementedError(
              "A slot variable was re-used as a dependency of a "
              "Trackable object. This is not currently allowed. File a "
              "feature request if this limitation bothers you.")
        object_names[slot] = naming_scheme(
            variable_path=object_names[variable],
            slot_name=slot_name)
        new_node_id = len(trackable_objects)
        node_ids[slot] = new_node_id
        trackable_objects.append(slot)
        reference = (
            trackable_object_graph_pb2.TrackableObjectGraph
            .TrackableObject.SlotVariableReference(
                slot_name=slot_name,
                original_variable_node_id=variable_node_id,
                slot_variable_node_id=new_node_id))
        refs_by_optimizer.setdefault(obj, []).append(reference)
  return refs_by_optimizer
Пример #8
0
 def _breadth_first_traversal(self):
     """Find shortest paths to all dependencies of self.root.

     Returns:
       A tuple `(bfs_sorted, path_to_root)`. The first item is the reachable
       objects in breadth-first visiting order. The second maps each reachable
       object to the tuple of `base.TrackableReference`s leading to it from
       `self.root`; the root's path is the empty tuple.

     Raises:
       NotImplementedError: If a reachable object is a `tracking.NotTrackable`.
     """
     visit_order = []
     queue = collections.deque([self.root])
     paths = object_identity.ObjectIdentityDictionary()
     paths[self.root] = ()
     while queue:
         current = queue.popleft()
         if isinstance(current, tracking.NotTrackable):
             raise NotImplementedError((
                 "The object %s does not support object-based saving. File a "
                 "feature request if this limitation bothers you. In the meantime, "
                 "you can remove the dependency on this object and save everything "
                 "else.") % (current, ))
         visit_order.append(current)
         for name, dependency in self.list_dependencies(current):
             if dependency in paths:
                 # Already reached by a path at least as short.
                 continue
             paths[dependency] = paths[current] + (
                 base.TrackableReference(name, dependency), )
             queue.append(dependency)
     return visit_order, paths
Пример #9
0
 def __init__(self, root):
     """Initializes the view and its per-object bookkeeping tables.

     Args:
       root: The root object of the view, forwarded to the base class.
     """
     super(_AugmentedGraphView, self).__init__(root)
     # Both tables are identity-keyed dictionaries, filled in elsewhere:
     #   _extra_dependencies: object -> {name: dependency}
     #   _functions: keyed by object (contents populated elsewhere)
     self._extra_dependencies, self._functions = (
         object_identity.ObjectIdentityDictionary(),
         object_identity.ObjectIdentityDictionary())