예제 #1
0
 def _serialize_gathered_objects(self, checkpointable_objects, path_to_root,
                                 object_map=None):
   """Create SaveableObjects and protos for gathered objects.

   Args:
     checkpointable_objects: List of gathered objects. An object's position in
       this list is its node id. Slot variables found by
       `_serialize_slot_variables` are appended to this list in place.
     path_to_root: Mapping from each object to its path from the root. Each
       object's name is derived from its path with `_object_prefix_from_path`.
     object_map: Optional; forwarded unchanged to
       `_add_attributes_to_object_graph`.

   Returns:
     A tuple `(named_saveable_objects, object_graph_proto, feed_additions)`.
   """
   names_by_object = object_identity.ObjectIdentityDictionary()
   for checkpointable, path in path_to_root.items():
     names_by_object[checkpointable] = _object_prefix_from_path(path)

   ids_by_object = object_identity.ObjectIdentityDictionary()
   for index, checkpointable in enumerate(checkpointable_objects):
     ids_by_object[checkpointable] = index

   # Slot-variable discovery may grow `checkpointable_objects` and both maps,
   # so it has to run before the proto and the attributes are built.
   slot_variable_protos = _serialize_slot_variables(
       checkpointable_objects=checkpointable_objects,
       node_ids=ids_by_object,
       object_names=names_by_object)

   graph_proto = self._fill_object_graph_proto(
       checkpointable_objects=checkpointable_objects,
       node_ids=ids_by_object,
       slot_variables=slot_variable_protos)

   saveables, feeds = self._add_attributes_to_object_graph(
       checkpointable_objects=checkpointable_objects,
       object_graph_proto=graph_proto,
       node_ids=ids_by_object,
       object_names=names_by_object,
       object_map=object_map)
   return saveables, graph_proto, feeds
예제 #2
0
    def __init__(self, checkpoint_view):
        """Collect the nodes, ids, slot variables and functions to save.

        Args:
          checkpoint_view: Object with two required methods.
            `objects_ids_and_slot_variables()` supplies the initial nodes,
            node ids and slot variables. `list_functions(node)` returns a
            mapping whose values are added as extra nodes.

        Populates `nodes`, `node_ids`, `slot_variables` and an empty
        `captured_tensor_node_ids` map. It also fills `concrete_functions`,
        which holds each concrete function at most once (deduplicated by
        `name`).
        """
        self.checkpoint_view = checkpoint_view
        graph_nodes, ids_by_node, slots = (
            self.checkpoint_view.objects_ids_and_slot_variables())
        self.nodes = graph_nodes
        self.node_ids = ids_by_node
        self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
        self.slot_variables = slots
        self.concrete_functions = []

        # Functions become nodes too. Walk a copy of the original node list,
        # because newly discovered functions are appended to `self.nodes`.
        original_nodes = list(self.nodes)
        seen_names = set()
        for node in original_nodes:
            for fn in checkpoint_view.list_functions(node).values():
                if fn not in self.node_ids:
                    self.node_ids[fn] = len(self.nodes)
                    self.nodes.append(fn)
                if not isinstance(fn, def_function.Function):
                    candidates = [fn]
                else:
                    # Listing the concrete functions is done for its side effects:
                    #  - it populates the cache for functions that have an
                    #    input_signature and have not been called;
                    #  - it forces the side effects of creating concrete
                    #    functions, e.g. creating variables on first run.
                    candidates = (
                        fn._list_all_concrete_functions_for_serialization())  # pylint: disable=protected-access
                for concrete in candidates:
                    if concrete.name in seen_names:
                        continue
                    seen_names.add(concrete.name)
                    self.concrete_functions.append(concrete)
예제 #3
0
    def map_resources(self):
        """Makes new resource handle ops corresponding to existing resource tensors.

        The resource handle ops are created in the current default graph. The
        objects in `self.nodes`, and the tensors captured by
        `self.concrete_functions`, come from an eager context. Resource mapping
        adds resource handle ops to the main GraphDef of a SavedModel, which
        allows the C++ loader API to interact with variables.

        Side effects:
          - A node id is recorded in `self.captured_tensor_node_ids` for every
            mapped resource handle, variable handle, asset path and copied
            eager capture.
          - Each copied capture is appended to `self.nodes` as a
            `_CapturedConstant` and registered in `self.node_ids`.

        Returns:
          A tuple of (object_map, resource_map, asset_info):
            object_map: A dictionary mapping from objects in `self.nodes` to
              replacement objects created to hold the new resource tensors.
            resource_map: A dictionary mapping from resource tensors taken from
              `self.nodes` (and from eager captures) to newly created tensors.
            asset_info: An _AssetInfo tuple describing external assets
              referenced from `self.nodes`.
        """
        # Only makes sense when adding to the export Graph
        assert not context.executing_eagerly()
        # TODO(allenl): Handle MirroredVariables and other types of variables which
        # may need special casing.
        object_map = object_identity.ObjectIdentityDictionary()
        resource_map = {}
        asset_info = _AssetInfo(
            asset_defs=[],
            asset_initializers_by_resource={},
            asset_filename_map={},
            asset_index={})

        # Pass 1: give every tracked resource, variable and asset a graph-side
        # counterpart. The checks are ordered; the first match wins. Statement
        # order inside each branch is deliberate: create_resource() runs before
        # resource_handle is read.
        for index, trackable in enumerate(self.nodes):
            if isinstance(trackable, tracking.TrackableResource):
                graph_resource = trackable.create_resource()
                resource_map[trackable.resource_handle] = graph_resource
                self.captured_tensor_node_ids[trackable.resource_handle] = index
            elif resource_variable_ops.is_resource_variable(trackable):
                graph_variable = (
                    resource_variable_ops.copy_to_graph_uninitialized(trackable))
                object_map[trackable] = graph_variable
                resource_map[trackable.handle] = graph_variable.handle
                self.captured_tensor_node_ids[trackable.handle] = index
            elif isinstance(trackable, tracking.TrackableAsset):
                _process_asset(trackable, asset_info, resource_map)
                self.captured_tensor_node_ids[trackable.asset_path] = index

        # Pass 2: copy each remaining copiable eager capture into the graph as a
        # constant and track it as a new `_CapturedConstant` node.
        for concrete_function in self.concrete_functions:
            for capture in concrete_function.captured_inputs:
                if not isinstance(capture, ops.EagerTensor):
                    continue
                if capture.dtype in _UNCOPIABLE_DTYPES:
                    continue
                if capture in self.captured_tensor_node_ids:
                    continue
                graph_copy = constant_op.constant(capture.numpy())
                new_id = len(self.nodes)
                captured_node = _CapturedConstant(eager_tensor=capture,
                                                  graph_tensor=graph_copy)
                self.nodes.append(captured_node)
                self.node_ids[capture] = new_id
                self.node_ids[captured_node] = new_id
                self.captured_tensor_node_ids[capture] = new_id
                resource_map[capture] = graph_copy

        return object_map, resource_map, asset_info
예제 #4
0
def _serialize_slot_variables(checkpointable_objects, node_ids, object_names):
    """Gather and name slot variables.

    An object is treated as an optimizer if it is an `optimizer_v1.Optimizer`
    or has a `_create_or_restore_slot_variable` attribute. For every such
    object, each of its slot names is looked up against every object that was
    in `checkpointable_objects` on entry. Every slot variable found is:
      - named via `_slot_variable_naming_for_optimizer`;
      - assigned the next node id;
      - appended to `checkpointable_objects`.

    Args:
      checkpointable_objects: List of objects. Discovered slot variables are
        appended to it in place.
      node_ids: Object -> node id mapping. Updated in place with slot
        variable ids.
      object_names: Object -> checkpoint name prefix. Updated in place with
        slot variable names.

    Returns:
      An ObjectIdentityDictionary mapping each optimizer to its list of
      `SlotVariableReference` protos.

    Raises:
      NotImplementedError: If a slot variable has checkpoint dependencies of
        its own, or is already present in `node_ids`.
    """
    # Snapshot: slot variables appended below must not be scanned themselves.
    snapshot = list(checkpointable_objects)
    slot_variables = object_identity.ObjectIdentityDictionary()
    for candidate in snapshot:
        is_optimizer = (
            isinstance(candidate, optimizer_v1.Optimizer)
            # TODO(b/110718070): Fix Keras imports.
            or hasattr(candidate, "_create_or_restore_slot_variable"))
        if not is_optimizer:
            continue
        naming_scheme = _slot_variable_naming_for_optimizer(
            optimizer_path=object_names[candidate])
        for slot_name in candidate.get_slot_names():
            # NOTE(review): the enumerate index is used as the original
            # variable's node id. This assumes list positions match `node_ids`,
            # which holds for the callers visible here.
            for variable_index, variable in enumerate(snapshot):
                try:
                    slot = candidate.get_slot(variable, slot_name)
                except (AttributeError, KeyError):
                    continue
                if slot is None:
                    continue
                slot._maybe_initialize_checkpointable()  # pylint: disable=protected-access
                if slot._checkpoint_dependencies:  # pylint: disable=protected-access
                    # TODO(allenl): Gather dependencies of slot variables.
                    raise NotImplementedError(
                        "Currently only variables with no dependencies can be saved as "
                        "slot variables. File a feature request if this limitation "
                        "bothers you.")
                if slot in node_ids:
                    raise NotImplementedError(
                        "A slot variable was re-used as a dependency of a "
                        "Checkpointable object. This is not currently allowed. File a "
                        "feature request if this limitation bothers you.")
                object_names[slot] = naming_scheme(
                    variable_path=object_names[variable],
                    slot_name=slot_name)
                new_node_id = len(checkpointable_objects)
                node_ids[slot] = new_node_id
                checkpointable_objects.append(slot)
                reference = (
                    checkpointable_object_graph_pb2.
                    CheckpointableObjectGraph.CheckpointableObject.
                    SlotVariableReference(
                        slot_name=slot_name,
                        original_variable_node_id=variable_index,
                        slot_variable_node_id=new_node_id))
                slot_variables.setdefault(candidate, []).append(reference)
    return slot_variables
예제 #5
0
  def objects_ids_and_slot_variables(self):
    """Traverse the object graph and list all accessible objects.

    Looks for `Checkpointable` objects which are dependencies of
    `root_checkpointable`. Includes slot variables only if the variable they are
    slotting for and the optimizer are dependencies of `root_checkpointable`
    (i.e. if they would be saved with a checkpoint).

    Returns:
      A tuple of (checkpointable objects, object -> node id, slot variables).
      The object list is in breadth-first order, followed by any slot
      variables appended by `_serialize_slot_variables`. An object's node id
      is its position in that list.
    """
    objects, paths = self._breadth_first_traversal()

    names = object_identity.ObjectIdentityDictionary()
    for obj, path in paths.items():
      names[obj] = _object_prefix_from_path(path)

    ids = object_identity.ObjectIdentityDictionary()
    for position, obj in enumerate(objects):
      ids[obj] = position

    slots = _serialize_slot_variables(
        checkpointable_objects=objects, node_ids=ids, object_names=names)
    return objects, ids, slots
예제 #6
0
 def _breadth_first_traversal(self):
   """Find shortest paths to all dependencies of self.root.

   Dependencies are visited breadth-first from `self.root` using
   `self.list_dependencies`. An object reached by several routes keeps the
   first (shortest) path recorded for it.

   Returns:
     A tuple `(bfs_sorted, path_to_root)`:
       - the visited objects in visit order;
       - a mapping from each object to a tuple of
         `base.CheckpointableReference`s leading from the root. The root maps
         to `()`.

   Raises:
     NotImplementedError: If a visited object is a
       `tracking.NotCheckpointable`.
   """
   visit_order = []
   queue = collections.deque([self.root])
   paths = object_identity.ObjectIdentityDictionary()
   paths[self.root] = ()
   while queue:
     current = queue.popleft()
     if isinstance(current, tracking.NotCheckpointable):
       raise NotImplementedError(
           ("The object %s does not support object-based saving. File a "
            "feature request if this limitation bothers you. In the meantime, "
            "you can remove the dependency on this object and save everything "
            "else.")
           % (current,))
     visit_order.append(current)
     # `current` was recorded before being queued, so its path is final here.
     current_path = paths[current]
     for name, dependency in self.list_dependencies(current):
       if dependency in paths:
         continue
       paths[dependency] = current_path + (
           base.CheckpointableReference(name, dependency),)
       queue.append(dependency)
   return visit_order, paths
예제 #7
0
 def __init__(self, root):
     """Initialize the view and its per-object lookup tables.

     Args:
       root: Forwarded unchanged to the parent class constructor. Presumably
         this is the root object of the graph being viewed; TODO confirm
         against the base class.
     """
     super(_AugmentedGraphView, self).__init__(root)
     # Object -> (name -> dep)
     self._extra_dependencies = object_identity.ObjectIdentityDictionary()
     # NOTE(review): presumably object -> functions listed for that object;
     # how it is filled is not visible here, so confirm.
     self._functions = object_identity.ObjectIdentityDictionary()