Example No. 1
from torch.nn.utils import parameters_to_vector


def global_add_by_abs_grad(sparse_modules, num_add):
    """
    Adds weights globally (among all given sparse modules) by ranking and selecting the
    top absolute gradients. Weights are added by adjusting the `zero_masks` of their
    modules; switching on the weight occurs by changing 1 to 0. Be sure to call
    `rezero_weights` following this function to initialize the new weights to zero.

    :param sparse_modules: list of modules of type SparseWeightsBase
    :param num_add: number of weights to add
    """

    flattened_grads = parameters_to_vector(
        [m.weight.grad for m in sparse_modules])
    flattened_mask = parameters_to_vector(
        [m.zero_mask for m in sparse_modules]).bool()

    # Find a mask of the top grads for weights that are currently off.
    topk_mask = get_topk_submask(k=num_add,
                                 values=flattened_grads.abs(),
                                 mask=flattened_mask,
                                 largest=True)

    # Pointer for slicing the mask to match the shape of each parameter.
    pointer = 0
    for module in sparse_modules:

        num_params = module.weight.numel()
        add_mask = topk_mask[pointer:pointer + num_params].view_as(
            module.weight)
        module.zero_mask[add_mask] = 0
        pointer += num_params
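The gradients must already be populated when this function is called, so a backward pass has to come first. A minimal calling sketch; `model`, `loss_fn`, `x`, and `y` are hypothetical stand-ins, and `rezero_weights` is assumed to be available on each module, as the docstring implies:

sparse_modules = [m for m in model.modules()
                  if isinstance(m, SparseWeightsBase)]

loss = loss_fn(model(x), y)
loss.backward()  # populate .grad so absolute gradients can be ranked

global_add_by_abs_grad(sparse_modules, num_add=100)
for m in sparse_modules:
    m.rezero_weights()  # initialize the newly enabled weights to zero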
Example No. 2
    def test_topk_submask_structure(self):

        values = torch.rand(5, 6, 7)
        k = 5
        mask = torch.rand_like(values) < 0.5
        submask = get_topk_submask(k, values, mask, largest=True)

        # Validate that the submask's on-positions are a subset of the original.
        is_subset = (submask <= mask).all()
        self.assertTrue(is_subset)

        # Validate only k positions have been chosen.
        self.assertEqual(submask.sum(), k)
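`get_topk_submask` itself is not shown in these examples. The following is a minimal sketch consistent with the behavior the tests assert (select the k largest, or smallest, values among the positions where `mask` is True, and return a boolean submask of exactly those positions); the actual implementation in the source codebase may differ:

import torch


def get_topk_submask(k, values, mask, largest=True):
    """Boolean mask of the top-k `values` restricted to positions where
    `mask` is True. Assumes `mask` is a boolean tensor and k <= mask.sum()."""
    # Flat indices of the candidate positions allowed by `mask`.
    candidate_idx = mask.flatten().nonzero(as_tuple=True)[0]

    # Rank only the candidate values; `largest=False` selects the bottom k.
    top = values.flatten()[candidate_idx].topk(k, largest=largest)

    # Scatter the winners back into a mask with the original shape.
    submask = torch.zeros_like(mask.flatten())
    submask[candidate_idx[top.indices]] = True
    return submask.view_as(mask)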
Example No. 3
from torch.nn.utils import parameters_to_vector


def global_prune_by_abs_weight(sparse_modules,
                               prune_fraction=None,
                               num_remove=None):
    """
    Globally prune 'num_remove' weights from a list of sparse modules by ranking and
    selecting the top absolute weights. If prune_fraction is given, then num_removed is
    calculated. Modules are pruned adjusting their `zero_masks`; switching off the
    weight occurs by changing 0 to 1. Be sure to call `rezero_weights` following this
    function to zero out the pruned weights.

    :param sparse_modules: list of modules of type SparseWeightsBase
    :param prune_fraction: fraction of weights to prune; between 0 and 1; can't be
                           specified if num_remove is not None
    :param num_remove: how many parameters to remove; can't be specified if
                       prune_fraction is not None
    """

    # Exactly one of these arguments must be given.
    assert (prune_fraction is None) != (num_remove is None)

    # Flatten parameters to compare them all at once for global pruning.
    flattened_params = parameters_to_vector([m.weight for m in sparse_modules])
    flattened_off_mask = parameters_to_vector(
        [m.zero_mask for m in sparse_modules])
    flattened_on_mask = ~flattened_off_mask.bool()

    # Calculate the number of parameters to keep.
    total_on = flattened_on_mask.sum().item()
    if prune_fraction is not None:
        assert 0 <= prune_fraction <= 1
        num_remove = int(round(total_on * prune_fraction))
    num_keep = total_on - num_remove

    # Prune by only keeping the top weights ranked by their absolute values.
    topk_mask = get_topk_submask(k=num_keep,
                                 values=flattened_params.abs(),
                                 mask=flattened_on_mask,
                                 largest=True)

    # Pointer for slicing the mask to match the shape of each parameter.
    pointer = 0
    for module in sparse_modules:

        num_params = module.weight.numel()
        keep_mask = topk_mask[pointer:pointer + num_params].view_as(
            module.weight)
        module.zero_mask[:] = ~keep_mask
        pointer += num_params

    return num_remove
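Because the function returns `num_remove`, it pairs naturally with `global_add_by_abs_grad` (Example No. 1) for a prune-and-regrow cycle similar to RigL. A sketch reusing the hypothetical `sparse_modules`, `model`, `loss_fn`, and batch from the earlier example:

# Drop the 30% of active weights with the smallest magnitudes.
num_removed = global_prune_by_abs_weight(sparse_modules, prune_fraction=0.3)
for m in sparse_modules:
    m.rezero_weights()  # zero out the pruned weights

# Regrow the same number of weights where |grad| is largest.
loss = loss_fn(model(x), y)
loss.backward()
global_add_by_abs_grad(sparse_modules, num_add=num_removed)
for m in sparse_modules:
    m.rezero_weights()  # start the regrown weights at zero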
Example No. 4
    def test_get_topk_submask(self):
        k = 3
        values = torch.tensor([[0.5086, 0.5467, 0.2095],
                               [0.9721, 0.2540, 0.2837],
                               [0.4696, 0.9867, 0.6543]])
        mask = torch.tensor([[True, True, False], [True, True, False],
                             [True, False, True]])
        submask = get_topk_submask(k, values, mask, largest=True)
        expected_submask = torch.tensor([[False, True, False],
                                         [True, False, False],
                                         [False, False, True]])

        # Validate that the submask's on-positions are a subset of the original.
        is_subset = (submask <= mask).all()
        self.assertTrue(is_subset)

        # Validate calculated submask.
        all_equal = (submask == expected_submask).all()
        self.assertTrue(all_equal)
Example No. 5
    def test_get_bottomk_submask(self):
        k = 3
        values = torch.tensor([[0.5184, 0.1562, 0.3428],
                               [0.7742, 0.8507, 0.0986],
                               [0.3525, 0.8384, 0.4315]])

        mask = torch.tensor([[True, False, True], [True, True, False],
                             [True, False, False]])
        submask = get_topk_submask(k, values, mask, largest=False)

        expected_submask = torch.tensor([[True, False, True],
                                         [False, False, False],
                                         [True, False, False]])

        # Validate that the submask's on-positions are a subset of the original.
        is_subset = (submask <= mask).all()
        self.assertTrue(is_subset)

        # Validate calculated submask.
        all_equal = (submask == expected_submask).all()
        self.assertTrue(all_equal)
Example No. 6
def local_add_by_abs_grad(sparse_modules, num_add):
    """
    Adds `num_add` weights distributed across the modules by ranking and
    selecting the top absolute gradients for each `sparse_module`. Weights are
    added by adjusting the sparse module `zero_masks`. Be sure to call
    `rezero_weights` following this function to initialize the new weights to
    zero.

    :param sparse_modules: list of modules of type SparseWeightsBase
    :param num_add: list with number of weights to add per sparse_module,
                    usually the output of :meth:`local_prune_by_abs_weight`
    """
    for i, module in enumerate(sparse_modules):
        grads = module.weight.grad.flatten()
        mask = module.zero_mask.detach().flatten().bool()

        # Find a mask of the top grads for weights that are currently off.
        topk_mask = get_topk_submask(k=num_add[i],
                                     values=grads.abs(),
                                     mask=mask,
                                     largest=True)
        module.zero_mask[topk_mask.view_as(module.weight)] = 0.
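A small, hypothetical 1-D illustration of the selection step: positions 1 and 4 are currently off (`zero_mask == 1`), and the single off-weight with the largest absolute gradient is regrown:

grads = torch.tensor([0.10, -0.80, 0.30, 0.05, -0.50])
off_mask = torch.tensor([0., 1., 0., 0., 1.]).bool()
submask = get_topk_submask(k=1, values=grads.abs(),
                           mask=off_mask, largest=True)
# submask == [False, True, False, False, False]; position 1 switches back on.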
Example No. 7
def local_prune_by_abs_weight(sparse_modules, prune_fraction):
    """
    Prune `prune_fraction` of the weigths of each module in `sparse_modules`.

    Modules are pruned by ranking and selecting the top absolute weights of each
    module and adjusting the sparse module `zero_mask`s. Be sure to call
    `rezero_weights` following this function to zero out the pruned weights.

    :param sparse_modules: list of modules of type SparseWeightsBase
    :param prune_fraction: fraction of weights to prune; between 0 and 1

    :return: list containing the number of weights removed for each `sparse_module`
    """
    assert 0 <= prune_fraction <= 1

    total_removed = []
    for module in sparse_modules:
        params = module.weight.detach().flatten()
        off_mask = module.zero_mask.detach().flatten().bool()

        # Compute new top K value
        on_mask = ~off_mask
        total_on = on_mask.sum()
        num_remove = (total_on * prune_fraction).round()
        k = int(total_on - num_remove)

        # Prune by only keeping the top weights ranked by their absolute values.
        topk_mask = get_topk_submask(k=k,
                                     values=params.abs(),
                                     mask=on_mask,
                                     largest=True)

        # Update module with new mask
        module.zero_mask[:] = (~topk_mask).view_as(module.weight)
        total_removed.append(int(num_remove))

    return total_removed
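The return value feeds directly into `local_add_by_abs_grad` (Example No. 6), giving a per-module prune-and-regrow cycle. A sketch, again using the hypothetical `sparse_modules`, `model`, `loss_fn`, and batch from the earlier examples:

removed_per_module = local_prune_by_abs_weight(sparse_modules,
                                               prune_fraction=0.3)
for m in sparse_modules:
    m.rezero_weights()  # zero out the pruned weights

loss = loss_fn(model(x), y)
loss.backward()  # fresh gradients for ranking regrowth candidates
local_add_by_abs_grad(sparse_modules, num_add=removed_per_module)
for m in sparse_modules:
    m.rezero_weights()  # start the regrown weights at zero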