def test_global_sparsity_mask_creator(tensors, mask_creator, sparsity_val):
    """
    Verify that masks created with global_sparsity=True reach the target
    sparsity across all tensors combined, while individual masks are allowed
    to (and should) land at different per-tensor sparsities.
    """
    created_masks = mask_creator.create_sparsity_masks(
        tensors, sparsity_val, global_sparsity=True
    )

    per_mask_sparsities = []
    for created_mask in created_masks:
        per_mask_sparsities.append(tensor_sparsity(created_mask))

    # global sparsity is measured over the concatenation of all flattened masks
    flattened = torch.cat([m.reshape(-1) for m in created_masks])
    overall_sparsity = tensor_sparsity(flattened)
    assert abs(overall_sparsity - sparsity_val) < 1e-2

    if sparsity_val not in [0.0, 1.0]:
        # check that individual sparsity masks are reasonably dissimilar
        assert len(set(per_mask_sparsities)) > 1

    if isinstance(mask_creator, GroupedPruningMaskCreator):
        _test_grouped_masks(created_masks, mask_creator)
def param_sparsity_dim(self, dim: Union[None, int, Tuple[int, ...]] = None) -> Tensor:
    """
    Measure the sparsity of the contained parameter, optionally structured
    along the given dimension(s).

    :param dim: a dimension(s) to calculate the sparsity over, ex over channels
    :return: the sparsity of the contained parameter structured according
        to the dim passed in
    """
    param_data = self._param.data
    return tensor_sparsity(param_data, dim)
def _test_set_param_mask_from_sparsity(layer, param_name, param, sparsity, mask_creator):
    """
    Helper: apply a sparsity-driven mask to a layer param and check the
    measured mask sparsity lands within 1% of the requested level.
    """
    pruning_mask = ModuleParamPruningMask([layer], [param_name], mask_creator=mask_creator)
    pruning_mask.set_param_data(param, 0)
    pruning_mask.set_param_masks_from_sparsity(sparsity)

    applied = tensor_sparsity(pruning_mask.param_masks[0])
    assert (applied - sparsity).abs() < 0.01

    if isinstance(mask_creator, GroupedPruningMaskCreator):
        _test_grouped_sparsity_mask_output(mask_creator, pruning_mask.param_masks[0])
def _test_sparsity_mask_creator(tensor_shapes, mask_creator, sparsity_val, device):
    """
    Helper: build random tensors of the given shapes on device, then check
    that masks created at sparsity_val each hit the target within 1%.
    """
    inputs = []
    for shape in tensor_shapes:
        inputs.append(torch.randn(shape).to(device))

    masks_from_tensor = mask_creator.create_sparsity_masks_from_tensor(inputs)
    masks_at_target = mask_creator.create_sparsity_masks(inputs, sparsity_val)

    for mask in masks_at_target:
        deviation = abs(tensor_sparsity(mask) - sparsity_val)
        assert deviation < 1e-2

    if isinstance(mask_creator, GroupedPruningMaskCreator):
        _test_grouped_masks(masks_from_tensor + masks_at_target, mask_creator)
def _test_set_param_mask_from_sparsity(layer, param_name, param, sparsity, mask_creator):
    """
    Helper: update a magnitude-scored pruning mask to the given sparsity
    and check the measured mask sparsity lands within 1% of the target.
    """
    magnitude_scorer = MagnitudePruningParamsScorer([getattr(layer, param_name)])
    module_mask = ModuleParamPruningMask(
        [layer],
        param_names=[param_name],
        mask_creator=mask_creator,
        scorer=magnitude_scorer,
    )
    module_mask.set_param_data(param, 0)
    module_mask.update_param_masks(sparsity)

    observed = tensor_sparsity(module_mask.param_masks[0])
    assert (observed - sparsity).abs() < 0.01

    if isinstance(mask_creator, GroupedPruningMaskCreator):
        _test_grouped_sparsity_mask_output(mask_creator, module_mask.param_masks[0])
def _test_set_param_mask_from_abs_threshold(
    layer,
    param_name,
    param,
    threshold,
    expected_sparsity,
    mask_creator,
):
    """
    Helper: mask a layer param by absolute-value threshold and check the
    resulting sparsity matches the expected level within 1%.
    """
    pruning_mask = ModuleParamPruningMask([layer], [param_name], mask_creator=mask_creator)
    pruning_mask.set_param_data(param, 0)
    pruning_mask.set_param_masks_from_abs_threshold(threshold)

    observed = tensor_sparsity(pruning_mask.param_masks[0])
    assert (observed - expected_sparsity).abs() < 0.01

    if isinstance(mask_creator, GroupedPruningMaskCreator):
        _test_grouped_sparsity_mask_output(mask_creator, pruning_mask.param_masks[0])
def _test_sparsity_mask_creator(tensor_shapes, mask_creator, sparsity_val, device):
    """
    Helper: create masks for random tensors and check each mask's sparsity
    against its target. A scalar sparsity_val applies to every tensor;
    a list gives one target per tensor by position.
    """
    inputs = [torch.randn(shape).to(device) for shape in tensor_shapes]
    masks = mask_creator.create_sparsity_masks(inputs, sparsity_val)

    # broadcast a single float target across all masks
    if isinstance(sparsity_val, float):
        targets = [sparsity_val] * len(masks)
    else:
        targets = sparsity_val

    for mask, target in zip(masks, targets):
        assert abs(tensor_sparsity(mask) - target) < 1e-2

    if isinstance(mask_creator, GroupedPruningMaskCreator):
        _test_grouped_masks(masks, mask_creator)
def _test_sparsity_mask_creator(tensor_shape, mask_creator, sparsity_val, device):
    """
    Helper: create initial and sparsity-targeted masks for a single random
    tensor and check the update mask's sparsity is within 1% of the target.

    :param tensor_shape: shape of the random tensor to mask
    :param mask_creator: the mask creator under test
    :param sparsity_val: target sparsity for the update mask
    :param device: device the tensor should live on
    """
    # BUG FIX: Tensor.to(device) is NOT in-place — it returns a new tensor.
    # The original `tensor.to(device)` discarded the result, leaving the
    # tensor on the CPU regardless of `device`; rebind instead.
    tensor = torch.randn(tensor_shape).to(device)
    initial_mask = mask_creator.create_sparsity_mask_from_tensor(tensor)
    update_mask = mask_creator.create_sparsity_mask(tensor, sparsity_val)
    update_mask_sparsity = tensor_sparsity(update_mask)
    assert abs(update_mask_sparsity - sparsity_val) < 1e-2
    if isinstance(mask_creator, GroupedPruningMaskCreator):
        # Check that every value in the mask_creator grouping
        # is the same within the mask. Assumes grouping applies
        # an absolute mean to each grouping
        for mask in [initial_mask, update_mask]:
            grouped_mask = mask_creator.group_tensor(mask)
            mask_vals_are_grouped = torch.all(
                (grouped_mask == 0.0) | (grouped_mask == 1.0)
            )
            assert mask_vals_are_grouped
def update_param_masks(self, target: Union[float, List[float]]) -> List[Tensor]:
    """
    Convenience function to set the parameter masks such that each masks
    have an amount of masked values such that the percentage equals the
    sparsity amount given. Masks the absolute smallest values up until
    sparsity is reached.

    :param target: the desired sparsity (decimal fraction of zeros) to reach
        within the mask or other float target value to base sparsity masks on.
        Can also be a list where each element is a target for a tensor in the
        same position in the tensor list. If global sparsity is enabled, all
        values of the target list must be the same
    :return: the result of applying the created masks via set_param_masks
    """
    # prefer scorer-produced importance scores when a scorer is configured
    if self._scorer:
        param_scores = self._scorer.score_parameters()
    else:
        # if scorer is not set, use param data
        param_scores = self.params_data
    # normalize a scalar target into one target per tracked parameter
    if not isinstance(target, Iterable):
        target = [target] * len(self._params)
    if self.adjust_target_sparsity_for_thinning:
        # NOTE(review): this rewrites `target` in place — if the caller passed
        # a list, the caller's list is mutated; confirm that is intended
        for idx, sparsity_val in enumerate(target):
            applied_thinning = self._params_applied_thinning[idx]
            if applied_thinning > 0.0:
                # adjust sparsity for thinned (compressed) layer param
                # derived from:
                # remaining_num_els * (1 - adjusted_sparsity) = \
                #     orig_num_els * (1 - sparsity)
                # with applied_thinning = 1 - (remaining_num_els / orig_num_els)
                target[idx] = (sparsity_val - applied_thinning) / (
                    1 - applied_thinning
                )
    masks = self._mask_creator.create_sparsity_masks(
        param_scores, target=target, global_sparsity=self._global_sparsity
    )
    if self._scorer:
        # record the sparsity currently present in the first param so the
        # scorer tracks the last applied sparsity level
        self._scorer.update_last_applied_sparsity(
            tensor_sparsity(self.params_data[0])
        )
    return self.set_param_masks(masks)
def from_sparse_model(model: Module) -> List[ScheduledModifier]:
    """
    Build constant pruning (ks) modifiers for every prunable param in the
    model (conv, linear) whose weight is already artificially sparse
    (sparsity > 40%). Useful for transfer learning from a pruned model.

    :param model: the model to create constant ks modifiers for
    :return: the list of created constant ks modifiers
    """
    created_modifiers = []
    for layer_name, prunable_layer in get_prunable_layers(model):
        layer_weight = getattr(prunable_layer, "weight")
        measured_sparsity = tensor_sparsity(layer_weight)
        if measured_sparsity <= 0.4:
            continue
        full_param_name = "{}.{}".format(layer_name, "weight")
        created_modifiers.append(
            ConstantPruningModifier(params=[full_param_name])
        )
    return created_modifiers
def test_lifecycle(
    self,
    modifier_lambda,
    model_lambda,
    optim_lambda,
    test_steps_per_epoch,  # noqa: F811
):
    """
    Walk a pruning modifier through its full schedule: no updates before
    start_epoch, initial sparsity applied at start_epoch, monotonically
    increasing sparsity at each update_frequency step, final sparsity at
    end_epoch, and no further updates after end_epoch.
    """
    modifier = modifier_lambda()
    model = model_lambda()
    optimizer = optim_lambda(model)
    self.initialize_helper(modifier, model)
    if modifier.start_epoch > 0:
        assert modifier.applied_sparsity is None
    assert modifier._mask_creator == modifier._module_masks._mask_creator

    # check sparsity is not set before
    for epoch in range(int(modifier.start_epoch)):
        assert not modifier.update_ready(epoch, test_steps_per_epoch)
        assert modifier.applied_sparsity is None

    # first scheduled update happens exactly at start_epoch
    epoch = int(modifier.start_epoch)
    assert modifier.update_ready(epoch, test_steps_per_epoch)
    modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)

    applied_sparsities = modifier.applied_sparsity
    if not isinstance(applied_sparsities, list):
        applied_sparsities = [applied_sparsities]

    if not isinstance(modifier.init_sparsity, str):
        # numeric init_sparsity: every applied sparsity must equal it
        assert all(
            applied_sparsity == modifier.init_sparsity
            for applied_sparsity in applied_sparsities
        )
    else:
        # string init_sparsity: one per-layer value measured from the params
        # NOTE(review): presumably the string mode derives init sparsity from
        # the existing param data — confirm against the modifier implementation
        assert len(modifier._init_sparsity) == len(modifier.module_masks.layers)
        for idx, param in enumerate(modifier.module_masks.params_data):
            assert modifier._init_sparsity[idx] == tensor_sparsity(param).item()

    last_sparsities = applied_sparsities

    # check forward pass
    input_shape = model_lambda.layer_descs()[0].input_size
    test_batch = torch.randn(10, *input_shape)
    _ = model(test_batch)

    # step through the schedule; each update must strictly increase sparsity
    while epoch < modifier.end_epoch - modifier.update_frequency:
        epoch += modifier.update_frequency
        assert modifier.update_ready(epoch, test_steps_per_epoch)
        modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)

        applied_sparsities = modifier.applied_sparsity
        if not isinstance(applied_sparsities, list):
            applied_sparsities = [applied_sparsities]

        assert all(
            applied_sparsity > last_sparsity
            for applied_sparsity, last_sparsity in zip(
                applied_sparsities, last_sparsities
            )
        )
        last_sparsities = applied_sparsities
        _ = model(test_batch)  # check forward pass

    # final update at end_epoch must land on the configured final sparsity
    epoch = int(modifier.end_epoch)
    assert modifier.update_ready(epoch, test_steps_per_epoch)
    modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)

    def _test_final_sparsity_applied():
        # final_sparsity may be a single float or a per-param list
        final_sparsities = (
            [modifier.final_sparsity]
            if isinstance(modifier.final_sparsity, float)
            else modifier.final_sparsity
        )
        assert all(
            sparsity in final_sparsities for sparsity in modifier.applied_sparsity
        )

    _test_final_sparsity_applied()

    # no more updates are scheduled after end_epoch; sparsity stays final
    for epoch in range(int(modifier.end_epoch) + 1, int(modifier.end_epoch) + 6):
        assert not modifier.update_ready(epoch, test_steps_per_epoch)
        _test_final_sparsity_applied()