Example #1
    def solve(self, w):

        A0 = self.model.get_A(w)
        D_lu = torch.lu(A0)
        FE = self.force_ex.get_f(w)
        # Initial guess of shape (1, 2 * self.Nf * self.Nh): solve the
        # linear problem A0 @ X = FE.
        X = FE.unsqueeze(-1).lu_solve(*D_lu).squeeze(-1)

        for iter in range(self.max_iter):

            self.aft_method.process(X)
            FN = self.aft_method.get_vector().detach()
            dFdX = self.aft_method.get_jacobian().detach()

            RX = torch.matmul(
                A0, X.unsqueeze(-1)) + FN.unsqueeze(-1) - FE.unsqueeze(-1)
            err = float(torch.norm(RX, 1).detach().numpy())

            if self.display: print("error:  " + str(err))
            if err < self.max_err:
                print("iter:  " + str(iter + 1) + "    error:   " + str(err))
                break

            D_lu = torch.lu(A0 + dFdX)
            dX = RX.lu_solve(*D_lu).squeeze(-1)
            X = X - dX

        return X.detach()
Example #2
File: math_ops.py Project: malfet/pytorch
    def blas_lapack_ops(self):
        m = torch.randn(3, 3)
        a = torch.randn(10, 3, 4)
        b = torch.randn(10, 4, 3)
        v = torch.randn(3)
        return (
            torch.addbmm(m, a, b),
            torch.addmm(torch.randn(2, 3), torch.randn(2, 3),
                        torch.randn(3, 3)),
            torch.addmv(torch.randn(2), torch.randn(2, 3), torch.randn(3)),
            torch.addr(torch.zeros(3, 3), v, v),
            torch.baddbmm(m, a, b),
            torch.bmm(a, b),
            torch.chain_matmul(torch.randn(3, 3), torch.randn(3, 3),
                               torch.randn(3, 3)),
            # torch.cholesky(a), # deprecated
            torch.cholesky_inverse(torch.randn(3, 3)),
            torch.cholesky_solve(torch.randn(3, 3), torch.randn(3, 3)),
            torch.dot(v, v),
            torch.eig(m),
            torch.geqrf(a),
            torch.ger(v, v),
            torch.inner(m, m),
            torch.inverse(m),
            torch.det(m),
            torch.logdet(m),
            torch.slogdet(m),
            torch.lstsq(m, m),
            torch.lu(m),
            torch.lu_solve(m, *torch.lu(m)),
            torch.lu_unpack(*torch.lu(m)),
            torch.matmul(m, m),
            torch.matrix_power(m, 2),
            # torch.matrix_rank(m),
            torch.matrix_exp(m),
            torch.mm(m, m),
            torch.mv(m, v),
            # torch.orgqr(a, m),
            # torch.ormqr(a, m, v),
            torch.outer(v, v),
            torch.pinverse(m),
            # torch.qr(a),
            torch.solve(m, m),
            torch.svd(a),
            # torch.svd_lowrank(a),
            # torch.pca_lowrank(a),
            # torch.symeig(a), # deprecated
            # torch.lobpcg(a, b), # not supported
            torch.trapz(m, m),
            torch.trapezoid(m, m),
            torch.cumulative_trapezoid(m, m),
            # torch.triangular_solve(m, m),
            torch.vdot(v, v),
        )
Example #3
    @staticmethod
    def forward(ctx, A, b):
        A_LU, pivots = torch.lu(A)
        x = torch.lu_solve(b, A_LU, pivots)

        ctx.save_for_backward(A_LU, pivots, x)

        return x
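The forward above saves the LU factors and the solution for the backward pass, but the backward itself is not shown. A minimal sketch of a matching backward, assuming square A and matrix-shaped b (an assumption, not part of the original snippet): for x = A^{-1} b one has grad_b = A^{-T} grad_x and grad_A = -grad_b x^T, and the transpose solve can reuse the saved factors through torch.lu_unpack and torch.triangular_solve.

    @staticmethod
    def backward(ctx, grad_x):
        A_LU, pivots, x = ctx.saved_tensors
        # A = P @ L @ U, so A^T y = g is solved in stages:
        # U^T z = g, then L^T w = z, then y = P @ w.
        P, L, U = torch.lu_unpack(A_LU, pivots)
        z = torch.triangular_solve(grad_x, U, upper=True,
                                   transpose=True).solution
        w = torch.triangular_solve(z, L, upper=False, transpose=True,
                                   unitriangular=True).solution
        grad_b = P @ w
        grad_A = -grad_b @ x.transpose(-2, -1)
        return grad_A, grad_b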
Example #4
    def __init__(self, num_channels, LU_decomposed):
        super().__init__()
        w_shape = [num_channels, num_channels]
        w_init = torch.qr(torch.randn(*w_shape))[0]

        if not LU_decomposed:
            self.weight = nn.Parameter(torch.Tensor(w_init))
        else:
            p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
            s = torch.diag(upper)
            sign_s = torch.sign(s)
            log_s = torch.log(torch.abs(s))
            upper = torch.triu(upper, 1)
            l_mask = torch.tril(torch.ones(w_shape), -1)
            eye = torch.eye(*w_shape)

            self.register_buffer("p", p)
            self.register_buffer("sign_s", sign_s)
            self.lower = nn.Parameter(lower)
            self.log_s = nn.Parameter(log_s)
            self.upper = nn.Parameter(upper)
            self.l_mask = l_mask
            self.eye = eye

        self.w_shape = w_shape
        self.LU_decomposed = LU_decomposed
Example #5
    def __init__(self, num_features, LU_decomposed=True):
        super(Invertible1x1ConvLU, self).__init__()
        w_shape = [num_features, num_features]
        w_init = torch.qr(torch.randn(*w_shape))[0]

        self.num_features = num_features

        if not LU_decomposed:
            self.weight = nn.Parameter(torch.Tensor(w_init))
        else:
            p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
            s = torch.diag(upper)
            sign_s = torch.sign(s)
            log_s = torch.log(torch.abs(s))
            upper = torch.triu(upper, 1)
            l_mask = torch.tril(torch.ones(w_shape), -1)
            eye = torch.eye(*w_shape)

            self.register_buffer("p", p)
            self.register_buffer("sign_s", sign_s)
            self.lower = nn.Parameter(lower)
            self.log_s = nn.Parameter(log_s)
            self.upper = nn.Parameter(upper)
            self.l_mask = l_mask
            self.eye = eye

        self.if_LU = LU_decomposed
        self.w_shape = w_shape
Example #6
    def __init__(self, num_channels, LU_decomposed=True):
        super().__init__()
        self.num_channels = num_channels
        self.LU_decomposed = LU_decomposed
        weight_shape = [num_channels, num_channels]
        weight, _ = torch.qr(torch.randn(*weight_shape))
        if not self.LU_decomposed:
            self.weight = nn.Parameter(weight)
        else:
            weight_lu, pivots = torch.lu(weight)
            w_p, w_l, w_u = torch.lu_unpack(weight_lu, pivots)
            w_s = torch.diag(w_u)
            sign_s = torch.sign(w_s)
            log_s = torch.log(torch.abs(w_s))
            w_u = torch.triu(w_u, 1)

            u_mask = torch.triu(torch.ones_like(w_u), 1)
            l_mask = u_mask.T.contiguous()
            eye = torch.eye(l_mask.shape[0])

            self.register_buffer('p', w_p)
            self.register_buffer('sign_s', sign_s)
            self.register_buffer('eye', eye)
            self.register_buffer('u_mask', u_mask)
            self.register_buffer('l_mask', l_mask)
            self.l = nn.Parameter(w_l)
            self.u = nn.Parameter(w_u)
            self.log_s = nn.Parameter(log_s)
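None of these __init__ methods show how the weight is rebuilt at call time. A minimal sketch consistent with the parameters and buffers registered in this example (the method itself is an assumption, not code from the source project):

    def _weight(self):
        # W = P @ L @ (U + diag(s)), where L is unit lower-triangular,
        # U is strictly upper-triangular, and s = sign_s * exp(log_s),
        # so that log|det W| = log_s.sum().
        lower = self.l * self.l_mask + self.eye
        upper = self.u * self.u_mask \
            + torch.diag(self.sign_s * torch.exp(self.log_s))
        return self.p @ lower @ upper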
Example #7
    def preprocess(self):
        X = self.X
        k = self.k
        sigma = self.sigma
        n, m = X.shape[0], self.m

        # Compute reduced kernel matrix
        Knm = k(X[:n, :], X[:m, :]).evaluate()

        # Compute eigen-decomposition (see eq. (7))
        Lambda_m, U_m = torch.symeig(Knm[:m, :], eigenvectors=True)

        # Compute approximate eigenvalues (see eq. (8))
        Lambda = (n / m) * Lambda_m

        # Compute approximate eigenvectors (see eq. (9))
        U = math.sqrt(m / n) * Knm @ U_m @ torch.diag(1.0 / Lambda_m)

        # Convert Lambda to diagonal matrix
        Lambda = torch.diag(Lambda)

        # Store approximate eigenvalues and eigenvectors
        self.Lambda, self.U = Lambda, U

        # Factorize M
        M = Lambda @ U.t() @ U + sigma * torch.eye(m)

        # Compute and store LU factors of M
        self.LU = torch.lu(M.detach())
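The cached factors are presumably consumed later with torch.lu_solve. A one-line sketch of such a solve for a right-hand side r of length m (assumed usage, not shown in the source):

    v = torch.lu_solve(r.unsqueeze(-1), *self.LU).squeeze(-1)  # solves M v = r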
Example #8
    def __init__(self, in_out_channels):
        super(InvertibleConv1x1, self).__init__()

        W = torch.zeros((in_out_channels, in_out_channels),
                        dtype=torch.float32)
        nn.init.orthogonal_(W)
        LU, pivots = torch.lu(W)

        P, L, U = torch.lu_unpack(LU, pivots)
        self.P = nn.Parameter(P, requires_grad=False)
        self.L = nn.Parameter(L, requires_grad=True)
        self.U = nn.Parameter(U, requires_grad=True)
        self.I = nn.Parameter(torch.eye(in_out_channels), requires_grad=False)
        self.pivots = nn.Parameter(pivots, requires_grad=False)

        L_mask = np.tril(np.ones((in_out_channels, in_out_channels),
                                 dtype='float32'),
                         k=-1)
        U_mask = L_mask.T.copy()
        self.L_mask = nn.Parameter(torch.from_numpy(L_mask),
                                   requires_grad=False)
        self.U_mask = nn.Parameter(torch.from_numpy(U_mask),
                                   requires_grad=False)

        s = torch.diag(U)
        sign_s = torch.sign(s)
        log_s = torch.log(torch.abs(s))
        self.log_s = nn.Parameter(log_s, requires_grad=True)
        self.sign_s = nn.Parameter(sign_s, requires_grad=False)
Example #9
def _exp_pade_generic(A, m=7):
    """
    Minimal, inefficient implementation of the [m/m]-Padé approximation of the
    matrix exponential.
    """
    LU = torch.lu(_pade_poly(-A, m))
    result = torch.lu_solve(_pade_poly(A, m), *LU)
    return result
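_pade_poly is not included in the snippet. A plausible reconstruction (hypothetical; _pade_coefficients stands in for the constant tables listed in Examples #25 and #26) evaluates p(A) = sum_k b[k] A^k by Horner's rule:

def _pade_poly(A, m):
    # Hypothetical helper: Horner evaluation of the [m/m] Pade polynomial.
    b = _pade_coefficients(m)  # assumed to return [b_0, ..., b_m] with b_m = 1
    I = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
    R = b[m] * I
    for k in range(m - 1, -1, -1):
        R = A @ R + b[k] * I
    return R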
Example #10
File: utilities.py Project: wujiren/INNLab
    def sampling_W(self, dim):
        # sample a rotation matrix
        W = torch.empty(dim, dim)
        torch.nn.init.orthogonal_(W)
        # compute LU
        LU, pivot = torch.lu(W)
        P, L, U = torch.lu_unpack(LU, pivot)
        return W, P, L, U
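A quick sanity check of the returned factors (not in the source; `sampler` is a hypothetical instance of the class above):

    W, P, L, U = sampler.sampling_W(4)
    assert torch.allclose(W, P @ L @ U, atol=1e-6)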
Example #11
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        Q = nn.init.orthogonal_(torch.randn(dim, dim).to(device))
        P, L, U = torch.lu_unpack(*torch.lu(Q))
        self.register_buffer('P', P)
        self.L = nn.Parameter(L)  # lower triangular portion
        self.S = nn.Parameter(U.diag())  # "crop out" the diagonal to its own parameter
        self.U = nn.Parameter(torch.triu(U, diagonal=1))  # "crop out" diagonal, stored in S
Example #12
    def _compute_weights(self):
        if self._target_dim > 1:
            # we first factorize the matrix
            self.nodes = torch.zeros(self.N, self._target_dim,
                                     dtype=self.di.dtype, device=self.device)
            lu_data = torch.lu(self.A)
            for i in range(self._target_dim):
                self.nodes[:, i] = torch.lu_solve(self.di[:, i].unsqueeze(0).T,
                                                  *lu_data).squeeze()
        else:
            # torch.solve takes the arguments as (B, A) and solves A @ X = B
            self.nodes = torch.solve(self.di, self.A)[0]
Example #13
    def prox(self, t, nu, warm_start, pool, cache):

        XtX = cache['XtX']
        XtY = cache['XtY']
        n = cache['n']

        A_LU = torch.lu(XtX + 1. /
                        (2 * t) * torch.eye(n).unsqueeze(0).double())
        b = XtY + 1. / (2 * t) * torch.from_numpy(nu)
        x = torch.lu_solve(b, *A_LU)

        return x.numpy()
Example #14
    def __init__(self, in_channel, fixed, use_lu):
        super().__init__()
        self.use_lu = use_lu
        self.fixed = fixed
        if not use_lu:
            weight = th.randn(in_channel, in_channel)
            q, _ = th.qr(weight)
            weight = q
            if fixed:
                self.register_buffer('weight', weight.data)
                self.register_buffer('weight_inverse', weight.data.inverse())
                self.register_buffer(
                    'fixed_log_det',
                    th.slogdet(self.weight.double())[1].float())
            else:
                self.weight = nn.Parameter(weight)
        if use_lu:
            assert not fixed
            weight = th.randn(in_channel, in_channel)
            q, _ = th.qr(weight)

            # PLU-decompose the random orthogonal matrix
            w_p, w_l, w_u = th.lu_unpack(*th.lu(q))

            w_s = th.diag(w_u)
            w_u = th.triu(w_u, 1)
            u_mask = th.triu(th.ones_like(w_u), 1)
            l_mask = u_mask.t()

            self.register_buffer('w_p', w_p)
            self.register_buffer('u_mask', u_mask)
            self.register_buffer('l_mask', l_mask)
            self.register_buffer('s_sign', th.sign(w_s))
            self.register_buffer('l_eye', th.eye(l_mask.shape[0]))
            self.w_l = nn.Parameter(w_l)
            self.w_s = nn.Parameter(th.log(th.abs(w_s)))
            self.w_u = nn.Parameter(w_u)
Example #15
    def __init__(self, c):
        super(Invertible1x1ConvLUS, self).__init__()
        # Sample a random orthonormal matrix to initialize weights
        W, _ = torch.linalg.qr(torch.randn(c, c))
        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        p, lower, upper = torch.lu_unpack(*torch.lu(W))

        self.register_buffer('p', p)
        # diagonals of lower will always be 1s anyway
        lower = torch.tril(lower, -1)
        lower_diag = torch.diag(torch.eye(c, c))
        self.register_buffer('lower_diag', lower_diag)
        self.lower = nn.Parameter(lower)
        self.upper_diag = nn.Parameter(torch.diag(upper))
        self.upper = nn.Parameter(torch.triu(upper, 1))
Example #16
def safe_solve_with_mask(B: torch.Tensor, A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Helper function, which avoids crashing because of singular matrix input and outputs the
    mask of valid solution"""
    if not torch_version_geq(1, 10):
        sol, lu = _torch_solve_cast(B, A)
        warnings.warn('PyTorch version < 1.10, solve validness mask maybe not correct', RuntimeWarning)
        return sol, lu, torch.ones(len(A), dtype=torch.bool, device=A.device)
    # Based on https://github.com/pytorch/pytorch/issues/31546#issuecomment-694135622
    if not isinstance(B, torch.Tensor):
        raise AssertionError(f"B must be torch.Tensor. Got: {type(B)}.")
    dtype: torch.dtype = B.dtype
    if dtype not in (torch.float32, torch.float64):
        dtype = torch.float32
    A_LU, pivots, info = torch.lu(A.to(dtype), get_infos=True)
    valid_mask: torch.Tensor = info == 0
    X = torch.lu_solve(B.to(dtype), A_LU, pivots)
    return X.to(B.dtype), A_LU.to(A.dtype), valid_mask
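Usage might look like the following (hypothetical shapes): solve A x = B for a batch in which some matrices may be singular, then keep only the valid solutions.

    A = torch.randn(8, 3, 3)
    B = torch.randn(8, 3, 1)
    X, A_LU, valid = safe_solve_with_mask(B, A)
    X_ok = X[valid]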
Example #17
    def __init__(self, dim):
        super(Invertible1x1Conv, self).__init__()
        self.dim = dim

        # Grab the weight from a randomly initialized Conv2d.
        m = nn.Conv2d(dim, dim, kernel_size=1)
        W = m.weight.clone().detach().reshape(dim, dim)
        LU, pivots = torch.lu(W)
        P, _, _ = torch.lu_unpack(LU, pivots)

        s = torch.diag(LU)
        # Zero the diagonal of LU; it is reparameterized through s below.
        LU = torch.where(torch.eye(dim) == 0, LU, torch.zeros_like(LU))

        self.register_buffer("P", P)
        self.register_buffer("s_sign", torch.sign(s))
        self.register_parameter("s_log", nn.Parameter(torch.log(torch.abs(s) + 1e-3)))
        self.register_parameter("LU", nn.Parameter(LU))
Example #18
def test_dcem():
    n_batch = 2
    n_sample = 100
    N = 2

    torch.manual_seed(0)
    Q = torch.eye(N).unsqueeze(0).repeat(n_batch, 1, 1)
    p = 0.1 * torch.randn(n_batch, N)

    Q_sample = Q.unsqueeze(1).repeat(1, n_sample, 1,
                                     1).view(n_batch * n_sample, N, N)
    p_sample = p.unsqueeze(1).repeat(1, n_sample,
                                     1).view(n_batch * n_sample, N)

    def f(X):
        assert X.size() == (n_batch, n_sample, N)
        X = X.view(n_batch * n_sample, N)
        obj = 0.5 * (bmv(Q_sample, X) * X).sum(dim=1) + (p_sample *
                                                         X).sum(dim=1)
        obj = obj.view(n_batch, n_sample)
        return obj

    def iter_cb(i, X, fX, I, X_I, mu, sigma):
        print(fX.mean(dim=1))

    xhat = dcem(
        f,
        nx=N,
        n_batch=n_batch,
        n_sample=n_sample,
        n_elite=50,
        n_iter=40,
        temp=1.,
        normalize=True,
        # temp = np.infty,
        init_sigma=1.,
        iter_cb=iter_cb,
        # lb = -5., ub = 5.,
    )

    Q_LU = torch.lu(Q)
    xstar = -torch.lu_solve(p.unsqueeze(-1), *Q_LU).squeeze(-1)
    assert (xhat - xstar).abs().max() < 1e-4
Example #19
    def __init__(self, dim_inputs):
        super().__init__()
        self.dim_inputs = dim_inputs
        w_shape = [dim_inputs, dim_inputs]
        w_init = torch.qr(torch.randn(w_shape))[0]

        p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
        s = torch.diag(upper)
        sign_s = torch.sign(s)
        log_s = torch.log(torch.abs(s))
        upper = torch.triu(upper, 1)
        l_mask = torch.tril(torch.ones(w_shape), -1)
        eye = torch.eye(*w_shape)

        self.register_buffer('p', p)
        self.register_buffer('sign_s', sign_s)
        self.lower = nn.Parameter(lower)
        self.log_s = nn.Parameter(log_s)
        self.upper = nn.Parameter(upper)
        self.register_buffer('l_mask', l_mask)
        self.register_buffer('eye', eye)
Example #20
def xpbatch_lu_factor(A):
    assert len(A.shape) == 3, "Actual: " + str(A.shape)
    assert A.shape[1] == A.shape[2], "Actual: " + str(A.shape)
    xp = get_array_module(A)
    A = copy.deepcopy(A)
    '''
    if cupy_available and xp == cupy:
        Ps = xp.empty((A.shape[0], A.shape[1]), dtype=np.int32)
        for i in range(len(A)):
            A[i], Ps[i] = cupyx.scipy.linalg.lu_factor(A[i], overwrite_a=True)
        return A, Ps
    else:
        Ps = []
        for i in range(len(A)):
            A[i], piv = scipy.linalg.lu_factor(A[i], overwrite_a=True)
            Ps.append(piv)
        return A, Ps
    '''
    # use PyTorch here, because scipy does not offer batched LU factorization
    A_LU, pivots = torch.lu(torch.tensor(A))
    return A_LU.cpu().numpy(), pivots.cpu().numpy()
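A matching batched solve that consumes the returned factors could look like this (a sketch, not part of the source):

def xpbatch_lu_solve(A_LU, pivots, B):
    # B: (batch, n, k) right-hand sides; solves each A[i] @ X[i] = B[i]
    X = torch.lu_solve(torch.tensor(B), torch.tensor(A_LU),
                       torch.tensor(pivots))
    return X.cpu().numpy()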
Example #21
    def __init__(self, num_channels, lu_decomposition=False):
        """
        Invertible 1x1 convolution layer

        :param num_channels: number of channels
        :type num_channels: int
        :param lu_decomposition: whether to use LU decomposition
        :type lu_decomposition: bool
        """
        super().__init__()
        self.num_channels = num_channels
        self.lu_decomposition = lu_decomposition
        w_shape = [num_channels, num_channels]
        tolerance = 1e-4
        # Sample a random orthogonal matrix
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype('float32')
        if self.lu_decomposition:
            w_LU_pts, pivots = torch.lu(torch.Tensor(w_init))
            p, w_l, w_u = torch.lu_unpack(w_LU_pts, pivots)
            s = torch.diag(torch.diag(w_u))
            w_u -= s
            # Verify that the PLU factors reconstruct the original weight.
            assert (torch.Tensor(w_init) - torch.matmul(
                p, torch.matmul(w_l, (w_u + s)))).abs().sum() < tolerance
            self.register_parameter('weight_L',
                                    nn.Parameter(torch.Tensor(w_l)))
            self.register_parameter('weight_U',
                                    nn.Parameter(torch.Tensor(w_u)))
            self.register_parameter('s', nn.Parameter(torch.Tensor(s)))
            self.register_buffer('weight_P', torch.FloatTensor(p))

        else:
            # Use the random orthogonal matrix directly as the weight.
            self.register_parameter('weight',
                                    nn.Parameter(torch.Tensor(w_init)))
Example #22
    def __init__(self, in_channel, weight=None):
        # in_channel indicates the channel size of the input
        super().__init__()
        self.in_channel = in_channel

        # set an random orthogonal matrix as the initial weight
        if weight is None:
            weight = torch.randn(in_channel, in_channel)
        Weight, _ = torch.qr(weight)

        # PLU decomposition
        self.weight = Weight  # only for testing

        Weight_LU, pivots = torch.lu(Weight)
        w_p, w_l, w_u = torch.lu_unpack(Weight_LU, pivots)
        w_s = torch.diag(w_u)
        w_logs = torch.log(torch.abs(w_s))
        s_sign = torch.sign(w_s)

        w_u = torch.triu(w_u, 1)

        u_mask = torch.triu(torch.ones_like(w_u), 1)
        l_mask = u_mask.T

        l_eye = torch.eye(l_mask.shape[0])

        # fix P
        self.register_buffer("w_p", w_p)
        self.register_buffer("u_mask", u_mask)
        self.register_buffer("l_mask", l_mask)
        self.register_buffer("s_sign", s_sign)
        self.register_buffer("l_eye", l_eye)

        self.w_l = torch.nn.Parameter(w_l)
        self.w_u = torch.nn.Parameter(w_u)
        # self.w_s = torch.nn.Parameter(w_s)
        self.w_logs = torch.nn.Parameter(w_logs)
Example #23
def rank_revealing_LUP_GPU(A, threshold=10**-10):
    """
    It returns the rank of A using GPU acceleration. This is based on PyTorch's LUP implementation.
    
    Parameters
    __________

    A: (2D float64 Numpy Array)
    Input array for which rank needs to be computed

    threshold: (Float)
    See documentation of the function "sort_rows_fast"


    Returns:
    _________

    rank: (Integer)
    rank of A

    """

    try:
        import torch
    except ImportError as e:
        print(str(e))
        print("rank_revealing_LUP_GPU requires PyTorch to run.")
        raise

    A = A / np.abs(A).max()
    A_tensor = torch.from_numpy(A).cuda()
    A_LU, pivots = torch.lu(A_tensor, get_infos=False)
    _, _, u = torch.lu_unpack(A_LU, pivots, unpack_pivots=False)
    u = u.cpu().numpy()
    ref = compute_ref_LUP_fast(u, threshold)
    rank = compute_rank_ref(ref, threshold)

    return rank
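A minimal invocation (assuming a CUDA device is available):

    A = np.random.rand(100, 80)
    print(rank_revealing_LUP_GPU(A))  # a generic random A has full rank, 80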
Example #24
    def corrector(self, w0, X0, dw, dX, ds):

        self.aft_method.process(X0)
        FE = self.force_ex.get_f(w0)
        FN = self.aft_method.get_vector().detach()

        dFdX = self.aft_method.get_jacobian().detach()

        A0 = self.model.get_A(w0).detach()
        dAdw = self.model.get_DADw(w0).detach() / self.weight_w
        dFEdw = self.force_ex.get_dfdw(w0) / self.weight_w

        RX = torch.matmul(
            A0, X0.unsqueeze(-1)) + FN.unsqueeze(-1) - FE.unsqueeze(-1)
        err = float(torch.norm(RX, 1).detach().numpy())

        dRdX = A0 + dFdX
        dRdw = torch.matmul(dAdw,
                            X0.unsqueeze(-1)).detach() - dFEdw.unsqueeze(-1)

        t, J = self.get_tau(dRdX, dRdw)

        J_ = torch.cat((J, t.unsqueeze(-1).permute(0, 2, 1)), dim=1)
        Rw = torch.sum(t * torch.cat(
            (dX, self.weight_w * dw.view(-1, 1)), dim=1)) - ds

        D_lu = torch.lu(J_)
        d = torch.cat((RX, Rw.unsqueeze(0).unsqueeze(0).unsqueeze(0)),
                      dim=1).detach().lu_solve(*D_lu)
        dw = dw - d[:, -1] / self.weight_w
        dX = dX - d[:, :-1].squeeze(-1)

        return dw.detach(), dX.detach(), err
Example #25
File: expm.py Project: RedekopEP/iunets
def _expm_frechet_pade(A, E, m=7):

    assert(m in [3,5,7,9,13])
    
    if m == 3:
        b = [120., 60., 12., 1.]
    elif m == 5:
        b = [30240., 15120., 3360., 420., 30., 1.]
    elif m == 7:
        b = [17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.]
    elif m == 9:
        b = [17643225600., 8821612800., 2075673600., 302702400., 30270240., 
             2162160., 110880., 3960., 90., 1.]
    elif m == 13:
        b = [64764752532480000., 32382376266240000., 7771770303897600.,
             1187353796428800., 129060195264000., 10559470521600.,
             670442572800., 33522128640., 1323241920., 40840800., 960960.,
             16380., 182., 1.]

    # Efficiently compute series terms of M_i (and A_i if needed).
    # Not very pretty, but more readable than the alternatives.
    I = _eye_like(A)
    if m!=13:
        if m >= 3:
            M_2 = A @ E + E @ A
            A_2 = A @ A 
            U = b[3]*A_2
            V = b[2]*A_2
            L_U = b[3]*M_2
            L_V = b[2]*M_2
        if m >= 5:
            M_4 = A_2 @ M_2 + M_2 @ A_2
            A_4 = A_2 @ A_2
            U = U + b[5]*A_4
            V = V + b[4]*A_4
            L_U = L_U + b[5]*M_4
            L_V = L_V + b[4]*M_4
        if m >= 7:
            M_6 = A_4 @ M_2 + M_4 @ A_2
            A_6 = A_4 @ A_2
            U = U + b[7]*A_6
            V = V + b[6]*A_6
            L_U = L_U + b[7]*M_6
            L_V = L_V + b[6]*M_6
        if m == 9:
            M_8 = A_4 @ M_4 + M_4 @ A_4
            A_8 = A_4 @ A_4
            U = U + b[9]*A_8
            V = V + b[8]*A_8
            L_U = L_U + b[9]*M_8
            L_V = L_V + b[8]*M_8
            
        U = U + b[1]*I
        V = V + b[0]*I
        del I

        L_U = A @ L_U
        L_U = L_U + E @ U

        U = A @ U
            
    else:
        M_2 = A @ E + E @ A
        A_2 = A @ A 
        
        M_4 = A_2 @ M_2 + M_2 @ A_2
        A_4 = A_2 @ A_2
        
        M_6 = A_4 @ M_2 + M_4 @ A_2
        A_6 = A_4 @ A_2
        
        W_1 = b[13]*A_6 + b[11]*A_4 + b[9]*A_2 
        W_2 = b[7]*A_6 + b[5]*A_4 + b[3]*A_2 + b[1]*I
        
        W = A_6 @ W_1 + W_2

        Z_1 = b[12]*A_6 + b[10]*A_4 + b[8]*A_2
        Z_2 = b[6]*A_6 + b[4]*A_4 + b[2]*A_2 + b[0]*I
        
        U = A @ W
        V = A_6 @ Z_1 + Z_2
        
        L_W1 = b[13]*M_6 + b[11]*M_4 + b[9]*M_2
        L_W2 = b[7]*M_6 + b[5]*M_4 + b[3]*M_2
        
        L_Z1 = b[12]*M_6 + b[10]*M_4 + b[8]*M_2
        L_Z2 = b[6]*M_6 + b[4]*M_4 + b[2]*M_2
        
        L_W = A_6 @ L_W1 + M_6 @ W_1 + L_W2
        L_U = A @ L_W + E @ W   
        L_V = A_6 @ L_Z1 + M_6 @ Z_1 + L_Z2


    lu_decom = torch.lu(-U + V)
    exp_A = torch.lu_solve(U + V, *lu_decom)
    dexp_A = torch.lu_solve(L_U + L_V + (L_U - L_V) @ exp_A, *lu_decom)
     
    return exp_A, dexp_A
Example #26
File: expm.py Project: RedekopEP/iunets
def _expm_pade(A, m=7):
    assert(m in [3,5,7,9,13])
    
    # The following are values generated as 
    # b = torch.tensor([_fraction(m, k) for k in range(m+1)]),
    # but reduced to natural numbers such that b[-1]=1. This still works,
    # because the same constants are used in the numerator and denominator
    # of the Padé approximation.
    if m == 3:
        b = [120., 60., 12., 1.]
    elif m == 5:
        b = [30240., 15120., 3360., 420., 30., 1.]
    elif m == 7:
        b = [17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.]
    elif m == 9:
        b = [17643225600., 8821612800., 2075673600., 302702400., 30270240., 
             2162160., 110880., 3960., 90., 1.]
    elif m == 13:
        b = [64764752532480000., 32382376266240000., 7771770303897600., 1187353796428800., 
             129060195264000., 10559470521600., 670442572800., 33522128640., 1323241920., 
             40840800., 960960., 16380., 182., 1.]
    
    
    # pre-computing terms
    I = _eye_like(A)
    if m!=13: # There is a more efficient algorithm for m=13
        U = b[1]*I
        V = b[0]*I
        if m >= 3:
            A_2 = A @ A
            U = U + b[3]*A_2
            V = V + b[2]*A_2
        if m >= 5:
            A_4 = A_2 @ A_2
            U = U + b[5]*A_4
            V = V + b[4]*A_4
        if m >= 7:
            A_6 = A_4 @ A_2
            U = U + b[7]*A_6
            V = V + b[6]*A_6
        if m == 9: 
            A_8 = A_4 @ A_4
            U = U + b[9]*A_8
            V = V + b[8]*A_8
        U = A @ U
    else:
        A_2 = A @ A 
        A_4 = A_2 @ A_2
        A_6 = A_4 @ A_2
        
        W_1 = b[13]*A_6 + b[11]*A_4 + b[9]*A_2
        W_2 = b[7]*A_6 + b[5]*A_4 + b[3]*A_2 + b[1]*I
        W = A_6 @ W_1 + W_2
        
        Z_1 = b[12]*A_6 + b[10]*A_4 + b[8]*A_2
        Z_2 = b[6]*A_6 + b[4]*A_4 + b[2]*A_2 + b[0]*I
        
        U = A @ W
        V = A_6 @ Z_1 + Z_2
    
    del A_2
    if m>=5: del A_4
    if m>=7: del A_6
    if m==9: del A_8
    
    R = torch.lu_solve(U+V, *torch.lu(-U+V))

    del U, V
    return R
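A quick consistency check against torch.matrix_exp (not part of the source) for a matrix of small norm, where the [7/7] approximant is accurate to near machine precision:

    A = 0.1 * torch.randn(4, 4, dtype=torch.float64)
    assert torch.allclose(_expm_pade(A, m=7), torch.matrix_exp(A), atol=1e-10)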
Example #27
def pnqp(H, q, lower, upper, x_init=None, n_iter=20):
    GAMMA = 0.1
    n_batch, n, _ = H.size()
    pnqp_I = 1e-11 * torch.eye(n).type_as(H).expand_as(H)

    def obj(x):
        return 0.5 * util.bquad(x, H) + util.bdot(q, x)

    if x_init is None:
        if n == 1:
            x_init = -(1. / H.squeeze(2)) * q
        else:
            # H_lu = H.btrifact()  # XXX deprecated!!!
            H_lu = torch.lu(H)
            x_init = -q.btrisolve(H_lu[0],
                                  H_lu[1])  # Clamped in the x assignment.
    else:
        x_init = x_init.clone()  # Don't over-write the original x_init.

    x = util.eclamp(x_init, lower, upper)

    # Active examples in the batch.
    J = torch.ones(n_batch).type_as(x).byte()

    for i in range(n_iter):
        g = util.bmv(H, x) + q
        Ic = ((x == lower) & (g > 0)) | ((x == upper) & (g < 0))
        If = 1 - Ic

        if If.is_cuda:
            Hff_I = util.bger(If.float(), If.float()).type_as(If)
            not_Hff_I = 1 - Hff_I
            Hfc_I = util.bger(If.float(), Ic.float()).type_as(If)
        else:
            Hff_I = util.bger(If, If)
            not_Hff_I = 1 - Hff_I
            Hfc_I = util.bger(If, Ic)

        g_ = g.clone()
        g_[Ic] = 0.
        H_ = H.clone()
        H_[not_Hff_I] = 0.0
        H_ += pnqp_I

        if n == 1:
            dx = -(1. / H_.squeeze(2)) * g_
        else:
            # H_lu_ = H_.btrifact()  # XXX deprecated!!!
            H_lu_ = torch.lu(H_)
            dx = -g_.btrisolve(*H_lu_)

        J = torch.norm(dx, 2, 1) >= 1e-4
        m = J.sum().item()  # Number of active examples in the batch.
        if m == 0:
            return x, H_ if n == 1 else H_lu_, If, i

        alpha = torch.ones(n_batch).type_as(x)
        decay = 0.1
        max_armijo = GAMMA
        count = 0
        while max_armijo <= GAMMA and count < 10:
            # Crude way of making sure too much time isn't being spent
            # doing the line search.
            # assert count < 10

            maybe_x = util.eclamp(x + torch.diag(alpha).mm(dx), lower, upper)
            armijos = (GAMMA + 1e-6) * torch.ones(n_batch).type_as(x)
            armijos[J] = (obj(x) - obj(maybe_x))[J] / util.bdot(
                g, x - maybe_x)[J]
            I = armijos <= GAMMA
            alpha[I] *= decay
            max_armijo = torch.max(armijos)
            count += 1

        x = maybe_x

    # TODO: Maybe change this to a warning.
    print("[WARNING] pnqp warning: Did not converge")
    return x, H_ if n == 1 else H_lu_, If, i
Example #28
def newton_exact(f, g, x_guess, opt_params, ls_method, ls_params):
    """
    This function minimizes f, using Newton's method to compute the search direction.

    INPUTS:
        f < function > : objective function f(x) -> f
        g < function > : gradient function g(x) -> g
        x_guess < tensor > : initial x
        opt_params < dict{
            'ep_g' < float > : conv. tolerance on gradient
            'ep_a' < float > : absolute tolerance
            'ep_r' < float > : relative tolerance
            'Hessian' < function > : function that returns the Hessian
            'iter_lim' < int > : iteration limit
        } > : dictionary of optimization settings
        ls_method < str > : indicates which method to use with line search
        ls_params < dict > : dictionary with parameters to use for line search
    """
    ep_g = opt_params['ep_g']
    ep_a = opt_params['ep_a']
    ep_r = opt_params['ep_r']
    H = opt_params['Hessian']
    iter_lim = opt_params['iter_lim']

    # initializations
    x_k = x_guess
    x_hist = [x_k]
    f_k = f(x_guess)
    f_hist = [f_k]
    k = 0
    conv_count = 0

    # how many iterations for rel. abs. tolerance met before stopping
    conv_count_max = 2

    while k < iter_lim:
        k += 1

        # compute gradient
        g_k = g(x_k)

        # check for gradient convergence
        if torch.norm(g_k) <= ep_g:
            converge = True
            message = "Exact Newton converged due to grad. tolerance."
            break

        # invert Hessian and find search direction
        H_k = H(x_k)
        H_LU, pivots, infos = torch.lu(H_k.reshape(
            [1, H_k.shape[0], -1]), get_infos=True)

        if infos.nonzero().size(0) != 0:
            # check if LU factorization failed
            converge = False
            message = "Hessian LU factorization failed."
            break

        # LU solve is designed for batch operations, hence the [0]
        delta_k = torch.lu_solve(-g_k.unsqueeze(0), H_LU, pivots)[0]

        if torch.matmul(delta_k.t(), g_k) < 0:
            p_k = delta_k
        else:
            p_k = -delta_k

        # perform line search
        alf, ls_converge, ls_message = line_search(f, x_k, g_k, p_k,
                                                   ls_method=ls_method, ls_params=ls_params)
        if not ls_converge:
            converge = ls_converge
            message = ls_message
            break

        # compute x_(k+1)
        x_k1, f_k1 = search_step(f, x_k, alf, p_k)

        # check relative and absolute convergence criteria
        if rel_abs_convergence(f_k, f_k1, ep_a, ep_r):
            conv_count += 1

        x_k = x_k1
        f_k = f_k1

        x_hist.append(x_k)
        f_hist.append(f_k)

        if conv_count >= conv_count_max:
            converge = True
            message = "Exact Newton converged due to abs. rel. tolerance."
            break

    if k == iter_lim:
        converge = False
        message = "Exact Newton iteration limit reached."

    return x_k, f_k, x_hist, f_hist, converge, message
Example #29
def quad_search(f, x_k, g_k, p_k, ls_params):
    """
    This function performs approximate quadratic line search

    INPUTS:
        f < function > : objective function f(x) -> f
        x_k < tensor > : current best guess for f(x) minimum
        g_k < tensor > : gradient evaluated at x_k
        p_k < tensor > : search direction
        alf < float > : initial step length
        ls_params < dict{
            'alf' < float > : initial guess for step-length
            'mu' < float > : small positive constant used in "Armijo suff. decrease condition"
            'rho' < float > : step-size discount coefficient
            'iter_lim' < int > : iteration limit for solver
            'alf_lower_coeff' : coefficient for determining point one in quad_search
            'alf_upper_coeff' : coefficient for determining point three in quad_search
        } > : dictionary with parameters to use for line search

    RETURNS:
        alf_new < float > : computed search length
        converge < bool > : bool indicating whether line search converged
        message < string > : string with output from back tracking method
    """
    mu = ls_params['mu']
    iter_lim = ls_params['iter_lim']
    alf_new = ls_params['alf']
    alf_coeff1 = ls_params['alf_lower_coeff']
    alf_coeff2 = ls_params['alf_upper_coeff']
    iter = 0
    while not armijo_suff_decrease(f, x_k, g_k, p_k, alf_new,
                                   mu) and iter < iter_lim:
        a1 = alf_new
        a2 = alf_new * 0.1
        a3 = alf_new * 2.0
        f1 = f(x_k + a1 * p_k)
        f2 = f(x_k + a2 * p_k)
        f3 = f(x_k + a3 * p_k)

        A = torch.tensor([[1 / 2 * a1.pow(2), a1,
                           1], [1 / 2 * a2.pow(2), a2, 1],
                          [1 / 2 * a3.pow(2), a3, 1]])
        b = torch.tensor([[f1], [f2], [f3]])

        A_LU, pivots, infos = torch.lu(A.reshape([1, A.shape[0], -1]),
                                       get_infos=True)

        if infos.nonzero().size(0) != 0:
            # Return immediately: a plain `break` here would let the
            # post-loop branch overwrite this failure status.
            return alf_new, False, "Quadratic approx was not possible."

        coeff = torch.lu_solve(b.unsqueeze(0), A_LU, pivots)[0]
        alf_new = -coeff[1] / coeff[0]
        iter += 1

    if iter == iter_lim:
        converge = False
        message = "Quadratic approx. line search iteration limit reached."
    else:
        converge = True
        message = "Quadratic approx. line search converged."

    return alf_new, converge, message