Example #1
 def _update_module_G(self, module):
     """Compute factor G for `module`, update its running average, and
     optionally sparsify it."""
     g = self.computeG(self.m_g[module], module, self.batch_averaged)
     #logger.info('G Name: %s, shape: %s', module, g.shape)
     if self.steps == 0:
         self._init_G(g, module)
     update_running_avg(g, self.m_G[module], self.factor_decay)
     if self.sparse:
         sparsification(self.m_G[module],
                        module,
                        ratio=self.sparse_ratio,
                        residuals=self.residualsG)
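
Example #1 folds each newly computed factor into a running average via `update_running_avg`. A minimal sketch of such a helper, assuming it performs an in-place exponential moving average weighted by the decay factor (the real helper may weight or normalize the terms differently):

 import torch

 def update_running_avg(new, current, alpha):
     """In-place running average: current <- alpha * current + (1 - alpha) * new.

     Hypothetical sketch of the helper called in these examples; only the
     call signature is taken from the source.
     """
     current.mul_(alpha).add_(new, alpha=1.0 - alpha)

 m_G = torch.zeros(64, 64)         # stored factor, updated in place
 g_factor = torch.randn(64, 64)    # freshly computed factor
 update_running_avg(g_factor, m_G, 0.95)

With this convention, a decay close to 1 makes the stored factor change slowly from step to step.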
Example #2
 def _update_A(self):
     """Compute and update factor A for all modules"""
     for module in self.modules:
         a = self.computeA(self.m_a[module], module)
         if self.steps == 0:
             self._init_A(a, module)
         update_running_avg(a, self.m_A[module], self.factor_decay)
         if self.sparse:
             sparsification(self.m_A[module],
                            module,
                            ratio=self.sparse_ratio,
                            residuals=self.residualsA)
 def _update_module_A(self, module):
     """Compute factor A for `module`, update its running average, and
     optionally keep a sparsified copy."""
     a = self.computeA(self.m_a[module], module)
     if self.steps == 0:
         self._init_A(a, module)
     update_running_avg(a, self.m_A[module], self.factor_decay)
     if self.sparse:
         #self.m_sparseA[module] = sparsification_randk(self.m_A[module], module, ratio=self.sparse_ratio, residuals=self.residualsA)
         #self.m_sparseA[module] = sparsification(self.m_A[module], module, ratio=self.sparse_ratio, residuals=self.residualsA)
         self.m_sparseA[module] = fake_sparsification(
             self.m_A[module],
             module,
             ratio=self.sparse_ratio,
             residuals=self.residualsA)
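
When `sparse` is enabled, Example #2 sparsifies the factor and keeps the dropped values as residuals for later steps. A minimal top-k sketch, assuming `residuals` is a dict keyed by module; the function name mirrors the calls above, but the body is hypothetical and the project's `sparsification`/`fake_sparsification` routines may behave differently:

 import torch

 def sparsification(tensor, module, ratio=0.01, residuals=None):
     """Keep only the largest-magnitude entries of `tensor` (in place) and
     store the dropped values in a per-module residual buffer.

     Hypothetical sketch; only the call signature is taken from Example #2.
     """
     flat = tensor.view(-1)
     if residuals is not None and module in residuals:
         flat.add_(residuals[module].view(-1))   # re-inject previously dropped values
     k = max(1, int(flat.numel() * ratio))
     _, idx = torch.topk(flat.abs(), k)
     mask = torch.zeros_like(flat)
     mask[idx] = 1.0
     if residuals is not None:
         residuals[module] = (flat * (1.0 - mask)).view_as(tensor)
     flat.mul_(mask)                             # zero out everything outside the top-k
     return tensor

Re-injecting the residual before selecting the top-k entries is what keeps the information in dropped entries from being lost permanently.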
Example #4
 def _update_module_G(self, module):
     """Compute factor G for `module` and update its running average."""
     G = self.computeG(self.m_g[module], module, self.batch_averaged)
     if self.steps == 0:
         self._init_G(G, module)
     update_running_avg(G, self.m_G[module], self.factor_decay)
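
For reference, the factors returned by `computeA` and `computeG` are the batch-averaged Kronecker factors A = a^T a / N and G = g^T g / N, which is exactly what the `torch.einsum('ki,kj->ij', x, x / x.size(0))` calls in Example #5 compute. A small standalone illustration (the function name here is illustrative, not taken from the source):

 import torch

 def compute_factor(x):
     """Batch-averaged second moment (1/N) * x^T x, equivalent to
     torch.einsum('ki,kj->ij', x, x / x.size(0))."""
     return x.t() @ (x / x.size(0))

 a = torch.randn(32, 128)    # activations: batch x in_features
 g = torch.randn(32, 64)     # pre-activation gradients: batch x out_features
 A = compute_factor(a)       # 128 x 128 factor
 G = compute_factor(g)       # 64 x 64 factor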
Example #5
    def _update_module_A(self, module):
        """Compute factor A for `module` and update its running average."""
        A = self.computeA(self.m_a[module], module)
        if self.steps == 0:
            self._init_A(A, module)
        update_running_avg(A, self.m_A[module], self.factor_decay)
    def step(self, closure=None, epoch=None):
        """Perform one K-FAC step

        Note:
        - this function should always be called before `optimizer.step()`
        - gradients must be averaged across ranks before calling `step()`

        Args:
          closure: for compatibility with the base optimizer class.
              `closure` is ignored by KFAC
          epoch (int, optional): epoch to use for determining when to end
              the `diag_warmup` period. `epoch` is not necessary if not using
              `diag_warmup`
        """

        # Update params, used for compatibility with `KFACParamScheduler`
        group = self.param_groups[0]
        self.lr = group['lr']
        self.damping = group['damping']
        self.fac_update_freq = group['fac_update_freq']
        self.kfac_update_freq = group['kfac_update_freq']
        #print('fac_update_freq: ', self.fac_update_freq)
        #print('kfac_update_freq: ', self.kfac_update_freq)

        updates = {}
        handles = []

        if epoch is None:
            if self.diag_warmup > 0:
                print("WARNING: diag_warmup > 0 but epoch was not passed to "
                      "KFAC.step(). Defaulting to no diag_warmup")
            diag_blocks = self.diag_blocks
        else:
            diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

        if hvd.size() > 1 and self.steps % self.fac_update_freq == 0:
            self.fw_merged_comm.synchronize()
            self.bw_merged_comm.synchronize()

            # Compute A and G after aggregation of a and g
            for module in self.modules:
                a = self.m_a[module]
                g = self.m_g[module]
                if hvd.rank() == 0:
                    logger.info('a Name: %s, shape %s', module, a.shape)
                    logger.info('g Name: %s, shape %s', module, g.shape)
                A = torch.einsum('ki,kj->ij', a, a / a.size(0))
                G = torch.einsum('ki,kj->ij', g, g / g.size(0))
                update_running_avg(A, self.m_A[module], self.factor_decay)
                update_running_avg(G, self.m_G[module], self.factor_decay)

        # if we are switching from no diag approx to approx, we need to clear
        # off-block-diagonal elements
        if not self.have_cleared_Q and \
                epoch == self.diag_warmup and \
                self.steps % self.kfac_update_freq == 0:
            self._clear_eigen()
            self.have_cleared_Q = True

        if self.steps % self.kfac_update_freq == 0:
            # reset rank_iter so each device gets the same layers
            # to compute, to take advantage of caching
            self.rank_iter.reset()
            handles = []

            #eigen_ranks = self._generate_eigen_ranks(epoch)
            eigen_ranks = self._generate_eigen_ranks_uniform(epoch)
            #eigen_ranks = self._generate_eigen_ranks_naive(epoch)
            #inverse_As = []
            #A_ranks = []
            #inverse_Gs = []
            #G_ranks = []
            rank_to_tensors = {}

            for module in self.modules:
                ranks_a, ranks_g = eigen_ranks[module]
                self.m_dA_ranks[module] = ranks_a[0]
                self.m_dG_ranks[module] = ranks_g[0]
                rank_a = ranks_a[0]
                rank_g = ranks_g[0]

                name = self.module_name_map[module]
                self._update_inverse_A(module, ranks_a)
                #if hvd.size() > 1 and rank_a >= 0:
                #    self.inverseA_merged_comm.bcast_async_(name, self.m_QA[module], rank_a)

                self._update_inverse_G(module, ranks_g)
                #if hvd.size() > 1 and rank_g >= 0:
                #    self.inverseG_merged_comm.bcast_async_(name, self.m_QG[module], rank_g)
                #if rank_a not in rank_to_tensors:
                #    rank_to_tensors[rank_a] = []
                #rank_to_tensors[rank_a].append((name, self.m_QA[module], self.m_QG[module]))
                if hvd.size() > 1 and rank_g >= 0:
                    self.multi_comm.bcast_async_(
                        [name], [self.m_QA[module], self.m_QG[module]], rank_g)
            #if hvd.size() > 1:
            #    for rank in rank_to_tensors.keys():
            #        names = []
            #        tensors = []
            #        for name, ta, tb in rank_to_tensors[rank]:
            #            names.append(name)
            #            tensors.append(ta)
            #            tensors.append(tb)
            #        self.multi_comm.bcast_async_(names, tensors, rank)

        if hvd.size() > 1 and self.steps % self.kfac_update_freq == 0:
            #self.inverseA_merged_comm.synchronize()
            #self.inverseG_merged_comm.synchronize()
            self.multi_comm.synchronize()

        # Precondition each module's gradient with the K-FAC factors
        for module in self.modules:
            grad = self._get_grad(module)
            precon_grad = self._get_preconditioned_grad(module, grad)
            updates[module] = precon_grad

        self._update_scale_grad(updates)

        self.steps += 1
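
The docstring above requires that gradients be averaged across ranks before `step()` runs, and that `step()` be called before the wrapped optimizer's own update. A minimal Horovod training-loop sketch under those assumptions; the constructor arguments, the synthetic data, and the assumption that the class shown above is importable as `KFAC` are all illustrative:

 import torch
 import torch.nn.functional as F
 import horovod.torch as hvd

 hvd.init()
 model = torch.nn.Linear(128, 10)
 optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
 optimizer = hvd.DistributedOptimizer(optimizer,
                                      named_parameters=model.named_parameters())
 # KFAC: the optimizer class these examples come from; import path and
 # constructor arguments depend on the project (hypothetical here)
 preconditioner = KFAC(model, lr=0.1, damping=0.003)

 for epoch in range(2):
     for _ in range(10):                         # stand-in for a real data loader
         data = torch.randn(32, 128)
         target = torch.randint(0, 10, (32,))
         optimizer.zero_grad()
         loss = F.cross_entropy(model(data), target)
         loss.backward()
         optimizer.synchronize()                 # average gradients across ranks first
         preconditioner.step(epoch=epoch)        # K-FAC preconditioning
         with optimizer.skip_synchronize():
             optimizer.step()                    # base optimizer applies the update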