def test_gelman_rubin():
    """Check R-hat on two short chains, the second offset from the first by one.

    Small data is enough here: the assertion is only about numerical precision.
    """
    # Broadcast a column of per-chain offsets (0, 1) against a 0..9 ramp,
    # giving a (2, 10) float array: row 0 = 0..9, row 1 = 1..10.
    chains = np.arange(10.0) + np.arange(2.0)[:, None]
    assert_allclose(gelman_rubin(chains), 0.98, atol=0.01)
def gelman_rubin(samples):
    """Compute R-hat for ``samples``, or return ``None`` if there is too little data.

    ``samples`` is indexed so that ``shape[0]`` selects the estimator and
    ``shape[1]`` sets the minimum-size check (presumably chains x draws;
    confirm against callers).

    - ``shape[0] >= 2``: delegates to ``diags.gelman_rubin`` and requires
      ``shape[1] >= 2``.
    - ``shape[0] < 2``: delegates to ``diags.split_gelman_rubin`` and requires
      ``shape[1] >= 4``.

    Returns:
        The value returned by the selected ``diags`` function, or ``None``
        when the size requirement above is not met.
    """
    n_rows = samples.shape[0]
    n_cols = samples.shape[1]

    if n_rows >= 2:
        if n_cols < 2:
            return None  # Too few chains or samples.
        return diags.gelman_rubin(samples)

    # Fewer than two rows: the split variant is used, which needs at least
    # 4 columns before it is attempted.
    if n_cols < 4:
        return None  # Too few chains or samples.
    return diags.split_gelman_rubin(samples)
def test_split_gelman_rubin_agree_with_gelman_rubin():
    """split_gelman_rubin on 2 chains must match gelman_rubin on their 4 halves."""
    draws = np.random.normal(size=(2, 10))
    # A row-major reshape to (4, 5) splits each length-10 row into two
    # consecutive halves: rows are draws[0, :5], draws[0, 5:],
    # draws[1, :5], draws[1, 5:].
    halves = draws.reshape(4, 5)
    assert_allclose(gelman_rubin(halves), split_gelman_rubin(draws))