def __call__(self, A, b):
    # Factor A once, then reuse the cached (LU, pivots) pair for every
    # subsequent solve against the same matrix object.
    lup = self.cache.get(id(A))
    if lup is None:
        lup = A.lu()
        self.cache[id(A)] = lup
    if b.ndim == 1:
        # torch.lu_solve expects a matrix right-hand side
        return torch.lu_solve(b.unsqueeze(-1), *lup).squeeze(-1)
    else:
        return torch.lu_solve(b, *lup)
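
# Usage sketch for the cached solver above, as a functional stand-in
# (assumes the method lives on a class holding a dict `self.cache`).
# The O(n^3) factorization is paid once per matrix; each further solve
# is O(n^2). Caveat: keying on id(A) is only safe while A stays alive,
# since Python may recycle ids after garbage collection.
import torch

cache = {}

def cached_solve(A, b):
    lup = cache.get(id(A))
    if lup is None:
        lup = A.lu()
        cache[id(A)] = lup
    if b.ndim == 1:
        return torch.lu_solve(b.unsqueeze(-1), *lup).squeeze(-1)
    return torch.lu_solve(b, *lup)

A, b = torch.randn(4, 4), torch.randn(4)
x = cached_solve(A, b)                 # factorizes A on first use
cached_solve(A, torch.randn(4, 2))     # reuses the cached (LU, pivots)
print(torch.allclose(A @ x, b, atol=1e-5))
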
def forward(ctx, A, b):
    # Solve A x = b via an LU factorization and stash the factors for the
    # backward pass (the matching backward appears further below).
    A_LU, pivots = torch.lu(A)
    x = torch.lu_solve(b, A_LU, pivots)
    ctx.save_for_backward(A_LU, pivots, x)
    return x
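
# Quick check of the forward solve above; in the original this is a
# torch.autograd.Function staticmethod, so ctx is mocked here with a
# SimpleNamespace (a combined Function sketch follows the matching
# backward further below).
import types
import torch

ctx = types.SimpleNamespace(save_for_backward=lambda *tensors: None)
A = torch.randn(4, 4, dtype=torch.double)
b = torch.randn(4, 1, dtype=torch.double)
x = forward(ctx, A, b)
print(torch.allclose(A @ x, b))
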
def _exp_pade_generic(A, m=7):
    """
    Minimal, inefficient implementation of the [m/m]-Padé approximation of
    the matrix exponential: exp(A) ~ q(A)^-1 p(A), where the denominator
    satisfies q(A) = p(-A).
    """
    LU = torch.lu(_pade_poly(-A, m))
    return torch.lu_solve(_pade_poly(A, m), *LU)
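
# Sketch of a compatible _pade_poly and a numerical check against
# torch.matrix_exp. The helper below is hypothetical (the original is
# not shown); it evaluates the standard [m/m] Padé numerator for exp,
# p(x) = sum_j (2m-j)! m! / ((2m)! j! (m-j)!) x^j, whose denominator is
# q(A) = p(-A), matching the lu/lu_solve pairing above.
import math
import torch

def _pade_poly(A, m):
    n = A.shape[-1]
    P = torch.zeros_like(A)
    Ak = torch.eye(n, dtype=A.dtype)
    for j in range(m + 1):
        c = (math.factorial(2 * m - j) * math.factorial(m)
             / (math.factorial(2 * m) * math.factorial(j) * math.factorial(m - j)))
        P = P + c * Ak
        Ak = Ak @ A
    return P

A = 0.2 * torch.randn(4, 4, dtype=torch.double)   # small norm: Padé is accurate
print(torch.allclose(_exp_pade_generic(A), torch.matrix_exp(A), atol=1e-10))
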
def _compute_weights(self):
    if self._target_dim > 1:
        # factorize the matrix once, then solve column by column
        self.nodes = torch.zeros(self.N, self._target_dim,
                                 dtype=self.di.dtype, device=self.device)
        lu_data = torch.lu(self.A)
        for i in range(self._target_dim):
            self.nodes[:, i] = torch.lu_solve(
                self.di[:, i].unsqueeze(0).T, *lu_data).squeeze()
    else:
        # torch.solve takes the right-hand side first: this solves A x = di
        self.nodes = torch.solve(self.di, self.A)[0]
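
# Side note, shown as a sketch: torch.lu_solve accepts an (n, k)
# right-hand side, so the per-column loop above is equivalent to a
# single multi-RHS solve against the same factorization.
import torch

A = torch.randn(6, 6, dtype=torch.double)
D = torch.randn(6, 3, dtype=torch.double)   # stand-in for self.di
lu_data = torch.lu(A)
by_column = torch.stack(
    [torch.lu_solve(D[:, i].unsqueeze(-1), *lu_data).squeeze(-1)
     for i in range(3)], dim=1)
all_at_once = torch.lu_solve(D, *lu_data)
print(torch.allclose(by_column, all_at_once))
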
def backward(ctx, grad_x):
    A_LU, pivots, x = ctx.saved_tensors
    # Math:
    #   A * grad_b = grad_x
    #   grad_A = -grad_b * x^T
    # Note: the exact adjoint solve is A^T * grad_b = grad_x; reusing the
    # factorization of A (as here) is correct only for symmetric A.
    grad_b = torch.lu_solve(grad_x, A_LU, pivots)
    grad_A = -torch.matmul(grad_b, x.view(1, -1))
    return grad_A, grad_b
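
# Sketch: the forward/backward pair above wired into an autograd
# Function (the class name LUSolve is hypothetical). Because the
# backward reuses the factorization of A rather than A^T, gradcheck is
# run at a symmetric, well-conditioned point, where the two coincide.
import torch

class LUSolve(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, b):
        A_LU, pivots = torch.lu(A)
        x = torch.lu_solve(b, A_LU, pivots)
        ctx.save_for_backward(A_LU, pivots, x)
        return x

    @staticmethod
    def backward(ctx, grad_x):
        A_LU, pivots, x = ctx.saved_tensors
        grad_b = torch.lu_solve(grad_x, A_LU, pivots)
        grad_A = -torch.matmul(grad_b, x.view(1, -1))
        return grad_A, grad_b

M = torch.randn(5, 5, dtype=torch.double)
A = (M + M.t() + 8 * torch.eye(5, dtype=torch.double)).requires_grad_()
b = torch.randn(5, 1, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(LUSolve.apply, (A, b)))
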
def prox(self, t, nu, warm_start, pool, cache):
    # Proximal operator of the least-squares term: solves
    # (X^T X + I/(2t)) x = X^T Y + nu/(2t).
    XtX = cache['XtX']
    XtY = cache['XtY']
    n = cache['n']
    A_LU = torch.lu(XtX + 1. / (2 * t) * torch.eye(n).unsqueeze(0).double())
    b = XtY + 1. / (2 * t) * torch.from_numpy(nu)
    x = torch.lu_solve(b, *A_LU)
    return x.numpy()
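
# Sketch verifying the normal equations behind the prox above: the
# returned x is the minimizer of ||X w - Y||^2 + (1/(2t)) ||w - nu||^2,
# i.e. 2 X^T (X x - Y) + (x - nu)/t = 0. prox_ls is a hypothetical
# standalone version of the method.
import numpy as np
import torch

def prox_ls(t, nu, XtX, XtY, n):
    A_LU = torch.lu(XtX + 1. / (2 * t) * torch.eye(n).unsqueeze(0).double())
    b = XtY + 1. / (2 * t) * torch.from_numpy(nu)
    return torch.lu_solve(b, *A_LU).numpy()

t, n = 0.7, 4
X = torch.randn(10, n, dtype=torch.double)
Y = torch.randn(10, 1, dtype=torch.double)
XtX = (X.t() @ X).unsqueeze(0)          # batch of one, as in the method
XtY = (X.t() @ Y).unsqueeze(0)
nu = np.random.randn(1, n, 1)
x = torch.from_numpy(prox_ls(t, nu, XtX, XtY, n))
residual = 2 * XtX @ x - 2 * XtY + (x - torch.from_numpy(nu)) / t
print(torch.allclose(residual, torch.zeros_like(residual), atol=1e-10))
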
def solve(self, t):
    sigma = self.sigma
    Lambda, U = self.Lambda, self.U
    LU = self.LU
    # Solve the linear system (see eq. (11)); a Woodbury-style update that
    # only needs the cached small factorization, not the full n x n system.
    y = Lambda @ U.t() @ t
    z = torch.lu_solve(y, *LU)
    alpha = 1.0 / sigma * (t - U @ z)
    return alpha
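
# Sketch of the identity the solve above appears to exploit (assuming
# self.LU caches torch.lu(sigma*I_m + Lambda @ U.T @ U)); by the
# Woodbury formula the result equals the direct n x n solve of
# (sigma*I_n + U @ Lambda @ U.T) alpha = t, at the cost of an m x m
# factorization only.
import torch

n, m, sigma = 8, 3, 0.5
U = torch.randn(n, m, dtype=torch.double)
Lambda = torch.diag(torch.rand(m, dtype=torch.double) + 0.1)
t = torch.randn(n, 1, dtype=torch.double)
LU = torch.lu(sigma * torch.eye(m, dtype=torch.double) + Lambda @ U.t() @ U)
y = Lambda @ U.t() @ t
z = torch.lu_solve(y, *LU)
alpha = 1.0 / sigma * (t - U @ z)
direct = torch.linalg.solve(
    sigma * torch.eye(n, dtype=torch.double) + U @ Lambda @ U.t(), t)
print(torch.allclose(alpha, direct))
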
def solve_kkt_be(self, rx, rs, rz, ry):
    # Solve the KKT system by block elimination: factor the (1,1) block,
    # form the Schur complement, then back-substitute.
    b1 = torch.cat((rx, rs), dim=1)
    if ry is not None:
        b2 = torch.cat((rz, ry), dim=1)
    else:
        b2 = rz
    A11 = self.J[:, :self.nx + self.nineq, :self.nx + self.nineq]
    A12 = self.J[:, :self.nx + self.nineq, self.nx + self.nineq:]
    A21 = torch.transpose(A12, dim0=2, dim1=1)
    # (a Cholesky factorization/solve could be used instead when the
    # blocks are symmetric positive definite)
    U_A11, U_A11_piv = self.lu_factorize(A11)
    u = torch.lu_solve(b1, U_A11, U_A11_piv)
    v = torch.lu_solve(A12, U_A11, U_A11_piv)
    S_neg = torch.bmm(A21, v)  # negated Schur complement A21 A11^-1 A12
    U_S_neg, U_S_neg_piv = self.lu_factorize(S_neg)
    w = torch.lu_solve(b2, U_S_neg, U_S_neg_piv)
    t = torch.lu_solve(A21, U_S_neg, U_S_neg_piv)
    x2 = -(w - torch.bmm(t, u))
    x1 = u - torch.bmm(v, x2)

    dx = x1[:, :self.nx, :]
    ds = x1[:, self.nx:, :]
    if ry is not None:
        dz = x2[:, :-self.neq, :]
        dy = x2[:, -self.neq:, :]
    else:
        dz = x2
        dy = None
    return (dx, ds, dz, dy)
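
# Standalone sketch of the block elimination above on a random batched
# saddle system [[A11, A12], [A21, 0]] [x1; x2] = [b1; b2], checked
# against a dense solve (lu_factorize is assumed to wrap torch.lu).
import torch

B, p, q = 2, 5, 3
A11 = torch.randn(B, p, p, dtype=torch.double) + 5 * torch.eye(p, dtype=torch.double)
A12 = torch.randn(B, p, q, dtype=torch.double)
A21 = A12.transpose(1, 2)
b1 = torch.randn(B, p, 1, dtype=torch.double)
b2 = torch.randn(B, q, 1, dtype=torch.double)

U_A11 = torch.lu(A11)
u = torch.lu_solve(b1, *U_A11)           # A11^-1 b1
v = torch.lu_solve(A12, *U_A11)          # A11^-1 A12
S_neg = torch.bmm(A21, v)                # A21 A11^-1 A12
U_S = torch.lu(S_neg)
w = torch.lu_solve(b2, *U_S)
t = torch.lu_solve(A21, *U_S)
x2 = -(w - torch.bmm(t, u))
x1 = u - torch.bmm(v, x2)

Z = torch.zeros(B, q, q, dtype=torch.double)
J = torch.cat([torch.cat([A11, A12], 2), torch.cat([A21, Z], 2)], 1)
x = torch.linalg.solve(J, torch.cat([b1, b2], 1))
print(torch.allclose(torch.cat([x1, x2], 1), x, atol=1e-8))
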
def backward(self, y, log_df_dz):
    with torch.no_grad():
        # Rebuild the packed LU matrix of the 1x1-convolution weight from
        # its L / U / diagonal parameters, then invert y with a pair of
        # triangular solves.
        LU = self.L * self.L_mask + self.U * self.U_mask + torch.diag(
            self.sign_s * torch.exp(self.log_s))

        y_reshape = y.view(y.size(0), y.size(1), -1)
        y_reshape = torch.lu_solve(y_reshape, LU.unsqueeze(0),
                                   self.pivots.unsqueeze(0))
        y = y_reshape.view(y.size())
        y = y.contiguous()

        num_pixels = np.prod(y.size()) // (y.size(0) * y.size(1))
        log_df_dz -= torch.sum(self.log_s, dim=0) * num_pixels

    return y, log_df_dz
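
# Sketch of why the packed matrix above works: for a Glow-style
# invertible 1x1 convolution with W = P @ L @ U (the snippet stores the
# diagonal of U as sign_s * exp(log_s)), adding the strictly-lower,
# strictly-upper, and diagonal parts back together reproduces exactly
# the (LU_data, pivots) pair torch.lu_solve expects, so the inverse
# pass costs only two triangular solves.
import torch

c = 4
W0 = torch.randn(c, c, dtype=torch.double)
LU_data, pivots = torch.lu(W0)                 # stand-in for the stored params
L_mask = torch.tril(torch.ones(c, c, dtype=torch.double), -1)
U_mask = torch.triu(torch.ones(c, c, dtype=torch.double), 1)
packed = (LU_data * L_mask + LU_data * U_mask
          + torch.diag(LU_data.diagonal()))    # == LU_data, rebuilt from parts
P, Lf, Uf = torch.lu_unpack(LU_data, pivots)
W = P @ Lf @ Uf
x = torch.randn(1, c, 10, dtype=torch.double)  # (batch, channels, positions)
y = torch.einsum('ij,bjn->bin', W, x)          # forward 1x1 convolution
x_rec = torch.lu_solve(y, packed.unsqueeze(0), pivots.unsqueeze(0))
print(torch.allclose(x_rec, x))
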
def factor_solve_kkt_reg(Q_tilde, D, G, A, rx, rs, rz, ry, eps):
    nineq, nz, neq, nBatch = get_sizes(G, A)

    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q_tilde)
    H_[:, :nz, :nz] = Q_tilde
    H_[:, -nineq:, -nineq:] = D
    if neq > 0:
        A_ = torch.cat([
            torch.cat(
                [G, torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)], 2),
            torch.cat([A, torch.zeros(nBatch, neq, nineq).type_as(Q_tilde)], 2)
        ], 1)
        g_ = torch.cat([rx, rs], 1)
        h_ = torch.cat([rz, ry], 1)
    else:
        A_ = torch.cat(
            [G, torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)], 2)
        g_ = torch.cat([rx, rs], 1)
        h_ = rz

    H_LU = lu_hack(H_)

    invH_A_ = A_.transpose(1, 2).lu_solve(*H_LU)
    invH_g_ = g_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    # regularized Schur complement of H_ in the full KKT matrix
    S_ = torch.bmm(A_, invH_A_)
    S_ -= eps * torch.eye(neq + nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)
    S_LU = lu_hack(S_)
    t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_
    w_ = -t_.unsqueeze(2).lu_solve(*S_LU).squeeze(2)
    t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze()
    v_ = t_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    dx = v_[:, :nz]
    ds = v_[:, nz:]
    dz = w_[:, :nineq]
    dy = w_[:, nineq:] if neq > 0 else None

    return dx, ds, dz, dy
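
# Small sketch: the Tensor.lu_solve method form used above is the same
# operation as the functional torch.lu_solve (lu_hack is assumed to
# behave like torch.lu), shown on a random batch.
import torch

H = torch.randn(3, 4, 4, dtype=torch.double) + 4 * torch.eye(4, dtype=torch.double)
g = torch.randn(3, 4, dtype=torch.double)
H_LU = torch.lu(H)
v1 = g.unsqueeze(2).lu_solve(*H_LU).squeeze(2)
v2 = torch.lu_solve(g.unsqueeze(2), *H_LU).squeeze(2)
print(torch.allclose(v1, v2))
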
def xpbatch_lu_solve(lu_and_piv, b):
    """
    Solve a x = b given a precomputed LU factorization of a.

    :param lu_and_piv: tuple (LU, piv) of NumPy arrays; the pivots must be
        the 1-based ones produced by torch.lu
    :param b: right-hand side as a NumPy array
    :return: the solution x as a NumPy array
    """
    LU, piv = lu_and_piv
    b = torch.Tensor(b)             # copies b into a fresh float32 tensor
    LU = torch.from_numpy(LU).float()
    piv = torch.from_numpy(piv)
    b = torch.lu_solve(b, LU, piv)
    return b.cpu().numpy()
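
# Usage sketch: round-trip through NumPy. torch.lu_solve expects the
# 1-based pivots produced by torch.lu; pivots from SciPy's lu_factor
# are 0-based and would need a `piv + 1` before being passed here.
import numpy as np
import torch

a = np.random.randn(4, 4)
b = np.random.randn(4, 1)
LU_t, piv_t = torch.lu(torch.from_numpy(a).float())
x = xpbatch_lu_solve((LU_t.numpy(), piv_t.numpy()), b)
print(np.allclose(a @ x, b, atol=1e-4))
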
def newton_exact(f, g, x_guess, opt_params, ls_method, ls_params):
    """
    This function minimizes f using Newton's method with an exact Hessian:
    the search direction solves H_k p = -g_k.
    INPUTS:
    f < function > : objective function f(x) -> f
    g < function > : gradient function g(x) -> g
    x_guess < tensor > : initial x
    opt_params < dict{
        'ep_g' < float > : conv. tolerance on gradient
        'ep_a' < float > : absolute tolerance
        'ep_r' < float > : relative tolerance
        'Hessian' < function > : function that returns the Hessian
        'iter_lim' < int > : iteration limit
    } > : dictionary of optimization settings
    ls_method < str > : indicates which method to use with line search
    ls_params < dict > : dictionary with parameters to use for line search
    """
    ep_g = opt_params['ep_g']
    ep_a = opt_params['ep_a']
    ep_r = opt_params['ep_r']
    H = opt_params['Hessian']
    iter_lim = opt_params['iter_lim']

    # initializations
    x_k = x_guess
    x_hist = [x_k]
    f_k = f(x_guess)
    f_hist = [f_k]
    k = 0
    conv_count = 0  # iterations meeting the rel./abs. tolerance before stopping
    conv_count_max = 2

    while k < iter_lim:
        k += 1

        # compute gradient
        g_k = g(x_k)

        # check for gradient convergence
        if torch.norm(g_k) <= ep_g:
            converge = True
            message = "Exact Newton converged due to grad. tolerance."
            break

        # invert Hessian and find search direction
        H_k = H(x_k)
        H_LU, pivots, infos = torch.lu(H_k.reshape([1, H_k.shape[0], -1]),
                                       get_infos=True)
        if infos.nonzero().size(0) != 0:  # check if LU factorization failed
            converge = False
            message = "Hessian LU factorization failed."
            break
        # LU solve is designed for batch operations, hence the [0]
        delta_k = torch.lu_solve(-g_k.unsqueeze(0), H_LU, pivots)[0]

        # make sure we step along a descent direction
        if torch.matmul(delta_k.t(), g_k) < 0:
            p_k = delta_k
        else:
            p_k = -delta_k

        # perform line search
        alf, ls_converge, ls_message = line_search(f, x_k, g_k, p_k,
                                                   ls_method=ls_method,
                                                   ls_params=ls_params)
        if not ls_converge:
            converge = ls_converge
            message = ls_message
            break

        # compute x_(k+1)
        x_k1, f_k1 = search_step(f, x_k, alf, p_k)

        # check relative and absolute convergence criteria
        if rel_abs_convergence(f_k, f_k1, ep_a, ep_r):
            conv_count += 1

        x_k = x_k1
        f_k = f_k1
        x_hist.append(x_k)
        f_hist.append(f_k)

        if conv_count >= conv_count_max:
            converge = True
            message = "Exact Newton converged due to abs. rel. tolerance."
            break
    else:
        # no break occurred: the iteration limit was reached
        converge = False
        message = "Exact Newton iteration limit reached."

    return x_k, f_k, x_hist, f_hist, converge, message
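
# Sketch of the core Newton step above in isolation: factor the Hessian
# with get_infos=True so singularity is detected without an exception,
# then solve H p = -g with the batched lu_solve. On a quadratic
# f(x) = 0.5 x^T Q x - c^T x, one step lands on the minimizer Q^-1 c.
import torch

def newton_step(H_k, g_k):
    H_LU, pivots, infos = torch.lu(H_k.reshape([1, H_k.shape[0], -1]),
                                   get_infos=True)
    if infos.nonzero().size(0) != 0:
        return None                      # LU factorization failed
    return torch.lu_solve(-g_k.unsqueeze(0), H_LU, pivots)[0]

Q = torch.tensor([[3., 1.], [1., 2.]])
c = torch.tensor([[1.], [1.]])
x0 = torch.zeros(2, 1)
x1 = x0 + newton_step(Q, Q @ x0 - c)
print(torch.allclose(Q @ x1, c, atol=1e-5))
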
def quad_search(f, x_k, g_k, p_k, ls_params):
    """
    This function performs approximate quadratic line search
    INPUTS:
    f < function > : objective function f(x) -> f
    x_k < tensor > : current best guess for f(x) minimum
    g_k < tensor > : gradient evaluated at x_k
    p_k < tensor > : search direction
    ls_params < dict{
        'alf' < float > : initial guess for step-length
        'mu' < float > : small positive constant used in "Armijo suff. decrease condition"
        'rho' < float > : step-size discount coefficient
        'iter_lim' < int > : iteration limit for solver
        'alf_lower_coeff' : coefficient for determining point one in quad_search
        'alf_upper_coeff' : coefficient for determining point three in quad_search
    } > : dictionary with parameters to use for line search
    RETURNS:
    alf_new < float > : computed search length
    converge < bool > : bool indicating whether line search converged
    message < string > : string with output from back tracking method
    """
    mu = ls_params['mu']
    iter_lim = ls_params['iter_lim']
    alf_new = ls_params['alf']
    alf_coeff1 = ls_params['alf_lower_coeff']
    alf_coeff2 = ls_params['alf_upper_coeff']

    iter = 0
    while not armijo_suff_decrease(f, x_k, g_k, p_k, alf_new, mu) and iter < iter_lim:
        # bracket the current step with the lower/upper coefficients
        a1 = alf_new
        a2 = alf_new * alf_coeff1
        a3 = alf_new * alf_coeff2
        f1 = f(x_k + a1 * p_k)
        f2 = f(x_k + a2 * p_k)
        f3 = f(x_k + a3 * p_k)

        # fit phi(a) ~ 1/2*c0*a^2 + c1*a + c2 through the three samples
        A = torch.tensor([[1 / 2 * a1 ** 2, a1, 1],
                          [1 / 2 * a2 ** 2, a2, 1],
                          [1 / 2 * a3 ** 2, a3, 1]])
        b = torch.tensor([[f1], [f2], [f3]])
        A_LU, pivots, infos = torch.lu(A.reshape([1, A.shape[0], -1]),
                                       get_infos=True)
        if infos.nonzero().size(0) != 0:
            # singular fit: report the failure directly so the messages
            # below cannot overwrite it
            return alf_new, False, "Quadratic approx was not possible."
        coeff = torch.lu_solve(b.unsqueeze(0), A_LU, pivots)[0]
        alf_new = (-coeff[1] / coeff[0]).item()  # vertex of the fitted quadratic
        iter += 1

    if iter == iter_lim:
        converge = False
        message = "Quadratic approx. line search iteration limit reached."
    else:
        converge = True
        message = "Quadratic approx. line search converged."

    return alf_new, converge, message
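
# Sketch of the 3-point quadratic fit above in isolation: for an exactly
# quadratic phi(a), the fitted vertex -c1/c0 recovers the true minimizer
# in a single iteration.
import torch

phi = lambda a: (a - 1.3) ** 2 + 0.5     # true minimizer at a = 1.3
a1, a2, a3 = 1.0, 0.1, 2.0
A = torch.tensor([[1 / 2 * a1 ** 2, a1, 1.],
                  [1 / 2 * a2 ** 2, a2, 1.],
                  [1 / 2 * a3 ** 2, a3, 1.]])
b = torch.tensor([[phi(a1)], [phi(a2)], [phi(a3)]])
A_LU, pivots = torch.lu(A.unsqueeze(0))
coeff = torch.lu_solve(b.unsqueeze(0), A_LU, pivots)[0]
print(float(-coeff[1] / coeff[0]))       # ~1.3
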