from typing import Dict

import tensorflow as tf
from tensorflow.core.protobuf import trackable_object_graph_pb2


def object_graph_key_mapping(checkpoint_path: str) -> Dict[str, str]:
  """Return name to key mappings from the checkpoint."""
  reader = tf.train.load_checkpoint(checkpoint_path)
  object_graph_string = reader.get_tensor('_CHECKPOINTABLE_OBJECT_GRAPH')
  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  object_graph_proto.ParseFromString(object_graph_string)
  names_to_keys = {}
  for node in object_graph_proto.nodes:
    for attribute in node.attributes:
      names_to_keys[attribute.full_name] = attribute.checkpoint_key
  return names_to_keys
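A minimal usage sketch (not from the source): write a small object-based checkpoint with tf.train.Checkpoint and dump the name-to-key mapping. The variable name `v` and the `/tmp/ckpt` prefix are illustrative.

ckpt = tf.train.Checkpoint(v=tf.Variable([1.0, 2.0]))
save_path = ckpt.save('/tmp/ckpt')  # returns a prefix such as '/tmp/ckpt-1'

names_to_keys = object_graph_key_mapping(save_path)
for full_name, checkpoint_key in names_to_keys.items():
  # Checkpoint keys look like 'v/.ATTRIBUTES/VARIABLE_VALUE'.
  print(full_name, '->', checkpoint_key)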
def _fill_object_graph_proto(graph_view, trackable_objects, node_ids,
                             slot_variables):
  """Name non-slot `Trackable`s and add them to `object_graph_proto`."""
  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  for checkpoint_id, trackable in enumerate(trackable_objects):
    assert node_ids[trackable] == checkpoint_id
    object_proto = object_graph_proto.nodes.add()
    object_proto.slot_variables.extend(slot_variables.get(trackable, ()))
    for child in graph_view.list_children(trackable):
      child_proto = object_proto.children.add()
      child_proto.node_id = node_ids[child.ref]
      child_proto.local_name = child.name
  return object_graph_proto
def _fill_object_graph_proto(self, trackable_objects, node_ids,
                             slot_variables, object_graph_proto=None):
  """Name non-slot `Trackable`s and add them to `object_graph_proto`."""
  if object_graph_proto is None:
    object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  for checkpoint_id, trackable in enumerate(trackable_objects):
    assert node_ids[trackable] == checkpoint_id
    object_proto = object_graph_proto.nodes.add()
    object_proto.slot_variables.extend(slot_variables.get(trackable, ()))
    for child in self.list_dependencies(trackable):
      child_proto = object_proto.children.add()
      child_proto.node_id = node_ids[child.ref]
      child_proto.local_name = child.name
  return object_graph_proto
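For orientation, here is a small helper (an illustrative addition, not part of the source) that walks the proto either variant of `_fill_object_graph_proto` returns and prints the parent-to-child edges that the inner loop above records.

def print_object_graph_edges(object_graph_proto):
  """Prints each `parent -> child` edge in a TrackableObjectGraph proto."""
  for node_id, node in enumerate(object_graph_proto.nodes):
    for child in node.children:
      print(f'{node_id} -> {child.node_id} ({child.local_name})')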
def object_graph_key_mapping(checkpoint_path):
  """Return name to key mappings from the checkpoint.

  Args:
    checkpoint_path: string, path to object-based checkpoint.

  Returns:
    Dictionary mapping tensor names to checkpoint keys.
  """
  reader = tf.train.load_checkpoint(checkpoint_path)
  object_graph_string = reader.get_tensor('_CHECKPOINTABLE_OBJECT_GRAPH')
  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  object_graph_proto.ParseFromString(object_graph_string)
  names_to_keys = {}
  for node in object_graph_proto.nodes:
    for attribute in node.attributes:
      names_to_keys[attribute.full_name] = attribute.checkpoint_key
  return names_to_keys
def __init__(self, save_path):
  """Configure the checkpoint view.

  Args:
    save_path: The path to the checkpoint.

  Raises:
    ValueError: If the save_path does not lead to a TF2 checkpoint.
  """
  reader = py_checkpoint_reader.NewCheckpointReader(save_path)
  try:
    object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
  except errors_impl.NotFoundError as not_found_error:
    raise ValueError(
        f"The specified checkpoint \"{save_path}\" does not appear to be "
        "object-based (saved with TF2) since it is missing the key "
        f"\"{base.OBJECT_GRAPH_PROTO_KEY}\". Likely it was created with the "
        "TF1 name-based saver and does not contain an object dependency "
        "graph.") from not_found_error
  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  object_graph_proto.ParseFromString(object_graph_string)
  self._object_graph_proto = object_graph_proto
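A hedged call-site sketch: this `__init__` matches TF's internal `CheckpointView`; the `tensorflow.python.checkpoint.checkpoint_view` module path is an assumption and may differ between TF releases.

from tensorflow.python.checkpoint.checkpoint_view import CheckpointView

try:
  view = CheckpointView('/tmp/ckpt-1')  # prefix from the earlier sketch
except ValueError as err:
  # Raised for TF1 name-based checkpoints, which lack the object graph key.
  print('Not an object-based checkpoint:', err)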
def restore(self, save_path, options=None):
  """Restore a training checkpoint with host mesh placement."""
  options = options or checkpoint_options.CheckpointOptions()
  if save_path is None:
    return util.InitializationOnlyStatus(self._graph_view, ops.uid())
  reader = py_checkpoint_reader.NewCheckpointReader(save_path)
  graph_building = not context.executing_eagerly()
  if graph_building:
    dtype_map = None
  else:
    dtype_map = reader.get_variable_to_dtype_map()
  try:
    object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
  except errors_impl.NotFoundError:
    # The object graph proto does not exist in this checkpoint. Try the
    # name-based compatibility mode.
    restore_coordinator = util._NameBasedRestoreCoordinator(  # pylint: disable=protected-access
        save_path=save_path, dtype_map=dtype_map)
    if not graph_building:
      for existing_trackable in self._graph_view.list_objects():
        # pylint: disable=protected-access
        existing_trackable._maybe_initialize_trackable()
        existing_trackable._name_based_restores.add(restore_coordinator)
        existing_trackable._name_based_attribute_restore(restore_coordinator)
        # pylint: enable=protected-access
    return util.NameBasedSaverStatus(
        restore_coordinator, graph_view=self._graph_view)

  if graph_building:
    if self._file_prefix_placeholder is None:
      # DTensor change: provide a hint for mesh broadcasting to put the input
      # onto the host mesh.
      self._file_prefix_placeholder = api.pack(
          [constant_op.constant("model")] * self._mesh.num_local_devices(),
          layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
    file_prefix_tensor = self._file_prefix_placeholder
    file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
  else:
    # DTensor change: provide a hint for mesh broadcasting to put the input
    # onto the host mesh.
    file_prefix_tensor = api.pack(
        [constant_op.constant(save_path)] * self._mesh.num_local_devices(),
        layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
    file_prefix_feed_dict = None
  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  object_graph_proto.ParseFromString(object_graph_string)
  # DTensor change: hook the proper DSaver in restore.
  checkpoint = _DCheckpointRestoreCoordinator(
      mesh=self._mesh,
      object_graph_proto=object_graph_proto,
      save_path=save_path,
      save_path_tensor=file_prefix_tensor,
      reader=reader,
      restore_op_cache=self._restore_op_cache,
      graph_view=self._graph_view,
      options=options,
      saveables_cache=self._saveables_cache)
  base.CheckpointPosition(
      checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)

  # Attached dependencies are not attached to the root, so should be restored
  # separately.
  if self._graph_view.attached_dependencies:
    for ref in self._graph_view.attached_dependencies:
      if ref.name == "root":
        # Root dependency is automatically added to attached dependencies --
        # this can be ignored since it maps back to the root object.
        continue
      proto_id = None
      # Find proto ID of attached dependency (if it is in the proto).
      for proto_ref in object_graph_proto.nodes[0].children:
        if proto_ref.local_name == ref.name:
          proto_id = proto_ref.node_id
          break
      if proto_id in checkpoint.object_by_proto_id:
        # Object has already been restored. This can happen when there's an
        # indirect connection from the attached object to the root.
        continue
      base.CheckpointPosition(
          checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref)

  load_status = util.CheckpointLoadStatus(
      checkpoint, graph_view=self._graph_view, feed_dict=file_prefix_feed_dict)
  return load_status
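The attached-dependency lookup inside `restore` is simple enough to isolate. A standalone sketch of the same search (an illustrative helper, not part of the source; the proto can come from any of the parsers above):

def find_child_proto_id(object_graph_proto, name):
  """Returns the node_id of the root's child named `name`, or None."""
  # Node 0 is the root of the object graph; its `children` carry the
  # local names that `restore` matches against `ref.name`.
  for proto_ref in object_graph_proto.nodes[0].children:
    if proto_ref.local_name == name:
      return proto_ref.node_id
  return None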