예제 #1
0
class ChannelTernaryQuantizer(TernaryQuantizer):
    """
    Ternary quantization with channel-wise scaling factors.

    Same mapping as ``TernaryQuantizer`` ({-2^base * scale, 0,
    2^base * scale}), except that ``scale`` is a trainable vector with one
    entry per channel of the last axis of the overridden value instead of a
    single scalar.
    """
    scale = Parameter('scale', None, None, 'float', trainable=True)

    def _quantize(self, value, base=None):
        # The ternary mapping is inherited unchanged; only the shape of
        # ``self.scale`` differs.  ``scale`` has shape ``[value.shape[-1]]``
        # (see ``_apply``), so trailing-axis broadcasting applies one
        # scaling factor per channel of the last axis.
        return super()._quantize(value, base)

    def _apply(self, value):
        # Parameter shapes depend on the overridden value, so the ``scale``
        # configuration must be set before ``_quantize`` reads ``self.scale``.
        self._parameter_config = {
            'scale': {
                'initial': tf.ones_initializer(),
                # a vector whose length matches the number of output
                # channels (the last axis of ``value``)
                'shape': value.shape[-1],
            }
        }
        return self._quantize(value)
예제 #2
0
파일: dns.py 프로젝트: zaf05/mayo
class MeanStdPruner(PrunerBase):
    """
    Prunes elements whose magnitude does not exceed a statistical threshold.

    The threshold is ``mean(|w|) + alpha * sqrt(var(|w|))``, with the moments
    reduced over every axis of the tensor.
    """
    alpha = Parameter('alpha', -2, [], 'float')

    def __init__(self, session, alpha=None, should_update=True):
        super().__init__(session, should_update)
        self.alpha = alpha

    def _threshold(self, tensor, alpha=None):
        # reduce over all axes so mean/var are tensor-wide statistics
        rank = len(util.get_shape(tensor))
        mean, var = util.moments(util.abs(tensor), list(range(rank)))
        factor = self.alpha if alpha is None else alpha
        return mean + factor * util.sqrt(var)

    def _updated_mask(self, var, mask):
        # keep only elements strictly above the threshold; ``mask`` is unused
        return util.abs(var) > self._threshold(var)

    def _info(self):
        _, mask, density, count = super()._info()
        return self._info_tuple(mask=mask,
                                alpha=self.session.run(self.alpha),
                                density=density,
                                count_=count)
예제 #3
0
파일: filter.py 프로젝트: zaf05/mayo
class FilterPruner(PrunerBase):
    """
    Prunes whole filters of a 4-D weight, ranked by their L1 norm.

    The boolean mask has the shape of the weight's last two axes and is
    broadcast over the first two (presumably the spatial kernel axes —
    confirm against the weight layout), so each slice ``value[:, :, i, o]``
    is kept or removed as a unit.
    """
    density = Parameter('density', 0.0, [], 'float')
    mask = Parameter('mask', None, None, 'bool')

    def __init__(self, session, density=None, should_update=True):
        super().__init__(session, should_update)
        self.density = density

    def _apply(self, value):
        self._parameter_config = {
            'mask': {
                'initial': tf.ones_initializer(dtype=tf.bool),
                'shape': tf.TensorShape([value.shape[-2], value.shape[-1]]),
            }
        }
        return value * util.cast(self.mask, float)

    def _l1_norm(self, value):
        """
        Return the L1 norm of every filter of a 4-D weight.

        Raises:
            ValueError: if ``value`` is not 4-D.
        """
        if len(value.shape) != 4:
            raise ValueError(
                'FilterPruner expects a 4-D weight, got shape {}.'
                .format(value.shape))
        return util.sum(util.abs(value), axis=(0, 1))

    def _threshold(self, value, density):
        """
        Return the element of ``value`` at sorted rank
        ``int(value.size * density)``.

        The rank is clamped to the last element, so ``density == 1.0`` yields
        the maximum instead of raising ``IndexError``.
        """
        value = value.flatten()
        index = min(int(value.size * density), value.size - 1)
        return sorted(value)[index]

    def _updated_mask(self, tensor, mask):
        value, mask, density = self.session.run([tensor, mask, self.density])
        l1_norm = self._l1_norm(value)
        # keep filters whose L1 norm is strictly above the threshold
        return l1_norm > self._threshold(l1_norm, density)

    def _info(self):
        _, mask, density, count = super()._info()
        density = self.session.run(self.density)
        return self._info_tuple(mask=mask, density=density, count_=count)

    @classmethod
    def finalize_info(cls, table):
        """Set the table footer from the base summary and return it."""
        footer = [None] + super().finalize_info(table)
        table.set_footer(footer)
        return footer
예제 #4
0
class TernaryQuantizer(QuantizerBase):
    """
    Ternary quantization: every value is mapped by its sign onto
        {-2^base * scale, 0, 2^base * scale}.

    Args:
        base: The universal coarse-grain scaling factor applied to the
              ternary weights; must be non-negative.
        stochastic: Not supported; any value other than ``None`` raises
              ``NotImplementedError``.
    References:
        - Extremely Low Bit Neural Network: Squeeze the Last Bit Out with ADMM
        - Trained Ternary Quantization
    """
    base = Parameter('base', 1, [], 'int', trainable=False)
    scale = Parameter('scale', 1.0, [], 'float', trainable=True)

    def __init__(self,
                 session,
                 base=None,
                 stochastic=None,
                 should_update=True,
                 enable=True):
        super().__init__(session, should_update, enable)
        if base is not None and base < 0:
            raise ValueError('Base of ternary quantization must be '
                             'greater or equal than 0.')
        if base is not None:
            self.base = base
        if stochastic is not None:
            raise NotImplementedError(
                'Ternary quantization does not implement stochastic mode.')

    def _quantize(self, value, base=None):
        scale = self.scale
        if base is None:
            base = self.base
        base = util.cast(base, int)
        shift = util.cast(2**base, float)
        # 1.0 where the sign matches, 0.0 elsewhere (zeros stay zero)
        is_positive = util.cast(value > 0, float)
        is_negative = util.cast(value < 0, float)
        return is_positive * shift * scale - is_negative * shift * scale

    def _apply(self, value):
        return self._quantize(value)

    def _info(self):
        return self._info_tuple(width=2, base=int(self.eval(self.base)))
예제 #5
0
파일: test_override.py 프로젝트: zaf05/mayo
    def setUp(self):
        # scalar int32 parameter that the test overrider hands back
        self.parameter = Parameter('test', 1, (), tf.int32, True)

        class Overrider(OverriderBase):
            param = self.parameter

            def _apply(self, value):
                # ignore the input; return the parameter instead
                return self.param

        self.overrider = Overrider()
        mock_variable = VariableMock('input', 1, (), tf.int32, True)
        self.overrider.apply('scope', self._get_variable, mock_variable)
예제 #6
0
파일: fixed.py 프로젝트: randysuen/mayo
class DGTrainableQuantizer(DGQuantizer):
    """
    Backpropagatable precision: a quantizer with a trainable ``width``.

    Note: the width is declared trainable, but no gradients reach it yet.
    """
    width = Parameter('width', 16, [], 'float', trainable=True)

    def __init__(self, session, overflow_rate, should_update=True):
        # NOTE(review): ``overflow_rate`` is accepted but not forwarded; the
        # two positional ``None`` values fill DGQuantizer's next parameters
        # — confirm which ones against DGQuantizer's signature.
        super().__init__(session, None, None, should_update=should_update)

    def _apply(self, value):
        width, point = self.width, self.point
        return self._quantize(value, width, point)
예제 #7
0
파일: fixed.py 프로젝트: zaf05/mayo
class ThresholdBinarizer(QuantizerBase):
    """Map values strictly above ``threshold`` to 1.0 and the rest to 0.0."""
    threshold = Parameter('threshold', 0, [], 'float')

    def __init__(self, session, threshold=None, should_update=True,
                 enable=True):
        super().__init__(session, should_update, enable)
        if threshold is None:
            return
        self.threshold = threshold

    def _apply(self, value):
        above = value > self.threshold
        return util.cast(above, float)
예제 #8
0
class ChannelPrunerBase(OverriderBase):
    """
    Base class for pruners that keep or remove whole channels.

    A boolean mask with one entry per channel of the last axis (initially
    all ``True``) is broadcast over every other axis of the overridden
    value.  Subclasses implement :meth:`_updated_mask`.
    """
    mask = Parameter('mask', None, None, 'bool')

    def __init__(self, session, should_update=True):
        super().__init__(session, should_update)

    def _apply(self, value):
        rank = len(value.shape)
        if rank < 3:
            raise ValueError(
                'Incorrect dimension {} for channel pruner'
                .format(value.shape))
        self.num_channels = value.shape[-1]
        self._parameter_config = {
            'mask': {
                'initial': tf.ones_initializer(dtype=tf.bool),
                'shape': self.num_channels,
            },
        }
        # Reshape to [1, ..., 1, C] with the same rank as ``value``; a fixed
        # (1, 1, 1, -1) would add an axis to rank-3 inputs when broadcast.
        mask = tf.reshape(self.mask, [1] * (rank - 1) + [-1])
        return value * util.cast(mask, float)

    def _updated_mask(self, var, mask):
        raise NotImplementedError(
            'Method to compute an updated mask is not implemented.')

    def _update(self):
        mask = self._updated_mask(self.before, self.mask)
        self.session.assign(self.mask, mask)

    def _info(self):
        mask = util.cast(self.session.run(self.mask), int)
        density = Percent(util.sum(mask) / util.count(mask))
        return self._info_tuple(
            mask=self.mask.name, density=density, count_=mask.size)

    @classmethod
    def finalize_info(cls, table):
        """Set a footer with the count-weighted mean density; return it."""
        densities = table.get_column('density')
        count = table.get_column('count_')
        avg_density = sum(d * c for d, c in zip(densities, count)) / sum(count)
        footer = [None, '    overall: ', Percent(avg_density), None]
        table.set_footer(footer)
        return footer
예제 #9
0
파일: gate.py 프로젝트: zaf05/mayo
class ChannelGater(GaterBase):
    """
    Gates channels of a 4-D activation by a pooled statistic.

    The ReLU of the input is pooled over its full spatial extent according
    to ``policy``; channels whose pooled magnitude is at least ``threshold``
    pass through and the rest are zeroed.

    Policies:
        - ``'avg'`` or ``None``: average pooling.
        - ``'max'``: max pooling.
        - ``'mix'``: max pooling minus average pooling.

    Raises:
        ValueError: from ``_apply`` for any other policy.
    """
    threshold = Parameter('threshold', 1, [], 'float')

    def __init__(self,
                 session,
                 threshold=None,
                 policy=None,
                 should_update=True):
        super().__init__(session, should_update)
        self.threshold = threshold
        self.policy = policy

    def _apply(self, value):
        policy = self.policy
        value_pool = tf.nn.relu(value)
        # Only the spatial extent must be static to build the pooling
        # window; the batch and channel sizes may be unknown.
        _, h, w, _ = value.shape
        pool_params = {
            'padding': 'VALID',
            'ksize': [1, int(h), int(w), 1],
            'strides': [1, 1, 1, 1]
        }
        if policy == 'avg' or policy is None:
            pooled = tf.nn.avg_pool(value_pool, **pool_params)
        elif policy == 'max':
            pooled = tf.nn.max_pool(tf.abs(value_pool), **pool_params)
        elif policy == 'mix':
            maxed = tf.nn.max_pool(tf.abs(value_pool), **pool_params)
            avged = tf.nn.avg_pool(value_pool, **pool_params)
            pooled = maxed - avged
        else:
            raise ValueError(
                'Unknown pooling policy {!r} for channel gater; expected '
                "'avg', 'max', 'mix' or None.".format(policy))
        # The window covers the whole h x w extent, so the gate has shape
        # [n, 1, 1, c] and broadcasts over the spatial axes: 1.0 keeps a
        # channel, 0.0 zeroes it.
        self.gate = util.cast(tf.abs(pooled) >= self.threshold, float)
        self.pooled = pooled
        tf.add_to_collection('mayo.overrider.gates', self.gate)
        return self.gate * value
예제 #10
0
파일: base.py 프로젝트: zaf05/mayo
class PrunerBase(OverriderBase):
    """
    Base class for element-wise pruners.

    Keeps a boolean ``mask`` with the same shape as the overridden value
    (initially all ``True``) and multiplies it into the value.  Subclasses
    implement :meth:`_updated_mask` to compute the new mask on update.
    """
    mask = Parameter('mask', None, None, 'bool')

    def __init__(self, session, should_update=True):
        super().__init__(session, should_update)

    def _apply(self, value):
        self._parameter_config = {
            'mask': {
                'initial': tf.ones_initializer(dtype=tf.bool),
                'shape': value.shape,
            }
        }
        return value * util.cast(self.mask, float)

    def _updated_mask(self, var, mask):
        raise NotImplementedError(
            'Method to compute an updated mask is not implemented.')

    def _update(self):
        mask = self._updated_mask(self.before, self.mask)
        self.session.assign(self.mask, mask)

    def _info(self):
        mask = util.cast(self.session.run(self.mask), int)
        density = Percent(util.sum(mask) / util.count(mask))
        return self._info_tuple(
            mask=self.mask.name, density=density, count_=mask.size)

    @classmethod
    def finalize_info(cls, table):
        """
        Append an 'overall' row and return it.

        The row holds the count-weighted mean density and the total element
        count over all rows.
        """
        densities = table.get_column('density')
        counts = table.get_column('count_')
        total = sum(counts)
        avg_density = sum(d * c for d, c in zip(densities, counts)) / total
        # report the total count, not the whole column list
        footer = ['overall: ', None, None, Percent(avg_density), total]
        table.add_row(footer)
        return footer
예제 #11
0
class Recentralizer(OverriderBase):
    """
    Recentralizes the distribution of pruned weights.

    The overridden value is split into a positive group (the ``positives``
    mask) and a negative group (its complement).  Each group is shifted by
    its own mean (``positives_mean`` / ``negatives_mean``) before being passed
    through the wrapped ``quantizer``, and the mean is added back afterwards.
    Entries that are zero in ``self.before`` stay zero.  With
    ``mean_quantizer`` the two means are quantized too, and a non-zero
    ``reg`` adds an L1 quantization loss to the regularization losses.
    """
    class QuantizedParameter(Parameter):
        """
        A ``Parameter`` whose value is passed through
        ``instance.parameter_quantizers[name]`` when such a quantizer exists.
        """
        def _quantize(self, instance, value):
            scope = '{}/{}/{}'.format(instance._scope,
                                      instance.__class__.__name__, self.name)
            quantizer = instance.parameter_quantizers.get(self.name)
            # no quantizer registered for this parameter: return it unchanged
            if not quantizer:
                return value
            return quantizer.apply(instance.node, scope,
                                   instance._original_getter, value)

        def __get__(self, instance, owner):
            # class-level access returns the descriptor itself
            if instance is None:
                return self
            # The quantized result is cached per instance under
            # '_quantized_<name>', so the quantizer is applied only once.
            name = '_quantized_{}'.format(self.name)
            try:
                return instance._parameter_variables[name]
            except KeyError:
                pass
            var = super().__get__(instance, owner)
            var = self._quantize(instance, var)
            instance._parameter_variables[name] = var
            return var

    positives = Parameter('positives', None, None, 'bool')
    positives_mean = QuantizedParameter('positives_mean', 1, [], 'float')
    negatives_mean = QuantizedParameter('negatives_mean', -1, [], 'float')

    def __init__(self,
                 session,
                 quantizer,
                 mean_quantizer=None,
                 should_update=True,
                 reg=False):
        super().__init__(session, should_update)
        # the quantizer applied to the recentralized values
        cls, params = object_from_params(quantizer)
        self.quantizer = cls(session, **params)
        self.reg = reg
        # optional quantizers for the two means; each mean gets its own
        # instance built from the same ``mean_quantizer`` description
        if mean_quantizer:
            cls, params = object_from_params(mean_quantizer)
            self.parameter_quantizers = {
                'positives_mean': cls(session, **params),
                'negatives_mean': cls(session, **params),
            }
        else:
            self.parameter_quantizers = {}

    @memoize_property
    def negatives(self):
        # complement of ``positives``; this includes zeros, which _apply
        # removes again with its ``non_zeros`` factor
        return util.logical_not(self.positives)

    def assign_parameters(self):
        # also assign the parameters of every wrapped quantizer
        super().assign_parameters()
        self.quantizer.assign_parameters()
        for quantizer in self.parameter_quantizers.values():
            quantizer.assign_parameters()

    def _quantize(self, value):
        # apply the wrapped quantizer under '<scope>/<class name>'
        quantizer = self.quantizer
        scope = '{}/{}'.format(self._scope, self.__class__.__name__)
        return quantizer.apply(self.node, scope, self._original_getter, value)

    def _apply(self, value):
        # dynamic parameter configuration
        self._parameter_config = {
            'positives': {
                'initial': tf.ones_initializer(tf.bool),
                'shape': value.shape,
            },
        }
        positives = util.cast(self.positives, float)
        negatives = util.cast(self.negatives, float)
        non_zeros = util.cast(tf.not_equal(self.before, 0), float)

        # shift each group so it is centred on its own mean
        positives_centralized = positives * (value - self.positives_mean)
        negatives_centralized = negatives * (value - self.negatives_mean)
        # keep a track of quantized value, without means
        self.quantized = self._quantize(
            non_zeros * (positives_centralized + negatives_centralized))
        # add each group's mean back; entries that were zero stay zero
        quantized_value = non_zeros * positives * \
            (self.quantized + self.positives_mean)
        quantized_value += non_zeros * negatives * \
            (self.quantized + self.negatives_mean)

        self._quantization_loss_regularizer(value, quantized_value)
        return quantized_value

    def _quantization_loss_regularizer(self, value, quantized_value):
        """
        Add ``reg * sum(|value - quantized_value|)`` to
        ``tf.GraphKeys.REGULARIZATION_LOSSES``.

        Skipped when ``reg == 0.0``, which includes the default
        ``reg=False`` because ``False == 0.0``.
        """
        if self.reg == 0.0:
            return
        loss = tf.reduce_sum(tf.abs(value - quantized_value))
        loss *= self.reg
        loss_name = tf.GraphKeys.REGULARIZATION_LOSSES
        tf.add_to_collection(loss_name, loss)

    def _update(self):
        # update positives mask and mean values
        value = self.session.run(self.before)
        # divide them into two groups around a fixed split point of 0
        # mean = util.mean(value)
        mean = 0.0
        # find two central points: the means of the strictly positive values
        # and of the non-zero, non-positive values
        positives = value > mean
        self.positives = positives
        self.positives_mean = util.mean(value[util.where(positives)])
        negatives = util.logical_and(util.logical_not(positives), value != 0)
        self.negatives_mean = util.mean(value[util.where(negatives)])
        # NOTE(review): if a group is empty, the mean of an empty selection
        # is presumably NaN, which this == 0 check does not detect — confirm.
        if self.positives_mean.eval() == 0 or self.negatives_mean.eval() == 0:
            log.warn(
                'means are skewed, pos mean is {} and neg mean is {}'.format(
                    self.positives_mean.eval(), self.negatives_mean.eval()))
        # update internal quantizer
        self.quantizer.update()
        for quantizer in self.parameter_quantizers.values():
            quantizer.update()

    def _info(self):
        # Merge the wrapped quantizer's info with each mean quantizer's info
        # (keys prefixed '<param>_'), dropping the 'name' field.
        info = self.quantizer.info()._asdict()
        for name, quantizer in self.parameter_quantizers.items():
            param_info = quantizer.info()
            param_info = {
                '{}_{}'.format(name, key): value
                for key, value in param_info._asdict().items()
            }
            info.update(param_info)
        info.pop('name')
        return self._info_tuple(**info)
예제 #12
0
파일: incremental.py 프로젝트: zaf05/mayo
class IncrementalQuantizer(OverriderBase):
    """
    Incremental quantization, https://arxiv.org/pdf/1702.03044.pdf

    Wraps ``quantizer`` and keeps a boolean ``mask`` (initially all off) of
    the elements that are currently quantized.  Masked-on elements take the
    quantizer's output and the rest pass through unchanged.  Every update
    grows the mask via :meth:`_policy`; once added, an element stays in it.

    Args:
        quantizer: description of the wrapped quantizer, resolved with
            ``object_from_params``.
        interval: fraction used by :meth:`_policy` to derive the threshold
            on each update; ``>= 1.0`` quantizes every element.
        count_zero: whether zero-valued elements count towards the element
            total used to derive the threshold.
    """
    mask = Parameter('mask', None, None, 'bool')

    def __init__(self,
                 session,
                 quantizer,
                 interval,
                 count_zero=True,
                 should_update=True,
                 enable=True):
        super().__init__(session, should_update, enable)
        cls, params = object_from_params(quantizer)
        self.quantizer = cls(session, **params)
        self.count_zero = count_zero
        self.interval = interval

    def _quantize(self, value, mean_quantizer=False):
        # apply the wrapped quantizer under '<scope>/<class name>'
        quantizer = self.quantizer
        scope = '{}/{}'.format(self._scope, self.__class__.__name__)
        return quantizer.apply(self.node, scope, self._original_getter, value)

    def _apply(self, value):
        self._parameter_config = {
            'mask': {
                'initial': tf.zeros_initializer(tf.bool),
                'shape': value.shape,
            }
        }
        quantized_value = self._quantize(value)
        off_mask = util.cast(util.logical_not(self.mask), float)
        mask = util.cast(self.mask, float)
        # on mask indicates the quantized values
        return value * off_mask + quantized_value * mask

    def _policy(self, value, quantized, previous_mask, interval):
        """
        Compute the next mask of quantized elements.

        Args:
            value: the unquantized values (a NumPy array from
                ``session.run``).
            quantized: ``value`` passed through the wrapped quantizer.
            previous_mask: the current mask; elements already on stay on.
            interval: fraction used with ``util.top_k`` to pick the error
                threshold; ``>= 1.0`` switches every element on.

        Returns:
            ``previous_mask`` OR the elements whose quantization error
            ``|value - quantized|`` is below the threshold.

        Raises:
            ValueError: if the derived element count is negative.
        """
        previous_pruned = util.sum(previous_mask)
        if self.count_zero:
            th_arg = util.cast(util.count(value) * interval, int)
        else:
            tmp = util.count(value[value != 0])
            flat_value_arg = util.where(value.flatten() != 0)
            th_arg = util.cast(tmp * interval, int)
        if th_arg < 0:
            raise ValueError('mask has {} elements, interval is {}'.format(
                previous_pruned, interval))
        off_mask = util.cast(util.logical_not(util.cast(previous_mask, bool)),
                             float)
        metric = value - quantized
        flat_value = (metric * off_mask).flatten()
        if interval >= 1.0:
            # The threshold must exceed every error *magnitude*, because the
            # comparison below is on abs(metric); a plain max() would leave
            # large negative errors unquantized.
            th = util.abs(flat_value).max() + 1.0
        else:
            if self.count_zero:
                th = util.top_k(util.abs(flat_value), th_arg)
            else:
                th = util.top_k(util.abs(flat_value[flat_value_arg]), th_arg)
        th = util.cast(th, float)
        new_mask = util.logical_not(util.greater_equal(util.abs(metric), th))
        return util.logical_or(new_mask, previous_mask)

    # override assign_parameters to assign quantizer as well
    def assign_parameters(self):
        super().assign_parameters()
        self.quantizer.assign_parameters()

    def _update(self):
        # reset index
        self.quantizer.update()
        # grow the mask from the current values and their quantized version
        value, quantized, mask = self.session.run(
            [self.before, self.quantizer.after, self.mask])
        new_mask = self._policy(value, quantized, mask, self.interval)
        self.session.assign(self.mask, new_mask)

    def dump(self):
        return self.quantizer.dump()

    def _info(self):
        return self.quantizer._info()
예제 #13
0
class FloatingPointQuantizer(QuantizerBase):
    """
    Minifloat quantization.

    Values are decomposed into sign, exponent and mantissa.  The exponent is
    clipped to the range allowed by ``exponent_width = width -
    mantissa_width`` bits and ``exponent_bias``, the mantissa is rounded to
    ``mantissa_width`` fractional bits, and the value is reassembled.

    When exponent_width is 0, the floating-point value is a degenerate case
    where exponent is always a constant bias, equivalent to fixed-point with a
    sign-magnitude representation.

    When mantissa_width is 0, the floating-point value is a degenerate
    case where mantissa is always 1, equivalent to shifts with only 2^n
    representations.

    When both exponent_width and mantissa_width are 0, the quantized value can
    only represent $2^{-bias}$ or 0, which is not very useful.
    """
    width = Parameter('width', 32, [], 'float')
    exponent_bias = Parameter('exponent_bias', -127, [], 'float')
    mantissa_width = Parameter('mantissa_width', 23, [], 'float')

    def __init__(self,
                 session,
                 width,
                 exponent_bias,
                 mantissa_width,
                 overflow_rate=0.0,
                 should_update=True,
                 stochastic=None):
        super().__init__(session, should_update)
        self.width = width
        self.exponent_bias = exponent_bias
        self.mantissa_width = mantissa_width
        # maximum tolerated overflow rate, used by _bias() to pick the bias
        self.overflow_rate = overflow_rate
        # if truthy, mantissas are rounded with util.stochastic_round
        self.stochastic = stochastic
        # The exponent gets the bits the mantissa does not use; reject
        # negative widths and the case where both widths are zero.
        exponent_width = width - mantissa_width
        is_valid = exponent_width >= 0 and mantissa_width >= 0
        is_valid = is_valid and (not (exponent_width == 0
                                      and mantissa_width == 0))
        if not is_valid:
            raise ValueError(
                'We expect exponent_width >= 0 and mantissa_width >= 0 '
                'where equalities must be exclusive.')

    def _decompose(self, value, exponent_bias=None):
        """
        Decompose a single-precision floating-point into
        sign, exponent and mantissa components.

        Returns ``(sign, exponent, mantissa)``:
            - ``sign`` is -1, 0 or 1; magnitudes not exceeding half of
              ``2**(-exponent_bias)`` get sign 0.
            - ``exponent`` is ``floor(log2(|value|))``.
            - ``mantissa`` is ``|value| / 2**exponent``, i.e. in [1, 2) for
              finite non-zero values (up to rounding of the logarithm).
        """
        if exponent_bias is None:
            exponent_bias = self.exponent_bias
        # smallest non-zero floating point: 2**(-exponent_bias) is the
        # smallest magnitude _transform can produce (its exponent_min with a
        # mantissa of 1); anything at or below half of it gets sign 0
        descriminator = (2**(-exponent_bias)) / 2
        sign = util.cast(value > descriminator, int)
        sign -= util.cast(value < -descriminator, int)
        value = util.abs(value)
        # NOTE(review): for exact zeros log(0) is -inf, so exponent and
        # mantissa are not finite there; _represent replaces sign == 0
        # entries with 0 (except on its util.is_constant path) — confirm.
        exponent = util.floor(util.log(value, 2))
        mantissa = value / (2**exponent)
        return sign, exponent, mantissa

    def _transform(self,
                   sign,
                   exponent,
                   mantissa,
                   exponent_width=None,
                   mantissa_width=None,
                   exponent_bias=None):
        """
        Fit ``(sign, exponent, mantissa)`` into the minifloat format.

        The exponent is clipped to
        ``[-exponent_bias, 2**exponent_width - 1 - exponent_bias]``.  The
        mantissa is rounded to ``mantissa_width`` fractional bits
        (stochastically if ``self.stochastic`` is set).  Missing arguments
        default to the instance's parameters.
        """
        if exponent_bias is None:
            exponent_bias = self.exponent_bias
        if exponent_width is None:
            exponent_width = self.width - self.mantissa_width
        if mantissa_width is None:
            mantissa_width = self.mantissa_width
        # clip exponent and quantize mantissa
        exponent_min = -exponent_bias
        exponent_max = 2**exponent_width - 1 - exponent_bias
        exponent = util.clip_by_value(exponent, exponent_min, exponent_max)
        shift = util.cast(2**mantissa_width, float)
        # quantize
        if self.stochastic:
            mantissa = util.stochastic_round(mantissa * shift, self.stochastic)
            mantissa /= shift
        else:
            mantissa = util.round(mantissa * shift) / shift

        # if the mantissa value gets rounded to >= 2 then we need to divide it
        # by 2 and increment exponent by 1
        # NOTE(review): this runs after the clip above, so the exponent can
        # end up one above exponent_max — confirm this is intended.
        is_out_of_range = util.greater_equal(mantissa, 2)
        mantissa = util.where(is_out_of_range, mantissa / 2, mantissa)
        exponent = util.where(is_out_of_range, exponent + 1, exponent)
        return sign, exponent, mantissa

    def _represent(self, sign, exponent, mantissa):
        """
        Represent the value in floating-point using
        sign, exponent and mantissa.

        Returns ``sign * 2**exponent * mantissa``.  Unless
        ``util.is_constant`` holds for the inputs, entries with ``sign == 0``
        are forced to exactly 0.
        """
        value = util.cast(sign, float) * (2.0**exponent) * mantissa
        if util.is_constant(sign, exponent, mantissa):
            return value
        if util.is_numpy(sign, exponent, mantissa):
            zeros = np.zeros(sign.shape, dtype=np.int32)
        else:
            zeros = tf.zeros(sign.shape, dtype=tf.int32)
        is_zero = util.equal(sign, zeros)
        return util.where(is_zero, util.cast(zeros, float), value)

    def _quantize(self,
                  value,
                  exponent_width=None,
                  mantissa_width=None,
                  exponent_bias=None):
        """
        Quantize ``value`` by decomposing, transforming and re-representing
        it.  The optional arguments override the instance parameters.
        """
        sign, exponent, mantissa = self._decompose(value, exponent_bias)
        sign, exponent, mantissa = self._transform(sign, exponent, mantissa,
                                                   exponent_width,
                                                   mantissa_width,
                                                   exponent_bias)
        return self._represent(sign, exponent, mantissa)

    def _apply(self, value):
        quantized = self._quantize(value)
        # fail at run time if quantization produced any NaN
        nan = tf.reduce_sum(tf.cast(tf.is_nan(quantized), tf.int32))
        assertion = tf.Assert(tf.equal(nan, 0), [nan])
        with tf.control_dependencies([assertion]):
            # straight-through estimator: the forward value equals
            # ``quantized`` while the gradient w.r.t. ``value`` is identity
            return value + tf.stop_gradient(quantized - value)

    def _bias(self, value, exponent_width):
        """
        Choose an exponent bias for ``value`` given ``exponent_width`` bits.

        Scans exponents upwards from ``min(-2**exponent_width, -4)`` and
        stops at the first one for which ``self._overflow_rate`` of the
        values outside ``[-2**(exponent + 1), 2**(exponent + 1)]`` is at most
        ``self.overflow_rate``.  If none qualifies, the last exponent scanned
        is used.  The returned bias makes that exponent equal to
        ``exponent_max`` in ``_transform``.
        """
        max_exponent = int(2**exponent_width)
        for exponent in range(min(-max_exponent, -4), max(max_exponent, 4)):
            max_value = 2**(exponent + 1)
            overflows = util.logical_or(value < -max_value, value > max_value)
            if self._overflow_rate(overflows) <= self.overflow_rate:
                break
        return 2**exponent_width - 1 - exponent

    def compute_quantization_loss(self, value, exponent_width, mantissa_width,
                                  overflow_rate):
        """
        Quantize ``value`` with the given widths and a bias chosen by
        :meth:`_bias`.

        Returns:
            ``(mean squared error, exponent_bias)``.

        NOTE(review): the ``overflow_rate`` argument is unused; ``_bias``
        reads ``self.overflow_rate`` instead.
        """
        exponent_bias = self._bias(value, exponent_width)
        quantized = self._quantize(value, exponent_width, mantissa_width,
                                   exponent_bias)
        # mean squared loss
        loss = ((value - quantized)**2).mean()
        return (loss, exponent_bias)

    def _info(self):
        width = int(self.eval(self.width))
        mantissa_width = int(self.eval(self.mantissa_width))
        exponent_bias = int(self.eval(self.exponent_bias))
        return self._info_tuple(width=width,
                                mantissa_width=mantissa_width,
                                exponent_bias=exponent_bias)

    def _update(self):
        # re-pick the exponent bias from the current (unquantized) values
        value = self.eval(self.before)
        exponent_width = self.eval(self.width) - self.eval(self.mantissa_width)
        self.exponent_bias = self._bias(value, exponent_width)
예제 #14
0
class LowRankApproximation(OverriderBase):
    """
    Replaces a weight with a truncated-SVD reconstruction of itself.

    A 4-D weight of shape ``[d0, d1, d2, d3]`` is reshaped (without
    transposition) into a ``(d0 * d2, d1 * d3)`` matrix; a 2-D weight is used
    as-is.  The matrix is stored as ``left @ diag(singular) @ right``, and on
    every update the ``ranks`` smallest singular values are zeroed.
    """
    singular = Parameter('singular', None, None, 'float')
    left = Parameter('left', None, None, 'float')
    right = Parameter('right', None, None, 'float')

    def __init__(self, session, should_update=True, ranks=0):
        super().__init__(session, should_update)
        # ranks to prune away (number of smallest singular values to zero)
        self.ranks = ranks

    @staticmethod
    def _matrix_shape(shape):
        """
        Return ``(rows, columns)`` of the matrix a weight is decomposed as.

        Raises:
            ValueError: if the weight is neither 2-D nor 4-D.
        """
        if len(shape) == 4:
            return (int(shape[0]) * int(shape[2]),
                    int(shape[1]) * int(shape[3]))
        if len(shape) == 2:
            return (int(shape[0]), int(shape[1]))
        raise ValueError(
            'LowRankApproximation expects a 2-D or 4-D weight, '
            'got shape {}.'.format(shape))

    def _parameter_initial(self, value):
        """Configure the SVD parameters for ``value``; return its matrix shape."""
        rows, columns = self._matrix_shape(value.shape)
        self._parameter_config = {
            'singular': {
                'initial': tf.ones_initializer(dtype=tf.float32),
                'shape': min(rows, columns),
            },
            'left': {
                'initial': tf.ones_initializer(dtype=tf.float32),
                'shape': (rows, rows),
            },
            'right': {
                'initial': tf.ones_initializer(dtype=tf.float32),
                'shape': (columns, columns),
            }
        }
        return (rows, columns)

    def _apply(self, value):
        rows, columns = self._parameter_initial(value)
        # embed the singular values in a (rows, columns) diagonal matrix
        if rows < columns:
            singular = tf.expand_dims(self.singular, 1) * tf.eye(rows, columns)
        else:
            singular = tf.expand_dims(self.singular, 0) * tf.eye(rows, columns)

        svd_construct = tf.matmul(tf.matmul(self.left, singular), self.right)
        return tf.reshape(svd_construct, value.shape)

    def _update(self):
        value = self.session.run(self.before)
        meshed = np.reshape(value, self._matrix_shape(value.shape))
        left, singular, right = np.linalg.svd(meshed, full_matrices=True)
        # ``singular`` is in descending order, so the tail holds the smallest
        # values.  Guard ranks == 0: ``singular[-0:]`` is the whole array and
        # would discard every singular value.
        if self.ranks > 0:
            singular[-self.ranks:] = 0.0
        self.session.assign(self.left, left)
        self.session.assign(self.singular, singular)
        self.session.assign(self.right, right)
예제 #15
0
파일: fixed.py 프로젝트: randysuen/mayo
class FixedPointQuantizer(QuantizerBase):
    """
    Quantize inputs into two's complement n-bit fixed-point values with d-bit
    dynamic range.

    Args:
        - width:
            The number of bits to use in number representation (defaults to
            the ``width`` parameter's initial value, 32).
        - point:
            The position of the binary point counted from the MSB, i.e. the
            number of integer bits including the sign bit; the remaining
            ``width - point`` bits are fractional.
        - stochastic:
            If truthy, it is passed to ``util.stochastic_round``; otherwise
            values are rounded with ``util.round``.

    References:
        [1] https://arxiv.org/pdf/1604.03168
    """
    width = Parameter('width', 32, [], 'int')
    point = Parameter('point', 2, [], 'int')

    def __init__(self,
                 session,
                 point=None,
                 width=None,
                 should_update=True,
                 stochastic=None):
        super().__init__(session, should_update)
        if point is not None:
            self.point = point
        if width is not None:
            if width < 1:
                raise ValueError(
                    'Width of quantized value must be greater than 0.')
            self.width = width
        # ``None`` means deterministic rounding; previously the ``False``
        # default was immediately overwritten by ``None``.
        self.stochastic = False if stochastic is None else stochastic

    def _quantize(self,
                  value,
                  point=None,
                  width=None,
                  compute_overflow_rate=False):
        """
        Quantize ``value`` to fixed point.

        ``point``/``width`` override the instance parameters.  With
        ``compute_overflow_rate``, the overflow rate of the non-zero scaled
        values is returned instead of the quantized values; that path indexes
        with a boolean mask, so it needs NumPy input.
        """
        point = util.cast(self.point if point is None else point, float)
        width = util.cast(self.width if width is None else width, float)
        # x << (width - point)
        shift = 2.0**(util.round(width) - util.round(point))
        value = value * shift
        # quantize
        if self.stochastic:
            value = util.stochastic_round(value, self.stochastic)
        else:
            value = util.round(value)
        # ensure number is representable without overflow
        max_value = util.cast(2**(width - 1), float)
        if compute_overflow_rate:
            overflow_value = value[value != 0]
            overflows = util.logical_or(overflow_value < -max_value,
                                        overflow_value > max_value - 1)
            return self._overflow_rate(overflows)
        value = util.clip_by_value(value, -max_value, max_value - 1)
        # revert bit-shift earlier
        return value / shift

    def _apply(self, value):
        return self._quantize(value)

    def _info(self):
        width = int(self.eval(self.width))
        point = int(self.eval(self.point))
        return self._info_tuple(width=width, point=point)
예제 #16
0
파일: mixed.py 프로젝트: zaf05/mayo
class MixedQuantizer(QuantizerBase):
    """
    Mixed Precision should be implemented as the following:
    mask1 * precision1 + mask2 * precision2 ...
    The masks are mutually exclusive
    Currently supporting:
        1. making a loss to the reg term
        2. quantizer_maps contains parallel quantizers that each can have
        a different quantizer
        3. channel wise granuarity based on output channels
    TODO:
    provide _update()
    """
    interval = Parameter('interval', 0.1, [], 'float')
    channel_mask = Parameter('channel_mask', None, None, 'int')

    def __init__(self,
                 session,
                 quantizers,
                 index=0,
                 should_update=True,
                 reg_factor=0.0,
                 interval=0.1):
        """Build one sub-quantizer per entry of ``quantizers``.

        Args:
            session: passed to the base class and to every sub-quantizer.
            quantizers: mapping from a name to a params description; each
                value is resolved with ``object_from_params`` into a class
                and the keyword arguments used to construct it.
            index: position (in ``quantizers`` key order) of the quantizer
                whose output feeds the quantization loss.
            should_update: forwarded to the base class.
            reg_factor: weight applied to the quantization loss that is added
                to the regularization losses.
            interval: value assigned to the ``interval`` parameter, which
                ``_update`` reads.
        """
        super().__init__(session, should_update)
        # Store the argument on its Parameter (as other overriders do with
        # their constructor arguments); without this the class-level default
        # of 0.1 would always be used regardless of what the caller passed.
        self.interval = interval
        self.quantizer_maps = {}
        for key, item in dict(quantizers).items():
            cls, params = object_from_params(item)
            self.quantizer_maps[key] = cls(session, **params)
        self.reg_factor = reg_factor
        self.quantizers = quantizers
        # the quantizer whose output is used to build the training loss
        self.picked_quantizer = list(quantizers.keys())[index]
        # keep record of an index for update
        self.index = index

    def _apply(self, value):
        """Quantize ``value`` with every sub-quantizer and merge per channel.

        Also registers the picked quantizer's error as a regularization loss.
        """
        # One mask entry per channel along the last axis of ``value``.
        # NOTE(review): the Parameter is declared with dtype 'int' but the
        # initializer is given ``tf.bool``; confirm which dtype takes effect.
        mask_config = {
            'initial': tf.zeros_initializer(tf.bool),
            'shape': value.shape[-1],
        }
        self._parameter_config = {'channel_mask': mask_config}
        candidates = self._quantize(value)
        self._quantization_loss(value, candidates[self.picked_quantizer])
        # channels selected by the mask receive their quantized values
        return self._combine_masks(value, candidates)

    def _quantize(self, value, mean_quantizer=False):
        """Run ``value`` through every sub-quantizer.

        Args:
            value: the tensor to quantize.
            mean_quantizer: unused.

        Returns:
            A dict mapping each key of ``quantizer_maps`` to the output of
            that sub-quantizer's ``apply``.
        """
        # iterate over a snapshot of the mapping, as before
        snapshot = dict(self.quantizer_maps)
        prefix = self.__class__.__name__
        outputs = {}
        for name in snapshot:
            # each sub-quantizer gets its own scope: <scope>/<ClassName><key>
            sub_scope = '{}/{}'.format(self._scope, prefix + name)
            outputs[name] = snapshot[name].apply(
                self.node, sub_scope, self._original_getter, value)
        return outputs

    def _combine_masks(self, value, quantized_values):
        """Select, per channel, either the original or a quantized value.

        ``self.channel_mask`` labels each channel: 0 keeps the original
        ``value``; label ``i + 1`` takes the output of the ``i``-th quantizer
        in ``quantizer_maps`` iteration order. Because each channel carries a
        single label, the per-label masks are mutually exclusive.

        Args:
            value: the unquantized tensor.
            quantized_values: dict from quantizer key to that quantizer's
                output, as returned by ``_quantize``.

        Returns:
            The combined tensor. With no sub-quantizers this is just the
            off-mask part of ``value``.
        """
        # channels not yet assigned to any quantizer keep their original value
        off_mask = util.cast(util.equal(self.channel_mask, 0), float)
        result = value * off_mask
        # The label must advance with each quantizer so that every quantizer
        # contributes the channels it owns (a constant label would match only
        # label 1 and let the last quantizer overwrite all the others).
        for index, key in enumerate(self.quantizer_maps):
            mask_label = index + 1
            channel_mask = util.cast(
                util.equal(self.channel_mask, mask_label), float)
            result = result + quantized_values[key] * channel_mask
        return result

    def _quantization_loss(self, value, quantized_value):
        """Register an L1 quantization-error penalty as a regularization loss.

        The penalty is ``sum(|value - quantized_value|) * reg_factor`` and is
        appended to ``tf.GraphKeys.REGULARIZATION_LOSSES``.
        """
        abs_error = tf.abs(value - quantized_value)
        penalty = tf.reduce_sum(abs_error) * self.reg_factor
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, penalty)

    def _new_mask(self, mask, value, quantized_value, interval):
        """Grow the mask of channels selected for quantization.

        Args:
            mask: per-channel boolean vector; entries that are false are
                treated as "not quantized". It is reshaped to
                ``(1, 1, 1, C)``, so ``value`` is presumably a 4-D tensor
                with channels last (TODO confirm).
            value: the unquantized weights.
            quantized_value: the weights after quantization.
            interval: fraction of channels used to pick the loss threshold;
                ``>= 1.0`` selects every not-yet-quantized channel.

        Returns:
            A float mask.
        """
        loss = util.abs(value - quantized_value)
        # check the ones that are not quantized
        mask = mask.reshape((1, 1, 1, mask.shape[0]))
        unquantized_mask = util.logical_not(mask)
        # Handle the "everything" case before indexing: with interval >= 1.0
        # the threshold index computed below would run past the sorted list,
        # so this return was previously unreachable.
        if interval >= 1.0:
            return util.cast(unquantized_mask, float)
        # TODO: mask shape is incorrect
        loss_vec = util.mean(loss * unquantized_mask, (0, 1, 2))
        # ``util.ceil`` may return a float, but list indices must be ints.
        # Clamp so that ceil(len * interval) == len still picks the largest
        # loss instead of raising IndexError.
        num_active = int(util.ceil(len(loss_vec) * interval))
        num_active = min(num_active, len(loss_vec) - 1)
        threshold = sorted(loss_vec)[num_active]
        # NOTE(review): this compares every element of the masked loss with a
        # per-channel threshold, so ``new_mask`` takes the full shape of
        # ``value`` rather than one entry per channel.
        new_mask = (unquantized_mask * loss) > threshold
        return util.cast(util.logical_or(new_mask, mask), float)

    def _update(self):
        # update only the selected index
        quantizer = self.quantizer_maps[self.picked_quantizer]
        mask, value, quantized_value, interval = self.session.run(
            [self.channel_mask, self.before, quantizer.after, self.interval])
        mask = mask == (self.index + 1)
        new_mask = self._new_mask(mask, value, quantized_value, interval)
        self.index += 1
        self.picked_quantizer = list(self.quantizers.keys())[self.index]