def read_ckpt(ckpt):
    """Load every variable stored in a checkpoint as a torch tensor.

    Args:
        ckpt: Checkpoint path/prefix handed to `NewCheckpointReader`.

    Returns:
        Dict mapping each variable name found in the checkpoint to
        `torch.as_tensor(tf2pth(value))`.
    """
    reader = py_checkpoint_reader.NewCheckpointReader(ckpt)
    weights = {}
    # Only the names are needed here; the shapes reported by the map are
    # ignored.
    for name in reader.get_variable_to_shape_map():
        converted = tf2pth(reader.get_tensor(name))
        weights[name] = torch.as_tensor(converted)
    return weights
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
                                     all_tensor_names=False,
                                     count_exclude_pattern=""):
    """Print the contents of a checkpoint file to stdout.

    Three modes, checked in this order:
      * `all_tensors` or `all_tensor_names` is truthy: print name, dtype and
        shape of every tensor (plus its value when `all_tensors` is truthy).
      * `tensor_name` is empty: print the reader's debug string.
      * otherwise: print name, dtype, shape and value of `tensor_name`, or a
        "not found" message (and return) when the checkpoint lacks it.

    Unless the "not found" path returned early, the total parameter count is
    printed last. Any exception is caught and printed rather than raised,
    followed by hints for two known failure messages.

    Args:
        file_name: Name of the checkpoint file.
        tensor_name: Name of the tensor in the checkpoint file to print.
        all_tensors: Boolean indicating whether to print all tensors.
        all_tensor_names: Boolean indicating whether to print all tensor
            names.
        count_exclude_pattern: Regex string, pattern to exclude tensors when
            counting parameters.
    """
    try:
        reader = py_checkpoint_reader.NewCheckpointReader(file_name)
        if all_tensors or all_tensor_names:
            shapes = reader.get_variable_to_shape_map()
            dtypes = reader.get_variable_to_dtype_map()
            for name in sorted(shapes):
                print("tensor: %s (%s) %s" %
                      (name, dtypes[name].name, shapes[name]))
                if all_tensors:
                    print(reader.get_tensor(name))
        elif not tensor_name:
            print(reader.debug_string().decode("utf-8", errors="ignore"))
        elif not reader.has_tensor(tensor_name):
            print("Tensor %s not found in checkpoint" % tensor_name)
            return
        else:
            shapes = reader.get_variable_to_shape_map()
            dtypes = reader.get_variable_to_dtype_map()
            print("tensor: %s (%s) %s" %
                  (tensor_name, dtypes[tensor_name].name, shapes[tensor_name]))
            print(reader.get_tensor(tensor_name))

        # Count total number of parameters
        total_params = _count_total_params(
            reader, count_exclude_pattern=count_exclude_pattern)
        print("# Total number of params: %d" % total_params)
    except Exception as e:  # pylint: disable=broad-except
        message = str(e)
        print(message)
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
        # A "Data loss" error on a path naming one of the V2 part files
        # usually means the caller passed a part file instead of the prefix.
        if ("Data loss" in message and
                any(part in file_name
                    for part in [".index", ".meta", ".data"])):
            proposed_file = ".".join(file_name.split(".")[0:-1])
            v2_file_error_template = """ 
It's likely that this is a V2 checkpoint and you need to provide the filename *prefix*. Try removing the '.' and extension. Try: inspect checkpoint --file_name = {}"""
            print(v2_file_error_template.format(proposed_file))
def restore_saveables(
    self,
    tensor_saveables: Dict[str, saveable_object.SaveableObject],
    python_saveables: List[base.PythonStateSaveable],
    registered_savers: Optional[Dict[str, Dict[str, base.Trackable]]] = None
) -> Optional[List[ops.Operation]]:
    """Run or build restore operations for SaveableObjects.

    Python-state saveables are restored immediately from a checkpoint reader.
    Tensor saveables are validated and restored through `_DSaver`; when graph
    building, the resulting ops are also recorded in
    `self.restore_ops_by_name`.

    Args:
        tensor_saveables: `SaveableObject`s which correspond to Tensors.
        python_saveables: `PythonStateSaveable`s which correspond to Python
            values.
        registered_savers: a dict mapping saver names-> object name ->
            Trackable. This argument is not implemented for DTensorCheckpoint
            and is discarded.

    Returns:
        When graph building, a list of restore operations newly created for
        `tensor_saveables`. When executing eagerly (or when there are no
        tensor saveables) the list is empty.

    Raises:
        AssertionError: If validation changes the set of saveable names.
    """
    del registered_savers

    # Python state is restored right away. The reader is created on first use
    # because it needs file access and there may be no Python saveables.
    reader = None
    for python_saveable in python_saveables:
        if reader is None:
            reader = py_checkpoint_reader.NewCheckpointReader(
                self.save_path_string)
        restored_values = [
            reader.get_tensor(spec.name) for spec in python_saveable.specs
        ]
        python_saveable.python_restore(restored_values)

    restore_ops = []
    if not tensor_saveables:
        return restore_ops

    # New SaveableObjects: validate them, then build their restore ops.
    validated_saveables = saveable_object_util.validate_and_slice_inputs(
        tensor_saveables)
    validated_names = {validated.name for validated in validated_saveables}
    if set(tensor_saveables.keys()) != validated_names:
        raise AssertionError(
            ("Saveable keys changed when validating. Got back %s, was "
             "expecting %s") % (tensor_saveables.keys(), validated_names))

    # DTensor change: Use _DSaver that does restore on DTensor with
    # customized DTensorRestoreV2 op.
    saver = _DSaver(self._mesh, validated_saveables)
    new_restore_ops = saver.restore(self.save_path_tensor, self.options)

    if context.executing_eagerly():
        return restore_ops

    # Graph mode: hand the ops back to the caller and cache them by name.
    for name, restore_op in sorted(new_restore_ops.items()):
        restore_ops.append(restore_op)
        assert name not in self.restore_ops_by_name
        self.restore_ops_by_name[name] = restore_op
    return restore_ops
def benchmark_raw_restore(self):
    """Benchmark a direct `restore_v2` call over every checkpoint variable.

    Saves a checkpoint via `_save_checkpoint()`, collects all variable names
    and dtypes from it, and passes a closure issuing one `restore_v2` call to
    `self._run` with `3` as the second argument.
    """
    checkpoint_path = _save_checkpoint()
    reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
    dtype_by_name = reader.get_variable_to_dtype_map()
    # Split the (name, dtype) pairs into two parallel tuples.
    all_names, all_dtypes = zip(*dtype_by_name.items())

    def _call_restore_v2():
        # An empty slice spec is supplied for each variable name.
        shape_and_slices = [""] * len(all_names)
        gen_io_ops.restore_v2(checkpoint_path, all_names, shape_and_slices,
                              all_dtypes)

    self._run(_call_restore_v2, 3)
def __init__(self, mapping=None) -> None:
    """Fetch the weights archive and open a reader on its latest checkpoint.

    The archive is obtained with `data_utils.get_file(DIR_NAME, WEIGHTS_URL,
    ...)`. Its `variables` sub-directory is expected to hold the checkpoint
    shards. If that directory has no `checkpoint` index file, the copy that
    sits next to this module is installed there so that
    `tf.train.latest_checkpoint` can resolve the checkpoint prefix.

    Args:
        mapping: Optional dict stored as `self.mapping`. When omitted (or
            `None`) each instance gets its own new empty dict; previously a
            single `{}` default object was shared by every instance.
    """
    checkpoint = data_utils.get_file(
        DIR_NAME, WEIGHTS_URL, untar=True, extract=True,
        cache_subdir="deep_learning")
    checkpoint = os.path.join(checkpoint, "variables")
    checkpoint_file = os.path.join(checkpoint, "checkpoint")
    if not os.path.isfile(checkpoint_file):
        # The directory lacks the `checkpoint` index file; install the one
        # located next to this module.
        local_ckpt_file = os.path.join(
            os.path.dirname(__file__), 'checkpoint')
        copyfile(local_ckpt_file, checkpoint_file)
    ckpt_path = tf.train.latest_checkpoint(checkpoint)
    self.reader = py_checkpoint_reader.NewCheckpointReader(ckpt_path)
    # Never keep a reference to a shared default object: mutations through
    # one instance would otherwise leak into all others.
    self.mapping = {} if mapping is None else mapping
def v1_to_v2(model_path, save_path):
    """Load the GPT2 model from TF1 checkpoint file and save it using TF2.

    Args:
        model_path: Directory containing `hparams.json` and the checkpoint
            resolved by `tf.train.latest_checkpoint`.
        save_path: Destination passed to `save_model`.
    """
    # Hyperparameters drive both the model construction and the layer count.
    with open(path_join(model_path, "hparams.json"), "r") as hparams_file:
        hparams = json.load(hparams_file)

    gpt2 = GPT2(hparams["n_layer"], hparams["n_head"], hparams["n_vocab"],
                hparams["n_ctx"], hparams["n_embd"])

    # Call the model once on a dummy [1, 1] int32 input so that it is built
    # before weights are loaded into it.
    dummy_tokens = tf.constant([0], shape=[1, 1], dtype=tf.int32)
    gpt2(dummy_tokens)

    # Open the latest checkpoint found in the model directory.
    latest = tf.train.latest_checkpoint(model_path)
    reader = py_checkpoint_reader.NewCheckpointReader(latest)

    # Embeddings and the final layer norm.
    load_weights("model", ["wte", "wpe"], gpt2.word_embedder, reader)
    load_weights("model/ln_f", ["g", "b"], gpt2.final_norm, reader)

    # Per-layer weights, in the same order for every block.
    for layer_index in range(hparams["n_layer"]):
        scope = "model/h%d" % layer_index
        block = gpt2.blocks[layer_index]
        load_weights(scope + "/attn/c_attn", ["w", "b"],
                     block.attn.expander, reader)
        load_weights(scope + "/attn/c_proj", ["w", "b"],
                     block.attn.compressor, reader)
        load_weights(scope + "/ln_1", ["g", "b"], block.attn_norm, reader)
        load_weights(scope + "/mlp/c_fc", ["w", "b"],
                     block.position_wise.dense1, reader)
        load_weights(scope + "/mlp/c_proj", ["w", "b"],
                     block.position_wise.dense2, reader)
        load_weights(scope + "/ln_2", ["g", "b"],
                     block.position_wise_norm, reader)

    # Persist the populated model.
    save_model(gpt2, save_path)
def load_checkpoint(ckpt_dir_or_file):
    """Return a `CheckpointReader` for the checkpoint at `ckpt_dir_or_file`.

    The argument is resolved to a checkpoint filename by
    `_get_checkpoint_filename`; per the original contract, a directory
    holding several checkpoints resolves to the latest one.

    Args:
        ckpt_dir_or_file: Directory with checkpoints file or path to
            checkpoint file.

    Returns:
        `CheckpointReader` object.

    Raises:
        ValueError: If `ckpt_dir_or_file` resolves to a directory with no
            checkpoints.
    """
    resolved = _get_checkpoint_filename(ckpt_dir_or_file)
    if resolved is not None:
        return py_checkpoint_reader.NewCheckpointReader(resolved)
    raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                     "given directory %s" % ckpt_dir_or_file)
def __init__(self, save_path):
    """Configure the checkpoint view.

    Reads the serialized object graph stored under
    `base.OBJECT_GRAPH_PROTO_KEY` and keeps the parsed proto on
    `self._object_graph_proto`.

    Args:
        save_path: The path to the checkpoint.

    Raises:
        ValueError: If the save_path does not lead to a TF2 checkpoint.
    """
    reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    try:
        serialized_graph = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError as not_found_error:
        # No object graph entry: this is treated as a name-based checkpoint.
        message = (
            f'The specified checkpoint "{save_path}" does not appear to be '
            'object-based (saved with TF2) since it is missing the key '
            f'"{base.OBJECT_GRAPH_PROTO_KEY}". Likely it was created with the '
            'TF1 name-based saver and does not contain an object dependency graph.'
        )
        raise ValueError(message) from not_found_error

    parsed_graph = trackable_object_graph_pb2.TrackableObjectGraph()
    parsed_graph.ParseFromString(serialized_graph)
    self._object_graph_proto = parsed_graph
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_denylist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph_def: A `GraphDef`.
        input_saver_def: A `SaverDef` (optional).
        input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
            priority. Typically the result of `Saver.save()` or that of
            `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded
            or V1/V2.
        output_node_names: The name(s) of the output nodes, comma separated.
        restore_op_name: Unused.
        filename_tensor_name: Unused.
        output_graph: String where to write the frozen `GraphDef`. Nothing is
            written when it is empty.
        clear_devices: A Bool whether to remove device specifications.
        initializer_nodes: Comma separated string of initializer nodes to run
            before freezing.
        variable_names_whitelist: The set of variable names to convert
            (optional, by default, all variables are converted).
        variable_names_denylist: The set of variable names to omit converting
            to constants (optional).
        input_meta_graph_def: A `MetaGraphDef` (optional),
        input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel'
            file and variables (optional).
        saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef
            to load, in string format (optional).
        checkpoint_version: Tensorflow variable file format
            (saver_pb2.SaverDef.V1 or saver_pb2.SaverDef.V2)

    Returns:
        The frozen `GraphDef` (the same object that is serialized to
        `output_graph` when that path is given).

    Raises:
        ValueError: If no SavedModel dir is given and `input_checkpoint` does
            not exist; if `output_node_names` is empty; if the checkpoint
            path is taken for a model with partition variables; or if the
            model contains no variables (already frozen).
        TypeError: Re-raised from `Saver` construction when neither of the
            two known causes applies.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir and
            not checkpoint_management.checkpoint_exists(input_checkpoint)):
        raise ValueError("Input checkpoint '" + input_checkpoint +
                         "' doesn't exist!")

    if not output_node_names:
        raise ValueError(
            "You need to supply the name of a node to --output_node_names.")

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(
                saver_def=input_saver_def, write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(
                input_meta_graph_def, clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = py_checkpoint_reader.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()

            # List of all partition variables. Because the condition is
            # heuristic based, the list could include false positives.
            all_partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ]
            has_partition_var = False

            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                    if any(key in name
                           for name in all_partition_variable_names):
                        has_partition_var = True
                except KeyError:
                    # This tensor doesn't exist in the graph (for example
                    # it's 'global_step' or a similar housekeeping element)
                    # so skip it.
                    continue
                var_list[key] = tensor

            try:
                saver = saver_lib.Saver(
                    var_list=var_list, write_version=checkpoint_version)
            except TypeError as e:
                # `var_list` is required to be a map of variable names to
                # Variable tensors. Partition variables are Identity tensors
                # that cannot be handled by Saver.
                if has_partition_var:
                    raise ValueError(
                        "Models containing partition variables cannot be converted "
                        "from checkpoint files. Please pass in a SavedModel using "
                        "the flag --input_saved_model_dir.") from e
                # Models that have been frozen previously do not contain
                # Variables.
                if _has_no_variables(sess):
                    raise ValueError(
                        "No variables were found in this model. It is likely the model "
                        "was frozen previously. You cannot freeze a graph twice."
                    ) from e
                # Unknown cause: propagate the original error untouched.
                raise

            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (
            variable_names_whitelist.replace(" ", "").split(",")
            if variable_names_whitelist else None)
        variable_names_denylist = (
            variable_names_denylist.replace(" ", "").split(",")
            if variable_names_denylist else None)

        # The meta graph, when supplied, is the graph that gets frozen;
        # otherwise the plain GraphDef is used.
        if input_meta_graph_def:
            graph_def_to_freeze = input_meta_graph_def.graph_def
        else:
            graph_def_to_freeze = input_graph_def

        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            graph_def_to_freeze,
            output_node_names.replace(" ", "").split(","),
            variable_names_whitelist=variable_names_whitelist,
            variable_names_blacklist=variable_names_denylist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
def restore(self, save_path, options=None):
    """Restore a training checkpoint with host mesh placement.

    Args:
        save_path: Path of the checkpoint to restore, or `None`.
        options: Checkpoint options; when falsy a default
            `checkpoint_options.CheckpointOptions()` is used.

    Returns:
        * `util.InitializationOnlyStatus` when `save_path` is `None`.
        * `util.NameBasedSaverStatus` when the checkpoint has no entry under
          `base.OBJECT_GRAPH_PROTO_KEY` (name-based compatibility mode).
        * `util.CheckpointLoadStatus` otherwise. When graph building, its
          `feed_dict` maps the cached file-prefix placeholder to `save_path`;
          when executing eagerly the `feed_dict` is `None`.
    """
    options = options or checkpoint_options.CheckpointOptions()
    if save_path is None:
        # Nothing to load.
        return util.InitializationOnlyStatus(self._graph_view, ops.uid())
    reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    graph_building = not context.executing_eagerly()
    # The dtype map is only read from the checkpoint when executing eagerly.
    if graph_building:
        dtype_map = None
    else:
        dtype_map = reader.get_variable_to_dtype_map()
    try:
        object_graph_string = reader.get_tensor(
            base.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError:
        # The object graph proto does not exist in this checkpoint. Try the
        # name-based compatibility mode.
        restore_coordinator = util._NameBasedRestoreCoordinator(  # pylint: disable=protected-access
            save_path=save_path, dtype_map=dtype_map)
        if not graph_building:
            # Eager mode: register the coordinator with every object of the
            # graph view and trigger its name-based attribute restore now.
            for existing_trackable in self._graph_view.list_objects():
                # pylint: disable=protected-access
                existing_trackable._maybe_initialize_trackable()
                existing_trackable._name_based_restores.add(
                    restore_coordinator)
                existing_trackable._name_based_attribute_restore(
                    restore_coordinator)
                # pylint: enable=protected-access
        return util.NameBasedSaverStatus(restore_coordinator,
                                         graph_view=self._graph_view)

    if graph_building:
        # The placeholder is created once and cached on the instance; the
        # actual path is supplied through the feed dict built below.
        if self._file_prefix_placeholder is None:
            # DTensor change: provide a hint for mesh broadcasting to put the
            # input onto the host mesh.
            self._file_prefix_placeholder = api.pack(
                [constant_op.constant("model")] *
                self._mesh.num_local_devices(),
                layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
        file_prefix_tensor = self._file_prefix_placeholder
        file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
    else:
        # DTensor change: provide a hint for mesh broadcasting to put the
        # input onto the host mesh.
        file_prefix_tensor = api.pack([constant_op.constant(save_path)] *
                                      self._mesh.num_local_devices(),
                                      layout.Layout.replicated(
                                          self._mesh.host_mesh(), rank=0))
        file_prefix_feed_dict = None
    object_graph_proto = (
        trackable_object_graph_pb2.TrackableObjectGraph())
    object_graph_proto.ParseFromString(object_graph_string)
    # DTensor Change: Hook the proper DSaver in restore.
    checkpoint = _DCheckpointRestoreCoordinator(
        mesh=self._mesh,
        object_graph_proto=object_graph_proto,
        save_path=save_path,
        save_path_tensor=file_prefix_tensor,
        reader=reader,
        restore_op_cache=self._restore_op_cache,
        graph_view=self._graph_view,
        options=options,
        saveables_cache=self._saveables_cache)
    # Restore starts at the root object, which is proto node 0.
    base.CheckpointPosition(checkpoint=checkpoint,
                            proto_id=0).restore(self._graph_view.root)

    # Attached dependencies are not attached to the root, so should be
    # restored separately.
    if self._graph_view.attached_dependencies:
        for ref in self._graph_view.attached_dependencies:
            if ref.name == "root":
                # Root dependency is automatically added to attached
                # dependencies -- this can be ignored since it maps back to
                # the root object.
                continue
            proto_id = None
            # Find proto ID of attached dependency (if it is in the proto).
            for proto_ref in object_graph_proto.nodes[0].children:
                if proto_ref.local_name == ref.name:
                    proto_id = proto_ref.node_id
                    break

            if proto_id in checkpoint.object_by_proto_id:
                # Object has already been restored. This can happen when
                # there's an indirect connection from the attached object to
                # the root.
                continue

            # NOTE(review): proto_id is still None here when no child of the
            # root node matched ref.name -- confirm that CheckpointPosition
            # tolerates a None proto_id.
            base.CheckpointPosition(checkpoint=checkpoint,
                                    proto_id=proto_id).restore(ref.ref)

    load_status = util.CheckpointLoadStatus(
        checkpoint,
        graph_view=self._graph_view,
        feed_dict=file_prefix_feed_dict)
    return load_status
import sys
import tensorflow as tf
import numpy as np
from tensorflow.python.training import py_checkpoint_reader

# Candidate run directories; only `f2` is inspected below.
# Alternative to the manual loop at the bottom:
# tensorflow.python.tools.inspect_checkpoint.print_tensors_in_checkpoint_file(
#     latest_ckp, all_tensors=True, tensor_name='')
f0 = 'iter9-parity-fullattn-padding-mask-adjust-1217-noTH2s-r0.1-20200311-101704'
f1 = 'iter9-parity-fullattn-padding-mask-adjust-dropout-1217-noTH2s-r0.1-20200311-201403'
f2 = 'iter9-parity-lshattn-padding-mask-hash2-1217-noTH2s-r0.1-20200311-101815'

print(sys.argv)

# Resolve the newest checkpoint of the selected run.
model_dir = f'running_center/{f2}/model'
latest_ckp = tf.train.latest_checkpoint(model_dir)

# Raise numpy's summarization threshold so that larger arrays print in full.
np.set_printoptions(threshold=1024 * 10)

reader = py_checkpoint_reader.NewCheckpointReader(latest_ckp)
var_to_shape_map = reader.get_variable_to_shape_map()
var_to_dtype_map = reader.get_variable_to_dtype_map()

# Dump every tensor, ordered by name: header line first, then the values.
for tensor_name in sorted(var_to_shape_map):
    shape = var_to_shape_map[tensor_name]
    dtype_name = var_to_dtype_map[tensor_name].name
    print("tensor: %s (%s) %s" % (tensor_name, dtype_name, shape))
    print(reader.get_tensor(tensor_name))