Example No. 1
def test_minres_with_jacobi():
    vv = theano.shared(v, name='v')
    gg = theano.shared(g, name='g')
    hh = theano.shared(h, name='h')
    dw = T.dot(v.T,g) / M
    dv = T.dot(g.T,h) / M
    da = T.mean(v, axis=0)
    db = T.mean(g, axis=0)
    dc = T.mean(h, axis=0)
   
    Ldiag_terms = natural.generic_compute_L_diag([vv,gg,hh])
    Ms = [Ldiag_term + 0.1 for Ldiag_term in Ldiag_terms]

    newgrads = minres.minres(
            lambda xw, xv, xa, xb, xc: natural.compute_Lx(vv,gg,hh,xw,xv,xa,xb,xc),
            [dw, dv, da, db, dc],
            rtol=1e-5,
            damp = 0.,
            maxiter = 10000,
            Ms = Ms,
            profile=0)[0]

    f = theano.function([], newgrads)
    [new_dw, new_dv, new_da, new_db, new_dc] = f()
    numpy.testing.assert_almost_equal(Linv_x_w, new_dw, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_v, new_dv, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_a, new_da, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_b, new_db, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_c, new_dc, decimal=1)
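
These tests rely on module-level fixtures (the arrays v, g, h, the batch size M, the layer sizes N0/N1/N2 and the reference solutions Linv_x_*) that are defined elsewhere in the test module. As a minimal, self-contained illustration of the same Jacobi-preconditioning idea, the sketch below uses SciPy's minres instead of the Theano implementation exercised above; the matrix L, the right-hand side b and the damping constant are illustrative assumptions, not values from the original tests.

import numpy
import scipy.sparse.linalg as spla

# Illustrative system (assumed, not from the original tests): a symmetric
# positive-definite matrix L and a right-hand side b.
rng = numpy.random.RandomState(0)
n = 200
A = rng.rand(n, n)
L = numpy.dot(A, A.T) + n * numpy.eye(n)
b = rng.rand(n)

# Jacobi preconditioner: the diagonal of L plus a small damping constant,
# mirroring the `Ldiag_term + 0.1` terms above.  SciPy expects M to act as
# an approximate inverse of L, hence the division.
diag = numpy.diag(L) + 0.1
M = spla.LinearOperator((n, n), matvec=lambda z: z / diag)

x, info = spla.minres(L, b, M=M, maxiter=10000)
assert info == 0  # 0 signals convergence to the requested tolerance
numpy.testing.assert_allclose(x, numpy.linalg.solve(L, b), atol=1e-3)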
Example No. 2
def test_minres_with_jacobi():
    vv = theano.shared(v, name='v')
    gg = theano.shared(g, name='g')
    hh = theano.shared(h, name='h')
    dw = T.dot(v.T, g) / M
    dv = T.dot(g.T, h) / M
    da = T.mean(v, axis=0)
    db = T.mean(g, axis=0)
    dc = T.mean(h, axis=0)

    Ldiag_terms = natural.generic_compute_L_diag([vv, gg, hh])
    Ms = [Ldiag_term + 0.1 for Ldiag_term in Ldiag_terms]

    newgrads = minres.minres(lambda xw, xv, xa, xb, xc: natural.compute_Lx(
        vv, gg, hh, xw, xv, xa, xb, xc), [dw, dv, da, db, dc],
                             rtol=1e-5,
                             damp=0.,
                             maxiter=10000,
                             Ms=Ms,
                             profile=0)[0]

    f = theano.function([], newgrads)
    [new_dw, new_dv, new_da, new_db, new_dc] = f()
    numpy.testing.assert_almost_equal(Linv_x_w, new_dw, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_v, new_dv, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_a, new_da, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_b, new_db, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_c, new_dc, decimal=1)
Example No. 3
def test_minres_with_xinit():
    rng = numpy.random.RandomState(123412)

    vv = theano.shared(v, name='v')
    gg = theano.shared(g, name='g')
    hh = theano.shared(h, name='h')
    dw = T.dot(v.T,g) / M
    dv = T.dot(g.T,h) / M
    da = T.mean(v, axis=0)
    db = T.mean(g, axis=0)
    dc = T.mean(h, axis=0)
  
    xinit = [ rng.rand(N0,N1),
              rng.rand(N1,N2),
              rng.rand(N0),
              rng.rand(N1),
              rng.rand(N2)]
    xinit = [xi.astype(floatX) for xi in xinit]

    newgrads = minres.minres(
            lambda xw, xv, xa, xb, xc: natural.compute_Lx(vv,gg,hh,xw,xv,xa,xb,xc),
            [dw, dv, da, db, dc],
            rtol=1e-5,
            damp = 0.,
            maxiter = 10000,
            xinit = xinit,
            profile=0)[0]

    f = theano.function([], newgrads)
    [new_dw, new_dv, new_da, new_db, new_dc] = f()
    numpy.testing.assert_almost_equal(Linv_x_w, new_dw, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_v, new_dv, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_a, new_da, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_b, new_db, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_c, new_dc, decimal=1)
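
The xinit argument above warm-starts the solver from a random initial guess. The sketch below shows the same idea in SciPy terms (illustrative only, not the Theano API used above): SciPy's minres accepts an initial guess x0 that plays the role of xinit, and the system L, b is an assumed stand-in for the fixtures of the original test.

import numpy
import scipy.sparse.linalg as spla

# Illustrative symmetric positive-definite system (assumed, not taken from
# the test module).
rng = numpy.random.RandomState(123412)
n = 100
A = rng.rand(n, n)
L = numpy.dot(A, A.T) + n * numpy.eye(n)
b = rng.rand(n)

x0 = rng.rand(n)  # plays the same role as the random xinit blocks above
x, info = spla.minres(L, b, x0=x0)
assert info == 0
numpy.testing.assert_allclose(x, numpy.linalg.solve(L, b), atol=1e-3)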
Example No. 4
def test_minres():
    vv = theano.shared(v, name='v')
    gg = theano.shared(g, name='g')
    hh = theano.shared(h, name='h')
    dw = T.dot(v.T,g) / M
    dv = T.dot(g.T,h) / M
    da = T.mean(v, axis=0)
    db = T.mean(g, axis=0)
    dc = T.mean(h, axis=0)
   
    newgrads = minres.minres(
            lambda xw, xv, xa, xb, xc: natural.compute_Lx(vv,gg,hh,xw,xv,xa,xb,xc),
            [dw, dv, da, db, dc],
            rtol=1e-5,
            damp = 0.,
            maxit = 30,
            profile=0)[0]

    f = theano.function([], newgrads)
    [new_dw, new_dv, new_da, new_db, new_dc] = f()
    numpy.testing.assert_almost_equal(Linv_x_w, new_dw, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_v, new_dv, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_a, new_da, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_b, new_db, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_c, new_dc, decimal=1)
Example No. 5
def test_minres():
    rval = minres.minres(lambda x: [T.dot(symb['L'], x)], [symb['g']],
                         rtol=1e-14,
                         damp=0.,
                         maxiter=10000,
                         profile=0)

    f = theano.function([symb['L'], symb['g']], [rval[0][0], rval[1], rval[2]])
    t1 = time.time()
    [Linv_g, flag, iter] = f(vals['L'], vals['g'])
    print('test_minres runtime (s):', time.time() - t1)
    numpy.testing.assert_almost_equal(Linv_g, vals['Linv_g'], decimal=3)
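
This test (and the xinit variants below) assumes module-level fixtures named symb, vals, rng, nparams and floatX. A hedged sketch of how they could be set up follows; only the names come from the test bodies, while the concrete construction (a dense symmetric positive-definite L and a reference solution Linv_g obtained with numpy.linalg.solve) is an assumption for illustration.

import numpy
import theano
import theano.tensor as T

floatX = theano.config.floatX
rng = numpy.random.RandomState(23091)   # seed is arbitrary
nparams = 100

symb, vals = {}, {}
symb['L'] = T.matrix('L')               # symbolic system matrix
symb['g'] = T.vector('g')               # symbolic right-hand side

A = rng.rand(nparams, nparams).astype(floatX)
vals['L'] = numpy.dot(A, A.T) + nparams * numpy.eye(nparams, dtype=floatX)
vals['g'] = rng.rand(nparams).astype(floatX)
# Dense reference solution that the MINRES output is compared against.
vals['Linv_g'] = numpy.linalg.solve(vals['L'], vals['g'])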
Example No. 6
def test_minres():
    rval = minres.minres(
            lambda x: [T.dot(symb['L'], x)],
            [symb['g']],
            rtol=1e-14,
            damp = 0.,
            maxiter = 10000,
            profile=0)

    f = theano.function([symb['L'], symb['g']], [rval[0][0], rval[1], rval[2]])
    t1 = time.time()
    [Linv_g, flag, iter] = f(vals['L'], vals['g'])
    print('test_minres runtime (s):', time.time() - t1)
    numpy.testing.assert_almost_equal(Linv_g, vals['Linv_g'], decimal=3)
Example No. 7
def test_minres_xinit():
    symb['xinit'] = T.vector('xinit')
    vals['xinit'] = rng.rand(nparams).astype(floatX)

    symb_Linv_g = minres.minres(lambda x: [T.dot(symb['L'], x)], [symb['g']],
                                rtol=1e-14,
                                damp=0.,
                                maxiter=10000,
                                xinit=[symb['xinit']],
                                profile=0)[0]

    f = theano.function([symb['L'], symb['g'], symb['xinit']], symb_Linv_g)
    t1 = time.time()
    Linv_g = f(vals['L'], vals['g'], vals['xinit'])[0]
    print('test_minres_xinit runtime (s):', time.time() - t1)
    numpy.testing.assert_almost_equal(Linv_g, vals['Linv_g'], decimal=3)
Example No. 8
def test_minres_xinit():
    symb['xinit'] = T.vector('xinit')
    vals['xinit'] = rng.rand(nparams).astype(floatX)

    symb_Linv_g = minres.minres(
            lambda x: [T.dot(symb['L'], x)],
            [symb['g']],
            rtol=1e-14,
            damp = 0.,
            maxiter = 10000,
            xinit = [symb['xinit']],
            profile=0)[0]

    f = theano.function([symb['L'], symb['g'], symb['xinit']], symb_Linv_g)
    t1 = time.time()
    Linv_g = f(vals['L'], vals['g'], vals['xinit'])[0]
    print('test_minres_xinit runtime (s):', time.time() - t1)
    numpy.testing.assert_almost_equal(Linv_g, vals['Linv_g'], decimal=3)
Example No. 9
def test_minres_with_xinit():
    rng = numpy.random.RandomState(123412)

    vv = theano.shared(v, name='v')
    gg = theano.shared(g, name='g')
    hh = theano.shared(h, name='h')
    dw = T.dot(v.T, g) / M
    dv = T.dot(g.T, h) / M
    da = T.mean(v, axis=0)
    db = T.mean(g, axis=0)
    dc = T.mean(h, axis=0)

    xinit = [
        rng.rand(N0, N1),
        rng.rand(N1, N2),
        rng.rand(N0),
        rng.rand(N1),
        rng.rand(N2)
    ]
    xinit = [xi.astype(floatX) for xi in xinit]

    newgrads = minres.minres(lambda xw, xv, xa, xb, xc: natural.compute_Lx(
        vv, gg, hh, xw, xv, xa, xb, xc), [dw, dv, da, db, dc],
                             rtol=1e-5,
                             damp=0.,
                             maxiter=10000,
                             xinit=xinit,
                             profile=0)[0]

    f = theano.function([], newgrads)
    [new_dw, new_dv, new_da, new_db, new_dc] = f()
    numpy.testing.assert_almost_equal(Linv_x_w, new_dw, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_v, new_dv, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_a, new_da, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_b, new_db, decimal=1)
    numpy.testing.assert_almost_equal(Linv_x_c, new_dc, decimal=1)
Example No. 10
    def get_natural_direction(self,
                              ml_cost,
                              nsamples,
                              xinit=None,
                              precondition=None):
        """
        Returns: list
            See lincg documentation for the meaning of each return value.
            rvals[0]: niter
            rvals[1]: rerr
        """
        assert precondition in [None, 'jacobi']
        self.cg_params.setdefault('batch_size', self.batch_size)

        nsamples = nsamples[:self.cg_params['batch_size']]
        neg_energies = self.energy(nsamples)

        if self.computational_bs > 0:
            raise NotImplementedError()
        else:

            def Lx_func(*args):
                Lneg_x = fisher.compute_Lx(neg_energies, self.params, args)
                if self.flags['minresQLP']:
                    return Lneg_x, {}
                else:
                    return Lneg_x

        M = None
        if precondition == 'jacobi':
            cnsamples = self.center_samples(nsamples)
            raw_M = fisher.compute_L_diag(cnsamples)
            M = [(Mi + self.cg_params['damp']) for Mi in raw_M]

        if self.flags['minres']:
            rvals = minres.minres(
                Lx_func, [ml_cost.grads[param] for param in self.params],
                rtol=self.cg_params['rtol'],
                maxiter=self.cg_params['maxiter'],
                damp=self.cg_params['damp'],
                xinit=xinit,
                Ms=M)
            [newgrads, flag, niter, rerr] = rvals[:4]
        elif self.flags['minresQLP']:
            param_shapes = []
            for p in self.params:
                param_shapes += [p.get_value().shape]
            rvals = minresQLP.minresQLP(
                Lx_func, [ml_cost.grads[param] for param in self.params],
                param_shapes,
                rtol=self.cg_params['rtol'],
                maxit=self.cg_params['maxiter'],
                damp=self.cg_params['damp'],
                Ms=M,
                profile=0)
            [newgrads, flag, niter, rerr] = rvals[:4]
        else:
            rvals = lincg.linear_cg(
                Lx_func, [ml_cost.grads[param] for param in self.params],
                rtol=self.cg_params['rtol'],
                damp=self.cg_params['damp'],
                maxiter=self.cg_params['maxiter'],
                xinit=xinit,
                M=M)
            [newgrads, niter, rerr] = rvals

        # Now replace grad with natural gradient.
        cos_dist = 0.
        norm2_old = 0.
        norm2_new = 0.
        for i, param in enumerate(self.params):
            norm2_old += T.sum(ml_cost.grads[param]**2)
            norm2_new += T.sum(newgrads[i]**2)
            cos_dist += T.dot(ml_cost.grads[param].flatten(),
                              newgrads[i].flatten())
            ml_cost.grads[param] = newgrads[i]
        cos_dist /= (norm2_old * norm2_new)

        return [niter, rerr, cos_dist], self.get_dparam_updates(*newgrads)
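
For reference, the solver configuration this method reads is spread across self.cg_params and self.flags. The sketch below collects the expected keys in one place; the concrete values are illustrative assumptions, only the key names are taken from the method body above.

cg_params = {
    'rtol': 1e-4,        # relative tolerance for minres / minresQLP / lincg
    'maxiter': 100,      # iteration cap (passed as maxit to minresQLP)
    'damp': 0.1,         # damping added to the system and to the Jacobi preconditioner
    # 'batch_size' falls back to self.batch_size via setdefault() above
}
flags = {
    'minres': True,      # use MINRES
    'minresQLP': False,  # use MINRES-QLP; Lx_func then also returns an updates dict
}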
Example No. 11
    def get_natural_direction(self, ml_cost, nsamples, xinit=None,
                              precondition=None):
        """
        Returns: list
            See lincg documentation for the meaning of each return value.
            rvals[0]: niter
            rvals[1]: rerr
        """
        assert precondition in [None, 'jacobi']
        self.cg_params.setdefault('batch_size', self.batch_size)

        nsamples = nsamples[:self.cg_params['batch_size']]
        neg_energies = self.energy(nsamples)

        if self.computational_bs > 0:
            raise NotImplementedError()
        else:
            def Lx_func(*args):
                Lneg_x = fisher.compute_Lx(
                        neg_energies,
                        self.params,
                        args)
                if self.flags['minresQLP']:
                    return Lneg_x, {}
                else:
                    return Lneg_x

        M = None
        if precondition == 'jacobi':
            cnsamples = self.center_samples(nsamples)
            raw_M = fisher.compute_L_diag(cnsamples)
            M = [(Mi + self.cg_params['damp']) for Mi in raw_M]

        if self.flags['minres']:
            rvals = minres.minres(
                    Lx_func,
                    [ml_cost.grads[param] for param in self.params],
                    rtol = self.cg_params['rtol'],
                    maxiter = self.cg_params['maxiter'],
                    damp = self.cg_params['damp'],
                    xinit = xinit,
                    Ms = M)
            [newgrads, flag, niter, rerr] = rvals[:4]
        elif self.flags['minresQLP']:
            param_shapes = []
            for p in self.params:
                param_shapes += [p.get_value().shape]
            rvals = minresQLP.minresQLP(
                    Lx_func,
                    [ml_cost.grads[param] for param in self.params],
                    param_shapes,
                    rtol = self.cg_params['rtol'],
                    maxit = self.cg_params['maxiter'],
                    damp = self.cg_params['damp'],
                    Ms = M,
                    profile = 0)
            [newgrads, flag, niter, rerr] = rvals[:4]
        else:
            rvals = lincg.linear_cg(
                    Lx_func,
                    [ml_cost.grads[param] for param in self.params],
                    rtol = self.cg_params['rtol'],
                    damp = self.cg_params['damp'],
                    maxiter = self.cg_params['maxiter'],
                    xinit = xinit,
                    M = M)
            [newgrads, niter, rerr] = rvals

        # Now replace grad with natural gradient.
        cos_dist  = 0.
        norm2_old = 0.
        norm2_new = 0.
        for i, param in enumerate(self.params):
            norm2_old += T.sum(ml_cost.grads[param]**2)
            norm2_new += T.sum(newgrads[i]**2)
            cos_dist += T.dot(ml_cost.grads[param].flatten(),
                              newgrads[i].flatten())
            ml_cost.grads[param] = newgrads[i]
        cos_dist /= (norm2_old * norm2_new)
        
        return [niter, rerr, cos_dist], self.get_dparam_updates(*newgrads)