def _policy(self, value, quantized, previous_mask, interval):
    """
    Extend `previous_mask` using the error between `value` and `quantized`.

    A rank is obtained by scaling an element count by `interval` and
    casting the result to int: every element of `value` is counted when
    `self.count_zero` is set, otherwise only its non-zero elements.

    The error `value - quantized` is zeroed wherever `previous_mask` is
    already set, and a threshold is chosen from what remains:

    * when `interval >= 1.0`, the maximum of the remaining (signed)
      errors plus 1.0;
    * otherwise the result of `util.top_k` on the error magnitudes at the
      computed rank, restricted to the positions where `value` is
      non-zero unless `self.count_zero` is set (presumably the rank-th
      largest magnitude; confirm against `util.top_k`).

    Returns the union of `previous_mask` with the elements whose error
    magnitude is not greater than or equal to the threshold.

    Raises ValueError if the computed rank is negative.
    """
    # Only reported in the error message below.
    masked_total = util.sum(previous_mask)

    if self.count_zero:
        rank = util.cast(util.count(value) * interval, int)
    else:
        nonzero_count = util.count(value[value != 0])
        # Flat positions of the non-zero entries, reused when picking the
        # threshold so that zero entries are left out of the ranking.
        nonzero_index = util.where(value.flatten() != 0)
        rank = util.cast(nonzero_count * interval, int)
    if rank < 0:
        message = 'mask has {} elements, interval is {}'
        raise ValueError(message.format(masked_total, interval))

    # 1.0 where the previous mask is unset, 0.0 where it is set.
    already_masked = util.cast(previous_mask, bool)
    available = util.cast(util.logical_not(already_masked), float)
    error = value - quantized
    candidates = (error * available).flatten()

    if interval >= 1.0:
        threshold = candidates.max() + 1.0
    elif self.count_zero:
        threshold = util.top_k(util.abs(candidates), rank)
    else:
        threshold = util.top_k(util.abs(candidates[nonzero_index]), rank)
    threshold = util.cast(threshold, float)

    reaches_threshold = util.greater_equal(util.abs(error), threshold)
    below_threshold = util.logical_not(reaches_threshold)
    return util.logical_or(below_threshold, previous_mask)
def _overflow_rate(mask):
    """
    Return the fraction of entries in `mask` that are flagged as overflow.

    `mask` is a boolean tensor in which True marks the presence of
    overflow and False marks its absence.
    """
    overflow_total = util.sum(util.cast(mask, int))
    return overflow_total / util.count(mask)
def _info(self):
    """
    Build the info tuple for the gate: its name, the density of one
    evaluated sample, and the number of elements in that sample.
    """
    # FIXME it doesn't make sense to run `gate` once as its density
    # varies from run to run.
    sampled = self.session.run(self.gate)
    gate_values = util.cast(sampled, int)
    active_ratio = util.sum(gate_values) / util.count(gate_values)
    density = Percent(active_ratio)
    return self._info_tuple(
        gate=self.gate.name,
        density=density,
        count_=gate_values.size,
    )
def _info(self):
    """
    Build the info tuple for the mask: its name, the density of its
    evaluated value, and the number of elements it holds.
    """
    evaluated = self.session.run(self.mask)
    mask_values = util.cast(evaluated, int)
    kept_ratio = util.sum(mask_values) / util.count(mask_values)
    density = Percent(kept_ratio)
    return self._info_tuple(
        mask=self.mask.name,
        density=density,
        count_=mask_values.size,
    )