Example #1
0
    def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
        """Compute this epoch's pruning threshold and apply the resulting mask.

        The threshold epsilon grows linearly with a "start" slope until the
        ramp epoch, and with a (typically steeper) "ramp" slope afterwards.
        Only parameters listed in ``self.params_names`` are masked.
        """
        if param_name not in self.params_names:
            return

        starting_epoch = meta['starting_epoch']
        current_epoch = meta['current_epoch']
        ending_epoch = meta['ending_epoch']
        freq = meta['frequency']

        ramp_epoch = self.ramp_epoch_offset + starting_epoch

        # The two slopes are invariant across epochs, so compute them lazily
        # once and cache them on the instance.
        if self.start_slope is None:
            span = (2 * (ramp_epoch - starting_epoch) +
                    3 * (ending_epoch - ramp_epoch))
            self.start_slope = (2 * self.q * freq) / span
            self.ramp_slope = self.start_slope * self.ramp_slope_mult

        if current_epoch < ramp_epoch:
            # Before the ramp: only the start slope contributes.
            eps = self.start_slope * (current_epoch - starting_epoch + 1) / freq
        else:
            # Past the ramp: start-slope contribution up to the ramp epoch,
            # plus the ramp-slope contribution for the epochs after it.
            pre_ramp = self.start_slope * (ramp_epoch - starting_epoch + 1)
            post_ramp = self.ramp_slope * (current_epoch - ramp_epoch + 1)
            eps = (pre_ramp + post_ramp) / freq

        # Threshold the parameter at eps to produce the pruning mask.
        zeros_mask_dict[param_name].mask = distiller.threshold_mask(
            param.data, eps)
 def create_mask(param, sensitivity):
     """Build a pruning mask by thresholding at ``sensitivity * std(param)``.

     The standard deviation is computed once and cached on the parameter
     itself, so repeated calls for the same tensor reuse it.
     """
     if not hasattr(param, 'stddev'):
         param.stddev = torch.std(param).item()
     with torch.no_grad():
         cutoff = sensitivity * param.stddev
         return distiller.threshold_mask(param.data, cutoff)
Example #3
0
 def prune_level(param, param_name, zeros_mask_dict, desired_sparsity):
     """Prune the smallest-magnitude fraction of `param` to reach `desired_sparsity`.

     Stores the resulting mask in ``zeros_mask_dict[param_name].mask``.

     Args:
         param: the parameter tensor to prune.
         param_name: key identifying the parameter in ``zeros_mask_dict``.
         zeros_mask_dict: maps parameter names to mask-holder objects.
         desired_sparsity: fraction (0..1) of elements to prune away.
     """
     num_to_prune = int(desired_sparsity * param.numel())
     if num_to_prune == 0:
         # Nothing to prune (sparsity is 0, or the tensor is so small the
         # requested fraction rounds down to zero elements).  The original
         # code would crash here indexing an empty topk result; keep all
         # elements instead.
         zeros_mask_dict[param_name].mask = torch.ones_like(param)
         return
     bottomk, _ = torch.topk(param.abs().view(-1),
                             num_to_prune,
                             largest=False,
                             sorted=True)
     # The largest element of the pruned group defines the threshold.
     threshold = bottomk.data[-1]
     zeros_mask_dict[param_name].mask = distiller.threshold_mask(
         param.data, threshold)
Example #4
0
def threshold_model(model, threshold):
    """Threshold an entire model using the provided threshold

    This function prunes weights only (biases are left untouched).
    """
    weights = (p for name, p in model.named_parameters() if "weight" in name)
    for p in weights:
        # mul_ mutates p.data in place, zeroing the masked-out elements.
        p.data.mul_(distiller.threshold_mask(p.data, threshold))
Example #5
0
    def threshold(self, param, param_name, zeros_mask_dict):
        """Soft threshold for L1-norm regularizer"""
        # Bail out unless a criterion is configured and this parameter has a
        # regularization regime.
        if self.threshold_criteria is None:
            return
        if param_name not in self.reg_regims:
            return

        strength = self.reg_regims[param_name]
        masker = zeros_mask_dict[param_name]
        masker.mask = distiller.threshold_mask(param.data, threshold=strength)
        masker.is_regularization_mask = True
Example #6
0
def test_threshold_mask():
    # Build a 4-D all-ones tensor, then drop a single element below the
    # threshold we are about to apply.
    tensor = torch.ones(3, 64, 32, 32)
    tensor[1, 4, 17, 31] = 0.2
    mask = distiller.threshold_mask(tensor, threshold=0.3)
    # Exactly one element should be masked out...
    assert np.sum(distiller.to_np(mask)) == (distiller.volume(tensor) - 1)
    # ...and it is the one we modified.
    assert mask[1, 4, 17, 31] == 0
    assert common.almost_equal(distiller.sparsity(mask),
                               1 / distiller.volume(tensor))
    def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
        """Mask `param` by thresholding at ``sensitivity * std(param)``.

        The per-parameter sensitivity wins; otherwise the ``'*'`` wildcard
        entry is used, and if neither is configured the parameter is left
        unmasked.
        """
        # Cache the tensor's standard deviation on the parameter so it is
        # computed only once.
        if not hasattr(param, 'stddev'):
            param.stddev = torch.std(param).item()

        if param_name in self.sensitivities:
            sensitivity = self.sensitivities[param_name]
        elif '*' in self.sensitivities:
            sensitivity = self.sensitivities['*']
        else:
            return

        threshold = param.stddev * sensitivity

        # After computing the threshold, we can create the mask
        zeros_mask_dict[param_name].mask = distiller.threshold_mask(
            param.data, threshold)
 def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
     """Mask `param` using its configured threshold.

     Falls back to the ``'*'`` wildcard threshold when the parameter has no
     entry of its own.
     """
     # NOTE: dict.get(param_name, self.thresholds["*"]) evaluates the "*"
     # fallback eagerly, so it raises KeyError when no wildcard entry exists
     # even if param_name has its own entry.  Look the fallback up lazily.
     if param_name in self.thresholds:
         threshold = self.thresholds[param_name]
     else:
         threshold = self.thresholds["*"]
     zeros_mask_dict[param_name].mask = distiller.threshold_mask(
         param.data, threshold)
Example #9
0
 def create_mask(param, desired_sparsity):
     """Return a mask that prunes the smallest `desired_sparsity` fraction of `param`.

     Args:
         param: the parameter tensor to prune.
         desired_sparsity: fraction (0..1) of elements to prune away.

     Returns:
         A mask tensor shaped like `param`.
     """
     with torch.no_grad():
         num_to_prune = int(desired_sparsity * param.numel())
         if num_to_prune == 0:
             # Nothing to prune (sparsity is 0, or the fraction rounds down
             # to zero elements).  The original code would crash indexing an
             # empty topk result; keep everything instead.
             return torch.ones_like(param)
         bottomk, _ = torch.topk(param.abs().view(-1), num_to_prune,
                                 largest=False, sorted=True)
         # The largest element of the pruned group defines the threshold.
         threshold = bottomk.data[-1]
         return distiller.threshold_mask(param.data, threshold)
Example #10
0
 def create_mask(param, threshold):
     """Threshold `param` (under no_grad) and return the resulting pruning mask."""
     with torch.no_grad():
         return distiller.threshold_mask(param.data, threshold)
Example #11
0
 def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
     """Mask `param` by pruning its configured fraction of smallest values."""
     ratio = self.threshold_ratio.get(param_name)
     flat = param.data.view(param.data.numel())
     # Translate the prune ratio into a magnitude cutoff for this tensor.
     cutoff = self.read_boundary_value_with_ratio(flat, ratio)
     zeros_mask_dict[param_name].mask = distiller.threshold_mask(
         param.data, cutoff)