def _policy(self, value, quantized, previous_mask, interval):
    """Extend ``previous_mask`` with the elements closest to their quantized value.

    The error ``value - quantized`` is zeroed wherever ``previous_mask`` is
    already set, a threshold is chosen from the absolute error, and every
    element whose absolute error is strictly below that threshold is added
    to the mask.

    Args:
        value: Unquantized values. Boolean-indexed and flattened here, so
            presumably a NumPy array -- confirm against the caller.
        quantized: Quantized counterpart of ``value``.
        previous_mask: Mask from the previous step; every element set in it
            stays set in the result.
        interval: Fraction that scales the element count into the rank
            passed to ``util.top_k``. ``interval >= 1.0`` selects everything.

    Returns:
        The logical OR of the newly selected elements and ``previous_mask``.

    Raises:
        ValueError: If the computed rank is negative.
    """
    # Only needed for the error message below.
    previous_pruned = util.sum(previous_mask)
    if self.count_zero:
        th_arg = util.cast(util.count(value) * interval, int)
    else:
        # Rank only among the non-zero entries of ``value``.
        nonzero_count = util.count(value[value != 0])
        flat_value_arg = util.where(value.flatten() != 0)
        th_arg = util.cast(nonzero_count * interval, int)
    if th_arg < 0:
        raise ValueError(
            'mask has {} elements, interval is {}'.format(
                previous_pruned, interval))
    off_mask = util.cast(
        util.logical_not(util.cast(previous_mask, bool)), float)
    metric = value - quantized
    flat_value = (metric * off_mask).flatten()
    if interval >= 1.0:
        # The threshold is compared against abs(metric) below, so it must be
        # derived from the absolute error; the signed maximum could be
        # exceeded by a large negative error, leaving that element unselected.
        th = util.abs(flat_value).max() + 1.0
    elif self.count_zero:
        th = util.top_k(util.abs(flat_value), th_arg)
    else:
        th = util.top_k(util.abs(flat_value[flat_value_arg]), th_arg)
    th = util.cast(th, float)
    new_mask = util.logical_not(util.greater_equal(util.abs(metric), th))
    return util.logical_or(new_mask, previous_mask)
def _apply(self, value):
    """Blend ``value`` with its quantized version according to ``self.mask``.

    Registers the configuration of the ``mask`` parameter (boolean,
    zero-initialized, same shape as ``value``) and then returns
    ``value`` where the mask is unset and ``self._quantize(value)`` where
    it is set.
    """
    mask_config = {
        'initial': tf.zeros_initializer(tf.bool),
        'shape': value.shape,
    }
    # Set before ``self.mask`` is read below; the original does the same.
    self._parameter_config = {'mask': mask_config}
    quantized = self._quantize(value)
    keep_original = util.cast(util.logical_not(self.mask), float)
    # on mask indicates the quantized values
    use_quantized = util.cast(self.mask, float)
    return value * keep_original + quantized * use_quantized
def _new_mask(self, mask, value, quantized_value, interval):
    """Compute an updated float mask from the per-channel quantization loss.

    Args:
        mask: One-dimensional mask; it is reshaped to
            ``(1, 1, 1, mask.shape[0])`` so it broadcasts over the last
            axis of ``value``.
        value: Unquantized values (assumed 4-D with channels last, given
            the reduction over axes ``(0, 1, 2)`` -- TODO confirm).
        quantized_value: Quantized counterpart of ``value``.
        interval: Fraction that scales the number of loss entries into
            the rank of the threshold. ``interval >= 1.0`` short-circuits.

    Returns:
        For ``interval >= 1.0``, the negated (reshaped) mask cast to float.
        Otherwise the logical OR of the reshaped mask and the elements
        whose unquantized loss is strictly above the threshold, cast to
        float.
    """
    loss = util.abs(value - quantized_value)
    # check the ones that are not quantized
    mask = mask.reshape((1, 1, 1, mask.shape[0]))
    unquantized_mask = util.logical_not(mask)
    # This guard must precede the threshold lookup: for interval >= 1.0 the
    # rank is at least len(loss_vec), so indexing first always raised
    # IndexError and this return was unreachable.
    if interval >= 1.0:
        return util.cast(unquantized_mask, float)
    # TODO: mask shape is incorrect
    loss_vec = util.mean(loss * unquantized_mask, (0, 1, 2))
    # sort
    ordered = sorted(loss_vec)
    # A ceil helper may return a float, which cannot index a list, and for
    # intervals just below 1.0 the rank can equal len(ordered); force an int
    # and clamp to the last valid index.
    num_active = min(
        int(util.ceil(len(ordered) * interval)), len(ordered) - 1)
    threshold = ordered[num_active]
    new_mask = (unquantized_mask * loss) > threshold
    return util.cast(util.logical_or(new_mask, mask), float)
def _update(self):
    """Refresh the positive mask and both group means, then update quantizers.

    Fetches the current value of ``self.before`` through ``self.session``,
    splits it at zero, stores the mask of positive entries and the mean of
    each group (zeros are excluded from the negative group), warns when
    either mean evaluates to zero, and finally calls ``update()`` on
    ``self.quantizer`` and on every entry of ``self.parameter_quantizers``.
    """
    # update positives mask and mean values
    value = self.session.run(self.before)
    # the split point between the two groups is fixed at zero
    positives = value > 0.0
    self.positives = positives
    self.positives_mean = util.mean(value[util.where(positives)])
    not_positive = util.logical_not(positives)
    nonzero_negatives = util.logical_and(not_positive, value != 0)
    self.negatives_mean = util.mean(value[util.where(nonzero_negatives)])
    if self.positives_mean.eval() == 0 or self.negatives_mean.eval() == 0:
        template = 'means are skewed, pos mean is {} and neg mean is {}'
        log.warn(
            template.format(
                self.positives_mean.eval(), self.negatives_mean.eval()))
    # update internal quantizer
    self.quantizer.update()
    for parameter_quantizer in self.parameter_quantizers.values():
        parameter_quantizer.update()
def negatives(self):
    """Return the element-wise logical negation of ``self.positives``."""
    positives = self.positives
    return util.logical_not(positives)