def _kron(self):
    if self.kron is None:
        I = torch.eye(self.in_size[0]).to(self.device())
        for size in self.in_size[1:]:
            I = torch.kron(I, torch.eye(size).to(self.device()))
        self.kron = self.r2c(
            I.reshape(np.prod(self.in_size), *self.in_size))
    return self.kron
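Since the Kronecker product of identity matrices is itself an identity matrix of the product size, the cached tensor above is just a reshaped identity. A quick standalone check of that identity (a sketch, independent of the class above):

import torch

# I_m ⊗ I_n == I_{m*n}, which is what the loop in _kron builds before reshaping
I = torch.kron(torch.eye(3), torch.eye(4))
print(torch.equal(I, torch.eye(12)))  # True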
def forward(self, x, uv):
    uv = torch.kron(self.harmonic_scales, uv)
    uv = torch.cat((torch.sin(uv), torch.cos(uv)), dim=1)
    uv = torch.flatten(uv, start_dim=1)
    mu = self.encode(x, uv)
    if not self.rica:
        return self.decode(mu, uv), mu
    else:
        mu = F.elu(self.fc2in(mu))
        muprime = F.elu(self.fc2out(mu))
        return self.decode(muprime, uv), mu
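For context, `torch.kron` of a 1-D scale vector with a batch of coordinates tiles each coordinate row once per scale, which produces the Fourier-feature style frequencies fed to the sin/cos above. A small standalone sketch, assuming `harmonic_scales` is a 1-D tensor (its shape is not shown in the snippet):

import torch

scales = torch.tensor([1., 2., 4.])   # stand-in for self.harmonic_scales
uv = torch.tensor([[0.1, 0.2]])       # batch of one coordinate pair
print(torch.kron(scales, uv))
# tensor([[0.1000, 0.2000, 0.2000, 0.4000, 0.4000, 0.8000]])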
def test_kron():
    """Test the implementation by evaluation.

    The example is taken from
    https://de.wikipedia.org/wiki/Kronecker-Produkt
    """
    a = torch.tensor([[1, 2], [3, 2], [5, 6]]).to_sparse()
    b = torch.tensor([[7, 8], [9, 0]]).to_sparse()
    sparse_result = sparse_kron(a, b)
    dense_result = torch.kron(a.to_dense(), b.to_dense())
    err = torch.sum(torch.abs(sparse_result.to_dense() - dense_result))
    condition = np.allclose(sparse_result.to_dense().numpy(),
                            dense_result.numpy())
    print("error {:2.2f}".format(err), condition)
    assert condition
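The `sparse_kron` under test is not shown here. A minimal sketch of what such a sparse Kronecker product could look like for 2-D COO tensors (a hypothetical helper, not the tested implementation):

import torch

def sparse_kron_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for sparse_kron: pairs every nonzero of a with
    # every nonzero of b, since kron(a, b)[i*p + k, j*q + l] = a[i, j] * b[k, l].
    a, b = a.coalesce(), b.coalesce()
    ai, av = a.indices(), a.values()
    bi, bv = b.indices(), b.values()
    rows = (ai[0].unsqueeze(1) * b.shape[0] + bi[0].unsqueeze(0)).reshape(-1)
    cols = (ai[1].unsqueeze(1) * b.shape[1] + bi[1].unsqueeze(0)).reshape(-1)
    vals = (av.unsqueeze(1) * bv.unsqueeze(0)).reshape(-1)
    size = (a.shape[0] * b.shape[0], a.shape[1] * b.shape[1])
    return torch.sparse_coo_tensor(torch.stack([rows, cols]), vals, size)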
def _dense_kron(sparse_tensor_a: torch.Tensor,
                sparse_tensor_b: torch.Tensor) -> torch.Tensor:
    """Faster than sparse_kron.

    Limited to resolutions of approximately 128x128 pixels
    by memory on my machine.

    Args:
        sparse_tensor_a (torch.Tensor): Sparse 2d-Tensor a of shape [m, n].
        sparse_tensor_b (torch.Tensor): Sparse 2d-Tensor b of shape [p, q].

    Returns:
        torch.Tensor: The resulting [mp, nq] tensor.
    """
    return torch.kron(sparse_tensor_a.to_dense(),
                      sparse_tensor_b.to_dense()).to_sparse()
def pre_compute(self):
    # This method is auxiliary to the `get_solution` method (below).
    # It extracts computation that does not depend on `area`, to reduce runtime.
    # You'll need to first understand the theory of truss problems and read this
    # method together with `get_solution` to follow it.
    dis_vec = self.coords[self.con[:, 1], :] - self.coords[self.con[:, 0], :]
    self.length = torch.linalg.norm(dis_vec, dim=-1)  # length of bars
    self.T = dis_vec / self.length.view(-1, 1)  # unit direction (transformation) vector
    self.matrix = torch.kron(
        torch.tensor([[1, -1], [-1, 1]], device=self.device),
        torch.einsum('ij,ik->ijk', self.T, self.T))
    self.idx = torch.cat(
        (torch.arange(0, self.n_dim, device=self.device) +
         (self.con[:, 0] * self.n_dim).view(-1, 1),
         torch.arange(0, self.n_dim, device=self.device) +
         (self.con[:, 1] * self.n_dim).view(-1, 1)),
        dim=1)
    self.free_list = self.free_mask.flatten()
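For a single 2-D bar, the Kronecker product above assembles the familiar truss element stiffness pattern k_e ∝ [[1, -1], [-1, 1]] ⊗ (t tᵀ), where t is the bar's unit direction. A minimal standalone sketch of that per-bar block (illustrative values only):

import torch

t = torch.tensor([[0.6, 0.8]])                                     # unit direction of one bar
outer = torch.einsum('ij,ik->ijk', t, t)[0]                        # t tᵀ, shape [2, 2]
k_unit = torch.kron(torch.tensor([[1., -1.], [-1., 1.]]), outer)   # 4x4 element block
print(k_unit)                                                      # scale by E * area / L for the stiffness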
def matvec_product(W: nn.ParameterList, x: torch.Tensor,
                   bias: Optional[nn.ParameterList],
                   phm_rule: Union[list, nn.ParameterList]) -> torch.Tensor:
    """
    Functional method to compute the generalized matrix-vector product based on the paper
    "Parameterization of Hypercomplex Multiplications (2020)"
    https://openreview.net/forum?id=rcQdycl0zyk

    y = Hx + b, where H is generated through a sum of Kronecker products from the
    ParameterList W, i.e. W is an nn.ParameterList with len(phm_rule) tensors of size
    (out_features, in_features).
    x has shape (batch_size, phm_dim*in_features).
    H = sum_i phm_rule[i] \otimes W[i], where \otimes is the Kronecker product.

    As of now, it iterates over the "hyper-imaginary" components; a more efficient
    implementation would stack the x and bias vectors directly as one 1D vector.
    """
    assert len(phm_rule) == len(W)
    assert x.size(1) == sum([weight.size(1) for weight in W]), (
        f"x has size(1): {x.size(1)}. "
        f"Should have {sum([weight.size(1) for weight in W])}")
    # H = sum_i phm_rule[i] ⊗ W[i]. Note the Kronecker product has to be taken
    # pair-wise: calling torch.kron on the stacked tensors and summing would
    # instead give (sum_i A_i) ⊗ (sum_j W_j).
    H = torch.stack([torch.kron(Ai, Wi) for Ai, Wi in zip(phm_rule, W)], dim=0).sum(0)
    # y = torch.mm(H, x.t()).t()  # avoid one transpose for speed; other alternatives tested:
    # (x @ H.t() is slower,
    #  torch.einsum('ik,jk->ij', [x, H]) much slower,
    #  torch.matmul(x, H.t()) is slower)
    y = x.mm(H.t())
    if bias is not None:
        bias = torch.cat([b for b in bias], dim=-1)
        y += bias
    return y
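A small shape check of the routine above, with hypothetical dimensions (phm_dim = 2, in_features = 3, out_features = 4) and assuming `matvec_product` is in scope:

import torch
import torch.nn as nn

phm_dim, in_features, out_features = 2, 3, 4
phm_rule = nn.ParameterList(
    [nn.Parameter(torch.randn(phm_dim, phm_dim)) for _ in range(phm_dim)])
W = nn.ParameterList(
    [nn.Parameter(torch.randn(out_features, in_features)) for _ in range(phm_dim)])
x = torch.randn(8, phm_dim * in_features)        # batch of 8
y = matvec_product(W, x, bias=None, phm_rule=phm_rule)
print(y.shape)                                   # torch.Size([8, 8]) = (batch, phm_dim * out_features)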
def other_ops(self):
    a = torch.randn(4)
    b = torch.randn(4)
    c = torch.randint(0, 8, (5, ), dtype=torch.int64)
    e = torch.randn(4, 3)
    f = torch.randn(4, 4, 4)
    size = [0, 1]
    dims = [0, 1]
    return (
        torch.atleast_1d(a),
        torch.atleast_2d(a),
        torch.atleast_3d(a),
        torch.bincount(c),
        torch.block_diag(a),
        torch.broadcast_tensors(a),
        torch.broadcast_to(a, (4)),
        # torch.broadcast_shapes(a),
        torch.bucketize(a, b),
        torch.cartesian_prod(a),
        torch.cdist(e, e),
        torch.clone(a),
        torch.combinations(a),
        torch.corrcoef(a),
        # torch.cov(a),
        torch.cross(e, e),
        torch.cummax(a, 0),
        torch.cummin(a, 0),
        torch.cumprod(a, 0),
        torch.cumsum(a, 0),
        torch.diag(a),
        torch.diag_embed(a),
        torch.diagflat(a),
        torch.diagonal(e),
        torch.diff(a),
        torch.einsum("iii", f),
        torch.flatten(a),
        torch.flip(e, dims),
        torch.fliplr(e),
        torch.flipud(e),
        torch.kron(a, b),
        torch.rot90(e),
        torch.gcd(c, c),
        torch.histc(a),
        torch.histogram(a),
        torch.meshgrid(a),
        torch.lcm(c, c),
        torch.logcumsumexp(a, 0),
        torch.ravel(a),
        torch.renorm(e, 1, 0, 5),
        torch.repeat_interleave(c),
        torch.roll(a, 1, 0),
        torch.searchsorted(a, b),
        torch.tensordot(e, e),
        torch.trace(e),
        torch.tril(e),
        torch.tril_indices(3, 3),
        torch.triu(e),
        torch.triu_indices(3, 3),
        torch.vander(a),
        torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
        torch.view_as_complex(torch.randn(4, 2)),
        torch.resolve_conj(a),
        torch.resolve_neg(a),
    )
m = torch.Tensor([[1, 2, 3], [4, 5, 6]])
n = torch.Tensor([[4, 5, 6], [1, 7, 3]])
assert bool((torch.mul(m, n) == m * n).all()) == True
print(torch.mul(m, n).size())

# torch.matmul is (batched) matrix multiplication;
# for 3-D inputs it agrees with torch.bmm
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
assert bool((torch.matmul(tensor1, tensor2) == (tensor1 @ tensor2)).all()) == True
print(torch.matmul(tensor1, tensor2).size())

# torch.kron is the Kronecker (tensor) product: \otimes
tensor1 = torch.randn(10, 3, 2)
tensor2 = torch.randn(4, 5, 7)
assert (torch.kron(tensor1, tensor2)).shape == (40, 15, 14)

# bmm only works on 3-D tensors, matmul is general, mm only works on 2-D
m = torch.randn(10, 3, 4)
n = torch.randn(10, 4, 6)
assert torch.bmm(m, n).size() == (10, 3, 6)

# dropout mainly addresses overfitting
m = nn.Dropout(p=0.2)
inputs = torch.randn(20, 16)
print(m(inputs).size())

# LayerNorm does not change the shape of the input
inputs = torch.randn(20, 5, 10, 10)
m = nn.LayerNorm(10)
print(m(inputs).size())
import numpy as np
import torch

"""
Test for kronecker gradient
"""

print(torch.__version__)

m = 7
p = 8
q = 9
p0 = 6
q0 = 5

A = torch.rand(m, p * q, requires_grad=False)
x = torch.rand(p, p0, requires_grad=True)
y = torch.rand(q, q0, requires_grad=False)
B = torch.rand(p0 * q0, p0 * q0, requires_grad=False)

print(torch.kron(A, B))

check1 = torch.matmul(torch.transpose(torch.kron(x, y), 0, 1),
                      torch.transpose(A, 0, 1))
check2 = torch.matmul(B, check1)
check3 = torch.matmul(torch.kron(x, y), check2)
check4 = torch.matmul(A, check3)
loss = torch.trace(check4)
# loss = torch.trace((@))
# for i in range(10):
loss.backward()
print(loss)
# print(loss.grad_fn)
# print(loss.grad_fn.next_functions)
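As a follow-up, the same trace loss can be verified numerically with `torch.autograd.gradcheck`, which needs double precision and a small problem size. A sketch with reduced, hypothetical dimensions:

xs = torch.rand(3, 2, dtype=torch.double, requires_grad=True)
ys = torch.rand(4, 3, dtype=torch.double)
As = torch.rand(5, 12, dtype=torch.double)   # 12 = 3 * 4
Bs = torch.rand(6, 6, dtype=torch.double)    # 6 = 2 * 3

def kron_trace_loss(x_):
    K = torch.kron(x_, ys)                   # shape [12, 6]
    return torch.trace(As @ K @ Bs @ K.t() @ As.t())

print(torch.autograd.gradcheck(kron_trace_loss, (xs,)))  # expect True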
def forward(self, z, x, *args):
    """
    Implement adaptive TAP fixed-point iteration step.

    Note: The linear response actually does too much work for this module's
    default choice of spin priors. In particular, the intermediate `big_lambda`
    is always a batch of identity matrices.

    WARNING: This module is very slow, especially the backward pass.

    Args:
        z (`torch.Tensor`): Current fixed-point state as a batch of big vectors.
        x (`torch.Tensor`): Input source injection (data). Shape should match that
            of `spin_mean` in `z` (see `_initial_guess`).

    Returns:
        `torch.Tensor` with updated fixed-point state as batch of vectors
    """
    spin_mean, cav_var = self.unpack_state(z)
    weight = self.weight()
    cav_mean = torch.einsum(
        'i j d e, b j e -> b i d', weight, spin_mean
    ) - torch.einsum('b i d e, b i d -> b i e', cav_var, spin_mean)
    spin_mean, spin_var = self._spin_mean_var(x, cav_mean, cav_var[0])

    if self.lin_response:
        N, dim = spin_mean.shape[-2], spin_mean.shape[-1]
        V = cav_var[0]
        S = rearrange(spin_var, 'i a b -> a b i')
        J = weight
        A = (
            torch.kron(torch.eye(dim, dtype=x.dtype, device=x.device),
                       torch.eye(N, dtype=x.dtype, device=x.device))
            - torch.einsum('a c i, i k c d -> a i d k', S, J).reshape(
                dim * N, dim * N
            )
            + torch.einsum(
                'a c i, i c d, i k -> a i d k', S, V,
                torch.eye(N, dtype=x.dtype, device=x.device)
            ).reshape(dim * N, dim * N)
        )
        B = rearrange(torch.diag_embed(S), 'a b i j -> (a i) (b j)')
        spin_cov = torch.solve(B, A).solution
        spin_cov = rearrange(
            spin_cov, '(a i) (b j) -> a b i j', a=dim, b=dim, i=N, j=N
        )
        # [DEBUG] check conditioning of system
        # print(torch.linalg.cond(A))
        spin_cov_diag = torch.diagonal(spin_cov, dim1=-2, dim2=-1)
        spin_cov_diag = rearrange(spin_cov_diag, 'a b i -> i a b')

        ones = batched_eye_like(spin_var)
        spin_inv_var = torch.solve(ones, spin_var).solution
        big_lambda = V + spin_inv_var

        A = spin_cov_diag
        B = spin_cov_diag @ big_lambda - batched_eye_like(spin_cov_diag)
        cav_var = torch.solve(B, A).solution
        # [DEBUG] eigvals should be positive (cov matrices should be psd)
        # print(torch.eig(spin_var[0]))  # check for spin 0
        # print(torch.eig(cav_var[0]))  # check for spin 0
        cav_var = cav_var.unsqueeze(0).expand(x.shape[0], -1, -1, -1)

    return self.pack_state([spin_mean, cav_var])
def train(epoch):
    torch.set_printoptions(precision=16)
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    step_st_time = time.time()
    epoch_time = 0

    print('\nKFAC/KBFGS damping: %f' % damping)
    print('\nNGD damping: %f' % (damping))

    desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (tag, lr_scheduler.get_last_lr()[0], 0, 0, correct, total))
    writer.add_scalar('train/lr', lr_scheduler.get_last_lr()[0], epoch)

    prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True)
    for batch_idx, (inputs, targets) in prog_bar:
        if optim_name in ['kfac', 'skfac', 'ekfac', 'sgd', 'adam']:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            if optim_name in ['kfac', 'skfac', 'ekfac'] and optimizer.steps % optimizer.TCov == 0:
                # compute true fisher
                optimizer.acc_stats = True
                with torch.no_grad():
                    sampled_y = torch.multinomial(
                        torch.nn.functional.softmax(outputs.cpu().data, dim=1),
                        1).squeeze().to(args.device)
                loss_sample = criterion(outputs, sampled_y)
                loss_sample.backward(retain_graph=True)
                optimizer.acc_stats = False
                optimizer.zero_grad()  # clear the gradient for computing true-fisher.
            loss.backward()
            optimizer.step()

        elif optim_name in ['kbfgs', 'kbfgsl', 'kbfgsl_2loop', 'kbfgsl_mem_eff']:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net.forward(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            # do another forward-backward pass over batch inside step()
            def closure():
                return inputs, targets, criterion, False  # is_autoencoder = False
            optimizer.step(closure)

        elif optim_name == 'exact_ngd':
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            # update Fisher inverse
            if batch_idx % args.freq == 0:
                # compute true fisher
                with torch.no_grad():
                    sampled_y = torch.multinomial(
                        torch.nn.functional.softmax(outputs.cpu().data, dim=1),
                        1).squeeze().to(args.device)
                # use backpack extension to compute individual gradient in a batch
                batch_grad = []
                with backpack(BatchGrad()):
                    loss_sample = criterion(outputs, sampled_y)
                    loss_sample.backward(retain_graph=True)
                for name, param in net.named_parameters():
                    if hasattr(param, "grad_batch"):
                        batch_grad.append(args.batch_size * param.grad_batch.reshape(args.batch_size, -1))
                    else:
                        raise NotImplementedError
                J = torch.cat(batch_grad, 1)
                fisher = torch.matmul(J.t(), J) / args.batch_size
                inv = torch.linalg.inv(fisher + damping * torch.eye(fisher.size(0)).to(fisher.device))

            # clean the gradient to compute the true fisher
            optimizer.zero_grad()
            loss.backward()

            # compute the step direction p = F^-1 @ g
            grad_list = []
            for name, param in net.named_parameters():
                grad_list.append(param.grad.data.reshape(-1, 1))
            g = torch.cat(grad_list, 0)
            p = torch.matmul(inv, g)

            start = 0
            for name, param in net.named_parameters():
                end = start + param.data.reshape(-1, 1).size(0)
                param.grad.copy_(p[start:end].reshape(param.grad.data.shape))
                start = end
            optimizer.step()

        ### new optimizer test
        elif optim_name in ['kngd']:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            if optimizer.steps % optimizer.freq == 0:
                # compute true fisher
                optimizer.acc_stats = True
                with torch.no_grad():
                    sampled_y = torch.multinomial(
                        torch.nn.functional.softmax(outputs, dim=1),
                        1).squeeze().to(args.device)
                loss_sample = criterion(outputs, sampled_y)
                loss_sample.backward(retain_graph=True)
                optimizer.acc_stats = False
                optimizer.zero_grad()  # clear the gradient for computing true-fisher.
                if args.partial_backprop == 'true':
                    idx = (sampled_y == targets) == False
                    loss = criterion(outputs[idx, :], targets[idx])
                    # print('extra:', idx.sum().item())
            loss.backward()
            optimizer.step()

        elif optim_name == 'ngd':
            if batch_idx % args.freq == 0:
                store_io_(True)
                inputs, targets = inputs.to(args.device), targets.to(args.device)
                optimizer.zero_grad()
                # net.set_require_grad(True)
                outputs = net(inputs)
                damp = damping
                loss = criterion(outputs, targets)
                loss.backward(retain_graph=True)

                # storing original gradient for later use
                grad_org = []
                # grad_dict = {}
                for name, param in net.named_parameters():
                    grad_org.append(param.grad.reshape(1, -1))
                    # grad_dict[name] = param.grad.clone()
                grad_org = torch.cat(grad_org, 1)

                ###### now we have to compute the true fisher
                with torch.no_grad():
                    # gg = torch.nn.functional.softmax(outputs, dim=1)
                    sampled_y = torch.multinomial(
                        torch.nn.functional.softmax(outputs, dim=1),
                        1).squeeze().to(args.device)
                if args.trial == 'true':
                    update_list, loss = optimal_JJT_v2(outputs, sampled_y, args.batch_size,
                                                       damping=damp, alpha=0.95,
                                                       low_rank=args.low_rank, gamma=args.gamma,
                                                       memory_efficient=args.memory_efficient,
                                                       super_opt=args.super_opt)
                else:
                    update_list, loss = optimal_JJT(outputs, sampled_y, args.batch_size,
                                                    damping=damp, alpha=0.95,
                                                    low_rank=args.low_rank, gamma=args.gamma,
                                                    memory_efficient=args.memory_efficient)
                # optimizer.zero_grad()
                # update_list, loss = optimal_JJT_fused(outputs, sampled_y, args.batch_size, damping=damp)
                optimizer.zero_grad()

                # last part of SMW formula
                grad_new = []
                for name, param in net.named_parameters():
                    param.grad.copy_(update_list[name])
                    grad_new.append(param.grad.reshape(1, -1))
                grad_new = torch.cat(grad_new, 1)
                # grad_new = grad_org
                store_io_(False)
            else:
                inputs, targets = inputs.to(args.device), targets.to(args.device)
                optimizer.zero_grad()
                # net.set_require_grad(True)
                outputs = net(inputs)
                damp = damping
                loss = criterion(outputs, targets)
                loss.backward()

                # storing original gradient for later use
                grad_org = []
                # grad_dict = {}
                for name, param in net.named_parameters():
                    grad_org.append(param.grad.reshape(1, -1))
                    # grad_dict[name] = param.grad.clone()
                grad_org = torch.cat(grad_org, 1)

                ###### now we have to compute the true fisher
                # with torch.no_grad():
                #     gg = torch.nn.functional.softmax(outputs, dim=1)
                #     sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1), 1).squeeze().to(args.device)
                all_modules = net.modules()
                for m in net.modules():
                    if hasattr(m, "NGD_inv"):
                        grad = m.weight.grad
                        if isinstance(m, nn.Linear):
                            I = m.I
                            G = m.G
                            n = I.shape[0]
                            NGD_inv = m.NGD_inv
                            grad_prod = einsum("ni,oi->no", (I, grad))
                            grad_prod = einsum("no,no->n", (grad_prod, G))
                            v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                            gv = einsum("n,no->no", (v, G))
                            gv = einsum("no,ni->oi", (gv, I))
                            gv = gv / n
                            update = (grad - gv) / damp
                            m.weight.grad.copy_(update)
                        elif isinstance(m, nn.Conv2d):
                            if hasattr(m, "AX"):
                                if args.low_rank.lower() == 'true':
                                    ###### using low rank structure
                                    U = m.U
                                    S = m.S
                                    V = m.V
                                    NGD_inv = m.NGD_inv
                                    n = NGD_inv.shape[0]
                                    grad_reshape = grad.reshape(grad.shape[0], -1)
                                    grad_prod = V @ grad_reshape.t().reshape(-1, 1)
                                    grad_prod = torch.diag(S) @ grad_prod
                                    grad_prod = U @ grad_prod
                                    grad_prod = grad_prod.squeeze()
                                    v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                    gv = U.t() @ v.unsqueeze(1)
                                    gv = torch.diag(S) @ gv
                                    gv = V.t() @ gv
                                    gv = gv.reshape(grad_reshape.shape[1], grad_reshape.shape[0]).t()
                                    gv = gv.view_as(grad)
                                    gv = gv / n
                                    update = (grad - gv) / damp
                                    m.weight.grad.copy_(update)
                                else:
                                    AX = m.AX
                                    NGD_inv = m.NGD_inv
                                    n = AX.shape[0]
                                    grad_reshape = grad.reshape(grad.shape[0], -1)
                                    grad_prod = einsum("nkm,mk->n", (AX, grad_reshape))
                                    v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                    gv = einsum("nkm,n->mk", (AX, v))
                                    gv = gv.view_as(grad)
                                    gv = gv / n
                                    update = (grad - gv) / damp
                                    m.weight.grad.copy_(update)
                            elif hasattr(m, "I"):
                                I = m.I
                                if args.memory_efficient == 'true':
                                    I = unfold_func(m)(I)
                                G = m.G
                                n = I.shape[0]
                                NGD_inv = m.NGD_inv
                                grad_reshape = grad.reshape(grad.shape[0], -1)
                                x1 = einsum("nkl,mk->nml", (I, grad_reshape))
                                grad_prod = einsum("nml,nml->n", (x1, G))
                                v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                gv = einsum("n,nml->nml", (v, G))
                                gv = einsum("nml,nkl->mk", (gv, I))
                                gv = gv.view_as(grad)
                                gv = gv / n
                                update = (grad - gv) / damp
                                m.weight.grad.copy_(update)
                        elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                            if args.batchnorm == 'true':
                                dw = m.dw
                                n = dw.shape[0]
                                NGD_inv = m.NGD_inv
                                grad_prod = einsum("ni,i->n", (dw, grad))
                                v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                gv = einsum("n,ni->i", (v, dw))
                                gv = gv / n
                                update = (grad - gv) / damp
                                m.weight.grad.copy_(update)

                # last part of SMW formula
                grad_new = []
                for name, param in net.named_parameters():
                    grad_new.append(param.grad.reshape(1, -1))
                grad_new = torch.cat(grad_new, 1)
                # grad_new = grad_org

            ##### do kl clip
            lr = lr_scheduler.get_last_lr()[0]
            # vg_sum = 0
            # vg_sum += (grad_new * grad_org).sum()
            # vg_sum = vg_sum * (lr ** 2)
            # nu = min(1.0, math.sqrt(args.kl_clip / vg_sum))
            # for name, param in net.named_parameters():
            #     param.grad.mul_(nu)

            # optimizer.step()
            # manual optimizing:
            with torch.no_grad():
                for name, param in net.named_parameters():
                    d_p = param.grad.data
                    # print('=== step ===')
                    # apply momentum
                    # if args.momentum != 0:
                    #     buf[name].mul_(args.momentum).add_(d_p)
                    #     d_p.copy_(buf[name])

                    # apply weight decay
                    if args.weight_decay != 0:
                        d_p.add_(args.weight_decay, param.data)

                    lr = lr_scheduler.get_last_lr()[0]
                    param.data.add_(-lr, d_p)
                    # print('d_p:', d_p.shape)
                    # print(d_p)

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (tag, lr_scheduler.get_last_lr()[0], train_loss / (batch_idx + 1),
                 100. * correct / total, correct, total))
        prog_bar.set_description(desc, refresh=True)

        if args.step_info == 'true' and (batch_idx % 50 == 0 or batch_idx == len(prog_bar) - 1):
            step_saved_time = time.time() - step_st_time
            epoch_time += step_saved_time
            test_acc, test_loss = test(epoch)
            TRAIN_INFO['train_acc'].append(float("{:.4f}".format(100. * correct / total)))
            TRAIN_INFO['test_acc'].append(float("{:.4f}".format(test_acc)))
            TRAIN_INFO['train_loss'].append(float("{:.4f}".format(train_loss / (batch_idx + 1))))
            TRAIN_INFO['test_loss'].append(float("{:.4f}".format(test_loss)))
            TRAIN_INFO['total_time'].append(float("{:.4f}".format(step_saved_time)))
            if args.debug_mem == 'true':
                TRAIN_INFO['memory'].append(torch.cuda.memory_reserved())
            step_st_time = time.time()
            net.train()

    writer.add_scalar('train/loss', train_loss / (batch_idx + 1), epoch)
    writer.add_scalar('train/acc', 100. * correct / total, epoch)
    acc = 100. * correct / total
    train_loss = train_loss / (batch_idx + 1)
    if args.step_info == 'true':
        TRAIN_INFO['epoch_time'].append(float("{:.4f}".format(epoch_time)))

    # save diagonal blocks of exact Fisher inverse or its approximations
    if args.save_inv == 'true':
        all_modules = net.modules()
        count = 0
        start, end = 0, 0
        if optim_name == 'ngd':
            for m in all_modules:
                if m.__class__.__name__ == 'Linear':
                    with torch.no_grad():
                        I = m.I
                        G = m.G
                        J = torch.einsum('ni,no->nio', I, G)
                        J = J.reshape(J.size(0), -1)
                        JTDJ = torch.matmul(J.t(), torch.matmul(m.NGD_inv, J)) / args.batch_size
                        with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
                            np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy())
                        count += 1
                elif m.__class__.__name__ == 'Conv2d':
                    with torch.no_grad():
                        AX = m.AX
                        AX = AX.reshape(AX.size(0), -1)
                        JTDJ = torch.matmul(AX.t(), torch.matmul(m.NGD_inv, AX)) / args.batch_size
                        with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
                            np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy())
                        count += 1
        elif optim_name == 'exact_ngd':
            for m in all_modules:
                if m.__class__.__name__ in ['Conv2d', 'Linear']:
                    with open('exact/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
                        end = start + m.weight.data.reshape(1, -1).size(1)
                        np.save(f, inv[start:end, start:end].cpu().numpy())
                        start = end + m.bias.data.size(0)
                    count += 1
        elif optim_name == 'kfac':
            for m in all_modules:
                if m.__class__.__name__ in ['Conv2d', 'Linear']:
                    with open('kfac/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
                        G = optimizer.m_gg[m]
                        A = optimizer.m_aa[m]
                        H_g = torch.linalg.inv(G + math.sqrt(damping) * torch.eye(G.size(0)).to(G.device))
                        H_a = torch.linalg.inv(A + math.sqrt(damping) * torch.eye(A.size(0)).to(A.device))
                        end = m.weight.data.reshape(1, -1).size(1)
                        kfac_inv = torch.kron(H_a, H_g)[:end, :end]
                        np.save(f, kfac_inv.cpu().numpy())
                    count += 1

    return acc, train_loss
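A small sanity sketch of the Kronecker-factored inverse used in the KFAC branch above: for invertible factors, (A ⊗ G)⁻¹ = A⁻¹ ⊗ G⁻¹, so inverting the two small factors and taking their Kronecker product matches inverting the full matrix. (The damping added to each factor separately in the code is the usual K-FAC approximation layered on top of this identity.)

import torch

A = torch.randn(3, 3)
A = A @ A.t() + 3 * torch.eye(3)   # small SPD "activation" factor
G = torch.randn(2, 2)
G = G @ G.t() + 2 * torch.eye(2)   # small SPD "gradient" factor
full_inv = torch.linalg.inv(torch.kron(A, G))
factored_inv = torch.kron(torch.linalg.inv(A), torch.linalg.inv(G))
print(torch.allclose(full_inv, factored_inv, atol=1e-5))   # True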