Ejemplo n.º 1
0
def test_dist():
    """Check RunningMeanStd when its updates are split across two MPI ranks.

    Six random column-vector batches are drawn with a fixed seed on every
    process. Rank 0 feeds the first three batches and rank 1 the last three
    into its own ``RunningMeanStd``. The test then expects the evaluated
    ``mean`` and ``std`` to equal the statistics of all six batches stacked
    together.

    Must be launched with exactly two MPI processes.
    """
    np.random.seed(0)
    # Draw order is fixed (3, 4, 5, 6, 7, 8 rows) so both ranks generate the
    # same data and can rebuild the pooled reference locally.
    batches_by_rank = [
        [np.random.randn(rows, 1) for rows in (3, 4, 5)],
        [np.random.randn(rows, 1) for rows in (6, 7, 8)],
    ]

    comm = MPI.COMM_WORLD
    assert comm.Get_size() == 2
    rank = comm.Get_rank()
    assert rank in (0, 1)
    local_batches = batches_by_rank[rank]

    rms = RunningMeanStd(epsilon=0.0, shape=(1, ))
    U.initialize()

    for batch in local_batches:
        rms.update(batch)

    pooled = np.concatenate(batches_by_rank[0] + batches_by_rank[1])

    def matches(expected, tensor):
        # Print both sides so a failing run shows the compared values.
        actual = U.eval(tensor)
        print(expected, actual)
        return np.allclose(expected, actual)

    assert matches(pooled.mean(axis=0), rms.mean)
    assert matches(pooled.std(axis=0), rms.std)
Ejemplo n.º 2
0
def validate_probtype(probtype, pdparam):
    """Run two Monte Carlo consistency checks on a distribution type.

    Check 1: the mean negative log-likelihood of samples drawn from the
    distribution should equal its entropy.
    Check 2: ``kl(p, q)`` should equal ``-entropy(p) - E_p[log q]``, where
    ``q`` uses the same parameters perturbed by Gaussian noise of scale 0.1.

    Both checks use 100000 samples and pass when the two estimates differ by
    less than three standard errors of the sampled log-likelihoods.

    Args:
        probtype: object providing ``param_placeholder``,
            ``sample_placeholder`` and ``pdclass``.
        pdparam: numpy array of distribution parameters; it is indexed as
            ``pdparam[None, :]`` and tiled once per sample.

    Raises:
        AssertionError: if either check falls outside the three-sigma band.
    """
    n_samples = 100000

    # Check 1: mean negative log likelihood == differential entropy.
    tiled_params = np.repeat(pdparam[None, :], n_samples, axis=0)
    param_ph = probtype.param_placeholder([n_samples])
    sample_ph = probtype.sample_placeholder([n_samples])
    dist = probtype.pdclass()(param_ph)
    loglik_fn = U.function([sample_ph, param_ph], dist.logp(sample_ph))
    entropy_fn = U.function([param_ph], dist.entropy())
    samples = U.eval(dist.sample(), feed_dict={param_ph: tiled_params})

    self_loglik = loglik_fn(samples, tiled_params)
    sampled_entropy = -self_loglik.mean()  # pylint: disable=E1101
    sampled_entropy_stderr = (
        self_loglik.std() / np.sqrt(n_samples))  # pylint: disable=E1101
    entropy = entropy_fn(tiled_params).mean()  # pylint: disable=E1101
    entropy_gap = np.abs(entropy - sampled_entropy)
    assert entropy_gap < 3 * sampled_entropy_stderr  # within 3 sigmas

    # Check 2: kldiv[p, q] == -ent[p] - E_p[log q].
    other_param_ph = probtype.param_placeholder([n_samples])
    other_dist = probtype.pdclass()(other_param_ph)
    perturbed = pdparam + np.random.randn(pdparam.size) * 0.1
    tiled_perturbed = np.repeat(perturbed[None, :], n_samples, axis=0)
    kl_fn = U.function([param_ph, other_param_ph], dist.kl(other_dist))
    kl_mean = kl_fn(tiled_params, tiled_perturbed).mean()  # pylint: disable=E1101

    # Same samples from p, scored under the perturbed parameters q.
    cross_loglik = loglik_fn(samples, tiled_perturbed)
    sampled_kl = -entropy - cross_loglik.mean()  # pylint: disable=E1101
    sampled_kl_stderr = (
        cross_loglik.std() / np.sqrt(n_samples))  # pylint: disable=E1101
    kl_gap = np.abs(kl_mean - sampled_kl)
    assert kl_gap < 3 * sampled_kl_stderr  # within 3 sigmas
Ejemplo n.º 3
0
def test_runningmeanstd():
    """Compare RunningMeanStd with numpy statistics on batched updates.

    Two cases are tried: three 1-D batches (3, 4 and 5 values) and three
    two-column batches (3, 4 and 5 rows). After feeding the batches one at a
    time, the evaluated ``mean`` and ``std`` must match the mean and std of
    the batches concatenated along axis 0.
    """
    # All random batches are drawn up front, 1-D case first.
    cases = [
        tuple(np.random.randn(rows, *trailing) for rows in (3, 4, 5))
        for trailing in ((), (2,))
    ]

    for batches in cases:
        # The statistic shape is the per-row shape: () or (2,).
        tracker = RunningMeanStd(epsilon=0.0, shape=batches[0].shape[1:])
        U.initialize()

        stacked = np.concatenate(batches, axis=0)
        expected = [stacked.mean(axis=0), stacked.std(axis=0)]
        for batch in batches:
            tracker.update(batch)
        actual = U.eval([tracker.mean, tracker.std])

        assert np.allclose(expected, actual)