def _policy(self, value, quantized, previous_mask, interval):
    """
    Compute the next mask from the error between `value` and `quantized`.

    The error `value - quantized` is zeroed wherever `previous_mask` is
    set, a threshold is picked from the absolute errors with
    `util.top_k`, and every element whose absolute error is below that
    threshold is added to the mask.

    Args:
        value: the original values; indexed with a boolean mask and
            flattened, so presumably a numpy array -- TODO confirm.
        quantized: the quantized counterpart of `value`.
        previous_mask: the mask from the previous step; it is always
            contained in the returned mask.
        interval: fraction of the counted elements passed on to
            `util.top_k`; `interval >= 1.0` selects every element.

    Returns:
        The logical OR of the newly selected elements and
        `previous_mask`.

    Raises:
        ValueError: if the computed element count is negative.
    """
    previous_pruned = util.sum(previous_mask)
    if self.count_zero:
        flat_value_arg = None
        th_arg = util.cast(util.count(value) * interval, int)
    else:
        # only non-zero entries of `value` take part in the ranking
        num_non_zero = util.count(value[value != 0])
        flat_value_arg = util.where(value.flatten() != 0)
        th_arg = util.cast(num_non_zero * interval, int)
    if th_arg < 0:
        raise ValueError(
            'mask has {} elements, interval is {}'.format(
                previous_pruned, interval))
    off_mask = util.cast(
        util.logical_not(util.cast(previous_mask, bool)), float)
    metric = value - quantized
    flat_value = (metric * off_mask).flatten()
    if interval >= 1.0:
        # The threshold is compared against absolute errors below, so it
        # must exceed the largest *absolute* error; the signed maximum
        # would let large negative errors escape the mask.
        th = util.abs(flat_value).max() + 1.0
    elif self.count_zero:
        th = util.top_k(util.abs(flat_value), th_arg)
    else:
        th = util.top_k(util.abs(flat_value[flat_value_arg]), th_arg)
    th = util.cast(th, float)
    new_mask = util.logical_not(util.greater_equal(util.abs(metric), th))
    return util.logical_or(new_mask, previous_mask)
def _quantize(self, value, base=None):
    """
    Map `value` onto the three levels `-m`, `0` and `+m`, where
    `m = 2**base * self.scale`.

    Args:
        value: the values to quantize.
        base: exponent of the power-of-two magnitude; `self.base` is
            used when omitted.

    Returns:
        `+m` where `value > 0`, `-m` where `value < 0`, `0` elsewhere.
    """
    scale = self.scale
    if base is None:
        base = self.base
    exponent = util.cast(base, int)
    magnitude = util.cast(2**exponent, float)
    is_positive = util.cast(value > 0, float)
    is_negative = util.cast(value < 0, float)
    positive_part = is_positive * magnitude * scale
    negative_part = is_negative * magnitude * scale
    return positive_part - negative_part
def _quantize(
        self, value, point=None, width=None, compute_overflow_rate=False):
    """
    Quantize `value` by scaling, rounding, clipping and scaling back.

    `value` is multiplied by `2**(round(width) - round(point))`, rounded
    (stochastically when `self.stochastic` is truthy), clipped to
    `[-2**(width - 1), 2**(width - 1) - 1]` and divided by the same
    factor again.

    Args:
        value: the values to quantize.
        point: defaults to `self.point` when omitted.
        width: defaults to `self.width` when omitted.
        compute_overflow_rate: when true, return the overflow rate of
            the non-zero rounded values instead of the quantized values.

    Returns:
        The quantized values, or the result of `self._overflow_rate`
        when `compute_overflow_rate` is set.
    """
    point = util.cast(self.point if point is None else point, float)
    width = util.cast(self.width if width is None else width, float)
    # equivalent of x << (width - point)
    scale = 2.0**(util.round(width) - util.round(point))
    scaled = value * scale
    # round to the integer grid
    if self.stochastic:
        rounded = util.stochastic_round(scaled, self.stochastic)
    else:
        rounded = util.round(scaled)
    if width is None:
        return rounded / scale
    # keep the number representable without overflow
    limit = util.cast(2**(width - 1), float)
    if compute_overflow_rate:
        non_zero = rounded[rounded != 0]
        too_small = non_zero < -limit
        too_large = non_zero > limit - 1
        return self._overflow_rate(util.logical_or(too_small, too_large))
    clipped = util.clip_by_value(rounded, -limit, limit - 1)
    # undo the earlier scaling
    return clipped / scale
def _quantize(self, value, base=None):
    """
    Map `value` onto the three levels `-m`, `0` and `+m`, where
    `m = 2**base * self.scale`; `base` defaults to `self.base`.
    """
    # @Aaron @Xitong FIXME possible redundancy:
    # this code is identical to super()._quantize()
    scale = self.scale
    if base is None:
        base = self.base
    exponent = util.cast(base, int)
    magnitude = util.cast(2**exponent, float)
    is_positive = util.cast(value > 0, float)
    is_negative = util.cast(value < 0, float)
    # FIXME verify this multiplication is broadcasting correctly
    positive_part = is_positive * magnitude * scale
    negative_part = is_negative * magnitude * scale
    return positive_part - negative_part
def _quantize(self, value, point, width, compute_overflow_rate=False):
    """
    Quantize `value` in the log2 domain.

    The sign of `value` is set aside, `log2(|value|)` is passed through
    `self.quantizer.apply`, and the result is mapped back as
    `sign * 2**quantized`, with `0` wherever the sign is zero.

    `point` and `width` are accepted but not used by this body.

    Returns:
        The reconstructed values, or whatever `self.quantizer.apply`
        returns when `compute_overflow_rate` is set.
    """
    # decompose into sign and base-2 exponent
    is_positive = util.cast(value > 0, float)
    is_negative = util.cast(value < 0, float)
    sign = is_positive - is_negative
    exponent = util.log(util.abs(value), 2.0)
    # quantize the exponent
    quantized = self.quantizer.apply(
        exponent, compute_overflow_rate=compute_overflow_rate)
    if compute_overflow_rate:
        return quantized
    # rebuild the value from sign and quantized exponent
    magnitude = 2**quantized
    return util.where(util.nonzero(sign), sign * magnitude, 0)
def _apply(self, value):
    """
    Blend `value` with its quantized version according to `self.mask`.

    Elements where the mask is set take the output of `self._quantize`;
    all other elements keep their original value.  The `mask` parameter
    is configured with the shape of `value` and a zeros initializer.
    """
    mask_config = {
        'initial': tf.zeros_initializer(tf.bool),
        'shape': value.shape,
    }
    self._parameter_config = {'mask': mask_config}
    quantized = self._quantize(value)
    keep_original = util.cast(util.logical_not(self.mask), float)
    # the "on" part of the mask selects the quantized values
    use_quantized = util.cast(self.mask, float)
    untouched_part = value * keep_original
    quantized_part = quantized * use_quantized
    return untouched_part + quantized_part
def _new_mask(self, mask, value, quantized_value, interval):
    """
    Compute a new float mask from the quantization loss.

    Args:
        mask: one-dimensional mask; it is reshaped to
            `(1, 1, 1, len)` so it broadcasts against `value`.
        value: the original values.
        quantized_value: the quantized counterpart of `value`.
        interval: fraction of the entries of the averaged loss vector
            used to pick the threshold; with `interval >= 1.0` the
            inverted mask is returned directly.

    Returns:
        The new mask cast to float.
    """
    loss = util.abs(value - quantized_value)
    # check the ones that are not quantized
    mask = mask.reshape((1, 1, 1, mask.shape[0]))
    unquantized_mask = util.logical_not(mask)
    if interval >= 1.0:
        # Return before the threshold lookup: with a full interval the
        # index below would run past the end of the sorted losses.
        return util.cast(unquantized_mask, float)
    # TODO: mask shape is incorrect
    loss_vec = util.mean(loss * unquantized_mask, (0, 1, 2))
    # sort and pick the threshold; the index must be an integer and is
    # clamped because ceil(len * interval) can reach len.
    num_active = int(util.ceil(len(loss_vec) * interval))
    num_active = min(num_active, len(loss_vec) - 1)
    threshold = sorted(loss_vec)[num_active]
    new_mask = (unquantized_mask * loss) > threshold
    return util.cast(util.logical_or(new_mask, mask), float)
def _represent(self, sign, exponent, mantissa):
    """
    Rebuild a floating-point value as `sign * 2**exponent * mantissa`.

    For non-constant inputs, positions where `sign` equals zero are
    forced to zero in the result.
    """
    components = (sign, exponent, mantissa)
    value = util.cast(sign, float) * (2.0**exponent) * mantissa
    if util.is_constant(*components):
        return value
    # build an integer zero tensor with the backend matching the inputs
    if util.is_numpy(*components):
        zeros = np.zeros(sign.shape, dtype=np.int32)
    else:
        zeros = tf.zeros(sign.shape, dtype=tf.int32)
    sign_is_zero = util.equal(sign, zeros)
    float_zeros = util.cast(zeros, float)
    return util.where(sign_is_zero, float_zeros, value)
def _decompose(self, value, exponent_bias=None):
    """
    Split a floating-point `value` into sign, exponent and mantissa.

    Args:
        value: the values to decompose.
        exponent_bias: defaults to `self.exponent_bias` when omitted.

    Returns:
        A `(sign, exponent, mantissa)` tuple, where `sign` is -1, 0 or
        +1, `exponent` is `floor(log2(|value|))` and `mantissa` is
        `|value| / 2**exponent`.
    """
    bias = self.exponent_bias if exponent_bias is None else exponent_bias
    # half of 2**(-bias): magnitudes at or below this get a sign of 0
    cutoff = (2**(-bias)) / 2
    is_positive = util.cast(value > cutoff, int)
    is_negative = util.cast(value < -cutoff, int)
    sign = is_positive - is_negative
    magnitude = util.abs(value)
    exponent = util.floor(util.log(magnitude, 2))
    mantissa = magnitude / (2**exponent)
    return sign, exponent, mantissa
def _apply(self, value):
    """
    Gate whole feature maps of a rank-4 `value` by a pooled statistic.

    `value` is passed through ReLU and pooled over its full spatial
    extent (dimensions 1 and 2) according to `self.policy`:

    * `'avg'` or `None`: average pooling;
    * `'max'`: max pooling;
    * `'mix'`: max pooling minus average pooling.

    The gate is 1.0 where the absolute pooled value is at least
    `self.threshold` and 0.0 elsewhere.  It is stored in `self.gate`,
    added to the `'mayo.overrider.gates'` collection, and multiplied
    with `value`.  The pooled tensor is kept in `self.pooled`.

    Raises:
        ValueError: if `value` is not rank 4 or the policy is unknown.
    """
    policy = self.policy
    value_pool = tf.nn.relu(value)
    if len(value.shape) != 4:
        raise ValueError(
            'Expected a rank-4 value, got shape {}'.format(value.shape))
    # Only the spatial dimensions are needed; the batch dimension is
    # left untouched so it may be unknown at graph construction time.
    height, width = (int(d) for d in value.shape[1:3])
    pool_params = {
        'padding': 'VALID',
        'ksize': [1, height, width, 1],
        'strides': [1, 1, 1, 1],
    }
    if policy == 'avg' or policy is None:
        pooled = tf.nn.avg_pool(value_pool, **pool_params)
    elif policy == 'max':
        pooled = tf.nn.max_pool(tf.abs(value_pool), **pool_params)
    elif policy == 'mix':
        maxed = tf.nn.max_pool(tf.abs(value_pool), **pool_params)
        avged = tf.nn.avg_pool(value_pool, **pool_params)
        pooled = maxed - avged
    else:
        raise ValueError('Unknown gating policy {!r}'.format(policy))
    self.gate = util.cast(tf.abs(pooled) >= self.threshold, float)
    self.pooled = pooled
    tf.add_to_collection('mayo.overrider.gates', self.gate)
    return self.gate * value
def _overflow_rate(mask):
    """
    Compute the overflow rate from a given overflow mask.

    `mask` is a boolean tensor where True and False represent the
    presence and absence of overflow respectively.  The rate is the
    number of True entries divided by the total number of entries.
    """
    num_overflows = util.sum(util.cast(mask, int))
    total = util.count(mask)
    return num_overflows / total
def _transform(
        self, sign, exponent, mantissa,
        exponent_width=None, mantissa_width=None, exponent_bias=None):
    """
    Clip the exponent and quantize the mantissa of a decomposed value.

    Args:
        sign: passed through unchanged.
        exponent: clipped to
            `[-exponent_bias, 2**exponent_width - 1 - exponent_bias]`.
        mantissa: rounded to a multiple of `2**(-mantissa_width)`
            (stochastically when `self.stochastic` is truthy).
        exponent_width: defaults to `self.width - self.mantissa_width`.
        mantissa_width: defaults to `self.mantissa_width`.
        exponent_bias: defaults to `self.exponent_bias`.

    Returns:
        The `(sign, exponent, mantissa)` tuple after transformation.
    """
    if exponent_bias is None:
        exponent_bias = self.exponent_bias
    if exponent_width is None:
        exponent_width = self.width - self.mantissa_width
    if mantissa_width is None:
        mantissa_width = self.mantissa_width
    # clip the exponent to the representable range
    lowest = -exponent_bias
    highest = 2**exponent_width - 1 - exponent_bias
    exponent = util.clip_by_value(exponent, lowest, highest)
    # quantize the mantissa on a grid of 2**(-mantissa_width)
    shift = util.cast(2**mantissa_width, float)
    scaled = mantissa * shift
    if self.stochastic:
        mantissa = util.stochastic_round(scaled, self.stochastic)
        mantissa /= shift
    else:
        mantissa = util.round(scaled) / shift
    # a mantissa rounded up to >= 2 is halved and its exponent bumped
    needs_carry = util.greater_equal(mantissa, 2)
    mantissa = util.where(needs_carry, mantissa / 2, mantissa)
    exponent = util.where(needs_carry, exponent + 1, exponent)
    return sign, exponent, mantissa
def _info(self):
    """
    Evaluate `self.gate` once and report its name, density and size.
    """
    # FIXME it doesn't make sense to run `gate` once as its density
    # varies from run to run.
    gate_values = util.cast(self.session.run(self.gate), int)
    active_ratio = util.sum(gate_values) / util.count(gate_values)
    density = Percent(active_ratio)
    return self._info_tuple(
        gate=self.gate.name, density=density, count_=gate_values.size)
def _apply(self, value):
    """
    Multiply `value` element-wise by `self.mask` cast to float.

    The `mask` parameter is configured with the shape of `value` and a
    boolean ones initializer.
    """
    mask_config = {
        'initial': tf.ones_initializer(dtype=tf.bool),
        'shape': value.shape,
    }
    self._parameter_config = {'mask': mask_config}
    float_mask = util.cast(self.mask, float)
    return value * float_mask
def _combine_masks(self, value, quantized_values):
    """
    Combine per-quantizer outputs according to `self.channel_mask`.

    The quantizers in `self.quantizer_maps` are numbered from 1 in
    iteration order.  Positions of `self.channel_mask` equal to a
    quantizer's number take that quantizer's entry from
    `quantized_values`; positions equal to 0 keep the original `value`.

    Args:
        value: the original, unquantized values.
        quantized_values: mapping from the keys of
            `self.quantizer_maps` to the corresponding quantized values.

    Returns:
        The combination of `value` and the selected quantized values.
    """
    # Start from zero so an empty `quantizer_maps` still yields a result.
    result = 0
    if self.quantizer_maps:
        # Labels start at 1 because label 0 marks unquantized positions;
        # enumerate gives every quantizer its own label.
        labelled = enumerate(self.quantizer_maps.items(), start=1)
        for mask_label, (key, _) in labelled:
            channel_mask = util.cast(
                util.equal(self.channel_mask, mask_label), float)
            result = result + quantized_values[key] * channel_mask
    # now handle off_mask
    off_mask = util.cast(util.equal(self.channel_mask, 0), float)
    return value * off_mask + result
def _apply(self, value):
    """
    Apply the parent overrider, register `self.gamma` with the
    estimator and add an L1 penalty on it to the regularization losses.

    Returns:
        The parent's result cast to float.
    """
    masked = super()._apply(value)
    gamma = self.gamma
    # register the latest gamma and mask to be used for later update
    # TODO this works as a way to collect global gammas, but the `gamma`
    # tensor is evaluated every time we use `session.run(batch=True)`,
    # will fix later if performance proves to be problematic.
    estimator = self.session.estimator
    estimator.register(gamma, 'NetworkSlimmer.gamma', node=self, history=1)
    # add the weighted sum of absolute gammas as a regularizer
    penalty = self.weight * tf.reduce_sum(tf.abs(gamma))
    tf.losses.add_loss(
        penalty, loss_collection=tf.GraphKeys.REGULARIZATION_LOSSES)
    return util.cast(masked, float)
def _apply(self, value):
    """
    Multiply `value` by a per-channel mask along its last dimension.

    The `mask` parameter holds one boolean per channel (the size of the
    last dimension of `value`) and is initialized to all ones.  It is
    reshaped so that it broadcasts over every leading dimension.

    Raises:
        ValueError: if `value` has fewer than three dimensions.
    """
    # check shape
    rank = len(value.shape)
    if rank < 3:
        raise ValueError(
            'Incorrect dimension {} for channel pruner'
            .format(value.shape))
    self.num_channels = value.shape[-1]
    self._parameter_config = {
        'mask': {
            'initial': tf.ones_initializer(dtype=tf.bool),
            'shape': self.num_channels,
        },
    }
    # Use one leading singleton axis per non-channel dimension, capped
    # at three: a fixed rank-4 mask would turn a rank-3 input into a
    # rank-4 output, while rank 4 and above keep the (1, 1, 1, -1) shape.
    mask_shape = (1,) * min(rank - 1, 3) + (-1,)
    mask = tf.reshape(self.mask, mask_shape)
    return value * util.cast(mask, float)
def _apply(self, value):
    """
    Quantize `value` after removing a separate mean from its positive
    and negative parts, then add the means back.

    The mean-free quantized tensor is kept in `self.quantized`, and
    `self._quantization_loss_regularizer` is called with the original
    and the reconstructed values.

    Returns:
        The reconstructed quantized values.
    """
    # dynamic parameter configuration
    positives_config = {
        'initial': tf.ones_initializer(tf.bool),
        'shape': value.shape,
    }
    self._parameter_config = {'positives': positives_config}
    pos = util.cast(self.positives, float)
    neg = util.cast(self.negatives, float)
    non_zero = util.cast(tf.not_equal(self.before, 0), float)
    pos_centered = pos * (value - self.positives_mean)
    neg_centered = neg * (value - self.negatives_mean)
    centered = non_zero * (pos_centered + neg_centered)
    # keep a track of quantized value, without means
    self.quantized = self._quantize(centered)
    pos_restored = non_zero * pos * (self.quantized + self.positives_mean)
    neg_restored = non_zero * neg * (self.quantized + self.negatives_mean)
    quantized_value = pos_restored + neg_restored
    self._quantization_loss_regularizer(value, quantized_value)
    return quantized_value
def _apply(self, value):
    """
    Return the comparison `value > self.threshold` cast to float.
    """
    above_threshold = value > self.threshold
    return util.cast(above_threshold, float)
def _info(self):
    """
    Evaluate `self.mask` and report its name, density and size.
    """
    mask_values = util.cast(self.session.run(self.mask), int)
    active_ratio = util.sum(mask_values) / util.count(mask_values)
    density = Percent(active_ratio)
    return self._info_tuple(
        mask=self.mask.name, density=density, count_=mask_values.size)