Example #1
    def weight(self, ext, module, g_inp, g_out, backproped):

        I = module.input0
        n = g_out[0].shape[0]
        g_out_sc = n * g_out[0]
        G = g_out_sc
        grad = module.weight.grad

        # Gram matrices of the layer inputs and of the scaled output gradients
        B = einsum("ni,li->nl", (I, I))
        A = einsum("no,lo->nl", (G, G))

        # compute vector jacobian product in optimization method
        grad_prod = einsum("ni,oi->no", (I, grad))
        grad_prod = einsum("no,no->n", (grad_prod, G))
        # grad_prod = 0
        out = A * B
        # out = 0
        NGD_kernel = out / n
        # invert the damped n x n kernel and solve for v
        NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device))
        v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()

        # apply the transposed Jacobian to map v back to weight space
        gv = einsum("n,no->no", (v, G))
        gv = einsum("no,ni->oi", (gv, I))
        gv = gv / n

        update = (grad - gv) / self.damping

        module.I = I
        module.G = G
        module.NGD_inv = NGD_inv
        return update
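
The n x n matrix inverted above is the Gram matrix of the per-sample weight gradients: for a linear layer the gradient of sample i is the outer product of its output gradient and its input, so pairwise inner products factor into (G G^T) * (I I^T) = A * B. A minimal standalone sketch (plain PyTorch, toy shapes chosen here, not part of the extension) checking that identity:

    # Sketch with assumed toy shapes: the factored kernel A * B equals the
    # Gram matrix J J^T of the flattened per-sample weight gradients.
    import torch

    n, d_in, d_out = 4, 3, 2
    I = torch.randn(n, d_in)   # layer inputs, one row per sample
    G = torch.randn(n, d_out)  # scaled output gradients, one row per sample

    # per-sample weight gradients, flattened: row i is vec(G[i] outer I[i])
    J = torch.einsum("no,ni->noi", G, I).reshape(n, -1)

    kernel_direct = J @ J.T                                # n x n Gram matrix
    kernel_factored = torch.einsum("no,lo->nl", G, G) * \
                      torch.einsum("ni,li->nl", I, I)      # A * B as above

    print(torch.allclose(kernel_direct, kernel_factored))  # True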
Example #2
    def weight(self, ext, module, g_inp, g_out, backproped):
        n = g_out[0].shape[0]
        g_out_sc = n * g_out[0]
        G = g_out_sc

        I = module.input0
        # recompute the BatchNorm normalization to obtain x-hat and the
        # per-sample gradient dw of the scale (weight) parameter
        mean = I.mean(dim=0)
        var = I.var(dim=0, unbiased=False)
        xhat = (I - mean) / (var + module.eps).sqrt()
        dw = g_out_sc * xhat

        # compute vector jacobian product in optimization method
        grad = module.weight.grad
        grad_prod = einsum("nk,k->n", (dw, grad))

        out = matmul(dw, dw.t())
        NGD_kernel = out / n
        NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device))
        v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()

        # gv = einsum("n,nk->k", (v, G))
        ### multiply with Jacobian
        gv = einsum("n,nk->k", (v, dw))
        gv = gv / n

        update = (grad - gv) / self.damping

        module.dw = dw
        module.NGD_inv = NGD_inv

        return (out, grad_prod, update)
Example #3
    def weight(self, ext, module, g_inp, g_out, backproped):
        # print(g_out)

        # check if there are stored variables:
        # if hasattr(module, "I"):
        # this is a sampling technique
        # inp = module.I
        # l = inp.shape[0]
        #     prob = 0.1
        #     l_new = int(np.floor(prob * l))

        #     # print('input to linear layer before droput:', inp.shape)
        #     Borg = einsum("ni,li->nl", (inp, inp))

        #     if inp.shape[1] > 7000:
        #         inp =  inp[:, torch.randint(l, (l_new,))]

        #     B =  einsum("ni,li->nl", (inp, inp)) / ( prob)

        I = module.input0
        n = g_out[0].shape[0]
        g_out_sc = n * g_out[0]
        G = g_out_sc
        grad = module.weight.grad

        B = einsum("ni,li->nl", (I, I))
        A = einsum("no,lo->nl", (G, G))

        # compute vector jacobian product in optimization method
        grad_prod = einsum("ni,oi->no", (I, grad))
        grad_prod = einsum("no,no->n", (grad_prod, G))
        # grad_prod = 0
        out = A * B
        # out = 0
        NGD_kernel = out / n
        NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device))
        v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()

        gv = einsum("n,no->no", (v, G))
        gv = einsum("no,ni->oi", (gv, I))
        gv = gv / n

        update = (grad - gv) / self.damping
        # update = grad

        # store for later use:
        # module.A = A
        # module.B = B
        # module.out = out
        module.I = I
        module.G = G
        module.NGD_inv = NGD_inv
        return (out, grad_prod, update)
Example #4
    def __init__(self, m, c, k, dt, **kwargs):
        super(RungeKuttaIntegratorCell, self).__init__(**kwargs)
        self.Minv = linalg.inv(diag(m))  # inverse of the diagonal mass matrix
        self.c1 = Parameter(c[0])
        self.c2 = Parameter(c[1])
        self.c3 = Parameter(c[2])

        # stiffness matrix assembled from the spring constants k
        self.K = Tensor([[k[0] + k[1], -k[1]], [-k[1], k[1] + k[2]]])
        self.state_size = 2 * len(m)
        # classical RK4 nodes and weights
        self.A = Tensor([0., 0.5, 0.5, 1.0])
        self.B = Tensor([[1 / 6, 2 / 6, 2 / 6, 1 / 6]])
        self.dt = dt
Example #5
    def bias(self, ext, module, g_inp, g_out, backproped):
        n = g_out[0].shape[0]
        g_out_sc = n * g_out[0]

        # compute vector jacobian product in optimization method
        grad = module.bias.grad
        grad_prod = einsum("no,o->n", (g_out_sc, grad))

        # bias NGD kernel: Gram matrix of the scaled output gradients
        out = einsum("no,lo->nl", (g_out_sc, g_out_sc))

        NGD_kernel = out / n
        NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device))
        v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
        gv = einsum("n,no->o", (v, g_out_sc))
        gv = gv / n

        update = (grad - gv) / self.damping

        return (out, grad_prod, update)
Example #6
    def weight(self, ext, module, g_inp, g_out, bpQuantities):
        if MODE == 0:  # my implementation

            grad = module.weight.grad
            # print(grad.shape)
            grad_reshape = grad.reshape(grad.shape[0], -1)
            n = g_out[0].shape[0]
            g_out_sc = n * g_out[0]

            I = unfold_func(module)(module.input0)
            grad_output_viewed = g_out_sc.reshape(g_out_sc.shape[0],
                                                  g_out_sc.shape[1], -1)
            G = grad_output_viewed

            N = I.shape[0]
            K = I.shape[1]
            L = I.shape[2]
            M = G.shape[1]
            # print(N,K,L,M)
            # pick the cheaper contraction order for the n x n kernel
            if self.super_opt == 'true':
                flag = N * (L * L) * (K + M) < K * M * L + N * K * M
            else:
                flag = (L * L) * (K + M) < K * M

            if flag:
                II = einsum("nkl,qkp->nqlp", (I, I))
                GG = einsum("nml,qmp->nqlp", (G, G))
                out = einsum('nqlp->nq', II * GG)
                x1 = einsum("nkl,mk->nml", (I, grad_reshape))
                grad_prod = einsum("nml,nml->n", (x1, G))
                NGD_kernel = out / n
                NGD_inv = inv(NGD_kernel +
                              self.damping * eye(n).to(grad.device))
                v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                gv = einsum("n,nml->nml", (v, G))
                gv = einsum("nml,nkl->mk", (gv, I))
                gv = gv.view_as(grad)
                gv = gv / n

                module.NGD_inv = NGD_inv
                if self.memory_efficient == 'true':
                    module.I = module.input0
                else:
                    module.I = I
                module.G = G
                del I
                del G
                del II
                del GG
                empty_cache()

            else:
                AX = einsum("nkl,nml->nkm", (I, G))
                del I
                del G
                AX_ = AX.reshape(n, -1)
                NGD_kernel = matmul(AX_, AX_.t()) / n
                ### testing low-rank
                if self.low_rank == 'true':
                    V, S, U = svd(AX_.T, compute_uv=True, full_matrices=False)
                    U = U.t()
                    V = V.t()

                    cs = cumsum(S, dim=0)
                    sum_s = sum(S)
                    index = ((cs - self.gamma * sum_s) <= 0).sum()
                    U = U[:, 0:index]
                    S = S[0:index]
                    V = V[0:index, :]

                    module.U = U
                    module.S = S
                    module.V = V
                del AX_

                grad_prod = einsum("nkm,mk->n", (AX, grad_reshape))

                NGD_inv = inv(NGD_kernel +
                              self.damping * eye(n).to(grad.device))
                module.NGD_inv = NGD_inv
                v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                del NGD_inv
                empty_cache()
                gv = einsum("nkm,n->mk", (AX, v)).view_as(grad) / n
                module.AX = AX

            update = (grad - gv) / self.damping
            return update
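
Both branches above compute the same n x n kernel for the unfolded convolution; the flag only selects the cheaper contraction order. A minimal standalone sketch with assumed toy shapes (plain PyTorch, not part of the extension) checking that the two contractions agree:

    # Sketch with assumed toy shapes: the factored einsum path and the
    # per-sample-gradient (AX) path yield the same n x n kernel.
    import torch

    N, K, L, M = 3, 5, 7, 4
    I = torch.randn(N, K, L)   # unfolded inputs, one slice per sample
    G = torch.randn(N, M, L)   # reshaped output gradients

    # branch 1: factor into input and gradient Gram tensors, then contract
    II = torch.einsum("nkl,qkp->nqlp", I, I)
    GG = torch.einsum("nml,qmp->nqlp", G, G)
    out_factored = torch.einsum("nqlp->nq", II * GG)

    # branch 2: per-sample gradients AX[i] = I[i] @ G[i]^T, then their Gram matrix
    AX = torch.einsum("nkl,nml->nkm", I, G)
    out_direct = AX.reshape(N, -1) @ AX.reshape(N, -1).T

    print(torch.allclose(out_factored, out_direct, atol=1e-4))  # True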
Example #7
    def weight(self, ext, module, g_inp, g_out, bpQuantities):
        if MODE == 0:  # my implementation

            grad = module.weight.grad
            # print(grad.shape)
            grad_reshape = grad.reshape(grad.shape[0], -1)
            n = g_out[0].shape[0]
            g_out_sc = n * g_out[0]

            input = unfold_func(module)(module.input0)
            I = input
            grad_output_viewed = g_out_sc.reshape(g_out_sc.shape[0],
                                                  g_out_sc.shape[1], -1)
            G = grad_output_viewed

            N = I.shape[0]
            K = I.shape[1]
            L = I.shape[2]
            M = G.shape[1]
            # print(N,K,L,M)
            # pick the cheaper contraction order for the n x n kernel
            if (L * L) * (K + M) < K * M:
                II = einsum("nkl,qkp->nqlp", (I, I))
                GG = einsum("nml,qmp->nqlp", (G, G))
                out = einsum('nqlp->nq', II * GG)
                x1 = einsum("nkl,mk->nml", (I, grad_reshape))
                grad_prod = einsum("nml,nml->n", (x1, G))
                NGD_kernel = out / n
                NGD_inv = inv(NGD_kernel +
                              self.damping * eye(n).to(grad.device))
                v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                gv = einsum("n,nml->nml", (v, G))
                gv = einsum("nml,nkl->mk", (gv, I))
                gv = gv.view_as(grad)
                gv = gv / n

                module.NGD_inv = NGD_inv
                if self.memory_efficient == 'true':
                    module.I = module.input0
                else:
                    module.I = I
                module.G = G

            else:
                AX = einsum("nkl,nml->nkm", (I, G))
                AX_ = AX.reshape(n, -1)
                out = matmul(AX_, AX_.t())
                grad_prod = einsum("nkm,mk->n", (AX, grad_reshape))

                NGD_kernel = out / n
                NGD_inv = inv(NGD_kernel +
                              self.damping * eye(n).to(grad.device))
                v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                gv = einsum("nkm,n->mk", (AX, v))
                gv = gv.view_as(grad)
                gv = gv / n

                module.NGD_inv = NGD_inv
                module.AX = AX

                ### testing low-rank
                if self.low_rank == 'true':
                    V, S, U = svd(AX_.T, compute_uv=True, full_matrices=False)
                    U = U.t()
                    V = V.t()

                    cs = cumsum(S, dim=0)
                    sum_s = sum(S)
                    index = ((cs - self.gamma * sum_s) <= 0).sum()
                    U = U[:, 0:index]
                    S = S[0:index]
                    V = V[0:index, :]

                    module.U = U
                    module.S = S
                    module.V = V

            update = (grad - gv) / self.damping
            return (out, grad_prod, update)
        elif MODE == 2:
            # st = time.time()

            A = module.input0
            n = A.shape[0]
            p = 1
            M = g_out[0]

            # form per-sample input/gradient correlations with one grouped
            # convolution: group i convolves sample i's input with its own
            # output gradient
            M = M.reshape(M.shape[1] * M.shape[0], M.shape[2],
                          M.shape[3]).unsqueeze(1)
            A = A.permute(1, 0, 2, 3)
            output = conv2d(A, M, groups=n, padding=(p, p))
            output = output.permute(1, 0, 2, 3)
            output = output.reshape(n, -1)
            K_torch = matmul(output, output.t())
            # en = time.time()
            # print('Elapsed Time Conv2d Mode 2:', en - st)

            return K_torch
Example #8
    def _update_inv(self, m):
        classname = m.__class__.__name__.lower()
        if classname == 'linear':
            assert m.optimized
            II = self.m_I[m][0]
            GG = self.m_G[m][0]
            n = II.shape[0]

            ### bias kernel is GG (II = all ones)
            bias_kernel = GG / n
            bias_inv = inv(bias_kernel + self.damping * eye(n).to(GG.device))
            self.m_bias_Kernel[m] = bias_inv

            NGD_kernel = (II * GG) / n
            if self.rand_svd:
                # U, S, Vh = torch.linalg.svd(NGD_kernel, full_matrices=False)
                U, S, Vh = self.rSVD(NGD_kernel, 50, 0, 20)
                # cs = torch.cumsum(S, dim=0)
                # cs_norm = cs / torch.sum(S)
                self.S_r[m] = inv(torch.diag(S) + self.damping * eye(S.shape[0]).to(II.device))
                self.V_r[m] = Vh
            else:
                NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device))

            if not self.rand_svd:
                self.m_NGD_Kernel[m] = NGD_inv

            self.m_I[m] = (None, self.m_I[m][1])
            self.m_G[m] = (None, self.m_G[m][1])
            torch.cuda.empty_cache()
        elif classname == 'conv2d':
            # SAEED: @TODO: we don't need II and GG after computations, clear the memory
            if m.optimized == True:
                # print('=== optimized ===')
                II = self.m_I[m][0]
                GG = self.m_G[m][0]
                n = II.shape[0]

                NGD_kernel = None
                if self.reduce_sum == 'true':
                    if self.diag == 'true':
                        NGD_kernel = (II * GG / n)
                        NGD_inv = torch.reciprocal(NGD_kernel + self.damping)
                    else:
                        NGD_kernel = II * GG / n
                        if self.rand_svd:
                            _, S, Vh = self.rSVD(NGD_kernel, 50, 0, 20)
                            self.S_r[m] = inv(torch.diag(S) + self.damping * eye(S.shape[0]).to(II.device))
                            self.V_r[m] = Vh
                        else:
                            NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device))
                        
                else:
                    NGD_kernel = (einsum('nqlp->nq', II * GG)) / n
                    NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device))

                if not self.rand_svd:
                    self.m_NGD_Kernel[m] = NGD_inv

                self.m_I[m] = (None, self.m_I[m][1])
                self.m_G[m] = (None, self.m_G[m][1])
                torch.cuda.empty_cache()
            else:
                # SAEED: @TODO memory cleanup
                I = self.m_I[m][1]
                G = self.m_G[m][1]
                n = I.shape[0]
                AX = einsum("nkl,nml->nkm", (I, G))

                del I
                del G

                AX_ = AX.reshape(n, -1)
                out = matmul(AX_, AX_.t())

                del AX

                NGD_kernel = out / n
                ### low-rank approximation of Jacobian
                if self.low_rank == 'true':
                    # print('=== low rank ===')
                    V, S, U = svd(AX_.T, full_matrices=False)
                    U = U.t()
                    V = V.t()
                    cs = cumsum(S, dim=0)
                    sum_s = sum(S)
                    index = ((cs - self.gamma * sum_s) <= 0).sum()
                    U = U[:, 0:index]
                    S = S[0:index]
                    V = V[0:index, :]
                    self.m_UV[m] = U, S, V

                del AX_

                NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(NGD_kernel.device))
                self.m_NGD_Kernel[m] = NGD_inv

                del NGD_inv
                self.m_I[m] = None, self.m_I[m][1]
                self.m_G[m] = None, self.m_G[m][1]
                torch.cuda.empty_cache()
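
The low-rank branches above (and in Example #6 and Example #7) all truncate the SVD with the same rule: keep the leading singular values whose cumulative sum stays within a gamma fraction of the total spectrum. A minimal standalone sketch of just that rule (toy values and gamma assumed here):

    # Sketch with assumed values: how many leading singular values survive
    # the cumulative-energy truncation used in the low-rank branches.
    import torch

    gamma = 0.95                                   # assumed energy fraction
    S = torch.tensor([5.0, 3.0, 1.0, 0.5, 0.1])    # singular values, descending
    cs = torch.cumsum(S, dim=0)
    index = int(((cs - gamma * S.sum()) <= 0).sum())
    print(index, S[:index])                        # 3 tensor([5., 3., 1.])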
Example #9
    def compute_Sigma(self, X):
        n, dim = X.shape
        XjXj = th.matmul(X.T, X) + th.eye(dim, device=self.device)
        Sigma = linalg.inv(
            n * th.matmul(X.reshape(n, dim, 1), X.reshape(n, 1, dim)) + XjXj)
        return Sigma