Example #1
    def testPruneWrapperAllowsOnlyValidPoolingType(self):
        layer = layers.Dense(10)
        with self.assertRaises(ValueError):
            pruning_wrapper.PruneLowMagnitude(layer, block_pooling_type='MIN')

        pruning_wrapper.PruneLowMagnitude(layer, block_pooling_type='AVG')
        pruning_wrapper.PruneLowMagnitude(layer, block_pooling_type='MAX')
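The same restriction is visible through the user-facing `prune_low_magnitude` API (shown in full in Example #6), which forwards `block_pooling_type` to the wrapper unchanged. A minimal sketch, assuming the standard `tensorflow_model_optimization` import path; the layer and block size here are illustrative:

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

layer = tf.keras.layers.Dense(10)

# Accepted: 'AVG' and 'MAX' pooling over (1, 2) weight blocks.
pruned = tfmot.sparsity.keras.prune_low_magnitude(
    layer, block_size=(1, 2), block_pooling_type='AVG')

# Rejected: any other pooling type raises ValueError at wrap time.
# tfmot.sparsity.keras.prune_low_magnitude(layer, block_pooling_type='MIN')
```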
Example #2
    def _add_pruning_wrapper(layer):
        if isinstance(layer, keras.Model):
            # Check whether the model is a subclass model.
            if (not layer._is_graph_network
                    and not isinstance(layer, keras.models.Sequential)):
                raise ValueError(
                    'Subclassed models are not supported currently.')

            return keras.models.clone_model(
                layer, input_tensors=None, clone_function=_add_pruning_wrapper)
        if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
            return layer
        return pruning_wrapper.PruneLowMagnitude(layer, **params)
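To illustrate the branch that recurses into nested models, here is a minimal sketch assuming the public `tfmot.sparsity.keras.prune_low_magnitude` entry point; the architecture is illustrative. A Functional sub-model is cloned recursively and each inner layer is wrapped, while a subclassed model would raise the ValueError above.

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# A Functional sub-model nested inside a Sequential model.
inner_inputs = tf.keras.Input(shape=(100,))
inner_outputs = tf.keras.layers.Dense(10, activation='relu')(inner_inputs)
inner = tf.keras.Model(inner_inputs, inner_outputs)

outer = tf.keras.Sequential([inner, tf.keras.layers.Dense(2)])

# The clone_function recurses into `inner` and wraps its Dense layer too.
pruned = tfmot.sparsity.keras.prune_low_magnitude(outer)
```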
Example #3
    def _prune_list(layers, **params):
        wrapped_layers = []

        for layer in layers:
            # Allow layer that is already wrapped by the pruning wrapper
            # to be used as is.
            # No need to wrap the input layer either.
            if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
                wrapped_layers.append(layer)
            elif isinstance(layer, keras.layers.InputLayer):
                # TODO(yunluli): Replace with a clone function in keras.
                wrapped_layers.append(
                    layer.__class__.from_config(layer.get_config()))
            else:
                wrapped_layers.append(
                    pruning_wrapper.PruneLowMagnitude(layer, **params))

        return wrapped_layers
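A sketch of the list path this helper backs: passing a plain list to `prune_low_magnitude` returns the wrapped layers, and the caller assembles them into a model. This is a hedged sketch; the layer sizes and input shape are illustrative.

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

wrapped_layers = tfmot.sparsity.keras.prune_low_magnitude([
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(2, activation='sigmoid'),
])

# Already-wrapped layers and InputLayers would pass through unchanged.
model = tf.keras.Sequential(
    [tf.keras.layers.InputLayer(input_shape=(100,))] + wrapped_layers)
```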
Example #4
    def testCustomLayerPrunable(self):
        # A custom layer that implements the prunable-layer interface can be wrapped.
        layer = CustomLayerPrunable(input_dim=16, output_dim=32)
        inputs = keras.layers.Input(shape=(16,))
        _ = layer(inputs)
        pruning_wrapper.PruneLowMagnitude(layer, block_pooling_type='MAX')
Example #5
    def testCustomLayerNonPrunable(self):
        # A custom layer that does not declare prunable weights is rejected.
        layer = CustomLayer(input_dim=16, output_dim=32)
        inputs = keras.layers.Input(shape=(16,))
        _ = layer(inputs)
        with self.assertRaises(ValueError):
            pruning_wrapper.PruneLowMagnitude(layer, block_pooling_type='MAX')
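A custom layer becomes prunable by declaring which of its weights may be masked. Below is a sketch of what a layer like `CustomLayerPrunable` might look like, assuming the `tfmot.sparsity.keras.PrunableLayer` interface; the layer body is an assumption, not the test's actual implementation.

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

class MyPrunableLayer(tf.keras.layers.Layer,
                      tfmot.sparsity.keras.PrunableLayer):
    """Custom layer that opts into pruning via get_prunable_weights()."""

    def __init__(self, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.kernel = self.add_weight(
            name='kernel',
            shape=(input_dim, output_dim),
            initializer='glorot_uniform',
            trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)

    def get_prunable_weights(self):
        # Only the weights returned here are masked by PruneLowMagnitude.
        return [self.kernel]

# Wrapping now succeeds, unlike the non-prunable layer in the test above.
pruned = tfmot.sparsity.keras.prune_low_magnitude(
    MyPrunableLayer(input_dim=16, output_dim=32))
```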
Example #6
def prune_low_magnitude(to_prune,
                        pruning_schedule=pruning_sched.ConstantSparsity(
                            0.5, 0),
                        block_size=(1, 1),
                        block_pooling_type='AVG',
                        pruning_policy=None,
                        **kwargs):
    """Modify a tf.keras layer or model to be pruned during training.

  This function wraps a tf.keras model or layer with pruning functionality which
  sparsifies the layer's weights during training. For example, using this with
  50% sparsity will ensure that 50% of the layer's weights are zero.

  The function accepts either a single keras layer (a subclass of
  `tf.keras.layers.Layer`), a list of keras layers, or a Sequential or
  Functional tf.keras model, and handles each case appropriately.

  If it encounters a layer it does not know how to handle, it raises an
  error; when pruning an entire model, even a single unsupported layer causes
  the whole call to fail.

  Prune a model:

  ```python
  pruning_params = {
      'pruning_schedule': ConstantSparsity(0.5, 0),
      'block_size': (1, 1),
      'block_pooling_type': 'AVG'
  }

  model = prune_low_magnitude(
      keras.Sequential([
          layers.Dense(10, activation='relu', input_shape=(100,)),
          layers.Dense(2, activation='sigmoid')
      ]), **pruning_params)
  ```

  Prune a layer:

  ```python
  pruning_params = {
      'pruning_schedule': PolynomialDecay(initial_sparsity=0.2,
          final_sparsity=0.8, begin_step=1000, end_step=2000),
      'block_size': (2, 3),
      'block_pooling_type': 'MAX'
  }

  model = keras.Sequential([
      layers.Dense(10, activation='relu', input_shape=(100,)),
      prune_low_magnitude(layers.Dense(2, activation='tanh'), **pruning_params)
  ])
  ```

  Pretrained models: you must first load the weights and then apply the
  prune API:

  ```python
  model.load_weights(...)
  model = prune_low_magnitude(model)
  ```

  Optimizer: this function removes the optimizer, so the user is expected to
  compile the model again. It's easiest to rely on the default (the step
  starts at 0) and use that to determine the desired begin_step for the
  pruning_schedules (a full training loop is sketched after this example).

  Checkpointing: checkpoints should include the optimizer, not just the
  weights. Pruning supports checkpointing, although upon inspection the
  weights of checkpoints are not sparse
  (https://github.com/tensorflow/model-optimization/issues/206).

  Arguments:
      to_prune: A single keras layer, list of keras layers, or a
        `tf.keras.Model` instance.
      pruning_schedule: A `PruningSchedule` object that controls pruning rate
        throughout training.
      block_size: (optional) The dimensions (height, width) for the block
        sparse pattern in rank-2 weight tensors.
      block_pooling_type: (optional) The function to use to pool weights in the
        block. Must be 'AVG' or 'MAX'.
      pruning_policy: (optional) The object that controls to which layers
        `PruneLowMagnitude` wrapper will be applied. This API is experimental
        and is subject to change.
      **kwargs: Additional keyword arguments to be passed to the keras layer.
        Ignored when to_prune is not a keras layer.

  Returns:
    Layer or model modified with pruning wrappers. Optimizer is removed.

  Raises:
    ValueError: if the keras layer is unsupported, or the keras model contains
    an unsupported layer.
  """
    def _prune_list(layers, **params):
        wrapped_layers = []

        for layer in layers:
            # Allow layer that is already wrapped by the pruning wrapper
            # to be used as is.
            # No need to wrap the input layer either.
            if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
                wrapped_layers.append(layer)
            elif isinstance(layer, keras.layers.InputLayer):
                # TODO(yunluli): Replace with a clone function in keras.
                wrapped_layers.append(
                    layer.__class__.from_config(layer.get_config()))
            else:
                wrapped_layers.append(
                    pruning_wrapper.PruneLowMagnitude(layer, **params))

        return wrapped_layers

    def _add_pruning_wrapper(layer):
        if isinstance(layer, keras.Model):
            # Check whether the model is a subclass model.
            if (not layer._is_graph_network
                    and not isinstance(layer, keras.models.Sequential)):
                raise ValueError(
                    'Subclassed models are not supported currently.')

            return keras.models.clone_model(
                layer, input_tensors=None, clone_function=_add_pruning_wrapper)
        if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
            return layer
        if pruning_policy and not pruning_policy.allow_pruning(layer):
            return layer
        else:
            return pruning_wrapper.PruneLowMagnitude(layer, **params)

    params = {
        'pruning_schedule': pruning_schedule,
        'block_size': block_size,
        'block_pooling_type': block_pooling_type
    }

    is_sequential_or_functional = isinstance(
        to_prune, keras.Model) and (isinstance(to_prune, keras.Sequential)
                                    or to_prune._is_graph_network)

    # A subclassed model is also a subclass of keras.layers.Layer.
    is_keras_layer = isinstance(
        to_prune,
        keras.layers.Layer) and not isinstance(to_prune, keras.Model)

    if isinstance(to_prune, list):
        return _prune_list(to_prune, **params)
    elif is_sequential_or_functional:
        if pruning_policy:
            pruning_policy.ensure_model_supports_pruning(to_prune)
        return _add_pruning_wrapper(to_prune)
    elif is_keras_layer:
        params.update(kwargs)
        return pruning_wrapper.PruneLowMagnitude(to_prune, **params)
    else:
        raise ValueError(
            '`prune_low_magnitude` can only prune an object of the following '
            'types: tf.keras.models.Sequential, tf.keras functional model, '
            'tf.keras.layers.Layer, list of tf.keras.layers.Layer. You passed '
            'an object of type: {input}.'.format(
                input=to_prune.__class__.__name__))
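To make the optimizer and checkpointing notes concrete, here is a hedged end-to-end sketch of the usual training loop: compile after pruning, drive the schedule with the `UpdatePruningStep` callback, and strip the wrappers before export. The data below is random placeholder data; `UpdatePruningStep`, `PruningSummaries`, and `strip_pruning` are the companion utilities in `tfmot.sparsity.keras`.

```python
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu', input_shape=(100,)),
    tf.keras.layers.Dense(2, activation='softmax'),
])
# For a pretrained model, load its weights before applying the prune API,
# e.g. model.load_weights(...).

pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    model,
    pruning_schedule=tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0))

# prune_low_magnitude removed the optimizer, so compile again.
pruned_model.compile(optimizer='adam',
                     loss='sparse_categorical_crossentropy',
                     metrics=['accuracy'])

x_train = np.random.rand(256, 100).astype('float32')
y_train = np.random.randint(0, 2, size=(256,))

callbacks = [
    # Advances the pruning step; required during training.
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Optional TensorBoard summaries of sparsity over time.
    tfmot.sparsity.keras.PruningSummaries(log_dir='/tmp/pruning_logs'),
]
pruned_model.fit(x_train, y_train, epochs=2, callbacks=callbacks)

# Remove the wrappers for export; the masked (sparse) weights are kept.
final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
```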
Example #7
def prune_low_magnitude(to_prune,
                        pruning_schedule=pruning_sched.ConstantSparsity(
                            0.5, 0),
                        block_size=(1, 1),
                        block_pooling_type='AVG',
                        **kwargs):
    """Modify a keras layer or model to be pruned during training.

  This function wraps a tf.keras model or layer with pruning functionality which
  sparsifies the layer's weights during training. For example, using this with
  50% sparsity will ensure that 50% of the layer's weights are zero.

  The function accepts either a single keras layer (a subclass of
  `tf.keras.layers.Layer`), a list of keras layers, or a Sequential or
  Functional keras model, and handles each case appropriately.

  If it encounters a layer it does not know how to handle, it raises an
  error; when pruning an entire model, even a single unsupported layer causes
  the whole call to fail.

  Prune a model:

  ```python
  pruning_params = {
      'pruning_schedule': ConstantSparsity(0.5, 0),
      'block_size': (1, 1),
      'block_pooling_type': 'AVG'
  }

  model = prune_low_magnitude(
      keras.Sequential([
          layers.Dense(10, activation='relu', input_shape=(100,)),
          layers.Dense(2, activation='sigmoid')
      ]), **pruning_params)
  ```

  Prune a layer:

  ```python
  pruning_params = {
      'pruning_schedule': PolynomialDecay(initial_sparsity=0.2,
          final_sparsity=0.8, begin_step=1000, end_step=2000),
      'block_size': (2, 3),
      'block_pooling_type': 'MAX'
  }

  model = keras.Sequential([
      layers.Dense(10, activation='relu', input_shape=(100,)),
      prune_low_magnitude(layers.Dense(2, activation='tanh'), **pruning_params)
  ])
  ```

  Arguments:
      to_prune: A single keras layer, list of keras layers, or a
        `tf.keras.Model` instance.
      pruning_schedule: A `PruningSchedule` object that controls pruning rate
        throughout training.
      block_size: (optional) The dimensions (height, width) for the block
        sparse pattern in rank-2 weight tensors.
      block_pooling_type: (optional) The function to use to pool weights in the
        block. Must be 'AVG' or 'MAX'.
      **kwargs: Additional keyword arguments to be passed to the keras layer.
        Ignored when to_prune is not a keras layer.

  Returns:
    Layer or model modified with pruning wrappers.

  Raises:
    ValueError: if the keras layer is unsupported, or the keras model contains
    an unsupported layer.
  """
    def _prune_list(layers, **params):
        wrapped_layers = []

        for layer in layers:
            # Allow layer that is already wrapped by the pruning wrapper
            # to be used as is.
            # No need to wrap the input layer either.
            if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
                wrapped_layers.append(layer)
            elif isinstance(layer, InputLayer):
                # TODO(yunluli): Replace with a clone function in keras.
                wrapped_layers.append(
                    layer.__class__.from_config(layer.get_config()))
            else:
                wrapped_layers.append(
                    pruning_wrapper.PruneLowMagnitude(layer, **params))

        return wrapped_layers

    def _add_pruning_wrapper(layer):
        if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
            return layer
        return pruning_wrapper.PruneLowMagnitude(layer, **params)

    params = {
        'pruning_schedule': pruning_schedule,
        'block_size': block_size,
        'block_pooling_type': block_pooling_type
    }
    is_sequential_or_functional = isinstance(
        to_prune, keras.Model) and (isinstance(to_prune, keras.Sequential)
                                    or to_prune._is_graph_network)

    # A subclassed model is also a subclass of keras.layers.Layer.
    is_keras_layer = isinstance(
        to_prune,
        keras.layers.Layer) and not isinstance(to_prune, keras.Model)

    if isinstance(to_prune, list):
        return _prune_list(to_prune, **params)
    elif is_sequential_or_functional:
        return keras.models.clone_model(to_prune,
                                        input_tensors=None,
                                        clone_function=_add_pruning_wrapper)
    elif is_keras_layer:
        params.update(kwargs)
        return pruning_wrapper.PruneLowMagnitude(to_prune, **params)
    else:
        raise ValueError(
            '`prune_low_magnitude` can only prune an object of the following '
            'types: tf.keras.models.Sequential, tf.keras functional model, '
            'tf.keras.layers.Layer, list of tf.keras.layers.Layer. You passed '
            'an object of type: {input}.'.format(
                input=to_prune.__class__.__name__))
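One practical point the docstring does not cover: the returned model contains the `PruneLowMagnitude` custom layer, so reloading a saved pruned model needs a custom-object scope. A sketch assuming `tfmot.sparsity.keras.prune_scope`; the file path is illustrative.

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation='relu', input_shape=(100,)),
        tf.keras.layers.Dense(2, activation='sigmoid'),
    ]))
pruned_model.compile(optimizer='adam', loss='binary_crossentropy')

# Include the optimizer so pruning can resume correctly after reload.
pruned_model.save('/tmp/pruned_model.h5', include_optimizer=True)

# prune_scope() registers the wrapper as a custom object for deserialization.
with tfmot.sparsity.keras.prune_scope():
    restored = tf.keras.models.load_model('/tmp/pruned_model.h5')
```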
Example #8
    def _add_pruning_wrapper(layer):
        # Leave already-wrapped layers as-is; wrap everything else.
        if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
            return layer
        return pruning_wrapper.PruneLowMagnitude(layer, **params)