Example #1
0
    def __init__(self, layer, quantize_config, **kwargs):
        """Wrap a keras layer so its quantized behavior can be emulated.

    Args:
      layer: The keras layer to be quantized. Must be a plain
        `tf.keras.layers.Layer` (not a `tf.keras.Model`).
      quantize_config: `QuantizeConfig` describing how to quantize the layer.
      **kwargs: Additional keyword arguments forwarded to the wrapper base.
    """
        if layer is None:
            raise ValueError('`layer` cannot be None.')

        # A keras.Model is itself a keras.layers.Layer, so it must be
        # excluded explicitly rather than relying on isinstance alone.
        is_wrappable_layer = (isinstance(layer, tf.keras.layers.Layer)
                              and not isinstance(layer, tf.keras.Model))
        if not is_wrappable_layer:
            raise ValueError(
                '`layer` can only be a `tf.keras.layers.Layer` instance. '
                'You passed an instance of type: {input}.'.format(
                    input=layer.__class__.__name__))

        if quantize_config is None:
            raise ValueError('quantize_config cannot be None. It is needed to '
                             'quantize a layer.')

        # Derive a wrapper name from the wrapped layer unless the caller
        # supplied one.
        if 'name' not in kwargs:
            kwargs['name'] = self._make_layer_name(layer)

        super(QuantizeWrapper, self).__init__(layer, **kwargs)
        self.quantize_config = quantize_config

        # Register the wrapped layer for checkpointing, and record usage
        # of this wrapper in the monitoring gauge.
        self._track_trackable(layer, name='layer')
        metrics.MonitorBoolGauge('quantize_wrapper_usage').set(
            layer.__class__.__name__)
Example #2
0
    def setUp(self):
        """Reset the shared usage gauge and build a monitor backed by it."""
        super(MetricsTest, self).setUp()
        self.test_label = tf.keras.layers.Conv2D(1, 1).__class__.__name__

        # Clear every cell the tests may have set, so each test starts
        # from a known-false gauge state.
        reset_labels = (
            self.test_label,
            metrics.MonitorBoolGauge._SUCCESS_LABEL,
            metrics.MonitorBoolGauge._FAILURE_LABEL,
        )
        for label in reset_labels:
            MetricsTest.gauge.get_cell(label).set(False)

        # Route the monitor's gauge lookup to the class-level test gauge.
        with mock.patch.object(
                metrics.MonitorBoolGauge,
                'get_usage_gauge',
                return_value=MetricsTest.gauge):
            self.monitor = metrics.MonitorBoolGauge('testing')
Example #3
0
    def __init__(self,
                 layer,
                 pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
                 block_size=(1, 1),
                 block_pooling_type='AVG',
                 **kwargs):
        """Create a pruning wrapper for a keras layer.

    #TODO(pulkitb): Consider if begin_step should be 0 by default.

    Args:
      layer: The keras layer to be pruned.
      pruning_schedule: A `PruningSchedule` object that controls pruning rate
        throughout training.
      block_size: (optional) The dimensions (height, width) for the block
        sparse pattern in rank-2 weight tensors.
      block_pooling_type: (optional) The function to use to pool weights in the
        block. Must be 'AVG' or 'MAX'.
      **kwargs: Additional keyword arguments to be passed to the keras layer.

    Raises:
      ValueError: If `block_pooling_type` is unsupported, or `layer` is not a
        keras layer, or `layer` does not support pruning.
    """
        self.pruning_schedule = pruning_schedule
        self.block_size = block_size
        self.block_pooling_type = block_pooling_type

        # An instance of the Pruning class. This class contains the logic to prune
        # the weights of this layer. Populated later (e.g. in build).
        self.pruning_obj = None

        # A list of all (weight,mask,threshold) tuples for this layer
        self.pruning_vars = []

        if block_pooling_type not in ['AVG', 'MAX']:
            raise ValueError(
                'Unsupported pooling type \'{}\'. Should be \'AVG\' or \'MAX\'.'
                .format(block_pooling_type))

        if not isinstance(layer, tf.keras.layers.Layer):
            raise ValueError(
                'Please initialize `Prune` layer with a '
                '`Layer` instance. You passed: {input}'.format(input=layer))

        # TODO(pulkitb): This should be pushed up to the wrappers.py
        # Name the layer using the wrapper and underlying layer name.
        # Prune(Dense) becomes prune_dense_1
        kwargs.update({
            'name':
            '{}_{}'.format(
                generic_utils.to_snake_case(self.__class__.__name__),
                layer.name)
        })

        if isinstance(layer, prunable_layer.PrunableLayer) or hasattr(
                layer, 'get_prunable_weights'):
            # Custom layer in client code which supports pruning.
            super(PruneLowMagnitude, self).__init__(layer, **kwargs)
        elif prune_registry.PruneRegistry.supports(layer):
            # Built-in keras layers which support pruning.
            super(PruneLowMagnitude, self).__init__(
                prune_registry.PruneRegistry.make_prunable(layer), **kwargs)
        else:
            # Fixed grammar/typo in the error message
            # ("should has a customer defined" -> "should have a custom-defined").
            raise ValueError(
                'Please initialize `Prune` with a supported layer. Layers should '
                'either be supported by the PruneRegistry (built-in keras layers) or '
                'should be a `PrunableLayer` instance, or should have a '
                'custom-defined `get_prunable_weights` method. You passed: '
                '{input}'.format(input=layer.__class__))

        # Register the wrapped layer for checkpointing.
        self._track_trackable(layer, name='layer')

        # TODO(yunluli): Work-around to handle the first layer of Sequential model
        # properly. Can remove this when it is implemented in the Wrapper base
        # class.
        #
        # Enables end-user to prune the first layer in Sequential models, while
        # passing the input shape to the original layer.
        #
        # tf.keras.Sequential(
        #   prune_low_magnitude(tf.keras.layers.Dense(2, input_shape=(3,)))
        # )
        #
        # as opposed to
        #
        # tf.keras.Sequential(
        #   prune_low_magnitude(tf.keras.layers.Dense(2), input_shape=(3,))
        # )
        #
        # Without this code, the pruning wrapper doesn't have an input
        # shape and being the first layer, this causes the model to not be
        # built. Being not built is confusing since the end-user has passed an
        # input shape.
        if not hasattr(self, '_batch_input_shape') and hasattr(
                layer, '_batch_input_shape'):
            self._batch_input_shape = self.layer._batch_input_shape
        # Record usage of this wrapper in the monitoring gauge.
        metrics.MonitorBoolGauge('prune_low_magnitude_wrapper_usage').set(
            layer.__class__.__name__)