def test_replay_buffer_dump():
    """Check that ``dump_buffer`` returns every stored experience with its core fields.

    Eight samples are added to a buffer created with ``buffer_size=10``; the
    dump is then expected to contain exactly eight entries, and the first
    entry must expose the ``state``, ``action``, ``reward`` and ``next_state``
    keys.
    """
    import torch

    # Assign
    filled_buffer = 8
    prop_keys = ["state", "action", "reward", "next_state"]
    buffer = ReplayBuffer(batch_size=5, buffer_size=10)
    for sars in generate_sample_SARS(filled_buffer):
        buffer.add(
            state=torch.tensor(sars[0]),
            reward=sars[1],
            action=[sars[2]],
            next_state=torch.tensor(sars[3]),
            dones=sars[4],
        )

    # Act
    dump = list(buffer.dump_buffer())

    # Assert
    # A single comparison needs no all([...]) wrapper.
    assert len(dump) == filled_buffer
    # Collect the absent keys so a failure names them instead of a bare False.
    missing = [key for key in prop_keys if key not in dump[0]]
    assert not missing, f"Dumped experience is missing keys: {missing}"
def test_replay_buffer_dump_serializable():
    """Check that a serialized buffer dump survives a JSON round trip.

    Eight dict-style samples are stored with their ``state`` and
    ``next_state`` entries converted to tensors. The dump requested with
    ``serialize=True`` must encode to a JSON string and decode back to a
    value equal to the dump itself.
    """
    import json
    import torch

    # Assign
    filled_buffer = 8
    tensor_fields = ("state", "next_state")
    buffer = ReplayBuffer(batch_size=5, buffer_size=10)
    for sample in generate_sample_SARS(filled_buffer, dict_type=True):
        # Only the observation fields are turned into tensors before storing.
        for field in tensor_fields:
            sample[field] = torch.tensor(sample[field])
        buffer.add(**sample)

    # Act
    dump = list(buffer.dump_buffer(serialize=True))

    # Assert
    encoded = json.dumps(dump)
    assert isinstance(encoded, str)
    decoded = json.loads(encoded)
    assert decoded == dump
# NOTE(review): this span starts inside an enclosing function whose header is
# not visible here; the trailing `return callback` belongs to it. Presumably
# that function is `buffer_callback(buffer)`, since the script below calls it
# with the buffer -- confirm against the full file.
#
# callback(obs_t, obs_next, action, rew, done, *args, **kwargs):
#   Stores one transition in the closed-over `buffer` via `buffer.add`, passing
#   `obs_t` as state and `obs_next` as next_state, and wrapping action, reward
#   and done in single-element lists. Returns `[rew]`, which is the value
#   handed back to the caller (here a PlayPlot configured with ["reward"]).
#
# Script part, in order:
#   1. Creates `ReplayBuffer(10, 2000)` -- presumably batch_size=10 and
#      buffer_size=2000, matching the keyword order used elsewhere; TODO confirm.
#   2. Builds the callback, wraps it in `PlayPlot(callback, 30 * 5, ["reward"])`.
#   3. Creates and resets the "Breakout-v0" gym environment and runs
#      `play(env, fps=20, callback=plotter.callback)`.
#   4. Calls `buffer.dump_buffer(serialize=True)`, then writes each dumped
#      experience as one JSON document per line to 'buffer.gzip', opened with
#      gzip in text-write mode ('wt').
#   5. Prints the elapsed wall-clock time between the two `time.time()` samples.
#
# NOTE(review): the first timestamp is taken after `dump_buffer` is called. If
# `dump_buffer` is lazy (a generator), the serialization cost is included in
# the reported write time; if it is eager, it is not -- can't tell from here.
def callback(obs_t, obs_next, action, rew, done, *args, **kwargs): buffer.add(**dict(state=obs_t, action=[action], reward=[rew], done=[done]), next_state=obs_next) return [ rew, ] return callback buffer = ReplayBuffer(10, 2000) callback = buffer_callback(buffer) plotter = PlayPlot(callback, 30 * 5, ["reward"]) env_name = "Breakout-v0" env = gym.make(env_name) env.reset() play(env, fps=20, callback=plotter.callback) t = [] exp_dump = buffer.dump_buffer(serialize=True) t.append(time.time()) with gzip.open('buffer.gzip', 'wt') as f: for exp in exp_dump: f.write(json.dumps(exp)) f.write("\n") t.append(time.time()) print(f"Writing to gzip took: {t[1]-t[0]} s")