コード例 #1
0
ファイル: fit.py プロジェクト: pyro-ppl/brmp
def gelman_rubin(samples):
    """Return an R-hat convergence diagnostic for ``samples``, or ``None``.

    ``samples.shape[0]`` is treated as the number of chains and
    ``samples.shape[1]`` as the number of draws per chain (assumes chains
    are the leading axis, draws the second -- TODO confirm against callers).

    * Two or more chains: ``diags.gelman_rubin`` is applied directly; each
      chain needs at least 2 draws.
    * Exactly one chain: ``diags.split_gelman_rubin`` is applied instead.
      It presumably cuts the chain into two halves, which is why at least
      4 draws (2 per half) are required.
    * Zero chains, or too few draws for the case above: ``None``.

    :param samples: array-like object exposing ``.shape`` with at least two
        dimensions.
    :return: the value returned by the chosen ``diags`` function, or
        ``None`` when there are too few chains or samples.
    """
    # Index individually (not shape[:2] unpacking) so 1-D input still fails
    # with the same IndexError as before.
    num_chains = samples.shape[0]
    num_draws = samples.shape[1]

    if num_chains >= 2:
        if num_draws < 2:
            return None  # Too few samples per chain.
        return diags.gelman_rubin(samples)

    if num_chains == 1:
        if num_draws < 4:
            return None  # Too few samples to split the single chain.
        return diags.split_gelman_rubin(samples)

    # No chains at all: there is nothing to compute a diagnostic from.
    # Previously this fell through to split_gelman_rubin.
    return None
コード例 #2
0
def test_split_gelman_rubin_agree_with_gelman_rubin():
    """Split R-hat on whole chains must match plain R-hat on half-chains.

    Two chains of 10 draws are cut into four chains of 5 draws. In row
    order, each original chain contributes its first half and then its
    second half. The plain statistic on those halves is compared with the
    split statistic on the untouched chains.
    """
    chains = np.random.normal(size=(2, 10))
    # (chain, half, draw) -> merge the first two axes into 4 half-chains.
    halves = chains.reshape(2, 2, 5).reshape(4, 5)
    assert_allclose(gelman_rubin(halves), split_gelman_rubin(chains))