def run(self, num_iter, dummy=None):
    """Run gradient descent with momentum for num_iter iterations."""
    if num_iter == 0:
        return

    lossvec = None
    if self.debug:
        lossvec = torch.zeros(num_iter + 1)
        grad_mags = torch.zeros(num_iter + 1)

    for i in range(num_iter):
        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        # Compute loss
        loss = self.problem.ip_output(self.f0, self.f0)

        # Compute grad
        grad = TensorList(torch.autograd.grad(loss, self.x))

        # Update direction
        if self.dir is None:
            self.dir = grad
        else:
            self.dir = grad + self.momentum * self.dir

        self.x.vdetach_()
        self.x -= self.step_legnth * self.dir

        if self.debug:
            lossvec[i] = loss.item()
            grad_mags[i] = sum(grad.view(-1) @ grad.view(-1)).sqrt().item()

    if self.debug:
        self.x.requires_grad_(True)
        self.f0 = self.problem(self.x)
        loss = self.problem.ip_output(self.f0, self.f0)
        grad = TensorList(torch.autograd.grad(loss, self.x))
        lossvec[-1] = loss.item()
        grad_mags[-1] = sum(
            grad.view(-1) @ grad.view(-1)).cpu().sqrt().item()
        self.losses = torch.cat((self.losses, lossvec))
        self.gradient_mags = torch.cat((self.gradient_mags, grad_mags))

        if self.plotting:
            plot_graph(self.losses, self.fig_num[0], title='Loss')
            plot_graph(self.gradient_mags, self.fig_num[1],
                       title='Gradient magnitude')

    self.x.vdetach_()
    self.clear_temp()
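# The loop above implements gradient descent with momentum (the heavy-ball
# method) on the squared residual norm L(x) = ip_output(f(x), f(x)). As a
# sketch, assuming ip_output is an inner product, each iteration computes
#
#   d_k     = g_k + momentum * d_{k-1}    # g_k = dL/dx at x_k
#   x_{k+1} = x_k - step_length * d_k     # step size stored as self.step_legnth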
class NewtonCG(ConjugateGradientBase):
    """Newton with Conjugate Gradient. Handles general minimization problems."""

    def __init__(self, problem: MinimizationProblem, variable: TensorList,
                 init_hessian_reg=0.0, hessian_reg_factor=1.0, cg_eps=0.0,
                 fletcher_reeves=True, standard_alpha=True,
                 direction_forget_factor=0, debug=False, analyze=False,
                 plotting=False, fig_num=(10, 11, 12)):
        super().__init__(fletcher_reeves, standard_alpha,
                         direction_forget_factor, debug or analyze or plotting)

        self.problem = problem
        self.x = variable

        self.analyze_convergence = analyze
        self.plotting = plotting
        self.fig_num = fig_num

        self.hessian_reg = init_hessian_reg
        self.hessian_reg_factor = hessian_reg_factor
        self.cg_eps = cg_eps
        self.f0 = None
        self.g = None

        self.residuals = torch.zeros(0)
        self.losses = torch.zeros(0)
        self.gradient_mags = torch.zeros(0)

    def clear_temp(self):
        self.f0 = None
        self.g = None

    def run(self, num_cg_iter, num_newton_iter=None):
        """Run the optimizer with the given number of CG iterations per Newton iteration."""
        if isinstance(num_cg_iter, int):
            if num_cg_iter == 0:
                return
            if num_newton_iter is None:
                num_newton_iter = 1
            num_cg_iter = [num_cg_iter] * num_newton_iter

        num_newton_iter = len(num_cg_iter)
        if num_newton_iter == 0:
            return

        if self.analyze_convergence:
            self.evaluate_CG_iteration(0)

        for cg_iter in num_cg_iter:
            self.run_newton_iter(cg_iter)
            self.hessian_reg *= self.hessian_reg_factor

        if self.debug:
            if not self.analyze_convergence:
                loss = self.problem(self.x)
                self.losses = torch.cat(
                    (self.losses, loss.detach().cpu().view(-1)))

            if self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.residuals, self.fig_num[1],
                           title='CG residuals')
                if self.analyze_convergence:
                    plot_graph(self.gradient_mags, self.fig_num[2],
                               'Gradient magnitude')

        self.x.vdetach_()
        self.clear_temp()

        return self.losses, self.residuals

    def run_newton_iter(self, num_cg_iter):
        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        if self.debug and not self.analyze_convergence:
            self.losses = torch.cat(
                (self.losses, self.f0.detach().cpu().view(-1)))

        # Gradient of loss
        self.g = TensorList(
            torch.autograd.grad(self.f0, self.x, create_graph=True))

        # Get the right hand side
        self.b = -self.g.vdetach()

        # Run CG
        delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

        self.x.vdetach_()
        self.x.plus_(delta_x)

        if self.debug:
            self.residuals = torch.cat((self.residuals, res))

    def A(self, x):
        # Hessian-vector product H x, obtained by differentiating the
        # gradient, plus Tikhonov regularization.
        return TensorList(
            torch.autograd.grad(self.g, self.x, x,
                                retain_graph=True)) + self.hessian_reg * x

    def ip(self, a, b):
        # Implements the inner product
        return self.problem.ip_input(a, b)

    def M1(self, x):
        return self.problem.M1(x)

    def M2(self, x):
        return self.problem.M2(x)

    def evaluate_CG_iteration(self, delta_x):
        if self.analyze_convergence:
            x = (self.x + delta_x).detach()
            x.requires_grad_(True)

            # compute loss and gradient
            loss = self.problem(x)
            grad = TensorList(torch.autograd.grad(loss, x))

            # store in the vectors
            self.losses = torch.cat(
                (self.losses, loss.detach().cpu().view(-1)))
            self.gradient_mags = torch.cat(
                (self.gradient_mags,
                 sum(grad.view(-1) @ grad.view(-1)).cpu().sqrt().detach().view(-1)))
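# Hedged usage sketch for NewtonCG. `QuadraticProblem` is an illustrative
# MinimizationProblem subclass assumed for this example only; it is not
# defined in this file. Kept commented out so the module stays import-safe.
#
#   class QuadraticProblem(MinimizationProblem):
#       def __call__(self, x):
#           # scalar loss: 0.5 * ||x||^2 summed over the TensorList
#           return sum((0.5 * e * e).sum() for e in x)
#
#   x = TensorList([torch.zeros(4)])
#   optimizer = NewtonCG(QuadraticProblem(), x,
#                        init_hessian_reg=1e-2, hessian_reg_factor=0.8)
#   optimizer.run(num_cg_iter=[5, 5, 5])   # three Newton steps, 5 CG iters each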
class GaussNewtonCG(ConjugateGradientBase):
    """Gauss-Newton with Conjugate Gradient optimizer."""

    def __init__(self, problem: L2Problem, variable: TensorList, cg_eps=0.0,
                 fletcher_reeves=True, standard_alpha=True,
                 direction_forget_factor=0, debug=False, analyze=False,
                 plotting=False, fig_num=(10, 11, 12)):
        super().__init__(fletcher_reeves, standard_alpha,
                         direction_forget_factor, debug or analyze or plotting)

        self.problem = problem
        self.x = variable

        self.analyze_convergence = analyze
        self.plotting = plotting
        self.fig_num = fig_num

        self.cg_eps = cg_eps
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

        self.residuals = torch.zeros(0)
        self.losses = torch.zeros(0)
        self.gradient_mags = torch.zeros(0)

    def clear_temp(self):
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

    def run_GN(self, *args, **kwargs):
        return self.run(*args, **kwargs)

    def run(self, num_cg_iter, num_gn_iter=None):
        """Run the optimizer.
        args:
            num_cg_iter: Number of CG iterations per GN iter. If list, then each entry specifies
                the number of CG iterations and the number of GN iterations is given by the
                length of the list.
            num_gn_iter: Number of GN iterations. Shall only be given if num_cg_iter is an integer.
        """
        if isinstance(num_cg_iter, int):
            if num_gn_iter is None:
                raise ValueError(
                    'Must specify number of GN iter if CG iter is constant')
            num_cg_iter = [num_cg_iter] * num_gn_iter

        num_gn_iter = len(num_cg_iter)
        if num_gn_iter == 0:
            return

        if self.analyze_convergence:
            self.evaluate_CG_iteration(0)

        # Outer loop for running the GN iterations.
        for cg_iter in num_cg_iter:
            self.run_GN_iter(cg_iter)

        if self.debug:
            if not self.analyze_convergence:
                self.f0 = self.problem(self.x)
                loss = self.problem.ip_output(self.f0, self.f0)
                self.losses = torch.cat(
                    (self.losses, loss.detach().cpu().view(-1)))

            if self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.residuals, self.fig_num[1],
                           title='CG residuals')
                if self.analyze_convergence:
                    plot_graph(self.gradient_mags, self.fig_num[2],
                               'Gradient magnitude')

        self.x.vdetach_()
        self.clear_temp()

        return self.losses, self.residuals

    def run_GN_iter(self, num_cg_iter):
        """Runs a single GN iteration."""
        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        # Create copy with graph detached
        self.g = self.f0.vdetach()

        if self.debug and not self.analyze_convergence:
            loss = self.problem.ip_output(self.g, self.g)
            self.losses = torch.cat(
                (self.losses, loss.detach().cpu().view(-1)))

        self.g.requires_grad_(True)

        # Get df/dx^t @ f0
        self.dfdxt_g = TensorList(
            torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))

        # Get the right hand side
        self.b = -self.dfdxt_g.vdetach()

        # Run CG
        delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

        self.x.vdetach_()
        self.x.plus_(delta_x)

        if self.debug:
            self.residuals = torch.cat((self.residuals, res))

    def A(self, x):
        # Gauss-Newton matrix-vector product J^T J x, computed with two
        # Jacobian-vector products instead of forming J explicitly.
        dfdx_x = torch.autograd.grad(self.dfdxt_g, self.g, x,
                                     retain_graph=True)
        return TensorList(
            torch.autograd.grad(self.f0, self.x, dfdx_x, retain_graph=True))

    def ip(self, a, b):
        return self.problem.ip_input(a, b)

    def M1(self, x):
        return self.problem.M1(x)

    def M2(self, x):
        return self.problem.M2(x)

    def evaluate_CG_iteration(self, delta_x):
        if self.analyze_convergence:
            x = (self.x + delta_x).detach()
            x.requires_grad_(True)

            # compute loss and gradient
            f = self.problem(x)
            loss = self.problem.ip_output(f, f)
            grad = TensorList(torch.autograd.grad(loss, x))

            # store in the vectors
            self.losses = torch.cat(
                (self.losses, loss.detach().cpu().view(-1)))
            self.gradient_mags = torch.cat(
                (self.gradient_mags,
                 sum(grad.view(-1) @ grad.view(-1)).cpu().sqrt().detach().view(-1)))
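# Note on GaussNewtonCG.A: with J = df/dx at the current estimate,
# `dfdxt_g` holds J^T g as a differentiable function of g, so
# grad(dfdxt_g, g, x) evaluates J x, and back-propagating J x through f0
# evaluates J^T (J x). CG therefore solves the Gauss-Newton normal
# equations (J^T J) delta_x = -J^T f without ever materializing J.
#
# Hedged usage sketch; `ResidualProblem` is an illustrative L2Problem
# subclass assumed for this example only (commented out to keep the module
# import-safe):
#
#   x = TensorList([torch.zeros(4)])
#   optimizer = GaussNewtonCG(ResidualProblem(), x, cg_eps=1e-6)
#   optimizer.run(num_cg_iter=5, num_gn_iter=3)   # 3 GN steps, 5 CG iters each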
class ConjugateGradient(ConjugateGradientBase):
    """Conjugate Gradient optimizer, performing a single linearization of the
    residuals at the start."""

    def __init__(self, problem: L2Problem, variable: TensorList, cg_eps=0.0,
                 fletcher_reeves=True, standard_alpha=True,
                 direction_forget_factor=0, debug=False, plotting=False,
                 fig_num=(10, 11)):
        super().__init__(fletcher_reeves, standard_alpha,
                         direction_forget_factor, debug or plotting)

        self.problem = problem
        self.x = variable.variable()

        self.plotting = plotting
        self.fig_num = fig_num

        self.cg_eps = cg_eps
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

        self.residuals = torch.zeros(0)
        self.losses = torch.zeros(0)

    def clear_temp(self):
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

    def run(self, num_cg_iter):
        """Run the optimizer with the provided number of iterations."""
        if num_cg_iter == 0:
            return

        lossvec = None
        if self.debug:
            lossvec = torch.zeros(2)

        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        # Create copy with graph detached
        self.g = self.f0.vdetach()

        if self.debug:
            lossvec[0] = self.problem.ip_output(self.g, self.g)

        self.g.requires_grad_(True)

        # Get df/dx^t @ f0
        self.dfdxt_g = TensorList(
            torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))

        # Get the right hand side
        self.b = -self.dfdxt_g.vdetach()

        # Run CG
        delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

        self.x.vdetach_()
        self.x.plus_(delta_x)

        if self.debug:
            self.f0 = self.problem(self.x)
            lossvec[-1] = self.problem.ip_output(self.f0, self.f0)
            self.residuals = torch.cat((self.residuals, res))
            self.losses = torch.cat((self.losses, lossvec))
            if self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.residuals, self.fig_num[1],
                           title='CG residuals')

        self.x.vdetach_()
        self.clear_temp()

    def A(self, x):
        # Gauss-Newton matrix-vector product J^T J x via two
        # Jacobian-vector products (see GaussNewtonCG.A).
        dfdx_x = torch.autograd.grad(self.dfdxt_g, self.g, x,
                                     retain_graph=True)
        return TensorList(
            torch.autograd.grad(self.f0, self.x, dfdx_x, retain_graph=True))

    def ip(self, a, b):
        return self.problem.ip_input(a, b)

    def M1(self, x):
        return self.problem.M1(x)

    def M2(self, x):
        return self.problem.M2(x)
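# Hedged usage sketch for ConjugateGradient: unlike GaussNewtonCG, it
# linearizes the residuals once and runs all CG iterations on that single
# linearization. `ResidualProblem` is again an assumed L2Problem subclass,
# not part of this file:
#
#   x = TensorList([torch.zeros(4)])
#   optimizer = ConjugateGradient(ResidualProblem(), x, cg_eps=1e-6)
#   optimizer.run(num_cg_iter=10)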