示例#1
0
def test_buffer(capacity, chunk_len, sample_shape):
    """Exercise FIFO insertion and sampling on a two-key Buffer.

    A Buffer of size `capacity` with keys "a" and "b" (both of shape
    `sample_shape`, dtype float) receives `3 * capacity` samples, stored in
    chunks of shape `(chunk_len,) + sample_shape`. Chunks are always filled
    with consecutive integers via `_fill_chunk`; the "b" chunks are offset
    by `3 * capacity` relative to the "a" chunks.

    Before each store the test checks that:
    * `len(buf)` equals the number inserted so far, saturating at `capacity`;
    * `buf._idx` equals the number inserted so far modulo `capacity`.

    After each store it checks that:
    * sampled values for both keys are in the expected range (via
      `_check_bound`), which verifies FIFO insertion;
    * element-wise, every sampled "b" value exceeds the matching "a" value
      by exactly `3 * capacity`;
    * overwriting the caller's chunks with NaN leaves the buffer's internal
      arrays NaN-free, i.e. the buffer is not mutated through its inputs.
    """
    keys = ("a", "b")
    buf = Buffer(
        capacity,
        sample_shapes={key: sample_shape for key in keys},
        dtypes={key: float for key in keys},
    )

    total = 3 * capacity
    for inserted in range(0, total, chunk_len):
        assert len(buf) == min(inserted, capacity)
        assert buf._idx == inserted % capacity

        chunk_a = _fill_chunk(inserted, chunk_len, sample_shape)
        chunk_b = _fill_chunk(inserted + total, chunk_len, sample_shape)
        buf.store({"a": chunk_a, "b": chunk_b})

        samples = buf.sample(100)
        assert set(samples.keys()) == {"a", "b"}, samples.keys()
        upper = inserted + chunk_len
        _check_bound(upper, capacity, samples["a"])
        _check_bound(upper + total, capacity, samples["b"])
        assert np.all(samples["b"] - samples["a"] == total)

        # Poison the caller-owned chunks; if the buffer aliased them, the
        # NaNs would show up in its internal storage.
        for chunk in (chunk_a, chunk_b):
            chunk[:] = np.nan
        for key in keys:
            assert not np.any(np.isnan(buf._arrays[key]))
示例#2
0
def test_buffer_from_data():
    """`Buffer.from_data` yields a buffer whose array is a distinct object
    with the same dtype and the same values as the input array."""
    # `np.ndarray(shape, dtype)` hands back *uninitialized* memory, which made
    # this test's data nondeterministic (and, for bool, possibly not 0/1
    # bytes). Use a deterministic (50, 30) bool array with both True and False.
    data = (np.arange(50 * 30).reshape(50, 30) % 3) == 0
    buf = Buffer.from_data({"k": data})
    assert buf._arrays["k"] is not data
    assert data.dtype == buf._arrays["k"].dtype
    assert np.array_equal(buf._arrays["k"], data)
示例#3
0
def test_buffer_init_errors():
    """Constructing a Buffer whose shape and dtype mappings have different
    key sets ({"a", "b"} vs {"a", "c"}) must raise a KeyError."""
    mismatched_shapes = {"a": (2, 1), "b": (3,)}
    mismatched_dtypes = {"a": "float32", "c": bool}
    with pytest.raises(KeyError, match=r"sample_shape and dtypes.*"):
        Buffer(10, mismatched_shapes, dtypes=mismatched_dtypes)
示例#4
0
def test_buffer_sample_errors():
    """Sampling from a freshly built Buffer, before anything has been
    stored in it, must raise a ValueError."""
    empty = Buffer(10, {"k": (2, 1)}, dtypes={"k": bool})
    with pytest.raises(ValueError):
        empty.sample(5)
示例#5
0
 def buf():
     """Build a Buffer with a single key "k".

     `capacity`, `sample_shape` and `dtype` are not defined in this
     function; they are resolved from a surrounding scope not shown here.
     """
     return Buffer(
         capacity,
         {"k": sample_shape},
         {"k": dtype},
     )