Exemplo n.º 1
0
def test_gelman_rubin():
    # Precision only needs checking on a tiny hand-built input: two chains of
    # ten draws each, where the second chain is the first shifted up by 1.
    base = np.arange(10.0)
    x = np.stack([base, base + 1])

    assert_allclose(gelman_rubin(x), 0.98, atol=0.01)
Exemplo n.º 2
0
Arquivo: fit.py Projeto: pyro-ppl/brmp
def gelman_rubin(samples):
    """Return the Gelman-Rubin R-hat diagnostic, or ``None`` if it can't be computed.

    Axis 0 of ``samples`` indexes chains and axis 1 indexes draws within a
    chain (as implied by the checks below).

    * Two or more chains: delegates to ``diags.gelman_rubin``; requires at
      least 2 draws per chain.
    * Exactly one chain: delegates to ``diags.split_gelman_rubin``; requires
      at least 4 draws.
    * Zero chains: nothing to diagnose.

    Returns:
        The value returned by the selected ``diags`` function, or ``None``
        when there are too few chains or too few draws.
    """
    num_chains = samples.shape[0]
    num_draws = samples.shape[1]

    if num_chains >= 2:
        if num_draws < 2:
            return None  # Too few samples per chain.
        return diags.gelman_rubin(samples)

    if num_chains == 1:
        if num_draws < 4:
            return None  # Too few samples for the split variant.
        return diags.split_gelman_rubin(samples)

    # No chains at all. Previously an empty array with >= 4 "draws" slipped
    # past the guard and was handed to split_gelman_rubin.
    return None
Exemplo n.º 3
0
def test_split_gelman_rubin_agree_with_gelman_rubin():
    # Seed a local generator so the test is reproducible and does not depend
    # on (or disturb) NumPy's global random state.
    rng = np.random.RandomState(0)
    x = rng.normal(size=(2, 10))

    # Cut each of the 2 chains of 10 draws into two consecutive halves of 5,
    # giving 4 pseudo-chains (row-major reshape keeps each half contiguous).
    # The split statistic on x should match the plain statistic on these.
    halves = x.reshape(4, 5)
    r_hat1 = gelman_rubin(halves)
    r_hat2 = split_gelman_rubin(x)
    assert_allclose(r_hat1, r_hat2)