def solve(self, w):
    X = torch.zeros((1, 2 * self.Nf * self.Nh), dtype=torch.float64)
    A0 = self.model.get_A(w)
    D_lu = torch.lu(A0)
    FE = self.force_ex.get_f(w)
    # initialize from the linear solve A0 X = FE, then enable grad;
    # in-place assignment into a leaf created with requires_grad=True
    # would raise an autograd error
    X[0, :] = FE.unsqueeze(-1).lu_solve(*D_lu).squeeze()
    X.requires_grad_(True)
    for iter in range(self.max_iter):
        self.aft_method.process(X)
        FN = self.aft_method.get_vector().detach()
        dFdX = self.aft_method.get_jacobian().detach()
        RX = torch.matmul(A0, X.unsqueeze(-1)) + FN.unsqueeze(-1) - FE.unsqueeze(-1)
        err = float(torch.norm(RX, 1).detach().numpy())
        if self.display:
            print("error: " + str(err))
        if err < self.max_err:
            print("iter: " + str(iter + 1) + " error: " + str(err))
            break
        # Newton step: (A0 + dF/dX) dX = RX
        D_lu = torch.lu(A0 + dFdX)
        dX = RX.lu_solve(*D_lu).squeeze(-1)
        X = X - dX
    return X.detach()
def blas_lapack_ops(self):
    m = torch.randn(3, 3)
    a = torch.randn(10, 3, 4)
    b = torch.randn(10, 4, 3)
    v = torch.randn(3)
    return (
        torch.addbmm(m, a, b),
        torch.addmm(torch.randn(2, 3), torch.randn(2, 3), torch.randn(3, 3)),
        torch.addmv(torch.randn(2), torch.randn(2, 3), torch.randn(3)),
        torch.addr(torch.zeros(3, 3), v, v),
        torch.baddbmm(m, a, b),
        torch.bmm(a, b),
        torch.chain_matmul(torch.randn(3, 3), torch.randn(3, 3), torch.randn(3, 3)),
        # torch.cholesky(a),  # deprecated
        torch.cholesky_inverse(torch.randn(3, 3)),
        torch.cholesky_solve(torch.randn(3, 3), torch.randn(3, 3)),
        torch.dot(v, v),
        torch.eig(m),
        torch.geqrf(a),
        torch.ger(v, v),
        torch.inner(m, m),
        torch.inverse(m),
        torch.det(m),
        torch.logdet(m),
        torch.slogdet(m),
        torch.lstsq(m, m),
        torch.lu(m),
        torch.lu_solve(m, *torch.lu(m)),
        torch.lu_unpack(*torch.lu(m)),
        torch.matmul(m, m),
        torch.matrix_power(m, 2),
        # torch.matrix_rank(m),
        torch.matrix_exp(m),
        torch.mm(m, m),
        torch.mv(m, v),
        # torch.orgqr(a, m),
        # torch.ormqr(a, m, v),
        torch.outer(v, v),
        torch.pinverse(m),
        # torch.qr(a),
        torch.solve(m, m),
        torch.svd(a),
        # torch.svd_lowrank(a),
        # torch.pca_lowrank(a),
        # torch.symeig(a),  # deprecated
        # torch.lobpcg(a, b),  # not supported
        torch.trapz(m, m),
        torch.trapezoid(m, m),
        torch.cumulative_trapezoid(m, m),
        # torch.triangular_solve(m, m),
        torch.vdot(v, v),
    )
@staticmethod
def forward(ctx, A, b):
    # Solve A x = b via LU factorization and save the factors for backward.
    A_LU, pivots = torch.lu(A)
    x = torch.lu_solve(b, A_LU, pivots)
    ctx.save_for_backward(A_LU, pivots, x)
    return x
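# A possible matching backward pass for the LU-solve Function above (a sketch,
# not part of the original snippet). For x = A^{-1} b the adjoint identities
# are grad_b = A^{-T} grad_x and grad_A = -grad_b @ x^T; with PyTorch >= 1.13,
# torch.linalg.lu_solve(..., adjoint=True) reuses the saved factorization to
# solve against A^T without refactoring.
@staticmethod
def backward(ctx, grad_x):
    A_LU, pivots, x = ctx.saved_tensors
    grad_b = torch.linalg.lu_solve(A_LU, pivots, grad_x, adjoint=True)
    grad_A = -grad_b @ x.transpose(-2, -1)
    return grad_A, grad_b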
def __init__(self, num_channels, LU_decomposed):
    super().__init__()
    w_shape = [num_channels, num_channels]
    w_init = torch.qr(torch.randn(*w_shape))[0]
    if not LU_decomposed:
        self.weight = nn.Parameter(torch.Tensor(w_init))
    else:
        p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
        s = torch.diag(upper)
        sign_s = torch.sign(s)
        log_s = torch.log(torch.abs(s))
        upper = torch.triu(upper, 1)
        l_mask = torch.tril(torch.ones(w_shape), -1)
        eye = torch.eye(*w_shape)
        self.register_buffer("p", p)
        self.register_buffer("sign_s", sign_s)
        self.lower = nn.Parameter(lower)
        self.log_s = nn.Parameter(log_s)
        self.upper = nn.Parameter(upper)
        self.l_mask = l_mask
        self.eye = eye
    self.w_shape = w_shape
    self.LU_decomposed = LU_decomposed
def __init__(self, num_features, LU_decomposed=True):
    super(Invertible1x1ConvLU, self).__init__()
    w_shape = [num_features, num_features]
    w_init = torch.qr(torch.randn(*w_shape))[0]
    self.num_features = num_features
    if not LU_decomposed:
        self.weight = nn.Parameter(torch.Tensor(w_init))
    else:
        p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
        s = torch.diag(upper)
        sign_s = torch.sign(s)
        log_s = torch.log(torch.abs(s))
        upper = torch.triu(upper, 1)
        l_mask = torch.tril(torch.ones(w_shape), -1)
        eye = torch.eye(*w_shape)
        self.register_buffer("p", p)
        self.register_buffer("sign_s", sign_s)
        self.lower = nn.Parameter(lower)
        self.log_s = nn.Parameter(log_s)
        self.upper = nn.Parameter(upper)
        self.l_mask = l_mask
        self.eye = eye
    self.if_LU = LU_decomposed
    self.w_shape = w_shape
def __init__(self, num_channels, LU_decomposed=True):
    super().__init__()
    self.num_channels = num_channels
    self.LU_decomposed = LU_decomposed
    weight_shape = [num_channels, num_channels]
    weight, _ = torch.qr(torch.randn(*weight_shape))
    if not self.LU_decomposed:
        self.weight = nn.Parameter(weight)
    else:
        weight_lu, pivots = torch.lu(weight)
        w_p, w_l, w_u = torch.lu_unpack(weight_lu, pivots)
        w_s = torch.diag(w_u)
        sign_s = torch.sign(w_s)
        log_s = torch.log(torch.abs(w_s))
        w_u = torch.triu(w_u, 1)
        u_mask = torch.triu(torch.ones_like(w_u), 1)
        l_mask = u_mask.T.contiguous()
        eye = torch.eye(l_mask.shape[0])
        self.register_buffer('p', w_p)
        self.register_buffer('sign_s', sign_s)
        self.register_buffer('eye', eye)
        self.register_buffer('u_mask', u_mask)
        self.register_buffer('l_mask', l_mask)
        self.l = nn.Parameter(w_l)
        self.u = nn.Parameter(w_u)
        self.log_s = nn.Parameter(log_s)
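# The __init__ variants above only register the PLU pieces; the actual weight
# is typically rebuilt on every forward pass. A minimal sketch of that
# reconstruction for the buffers/parameters registered directly above (the
# method name calc_weight is an assumption):
def calc_weight(self):
    # W = P @ L @ (U + diag(s)): the masks keep l strictly lower and u
    # strictly upper triangular, `eye` restores the unit diagonal of L, and
    # sign_s * exp(log_s) recovers the diagonal s of U, so that
    # log|det W| = log_s.sum().
    lower = self.l * self.l_mask + self.eye
    upper = self.u * self.u_mask + torch.diag(self.sign_s * torch.exp(self.log_s))
    return self.p @ lower @ upper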
def preprocess(self):
    X = self.X
    k = self.k
    sigma = self.sigma
    n, m = X.shape[0], self.m
    # Compute reduced kernel matrix
    Knm = k(X[:n, :], X[:m, :]).evaluate()
    # Compute eigen-decomposition (see eq. (7))
    Lambda_m, U_m = torch.symeig(Knm[:m, :], eigenvectors=True)
    # Compute approximate eigenvalues (see eq. (8))
    Lambda = (n / m) * Lambda_m
    # Compute approximate eigenvectors (see eq. (9))
    U = math.sqrt(m / n) * Knm @ U_m @ torch.diag(1.0 / Lambda_m)
    # Convert Lambda to diagonal matrix
    Lambda = torch.diag(Lambda)
    # Store approximate eigenvalues and eigenvectors
    self.Lambda, self.U = Lambda, U
    # Factorize M
    M = Lambda @ U.t() @ U + sigma * torch.eye(m)
    # Compute and store LU factors of M
    self.LU = torch.lu(M.detach())
def __init__(self, in_out_channels):
    super(InvertibleConv1x1, self).__init__()
    W = torch.zeros((in_out_channels, in_out_channels), dtype=torch.float32)
    nn.init.orthogonal_(W)
    LU, pivots = torch.lu(W)
    P, L, U = torch.lu_unpack(LU, pivots)
    self.P = nn.Parameter(P, requires_grad=False)
    self.L = nn.Parameter(L, requires_grad=True)
    self.U = nn.Parameter(U, requires_grad=True)
    self.I = nn.Parameter(torch.eye(in_out_channels), requires_grad=False)
    self.pivots = nn.Parameter(pivots, requires_grad=False)
    L_mask = np.tril(np.ones((in_out_channels, in_out_channels), dtype='float32'), k=-1)
    U_mask = L_mask.T.copy()
    self.L_mask = nn.Parameter(torch.from_numpy(L_mask), requires_grad=False)
    self.U_mask = nn.Parameter(torch.from_numpy(U_mask), requires_grad=False)
    s = torch.diag(U)
    sign_s = torch.sign(s)
    log_s = torch.log(torch.abs(s))
    self.log_s = nn.Parameter(log_s, requires_grad=True)
    self.sign_s = nn.Parameter(sign_s, requires_grad=False)
def _exp_pade_generic(A, m=7):
    """
    Minimal, inefficient implementation of the [m/m]-Padé approximation
    of the matrix exponential.
    """
    LU = torch.lu(_pade_poly(-A, m))
    result = torch.lu_solve(_pade_poly(A, m), *LU)
    return result
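# A quick sanity check one might run against torch.matrix_exp (an assumption:
# _pade_poly is the module's helper returning the Padé polynomial evaluated
# at A; keep ||A|| small, since a bare [m/m] approximant is only accurate
# near the origin):
A = torch.randn(4, 4, dtype=torch.float64) * 0.1
assert torch.allclose(_exp_pade_generic(A), torch.matrix_exp(A), atol=1e-10)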
def sampling_W(self, dim):
    # sample a rotation matrix
    W = torch.empty(dim, dim)
    torch.nn.init.orthogonal_(W)
    # compute LU
    LU, pivot = torch.lu(W)
    P, L, U = torch.lu_unpack(LU, pivot)
    return W, P, L, U
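# Hypothetical usage of sampling_W (assuming `sampler` is an instance of the
# surrounding class): torch.lu factorizes W = P @ L @ U, so the unpacked
# factors should reproduce W up to floating-point error.
W, P, L, U = sampler.sampling_W(8)
assert torch.allclose(W, P @ L @ U, atol=1e-5)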
def __init__(self, dim):
    super().__init__()
    self.dim = dim
    Q = nn.init.orthogonal_(torch.randn(dim, dim).to(device))
    P, L, U = torch.lu_unpack(*torch.lu(Q))
    self.register_buffer('P', P)
    self.L = nn.Parameter(L)  # lower triangular portion
    self.S = nn.Parameter(U.diag())  # "crop out" the diagonal to its own parameter
    self.U = nn.Parameter(torch.triu(U, diagonal=1))  # "crop out" diagonal, stored in S
def _compute_weights(self):
    if self._target_dim > 1:
        # we first factorize the matrix
        self.nodes = torch.zeros(self.N, self._target_dim, dtype=self.di.dtype, device=self.device)
        lu_data = torch.lu(self.A)
        for i in range(self._target_dim):
            self.nodes[:, i] = torch.lu_solve(self.di[:, i].unsqueeze(0).T, *lu_data).squeeze()
    else:
        # deprecated torch.solve takes (B, A) and solves A @ X = B
        self.nodes = torch.solve(self.di, self.A)[0]
def prox(self, t, nu, warm_start, pool, cache):
    XtX = cache['XtX']
    XtY = cache['XtY']
    n = cache['n']
    # solve (X^T X + I/(2t)) x = X^T Y + nu/(2t) via a batched LU factorization
    A_LU = torch.lu(XtX + 1. / (2 * t) * torch.eye(n).unsqueeze(0).double())
    b = XtY + 1. / (2 * t) * torch.from_numpy(nu)
    x = torch.lu_solve(b, *A_LU)
    return x.numpy()
def __init__(self, in_channel, fixed, use_lu):
    super().__init__()
    self.use_lu = use_lu
    self.fixed = fixed
    if not use_lu:
        weight = th.randn(in_channel, in_channel)
        q, _ = th.qr(weight)
        weight = q
        if fixed:
            self.register_buffer('weight', weight.data)
            self.register_buffer('weight_inverse', weight.data.inverse())
            self.register_buffer(
                'fixed_log_det', th.slogdet(self.weight.double())[1].float())
        else:
            self.weight = nn.Parameter(weight)
    if use_lu:
        assert not fixed
        weight = th.randn(in_channel, in_channel)
        q, _ = th.qr(weight)
        w_p, w_l, w_u = th.lu_unpack(*th.lu(q))
        w_s = th.diag(w_u)
        w_u = th.triu(w_u, 1)
        u_mask = th.triu(th.ones_like(w_u), 1)
        l_mask = u_mask.t()
        self.register_buffer('w_p', w_p)
        self.register_buffer('u_mask', u_mask)
        self.register_buffer('l_mask', l_mask)
        self.register_buffer('s_sign', th.sign(w_s))
        self.register_buffer('l_eye', th.eye(l_mask.shape[0]))
        self.w_l = nn.Parameter(w_l)
        self.w_s = nn.Parameter(th.log(th.abs(w_s)))
        self.w_u = nn.Parameter(w_u)
def __init__(self, c):
    super(Invertible1x1ConvLUS, self).__init__()
    # Sample a random orthonormal matrix to initialize weights
    W, _ = torch.linalg.qr(torch.randn(c, c))
    # Ensure determinant is 1.0 not -1.0
    if torch.det(W) < 0:
        W[:, 0] = -1 * W[:, 0]
    p, lower, upper = torch.lu_unpack(*torch.lu(W))
    self.register_buffer('p', p)
    # diagonals of lower will always be 1s anyway
    lower = torch.tril(lower, -1)
    lower_diag = torch.diag(torch.eye(c, c))
    self.register_buffer('lower_diag', lower_diag)
    self.lower = nn.Parameter(lower)
    self.upper_diag = nn.Parameter(torch.diag(upper))
    self.upper = nn.Parameter(torch.triu(upper, 1))
def safe_solve_with_mask(B: torch.Tensor, A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Helper function that avoids crashing on singular matrix input and returns a mask of valid solutions."""
    if not torch_version_geq(1, 10):
        sol, lu = _torch_solve_cast(B, A)
        warnings.warn('PyTorch version < 1.10, solve validness mask maybe not correct', RuntimeWarning)
        return sol, lu, torch.ones(len(A), dtype=torch.bool, device=A.device)
    # Based on https://github.com/pytorch/pytorch/issues/31546#issuecomment-694135622
    if not isinstance(B, torch.Tensor):
        raise AssertionError(f"B must be torch.Tensor. Got: {type(B)}.")
    dtype: torch.dtype = B.dtype
    if dtype not in (torch.float32, torch.float64):
        dtype = torch.float32
    A_LU, pivots, info = torch.lu(A.to(dtype), get_infos=True)
    valid_mask: torch.Tensor = info == 0
    X = torch.lu_solve(B.to(dtype), A_LU, pivots)
    return X.to(B.dtype), A_LU.to(A.dtype), valid_mask
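# Hypothetical usage of safe_solve_with_mask: a batch of two 3x3 systems where
# the second matrix is singular. The call does not raise; instead the returned
# mask flags which batch entries hold trustworthy solutions.
A = torch.stack([torch.eye(3), torch.zeros(3, 3)])
B = torch.randn(2, 3, 1)
X, A_LU, valid = safe_solve_with_mask(B, A)
# valid == tensor([True, False]); only X[valid] should be used downstream.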
def __init__(self, dim):
    super(Invertible1x1Conv, self).__init__()
    self.dim = dim
    # Grab the weight from a randomly initialized Conv2d.
    m = nn.Conv2d(dim, dim, kernel_size=1)
    W = m.weight.clone().detach().reshape(dim, dim)
    LU, pivots = torch.lu(W)
    P, _, _ = torch.lu_unpack(LU, pivots)
    s = torch.diag(LU)
    # zero out the diagonal of the packed LU factors; it is re-parameterized via s
    # noinspection PyTypeChecker
    LU = torch.where(torch.eye(dim) == 0, LU, torch.zeros_like(LU))
    self.register_buffer("P", P)
    self.register_buffer("s_sign", torch.sign(s))
    self.register_parameter("s_log", nn.Parameter(torch.log(torch.abs(s) + 1e-3)))
    self.register_parameter("LU", nn.Parameter(LU))
def test_dcem():
    n_batch = 2
    n_sample = 100
    N = 2
    torch.manual_seed(0)
    Q = torch.eye(N).unsqueeze(0).repeat(n_batch, 1, 1)
    p = 0.1 * torch.randn(n_batch, N)
    Q_sample = Q.unsqueeze(1).repeat(1, n_sample, 1, 1).view(n_batch * n_sample, N, N)
    p_sample = p.unsqueeze(1).repeat(1, n_sample, 1).view(n_batch * n_sample, N)

    def f(X):
        assert X.size() == (n_batch, n_sample, N)
        X = X.view(n_batch * n_sample, N)
        obj = 0.5 * (bmv(Q_sample, X) * X).sum(dim=1) + (p_sample * X).sum(dim=1)
        obj = obj.view(n_batch, n_sample)
        return obj

    def iter_cb(i, X, fX, I, X_I, mu, sigma):
        print(fX.mean(dim=1))

    xhat = dcem(
        f,
        nx=N,
        n_batch=n_batch,
        n_sample=n_sample,
        n_elite=50,
        n_iter=40,
        temp=1.,
        normalize=True,
        # temp=np.infty,
        init_sigma=1.,
        iter_cb=iter_cb,
        # lb=-5., ub=5.,
    )

    # analytic optimum: x* = -Q^{-1} p, solved batched via LU
    # (lu_solve expects a (n_batch, N, 1) right-hand side)
    Q_LU = torch.lu(Q)
    xstar = -torch.lu_solve(p.unsqueeze(2), *Q_LU).squeeze(2)
    assert (xhat - xstar).abs().max() < 1e-4
def __init__(self, dim_inputs):
    super().__init__()
    self.dim_inputs = dim_inputs
    w_shape = [dim_inputs, dim_inputs]
    w_init = torch.qr(torch.randn(w_shape))[0]
    p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
    s = torch.diag(upper)
    sign_s = torch.sign(s)
    log_s = torch.log(torch.abs(s))
    upper = torch.triu(upper, 1)
    l_mask = torch.tril(torch.ones(w_shape), -1)
    eye = torch.eye(*w_shape)
    self.register_buffer('p', p)
    self.register_buffer('sign_s', sign_s)
    self.lower = nn.Parameter(lower)
    self.log_s = nn.Parameter(log_s)
    self.upper = nn.Parameter(upper)
    self.register_buffer('l_mask', l_mask)
    self.register_buffer('eye', eye)
def xpbatch_lu_factor(A):
    assert len(A.shape) == 3, "Actual: " + str(A.shape)
    assert A.shape[1] == A.shape[2], "Actual: " + str(A.shape)
    xp = get_array_module(A)
    A = copy.deepcopy(A)
    '''
    if cupy_available and xp == cupy:
        Ps = xp.empty((A.shape[0], A.shape[1]), dtype=np.int32)
        for i in range(len(A)):
            A[i], Ps[i] = cupyx.scipy.linalg.lu_factor(A[i], overwrite_a=True)
        return A, Ps
    else:
        Ps = []
        for i in range(len(A)):
            A[i], piv = scipy.linalg.lu_factor(A[i], overwrite_a=True)
            Ps.append(piv)
        return A, Ps
    '''
    # use PyTorch here, because scipy does not offer a batched lu_factor
    A_LU, pivots = torch.lu(torch.tensor(A))
    return A_LU.cpu().numpy(), pivots.cpu().numpy()
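# Hypothetical round trip for xpbatch_lu_factor: the returned factors can be
# fed straight back into torch.lu_solve. Note that torch.lu pivots follow the
# 1-based LAPACK convention, unlike scipy.linalg.lu_factor's 0-based piv
# array, so they are not interchangeable with SciPy's lu_solve.
A = np.random.randn(5, 4, 4)
b = np.random.randn(5, 4, 1)
A_LU, pivots = xpbatch_lu_factor(A)
x = torch.lu_solve(torch.tensor(b), torch.tensor(A_LU), torch.tensor(pivots))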
def __init__(self, num_channels, lu_decomposition=False):
    """
    Invertible 1x1 convolution layer

    :param num_channels: number of channels
    :type num_channels: int
    :param lu_decomposition: whether to use LU decomposition
    :type lu_decomposition: bool
    """
    super().__init__()
    self.num_channels = num_channels
    self.lu_decomposition = lu_decomposition
    w_shape = [num_channels, num_channels]
    tolerance = 1e-4
    # Sample a random orthogonal matrix
    w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype('float32')
    if self.lu_decomposition:
        w_LU_pts, pivots = torch.lu(torch.Tensor(w_init))
        p, w_l, w_u = torch.lu_unpack(w_LU_pts, pivots)
        s = torch.diag(torch.diag(w_u))
        w_u -= s
        # verify the reconstruction w_init == P @ L @ (U + S)
        print((torch.Tensor(w_init) - torch.matmul(p, torch.matmul(w_l, (w_u + s)))).abs().sum())
        assert (torch.Tensor(w_init) - torch.matmul(p, torch.matmul(w_l, (w_u + s)))).abs().sum() < tolerance
        self.register_parameter('weight_L', nn.Parameter(torch.Tensor(w_l)))
        self.register_parameter('weight_U', nn.Parameter(torch.Tensor(w_u)))
        self.register_parameter('s', nn.Parameter(torch.Tensor(s)))
        self.register_buffer('weight_P', torch.FloatTensor(p))
    else:
        self.register_parameter('weight', nn.Parameter(torch.Tensor(w_init)))
def __init__(self, in_channel, weight=None):
    # in_channel indicates the channel size of the input
    super().__init__()
    self.in_channel = in_channel
    # set a random orthogonal matrix as the initial weight
    if weight is None:
        weight = torch.randn(in_channel, in_channel)
    Weight, _ = torch.qr(weight)
    # PLU decomposition
    self.weight = Weight  # only for testing
    Weight_LU, pivots = torch.lu(Weight)
    w_p, w_l, w_u = torch.lu_unpack(Weight_LU, pivots)
    w_s = torch.diag(w_u)
    w_logs = torch.log(torch.abs(w_s))
    s_sign = torch.sign(w_s)
    w_u = torch.triu(w_u, 1)
    u_mask = torch.triu(torch.ones_like(w_u), 1)
    l_mask = u_mask.T
    l_eye = torch.eye(l_mask.shape[0])
    # fix P
    self.register_buffer("w_p", w_p)
    self.register_buffer("u_mask", u_mask)
    self.register_buffer("l_mask", l_mask)
    self.register_buffer("s_sign", s_sign)
    self.register_buffer("l_eye", l_eye)
    self.w_l = torch.nn.Parameter(w_l)
    self.w_u = torch.nn.Parameter(w_u)
    # self.w_s = torch.nn.Parameter(w_s)
    self.w_logs = torch.nn.Parameter(w_logs)
def rank_revealing_LUP_GPU(A, threshold=10**-10):
    """
    It returns the rank of A using GPU acceleration. This is based on
    PyTorch's LUP implementation.

    Parameters
    ----------
    A: (2D float64 Numpy Array) Input array for which rank needs to be computed
    threshold: (Float) See documentation of the function "sort_rows_fast"

    Returns
    -------
    rank: (Integer) rank of A
    """
    try:
        import torch
    except ImportError as e:
        print(str(e))
        print("rank_revealing_LUP_GPU requires PyTorch to run.")
    A = A / np.abs(A).max()
    A_tensor = torch.from_numpy(A).cuda()
    A_LU, pivots = torch.lu(A_tensor, get_infos=False)
    _, _, u = torch.lu_unpack(A_LU, pivots, unpack_pivots=False)
    u = u.cpu().numpy()
    ref = compute_ref_LUP_fast(u, threshold)
    rank = compute_rank_ref(ref, threshold)
    return rank
def corrector(self, w0, X0, dw, dX, ds):
    self.aft_method.process(X0)
    FE = self.force_ex.get_f(w0)
    FN = self.aft_method.get_vector().detach()
    dFdX = self.aft_method.get_jacobian().detach()
    A0 = self.model.get_A(w0).detach()
    dAdw = self.model.get_DADw(w0).detach() / self.weight_w
    dFEdw = self.force_ex.get_dfdw(w0) / self.weight_w
    RX = torch.matmul(A0, X0.unsqueeze(-1)) + FN.unsqueeze(-1) - FE.unsqueeze(-1)
    err = float(torch.norm(RX, 1).detach().numpy())
    dRdX = A0 + dFdX
    dRdw = torch.matmul(dAdw, X0.unsqueeze(-1)).detach() - dFEdw.unsqueeze(-1)
    t, J = self.get_tau(dRdX, dRdw)
    # augment the Jacobian with the tangent row for the arc-length constraint
    J_ = torch.cat((J, t.unsqueeze(-1).permute(0, 2, 1)), dim=1)
    Rw = torch.sum(t * torch.cat((dX, self.weight_w * dw.view(-1, 1)), dim=1)) - ds
    D_lu = torch.lu(J_)
    d = torch.cat((RX, Rw.unsqueeze(0).unsqueeze(0).unsqueeze(0)), dim=1).detach().lu_solve(*D_lu)
    dw = dw - d[:, -1] / self.weight_w
    dX = dX - d[:, :-1].squeeze(-1)
    return dw.detach(), dX.detach(), err
def _expm_frechet_pade(A, E, m=7):
    assert m in [3, 5, 7, 9, 13]

    if m == 3:
        b = [120., 60., 12., 1.]
    elif m == 5:
        b = [30240., 15120., 3360., 420., 30., 1.]
    elif m == 7:
        b = [17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.]
    elif m == 9:
        b = [17643225600., 8821612800., 2075673600., 302702400., 30270240.,
             2162160., 110880., 3960., 90., 1.]
    elif m == 13:
        b = [64764752532480000., 32382376266240000., 7771770303897600.,
             1187353796428800., 129060195264000., 10559470521600.,
             670442572800., 33522128640., 1323241920., 40840800., 960960.,
             16380., 182., 1.]

    # Efficiently compute series terms of M_i (and A_i if needed).
    # Not very pretty, but more readable than the alternatives.
    I = _eye_like(A)
    if m != 13:
        if m >= 3:
            M_2 = A @ E + E @ A
            A_2 = A @ A
            U = b[3] * A_2
            V = b[2] * A_2
            L_U = b[3] * M_2
            L_V = b[2] * M_2
        if m >= 5:
            M_4 = A_2 @ M_2 + M_2 @ A_2
            A_4 = A_2 @ A_2
            U = U + b[5] * A_4
            V = V + b[4] * A_4
            L_U = L_U + b[5] * M_4
            L_V = L_V + b[4] * M_4
        if m >= 7:
            M_6 = A_4 @ M_2 + M_4 @ A_2
            A_6 = A_4 @ A_2
            U = U + b[7] * A_6
            V = V + b[6] * A_6
            L_U = L_U + b[7] * M_6
            L_V = L_V + b[6] * M_6
        if m == 9:
            M_8 = A_4 @ M_4 + M_4 @ A_4
            A_8 = A_4 @ A_4
            U = U + b[9] * A_8
            V = V + b[8] * A_8
            L_U = L_U + b[9] * M_8
            L_V = L_V + b[8] * M_8
        U = U + b[1] * I
        V = V + b[0] * I  # constant terms: b[1]*I belongs to U (odd powers), b[0]*I to V (even powers)
        del I
        L_U = A @ L_U
        L_U = L_U + E @ U
        U = A @ U
    else:
        M_2 = A @ E + E @ A
        A_2 = A @ A
        M_4 = A_2 @ M_2 + M_2 @ A_2
        A_4 = A_2 @ A_2
        M_6 = A_4 @ M_2 + M_4 @ A_2
        A_6 = A_4 @ A_2
        W_1 = b[13] * A_6 + b[11] * A_4 + b[9] * A_2
        W_2 = b[7] * A_6 + b[5] * A_4 + b[3] * A_2 + b[1] * I
        W = A_6 @ W_1 + W_2
        Z_1 = b[12] * A_6 + b[10] * A_4 + b[8] * A_2
        Z_2 = b[6] * A_6 + b[4] * A_4 + b[2] * A_2 + b[0] * I
        U = A @ W
        V = A_6 @ Z_1 + Z_2
        L_W1 = b[13] * M_6 + b[11] * M_4 + b[9] * M_2
        L_W2 = b[7] * M_6 + b[5] * M_4 + b[3] * M_2
        L_Z1 = b[12] * M_6 + b[10] * M_4 + b[8] * M_2
        L_Z2 = b[6] * M_6 + b[4] * M_4 + b[2] * M_2
        L_W = A_6 @ L_W1 + M_6 @ W_1 + L_W2
        L_U = A @ L_W + E @ W
        L_V = A_6 @ L_Z1 + M_6 @ Z_1 + L_Z2

    lu_decom = torch.lu(-U + V)
    exp_A = torch.lu_solve(U + V, *lu_decom)
    dexp_A = torch.lu_solve(L_U + L_V + (L_U - L_V) @ exp_A, *lu_decom)
    return exp_A, dexp_A
def _expm_pade(A, m=7):
    assert m in [3, 5, 7, 9, 13]
    # The following are values generated as
    # b = torch.tensor([_fraction(m, k) for k in range(m+1)]),
    # but reduced to natural numbers such that b[-1]=1. This still works,
    # because the same constants are used in the numerator and denominator
    # of the Padé approximation.
    if m == 3:
        b = [120., 60., 12., 1.]
    elif m == 5:
        b = [30240., 15120., 3360., 420., 30., 1.]
    elif m == 7:
        b = [17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.]
    elif m == 9:
        b = [17643225600., 8821612800., 2075673600., 302702400., 30270240.,
             2162160., 110880., 3960., 90., 1.]
    elif m == 13:
        b = [64764752532480000., 32382376266240000., 7771770303897600.,
             1187353796428800., 129060195264000., 10559470521600.,
             670442572800., 33522128640., 1323241920., 40840800., 960960.,
             16380., 182., 1.]

    # pre-computing terms
    I = _eye_like(A)
    if m != 13:  # There is a more efficient algorithm for m=13
        U = b[1] * I
        V = b[0] * I
        if m >= 3:
            A_2 = A @ A
            U = U + b[3] * A_2
            V = V + b[2] * A_2
        if m >= 5:
            A_4 = A_2 @ A_2
            U = U + b[5] * A_4
            V = V + b[4] * A_4
        if m >= 7:
            A_6 = A_4 @ A_2
            U = U + b[7] * A_6
            V = V + b[6] * A_6
        if m == 9:
            A_8 = A_4 @ A_4
            U = U + b[9] * A_8
            V = V + b[8] * A_8
        U = A @ U
    else:
        A_2 = A @ A
        A_4 = A_2 @ A_2
        A_6 = A_4 @ A_2
        W_1 = b[13] * A_6 + b[11] * A_4 + b[9] * A_2
        W_2 = b[7] * A_6 + b[5] * A_4 + b[3] * A_2 + b[1] * I
        W = A_6 @ W_1 + W_2
        Z_1 = b[12] * A_6 + b[10] * A_4 + b[8] * A_2
        Z_2 = b[6] * A_6 + b[4] * A_4 + b[2] * A_2 + b[0] * I
        U = A @ W
        V = A_6 @ Z_1 + Z_2

    del A_2
    if m >= 5:
        del A_4
    if m >= 7:
        del A_6
    if m == 9:
        del A_8

    # Padé approximant: solve (V - U) R = (U + V)
    R = torch.lu_solve(U + V, *torch.lu(-U + V))
    del U, V
    return R
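# The Padé approximants above are only accurate for small ||A||; production
# implementations wrap them in scaling-and-squaring. A minimal sketch (an
# assumption, not the original module's wrapper): pick s so that ||A / 2^s||
# is small, approximate, then square s times, since exp(A) = exp(A / 2^s)^(2^s).
def _expm_scaled(A, m=7):
    norm = torch.linalg.matrix_norm(A, ord=1)
    s = max(0, int(torch.ceil(torch.log2(norm)).item()) + 1) if norm > 0 else 0
    R = _expm_pade(A / (2 ** s), m=m)
    for _ in range(s):
        R = R @ R
    return R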
def pnqp(H, q, lower, upper, x_init=None, n_iter=20):
    GAMMA = 0.1
    n_batch, n, _ = H.size()
    pnqp_I = 1e-11 * torch.eye(n).type_as(H).expand_as(H)

    def obj(x):
        return 0.5 * util.bquad(x, H) + util.bdot(q, x)

    if x_init is None:
        if n == 1:
            x_init = -(1. / H.squeeze(2)) * q
        else:
            H_lu = torch.lu(H)  # replaces the deprecated H.btrifact()
            x_init = -q.btrisolve(H_lu[0], H_lu[1])  # Clamped in the x assignment.
    else:
        x_init = x_init.clone()  # Don't over-write the original x_init.

    x = util.eclamp(x_init, lower, upper)

    # Active examples in the batch.
    J = torch.ones(n_batch).type_as(x).byte()

    for i in range(n_iter):
        g = util.bmv(H, x) + q
        Ic = ((x == lower) & (g > 0)) | ((x == upper) & (g < 0))
        If = 1 - Ic

        if If.is_cuda:
            Hff_I = util.bger(If.float(), If.float()).type_as(If)
            not_Hff_I = 1 - Hff_I
            Hfc_I = util.bger(If.float(), Ic.float()).type_as(If)
        else:
            Hff_I = util.bger(If, If)
            not_Hff_I = 1 - Hff_I
            Hfc_I = util.bger(If, Ic)

        g_ = g.clone()
        g_[Ic] = 0.
        H_ = H.clone()
        H_[not_Hff_I] = 0.0
        H_ += pnqp_I

        if n == 1:
            dx = -(1. / H_.squeeze(2)) * g_
        else:
            H_lu_ = torch.lu(H_)  # factor the regularized subproblem matrix H_
            dx = -g_.btrisolve(*H_lu_)

        J = torch.norm(dx, 2, 1) >= 1e-4
        m = J.sum().item()  # Number of active examples in the batch.
        if m == 0:
            return x, H_ if n == 1 else H_lu_, If, i

        alpha = torch.ones(n_batch).type_as(x)
        decay = 0.1
        max_armijo = GAMMA
        count = 0
        while max_armijo <= GAMMA and count < 10:
            # Crude way of making sure too much time isn't being spent
            # doing the line search.
            maybe_x = util.eclamp(x + torch.diag(alpha).mm(dx), lower, upper)
            armijos = (GAMMA + 1e-6) * torch.ones(n_batch).type_as(x)
            armijos[J] = (obj(x) - obj(maybe_x))[J] / util.bdot(g, x - maybe_x)[J]
            I = armijos <= GAMMA
            alpha[I] *= decay
            max_armijo = torch.max(armijos)
            count += 1
        x = maybe_x

    # TODO: Maybe change this to a warning.
    print("[WARNING] pnqp warning: Did not converge")
    return x, H_ if n == 1 else H_lu_, If, i
def newton_exact(f, g, x_guess, opt_params, ls_method, ls_params):
    """
    This function performs gradient descent using Newton's method as the
    search direction.

    INPUTS:
    f < function > : objective function f(x) -> f
    g < function > : gradient function g(x) -> g
    x_guess < tensor > : initial x
    opt_params < dict{
        'ep_g' < float > : conv. tolerance on gradient
        'ep_a' < float > : absolute tolerance
        'ep_r' < float > : relative tolerance
        'Hessian' < function > : function that returns the Hessian
        'iter_lim' < int > : iteration limit
    } > : dictionary of optimization settings
    ls_method < str > : indicates which method to use with line search
    ls_params < dict > : dictionary with parameters to use for line search
    """
    ep_g = opt_params['ep_g']
    ep_a = opt_params['ep_a']
    ep_r = opt_params['ep_r']
    H = opt_params['Hessian']
    iter_lim = opt_params['iter_lim']

    # initializations
    x_k = x_guess
    x_hist = [x_k]
    f_k = f(x_guess)
    f_hist = [f_k]
    k = 0
    conv_count = 0  # how many iterations the rel./abs. tolerance must hold before stopping
    conv_count_max = 2

    while k < iter_lim:
        k += 1

        # compute gradient
        g_k = g(x_k)

        # check for gradient convergence
        if torch.norm(g_k) <= ep_g:
            converge = True
            message = "Exact Newton converged due to grad. tolerance."
            break

        # invert Hessian and find search direction
        H_k = H(x_k)
        H_LU, pivots, infos = torch.lu(H_k.reshape([1, H_k.shape[0], -1]), get_infos=True)
        if infos.nonzero().size(0) != 0:  # check if LU factorization failed
            converge = False
            message = "Hessian LU factorization failed."
            break

        # LU solve is designed for batch operations, hence the [0]
        delta_k = torch.lu_solve(-g_k.unsqueeze(0), H_LU, pivots)[0]
        # keep only descent directions
        if torch.matmul(delta_k.t(), g_k) < 0:
            p_k = delta_k
        else:
            p_k = -delta_k

        # perform line search
        alf, ls_converge, ls_message = line_search(f, x_k, g_k, p_k,
                                                   ls_method=ls_method,
                                                   ls_params=ls_params)
        if not ls_converge:
            converge = ls_converge
            message = ls_message
            break

        # compute x_(k+1)
        x_k1, f_k1 = search_step(f, x_k, alf, p_k)

        # check relative and absolute convergence criteria
        if rel_abs_convergence(f_k, f_k1, ep_a, ep_r):
            conv_count += 1

        x_k = x_k1
        f_k = f_k1
        x_hist.append(x_k)
        f_hist.append(f_k)

        if conv_count >= conv_count_max:
            converge = True
            message = "Exact Newton converged due to abs. rel. tolerance."
            break

    if k == iter_lim:
        converge = False
        message = "Exact Newton iteration limit reached."

    return x_k, f_k, x_hist, f_hist, converge, message
def quad_search(f, x_k, g_k, p_k, ls_params):
    """
    This function performs approximate quadratic line search.

    INPUTS:
    f < function > : objective function f(x) -> f
    x_k < tensor > : current best guess for f(x) minimum
    g_k < tensor > : gradient evaluated at x_k
    p_k < tensor > : search direction
    ls_params < dict{
        'alf' < float > : initial guess for step-length
        'mu' < float > : small positive constant used in "Armijo suff. decrease condition"
        'rho' < float > : step-size discount coefficient
        'iter_lim' < int > : iteration limit for solver
        'alf_lower_coeff' : coefficient for determining point one in quad_search
        'alf_upper_coeff' : coefficient for determining point three in quad_search
    } > : dictionary with parameters to use for line search

    RETURNS:
    alf_new < float > : computed search length
    converge < bool > : bool indicating whether line search converged
    message < string > : string with output from back tracking method
    """
    mu = ls_params['mu']
    iter_lim = ls_params['iter_lim']
    alf_new = torch.as_tensor(ls_params['alf'])  # as a tensor so .pow() works below
    alf_coeff1 = ls_params['alf_lower_coeff']
    alf_coeff2 = ls_params['alf_upper_coeff']
    iter = 0

    while not armijo_suff_decrease(f, x_k, g_k, p_k, alf_new, mu) and iter < iter_lim:
        # bracket the current step with a lower and an upper trial point
        a1 = alf_new
        a2 = alf_new * 0.1
        a3 = alf_new * 2.0
        f1 = f(x_k + a1 * p_k)
        f2 = f(x_k + a2 * p_k)
        f3 = f(x_k + a3 * p_k)

        # fit q(a) = (c0/2) a^2 + c1 a + c2 through the three samples
        A = torch.tensor([[1 / 2 * a1.pow(2), a1, 1],
                          [1 / 2 * a2.pow(2), a2, 1],
                          [1 / 2 * a3.pow(2), a3, 1]])
        b = torch.tensor([[f1], [f2], [f3]])
        A_LU, pivots, infos = torch.lu(A.reshape([1, A.shape[0], -1]), get_infos=True)
        if infos.nonzero().size(0) != 0:
            # return immediately so the failure status is not overwritten below
            return alf_new, False, "Quadratic approx was not possible."
        coeff = torch.lu_solve(b.unsqueeze(0), A_LU, pivots)[0]
        alf_new = -coeff[1] / coeff[0]  # minimizer of the fitted quadratic
        iter += 1

    if iter == iter_lim:
        converge = False
        message = "Quadratic approx. line search iteration limit reached."
    else:
        converge = True
        message = "Quadratic approx. line search converged."

    return alf_new, converge, message