def _broadcast_eigendecomp(self):
        """Broadcasts the eigendecompositions for all layers

        Note: we use `op=hvd.Sum` to simulate an allgather`. Each rank will
        either compute the eigendecomposition for a factor or just return
        zeros so we sum instead of averaging.
        """
        handles = []
        rank = hvd.rank()

        for i, m in enumerate(self.modules):
            rank_a = self.m_dA_ranks[m]
            rank_g = self.m_dG_ranks[m]
            name = self.module_names[i]

            h = hvd.broadcast_async_(self.m_QA[m], rank_a, name=name + 'mQA')
            handles.append(h)
            h = hvd.broadcast_async_(self.m_dA[m], rank_a, name=name + 'mdA')
            handles.append(h)
            h = hvd.broadcast_async_(self.m_QG[m], rank_g, name=name + 'mQG')
            handles.append(h)
            h = hvd.broadcast_async_(self.m_dG[m], rank_g, name=name + 'mdG')
            handles.append(h)

        for handle in handles:
            hvd.synchronize(handle)
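The broadcast roots `rank_a` and `rank_g` are whatever ranks were assigned to compute each factor (in this codebase that assignment comes from `_generate_eigen_ranks_uniform`, used in `step()` further down). As a rough illustration only, a round-robin assignment could look like the sketch below; `assign_factor_ranks` is hypothetical and not part of the snippet.

import horovod.torch as hvd

def assign_factor_ranks(modules):
    """Hypothetical round-robin assignment of each module's A and G
    factors to the ranks that will compute and later broadcast them."""
    world_size = hvd.size()
    m_dA_ranks, m_dG_ranks = {}, {}
    for i, module in enumerate(modules):
        m_dA_ranks[module] = (2 * i) % world_size
        m_dG_ranks[module] = (2 * i + 1) % world_size
    return m_dA_ranks, m_dG_ranks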
Example 2
    def bcast_async_(self, name, tensor, rank):
        """Non-blocking in-place broadcast of `tensor` from `rank`,
        optionally deferred until the tensor group has a fused buffer ready."""
        if self.merge:
            new_name, new_tensor = self._tensor_group.push_tensor(name, tensor)
            self._name_tensors[name] = tensor
            if new_tensor is not None:
                # make sure pending GPU work on the fused buffer has finished
                # before Horovod starts reading it
                current_stream = torch.cuda.current_stream()
                current_stream.synchronize()

                handle = hvd.broadcast_async_(new_tensor, rank, name=self.prefix+new_name)
                self.handles.append(handle)
        else:
            handle = hvd.broadcast_async_(tensor, rank)
            self.handles.append(handle)
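The merged path above relies on a tensor-group helper whose `push_tensor` is assumed to return `None` while it is still accumulating tensors and a fused buffer once a group is ready. A minimal, hypothetical stand-in (`SimpleTensorGroup`, not the project's actual implementation) might look like:

import torch

class SimpleTensorGroup:
    """Hypothetical stand-in for the tensor group used above: collects
    tensors and returns a fused 1-D buffer once `group_size` have been pushed."""

    def __init__(self, group_size=4):
        self.group_size = group_size
        self._names = []
        self._tensors = []

    def push_tensor(self, name, tensor):
        self._names.append(name)
        self._tensors.append(tensor.reshape(-1))
        if len(self._tensors) < self.group_size:
            return name, None           # still accumulating; nothing to send yet
        merged_name = ','.join(self._names)
        merged = torch.cat(self._tensors)
        self._names, self._tensors = [], []
        return merged_name, merged      # caller broadcasts the fused buffer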
    def _broadcast_inverse_factors(self):
        """Broadcasts the inverse factors for all layers from the ranks
        that computed them."""
        handles = []

        for i, m in enumerate(self.modules):
            rank_a, rank_g = self.module_ranks[m]
            name = self.module_names[i]

            h = hvd.broadcast_async_(self.m_inv_A[m], rank_a, name=name+'inverseA')
            handles.append(h)
            h = hvd.broadcast_async_(self.m_inv_G[m], rank_g, name=name+'inverseG')
            handles.append(h)
    
        for handle in handles:
            hvd.synchronize(handle)
Example 4
    def bcast_async_(self, names, tensors, rank):
        """Packs `tensors` (optionally cast to fp16 and reduced to their
        upper triangles) into a fused buffer and broadcasts it from `rank`."""
        if self.fp16:
            comm_tensors = [t.half() for t in tensors]
        else:
            comm_tensors = tensors

        if self.symmetric:
            # the factors are symmetric, so sending only the upper triangle
            # roughly halves the communication volume
            sym_comm_tensors = []
            for tensor in comm_tensors:
                upper_indices = torch.triu_indices(tensor.shape[0], tensor.shape[0], device=tensor.device)
                comm_tensor = tensor[upper_indices[0], upper_indices[1]]
                sym_comm_tensors.append(comm_tensor)
            comm_tensors = sym_comm_tensors

        name = ','.join(names)
        if len(comm_tensors) > 1:
            if name not in self.merged_tensors:
                size = 0
                for t in comm_tensors:
                    size += t.numel()
                buf = comm_tensors[0].new_zeros(size)
                self.merged_tensors[name] = buf
        else:
            self.merged_tensors[name] = comm_tensors[0]
        buf = self.merged_tensors[name]
        if len(comm_tensors) > 1:
            # copy each packed tensor into its contiguous slice of the fused buffer
            offset = 0
            for t in comm_tensors:
                numel = t.numel()
                buf.data[offset:offset+numel].copy_(t.view(numel))
                offset += numel
        handle = hvd.broadcast_async_(buf, rank)
        self.handles.append((handle, names, tensors, comm_tensors))
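After `hvd.synchronize(handle)` the receiving ranks still need to split the fused buffer back into per-factor tensors and, when `symmetric` packing was used, rebuild the full matrices from their upper triangles (and cast back from fp16). That unpacking step is not shown above; a hedged sketch of the symmetric part, using a hypothetical helper name:

import torch

def unpack_symmetric(packed, n, dtype=torch.float32):
    """Hypothetical helper: rebuild an n x n symmetric matrix from the
    packed upper-triangular values produced by the code above."""
    upper = torch.triu_indices(n, n, device=packed.device)
    full = torch.zeros(n, n, device=packed.device, dtype=dtype)
    full[upper[0], upper[1]] = packed.to(dtype)
    # mirror the strictly-upper part into the lower triangle
    return full + full.triu(1).t()

Communicating only the upper triangle roughly halves the payload for these symmetric Kronecker factors, and the optional fp16 cast halves it again at the cost of some precision.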
    def _broadcast_sparse_inv(self):
        """Broadcasts the factors `m_QA` and `m_QG` for all layers from the
        ranks that computed them."""
        handles = []
        rank = hvd.rank()

        for i, m in enumerate(self.modules):
            rank_a = self.m_dA_ranks[m]
            rank_g = self.m_dG_ranks[m]
            name = self.module_names[i]

            h = hvd.broadcast_async_(self.m_QA[m], rank_a, name=name + 'mQA')
            handles.append(h)
            h = hvd.broadcast_async_(self.m_QG[m], rank_g, name=name + 'mQG')
            handles.append(h)

        for handle in handles:
            hvd.synchronize(handle)
Example 6
    def _broadcast_precon_grads(self):
        """Broadcasts each layer's preconditioned gradient from the single
        rank that computed it (the A and G factors share a rank here)."""
        handles = []

        for i, m in enumerate(self.modules):
            rank_a, rank_g = self.module_ranks[m]
            assert rank_a == rank_g
            name = self.module_names[i]
            v = self.m_precon_grad[m]

            h = hvd.broadcast_async_(v, rank_a, name=name + 'preconGrad')
            handles.append(h)

        for handle in handles:
            hvd.synchronize(handle)
Example 7
def barrier():
    """Simple barrier: every rank must join the broadcast before it completes."""
    torch.cuda.synchronize()
    sync_tensor = torch.zeros(1)  # dummy payload; only the collective call matters
    handle = hvd.broadcast_async_(sync_tensor, root_rank=0)
    hvd.synchronize(handle)
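The broadcast of a dummy tensor only completes once every rank has entered the call, which is what makes this usable as a barrier. A small illustrative use of the `barrier()` above, e.g. to time an epoch consistently across ranks:

import time

barrier()                          # all ranks reach this point together
start = time.time()
# ... run one training epoch here ...
barrier()                          # wait for the slowest rank before reading the clock
print('epoch time on rank %d: %.2fs' % (hvd.rank(), time.time() - start))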
    def step(self, closure=None, epoch=None):
        """Perform one K-FAC step

        Note:
        - this function should always be called before `optimizer.step()`
        - gradients must be averaged across ranks before calling `step()`

        Args:
          closure: for compatibility with the base optimizer class.
              `closure` is ignored by KFAC
          epoch (int, optional): epoch to use for determining when to end
              the `diag_warmup` period. `epoch` is not necessary if not using
              `diag_warmup`
        """

        # Update params, used for compatibility with `KFACParamScheduler`
        group = self.param_groups[0]
        self.lr = group['lr']
        self.damping = group['damping']
        self.fac_update_freq = group['fac_update_freq']
        self.kfac_update_freq = group['kfac_update_freq']

        updates = {}
        handles = []

        if epoch is None:
            if self.diag_warmup > 0:
                print("WARNING: diag_warmup > 0 but epoch was not passed to "
                      "KFAC.step(). Defaulting to no diag_warmup")
            diag_blocks = self.diag_blocks
        else:
            diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

        if hvd.size() > 1 and self.steps % self.fac_update_freq == 0:
            self.fw_merged_comm.synchronize()
            self.bw_merged_comm.synchronize()

            #for handle in self.fw_factor_handles:
            #    hvd.synchronize(handle)
            #self.fw_factor_handles.clear()
            #for handle in self.bw_factor_handles:
            #    hvd.synchronize(handle)
            #self.bw_factor_handles.clear()

        # if we are switching from no diag approx to approx, we need to clear
        # off-block-diagonal elements
        if not self.have_cleared_Q and \
                epoch == self.diag_warmup and \
                self.steps % self.kfac_update_freq == 0:
            self._clear_eigen()
            self.have_cleared_Q = True

        if self.steps % self.kfac_update_freq == 0:
            # reset rank iter so each device gets the same layers
            # to compute, taking advantage of caching
            self.rank_iter.reset()
            handles = []

            #eigen_ranks = self._generate_eigen_ranks(epoch)
            eigen_ranks = self._generate_eigen_ranks_uniform(epoch)
            #eigen_ranks = self._generate_eigen_ranks_naive(epoch)

            for module in self.modules:
                ranks_a, ranks_g = eigen_ranks[module]
                self.m_dA_ranks[module] = ranks_a[0]
                self.m_dG_ranks[module] = ranks_g[0]
                rank_a = ranks_a[0]
                rank_g = ranks_g[0]

                self._update_eigen_A(module, ranks_a)
                h1 = hvd.broadcast_async_(self.m_QA[module], rank_a)
                h2 = hvd.broadcast_async_(self.m_dA[module], rank_a)
                self._update_eigen_G(module, ranks_g)
                h3 = hvd.broadcast_async_(self.m_QG[module], rank_g)
                h4 = hvd.broadcast_async_(self.m_dG[module], rank_g)
                handles.append((h1, h2, h3, h4))

            if hvd.size() > 1:
                #for handle in handles:
                #    hvd.synchronize(handle)
                #self._allreduce_eigendecomp()
                #self._broadcast_eigendecomp()
                pass

        for i, module in enumerate(self.modules):
            if hvd.size() > 1 and len(handles) > 0:
                h1, h2, h3, h4 = handles[i]
                hvd.synchronize(h1)
                hvd.synchronize(h2)
                hvd.synchronize(h3)
                hvd.synchronize(h4)

            grad = self._get_grad(module)
            precon_grad = self._get_preconditioned_grad(module, grad)
            updates[module] = precon_grad

        #self._update_scale_grad(updates)

        self.steps += 1
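As the docstring notes, `step()` must run after gradients have been averaged across ranks and before the wrapped optimizer applies its update. A hedged sketch of that ordering with Horovod's `DistributedOptimizer`; the `KFAC` constructor arguments and the `model`, `train_loader`, `criterion`, and `num_epochs` objects are assumed to be defined as in the surrounding project:

optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters())
preconditioner = KFAC(model)                 # constructor arguments omitted here

for epoch in range(num_epochs):
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.synchronize()              # finish averaging gradients across ranks
        preconditioner.step(epoch=epoch)     # K-FAC preconditioning before the update
        with optimizer.skip_synchronize():   # gradients were already synchronized above
            optimizer.step()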