def solve_Ax_g(num_params, params, grads, grads_norm):

    def compute_Ax(x):
        # There are three ways to compute the Fisher-vector product:
        # 1. https://github.com/joschu/modular_rl/blob/master/modular_rl/trpo.py#L54
        #    Use theano.gradient.disconnected_grad and call theano.tensor.grad() twice.
        #    WARNING: In our case (with the attention mechanism) it is extremely slow.
        # 2. http://deeplearning.net/software/theano/tutorial/gradients.html#hessian-times-a-vector
        #    Use only theano.tensor.Rop, but you will need to calculate the fixed_output
        #    outside of the compiled function, because disconnected_grad will not work with Rop.
        # 3. https://github.com/pascanur/natgrad/blob/master/model_convMNIST_standard.py
        #    Rop divided by the output, because the metric F is based on the gradient of
        #    log(output). Here we also split the vector of parameters. Not checked, but it
        #    may be faster than supplying several vectors to minresQLP.
        # Split the flat vector x into per-parameter blocks with the original shapes.
        xs = []
        offset = 0
        for p in params:
            shape = p.get_value().shape
            size = np.prod(shape)
            xs.append(x[offset:offset + size].reshape(shape))
            offset += size
        # J*x normalized by the output, then J^T applied to the result (option 3 above).
        jvp = T.Rop(new_output, params, xs) / (
            new_output * self.batch_size * self.history + TINY)
        fvp = T.Lop(new_output, params, jvp)
        fvp = T.concatenate([g.flatten() for g in fvp])
        return [fvp], {}

    # Solve F x = g on the normalized right-hand side.
    rvals = minresQLP(compute_Ax, grads / grads_norm, num_params,
                      damp=DAMPING, rtol=1e-10, maxit=40, TranCond=1)
    flag = T.cast(rvals[1], 'int32')
    residual = rvals[3]
    Acond = rvals[5]
    # Undo the normalization of the right-hand side.
    x = rvals[0] * grads_norm
    Ax = compute_Ax(x)[0][0] + DAMPING * x
    xAx = x.dot(Ax.T)
    # Scale the step so that the quadratic KL approximation 0.5 * x^T F x equals MAX_KL.
    lm = T.sqrt(2 * MAX_KL / xAx)
    rs = lm * x
    return rs, lm, flag, residual, Acond
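# A minimal sketch of option 1 from the comments above: the Fisher-vector
# product via double backpropagation through a KL divergence, as in the
# modular_rl TRPO code. This is an illustration, not part of the code above;
# `output` (a vector of probabilities) and the flat vector `vec` are
# hypothetical placeholders, only `params` matches the function above.
import theano
import theano.tensor as T


def fisher_vector_product(output, params, vec, tiny=1e-8):
    # KL between the current distribution and a copy through which gradients
    # do not flow; its Hessian w.r.t. the parameters is the Fisher matrix.
    fixed = theano.gradient.disconnected_grad(output)
    kl = T.sum(fixed * (T.log(fixed + tiny) - T.log(output + tiny)))
    # First backprop: gradient of the KL w.r.t. the parameters, flattened.
    kl_grads = T.grad(kl, params)
    flat_kl_grad = T.concatenate([g.flatten() for g in kl_grads])
    # Second backprop: gradient of (grad_KL . vec) gives F * vec.
    fvps = T.grad(T.sum(flat_kl_grad * vec), params)
    return T.concatenate([f.flatten() for f in fvps])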
def test_minres():
    sol, flag, iters, relres, Anorm, Acond = minresQLP(
        lambda x: ([T.dot(L, x)], {}),
        g,
        param_shapes=(nparams,),
        rtol=1e-20,
        maxit=100000)
    f = theano.function([], [sol])
    t1 = time.time()
    rvals = f()
    Linv_g = rvals[0]
    print 'test_minres runtime (s):', time.time() - t1
    numpy.testing.assert_almost_equal(Linv_g, vals['Linv_g'], decimal=2)
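# A possible setup for the globals that test_minres above relies on
# (L, g, nparams, vals); this is an illustrative sketch, not the original
# fixture. minresQLP expects a symmetric operator, so L is built as A*A^T
# plus a small diagonal term, and the reference solution Linv_g comes from
# a dense solve.
import numpy
import theano

rng = numpy.random.RandomState(123)
nparams = 100
A = rng.randn(nparams, nparams)
L_val = numpy.dot(A, A.T) + 0.1 * numpy.eye(nparams)

L = theano.shared(L_val.astype(theano.config.floatX), name='L')
g = theano.shared(rng.randn(nparams).astype(theano.config.floatX), name='g')
vals = {'Linv_g': numpy.linalg.solve(L_val, g.get_value())}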
def get_natural_direction(self, ml_cost, nsamples, xinit=None,
                          precondition=None):
    """
    Returns: list
        See lincg documentation for the meaning of each return value.
        rvals[0]: niter
        rvals[1]: rerr
    """
    assert precondition in [None, 'jacobi']
    self.cg_params.setdefault('batch_size', self.batch_size)
    nsamples = nsamples[:self.cg_params['batch_size']]
    neg_energies = self.energy(nsamples)

    if self.computational_bs > 0:
        raise NotImplementedError()
    else:
        def Lx_func(*args):
            Lneg_x = fisher.compute_Lx(neg_energies, self.params, args)
            if self.flags['minresQLP']:
                return Lneg_x, {}
            else:
                return Lneg_x

    M = None
    if precondition == 'jacobi':
        # Jacobi preconditioner: damped diagonal of the Fisher matrix.
        cnsamples = self.center_samples(nsamples)
        raw_M = fisher.compute_L_diag(cnsamples)
        M = [(Mi + self.cg_params['damp']) for Mi in raw_M]

    if self.flags['minres']:
        rvals = minres.minres(
            Lx_func,
            [ml_cost.grads[param] for param in self.params],
            rtol=self.cg_params['rtol'],
            maxiter=self.cg_params['maxiter'],
            damp=self.cg_params['damp'],
            xinit=xinit,
            Ms=M)
        [newgrads, flag, niter, rerr] = rvals[:4]
    elif self.flags['minresQLP']:
        param_shapes = [p.get_value().shape for p in self.params]
        rvals = minresQLP.minresQLP(
            Lx_func,
            [ml_cost.grads[param] for param in self.params],
            param_shapes,
            rtol=self.cg_params['rtol'],
            maxit=self.cg_params['maxiter'],
            damp=self.cg_params['damp'],
            Ms=M,
            profile=0)
        [newgrads, flag, niter, rerr] = rvals[:4]
    else:
        rvals = lincg.linear_cg(
            Lx_func,
            [ml_cost.grads[param] for param in self.params],
            rtol=self.cg_params['rtol'],
            damp=self.cg_params['damp'],
            maxiter=self.cg_params['maxiter'],
            xinit=xinit,
            M=M)
        [newgrads, niter, rerr] = rvals

    # Now replace grad with natural gradient.
    cos_dist = 0.
    norm2_old = 0.
    norm2_new = 0.
    for i, param in enumerate(self.params):
        norm2_old += T.sum(ml_cost.grads[param]**2)
        norm2_new += T.sum(newgrads[i]**2)
        cos_dist += T.dot(ml_cost.grads[param].flatten(),
                          newgrads[i].flatten())
        ml_cost.grads[param] = newgrads[i]
    # Cosine similarity divides by the product of the norms, not of the squared norms.
    cos_dist /= T.sqrt(norm2_old * norm2_new)

    return [niter, rerr, cos_dist], self.get_dparam_updates(*newgrads)
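# Illustration (not part of the original code) of what get_natural_direction
# computes, written with an explicit Fisher matrix. In the real code the Fisher
# matrix is never formed; Lx_func only supplies matrix-vector products to
# minres / minresQLP / linear_cg. All names below are local to this sketch.
import numpy

rng = numpy.random.RandomState(0)
n = 5
J = rng.randn(20, n)
F = numpy.dot(J.T, J) / 20 + 1e-3 * numpy.eye(n)   # damped Fisher estimate
grad = rng.randn(n)                                # ordinary gradient

nat_grad = numpy.linalg.solve(F, grad)             # natural gradient F^{-1} g

# Cosine similarity between the two directions, the cos_dist diagnostic above.
cos_dist = numpy.dot(grad, nat_grad) / numpy.sqrt(
    numpy.dot(grad, grad) * numpy.dot(nat_grad, nat_grad))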