def get_filled_buffer_frame_stack(frame_stack=4, frame_dim=1):
    """
    Build a buffer created with ``frame_stack_compensation=True`` and fill it with 30 stacked observations.

    Every stored observation is the concatenation (along the last axis) of the ``frame_stack``
    most recent raw frames; while fewer than ``frame_stack`` raw frames exist, the current
    frame is repeated ``frame_stack`` times instead.

    :param frame_stack: number of raw frames concatenated into one observation, also passed
        to the buffer as ``frame_history``
    :param frame_dim: size of the last axis of a single raw frame
    :return: the filled ``CircularVecEnvBufferBackend``
    """
    buffer = CircularVecEnvBufferBackend(
        buffer_capacity=20,
        num_envs=2,
        observation_space=gym.spaces.Box(low=0, high=255, shape=(2, 2, frame_dim * frame_stack), dtype=int),
        action_space=gym.spaces.Discrete(4),
        frame_stack_compensation=True,
        frame_history=frame_stack
    )

    terminal_steps = {2, 5, 10, 13, 18, 22, 28}

    # Raw frames seen so far, used to emulate an external frame-stacking wrapper
    raw_frames = []

    for step in range(30):
        following = step + 1

        # Index 0 along the second axis carries (step + 1), index 1 carries ten times that
        frame = np.ones((2, 2, 2, frame_dim))
        frame[:, 0] *= following
        frame[:, 1] *= 10 * following

        raw_frames.append(frame)

        # Not enough history yet -> pad by repeating the current frame
        window = raw_frames[-frame_stack:] if len(raw_frames) >= frame_stack else [frame] * frame_stack
        stacked = np.concatenate(window, axis=-1)

        # First flag is set when this step is terminal, second when the next step is
        dones = np.array([step in terminal_steps, following in terminal_steps], dtype=bool)

        buffer.store_transition(stacked, 0, step / 2.0, dones)

    return buffer
def get_filled_buffer(frame_history=1):
    """
    Build a capacity-20, two-environment buffer and push 30 transitions without any done flags.

    At step ``i`` the first slice of the frame batch holds ``i + 1`` and the second holds
    ``10 * (i + 1)``; the action is always 0 and the reward is ``i / 2``.

    :param frame_history: forwarded to the buffer constructor as ``frame_history``
    :return: the filled ``CircularVecEnvBufferBackend``
    """
    buffer = CircularVecEnvBufferBackend(
        20,
        num_envs=2,
        observation_space=gym.spaces.Box(low=0, high=255, shape=(2, 2, 1), dtype=int),
        action_space=gym.spaces.Discrete(4),
        frame_history=frame_history
    )

    for step in range(30):
        scale = step + 1

        # Fresh all-ones batch every step, then scale each leading-axis slice separately
        frame = np.ones((2, 2, 2, 1))
        frame[0] *= scale
        frame[1] *= 10 * scale

        buffer.store_transition(frame, 0, step / 2.0, False)

    return buffer
def get_filled_buffer_with_dones(frame_history=1):
    """
    Build a capacity-20, two-environment buffer and push 30 transitions, some flagged as done.

    At step ``i`` the first slice of the frame batch holds ``i + 1`` and the second holds
    ``10 * (i + 1)``. The first done flag is set when ``i`` is one of the terminal steps,
    the second when ``i + 1`` is.

    :param frame_history: forwarded to the buffer constructor as ``frame_history``
    :return: the filled ``CircularVecEnvBufferBackend``
    """
    buffer = CircularVecEnvBufferBackend(
        20,
        num_envs=2,
        observation_space=gym.spaces.Box(low=0, high=255, shape=(2, 2, 1), dtype=int),
        action_space=gym.spaces.Discrete(4),
        frame_history=frame_history
    )

    terminal_steps = {2, 5, 10, 13, 18, 22, 28}

    for step in range(30):
        following = step + 1

        # Fresh all-ones batch every step, then scale each leading-axis slice separately
        frame = np.ones((2, 2, 2, 1))
        frame[0] *= following
        frame[1] *= 10 * following

        first_done = step in terminal_steps
        second_done = following in terminal_steps
        dones = np.array([first_done, second_done], dtype=bool)

        buffer.store_transition(frame, 0, step / 2.0, dones)

    return buffer
def test_buffer_filling_size():
    """
    Verify that ``current_size`` tracks the number of stored transitions and stops at capacity.

    The buffer starts empty, reports 2 after two stores, and reports ``buffer_capacity``
    once more transitions than the capacity of 20 have been pushed.
    """
    buffer = CircularVecEnvBufferBackend(
        20,
        num_envs=2,
        observation_space=gym.spaces.Box(low=0, high=255, shape=(2, 2, 1), dtype=int),
        action_space=gym.spaces.Discrete(4)
    )

    ones = np.ones((2, 2, 2, 1))

    assert buffer.current_size == 0

    for _ in range(2):
        buffer.store_transition(ones, 0, 0, False)

    assert buffer.current_size == 2

    # 30 further stores exceed the capacity of 20
    for step in range(30):
        buffer.store_transition(ones * (step + 1), 0, step / 2.0, False)

    assert buffer.current_size == buffer.buffer_capacity
def test_simple_get_frame():
    """
    Verify ``get_frame`` on a buffer that holds only three transitions with ``frame_history=4``.

    The expected values show history slots with no stored frame coming back as zeros,
    and requests for indices 3 and 4, which were never written, raising ``VelException``.
    """
    buffer = CircularVecEnvBufferBackend(
        20,
        num_envs=2,
        observation_space=gym.spaces.Box(low=0, high=255, shape=(2, 2, 1), dtype=int),
        action_space=gym.spaces.Discrete(4),
        frame_history=4
    )

    # Second leading-axis slice carries twice the value of the first
    base = np.ones((2, 2, 2, 1))
    base[1] *= 2

    for frame in (base, base * 2, base * 3):
        buffer.store_transition(frame, 0, 0, False)

    # (frame index, second get_frame argument, expected per-channel maximum)
    expectations = [
        (0, 0, [0, 0, 0, 1]),
        (1, 0, [0, 0, 1, 2]),
        (2, 0, [0, 1, 2, 3]),
        (0, 1, [0, 0, 0, 2]),
        (1, 1, [0, 0, 2, 4]),
        (2, 1, [0, 2, 4, 6]),
    ]

    for frame_idx, env_idx, channel_max in expectations:
        observed = buffer.get_frame(frame_idx, env_idx).max(0).max(0)
        assert np.all(observed == np.array(channel_max))

    # Indices that were never written must be rejected
    for frame_idx, env_idx in [(3, 0), (4, 0), (3, 1), (4, 1)]:
        with pytest.raises(VelException):
            buffer.get_frame(frame_idx, env_idx)