class KFAC(optim.Optimizer):
    """KFAC Distributed Gradient Preconditioner

    Computes the natural gradient of a model in place with a layer-wise
    FIM approximation. Layer computations are distributed across workers
    using Horovod.

    Usage:
      optimizer = optim.SGD(model.parameters(), ...)
      optimizer = hvd.DistributedOptimizer(optimizer, ...)
      preconditioner = KFAC(model, ...)
      ... 
      for i, (data, target) in enumerate(train_loader):
          optimizer.zero_grad()
          output = model(data)
          loss = criterion(output, target)
          loss.backward()
          optimizer.synchronize()
          preconditioner.step()
          with optimizer.skip_synchronize():
              optimizer.step()

    Args:
      model (nn): Torch model to precondition
      lr (float, optional): learning rate (default: 0.1)
      factor_decay (float, optional): running average coefficient for Kronecker
          factors (default: 0.95)
      damping (float, optional): Tikhonov damping parameter (default: 0.001)
      kl_clip (float, optional): clipping parameter for gradient scaling
          (default: 0.001)
      fac_update_freq (int, optional): iterations between calculating and
          updating the running average of the Kronecker factors (default: 10)
      kfac_update_freq (int, optional): iterations between applying gradient
          preconditioning (default: 100)
      batch_averaged (bool, optional): whether the gradient has already been
          averaged over the batch (default: True)
      diag_blocks (int, optional): Experimental: number of diagonal blocks to
          approximate the Kronecker factor eigendecomposition with. 
          `diag_blocks=1` computes the eigendecomposition of the entire factor
          (default: 1)
      diag_warmup (int, optional): number of epochs to wait before starting
          the block diagonal factor approximation (default: 0)
      distribute_layer_factors (bool, optional): if `True`, computes factors A
          and G on different workers else computes A and G for a single layer
          on the same worker. If `None`, determines best value based on layer
          count (default: None)
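      sparse (bool, optional): if `True`, sparsify the Kronecker factors after
          each running average update (default: False)
      sparse_ratio (float, optional): ratio used when sparsifying the factors
          (default: 0.01)
      exclude_parts (str, optional): comma-separated list of K-FAC stages to
          skip, chosen from 'CommunicateInverse', 'ComputeInverse',
          'CommunicateFactor', and 'ComputeFactor' (default: '')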
    """
    def __init__(self,
                 model,
                 lr=0.1,
                 factor_decay=0.95,
                 damping=0.001,
                 kl_clip=0.001,
                 fac_update_freq=10,
                 kfac_update_freq=100,
                 batch_averaged=True,
                 diag_blocks=1,
                 diag_warmup=0,
                 distribute_layer_factors=None,
                 sparse=False,
                 sparse_ratio=0.01,
                 exclude_parts=''):
        #exclude_parts='CommunicateInverse,ComputeInverse,CommunicateFactor,ComputeFactor'):

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 < factor_decay <= 1:
            raise ValueError(
                "Invalid factor decay rate: {}".format(factor_decay))
        if not 0.0 < damping:
            raise ValueError("Invalid damping: {}".format(damping))
        if not 0.0 < kl_clip:
            raise ValueError("Invalid clipping value: {}".format(kl_clip))
        if not 0 < fac_update_freq:
            raise ValueError(
                "Invalid factor update frequency: {}".format(fac_update_freq))
        if not 0 < kfac_update_freq:
            raise ValueError(
                "Invalid K-FAC update frequency: {}".format(kfac_update_freq))
        if kfac_update_freq % fac_update_freq != 0:
            print(
                "WARNING: it is suggested that kfac_update_freq be a multiple of fac_update_freq"
            )
        if not 0 < diag_blocks:
            raise ValueError(
                "Invalid diagonal block approx count: {}".format(diag_blocks))
        if not 0 <= diag_warmup:
            raise ValueError(
                "Invalid diagonal block warmup epochs: {}".format(diag_warmup))
        if diag_blocks != 1:
            print(
                "WARNING: diag_blocks > 1 is experimental and may give poor results."
            )

        # For compatibility with `KFACParamScheduler`
        defaults = dict(lr=lr,
                        damping=damping,
                        fac_update_freq=fac_update_freq,
                        kfac_update_freq=kfac_update_freq)

        super(KFAC, self).__init__(model.parameters(), defaults)

        self.computeA = ComputeA()
        self.computeG = ComputeG()
        self.known_modules = {'Linear', 'Conv2d'}
        self.modules = []
        self.module_names = []
        self.fw_factor_handles = []
        self.bw_factor_handles = []
        self.module_name_map = {}
        self._register_modules(model)

        self.steps = 0

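        # Merged communicators reduce the A factors captured by the forward
        # hooks and the G factors captured by the backward hooks (in reverse
        # module order); `multi_comm` broadcasts the computed inverses later.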
        self.fw_merged_comm = MergedCommReduce(self.module_names,
                                               prefix='forward',
                                               merge=True,
                                               single_layer=False,
                                               symmetric=True,
                                               fp16=False)
        self.bw_merged_comm = MergedCommReduce(self.module_names[::-1],
                                               prefix='backward',
                                               merge=False,
                                               single_layer=False,
                                               symmetric=False,
                                               fp16=False)
        self.multi_comm = MultiTensorComm(symmetric=True, fp16=False)

        # Dictionaries keyed by `module` for storing the factors and
        # eigendecompositions
        self.m_a, self.m_g = {}, {}
        self.m_A, self.m_G = {}, {}
        self.m_QA, self.m_QG = {}, {}
        self.m_dA, self.m_dG = {}, {}
        self.m_dA_ranks = {}
        self.m_dG_ranks = {}
        self.module_ranks = None
        self.eigen_ranks = None

        self.sparse = sparse
        self.sparse_ratio = sparse_ratio
        self.residualsA, self.residualsG = {}, {}

        self.factor_decay = factor_decay
        self.kl_clip = kl_clip
        self.fac_update_freq = fac_update_freq
        self.kfac_update_freq = kfac_update_freq
        self.diag_blocks = diag_blocks
        self.diag_warmup = diag_warmup
        self.batch_averaged = batch_averaged

        self.exclude_communicate_inverse = 'CommunicateInverse' in exclude_parts
        self.exclude_compute_inverse = 'ComputeInverse' in exclude_parts
        self.exclude_communicate_factor = 'CommunicateFactor' in exclude_parts
        self.exclude_compute_factor = 'ComputeFactor' in exclude_parts

        # Compute ideal value for `distribute_layer_factors` based on
        # registered module count
        if distribute_layer_factors is None:
            self.distribute_layer_factors = hvd.size() > len(self.modules)
        else:
            self.distribute_layer_factors = distribute_layer_factors

        self.have_cleared_Q = self.diag_warmup == 0
        self.eps = 1e-10  # for numerical stability
        self.rank_iter = cycle(list(range(hvd.size())))

    def _compute_forward_factor(self, module, input):
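        """Forward pre-hook: save the layer input and update/reduce factor A"""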
        if torch.is_grad_enabled() and self.steps % self.fac_update_freq == 0:
            self.m_a[module] = input[0].data
            if not self.exclude_compute_factor:
                self._update_module_A(module)
            if not self.exclude_communicate_factor:
                if hvd.size() > 1:
                    name = self.module_name_map[module]
                    if self.eigen_ranks is not None:
                        ranks_a, ranks_g = self.eigen_ranks[module]
                        rank_a = ranks_a[0]
                        self.fw_merged_comm.reduce_async_(
                            name, self.m_A[module].data, rank_a)

    def _compute_backward_factor(self, module, grad_input, grad_output):
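        """Backward hook: save the output gradient and update/reduce factor G"""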
        if self.steps % self.fac_update_freq == 0:
            self.m_g[module] = grad_output[0].data

            if not self.exclude_compute_factor:
                self._update_module_G(module)
            if not self.exclude_communicate_factor:
                if hvd.size() > 1:
                    name = self.module_name_map[module]
                    if self.eigen_ranks is not None:
                        ranks_a, ranks_g = self.eigen_ranks[module]
                        rank_g = ranks_g[0]
                        self.bw_merged_comm.reduce_async_(
                            name, self.m_G[module].data, rank_g)

    def _register_modules(self, model):
        """Register hooks to all supported layers in the model"""
        name_idx = 0
        for module in model.modules():
            classname = module.__class__.__name__
            if classname in self.known_modules:
                self.modules.append(module)
                module.register_forward_pre_hook(self._compute_forward_factor)
                module.register_backward_hook(self._compute_backward_factor)
                module_name = 'module_name_%s_%d' % (classname, name_idx)
                self.module_names.append(module_name)
                self.module_name_map[module] = module_name
                name_idx += 1

    def _init_A(self, factor, module):
        """Initialize memory for factor A and its eigendecomp"""
        self.m_A[module] = torch.diag(factor.new(factor.shape[0]).fill_(1))
        self.m_dA[module] = factor.new_zeros(factor.shape[0])
        self.m_QA[module] = factor.new_zeros(factor.shape)

    def _init_G(self, factor, module):
        """Initialize memory for factor G and its eigendecomp"""
        self.m_G[module] = torch.diag(factor.new(factor.shape[0]).fill_(1))
        self.m_dG[module] = factor.new_zeros(factor.shape[0])
        self.m_QG[module] = factor.new_zeros(factor.shape)

    def _clear_eigen(self):
        """Clear eigendecompositions

        Useful when switching between `diag_blocks=1` and `diag_blocks>1`
        because the eigendecompositions are saved in place and the
        off-diagonal blocks must be cleared.
        """
        for module in self.modules:
            self.m_QA[module].fill_(0)
            self.m_QG[module].fill_(0)
            self.m_dA[module].fill_(0)
            self.m_dG[module].fill_(0)

    def _update_module_A(self, module):
        a = self.computeA(self.m_a[module], module)
        if self.steps == 0:
            self._init_A(a, module)
        update_running_avg(a, self.m_A[module], self.factor_decay)
        if self.sparse:
            sparsification(self.m_A[module],
                           module,
                           ratio=self.sparse_ratio,
                           residuals=self.residualsA)

    def _update_A(self):
        """Compute and update factor A for all modules"""
        for module in self.modules:
            self._update_module_A(module)

    def _update_module_G(self, module):
        g = self.computeG(self.m_g[module], module, self.batch_averaged)
        #logger.info('G Name: %s, shape: %s', module, g.shape)
        if self.steps == 0:
            self._init_G(g, module)
        update_running_avg(g, self.m_G[module], self.factor_decay)
        if self.sparse:
            sparsification(self.m_G[module],
                           module,
                           ratio=self.sparse_ratio,
                           residuals=self.residualsG)

    def _update_G(self):
        """Compute and update factor G for all modules"""
        for module in self.modules:
            self._update_module_G(module)

    def _update_eigen_A(self, module, ranks):
        """Compute eigendecomposition of A for module on specified workers

        Note: all ranks will enter this function but only the ranks specified
        in `ranks` will continue to actually compute the eigendecomposition.
        All other ranks will simply zero out their buffer for the 
        eigendecomposition for the current module. This is done so we can sum
        the eigendecompositions across all ranks to communicate the results
        of locally computed eigendecompositions.

        Args:
          module: module to compute eigendecomposition of A on
          ranks: list of horovod ranks (i.e. workers) to use when computing
              the eigendecomposition.
        """
        if hvd.rank() in ranks:
            self._distributed_compute_eigen(self.m_A[module],
                                            self.m_QA[module],
                                            self.m_dA[module], ranks)
        else:
            self.m_QA[module].fill_(0)
            self.m_dA[module].fill_(0)

    def _update_eigen_G(self, module, ranks):
        """Compute eigendecomposition of A for module on specified workers

        See `_update_eigen_A` for more info`
        """
        if hvd.rank() in ranks:
            self._distributed_compute_eigen(self.m_G[module],
                                            self.m_QG[module],
                                            self.m_dG[module], ranks)
        else:
            self.m_QG[module].fill_(0)
            self.m_dG[module].fill_(0)

    def _distributed_compute_eigen(self, factor, evectors, evalues, ranks):
        """Computes the eigendecomposition of a factor across ranks
        
        Assigns each rank in `ranks` to enter this function to compute a
        diagonal block of `factor`. Results are written to `evectors` and
        `evalues`. If `len(ranks)==1`, then that rank computes the
        eigendecomposition of the entire `factor`.

        Args:
            factor (tensor): tensor to eigendecompose
            evectors (tensor): tensor to save eigenvectors of `factor` to
            evalues (tensor): tensor to save eigenvalues of `factor` to
            ranks (list): list of ranks that will enter this function
        """
        i = ranks.index(hvd.rank())
        n = len(ranks)
        if n > min(factor.shape):
            n = min(factor.shape)

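        # Each assigned rank inverts one diagonal block of the factor; ranks
        # beyond the available block count (i >= n) have nothing to compute.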
        if i < n:
            start, end = get_block_boundary(i, n, factor.shape)
            block = factor[start[0]:end[0], start[1]:end[1]]
            block = add_value_to_diagonal(block, self.damping)
            inverse = torchsso.utils.inv(block)

            evectors.data[start[0]:end[0], start[1]:end[1]].copy_(inverse)

    def _get_diag_blocks(self, module, diag_blocks):
        """Helper method for determining number of diag_blocks to use

        Overrides `diag_blocks` if the `module` does not support
        `diag_blocks>1`, e.g. for a Linear layer we do not want to use
        `diag_blocks>1`.

        Args:
          module: module
          diag_blocks (int): default number of diag blocks to use
        """
        return diag_blocks if module.__class__.__name__ == 'Conv2d' else 1

    def _get_grad(self, module):
        """Get formated gradient of module

        Args:
          module: module/layer to get gradient of

        Returns:
          Formatted gradient with shape [output_dim, input_dim] for module
        """
        if module.__class__.__name__ == 'Conv2d':
            # n_filters * (in_c * kw * kh)
            grad = module.weight.grad.data.view(
                module.weight.grad.data.size(0), -1)
        else:
            grad = module.weight.grad.data
        if module.bias is not None:
            grad = torch.cat([grad, module.bias.grad.data.view(-1, 1)], 1)
        return grad

    def _get_preconditioned_grad(self, module, grad):
        """Precondition gradient of module
        
        Args:
          module: module to compute preconditioned gradient for
          grad: formatted gradient from `_get_grad()`

        Returns:
          preconditioned gradient with same shape as `grad`
        """
        #v = self.m_QG[module].t() @ grad @ self.m_QA[module]
        v = self.m_QG[module] @ grad @ self.m_QA[module]

        if module.bias is not None:
            v = [v[:, :-1], v[:, -1:]]
            v[0] = v[0].view(module.weight.grad.data.size())  # weight
            v[1] = v[1].view(module.bias.grad.data.size())  # bias
        else:
            v = [v.view(module.weight.grad.data.size())]
        return v

    def _update_scale_grad(self, updates):
        """Update the gradients in place and scale

        Updates the gradients in-place for all modules using the preconditioned
        gradients and scales the gradients.

        Args:
          updates (dict): dict of {module: precon_grad}
        """
        vg_sum = 0
        for module in self.modules:
            v = updates[module]
            vg_sum += (v[0] * module.weight.grad.data *
                       self.lr**2).sum().item()
            if module.bias is not None:
                vg_sum += (v[1] * module.bias.grad.data *
                           self.lr**2).sum().item()
        if self.exclude_communicate_inverse:
            nu = 1
        else:
            nu = min(1.0, math.sqrt(self.kl_clip / abs(vg_sum)))

        for module in self.modules:
            v = updates[module]
            module.weight.grad.data.copy_(v[0])
            module.weight.grad.data.mul_(nu)
            if module.bias is not None:
                module.bias.grad.data.copy_(v[1])
                module.bias.grad.data.mul_(nu)

    def step(self, closure=None, epoch=None):
        """Perform one K-FAC step

        Note:
        - this function should always be called before `optimizer.step()`
        - gradients must be averaged across ranks before calling `step()`

        Args:
          closure: for compatibility with the base optimizer class.
              `closure` is ignored by KFAC
          epoch (int, optional): epoch to use for determining when to end
              the `diag_warmup` period. `epoch` is not necessary if not using
              `diag_warmup`
        """

        # Update params, used for compatibility with `KFACParamScheduler`
        group = self.param_groups[0]
        self.lr = group['lr']
        self.damping = group['damping']
        self.fac_update_freq = group['fac_update_freq']
        self.kfac_update_freq = group['kfac_update_freq']

        updates = {}
        handles = []

        if epoch is None:
            if self.diag_warmup > 0:
                print("WARNING: diag_warmup > 0 but epoch was not passed to "
                      "KFAC.step(). Defaulting to no diag_warmup")
            diag_blocks = self.diag_blocks
        else:
            diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

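        # On factor-update steps, either compute and reduce the factors for
        # the first time (establishing the layer-to-rank assignment) or wait
        # for the asynchronous reductions launched by the hooks to finish.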
        if self.steps % self.fac_update_freq == 0:

            if self.eigen_ranks is None:
                if not self.exclude_compute_factor:
                    self._update_A()
                    self._update_G()
                #self.eigen_ranks = self._generate_eigen_ranks_uniform(epoch)
                #self.eigen_ranks = self._generate_eigen_ranks_naive(epoch)
                self.eigen_ranks = self._generate_eigen_ranks_match_merging(
                    epoch)
                if not self.exclude_communicate_factor:
                    if hvd.size() > 1:
                        self._reduce_factors(self.eigen_ranks)
            else:
                if not self.exclude_communicate_factor:
                    if hvd.size() > 1:
                        self.fw_merged_comm.synchronize()
                        self.bw_merged_comm.synchronize()
        eigen_ranks = self.eigen_ranks

        # if we are switching from no diag approx to approx, we need to clear
        # off-block-diagonal elements
        if not self.have_cleared_Q and \
                epoch == self.diag_warmup and \
                self.steps % self.kfac_update_freq == 0:
            self._clear_eigen()
            self.have_cleared_Q = True

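        # On K-FAC update steps, recompute the damped factor inverses on their
        # owner ranks and broadcast them to all workers.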
        if self.steps % self.kfac_update_freq == 0:
            # reset rank iter so each device gets the same layers
            # to compute, to take advantage of caching
            self.rank_iter.reset()
            handles = []

            for module in self.modules:
                ranks_a, ranks_g = eigen_ranks[module]
                self.m_dA_ranks[module] = ranks_a[0]
                self.m_dG_ranks[module] = ranks_g[0]
                rank_a = ranks_a[0]
                rank_g = ranks_g[0]

                if not self.exclude_compute_inverse:
                    self._update_eigen_A(module, ranks_a)
                    self._update_eigen_G(module, ranks_g)

            if not self.exclude_communicate_inverse:
                if hvd.size() > 1:
                    self._broadcast_eigendecomp()
            elif not self.exclude_compute_inverse:
                # should have a barrier here
                if hvd.size() > 1:
                    barrier()

        for i, module in enumerate(self.modules):
            grad = self._get_grad(module)
            precon_grad = self._get_preconditioned_grad(module, grad)
            updates[module] = precon_grad

        self._update_scale_grad(updates)

        self.steps += 1

    def _generate_eigen_ranks_naive(self, epoch):
        if self.module_ranks is not None:
            return self.module_ranks
        module_ranks = {}
        diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1
        buckets = [0] * hvd.size()
        for module in self.modules:
            # Get ranks to compute this layer on
            n = self._get_diag_blocks(module, diag_blocks)
            ranks_a = self.rank_iter.next(n)
            ranks_g = self.rank_iter.next(n) if self.distribute_layer_factors \
                                             else ranks_a
            module_ranks[module] = (ranks_a, ranks_g)
            buckets[ranks_a[0]] += self.m_A[module].shape[1]
            buckets[ranks_g[0]] += self.m_G[module].shape[1]
        self.module_ranks = module_ranks
        if hvd.rank() == 0:
            logger.info('buckets: %s', buckets)
            logger.info('module_ranks: %s', module_ranks.values())
        return module_ranks

    def _generate_eigen_ranks_match_merging(self, epoch):
        if self.module_ranks is not None:
            return self.module_ranks
        module_ranks = {}
        diag_blocks = self.diag_blocks \
            if epoch is not None and epoch >= self.diag_warmup else 1
        assigned_rank = 0
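        # Assign modules to ranks in contiguous groups of three, cycling over
        # ranks, so consecutive layers share an owner rank (matching the
        # merged factor communication order).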
        for i, module in enumerate(self.modules):
            # Get ranks to compute this layer on
            if i > 0 and i % 3 == 0:
                assigned_rank += 1
                assigned_rank %= hvd.size()
            rank = assigned_rank
            ranks_a = (rank, )
            ranks_g = (rank, )
            module_ranks[module] = (ranks_a, ranks_g)
        self.module_ranks = module_ranks
        return module_ranks

    def _generate_eigen_ranks_uniform(self, epoch):
        if self.module_ranks is not None:
            return self.module_ranks
        module_ranks = {}
        diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1
        buckets = [0] * hvd.size()
        dimensions = []
        module_factors = []
        for i, m in enumerate(self.modules):
            name = self.module_names[i]
            a_dimension = self.m_A[m].shape[1]
            #g_dimension = self.m_G[m].shape[1]
            dimensions.append(a_dimension)
            module_factors.append(name + '-A')
            #dimensions.append(g_dimension)
            #module_factors.append(name+'-G')

        descending_sorted_idx = np.argsort(dimensions)[::-1]
        A_ranks = {}
        G_ranks = {}
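        # Greedy load balancing: visit factors in descending order of
        # dimension and assign each one to the currently least-loaded rank.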
        for i in descending_sorted_idx:
            factor = module_factors[i]
            dimension = dimensions[i]
            m_i = self.module_names.index(factor[0:-2])
            m = self.modules[m_i]

            bi = np.argmin(buckets)
            buckets[bi] += dimension
            if factor[-1] == 'A':
                A_ranks[m] = (bi, )
                G_ranks[m] = (bi, )
            else:
                G_ranks[m] = (bi, )
        for m in self.modules:
            module_ranks[m] = (A_ranks[m], G_ranks[m])

        self.module_ranks = module_ranks
        if hvd.rank() == 0:
            logger.info('buckets: %s', buckets)
            logger.info('module_ranks: %s', module_ranks.values())
        return module_ranks

    def _generate_eigen_ranks(self, epoch):
        if self.module_ranks is not None:
            return self.module_ranks
        module_ranks = {}
        diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1
        buckets = [0] * hvd.size()

        for module in self.modules:
            i = np.argmin(buckets)
            if hvd.rank() == 0:
                logger.info('A Name: %s, shape: %s', module,
                            self.m_A[module].shape)
                logger.info('G Name: %s, shape: %s', module,
                            self.m_G[module].shape)
            a_dimension = self.m_A[module].shape[1]
            g_dimension = self.m_G[module].shape[1]
            #buckets[i] += (a_dimension) + g_dimension)
            buckets[i] += a_dimension
            ranks_a = (i, )
            i = np.argmin(buckets)
            ranks_g = (i, )
            buckets[i] += g_dimension

            module_ranks[module] = (ranks_a, ranks_g)
        self.module_ranks = module_ranks
        if hvd.rank() == 0:
            logger.info('buckets: %s', buckets)
            logger.info('module_ranks: %s', module_ranks.values())
        return module_ranks

    def _reduce_factors(self, eigen_ranks):
        """Allreduce the factors for all layers"""
        handles = []

        for m in self.modules:
            name = self.module_name_map[m]
            ranks_a, ranks_g = eigen_ranks[m]
            rank_a = ranks_a[0]
            rank_g = ranks_g[0]

            self.fw_merged_comm.reduce_async_(name, self.m_A[m].data, rank_a)
            self.bw_merged_comm.reduce_async_(name, self.m_G[m].data, rank_g)
        self.fw_merged_comm.synchronize()
        self.bw_merged_comm.synchronize()

        #for m in self.modules:
        #    name = self.module_name_map[m]
        #    ranks_a, ranks_g = eigen_ranks[m]
        #    rank_a = ranks_a[0]
        #    rank_g = ranks_g[0]
        #    if rank_a == hvd.rank():
        #        self.m_A[m].data.div_(hvd.size())
        #    if rank_g == hvd.rank():
        #        self.m_G[m].data.div_(hvd.size())

    def _allgather_factors(self):
        """Allgather the factors for all layers"""
        handles = []

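        # Pack each factor into its nonzero values and flat indices so that
        # only the nonzeros are allgathered, then scatter-add the gathered
        # entries back and average over workers.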
        def _get_value_and_idx(sparse_tensor):
            tensor = sparse_tensor.data.view(-1)
            one_indexes = tensor != 0
            indexes = one_indexes.nonzero().data.squeeze().view(-1)
            values = tensor.data[indexes]
            return values, indexes.int()

        for i, m in enumerate(self.modules):
            module_name = self.module_names[i]

            A_values, A_indexes = _get_value_and_idx(self.m_A[m].data)
            A_value_name = module_name + '_A_value'
            A_idx_name = module_name + '_A_idx'
            h_value = allgather_async(A_values, A_value_name)
            h_idx = allgather_async(A_indexes, A_idx_name)

            G_values, G_indexes = _get_value_and_idx(self.m_G[m].data)
            G_value_name = module_name + '_G_value'
            G_idx_name = module_name + '_G_idx'
            h_value_G = allgather_async(G_values, G_value_name)
            h_idx_G = allgather_async(G_indexes, G_idx_name)
            handles.append((h_value, h_idx, h_value_G, h_idx_G))

        for i, handle in enumerate(handles):
            module_name = self.module_names[i]
            module = self.modules[i]
            m_A = self.m_A[module].view(-1)
            m_A.fill_(0.0)
            m_G = self.m_G[module].view(-1)
            m_G.fill_(0.0)

            h_value_A, h_idx_A, h_value_G, h_idx_G = handle
            A_values = hvd.synchronize(h_value_A)
            A_indexes = hvd.synchronize(h_idx_A).long()
            m_A.scatter_add_(0, A_indexes, A_values)
            m_A.div_(hvd.size())

            G_values = hvd.synchronize(h_value_G)
            G_indexes = hvd.synchronize(h_idx_G).long()
            m_G.scatter_add_(0, G_indexes, G_values)
            m_G.div_(hvd.size())

    def _allreduce_eigendecomp(self):
        """Allreduce the eigendecompositions for all layers

        Note: we use `op=hvd.Sum` to simulate an allgather. Each rank will
        either compute the eigendecomposition for a factor or just return
        zeros so we sum instead of averaging.
        """
        handles = []

        for m in self.modules:
            handles.append(hvd.allreduce_async_(self.m_QA[m].data, op=hvd.Sum))
            handles.append(hvd.allreduce_async_(self.m_QG[m].data, op=hvd.Sum))
            handles.append(hvd.allreduce_async_(self.m_dA[m].data, op=hvd.Sum))
            handles.append(hvd.allreduce_async_(self.m_dG[m].data, op=hvd.Sum))

        for handle in handles:
            hvd.synchronize(handle)

    def _broadcast_eigendecomp(self):
        """Broadcasts the eigendecompositions for all layers

        Note: each layer's inverses are broadcast from the rank that computed
        them (recorded in `m_dA_ranks`/`m_dG_ranks`) to every other worker.
        """
        rank = hvd.rank()

        for i, m in enumerate(self.modules):
            rank_a = self.m_dA_ranks[m]
            rank_g = self.m_dG_ranks[m]
            name = self.module_names[i]

            self.multi_comm.bcast_async_([name + 'mQA'], [self.m_QA[m]],
                                         rank_a)
            self.multi_comm.bcast_async_([name + 'mQG'], [self.m_QG[m]],
                                         rank_g)
        self.multi_comm.synchronize()