Esempio n. 1
0
 def _policy(self, value, quantized, previous_mask, interval):
     """Extend ``previous_mask`` with the elements whose quantization error
     magnitude is below a threshold derived from ``interval``.

     Args:
         value: the original values.  ``.flatten()`` and boolean indexing
             are applied to it, so presumably a numpy array -- TODO confirm.
         quantized: subtracted from ``value`` to obtain the per-element
             error that is ranked.
         previous_mask: elements selected earlier; cast to bool here and
             always kept in the result.
         interval: scales the element count used to pick the threshold;
             any value >= 1.0 selects every element.

     Returns:
         The logical OR of the newly selected elements and
         ``previous_mask``.

     Raises:
         ValueError: if the element count derived from ``interval`` is
             negative.
     """
     previous_pruned = util.sum(previous_mask)
     if self.count_zero:
         th_arg = util.cast(util.count(value) * interval, int)
     else:
         # Only the non-zero entries of `value` take part in the ranking.
         tmp = util.count(value[value != 0])
         flat_value_arg = util.where(value.flatten() != 0)
         th_arg = util.cast(tmp * interval, int)
     if th_arg < 0:
         raise ValueError('mask has {} elements, interval is {}'.format(
             previous_pruned, interval))
     # Zero out the error of elements that are already in the mask so they
     # do not influence the threshold.
     off_mask = util.cast(util.logical_not(util.cast(previous_mask, bool)),
                          float)
     metric = value - quantized
     flat_value = (metric * off_mask).flatten()
     if interval >= 1.0:
         # The threshold is compared against error *magnitudes* below, so it
         # must exceed the largest absolute error.  The signed maximum alone
         # is too small whenever the biggest error is negative.
         th = max(flat_value.max(), -flat_value.min()) + 1.0
     elif self.count_zero:
         th = util.top_k(util.abs(flat_value), th_arg)
     else:
         th = util.top_k(util.abs(flat_value[flat_value_arg]), th_arg)
     th = util.cast(th, float)
     # Select everything whose error magnitude is strictly below `th`.
     new_mask = util.logical_not(util.greater_equal(util.abs(metric), th))
     return util.logical_or(new_mask, previous_mask)
Esempio n. 2
0
 def _apply(self, value):
     """Blend ``value`` with its quantized form according to ``self.mask``.

     A ``'mask'`` entry (zeros initializer of boolean dtype, same shape as
     ``value``) is written to ``self._parameter_config`` first; presumably
     this is what backs ``self.mask`` -- TODO confirm.

     Returns:
         ``value`` where the mask is off and ``self._quantize(value)`` where
         the mask is on.
     """
     mask_spec = {
         'initial': tf.zeros_initializer(tf.bool),
         'shape': value.shape,
     }
     self._parameter_config = {'mask': mask_spec}
     quantized = self._quantize(value)
     keep_original = util.cast(util.logical_not(self.mask), float)
     # the "on" entries of the mask mark the quantized values
     use_quantized = util.cast(self.mask, float)
     original_part = value * keep_original
     quantized_part = quantized * use_quantized
     return original_part + quantized_part
Esempio n. 3
0
File: mixed.py Progetto: zaf05/mayo
 def _new_mask(self, mask, value, quantized_value, interval):
     """Return an updated float mask that adds the not-yet-quantized
     entries whose quantization loss exceeds a ranked threshold.

     Args:
         mask: current mask; reshaped to ``(1, 1, 1, mask.shape[0])`` so it
             broadcasts against the loss (assumes ``value`` is 4-D with the
             masked axis last -- TODO confirm).
         value: the original values.
         quantized_value: the quantized values, same shape as ``value``.
         interval: fraction of the per-entry losses used to pick the
             threshold; for values >= 1.0 the complement of ``mask`` is
             returned directly.

     Returns:
         A float mask.
     """
     # check the ones that are not quantized
     mask = mask.reshape((1, 1, 1, mask.shape[0]))
     unquantized_mask = util.logical_not(mask)
     if interval >= 1.0:
         # Must be handled before ranking: for such intervals the rank index
         # below is at least len(loss_vec) and would be out of range.
         return util.cast(unquantized_mask, float)
     loss = util.abs(value - quantized_value)
     # TODO: mask shape is incorrect
     loss_vec = util.mean(loss * unquantized_mask, (0, 1, 2))
     # sort
     ranked = sorted(loss_vec)
     # A ceil result may be a float, which cannot index a list; it can also
     # equal len(ranked) for intervals just below 1.0, so clamp it.
     num_active = int(util.ceil(len(loss_vec) * interval))
     threshold = ranked[min(num_active, len(ranked) - 1)]
     new_mask = (unquantized_mask * loss) > threshold
     return util.cast(util.logical_or(new_mask, mask), float)
Esempio n. 4
0
 def _update(self):
     """Recompute the positive/negative split of the current values and
     update every attached quantizer.

     Stores the positives mask and the means of the positive and of the
     non-zero negative values on ``self``, warns when either mean evaluates
     to zero, then calls ``update()`` on ``self.quantizer`` and on each
     entry of ``self.parameter_quantizers``.
     """
     # fetch the current values of `self.before`
     current = self.session.run(self.before)
     # the two groups are split around zero rather than around the mean
     pivot = 0.0
     is_positive = current > pivot
     self.positives = is_positive
     positive_indices = util.where(is_positive)
     self.positives_mean = util.mean(current[positive_indices])
     # exact zeros belong to neither group
     is_negative = util.logical_and(
         util.logical_not(is_positive), current != 0)
     negative_indices = util.where(is_negative)
     self.negatives_mean = util.mean(current[negative_indices])
     if self.positives_mean.eval() == 0 or self.negatives_mean.eval() == 0:
         template = 'means are skewed, pos mean is {} and neg mean is {}'
         log.warn(template.format(
             self.positives_mean.eval(), self.negatives_mean.eval()))
     # refresh the internal quantizer, then the per-parameter ones
     self.quantizer.update()
     for quantizer in self.parameter_quantizers.values():
         quantizer.update()
Esempio n. 5
0
 def negatives(self):
     """Return the element-wise logical complement of ``self.positives``."""
     complement = util.logical_not(self.positives)
     return complement