def load_buffer(buffer_path: str) -> ReplayBuffer:
    """Create a replay buffer from the HDF5 file at ``buffer_path``.

    The file is opened read-only and its ``observations``, ``actions``,
    ``rewards``, ``terminals`` and ``next_observations`` datasets are passed
    to ``ReplayBuffer.from_data`` as ``obs``, ``act``, ``rew``, ``done`` and
    ``obs_next`` respectively.

    :param buffer_path: path of the HDF5 file to read.
    :return: the buffer produced by ``ReplayBuffer.from_data``.
    """
    # Buffer field name -> dataset name inside the file. Insertion order is
    # kept so the datasets are looked up in the same sequence as before.
    dataset_names = {
        "obs": "observations",
        "act": "actions",
        "rew": "rewards",
        "done": "terminals",
        "obs_next": "next_observations",
    }
    with h5py.File(buffer_path, "r") as h5_file:
        fields = {
            field: h5_file[name] for field, name in dataset_names.items()
        }
        return ReplayBuffer.from_data(**fields)
def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
    """Create a replay buffer from the D4RL dataset of a gym task.

    The environment named ``expert_data_task`` is instantiated with
    ``gym.make`` and handed to ``d4rl.qlearning_dataset``; the resulting
    ``observations``, ``actions``, ``rewards``, ``terminals`` and
    ``next_observations`` entries are forwarded to ``ReplayBuffer.from_data``.

    :param expert_data_task: environment id understood by ``gym.make``.
    :return: the buffer produced by ``ReplayBuffer.from_data``.
    """
    env = gym.make(expert_data_task)
    transitions = d4rl.qlearning_dataset(env)
    # (buffer field, dataset key) pairs, in the original lookup order.
    field_keys = (
        ("obs", "observations"),
        ("act", "actions"),
        ("rew", "rewards"),
        ("done", "terminals"),
        ("obs_next", "next_observations"),
    )
    return ReplayBuffer.from_data(
        **{field: transitions[key] for field, key in field_keys}
    )
def test_from_data():
    """Check that ``ReplayBuffer.from_data`` ingests HDF5 datasets.

    Ten transitions are written to a temporary HDF5 file: ``obs[i]`` is a
    3x3 ``uint8`` array filled with ``i``, ``obs_next`` is ``obs`` shifted by
    one step (the last entry stays zero), ``act`` and ``rew`` count from 0 to
    9 and ``done`` is all ``False``. The buffer built from these datasets
    must report length 10 and item 3 must carry the values written at
    index 3.
    """
    n_steps = 10
    obs_data = np.empty((n_steps, 3, 3), dtype="uint8")
    for i in range(n_steps):
        obs_data[i] = i * np.ones((3, 3), dtype="uint8")
    obs_next_data = np.zeros_like(obs_data)
    obs_next_data[:-1] = obs_data[1:]

    # mkstemp hands back an open OS-level descriptor; close it immediately
    # because h5py opens the path on its own.
    fd, path = tempfile.mkstemp(suffix='.hdf5')
    os.close(fd)
    try:
        with h5py.File(path, "w") as h5_file:
            obs = h5_file.create_dataset("obs", data=obs_data)
            act = h5_file.create_dataset(
                "act", data=np.arange(n_steps, dtype="int32")
            )
            rew = h5_file.create_dataset(
                "rew", data=np.arange(n_steps, dtype="float32")
            )
            done = h5_file.create_dataset(
                "done", data=np.zeros(n_steps, dtype="bool")
            )
            obs_next = h5_file.create_dataset("obs_next", data=obs_next_data)
            buf = ReplayBuffer.from_data(obs, act, rew, done, obs_next)

        # NOTE(review): the checks run after the file is closed, which
        # presumes the buffer no longer depends on the open datasets --
        # confirm against ReplayBuffer.from_data.
        assert len(buf) == n_steps
        batch = buf[3]
        assert np.array_equal(batch.obs, 3 * np.ones((3, 3), dtype="uint8"))
        assert batch.act == 3
        assert batch.rew == 3.0
        assert not batch.done
        assert np.array_equal(
            batch.obs_next, 4 * np.ones((3, 3), dtype="uint8")
        )
    finally:
        # Remove the temporary file even when an assertion above fails, so
        # failing runs do not leave stray .hdf5 files behind.
        os.remove(path)