def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
    """
    Serialize the current replay buffer to ``path`` using pickle.

    :param path: Destination for the pickled replay buffer.
        When a str or pathlib.Path is given, the path is created automatically if necessary.
    """
    buffer = self.replay_buffer
    # NOTE: this is an ``assert``, so the check disappears under ``python -O``;
    # kept as-is so callers still see an AssertionError when no buffer is set.
    assert buffer is not None, "The replay buffer is not defined"
    save_to_pkl(path, buffer, self.verbose)
def test_open_file_str_pathlib(tmp_path, pathtype):
    """
    Exercise ``open_path`` / ``save_to_pkl`` / ``load_from_pkl`` with str and pathlib paths.

    Covers explicit file names, custom extensions, the ``suffix`` argument on
    write and on read, and which ``verbose`` levels emit warnings.
    """
    # Local imports: the file-level import block is not visible from here.
    import contextlib
    import warnings

    @contextlib.contextmanager
    def record_warnings():
        # Replacement for ``pytest.warns(None)`` (deprecated in pytest 7, removed in
        # pytest 8): record every warning raised inside the ``with`` body. The
        # "always" filter prevents repeated warnings from being de-duplicated.
        with warnings.catch_warnings(record=True) as record:
            warnings.simplefilter("always")
            yield record

    # check that suffix isn't added because we used open_path first
    with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with record_warnings() as record:
        assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
    assert not record

    # test custom suffix
    with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with record_warnings() as record:
        assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
    assert not record

    # test that the given suffix is appended when the path has none
    with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with record_warnings() as record:
        assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
    assert not record

    # test reading a path that doesn't exist: only t2.pkl is written, so reading
    # t2 with suffix="pkl" must fall back to t2.pkl. No warning by default,
    # exactly one warning with verbose=2.
    with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with record_warnings() as record:
        assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
    assert len(record) == 0
    with record_warnings() as record:
        assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
    assert len(record) == 1

    # Put an unrelated file at the un-suffixed path; writing with suffix="pkl"
    # targets t2.pkl and must leave this file alone.
    rubbish_path = pathlib.Path(f"{tmp_path}/t2")
    rubbish_path.write_text("rubbish")

    # test that a warning is only raised when verbose = 2 (t2.pkl already exists)
    with record_warnings() as record:
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close()
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close()
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close()
    assert len(record) == 1
    assert rubbish_path.read_text() == "rubbish"