# NOTE(review): the statements down to the first `def` use
# `checkpoint_directory` and `self` without binding them, so they are the tail
# of a definition whose header precedes this chunk. They are reproduced
# unchanged and only re-flowed onto separate lines. The original indentation
# was lost, so nesting relative to any enclosing class is not reconstructed
# here -- confirm against the full file.
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
with context.graph_mode():
    save_graph = ops.Graph()
    with save_graph.as_default(), self.test_session(
            graph=save_graph) as session:
        root = self._initialized_model()
        save_path = root.save(
            session=session, file_prefix=checkpoint_prefix)
with context.eager_mode():
    root = self._initialized_model()
    self._set_sentinels(root)
    root.restore(save_path).assert_consumed()
    self._check_sentinels(root)


def testSaveEagerLoadGraph(self):
    """Save a checkpoint in eager mode, then restore it in graph mode.

    The model is built and saved under `context.eager_mode()`. A second
    model is then built inside a new `ops.Graph`, sentinels are set on it,
    the checkpoint is restored (asserting it was fully consumed and running
    the restore ops), and the sentinels are checked.
    """
    prefix = os.path.join(self.get_temp_dir(), "ckpt")

    # Eager side: build the model and write the checkpoint.
    with context.eager_mode():
        eager_model = self._initialized_model()
        saved_to = eager_model.save(file_prefix=prefix)

    # Graph side: rebuild the model in its own graph and load the checkpoint.
    with context.graph_mode():
        load_graph = ops.Graph()
        with load_graph.as_default(), self.test_session(graph=load_graph):
            graph_model = self._initialized_model()
            self._set_sentinels(graph_model)
            status = graph_model.restore(saved_to)
            # In graph mode the restore ops have to be run explicitly.
            status.assert_consumed().run_restore_ops()
            self._check_sentinels(graph_model)


if __name__ == "__main__":
    test.main()
# NOTE(review): `_restore_from_tensors` takes `self`, and the statements after
# it use `self`, `NewTrackable` and `DeprecatedTrackable` without defining
# them, so the enclosing class/test-method headers precede this chunk. The
# original indentation was lost; nesting is not reconstructed here -- confirm
# against the full file.
def _restore_from_tensors(self, restored_tensors):
    """Assign the restored "-a" and "-b" entries to `self.a` and `self.b`.

    Args:
      restored_tensors: Mapping indexed by the keys "-a" and "-b".
    """
    # NOTE(review): the values returned by `assign` are discarded and this
    # method returns None -- confirm no caller needs the assign ops.
    for attr_name, tensor_key in (("a", "-a"), ("b", "-b")):
        getattr(self, attr_name).assign(restored_tensors[tensor_key])


new = NewTrackable()

# Test with the checkpoint conversion flag disabled (normal compatibility).
saveable_compat.force_checkpoint_conversion(False)
checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt")
checkpoint.Checkpoint(new).write(checkpoint_path)

dep = DeprecatedTrackable()
checkpoint.Checkpoint(dep).read(checkpoint_path).assert_consumed()
self.assertEqual(3, self.evaluate(dep.a))
self.assertEqual(4, self.evaluate(dep.b))

# Now test with the checkpoint conversion flag enabled (forward compat).
# The deprecated object will try to load from the new checkpoint.
saveable_compat.force_checkpoint_conversion()
checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt2")
checkpoint.Checkpoint(new).write(checkpoint_path)

dep = DeprecatedTrackable()
checkpoint.Checkpoint(dep).read(checkpoint_path).assert_consumed()
self.assertEqual(3, self.evaluate(dep.a))
self.assertEqual(4, self.evaluate(dep.b))


if __name__ == "__main__":
    test.main()