def __init__(self, name, state_callback, restore_callback=None):
  """Sets up a saveable whose checkpointed value is a Python string.

  Args:
    name: Checkpoint key the string is written under.
    state_callback: Zero-argument callable returning a string. It is called
      every time a checkpoint is written.
    restore_callback: Optional callable that accepts a Python string and is
      used to restore state. When omitted, restoring does nothing.
  """
  self._restore_callback = restore_callback
  if not context.executing_eagerly():
    # Graph mode: the saved tensor is an empty string constant, and
    # `feed_dict_additions` maps that tensor to the callback's current value.
    self._save_string = constant_op.constant("", dtype=dtypes.string)
    self.feed_dict_additions = (
        lambda: {self._save_string: state_callback()})
  else:
    # Eager mode: the saved "tensor" is a callable that builds a fresh string
    # tensor from the callback each time it is invoked.
    def _current_state_tensor():
      return constant_op.constant(state_callback(), dtype=dtypes.string)
    self._save_string = _current_state_tensor
  spec = saveable_object.SaveSpec(
      self._save_string, "", name, dtype=dtypes.string)
  super(PythonStringStateSaveable, self).__init__(
      self._save_string, [spec], name)
def __init__(self, name, state_callback, restore_callback=None):
  """Sets up a saveable whose checkpointed value is a Python string.

  Args:
    name: Checkpoint key the string is written under.
    state_callback: Zero-argument callable returning a string. It is called
      every time a checkpoint is written.
    restore_callback: Optional callable that accepts a Python string and is
      used to restore state. When omitted, restoring does nothing.
  """
  # Keep both callbacks so other methods can use them.
  self._state_callback = state_callback
  self._restore_callback = restore_callback
  # Only the empty string constant is created under the CPU device scope;
  # the SaveSpec below is built outside it.
  with ops.device("/cpu:0"):
    self._save_string = constant_op.constant("", dtype=dtypes.string)
  super(PythonStringStateSaveable, self).__init__(
      self._save_string,
      [saveable_object.SaveSpec(
          self._save_string, "", name, dtype=dtypes.string)],
      name)
def __init__(self, tensor, name, dtype=None):
  """Wraps `tensor` in a single SaveSpec keyed by `name`.

  Args:
    tensor: Value to save; passed to both the SaveSpec and the base class.
    name: Checkpoint key for the SaveSpec, also passed to the base-class
      constructor.
    dtype: Optional dtype forwarded to the SaveSpec.
  """
  specs = [saveable_object.SaveSpec(tensor, "", name, dtype=dtype)]
  super(NoRestoreSaveable, self).__init__(tensor, specs, name)