def step(self, closure=None, epoch=None):
        """Perform one K-FAC step

        Note:
        - this function should always be called before `optimizer.step()`
        - gradients must be averaged across ranks before calling `step()`

        Args:
          closure: for compatibility with the base optimizer class.
              `closure` is ignored by KFAC
          epoch (int, optional): epoch to use for determining when to end
              the `diag_warmup` period. `epoch` is not necessary if not using
              `diag_warmup`
        """

        # Update params, used for compatibility with `KFACParamScheduler`
        group = self.param_groups[0]
        self.lr = group['lr']
        self.damping = group['damping']
        self.fac_update_freq = group['fac_update_freq']
        self.kfac_update_freq = group['kfac_update_freq']

        updates = {}
        handles = []

        if epoch is None:
            if self.diag_warmup > 0:
                print("WARNING: diag_warmup > 0 but epoch was not passed to "
                      "KFAC.step(). Defaulting to no diag_warmup")
            diag_blocks = self.diag_blocks
        else:
            diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

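        # Wait for the asynchronous merged reductions of the Kronecker factors
        # (presumably launched from the forward/backward hooks) to finish before
        # the factors are used below.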
        if hvd.size() > 1 and self.steps % self.fac_update_freq == 0:
            self.fw_merged_comm.synchronize()
            self.bw_merged_comm.synchronize()

        # if we are switching from no diag approx to approx, we need to clear
        # off-block-diagonal elements
        if not self.have_cleared_Q and \
                epoch == self.diag_warmup and \
                self.steps % self.kfac_update_freq == 0:
            self._clear_eigen()
            self.have_cleared_Q = True
        torch.cuda.synchronize()

        if self.steps % self.kfac_update_freq == 0:
            # reset rank iter so each device gets the same layers to compute,
            # taking advantage of caching
            self.rank_iter.reset()
            handles = []

            # Alternative rank-assignment strategies:
            #   _generate_eigen_ranks(epoch), _generate_eigen_ranks_uniform(epoch)
            eigen_ranks = self._generate_eigen_ranks_naive(epoch)

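            # For each layer, the rank selected by `eigen_ranks` computes the
            # inverses of the A (input) and G (output-gradient) factors and
            # broadcasts them asynchronously to the other ranks.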
            for module in self.modules:
                ranks_a, ranks_g = eigen_ranks[module]
                self.m_dA_ranks[module] = ranks_a[0]
                self.m_dG_ranks[module] = ranks_g[0]
                rank_a = ranks_a[0]
                rank_g = ranks_g[0]

                name = self.module_name_map[module]
                if not self.exclude_compute_inverse:
                    self._update_inverse_A(module, ranks_a)
                if not self.exclude_communicate_inverse:
                    if hvd.size() > 1 and rank_a >= 0:
                        self.multi_comm.bcast_async_([name + 'mQA'],
                                                     [self.m_QA[module]],
                                                     rank_a)

                if not self.exclude_compute_inverse:
                    self._update_inverse_G(module, ranks_g)
                if not self.exclude_communicate_inverse:
                    if hvd.size() > 1 and rank_g >= 0:
                        self.multi_comm.bcast_async_([name + 'mQG'],
                                                     [self.m_QG[module]],
                                                     rank_g)
            if self.exclude_communicate_inverse and not self.exclude_compute_inverse:
                # no inverse broadcast: use a barrier to keep ranks in sync
                if hvd.size() > 1:
                    barrier()

        if hvd.size() > 1 and self.steps % self.kfac_update_freq == 0:
            self.multi_comm.synchronize()

        for module in self.modules:
            grad = self._get_grad(module)
            precon_grad = self._get_preconditioned_grad(module, grad)
            updates[module] = precon_grad

        self._update_scale_grad(updates)

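        # Dynamic tensor fusion: turn on layer-wise profiling at step 5; at step
        # 25, broadcast the measured forward/backward times from rank 0 and use
        # them to regroup the merged factor communicators.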
        if (self.dynamic_merge and hvd.size() > 1
                and self.steps % self.kfac_update_freq == 0):
            if self.steps == 5:
                self.profiling = True
            elif self.steps == 25:
                fw_layerwise_times = torch.tensor(
                    self.fw_profiler.get_results())
                bw_layerwise_times = torch.tensor(
                    self.bw_profiler.get_results())
                hvd.broadcast_(fw_layerwise_times, root_rank=0)
                hvd.broadcast_(bw_layerwise_times, root_rank=0)
                fw_layerwise_times = fw_layerwise_times.numpy()
                bw_layerwise_times = bw_layerwise_times.numpy()

                fw_factor_sizes = [
                    self.fw_factor_sizes[m] for m in self.module_names
                ]
                bw_factor_sizes = [
                    self.bw_factor_sizes[m] for m in self.module_names[::-1]
                ]
                self.fw_merged_comm.update_groups(fw_factor_sizes,
                                                  fw_layerwise_times,
                                                  reverse=False)
                self.bw_merged_comm.update_groups(bw_factor_sizes,
                                                  bw_layerwise_times,
                                                  reverse=True)
                self.profiling = False

        self.steps += 1

    def step(self, closure=None, epoch=None):
        """Perform one K-FAC step

        Note:
        - this function should always be called before `optimizer.step()`
        - gradients must be averaged across ranks before calling `step()`

        Args:
          closure: for compatibility with the base optimizer class.
              `closure` is ignored by KFAC
          epoch (int, optional): epoch to use for determining when to end
              the `diag_warmup` period. `epoch` is not necessary if not using
              `diag_warmup`
        """

        # Update params, used for compatibility with `KFACParamScheduler`
        group = self.param_groups[0]
        self.lr = group['lr']
        self.damping = group['damping']
        self.fac_update_freq = group['fac_update_freq']
        self.kfac_update_freq = group['kfac_update_freq']

        updates = {}
        handles = []

        if epoch is None:
            if self.diag_warmup > 0:
                print("WARNING: diag_warmup > 0 but epoch was not passed to "
                      "KFAC.step(). Defaulting to no diag_warmup")
            diag_blocks = self.diag_blocks
        else:
            diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

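        # Factor stage: on the first invocation compute the Kronecker factors A
        # and G, assign each layer to a rank, and reduce the factors across
        # ranks; on later invocations only wait for the outstanding asynchronous
        # merged reductions.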
        if self.steps % self.fac_update_freq == 0:

            if self.eigen_ranks is None:
                if not self.exclude_compute_factor:
                    self._update_A()
                    self._update_G()
                # Alternative rank-assignment strategies:
                #   _generate_eigen_ranks_uniform(epoch), _generate_eigen_ranks_naive(epoch)
                self.eigen_ranks = self._generate_eigen_ranks_match_merging(epoch)
                if not self.exclude_communicate_factor:
                    if hvd.size() > 1:
                        self._reduce_factors(self.eigen_ranks)
            else:
                if not self.exclude_communicate_factor:
                    if hvd.size() > 1:
                        self.fw_merged_comm.synchronize()
                        self.bw_merged_comm.synchronize()
            eigen_ranks = self.eigen_ranks

        # if we are switching from no diag approx to approx, we need to clear
        # off-block-diagonal elements
        if not self.have_cleared_Q and \
                epoch == self.diag_warmup and \
                self.steps % self.kfac_update_freq == 0:
            self._clear_eigen()
            self.have_cleared_Q = True

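        # Inverse stage: the assigned ranks compute the eigendecompositions of
        # their layers' A and G factors, and the results are then broadcast to
        # all ranks (or a barrier keeps ranks in sync if communication is excluded).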
        if self.steps % self.kfac_update_freq == 0:
            # reset rank iter so each device gets the same layers to compute,
            # taking advantage of caching
            self.rank_iter.reset()
            handles = []

            for module in self.modules:
                ranks_a, ranks_g = eigen_ranks[module]
                self.m_dA_ranks[module] = ranks_a[0]
                self.m_dG_ranks[module] = ranks_g[0]
                rank_a = ranks_a[0]
                rank_g = ranks_g[0]

                if not self.exclude_compute_inverse:
                    self._update_eigen_A(module, ranks_a)
                    self._update_eigen_G(module, ranks_g)

            if not self.exclude_communicate_inverse:
                if hvd.size() > 1:
                    self._broadcast_eigendecomp()
            elif not self.exclude_compute_inverse:
                # no inverse broadcast: use a barrier to keep ranks in sync
                if hvd.size() > 1:
                    barrier()

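        # Precondition every layer's gradient with the K-FAC approximation and
        # hand the results to `_update_scale_grad`, which is expected to scale
        # them and write them back into the parameter gradients.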
        for i, module in enumerate(self.modules):
            grad = self._get_grad(module)
            precon_grad = self._get_preconditioned_grad(module, grad)
            updates[module] = precon_grad

        self._update_scale_grad(updates)

        self.steps += 1

    def step(self, closure=None, epoch=None):
        """Perform one K-FAC step

        Note:
        - this function should always be called before `optimizer.step()`
        - gradients must be averaged across ranks before calling `step()`

        Args:
          closure: for compatibility with the base optimizer class.
              `closure` is ignored by KFAC
          epoch (int, optional): epoch to use for determining when to end
              the `diag_warmup` period. `epoch` is not necessary if not using
              `diag_warmup`
        """

        # Update params, used for compatibility with `KFACParamScheduler`
        group = self.param_groups[0]
        self.lr = group['lr']
        self.damping = group['damping']
        self.fac_update_freq = group['fac_update_freq']
        self.kfac_update_freq = group['kfac_update_freq']

        updates = {}
        handles = []

        if epoch is None:
            if self.diag_warmup > 0:
                print("WARNING: diag_warmup > 0 but epoch was not passed to "
                      "KFAC.step(). Defaulting to no diag_warmup")
            diag_blocks = self.diag_blocks
        else:
            diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

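        # Factor stage: on the first invocation compute A and G, build the
        # block-partitioned rank assignment and tensor-fusion groups, and
        # initialize the merged communicators; afterwards only wait for the
        # outstanding asynchronous reductions and all-reduces.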
        if self.steps % self.fac_update_freq == 0:

            if self.eigen_ranks is None:
                if not self.exclude_compute_factor:
                    self._update_A()
                    self._update_G()
                self.eigen_ranks, self.fusion_groups_A, self.fusion_groups_G = \
                    self._generate_eigen_ranks_blockpartition_opt(epoch)
                if not self.exclude_communicate_factor:
                    if hvd.size() > 1:
                        self._reduce_factors(self.eigen_ranks)
                self.fw_merged_comm.init_tensor_group(self.reduce_module_names)
                self.bw_merged_comm.init_tensor_group(
                    self.reduce_module_names[::-1])
                if hvd.rank() == 0:
                    print('module_names: ', self.module_names)
                    print('fusion_groups_A: ', self.fusion_groups_A)

                self.fw_merged_comm.update_tensor_fusion(self.fusion_groups_A)
            else:  # starting from the 2nd iteration
                if not self.exclude_communicate_factor:
                    if hvd.size() > 1:
                        self.fw_merged_comm.synchronize()
                        self.bw_merged_comm.synchronize()
                        self.fw_allreduce_comm.synchronize()
                        self.bw_allreduce_comm.synchronize()
            eigen_ranks = self.eigen_ranks

        # if we are switching from no diag approx to approx, we need to clear
        # off-block-diagonal elements
        if not self.have_cleared_Q and \
                epoch == self.diag_warmup and \
                self.steps % self.kfac_update_freq == 0:
            self._clear_eigen()
            self.have_cleared_Q = True

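        # Inverse stage: compute the per-layer eigendecompositions on their
        # assigned ranks, group the resulting QA/QG tensors by owner rank, and
        # broadcast them (currently via `_broadcast_eigendecomp`).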
        if self.steps % self.kfac_update_freq == 0:
            # reset rank iter so each device gets the same layers to compute,
            # taking advantage of caching
            self.rank_iter.reset()
            handles = []

            # Note: `[[]] * n` would alias one list n times; build independent lists.
            merged_name_AGs = [[] for _ in range(hvd.size())]
            merged_tensor_AGs = [[] for _ in range(hvd.size())]
            for i, module in enumerate(self.modules):
                ranks_a, ranks_g = eigen_ranks[module]
                self.m_dA_ranks[module] = ranks_a[0]
                self.m_dG_ranks[module] = ranks_g[0]
                rank_a = ranks_a[0]
                rank_g = ranks_g[0]
                name = self.module_name_map[module]

                if not self.exclude_compute_inverse:
                    self._update_eigen_A(module, ranks_a)
                    self._update_eigen_G(module, ranks_g)

                merged_name_AGs[rank_a].append(name + '-A')
                merged_name_AGs[rank_g].append(name + '-G')
                merged_tensor_AGs[rank_a].append(self.m_QA[module])
                merged_tensor_AGs[rank_g].append(self.m_QG[module])

            if not self.exclude_communicate_inverse:
                if hvd.size() > 1:
                    # Alternative: broadcast each rank's merged name/tensor lists
                    # individually via self.multi_comm.bcast_async_(merged_names,
                    # merged_tensors, rank).
                    self._broadcast_eigendecomp()

            if self.exclude_communicate_inverse and not self.exclude_compute_inverse:
                # no inverse broadcast: use a barrier to keep ranks in sync
                if hvd.size() > 1:
                    barrier()

        if hvd.size() > 1 and self.steps % self.kfac_update_freq == 0:
            self.multi_comm.synchronize()

        for i, module in enumerate(self.modules):
            grad = self._get_grad(module)
            precon_grad = self._get_preconditioned_grad(module, grad)
            updates[module] = precon_grad

        self._update_scale_grad(updates)

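        # Dynamic tensor fusion: profile layer-wise times between steps 5 and
        # 25, then re-tune the forward fusion groups from the gathered timings
        # (the backward groups are left unchanged here).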
        if hvd.size() > 1 and self.steps % self.kfac_update_freq == 0:
            if self.steps == 5:
                self.profiling = True
            elif self.steps == 25:
                fw_layerwise_times = torch.tensor(
                    self.fw_profiler.get_results())
                bw_layerwise_times = torch.tensor(
                    self.bw_profiler.get_results())
                hvd.broadcast_(fw_layerwise_times, root_rank=0)
                hvd.broadcast_(bw_layerwise_times, root_rank=0)
                fw_layerwise_times = fw_layerwise_times.numpy()
                bw_layerwise_times = bw_layerwise_times.numpy()
                fw_factor_sizes = [
                    self.fw_factor_sizes[m] for m in self.module_names
                ]
                bw_factor_sizes = [
                    self.bw_factor_sizes[m] for m in self.module_names[::-1]
                ]
                self.fw_merged_comm.update_groups(self.fusion_groups_A,
                                                  fw_factor_sizes,
                                                  fw_layerwise_times,
                                                  reverse=False)
                #self.bw_merged_comm.update_groups(self.fusion_groups_G, bw_factor_sizes, bw_layerwise_times, reverse=False)
                self.profiling = False

        self.steps += 1