class ChannelTernaryQuantizer(TernaryQuantizer):
    """Same ternary quantization, but with channel-wise scaling factors."""
    scale = Parameter('scale', None, None, 'float', trainable=True)

    def _quantize(self, value, base=None):
        # @Aaron @Xitong FIXME possible redundancy:
        # this code is identical to super()._quantize()
        scale = self.scale
        base = util.cast(self.base if base is None else base, int)
        shift = util.cast(2 ** base, float)
        positives = util.cast(value > 0, float)
        negatives = util.cast(value < 0, float)
        # FIXME verify this multiplication is broadcasting correctly
        return positives * shift * scale - negatives * shift * scale

    def _apply(self, value):
        self._parameter_config = {
            'scale': {
                'initial': tf.ones_initializer(),
                # a vector whose length matches the number of output channels
                'shape': value.shape[-1],
            }
        }
        return self._quantize(value)
class MeanStdPruner(PrunerBase):
    alpha = Parameter('alpha', -2, [], 'float')

    def __init__(self, session, alpha=None, should_update=True):
        super().__init__(session, should_update)
        self.alpha = alpha

    def _threshold(self, tensor, alpha=None):
        # axes = list(range(len(tensor.get_shape()) - 1))
        tensor_shape = util.get_shape(tensor)
        axes = list(range(len(tensor_shape)))
        mean, var = util.moments(util.abs(tensor), axes)
        if alpha is None:
            return mean + self.alpha * util.sqrt(var)
        return mean + alpha * util.sqrt(var)

    def _updated_mask(self, var, mask):
        return util.abs(var) > self._threshold(var)

    def _info(self):
        _, mask, density, count = super()._info()
        alpha = self.session.run(self.alpha)
        return self._info_tuple(
            mask=mask, alpha=alpha, density=density, count_=count)
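# A minimal NumPy sketch (not part of the class above) of the MeanStdPruner
# rule: a weight survives when its magnitude exceeds mean(|w|) + alpha * std(|w|).
# The helper name `mean_std_prune_mask` and the example values are
# illustrative assumptions only.
def mean_std_prune_mask(weights, alpha=-2.0):
    import numpy as np
    magnitude = np.abs(weights)
    threshold = magnitude.mean() + alpha * magnitude.std()
    return magnitude > threshold  # boolean mask, True = keep

# Example: alpha = -2 keeps almost everything; larger alpha prunes more.
# mask = mean_std_prune_mask(np.random.randn(64, 64), alpha=0.5)
# density = mask.mean()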
class FilterPruner(PrunerBase):
    density = Parameter('density', 0.0, [], 'float')
    mask = Parameter('mask', None, None, 'bool')

    def __init__(self, session, density=None, should_update=True):
        super().__init__(session, should_update)
        self.density = density

    def _apply(self, value):
        self._parameter_config = {
            'mask': {
                'initial': tf.ones_initializer(dtype=tf.bool),
                'shape': tf.TensorShape([value.shape[-2], value.shape[-1]]),
            }
        }
        return value * util.cast(self.mask, float)

    def _l1_norm(self, value):
        # compute the l1 norm for each filter
        axes = len(value.shape)
        assert axes == 4
        # mean, var = tf.nn.moments(util.abs(tensor), axes=[0, 1])
        # mean = np.mean(value, axis=(0, 1))
        # var = np.var(value, axis=(0, 1))
        # return mean + util.sqrt(var)
        return util.sum(util.abs(value), axis=(0, 1))

    def _threshold(self, value, density):
        value = value.flatten()
        index = int(value.size * density)
        return sorted(value)[index]

    def _updated_mask(self, tensor, mask):
        value, mask, density = self.session.run([tensor, mask, self.density])
        l1_norm = self._l1_norm(value)
        # mean, var = tf.nn.moments(util.abs(tensor), axes=[0, 1])
        return l1_norm > self._threshold(l1_norm, density)

    def _info(self):
        _, mask, density, count = super()._info()
        density = self.session.run(self.density)
        return self._info_tuple(mask=mask, density=density, count_=count)

    @classmethod
    def finalize_info(cls, table):
        footer = super().finalize_info(table)
        table.set_footer([None] + footer)
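# A rough NumPy sketch of the FilterPruner idea above: rank the 2-D filter
# slices of a convolution kernel by their L1 norm and keep only the slices
# whose norm lies above the `density`-quantile threshold. The helper
# `filter_prune_mask` is a hypothetical name, not part of this module.
def filter_prune_mask(kernel, density):
    import numpy as np
    assert kernel.ndim == 4  # (height, width, in_channels, out_channels)
    l1 = np.abs(kernel).sum(axis=(0, 1))              # one norm per filter slice
    threshold = np.sort(l1, axis=None)[int(l1.size * density)]
    return l1 > threshold                             # shape (in_channels, out_channels)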
class TernaryQuantizer(QuantizerBase):
    """
    Ternary quantization, quantizes all values into the range:
        {- 2^base * scale, 0, 2^base * scale}.

    Args:
        base: the universal coarse-grained scaling factor
            applied to ternary weights.

    References:
        - Extremely Low Bit Neural Network:
          Squeeze the Last Bit Out with ADMM
        - Trained Ternary Quantization
    """
    base = Parameter('base', 1, [], 'int', trainable=False)
    scale = Parameter('scale', 1.0, [], 'float', trainable=True)

    def __init__(
            self, session, base=None, stochastic=None,
            should_update=True, enable=True):
        super().__init__(session, should_update, enable)
        if base is not None:
            if base < 0:
                raise ValueError(
                    'Base of ternary quantization must be '
                    'greater than or equal to 0.')
            self.base = base
        if stochastic is not None:
            raise NotImplementedError(
                'Ternary quantization does not implement stochastic mode.')

    def _quantize(self, value, base=None):
        scale = self.scale
        base = util.cast(self.base if base is None else base, int)
        shift = util.cast(2 ** base, float)
        positives = util.cast(value > 0, float)
        negatives = util.cast(value < 0, float)
        return positives * shift * scale - negatives * shift * scale

    def _apply(self, value):
        return self._quantize(value)

    def _info(self):
        base = int(self.eval(self.base))
        return self._info_tuple(width=2, base=base)
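# A small, self-contained NumPy sketch of the ternary rule used above,
# {-2**base * scale, 0, 2**base * scale}. This helper is an illustration only
# and is not part of the quantizer API.
def ternary_quantize(value, base=1, scale=1.0):
    import numpy as np
    shift = float(2 ** base)
    positives = (value > 0).astype(float)
    negatives = (value < 0).astype(float)
    return positives * shift * scale - negatives * shift * scale

# e.g. ternary_quantize(np.array([-0.3, 0.0, 0.7]), base=1, scale=0.5)
# -> array([-1.,  0.,  1.])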
def setUp(self):
    self.parameter = Parameter('test', 1, (), tf.int32, True)

    class Overrider(OverriderBase):
        param = self.parameter

        def _apply(self, value):
            return self.param

    self.overrider = Overrider()
    var = VariableMock('input', 1, (), tf.int32, True)
    self.overrider.apply('scope', self._get_variable, var)
class DGTrainableQuantizer(DGQuantizer):
    """
    Backpropagatable precision.

    The width is trainable, but gradients do not currently propagate to it.
    """
    width = Parameter('width', 16, [], 'float', trainable=True)

    def __init__(self, session, overflow_rate, should_update=True):
        super().__init__(session, None, None, should_update=should_update)

    def _apply(self, value):
        return self._quantize(value, self.width, self.point)
class ThresholdBinarizer(QuantizerBase):
    threshold = Parameter('threshold', 0, [], 'float')

    def __init__(
            self, session, threshold=None, should_update=True, enable=True):
        super().__init__(session, should_update, enable)
        if threshold is not None:
            self.threshold = threshold

    def _apply(self, value):
        return util.cast(value > self.threshold, float)
class ChannelPrunerBase(OverriderBase):
    mask = Parameter('mask', None, None, 'bool')

    def __init__(self, session, should_update=True):
        super().__init__(session, should_update)

    def _apply(self, value):
        # check shape
        if len(value.shape) < 3:
            raise ValueError(
                'Incorrect dimension {} for channel pruner'
                .format(value.shape))
        self.num_channels = value.shape[-1]
        self._parameter_config = {
            'mask': {
                'initial': tf.ones_initializer(dtype=tf.bool),
                'shape': self.num_channels,
            },
        }
        mask = tf.reshape(self.mask, (1, 1, 1, -1))
        return value * util.cast(mask, float)

    def _updated_mask(self, var, mask):
        raise NotImplementedError(
            'Method to compute an updated mask is not implemented.')

    def _update(self):
        mask = self._updated_mask(self.before, self.mask)
        self.session.assign(self.mask, mask)

    def _info(self):
        mask = util.cast(self.session.run(self.mask), int)
        density = Percent(util.sum(mask) / util.count(mask))
        return self._info_tuple(
            mask=self.mask.name, density=density, count_=mask.size)

    @classmethod
    def finalize_info(cls, table):
        densities = table.get_column('density')
        count = table.get_column('count_')
        avg_density = sum(d * c for d, c in zip(densities, count)) / sum(count)
        footer = [None, '    overall: ', Percent(avg_density), None]
        table.set_footer(footer)
        return footer
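# A hedged NumPy sketch of the channel-masking step performed by _apply above:
# the boolean per-channel mask is reshaped to (1, 1, 1, C) so it broadcasts
# over an NHWC activation tensor. Names below are illustrative only.
def apply_channel_mask(activations, channel_mask):
    import numpy as np
    # activations: (N, H, W, C); channel_mask: (C,) of bools
    mask = channel_mask.reshape(1, 1, 1, -1).astype(activations.dtype)
    return activations * mask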
class ChannelGater(GaterBase):
    threshold = Parameter('threshold', 1, [], 'float')

    def __init__(
            self, session, threshold=None, policy=None, should_update=True):
        super().__init__(session, should_update)
        self.threshold = threshold
        self.policy = policy

    def _apply(self, value):
        policy = self.policy
        value_pool = tf.nn.relu(value)
        n, h, w, c = (int(d) for d in value.shape)
        pool_params = {
            'padding': 'VALID',
            'ksize': [1, h, w, 1],
            'strides': [1, 1, 1, 1],
        }
        if policy == 'avg' or policy is None:
            pooled = tf.nn.avg_pool(value_pool, **pool_params)
        if policy == 'max':
            pooled = tf.nn.max_pool(tf.abs(value_pool), **pool_params)
        if policy == 'mix':
            maxed = tf.nn.max_pool(tf.abs(value_pool), **pool_params)
            avged = tf.nn.avg_pool(value_pool, **pool_params)
            pooled = maxed - avged
        # an alternative policy on per-channel moments, kept for reference:
        # mean, variance = tf.nn.moments(value, axes=[1, 2])
        # variance = tf.reshape(variance, shape=[n, 1, 1, c])
        # mean = tf.reshape(mean, shape=[n, 1, 1, c])
        # omap = {'Sign': 'Identity'}
        # with tf.get_default_graph().gradient_override_map(omap):
        #     self.gate = tf.sign(mean - self.threshold)
        #     self.gate = tf.clip_by_value(self.gate, 0, 1)
        # gate out feature maps whose pooled magnitude falls below the
        # threshold
        self.gate = util.cast(tf.abs(pooled) >= self.threshold, float)
        self.pooled = pooled
        tf.add_to_collection('mayo.overrider.gates', self.gate)
        # return mean * (1 - self.gate) + self.gate * var
        return self.gate * value
class PrunerBase(OverriderBase):
    mask = Parameter('mask', None, None, 'bool')

    def __init__(self, session, should_update=True):
        super().__init__(session, should_update)

    def _apply(self, value):
        self._parameter_config = {
            'mask': {
                'initial': tf.ones_initializer(dtype=tf.bool),
                'shape': value.shape,
            }
        }
        return value * util.cast(self.mask, float)

    def _updated_mask(self, var, mask):
        raise NotImplementedError(
            'Method to compute an updated mask is not implemented.')

    def _update(self):
        mask = self._updated_mask(self.before, self.mask)
        self.session.assign(self.mask, mask)

    def _info(self):
        mask = util.cast(self.session.run(self.mask), int)
        density = Percent(util.sum(mask) / util.count(mask))
        return self._info_tuple(
            mask=self.mask.name, density=density, count_=mask.size)

    @classmethod
    def finalize_info(cls, table):
        densities = table.get_column('density')
        count = table.get_column('count_')
        avg_density = sum(d * c for d, c in zip(densities, count)) / sum(count)
        footer = ['overall: ', None, None, Percent(avg_density), count]
        table.add_row(footer)
        return footer
class Recentralizer(OverriderBase):
    """Recentralizes the distribution of pruned weights."""

    class QuantizedParameter(Parameter):
        def _quantize(self, instance, value):
            scope = '{}/{}/{}'.format(
                instance._scope, instance.__class__.__name__, self.name)
            quantizer = instance.parameter_quantizers.get(self.name)
            if not quantizer:
                return value
            return quantizer.apply(
                instance.node, scope, instance._original_getter, value)

        def __get__(self, instance, owner):
            if instance is None:
                return self
            name = '_quantized_{}'.format(self.name)
            try:
                return instance._parameter_variables[name]
            except KeyError:
                pass
            var = super().__get__(instance, owner)
            var = self._quantize(instance, var)
            instance._parameter_variables[name] = var
            return var

    positives = Parameter('positives', None, None, 'bool')
    positives_mean = QuantizedParameter('positives_mean', 1, [], 'float')
    negatives_mean = QuantizedParameter('negatives_mean', -1, [], 'float')

    def __init__(
            self, session, quantizer, mean_quantizer=None,
            should_update=True, reg=False):
        super().__init__(session, should_update)
        cls, params = object_from_params(quantizer)
        self.quantizer = cls(session, **params)
        self.reg = reg
        if mean_quantizer:
            cls, params = object_from_params(mean_quantizer)
            self.parameter_quantizers = {
                'positives_mean': cls(session, **params),
                'negatives_mean': cls(session, **params),
            }
        else:
            self.parameter_quantizers = {}

    @memoize_property
    def negatives(self):
        return util.logical_not(self.positives)

    def assign_parameters(self):
        super().assign_parameters()
        self.quantizer.assign_parameters()
        for quantizer in self.parameter_quantizers.values():
            quantizer.assign_parameters()

    def _quantize(self, value):
        quantizer = self.quantizer
        scope = '{}/{}'.format(self._scope, self.__class__.__name__)
        return quantizer.apply(
            self.node, scope, self._original_getter, value)

    def _apply(self, value):
        # dynamic parameter configuration
        self._parameter_config = {
            'positives': {
                'initial': tf.ones_initializer(tf.bool),
                'shape': value.shape,
            },
        }
        positives = util.cast(self.positives, float)
        negatives = util.cast(self.negatives, float)
        non_zeros = util.cast(tf.not_equal(self.before, 0), float)
        positives_centralized = positives * (value - self.positives_mean)
        negatives_centralized = negatives * (value - self.negatives_mean)
        # keep track of the quantized value, without the means
        self.quantized = self._quantize(
            non_zeros * (positives_centralized + negatives_centralized))
        quantized_value = non_zeros * positives * \
            (self.quantized + self.positives_mean)
        quantized_value += non_zeros * negatives * \
            (self.quantized + self.negatives_mean)
        self._quantization_loss_regularizer(value, quantized_value)
        return quantized_value

    def _quantization_loss_regularizer(self, value, quantized_value):
        if self.reg == 0.0:
            return
        loss = tf.reduce_sum(tf.abs(value - quantized_value))
        loss *= self.reg
        loss_name = tf.GraphKeys.REGULARIZATION_LOSSES
        tf.add_to_collection(loss_name, loss)

    def _update(self):
        # update the positives mask and the mean values
        value = self.session.run(self.before)
        # divide the weights into two groups about a central point
        # mean = util.mean(value)
        mean = 0.0
        # find the two central points
        positives = value > mean
        self.positives = positives
        self.positives_mean = util.mean(value[util.where(positives)])
        negatives = util.logical_and(util.logical_not(positives), value != 0)
        self.negatives_mean = util.mean(value[util.where(negatives)])
        if self.positives_mean.eval() == 0 or self.negatives_mean.eval() == 0:
            log.warn(
                'means are skewed, pos mean is {} and neg mean is {}'.format(
                    self.positives_mean.eval(), self.negatives_mean.eval()))
        # update the internal quantizer
        self.quantizer.update()
        for quantizer in self.parameter_quantizers.values():
            quantizer.update()

    def _info(self):
        info = self.quantizer.info()._asdict()
        for name, quantizer in self.parameter_quantizers.items():
            param_info = quantizer.info()
            param_info = {
                '{}_{}'.format(name, key): value
                for key, value in param_info._asdict().items()
            }
            info.update(param_info)
        info.pop('name')
        return self._info_tuple(**info)
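# A simplified NumPy sketch of what Recentralizer does: split the non-zero
# weights about zero, subtract each group's mean before quantization, and add
# the means back afterwards. `fake_quantizer` stands in for the wrapped
# quantizer; both names are assumptions made for illustration.
def recentralize(value, fake_quantizer):
    import numpy as np
    non_zeros = (value != 0).astype(float)
    positives = (value > 0).astype(float)
    negatives = non_zeros * (1.0 - positives)
    positives_mean = value[value > 0].mean()
    negatives_mean = value[value < 0].mean()
    centralized = positives * (value - positives_mean) + \
        negatives * (value - negatives_mean)
    quantized = fake_quantizer(non_zeros * centralized)
    return non_zeros * (
        positives * (quantized + positives_mean) +
        negatives * (quantized + negatives_mean))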
class IncrementalQuantizer(OverriderBase):
    """
    Incremental network quantization:
        https://arxiv.org/pdf/1702.03044.pdf
    """
    mask = Parameter('mask', None, None, 'bool')

    def __init__(
            self, session, quantizer, interval, count_zero=True,
            should_update=True, enable=True):
        super().__init__(session, should_update, enable)
        cls, params = object_from_params(quantizer)
        self.quantizer = cls(session, **params)
        self.count_zero = count_zero
        self.interval = interval

    def _quantize(self, value, mean_quantizer=False):
        quantizer = self.quantizer
        scope = '{}/{}'.format(self._scope, self.__class__.__name__)
        return quantizer.apply(
            self.node, scope, self._original_getter, value)

    def _apply(self, value):
        self._parameter_config = {
            'mask': {
                'initial': tf.zeros_initializer(tf.bool),
                'shape': value.shape,
            }
        }
        quantized_value = self._quantize(value)
        off_mask = util.cast(util.logical_not(self.mask), float)
        mask = util.cast(self.mask, float)
        # an on-bit in the mask indicates a quantized value
        return value * off_mask + quantized_value * mask

    def _policy(self, value, quantized, previous_mask, interval):
        previous_pruned = util.sum(previous_mask)
        if self.count_zero:
            th_arg = util.cast(util.count(value) * interval, int)
        else:
            tmp = util.count(value[value != 0])
            flat_value_arg = util.where(value.flatten() != 0)
            th_arg = util.cast(tmp * interval, int)
        if th_arg < 0:
            raise ValueError(
                'mask has {} elements, interval is {}'.format(
                    previous_pruned, interval))
        off_mask = util.cast(
            util.logical_not(util.cast(previous_mask, bool)), float)
        metric = value - quantized
        flat_value = (metric * off_mask).flatten()
        if interval >= 1.0:
            th = flat_value.max() + 1.0
        else:
            if self.count_zero:
                th = util.top_k(util.abs(flat_value), th_arg)
            else:
                th = util.top_k(util.abs(flat_value[flat_value_arg]), th_arg)
        th = util.cast(th, float)
        new_mask = util.logical_not(util.greater_equal(util.abs(metric), th))
        return util.logical_or(new_mask, previous_mask)

    def assign_parameters(self):
        # override assign_parameters to assign the internal quantizer as well
        super().assign_parameters()
        self.quantizer.assign_parameters()

    def _update(self):
        # reset index
        self.quantizer.update()
        # if chosen to be quantized, turn the corresponding mask entries on
        value, quantized, mask = self.session.run(
            [self.before, self.quantizer.after, self.mask])
        new_mask = self._policy(value, quantized, mask, self.interval)
        self.session.assign(self.mask, new_mask)

    def dump(self):
        return self.quantizer.dump()

    def _info(self):
        return self.quantizer._info()
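# A rough NumPy sketch of the incremental scheme above: a boolean mask selects
# which weights take their quantized value, and the mask grows over successive
# updates. The growth policy below is a deliberately simplified assumption
# (it is not a faithful copy of `_policy`), and the helper names
# `grow_mask` / `incremental_apply` are illustrative only.
def grow_mask(value, quantized, mask, interval):
    import numpy as np
    error = np.abs(value - quantized)
    error[mask] = np.inf                    # already-quantized weights stay on
    threshold = np.sort(error, axis=None)[int(value.size * interval)]
    return mask | (error < threshold)

def incremental_apply(value, quantized, mask):
    import numpy as np
    return np.where(mask, quantized, value)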
class FloatingPointQuantizer(QuantizerBase):
    """
    Minifloat quantization.

    When exponent_width is 0, the floating-point value is a degenerate case
    where the exponent is always a constant bias, equivalent to fixed-point
    with a sign-magnitude representation.

    When mantissa_width is 0, the floating-point value is a degenerate case
    where the mantissa is always 1, equivalent to shifts with only 2^n
    representations.

    When both exponent_width and mantissa_width are 0, the quantized value
    can only represent $2^{-bias}$ or 0, which is not very useful.
    """
    width = Parameter('width', 32, [], 'float')
    exponent_bias = Parameter('exponent_bias', -127, [], 'float')
    mantissa_width = Parameter('mantissa_width', 23, [], 'float')

    def __init__(
            self, session, width, exponent_bias, mantissa_width,
            overflow_rate=0.0, should_update=True, stochastic=None):
        super().__init__(session, should_update)
        self.width = width
        self.exponent_bias = exponent_bias
        self.mantissa_width = mantissa_width
        self.overflow_rate = overflow_rate
        self.stochastic = stochastic
        exponent_width = width - mantissa_width
        is_valid = exponent_width >= 0 and mantissa_width >= 0
        is_valid = is_valid and not (
            exponent_width == 0 and mantissa_width == 0)
        if not is_valid:
            raise ValueError(
                'We expect exponent_width >= 0 and mantissa_width >= 0, '
                'where the equalities must not hold simultaneously.')

    def _decompose(self, value, exponent_bias=None):
        """
        Decompose a single-precision floating-point value into
        sign, exponent and mantissa components.
        """
        if exponent_bias is None:
            exponent_bias = self.exponent_bias
        # smallest non-zero floating point
        discriminator = (2 ** (-exponent_bias)) / 2
        sign = util.cast(value > discriminator, int)
        sign -= util.cast(value < -discriminator, int)
        value = util.abs(value)
        exponent = util.floor(util.log(value, 2))
        mantissa = value / (2 ** exponent)
        return sign, exponent, mantissa

    def _transform(
            self, sign, exponent, mantissa,
            exponent_width=None, mantissa_width=None, exponent_bias=None):
        if exponent_bias is None:
            exponent_bias = self.exponent_bias
        if exponent_width is None:
            exponent_width = self.width - self.mantissa_width
        if mantissa_width is None:
            mantissa_width = self.mantissa_width
        # clip exponent and quantize mantissa
        exponent_min = -exponent_bias
        exponent_max = 2 ** exponent_width - 1 - exponent_bias
        exponent = util.clip_by_value(exponent, exponent_min, exponent_max)
        shift = util.cast(2 ** mantissa_width, float)
        # quantize
        if self.stochastic:
            mantissa = util.stochastic_round(mantissa * shift, self.stochastic)
            mantissa /= shift
        else:
            mantissa = util.round(mantissa * shift) / shift
        # if the mantissa value gets rounded to >= 2, we need to divide it
        # by 2 and increment the exponent by 1
        is_out_of_range = util.greater_equal(mantissa, 2)
        mantissa = util.where(is_out_of_range, mantissa / 2, mantissa)
        exponent = util.where(is_out_of_range, exponent + 1, exponent)
        return sign, exponent, mantissa

    def _represent(self, sign, exponent, mantissa):
        """
        Represent the value in floating-point using
        sign, exponent and mantissa.
        """
        value = util.cast(sign, float) * (2.0 ** exponent) * mantissa
        if util.is_constant(sign, exponent, mantissa):
            return value
        if util.is_numpy(sign, exponent, mantissa):
            zeros = np.zeros(sign.shape, dtype=np.int32)
        else:
            zeros = tf.zeros(sign.shape, dtype=tf.int32)
        is_zero = util.equal(sign, zeros)
        return util.where(is_zero, util.cast(zeros, float), value)

    def _quantize(
            self, value, exponent_width=None, mantissa_width=None,
            exponent_bias=None):
        sign, exponent, mantissa = self._decompose(value, exponent_bias)
        sign, exponent, mantissa = self._transform(
            sign, exponent, mantissa,
            exponent_width, mantissa_width, exponent_bias)
        return self._represent(sign, exponent, mantissa)

    def _apply(self, value):
        quantized = self._quantize(value)
        nan = tf.reduce_sum(tf.cast(tf.is_nan(quantized), tf.int32))
        assertion = tf.Assert(tf.equal(nan, 0), [nan])
        with tf.control_dependencies([assertion]):
            return value + tf.stop_gradient(quantized - value)

    def _bias(self, value, exponent_width):
        max_exponent = int(2 ** exponent_width)
        for exponent in range(min(-max_exponent, -4), max(max_exponent, 4)):
            max_value = 2 ** (exponent + 1)
            overflows = util.logical_or(value < -max_value, value > max_value)
            if self._overflow_rate(overflows) <= self.overflow_rate:
                break
        return 2 ** exponent_width - 1 - exponent

    def compute_quantization_loss(
            self, value, exponent_width, mantissa_width, overflow_rate):
        exponent_bias = self._bias(value, exponent_width)
        quantized = self._quantize(
            value, exponent_width, mantissa_width, exponent_bias)
        # mean squared loss
        loss = ((value - quantized) ** 2).mean()
        return (loss, exponent_bias)

    def _info(self):
        width = int(self.eval(self.width))
        mantissa_width = int(self.eval(self.mantissa_width))
        exponent_bias = int(self.eval(self.exponent_bias))
        return self._info_tuple(
            width=width, mantissa_width=mantissa_width,
            exponent_bias=exponent_bias)

    def _update(self):
        value = self.eval(self.before)
        exponent_width = self.eval(self.width) - self.eval(self.mantissa_width)
        self.exponent_bias = self._bias(value, exponent_width)
class LowRankApproximation(OverriderBase):
    singular = Parameter('singular', None, None, 'float')
    left = Parameter('left', None, None, 'float')
    right = Parameter('right', None, None, 'float')

    def __init__(self, session, should_update=True, ranks=0):
        super().__init__(session, should_update)
        # the number of smallest singular values to prune away
        self.ranks = ranks

    def _parameter_initial(self, value):
        dimensions = value.shape
        left_dimension = dimensions[0] * dimensions[2]
        right_dimension = dimensions[1] * dimensions[3]
        left_shape = (left_dimension, left_dimension)
        right_shape = (right_dimension, right_dimension)
        rows = int(left_dimension)
        columns = int(right_dimension)
        singular_shape = left_dimension if rows < columns else right_dimension
        self._parameter_config = {
            'singular': {
                'initial': tf.ones_initializer(dtype=tf.float32),
                'shape': singular_shape,
            },
            'left': {
                'initial': tf.ones_initializer(dtype=tf.float32),
                'shape': left_shape,
            },
            'right': {
                'initial': tf.ones_initializer(dtype=tf.float32),
                'shape': right_shape,
            },
        }
        return (rows, columns)

    def _apply(self, value):
        rows, columns = self._parameter_initial(value)
        if rows < columns:
            singular = tf.expand_dims(self.singular, 1) * tf.eye(rows, columns)
        else:
            singular = tf.expand_dims(self.singular, 0) * tf.eye(rows, columns)
        svd_construct = tf.matmul(tf.matmul(self.left, singular), self.right)
        return tf.reshape(svd_construct, value.shape)

    def _update(self):
        value = self.session.run(self.before)
        dimensions = value.shape
        if len(dimensions) == 4:
            meshed = np.reshape(
                value,
                [dimensions[0] * dimensions[2],
                 dimensions[1] * dimensions[3]])
        elif len(dimensions) == 2:
            meshed = value
        else:
            raise ValueError(
                'Unsupported tensor rank {} for low-rank approximation.'
                .format(len(dimensions)))
        left, singular, right = np.linalg.svd(meshed, full_matrices=True)
        if self.ranks:
            # guard against ranks == 0, which would otherwise zero out every
            # singular value because singular[-0:] selects the whole array
            singular[-self.ranks:] = 0.0
        self.session.assign(self.left, left)
        self.session.assign(self.singular, singular)
        self.session.assign(self.right, right)
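# A brief NumPy sketch of the low-rank step in _update above: reshape a 4-D
# convolution kernel into a matrix, take its SVD, zero the smallest singular
# values, and rebuild. It uses the reduced SVD rather than the full-matrices
# form used by the class; `low_rank_approximate` is an illustrative name.
def low_rank_approximate(kernel, ranks):
    import numpy as np
    h, w, c_in, c_out = kernel.shape
    meshed = kernel.reshape(h * c_in, w * c_out)
    left, singular, right = np.linalg.svd(meshed, full_matrices=False)
    if ranks:
        singular[-ranks:] = 0.0             # discard the smallest components
    approx = (left * singular) @ right
    return approx.reshape(kernel.shape)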
class FixedPointQuantizer(QuantizerBase):
    """
    Quantize inputs into 2's complement n-bit fixed-point values with
    d-bit dynamic range.

    Args:
        - width: The number of bits to use in the number representation.
          If not specified, we do not limit the range of values.
        - point: The position of the binary point, counting from the LSB.

    References:
        [1] https://arxiv.org/pdf/1604.03168
    """
    width = Parameter('width', 32, [], 'int')
    point = Parameter('point', 2, [], 'int')

    def __init__(
            self, session, point=None, width=None,
            should_update=True, stochastic=None):
        super().__init__(session, should_update)
        if point is not None:
            self.point = point
        if width is not None:
            if width < 1:
                raise ValueError(
                    'Width of quantized value must be greater than 0.')
            self.width = width
        self.stochastic = False if stochastic is None else stochastic

    def _quantize(
            self, value, point=None, width=None, compute_overflow_rate=False):
        point = util.cast(self.point if point is None else point, float)
        width = util.cast(self.width if width is None else width, float)
        # x << (width - point)
        shift = 2.0 ** (util.round(width) - util.round(point))
        value = value * shift
        # quantize
        if self.stochastic:
            value = util.stochastic_round(value, self.stochastic)
        else:
            value = util.round(value)
        # ensure the number is representable without overflow
        if width is not None:
            max_value = util.cast(2 ** (width - 1), float)
            if compute_overflow_rate:
                overflow_value = value[value != 0]
                overflows = util.logical_or(
                    overflow_value < -max_value,
                    overflow_value > max_value - 1)
                return self._overflow_rate(overflows)
            value = util.clip_by_value(value, -max_value, max_value - 1)
        # revert the earlier bit-shift
        return value / shift

    def _apply(self, value):
        return self._quantize(value)

    def _info(self):
        width = int(self.eval(self.width))
        point = int(self.eval(self.point))
        return self._info_tuple(width=width, point=point)
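# A compact NumPy sketch of the fixed-point rounding above: scale by
# 2**(width - point), round, clip to the signed two's-complement range, then
# scale back. The helper name is illustrative only.
def fixed_point_quantize(value, width, point):
    import numpy as np
    shift = 2.0 ** (width - point)
    scaled = np.round(np.asarray(value, dtype=float) * shift)
    max_value = 2.0 ** (width - 1)
    scaled = np.clip(scaled, -max_value, max_value - 1)
    return scaled / shift

# e.g. fixed_point_quantize([0.3, -1.7, 5.0], width=8, point=4)
# -> array([ 0.3125, -1.6875,  5.    ])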
class MixedQuantizer(QuantizerBase):
    """
    Mixed precision is implemented as:
        mask1 * precision1 + mask2 * precision2 + ...
    where the masks are mutually exclusive.

    Currently supported:
    1. adding a quantization loss to the regularization term
    2. quantizer_maps contains parallel quantizers, each of which can be a
       different quantizer
    3. channel-wise granularity based on output channels

    TODO: provide _update()
    """
    interval = Parameter('interval', 0.1, [], 'float')
    channel_mask = Parameter('channel_mask', None, None, 'int')

    def __init__(
            self, session, quantizers, index=0, should_update=True,
            reg_factor=0.0, interval=0.1):
        super().__init__(session, should_update)
        self.quantizer_maps = {}
        for key, item in dict(quantizers).items():
            cls, params = object_from_params(item)
            quantizer = cls(session, **params)
            self.quantizer_maps[key] = quantizer
        self.reg_factor = reg_factor
        # the quantizer that contributes a loss for training
        self.quantizers = quantizers
        self.picked_quantizer = list(quantizers.keys())[index]
        # keep a record of the index for update
        self.index = index

    def _apply(self, value):
        # add a quantization loss to the regularization loss
        self._parameter_config = {
            'channel_mask': {
                'initial': tf.zeros_initializer(tf.bool),
                'shape': value.shape[-1],
            }
        }
        quantized_values = self._quantize(value)
        self._quantization_loss(
            value, quantized_values[self.picked_quantizer])
        # an on label in the mask indicates a quantized channel
        return self._combine_masks(value, quantized_values)

    def _quantize(self, value, mean_quantizer=False):
        quantized_values = {}
        for key, quantizer in dict(self.quantizer_maps).items():
            scope = '{}/{}'.format(
                self._scope, self.__class__.__name__ + key)
            quantized_values[key] = quantizer.apply(
                self.node, scope, self._original_getter, value)
        return quantized_values

    def _combine_masks(self, value, quantized_values):
        """
        Args:
            quantized_values: the channel mask selects among these quantized
                values; channels with label 0 fall through unquantized.
        """
        if self.quantizer_maps:
            index = 0
            for key, quantizer in self.quantizer_maps.items():
                mask_label = index + 1
                channel_mask = util.cast(
                    util.equal(self.channel_mask, mask_label), float)
                if index == 0:
                    result = quantized_values[key] * channel_mask
                else:
                    result += quantized_values[key] * channel_mask
        # now handle the off mask
        off_mask = util.cast(util.equal(self.channel_mask, 0), float)
        return value * off_mask + result

    def _quantization_loss(self, value, quantized_value):
        loss = tf.reduce_sum(tf.abs(value - quantized_value))
        loss *= self.reg_factor
        loss_name = tf.GraphKeys.REGULARIZATION_LOSSES
        tf.add_to_collection(loss_name, loss)

    def _new_mask(self, mask, value, quantized_value, interval):
        loss = util.abs(value - quantized_value)
        # check the ones that are not quantized
        mask = mask.reshape((1, 1, 1, mask.shape[0]))
        unquantized_mask = util.logical_not(mask)
        # TODO: mask shape is incorrect
        loss_vec = util.mean(loss * unquantized_mask, (0, 1, 2))
        # sort
        num_active = util.ceil(len(loss_vec) * interval)
        threshold = sorted(loss_vec)[num_active]
        if interval >= 1.0:
            return util.cast(unquantized_mask, float)
        new_mask = (unquantized_mask * loss) > threshold
        return util.cast(util.logical_or(new_mask, mask), float)

    def _update(self):
        # update only the selected index
        quantizer = self.quantizer_maps[self.picked_quantizer]
        mask, value, quantized_value, interval = self.session.run(
            [self.channel_mask, self.before, quantizer.after, self.interval])
        mask = mask == (self.index + 1)
        new_mask = self._new_mask(mask, value, quantized_value, interval)
        self.index += 1
        self.picked_quantizer = list(self.quantizers.keys())[self.index]
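# A short NumPy sketch of the channel-wise combination in _combine_masks
# above: an integer label per output channel selects which quantizer's output
# that channel takes, and label 0 leaves the channel unquantized. The helper
# name and argument layout are illustrative assumptions.
def combine_channelwise(value, quantized_values, channel_mask):
    import numpy as np
    # value: (..., C); quantized_values: list of arrays shaped like value;
    # channel_mask: (C,) of integer labels, 0 = unquantized
    result = value * (channel_mask == 0)
    for label, quantized in enumerate(quantized_values, start=1):
        result = result + quantized * (channel_mask == label)
    return result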