def test_open_file_str_pathlib(tmp_path, pathtype):
    """
    Check ``open_path`` / ``save_to_pkl`` / ``load_from_pkl`` with paths built by ``pathtype``.

    :param tmp_path: Temporary directory in which the pickle files are created.
    :param pathtype: Callable turning a string into the path object under test
        (presumably ``str`` or ``pathlib.Path``, supplied by the test parametrization).
    """
    # Imported locally so the block does not rely on a module-level import.
    # ``pytest.warns(None)`` is deprecated since pytest 7 and removed in pytest 8;
    # ``warnings.catch_warnings(record=True)`` with an "always" filter records
    # every warning the same way without asserting that one is raised.
    import warnings

    # check that suffix isn't added because we used open_path first
    with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
    assert not record

    # test custom suffix
    with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
    assert not record

    # test without suffix
    with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
    assert not record

    # test that a warning is raised when the path doesn't exist:
    # only "t2.pkl" is written, then "t2" is opened with suffix="pkl"
    with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
        save_to_pkl(fp1, "foo")
    assert fp1.closed
    # default verbosity: no warning expected
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
    assert len(record) == 0

    # verbose=2: exactly one warning expected
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
    assert len(record) == 1

    # Create a file at the suffix-less path so that it exists for the write checks below.
    with pathlib.Path(f"{tmp_path}/t2").open("w") as fp:
        fp.write("rubbish")

    # Opening the existing path for writing: exactly one of the three calls
    # (verbose 0, 1 and 2) is expected to emit a warning.
    # NOTE(review): the original comment said the warning is raised for
    # verbose = 0, while the read case above warns at verbose=2 -- confirm
    # which level triggers it in ``open_path``.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close()
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close()
        open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close()
    assert len(record) == 1
def load_replay_buffer(
    self,
    path: Union[str, pathlib.Path, io.BufferedIOBase],
    truncate_last_traj: bool = True,
) -> None:
    """
    Restore ``self.replay_buffer`` from a pickled replay buffer.

    :param path: Location of the pickled replay buffer
        (a path or an already opened binary file object).
    :param truncate_last_traj: Only used when the loaded buffer is a ``HerReplayBuffer``
        (online sampling). ``True`` treats the last stored trajectory as finished
        and truncates it; ``False`` assumes the same trajectory (same episode)
        is being continued.
    """
    self.replay_buffer = load_from_pkl(path, self.verbose)
    buffer = self.replay_buffer
    assert isinstance(buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class"

    # Buffers saved with SB3 < 2.1.0 lack this attribute: keep the old
    # behavior, i.e. timeout termination is not handled separately.
    missing_timeout_flag = not hasattr(buffer, "handle_timeout_termination")
    if missing_timeout_flag:  # pragma: no cover
        buffer.handle_timeout_termination = False
        buffer.timeouts = np.zeros_like(buffer.dones)

    # Everything below only concerns HER buffers.
    if not isinstance(buffer, HerReplayBuffer):
        return

    assert self.env is not None, "You must pass an environment at load time when using `HerReplayBuffer`"
    buffer.set_env(self.get_env())
    if truncate_last_traj:
        buffer.truncate_last_trajectory()
def load_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
    """
    Restore ``self.replay_buffer`` from a pickled replay buffer.

    :param path: Location of the pickled replay buffer
        (a path or an already opened binary file object).
    """
    loaded_buffer = load_from_pkl(path, self.verbose)
    # Stored before the type check, so the attribute is set even if the check fails.
    self.replay_buffer = loaded_buffer
    assert isinstance(loaded_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class"