示例#1
0
def test_check_random_state():
    """_check_random_state should hand back a numpy RandomState.

    Covers the three accepted seed kinds (None, an int, an existing
    RandomState) and checks that any other input raises ValueError.
    """
    # seed is None
    rng = _check_random_state(None)
    assert isinstance(rng, np.random.RandomState)

    # seed is int
    rng = _check_random_state(10)
    assert isinstance(rng, np.random.RandomState)

    # seed is RandomState
    rng_test = np.random.RandomState(10)
    rng = _check_random_state(rng_test)
    assert isinstance(rng, np.random.RandomState)

    # seed is none of the above: error
    with pytest.raises(ValueError):
        _check_random_state('random')
示例#2
0
def test_sample_choose_weighted():
    """_sample_choose_weighted should respect the supplied weights.

    One-hot weights must always select the matching element. For weights
    [0.3, 0.7] the empirical frequencies over many draws must come within
    0.01 of those probabilities.
    """
    # make sure probabilities are factored in
    rng = _check_random_state(0)
    assert _sample_choose_weighted([0, 1, 2], [1, 0, 0], rng) == 0
    assert _sample_choose_weighted([0, 1, 2], [0, 1, 0], rng) == 1
    assert _sample_choose_weighted([0, 1, 2], [0, 0, 1], rng) == 2

    # Many draws from the same seeded rng; the large count keeps sampling
    # noise well inside the 1e-2 tolerance used below.
    n_draws = 100000
    samples = np.asarray(
        [_sample_choose_weighted([0, 1], [0.3, 0.7], rng) for _ in range(n_draws)]
    )

    # mean of a boolean mask == fraction of draws equal to that value
    zero_ratio = np.mean(samples == 0)
    one_ratio = np.mean(samples == 1)
    assert np.isclose(zero_ratio, 0.3, atol=1e-2)
    assert np.isclose(one_ratio, 0.7, atol=1e-2)
示例#3
0
def test_sample_trunc_norm():
    '''
    Should return values from a truncated normal distribution.

    Draws many samples with a seeded RandomState and checks two things:
    (1) every sample lies within [trunc_min, trunc_max], and
    (2) the density-normalised histogram of the samples matches scipy's
    closed-form truncated-normal PDF evaluated at the bin centres.
    '''
    rng = _check_random_state(0)
    # sample values from a distribution
    mu, sigma, trunc_min, trunc_max = 2, 1, 0, 5
    n_samples = 100000
    x = np.asarray([
        _sample_trunc_norm(mu, sigma, trunc_min, trunc_max, random_state=rng)
        for _ in range(n_samples)
    ])

    # simple check: values must be within truncated bounds
    assert (x >= trunc_min).all() and (x <= trunc_max).all()

    # trickier check: values must approximate distribution's PDF.
    # 50 bins of width 0.2 over [0, 10]. linspace fixes the edge count
    # exactly, unlike arange with a non-integer float step.
    bin_edges = np.linspace(0, 10, 51)
    hist, bins = np.histogram(x, bins=bin_edges, density=True)
    # Evaluate the PDF at each bin's midpoint. Deriving the centres from the
    # edges keeps this correct if the binning is ever changed.
    bin_centers = 0.5 * (bins[:-1] + bins[1:])
    # scipy's truncnorm expects the clip points in standard-normal units
    a, b = (trunc_min - mu) / float(sigma), (trunc_max - mu) / float(sigma)
    trunc_closed = truncnorm.pdf(bin_centers, a, b, loc=mu, scale=sigma)
    assert np.allclose(hist, trunc_closed, atol=0.015)
示例#4
0
def test_sample_choose():
    """Choosing from a candidate list that contains duplicates should warn.

    Calls _sample_choose on [0, 1, 2, 2, 2] with a seeded RandomState and
    requires that a ScaperWarning is emitted.
    """
    seeded_rng = _check_random_state(0)
    candidates_with_dupes = [0, 1, 2, 2, 2]
    with pytest.warns(ScaperWarning):
        _sample_choose(candidates_with_dupes, seeded_rng)