def test_dist():
    """Check RunningMeanStd statistics when updates come from two MPI ranks.

    Must be launched with exactly two MPI processes. Both ranks seed numpy
    identically and draw the same six batches; rank 0 feeds the first three
    into its RunningMeanStd and rank 1 feeds the last three. The reported
    mean and std are then compared with those of all six batches
    concatenated.

    Raises:
        AssertionError: if the world size is not 2, the rank is neither 0
            nor 1, or the reported statistics differ from the reference.
    """
    np.random.seed(0)
    # Draw order matters for reproducibility: rank-0 batches first, then
    # rank-1 batches, each as a column of shape (rows, 1).
    rank0_batches = tuple(np.random.randn(rows, 1) for rows in (3, 4, 5))
    rank1_batches = tuple(np.random.randn(rows, 1) for rows in (6, 7, 8))

    # Disabled alternative using 1-D inputs:
    # p1,p2,p3=(np.random.randn(3), np.random.randn(4), np.random.randn(5))
    # q1,q2,q3=(np.random.randn(6), np.random.randn(7), np.random.randn(8))

    world = MPI.COMM_WORLD
    assert world.Get_size() == 2

    rank = world.Get_rank()
    if rank == 0:
        local_batches = rank0_batches
    elif rank == 1:
        local_batches = rank1_batches
    else:
        assert False

    rms = RunningMeanStd(epsilon=0.0, shape=(1, ))
    U.initialize()

    for batch in local_batches:
        rms.update(batch)

    # Reference data: everything that either rank fed in.
    everything = np.concatenate(rank0_batches + rank1_batches)

    def report_and_compare(expected, actual):
        # Print both values so a failing run shows what was compared.
        print(expected, actual)
        return np.allclose(expected, actual)

    assert report_and_compare(everything.mean(axis=0), U.eval(rms.mean))
    assert report_and_compare(everything.std(axis=0), U.eval(rms.std))
def validate_probtype(probtype, pdparam):
    """Monte Carlo consistency check for a probability-distribution type.

    Two identities are verified, each to within three standard errors of
    the sampled estimate:

    1. The mean negative log-likelihood of samples drawn from p equals the
       entropy reported by p.
    2. kl(p, q) equals -entropy(p) - E_p[log q], where q uses ``pdparam``
       perturbed by Gaussian noise of scale 0.1.

    Args:
        probtype: object providing ``param_placeholder``,
            ``sample_placeholder`` and ``pdclass``.
        pdparam: distribution parameters; indexed as ``pdparam[None, :]``
            and read via ``.size`` (assumed to be a 1-D numpy array --
            TODO confirm against callers).

    Raises:
        AssertionError: if either identity is violated beyond 3 sigma.
    """
    num_samples = 100000
    root_n = np.sqrt(num_samples)

    def tile_rows(vec):
        # One identical parameter row per sample.
        return np.repeat(vec[None, :], num_samples, axis=0)

    # --- Check 1: mean negative log likelihood == differential entropy ---
    p_values = tile_rows(pdparam)
    p_params = probtype.param_placeholder([num_samples])
    samples_ph = probtype.sample_placeholder([num_samples])
    dist_p = probtype.pdclass()(p_params)
    loglik_fn = U.function([samples_ph, p_params], dist_p.logp(samples_ph))
    entropy_fn = U.function([p_params], dist_p.entropy())
    drawn = U.eval(dist_p.sample(), feed_dict={p_params: p_values})

    loglik_p = loglik_fn(drawn, p_values)
    entropy_mc = -loglik_p.mean()  # pylint: disable=E1101
    entropy_mc_stderr = loglik_p.std() / root_n  # pylint: disable=E1101
    entropy_exact = entropy_fn(p_values).mean()  # pylint: disable=E1101
    # within 3 sigmas
    assert np.abs(entropy_exact - entropy_mc) < 3 * entropy_mc_stderr

    # --- Check 2: kldiv[p,q] = - ent[p] - E_p[log q] ---
    q_params = probtype.param_placeholder([num_samples])
    dist_q = probtype.pdclass()(q_params)
    perturbed = pdparam + np.random.randn(pdparam.size) * 0.1
    q_values = tile_rows(perturbed)
    kl_fn = U.function([p_params, q_params], dist_p.kl(dist_q))
    kl_exact = kl_fn(p_values, q_values).mean()  # pylint: disable=E1101

    # Samples still come from p; only the scoring distribution changes to q.
    loglik_q = loglik_fn(drawn, q_values)
    kl_mc = -entropy_exact - loglik_q.mean()  # pylint: disable=E1101
    kl_mc_stderr = loglik_q.std() / root_n  # pylint: disable=E1101
    # within 3 sigmas
    assert np.abs(kl_exact - kl_mc) < 3 * kl_mc_stderr
def test_runningmeanstd():
    """Check RunningMeanStd against numpy on incrementally fed batches.

    Two cases are exercised: batches of 3, 4 and 5 rows with no trailing
    dimension, and the same row counts with a trailing dimension of 2. In
    each case the mean and std reported after three updates must match the
    mean and std of the batches concatenated along axis 0.

    Raises:
        AssertionError: if the reported statistics differ from numpy's.
    """
    # All batches are drawn up front, case by case, before any is consumed.
    cases = [
        tuple(np.random.randn(rows, *trailing) for rows in (3, 4, 5))
        for trailing in ((), (2,))
    ]

    for batches in cases:
        first = batches[0]
        rms = RunningMeanStd(epsilon=0.0, shape=first.shape[1:])
        U.initialize()

        stacked = np.concatenate(list(batches), axis=0)
        expected = [stacked.mean(axis=0), stacked.std(axis=0)]

        for batch in batches:
            rms.update(batch)

        actual = U.eval([rms.mean, rms.std])
        assert np.allclose(expected, actual)