Example 1
    def _save_input(self, m, input):
        if torch.is_grad_enabled() and self.steps % self.TCov == 0:
            if m.__class__.__name__ == "Conv2d":
                # KF-QN-CNN uses a per-batch estimate instead of a running estimate
                a, self.a_avg[m] = self.CovAHandler(input[0].data, m, bfgs=True)
                if m not in self.H_a:
                    batch_size, spatial_size = a.size(0), a.size(1)
                    a_ = a.view(-1, a.size(-1)) / spatial_size
                    cov_a = a_.t() @ (a_ / batch_size)
                    self.H_a[m] = torch.linalg.inv(
                        cov_a + math.sqrt(self.damping)
                        * torch.eye(cov_a.size(0), device=cov_a.device))
                self.s_a[m] = self.H_a[m] @ self.a_avg[m].transpose(0, 1)
                s_a = self.s_a[m].view(self.s_a[m].size(0))
                batch_size = a.size(0)
                self.As[m] = torch.einsum('ntd,d->nt', a, s_a)           # broadcasted dot product
                self.As[m] = torch.einsum('nt,ntd->ntd', self.As[m], a)  # vector scaling
                self.As[m] = torch.einsum('ntd->d', self.As[m])          # sum over batch and spatial dims
                self.As[m] = self.As[m].unsqueeze(1) / batch_size

            elif m.__class__.__name__ == "Linear":
                aa, self.a_avg[m] = self.CovAHandler(input[0].data, m, bfgs=True)
                # initialize buffer
                if self.steps == 0:
                    self.m_aa[m] = torch.diag(aa.new(aa.size(0)).fill_(1))
                # KF-QN-FC uses a running estimate
                update_running_stat(aa, self.m_aa[m], self.stat_decay)
                # initialize the inverse buffer
                if m not in self.H_a:
                    self.H_a[m] = torch.linalg.inv(
                        self.m_aa[m] + math.sqrt(self.damping)
                        * torch.eye(self.m_aa[m].size(0), device=self.m_aa[m].device))
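
The helper update_running_stat appears throughout these snippets but is not shown. A minimal sketch, assuming it maintains an in-place exponential moving average with decay factor stat_decay (in-place ops preserve the buffer identity, so callers holding a reference such as self.m_aa[m] see the update):

    def update_running_stat(new_stat, running_stat, stat_decay):
        # In-place EMA: running <- decay * running + (1 - decay) * new
        running_stat.mul_(stat_decay).add_(new_stat, alpha=1 - stat_decay)
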
Example 2
 def _save_input(self, module, input):
     if torch.is_grad_enabled() and self.steps % self.TCov == 0:
         aa = self.CovAHandler(input[0].data, module)
         # Initialize buffers
         if self.steps == 0:
             self.m_aa[module] = torch.diag(aa.new(aa.size(0)).fill_(1))
         update_running_stat(aa, self.m_aa[module], self.stat_decay)
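
These methods are written as PyTorch hooks. A sketch of how they would typically be attached; the _prepare_model name and the supported-layer filter are assumptions, but register_forward_pre_hook is the standard API whose (module, input) callback signature matches _save_input above:

    def _prepare_model(self, model):
        for module in model.modules():
            if module.__class__.__name__ in ("Linear", "Conv2d"):
                self.modules.append(module)
                module.register_forward_pre_hook(self._save_input)
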
Example 3
    def _update_scale(self, m):
        with torch.no_grad():
            A, S = self.A[m], self.DS[m]
            grad_mat = self.MatGradHandler(A, S, m)  # (batch_size, out_dim, in_dim)
            if self.batch_averaged:
                grad_mat *= S.size(0)

            s_l = (self.Q_g[m] @ grad_mat
                   @ self.Q_a[m].t())**2  # <- this consumes too much memory!
            s_l = s_l.mean(dim=0)
            if self.steps == 0:
                self.S_l[m] = s_l.new(s_l.size()).fill_(1)
            # s_ls = self.Q_g[m] @ grad_s
            # s_la = in_a @ self.Q_a[m].t()
            # s_l = 0
            # for i in range(0, s_ls.size(0), S.size(0)):    # tradeoff between time and memory
            #     start = i
            #     end = min(s_ls.size(0), i + S.size(0))
            #     s_l += (torch.bmm(s_ls[start:end,:], s_la[start:end,:]) ** 2).sum(0)
            # s_l /= s_ls.size(0)
            # if self.steps == 0:
            #     self.S_l[m] = s_l.new(s_l.size()).fill_(1)
            update_running_stat(s_l, self.S_l[m], self.stat_decay)
            # Drop references to reduce memory cost.
            self.A[m] = None
            self.DS[m] = None
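
The commented-out block above keeps a chunked variant that trades time for peak memory. A standalone sketch of the same idea, assuming grad_mat has shape (batch, out_dim, in_dim) and Q_g, Q_a are the Kronecker eigenbasis factors:

    import torch

    def squared_projection_mean(Q_g, grad_mat, Q_a, chunk=32):
        # Equivalent to ((Q_g @ grad_mat @ Q_a.t()) ** 2).mean(dim=0),
        # but only `chunk` transformed gradients are materialized at once.
        n = grad_mat.size(0)
        s_l = torch.zeros(grad_mat.size(1), grad_mat.size(2),
                          device=grad_mat.device, dtype=grad_mat.dtype)
        for start in range(0, n, chunk):
            block = grad_mat[start:start + chunk]
            s_l += ((Q_g @ block @ Q_a.t()) ** 2).sum(dim=0)
        return s_l / n
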
Example 4
 def _save_grad_output(self, module, grad_input, grad_output):
     # Accumulate statistics for Fisher matrices
     if self.acc_stats and self.steps % self.TCov == 0:
         gg, _ = self.CovGHandler(grad_output[0], module, self.batch_averaged)
         # Initialize buffers
         if self.steps == 0:
             self.m_gg[module] = torch.zeros_like(gg)
         update_running_stat(gg, self.m_gg[module], self.stat_decay)
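
_save_grad_output follows the (module, grad_input, grad_output) signature that PyTorch passes to backward hooks. The assumed registration, one line per tracked module:

    # register_full_backward_hook is the current API; older code used the
    # deprecated module.register_backward_hook with the same call signature.
    module.register_full_backward_hook(self._save_grad_output)
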
Example 5
 def _save_input(self, module, input):
     if torch.is_grad_enabled() and self.steps % self.TCov == 0:
         with torch.no_grad():
             aa, _ = self.CovAHandler(input[0], module)
         # Initialize buffers
         if self.steps == 0:
             self.m_aa[module] = torch.zeros_like(aa)
         update_running_stat(aa, self.m_aa[module], self.stat_decay)
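
CovAHandler itself is not defined in these snippets. For a Linear layer it plausibly computes the batch-averaged activation covariance; a hedged sketch (folding the bias in via a homogeneous coordinate is an assumption):

    import torch

    def cov_a_linear(a, module):
        # a: (batch, in_dim) input activations.
        batch_size = a.size(0)
        if module.bias is not None:
            # Append a column of ones so the bias is absorbed into the factor.
            a = torch.cat([a, a.new_ones(batch_size, 1)], dim=1)
        return a.t() @ (a / batch_size)
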
Example 6
 def _save_grad_output(self, module, grad_input, grad_output):
     # Accumulate statistics for Fisher matrices
     if self.acc_stats and self.steps % self.TCov == 0:
         gg = self.CovGHandler(grad_output[0].data, module, self.batch_averaged)
         # Initialize buffers
         if self.steps == 0:
             self.m_gg[module] = torch.diag(gg.new(gg.size(0)).fill_(1))
         update_running_stat(gg, self.m_gg[module], self.stat_decay)
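
Likewise, a plausible CovGHandler for a Linear layer, following the batch_averaged convention visible above (a sketch, not the repository's definition):

    def cov_g_linear(g, module, batch_averaged):
        # g: (batch, out_dim) pre-activation gradients. If the loss was
        # averaged over the batch, each row carries a 1/batch factor, so
        # rescale to recover per-example statistics.
        batch_size = g.size(0)
        if batch_averaged:
            return g.t() @ (g * batch_size)
        return g.t() @ (g / batch_size)
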
Example 7
 def _save_input(self, module, input):
     if torch.is_grad_enabled() and self.steps % self.TCov == 0:
         # Get module index
         i = self.modules.index(module)
         with torch.no_grad():
             aa, self.a[i] = self.CovAHandler(input[0], module)
         # Initialize buffer
         if self.steps == 0:
             self.m_aa[module] = torch.zeros_like(aa)
         update_running_stat(aa, self.m_aa[module], self.stat_decay)
         self.sum_aa[i][i] = self.m_aa[module].sum()
         # Update sums of off-diagonal blocks of A
         for j in range(i):
             # Compute inter-layer covariances (downsample if needed)
             new_aa = self._downsample_multiply(self.a, i, j, input[0].shape[0], self.mode)
             # Update sum
             self.sum_aa[i][j] *= self.stat_decay
             self.sum_aa[i][j] += (1 - self.stat_decay) * new_aa.sum()
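
_downsample_multiply is not shown either. A purely illustrative sketch, assuming it forms the cross-layer covariance between the stored activations of layers i and j, pooling the wider feature dimension so tensors from different layers can be multiplied; every detail here is a guess at the helper's intent:

    import torch.nn.functional as F

    def _downsample_multiply(self, feats, i, j, batch_size, mode):
        # Hypothetical: feats[k] holds (batch, d_k) activations.
        a_i, a_j = feats[i], feats[j]
        d = min(a_i.size(1), a_j.size(1))
        if mode == "avg":
            a_i = F.adaptive_avg_pool1d(a_i.unsqueeze(1), d).squeeze(1)
            a_j = F.adaptive_avg_pool1d(a_j.unsqueeze(1), d).squeeze(1)
        return a_i.t() @ a_j / batch_size
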
Example 8
 def _save_grad_output(self, module, grad_input, grad_output):
     # Accumulate statistics for Fisher matrices
     if self.acc_stats and self.steps % self.TCov == 0:
         # Get module index
         i = self.modules.index(module)
         gg, self.g[i] = self.CovGHandler(grad_output[0], module, self.batch_averaged)
         # Initialize buffers
         if self.steps == 0:
             self.m_gg[module] = torch.zeros_like(gg)
         update_running_stat(gg, self.m_gg[module], self.stat_decay)
         self.sum_gg[i][i] = self.m_gg[module].sum()
         # Update sums of off-diagonal blocks of G
         for j in range(i + 1, self.nlayers):  # strictly off-diagonal; sum_gg[i][i] is set above
             # Compute inter-layer covariances (downsample if needed)
             new_gg = self._downsample_multiply(self.g, i, j, grad_output[0].shape[0], self.mode)
             # Update sum
             self.sum_gg[i][j] *= self.stat_decay
             self.sum_gg[i][j] += (1 - self.stat_decay) * new_gg.sum()
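
Putting the hooks together: the forward pre-hooks and backward hooks above both gate on self.steps % self.TCov == 0, and the backward hooks additionally on self.acc_stats, so a typical driver loop would look like this (the KFACOptimizer name is assumed; any optimizer class owning these hooks fits):

    optimizer = KFACOptimizer(model)      # hypothetical wrapper owning the hooks
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        optimizer.acc_stats = True        # let _save_grad_output collect stats
        loss.backward()
        optimizer.acc_stats = False
        optimizer.step()                  # advances self.steps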