예제 #1
0
파일: netslim.py 프로젝트: zaf05/mayo
 def _global_threshold(self):
     estimator = self.session.estimator
     gamma_name = 'NetworkSlimmer.gamma'
     threshold_name = 'NetworkSlimmer.threshold'
     if estimator.max_len(gamma_name) == 0:
         try:
             return estimator.get_value(threshold_name)
         except KeyError:
             raise RuntimeError(
                 'Train for a while before running update to collect '
                 'gamma values.')
     # extract all gammas globally
     gammas = []
     for overrider, gamma in estimator.get_values(gamma_name).items():
         if not overrider.should_update:
             continue
         if self.incremental:
             mask = self.session.run(overrider.mask)
             gamma = gamma[util.nonzero(mask)]
         gammas += gamma.tolist()
     threshold = self._threshold(gammas)
     log.debug(
         'Extracted a global threshold for all gammas: {}.'
         .format(threshold))
     estimator.flush_all(gamma_name)
     estimator.add(threshold, threshold_name)
     return threshold
예제 #2
0
파일: fixed.py 프로젝트: randysuen/mayo
 def _quantize(self, value, point, width, compute_overflow_rate=False):
     """Quantize ``value`` on a base-2 logarithmic scale, keeping its sign.

     The sign is split off and the magnitude is mapped to
     ``log2(|value|)``. That exponent is passed through
     ``self.quantizer.apply``, and the result is rebuilt as
     ``sign * 2 ** exponent``. Entries that are neither positive nor
     negative (sign 0) become 0.

     Args:
         value: input to quantize (presumably a tensor/array -- confirm
             against callers).
         point: accepted for interface compatibility; not used here.
         width: accepted for interface compatibility; not used here.
         compute_overflow_rate: when true, return whatever
             ``self.quantizer.apply`` returns in that mode instead of the
             reconstructed value.
     """
     # sign is +1.0 for positive, -1.0 for negative, 0.0 otherwise.
     sign = (util.cast(value > 0, float)
             - util.cast(value < 0, float))
     exponent = util.log(util.abs(value), 2.0)
     exponent = self.quantizer.apply(
         exponent, compute_overflow_rate=compute_overflow_rate)
     if not compute_overflow_rate:
         # The sign mask forces sign-0 entries to 0, so the quantized log
         # of 0 never reaches the output.
         return util.where(util.nonzero(sign), sign * 2 ** exponent, 0)
     return exponent
예제 #3
0
파일: netslim.py 프로젝트: zaf05/mayo
 def _updated_mask(self, var, mask):
     mask, gamma = self.session.run([mask, self.gamma])
     if self.global_threshold:
         threshold = self._global_threshold()
     else:
         if self.incremental:
             gammas = gamma[util.nonzero(self.mask)]
         threshold = self._threshold(gammas)
     new_mask = gamma > threshold
     if self.incremental:
         return util.logical_and(mask, new_mask)
     return new_mask