def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict,
                           model=None, binary_map=None, magnitude_fn=l1_magnitude):
    """Mask out the lowest-magnitude filters of a 4D weights tensor.

    When ``binary_map`` is None, the filters (dimension 0 of ``param``) are
    ranked with ``magnitude_fn`` and the magnitude of the last of the
    ``int(fraction_to_prune * num_filters)`` smallest filters becomes the
    pruning threshold.  When ``binary_map`` is supplied, no ranking is done and
    the map is handed to ``distiller.group_threshold_mask`` with a ``None``
    threshold.

    Args:
        fraction_to_prune: fraction of the filters to prune.
        param: the 4D weights tensor.
        param_name: key into ``zeros_mask_dict``; also used in log messages.
        zeros_mask_dict: if not None, ``zeros_mask_dict[param_name].mask`` is
            set to the computed mask.
        model: unused by this function.
        binary_map: optional pre-computed filter map.
        magnitude_fn: callable invoked as ``magnitude_fn(tensor, dim=1)``.

    Returns:
        The binary map returned by ``distiller.group_threshold_mask``, or None
        when the requested fraction rounds down to zero filters.
    """
    assert param.dim() == 4, "This pruning is only supported for 4D weights"

    norm_name = 'L1' if magnitude_fn == l1_magnitude else 'L2'
    cutoff = None
    if binary_map is None:
        # Rank the filters: flatten each filter to a row and measure it.
        magnitudes = magnitude_fn(param.view(param.size(0), -1), dim=1)
        total_filters = magnitudes.size(0)
        prune_count = int(fraction_to_prune * total_filters)
        if not prune_count:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return None
        smallest, _ = torch.topk(magnitudes, prune_count, largest=False, sorted=True)
        cutoff = smallest[-1]
        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=(%d/%d)",
                       norm_name, param_name, prune_count, total_filters)

    # Apply the threshold (or expand the supplied map) into a full-size mask.
    mask, binary_map = distiller.group_threshold_mask(param, 'Filters', cutoff, norm_name, binary_map)
    if zeros_mask_dict is not None:
        zeros_mask_dict[param_name].mask = mask
    msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                   norm_name, param_name, distiller.sparsity(mask), fraction_to_prune)
    return binary_map
def rank_and_prune_rows(fraction_to_prune, param, param_name, zeros_mask_dict,
                        model=None, binary_map=None, magnitude_fn=l1_magnitude):
    """Prune the rows of a matrix, based on the ranked norms of the matrix rows.

    PyTorch stores the weights matrices in a transposed format.  I.e. before
    performing GEMM, a matrix is transposed.  This is counter-intuitive.  To
    deal with this, we can either transpose the matrix and then proceed to
    compute the masks as usual, or we can treat columns as rows, and rows as
    columns :-(.
    We choose the latter, because transposing very large matrices can be
    detrimental to performance.  Note that computing the norm of columns is
    also not optimal, because consecutive column elements are far away from
    each other in memory, and this means poor use of caches and system memory.

    Args:
        fraction_to_prune: fraction of the ranked structures to prune.
        param: the 2D weights tensor.
        param_name: key into ``zeros_mask_dict``; also used in log messages.
        zeros_mask_dict: if not None, ``zeros_mask_dict[param_name].mask`` is
            set to the computed mask.
        model: unused by this function.
        binary_map: ignored on input; it is always recomputed.
        magnitude_fn: callable invoked as ``magnitude_fn(param, dim=0)``.

    Returns:
        The binary map returned by ``distiller.group_threshold_mask``, or None
        when the requested fraction rounds down to zero structures.
    """
    assert param.dim() == 2, "This pruning is only supported for 2D weights"
    ROWS_DIM = 0
    THRESHOLD_DIM = 'Cols'

    rows_mags = magnitude_fn(param, dim=ROWS_DIM)
    num_rows_to_prune = int(fraction_to_prune * rows_mags.size(0))
    if num_rows_to_prune == 0:
        msglogger.info("Too few rows - can't prune %.1f%% rows", 100*fraction_to_prune)
        return None

    # The largest of the k smallest magnitudes is the pruning threshold.
    bottomk_rows, _ = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True)
    threshold = bottomk_rows[-1]
    threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
    mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM, threshold, threshold_type)

    # Storing the mask is optional, as in rank_and_prune_filters.
    if zeros_mask_dict is not None:
        zeros_mask_dict[param_name].mask = mask
    msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                   threshold_type, param_name, distiller.sparsity(mask),
                   fraction_to_prune, num_rows_to_prune, rows_mags.size(0))
    return binary_map
def test_channel_thresholding_2():
    """Check channel-wise thresholding of the 4D test tensor ('Mean_L1', threshold 0.7).

    The tensor is moved to the GPU with ``.cuda()``, so a CUDA device is needed.
    """
    weights = get_test_4d_tensor().cuda()
    mask, channel_map = distiller.group_threshold_mask(weights, 'Channels', 0.7, 'Mean_L1')

    # Test the binary map: one entry per channel; the third channel is
    # expected to fall below the 'Mean_L1' threshold of 0.7.
    assert channel_map.shape == torch.Size([3])
    expected_map = torch.tensor([1., 1., 0.], device=channel_map.device)
    assert (channel_map == expected_map).all()

    # Test the full mask: the third channel is zeroed in both filters.
    kept = [[1., 1., 1.]] * 3
    dropped = [[0., 0., 0.]] * 3
    expected_mask = torch.tensor([[kept, kept, dropped],
                                  [kept, kept, dropped]], device=mask.device)
    assert (mask == expected_mask).all()
    return mask
def rank_and_prune_rows(fraction_to_prune, param, param_name, zeros_mask_dict,
                        model=None, binary_map=None, magnitude_fn=l1_magnitude, group_size=1):
    """Prune the rows of a matrix, based on the ranked norms of the matrix rows.

    PyTorch stores the weights matrices in a transposed format.  I.e. before
    performing GEMM, a matrix is transposed.  This is because the output is
    computed as follows:
        y = x(W^T) + b ; where W^T is the transpose of W

    Removing input_channels from W^T, is removing rows of W^T, which is
    removing columns of W.

    To deal with this rotation, we can either transpose the matrix and then
    proceed to compute the masks as usual, or we can treat columns as rows,
    and rows as columns :-(.
    We choose the latter, because transposing very large matrices can be
    detrimental to performance.  Note that computing the norm of columns is
    also not optimal, because consecutive column elements are far away from
    each other in memory, and this means poor use of caches and system memory.

    Args:
        fraction_to_prune: fraction of the ranked structures to prune.
        param: the weights tensor handed to ``rank_rows``.
        param_name: key into ``zeros_mask_dict``; also used in log messages.
        zeros_mask_dict: if not None, ``zeros_mask_dict[param_name].mask`` is
            set to the computed mask.
        model: unused by this function.
        binary_map: ignored on input; it is always recomputed.
        magnitude_fn: magnitude callable forwarded to ``rank_rows``.
        group_size: unused by this function.

    Returns:
        The binary map returned by ``distiller.group_threshold_mask``, or None
        when ``rank_rows`` yields no ranking.
    """
    ranking = LpRankedStructureParameterPruner.rank_rows(magnitude_fn, fraction_to_prune, param)
    # NOTE(review): rank_rows is not visible here; presumably it returns None
    # when nothing can be pruned (the inline predecessor of this logic did).
    # Unpacking None would raise a TypeError, so bail out instead - confirm.
    if ranking is None:
        return None
    bottomk_cols, cols_mags = ranking

    THRESHOLD_DIM = 'Cols'
    ROWS_DIM = 0
    # The largest of the k smallest magnitudes is the pruning threshold.
    threshold = bottomk_cols[-1]
    threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
    mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM, threshold, threshold_type)

    # Storing the mask is optional, as in rank_and_prune_filters.
    if zeros_mask_dict is not None:
        zeros_mask_dict[param_name].mask = mask

    num_cols_to_prune = int(fraction_to_prune * cols_mags.size(ROWS_DIM))
    msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                   threshold_type, param_name, distiller.sparsity(mask),
                   fraction_to_prune, num_cols_to_prune, cols_mags.size(ROWS_DIM))
    return binary_map
def test_filter_thresholding():
    """Check 3D (whole-filter) thresholding of the 4D test tensor ('L2', threshold 4.7).

    The tensor is moved to the GPU with ``.cuda()``, so a CUDA device is needed.
    """
    weights = get_test_4d_tensor().cuda()
    mask, filter_map = distiller.group_threshold_mask(weights, '3D', 4.7, 'L2')

    # Test the binary map: 1s indicate 3D-filters that have an L2 above 4.7
    assert filter_map.shape == torch.Size([2])
    expected_map = torch.tensor([0., 1.], device=filter_map.device)
    assert (filter_map == expected_map).all()

    # Test the full mask: the first filter is zeroed, the second is kept.
    kept = [[1., 1., 1.]] * 3
    dropped = [[0., 0., 0.]] * 3
    expected_mask = torch.tensor([[dropped, dropped, dropped],
                                  [kept, kept, kept]], device=mask.device)
    assert (mask == expected_mask).all()
    return mask
def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict,
                           model=None, binary_map=None):
    """Mask out the filters of a 4D weights tensor with the lowest mean absolute value.

    When ``binary_map`` is None, the filters (dimension 0 of ``param``) are
    ranked by the mean of their absolute values, and the score of the last of
    the ``int(fraction_to_prune * num_filters)`` smallest filters becomes the
    'Mean_Abs' threshold.  When ``binary_map`` is supplied, no ranking is done
    and the map is handed to ``distiller.group_threshold_mask`` with a ``None``
    threshold.

    Args:
        fraction_to_prune: fraction of the filters to prune.
        param: the 4D weights tensor.
        param_name: key into ``zeros_mask_dict``; also used in log messages.
        zeros_mask_dict: if not None, ``zeros_mask_dict[param_name].mask`` is
            set to the computed mask.
        model: unused by this function.
        binary_map: optional pre-computed filter map.

    Returns:
        The binary map returned by ``distiller.group_threshold_mask``, or None
        when the requested fraction rounds down to zero filters.
    """
    assert param.dim() == 4, "This thresholding is only supported for 4D weights"

    cutoff = None
    if binary_map is None:
        # Rank the filters: flatten each filter to a row and average |w|.
        flattened = param.view(param.size(0), -1)
        mean_abs = flattened.data.abs().mean(dim=1)
        total_filters = mean_abs.size(0)
        prune_count = int(fraction_to_prune * total_filters)
        if not prune_count:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return None
        smallest, _ = torch.topk(mean_abs, prune_count, largest=False, sorted=True)
        cutoff = smallest[-1]
        msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=(%d/%d)",
                       param_name, prune_count, total_filters)

    # Apply the threshold (or expand the supplied map) into a full-size mask.
    mask, binary_map = distiller.group_threshold_mask(param, 'Filters', cutoff, 'Mean_Abs', binary_map)
    if zeros_mask_dict is not None:
        zeros_mask_dict[param_name].mask = mask
    msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                   param_name, distiller.sparsity(mask), fraction_to_prune)
    return binary_map
def test_col_thresholding():
    """Check column thresholding of the 2D test tensor ('Max', threshold 11).

    The tensor is moved to the GPU with ``.cuda()``, so a CUDA device is needed.
    """
    weights = get_test_2d_tensor().cuda()
    # Only the full mask is checked here; the binary map is not inspected.
    mask, _ = distiller.group_threshold_mask(weights, 'Cols', 11, 'Max')

    # Only the third column is expected to survive.
    expected_mask = torch.tensor([[0., 0., 1.]] * 4, device=mask.device)
    assert (mask == expected_mask).all()
    return mask
def test_row_thresholding():
    """Check row thresholding of the 2D test tensor ('Max', threshold 7).

    The tensor is moved to the GPU with ``.cuda()``, so a CUDA device is needed.
    """
    weights = get_test_2d_tensor().cuda()
    mask, row_map = distiller.group_threshold_mask(weights, 'Rows', 7, 'Max')

    # The first two rows are expected to be dropped, the last two kept.
    expected_map = torch.tensor([0., 0., 1., 1.], device=mask.device)
    assert (row_map == expected_map).all()

    expected_mask = torch.tensor([[0., 0., 0.]] * 2 + [[1., 1., 1.]] * 2, device=mask.device)
    assert (mask == expected_mask).all()

    # Applying the mask must give the tensor the same sparsity as the mask.
    masked_tensor = distiller.mask_tensor(weights, mask)
    assert distiller.sparsity(masked_tensor) == distiller.sparsity(mask)
    return mask
def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict,
                           model=None, binary_map=None, magnitude_fn=l1_magnitude,
                           noise=0.0, group_size=1, rounding_fn=math.floor):
    """Mask out the lowest-magnitude filters of a 3D or 4D weights tensor.

    The number of filters to prune is ``fraction_to_prune * num_filters``,
    rounded with ``rounding_fn`` and then rounded again to a multiple of
    ``group_size``.  When ``binary_map`` is None the filters are ranked with
    ``magnitude_fn``; otherwise the supplied map is handed to
    ``distiller.group_threshold_mask`` with a ``None`` threshold.

    Args:
        fraction_to_prune: fraction of the filters to prune.
        param: the 3D/4D weights tensor; dimension 0 indexes the filters.
        param_name: key into ``zeros_mask_dict``; also used in log messages.
        zeros_mask_dict: if not None, ``zeros_mask_dict[param_name].mask`` is
            set to the computed mask.
        model: unused by this function.
        binary_map: optional pre-computed filter map.
        magnitude_fn: callable invoked as ``magnitude_fn(tensor, dim=1)``.
        noise: when truthy, the magnitudes are multiplied by Gaussian noise if
            ``uniform(0, 1) <= noise``.
        group_size: the number of pruned filters is a multiple of this value.
        rounding_fn: rounding callable used for both rounding steps.

    Returns:
        The binary map returned by ``distiller.group_threshold_mask``, or None
        when filters are ranked and the number to prune is zero.
    """
    assert param.dim() in (3, 4), "This pruning is only supported for 3D and 4D weights"

    norm_name = 'L1' if magnitude_fn == l1_magnitude else 'L2'
    total_filters = param.size(0)
    prune_count = rounding_fn(fraction_to_prune * total_filters)
    prune_count = int(rounding_fn(prune_count * 1. / group_size) * group_size)
    # Never remove every filter, unless fraction_to_prune explicitly asks for it.
    if prune_count == total_filters and fraction_to_prune != 1.0:
        prune_count = total_filters - group_size

    cutoff = None
    if binary_map is None:
        # Rank the filters: flatten each filter to a row and measure it.
        magnitudes = magnitude_fn(param.view(total_filters, -1), dim=1)
        if noise and uniform(0, 1) <= noise:
            msglogger.info("%sRankedStructureParameterPruner - param: %s - randomly choosing filters",
                           norm_name, param_name)
            magnitudes *= torch.randn_like(magnitudes)
        if not prune_count:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return None
        smallest, _ = torch.topk(magnitudes, prune_count, largest=False, sorted=True)
        cutoff = smallest[-1]
        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=(%d/%d)",
                       norm_name, param_name, prune_count, magnitudes.size(0))

    # Apply the threshold (or expand the supplied map) into a full-size mask.
    mask, binary_map = distiller.group_threshold_mask(param, 'Filters', cutoff, norm_name, binary_map)
    if zeros_mask_dict is not None:
        zeros_mask_dict[param_name].mask = mask
    msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                   norm_name, param_name, distiller.sparsity(mask), fraction_to_prune)
    return binary_map
def test_kernel_thresholding():
    """Check 2D (kernel) thresholding of the 4D test tensor ('L1', threshold 6).

    The tensor is moved to the GPU with ``.cuda()``, so a CUDA device is needed.
    """
    weights = get_test_4d_tensor().cuda()
    mask, kernel_map = distiller.group_threshold_mask(weights, '2D', 6, 'L1')

    # Test the binary map: 1s indicate 2D-kernels that have an L1 above 6
    assert kernel_map.shape == torch.Size([6])
    expected_map = torch.tensor([1., 1., 0., 1., 0., 1.], device=kernel_map.device)
    assert (kernel_map == expected_map).all()

    # Test the full mask: kernels 2 and 4 (0-based, flattened) are zeroed.
    kept = [[1., 1., 1.]] * 3
    dropped = [[0., 0., 0.]] * 3
    expected_mask = torch.tensor([[kept, kept, dropped],
                                  [kept, dropped, kept]], device=mask.device)
    assert (mask == expected_mask).all()
    return mask
def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict,
                           model=None, binary_map=None, magnitude_fn=l1_magnitude,
                           noise=0.0, group_size=1, rounding_fn=math.floor,
                           fpgm=False, HRank=False, conv_index={}):
    """Build a filter-pruning mask for a 3D or 4D weights tensor.

    The number of filters to prune is ``fraction_to_prune * num_filters``,
    rounded with ``rounding_fn`` and then rounded again to a multiple of
    ``group_size``.  The filters to remove are selected as follows:

    * ``fpgm`` is truthy: the filters whose summed Euclidean distance to all
      filters is smallest are removed.  Any supplied ``binary_map`` is ignored.
    * otherwise, ``binary_map`` is supplied: it is expanded into a full mask by
      ``distiller.group_threshold_mask`` (no ranking takes place).
    * otherwise, ``HRank`` is truthy: the filters with the smallest entries in
      ``conv_index[param_name]`` are removed.
    * otherwise: the filters with the smallest ``magnitude_fn`` are removed.

    Args:
        fraction_to_prune: fraction of the filters to prune.
        param: the 3D/4D weights tensor; dimension 0 indexes the filters.
        param_name: key into ``zeros_mask_dict`` and ``conv_index``; also used
            in log messages.
        zeros_mask_dict: if not None, ``zeros_mask_dict[param_name].mask`` is
            set to the computed mask.
        model: unused by this function.
        binary_map: optional pre-computed filter map (see above).
        magnitude_fn: callable invoked as ``magnitude_fn(tensor, dim=1)``.
        noise: when truthy, the ranking scores are multiplied by Gaussian noise
            if ``uniform(0, 1) <= noise`` (magnitude and FPGM ranking only).
        group_size: the number of pruned filters is a multiple of this value.
        rounding_fn: rounding callable used for both rounding steps.
        fpgm: select filters by summed pairwise distance.
        HRank: select filters by the scores in ``conv_index``.
        conv_index: mapping from parameter name to an object providing
            ``argsort()`` (presumably one score per filter - confirm against
            the caller).  It is only read, never mutated.

    Returns:
        The binary filter map (1 = keep, 0 = prune), or None when magnitude or
        FPGM ranking is requested and the number of filters to prune is zero.
    """
    assert param.dim() == 4 or param.dim() == 3, "This pruning is only supported for 3D and 4D weights"

    threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
    num_filters = param.size(0)
    num_filters_to_prune = rounding_fn(fraction_to_prune * num_filters)
    num_filters_to_prune = int(rounding_fn(num_filters_to_prune * 1. / group_size) * group_size)
    # We can't allow removing all of the filters! --
    # Except when the fraction_to_prune is explicitly instructing us to do so.
    if num_filters_to_prune == num_filters and fraction_to_prune != 1.0:
        num_filters_to_prune = num_filters - group_size

    def mask_without(filter_indices):
        """Return (mask, binary_map) with the given filter indices zeroed."""
        keep = torch.ones(num_filters, device=param.device)
        pruned = torch.as_tensor(filter_indices, dtype=torch.long, device=param.device)
        keep[pruned] = 0
        # Broadcast the per-filter flag over the remaining dimensions of param.
        full_mask = keep.view(-1, *([1] * (param.dim() - 1))).expand_as(param)
        return full_mask, keep

    if fpgm:
        msglogger.debug("FPGM filter selection - param: %s", param_name)
        view_filters = param.view(num_filters, -1)
        randomize = bool(noise) and uniform(0, 1) <= noise
        if randomize:
            msglogger.info("%sRankedStructureParameterPruner - param: %s - randomly choosing filters",
                           threshold_type, param_name)
        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100 * fraction_to_prune)
            return None
        view_filters = view_filters.cpu().detach().numpy()
        similar_matrix = distance.cdist(view_filters, view_filters, 'euclidean')
        similar_sum = np.sum(np.abs(similar_matrix), axis=0)
        if randomize:
            # Perturb the ranking score itself; FPGM computes no filter
            # magnitudes, so there is nothing else to apply the noise to.
            similar_sum = similar_sum * np.random.randn(*similar_sum.shape)
        similar_small_index = similar_sum.argsort()[:num_filters_to_prune]
        mask, binary_map = mask_without(similar_small_index)
    elif binary_map is not None:
        # The caller already chose the filters: expand the map into a mask.
        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', None, threshold_type, binary_map)
    elif HRank:
        msglogger.debug("HRank filter selection - param: %s", param_name)
        small_index = conv_index[param_name].argsort()[:num_filters_to_prune]
        mask, binary_map = mask_without(small_index)
    else:
        msglogger.debug("Magnitude filter ranking - param: %s", param_name)
        # First we rank the filters
        view_filters = param.view(num_filters, -1)
        filter_mags = magnitude_fn(view_filters, dim=1)
        if noise and uniform(0, 1) <= noise:
            msglogger.info("%sRankedStructureParameterPruner - param: %s - randomly choosing filters",
                           threshold_type, param_name)
            filter_mags *= torch.randn_like(filter_mags)
        if num_filters_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return None
        bottomk, _ = torch.topk(filter_mags, num_filters_to_prune, largest=False, sorted=True)
        threshold = bottomk[-1]
        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=(%d/%d)",
                       threshold_type, param_name, num_filters_to_prune, filter_mags.size(0))
        # Now apply the threshold.  threshold is the cut-off value and
        # threshold_type says whether the L1 or the L2 norm is used.  No
        # binary_map is passed in; group_threshold_mask returns one, built by
        # comparing each filter's norm against the threshold.
        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, None)

    if zeros_mask_dict is not None:
        zeros_mask_dict[param_name].mask = mask
    msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                   threshold_type, param_name, distiller.sparsity(mask), fraction_to_prune)
    return binary_map