def _save_input(self, m, input):
    if torch.is_grad_enabled() and self.steps % self.TCov == 0:
        if m.__class__.__name__ == "Conv2d":
            # KF-QN-CNN uses an estimate over the batch instead of a running estimate
            a, self.a_avg[m] = self.CovAHandler(input[0].data, m, bfgs=True)
            if m not in self.H_a:
                batch_size, spatial_size = a.size(0), a.size(1)
                a_ = a.view(-1, a.size(-1)) / spatial_size
                cov_a = a_.t() @ (a_ / batch_size)
                self.H_a[m] = torch.linalg.inv(
                    cov_a + math.sqrt(self.damping) * torch.eye(cov_a.size(0)).to(cov_a.device))
            self.s_a[m] = self.H_a[m] @ self.a_avg[m].transpose(0, 1)
            s_a = self.s_a[m].view(self.s_a[m].size(0))
            batch_size, spatial_size = a.size(0), a.size(1)
            self.As[m] = torch.einsum('ntd,d->nt', a, s_a)            # broadcasted dot product
            self.As[m] = torch.einsum('nt,ntd->ntd', self.As[m], a)   # vector scaling
            self.As[m] = torch.einsum('ntd->d', self.As[m])           # sum over batch and spatial dims
            self.As[m] = self.As[m].unsqueeze(1) / batch_size
        elif m.__class__.__name__ == "Linear":
            aa, self.a_avg[m] = self.CovAHandler(input[0].data, m, bfgs=True)
            # Initialize buffer
            if self.steps == 0:
                self.m_aa[m] = torch.diag(aa.new(aa.size(0)).fill_(1))
            # KF-QN-FC uses a running estimate
            update_running_stat(aa, self.m_aa[m], self.stat_decay)
            # Initialize buffer
            if m not in self.H_a:
                self.H_a[m] = torch.linalg.inv(
                    self.m_aa[m] + math.sqrt(self.damping) * torch.eye(self.m_aa[m].size(0)).to(self.m_aa[m].device))

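# Every hook in this section pushes its fresh statistics through
# update_running_stat. The helper is not shown here; the following is a
# minimal sketch, assuming it keeps an in-place exponential moving average
# with decay stat_decay, as in common K-FAC implementations.
def update_running_stat(aa, m_aa, stat_decay):
    # In-place update: m_aa <- stat_decay * m_aa + (1 - stat_decay) * aa.
    # Rescaling before adding avoids allocating a temporary tensor.
    m_aa *= stat_decay / (1 - stat_decay)
    m_aa += aa
    m_aa *= (1 - stat_decay)
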
def _save_input(self, module, input):
    if torch.is_grad_enabled() and self.steps % self.TCov == 0:
        aa = self.CovAHandler(input[0].data, module)
        # Initialize buffers
        if self.steps == 0:
            self.m_aa[module] = torch.diag(aa.new(aa.size(0)).fill_(1))
        update_running_stat(aa, self.m_aa[module], self.stat_decay)

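# None of these hooks fire on their own; they must be registered on the
# supported modules first. A minimal registration sketch (the _prepare_model
# name and the set of supported layers are assumptions, not shown in this
# section); register_forward_pre_hook and register_full_backward_hook are
# standard PyTorch APIs whose signatures match _save_input and
# _save_grad_output above.
def _prepare_model(self):
    for module in self.model.modules():
        if module.__class__.__name__ in ("Linear", "Conv2d"):
            self.modules.append(module)
            module.register_forward_pre_hook(self._save_input)
            module.register_full_backward_hook(self._save_grad_output)
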
def _update_scale(self, m):
    with torch.no_grad():
        A, S = self.A[m], self.DS[m]
        grad_mat = self.MatGradHandler(A, S, m)  # batch_size x out_dim x in_dim
        if self.batch_averaged:
            grad_mat *= S.size(0)
        s_l = (self.Q_g[m] @ grad_mat @ self.Q_a[m].t()) ** 2  # <- this consumes too much memory!
        s_l = s_l.mean(dim=0)
        if self.steps == 0:
            self.S_l[m] = s_l.new(s_l.size()).fill_(1)
        # Chunked alternative (trade-off between time and memory):
        # s_ls = self.Q_g[m] @ grad_s
        # s_la = in_a @ self.Q_a[m].t()
        # s_l = 0
        # for i in range(0, s_ls.size(0), S.size(0)):
        #     start = i
        #     end = min(s_ls.size(0), i + S.size(0))
        #     s_l += (torch.bmm(s_ls[start:end, :], s_la[start:end, :]) ** 2).sum(0)
        # s_l /= s_ls.size(0)
        # if self.steps == 0:
        #     self.S_l[m] = s_l.new(s_l.size()).fill_(1)
        update_running_stat(s_l, self.S_l[m], self.stat_decay)
        # Drop references to reduce memory cost
        self.A[m] = None
        self.DS[m] = None

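# The flagged line above materializes every rotated per-sample gradient at
# once. A small self-contained illustration of the shapes involved, using
# random stand-ins for the eigenbases Q_g, Q_a and the per-sample gradient
# matrices (all names here are illustrative):
import torch

N, out_dim, in_dim = 8, 4, 3                # batch size and layer dimensions
grad_mat = torch.randn(N, out_dim, in_dim)  # per-sample gradient matrices
Q_g = torch.randn(out_dim, out_dim)         # stand-in for the G-factor eigenbasis
Q_a = torch.randn(in_dim, in_dim)           # stand-in for the A-factor eigenbasis

# The matmul broadcasts over the batch, so the intermediate tensor holds
# N * out_dim * in_dim entries -- the memory cost flagged in the comment above.
s_l = (Q_g @ grad_mat @ Q_a.t()) ** 2       # N x out_dim x in_dim
s_l = s_l.mean(dim=0)                       # second moment per eigen-direction
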
def _save_grad_output(self, module, grad_input, grad_output):
    # Accumulate statistics for Fisher matrices
    if self.acc_stats and self.steps % self.TCov == 0:
        gg, _ = self.CovGHandler(grad_output[0], module, self.batch_averaged)
        # Initialize buffers
        if self.steps == 0:
            self.m_gg[module] = torch.zeros_like(gg)
        update_running_stat(gg, self.m_gg[module], self.stat_decay)

def _save_input(self, module, input):
    if torch.is_grad_enabled() and self.steps % self.TCov == 0:
        with torch.no_grad():
            aa, _ = self.CovAHandler(input[0], module)
            # Initialize buffers
            if self.steps == 0:
                self.m_aa[module] = torch.zeros_like(aa)
            update_running_stat(aa, self.m_aa[module], self.stat_decay)

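# CovAHandler is injected elsewhere; a minimal sketch of its Linear branch,
# assuming the usual K-FAC activation covariance with a homogeneous coordinate
# appended when the layer has a bias (compute_cov_a_linear is an illustrative
# name, not the source's):
import torch

def compute_cov_a_linear(a, layer):
    batch_size = a.size(0)
    if layer.bias is not None:
        # Fold the bias into the factor by appending a column of ones
        a = torch.cat([a, a.new_ones(batch_size, 1)], dim=1)
    return a.t() @ (a / batch_size)  # (in_dim[+1]) x (in_dim[+1]) covariance
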
def _save_grad_output(self, module, grad_input, grad_output):
    # Accumulate statistics for Fisher matrices
    if self.acc_stats and self.steps % self.TCov == 0:
        gg = self.CovGHandler(grad_output[0].data, module, self.batch_averaged)
        # Initialize buffers
        if self.steps == 0:
            self.m_gg[module] = torch.diag(gg.new(gg.size(0)).fill_(1))
        update_running_stat(gg, self.m_gg[module], self.stat_decay)

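# Likewise, a sketch of what the Linear branch of CovGHandler could compute.
# batch_averaged is taken to mean the loss was averaged over the batch, so the
# per-sample gradients carry a 1/N factor that must be undone before forming
# the covariance (compute_cov_g_linear is an illustrative name):
def compute_cov_g_linear(g, layer, batch_averaged):
    batch_size = g.size(0)
    if batch_averaged:
        return g.t() @ (g * batch_size)  # cancel the 1/N from the averaged loss
    return g.t() @ (g / batch_size)      # plain batch average of g g^T
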
def _save_input(self, module, input):
    if torch.is_grad_enabled() and self.steps % self.TCov == 0:
        # Get module index
        i = self.modules.index(module)
        with torch.no_grad():
            aa, self.a[i] = self.CovAHandler(input[0], module)
            # Initialize buffer
            if self.steps == 0:
                self.m_aa[module] = torch.zeros_like(aa)
            update_running_stat(aa, self.m_aa[module], self.stat_decay)
            self.sum_aa[i][i] = self.m_aa[module].sum()
            # Update sums of off-diagonal blocks of A
            for j in range(i):
                # Compute inter-layer covariances (downsample if needed)
                new_aa = self._downsample_multiply(self.a, i, j, input[0].shape[0], self.mode)
                # Update sum
                self.sum_aa[i][j] *= self.stat_decay
                self.sum_aa[i][j] += (1 - self.stat_decay) * new_aa.sum()

def _save_grad_output(self, module, grad_input, grad_output):
    # Accumulate statistics for Fisher matrices
    if self.acc_stats and self.steps % self.TCov == 0:
        # Get module index
        i = self.modules.index(module)
        gg, self.g[i] = self.CovGHandler(grad_output[0], module, self.batch_averaged)
        # Initialize buffers
        if self.steps == 0:
            self.m_gg[module] = torch.zeros_like(gg)
        update_running_stat(gg, self.m_gg[module], self.stat_decay)
        self.sum_gg[i][i] = self.m_gg[module].sum()
        # Update sums of off-diagonal blocks of G
        for j in range(i, self.nlayers):
            # Compute inter-layer covariances (downsample if needed)
            new_gg = self._downsample_multiply(self.g, i, j, grad_output[0].shape[0], self.mode)
            # Update sum
            self.sum_gg[i][j] *= self.stat_decay
            self.sum_gg[i][j] += (1 - self.stat_decay) * new_gg.sum()

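# _downsample_multiply is referenced above but not defined in this section.
# A hypothetical sketch of the inter-layer covariance it could compute: align
# the two layers' feature widths (here with F.interpolate, driven by the
# configured mode), then form the cross-covariance block. Treat every detail
# below as an assumption about the missing helper.
import torch.nn.functional as F

def _downsample_multiply(self, feats, i, j, batch_size, mode):
    x, y = feats[i], feats[j]  # each: batch_size x feature_dim
    if x.size(1) > y.size(1):
        # interpolate expects a channel dim: (N, 1, L) -> (N, 1, L')
        x = F.interpolate(x.unsqueeze(1), size=y.size(1), mode=mode).squeeze(1)
    elif y.size(1) > x.size(1):
        y = F.interpolate(y.unsqueeze(1), size=x.size(1), mode=mode).squeeze(1)
    return x.t() @ (y / batch_size)  # cross-covariance block between layers i and j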