Example #1
0
def test_magnitude_pruning():
    """Exercise MagnitudeParameterPruner together with a ParameterMasker."""
    param_name = 'a'
    shape = (3, 64, 32, 32)

    # A 4-D tensor of ones in which a single element is made small.
    weights = torch.ones(*shape)
    weights[1, 4, 17, 31] = 0.2

    # The masks dictionary holds one ParameterMasker, keyed by parameter name.
    zeros_mask_dict = {param_name: distiller.ParameterMasker(param_name)}

    # Constructing the pruner WITHOUT a thresholds dictionary must fail.
    with pytest.raises(AssertionError):
        distiller.pruning.MagnitudeParameterPruner("test", None)

    # With a default ("*") threshold the pruner can be constructed.
    pruner = distiller.pruning.MagnitudeParameterPruner("test", {"*": 0.4})
    assert distiller.sparsity(weights) == 0

    # Creating the mask does not touch the tensor itself; exactly one element
    # (the 0.2) falls under the 0.4 threshold.
    pruner.set_param_mask(weights, param_name, zeros_mask_dict, None)
    mask_sparsity = distiller.sparsity(zeros_mask_dict[param_name].mask)
    assert common.almost_equal(mask_sparsity, 1 / distiller.volume(weights))

    # Applying the mask prunes the tensor in place.
    masker = zeros_mask_dict[param_name]
    masker.apply_mask(weights)
    assert common.almost_equal(distiller.sparsity(weights), 1 / distiller.volume(weights))

    # The existing mask can be applied to any tensor of the same shape.  No
    # thresholding happens here: the already-computed mask is simply applied,
    # so an element is zeroed even though every value is 0.3.
    other = torch.full(shape, 0.3)
    masker.apply_mask(other)
    assert common.almost_equal(distiller.sparsity(other), 1 / distiller.volume(weights))
Example #2
0
def test_sparsity():
    """Check distiller.sparsity on an all-zeros and on an all-ones tensor."""
    all_zeros = torch.zeros(2, 3, 5, 6)
    print(distiller.sparsity(all_zeros))
    assert distiller.sparsity(all_zeros) == 1.0

    # A tensor without a single zero element has no sparsity at all.
    all_ones = torch.ones(12, 43, 4, 6)
    assert distiller.sparsity(all_ones) == 0.0
Example #3
0
def test_row_thresholding():
    """Threshold the rows of the 2D test tensor and verify the mask and the row map.

    Requires a CUDA device, because the test tensor is moved to the GPU.
    Returns the computed mask.
    """
    p = get_test_2d_tensor().cuda()
    mask, row_map = distiller.group_threshold_mask(p, 'Rows', 7, 'Max')

    # Expected results are built on the same device as the computed mask.
    device = mask.device
    expected_map = torch.tensor([0., 0., 1., 1.], device=device)
    expected_mask = torch.tensor(
        [[0., 0., 0.], [0., 0., 0.], [1., 1., 1.], [1., 1., 1.]],
        device=device)

    assert torch.eq(row_map, expected_map).all()
    assert torch.eq(mask, expected_mask).all()

    # Masking the tensor must yield exactly the sparsity of the mask.
    masked_tensor = distiller.mask_tensor(p, mask)
    assert distiller.sparsity(masked_tensor) == distiller.sparsity(mask)
    return mask
Example #4
0
def test_level_mask():
    """A level-criterion mask must reach the requested sparsity level."""
    target_sparsity = 0.3

    # A 4-D tensor of random values (torch.rand), not of ones.
    weights = torch.rand(3, 64, 32, 32)

    # Create the mask and compare its sparsity with the target.
    level_mask = distiller.create_mask_level_criterion(weights, desired_sparsity=target_sparsity)
    measured = distiller.sparsity(level_mask)
    assert common.almost_equal(measured, target_sparsity, max_diff=0.0001)
    def rank_and_prune_filters(fraction_to_prune, param, param_name,
                               zeros_mask_dict, model=None, binary_map=None):
        """Mask the filters of a 4D weights tensor with the lowest mean absolute value.

        When `binary_map` is None, the per-filter mean of absolute values is
        ranked and the value of the k-th smallest filter, with
        k = int(fraction_to_prune * num_filters), becomes the threshold handed
        to distiller.group_threshold_mask.  When a `binary_map` is supplied it
        is forwarded together with a threshold of None.

        Args:
            fraction_to_prune: fraction of the filters (dim 0 of `param`) to prune.
            param: the 4D weights tensor.
            param_name: key into `zeros_mask_dict`; also used in log messages.
            zeros_mask_dict: mapping from parameter name to an object whose
                `mask` attribute is overwritten; pass None to skip the update.
            model: unused.
            binary_map: optional precomputed map passed to group_threshold_mask.

        Returns:
            The binary map returned by distiller.group_threshold_mask, or None
            when the fraction is too small to prune a single filter.
        """
        assert param.dim() == 4, "This thresholding is only supported for 4D weights"

        threshold = None
        if binary_map is None:
            # Rank: one magnitude per filter, i.e. per slice along dim 0.
            per_filter = param.view(param.size(0), -1)
            filter_mags = per_filter.data.abs().mean(dim=1)
            num_filters = filter_mags.size(0)
            num_to_prune = int(fraction_to_prune * num_filters)
            if num_to_prune == 0:
                msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
                return
            # topk is sorted ascending here, so the last value is the cutoff.
            smallest_mags = torch.topk(filter_mags, num_to_prune, largest=False, sorted=True)[0]
            threshold = smallest_mags[-1]
            msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=(%d/%d)",
                           param_name, num_to_prune, num_filters)

        # Threshold: build the full-size mask (and the binary map).
        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, 'Mean_Abs', binary_map)
        if zeros_mask_dict is not None:
            zeros_mask_dict[param_name].mask = mask
        msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                       param_name, distiller.sparsity(mask), fraction_to_prune)
        return binary_map
    def rank_and_prune_rows(fraction_to_prune, param, param_name,
                            zeros_mask_dict, model=None, binary_map=None,
                            magnitude_fn=l1_magnitude, group_size=1):
        """Prune rows of the logical weights matrix by masking columns of `param`.

        PyTorch keeps weights matrices transposed: the output is computed as
            y = x(W^T) + b ; where W^T is the transpose of W
        so removing input channels means removing rows of W^T, which are
        columns of W.  Rather than transposing a potentially very large matrix,
        columns are treated as rows here.  Computing per-column magnitudes is
        not cache-friendly either, since consecutive column elements are far
        apart in memory.

        The ranking is delegated to LpRankedStructureParameterPruner.rank_rows;
        the last of the returned bottom-k magnitudes is used as the threshold.
        NOTE(review): rank_rows is not visible here - its result is unpacked as
        a (bottom-k values, all column magnitudes) pair; confirm it never
        returns None.

        The incoming `binary_map`, `model` and `group_size` are not used.
        `zeros_mask_dict` must not be None.

        Returns:
            The binary map returned by distiller.group_threshold_mask.
        """
        ROWS_DIM = 0
        THRESHOLD_DIM = 'Cols'

        bottomk_cols, cols_mags = LpRankedStructureParameterPruner.rank_rows(magnitude_fn, fraction_to_prune, param)
        threshold = bottomk_cols[-1]
        if magnitude_fn == l1_magnitude:
            threshold_type = 'L1'
        else:
            threshold_type = 'L2'

        mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM, threshold, threshold_type)
        masker = zeros_mask_dict[param_name]
        masker.mask = mask

        total_cols = cols_mags.size(ROWS_DIM)
        num_cols_to_prune = int(fraction_to_prune * total_cols)
        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                       threshold_type, param_name,
                       distiller.sparsity(masker.mask),
                       fraction_to_prune, num_cols_to_prune, total_cols)
        return binary_map
    def rank_and_prune_filters(fraction_to_prune,
                               param,
                               param_name,
                               zeros_mask_dict,
                               model=None,
                               binary_map=None,
                               magnitude_fn=distiller.norms.l1_norm,
                               noise=0.0,
                               group_size=1,
                               rounding_fn=math.floor):
        """Compute a per-filter binary map for `param` and optionally install the mask.

        When `binary_map` is None, the filters are ranked with
        distiller.norms.rank_filters and every filter whose magnitude is
        strictly greater than the largest of the bottom-k magnitudes is kept.
        When `zeros_mask_dict` is given, the binary map is expanded to a
        full-size mask and stored on zeros_mask_dict[param_name].

        Args:
            fraction_to_prune: fraction of filters to prune.
            param: 3D or 4D weights tensor.
            param_name: key into `zeros_mask_dict`; also used for logging.
            zeros_mask_dict: mapping from parameter name to an object with a
                `mask` attribute, or None to only compute the binary map.
            model: unused.
            binary_map: optional precomputed per-filter map; skips the ranking.
            magnitude_fn, noise, group_size, rounding_fn: forwarded to
                distiller.norms.rank_filters.

        Returns:
            The binary map, or None when ranking yields nothing to prune.
        """
        assert param.dim() in (3, 4), "This pruning is only supported for 3D and 4D weights"

        if binary_map is None:
            ranking = distiller.norms.rank_filters(
                param, group_size, magnitude_fn, fraction_to_prune,
                rounding_fn, noise)
            bottomk_filters, filter_mags = ranking
            if bottomk_filters is None:
                # A None ranking means fraction_to_prune is too low to prune anything
                msglogger.info("Too few filters - can't prune %.1f%% filters",
                               100 * fraction_to_prune)
                return
            cutoff = bottomk_filters[-1]
            keep = filter_mags.gt(cutoff)
            binary_map = keep.type(param.data.type())

        if zeros_mask_dict is None:
            return binary_map

        mask, _ = distiller.thresholding.expand_binary_map(
            param, 'Filters', binary_map)
        zeros_mask_dict[param_name].mask = mask
        msglogger.info(
            "%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
            magnitude_fn, param_name, distiller.sparsity(mask),
            fraction_to_prune)
        return binary_map
    def rank_and_prune_filters(fraction_to_prune, param, param_name,
                               zeros_mask_dict, model=None, binary_map=None, magnitude_fn=l1_magnitude):
        """Mask the lowest-magnitude filters of a 4D weights tensor.

        When `binary_map` is None, `magnitude_fn` is evaluated over each
        flattened filter and the k-th smallest magnitude, with
        k = int(fraction_to_prune * num_filters), is used as the threshold for
        distiller.group_threshold_mask.  A supplied `binary_map` is forwarded
        with a threshold of None.

        Args:
            fraction_to_prune: fraction of the filters (dim 0) to prune.
            param: the 4D weights tensor.
            param_name: key into `zeros_mask_dict`; also used for logging.
            zeros_mask_dict: mapping from parameter name to an object whose
                `mask` attribute is overwritten; None skips the update.
            model: unused.
            binary_map: optional precomputed map passed to group_threshold_mask.
            magnitude_fn: called as magnitude_fn(tensor, dim=1); anything other
                than l1_magnitude is labelled 'L2'.

        Returns:
            The binary map from distiller.group_threshold_mask, or None when
            the fraction is too small to prune a single filter.
        """
        assert param.dim() == 4, "This pruning is only supported for 4D weights"

        threshold_type = 'L2'
        if magnitude_fn == l1_magnitude:
            threshold_type = 'L1'

        threshold = None
        if binary_map is None:
            # Rank the filters by magnitude.
            filter_mags = magnitude_fn(param.view(param.size(0), -1), dim=1)
            num_filters = filter_mags.size(0)
            num_to_prune = int(fraction_to_prune * num_filters)
            if num_to_prune == 0:
                msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
                return
            # Ascending order: the last of the k smallest values is the cutoff.
            smallest = torch.topk(filter_mags, num_to_prune, largest=False, sorted=True)[0]
            threshold = smallest[-1]
            msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=(%d/%d)",
                           threshold_type, param_name, num_to_prune, num_filters)

        # Apply the threshold (or expand the supplied binary map).
        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, binary_map)
        if zeros_mask_dict is not None:
            zeros_mask_dict[param_name].mask = mask
        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                       threshold_type, param_name, distiller.sparsity(mask), fraction_to_prune)

        # Compensation for the dropped filters is currently disabled:
        #param.data /= float(distiller.sparsity(mask))
        return binary_map
    def rank_and_prune_rows(fraction_to_prune, param, param_name,
                            zeros_mask_dict, model=None, binary_map=None, magnitude_fn=l1_magnitude):
        """Prune rows of the logical weights matrix by masking columns of a 2D `param`.

        PyTorch stores weights matrices transposed, i.e. the matrix is
        transposed before the GEMM.  Instead of transposing a potentially very
        large matrix, columns are treated as rows and rows as columns.  Note
        that computing per-column magnitudes is not optimal either: consecutive
        column elements are far from each other in memory, which makes poor use
        of caches and system memory.

        `magnitude_fn` is reduced over dim 0, giving one magnitude per column;
        the k-th smallest, with k = int(fraction_to_prune * num_columns), is
        the threshold passed to distiller.group_threshold_mask.

        The incoming `binary_map` and `model` are not used, and
        `zeros_mask_dict` must not be None.

        Returns:
            The binary map from distiller.group_threshold_mask, or None when
            the fraction is too small to prune anything.
        """
        assert param.dim() == 2, "This pruning is only supported for 2D weights"

        ROWS_DIM = 0
        THRESHOLD_DIM = 'Cols'

        rows_mags = magnitude_fn(param, dim=ROWS_DIM)
        total = rows_mags.size(0)
        num_rows_to_prune = int(fraction_to_prune * total)
        if num_rows_to_prune == 0:
            msglogger.info("Too few filters - can't prune %.1f%% rows", 100*fraction_to_prune)
            return

        # Ascending order: the last of the k smallest magnitudes is the cutoff.
        smallest = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True)[0]
        threshold = smallest[-1]
        if magnitude_fn == l1_magnitude:
            threshold_type = 'L1'
        else:
            threshold_type = 'L2'

        mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM, threshold, threshold_type)
        masker = zeros_mask_dict[param_name]
        masker.mask = mask
        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                       threshold_type, param_name,
                       distiller.sparsity(masker.mask),
                       fraction_to_prune, num_rows_to_prune, total)
        return binary_map
    def rank_prune_filters(self, fraction_to_prune, param, param_name,
                           zeros_mask_dict):
        """Install a mask that zeroes the filters of `param` with the smallest L1 norm.

        The k = int(fraction_to_prune * num_filters) smallest per-filter L1
        norms are found, and every filter whose norm is not strictly greater
        than the k-th smallest one is masked out.  The resulting mask, which
        has the shape of `param`, is stored on zeros_mask_dict[param_name].
        Nothing is stored when k is 0.

        Args:
            fraction_to_prune: fraction of the filters (dim 0) to prune.
            param: the 4D weights tensor.
            param_name: key into `zeros_mask_dict`; also used for logging.
            zeros_mask_dict: mapping from parameter name to an object whose
                `mask` attribute is overwritten.
        """
        assert param.dim(
        ) == 4, "This thresholding is only supported for 4D weights"

        num_filters = param.size(0)
        # L1 norm of each flattened filter (same as abs().sum(dim=1)).
        filter_mags = param.view(num_filters, -1).data.norm(1, dim=1)
        topk_filters = int(fraction_to_prune * filter_mags.size(0))
        if topk_filters == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters",
                           100 * fraction_to_prune)
            return

        # Ascending order: the last of the k smallest norms is the cutoff.
        smallest = torch.topk(filter_mags, topk_filters, largest=False, sorted=True)[0]
        threshold = smallest[-1]
        binary_map = filter_mags.gt(threshold).type(param.data.type())

        # Replicate each filter's 0/1 flag over all of that filter's elements.
        elems_per_filter = param.size(1) * param.size(2) * param.size(3)
        expanded = binary_map.expand(elems_per_filter, num_filters).t().contiguous()
        masker = zeros_mask_dict[param_name]
        masker.mask = expanded.view(
            num_filters, param.size(1), param.size(2), param.size(3))
        msglogger.info(
            "L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
            param_name, distiller.sparsity(masker.mask),
            fraction_to_prune, topk_filters, filter_mags.size(0))
    def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
        """Install a filter-pruning mask for `param`, if a regime is configured for it.

        `self.reg_regims[param_name]` is read as a sequence whose element 0 is
        the fraction of filters to prune and whose element 1 is the group type.
        Parameters without an entry, or with a fraction of 0, are left alone.

        The filters (slices along dim 0) are ranked by the mean of their
        absolute values; every filter whose magnitude is not strictly greater
        than the k-th smallest one, with k = int(fraction * num_filters), is
        masked out.  The mask has the shape of `param` and is stored on
        zeros_mask_dict[param_name].

        Args:
            param: the 4D weights tensor.
            param_name: key into `self.reg_regims` and `zeros_mask_dict`.
            zeros_mask_dict: mapping from parameter name to an object whose
                `mask` attribute is overwritten.
            meta: unused.

        Raises:
            AssertionError: if the group type is not "3D" or `param` is not 4D.
        """
        if param_name not in self.reg_regims:
            return

        group_type = self.reg_regims[param_name][1]
        fraction_to_prune = self.reg_regims[param_name][0]
        if fraction_to_prune == 0:
            return

        assert group_type == "3D", "Currently only filter ranking is supported"
        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
        view_filters = param.view(param.size(0), -1)
        filter_mags = view_filters.data.abs().mean(dim=1)
        topk_filters = int(fraction_to_prune * filter_mags.size(0))
        if topk_filters == 0:
            msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
            return

        # Ascending order: the last of the k smallest magnitudes is the cutoff.
        bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True)
        threshold = bottomk[-1]
        # Cast with the parameter's concrete type string (e.g.
        # 'torch.cuda.FloatTensor').  type(param.data) is the generic
        # torch.Tensor class on current PyTorch, which would cast the map to
        # the default tensor type and lose the parameter's dtype and device.
        binary_map = filter_mags.gt(threshold).type(param.data.type())
        # Replicate each filter's 0/1 flag over all of that filter's elements.
        expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
        zeros_mask_dict[param_name].mask = expanded.view(param.size(0), param.size(1), param.size(2), param.size(3))
        msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
                       distiller.sparsity(zeros_mask_dict[param_name].mask),
                       fraction_to_prune, topk_filters, filter_mags.size(0))
Example #12
0
def print_sparsities(masks_dict):
    """Print a table with the sparsity of every mask in `masks_dict`.

    Args:
        masks_dict: mapping from a name to a mask; entries whose mask is None
            are left out of the table.
    """
    rows = []
    for param_name, mask in masks_dict.items():
        if mask is None:
            continue
        rows.append((param_name, distiller.sparsity(mask)))

    table = tabulate(rows,
                     headers=["Module", "Mask Sparsity"],
                     tablefmt="fancy_grid")
    print(table)
Example #13
0
def test_sparsity():
    """Check the sparsity and density helpers on small hand-built tensors."""
    # An all-zeros tensor is fully sparse by every measure.
    all_zeros = torch.zeros(2, 3, 5, 6)
    print(distiller.sparsity(all_zeros))
    assert distiller.sparsity(all_zeros) == 1.0
    assert distiller.sparsity_3D(all_zeros) == 1.0
    assert distiller.density_3D(all_zeros) == 0.0

    # An all-ones tensor has no sparsity.
    all_ones = torch.ones(12, 43, 4, 6)
    assert distiller.sparsity(all_ones) == 0.0

    # Two identical rows in which 2 of the 5 columns are entirely zero.
    two_rows = torch.tensor([[1., 2., 0, 4., 0]] * 2)
    assert distiller.density(two_rows) == 0.6
    assert distiller.density_cols(two_rows, transposed=False) == 0.6
    assert distiller.sparsity_rows(two_rows, transposed=False) == 0

    # 4x3 matrix: the first and last rows and the last column are all zeros.
    four_rows = torch.tensor([[0., 0., 0], [1., 4., 0], [1., 2., 0], [0., 0., 0]])
    assert distiller.density(four_rows) == 4 / 12
    assert distiller.sparsity_rows(four_rows, transposed=False) == 0.5
    cols_sparsity = distiller.sparsity_cols(four_rows, transposed=False)
    assert common.almost_equal(cols_sparsity, 1 / 3)
    # With the default `transposed` argument, sparsity_rows yields 1/3 here.
    assert common.almost_equal(distiller.sparsity_rows(four_rows), 1 / 3)
Example #14
0
def test_threshold_mask():
    """A threshold mask must zero exactly the one element below the threshold."""
    small_elem = (1, 4, 17, 31)

    # A 4-D tensor of ones with a single element lowered to 0.2.
    weights = torch.ones(3, 64, 32, 32)
    weights[small_elem] = 0.2
    num_elements = distiller.volume(weights)

    # Only the 0.2 element falls under the 0.3 threshold.
    mask = distiller.threshold_mask(weights, threshold=0.3)
    assert np.sum(distiller.to_np(mask)) == (num_elements - 1)
    assert mask[small_elem] == 0
    assert common.almost_equal(distiller.sparsity(mask), 1 / num_elements)
Example #15
0
def test_sensitivity_mask():
    """A sensitivity-1 mask on ~N(0,1) data should have roughly 68.27% sparsity."""
    # For a normal distribution, Pr(mean - std <= X <= mean + std) is about
    # 68.27%, which is the sparsity expected for a sensitivity of 1.
    expected_sparsity = 0.6827

    # A 4-D tensor of normally-distributed coefficients.
    weights = torch.randn(3, 64, 32, 32)

    mask = distiller.create_mask_sensitivity_criterion(weights, sensitivity=1)
    measured = distiller.sparsity(mask)
    assert common.almost_equal(measured, expected_sparsity, max_diff=0.005)
    def rank_and_prune_rows(fraction_to_prune,
                            param,
                            param_name,
                            zeros_mask_dict,
                            model=None,
                            binary_map=None,
                            magnitude_fn=distiller.norms.l1_norm,
                            group_size=1):
        """Prune rows of the logical weights matrix by masking columns of `param`.

        PyTorch keeps weights matrices transposed: the output is computed as
            y = x(W^T) + b ; where W^T is the transpose of W
        so removing input channels means removing rows of W^T, which are
        columns of W.  Rather than transposing a potentially very large matrix,
        columns are treated as rows here.  Computing per-column magnitudes is
        not cache-friendly either, since consecutive column elements are far
        apart in memory.

        When `binary_map` is None, the columns are ranked with
        distiller.norms.rank_cols (floor rounding, no noise) and every column
        whose magnitude is strictly greater than the largest of the bottom-k
        magnitudes is kept.  When `zeros_mask_dict` is given, the binary map is
        expanded to a full-size mask and stored on zeros_mask_dict[param_name].

        Returns:
            The binary map, or None when ranking yields nothing to prune.
        """
        if binary_map is None:
            ranking = distiller.norms.rank_cols(
                param,
                group_size,
                magnitude_fn,
                fraction_to_prune,
                rounding_fn=math.floor,
                noise=None)
            bottomk_cols, cols_mags = ranking
            if bottomk_cols is None:
                # A None ranking means fraction_to_prune is too low to prune anything
                msglogger.info("Too few cols - can't prune %.1f%% cols",
                               100 * fraction_to_prune)
                return
            cutoff = bottomk_cols[-1]
            keep = cols_mags.gt(cutoff)
            binary_map = keep.type(param.data.type())

        if zeros_mask_dict is None:
            return binary_map

        mask, _ = distiller.thresholding.expand_binary_map(
            param, 'Cols', binary_map)
        zeros_mask_dict[param_name].mask = mask
        msglogger.info(
            "%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
            magnitude_fn, param_name, distiller.sparsity(mask),
            fraction_to_prune)
        return binary_map
    def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, 
                               model=None, binary_map=None, magnitude_fn=l1_magnitude, 
                               noise=0.0, group_size=1, rounding_fn=math.floor):
        """Mask the lowest-magnitude filters of a 3D/4D weights tensor.

        The number of filters to prune is rounding_fn(fraction * num_filters),
        rounded again to a multiple of `group_size`.  Unless `fraction_to_prune`
        is exactly 1.0, one group of filters is always left unpruned.

        When `binary_map` is None the filters are ranked with `magnitude_fn`
        and the k-th smallest magnitude becomes the threshold handed to
        distiller.group_threshold_mask; a supplied `binary_map` is forwarded
        with a threshold of None.

        Args:
            fraction_to_prune: fraction of the filters (dim 0) to prune.
            param: 3D or 4D weights tensor.
            param_name: key into `zeros_mask_dict`; also used for logging.
            zeros_mask_dict: mapping from parameter name to an object whose
                `mask` attribute is overwritten; None skips the update.
            model: unused.
            binary_map: optional precomputed map passed to group_threshold_mask.
            magnitude_fn: called as magnitude_fn(tensor, dim=1); anything other
                than l1_magnitude is labelled 'L2'.
            noise: probability with which the magnitudes are multiplied by
                normally-distributed random values before ranking.
            group_size: granularity of the number of pruned filters.
            rounding_fn: rounding applied to the filter counts.

        Returns:
            The binary map from distiller.group_threshold_mask, or None when
            there is nothing to prune.
        """
        assert param.dim() in (3, 4), "This pruning is only supported for 3D and 4D weights"

        if magnitude_fn == l1_magnitude:
            threshold_type = 'L1'
        else:
            threshold_type = 'L2'

        num_filters = param.size(0)
        ungrouped_count = rounding_fn(fraction_to_prune * num_filters)
        num_filters_to_prune = int(rounding_fn(ungrouped_count * 1. / group_size) * group_size)
        # Pruning every filter is only allowed when fraction_to_prune
        # explicitly asks for it; otherwise keep one group.
        if fraction_to_prune != 1.0 and num_filters_to_prune == num_filters:
            num_filters_to_prune = num_filters - group_size

        threshold = None
        if binary_map is None:
            # Rank the filters by magnitude.
            filter_mags = magnitude_fn(param.view(num_filters, -1), dim=1)

            # The random draw happens before the "nothing to prune" check.
            if noise and uniform(0, 1) <= noise:
                msglogger.info("%sRankedStructureParameterPruner - param: %s - randomly choosing filters", 
                               threshold_type, param_name)
                filter_mags *= torch.randn_like(filter_mags)

            if num_filters_to_prune == 0:
                msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
                return
            # Ascending order: the last of the k smallest values is the cutoff.
            smallest = torch.topk(filter_mags, num_filters_to_prune, largest=False, sorted=True)[0]
            threshold = smallest[-1]
            msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=(%d/%d)",
                           threshold_type, param_name,
                           num_filters_to_prune, filter_mags.size(0))

        # Apply the threshold (or expand the supplied binary map).
        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, binary_map)

        if zeros_mask_dict is not None:
            zeros_mask_dict[param_name].mask = mask
        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                       threshold_type, param_name, distiller.sparsity(mask), fraction_to_prune)
        return binary_map
Example #18
0
def extract_lottery_ticket(args, untrained_ckpt_name, pruned_ckpt_name):
    """Apply the sparsity pattern of a pruned checkpoint to an untrained one and save it.

    Both checkpoints are loaded on the CPU.  For every parameter of the pruned
    model that has an entry in the pruned scheduler's `zeros_mask_dict`, a mask
    is inferred from the parameter's non-zero elements and multiplied in place
    into the matching entry of the untrained model's state dict.  The masked
    untrained model is saved under "<untrained_ckpt_name>_masked", together
    with the pruned checkpoint's optimizer and the re-initialized scheduler.

    Args:
        args: unused.
        untrained_ckpt_name: path of the untrained checkpoint file.
        pruned_ckpt_name: path of the pruned checkpoint file.
    """
    untrained_model, _, _, _ = apputils.load_checkpoint(
        model=None, chkpt_file=untrained_ckpt_name, model_device='cpu')

    # The optimizer that is saved below is the one from the pruned checkpoint.
    pruned_model, pruned_scheduler, optimizer, _ = apputils.load_checkpoint(
        model=None, chkpt_file=pruned_ckpt_name, model_device='cpu')

    # Infer the masks from the sparsity of the pruned parameters.
    masked_names = pruned_scheduler.zeros_mask_dict
    masks_dict = {}
    for pname, param in pruned_model.named_parameters():
        if pname in masked_names:
            masks_dict[pname] = torch.ne(param, 0).type(param.type())

    # Mask the untrained weights in place.
    untrained_state = untrained_model.state_dict()
    for pname, mask in masks_dict.items():
        untrained_state[pname].mul_(mask)

    sparsities = {}
    for pname, mask in masks_dict.items():
        sparsities[pname] = distiller.sparsity(mask)
    print(sparsities)
    pruned_scheduler.init_from_masks_dict(masks_dict)

    masked_name = '_'.join((untrained_ckpt_name, 'masked'))
    apputils.save_checkpoint(0, pruned_model.arch, untrained_model, optimizer=optimizer,
                             scheduler=pruned_scheduler,
                             name=masked_name)
    def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, 
                               model=None, binary_map=None, magnitude_fn=l1_magnitude, 
                               noise=0.0, group_size=1, rounding_fn=math.floor, fpgm=False, HRank=False,
                               conv_index=None):
        """Build a filter-pruning mask for `param` by magnitude, FPGM or HRank ranking.

        The number of filters to prune is rounding_fn(fraction * num_filters),
        rounded again to a multiple of `group_size`.  Unless `fraction_to_prune`
        is exactly 1.0, one group of filters is always left unpruned.

        The mask is then chosen as follows:
          * `fpgm` is true: the filters with the smallest sum of Euclidean
            distances to all filters are pruned.  This is always recomputed;
            a supplied `binary_map` is ignored.
          * a `binary_map` is supplied: it is expanded to a full-size mask by
            distiller.group_threshold_mask, without re-ranking.
          * `HRank` is true: the filters with the smallest values in
            conv_index[param_name] are pruned.
          * otherwise: the filters with the smallest `magnitude_fn` are pruned.

        Args:
            fraction_to_prune: fraction of the filters (dim 0) to prune.
            param: 3D or 4D weights tensor.
            param_name: key into `zeros_mask_dict` and `conv_index`; also used
                for logging.
            zeros_mask_dict: mapping from parameter name to an object whose
                `mask` attribute is overwritten; None skips the update.
            model: unused.
            binary_map: optional precomputed per-filter map.
            magnitude_fn: called as magnitude_fn(tensor, dim=1); anything other
                than l1_magnitude is labelled 'L2'.
            noise: probability with which the ranking scores are multiplied by
                normally-distributed random values before ranking (magnitude
                and FPGM ranking only).
            group_size: granularity of the number of pruned filters.
            rounding_fn: rounding applied to the filter counts.
            fpgm: select FPGM ranking.
            HRank: select HRank ranking.
            conv_index: mapping from parameter name to per-filter scores that
                support `.argsort()`; required for HRank ranking.

        Returns:
            The per-filter binary map, or None when magnitude/FPGM ranking has
            nothing to prune.
        """
        assert param.dim() == 4 or param.dim() == 3, "This pruning is only supported for 3D and 4D weights"
        if conv_index is None:
            conv_index = {}

        threshold = None
        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
        num_filters = param.size(0)
        num_filters_to_prune = rounding_fn(fraction_to_prune * num_filters)
        num_filters_to_prune = int(rounding_fn(num_filters_to_prune * 1. / group_size) * group_size)
        # We can't allow removing all of the filters! --
        # Except when the fraction_to_prune is explicitly instructing us to do so.
        if num_filters_to_prune == num_filters and fraction_to_prune != 1.0:
            num_filters_to_prune = num_filters - group_size

        def _keep_map_without(pruned_indices):
            # Per-filter vector of ones with a 0 at every pruned filter index.
            keep = torch.ones(num_filters, device=param.device)
            keep[torch.as_tensor(pruned_indices, dtype=torch.long, device=param.device)] = 0
            return keep

        def _expand_over_filters(filter_map):
            # Replicate each filter's 0/1 flag over all of that filter's elements.
            per_filter_shape = [num_filters] + [1] * (param.dim() - 1)
            return filter_map.view(per_filter_shape).expand(param.shape).contiguous()

        if fpgm:
            msglogger.debug("FPGM ranking - param: %s", param_name)
            # The random draw happens before the "nothing to prune" check.
            randomize = bool(noise) and uniform(0, 1) <= noise
            if randomize:
                msglogger.info("%sRankedStructureParameterPruner - param: %s - randomly choosing filters",
                               threshold_type, param_name)

            if num_filters_to_prune == 0:
                msglogger.info("Too few filters - can't prune %.1f%% filters", 100 * fraction_to_prune)
                return

            # Score every filter by the sum of its distances to all filters.
            flat_filters = param.view(num_filters, -1).cpu().detach().numpy()
            similar_matrix = distance.cdist(flat_filters, flat_filters, 'euclidean')
            similar_sum = np.sum(np.abs(similar_matrix), axis=0)
            if randomize:
                # There are no magnitudes in this branch, so the noise is
                # applied to the FPGM scores (the original code referenced an
                # undefined variable here).
                similar_sum = similar_sum * np.random.standard_normal(similar_sum.shape)

            binary_map = _keep_map_without(similar_sum.argsort()[:num_filters_to_prune])
            mask = _expand_over_filters(binary_map)

        elif binary_map is not None:
            # Re-use the caller's map: with a binary map given, the threshold
            # is not needed and stays None.
            mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, binary_map)

        elif HRank:
            msglogger.debug("HRank ranking - param: %s", param_name)
            small_index = conv_index[param_name].argsort()[:num_filters_to_prune]
            binary_map = _keep_map_without(small_index)
            mask = _expand_over_filters(binary_map)

        else:
            msglogger.debug("Magnitude ranking - param: %s", param_name)
            # First we rank the filters
            view_filters = param.view(num_filters, -1)
            filter_mags = magnitude_fn(view_filters, dim=1)

            if noise and uniform(0, 1) <= noise:
                msglogger.info("%sRankedStructureParameterPruner - param: %s - randomly choosing filters", 
                               threshold_type, param_name)
                filter_mags *= torch.randn_like(filter_mags)

            if num_filters_to_prune == 0:
                msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
                return
            bottomk, _ = torch.topk(filter_mags, num_filters_to_prune, largest=False, sorted=True)
            threshold = bottomk[-1]
            msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=(%d/%d)",
                           threshold_type, param_name,
                           num_filters_to_prune, filter_mags.size(0))

            # Now apply the threshold.  `threshold` is the cutoff value and
            # `threshold_type` says whether the L1 or the L2 norm is used.
            # `binary_map` is None at this point; presumably the returned map
            # results from comparing each filter's norm with the threshold -
            # confirm against group_threshold_mask.
            mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, binary_map)

        if zeros_mask_dict is not None:
            zeros_mask_dict[param_name].mask = mask
        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
                       threshold_type, param_name,
                       distiller.sparsity(mask),
                       fraction_to_prune)
        return binary_map