def test_buffer(capacity, chunk_len, sample_shape):
    """Check FIFO storage, wrap-around and sampling of a two-key ``Buffer``.

    A ``Buffer`` of size ``capacity`` with float keys ``"a"`` and ``"b"`` is fed
    ``3 * capacity`` samples, ``chunk_len`` at a time. Each chunk has shape
    ``(chunk_len,) + sample_shape`` and holds consecutive integers; the ``"b"``
    chunk is the ``"a"`` chunk offset by ``3 * capacity``.

    Before each insertion:

    * ``len(buf)`` equals the number of samples stored so far, capped at
      ``capacity``.
    * ``buf._idx`` equals that count modulo ``capacity``.

    After each insertion:

    * sampled values lie in the range ``_check_bound`` expects, which verifies
      first-in-first-out eviction;
    * ``"a"`` and ``"b"`` are drawn in lockstep (their difference is exactly
      the offset);
    * overwriting the caller's chunks with NaN leaves the buffer untouched.
    """
    keys = ("a", "b")
    buf = Buffer(
        capacity,
        sample_shapes={key: sample_shape for key in keys},
        dtypes={key: float for key in keys},
    )

    total = 3 * capacity
    for start in range(0, total, chunk_len):
        # State before this chunk: `start` samples have been stored so far.
        assert len(buf) == min(start, capacity)
        assert buf._idx == start % capacity

        # "b" mirrors "a" shifted by `total`, so paired sampling is checkable.
        chunk_a = _fill_chunk(start, chunk_len, sample_shape)
        chunk_b = _fill_chunk(start + total, chunk_len, sample_shape)
        buf.store({"a": chunk_a, "b": chunk_b})

        drawn = buf.sample(100)
        assert set(drawn.keys()) == {"a", "b"}, drawn.keys()
        end = start + chunk_len
        _check_bound(end, capacity, drawn["a"])
        _check_bound(end + total, capacity, drawn["b"])
        assert np.all(drawn["b"] - drawn["a"] == total)

        # Scribbling over the caller-owned chunks must not leak into storage.
        for chunk in (chunk_a, chunk_b):
            chunk[:] = np.nan
        for key in keys:
            assert not np.any(np.isnan(buf._arrays[key]))
def test_buffer_from_data():
    """``Buffer.from_data`` stores a copy of its input with the same dtype and contents.

    The input is a deterministic boolean pattern rather than uninitialized
    memory (which ``np.ndarray(shape, dtype)`` would return). With uninitialized
    memory, the equality check would compare arbitrary values, and it could pass
    trivially on, e.g., an all-``False`` scratch allocation.
    """
    # Periodic pattern containing both True and False entries, shape (50, 30).
    data = np.arange(50 * 30).reshape(50, 30) % 3 == 0
    buf = Buffer.from_data({"k": data})
    stored = buf._arrays["k"]
    # The buffer must hold its own array, not the caller's object.
    assert stored is not data
    assert data.dtype == stored.dtype
    assert np.array_equal(stored, data)
def test_buffer_init_errors():
    """Mismatched key sets for sample shapes and dtypes raise ``KeyError``."""
    # Shapes name keys {"a", "b"} while dtypes name {"a", "c"}.
    shapes = {"a": (2, 1), "b": (3,)}
    kinds = {"a": "float32", "c": bool}
    with pytest.raises(KeyError, match=r"sample_shape and dtypes.*"):
        Buffer(10, shapes, dtypes=kinds)
def test_buffer_sample_errors():
    """Sampling from a ``Buffer`` that has had nothing stored raises ``ValueError``."""
    empty_buf = Buffer(10, {"k": (2, 1)}, dtypes={"k": bool})
    with pytest.raises(ValueError):
        empty_buf.sample(5)
def buf():
    """Build a fresh single-key (``"k"``) ``Buffer``.

    ``capacity``, ``sample_shape`` and ``dtype`` are resolved from the enclosing
    scope at call time (not visible here; presumably a surrounding test).
    """
    shapes = {"k": sample_shape}
    kinds = {"k": dtype}
    return Buffer(capacity, shapes, kinds)