Beispiel #1
0
def test_buffer(capacity, chunk_len, sample_shape):
  """Store `3 * capacity` samples into a Buffer and check FIFO behaviour.

  A Buffer holding float fields 'a' and 'b' (each of shape `sample_shape`)
  receives chunks of shape `(chunk_len,) + sample_shape` filled with
  consecutive integers. The 'b' chunk always starts `3 * capacity` values
  after the matching 'a' chunk.

  Around every `store` call this checks that:
  * `len(buffer)` grows with each insertion and saturates at `capacity`;
  * `buffer._idx` wraps around modulo `capacity`;
  * sampled values lie in the range expected for FIFO eviction;
  * mutating a chunk after storing it does not mutate the buffer.
  """
  fields = ('a', 'b')
  buffer = Buffer(capacity,
                  sample_shapes={name: sample_shape for name in fields},
                  dtypes={name: float for name in fields})

  total = capacity * 3
  for start in range(0, total, chunk_len):
    end = start + chunk_len

    # Bookkeeping must reflect exactly what has been stored so far.
    assert len(buffer) == min(start, capacity)
    assert buffer._idx == start % capacity

    values_a = _fill_chunk(start, chunk_len, sample_shape)
    values_b = _fill_chunk(start + total, chunk_len, sample_shape)
    buffer.store({'a': values_a, 'b': values_b})

    drawn = buffer.sample(100)
    assert set(drawn.keys()) == {'a', 'b'}, drawn.keys()
    _check_bound(end, capacity, drawn['a'])
    _check_bound(end + total, capacity, drawn['b'])
    # 'a' and 'b' rows are sampled together, so their offset is preserved.
    assert np.all(drawn['b'] - drawn['a'] == total)

    # Clobbering the inserted chunks must not leak into the buffer's storage.
    values_a[:] = np.nan
    values_b[:] = np.nan
    for name in fields:
      assert not np.any(np.isnan(buffer._arrays[name]))
Beispiel #2
0
def test_buffer_from_data():
    """`Buffer.from_data` yields a distinct array with the same dtype and values."""
    # Use a deterministic, non-trivial boolean pattern. The previous
    # `np.ndarray([50, 30], dtype=bool)` is numpy's low-level constructor and
    # returns uninitialized memory. That made the input nondeterministic: it
    # could be all-False, or hold bytes other than 0/1, which weakens the
    # value comparison below.
    data = np.arange(50 * 30).reshape(50, 30) % 3 == 0
    buf = Buffer.from_data({'k': data})
    assert buf._arrays['k'] is not data
    assert data.dtype == buf._arrays['k'].dtype
    assert np.array_equal(buf._arrays['k'], data)
Beispiel #3
0
def test_buffer_init_errors():
    """Constructing a Buffer whose shape and dtype keys disagree raises KeyError."""
    # 'b' has a shape but no dtype, while 'c' has a dtype but no shape.
    shapes = {'a': (2, 1), 'b': (3, )}
    dtype_map = {'a': "float32", 'c': bool}
    with pytest.raises(KeyError, match=r"sample_shape and dtypes.*"):
        Buffer(10, shapes, dtypes=dtype_map)
Beispiel #4
0
def test_buffer_sample_errors():
    """Sampling from a Buffer into which nothing has been stored raises ValueError."""
    fresh = Buffer(10, {'k': (2, 1)}, dtypes={'k': bool})
    with pytest.raises(ValueError):
        fresh.sample(5)
Beispiel #5
0
 def buf():
     """Build a single-field ('k') Buffer from the enclosing-scope settings."""
     # NOTE(review): `capacity`, `sample_shape` and `dtype` are free names
     # resolved from an enclosing scope not visible here.
     shapes = {'k': sample_shape}
     dtype_map = {'k': dtype}
     return Buffer(capacity, shapes, dtype_map)