def object_graph_key_mapping(checkpoint_path: str) -> Dict[str, str]:
    """Build a mapping from tensor full names to checkpoint keys.

    Args:
      checkpoint_path: Path to an object-based checkpoint.

    Returns:
      Dict mapping each serialized attribute's full tensor name to the key
      under which it is stored in the checkpoint.
    """
    reader = tf.train.load_checkpoint(checkpoint_path)
    serialized_graph = reader.get_tensor('_CHECKPOINTABLE_OBJECT_GRAPH')
    graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
    graph_proto.ParseFromString(serialized_graph)
    # Flatten every node's attributes into a single name -> key dict.
    return {
        attr.full_name: attr.checkpoint_key
        for node in graph_proto.nodes
        for attr in node.attributes
    }
def _fill_object_graph_proto(graph_view, trackable_objects, node_ids,
                             slot_variables):
    """Name non-slot `Trackable`s and add them to `object_graph_proto`."""
    proto = trackable_object_graph_pb2.TrackableObjectGraph()
    for expected_id, obj in enumerate(trackable_objects):
        # Ordering invariant: list position must equal the assigned node id.
        assert node_ids[obj] == expected_id
        node = proto.nodes.add()
        node.slot_variables.extend(slot_variables.get(obj, ()))
        # Record each child dependency by id and local name.
        for dep in graph_view.list_children(obj):
            entry = node.children.add()
            entry.node_id = node_ids[dep.ref]
            entry.local_name = dep.name
    return proto
# Example #3 (score: 0) — scraped-example separator, commented out so the file stays importable
 def _fill_object_graph_proto(self, trackable_objects,
                              node_ids,
                              slot_variables,
                              object_graph_proto=None):
   """Name non-slot `Trackable`s and add them to `object_graph_proto`.

   Args:
     trackable_objects: `Trackable`s in checkpoint-id order.
     node_ids: Mapping from trackable to its expected node id.
     slot_variables: Mapping from trackable to its slot-variable entries.
     object_graph_proto: Optional proto to fill in; a fresh
       `TrackableObjectGraph` is created when omitted.

   Returns:
     The populated `TrackableObjectGraph` proto.
   """
   if object_graph_proto is None:
     object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
   for expected_id, obj in enumerate(trackable_objects):
     # Ordering invariant: list position must equal the assigned node id.
     assert node_ids[obj] == expected_id
     node = object_graph_proto.nodes.add()
     node.slot_variables.extend(slot_variables.get(obj, ()))
     # Record each dependency of this object by id and local name.
     for dep in self.list_dependencies(obj):
       entry = node.children.add()
       entry.node_id = node_ids[dep.ref]
       entry.local_name = dep.name
   return object_graph_proto
def object_graph_key_mapping(checkpoint_path):
    """Return name to key mappings from the checkpoint.

    Args:
      checkpoint_path: string, path to object-based checkpoint

    Returns:
      Dictionary mapping tensor names to checkpoint keys.
    """
    reader = tf.train.load_checkpoint(checkpoint_path)
    graph_bytes = reader.get_tensor('_CHECKPOINTABLE_OBJECT_GRAPH')
    graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
    graph_proto.ParseFromString(graph_bytes)
    mapping = {}
    # Collect every attribute's (full_name -> checkpoint_key) pair.
    for node in graph_proto.nodes:
        mapping.update(
            (attr.full_name, attr.checkpoint_key) for attr in node.attributes)
    return mapping
  def __init__(self, save_path):
    """Configure the checkpoint view.

    Args:
      save_path: The path to the checkpoint.

    Raises:
      ValueError: If the save_path does not lead to a TF2 checkpoint.
    """
    ckpt_reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    try:
      serialized_graph = ckpt_reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError as not_found_error:
      # A TF1 name-based checkpoint has no object-graph entry at all.
      raise ValueError(
          f"The specified checkpoint \"{save_path}\" does not appear to be "
          "object-based (saved with TF2) since it is missing the key "
          f"\"{base.OBJECT_GRAPH_PROTO_KEY}\". Likely it was created with the "
          "TF1 name-based saver and does not contain an object dependency graph."
      ) from not_found_error
    graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
    graph_proto.ParseFromString(serialized_graph)
    self._object_graph_proto = graph_proto
# Example #6 (score: 0) — scraped-example separator, commented out so the file stays importable
    def restore(self, save_path, options=None):
        """Restore a training checkpoint with host mesh placement.

        Args:
            save_path: Checkpoint path prefix to restore from, or None to
                skip restoring and return an initialization-only status.
            options: Optional `CheckpointOptions`; a default instance is
                created when omitted.

        Returns:
            `InitializationOnlyStatus` if `save_path` is None,
            `NameBasedSaverStatus` if the checkpoint lacks an object graph
            (TF1 name-based), or `CheckpointLoadStatus` for an object-based
            (TF2) checkpoint.
        """
        options = options or checkpoint_options.CheckpointOptions()
        if save_path is None:
            # Nothing to restore; the returned status only runs initializers.
            return util.InitializationOnlyStatus(self._graph_view, ops.uid())
        reader = py_checkpoint_reader.NewCheckpointReader(save_path)
        graph_building = not context.executing_eagerly()
        if graph_building:
            # In graph mode the dtype map is not needed up front.
            dtype_map = None
        else:
            dtype_map = reader.get_variable_to_dtype_map()
        try:
            object_graph_string = reader.get_tensor(
                base.OBJECT_GRAPH_PROTO_KEY)
        except errors_impl.NotFoundError:
            # The object graph proto does not exist in this checkpoint. Try the
            # name-based compatibility mode.
            restore_coordinator = util._NameBasedRestoreCoordinator(  # pylint: disable=protected-access
                save_path=save_path,
                dtype_map=dtype_map)
            if not graph_building:
                # Eager mode: eagerly push the name-based restore into every
                # trackable that already exists in the graph view.
                for existing_trackable in self._graph_view.list_objects():
                    # pylint: disable=protected-access
                    existing_trackable._maybe_initialize_trackable()
                    existing_trackable._name_based_restores.add(
                        restore_coordinator)
                    existing_trackable._name_based_attribute_restore(
                        restore_coordinator)
                    # pylint: enable=protected-access
            return util.NameBasedSaverStatus(restore_coordinator,
                                             graph_view=self._graph_view)

        if graph_building:
            if self._file_prefix_placeholder is None:
                # DTensor change: provide a hint for mesh broadcasting to put the input
                # onto the host mesh.
                self._file_prefix_placeholder = api.pack(
                    [constant_op.constant("model")] *
                    self._mesh.num_local_devices(),
                    layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
            file_prefix_tensor = self._file_prefix_placeholder
            # The placeholder is fed the real path when the session runs.
            file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
        else:
            # DTensor change: provide a hint for mesh broadcasting to put the input
            # onto the host mesh.
            file_prefix_tensor = api.pack([constant_op.constant(save_path)] *
                                          self._mesh.num_local_devices(),
                                          layout.Layout.replicated(
                                              self._mesh.host_mesh(), rank=0))
            file_prefix_feed_dict = None
        object_graph_proto = (
            trackable_object_graph_pb2.TrackableObjectGraph())
        object_graph_proto.ParseFromString(object_graph_string)
        # DTensor Change: Hook the proper DSaver in restore.
        checkpoint = _DCheckpointRestoreCoordinator(
            mesh=self._mesh,
            object_graph_proto=object_graph_proto,
            save_path=save_path,
            save_path_tensor=file_prefix_tensor,
            reader=reader,
            restore_op_cache=self._restore_op_cache,
            graph_view=self._graph_view,
            options=options,
            saveables_cache=self._saveables_cache)
        # Kick off restoration from the root object (proto node 0).
        base.CheckpointPosition(checkpoint=checkpoint,
                                proto_id=0).restore(self._graph_view.root)

        # Attached dependencies are not attached to the root, so should be restored
        # separately.
        if self._graph_view.attached_dependencies:
            for ref in self._graph_view.attached_dependencies:
                if ref.name == "root":
                    # Root dependency is automatically added to attached dependencies --
                    # this can be ignored since it maps back to the root object.
                    continue
                proto_id = None
                # Find proto ID of attached dependency (if it is in the proto).
                for proto_ref in object_graph_proto.nodes[0].children:
                    if proto_ref.local_name == ref.name:
                        proto_id = proto_ref.node_id
                        break

                if proto_id in checkpoint.object_by_proto_id:
                    # Object has already been restored. This can happen when there's an
                    # indirect connection from the attached object to the root.
                    continue

                # NOTE(review): if the dependency was not found above, proto_id
                # is still None here — presumably CheckpointPosition tolerates
                # that (deferred/missing entry); confirm against the
                # coordinator's implementation.
                base.CheckpointPosition(checkpoint=checkpoint,
                                        proto_id=proto_id).restore(ref.ref)

        load_status = util.CheckpointLoadStatus(
            checkpoint,
            graph_view=self._graph_view,
            feed_dict=file_prefix_feed_dict)
        return load_status