Example No. 1
        def solve_Ax_g(num_params, params, grads, grads_norm):
            def compute_Ax(x):

                # There are three ways to compute the Fisher-vector product:

                # 1. https://github.com/joschu/modular_rl/blob/master/modular_rl/trpo.py#L54
                # Use theano.gradient.disconnected_grad and call theano.tensor.grad() twice.
                # WARNING: in our case (with the attention mechanism) it is extremely slow.
                # (A sketch of this method follows this example.)

                # 2. http://deeplearning.net/software/theano/tutorial/gradients.html#hessian-times-a-vector
                # Use only theano.tensor.Rop, but fixed_output must then be computed
                # outside the compiled function, because disconnected_grad does not
                # work with Rop.

                # 3. https://github.com/pascanur/natgrad/blob/master/model_convMNIST_standard.py
                # Rop divided by the output, because the metric F is based on the
                # gradient of log(output). That code also splits the parameter vector;
                # not benchmarked, but it may be faster than passing several vectors
                # to minresQLP.

                # Unflatten the probe vector x into per-parameter tensors.
                xs = []
                offset = 0
                for p in params:
                    shape = p.get_value().shape
                    size = np.prod(shape)
                    xs.append(x[offset:offset + size].reshape(shape))
                    offset += size

                # Rop(output) / output approximates the R-op of log(output);
                # the L-op of that against the same output then yields the
                # Fisher-vector product (method 3 above).
                jvp = T.Rop(new_output, params, xs) / (
                    new_output * self.batch_size * self.history + TINY)
                fvp = T.Lop(new_output, params, jvp)
                fvp = T.concatenate([g.flatten() for g in fvp])

                return [fvp], {}

            # Solve F x = g / ||g|| with MinRes-QLP (damped for numerical
            # stability); the gradient is normalized to improve conditioning.
            rvals = minresQLP(compute_Ax,
                              grads / grads_norm,
                              num_params,
                              damp=DAMPING,
                              rtol=1e-10,
                              maxit=40,
                              TranCond=1)

            # Solver diagnostics: termination flag, residual norm, and an
            # estimate of the condition number of the system matrix.
            flag = T.cast(rvals[1], 'int32')
            residual = rvals[3]
            Acond = rvals[5]

            # Undo the gradient normalization, then rescale the step so the
            # quadratic KL estimate 0.5 * x^T F x equals MAX_KL.
            x = rvals[0] * grads_norm
            Ax = compute_Ax(x)[0][0] + DAMPING * x
            xAx = x.dot(Ax.T)

            lm = T.sqrt(2 * MAX_KL / xAx)
            rs = lm * x

            return rs, lm, flag, residual, Acond
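
A minimal sketch of method 1 from the comments above (the double-grad trick used in the modular_rl TRPO code), assuming `mean_kl` is a scalar KL divergence between a `disconnected_grad` copy of the output distribution and the live one, so that its Hessian at the current parameters is the Fisher matrix. The function name and arguments are illustrative, not part of the snippet above.

import theano.tensor as T


def fvp_double_grad(mean_kl, params, v):
    """Fisher-vector product F v via two backward passes.

    mean_kl : scalar KL(disconnected_grad(dist), dist).
    params  : list of shared parameter variables.
    v       : flat symbolic vector, size = total number of parameters.
    """
    grads = T.grad(mean_kl, params)
    flat_grad = T.concatenate([g.flatten() for g in grads])
    # The gradient of (grad(KL) . v) w.r.t. the parameters is F v.
    fvp = T.grad(T.sum(flat_grad * v), params)
    return T.concatenate([f.flatten() for f in fvp])

This avoids Rop entirely at the cost of a second backward pass through the full graph, which is what makes it slow with the attention mechanism mentioned above.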
Example No. 2
import time

import numpy
import theano
import theano.tensor as T


def test_minres():
    # L, g, nparams and vals are module-level test fixtures (minresQLP is
    # imported from the surrounding project); a sketch of how to build the
    # fixtures follows this example.
    sol, flag, iters, relres, Anorm, Acond = minresQLP(
        lambda x: ([T.dot(L, x)], {}),
        g,
        param_shapes=(nparams, ),
        rtol=1e-20,
        maxit=100000)

    f = theano.function([], [sol])
    t1 = time.time()
    rvals = f()
    Linv_g = rvals[0]
    print('test_minres runtime (s):', time.time() - t1)
    numpy.testing.assert_almost_equal(Linv_g, vals['Linv_g'], decimal=2)
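
A minimal, self-contained way to build the fixtures the test refers to, assuming a random symmetric positive-definite system (the actual test data may differ):

import numpy
import theano

nparams = 100
rng = numpy.random.RandomState(0)
A = rng.randn(nparams, nparams)
L_val = numpy.dot(A, A.T) + nparams * numpy.eye(nparams)  # SPD, well-conditioned
g_val = rng.randn(nparams)

L = theano.shared(L_val.astype(theano.config.floatX), name='L')
g = theano.shared(g_val.astype(theano.config.floatX), name='g')
# Reference solution L^{-1} g that the test compares against.
vals = {'Linv_g': numpy.linalg.solve(L_val, g_val)}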
Example No. 3
    def get_natural_direction(self,
                              ml_cost,
                              nsamples,
                              xinit=None,
                              precondition=None):
        """
        Returns: list
            See lincg documentation for the meaning of each return value.
            rvals[0]: niter
            rvals[1]: rerr
        """
        assert precondition in [None, 'jacobi']
        self.cg_params.setdefault('batch_size', self.batch_size)

        nsamples = nsamples[:self.cg_params['batch_size']]
        neg_energies = self.energy(nsamples)

        if self.computational_bs > 0:
            raise NotImplementedError()
        else:

            def Lx_func(*args):
                # minresQLP expects the operator to return (outputs, updates);
                # minres and lincg expect just the list of outputs (see the
                # sketch after this example).
                Lneg_x = fisher.compute_Lx(neg_energies, self.params, args)
                if self.flags['minresQLP']:
                    return Lneg_x, {}
                else:
                    return Lneg_x

        # Jacobi preconditioner: the damped diagonal of the Fisher matrix.
        M = None
        if precondition == 'jacobi':
            cnsamples = self.center_samples(nsamples)
            raw_M = fisher.compute_L_diag(cnsamples)
            M = [(Mi + self.cg_params['damp']) for Mi in raw_M]

        if self.flags['minres']:
            rvals = minres.minres(
                Lx_func, [ml_cost.grads[param] for param in self.params],
                rtol=self.cg_params['rtol'],
                maxiter=self.cg_params['maxiter'],
                damp=self.cg_params['damp'],
                xinit=xinit,
                Ms=M)
            [newgrads, flag, niter, rerr] = rvals[:4]
        elif self.flags['minresQLP']:
            param_shapes = []
            for p in self.params:
                param_shapes += [p.get_value().shape]
            rvals = minresQLP.minresQLP(
                Lx_func, [ml_cost.grads[param] for param in self.params],
                param_shapes,
                rtol=self.cg_params['rtol'],
                maxit=self.cg_params['maxiter'],
                damp=self.cg_params['damp'],
                Ms=M,
                profile=0)
            [newgrads, flag, niter, rerr] = rvals[:4]
        else:
            rvals = lincg.linear_cg(
                Lx_func, [ml_cost.grads[param] for param in self.params],
                rtol=self.cg_params['rtol'],
                damp=self.cg_params['damp'],
                maxiter=self.cg_params['maxiter'],
                xinit=xinit,
                M=M)
            [newgrads, niter, rerr] = rvals

        # Now replace grad with natural gradient.
        cos_dist = 0.
        norm2_old = 0.
        norm2_new = 0.
        for i, param in enumerate(self.params):
            norm2_old += T.sum(ml_cost.grads[param]**2)
            norm2_new += T.sum(newgrads[i]**2)
            cos_dist += T.dot(ml_cost.grads[param].flatten(),
                              newgrads[i].flatten())
            ml_cost.grads[param] = newgrads[i]
        # Normalize by the product of the norms (not the squared norms) so
        # cos_dist is a true cosine similarity in [-1, 1].
        cos_dist /= T.sqrt(norm2_old * norm2_new)

        return [niter, rerr, cos_dist], self.get_dparam_updates(*newgrads)
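
A minimal sketch of the two operator conventions `Lx_func` has to satisfy above: `minres.minres` and `lincg.linear_cg` expect just the list of matrix-vector products, while `minresQLP.minresQLP` expects a `(outputs, updates)` pair. The damping constant and the damped identity standing in for `fisher.compute_Lx` are assumptions for illustration only.

DAMP = 0.1  # assumed constant, standing in for cg_params['damp']


def op_for_minres(*xs):
    # minres / lincg convention: one output tensor per parameter block,
    # here (A + DAMP * I) x with A = I.
    return [(1.0 + DAMP) * x for x in xs]


def op_for_minresQLP(*xs):
    # minresQLP convention: same outputs plus a dict of theano updates
    # (empty here because the operator has no side effects).
    return [(1.0 + DAMP) * x for x in xs], {}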