Beispiel #1
0
def test_hdf5():
    """Round-trip several replay-buffer flavours through HDF5 and compare.

    Fills an array-backed, a list-backed and a prioritized buffer with the
    same four transitions. Each buffer is saved with ``save_hdf5`` to a
    temporary file, reloaded with its class's ``load_hdf5``, and the
    reloaded buffer's length, actions, configuration and internal indices
    are compared against the original. The temporary files are always
    removed afterwards. Finally checks that ``to_hdf5`` raises for values
    it cannot store.
    """
    size = 100
    buffers = {
        "array": ReplayBuffer(size, stack_num=2),
        "list": ListReplayBuffer(),
        "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4),
    }
    buffer_types = {k: b.__class__ for k, b in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # A tensor inside ``info`` exercises tensor serialization (on the GPU
    # when one is available).
    info_t = torch.tensor([1.]).to(device)
    for i in range(4):
        kwargs = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': np.array([1, 2]),
            'done': i % 3 == 2,
            'info': {"number": {"n": i, "t": info_t}, 'extra': None},
        }
        buffers["array"].add(**kwargs)
        buffers["list"].add(**kwargs)
        buffers["prioritized"].add(weight=np.random.rand(), **kwargs)

    paths = {}
    try:
        # save: mkstemp only reserves a unique file name; close the OS-level
        # handle so save_hdf5 can reopen the path itself.
        for k, buf in buffers.items():
            f, path = tempfile.mkstemp(suffix='.hdf5')
            os.close(f)
            # Record the path before writing so a failing save is still
            # cleaned up in the ``finally`` below.
            paths[k] = path
            buf.save_hdf5(path)

        # load replay buffer
        _buffers = {
            k: buffer_types[k].load_hdf5(path) for k, path in paths.items()
        }

        # compare
        for k in buffers:
            assert len(_buffers[k]) == len(buffers[k])
            assert np.allclose(_buffers[k].act, buffers[k].act)
            assert _buffers[k].stack_num == buffers[k].stack_num
            assert _buffers[k].maxsize == buffers[k].maxsize
            assert np.all(_buffers[k]._indices == buffers[k]._indices)
        for k in ["array", "prioritized"]:
            assert _buffers[k]._index == buffers[k]._index
            assert isinstance(buffers[k].get(0, "info"), Batch)
            assert isinstance(_buffers[k].get(0, "info"), Batch)
        for k in ["array"]:
            assert np.all(
                buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
            assert np.all(
                buffers[k][:].info.extra == _buffers[k][:].info.extra)
    finally:
        # mkstemp never deletes its files; remove them so repeated runs do
        # not accumulate files in the temp directory.
        for path in paths.values():
            try:
                os.remove(path)
            except OSError:
                # Best-effort cleanup: never let a cleanup error mask the
                # test's own failure.
                pass

    # raise exception when value cannot be pickled
    # NOTE(review): ``h5py.Group`` here is the class itself, not an open
    # group; these checks rely on to_hdf5 raising before it needs a usable
    # group object — confirm this is intended.
    data = {"not_supported": lambda x: x * x}
    grp = h5py.Group
    with pytest.raises(NotImplementedError):
        to_hdf5(data, grp)
    # ndarray with data type not supported by HDF5 that cannot be pickled
    data = {"not_supported": np.array(lambda x: x * x)}
    grp = h5py.Group
    with pytest.raises(RuntimeError):
        to_hdf5(data, grp)
Beispiel #2
0
 def save_hdf5(self, path: str) -> None:
     """Serialize this replay buffer into an HDF5 file.

     The file at ``path`` is opened in ``"w"`` mode, so an existing file
     there is truncated and overwritten. The instance's whole ``__dict__``
     is handed to ``to_hdf5`` for writing.
     """
     state = self.__dict__
     with h5py.File(path, "w") as hdf5_file:
         to_hdf5(state, hdf5_file)
Beispiel #3
0
 def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
     """Serialize this replay buffer into an HDF5 file.

     Args:
         path: destination file. It is opened in ``"w"`` mode, so an
             existing file there is truncated and overwritten.
         compression: forwarded unchanged to ``to_hdf5``. Presumably an
             h5py compression filter name (e.g. ``"gzip"``), or ``None``
             for no compression; confirm against ``to_hdf5``.
     """
     state = self.__dict__
     with h5py.File(path, "w") as hdf5_file:
         to_hdf5(state, hdf5_file, compression=compression)