def _broadcast_eigendecomp(self):
    """Broadcasts the eigendecompositions for all layers.

    Note: an earlier variant used an allreduce with `op=hvd.Sum` to
    simulate an allgather: each rank either computes the
    eigendecomposition for a factor or contributes zeros, so the factors
    are summed rather than averaged. This version instead broadcasts each
    factor's eigendecomposition from the rank that computed it.
    """
    handles = []

    for i, m in enumerate(self.modules):
        rank_a = self.m_dA_ranks[m]
        rank_g = self.m_dG_ranks[m]
        name = self.module_names[i]

        h = hvd.broadcast_async_(self.m_QA[m], rank_a, name=name + 'mQA')
        handles.append(h)
        h = hvd.broadcast_async_(self.m_dA[m], rank_a, name=name + 'mdA')
        handles.append(h)
        h = hvd.broadcast_async_(self.m_QG[m], rank_g, name=name + 'mQG')
        handles.append(h)
        h = hvd.broadcast_async_(self.m_dG[m], rank_g, name=name + 'mdG')
        handles.append(h)

    for handle in handles:
        hvd.synchronize(handle)
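# For reference, the sum-based variant mentioned in the docstring above can
# be emulated as follows. This is a hedged sketch of that older approach,
# not code used by this optimizer; `share_factor_via_sum` is a hypothetical
# helper name.
def share_factor_via_sum(tensor, computed_here):
    # Non-owners contribute zeros, so a Sum allreduce leaves every rank
    # holding the single nonzero contribution (an allgather in effect).
    if not computed_here:
        tensor.zero_()
    return hvd.allreduce_(tensor, op=hvd.Sum)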
def bcast_async_(self, name, tensor, rank):
    """Broadcast `tensor` from `rank`, optionally batching small tensors."""
    if self.merge:
        # Accumulate tensors into a fused buffer; push_tensor returns a
        # non-None tensor only once the group is ready to communicate.
        new_name, new_tensor = self._tensor_group.push_tensor(name, tensor)
        self._name_tensors[name] = tensor
        if new_tensor is not None:
            # Ensure pending copies into the fused buffer are complete
            # before handing it to Horovod.
            torch.cuda.current_stream().synchronize()
            handle = hvd.broadcast_async_(new_tensor, rank,
                                          name=self.prefix + new_name)
            self.handles.append(handle)
    else:
        handle = hvd.broadcast_async_(tensor, rank)
        self.handles.append(handle)
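# A hedged sketch of the contract that `bcast_async_` above assumes of
# `self._tensor_group`. This `TensorGroup` is hypothetical and only
# illustrates why the broadcast is deferred until `push_tensor` returns a
# non-None tensor.
class TensorGroup:
    def __init__(self, expected_names):
        self._expected = list(expected_names)
        self._pending = {}

    def push_tensor(self, name, tensor):
        self._pending[name] = tensor
        if len(self._pending) < len(self._expected):
            return name, None  # group incomplete; caller skips the broadcast
        # Group complete: fuse everything into one flat buffer under a
        # combined name so a single collective covers all member tensors.
        flat = torch.cat([self._pending[n].reshape(-1) for n in self._expected])
        fused_name = ','.join(self._expected)
        self._pending.clear()
        return fused_name, flat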
def _broadcast_inverse_factors(self):
    handles = []

    for i, m in enumerate(self.modules):
        rank_a, rank_g = self.module_ranks[m]
        name = self.module_names[i]

        h = hvd.broadcast_async_(self.m_inv_A[m], rank_a, name=name + 'inverseA')
        handles.append(h)
        h = hvd.broadcast_async_(self.m_inv_G[m], rank_g, name=name + 'inverseG')
        handles.append(h)

    for handle in handles:
        hvd.synchronize(handle)
def bcast_async_(self, names, tensors, rank):
    # Optionally cast to fp16 to halve the communication volume.
    if self.fp16:
        comm_tensors = [t.half() for t in tensors]
    else:
        comm_tensors = tensors

    # Symmetric factors only need their upper triangles communicated.
    if self.symmetric:
        sym_comm_tensors = []
        for tensor in comm_tensors:
            upper_indices = torch.triu_indices(tensor.shape[0], tensor.shape[0],
                                               device=tensor.device)
            comm_tensor = tensor[upper_indices[0], upper_indices[1]]
            sym_comm_tensors.append(comm_tensor)
        comm_tensors = sym_comm_tensors

    # Pack all tensors for this broadcast into one flat buffer, cached
    # under the joined tensor names.
    name = ','.join(names)
    if len(comm_tensors) > 1:
        if name not in self.merged_tensors:
            size = sum(t.numel() for t in comm_tensors)
            self.merged_tensors[name] = comm_tensors[0].new_zeros(size)
    else:
        self.merged_tensors[name] = comm_tensors[0]
    buf = self.merged_tensors[name]

    if len(comm_tensors) > 1:
        offset = 0
        for t in comm_tensors:
            numel = t.numel()
            buf.data[offset:offset + numel].copy_(t.view(numel))
            offset += numel

    handle = hvd.broadcast_async_(buf, rank)
    self.handles.append((handle, names, tensors, comm_tensors))
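# The receiving side must undo the packing above. A minimal sketch, assuming
# square symmetric factors; `unpack_symmetric` is a hypothetical helper and
# not part of the original code.
def unpack_symmetric(packed, n):
    # Rebuild a full (n, n) symmetric matrix from its packed upper triangle.
    full = packed.new_zeros(n, n)
    upper = torch.triu_indices(n, n, device=packed.device)
    full[upper[0], upper[1]] = packed
    # Mirror the strict upper triangle into the lower triangle.
    return full + full.triu(1).t()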
def _broadcast_sparse_inv(self):
    handles = []

    for i, m in enumerate(self.modules):
        rank_a = self.m_dA_ranks[m]
        rank_g = self.m_dG_ranks[m]
        name = self.module_names[i]

        h = hvd.broadcast_async_(self.m_QA[m], rank_a, name=name + 'mQA')
        handles.append(h)
        h = hvd.broadcast_async_(self.m_QG[m], rank_g, name=name + 'mQG')
        handles.append(h)

    for handle in handles:
        hvd.synchronize(handle)
def _broadcast_precon_grads(self):
    handles = []

    for i, m in enumerate(self.modules):
        rank_a, rank_g = self.module_ranks[m]
        assert rank_a == rank_g
        name = self.module_names[i]

        v = self.m_precon_grad[m]
        h = hvd.broadcast_async_(v, rank_a, name=name + 'preconGrad')
        handles.append(h)

    for handle in handles:
        hvd.synchronize(handle)
def barrier():
    """Block until all ranks reach this point.

    A broadcast of a dummy tensor stands in for an explicit barrier:
    `hvd.synchronize` returns only once every rank has joined the
    collective.
    """
    torch.cuda.synchronize()
    sync_tensor = torch.zeros(1)  # dummy payload; the value is irrelevant
    handle = hvd.broadcast_async_(sync_tensor, root_rank=0)
    hvd.synchronize(handle)
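# Example use of barrier(): line ranks up before and after a timed section so
# the measurement is comparable across ranks. `run_step` is a hypothetical
# stand-in for the work being timed.
import time

barrier()
start = time.time()
run_step()
barrier()
elapsed = time.time() - start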
def step(self, closure=None, epoch=None):
    """Perform one K-FAC step

    Note:
    - this function should always be called before `optimizer.step()`
    - gradients must be averaged across ranks before calling `step()`

    Args:
        closure: for compatibility with the base optimizer class.
            `closure` is ignored by KFAC
        epoch (int, optional): epoch to use for determining when to end
            the `diag_warmup` period. `epoch` is not necessary if not
            using `diag_warmup`
    """
    # Update params, used for compatibility with `KFACParamScheduler`
    group = self.param_groups[0]
    self.lr = group['lr']
    self.damping = group['damping']
    self.fac_update_freq = group['fac_update_freq']
    self.kfac_update_freq = group['kfac_update_freq']

    updates = {}
    handles = []

    if epoch is None:
        if self.diag_warmup > 0:
            print("WARNING: diag_warmup > 0 but epoch was not passed to "
                  "KFAC.step(). Defaulting to no diag_warmup")
        diag_blocks = self.diag_blocks
    else:
        diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

    if hvd.size() > 1 and self.steps % self.fac_update_freq == 0:
        # Wait for the factor communication launched during the
        # forward/backward passes to finish.
        self.fw_merged_comm.synchronize()
        self.bw_merged_comm.synchronize()

    # if we are switching from no diag approx to approx, we need to clear
    # off-block-diagonal elements
    if not self.have_cleared_Q and \
            epoch == self.diag_warmup and \
            self.steps % self.kfac_update_freq == 0:
        self._clear_eigen()
        self.have_cleared_Q = True

    if self.steps % self.kfac_update_freq == 0:
        # reset rank iter so each device gets the same layers to compute,
        # to take advantage of caching
        self.rank_iter.reset()

        handles = []
        eigen_ranks = self._generate_eigen_ranks_uniform(epoch)
        #eigen_ranks = self._generate_eigen_ranks(epoch)
        #eigen_ranks = self._generate_eigen_ranks_naive(epoch)

        for module in self.modules:
            ranks_a, ranks_g = eigen_ranks[module]
            self.m_dA_ranks[module] = ranks_a[0]
            self.m_dG_ranks[module] = ranks_g[0]
            rank_a = ranks_a[0]
            rank_g = ranks_g[0]

            # Compute the eigendecompositions (only the assigned rank does
            # real work), then broadcast the results asynchronously.
            self._update_eigen_A(module, ranks_a)
            h1 = hvd.broadcast_async_(self.m_QA[module], rank_a)
            h2 = hvd.broadcast_async_(self.m_dA[module], rank_a)
            self._update_eigen_G(module, ranks_g)
            h3 = hvd.broadcast_async_(self.m_QG[module], rank_g)
            h4 = hvd.broadcast_async_(self.m_dG[module], rank_g)
            handles.append((h1, h2, h3, h4))

    for i, module in enumerate(self.modules):
        if hvd.size() > 1 and len(handles) > 0:
            # Overlap communication with computation: block on a module's
            # broadcasts only right before its preconditioned gradient is
            # needed.
            h1, h2, h3, h4 = handles[i]
            hvd.synchronize(h1)
            hvd.synchronize(h2)
            hvd.synchronize(h3)
            hvd.synchronize(h4)
        grad = self._get_grad(module)
        precon_grad = self._get_preconditioned_grad(module, grad)
        updates[module] = precon_grad

    #self._update_scale_grad(updates)

    self.steps += 1
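# For context, a typical training loop with this preconditioner looks roughly
# like the sketch below. It assumes the optimizer class is named `KFAC` and
# accepts the hyperparameters read from `param_groups` in `step()` above;
# `build_model`, `criterion`, `loader`, and `epochs` are hypothetical. Note
# the order: gradients are averaged across ranks first, then `KFAC.step()`
# runs before `optimizer.step()`.
import torch
import horovod.torch as hvd

hvd.init()
torch.cuda.set_device(hvd.local_rank())

model = build_model().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters())
preconditioner = KFAC(model, lr=0.1, damping=0.003,
                      fac_update_freq=10, kfac_update_freq=100)

for epoch in range(epochs):
    for data, target in loader:
        optimizer.zero_grad()
        loss = criterion(model(data.cuda()), target.cuda())
        loss.backward()
        optimizer.synchronize()            # average gradients across ranks
        preconditioner.step(epoch=epoch)   # precondition before optimizer.step()
        with optimizer.skip_synchronize():
            optimizer.step()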