Exemple #1
0
def read_ckpt(ckpt):
    """Load every variable stored in a checkpoint as a torch tensor.

    Args:
      ckpt: Checkpoint path handed to `NewCheckpointReader`.

    Returns:
      Dict mapping each variable name found in the checkpoint to
      `torch.as_tensor(tf2pth(value))`, where `value` is what the reader
      returns for that name.
    """
    reader = py_checkpoint_reader.NewCheckpointReader(ckpt)
    weights = {}
    # Only the variable names are needed; the shapes in the map are unused.
    for name in reader.get_variable_to_shape_map():
        converted = tf2pth(reader.get_tensor(name))
        weights[name] = torch.as_tensor(converted)
    return weights
Exemple #2
0
def print_tensors_in_checkpoint_file(file_name,
                                     tensor_name,
                                     all_tensors,
                                     all_tensor_names=False,
                                     count_exclude_pattern=""):
    """Print the contents of a checkpoint file to stdout.

    Three modes, checked in this order:
      * `all_tensors` or `all_tensor_names` is truthy: list every tensor's
        name, dtype and shape (plus its value when `all_tensors` is truthy).
      * `tensor_name` is empty: print the reader's debug string.
      * otherwise: print name, dtype, shape and value of `tensor_name`, or a
        "not found" message (and return early) if it is absent.

    Unless the function returned early, the total parameter count is printed
    last. Any exception is caught and printed rather than propagated.

    Args:
      file_name: Name of the checkpoint file.
      tensor_name: Name of the tensor in the checkpoint file to print.
      all_tensors: Boolean indicating whether to print all tensors.
      all_tensor_names: Boolean indicating whether to print all tensor names.
      count_exclude_pattern: Regex string, pattern to exclude tensors when
        counting.
    """
    try:
        reader = py_checkpoint_reader.NewCheckpointReader(file_name)
        if all_tensors or all_tensor_names:
            shapes = reader.get_variable_to_shape_map()
            dtypes = reader.get_variable_to_dtype_map()
            for name, shape in sorted(shapes.items()):
                print("tensor: %s (%s) %s" % (name, dtypes[name].name, shape))
                if all_tensors:
                    print(reader.get_tensor(name))
        elif not tensor_name:
            print(reader.debug_string().decode("utf-8", errors="ignore"))
        elif not reader.has_tensor(tensor_name):
            print("Tensor %s not found in checkpoint" % tensor_name)
            return
        else:
            shapes = reader.get_variable_to_shape_map()
            dtypes = reader.get_variable_to_dtype_map()
            print("tensor: %s (%s) %s" %
                  (tensor_name, dtypes[tensor_name].name, shapes[tensor_name]))
            print(reader.get_tensor(tensor_name))

        # Count total number of parameters
        total_params = _count_total_params(
            reader, count_exclude_pattern=count_exclude_pattern)
        print("# Total number of params: %d" % total_params)
    except Exception as e:  # pylint: disable=broad-except
        message = str(e)
        print(message)
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
        # A V2 checkpoint must be addressed by its prefix, not by one of the
        # files that make it up.
        v2_suffixes = (".index", ".meta", ".data")
        if ("Data loss" in message
                and any(suffix in file_name for suffix in v2_suffixes)):
            # Drop everything from the last '.' onwards.
            proposed_file = file_name.rsplit(".", 1)[0]
            v2_file_error_template = """
It's likely that this is a V2 checkpoint and you need to provide the filename
*prefix*.  Try removing the '.' and extension.  Try:
inspect checkpoint --file_name = {}"""
            print(v2_file_error_template.format(proposed_file))
Exemple #3
0
    def restore_saveables(
        self,
        tensor_saveables: Dict[str, saveable_object.SaveableObject],
        python_saveables: List[base.PythonStateSaveable],
        registered_savers: Optional[Dict[str, Dict[str,
                                                   base.Trackable]]] = None
    ) -> Optional[List[ops.Operation]]:
        """Restore Python state eagerly, then run or build tensor restores.

        Args:
          tensor_saveables: `SaveableObject`s which correspond to Tensors,
            keyed by name.
          python_saveables: `PythonStateSaveable`s which correspond to Python
            values.
          registered_savers: a dict mapping saver names-> object name ->
            Trackable. Not implemented for DTensorCheckpoint; the argument is
            discarded on entry.

        Returns:
          The restore operations newly created for `tensor_saveables` when
          graph building. The list is empty when executing eagerly or when
          `tensor_saveables` is empty.

        Raises:
          AssertionError: If the names returned by validation differ from
            the keys of `tensor_saveables`, or (via `assert`) if a new restore
            op's name is already present in `self.restore_ops_by_name`.
        """
        del registered_savers

        restore_ops = []

        # Python state is always restored eagerly. The reader needs file
        # access, so it is only opened once a Python saveable actually shows up.
        ckpt_reader = None
        for py_saveable in python_saveables:
            if ckpt_reader is None:
                ckpt_reader = py_checkpoint_reader.NewCheckpointReader(
                    self.save_path_string)
            restored_values = [
                ckpt_reader.get_tensor(spec.name)
                for spec in py_saveable.specs
            ]
            py_saveable.python_restore(restored_values)

        if not tensor_saveables:
            return restore_ops

        # New SaveableObjects: validate them, then extract and cache restore ops.
        validated_saveables = saveable_object_util.validate_and_slice_inputs(
            tensor_saveables)
        validated_names = {item.name for item in validated_saveables}
        if validated_names != set(tensor_saveables.keys()):
            raise AssertionError(
                ("Saveable keys changed when validating. Got back %s, was "
                 "expecting %s") %
                (tensor_saveables.keys(), validated_names))

        # DTensor change: Use _DSaver that does restore on DTensor with
        # customized DTensorRestoreV2 op.
        dsaver = _DSaver(self._mesh, validated_saveables)
        new_restore_ops = dsaver.restore(self.save_path_tensor, self.options)
        if context.executing_eagerly():
            return restore_ops

        # Graph mode: hand the ops back to the caller and remember them by name.
        for name, restore_op in sorted(new_restore_ops.items()):
            restore_ops.append(restore_op)
            assert name not in self.restore_ops_by_name
            self.restore_ops_by_name[name] = restore_op
        return restore_ops
Exemple #4
0
    def benchmark_raw_restore(self):
        """Benchmark a direct `restore_v2` call over all checkpoint variables.

        Saves a checkpoint, reads its variable names and dtypes once up front,
        then passes a closure that issues the raw restore op to `self._run`.
        """
        checkpoint_path = _save_checkpoint()
        reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
        dtype_items = reader.get_variable_to_dtype_map().items()
        # Transpose the (name, dtype) pairs into two parallel tuples.
        all_names, all_dtypes = zip(*dtype_items)

        def _call_restore_v2():
            # One empty slice spec per variable name.
            slice_specs = [""] * len(all_names)
            gen_io_ops.restore_v2(checkpoint_path, all_names, slice_specs,
                                  all_dtypes)

        self._run(_call_restore_v2, 3)
 def __init__(self, mapping=None) -> None:
     """Fetch the pretrained weights and open a checkpoint reader on them.

     Downloads/extracts the archive via `data_utils.get_file`, makes sure the
     `variables` directory contains a `checkpoint` state file (copying the
     one that sits next to this module if it is missing), and opens a reader
     on the latest checkpoint found there.

     Args:
       mapping: Optional mapping stored as `self.mapping`. When omitted, each
         instance gets its own new empty dict.
     """
     checkpoint = data_utils.get_file(
         DIR_NAME, WEIGHTS_URL, untar=True, extract=True, cache_subdir="deep_learning")
     checkpoint = os.path.join(checkpoint, "variables")
     checkpoint_file = os.path.join(checkpoint, "checkpoint")
     if not os.path.isfile(checkpoint_file):
         # `tf.train.latest_checkpoint` is called on this directory below, so
         # supply the bundled `checkpoint` file when the archive lacks one.
         local_ckpt_file = os.path.join(
             os.path.dirname(__file__), 'checkpoint')
         copyfile(local_ckpt_file, checkpoint_file)
     ckpt_path = tf.train.latest_checkpoint(checkpoint)
     self.reader = py_checkpoint_reader.NewCheckpointReader(ckpt_path)
     # A `{}` default in the signature would be one dict shared by every
     # instance created without an explicit mapping; build a new one here.
     self.mapping = {} if mapping is None else mapping
def v1_to_v2(model_path, save_path):
    """Load the GPT2 model from TF1 checkpoint file and save it using TF2.

    Args:
      model_path: Directory holding `hparams.json` and the checkpoint files.
      save_path: Destination passed to `save_model`.
    """
    # Load hyperparameters
    with open(path_join(model_path, "hparams.json"), "r") as hparams_file:
        hparams = json.load(hparams_file)

    # Initialize the GPT2 model
    gpt2 = GPT2(
        hparams["n_layer"],
        hparams["n_head"],
        hparams["n_vocab"],
        hparams["n_ctx"],
        hparams["n_embd"],
    )

    # Call the model once on a dummy [1, 1] input so it gets built before
    # weights are loaded into it.
    fake_input = tf.constant([0], shape=[1, 1], dtype=tf.int32)
    _ = gpt2(fake_input)

    # Open a reader on the latest checkpoint in the model directory
    ckpt = tf.train.latest_checkpoint(model_path)
    reader = py_checkpoint_reader.NewCheckpointReader(ckpt)

    # Embeddings and final layer norm
    load_weights("model", ["wte", "wpe"], gpt2.word_embedder, reader)
    load_weights("model/ln_f", ["g", "b"], gpt2.final_norm, reader)

    # Transformer blocks: attention sub-layer first, then the MLP sub-layer
    for layer_index in range(hparams["n_layer"]):
        block = gpt2.blocks[layer_index]

        load_weights("model/h%d/attn/c_attn" % layer_index, ["w", "b"],
                     block.attn.expander, reader)
        load_weights("model/h%d/attn/c_proj" % layer_index, ["w", "b"],
                     block.attn.compressor, reader)
        load_weights("model/h%d/ln_1" % layer_index, ["g", "b"],
                     block.attn_norm, reader)

        load_weights("model/h%d/mlp/c_fc" % layer_index, ["w", "b"],
                     block.position_wise.dense1, reader)
        load_weights("model/h%d/mlp/c_proj" % layer_index, ["w", "b"],
                     block.position_wise.dense2, reader)
        load_weights("model/h%d/ln_2" % layer_index, ["g", "b"],
                     block.position_wise_norm, reader)

    # Save model v2
    save_model(gpt2, save_path)
Exemple #7
0
def load_checkpoint(ckpt_dir_or_file):
    """Open a `CheckpointReader` for the checkpoint at `ckpt_dir_or_file`.

    The argument is resolved to a checkpoint filename by
    `_get_checkpoint_filename`; per the original contract, a directory holding
    several checkpoints resolves to the latest one.

    Args:
      ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint
        file.

    Returns:
      `CheckpointReader` object.

    Raises:
      ValueError: If no checkpoint filename could be resolved from
        `ckpt_dir_or_file`.
    """
    filename = _get_checkpoint_filename(ckpt_dir_or_file)
    if filename is not None:
        return py_checkpoint_reader.NewCheckpointReader(filename)
    raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                     "given directory %s" % ckpt_dir_or_file)
  def __init__(self, save_path):
    """Read and parse the object graph stored in a checkpoint.

    Args:
      save_path: The path to the checkpoint.

    Raises:
      ValueError: If the checkpoint has no entry under
        `base.OBJECT_GRAPH_PROTO_KEY`, i.e. it is not an object-based (TF2)
        checkpoint.
    """
    ckpt_reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    try:
      serialized_graph = ckpt_reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError as not_found_error:
      raise ValueError(
          f"The specified checkpoint \"{save_path}\" does not appear to be "
          "object-based (saved with TF2) since it is missing the key "
          f"\"{base.OBJECT_GRAPH_PROTO_KEY}\". Likely it was created with the "
          "TF1 name-based saver and does not contain an object dependency graph."
      ) from not_found_error

    # Parse into a local first so the attribute is only set on success.
    parsed_graph = trackable_object_graph_pb2.TrackableObjectGraph()
    parsed_graph.ParseFromString(serialized_graph)
    self._object_graph_proto = parsed_graph
Exemple #9
0
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_denylist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    Variable values are restored from exactly one source, tried in this order:
    `input_saver_def`, `input_meta_graph_def`, `input_saved_model_dir`, and
    finally a `Saver` built from the variables that appear both in the
    checkpoint and in the session graph.

    Args:
      input_graph_def: A `GraphDef`.
      input_saver_def: A `SaverDef` (optional).
      input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
        priority.  Typically the result of `Saver.save()` or that of
        `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
        V1/V2.
      output_node_names: The name(s) of the output nodes, comma separated.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      output_graph: String where to write the frozen `GraphDef`. Nothing is
        written when it is empty.
      clear_devices: A Bool whether to remove device specifications.
      initializer_nodes: Comma separated string of initializer nodes to run
        before freezing.
      variable_names_whitelist: The set of variable names to convert, comma
        separated (optional, by default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting
        to constants, comma separated (optional).
      input_meta_graph_def: A `MetaGraphDef` (optional),
      input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
        and variables (optional).
      saved_model_tags: Group of tag(s) of the MetaGraphDef to load; passed
        straight to `loader.load` (optional, defaults to an empty list).
      checkpoint_version: Tensorflow variable file format
        (saver_pb2.SaverDef.V1 or saver_pb2.SaverDef.V2)

    Returns:
      The frozen `GraphDef` (the same object that is serialized to
      `output_graph` when that path is given).

    Raises:
      ValueError: If `input_checkpoint` does not exist and no
        `input_saved_model_dir` was given; if `output_node_names` is empty;
        or, on the checkpoint-only path, if the `Saver` cannot be built
        because the model contains partition variables or no variables.
      TypeError: Re-raised from `Saver` construction when neither of the
        two conditions above explains the failure.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir
            and not checkpoint_management.checkpoint_exists(input_checkpoint)):
        raise ValueError("Input checkpoint '" + input_checkpoint +
                         "' doesn't exist!")

    if not output_node_names:
        raise ValueError(
            "You need to supply the name of a node to --output_node_names.")

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = py_checkpoint_reader.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()

            # List of all partition variables. Because the condition is heuristic
            # based, the list could include false positives.
            all_partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ]
            has_partition_var = False

            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                    if any(key in name
                           for name in all_partition_variable_names):
                        has_partition_var = True
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor

            try:
                saver = saver_lib.Saver(var_list=var_list,
                                        write_version=checkpoint_version)
            except TypeError as e:
                # `var_list` is required to be a map of variable names to Variable
                # tensors. Partition variables are Identity tensors that cannot be
                # handled by Saver.
                if has_partition_var:
                    raise ValueError(
                        "Models containing partition variables cannot be converted "
                        "from checkpoint files. Please pass in a SavedModel using "
                        "the flag --input_saved_model_dir.") from e
                # Models that have been frozen previously do not contain Variables.
                if _has_no_variables(sess):
                    raise ValueError(
                        "No variables were found in this model. It is likely the model "
                        "was frozen previously. You cannot freeze a graph twice."
                    ) from e
                # Neither explanation applies: propagate the original error.
                raise

            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (variable_names_whitelist.replace(
            " ", "").split(",") if variable_names_whitelist else None)
        variable_names_denylist = (variable_names_denylist.replace(
            " ", "").split(",") if variable_names_denylist else None)

        # The meta graph's GraphDef takes precedence over `input_graph_def`.
        source_graph_def = (input_meta_graph_def.graph_def
                            if input_meta_graph_def else input_graph_def)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            source_graph_def,
            output_node_names.replace(" ", "").split(","),
            variable_names_whitelist=variable_names_whitelist,
            variable_names_blacklist=variable_names_denylist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
Exemple #10
0
    def restore(self, save_path, options=None):
        """Restore a training checkpoint with host mesh placement.

        Args:
          save_path: Path of the checkpoint to restore. `None` means there is
            nothing to restore.
          options: Optional checkpoint options; a default
            `checkpoint_options.CheckpointOptions()` is used when falsy.

        Returns:
          `util.InitializationOnlyStatus` when `save_path` is None;
          `util.NameBasedSaverStatus` when the checkpoint has no entry under
          `base.OBJECT_GRAPH_PROTO_KEY` (name-based compatibility mode);
          otherwise a `util.CheckpointLoadStatus`.
        """
        if not options:
            options = checkpoint_options.CheckpointOptions()
        if save_path is None:
            return util.InitializationOnlyStatus(self._graph_view, ops.uid())

        reader = py_checkpoint_reader.NewCheckpointReader(save_path)
        graph_building = not context.executing_eagerly()
        # The dtype map is only read from the checkpoint in eager mode.
        dtype_map = (None if graph_building else
                     reader.get_variable_to_dtype_map())

        try:
            object_graph_string = reader.get_tensor(
                base.OBJECT_GRAPH_PROTO_KEY)
        except errors_impl.NotFoundError:
            # The object graph proto does not exist in this checkpoint. Try the
            # name-based compatibility mode.
            restore_coordinator = util._NameBasedRestoreCoordinator(  # pylint: disable=protected-access
                save_path=save_path,
                dtype_map=dtype_map)
            if not graph_building:
                for trackable in self._graph_view.list_objects():
                    # pylint: disable=protected-access
                    trackable._maybe_initialize_trackable()
                    trackable._name_based_restores.add(restore_coordinator)
                    trackable._name_based_attribute_restore(
                        restore_coordinator)
                    # pylint: enable=protected-access
            return util.NameBasedSaverStatus(restore_coordinator,
                                             graph_view=self._graph_view)

        # DTensor change: provide a hint for mesh broadcasting to put the input
        # onto the host mesh (both branches below pack a replicated value).
        if graph_building:
            if self._file_prefix_placeholder is None:
                # Built once and cached; the actual path is fed at run time.
                per_device = ([constant_op.constant("model")] *
                              self._mesh.num_local_devices())
                host_layout = layout.Layout.replicated(
                    self._mesh.host_mesh(), rank=0)
                self._file_prefix_placeholder = api.pack(
                    per_device, host_layout)
            file_prefix_tensor = self._file_prefix_placeholder
            file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
        else:
            per_device = ([constant_op.constant(save_path)] *
                          self._mesh.num_local_devices())
            host_layout = layout.Layout.replicated(self._mesh.host_mesh(),
                                                   rank=0)
            file_prefix_tensor = api.pack(per_device, host_layout)
            file_prefix_feed_dict = None

        object_graph_proto = (
            trackable_object_graph_pb2.TrackableObjectGraph())
        object_graph_proto.ParseFromString(object_graph_string)

        # DTensor Change: Hook the proper DSaver in restore.
        checkpoint = _DCheckpointRestoreCoordinator(
            mesh=self._mesh,
            object_graph_proto=object_graph_proto,
            save_path=save_path,
            save_path_tensor=file_prefix_tensor,
            reader=reader,
            restore_op_cache=self._restore_op_cache,
            graph_view=self._graph_view,
            options=options,
            saveables_cache=self._saveables_cache)
        root_position = base.CheckpointPosition(checkpoint=checkpoint,
                                                proto_id=0)
        root_position.restore(self._graph_view.root)

        # Attached dependencies are not attached to the root, so should be restored
        # separately.
        if self._graph_view.attached_dependencies:
            root_children = object_graph_proto.nodes[0].children
            for ref in self._graph_view.attached_dependencies:
                if ref.name == "root":
                    # Root dependency is automatically added to attached dependencies --
                    # this can be ignored since it maps back to the root object.
                    continue

                # Proto ID of the attached dependency: the first child of the
                # root node with a matching local name, or None if absent.
                proto_id = next(
                    (child.node_id for child in root_children
                     if child.local_name == ref.name),
                    None)

                if proto_id in checkpoint.object_by_proto_id:
                    # Object has already been restored. This can happen when there's an
                    # indirect connection from the attached object to the root.
                    continue

                base.CheckpointPosition(checkpoint=checkpoint,
                                        proto_id=proto_id).restore(ref.ref)

        return util.CheckpointLoadStatus(
            checkpoint,
            graph_view=self._graph_view,
            feed_dict=file_prefix_feed_dict)
import sys
import tensorflow as tf
import numpy as np
#from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
from tensorflow.python.training import py_checkpoint_reader

# Run directories of the experiments; only `f2` is inspected below.
f0 = 'iter9-parity-fullattn-padding-mask-adjust-1217-noTH2s-r0.1-20200311-101704'
f1 = 'iter9-parity-fullattn-padding-mask-adjust-dropout-1217-noTH2s-r0.1-20200311-201403'
f2 = 'iter9-parity-lshattn-padding-mask-hash2-1217-noTH2s-r0.1-20200311-101815'

print(sys.argv)

# Raise numpy's summarization threshold so arrays of up to 10240 elements are
# printed in full instead of being elided.
np.set_printoptions(threshold=1024 * 10)
latest_ckp = tf.train.latest_checkpoint(f'running_center/{f2}/model')

# Dump every variable in the checkpoint: name, dtype, shape, then its value.
# (Hand-rolled stand-in for the commented-out helper call below.)
#print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
reader = py_checkpoint_reader.NewCheckpointReader(latest_ckp)
var_to_shape_map = reader.get_variable_to_shape_map()
var_to_dtype_map = reader.get_variable_to_dtype_map()
for key in sorted(var_to_shape_map):
    value = var_to_shape_map[key]
    print("tensor: %s (%s) %s" % (key, var_to_dtype_map[key].name, value))
    print(reader.get_tensor(key))