Example #1
def _serialize_checkpointables(checkpointable_objects, node_ids, object_names,
                               slot_variables):
    """Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
    object_graph_proto = (
        checkpointable_object_graph_pb2.CheckpointableObjectGraph())
    named_saveables = {}

    for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
        assert node_ids[checkpointable] == checkpoint_id
        object_proto = object_graph_proto.nodes.add()
        object_proto.slot_variables.extend(
            slot_variables.get(checkpointable, ()))
        object_name = object_names[checkpointable]
        for name, saveable in (
                checkpointable._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
            attribute = object_proto.attributes.add()
            attribute.name = name
            attribute.checkpoint_key = "%s/%s/%s" % (
                object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
            # Figure out the name-based Saver's name for this variable.
            saver_dict = saver_lib.BaseSaverBuilder.OpListToDict(
                [saveable], convert_variable_to_tensor=False)
            attribute.full_name, = saver_dict.keys()
            named_saveables[attribute.checkpoint_key] = saveable

        for child in checkpointable._checkpoint_dependencies:  # pylint: disable=protected-access
            child_proto = object_proto.children.add()
            child_proto.node_id = node_ids[child.ref]
            child_proto.local_name = child.name

    return named_saveables, object_graph_proto
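The `checkpoint_key` assembled above follows a fixed three-part scheme. A minimal sketch of that format, assuming `".ATTRIBUTES"` as the value of the `_OBJECT_ATTRIBUTES_NAME` separator (its value in the TensorFlow source); the object and attribute names below are hypothetical:

```python
# Sketch of the checkpoint_key format built in _serialize_checkpointables.
# ".ATTRIBUTES" stands in for _OBJECT_ATTRIBUTES_NAME; "dense" and
# "VARIABLE_VALUE" are hypothetical object/attribute names.
object_name = "dense"              # path of the object from the checkpoint root
attribute_name = "VARIABLE_VALUE"  # local name of the saveable attribute
checkpoint_key = "%s/%s/%s" % (object_name, ".ATTRIBUTES", attribute_name)
print(checkpoint_key)  # dense/.ATTRIBUTES/VARIABLE_VALUE
```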
Example #2
  def testTrainSpinn(self):
    """Test with fake toy SNLI data and GloVe vectors."""

    # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = self._create_test_data(snli_1_0_dir)

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    train_data = data.SnliData(fake_train_file, word2index)
    dev_data = data.SnliData(fake_train_file, word2index)
    test_data = data.SnliData(fake_train_file, word2index)

    # 2. Create a fake config.
    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=os.path.join(self._temp_data_dir, "logdir"))

    # 3. Test training of a SPINN model.
    trainer = spinn.train_or_infer_spinn(
        embed, word2index, train_data, dev_data, test_data, config)

    # 4. Load train loss values from the summary files and verify that one
    #    loss value is recorded per training epoch.
    summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0]
    events = summary_test_util.events_from_file(summary_file)
    train_losses = [event.summary.value[0].simple_value for event in events
                    if event.summary.value
                    and event.summary.value[0].tag == "train/loss"]
    self.assertEqual(config.epochs, len(train_losses))

    # 5. Verify that checkpoints exist and contain all the expected variables.
    self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
    object_graph_string = checkpoint_utils.load_variable(
        config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH")
    object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
    object_graph.ParseFromString(object_graph_string)
    ckpt_variable_names = set()
    for node in object_graph.nodes:
      for attribute in node.attributes:
        ckpt_variable_names.add(attribute.full_name)
    self.assertIn("global_step", ckpt_variable_names)
    for v in trainer.variables:
      variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
      self.assertIn(variable_name, ckpt_variable_names)
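Step 5 of the test can be reproduced as a standalone script. A sketch that lists every variable recorded in an object-based checkpoint, assuming a TF 1.x-era install where the proto is importable from `tensorflow.core.protobuf` (it was later renamed `trackable_object_graph_pb2`); `logdir` is a hypothetical directory:

```python
import tensorflow as tf
from tensorflow.core.protobuf import checkpointable_object_graph_pb2

logdir = "/tmp/logdir"  # hypothetical directory holding an object-based checkpoint
# The serialized object graph is stored like any other checkpoint entry.
object_graph_string = tf.train.load_variable(
    logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH")
object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
object_graph.ParseFromString(object_graph_string)
for node in object_graph.nodes:
    for attribute in node.attributes:
        print(attribute.checkpoint_key, "->", attribute.full_name)
```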
Example #3
def _serialize_object_graph(root_checkpointable):
    """Determine checkpoint keys for variables and build a serialized graph.

    Non-slot variables are keyed based on a shortest path from the root saveable
    to the object which owns the variable (i.e. the one which called
    `Checkpointable.add_variable` to create it).

    Slot variables are keyed based on a shortest path to the variable being
    slotted for, a shortest path to their optimizer, and the slot name.

    Args:
      root_checkpointable: A `Checkpointable` object whose variables (including
        the variables of dependencies, recursively) should be saved.

    Returns:
      A tuple of (named_variables, object_graph_proto):
        named_variables: A dictionary mapping names to variable objects.
        object_graph_proto: A CheckpointableObjectGraph protocol buffer
          containing the serialized object graph and variable references.

    Raises:
      ValueError: If there are invalid characters in an optimizer's slot names.
    """
    checkpointable_objects, path_to_root = (
        _breadth_first_checkpointable_traversal(root_checkpointable))
    object_graph_proto = (
        checkpointable_object_graph_pb2.CheckpointableObjectGraph())

    # Gather non-slot variables.
    named_variables, non_slot_variables = _serialize_non_slot_variables(
        checkpointable_objects, path_to_root, object_graph_proto)

    # Gather slot variables which are associated with variables gathered above.
    named_slot_variables = _serialize_slot_variables(checkpointable_objects,
                                                     path_to_root,
                                                     non_slot_variables,
                                                     object_graph_proto)

    named_variables.update(named_slot_variables)
    return named_variables, object_graph_proto
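The shortest-path naming rule in the docstring is easiest to see from the public successor API. A sketch using `tf.train.Checkpoint` (available from TF 1.11), which applies the same keying scheme; the directory and the exact key list in the closing comment are assumptions:

```python
import tensorflow as tf
tf.enable_eager_execution()  # TF 1.x; keys are easiest to inspect eagerly

leaf = tf.train.Checkpoint(v=tf.Variable(1.0))
root = tf.train.Checkpoint(leaf=leaf)
prefix = root.save("/tmp/toy_ckpt/ckpt")  # hypothetical directory

# Each variable is keyed by its shortest path from the root object.
for name, shape in tf.train.list_variables(prefix):
    print(name, shape)
# Expect an entry like "leaf/v/.ATTRIBUTES/VARIABLE_VALUE", alongside the
# serialized object graph and the saver's save_counter.
```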
Example #4
    def restore(self, save_path, session=None):
        """Restore a training checkpoint.

        Restores `root_checkpointable` and any objects that it tracks
        (transitively). Either assigns values immediately if the variables to
        restore have already been created, or defers restoration until they
        are created. Dependencies added to the `root_checkpointable` passed to
        the constructor after this call will be matched if they have a
        corresponding object in the checkpoint.

        When building a graph, restorations are added to the graph but not
        run. A session is required to retrieve checkpoint metadata.

        To disallow deferred loading, assert immediately that all checkpointed
        variables have been matched to variable objects:

        ```python
        saver = Saver(root)
        saver.restore(path).assert_consumed()
        ```

        An exception will be raised unless every object was matched and its
        variables already exist.

        When graph building, `assert_consumed()` indicates that all of the
        restore ops which will be created for this checkpoint have been
        created. They can be run via the `run_restore_ops()` method of the
        status object:

        ```python
        saver.restore(path).assert_consumed().run_restore_ops()
        ```

        If the checkpoint has not been consumed completely, then the list of
        restore ops will grow as more objects are added to the dependency
        graph.

        Name-based `tf.train.Saver` checkpoints can be loaded using this
        method. There is no deferred loading, and names are used to match
        variables. No restore ops are created or run until `run_restore_ops()`
        or `initialize_or_restore()` is called on the returned status object,
        even when executing eagerly. Re-encode name-based checkpoints using
        this object-based `Saver.save` as soon as possible.

        Args:
          save_path: The path to the checkpoint, as returned by `save` or
            `tf.train.latest_checkpoint`. If None (as when there is no latest
            checkpoint for `tf.train.latest_checkpoint` to return), returns an
            object which may run initializers for objects in the dependency
            graph. If the checkpoint was written by the name-based
            `tf.train.Saver`, names are used to match variables.
          session: The session to retrieve metadata with. Ignored when
            executing eagerly. If not provided when graph building, the
            default session is used.

        Returns:
          A load status object, which can be used to make assertions about the
          status of checkpoint restoration and run initialization/restore ops
          (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if
          `save_path` is `None`).

          If `save_path` points to a name-based checkpoint, a
          `NameBasedSaverStatus` object is returned which runs restore ops
          from a name-based saver.
        """
        if save_path is None:
            return InitializationOnlyStatus(self._root_checkpointable)
        in_graph_mode = context.in_graph_mode()
        if in_graph_mode:
            if session is None:
                session = ops.get_default_session()
            file_prefix_tensor = self._file_prefix_placeholder
            file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
        else:
            session = None
            file_prefix_tensor = constant_op.constant(save_path)
            file_prefix_feed_dict = None
        try:
            if not in_graph_mode or self._object_graph_restore_tensor is None:
                object_graph_string, = io_ops.restore_v2(
                    prefix=file_prefix_tensor,
                    tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
                    shape_and_slices=[""],
                    dtypes=[dtypes.string],
                    name="object_graph_proto_read")
                if in_graph_mode:
                    self._object_graph_restore_tensor = object_graph_string
            if in_graph_mode:
                object_graph_string = session.run(
                    self._object_graph_restore_tensor,
                    feed_dict=file_prefix_feed_dict)
            else:
                object_graph_string = object_graph_string.numpy()
        except errors_impl.NotFoundError:
            # The object graph proto does not exist in this checkpoint. Try again with
            # name-based saving.
            return NameBasedSaverStatus(self, save_path)

        object_graph_proto = (
            checkpointable_object_graph_pb2.CheckpointableObjectGraph())
        object_graph_proto.ParseFromString(object_graph_string)
        if in_graph_mode and object_graph_proto == self._last_restore_object_graph:
            checkpoint = self._last_restore_checkpoint
        else:
            if in_graph_mode:
                dtype_map = None
            else:
                reader = pywrap_tensorflow.NewCheckpointReader(save_path)
                dtype_map = reader.get_variable_to_dtype_map()
            checkpoint = core_checkpointable_utils._Checkpoint(  # pylint: disable=protected-access
                object_graph_proto=object_graph_proto,
                save_path=file_prefix_tensor,
                dtype_map=dtype_map)
            if in_graph_mode:
                if self._last_restore_object_graph is not None:
                    raise NotImplementedError(
                        "Using a single Saver to restore different object graphs is not "
                        "currently supported when graph building. Use a different Saver "
                        "for each object graph (restore ops will be duplicated), or "
                        "file a feature request if this limitation bothers you."
                    )
                self._last_restore_checkpoint = checkpoint
                self._last_restore_object_graph = object_graph_proto
        core_checkpointable._CheckpointPosition(  # pylint: disable=protected-access
            checkpoint=checkpoint,
            proto_id=0).restore(self._root_checkpointable)
        load_status = CheckpointLoadStatus(checkpoint,
                                           feed_dict=file_prefix_feed_dict)
        return load_status
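The eager and graph-mode patterns documented above look like this in use. A sketch written against the public `tf.train.Checkpoint` API, whose load status exposes the same `assert_consumed()` and `run_restore_ops()` methods; the paths are hypothetical:

```python
import tensorflow as tf
tf.enable_eager_execution()

root = tf.train.Checkpoint(v=tf.Variable(2.0))
save_path = root.save("/tmp/obj_ckpt/ckpt")  # hypothetical prefix

root.v.assign(0.0)
# Eager: values are assigned as soon as restore() finds matching variables.
root.restore(save_path).assert_consumed()
assert root.v.numpy() == 2.0

# Graph building (for contrast): restore() only adds ops to the graph, so the
# status object's restore ops must be run explicitly, e.g.
#   status = root.restore(save_path)
#   status.assert_consumed().run_restore_ops(session)
```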
Example #5
def restore(save_path, root_checkpointable, session=None):
    """Restore a training checkpoint.

    Restores the values of variables created with `Checkpointable._add_variable`
    in `root_checkpointable` and any objects that it tracks (transitively).
    Either assigns values immediately if the variables to restore have already
    been created, or defers restoration until they are created. Dependencies
    added to `root_checkpointable` after this call will be matched if they have
    a corresponding object in the checkpoint.

    When building a graph, restorations are added to the graph but not run. A
    session is required to retrieve checkpoint metadata.

    To disallow deferred loading, assert immediately that all checkpointed
    variables have been matched to variable objects:

    ```python
    restore(path, root).assert_consumed()
    ```

    An exception will be raised unless every object was matched and its
    variables already exist.

    When graph building, `assert_consumed()` indicates that all of the restore
    ops which will be created for this checkpoint have been created. They are
    available in the `restore_ops` property of the status object:

    ```python
    session.run(restore(path, root).assert_consumed().restore_ops)
    ```

    If the checkpoint has not been consumed completely, then the list of
    `restore_ops` will grow as more objects are added to the dependency graph.

    Args:
      save_path: The path to the checkpoint, as returned by `save` or
        `tf.train.latest_checkpoint`. If None (as when there is no latest
        checkpoint for `tf.train.latest_checkpoint` to return), does nothing.
      root_checkpointable: The root of the object graph to restore. Variables to
        restore need not have been created yet, but all dependencies on other
        `Checkpointable` objects should already be declared. Objects in the
        dependency graph are matched to objects in the checkpointed graph, and
        matching objects have their variables restored (or the checkpointed
        values saved for eventual restoration when the variable is created).
      session: The session to retrieve metadata with. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      A `CheckpointLoadStatus` object, which can be used to make assertions
      about the status of checkpoint restoration and fetch restore ops.
    """
    if save_path is None:
        return
    if context.in_graph_mode():
        if session is None:
            session = ops.get_default_session()
    else:
        session = None
    object_graph_string, = io_ops.restore_v2(
        prefix=save_path,
        tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
        shape_and_slices=[""],
        dtypes=[dtypes.string],
        name="object_graph_proto_read")
    if session is not None:
        object_graph_string = session.run(object_graph_string)
    else:
        object_graph_string = object_graph_string.numpy()
    object_graph_proto = (
        checkpointable_object_graph_pb2.CheckpointableObjectGraph())
    object_graph_proto.ParseFromString(object_graph_string)
    checkpoint = core_checkpointable._Checkpoint(  # pylint: disable=protected-access
        object_graph_proto=object_graph_proto,
        save_path=save_path)
    core_checkpointable._CheckpointPosition(  # pylint: disable=protected-access
        checkpoint=checkpoint, proto_id=0).restore(root_checkpointable)
    load_status = CheckpointLoadStatus(checkpoint)
    return load_status
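A graph-mode usage sketch for this function, following its docstring. The import path is an assumption (this contrib-era module moved between releases), and `build_model()` is a hypothetical helper returning a `Checkpointable` root whose dependencies are already declared:

```python
import tensorflow as tf
from tensorflow.contrib.eager.python import checkpointable_utils  # assumed path

root = build_model()  # hypothetical Checkpointable root
save_path = tf.train.latest_checkpoint("/tmp/logdir")  # hypothetical directory
with tf.Session() as session:
    status = checkpointable_utils.restore(save_path, root, session=session)
    if status is not None:  # restore() returns None when save_path is None
        status.assert_consumed()         # every checkpointed object matched
        session.run(status.restore_ops)  # run the gathered restore ops
```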