def test_global_sparsity_mask_creator(tensors, mask_creator, sparsity_val):
    masks = mask_creator.create_sparsity_masks(
        tensors, sparsity_val, global_sparsity=True
    )
    # .item() so the set() below compares values rather than tensor identities
    mask_sparsities = [tensor_sparsity(mask).item() for mask in masks]
    global_sparsity = tensor_sparsity(torch.cat([mask.reshape(-1) for mask in masks]))
    assert abs(global_sparsity - sparsity_val) < 1e-2

    if sparsity_val not in [0.0, 1.0]:
        # check that individual sparsity masks are reasonably dissimilar
        assert len(set(mask_sparsities)) > 1

    if isinstance(mask_creator, GroupedPruningMaskCreator):
        _test_grouped_masks(masks, mask_creator)
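For reference, a minimal sketch of what global sparsity does differently from per-tensor masking: a single magnitude threshold is computed over all tensors at once, so individual mask sparsities can diverge while the combined sparsity hits the target. The helper below is illustrative only, not the library's implementation:

import torch

def global_magnitude_masks(tensors, sparsity):
    # hypothetical helper: one threshold over all tensors combined
    all_vals = torch.cat([t.reshape(-1).abs() for t in tensors])
    k = int(sparsity * all_vals.numel())
    if k == 0:
        return [torch.ones_like(t) for t in tensors]  # nothing masked
    threshold = all_vals.kthvalue(k).values
    # values at or below the k-th smallest magnitude are zeroed
    return [(t.abs() > threshold).float() for t in tensors]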
Example #2
    def param_sparsity_dim(
        self, dim: Union[None, int, Tuple[int, ...]] = None
    ) -> Tensor:
        """
        :param dim: the dimension(s) to calculate the sparsity over, e.g. over channels
        :return: the sparsity of the contained parameter, structured according
            to the dim passed in
        """
        return tensor_sparsity(self._param.data, dim)
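A brief usage sketch for the method above (assuming the containing class is a parameter analyzer wrapping a conv weight as self._param; the analyzer name is hypothetical):

# hypothetical calls against an instance holding a conv weight (out, in, kh, kw)
# analyzer.param_sparsity_dim()           -> scalar sparsity over the whole tensor
# analyzer.param_sparsity_dim(dim=0)      -> one sparsity value per output channel
# analyzer.param_sparsity_dim(dim=(0, 1)) -> sparsity per (out, in) channel pair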
Example #3
def _test_set_param_mask_from_sparsity(
    layer, param_name, param, sparsity, mask_creator
):
    mask = ModuleParamPruningMask([layer], [param_name], mask_creator=mask_creator)
    mask.set_param_data(param, 0)
    mask.set_param_masks_from_sparsity(sparsity)
    measured = tensor_sparsity(mask.param_masks[0])
    assert (measured - sparsity).abs() < 0.01
    if isinstance(mask_creator, GroupedPruningMaskCreator):
        _test_grouped_sparsity_mask_output(mask_creator, mask.param_masks[0])
Example #4
def _test_sparsity_mask_creator(tensor_shapes, mask_creator, sparsity_val, device):
    tensors = [torch.randn(tensor_shape).to(device) for tensor_shape in tensor_shapes]
    initial_masks = mask_creator.create_sparsity_masks_from_tensor(tensors)
    update_masks = mask_creator.create_sparsity_masks(tensors, sparsity_val)

    for update_mask in update_masks:
        assert abs(tensor_sparsity(update_mask) - sparsity_val) < 1e-2

    if isinstance(mask_creator, GroupedPruningMaskCreator):
        _test_grouped_masks(initial_masks + update_masks, mask_creator)
Example #5
def _test_set_param_mask_from_sparsity(layer, param_name, param, sparsity,
                                       mask_creator):
    mask = ModuleParamPruningMask(
        [layer],
        param_names=[param_name],
        mask_creator=mask_creator,
        scorer=MagnitudePruningParamsScorer([getattr(layer, param_name)]),
    )
    mask.set_param_data(param, 0)
    mask.update_param_masks(sparsity)
    measured = tensor_sparsity(mask.param_masks[0])
    assert (measured - sparsity).abs() < 0.01
    if isinstance(mask_creator, GroupedPruningMaskCreator):
        _test_grouped_sparsity_mask_output(mask_creator, mask.param_masks[0])
Example #6
def _test_set_param_mask_from_abs_threshold(
    layer,
    param_name,
    param,
    threshold,
    expected_sparsity,
    mask_creator,
):
    mask = ModuleParamPruningMask([layer], [param_name], mask_creator=mask_creator)
    mask.set_param_data(param, 0)
    mask.set_param_masks_from_abs_threshold(threshold)
    sparsity = tensor_sparsity(mask.param_masks[0])
    assert (sparsity - expected_sparsity).abs() < 0.01
    if isinstance(mask_creator, GroupedPruningMaskCreator):
        _test_grouped_sparsity_mask_output(mask_creator, mask.param_masks[0])
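For intuition, the absolute-threshold path reduces to an elementwise comparison. A minimal sketch, not the library code:

import torch

param = torch.randn(64, 64)
threshold = 0.5
mask = (param.abs() > threshold).float()  # keep only values above the threshold
measured_sparsity = (mask == 0).float().mean()  # fraction of zeros produced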
Example #7
def _test_sparsity_mask_creator(tensor_shapes, mask_creator, sparsity_val,
                                device):
    tensors = [
        torch.randn(tensor_shape).to(device) for tensor_shape in tensor_shapes
    ]
    update_masks = mask_creator.create_sparsity_masks(tensors, sparsity_val)

    if isinstance(sparsity_val, float):
        sparsity_val = [sparsity_val] * len(update_masks)

    for update_mask, target_sparsity in zip(update_masks, sparsity_val):
        assert abs(tensor_sparsity(update_mask) - target_sparsity) < 1e-2

    if isinstance(mask_creator, GroupedPruningMaskCreator):
        _test_grouped_masks(update_masks, mask_creator)
Example #8
def _test_sparsity_mask_creator(tensor_shape, mask_creator, sparsity_val,
                                device):
    tensor = torch.randn(tensor_shape).to(device)  # .to() is not in-place; reassign
    initial_mask = mask_creator.create_sparsity_mask_from_tensor(tensor)
    update_mask = mask_creator.create_sparsity_mask(tensor, sparsity_val)

    update_mask_sparsity = tensor_sparsity(update_mask)
    assert abs(update_mask_sparsity - sparsity_val) < 1e-2
    if isinstance(mask_creator, GroupedPruningMaskCreator):
        # Check that every value within each of the mask_creator's
        # groupings is identical within the mask.  Assumes the grouping
        # applies an absolute mean to each group
        for mask in [initial_mask, update_mask]:
            grouped_mask = mask_creator.group_tensor(mask)
            mask_vals_are_grouped = torch.all((grouped_mask == 0.0)
                                              | (grouped_mask == 1.0))
            assert mask_vals_are_grouped
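To make that invariant concrete, a small self-contained sketch with a stand-in grouping that averages over each row (the real group_tensor reduction depends on the mask creator):

import torch

mask = torch.tensor([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]])  # one group per row
grouped = mask.mean(dim=1, keepdim=True)  # stand-in for mask_creator.group_tensor
# each group is uniformly kept (1.0) or pruned (0.0), never a mix
assert torch.all((grouped == 0.0) | (grouped == 1.0))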
Example #9
    def update_param_masks(self, target: Union[float,
                                               List[float]]) -> List[Tensor]:
        """
        Convenience function to set the parameter masks such that each masks have an
        amount of masked values such that the percentage equals the sparsity amount
        given. Masks the absolute smallest values up until sparsity is reached.

        :param target: the desired sparsity (decimal fraction of zeros) to reach
            within the mask or other float target value to base sparsity masks on.
            Can also be a list where each element is a
            target for a tensor in the same position in the tensor list. If global
            sparsity is enabled, all values of the target list must be the same
        """
        if self._scorer:
            param_scores = self._scorer.score_parameters()
        else:
            # if scorer is not set, use param data
            param_scores = self.params_data

        if not isinstance(target, Iterable):
            target = [target] * len(self._params)
        if self.adjust_target_sparsity_for_thinning:
            for idx, sparsity_val in enumerate(target):
                applied_thinning = self._params_applied_thinning[idx]
                if applied_thinning > 0.0:
                    # adjust sparsity for thinned (compressed) layer param
                    # derived from:
                    # remaining_num_els * (1 - adjusted_sparsity) = \
                    #   orig_num_els * (1 - sparsity)
                    # with applied_thinning = 1 - (remaining_num_els / orig_num_els)
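                    # e.g. with 50% of elements already removed by thinning
                    # and an overall target of 0.8:
                    # (0.8 - 0.5) / (1 - 0.5) = 0.6, i.e. mask 60% of the
                    # remaining elements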
                    target[idx] = (sparsity_val -
                                   applied_thinning) / (1 - applied_thinning)

        masks = self._mask_creator.create_sparsity_masks(
            param_scores, target=target, global_sparsity=self._global_sparsity)

        if self._scorer:
            self._scorer.update_last_applied_sparsity(
                tensor_sparsity(self.params_data[0]))

        return self.set_param_masks(masks)
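A usage sketch for single and per-parameter targets, reusing the ModuleParamPruningMask pattern from the test helpers above (the layers, weights, and mask_creator are stand-ins):

# hypothetical usage mirroring the test helpers above
mask = ModuleParamPruningMask(
    [layer_a, layer_b], ["weight", "weight"], mask_creator=mask_creator
)
mask.update_param_masks(0.8)         # one target applied to both params
mask.update_param_masks([0.5, 0.9])  # per-parameter targets, one per tensor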
Example #10
    def from_sparse_model(model: Module) -> List[ScheduledModifier]:
        """
        Create constant ks modifiers for all prunable params in the given model
        (conv, linear) that have been artificially sparsified (sparsity > 40%).
        Useful for transfer learning from a pruned model.

        :param model: the model to create constant ks modifiers for
        :return: the list of created constant ks modifiers
        """
        prunable = get_prunable_layers(model)
        modifiers = []

        for name, layer in prunable:
            weight = getattr(layer, "weight")
            sparsity = tensor_sparsity(weight)

            if sparsity > 0.4:
                modifiers.append(
                    ConstantPruningModifier(
                        params=["{}.{}".format(name, "weight")]))

        return modifiers
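A usage sketch for the factory above (pruned_model stands in for any already-pruned Module; only layers past the 40% cutoff receive a modifier):

# hypothetical usage: freeze existing sparsity while transfer learning
modifiers = ConstantPruningModifier.from_sparse_model(pruned_model)
for modifier in modifiers:
    print(modifier.params)  # e.g. ["features.0.weight"]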
Example #11
    def test_lifecycle(
            self,
            modifier_lambda,
            model_lambda,
            optim_lambda,
            test_steps_per_epoch,  # noqa: F811
    ):
        modifier = modifier_lambda()
        model = model_lambda()
        optimizer = optim_lambda(model)
        self.initialize_helper(modifier, model)
        if modifier.start_epoch > 0:
            assert modifier.applied_sparsity is None
        assert modifier._mask_creator == modifier._module_masks._mask_creator

        # check sparsity is not set before
        for epoch in range(int(modifier.start_epoch)):
            assert not modifier.update_ready(epoch, test_steps_per_epoch)
            assert modifier.applied_sparsity is None

        epoch = int(modifier.start_epoch)
        assert modifier.update_ready(epoch, test_steps_per_epoch)
        modifier.scheduled_update(model, optimizer, epoch,
                                  test_steps_per_epoch)

        applied_sparsities = modifier.applied_sparsity
        if not isinstance(applied_sparsities, list):
            applied_sparsities = [applied_sparsities]

        if not isinstance(modifier.init_sparsity, str):
            assert all(applied_sparsity == modifier.init_sparsity
                       for applied_sparsity in applied_sparsities)
        else:
            assert len(modifier._init_sparsity) == len(
                modifier.module_masks.layers)
            for idx, param in enumerate(modifier.module_masks.params_data):
                assert modifier._init_sparsity[idx] == tensor_sparsity(
                    param).item()

        last_sparsities = applied_sparsities

        # check forward pass
        input_shape = model_lambda.layer_descs()[0].input_size
        test_batch = torch.randn(10, *input_shape)
        _ = model(test_batch)

        while epoch < modifier.end_epoch - modifier.update_frequency:
            epoch += modifier.update_frequency
            assert modifier.update_ready(epoch, test_steps_per_epoch)
            modifier.scheduled_update(model, optimizer, epoch,
                                      test_steps_per_epoch)

            applied_sparsities = modifier.applied_sparsity
            if not isinstance(applied_sparsities, list):
                applied_sparsities = [applied_sparsities]

            assert all(applied_sparsity > last_sparsity
                       for applied_sparsity, last_sparsity in zip(
                           applied_sparsities, last_sparsities))

            last_sparsities = applied_sparsities

        _ = model(test_batch)  # check forward pass
        epoch = int(modifier.end_epoch)
        assert modifier.update_ready(epoch, test_steps_per_epoch)
        modifier.scheduled_update(model, optimizer, epoch,
                                  test_steps_per_epoch)

        def _test_final_sparsity_applied():
            final_sparsities = ([modifier.final_sparsity] if isinstance(
                modifier.final_sparsity, float) else modifier.final_sparsity)
            assert all(sparsity in final_sparsities
                       for sparsity in modifier.applied_sparsity)

        _test_final_sparsity_applied()

        for epoch in range(
                int(modifier.end_epoch) + 1,
                int(modifier.end_epoch) + 6):
            assert not modifier.update_ready(epoch, test_steps_per_epoch)
            _test_final_sparsity_applied()