def test_magnitude_pruning():
    """Exercise MagnitudeParameterPruner together with ParameterMasker.

    Checks that:
      * constructing the pruner without a default ("*") threshold fails;
      * creating a mask at threshold 0.4 masks exactly one element of a
        tensor of 1s that contains a single 0.2;
      * an existing mask can be re-applied to another tensor of the same
        shape (pruning with the stored mask, not re-thresholding).
    """
    # Create a 4-D tensor of 1s
    a = torch.ones(3, 64, 32, 32)
    # Change one element so it falls below the 0.4 threshold used later
    a[1, 4, 17, 31] = 0.2
    # Create a masks dictionary and populate it with one ParameterMasker
    zeros_mask_dict = {}
    masker = distiller.ParameterMasker('a')
    zeros_mask_dict['a'] = masker

    # Creating a MagnitudeParameterPruner WITHOUT a default threshold must fail
    with pytest.raises(AssertionError):
        distiller.pruning.MagnitudeParameterPruner("test", None)

    # Now define the default threshold
    thresholds = {"*": 0.4}
    pruner = distiller.pruning.MagnitudeParameterPruner("test", thresholds)
    assert distiller.sparsity(a) == 0

    # Create a mask for parameter 'a'
    pruner.set_param_mask(a, 'a', zeros_mask_dict, None)
    assert common.almost_equal(distiller.sparsity(zeros_mask_dict['a'].mask), 1/distiller.volume(a))

    # Let's now use the masker to prune a parameter
    masker = zeros_mask_dict['a']
    masker.apply_mask(a)
    assert common.almost_equal(distiller.sparsity(a), 1/distiller.volume(a))

    # We can use the masker on other tensors, if we want (and if they have the correct shape).
    # Remember that the mask was created already, so we're not thresholding - we are pruning:
    # every element of b (0.3) is below 0.4, yet only the single masked position is zeroed.
    b = torch.ones(3, 64, 32, 32)
    b[:] = 0.3
    masker.apply_mask(b)
    assert common.almost_equal(distiller.sparsity(b), 1/distiller.volume(a))
def test_sparsity():
    """An all-zeros tensor is fully sparse; an all-ones tensor has no sparsity."""
    all_zeros = torch.zeros(2, 3, 5, 6)
    print(distiller.sparsity(all_zeros))
    assert distiller.sparsity(all_zeros) == 1.0

    all_ones = torch.ones(12, 43, 4, 6)
    assert distiller.sparsity(all_ones) == 0.0
def test_row_thresholding():
    """Row-wise 'Max' thresholding at 7 keeps only rows 2 and 3 of the test tensor.

    Also checks that masking the tensor yields the same sparsity as the mask.
    Requires a CUDA device (the test tensor is moved with ``.cuda()``).

    Returns:
        The computed mask. NOTE(review): pytest warns when a test returns a
        non-None value -- confirm no caller relies on it before removing.
    """
    p = get_test_2d_tensor().cuda()
    # `binary_map` (not `map`) so the builtin is not shadowed
    mask, binary_map = distiller.group_threshold_mask(p, 'Rows', 7, 'Max')
    assert torch.eq(binary_map, torch.tensor([0., 0., 1., 1.], device=mask.device)).all()
    assert torch.eq(
        mask, torch.tensor([[0., 0., 0.],
                            [0., 0., 0.],
                            [1., 1., 1.],
                            [1., 1., 1.]], device=mask.device)).all()
    masked_tensor = distiller.mask_tensor(p, mask)
    assert distiller.sparsity(masked_tensor) == distiller.sparsity(mask)
    return mask
def test_level_mask():
    """A level-criterion mask requested at 30% sparsity is ~30% sparse."""
    # Create a 4-D tensor of uniformly-distributed values in [0, 1)
    a = torch.rand(3, 64, 32, 32)
    # Create a mask targeting 30% sparsity (the mask is not applied to `a`)
    mask = distiller.create_mask_level_criterion(a, desired_sparsity=0.3)
    assert common.almost_equal(distiller.sparsity(mask), 0.3, max_diff=0.0001)
def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
    """Mask out the filters of a 4-D weight with the smallest mean absolute value.

    When ``binary_map`` is None, filters are ranked by mean |w| and the value
    of the ``int(fraction_to_prune * num_filters)``-th smallest one becomes the
    threshold passed to ``distiller.group_threshold_mask`` ('Filters',
    'Mean_Abs'). When a ``binary_map`` is given, it is handed to
    ``group_threshold_mask`` unchanged and no ranking is done.

    The mask is stored in ``zeros_mask_dict[param_name].mask`` if a dictionary
    is supplied. ``model`` is not used.

    Returns:
        The binary map from ``group_threshold_mask``, or None when the
        fraction is too small to select a single filter.
    """
    assert param.dim() == 4, "This thresholding is only supported for 4D weights"
    threshold = None
    if binary_map is None:
        per_filter = param.view(param.size(0), -1)
        mean_abs = per_filter.data.abs().mean(dim=1)
        n_filters = mean_abs.size(0)
        n_prune = int(fraction_to_prune * n_filters)
        if n_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return
        smallest, _ = torch.topk(mean_abs, n_prune, largest=False, sorted=True)
        threshold = smallest[-1]
        msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=(%d/%d)",
                       param_name, n_prune, n_filters)

    # Threshold (or apply the supplied binary map) to obtain the full mask
    mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, 'Mean_Abs', binary_map)
    if zeros_mask_dict is not None:
        zeros_mask_dict[param_name].mask = mask
    msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                   param_name, distiller.sparsity(mask), fraction_to_prune)
    return binary_map
def rank_and_prune_rows(fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None,
                        magnitude_fn=l1_magnitude, group_size=1):
    """Prune matrix columns (the "rows" of W^T), ranked by magnitude.

    PyTorch keeps weight matrices in a transposed layout: the output is
    computed as y = x(W^T) + b. Removing input channels means removing rows
    of W^T, which are the columns of W. Instead of transposing (expensive for
    very large matrices), this function ranks and thresholds the columns of W
    directly. Ranking columns is itself not cache-friendly, because
    consecutive column elements are far apart in memory.

    Ranking is delegated to ``LpRankedStructureParameterPruner.rank_rows``.
    The last of the returned lowest magnitudes becomes the threshold for a
    'Cols' group mask, which is written to ``zeros_mask_dict[param_name].mask``.
    ``model``, ``binary_map`` and ``group_size`` are accepted but not used here.

    Returns:
        The binary map produced by ``distiller.group_threshold_mask``.
    """
    ROWS_DIM = 0
    lowest_cols, col_mags = LpRankedStructureParameterPruner.rank_rows(magnitude_fn, fraction_to_prune, param)
    cutoff = lowest_cols[-1]
    norm_name = 'L1' if magnitude_fn == l1_magnitude else 'L2'
    mask, binary_map = distiller.group_threshold_mask(param, 'Cols', cutoff, norm_name)
    zeros_mask_dict[param_name].mask = mask

    total_cols = col_mags.size(ROWS_DIM)
    n_pruned_cols = int(fraction_to_prune * total_cols)
    msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                   norm_name, param_name,
                   distiller.sparsity(zeros_mask_dict[param_name].mask),
                   fraction_to_prune, n_pruned_cols, total_cols)
    return binary_map
def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None,
                           magnitude_fn=distiller.norms.l1_norm, noise=0.0, group_size=1, rounding_fn=math.floor):
    """Compute (or apply) a per-filter binary map for a 3-D/4-D weight.

    Without a ``binary_map``, filters are ranked with
    ``distiller.norms.rank_filters``. Every filter whose magnitude is strictly
    greater than the largest of the selected lowest-ranked ones is kept. A
    supplied ``binary_map`` skips the ranking.

    If ``zeros_mask_dict`` is given, the binary map is expanded into a full
    mask, stored in ``zeros_mask_dict[param_name].mask`` and logged.
    ``model`` is not used.

    Returns:
        The binary map, or None when ``rank_filters`` selects nothing.
    """
    assert param.dim() in (4, 3), "This pruning is only supported for 3D and 4D weights"
    if binary_map is None:
        ranked, mags = distiller.norms.rank_filters(param, group_size, magnitude_fn,
                                                    fraction_to_prune, rounding_fn, noise)
        if ranked is None:
            # None signals that fraction_to_prune is too low to select any filter
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100 * fraction_to_prune)
            return
        binary_map = mags.gt(ranked[-1]).type(param.data.type())

    if zeros_mask_dict is None:
        return binary_map

    mask, _ = distiller.thresholding.expand_binary_map(param, 'Filters', binary_map)
    zeros_mask_dict[param_name].mask = mask
    msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                   magnitude_fn, param_name, distiller.sparsity(mask), fraction_to_prune)
    return binary_map
def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None,
                           magnitude_fn=l1_magnitude):
    """Mask out the lowest-magnitude filters of a 4-D weight.

    Each filter is flattened and scored with ``magnitude_fn(..., dim=1)``. The
    score is labelled 'L1' when ``magnitude_fn`` is ``l1_magnitude`` and 'L2'
    otherwise. The ``int(fraction_to_prune * num_filters)``-th smallest score
    becomes the threshold for ``distiller.group_threshold_mask('Filters')``.
    A supplied ``binary_map`` skips the ranking and is passed through.

    The mask is stored in ``zeros_mask_dict[param_name].mask`` if a dictionary
    is given. ``model`` is not used.

    Returns:
        The binary map from ``group_threshold_mask``, or None when the
        fraction is too small to select a single filter.
    """
    assert param.dim() == 4, "This pruning is only supported for 4D weights"
    norm_name = 'L1' if magnitude_fn == l1_magnitude else 'L2'
    threshold = None
    if binary_map is None:
        scores = magnitude_fn(param.view(param.size(0), -1), dim=1)
        n_prune = int(fraction_to_prune * scores.size(0))
        if n_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return
        lowest, _ = torch.topk(scores, n_prune, largest=False, sorted=True)
        threshold = lowest[-1]
        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=(%d/%d)",
                       norm_name, param_name, n_prune, scores.size(0))

    mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, norm_name, binary_map)
    if zeros_mask_dict is not None:
        zeros_mask_dict[param_name].mask = mask
    msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                   norm_name, param_name, distiller.sparsity(mask), fraction_to_prune)
    return binary_map
def rank_and_prune_rows(fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None,
                        magnitude_fn=l1_magnitude):
    """Prune the columns of a 2-D weight matrix, ranked by magnitude.

    PyTorch stores weight matrices in a transposed format, i.e. before
    performing GEMM the matrix is transposed. This is counter-intuitive. To
    deal with it we could transpose the matrix and compute the masks as usual.
    Instead, columns are treated as rows (and vice versa), because transposing
    very large matrices can be detrimental to performance. Computing norms of
    columns is not optimal either: consecutive column elements are far apart
    in memory, which makes poor use of caches and system memory.

    Concretely, ``magnitude_fn(param, dim=0)`` reduces over dim 0, giving
    (presumably) one magnitude per column. The ``int(fraction_to_prune * n)``
    smallest magnitudes are selected, and the largest of them is the threshold
    for a 'Cols' group mask. ``model`` and ``binary_map`` are accepted but not
    used here.

    Args:
        zeros_mask_dict: mapping of param name to masker; the mask is stored in
            ``zeros_mask_dict[param_name].mask``. May be None, in which case
            only the binary map is computed (consistent with the filter
            pruners).

    Returns:
        The binary map from ``distiller.group_threshold_mask``, or None when
        the fraction is too small to prune anything.
    """
    assert param.dim() == 2, "This pruning is only supported for 2D weights"
    ROWS_DIM = 0
    THRESHOLD_DIM = 'Cols'
    rows_mags = magnitude_fn(param, dim=ROWS_DIM)
    num_rows_to_prune = int(fraction_to_prune * rows_mags.size(0))
    if num_rows_to_prune == 0:
        # Message previously said "filters" although this function prunes rows/columns
        msglogger.info("Too few rows - can't prune %.1f%% rows", 100*fraction_to_prune)
        return
    bottomk_rows, _ = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True)
    threshold = bottomk_rows[-1]
    threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
    mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM, threshold, threshold_type)
    # Allow callers to request only the binary map, like the sibling pruners do
    if zeros_mask_dict is not None:
        zeros_mask_dict[param_name].mask = mask
    msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                   threshold_type, param_name,
                   distiller.sparsity(mask),
                   fraction_to_prune, num_rows_to_prune, rows_mags.size(0))
    return binary_map
def rank_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict):
    """Store an L1-ranked filter mask for a 4-D weight in ``zeros_mask_dict``.

    Each filter's L1 norm (sum of absolute values) is computed. Filters whose
    norm is not strictly greater than the ``int(fraction_to_prune * N)``-th
    smallest norm are masked out. Nothing is stored when that count rounds
    down to zero. ``self`` is not used.
    """
    assert param.dim() == 4, "This thresholding is only supported for 4D weights"
    num_filters, channels, height, width = param.size(0), param.size(1), param.size(2), param.size(3)

    # L1 norm per flattened filter -- identical to .abs().sum(dim=1)
    l1_per_filter = param.view(num_filters, -1).data.norm(1, dim=1)
    n_prune = int(fraction_to_prune * l1_per_filter.size(0))
    if n_prune == 0:
        msglogger.info("Too few filters - can't prune %.1f%% filters", 100 * fraction_to_prune)
        return

    lowest, _ = torch.topk(l1_per_filter, n_prune, largest=False, sorted=True)
    keep = l1_per_filter.gt(lowest[-1]).type(param.data.type())
    # Broadcast the per-filter decision over every element of its filter
    flat_mask = keep.expand(channels * height * width, num_filters).t().contiguous()
    zeros_mask_dict[param_name].mask = flat_mask.view(num_filters, channels, height, width)

    msglogger.info(
        "L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
        param_name,
        distiller.sparsity(zeros_mask_dict[param_name].mask),
        fraction_to_prune, n_prune, l1_per_filter.size(0))
def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
    """Store a filter-pruning mask for ``param`` in ``zeros_mask_dict``.

    ``self.reg_regims[param_name]`` is read as ``(fraction_to_prune,
    group_type)``; only the ``"3D"`` (filter) group type is supported. Filters
    are ranked by the mean absolute value of their weights. Filters not
    strictly above the ``int(fraction_to_prune * N)``-th smallest value are
    masked out.

    Parameters without a regime, or with a zero fraction, are left untouched.
    ``meta`` is not used.
    """
    if param_name not in self.reg_regims:
        return
    regime = self.reg_regims[param_name]
    fraction_to_prune, group_type = regime[0], regime[1]
    if fraction_to_prune == 0:
        return
    assert group_type == "3D", "Currently only filter ranking is supported"
    assert param.dim() == 4, "This thresholding is only supported for 4D weights"

    view_filters = param.view(param.size(0), -1)
    filter_mags = view_filters.data.abs().mean(dim=1)
    topk_filters = int(fraction_to_prune * filter_mags.size(0))
    if topk_filters == 0:
        msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
        return

    bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True)
    threshold = bottomk[-1]
    # Convert to the weight's own tensor type string (e.g. 'torch.cuda.FloatTensor').
    # The previous `.type(type(param.data))` passed the Python class torch.Tensor,
    # which does not carry the weight's concrete type/device.
    binary_map = filter_mags.gt(threshold).type(param.data.type())
    # Broadcast the per-filter decision over every element of its filter
    expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
    zeros_mask_dict[param_name].mask = expanded.view(param.size(0), param.size(1), param.size(2), param.size(3))
    msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
                   distiller.sparsity(zeros_mask_dict[param_name].mask),
                   fraction_to_prune, topk_filters, filter_mags.size(0))
def print_sparsities(masks_dict):
    """Print a fancy-grid table of (parameter name, mask sparsity).

    Entries whose mask is None are skipped.
    """
    rows = []
    for name, mask in masks_dict.items():
        if mask is None:
            continue
        rows.append((name, distiller.sparsity(mask)))
    table = tabulate(rows, headers=["Module", "Mask Sparsity"], tablefmt="fancy_grid")
    print(table)
def test_sparsity():
    """Check the sparsity/density helpers on trivial and hand-built tensors."""
    all_zeros = torch.zeros(2, 3, 5, 6)
    print(distiller.sparsity(all_zeros))
    assert distiller.sparsity(all_zeros) == 1.0
    assert distiller.sparsity_3D(all_zeros) == 1.0
    assert distiller.density_3D(all_zeros) == 0.0

    all_ones = torch.ones(12, 43, 4, 6)
    assert distiller.sparsity(all_ones) == 0.0

    # Two identical rows; columns 2 and 4 are entirely zero
    two_rows = torch.tensor([[1., 2., 0, 4., 0],
                             [1., 2., 0, 4., 0]])
    assert distiller.density(two_rows) == 0.6
    assert distiller.density_cols(two_rows, transposed=False) == 0.6
    assert distiller.sparsity_rows(two_rows, transposed=False) == 0

    # Rows 0 and 3, and column 2, are entirely zero
    framed = torch.tensor([[0., 0., 0],
                           [1., 4., 0],
                           [1., 2., 0],
                           [0., 0., 0]])
    assert distiller.density(framed) == 4 / 12
    assert distiller.sparsity_rows(framed, transposed=False) == 0.5
    assert common.almost_equal(distiller.sparsity_cols(framed, transposed=False), 1 / 3)
    assert common.almost_equal(distiller.sparsity_rows(framed), 1 / 3)
def test_threshold_mask():
    """threshold_mask masks out exactly the single element below the threshold."""
    tensor = torch.ones(3, 64, 32, 32)
    low_idx = (1, 4, 17, 31)
    tensor[low_idx] = 0.2  # the only element under the 0.3 threshold

    # Build the mask (it is not applied to `tensor`)
    mask = distiller.threshold_mask(tensor, threshold=0.3)

    assert np.sum(distiller.to_np(mask)) == (distiller.volume(tensor) - 1)
    assert mask[low_idx] == 0
    assert common.almost_equal(distiller.sparsity(mask), 1/distiller.volume(tensor))
def test_sensitivity_mask():
    """A sensitivity-1 mask on N(0,1) data should be ~68.27% sparse.

    The data is drawn from a locally seeded generator, so the test is
    deterministic and does not disturb the global RNG state.
    """
    # Create a 4-D tensor of normally-distributed coefficients (seeded for reproducibility)
    gen = torch.Generator().manual_seed(0)
    a = torch.randn(3, 64, 32, 32, generator=gen)
    # Create a mask (it is not applied to `a`)
    mask = distiller.create_mask_sensitivity_criterion(a, sensitivity=1)
    # The width of 1-std on ~N(0,1) is about 68.27%. In other words:
    # Pr(mean - std <= X <= mean + std) is about 68.27%
    assert common.almost_equal(distiller.sparsity(mask), 0.6827, max_diff=0.005)
def rank_and_prune_rows(fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None,
                        magnitude_fn=distiller.norms.l1_norm, group_size=1):
    """Prune matrix columns (the "rows" of W^T), ranked by magnitude.

    PyTorch keeps weight matrices in a transposed layout: the output is
    computed as y = x(W^T) + b. Removing input channels means removing rows of
    W^T, which are the columns of W. Rather than transposing (expensive for
    very large matrices), columns of W are ranked directly. That is also not
    ideal for caches, since consecutive column elements are far apart in
    memory.

    Without a ``binary_map``, columns are ranked with
    ``distiller.norms.rank_cols`` (floor rounding, no noise). Columns whose
    magnitude is strictly greater than the largest selected one are kept. A
    supplied ``binary_map`` skips ranking. If ``zeros_mask_dict`` is given,
    the map is expanded into a full mask and stored. ``model`` is not used.

    Returns:
        The binary map, or None when ``rank_cols`` selects nothing.
    """
    if binary_map is None:
        ranked_cols, col_mags = distiller.norms.rank_cols(param, group_size, magnitude_fn, fraction_to_prune,
                                                          rounding_fn=math.floor, noise=None)
        if ranked_cols is None:
            # None signals that fraction_to_prune is too low to select any column
            msglogger.info("Too few cols - can't prune %.1f%% cols", 100 * fraction_to_prune)
            return
        binary_map = col_mags.gt(ranked_cols[-1]).type(param.data.type())

    if zeros_mask_dict is None:
        return binary_map

    mask, _ = distiller.thresholding.expand_binary_map(param, 'Cols', binary_map)
    zeros_mask_dict[param_name].mask = mask
    msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                   magnitude_fn, param_name, distiller.sparsity(mask), fraction_to_prune)
    return binary_map
def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None,
                           magnitude_fn=l1_magnitude, noise=0.0, group_size=1, rounding_fn=math.floor):
    """Mask out the lowest-magnitude filters of a 3-D/4-D weight.

    The number of filters to prune is ``rounding_fn(fraction * N)``, snapped
    to a multiple of ``group_size``. If that would remove every filter while
    ``fraction_to_prune`` is not exactly 1.0, one group is kept.

    Without a ``binary_map``, filters are scored with
    ``magnitude_fn(..., dim=1)``. With probability ``noise`` the scores are
    multiplied by standard-normal noise (random filter choice). The k-th
    smallest score becomes the threshold for
    ``distiller.group_threshold_mask('Filters')``. A supplied ``binary_map``
    skips ranking.

    The mask is stored in ``zeros_mask_dict[param_name].mask`` when a
    dictionary is given. ``model`` is not used.

    Returns:
        The binary map, or None when nothing can be pruned.
    """
    assert param.dim() in (4, 3), "This pruning is only supported for 3D and 4D weights"
    norm_name = 'L1' if magnitude_fn == l1_magnitude else 'L2'
    num_filters = param.size(0)

    # Round the requested count, then snap it to a multiple of group_size
    n_prune = rounding_fn(fraction_to_prune * num_filters)
    n_prune = int(rounding_fn(n_prune * 1. / group_size) * group_size)
    if n_prune == num_filters and fraction_to_prune != 1.0:
        # Never remove all filters unless a 100% fraction is explicitly requested
        n_prune = num_filters - group_size

    threshold = None
    if binary_map is None:
        scores = magnitude_fn(param.view(num_filters, -1), dim=1)
        if noise and uniform(0, 1) <= noise:
            msglogger.info("%sRankedStructureParameterPruner - param: %s - randomly choosing filters",
                           norm_name, param_name)
            scores *= torch.randn_like(scores)
        if n_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return
        lowest, _ = torch.topk(scores, n_prune, largest=False, sorted=True)
        threshold = lowest[-1]
        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=(%d/%d)",
                       norm_name, param_name, n_prune, scores.size(0))

    mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, norm_name, binary_map)
    if zeros_mask_dict is not None:
        zeros_mask_dict[param_name].mask = mask
    msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                   norm_name, param_name, distiller.sparsity(mask), fraction_to_prune)
    return binary_map
def extract_lottery_ticket(args, untrained_ckpt_name, pruned_ckpt_name):
    """Apply a pruned model's sparsity pattern to an untrained model and save it.

    Both checkpoints are loaded on the CPU. For every parameter of the pruned
    model that has an entry in the pruned scheduler's ``zeros_mask_dict``, a
    mask is inferred from its non-zero positions. That mask is multiplied
    in-place into the matching untrained weights.

    The mask sparsities are printed, the scheduler is re-initialized from the
    inferred masks, and the untrained model is saved as
    ``"<untrained_ckpt_name>_masked"`` at epoch 0. Note that the saved
    optimizer is the one from the pruned checkpoint (the later unpacking
    overwrites it). ``args`` is not used.
    """
    untrained_model, _, optimizer, start_epoch = apputils.load_checkpoint(
        model=None, chkpt_file=untrained_ckpt_name, model_device='cpu')
    pruned_model, pruned_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
        model=None, chkpt_file=pruned_ckpt_name, model_device='cpu')

    # Infer each mask from where the pruned parameter is non-zero
    masked_names = pruned_scheduler.zeros_mask_dict.keys()
    masks_dict = {}
    for pname, param in pruned_model.named_parameters():
        if pname in masked_names:
            masks_dict[pname] = torch.ne(param, 0).type(param.type())

    # state_dict() tensors share storage with the model, so mul_ edits the weights
    untrained_state = untrained_model.state_dict()
    for pname, mask in masks_dict.items():
        untrained_state[pname].mul_(mask)

    sparsities = {pname: distiller.sparsity(mask) for pname, mask in masks_dict.items()}
    print(sparsities)

    pruned_scheduler.init_from_masks_dict(masks_dict)
    apputils.save_checkpoint(0, pruned_model.arch, untrained_model, optimizer=optimizer,
                             scheduler=pruned_scheduler,
                             name='_'.join([untrained_ckpt_name, 'masked']))
def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None,
                           magnitude_fn=l1_magnitude, noise=0.0, group_size=1, rounding_fn=math.floor,
                           fpgm=False, HRank=False, conv_index=None):
    """Prune whole filters of a 3-D/4-D weight using one of several criteria.

    The number of filters to prune is ``rounding_fn(fraction * N)``, snapped
    to a multiple of ``group_size``. If that would remove every filter while
    ``fraction_to_prune`` is not exactly 1.0, one group is kept.

    Criteria, checked in this order:
      * ``fpgm``: FPGM-style. Filters with the smallest summed Euclidean
        distance to all other filters are pruned. Any ``binary_map`` passed
        in is ignored and recomputed (as in the original code).
      * a ``binary_map`` was supplied: it is applied as is.
      * ``HRank``: filters with the smallest ``conv_index[param_name]``
        scores are pruned.
      * otherwise: filters with the smallest ``magnitude_fn`` score are
        pruned via ``distiller.group_threshold_mask``.
    With probability ``noise`` the magnitude/FPGM scores are multiplied by
    standard-normal noise, i.e. filters are chosen randomly.

    Args:
        conv_index: mapping of param name to per-filter scores; only read
            when ``HRank`` is set. Defaults to an empty dict. NOTE(review):
            presumably HRank feature-map ranks -- confirm with the caller.
        zeros_mask_dict: if not None, the mask is stored in
            ``zeros_mask_dict[param_name].mask``.
        model: not used.

    Returns:
        The per-filter binary map, or None when the magnitude/FPGM criteria
        cannot prune a single filter.
    """
    assert param.dim() == 4 or param.dim() == 3, "This pruning is only supported for 3D and 4D weights"
    if conv_index is None:  # avoid a shared mutable default
        conv_index = {}
    threshold = None
    threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
    num_filters = param.size(0)
    num_filters_to_prune = rounding_fn(fraction_to_prune * num_filters)
    num_filters_to_prune = int(rounding_fn(num_filters_to_prune * 1. / group_size) * group_size)
    # We can't allow removing all of the filters! --
    # Except when the fraction_to_prune is explicitly instructing us to do so.
    if num_filters_to_prune == num_filters and fraction_to_prune != 1.0:
        num_filters_to_prune = num_filters - group_size

    def _mask_from_pruned_indices(pruned_indices):
        """Return (mask, binary_map) that zero out the filters at ``pruned_indices``."""
        filter_map = torch.ones(num_filters, device=param.device)
        for idx in pruned_indices:
            filter_map[idx] = 0
        filter_volume = int(np.prod(param.shape[1:]))
        filter_mask = filter_map.expand(filter_volume, num_filters).t().view(*param.shape)
        return filter_mask, filter_map

    if fpgm:
        msglogger.debug("FPGM - param: %s", param_name)
        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100 * fraction_to_prune)
            return
        flat_filters = param.view(num_filters, -1).cpu().detach().numpy()
        # Sum of distances from each filter to all others; the smallest sums
        # belong to filters closest to the filters' geometric median.
        similar_matrix = distance.cdist(flat_filters, flat_filters, 'euclidean')
        similar_sum = np.sum(np.abs(similar_matrix), axis=0)
        if noise and uniform(0, 1) <= noise:
            msglogger.info("%sRankedStructureParameterPruner - param: %s - randomly choosing filters",
                           threshold_type, param_name)
            # Previously referenced an undefined `filter_mags` here (NameError)
            similar_sum = similar_sum * torch.randn(num_filters, dtype=torch.float64).numpy()
        mask, binary_map = _mask_from_pruned_indices(similar_sum.argsort()[:num_filters_to_prune])
    elif binary_map is not None:
        # Follow a precomputed per-filter decision; this path previously left `mask` unbound
        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, binary_map)
    elif HRank:
        msglogger.debug("HRank - param: %s", param_name)
        mask, binary_map = _mask_from_pruned_indices(conv_index[param_name].argsort()[:num_filters_to_prune])
    else:
        msglogger.debug("RANK - param: %s", param_name)
        # First we rank the filters
        view_filters = param.view(num_filters, -1)
        filter_mags = magnitude_fn(view_filters, dim=1)
        if noise and uniform(0, 1) <= noise:
            msglogger.info("%sRankedStructureParameterPruner - param: %s - randomly choosing filters",
                           threshold_type, param_name)
            filter_mags *= torch.randn_like(filter_mags)
        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return
        bottomk, _ = torch.topk(filter_mags, num_filters_to_prune, largest=False, sorted=True)
        threshold = bottomk[-1]
        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=(%d/%d)",
                       threshold_type, param_name, num_filters_to_prune, filter_mags.size(0))
        # Now apply the threshold: `threshold` is the cut-off value and `threshold_type`
        # selects L1 vs L2; group_threshold_mask compares each filter's norm against the
        # threshold to build the binary map.
        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, binary_map)

    if zeros_mask_dict is not None:
        zeros_mask_dict[param_name].mask = mask
    msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                   threshold_type, param_name, distiller.sparsity(mask), fraction_to_prune)
    return binary_map