예제 #1
0
def test_minres_with_jacobi():
    """Run Jacobi-preconditioned MINRES on ``natural.compute_Lx`` and check it.

    The right-hand side is assembled from the module-level arrays ``v``, ``g``
    and ``h``: two ``x.T . y / M`` products followed by three column means.
    The preconditioner is ``natural.generic_compute_L_diag`` shifted by 0.1.
    The first element returned by ``minres.minres`` is compiled and must match
    the precomputed ``Linv_x_w``, ``Linv_x_v``, ``Linv_x_a``, ``Linv_x_b`` and
    ``Linv_x_c`` to one decimal place.
    """
    vv = theano.shared(v, name='v')
    gg = theano.shared(g, name='g')
    hh = theano.shared(h, name='h')

    # Right-hand side, built from the plain numpy arrays (not the shared
    # variables). Its order mirrors the (xw, xv, xa, xb, xc) arguments of
    # apply_L below.
    rhs = [
        T.dot(v.T, g) / M,
        T.dot(g.T, h) / M,
        T.mean(v, axis=0),
        T.mean(g, axis=0),
        T.mean(h, axis=0),
    ]

    # Jacobi preconditioner: diagonal terms shifted by 0.1 (presumably to keep
    # them safely away from zero).
    diag_terms = natural.generic_compute_L_diag([vv, gg, hh])
    Ms = [term + 0.1 for term in diag_terms]

    def apply_L(xw, xv, xa, xb, xc):
        # Operator handed to MINRES; presumably computes L x against the
        # shared v/g/h data.
        return natural.compute_Lx(vv, gg, hh, xw, xv, xa, xb, xc)

    solver_output = minres.minres(
        apply_L,
        rhs,
        rtol=1e-5,
        damp=0.,
        maxiter=10000,
        Ms=Ms,
        profile=0,
    )
    newgrads = solver_output[0]

    solve = theano.function([], newgrads)
    new_dw, new_dv, new_da, new_db, new_dc = solve()

    checks = (
        (Linv_x_w, new_dw),
        (Linv_x_v, new_dv),
        (Linv_x_a, new_da),
        (Linv_x_b, new_db),
        (Linv_x_c, new_dc),
    )
    for expected, actual in checks:
        numpy.testing.assert_almost_equal(expected, actual, decimal=1)
예제 #2
0
def test_minres_with_jacobi():
    """Check the MINRES solution of L x = rhs against precomputed baselines.

    ``natural.compute_Lx`` is used as the (matrix-free) operator and
    ``generic_compute_L_diag + 0.1`` as the Jacobi preconditioner. The solved
    terms are compared with ``Linv_x_w``, ``Linv_x_v``, ``Linv_x_a``,
    ``Linv_x_b`` and ``Linv_x_c`` to one decimal place.
    """
    shared_v = theano.shared(v, name='v')
    shared_g = theano.shared(g, name='g')
    shared_h = theano.shared(h, name='h')

    # Right-hand side terms, computed from the raw numpy arrays.
    grad_w = T.dot(v.T, g) / M
    grad_v = T.dot(g.T, h) / M
    grad_a, grad_b, grad_c = [T.mean(arr, axis=0) for arr in (v, g, h)]

    shared_vars = [shared_v, shared_g, shared_h]
    precond = [d + 0.1 for d in natural.generic_compute_L_diag(shared_vars)]

    def operator(xw, xv, xa, xb, xc):
        return natural.compute_Lx(
            shared_v, shared_g, shared_h, xw, xv, xa, xb, xc)

    solver_kwargs = dict(
        rtol=1e-5,
        damp=0.,
        maxiter=10000,
        Ms=precond,
        profile=0,
    )
    solution = minres.minres(
        operator, [grad_w, grad_v, grad_a, grad_b, grad_c], **solver_kwargs)

    compute = theano.function([], solution[0])
    outputs = compute()
    new_dw, new_dv, new_da, new_db, new_dc = outputs

    baselines = [Linv_x_w, Linv_x_v, Linv_x_a, Linv_x_b, Linv_x_c]
    solved = [new_dw, new_dv, new_da, new_db, new_dc]
    for baseline, got in zip(baselines, solved):
        numpy.testing.assert_almost_equal(baseline, got, decimal=1)
예제 #3
0
def test_generic_compute_Ldiag():
    """Compare ``natural.generic_compute_L_diag`` with the diagonal of ``L``.

    The Theano expression is built on four symbolic matrices, compiled, and
    evaluated on the module-level arrays ``v``, ``g``, ``h`` and ``q``. Every
    returned term (three matrix-shaped terms, then four vector terms) is
    checked against the matching slice of ``numpy.diag(L)`` to three decimals.
    """
    # Build and evaluate the Theano version of the diagonal.
    vv = T.matrix()
    gg = T.matrix()
    hh = T.matrix()
    qq = T.matrix()
    # test generic_compute_L_diag
    LL = natural.generic_compute_L_diag([vv, gg, hh, qq])
    f = theano.function([vv, gg, hh, qq], LL)
    rvals = f(v, g, h, q)

    # Baseline: diag(L) holds the flattened (N0, N1), (N1, N2) and (N2, N3)
    # blocks in that order, followed by vectors of length N0, N1, N2, N3.
    Ldiag = numpy.diag(L)
    Ldiag_w = Ldiag[:N0 * N1].reshape(N0, N1)
    Ldiag_v = Ldiag[N0 * N1:N0 * N1 + N1 * N2].reshape(N1, N2)
    Ldiag_z = Ldiag[N0 * N1 + N1 * N2:N0 * N1 + N1 * N2 + N2 * N3].reshape(
        N2, N3)
    Ldiag_a = Ldiag[-N3 - N2 - N1 - N0:-N3 - N2 - N1]
    Ldiag_b = Ldiag[-N3 - N2 - N1:-N3 - N2]
    Ldiag_c = Ldiag[-N3 - N2:-N3]
    Ldiag_d = Ldiag[-N3:]
    numpy.testing.assert_almost_equal(Ldiag_w, rvals[0], decimal=3)
    numpy.testing.assert_almost_equal(Ldiag_v, rvals[1], decimal=3)
    numpy.testing.assert_almost_equal(Ldiag_z, rvals[2], decimal=3)
    numpy.testing.assert_almost_equal(Ldiag_a, rvals[3], decimal=3)
    numpy.testing.assert_almost_equal(Ldiag_b, rvals[4], decimal=3)
    numpy.testing.assert_almost_equal(Ldiag_c, rvals[5], decimal=3)
    # Ldiag_d was previously computed but never checked, leaving the last
    # vector term untested.
    numpy.testing.assert_almost_equal(Ldiag_d, rvals[6], decimal=3)
예제 #4
0
def test_generic_compute_Ldiag():
    """Compare ``natural.generic_compute_L_diag`` with the diagonal of ``L``.

    The Theano expression is built on four symbolic matrices, compiled, and
    evaluated on the module-level arrays ``v``, ``g``, ``h`` and ``q``. All
    seven returned terms are checked, in order, against the matching slices
    of ``numpy.diag(L)`` to three decimals.
    """
    # Build and evaluate the Theano version of the diagonal.
    vv = T.matrix()
    gg = T.matrix()
    hh = T.matrix()
    qq = T.matrix()
    # test generic_compute_L_diag
    LL = natural.generic_compute_L_diag([vv, gg, hh, qq])
    f = theano.function([vv, gg, hh, qq], LL)
    rvals = f(v, g, h, q)

    # Baseline: diag(L) holds the flattened (N0, N1), (N1, N2) and (N2, N3)
    # blocks in that order, followed by vectors of length N0, N1, N2, N3.
    Ldiag = numpy.diag(L)
    Ldiag_w = Ldiag[:N0 * N1].reshape(N0, N1)
    Ldiag_v = Ldiag[N0 * N1:N0 * N1 + N1 * N2].reshape(N1, N2)
    Ldiag_z = Ldiag[N0 * N1 + N1 * N2:N0 * N1 + N1 * N2 + N2 * N3].reshape(
        N2, N3)
    Ldiag_a = Ldiag[-N3 - N2 - N1 - N0:-N3 - N2 - N1]
    Ldiag_b = Ldiag[-N3 - N2 - N1:-N3 - N2]
    Ldiag_c = Ldiag[-N3 - N2:-N3]
    Ldiag_d = Ldiag[-N3:]

    # Index explicitly (rather than zip) so a missing returned term raises
    # IndexError instead of being silently skipped. Ldiag_d was previously
    # computed but never compared.
    expected_terms = [Ldiag_w, Ldiag_v, Ldiag_z,
                      Ldiag_a, Ldiag_b, Ldiag_c, Ldiag_d]
    for idx, expected in enumerate(expected_terms):
        numpy.testing.assert_almost_equal(expected, rvals[idx], decimal=3)