Example #1
    def _kron(self):

        if self.kron is None:
            I = torch.eye(self.in_size[0]).to(self.device())
            for size in self.in_size[1:]:
                I = torch.kron(I, torch.eye(size).to(self.device()))
            self.kron = self.r2c(
                I.reshape(np.prod(self.in_size), *self.in_size))
        return self.kron
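For reference, a standalone sketch of the identity-building step above (the class's `r2c` conversion and device handling are omitted; `in_size` below is an illustrative assumption): the Kronecker product of identity matrices is the identity of the product size, which is then reshaped into a stack of one-hot basis tensors.
import numpy as np
import torch

in_size = (2, 3)                       # illustrative input shape
I = torch.eye(in_size[0])
for size in in_size[1:]:
    I = torch.kron(I, torch.eye(size))
# kron of identities is the identity of the product size
assert torch.equal(I, torch.eye(int(np.prod(in_size))))
basis = I.reshape(int(np.prod(in_size)), *in_size)   # (6, 2, 3) stack of one-hot basis tensors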
Example #2
 def forward(self, x, uv):
     uv = torch.kron(self.harmonic_scales, uv)
     uv = torch.cat((torch.sin(uv), torch.cos(uv)), dim=1)
     uv = torch.flatten(uv, start_dim=1)
     mu = self.encode(x, uv)
     if not self.rica:
         return self.decode(mu, uv), mu
     else:
         mu = F.elu(self.fc2in(mu))
         muprime = F.elu(self.fc2out(mu))
         return self.decode(muprime, uv), mu
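A minimal sketch of the Fourier-feature step in the forward pass above, outside the module; `harmonic_scales` and the `uv` shape are illustrative assumptions, not values from the source.
import torch

harmonic_scales = torch.tensor([1., 2., 4.])
uv = torch.rand(8, 2)                                 # batch of 2-D coordinates
scaled = torch.kron(harmonic_scales, uv)              # (8, 6): every coordinate times every scale
feats = torch.cat((torch.sin(scaled), torch.cos(scaled)), dim=1)   # (8, 12) sin/cos features
feats = torch.flatten(feats, start_dim=1)             # no-op here, matters for higher-rank uv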
Example #3
def test_kron():
    """Test the implementation by evaluation.

    The example is taken from
    https://de.wikipedia.org/wiki/Kronecker-Produkt
    """
    a = torch.tensor([[1, 2], [3, 2], [5, 6]]).to_sparse()
    b = torch.tensor([[7, 8], [9, 0]]).to_sparse()
    sparse_result = sparse_kron(a, b)
    dense_result = torch.kron(a.to_dense(), b.to_dense())
    err = torch.sum(torch.abs(sparse_result.to_dense() - dense_result))
    condition = np.allclose(sparse_result.to_dense().numpy(),
                            dense_result.numpy())
    print("error {:2.2f}".format(err), condition)
    assert condition
def _dense_kron(sparse_tensor_a: torch.Tensor,
                sparse_tensor_b: torch.Tensor) -> torch.Tensor:
    """Faster than sparse_kron.

    Limited to resolutions of approximately 128x128 pixels
    by memory on my machine.

    Args:
        sparse_tensor_a (torch.Tensor): Sparse 2d-Tensor a of shape [m, n].
        sparse_tensor_b (torch.Tensor): Sparse 2d-Tensor b of shape [p, q].

    Returns:
        torch.Tensor: The resulting [mp, nq] tensor.

    """
    return torch.kron(sparse_tensor_a.to_dense(),
                      sparse_tensor_b.to_dense()).to_sparse()
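The `sparse_kron` exercised by the test above is not included in this listing; here is a minimal sketch of one possible COO-based implementation (the function name and the layout assumptions are mine, not from the source).
import torch

def sparse_kron_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Kronecker product of two sparse 2-D COO tensors (hypothetical helper)."""
    a, b = a.coalesce(), b.coalesce()
    (ar, ac), av = a.indices(), a.values()
    (br, bc), bv = b.indices(), b.values()
    p, q = b.shape
    # block (ra, ca) of the result holds a[ra, ca] * b, so the offsets are ra*p and ca*q
    rows = (ar.unsqueeze(1) * p + br.unsqueeze(0)).reshape(-1)
    cols = (ac.unsqueeze(1) * q + bc.unsqueeze(0)).reshape(-1)
    vals = (av.unsqueeze(1) * bv.unsqueeze(0)).reshape(-1)
    size = (a.shape[0] * p, a.shape[1] * q)
    return torch.sparse_coo_tensor(torch.stack((rows, cols)), vals, size)

a = torch.tensor([[1, 2], [3, 2], [5, 6]]).to_sparse()
b = torch.tensor([[7, 8], [9, 0]]).to_sparse()
assert torch.equal(sparse_kron_sketch(a, b).to_dense(),
                   torch.kron(a.to_dense(), b.to_dense()))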
 def pre_compute(self):
     # Auxiliary of the `get_solution` method (below): it pre-computes everything
     # that does not depend on `area`, to reduce computational time.
     # Read it together with `get_solution` and the underlying truss theory.
     dis_vec = self.coords[self.con[:, 1], :] - self.coords[self.con[:, 0], :]
     self.length = torch.linalg.norm(dis_vec, dim=-1)  # length of bars
     self.T = dis_vec / self.length.view(-1, 1)  # distance unit vector or transformation vector
     self.matrix = torch.kron(
         torch.tensor([[1, -1], [-1, 1]], device=self.device),
         torch.einsum('ij,ik->ijk', self.T, self.T))
     self.idx = torch.cat((torch.arange(0, self.n_dim, device=self.device) +
                           (self.con[:, 0] * self.n_dim).view(-1, 1),
                           torch.arange(0, self.n_dim, device=self.device) +
                           (self.con[:, 1] * self.n_dim).view(-1, 1)),
                          dim=1)
     self.free_list = self.free_mask.flatten()
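A minimal sketch of the per-bar stiffness-block construction above for a single 2-D bar (the coordinates are illustrative assumptions; the class bookkeeping such as `self.con` and the DOF indexing is omitted).
import torch

coords = torch.tensor([[0., 0.], [3., 4.]])            # one bar between two nodes
d = coords[1] - coords[0]
length = torch.linalg.norm(d)                          # 5.0
t = (d / length).unsqueeze(0)                          # (1, 2) unit direction vector
outer = torch.einsum('ij,ik->ijk', t, t)               # (1, 2, 2) outer product t t^T
k_elem = torch.kron(torch.tensor([[1., -1.], [-1., 1.]]), outer)   # (1, 4, 4) block [[K, -K], [-K, K]]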
Example #6
def matvec_product(W: nn.ParameterList, x: torch.Tensor,
                   bias: Optional[nn.ParameterList],
                   phm_rule: Union[list, nn.ParameterList]) -> torch.Tensor:
    """
    Functional method to compute the generalized matrix-vector product based on the paper
    "Parameterization of Hypercomplex Multiplications (2020)"
    https://openreview.net/forum?id=rcQdycl0zyk
    y = Hx + b, where H is generated through a sum of Kronecker products from the ParameterList W, i.e.

    W is an nn.ParameterList with len(phm_rule) tensors of size (out_features, in_features),
    x has shape (batch_size, phm_dim*in_features), and
    H = sum_{i=0}^{d} phm_rule[i] \otimes W[i], where \otimes is the Kronecker product.

    As of now, it iterates over the "hyper-imaginary" components; a more efficient implementation
    would stack the x and bias vectors directly as a 1D vector.
    """
    assert len(phm_rule) == len(W)
    assert x.size(1) == sum([weight.size(1) for weight in W]), (f"x has size(1): {x.size(1)}."
                                                                f"Should have {sum([weight.size(1) for weight in W])}")

    #H = torch.stack([kronecker_product(Ai, Wi) for Ai, Wi in zip(phm_rule, W)], dim=0).sum(0)
    A = torch.stack([Ai for Ai in phm_rule], dim=0)
    W = torch.stack([Wi for Wi in W], dim=0)
    H = torch.kron(A, W).sum(0)

    #y = torch.mm(H, x.t()).t()
    # avoid one transpose for speed; other alternatives tested:
    # ( x @ H.t() is slower,
    # torch.einsum('ik,jk->ij', [x, H]) is much slower,
    # torch.matmul(x, H.t()) is slower)
    y = x.mm(H.t())
    
    if bias is not None:
        bias = torch.cat([b for b in bias], dim=-1)
        y += bias
    return y
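For reference, a minimal sketch of the construction the docstring describes, H = sum_i phm_rule[i] ⊗ W[i], written as a plain loop over torch.kron (all sizes below are illustrative assumptions, not values from the source).
import torch

phm_dim, in_f, out_f, batch = 4, 8, 16, 32
phm_rule = [torch.randn(phm_dim, phm_dim) for _ in range(phm_dim)]
W = [torch.randn(out_f, in_f) for _ in range(phm_dim)]
# H has shape (phm_dim*out_f, phm_dim*in_f)
H = sum(torch.kron(Ai, Wi) for Ai, Wi in zip(phm_rule, W))
x = torch.randn(batch, phm_dim * in_f)
y = x @ H.t()                                          # (batch, phm_dim*out_f)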
Example #7
 def other_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     c = torch.randint(0, 8, (5, ), dtype=torch.int64)
     e = torch.randn(4, 3)
     f = torch.randn(4, 4, 4)
     size = [0, 1]
     dims = [0, 1]
     return (
         torch.atleast_1d(a),
         torch.atleast_2d(a),
         torch.atleast_3d(a),
         torch.bincount(c),
         torch.block_diag(a),
         torch.broadcast_tensors(a),
         torch.broadcast_to(a, (4)),
         # torch.broadcast_shapes(a),
         torch.bucketize(a, b),
         torch.cartesian_prod(a),
         torch.cdist(e, e),
         torch.clone(a),
         torch.combinations(a),
         torch.corrcoef(a),
         # torch.cov(a),
         torch.cross(e, e),
         torch.cummax(a, 0),
         torch.cummin(a, 0),
         torch.cumprod(a, 0),
         torch.cumsum(a, 0),
         torch.diag(a),
         torch.diag_embed(a),
         torch.diagflat(a),
         torch.diagonal(e),
         torch.diff(a),
         torch.einsum("iii", f),
         torch.flatten(a),
         torch.flip(e, dims),
         torch.fliplr(e),
         torch.flipud(e),
         torch.kron(a, b),
         torch.rot90(e),
         torch.gcd(c, c),
         torch.histc(a),
         torch.histogram(a),
         torch.meshgrid(a),
         torch.lcm(c, c),
         torch.logcumsumexp(a, 0),
         torch.ravel(a),
         torch.renorm(e, 1, 0, 5),
         torch.repeat_interleave(c),
         torch.roll(a, 1, 0),
         torch.searchsorted(a, b),
         torch.tensordot(e, e),
         torch.trace(e),
         torch.tril(e),
         torch.tril_indices(3, 3),
         torch.triu(e),
         torch.triu_indices(3, 3),
         torch.vander(a),
         torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
         torch.view_as_complex(torch.randn(4, 2)),
         torch.resolve_conj(a),
         torch.resolve_neg(a),
     )
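A side note on the `torch.kron(a, b)` entry above: for 1-D inputs it coincides with the flattened outer product (a small sketch, independent of the method above).
import torch

a = torch.randn(4)
b = torch.randn(4)
# kron of two vectors equals the outer product flattened row-major
assert torch.allclose(torch.kron(a, b), torch.outer(a, b).reshape(-1))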
Example #8
m = torch.Tensor([[1, 2, 3], [4, 5, 6]])
n = torch.Tensor([[4, 5, 6], [1, 7, 3]])
assert bool((torch.mul(m, n) == m * n).all()) == True
print(torch.mul(m, n).size())

# the tensor product here is also matrix multiplication: \otimes
# it is also equivalent to torch.bmm
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
assert bool((torch.matmul(tensor1,
                          tensor2) == (tensor1 @ tensor2)).all()) == True
print(torch.matmul(tensor1, tensor2).size())

tensor1 = torch.randn(10, 3, 2)
tensor2 = torch.randn(4, 5, 7)
assert (torch.kron(tensor1, tensor2)).shape == (40, 15, 14)

# bmm only works on 3-D tensors, matmul works in general, and mm only works on 2-D tensors
m = torch.randn(10, 3, 4)
n = torch.randn(10, 4, 6)
assert torch.bmm(m, n).size() == (10, 3, 6)

# Dropout mainly addresses overfitting
m = nn.Dropout(p=0.2)
inputs = torch.randn(20, 16)
print(m(inputs).size())

# LayerNorm does not change the shape of the original input
inputs = torch.randn(20, 5, 10, 10)
m = nn.LayerNorm(10)
print(m(inputs).size())
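One more relation between torch.kron and matrix multiplication worth noting, sketched with arbitrarily chosen shapes: the mixed-product property (A ⊗ B)(C ⊗ D) = (AC) ⊗ (BD).
import torch

A, C = torch.randn(2, 3), torch.randn(3, 4)
B, D = torch.randn(5, 6), torch.randn(6, 2)
lhs = torch.kron(A, B) @ torch.kron(C, D)   # (10, 18) @ (18, 8) -> (10, 8)
rhs = torch.kron(A @ C, B @ D)              # (2, 4) kron (5, 2)  -> (10, 8)
assert torch.allclose(lhs, rhs, atol=1e-4)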
Example #9
import numpy as np
import torch
"""
Test for kronecker gradient
"""
print(torch.__version__)
m = 7
p = 8
q = 9
p0 = 6
q0 = 5
A = torch.rand(m, p * q, requires_grad=False)
x = torch.rand(p, p0, requires_grad=True)
y = torch.rand(q, q0, requires_grad=False)
B = torch.rand(p0 * q0, p0 * q0, requires_grad=False)
print(torch.kron(A, B))
check1 = torch.matmul(torch.transpose(torch.kron(x, y), 0, 1),
                      torch.transpose(A, 0, 1))
check2 = torch.matmul(B, check1)
check3 = torch.matmul(torch.kron(x, y), check2)
check4 = torch.matmul(A, check3)
loss = torch.trace(check4)
#loss=torch.trace((@))

#for i in range(10):

loss.backward()
print(loss)

#print(loss.grad_fn)
#print(loss.grad_fn.next_functions)
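Related to the gradient check above: torch.kron can also be written as an outer product via einsum followed by a reshape, which is sometimes convenient when reasoning about its gradient (a self-contained sketch with illustrative shapes).
import torch

x = torch.rand(8, 6, requires_grad=True)
y = torch.rand(9, 5)
# kron(x, y)[a*9 + c, b*5 + d] == x[a, b] * y[c, d]
k = torch.einsum('ab,cd->acbd', x, y).reshape(8 * 9, 6 * 5)
assert torch.allclose(k, torch.kron(x, y))
k.sum().backward()   # gradients flow through the einsum form as well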
Example #10
    def forward(self, z, x, *args):
        """
        Implement adaptive TAP fixed-point iteration step.

        Note:
            The linear response actually does too much work for this module's
            default choice of spin priors. In particular, the intermediate
            `big_lambda` is always a batch of identity matrices.

            WARNING: This module is very slow, especially the backward pass.

        Args:
            z (`torch.Tensor`):
                Current fixed-point state as a batch of big vectors.
            x (`torch.Tensor`):
                Input source injection (data). Shape should match that
                of `spin_mean` in `z` (see `_initial_guess`).

        Returns:
            `torch.Tensor` with updated fixed-point state as batch of vectors

        """
        spin_mean, cav_var = self.unpack_state(z)

        weight = self.weight()

        cav_mean = torch.einsum(
            'i j d e, b j e -> b i d', weight, spin_mean
        ) - torch.einsum('b i d e, b i d -> b i e', cav_var, spin_mean)

        spin_mean, spin_var = self._spin_mean_var(x, cav_mean, cav_var[0])

        if self.lin_response:
            N, dim = spin_mean.shape[-2], spin_mean.shape[-1]

            V = cav_var[0]
            S = rearrange(spin_var, 'i a b -> a b i')
            J = weight

            A = (
                torch.kron(torch.eye(dim, dtype=x.dtype, device=x.device),
                           torch.eye(N, dtype=x.dtype, device=x.device))
                - torch.einsum('a c i, i k c d -> a i d k', S, J).reshape(
                    dim * N, dim * N
                )
                + torch.einsum(
                    'a c i, i c d, i k -> a i d k', S, V, torch.eye(
                        N, dtype=x.dtype, device=x.device)
                ).reshape(dim * N, dim * N)
            )
            B = rearrange(torch.diag_embed(S), 'a b i j -> (a i) (b j)')
            spin_cov = torch.solve(B, A).solution
            spin_cov = rearrange(
                spin_cov, '(a i) (b j) -> a b i j', a=dim, b=dim, i=N, j=N
            )

            # [DEBUG] check conditioning of system
            # print(torch.linalg.cond(A))

            spin_cov_diag = torch.diagonal(spin_cov, dim1=-2, dim2=-1)
            spin_cov_diag = rearrange(spin_cov_diag, 'a b i -> i a b')

            ones = batched_eye_like(spin_var)
            spin_inv_var = torch.solve(ones, spin_var).solution
            big_lambda = V + spin_inv_var
            A = spin_cov_diag
            B = spin_cov_diag @ big_lambda - batched_eye_like(spin_cov_diag)
            cav_var = torch.solve(B, A).solution

            # [DEBUG] eigvals should be positive (cov matrices should be psd)
            # print(torch.eig(spin_var[0]))  # check for spin 0
            # print(torch.eig(cav_var[0]))  # check for spin 0

            cav_var = cav_var.unsqueeze(
                0).expand(x.shape[0], -1, -1, -1)

        return self.pack_state([spin_mean, cav_var])
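A note on the torch.solve calls above: that function was deprecated and later removed from PyTorch; torch.linalg.solve(A, B) is the current equivalent of torch.solve(B, A).solution (a tiny sketch, independent of the module above).
import torch

A = torch.randn(4, 4) + 4 * torch.eye(4)   # reasonably well-conditioned system
B = torch.randn(4, 2)
X = torch.linalg.solve(A, B)               # solves A @ X = B
assert torch.allclose(A @ X, B, atol=1e-5)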
Example #11
def train(epoch):
    torch.set_printoptions(precision=16)
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    step_st_time = time.time()
    epoch_time = 0
    print('\nKFAC/KBFGS damping: %f' % damping)
    print('\nNGD damping: %f' % (damping))

    # 
    desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (tag, lr_scheduler.get_last_lr()[0], 0, 0, correct, total))

    writer.add_scalar('train/lr', lr_scheduler.get_last_lr()[0], epoch)

    prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True)
    for batch_idx, (inputs, targets) in prog_bar:

        if optim_name in ['kfac', 'skfac', 'ekfac', 'sgd', 'adam']:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            if optim_name in ['kfac', 'skfac', 'ekfac'] and optimizer.steps % optimizer.TCov == 0:
                # compute true fisher
                optimizer.acc_stats = True
                with torch.no_grad():
                    sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),1).squeeze().to(args.device)
                loss_sample = criterion(outputs, sampled_y)
                loss_sample.backward(retain_graph=True)
                optimizer.acc_stats = False
                optimizer.zero_grad()  # clear the gradient for computing true-fisher.
            loss.backward()
            optimizer.step()
        elif optim_name in ['kbfgs', 'kbfgsl', 'kbfgsl_2loop', 'kbfgsl_mem_eff']:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net.forward(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            # do another forward-backward pass over batch inside step()
            def closure():
                return inputs, targets, criterion, False # is_autoencoder = False
            optimizer.step(closure)
        elif optim_name == 'exact_ngd':
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            # update Fisher inverse
            if batch_idx % args.freq == 0:
              # compute true fisher
              with torch.no_grad():
                sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),1).squeeze().to(args.device)
              # use backpack extension to compute individual gradient in a batch
              batch_grad = []
              with backpack(BatchGrad()):
                loss_sample = criterion(outputs, sampled_y)
                loss_sample.backward(retain_graph=True)

              for name, param in net.named_parameters():
                if hasattr(param, "grad_batch"):
                  batch_grad.append(args.batch_size * param.grad_batch.reshape(args.batch_size, -1))
                else:
                  raise NotImplementedError

              J = torch.cat(batch_grad, 1)
              fisher = torch.matmul(J.t(), J) / args.batch_size
              inv = torch.linalg.inv(fisher + damping * torch.eye(fisher.size(0)).to(fisher.device))
              # clean the gradient to compute the true fisher
              optimizer.zero_grad()

            loss.backward()
            # compute the step direction p = F^-1 @ g
            grad_list = []
            for name, param in net.named_parameters():
              grad_list.append(param.grad.data.reshape(-1, 1))
            g = torch.cat(grad_list, 0)
            p = torch.matmul(inv, g)

            start = 0
            for name, param in net.named_parameters():
              end = start + param.data.reshape(-1, 1).size(0)
              param.grad.copy_(p[start:end].reshape(param.grad.data.shape))
              start = end

            optimizer.step()

        ### new optimizer test
        elif optim_name in ['kngd'] :
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            if  optimizer.steps % optimizer.freq == 0:
                # compute true fisher
                optimizer.acc_stats = True
                with torch.no_grad():
                    sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device)
                loss_sample = criterion(outputs, sampled_y)
                loss_sample.backward(retain_graph=True)
                optimizer.acc_stats = False
                optimizer.zero_grad()  # clear the gradient for computing true-fisher.
                if args.partial_backprop == 'true':
                  idx = (sampled_y == targets) == False
                  loss = criterion(outputs[idx,:], targets[idx])
                  # print('extra:', idx.sum().item())
            loss.backward()
            optimizer.step()

        elif optim_name == 'ngd':
            if batch_idx % args.freq == 0:
                store_io_(True)
                inputs, targets = inputs.to(args.device), targets.to(args.device)
                optimizer.zero_grad()
                # net.set_require_grad(True)

                outputs = net(inputs)
                damp = damping
                loss = criterion(outputs, targets)
                loss.backward(retain_graph=True)

                # storing original gradient for later use
                grad_org = []
                # grad_dict = {}
                for name, param in net.named_parameters():
                    grad_org.append(param.grad.reshape(1, -1))
                #     grad_dict[name] = param.grad.clone()
                grad_org = torch.cat(grad_org, 1)

                ###### now we have to compute the true fisher
                with torch.no_grad():
                # gg = torch.nn.functional.softmax(outputs, dim=1)
                    sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device)
                
                if args.trial == 'true':
                    update_list, loss = optimal_JJT_v2(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient, super_opt=args.super_opt)
                else:
                    update_list, loss = optimal_JJT(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient)

                # optimizer.zero_grad()
                # update_list, loss = optimal_JJT_fused(outputs, sampled_y, args.batch_size, damping=damp)

                optimizer.zero_grad()
   
                # last part of SMW formula
                grad_new = []
                for name, param in net.named_parameters():
                    param.grad.copy_(update_list[name])
                    grad_new.append(param.grad.reshape(1, -1))
                grad_new = torch.cat(grad_new, 1)   
                # grad_new = grad_org
                store_io_(False)
            else:
                inputs, targets = inputs.to(args.device), targets.to(args.device)
                optimizer.zero_grad()
                # net.set_require_grad(True)

                outputs = net(inputs)
                damp = damping
                loss = criterion(outputs, targets)
                loss.backward()

                # storing original gradient for later use
                grad_org = []
                # grad_dict = {}
                for name, param in net.named_parameters():
                    grad_org.append(param.grad.reshape(1, -1))
                #     grad_dict[name] = param.grad.clone()
                grad_org = torch.cat(grad_org, 1)

                ###### now we have to compute the true fisher
                # with torch.no_grad():
                # gg = torch.nn.functional.softmax(outputs, dim=1)
                    # sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device)
                all_modules = net.modules()

                for m in net.modules():
                    if hasattr(m, "NGD_inv"):                    
                        grad = m.weight.grad
                        if isinstance(m, nn.Linear):
                            I = m.I
                            G = m.G
                            n = I.shape[0]
                            NGD_inv = m.NGD_inv
                            grad_prod = einsum("ni,oi->no", (I, grad))
                            grad_prod = einsum("no,no->n", (grad_prod, G))
                            v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                            gv = einsum("n,no->no", (v, G))
                            gv = einsum("no,ni->oi", (gv, I))
                            gv = gv / n
                            update = (grad - gv)/damp
                            m.weight.grad.copy_(update)
                        elif isinstance(m, nn.Conv2d):
                            if hasattr(m, "AX"):

                                if args.low_rank.lower() == 'true':
                                    ###### using low rank structure
                                    U = m.U
                                    S = m.S
                                    V = m.V
                                    NGD_inv = m.NGD_inv
                                    n = NGD_inv.shape[0]

                                    grad_reshape = grad.reshape(grad.shape[0], -1)
                                    grad_prod = V @ grad_reshape.t().reshape(-1, 1)
                                    grad_prod = torch.diag(S) @ grad_prod
                                    grad_prod = U @ grad_prod
                                    
                                    grad_prod = grad_prod.squeeze()
                                    v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                    gv = U.t() @ v.unsqueeze(1)
                                    gv = torch.diag(S) @ gv
                                    gv = V.t() @ gv

                                    gv = gv.reshape(grad_reshape.shape[1], grad_reshape.shape[0]).t()
                                    gv = gv.view_as(grad)
                                    gv = gv / n
                                    update = (grad - gv)/damp
                                    m.weight.grad.copy_(update)
                                else:
                                    AX = m.AX
                                    NGD_inv = m.NGD_inv
                                    n = AX.shape[0]

                                    grad_reshape = grad.reshape(grad.shape[0], -1)
                                    grad_prod = einsum("nkm,mk->n", (AX, grad_reshape))
                                    v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                    gv = einsum("nkm,n->mk", (AX, v))
                                    gv = gv.view_as(grad)
                                    gv = gv / n
                                    update = (grad - gv)/damp
                                    m.weight.grad.copy_(update)
                            elif hasattr(m, "I"):
                                I = m.I
                                if args.memory_efficient == 'true':
                                    I = unfold_func(m)(I)
                                G = m.G
                                n = I.shape[0]
                                NGD_inv = m.NGD_inv
                                grad_reshape = grad.reshape(grad.shape[0], -1)
                                x1 = einsum("nkl,mk->nml", (I, grad_reshape))
                                grad_prod = einsum("nml,nml->n", (x1, G))
                                v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                gv = einsum("n,nml->nml", (v, G))
                                gv = einsum("nml,nkl->mk", (gv, I))
                                gv = gv.view_as(grad)
                                gv = gv / n
                                update = (grad - gv)/damp
                                m.weight.grad.copy_(update)
                        elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                            if args.batchnorm == 'true':
                                dw = m.dw
                                n = dw.shape[0]
                                NGD_inv = m.NGD_inv
                                grad_prod = einsum("ni,i->n", (dw, grad))

                                v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                gv = einsum("n,ni->i", (v, dw))
                                
                                gv = gv / n
                                update = (grad - gv)/damp
                                m.weight.grad.copy_(update)
                        
                        

                # last part of SMW formula
                grad_new = []
                for name, param in net.named_parameters():
                    grad_new.append(param.grad.reshape(1, -1))
                grad_new = torch.cat(grad_new, 1)   
                # grad_new = grad_org


            ##### do kl clip
            lr = lr_scheduler.get_last_lr()[0]
            # vg_sum = 0
            # vg_sum += (grad_new * grad_org ).sum()
            # vg_sum = vg_sum * (lr ** 2)
            # nu = min(1.0, math.sqrt(args.kl_clip / vg_sum))
            # for name, param in net.named_parameters():
            #     param.grad.mul_(nu)

            # optimizer.step()
            # manual optimizing:
            with torch.no_grad():
                for name, param in net.named_parameters():
                    d_p = param.grad.data
                    # print('=== step ===')

                    # apply momentum
                    # if args.momentum != 0:
                    #     buf[name].mul_(args.momentum).add_(d_p)
                    #     d_p.copy_(buf[name])

                    # apply weight decay
                    if args.weight_decay != 0:
                        d_p.add_(args.weight_decay, param.data)

                    lr = lr_scheduler.get_last_lr()[0]
                    param.data.add_(-lr, d_p)
                    # print('d_p:', d_p.shape)
                    # print(d_p)



        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (tag, lr_scheduler.get_last_lr()[0], train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
        prog_bar.set_description(desc, refresh=True)
        if args.step_info == 'true' and (batch_idx % 50 == 0 or batch_idx == len(prog_bar) - 1):
            step_saved_time = time.time() - step_st_time
            epoch_time += step_saved_time
            test_acc, test_loss = test(epoch)
            TRAIN_INFO['train_acc'].append(float("{:.4f}".format(100. * correct / total)))
            TRAIN_INFO['test_acc'].append(float("{:.4f}".format(test_acc)))
            TRAIN_INFO['train_loss'].append(float("{:.4f}".format(train_loss/(batch_idx + 1))))
            TRAIN_INFO['test_loss'].append(float("{:.4f}".format(test_loss)))
            TRAIN_INFO['total_time'].append(float("{:.4f}".format(step_saved_time)))
            if args.debug_mem == 'true':
                TRAIN_INFO['memory'].append(torch.cuda.memory_reserved())
            step_st_time = time.time()
            net.train()

    writer.add_scalar('train/loss', train_loss/(batch_idx + 1), epoch)
    writer.add_scalar('train/acc', 100. * correct / total, epoch)
    acc = 100. * correct / total
    train_loss = train_loss/(batch_idx + 1)
    if args.step_info == 'true':
        TRAIN_INFO['epoch_time'].append(float("{:.4f}".format(epoch_time)))
    # save diagonal blocks of exact Fisher inverse or its approximations
    if args.save_inv == 'true':
      all_modules = net.modules()

      count = 0
      start, end = 0, 0
      if optim_name == 'ngd':
        for m in all_modules:
          if m.__class__.__name__ == 'Linear':
            with torch.no_grad():
              I = m.I
              G = m.G
              J = torch.einsum('ni,no->nio', I, G)
              J = J.reshape(J.size(0), -1)
              JTDJ = torch.matmul(J.t(), torch.matmul(m.NGD_inv, J)) / args.batch_size

              with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
                np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy())
                count += 1

          elif m.__class__.__name__ == 'Conv2d':
            with torch.no_grad():
              AX = m.AX
              AX = AX.reshape(AX.size(0), -1)
              JTDJ = torch.matmul(AX.t(), torch.matmul(m.NGD_inv, AX)) / args.batch_size
              with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
                np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy())
                count += 1

      elif optim_name == 'exact_ngd':
        for m in all_modules:
          if m.__class__.__name__ in ['Conv2d', 'Linear']:
            with open('exact/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
              end = start + m.weight.data.reshape(1, -1).size(1)
              np.save(f, inv[start:end,start:end].cpu().numpy())
              start = end + m.bias.data.size(0)
              count += 1

      elif optim_name == 'kfac':
        for m in all_modules:
          if m.__class__.__name__ in ['Conv2d', 'Linear']:
            with open('kfac/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
              G = optimizer.m_gg[m]
              A = optimizer.m_aa[m]

              H_g = torch.linalg.inv(G + math.sqrt(damping) * torch.eye(G.size(0)).to(G.device))
              H_a = torch.linalg.inv(A + math.sqrt(damping) * torch.eye(A.size(0)).to(A.device))

              end = m.weight.data.reshape(1, -1).size(1)
              kfac_inv = torch.kron(H_a, H_g)[:end,:end]
              np.save(f, kfac_inv.cpu().numpy())
              count += 1

    return acc, train_loss
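The KFAC branch above inverts the two damped factors separately and only then forms torch.kron(H_a, H_g); this rests on the identity (A ⊗ G)^{-1} = A^{-1} ⊗ G^{-1}, sketched below with small, arbitrarily chosen SPD matrices.
import torch

A = torch.randn(3, 3, dtype=torch.float64); A = A @ A.T + torch.eye(3, dtype=torch.float64)
G = torch.randn(4, 4, dtype=torch.float64); G = G @ G.T + torch.eye(4, dtype=torch.float64)
lhs = torch.linalg.inv(torch.kron(A, G))
rhs = torch.kron(torch.linalg.inv(A), torch.linalg.inv(G))
assert torch.allclose(lhs, rhs)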