def _updated_mask(self, var, mask):
    """Return the new mask derived from the current values of `var` and `mask`.

    `var`, `mask` and `self.alpha` are fetched through `self.session.run`,
    and a threshold is obtained from `self._threshold(values, alpha)`.
    An element is kept in the result when both conditions hold:

    * it was already set in the fetched mask, or its magnitude exceeds
      `self.on_factor * threshold`;
    * its magnitude exceeds `self.off_factor * threshold`.
    """
    values, current_mask, alpha = self.session.run([var, mask, self.alpha])
    threshold = self._threshold(values, alpha)
    magnitude = util.abs(values)
    # Elements above the "on" bound are switched on in addition to the
    # ones that are already enabled.
    above_on = magnitude > self.on_factor * threshold
    # Elements at or below the "off" bound are always switched off,
    # regardless of their previous state.
    above_off = magnitude > self.off_factor * threshold
    enabled = util.logical_or(current_mask, above_on)
    return util.logical_and(enabled, above_off)
def _updated_mask(self, var, mask): mask, gamma = self.session.run([mask, self.gamma]) if self.global_threshold: threshold = self._global_threshold() else: if self.incremental: gammas = gamma[util.nonzero(self.mask)] threshold = self._threshold(gammas) new_mask = gamma > threshold if self.incremental: return util.logical_and(mask, new_mask) return new_mask
def _update(self):
    """Refresh the positive mask and group means, then update quantizers.

    The current value of `self.before` is fetched through
    `self.session.run` and split around zero: strictly positive entries
    form one group, and non-positive, non-zero entries form the other.
    The positive mask and the mean of each group are stored on `self`.
    A warning is logged when either stored mean evaluates to zero.
    Finally `self.quantizer` and every quantizer in
    `self.parameter_quantizers` are updated.
    """
    value = self.session.run(self.before)

    # The split point is fixed at zero rather than the mean of the values.
    pivot = 0.0
    is_positive = value > pivot
    self.positives = is_positive
    self.positives_mean = util.mean(value[util.where(is_positive)])

    # Exact zeros are excluded from the negative group.
    not_positive = util.logical_not(is_positive)
    is_negative = util.logical_and(not_positive, value != 0)
    self.negatives_mean = util.mean(value[util.where(is_negative)])

    # NOTE(review): the stored means are read back via `.eval()`, so the
    # attributes presumably return a tensor-like object -- confirm.
    positives_vanish = self.positives_mean.eval() == 0
    if positives_vanish or self.negatives_mean.eval() == 0:
        message = 'means are skewed, pos mean is {} and neg mean is {}'
        log.warn(message.format(
            self.positives_mean.eval(), self.negatives_mean.eval()))

    # Propagate the update to the internal and per-parameter quantizers.
    self.quantizer.update()
    for param_quantizer in self.parameter_quantizers.values():
        param_quantizer.update()