def global_add_by_abs_grad(sparse_modules, num_add):
    """
    Adds weights globally (among all given sparse modules) by ranking and
    selecting the top absolute gradients. Weights are added by adjusting the
    `zero_masks` of their modules; switching on the weight occurs by changing
    1 to 0. Be sure to call `rezero_weights` following this function to
    initialize the new weights to zero.

    :param sparse_modules: list of modules of type SparseWeightsBase
    :param num_add: number of weights to add
    """
    all_grads = parameters_to_vector(
        [m.weight.grad for m in sparse_modules])
    off_mask = parameters_to_vector(
        [m.zero_mask for m in sparse_modules]).bool()

    # Among the currently-off weights, pick the ones with the largest
    # absolute gradients.
    selected = get_topk_submask(
        k=num_add, values=all_grads.abs(), mask=off_mask, largest=True)

    # Walk each module's slice of the flattened selection and switch those
    # weights on (1 -> 0 in the zero mask).
    start = 0
    for module in sparse_modules:
        end = start + module.weight.numel()
        chosen = selected[start:end].view_as(module.weight)
        module.zero_mask[chosen] = 0
        start = end
def test_topk_submask_structure(self):
    """The submask selects exactly k positions, all within the given mask."""
    k = 5
    values = torch.rand(5, 6, 7)
    mask = torch.rand_like(values) < 0.5

    submask = get_topk_submask(k, values, mask, largest=True)

    # Every selected position must come from the original mask.
    self.assertTrue((submask <= mask).all())
    # Exactly k positions are selected.
    self.assertEqual(submask.sum(), k)
def global_prune_by_abs_weight(sparse_modules, prune_fraction=None,
                               num_remove=None):
    """
    Globally prune 'num_remove' weights from a list of sparse modules by
    ranking and selecting the top absolute weights. If prune_fraction is
    given, then num_remove is calculated. Modules are pruned adjusting
    their `zero_masks`; switching off the weight occurs by changing 0 to 1.
    Be sure to call `rezero_weights` following this function to zero out the
    pruned weights.

    :param sparse_modules: list of modules of type SparseWeightsBase
    :param prune_fraction: fraction of weights to prune; between 0 and 1;
                           can't be specified if num_remove is not None
    :param num_remove: how many parameters to remove; can't be specified if
                       prune_fraction is not None
    :return: the number of weights removed
    """
    # Exactly one of the two arguments must be given. The original check only
    # rejected passing both; passing neither slipped through and crashed later
    # with a confusing ``TypeError`` (``total_on - None``).
    assert (prune_fraction is None) != (num_remove is None), \
        "Specify exactly one of prune_fraction or num_remove"

    # Flatten parameters to compare them all at once for global pruning.
    flattened_params = parameters_to_vector([m.weight for m in sparse_modules])
    flattened_off_mask = parameters_to_vector(
        [m.zero_mask for m in sparse_modules])
    flattened_on_mask = ~flattened_off_mask.bool()

    # Calculate the number of parameters to keep.
    total_on = flattened_on_mask.sum().item()
    if prune_fraction is not None:
        assert 0 <= prune_fraction <= 1
        num_remove = int(round(total_on * prune_fraction))
    num_keep = total_on - num_remove

    # Prune by only keeping the top weights ranked by their absolute values.
    topk_mask = get_topk_submask(
        k=num_keep, values=flattened_params.abs(),
        mask=flattened_on_mask, largest=True)

    # Pointer for slicing the mask to match the shape of each parameter.
    pointer = 0
    for module in sparse_modules:
        num_params = module.weight.numel()
        keep_mask = topk_mask[pointer:pointer + num_params].view_as(
            module.weight)
        # Invert: positions NOT kept are switched off (0 -> 1 in zero mask).
        module.zero_mask[:] = ~keep_mask
        pointer += num_params

    return num_remove
def test_get_topk_submask(self):
    """Largest-k selection matches a hand-computed expected submask."""
    values = torch.tensor([[0.5086, 0.5467, 0.2095],
                           [0.9721, 0.2540, 0.2837],
                           [0.4696, 0.9867, 0.6543]])
    mask = torch.tensor([[True, True, False],
                         [True, True, False],
                         [True, False, True]])
    expected = torch.tensor([[False, True, False],
                             [True, False, False],
                             [False, False, True]])

    submask = get_topk_submask(3, values, mask, largest=True)

    # Selected positions must be a subset of the original mask.
    self.assertTrue((submask <= mask).all())
    # The selection matches the hand-computed result.
    self.assertTrue((submask == expected).all())
def test_get_bottomk_submask(self):
    """Smallest-k selection (largest=False) matches a hand-computed submask."""
    values = torch.tensor([[0.5184, 0.1562, 0.3428],
                           [0.7742, 0.8507, 0.0986],
                           [0.3525, 0.8384, 0.4315]])
    mask = torch.tensor([[True, False, True],
                         [True, True, False],
                         [True, False, False]])
    expected = torch.tensor([[True, False, True],
                             [False, False, False],
                             [True, False, False]])

    submask = get_topk_submask(3, values, mask, largest=False)

    # Selected positions must be a subset of the original mask.
    self.assertTrue((submask <= mask).all())
    # The selection matches the hand-computed result.
    self.assertTrue((submask == expected).all())
def local_add_by_abs_grad(sparse_modules, num_add):
    """
    Adds `num_add` weights distributed across the modules by ranking and
    selecting the top absolute gradients for each `sparse_module`. Weights
    are added by adjusting the sparse module `zero_masks`. Be sure to call
    `rezero_weights` following this function to initialize the new weights
    to zero.

    :param sparse_modules: list of modules of type SparseWeightsBase
    :param num_add: list with number of weights to add per sparse_module,
                    usually the output of :meth:`local_prune_by_abs_weight`
    """
    for idx, module in enumerate(sparse_modules):
        flat_grads = module.weight.grad.flatten()
        off = module.zero_mask.detach().flatten().bool()

        # Among this module's currently-off weights, pick the ones with the
        # largest absolute gradients.
        picked = get_topk_submask(
            k=num_add[idx], values=flat_grads.abs(), mask=off, largest=True)

        # Switch the picked weights on (1 -> 0 in the zero mask).
        module.zero_mask[picked.view_as(module.weight)] = 0.
def local_prune_by_abs_weight(sparse_modules, prune_fraction):
    """
    Prune `prune_fraction` of the weights of each module in `sparse_modules`.
    Modules are pruned by ranking and selecting the top absolute weights of
    each module and adjusting the sparse module `zero_mask`s. Be sure to call
    `rezero_weights` following this function to zero out the pruned weights.

    :param sparse_modules: list of modules of type SparseWeightsBase
    :param prune_fraction: fraction of weights to prune; between 0 and 1

    :return: list containing the number of weights removed for each
             `sparse_module`
    """
    assert 0 <= prune_fraction <= 1

    removed_counts = []
    for module in sparse_modules:
        flat_weights = module.weight.detach().flatten()
        off = module.zero_mask.detach().flatten().bool()
        on = ~off

        # How many currently-on weights survive this round of pruning.
        on_count = on.sum()
        n_remove = (on_count * prune_fraction).round()
        keep_k = int(on_count - n_remove)

        # Keep only the top weights ranked by their absolute values.
        keep = get_topk_submask(
            k=keep_k, values=flat_weights.abs(), mask=on, largest=True)

        # Everything not kept is switched off (0 -> 1 in the zero mask).
        module.zero_mask[:] = (~keep).view_as(module.weight)
        removed_counts.append(int(n_remove))

    return removed_counts