Exemplo n.º 1
0
 def _policy(self, value, quantized, previous_mask, interval):
     """
     Compute the next mask from the difference between `value` and
     `quantized`.

     Elements outside `previous_mask` whose absolute difference falls
     below a threshold are added to the mask.  The threshold is obtained
     from `util.top_k` using an element count derived from `interval`;
     that count is based on all elements when `self.count_zero` is set,
     and on the non-zero elements of `value` otherwise.

     Args:
         value: the original values.
         quantized: the quantized counterpart of `value`.
         previous_mask: the mask of elements selected so far.
         interval: the fraction used to derive the element count; at 1.0
             or above the threshold is placed above every magnitude.

     Returns:
         The logical OR of the newly selected elements and
         `previous_mask`.

     Raises:
         ValueError: if the derived element count is negative.
     """
     previous_pruned = util.sum(previous_mask)
     if self.count_zero:
         th_arg = util.cast(util.count(value) * interval, int)
     else:
         num_nonzero = util.count(value[value != 0])
         flat_value_arg = util.where(value.flatten() != 0)
         th_arg = util.cast(num_nonzero * interval, int)
     if th_arg < 0:
         raise ValueError('mask has {} elements, interval is {}'.format(
             previous_pruned, interval))
     off_mask = util.cast(util.logical_not(util.cast(previous_mask, bool)),
                          float)
     metric = value - quantized
     flat_value = (metric * off_mask).flatten()
     if interval >= 1.0:
         # The comparison below is on magnitudes, so the threshold must
         # exceed the largest magnitude rather than the largest signed
         # value; otherwise large negative differences stay unselected.
         th = util.abs(flat_value).max() + 1.0
     elif self.count_zero:
         th = util.top_k(util.abs(flat_value), th_arg)
     else:
         # rank only the positions where `value` is non-zero
         th = util.top_k(util.abs(flat_value[flat_value_arg]), th_arg)
     th = util.cast(th, float)
     new_mask = util.logical_not(util.greater_equal(util.abs(metric), th))
     return util.logical_or(new_mask, previous_mask)
Exemplo n.º 2
0
 def _quantize(self, value, base=None):
     """
     Replace every element of `value` by a signed constant magnitude.

     The magnitude is `2**base * self.scale`; positive elements map to
     `+magnitude`, negative elements to `-magnitude`, and all others to
     zero.  `base` falls back to `self.base` when not given.
     """
     scale = self.scale
     if base is None:
         base = self.base
     exponent = util.cast(base, int)
     shift = util.cast(2**exponent, float)
     is_positive = util.cast(value > 0, float)
     is_negative = util.cast(value < 0, float)
     positive_part = is_positive * shift * scale
     negative_part = is_negative * shift * scale
     return positive_part - negative_part
Exemplo n.º 3
0
 def _quantize(self,
               value,
               point=None,
               width=None,
               compute_overflow_rate=False):
     """
     Round `value` onto a grid defined by `point` and `width`.

     The value is multiplied by `2**(width - point)`, rounded (with
     `util.stochastic_round` when `self.stochastic` is set), clipped to
     `[-2**(width - 1), 2**(width - 1) - 1]` and scaled back.

     Args:
         value: the values to quantize.
         point: overrides `self.point` when given.
         width: overrides `self.width` when given.
         compute_overflow_rate: when true, return the overflow rate of
             the non-zero rounded values instead of the quantized values.
     """
     point = util.cast(self.point if point is None else point, float)
     width = util.cast(self.width if width is None else width, float)
     # equivalent to x << (width - point)
     scale = 2.0**(util.round(width) - util.round(point))
     scaled = value * scale
     if self.stochastic:
         rounded = util.stochastic_round(scaled, self.stochastic)
     else:
         rounded = util.round(scaled)
     # NOTE(review): `width` has already been cast above, so this guard
     # looks like it can never trigger; confirm.
     if width is None:
         return rounded / scale
     # ensure the number is representable without overflow
     limit = util.cast(2**(width - 1), float)
     if compute_overflow_rate:
         nonzero = rounded[rounded != 0]
         below = nonzero < -limit
         above = nonzero > limit - 1
         return self._overflow_rate(util.logical_or(below, above))
     clipped = util.clip_by_value(rounded, -limit, limit - 1)
     # undo the earlier scaling
     return clipped / scale
Exemplo n.º 4
0
 def _quantize(self, value, base=None):
     """
     Map positive elements of `value` to `2**base * self.scale`, negative
     elements to its negation, and everything else to zero.

     `base` defaults to `self.base`.
     """
     # @Aaron @Xitong FIXME possible redundancy:
     # this code is identical to super()._quantize()
     scale = self.scale
     chosen_base = self.base if base is None else base
     shift = util.cast(2**util.cast(chosen_base, int), float)
     positive_flags = util.cast(value > 0, float)
     negative_flags = util.cast(value < 0, float)
     # FIXME verify this multiplication is broadcasting correctly
     upper = positive_flags * shift * scale
     lower = negative_flags * shift * scale
     return upper - lower
Exemplo n.º 5
0
 def _quantize(self, value, point, width, compute_overflow_rate=False):
     """
     Quantize the base-2 logarithm of the magnitude of `value`.

     The logarithm is passed to `self.quantizer.apply`; the result is
     then raised back to a power of two and given the original sign.
     Elements with a zero sign are returned as 0.  When
     `compute_overflow_rate` is true, the result of the inner quantizer
     is returned unchanged.

     `point` and `width` are accepted but not used here.
     """
     # keep the sign, as the logarithm only sees the magnitude
     is_positive = util.cast(value > 0, float)
     is_negative = util.cast(value < 0, float)
     sign = is_positive - is_negative
     log_magnitude = util.log(util.abs(value), 2.0)
     quantized_log = self.quantizer.apply(
         log_magnitude, compute_overflow_rate=compute_overflow_rate)
     if compute_overflow_rate:
         return quantized_log
     # rebuild the signed value from the quantized exponent
     restored = sign * (2**quantized_log)
     return util.where(util.nonzero(sign), restored, 0)
Exemplo n.º 6
0
 def _apply(self, value):
     """
     Blend `value` with its quantized counterpart using `self.mask`.

     A boolean `mask` parameter with the shape of `value` and a zeros
     initializer is configured first.  Where the mask is set the
     quantized value is returned, elsewhere the original value.
     """
     mask_config = {
         'initial': tf.zeros_initializer(tf.bool),
         'shape': value.shape,
     }
     self._parameter_config = {'mask': mask_config}
     quantized_value = self._quantize(value)
     keep_original = util.cast(util.logical_not(self.mask), float)
     # the mask marks the quantized values
     use_quantized = util.cast(self.mask, float)
     return value * keep_original + quantized_value * use_quantized
Exemplo n.º 7
0
Arquivo: mixed.py Projeto: zaf05/mayo
 def _new_mask(self, mask, value, quantized_value, interval):
     """
     Compute an updated mask from the quantization loss.

     The loss is the absolute difference between `value` and
     `quantized_value`.  It is averaged over axes (0, 1, 2) for the
     entries outside `mask`, and a threshold is picked from the sorted
     averages at the position given by `interval`.

     Args:
         mask: the current mask; it is reshaped to
             `(1, 1, 1, mask.shape[0])`.
         value: the original values.
         quantized_value: the quantized counterpart of `value`.
         interval: the fraction of the sorted losses at which the
             threshold is taken.

     Returns:
         A float mask.  For `interval >= 1.0` this is the negation of
         the reshaped `mask`; otherwise it is the logical OR of `mask`
         and the entries whose unquantized loss exceeds the threshold.
     """
     loss = util.abs(value - quantized_value)
     # check the ones that are not quantized
     mask = mask.reshape((1, 1, 1, mask.shape[0]))
     unquantized_mask = util.logical_not(mask)
     if interval >= 1.0:
         # Must be handled before the threshold lookup: the index
         # computed below would be out of range for such an interval.
         return util.cast(unquantized_mask, float)
     # TODO: mask shape is incorrect
     loss_vec = util.mean(loss * unquantized_mask, (0, 1, 2))
     # sort
     ordered = sorted(loss_vec)
     num_active = int(util.ceil(len(ordered) * interval))
     # rounding up can land one past the end for intervals close to 1.0
     num_active = min(num_active, len(ordered) - 1)
     threshold = ordered[num_active]
     new_mask = (unquantized_mask * loss) > threshold
     return util.cast(util.logical_or(new_mask, mask), float)
Exemplo n.º 8
0
 def _represent(self, sign, exponent, mantissa):
     """
     Rebuild floating-point values from their sign, exponent and
     mantissa components.

     The result is `sign * 2**exponent * mantissa`.  For non-constant
     inputs, entries whose sign equals zero are forced to zero.
     """
     components = (sign, exponent, mantissa)
     value = util.cast(sign, float) * (2.0**exponent) * mantissa
     if util.is_constant(*components):
         return value
     # build integer zeros with the backend matching the inputs
     if util.is_numpy(*components):
         zeros = np.zeros(sign.shape, dtype=np.int32)
     else:
         zeros = tf.zeros(sign.shape, dtype=tf.int32)
     sign_is_zero = util.equal(sign, zeros)
     float_zeros = util.cast(zeros, float)
     return util.where(sign_is_zero, float_zeros, value)
Exemplo n.º 9
0
 def _decompose(self, value, exponent_bias=None):
     """
     Split floating-point values into sign, exponent and mantissa.

     The sign is +1 above a cut-off of `2**(-exponent_bias) / 2`, -1
     below its negation and 0 in between.  The exponent is the floor of
     the base-2 logarithm of the magnitude, and the mantissa is the
     magnitude divided by `2**exponent`.  `exponent_bias` defaults to
     `self.exponent_bias`.
     """
     bias = self.exponent_bias if exponent_bias is None else exponent_bias
     # half of the smallest non-zero floating-point value
     cutoff = (2**(-bias)) / 2
     is_positive = util.cast(value > cutoff, int)
     is_negative = util.cast(value < -cutoff, int)
     sign = is_positive - is_negative
     magnitude = util.abs(value)
     exponent = util.floor(util.log(magnitude, 2))
     mantissa = magnitude / (2**exponent)
     return sign, exponent, mantissa
Exemplo n.º 10
0
Arquivo: gate.py Projeto: zaf05/mayo
 def _apply(self, value):
     """
     Gate `value` with a binary mask derived from pooled activations.

     The rectified input is pooled with a window covering dimensions 1
     and 2 of `value` (assumes an NHWC layout; TODO confirm), using the
     pooling selected by `self.policy`:

     * 'avg' or None: average pooling;
     * 'max': max pooling;
     * 'mix': max pooling minus average pooling.

     The gate is 1.0 where the magnitude of the pooled value reaches
     `self.threshold` and 0.0 elsewhere.  It is stored in `self.gate`
     and added to the 'mayo.overrider.gates' collection; the pooled
     tensor is stored in `self.pooled`.

     Returns:
         `value` multiplied by the gate.

     Raises:
         ValueError: if `self.policy` is not one of the policies above.
     """
     policy = self.policy
     value_pool = tf.nn.relu(value)
     # only the two middle dimensions are needed for the pooling window
     _, h, w, _ = (int(d) for d in value.shape)
     pool_params = {
         'padding': 'VALID',
         'ksize': [1, h, w, 1],
         'strides': [1, 1, 1, 1]
     }
     if policy == 'avg' or policy is None:
         pooled = tf.nn.avg_pool(value_pool, **pool_params)
     elif policy == 'max':
         pooled = tf.nn.max_pool(tf.abs(value_pool), **pool_params)
     elif policy == 'mix':
         maxed = tf.nn.max_pool(tf.abs(value_pool), **pool_params)
         avged = tf.nn.avg_pool(value_pool, **pool_params)
         pooled = maxed - avged
     else:
         # fail with a clear message instead of an unbound `pooled`
         raise ValueError(
             'Unknown gating policy {!r}.'.format(policy))
     self.gate = util.cast(tf.abs(pooled) >= self.threshold, float)
     self.pooled = pooled
     tf.add_to_collection('mayo.overrider.gates', self.gate)
     return self.gate * value
Exemplo n.º 11
0
Arquivo: base.py Projeto: zaf05/mayo
 def _overflow_rate(mask):
     """
     Return the fraction of entries in `mask` that flag an overflow.

     `mask` is a boolean tensor in which True marks the presence of an
     overflow and False its absence.
     """
     num_overflows = util.sum(util.cast(mask, int))
     num_entries = util.count(mask)
     return num_overflows / num_entries
Exemplo n.º 12
0
    def _transform(self,
                   sign,
                   exponent,
                   mantissa,
                   exponent_width=None,
                   mantissa_width=None,
                   exponent_bias=None):
        """
        Clip the exponent and round the mantissa to the given widths.

        Missing arguments fall back to `self.exponent_bias`,
        `self.width - self.mantissa_width` and `self.mantissa_width`
        respectively.  The exponent is clipped to
        `[-exponent_bias, 2**exponent_width - 1 - exponent_bias]`, and the
        mantissa is rounded to a multiple of `2**(-mantissa_width)`, using
        `util.stochastic_round` when `self.stochastic` is set.

        Returns:
            The tuple `(sign, exponent, mantissa)`; `sign` is passed
            through unchanged.
        """
        if exponent_bias is None:
            exponent_bias = self.exponent_bias
        if exponent_width is None:
            exponent_width = self.width - self.mantissa_width
        if mantissa_width is None:
            mantissa_width = self.mantissa_width
        # clip the exponent to its representable range
        lowest = -exponent_bias
        highest = 2**exponent_width - 1 - exponent_bias
        exponent = util.clip_by_value(exponent, lowest, highest)
        # round the mantissa at the requested precision
        scale = util.cast(2**mantissa_width, float)
        scaled = mantissa * scale
        if self.stochastic:
            rounded = util.stochastic_round(scaled, self.stochastic)
        else:
            rounded = util.round(scaled)
        mantissa = rounded / scale

        # a mantissa rounded up to >= 2 is halved and its exponent is
        # incremented by 1 to compensate
        # NOTE(review): this happens after clipping, so the exponent may
        # end up above the upper bound; confirm this is intended.
        overflowed = util.greater_equal(mantissa, 2)
        halved = mantissa / 2
        bumped = exponent + 1
        mantissa = util.where(overflowed, halved, mantissa)
        exponent = util.where(overflowed, bumped, exponent)
        return sign, exponent, mantissa
Exemplo n.º 13
0
Arquivo: gate.py Projeto: zaf05/mayo
 def _info(self):
     """
     Report the name, density and element count of the gate.

     The density is the sum of the evaluated gate divided by its
     element count, wrapped in `Percent`.
     """
     # FIXME a single run of `gate` is not representative, as its
     # density varies from run to run.
     gate_values = util.cast(self.session.run(self.gate), int)
     active_ratio = util.sum(gate_values) / util.count(gate_values)
     density = Percent(active_ratio)
     return self._info_tuple(
         gate=self.gate.name, density=density, count_=gate_values.size)
Exemplo n.º 14
0
 def _apply(self, value):
     """
     Multiply `value` by `self.mask` cast to float.

     A boolean `mask` parameter with the shape of `value` and a ones
     initializer is configured before the mask is read.
     """
     mask_config = {
         'initial': tf.ones_initializer(dtype=tf.bool),
         'shape': value.shape,
     }
     self._parameter_config = {'mask': mask_config}
     keep = util.cast(self.mask, float)
     return value * keep
Exemplo n.º 15
0
Arquivo: mixed.py Projeto: zaf05/mayo
 def _combine_masks(self, value, quantized_values):
     """
     Combine the quantized values according to `self.channel_mask`.

     Entries of the channel mask equal to 0 keep the original `value`.
     Entries equal to `i + 1` take the quantized value of the i-th key,
     in iteration order, of `self.quantizer_maps`.

     Args:
         value: the original, unquantized values.
         quantized_values: a mapping from each key of
             `self.quantizer_maps` to the values quantized by the
             corresponding quantizer.

     Returns:
         The combination of `value` and the quantized values.
     """
     # start from zero so the result is defined without any quantizers
     result = 0
     if self.quantizer_maps:
         # enumerate so each quantizer gets its own label; labels start
         # at 1 because 0 is reserved for unquantized entries
         for index, key in enumerate(self.quantizer_maps.keys()):
             mask_label = index + 1
             channel_mask = util.cast(
                 util.equal(self.channel_mask, mask_label), float)
             result = result + quantized_values[key] * channel_mask
     # now handle the entries that are not quantized
     off_mask = util.cast(util.equal(self.channel_mask, 0), float)
     return value * off_mask + result
Exemplo n.º 16
0
 def _apply(self, value):
     """
     Apply the parent overrider, register `gamma` and add its penalty.

     `self.gamma` is registered with the session estimator under
     'NetworkSlimmer.gamma', and `self.weight` times the sum of its
     absolute values is added to the regularization losses.

     Returns:
         The result of the parent `_apply`, cast to float.
     """
     masked = super()._apply(value)
     gamma = self.gamma
     # Register the latest gamma and mask to be used for later update.
     # TODO this works as a way to collect global gammas, but the `gamma`
     # tensor is evaluated every time we use `session.run(batch=True)`,
     # will fix later if performance proves to be problematic.
     estimator = self.session.estimator
     estimator.register(
         gamma, 'NetworkSlimmer.gamma', node=self, history=1)
     # add the regularization term
     penalty = self.weight * tf.reduce_sum(tf.abs(gamma))
     tf.losses.add_loss(
         penalty, loss_collection=tf.GraphKeys.REGULARIZATION_LOSSES)
     return util.cast(masked, float)
Exemplo n.º 17
0
 def _apply(self, value):
     """
     Multiply `value` by a mask over its last dimension.

     A boolean `mask` parameter with one entry per element of the last
     dimension and a ones initializer is configured, reshaped to
     `(1, 1, 1, -1)` and multiplied with `value`.

     Raises:
         ValueError: if `value` has fewer than 3 dimensions.
     """
     # check shape
     if len(value.shape) < 3:
         message = 'Incorrect dimension {} for channel pruner'
         raise ValueError(message.format(value.shape))
     self.num_channels = value.shape[-1]
     mask_config = {
         'initial': tf.ones_initializer(dtype=tf.bool),
         'shape': self.num_channels,
     }
     self._parameter_config = {'mask': mask_config}
     channel_mask = tf.reshape(self.mask, (1, 1, 1, -1))
     return value * util.cast(channel_mask, float)
Exemplo n.º 18
0
    def _apply(self, value):
        """
        Quantize `value` after removing the positive and negative means.

        Entries flagged by `self.positives` are shifted by
        `self.positives_mean` and those flagged by `self.negatives` by
        `self.negatives_mean`.  The shifted values are quantized, stored
        in `self.quantized`, and the respective means are added back.
        Entries where `self.before` equals 0 are zeroed throughout.

        Returns:
            The quantized value, which is also passed together with
            `value` to `self._quantization_loss_regularizer`.
        """
        # dynamic parameter configuration
        positives_config = {
            'initial': tf.ones_initializer(tf.bool),
            'shape': value.shape,
        }
        self._parameter_config = {'positives': positives_config}
        positives = util.cast(self.positives, float)
        negatives = util.cast(self.negatives, float)
        non_zeros = util.cast(tf.not_equal(self.before, 0), float)

        positives_centralized = positives * (value - self.positives_mean)
        negatives_centralized = negatives * (value - self.negatives_mean)
        centralized = positives_centralized + negatives_centralized
        # keep track of the quantized value, without the means
        self.quantized = self._quantize(non_zeros * centralized)

        positive_part = non_zeros * positives * \
            (self.quantized + self.positives_mean)
        negative_part = non_zeros * negatives * \
            (self.quantized + self.negatives_mean)
        quantized_value = positive_part + negative_part

        self._quantization_loss_regularizer(value, quantized_value)
        return quantized_value
Exemplo n.º 19
0
Arquivo: fixed.py Projeto: zaf05/mayo
 def _apply(self, value):
     """
     Return 1.0 where `value` exceeds `self.threshold` and 0.0 elsewhere.
     """
     above_threshold = value > self.threshold
     return util.cast(above_threshold, float)
Exemplo n.º 20
0
 def _info(self):
     """
     Report the name, density and element count of the mask.

     The density is the sum of the evaluated mask divided by its
     element count, wrapped in `Percent`.
     """
     mask_values = util.cast(self.session.run(self.mask), int)
     set_ratio = util.sum(mask_values) / util.count(mask_values)
     density = Percent(set_ratio)
     return self._info_tuple(
         mask=self.mask.name,
         density=density,
         count_=mask_values.size)