def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
    """Compute the epoch-dependent magnitude threshold and store the mask.

    The threshold grows in two linear phases.  Up to the ramp epoch
    (``meta['starting_epoch'] + self.ramp_epoch_offset``) it rises with
    ``self.start_slope``; from the ramp epoch onwards an additional term
    rising with ``self.ramp_slope`` is added.  Both slopes are derived on
    the first call (while ``self.start_slope`` is ``None``) and then
    cached on ``self``.

    Args:
        param: parameter whose ``.data`` is thresholded.
        param_name: name of the parameter; names not listed in
            ``self.params_names`` are ignored.
        zeros_mask_dict: mapping from parameter name to an object whose
            ``mask`` attribute receives the new mask.
        meta: mapping providing ``'starting_epoch'``, ``'current_epoch'``,
            ``'ending_epoch'`` and ``'frequency'``.
    """
    if param_name not in self.params_names:
        return

    first_epoch = meta['starting_epoch']
    epoch = meta['current_epoch']
    last_epoch = meta['ending_epoch']
    frequency = meta['frequency']
    ramp_start = self.ramp_epoch_offset + first_epoch

    if self.start_slope is None:
        # The slopes depend only on the schedule: derive once, then reuse.
        numerator = 2 * self.q * frequency
        denominator = (2 * (ramp_start - first_epoch)
                       + 3 * (last_epoch - ramp_start))
        self.start_slope = numerator / denominator
        self.ramp_slope = self.start_slope * self.ramp_slope_mult

    if epoch < ramp_start:
        # Initial phase: only the start slope contributes.
        scaled = self.start_slope * (epoch - first_epoch + 1)
    else:
        # Ramp phase: full contribution of the initial phase plus the ramp.
        scaled = (self.start_slope * (ramp_start - first_epoch + 1)
                  + self.ramp_slope * (epoch - ramp_start + 1))
    eps = scaled / frequency

    # With the threshold known, build the mask and publish it.
    new_mask = distiller.threshold_mask(param.data, eps)
    zeros_mask_dict[param_name].mask = new_mask
def create_mask(param, sensitivity):
    """Build a mask by thresholding ``param`` at ``stddev * sensitivity``.

    The standard deviation is measured the first time a parameter is seen
    and cached on it as ``param.stddev``; later calls reuse the cached
    value rather than recomputing it.

    Args:
        param: parameter to threshold (its ``.data`` is passed to
            ``distiller.threshold_mask``).
        sensitivity: multiplier applied to the cached standard deviation.

    Returns:
        Whatever ``distiller.threshold_mask`` returns for the computed
        threshold.
    """
    try:
        spread = param.stddev
    except AttributeError:
        # First visit: measure once and remember the value on the parameter.
        param.stddev = torch.std(param).item()
        spread = param.stddev

    with torch.no_grad():
        cutoff = spread * sensitivity
        return distiller.threshold_mask(param.data, cutoff)
def prune_level(param, param_name, zeros_mask_dict, desired_sparsity):
    """Mask the smallest-magnitude elements of ``param`` to reach a sparsity.

    The ``int(desired_sparsity * param.numel())`` smallest magnitudes are
    located with ``torch.topk``; the largest of them becomes the threshold
    handed to ``distiller.threshold_mask``.

    Args:
        param: tensor/parameter to prune.
        param_name: key into ``zeros_mask_dict``.
        zeros_mask_dict: mapping from parameter name to an object whose
            ``mask`` attribute receives the new mask.
        desired_sparsity: fraction of elements to prune.

    When the requested sparsity rounds down to zero elements, an all-ones
    mask (same shape, dtype and device as ``param.data``) is stored and
    nothing is pruned.
    """
    num_to_prune = int(desired_sparsity * param.numel())
    if num_to_prune == 0:
        # torch.topk(..., 0) yields an empty tensor, and indexing it with
        # [-1] raises IndexError.  "Prune nothing" means keep every element.
        zeros_mask_dict[param_name].mask = torch.ones_like(param.data)
        return

    bottomk, _ = torch.topk(param.abs().view(-1), num_to_prune,
                            largest=False, sorted=True)
    # This is the largest element from the group of elements that we prune
    # away.
    threshold = bottomk.data[-1]
    zeros_mask_dict[param_name].mask = distiller.threshold_mask(
        param.data, threshold)
def threshold_model(model, threshold):
    """Apply one magnitude threshold to every weight tensor of ``model``.

    Only parameters whose name contains ``"weight"`` are masked; every
    other parameter (biases, for instance) is left untouched.  The masking
    is done in place on each parameter's ``.data``.

    Args:
        model: object exposing ``named_parameters()``.
        threshold: threshold forwarded to ``distiller.threshold_mask``.
    """
    weight_tensors = (tensor for key, tensor in model.named_parameters()
                      if "weight" in key)
    for tensor in weight_tensors:
        keep = distiller.threshold_mask(tensor.data, threshold)
        masked = tensor.data.mul_(keep)
        tensor.data = masked
def threshold(self, param, param_name, zeros_mask_dict):
    """Soft threshold for L1-norm regularizer.

    Does nothing when ``self.threshold_criteria`` is ``None`` or when
    ``param_name`` has no entry in ``self.reg_regims``.  Otherwise the
    strength configured for the parameter is used as the threshold, and
    the resulting mask is stored and flagged as a regularization mask.

    Args:
        param: parameter whose ``.data`` is thresholded.
        param_name: key into ``self.reg_regims`` and ``zeros_mask_dict``.
        zeros_mask_dict: mapping from parameter name to an object that
            receives the ``mask`` and ``is_regularization_mask`` attributes.
    """
    if self.threshold_criteria is None:
        return
    if param_name not in self.reg_regims:
        return

    reg_strength = self.reg_regims[param_name]
    new_mask = distiller.threshold_mask(param.data, threshold=reg_strength)
    zeros_mask_dict[param_name].mask = new_mask
    zeros_mask_dict[param_name].is_regularization_mask = True
def test_threshold_mask():
    """threshold_mask must zero exactly the one element below the threshold."""
    # A 4-D tensor of ones with a single small outlier.
    outlier = (1, 4, 17, 31)
    tensor = torch.ones(3, 64, 32, 32)
    tensor[outlier] = 0.2

    # Threshold sits between the outlier (0.2) and all other values (1.0).
    mask = distiller.threshold_mask(tensor, threshold=0.3)

    # Every position except the outlier survives...
    kept = np.sum(distiller.to_np(mask))
    assert kept == (distiller.volume(tensor) - 1)
    # ...the outlier itself is masked out...
    assert mask[outlier] == 0
    # ...and the mask's sparsity is exactly one element out of the volume.
    assert common.almost_equal(distiller.sparsity(mask),
                               1/distiller.volume(tensor))
def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
    """Mask ``param`` at ``stddev * sensitivity`` and store the mask.

    The standard deviation is measured once per parameter and cached as
    ``param.stddev``; this happens before the sensitivity lookup, so it is
    cached even for parameters that end up being skipped.  The sensitivity
    is looked up by ``param_name`` with ``'*'`` as the fallback key; if
    neither key is configured the parameter is left without a new mask.

    Args:
        param: parameter to threshold.
        param_name: key into ``self.sensitivities`` and ``zeros_mask_dict``.
        zeros_mask_dict: mapping from parameter name to an object whose
            ``mask`` attribute receives the new mask.
        meta: unused here.
    """
    try:
        param.stddev
    except AttributeError:
        param.stddev = torch.std(param).item()

    table = self.sensitivities
    if param_name in table:
        sensitivity = table[param_name]
    elif '*' in table:
        sensitivity = table['*']
    else:
        # No specific entry and no wildcard: nothing to do for this one.
        return

    cutoff = param.stddev * sensitivity
    # With the threshold known, build the mask and publish it.
    new_mask = distiller.threshold_mask(param.data, cutoff)
    zeros_mask_dict[param_name].mask = new_mask
def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
    """Mask ``param`` using its configured absolute threshold.

    The threshold comes from ``self.thresholds[param_name]`` when present,
    otherwise from the ``"*"`` entry.

    Args:
        param: parameter whose ``.data`` is thresholded.
        param_name: key into ``self.thresholds`` and ``zeros_mask_dict``.
        zeros_mask_dict: mapping from parameter name to an object whose
            ``mask`` attribute receives the new mask.
        meta: unused here.

    Raises:
        KeyError: if ``self.thresholds`` has no ``"*"`` entry.
    """
    # The "*" entry is read unconditionally, so it must always be present,
    # even when the parameter has its own threshold.
    fallback = self.thresholds["*"]
    cutoff = self.thresholds.get(param_name, fallback)
    new_mask = distiller.threshold_mask(param.data, cutoff)
    zeros_mask_dict[param_name].mask = new_mask
def create_mask(param, desired_sparsity):
    """Return a mask that prunes the smallest-magnitude elements of ``param``.

    The ``int(desired_sparsity * param.numel())`` smallest magnitudes are
    located with ``torch.topk``; the largest of them becomes the threshold
    handed to ``distiller.threshold_mask``.  Everything runs under
    ``torch.no_grad()``.

    Args:
        param: tensor/parameter to prune.
        desired_sparsity: fraction of elements to prune.

    Returns:
        The mask from ``distiller.threshold_mask``.  When the requested
        sparsity rounds down to zero elements, an all-ones mask with the
        shape, dtype and device of ``param.data`` is returned instead.
    """
    with torch.no_grad():
        num_to_prune = int(desired_sparsity * param.numel())
        if num_to_prune == 0:
            # torch.topk(..., 0) yields an empty tensor, and indexing it
            # with [-1] raises IndexError.  "Prune nothing" means keep
            # every element.
            return torch.ones_like(param.data)

        bottomk, _ = torch.topk(param.abs().view(-1), num_to_prune,
                                largest=False, sorted=True)
        # This is the largest element from the group of elements that we
        # prune away.
        threshold = bottomk.data[-1]
        return distiller.threshold_mask(param.data, threshold)
def create_mask(param, threshold):
    """Return the mask obtained by thresholding ``param.data``.

    Args:
        param: parameter whose ``.data`` is thresholded.
        threshold: threshold forwarded to ``distiller.threshold_mask``.

    Returns:
        The result of ``distiller.threshold_mask``, computed under
        ``torch.no_grad()``.
    """
    weights = param.data
    with torch.no_grad():
        return distiller.threshold_mask(weights, threshold)
def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
    """Mask ``param`` at the boundary value derived from its prune ratio.

    The ratio configured for ``param_name`` in ``self.threshold_ratio`` and
    a flattened view of the parameter data are handed to
    ``self.read_boundary_value_with_ratio``; its result is used as the
    threshold for ``distiller.threshold_mask``.

    Args:
        param: parameter whose ``.data`` is thresholded.
        param_name: key into ``self.threshold_ratio`` and
            ``zeros_mask_dict``.
        zeros_mask_dict: mapping from parameter name to an object whose
            ``mask`` attribute receives the new mask.
        meta: unused here.
    """
    weights = param.data
    # NOTE(review): .get() yields None for an unconfigured parameter and the
    # None is passed straight on - confirm the helper copes with that.
    ratio = self.threshold_ratio.get(param_name)
    flattened = weights.view(weights.numel())
    boundary = self.read_boundary_value_with_ratio(flattened, ratio)
    new_mask = distiller.threshold_mask(weights, boundary)
    zeros_mask_dict[param_name].mask = new_mask