Beispiel #1
0
def _serialize_checkpointables(
    checkpointable_objects, node_ids, object_names, slot_variables):
  """Serializes objects into an object graph proto and collects saveables.

  A fresh `CheckpointableObjectGraph` is built with one node per entry of
  `checkpointable_objects`, in order. Each node records the object's slot
  variable references, one attribute per saveable it exposes, and one child
  reference per checkpoint dependency.

  Args:
    checkpointable_objects: Ordered sequence of objects. The object at
      position i must satisfy `node_ids[object] == i`.
    node_ids: Mapping from object to integer node id; also used to resolve
      each dependency's `ref` into a child node id.
    object_names: Mapping from object to the prefix used when building the
      checkpoint keys of that object's attributes.
    slot_variables: Mapping from object to slot variable references to attach
      to its node; objects with no entry get none.

  Returns:
    A tuple `(named_saveables, object_graph_proto)`, where `named_saveables`
    maps each attribute's checkpoint key to its saveable object.
  """
  graph_proto = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
  saveables_by_key = {}

  for position, obj in enumerate(checkpointable_objects):
    # Nodes are appended in iteration order, so an object's node id has to
    # equal its position in `checkpointable_objects`.
    assert node_ids[obj] == position
    node = graph_proto.nodes.add()
    node.slot_variables.extend(slot_variables.get(obj, ()))
    prefix = object_names[obj]

    saveable_factories = obj._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
    for attr_name, factory in saveable_factories.items():
      attr = node.attributes.add()
      attr.name = attr_name
      attr.checkpoint_key = "%s/%s/%s" % (
          prefix, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(attr_name))
      key = attr.checkpoint_key
      # Callable factories are invoked with the checkpoint key as their name;
      # anything else is already the saveable.
      saveable = factory(name=key) if callable(factory) else factory
      # Record the name a name-based Saver would use for this saveable
      # (OpListToDict must yield exactly one entry).
      (attr.full_name,) = saver_lib.BaseSaverBuilder.OpListToDict(
          [saveable], convert_variable_to_tensor=False).keys()
      saveables_by_key[key] = saveable

    for dependency in obj._checkpoint_dependencies:  # pylint: disable=protected-access
      child = node.children.add()
      child.node_id = node_ids[dependency.ref]
      child.local_name = dependency.name

  return saveables_by_key, graph_proto
Beispiel #2
0
    def testTrainSpinn(self):
        """Test with fake toy SNLI data and GloVe vectors.

        Trains a SPINN model on a fake data set (the same file is used for the
        train, dev and test splits), then checks that the number of recorded
        train-loss summaries equals `config.epochs` and that the saved
        object-based checkpoint contains `global_step` and every trainer
        variable.
        """

        # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
        snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
        fake_train_file = self._create_test_data(snli_1_0_dir)

        vocab = data.load_vocabulary(self._temp_data_dir)
        word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

        train_data = data.SnliData(fake_train_file, word2index)
        dev_data = data.SnliData(fake_train_file, word2index)
        test_data = data.SnliData(fake_train_file, word2index)

        # 2. Create a fake config.
        config = _test_spinn_config(
            data.WORD_VECTOR_LEN,
            4,
            logdir=os.path.join(self._temp_data_dir, "logdir"))

        # 3. Test training of a SPINN model.
        trainer = spinn.train_or_infer_spinn(embed, word2index, train_data,
                                             dev_data, test_data, config)

        # 4. Load train loss values from the summary files and verify that the
        #    number of recorded values equals the number of epochs. (The loss
        #    values themselves are not checked here.)
        summary_files = glob.glob(os.path.join(config.logdir, "events.out.*"))
        # Fail with a clear assertion instead of an IndexError when training
        # wrote no summary file.
        self.assertTrue(summary_files)
        events = summary_test_util.events_from_file(summary_files[0])
        train_losses = [
            event.summary.value[0].simple_value for event in events if
            event.summary.value and event.summary.value[0].tag == "train/loss"
        ]
        self.assertEqual(config.epochs, len(train_losses))

        # 5. Verify that checkpoints exist and contain all the expected variables.
        self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
        object_graph_string = checkpoint_utils.load_variable(
            config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH")
        object_graph = (
            checkpointable_object_graph_pb2.CheckpointableObjectGraph())
        object_graph.ParseFromString(object_graph_string)
        ckpt_variable_names = {
            attribute.full_name
            for node in object_graph.nodes
            for attribute in node.attributes
        }
        self.assertIn("global_step", ckpt_variable_names)
        for v in trainer.variables:
            # Keep only the part of the name before the first ":" (presumably
            # the tensor output-index suffix such as ":0"; names without a
            # colon are used unchanged).
            variable_name = v.name.split(":", 1)[0]
            self.assertIn(variable_name, ckpt_variable_names)
Beispiel #3
0
 def _fill_object_graph_proto(self, checkpointable_objects,
                              node_ids,
                              slot_variables,
                              object_graph_proto=None):
   """Appends a node for each object to an object graph proto.

   For every object, a node is added that records the object's slot variable
   references and one child reference per dependency returned by
   `self.list_dependencies`.

   Args:
     checkpointable_objects: Ordered objects. The object at position i must
       satisfy `node_ids[object] == i`.
     node_ids: Mapping from object to integer node id; also used to resolve
       each dependency's `ref` into a child node id.
     slot_variables: Mapping from object to slot variable references to
       attach to its node; objects without an entry get none.
     object_graph_proto: Optional proto to append nodes to. When omitted, a
       new empty `CheckpointableObjectGraph` is created.

   Returns:
     The proto the nodes were appended to.
   """
   graph = object_graph_proto
   if graph is None:
     graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
   for position, obj in enumerate(checkpointable_objects):
     # Ids are assigned in iteration order; check that they line up.
     assert node_ids[obj] == position
     node = graph.nodes.add()
     node.slot_variables.extend(slot_variables.get(obj, ()))
     for dependency in self.list_dependencies(obj):
       child = node.children.add()
       child.node_id = node_ids[dependency.ref]
       child.local_name = dependency.name
   return graph
Beispiel #4
0
def object_metadata(save_path):
  """Reads and parses the object graph stored in a checkpoint.

  Example usage:

  ```python
  object_graph = tf.contrib.checkpoint.object_metadata(
      tf.train.latest_checkpoint(checkpoint_directory))
  ckpt_variable_names = set()
  for node in object_graph.nodes:
    for attribute in node.attributes:
      ckpt_variable_names.add(attribute.full_name)
  ```

  Args:
    save_path: The path to the checkpoint, as returned by `save` or
      `tf.train.latest_checkpoint`.
  Returns:
    A parsed `tf.contrib.checkpoint.CheckpointableObjectGraph` protocol buffer.
  Raises:
    ValueError: If the checkpoint has no serialized object graph (for example
      because it was written by a name-based saver).
  """
  reader = pywrap_tensorflow.NewCheckpointReader(save_path)
  graph_key = checkpointable_lib.OBJECT_GRAPH_PROTO_KEY
  try:
    serialized_graph = reader.get_tensor(graph_key)
  except errors_impl.NotFoundError:
    # Without the object graph key there is nothing to parse; report it as a
    # ValueError with an explanation instead of the reader's NotFoundError.
    message = (
        'The specified checkpoint "%s" does not appear to be object-based (it '
        'is missing the key "%s"). Likely it was created with a name-based '
        'saver and does not contain an object dependency graph.')
    raise ValueError(message % (save_path, graph_key))
  parsed_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
  parsed_graph.ParseFromString(serialized_graph)
  return parsed_graph
Beispiel #5
0
  def restore(self, save_path):
    """Restore a training checkpoint.

    Restores `root_checkpointable` and any objects that it tracks
    (transitive). Either assigns values immediately if variables to restore have
    been created already, or defers restoration until the variables are
    created. Dependencies added to the `root_checkpointable` passed to the
    constructor after this call will be matched if they have a corresponding
    object in the checkpoint.

    When building a graph, restorations are added to the graph but not run.

    To disallow deferred loading, assert immediately that all checkpointed
    variables have been matched to variable objects:

    ```python
    saver = Saver(root)
    saver.restore(path).assert_consumed()
    ```

    An exception will be raised unless every object was matched and its
    variables already exist.

    When graph building, `assert_consumed()` indicates that all of the restore
    ops which will be created for this checkpoint have been created. They can be
    run via the `run_restore_ops()` function of the status object:

    ```python
    saver.restore(path).assert_consumed().run_restore_ops()
    ```

    If the checkpoint has not been consumed completely, then the list of restore
    ops will grow as more objects are added to the dependency graph.

    Name-based `tf.train.Saver` checkpoints can be loaded using this
    method. There is no deferred loading, and names are used to match
    variables. No restore ops are created/run until `run_restore_ops()` or
    `initialize_or_restore()` are called on the returned status object, even
    when executing eagerly. Re-encode name-based checkpoints using this
    object-based `Saver.save` as soon as possible.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), returns an
        object which may run initializers for objects in the dependency
        graph. If the checkpoint was written by the name-based `tf.train.Saver`,
        names are used to match variables.

    Returns:
      A load status object, which can be used to make assertions about the
      status of checkpoint restoration and run initialization/restore ops
      (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if
      `save_path` is `None`).

      If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
      object is returned which runs restore ops from a name-based saver.

    Raises:
      NotImplementedError: When graph building, if this Saver has already
        restored a different object graph.
    """
    # No checkpoint to read: hand back a status object that only deals with
    # initialization.
    if save_path is None:
      return InitializationOnlyStatus(self._root_checkpointable, ops.uid())
    in_graph_mode = not context.executing_eagerly()
    if in_graph_mode:
      # Graph mode: one file-prefix tensor is created lazily and cached on the
      # Saver. Its dummy value ("model") is replaced at run time by feeding
      # `save_path` through `file_prefix_feed_dict`, so ops built against it
      # can be reused for other paths.
      if self._file_prefix_placeholder is None:
        with ops.device("/cpu:0"):
          self._file_prefix_placeholder = constant_op.constant("model")
      file_prefix_tensor = self._file_prefix_placeholder
      file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
    else:
      # Eager mode: the path is baked directly into a constant; nothing to feed.
      with ops.device("/cpu:0"):
        file_prefix_tensor = constant_op.constant(save_path)
      file_prefix_feed_dict = None
    # Read the checkpoint from Python to fetch the serialized object graph.
    reader = pywrap_tensorflow.NewCheckpointReader(save_path)
    try:
      object_graph_string = reader.get_tensor(
          checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError:
      # The object graph proto does not exist in this checkpoint. Try again with
      # name-based saving.
      return NameBasedSaverStatus(self, save_path)

    object_graph_proto = (
        checkpointable_object_graph_pb2.CheckpointableObjectGraph())
    object_graph_proto.ParseFromString(object_graph_string)
    # When graph building, restoring an object graph equal to the previous one
    # reuses the cached restore coordinator instead of building a new one.
    if in_graph_mode and object_graph_proto == self._last_restore_object_graph:
      checkpoint = self._last_restore_checkpoint
    else:
      # The dtype map is only read from the checkpoint when executing eagerly;
      # graph mode passes None.
      if in_graph_mode:
        dtype_map = None
      else:
        dtype_map = reader.get_variable_to_dtype_map()
      checkpoint = _CheckpointRestoreCoordinator(
          object_graph_proto=object_graph_proto,
          save_path=file_prefix_tensor,
          dtype_map=dtype_map)
      if in_graph_mode:
        # Only one distinct object graph is supported per Saver in graph mode;
        # the first one is cached so later restores of it hit the branch above.
        if self._last_restore_object_graph is not None:
          raise NotImplementedError(
              "Using a single Saver to restore different object graphs is not "
              "currently supported when graph building. Use a different Saver "
              "for each object graph (restore ops will be duplicated), or "
              "file a feature request if this limitation bothers you.")
        self._last_restore_checkpoint = checkpoint
        self._last_restore_object_graph = object_graph_proto
    # Match the checkpoint's node 0 against the root object; per the docstring,
    # restoration then extends to the objects it (transitively) tracks.
    checkpointable_lib._CheckpointPosition(  # pylint: disable=protected-access
        checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable)
    load_status = CheckpointLoadStatus(
        checkpoint,
        root_checkpointable=self._root_checkpointable,
        feed_dict=file_prefix_feed_dict)
    return load_status
Beispiel #6
0
def dot_graph_from_checkpoint(save_path):
  r"""Visualizes an object-based checkpoint (from `tf.train.Checkpoint`).

  Useful for inspecting checkpoints and debugging loading issues.

  Example usage from Python (requires pydot):
  ```python
  import tensorflow as tf
  import pydot

  dot_string = tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt')
  parsed, = pydot.graph_from_dot_data(dot_string)
  parsed.write_svg('/tmp/tensorflow/visualized_checkpoint.svg')
  ```

  Example command line usage:
  ```sh
  python -c "import tensorflow as tf;\
    print(tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt'))"\
    | dot -Tsvg > /tmp/tensorflow/checkpoint_viz.svg
  ```

  Args:
    save_path: The checkpoint prefix, as returned by `tf.train.Checkpoint.save`
      or `tf.train.latest_checkpoint`.
  Returns:
    A graph in DOT format as a string.
  Raises:
    ValueError: If the checkpoint has no serialized object graph (for example
      because it was written by a name-based saver).
  """
  reader = pywrap_tensorflow.NewCheckpointReader(save_path)
  try:
    object_graph_string = reader.get_tensor(
        checkpointable.OBJECT_GRAPH_PROTO_KEY)
  except errors_impl.NotFoundError:
    raise ValueError(
        ('The specified checkpoint "%s" does not appear to be object-based (it '
         'is missing the key "%s"). Likely it was created with a name-based '
         'saver and does not contain an object dependency graph.') % (
             save_path, checkpointable.OBJECT_GRAPH_PROTO_KEY))
  shape_map = reader.get_variable_to_shape_map()
  dtype_map = reader.get_variable_to_dtype_map()
  object_graph = (
      checkpointable_object_graph_pb2.CheckpointableObjectGraph())
  object_graph.ParseFromString(object_graph_string)

  def _escape(name):
    # Escape backslashes before double quotes. Otherwise a name ending in "\"
    # turns the closing quote into an escaped quote and yields invalid DOT, and
    # sequences such as "\n" would be interpreted as label escapes.
    return name.replace('\\', '\\\\').replace('"', '\\"')

  # Ids of nodes referenced as slot variables; these are drawn differently
  # from ordinary variables below.
  slot_ids = set(
      slot_reference.slot_variable_node_id
      for node in object_graph.nodes
      for slot_reference in node.slot_variables)

  # Collect lines and join once at the end instead of repeatedly concatenating
  # onto a growing string (quadratic in the number of nodes).
  lines = ['digraph {']
  for node_id, node in enumerate(object_graph.nodes):
    if (len(node.attributes) == 1
        and node.attributes[0].name == checkpointable.VARIABLE_VALUE_KEY):
      # A node whose only attribute is the variable value is drawn as a
      # variable, annotated with its shape and dtype.
      if node_id in slot_ids:
        color = 'orange'
        tooltip_prefix = 'Slot variable'
      else:
        color = 'blue'
        tooltip_prefix = 'Variable'
      attribute = node.attributes[0]
      lines.append(('N_%d [shape=point label="" color=%s width=.25'
                    ' tooltip="%s %s shape=%s %s"]') % (
                        node_id,
                        color,
                        tooltip_prefix,
                        _escape(attribute.full_name),
                        shape_map[attribute.checkpoint_key],
                        dtype_map[attribute.checkpoint_key].name))
    elif node.slot_variables:
      # Nodes that own slot variables are drawn as optimizers.
      lines.append(('N_%d [shape=point label="" width=.25 color=red,'
                    'tooltip="Optimizer"]') % node_id)
    else:
      lines.append('N_%d [shape=point label="" width=.25]' % node_id)
    for reference in node.children:
      lines.append('N_%d -> N_%d [label="%s"]' % (
          node_id, reference.node_id, _escape(reference.local_name)))
    for slot_reference in node.slot_variables:
      # Dotted edges: owner -> slot variable (labelled with the slot name) and
      # original variable -> slot variable.
      lines.append('N_%d -> N_%d [label="%s" style=dotted]' % (
          node_id,
          slot_reference.slot_variable_node_id,
          _escape(slot_reference.slot_name)))
      lines.append('N_%d -> N_%d [style=dotted]' % (
          slot_reference.original_variable_node_id,
          slot_reference.slot_variable_node_id))
  lines.append('}')
  return '\n'.join(lines) + '\n'