def test_ema_decay_reset_nofilter():
    """Passing a ``reset`` mask without a ``filter`` should emit a UserWarning.

    The warning is meant to tell the caller that the reset mask will be
    ignored. This test assumes ema_decay reports that via ``warnings.warn()``
    rather than raising ``UserWarning`` as an exception — confirm against the
    implementation.
    """
    # Create a FastArray instance to call ema_decay() with;
    # the data itself doesn't matter, since we're just looking to
    # check how the function validates preconditions.
    count = 10
    data = rt.ones(count, dtype=np.float32)
    times = rt.ones(count, dtype=np.float32)
    reset = np.full(count, True, dtype=np.bool_)

    # Check that calling ema_decay with a reset mask but no filter issues a warning
    # to notify the user that the reset mask will be ignored.
    # Under the default warning filters, warnings.warn() does not raise, so
    # pytest.raises(UserWarning) would fail with "DID NOT RAISE". pytest.warns
    # records the warning regardless of the active filter configuration.
    with pytest.warns(UserWarning):
        data.ema_decay(times, 1.0, reset=reset)
def test_ema_decay_requires_matching_filter_len():
    """ema_decay must reject a filter mask whose length differs from the data.

    Both a mask that is one element too short and one that is twice as long
    are expected to raise ValueError.
    """
    n = 10
    values = rt.ones(n, dtype=np.float32)
    timestamps = rt.ones(n, dtype=np.float32)

    # (mask length, fill value) for each mismatched case: the short mask is
    # all-True, the long one all-False, so neither passes just by content.
    mismatched_cases = ((n - 1, True), (n * 2, False))
    for mask_len, fill in mismatched_cases:
        bad_filter = np.full(mask_len, fill, dtype=np.bool_)
        with pytest.raises(ValueError):
            values.ema_decay(timestamps, 1.0, filter=bad_filter)
def test_ema_decay_requires_matching_time_len():
    """ema_decay must reject a time array whose length differs from the data.

    A time array one element shorter and one twice as long as the input are
    each expected to raise ValueError.
    """
    n = 10
    values = rt.ones(n, dtype=np.float32)

    for time_len in (n - 1, n * 2):
        # Monotonic float32 timestamps of the wrong length.
        bad_times = np.arange(time_len).astype(np.float32)
        with pytest.raises(ValueError):
            values.ema_decay(bad_times, 1.0)
def test_ema_decay(decay_rate, filter, reset, dtype_override, expected):
    """Compare ema_decay output against precomputed expected values.

    Each test case supplies a decay rate, optional ``filter``/``reset`` masks,
    an optional dtype override, and the expected result. The input series is
    ten ones sampled at fixed, irregularly spaced timestamps.
    """
    values = rt.ones(10)
    timestamps = rt.FastArray([0, 1, 1, 3, 4, 5, 5.5, 10.5, 10.55, 11])

    # Forward ``dtype`` only when a case actually overrides it, so the method's
    # own default is used otherwise. ``filter`` and ``reset`` are always passed
    # explicitly since None is already their default.
    extra_kwargs = {} if dtype_override is None else {'dtype': dtype_override}
    actual = values.ema_decay(
        timestamps,
        decay_rate,
        filter=filter,
        reset=reset,
        **extra_kwargs,
    )

    assert_array_almost_equal(actual, expected)