def _update_module_G(self, module):
    """Compute factor G for `module` and update its running average,
    sparsifying it when sparse communication is enabled."""
    g = self.computeG(self.m_g[module], module, self.batch_averaged)
    #logger.info('G Name: %s, shape: %s', module, g.shape)
    if self.steps == 0:
        self._init_G(g, module)
    update_running_avg(g, self.m_G[module], self.factor_decay)
    if self.sparse:
        sparsification(self.m_G[module], module,
                       ratio=self.sparse_ratio,
                       residuals=self.residualsG)
def _update_A(self):
    """Compute and update factor A for all modules"""
    for module in self.modules:
        a = self.computeA(self.m_a[module], module)
        if self.steps == 0:
            self._init_A(a, module)
        update_running_avg(a, self.m_A[module], self.factor_decay)
        if self.sparse:
            sparsification(self.m_A[module], module,
                           ratio=self.sparse_ratio,
                           residuals=self.residualsA)
def _update_module_A(self, module):
    """Compute factor A for `module` and update its running average,
    sparsifying it when sparse communication is enabled."""
    a = self.computeA(self.m_a[module], module)
    if self.steps == 0:
        self._init_A(a, module)
    update_running_avg(a, self.m_A[module], self.factor_decay)
    if self.sparse:
        #self.m_sparseA[module] = sparsification_randk(self.m_A[module], module, ratio=self.sparse_ratio, residuals=self.residualsA)
        #self.m_sparseA[module] = sparsification(self.m_A[module], module, ratio=self.sparse_ratio, residuals=self.residualsA)
        self.m_sparseA[module] = fake_sparsification(
                self.m_A[module], module,
                ratio=self.sparse_ratio,
                residuals=self.residualsA)
def _update_module_G(self, module):
    """Compute factor G for `module` and update its running average (dense variant)."""
    G = self.computeG(self.m_g[module], module, self.batch_averaged)
    if self.steps == 0:
        self._init_G(G, module)
    update_running_avg(G, self.m_G[module], self.factor_decay)
def _update_module_A(self, module):
    """Compute factor A for `module` and update its running average (dense variant)."""
    A = self.computeA(self.m_a[module], module)
    if self.steps == 0:
        self._init_A(A, module)
    update_running_avg(A, self.m_A[module], self.factor_decay)
def step(self, closure=None, epoch=None):
    """Perform one K-FAC step

    Note:
    - this function should always be called before `optimizer.step()`
    - gradients must be averaged across ranks before calling `step()`

    Args:
        closure: for compatibility with the base optimizer class.
                 `closure` is ignored by KFAC
        epoch (int, optional): epoch to use for determining when to end
                 the `diag_warmup` period. `epoch` is not necessary if
                 not using `diag_warmup`
    """
    # Update params, used for compatibility with `KFACParamScheduler`
    group = self.param_groups[0]
    self.lr = group['lr']
    self.damping = group['damping']
    self.fac_update_freq = group['fac_update_freq']
    self.kfac_update_freq = group['kfac_update_freq']
    #print('fac_update_freq: ', self.fac_update_freq)
    #print('kfac_update_freq: ', self.kfac_update_freq)

    updates = {}
    handles = []

    if epoch is None:
        if self.diag_warmup > 0:
            print("WARNING: diag_warmup > 0 but epoch was not passed to "
                  "KFAC.step(). Defaulting to no diag_warmup")
        diag_blocks = self.diag_blocks
    else:
        diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

    if hvd.size() > 1 and self.steps % self.fac_update_freq == 0:
        self.fw_merged_comm.synchronize()
        self.bw_merged_comm.synchronize()

        # Compute A and G after aggregation of a and g
        for module in self.modules:
            a = self.m_a[module]
            g = self.m_g[module]
            if hvd.rank() == 0:
                logger.info('a Name: %s, shape %s', module, a.shape)
                logger.info('g Name: %s, shape %s', module, g.shape)
            # Kronecker factors: A = a^T a / batch_size and G = g^T g / batch_size,
            # where the einsum index k runs over the samples in the batch
            A = torch.einsum('ki,kj->ij', a, a / a.size(0))
            G = torch.einsum('ki,kj->ij', g, g / g.size(0))
            update_running_avg(A, self.m_A[module], self.factor_decay)
            update_running_avg(G, self.m_G[module], self.factor_decay)

    # if we are switching from no diag approx to approx, we need to clear
    # off-block-diagonal elements
    if not self.have_cleared_Q and \
            epoch == self.diag_warmup and \
            self.steps % self.kfac_update_freq == 0:
        self._clear_eigen()
        self.have_cleared_Q = True

    if self.steps % self.kfac_update_freq == 0:
        # reset rank_iter so that each device gets the same layers
        # to compute, to take advantage of caching
        self.rank_iter.reset()
        handles = []

        #eigen_ranks = self._generate_eigen_ranks(epoch)
        eigen_ranks = self._generate_eigen_ranks_uniform(epoch)
        #eigen_ranks = self._generate_eigen_ranks_naive(epoch)
        #inverse_As = []
        #A_ranks = []
        #inverse_Gs = []
        #G_ranks = []
        rank_to_tensors = {}

        for module in self.modules:
            ranks_a, ranks_g = eigen_ranks[module]
            self.m_dA_ranks[module] = ranks_a[0]
            self.m_dG_ranks[module] = ranks_g[0]
            rank_a = ranks_a[0]
            rank_g = ranks_g[0]
            name = self.module_name_map[module]

            self._update_inverse_A(module, ranks_a)
            #if hvd.size() > 1 and rank_a >= 0:
            #    self.inverseA_merged_comm.bcast_async_(name, self.m_QA[module], rank_a)

            self._update_inverse_G(module, ranks_g)
            #if hvd.size() > 1 and rank_g >= 0:
            #    self.inverseG_merged_comm.bcast_async_(name, self.m_QG[module], rank_g)

            #if rank_a not in rank_to_tensors:
            #    rank_to_tensors[rank_a] = []
            #rank_to_tensors[rank_a].append((name, self.m_QA[module], self.m_QG[module]))

            # Broadcast both inverse factors of this module from the rank that
            # computed them
            if hvd.size() > 1 and rank_g >= 0:
                self.multi_comm.bcast_async_(
                        [name],
                        [self.m_QA[module], self.m_QG[module]],
                        rank_g)

        #if hvd.size() > 1:
        #    for rank in rank_to_tensors.keys():
        #        names = []
        #        tensors = []
        #        for name, ta, tb in rank_to_tensors[rank]:
        #            names.append(name)
        #            tensors.append(ta)
        #            tensors.append(tb)
        #        self.multi_comm.bcast_async_(names, tensors, rank)

    if hvd.size() > 1 and self.steps % self.kfac_update_freq == 0:
        #self.inverseA_merged_comm.synchronize()
        #self.inverseG_merged_comm.synchronize()
        self.multi_comm.synchronize()

    for i, module in enumerate(self.modules):
        grad = self._get_grad(module)
        precon_grad = self._get_preconditioned_grad(module, grad)
        updates[module] = precon_grad

    self._update_scale_grad(updates)

    self.steps += 1
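
# Illustrative usage sketch (assumptions, not part of this module): the class is
# assumed to be exposed as `KFAC`, the constructor keywords mirror the entries
# read from `param_groups[0]` in `step()` above, and `model`, `criterion`, and
# `train_loader` are placeholders. Gradients must already be averaged across
# ranks when `preconditioner.step()` runs, which is why the Horovod optimizer
# is synchronized first.
#
#   import horovod.torch as hvd
#   import torch.optim as optim
#
#   hvd.init()
#   optimizer = optim.SGD(model.parameters(), lr=0.1)
#   optimizer = hvd.DistributedOptimizer(
#       optimizer, named_parameters=model.named_parameters())
#   preconditioner = KFAC(model, lr=0.1, damping=0.003,
#                         fac_update_freq=10, kfac_update_freq=100)
#
#   for inputs, targets in train_loader:
#       optimizer.zero_grad()
#       loss = criterion(model(inputs), targets)
#       loss.backward()
#       optimizer.synchronize()            # average gradients across ranks
#       preconditioner.step()              # precondition gradients in place
#       with optimizer.skip_synchronize():
#           optimizer.step()               # apply the preconditioned update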