def run(self, num_iter, dummy=None):
    """Run gradient descent with momentum for num_iter iterations."""

    if num_iter == 0:
        return

    lossvec = None
    if self.debug:
        lossvec = torch.zeros(num_iter + 1)
        grad_mags = torch.zeros(num_iter + 1)

    for i in range(num_iter):
        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        # Compute loss
        loss = self.problem.ip_output(self.f0, self.f0)

        # Compute grad
        grad = TensorList(torch.autograd.grad(loss, self.x))

        # Update direction
        if self.dir is None:
            self.dir = grad
        else:
            self.dir = grad + self.momentum * self.dir

        self.x.vdetach_()
        self.x -= self.step_length * self.dir

        if self.debug:
            lossvec[i] = loss.item()
            grad_mags[i] = sum(grad.view(-1) @ grad.view(-1)).sqrt().item()

    if self.debug:
        # Evaluate loss and gradient magnitude at the final estimate.
        self.x.requires_grad_(True)
        self.f0 = self.problem(self.x)
        loss = self.problem.ip_output(self.f0, self.f0)
        grad = TensorList(torch.autograd.grad(loss, self.x))
        lossvec[-1] = loss.item()
        grad_mags[-1] = sum(grad.view(-1) @ grad.view(-1)).cpu().sqrt().item()
        self.losses = torch.cat((self.losses, lossvec))
        self.gradient_mags = torch.cat((self.gradient_mags, grad_mags))
        if self.plotting:
            plot_graph(self.losses, self.fig_num[0], title='Loss')
            plot_graph(self.gradient_mags, self.fig_num[1], title='Gradient magnitude')

    self.x.vdetach_()
    self.clear_temp()
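# The update above is classic heavy-ball gradient descent: the search direction
# accumulates past gradients scaled by the momentum coefficient. Below is a
# minimal, self-contained sketch of the same update on a plain tensor; the
# quadratic objective and all constants are illustrative stand-ins, not part of
# the class above.
import torch

x = torch.tensor([2.0, -3.0], requires_grad=True)
direction = None
momentum, step_length = 0.5, 0.2

for _ in range(100):
    loss = (x ** 2).sum()                 # stand-in for problem.ip_output(f0, f0)
    grad, = torch.autograd.grad(loss, x)
    # First iteration uses the raw gradient, afterwards the momentum-filtered one.
    direction = grad if direction is None else grad + momentum * direction
    with torch.no_grad():
        x -= step_length * direction      # mirrors self.x -= self.step_length * self.dir

print(x)  # close to the minimizer at the origin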
def run(self, num_cg_iter):
    """Run the optimizer with the provided number of iterations."""

    if num_cg_iter == 0:
        return

    lossvec = None
    if self.debug:
        lossvec = torch.zeros(2)

    self.x.requires_grad_(True)

    # Evaluate function at current estimate
    self.f0 = self.problem(self.x)

    # Create copy with graph detached
    self.g = self.f0.vdetach()

    if self.debug:
        lossvec[0] = self.problem.ip_output(self.g, self.g)

    self.g.requires_grad_(True)

    # Get df/dx^t @ f0
    self.dfdxt_g = TensorList(torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))

    # Get the right hand side
    self.b = -self.dfdxt_g.vdetach()

    # Run CG
    delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

    self.x.vdetach_()
    self.x.plus_(delta_x)

    if self.debug:
        self.f0 = self.problem(self.x)
        lossvec[-1] = self.problem.ip_output(self.f0, self.f0)
        self.residuals = torch.cat((self.residuals, res))
        self.losses = torch.cat((self.losses, lossvec))
        if self.plotting:
            plot_graph(self.losses, self.fig_num[0], title='Loss')
            plot_graph(self.residuals, self.fig_num[1], title='CG residuals')

    self.x.vdetach_()
    self.clear_temp()
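# The pair of autograd calls above is the standard matrix-free Gauss-Newton
# trick: torch.autograd.grad(f0, x, g, create_graph=True) is a vector-Jacobian
# product J(x)^T g, and since that result is linear in g, differentiating it
# again with respect to g recovers Jacobian-vector products J p. A minimal
# sketch with plain tensors follows; the residual function and all names are
# illustrative, and run_CG presumably applies the final product A p = J^T J p
# at every CG step.
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
f0 = torch.stack([x[0] ** 2 + x[1], x[0] * x[1]])   # stand-in residuals f(x)
g = f0.detach().requires_grad_(True)                # detached copy, like self.g

# Vector-Jacobian product J^T g, kept differentiable w.r.t. g.
dfdxt_g, = torch.autograd.grad(f0, x, g, create_graph=True)
b = -dfdxt_g.detach()                               # RHS of the normal equations J^T J dx = -J^T f

# Matrix-vector product A p = J^T (J p), without ever forming J explicitly:
p = torch.tensor([0.5, -1.0])                       # an arbitrary CG direction
Jp, = torch.autograd.grad(dfdxt_g, g, p, retain_graph=True)  # J p, by linearity in g
Ap, = torch.autograd.grad(f0, x, Jp, retain_graph=True)      # J^T (J p)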
def run(self, num_cg_iter, num_gn_iter=None):
    """Run the optimizer.

    args:
        num_cg_iter: Number of CG iterations per GN iteration. If a list, each
            entry gives the number of CG iterations for one GN iteration, and
            the number of GN iterations is the length of the list.
        num_gn_iter: Number of GN iterations. Shall only be given if
            num_cg_iter is an integer.
    """

    if isinstance(num_cg_iter, int):
        if num_gn_iter is None:
            raise ValueError('Must specify number of GN iter if CG iter is constant')
        num_cg_iter = [num_cg_iter] * num_gn_iter

    num_gn_iter = len(num_cg_iter)
    if num_gn_iter == 0:
        return

    if self.analyze_convergence:
        self.evaluate_CG_iteration(0)

    # Outer loop for running the GN iterations.
    for cg_iter in num_cg_iter:
        self.run_GN_iter(cg_iter)

    if self.debug:
        if not self.analyze_convergence:
            self.f0 = self.problem(self.x)
            loss = self.problem.ip_output(self.f0, self.f0)
            self.losses = torch.cat((self.losses, loss.detach().cpu().view(-1)))

        if self.plotting:
            plot_graph(self.losses, self.fig_num[0], title='Loss')
            plot_graph(self.residuals, self.fig_num[1], title='CG residuals')
            if self.analyze_convergence:
                plot_graph(self.gradient_mags, self.fig_num[2], title='Gradient magnitude')

    self.x.vdetach_()
    self.clear_temp()

    return self.losses, self.residuals
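# Both calling conventions accepted by the signature above, shown for a
# hypothetical, already-constructed optimizer; the first two calls are
# equivalent, and passing an int without num_gn_iter raises the ValueError above.
losses, residuals = optimizer.run(num_cg_iter=5, num_gn_iter=3)   # constant CG budget
losses, residuals = optimizer.run(num_cg_iter=[5, 5, 5])          # same thing, explicit list
losses, residuals = optimizer.run(num_cg_iter=[10, 5, 2])         # per-GN-iteration schedule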
def run(self, num_cg_iter, num_newton_iter=None):
    """Run the Newton optimizer. Arguments mirror those of the GN variant
    above, except that num_newton_iter defaults to 1 when num_cg_iter is an
    integer."""

    if isinstance(num_cg_iter, int):
        if num_cg_iter == 0:
            return
        if num_newton_iter is None:
            num_newton_iter = 1
        num_cg_iter = [num_cg_iter] * num_newton_iter

    num_newton_iter = len(num_cg_iter)
    if num_newton_iter == 0:
        return

    if self.analyze_convergence:
        self.evaluate_CG_iteration(0)

    for cg_iter in num_cg_iter:
        self.run_newton_iter(cg_iter)
        self.hessian_reg *= self.hessian_reg_factor

    if self.debug:
        if not self.analyze_convergence:
            loss = self.problem(self.x)
            self.losses = torch.cat((self.losses, loss.detach().cpu().view(-1)))

        if self.plotting:
            plot_graph(self.losses, self.fig_num[0], title='Loss')
            plot_graph(self.residuals, self.fig_num[1], title='CG residuals')
            if self.analyze_convergence:
                plot_graph(self.gradient_mags, self.fig_num[2], title='Gradient magnitude')

    self.x.vdetach_()
    self.clear_temp()

    return self.losses, self.residuals
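# Unlike the Gauss-Newton variant above, this loop also anneals the Hessian
# regularization by a constant factor after every Newton iteration. A small
# sketch of that schedule with illustrative values, assuming (as is usual for
# damped Newton) that run_newton_iter solves a system of the form
# (H + hessian_reg * I) dx = -g:
hessian_reg, hessian_reg_factor = 1e-2, 0.5

for cg_iter in [10, 10, 10, 10]:
    # ... run_newton_iter(cg_iter) would run CG on the damped system here ...
    hessian_reg *= hessian_reg_factor   # same decay as self.hessian_reg *= self.hessian_reg_factor

print(hessian_reg)  # 1e-2 * 0.5 ** 4 = 6.25e-04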