예제 #1
0
 def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
   """Asserts expand_tensor agrees with a Kronecker product against ones.

   Evaluates `pruning_utils.expand_tensor(tensor, block_dim)` and
   `pruning_utils.kronecker_product(tensor, ones(block_dim))` in one
   session run and checks that the two results are element-wise equal.

   Args:
     tensor: Value handed to both `expand_tensor` and `kronecker_product`.
     block_dim: Block shape; used for the expansion and as the shape of the
       all-ones right-hand operand of the Kronecker product.
   """
   with self.cached_session() as sess:
     # The initializer is run before the ops under test are built; that
     # ordering is kept exactly as it was.
     variables.global_variables_initializer().run()

     expanded = pruning_utils.expand_tensor(tensor, block_dim)
     ones_block = array_ops.ones(block_dim)
     kron = pruning_utils.kronecker_product(tensor, ones_block)

     # Fetch both tensors in a single run so they are evaluated together.
     expanded_val, kron_val = sess.run([expanded, kron])
     self.assertAllEqual(expanded_val, kron_val)
예제 #2
0
 def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
     """Asserts expand_tensor agrees with a Kronecker product against ones.

     Evaluates `pruning_utils.expand_tensor(tensor, block_dim)` and
     `pruning_utils.kronecker_product(tensor, ones(block_dim))` in one
     session run and checks that the two results are element-wise equal.

     Args:
       tensor: Value handed to both `expand_tensor` and `kronecker_product`.
       block_dim: Block shape; used for the expansion and as the shape of
         the all-ones right-hand operand of the Kronecker product.
     """
     # `cached_session()` replaces the deprecated `test_session()` on
     # tf.test.TestCase.
     with self.cached_session() as session:
         # The initializer is run before the ops under test are built; that
         # ordering is kept exactly as it was.
         variables.global_variables_initializer().run()
         expanded_tensor = pruning_utils.expand_tensor(tensor, block_dim)
         kronecker_product = pruning_utils.kronecker_product(
             tensor, array_ops.ones(block_dim))
         # Fetch both tensors in a single run so they are evaluated together.
         expanded_tensor_val, kronecker_product_val = session.run(
             [expanded_tensor, kronecker_product])
         self.assertAllEqual(expanded_tensor_val, kronecker_product_val)
예제 #3
0
    def _maybe_update_block_mask(self, weights, threshold):
        """Computes a block-granular pruning mask for `weights` when applicable.

        If the squeezed weight tensor is not rank 2, or `self._block_dim`
        equals `[1, 1]`, the call is delegated unchanged to
        `self._update_mask(weights, threshold)`. Otherwise the absolute
        squeezed weights are pooled over non-overlapping windows of size
        `self._block_dim`, the mask is computed on the pooled values, expanded
        back by the block dimensions, trimmed to the squeezed weight
        dimensions and reshaped to the shape of `weights`.

        Args:
          weights: The weight tensor that needs to be masked.
          threshold: The current threshold value; forwarded to
            `self._update_mask`.

        Returns:
          new_threshold: The threshold returned by `self._update_mask`.
          new_mask: The mask reshaped to the shape of `weights` (its entries
            come from `self._update_mask`, which is not visible here).

        Raises:
          ValueError: if `self._block_pooling_function` is neither 'AVG' nor
            'MAX' (only checked on the block-pruning path).
        """
        squeezed = array_ops.squeeze(weights)
        is_rank_two = squeezed.get_shape().ndims == 2
        if not is_rank_two or self._block_dim == [1, 1]:
            # Element-wise path: no block structure to exploit.
            return self._update_mask(weights, threshold)

        pooling_type = self._block_pooling_function
        if pooling_type not in ['AVG', 'MAX']:
            raise ValueError(
                'Unknown pooling function for block sparsity: %s' %
                pooling_type)

        with ops.name_scope(weights.op.name + '_pruning_ops'):
            magnitudes = math_ops.abs(squeezed)
            block_shape = [self._block_dim[0], self._block_dim[1]]

            # Default to the factorized pool on the 2-D tensor; when not on
            # TPU, switch to nn_ops.pool, which is fed a 4-D
            # [1, rows, cols, 1] tensor here.
            pool_fn, squeeze_axis = pruning_utils.factorized_pool, None
            if not self._spec.use_tpu:
                pool_fn = nn_ops.pool
                static_shape = magnitudes.get_shape()
                magnitudes = array_ops.reshape(
                    magnitudes, [1, static_shape[0], static_shape[1], 1])
                squeeze_axis = [0, 3]

            # Window and stride are the same list, so blocks do not overlap.
            pool_kwargs = {
                'window_shape': block_shape,
                'pooling_type': pooling_type,
                'strides': block_shape,
                'padding': 'SAME',
                'name': weights.op.name + '_pooled',
            }
            pooled = pool_fn(magnitudes, **pool_kwargs)

            # Drop the extra axes added for nn_ops.pool, if any.
            if pooled.get_shape().ndims != 2:
                pooled = array_ops.squeeze(pooled, axis=squeeze_axis)

            smoothed_threshold, block_mask = self._update_mask(
                pooled, threshold)

            # Blow the per-block mask back up, then trim it to the squeezed
            # weight dimensions (the expanded mask may be larger).
            expanded_mask = pruning_utils.expand_tensor(
                block_mask, self._block_dim)
            rows = squeezed.get_shape()[0]
            cols = squeezed.get_shape()[1]
            trimmed_mask = array_ops.slice(expanded_mask, [0, 0],
                                           [rows, cols])

        # Built outside the name scope, as in the original.
        final_mask = array_ops.reshape(trimmed_mask, array_ops.shape(weights))
        return smoothed_threshold, final_mask
예제 #4
0
  def _maybe_update_block_mask(self, weights, threshold):
    """Builds a pruning mask at block granularity when the weights allow it.

    Delegates unchanged to `self._update_mask(weights, threshold)` when the
    squeezed weights are not rank 2 or when `self._block_dim` equals `[1, 1]`.
    Otherwise pools the absolute squeezed weights over `self._block_dim`
    windows, computes the mask on the pooled tensor, expands it by the block
    dimensions, trims it to the squeezed weight dimensions and reshapes it to
    the shape of `weights`.

    Args:
      weights: The weight tensor that needs to be masked.
      threshold: The current threshold value; passed on to
        `self._update_mask`.

    Returns:
      new_threshold: The threshold returned by `self._update_mask`.
      new_mask: The mask reshaped to the shape of `weights` (its entries come
        from `self._update_mask`, which is not visible here).

    Raises:
      ValueError: if `self._block_pooling_function` is neither 'AVG' nor
        'MAX' (only checked on the block-pruning path).
    """
    weights_2d = array_ops.squeeze(weights)

    # Guard clauses: fall back to element-wise masking.
    if weights_2d.get_shape().ndims != 2:
      return self._update_mask(weights, threshold)
    if self._block_dim == [1, 1]:
      return self._update_mask(weights, threshold)

    if self._block_pooling_function not in ['AVG', 'MAX']:
      message = ('Unknown pooling function for block sparsity: %s' %
                 self._block_pooling_function)
      raise ValueError(message)

    op_name = weights.op.name
    with ops.name_scope(op_name + '_pruning_ops'):
      abs_vals = math_ops.abs(weights_2d)
      window = [self._block_dim[0], self._block_dim[1]]

      # Start from the factorized pool on the 2-D tensor; off TPU, use
      # nn_ops.pool on a [1, rows, cols, 1] view instead.
      pool_fn = pruning_utils.factorized_pool
      axes_to_squeeze = None
      if not self._spec.use_tpu:
        pool_fn = nn_ops.pool
        four_d_shape = [1, abs_vals.get_shape()[0],
                        abs_vals.get_shape()[1], 1]
        abs_vals = array_ops.reshape(abs_vals, four_d_shape)
        axes_to_squeeze = [0, 3]

      # Stride equals the window, so pooling blocks do not overlap.
      pooled = pool_fn(abs_vals,
                       window_shape=window,
                       pooling_type=self._block_pooling_function,
                       strides=window,
                       padding='SAME',
                       name=op_name + '_pooled')

      # Remove the axes added for nn_ops.pool, if present.
      if pooled.get_shape().ndims != 2:
        pooled = array_ops.squeeze(pooled, axis=axes_to_squeeze)

      smoothed_threshold, pooled_mask = self._update_mask(pooled, threshold)

      # Expand the per-block mask, then cut it down to the squeezed weight
      # dimensions (the expanded mask may be larger).
      full_mask = pruning_utils.expand_tensor(pooled_mask, self._block_dim)
      target_size = [weights_2d.get_shape()[0], weights_2d.get_shape()[1]]
      cropped_mask = array_ops.slice(full_mask, [0, 0], target_size)

    # The final reshape is built outside the name scope, as before.
    reshaped_mask = array_ops.reshape(cropped_mask, array_ops.shape(weights))
    return smoothed_threshold, reshaped_mask