def weight(self, ext, module, g_inp, g_out, backproped):
    # (einsum, matmul, inv, and eye are expected to be imported from torch at module level)
    I = module.input0                      # layer input, shape (n, in_features)
    n = g_out[0].shape[0]
    g_out_sc = n * g_out[0]                # scale output gradients by the batch size
    G = g_out_sc                           # output gradients, shape (n, out_features)
    grad = module.weight.grad

    B = einsum("ni,li->nl", (I, I))        # Gram matrix of inputs
    A = einsum("no,lo->nl", (G, G))        # Gram matrix of output gradients

    # compute the vector-Jacobian product used by the optimization method
    grad_prod = einsum("ni,oi->no", (I, grad))
    grad_prod = einsum("no,no->n", (grad_prod, G))

    out = A * B                            # n x n NGD kernel (kernel trick)
    NGD_kernel = out / n
    NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device))
    v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()

    # apply the transposed Jacobian to v
    gv = einsum("n,no->no", (v, G))
    gv = einsum("no,ni->oi", (gv, I))
    gv = gv / n

    update = (grad - gv) / self.damping

    # store for later use in the optimizer
    module.I = I
    module.G = G
    module.NGD_inv = NGD_inv

    return update
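# Aside (not part of the extension above): a minimal, self-contained sketch of the
# kernel-trick identity the linear-layer `weight` method relies on. For a linear
# layer the per-sample weight gradient is the outer product g_n i_n^T, so the Gram
# matrix of flattened per-sample gradients equals (G G^T) * (I I^T) elementwise.
# All tensors below are random placeholders, not values produced by the code above.
import torch

n, i, o = 8, 5, 3                               # batch size, in-features, out-features
I = torch.randn(n, i)                           # layer inputs
G = torch.randn(n, o)                           # (scaled) output gradients
per_sample_grads = torch.einsum("no,ni->noi", G, I).reshape(n, -1)
gram = per_sample_grads @ per_sample_grads.t()
kernel = torch.einsum("no,lo->nl", G, G) * torch.einsum("ni,li->nl", I, I)
assert torch.allclose(gram, kernel, atol=1e-5)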
def weight(self, ext, module, g_inp, g_out, backproped):
    n = g_out[0].shape[0]
    g_out_sc = n * g_out[0]                # scale output gradients by the batch size
    G = g_out_sc

    I = module.input0
    mean = I.mean(dim=0)
    var = I.var(dim=0, unbiased=False)
    xhat = (I - mean) / (var + module.eps).sqrt()
    dw = g_out_sc * xhat                   # per-sample gradient of the scale parameter

    # compute the vector-Jacobian product used by the optimization method
    grad = module.weight.grad
    grad_prod = einsum("nk,k->n", (dw, grad))

    out = matmul(dw, dw.t())               # n x n NGD kernel (unnormalized)
    NGD_kernel = out / n
    NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device))
    v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()

    # multiply with the transposed Jacobian
    gv = einsum("n,nk->k", (v, dw))
    gv = gv / n

    update = (grad - gv) / self.damping

    module.dw = dw
    module.NGD_inv = NGD_inv

    return (out, grad_prod, update)
def weight(self, ext, module, g_inp, g_out, backproped):
    # Disabled input-sampling variant: when the layer input is very wide,
    # subsample its columns before forming the input Gram matrix.
    # if hasattr(module, "I"):
    #     inp = module.I
    #     l = inp.shape[0]
    #     prob = 0.1
    #     l_new = int(np.floor(prob * l))
    #     Borg = einsum("ni,li->nl", (inp, inp))
    #     if inp.shape[1] > 7000:
    #         inp = inp[:, torch.randint(l, (l_new,))]
    #     B = einsum("ni,li->nl", (inp, inp)) / prob

    I = module.input0
    n = g_out[0].shape[0]
    g_out_sc = n * g_out[0]
    G = g_out_sc
    grad = module.weight.grad

    B = einsum("ni,li->nl", (I, I))
    A = einsum("no,lo->nl", (G, G))

    # compute the vector-Jacobian product used by the optimization method
    grad_prod = einsum("ni,oi->no", (I, grad))
    grad_prod = einsum("no,no->n", (grad_prod, G))

    out = A * B
    NGD_kernel = out / n
    NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device))
    v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()

    gv = einsum("n,no->no", (v, G))
    gv = einsum("no,ni->oi", (gv, I))
    gv = gv / n

    update = (grad - gv) / self.damping

    # store for later use in the optimizer
    module.I = I
    module.G = G
    module.NGD_inv = NGD_inv

    return (out, grad_prod, update)
def __init__(self, m, c, k, dt, **kwargs):
    super(RungeKuttaIntegratorCell, self).__init__(**kwargs)
    self.Minv = linalg.inv(diag(m))        # inverse of the diagonal mass matrix
    # damping coefficients are the trainable parameters
    self.c1 = Parameter(c[0])
    self.c2 = Parameter(c[1])
    self.c3 = Parameter(c[2])
    # stiffness matrix of the two-mass, three-spring system
    self.K = Tensor([[k[0] + k[1], -k[1]],
                     [-k[1],       k[1] + k[2]]])
    self.state_size = 2 * len(m)           # displacement and velocity per mass
    # nodes and weights of the classical fourth-order Runge-Kutta method
    self.A = Tensor([0., 0.5, 0.5, 1.0])
    self.B = Tensor([[1 / 6, 2 / 6, 2 / 6, 1 / 6]])
    self.dt = dt
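# Aside: one possible way to instantiate the cell above for a two-mass, three-spring
# oscillator. The numbers are illustrative placeholders (not the original experiment's
# values), and this assumes Parameter/Tensor/diag/linalg in the class refer to their
# PyTorch counterparts.
import torch

m = torch.tensor([20.0, 10.0])                  # masses
c = torch.tensor([10.0, 10.0, 10.0])            # initial damping coefficients (trainable)
k = [2.0e3, 1.0e3, 5.0e3]                       # spring stiffnesses
cell = RungeKuttaIntegratorCell(m=m, c=c, k=k, dt=0.002)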
def bias(self, ext, module, g_inp, g_out, backproped):
    n = g_out[0].shape[0]
    g_out_sc = n * g_out[0]

    # compute vector-Jacobian product in optimization method
    grad = module.bias.grad
    grad_prod = einsum("no,o->n", (g_out_sc, grad))

    out = einsum("no,lo->nl", g_out_sc, g_out_sc)
    NGD_kernel = out / n
    NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device))
    v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()

    gv = einsum("n,no->o", (v, g_out_sc))
    gv = gv / n

    update = (grad - gv) / self.damping
    return (out, grad_prod, update)
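# Aside: a numerical sketch (on random placeholder data, in double precision) showing
# that the "(grad - gv) / damping" form used above is the matrix-inversion-lemma
# (Woodbury) solution of (J^T J / n + damping * I)^{-1} grad, where the rows of J
# are the per-sample bias gradients g_out_sc.
import torch

n, o, damping = 6, 4, 1e-2
J = torch.randn(n, o, dtype=torch.float64)      # per-sample bias gradients
grad = torch.randn(o, dtype=torch.float64)

direct = torch.linalg.solve(J.t() @ J / n + damping * torch.eye(o, dtype=torch.float64), grad)

kernel = J @ J.t() / n                          # n x n NGD kernel
v = torch.linalg.solve(kernel + damping * torch.eye(n, dtype=torch.float64), J @ grad)
woodbury = (grad - J.t() @ v / n) / damping

assert torch.allclose(direct, woodbury)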
def weight(self, ext, module, g_inp, g_out, bpQuantities):
    if MODE == 0:  # my implementation
        grad = module.weight.grad
        grad_reshape = grad.reshape(grad.shape[0], -1)

        n = g_out[0].shape[0]
        g_out_sc = n * g_out[0]

        # unfolded input and reshaped output gradients are the per-sample Jacobian factors
        I = unfold_func(module)(module.input0)
        G = g_out_sc.reshape(g_out_sc.shape[0], g_out_sc.shape[1], -1)

        N, K, L = I.shape
        M = G.shape[1]

        # decide whether to form the n x n kernel from the factored form or from AX
        if self.super_opt == 'true':
            flag = N * (L * L) * (K + M) < K * M * L + N * K * M
        else:
            flag = (L * L) * (K + M) < K * M

        if flag:
            II = einsum("nkl,qkp->nqlp", (I, I))
            GG = einsum("nml,qmp->nqlp", (G, G))
            out = einsum("nqlp->nq", II * GG)

            x1 = einsum("nkl,mk->nml", (I, grad_reshape))
            grad_prod = einsum("nml,nml->n", (x1, G))

            NGD_kernel = out / n
            NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device))
            v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()

            gv = einsum("n,nml->nml", (v, G))
            gv = einsum("nml,nkl->mk", (gv, I))
            gv = gv.view_as(grad)
            gv = gv / n

            module.NGD_inv = NGD_inv
            if self.memory_efficient == 'true':
                module.I = module.input0
            else:
                module.I = I
            module.G = G

            del I
            del G
            del II
            del GG
            empty_cache()
        else:
            AX = einsum("nkl,nml->nkm", (I, G))
            del I
            del G

            AX_ = AX.reshape(n, -1)
            NGD_kernel = matmul(AX_, AX_.t()) / n

            # testing low-rank approximation of the Jacobian
            if self.low_rank == 'true':
                V, S, U = svd(AX_.T, compute_uv=True, full_matrices=False)
                U = U.t()
                V = V.t()
                cs = cumsum(S, dim=0)
                sum_s = sum(S)
                index = ((cs - self.gamma * sum_s) <= 0).sum()
                U = U[:, 0:index]
                S = S[0:index]
                V = V[0:index, :]
                module.U = U
                module.S = S
                module.V = V

            del AX_
            grad_prod = einsum("nkm,mk->n", (AX, grad_reshape))
            NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device))
            module.NGD_inv = NGD_inv
            v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
            del NGD_inv
            empty_cache()

            gv = einsum("nkm,n->mk", (AX, v)).view_as(grad) / n
            module.AX = AX

        update = (grad - gv) / self.damping
        return update
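# Aside: a random-tensor sketch showing that the two branches above compute the same
# n x n kernel: contracting the factored form einsum("nqlp->nq", II * GG) agrees with
# the Gram matrix of the per-sample contractions AX = einsum("nkl,nml->nkm", I, G).
# Shapes are small placeholders (batch n, unfolded-input rows k, output channels m,
# spatial positions l).
import torch

n, k, m, l = 4, 6, 3, 5
I = torch.randn(n, k, l)
G = torch.randn(n, m, l)

II = torch.einsum("nkl,qkp->nqlp", I, I)
GG = torch.einsum("nml,qmp->nqlp", G, G)
kernel_factored = torch.einsum("nqlp->nq", II * GG)

AX = torch.einsum("nkl,nml->nkm", I, G).reshape(n, -1)
kernel_gram = AX @ AX.t()

assert torch.allclose(kernel_factored, kernel_gram, atol=1e-4)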
def weight(self, ext, module, g_inp, g_out, bpQuantities):
    if MODE == 0:  # my implementation
        grad = module.weight.grad
        grad_reshape = grad.reshape(grad.shape[0], -1)

        n = g_out[0].shape[0]
        g_out_sc = n * g_out[0]

        I = unfold_func(module)(module.input0)
        G = g_out_sc.reshape(g_out_sc.shape[0], g_out_sc.shape[1], -1)

        N, K, L = I.shape
        M = G.shape[1]

        if (L * L) * (K + M) < K * M:
            II = einsum("nkl,qkp->nqlp", (I, I))
            GG = einsum("nml,qmp->nqlp", (G, G))
            out = einsum("nqlp->nq", II * GG)

            x1 = einsum("nkl,mk->nml", (I, grad_reshape))
            grad_prod = einsum("nml,nml->n", (x1, G))

            NGD_kernel = out / n
            NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device))
            v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()

            gv = einsum("n,nml->nml", (v, G))
            gv = einsum("nml,nkl->mk", (gv, I))
            gv = gv.view_as(grad)
            gv = gv / n

            module.NGD_inv = NGD_inv
            if self.memory_efficient == 'true':
                module.I = module.input0
            else:
                module.I = I
            module.G = G
        else:
            AX = einsum("nkl,nml->nkm", (I, G))
            AX_ = AX.reshape(n, -1)

            out = matmul(AX_, AX_.t())
            grad_prod = einsum("nkm,mk->n", (AX, grad_reshape))

            NGD_kernel = out / n
            NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(grad.device))
            v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()

            gv = einsum("nkm,n->mk", (AX, v))
            gv = gv.view_as(grad)
            gv = gv / n

            module.NGD_inv = NGD_inv
            module.AX = AX

            # testing low-rank approximation of the Jacobian
            if self.low_rank == 'true':
                V, S, U = svd(AX_.T, compute_uv=True, full_matrices=False)
                U = U.t()
                V = V.t()
                cs = cumsum(S, dim=0)
                sum_s = sum(S)
                index = ((cs - self.gamma * sum_s) <= 0).sum()
                U = U[:, 0:index]
                S = S[0:index]
                V = V[0:index, :]
                module.U = U
                module.S = S
                module.V = V

        update = (grad - gv) / self.damping
        return (out, grad_prod, update)

    elif MODE == 2:
        A = module.input0
        n = A.shape[0]
        p = 1

        M = g_out[0]
        M = M.reshape(M.shape[1] * M.shape[0], M.shape[2], M.shape[3]).unsqueeze(1)
        A = A.permute(1, 0, 2, 3)

        # grouped convolution evaluates all per-sample correlations in one call
        output = conv2d(A, M, groups=n, padding=(p, p))
        output = output.permute(1, 0, 2, 3)
        output = output.reshape(n, -1)

        K_torch = matmul(output, output.t())
        return K_torch
def _update_inv(self, m):
    classname = m.__class__.__name__.lower()
    if classname == 'linear':
        assert m.optimized

        II = self.m_I[m][0]
        GG = self.m_G[m][0]
        n = II.shape[0]

        # bias kernel is GG (II would be all ones)
        bias_kernel = GG / n
        bias_inv = inv(bias_kernel + self.damping * eye(n).to(GG.device))
        self.m_bias_Kernel[m] = bias_inv

        NGD_kernel = (II * GG) / n
        if self.rand_svd:
            # randomized SVD of the kernel instead of a dense inverse
            U, S, Vh = self.rSVD(NGD_kernel, 50, 0, 20)
            self.S_r[m] = inv(torch.diag(S) + self.damping * eye(S.shape[0]).to(II.device))
            self.V_r[m] = Vh
        else:
            NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device))

        if not self.rand_svd:
            self.m_NGD_Kernel[m] = NGD_inv

        self.m_I[m] = (None, self.m_I[m][1])
        self.m_G[m] = (None, self.m_G[m][1])
        torch.cuda.empty_cache()

    elif classname == 'conv2d':
        # SAEED: @TODO: II and GG are not needed after these computations; clear the memory
        if m.optimized:
            II = self.m_I[m][0]
            GG = self.m_G[m][0]
            n = II.shape[0]

            NGD_kernel = None
            if self.reduce_sum == 'true':
                if self.diag == 'true':
                    NGD_kernel = II * GG / n
                    NGD_inv = torch.reciprocal(NGD_kernel + self.damping)
                else:
                    NGD_kernel = II * GG / n
                    if self.rand_svd:
                        _, S, Vh = self.rSVD(NGD_kernel, 50, 0, 20)
                        self.S_r[m] = inv(torch.diag(S) + self.damping * eye(S.shape[0]).to(II.device))
                        self.V_r[m] = Vh
                    else:
                        NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device))
            else:
                NGD_kernel = einsum('nqlp->nq', II * GG) / n
                NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(II.device))

            if not self.rand_svd:
                self.m_NGD_Kernel[m] = NGD_inv

            self.m_I[m] = (None, self.m_I[m][1])
            self.m_G[m] = (None, self.m_G[m][1])
            torch.cuda.empty_cache()
        else:
            # SAEED: @TODO memory cleanup
            I = self.m_I[m][1]
            G = self.m_G[m][1]
            n = I.shape[0]

            AX = einsum("nkl,nml->nkm", (I, G))
            del I
            del G

            AX_ = AX.reshape(n, -1)
            out = matmul(AX_, AX_.t())
            del AX

            NGD_kernel = out / n

            # low-rank approximation of the Jacobian
            if self.low_rank == 'true':
                V, S, U = svd(AX_.T, full_matrices=False)
                U = U.t()
                V = V.t()
                cs = cumsum(S, dim=0)
                sum_s = sum(S)
                index = ((cs - self.gamma * sum_s) <= 0).sum()
                U = U[:, 0:index]
                S = S[0:index]
                V = V[0:index, :]
                self.m_UV[m] = U, S, V

            del AX_
            NGD_inv = inv(NGD_kernel + self.damping * eye(n).to(NGD_kernel.device))
            self.m_NGD_Kernel[m] = NGD_inv
            del NGD_inv

            self.m_I[m] = (None, self.m_I[m][1])
            self.m_G[m] = (None, self.m_G[m][1])
            torch.cuda.empty_cache()
def compute_Sigma(self, X):
    n, dim = X.shape
    # shared term X^T X + I, shape (dim, dim)
    XjXj = th.matmul(X.T, X) + th.eye(dim, device=self.device)
    # batched per-sample inverse: Sigma[i] = (n * x_i x_i^T + X^T X + I)^{-1}
    Sigma = linalg.inv(
        n * th.matmul(X.reshape(n, dim, 1), X.reshape(n, 1, dim)) + XjXj)
    return Sigma
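# Aside: what compute_Sigma evaluates, written as an explicit per-sample loop on
# random placeholder data: Sigma[i] = (n * x_i x_i^T + X^T X + I)^{-1}. The device
# argument is dropped here since the enclosing class is not shown.
import torch as th

n, dim = 16, 4
X = th.randn(n, dim, dtype=th.float64)
shared = X.T @ X + th.eye(dim, dtype=th.float64)

Sigma_loop = th.stack([th.linalg.inv(n * th.outer(x, x) + shared) for x in X])
Sigma_batched = th.linalg.inv(
    n * th.matmul(X.reshape(n, dim, 1), X.reshape(n, 1, dim)) + shared)
assert th.allclose(Sigma_loop, Sigma_batched)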