def step(self, closure=None, epoch=None):
    """Perform one K-FAC step (per-layer inverse-factor broadcast variant).

    Synchronizes the factor reductions issued by the forward/backward
    hooks, computes the inverse factors and broadcasts each layer's
    inverse from its owner rank, preconditions the gradients, and — when
    ``dynamic_merge`` is enabled — re-tunes the communication fusion
    groups from layer times profiled between steps 5 and 25.

    Note:
        - this function should always be called before `optimizer.step()`
        - gradients must be averaged across ranks before calling `step()`

    Args:
        closure: for compatibility with the base optimizer class.
            `closure` is ignored by KFAC.
        epoch (int, optional): epoch to use for determining when to end
            the `diag_warmup` period. `epoch` is not necessary if not
            using `diag_warmup`.
    """
    # Update params, used for compatibilty with `KFACParamScheduler`
    group = self.param_groups[0]
    self.lr = group['lr']
    self.damping = group['damping']
    self.fac_update_freq = group['fac_update_freq']
    self.kfac_update_freq = group['kfac_update_freq']

    updates = {}

    # NOTE(review): diag_blocks is computed for scheduler compatibility but
    # is currently unused in this method body.
    if epoch is None:
        if self.diag_warmup > 0:
            print("WARNING: diag_warmup > 0 but epoch was not passed to "
                  "KFAC.step(). Defaulting to no diag_warmup")
        diag_blocks = self.diag_blocks
    else:
        diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

    # Wait for the asynchronous (merged) factor communication launched by
    # the forward/backward hooks before consuming the factors.
    if hvd.size() > 1 and self.steps % self.fac_update_freq == 0:
        self.fw_merged_comm.synchronize()
        self.bw_merged_comm.synchronize()

    # if we are switching from no diag approx to approx, we need to clear
    # off-block-diagonal elements
    if not self.have_cleared_Q and \
            epoch == self.diag_warmup and \
            self.steps % self.kfac_update_freq == 0:
        self._clear_eigen()
        self.have_cleared_Q = True

    torch.cuda.synchronize()

    if self.steps % self.kfac_update_freq == 0:
        # reset rank iter so devices get the same layers
        # to compute to take advantage of caching
        self.rank_iter.reset()

        eigen_ranks = self._generate_eigen_ranks_naive(epoch)

        for module in self.modules:
            ranks_a, ranks_g = eigen_ranks[module]
            self.m_dA_ranks[module] = ranks_a[0]
            self.m_dG_ranks[module] = ranks_g[0]
            rank_a = ranks_a[0]
            rank_g = ranks_g[0]
            name = self.module_name_map[module]

            # Compute the inverse of A locally, then broadcast the result
            # asynchronously from the rank that owns this layer.
            if not self.exclude_compute_inverse:
                self._update_inverse_A(module, ranks_a)
            if not self.exclude_communicate_inverse:
                if hvd.size() > 1 and rank_a >= 0:
                    self.multi_comm.bcast_async_([name + 'mQA'],
                                                 [self.m_QA[module]],
                                                 rank_a)

            # Same for G.
            if not self.exclude_compute_inverse:
                self._update_inverse_G(module, ranks_g)
            if not self.exclude_communicate_inverse:
                if hvd.size() > 1 and rank_g >= 0:
                    self.multi_comm.bcast_async_([name + 'mQG'],
                                                 [self.m_QG[module]],
                                                 rank_g)

        if self.exclude_communicate_inverse and not self.exclude_compute_inverse:
            # Keep ranks in lock-step when computing without communicating
            # (benchmark/ablation mode).
            if hvd.size() > 1:
                barrier()

    # Wait for all outstanding inverse broadcasts before preconditioning.
    if hvd.size() > 1 and self.steps % self.kfac_update_freq == 0:
        self.multi_comm.synchronize()

    for module in self.modules:
        grad = self._get_grad(module)
        updates[module] = self._get_preconditioned_grad(module, grad)
    self._update_scale_grad(updates)

    # Dynamic merging: enable layer-time profiling at step 5, and at step
    # 25 rebuild the communication fusion groups from the measurements.
    if self.dynamic_merge and hvd.size() > 1 and \
            self.steps % self.kfac_update_freq == 0:
        if self.steps == 5:
            self.profiling = True
        elif self.steps == 25:
            fw_layerwise_times = torch.tensor(self.fw_profiler.get_results())
            bw_layerwise_times = torch.tensor(self.bw_profiler.get_results())
            # Use rank 0's timings on every rank so that all ranks derive
            # identical fusion groups.
            hvd.broadcast_(fw_layerwise_times, root_rank=0)
            hvd.broadcast_(bw_layerwise_times, root_rank=0)
            fw_layerwise_times = fw_layerwise_times.numpy()
            bw_layerwise_times = bw_layerwise_times.numpy()
            fw_factor_sizes = [
                self.fw_factor_sizes[m] for m in self.module_names
            ]
            bw_factor_sizes = [
                self.bw_factor_sizes[m] for m in self.module_names[::-1]
            ]
            self.fw_merged_comm.update_groups(fw_factor_sizes,
                                              fw_layerwise_times,
                                              reverse=False)
            self.bw_merged_comm.update_groups(bw_factor_sizes,
                                              bw_layerwise_times,
                                              reverse=True)
            self.profiling = False

    self.steps += 1
def step(self, closure=None, epoch=None):
    """Perform one K-FAC step (eigendecomposition-broadcast variant).

    On the first invocation, computes the factors, generates the
    compute-rank assignment, and reduces the factors across ranks; on
    later invocations it synchronizes the hook-issued factor
    communication instead. Then computes the eigendecompositions on the
    assigned ranks, broadcasts them, and preconditions the gradients.

    Note:
        - this function should always be called before `optimizer.step()`
        - gradients must be averaged across ranks before calling `step()`

    Args:
        closure: for compatibility with the base optimizer class.
            `closure` is ignored by KFAC.
        epoch (int, optional): epoch to use for determining when to end
            the `diag_warmup` period. `epoch` is not necessary if not
            using `diag_warmup`.
    """
    # Update params, used for compatibilty with `KFACParamScheduler`
    group = self.param_groups[0]
    self.lr = group['lr']
    self.damping = group['damping']
    self.fac_update_freq = group['fac_update_freq']
    self.kfac_update_freq = group['kfac_update_freq']

    updates = {}

    # NOTE(review): diag_blocks is computed for scheduler compatibility but
    # is currently unused in this method body.
    if epoch is None:
        if self.diag_warmup > 0:
            print("WARNING: diag_warmup > 0 but epoch was not passed to "
                  "KFAC.step(). Defaulting to no diag_warmup")
        diag_blocks = self.diag_blocks
    else:
        diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

    if self.steps % self.fac_update_freq == 0:
        if self.eigen_ranks is None:
            # First factor update: build factors and the rank assignment,
            # then reduce the factors across all ranks.
            if not self.exclude_compute_factor:
                self._update_A()
                self._update_G()
            self.eigen_ranks = self._generate_eigen_ranks_match_merging(epoch)
            if not self.exclude_communicate_factor:
                if hvd.size() > 1:
                    self._reduce_factors(self.eigen_ranks)
        else:
            # Subsequent updates: factors were reduced asynchronously by the
            # forward/backward hooks; wait for those reductions to finish.
            if not self.exclude_communicate_factor:
                if hvd.size() > 1:
                    self.fw_merged_comm.synchronize()
                    self.bw_merged_comm.synchronize()
    eigen_ranks = self.eigen_ranks

    # if we are switching from no diag approx to approx, we need to clear
    # off-block-diagonal elements
    if not self.have_cleared_Q and \
            epoch == self.diag_warmup and \
            self.steps % self.kfac_update_freq == 0:
        self._clear_eigen()
        self.have_cleared_Q = True

    if self.steps % self.kfac_update_freq == 0:
        # reset rank iter so devices get the same layers
        # to compute to take advantage of caching
        self.rank_iter.reset()

        for module in self.modules:
            ranks_a, ranks_g = eigen_ranks[module]
            self.m_dA_ranks[module] = ranks_a[0]
            self.m_dG_ranks[module] = ranks_g[0]
            rank_a = ranks_a[0]
            rank_g = ranks_g[0]
            if not self.exclude_compute_inverse:
                self._update_eigen_A(module, ranks_a)
                self._update_eigen_G(module, ranks_g)

        if not self.exclude_communicate_inverse:
            if hvd.size() > 1:
                self._broadcast_eigendecomp()
        elif not self.exclude_compute_inverse:
            # Should have a barrier to keep ranks in lock-step when
            # computing without communicating.
            if hvd.size() > 1:
                barrier()

    for module in self.modules:
        grad = self._get_grad(module)
        updates[module] = self._get_preconditioned_grad(module, grad)
    self._update_scale_grad(updates)

    self.steps += 1
def step(self, closure=None, epoch=None):
    """Perform one K-FAC step (block-partition tensor-fusion variant).

    On the first invocation, computes the factors, derives the rank
    assignment and fusion groups via block partitioning, reduces the
    factors, and initializes the merged communicators' tensor groups; on
    later invocations it synchronizes the hook-issued communication.
    Then computes and broadcasts the eigendecompositions, preconditions
    the gradients, and re-tunes the forward fusion groups from layer
    times profiled between steps 5 and 25.

    Note:
        - this function should always be called before `optimizer.step()`
        - gradients must be averaged across ranks before calling `step()`

    Args:
        closure: for compatibility with the base optimizer class.
            `closure` is ignored by KFAC.
        epoch (int, optional): epoch to use for determining when to end
            the `diag_warmup` period. `epoch` is not necessary if not
            using `diag_warmup`.
    """
    # Update params, used for compatibilty with `KFACParamScheduler`
    group = self.param_groups[0]
    self.lr = group['lr']
    self.damping = group['damping']
    self.fac_update_freq = group['fac_update_freq']
    self.kfac_update_freq = group['kfac_update_freq']

    updates = {}

    # NOTE(review): diag_blocks is computed for scheduler compatibility but
    # is currently unused in this method body.
    if epoch is None:
        if self.diag_warmup > 0:
            print("WARNING: diag_warmup > 0 but epoch was not passed to "
                  "KFAC.step(). Defaulting to no diag_warmup")
        diag_blocks = self.diag_blocks
    else:
        diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

    if self.steps % self.fac_update_freq == 0:
        if self.eigen_ranks is None:
            # First factor update: build factors, the rank assignment, and
            # the A/G fusion groups, then set up the merged communicators.
            if not self.exclude_compute_factor:
                self._update_A()
                self._update_G()
            self.eigen_ranks, self.fusion_groups_A, self.fusion_groups_G = \
                self._generate_eigen_ranks_blockpartition_opt(epoch)
            if not self.exclude_communicate_factor:
                if hvd.size() > 1:
                    self._reduce_factors(self.eigen_ranks)
                    # NOTE(review): in the flattened original the exact
                    # nesting of the following setup calls is ambiguous;
                    # they are placed with the multi-rank communication
                    # setup — confirm against version history.
                    self.fw_merged_comm.init_tensor_group(
                        self.reduce_module_names)
                    self.bw_merged_comm.init_tensor_group(
                        self.reduce_module_names[::-1])
                    if hvd.rank() == 0:
                        print('module_names: ', self.module_names)
                        print('fusion_groups_A: ', self.fusion_groups_A)
                    self.fw_merged_comm.update_tensor_fusion(
                        self.fusion_groups_A)
        else:
            # starting from the 2nd iteration: wait for the asynchronous
            # factor communication issued by the forward/backward hooks.
            if not self.exclude_communicate_factor:
                if hvd.size() > 1:
                    self.fw_merged_comm.synchronize()
                    self.bw_merged_comm.synchronize()
                    self.fw_allreduce_comm.synchronize()
                    self.bw_allreduce_comm.synchronize()
    eigen_ranks = self.eigen_ranks

    # if we are switching from no diag approx to approx, we need to clear
    # off-block-diagonal elements
    if not self.have_cleared_Q and \
            epoch == self.diag_warmup and \
            self.steps % self.kfac_update_freq == 0:
        self._clear_eigen()
        self.have_cleared_Q = True

    if self.steps % self.kfac_update_freq == 0:
        # reset rank iter so devices get the same layers
        # to compute to take advantage of caching
        self.rank_iter.reset()

        # BUG FIX: the original used `[[]] * hvd.size()`, which creates
        # hvd.size() references to ONE shared list, so every per-rank
        # append landed in the same list. Build independent lists instead.
        # NOTE(review): these per-rank merged lists feed a merged-broadcast
        # path that is currently disabled (eigendecompositions go through
        # _broadcast_eigendecomp() below).
        merged_name_AGs = [[] for _ in range(hvd.size())]
        merged_tensor_AGs = [[] for _ in range(hvd.size())]

        for module in self.modules:
            ranks_a, ranks_g = eigen_ranks[module]
            self.m_dA_ranks[module] = ranks_a[0]
            self.m_dG_ranks[module] = ranks_g[0]
            rank_a = ranks_a[0]
            rank_g = ranks_g[0]
            name = self.module_name_map[module]

            if not self.exclude_compute_inverse:
                self._update_eigen_A(module, ranks_a)
            if not self.exclude_compute_inverse:
                self._update_eigen_G(module, ranks_g)

            merged_name_AGs[rank_a].append(name + '-A')
            merged_name_AGs[rank_g].append(name + '-G')
            merged_tensor_AGs[rank_a].append(self.m_QA[module])
            merged_tensor_AGs[rank_g].append(self.m_QG[module])

        if not self.exclude_communicate_inverse:
            if hvd.size() > 1:
                self._broadcast_eigendecomp()
        if self.exclude_communicate_inverse and not self.exclude_compute_inverse:
            # Keep ranks in lock-step when computing without communicating
            # (benchmark/ablation mode).
            if hvd.size() > 1:
                barrier()

    # Wait for all outstanding inverse broadcasts before preconditioning.
    if hvd.size() > 1 and self.steps % self.kfac_update_freq == 0:
        self.multi_comm.synchronize()

    for module in self.modules:
        grad = self._get_grad(module)
        updates[module] = self._get_preconditioned_grad(module, grad)
    self._update_scale_grad(updates)

    # Profiling window: enable layer-time profiling at step 5 and rebuild
    # the forward fusion groups from the measurements at step 25.
    if hvd.size() > 1 and self.steps % self.kfac_update_freq == 0:
        if self.steps == 5:
            self.profiling = True
        elif self.steps == 25:
            fw_layerwise_times = torch.tensor(self.fw_profiler.get_results())
            bw_layerwise_times = torch.tensor(self.bw_profiler.get_results())
            # Use rank 0's timings on every rank so that all ranks derive
            # identical fusion groups.
            hvd.broadcast_(fw_layerwise_times, root_rank=0)
            hvd.broadcast_(bw_layerwise_times, root_rank=0)
            fw_layerwise_times = fw_layerwise_times.numpy()
            bw_layerwise_times = bw_layerwise_times.numpy()
            fw_factor_sizes = [
                self.fw_factor_sizes[m] for m in self.module_names
            ]
            bw_factor_sizes = [
                self.bw_factor_sizes[m] for m in self.module_names[::-1]
            ]
            # NOTE(review): only the forward (A) fusion groups are re-tuned
            # here; the backward update was disabled in the original.
            self.fw_merged_comm.update_groups(self.fusion_groups_A,
                                              fw_factor_sizes,
                                              fw_layerwise_times,
                                              reverse=False)
            self.profiling = False

    self.steps += 1