Example #1
0
    def __init__(self, name, state_callback, restore_callback=None):
        """Configure checkpointing of a Python string value.

        Args:
          name: Checkpoint key the string is written under.
          state_callback: Zero-argument callable returning the string to
            save. In eager mode it is invoked whenever the `_save_string`
            callable is called; in graph mode it is invoked whenever
            `feed_dict_additions` is called.
          restore_callback: Optional callable taking a Python string, kept
            on the instance for restoring state. Defaults to None.
        """
        self._restore_callback = restore_callback
        if not context.executing_eagerly():
            # Graph mode: an empty string constant serves as the save tensor.
            # `feed_dict_additions` maps it to the callback's current value
            # (presumably merged into the session feed at save time — confirm
            # against the saver that consumes it).
            self._save_string = constant_op.constant("", dtype=dtypes.string)
            self.feed_dict_additions = (
                lambda: {self._save_string: state_callback()})
        else:
            # Eager mode: defer calling state_callback until the save tensor
            # is requested, building a fresh constant on every call.
            def _make_save_string():
                return constant_op.constant(
                    state_callback(), dtype=dtypes.string)

            self._save_string = _make_save_string
        save_spec = saveable_object.SaveSpec(
            self._save_string, "", name, dtype=dtypes.string)
        super(PythonStringStateSaveable, self).__init__(
            self._save_string, [save_spec], name)
Example #2
0
    def __init__(self, name, state_callback, restore_callback=None):
        """Configure checkpointing of a Python string value.

        Args:
          name: Checkpoint key the string is written under.
          state_callback: Zero-argument callable returning the string to
            save; stored on the instance as `_state_callback`.
          restore_callback: Optional callable taking a Python string, kept
            on the instance for restoring state. Defaults to None.
        """
        self._restore_callback = restore_callback
        self._state_callback = state_callback
        # Pin the empty string constant to the CPU device.
        # NOTE(review): presumably because string tensors are host-only;
        # confirm.
        with ops.device("/cpu:0"):
            empty_string = constant_op.constant("", dtype=dtypes.string)
        self._save_string = empty_string
        save_spec = saveable_object.SaveSpec(
            self._save_string, "", name, dtype=dtypes.string)
        super(PythonStringStateSaveable, self).__init__(
            self._save_string, [save_spec], name)
Example #3
0
 def __init__(self, tensor, name, dtype=None):
     """Wrap `tensor` as a saveable with a single SaveSpec keyed by `name`.

     Args:
       tensor: Value passed both as the SaveSpec tensor and as the saveable's op.
       name: Checkpoint key for the single spec (slice spec is "").
       dtype: Optional dtype forwarded to the SaveSpec; defaults to None.
     """
     single_spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype)
     specs = [single_spec]
     super(NoRestoreSaveable, self).__init__(tensor, specs, name)