def _test_constructor(layer, param_name, mask_creator):
    """Check that a freshly built ModuleParamPruningMask reflects its inputs.

    Builds a mask for a single ``layer`` / ``param_name`` pair using
    ``mask_creator`` and asserts that:

    * the layer, parameter name and mask creator are stored as given;
    * ``track_grad_mom`` equals ``-1.0``;
    * ``store_init``, ``store_unmasked``, ``global_sparsity`` and ``enabled``
      are all falsy.
    """
    pruning_mask = ModuleParamPruningMask(
        [layer], [param_name], mask_creator=mask_creator
    )

    assert pruning_mask.layers[0] == layer
    assert pruning_mask.param_names[0] == param_name

    # None of these options were passed to the constructor above, so each
    # flag is expected to be falsy.
    for flag_name in ("store_init", "store_unmasked"):
        assert not getattr(pruning_mask, flag_name)

    assert pruning_mask.track_grad_mom == -1.0

    for flag_name in ("global_sparsity", "enabled"):
        assert not getattr(pruning_mask, flag_name)

    assert mask_creator == pruning_mask.mask_creator
def _test_set_param_mask_from_sparsity(
    layer, param_name, param, sparsity, mask_creator
):
    """Check that masking ``param`` to a target ``sparsity`` achieves it.

    Loads ``param`` as the data for parameter index 0 of a new mask. It then
    asks the mask to derive its masks from ``sparsity`` and asserts that the
    sparsity measured on the resulting mask is within 0.01 of the target.
    When ``mask_creator`` is a ``GroupedPruningMaskCreator``, the produced
    mask is additionally validated by ``_test_grouped_sparsity_mask_output``.
    """
    pruning_mask = ModuleParamPruningMask(
        [layer], [param_name], mask_creator=mask_creator
    )
    pruning_mask.set_param_data(param, 0)
    pruning_mask.set_param_masks_from_sparsity(sparsity)

    actual_sparsity = tensor_sparsity(pruning_mask.param_masks[0])
    assert (actual_sparsity - sparsity).abs() < 0.01

    # Only grouped creators get the extra structural check.
    if not isinstance(mask_creator, GroupedPruningMaskCreator):
        return
    _test_grouped_sparsity_mask_output(mask_creator, pruning_mask.param_masks[0])
def _test_set_param_mask_from_abs_threshold(
    layer,
    param_name,
    param,
    threshold,
    expected_sparsity,
    mask_creator,
):
    """Check that masking ``param`` by an absolute ``threshold`` works.

    Loads ``param`` as the data for parameter index 0 of a new mask. It then
    derives the masks from ``threshold`` and asserts that the sparsity
    measured on the resulting mask is within 0.01 of ``expected_sparsity``.
    When ``mask_creator`` is a ``GroupedPruningMaskCreator``, the produced
    mask is additionally validated by ``_test_grouped_sparsity_mask_output``.
    """
    pruning_mask = ModuleParamPruningMask(
        [layer], [param_name], mask_creator=mask_creator
    )
    pruning_mask.set_param_data(param, 0)
    pruning_mask.set_param_masks_from_abs_threshold(threshold)

    measured_sparsity = tensor_sparsity(pruning_mask.param_masks[0])
    assert (measured_sparsity - expected_sparsity).abs() < 0.01

    # Only grouped creators get the extra structural check.
    if not isinstance(mask_creator, GroupedPruningMaskCreator):
        return
    _test_grouped_sparsity_mask_output(mask_creator, pruning_mask.param_masks[0])
def _test_set_param_mask(layer, param_name, param_mask):
    """Check ``set_param_masks`` change reporting and ``apply`` for one param.

    ``set_param_masks`` returns one result tensor per parameter. The test
    partitions the entries of that result into three categories:

    * 1.0 (labelled "unmasked" here): must not appear at all;
    * -1.0 (labelled "masked"): must appear exactly where ``param_mask`` is 0;
    * 0.0 (labelled "no change"): must appear exactly where ``param_mask``
      is 1.

    After enabling the mask and calling ``apply``, the parameter data must be
    zero exactly where ``param_mask`` is 0.

    NOTE(review): this assumes ``param_mask`` contains only 0s and 1s, and
    that the parameter has no pre-existing zeros at positions the mask keeps.
    Otherwise the final check fails. Confirm this against the callers.
    """
    mask = ModuleParamPruningMask([layer], [param_name])
    result = mask.set_param_masks([param_mask])[0]

    def _indicator(tensor, value):
        # 0/1 float32 tensor marking the positions where ``tensor == value``.
        return (tensor == value).type(torch.float32)

    res_unmasked = _indicator(result, 1.0)
    res_masked = _indicator(result, -1.0)
    res_no_change = _indicator(result, 0.0)
    mask_ones = _indicator(param_mask, 1.0)
    mask_zeros = _indicator(param_mask, 0.0)

    assert torch.sum(res_unmasked.abs()) < sys.float_info.epsilon
    assert torch.sum((res_masked - mask_zeros).abs()) < sys.float_info.epsilon
    assert torch.sum((res_no_change - mask_ones).abs()) < sys.float_info.epsilon

    mask.enabled = True
    mask.apply()

    # Previously ``.type("float32")``: torch only accepts a dtype object or a
    # tensor-type name such as "torch.FloatTensor", so that raised ValueError.
    param_data_zeros = _indicator(mask.params_data[0], 0.0)
    assert torch.sum((param_data_zeros - mask_zeros).abs()) < sys.float_info.epsilon
def _test_set_param_data(layer, param_name, data):
    """Check that ``set_param_data`` stores ``data`` for parameter index 0.

    After loading ``data`` into a new mask, the summed absolute difference
    between ``params_data[0]`` and ``data`` must be below
    ``sys.float_info.epsilon``.
    """
    # NOTE(review): unlike the sibling helpers, ``param_name`` is passed here
    # without being wrapped in a list. Confirm that the constructor accepts a
    # bare parameter name.
    pruning_mask = ModuleParamPruningMask([layer], param_name)
    pruning_mask.set_param_data(data, 0)

    stored_data = pruning_mask.params_data[0]
    total_abs_diff = torch.sum((stored_data - data).abs())
    assert total_abs_diff < sys.float_info.epsilon