def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
  """Asserts that expand_tensor agrees with a Kronecker product of ones.

  Builds `pruning_utils.expand_tensor(tensor, block_dim)` and
  `pruning_utils.kronecker_product(tensor, array_ops.ones(block_dim))`,
  evaluates both in a single `run` call on the cached test session and
  checks that the two results are element-wise equal.

  Args:
    tensor: The tensor to expand.
    block_dim: Block shape; passed to expand_tensor and also used as the
      shape of the all-ones operand of the Kronecker product.
  """
  with self.cached_session() as sess:
    # Initialize any graph variables before evaluating the fetches.
    variables.global_variables_initializer().run()
    # List order keeps op construction order: expand_tensor first, then the
    # ones block and the Kronecker product.
    fetches = [
        pruning_utils.expand_tensor(tensor, block_dim),
        pruning_utils.kronecker_product(tensor, array_ops.ones(block_dim)),
    ]
    actual, expected = sess.run(fetches)
    self.assertAllEqual(actual, expected)
 def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
     with self.test_session() as session:
         variables.global_variables_initializer().run()
         expanded_tensor = pruning_utils.expand_tensor(tensor, block_dim)
         kronecker_product = pruning_utils.kronecker_product(
             tensor, array_ops.ones(block_dim))
         expanded_tensor_val, kronecker_product_val = session.run(
             [expanded_tensor, kronecker_product])
         self.assertAllEqual(expanded_tensor_val, kronecker_product_val)
Example #3
0
  def _maybe_update_block_mask(self, weights, threshold):
    """Performs block-granular masking of the weights.

    Block pruning occurs only if the block_height or block_width is > 1 and
    if the weight tensor, when squeezed, has ndims = 2. Otherwise, elementwise
    pruning occurs via `self._update_mask(weights, threshold)`.

    Args:
      weights: The weight tensor that needs to be masked.
      threshold: The current threshold value. It is forwarded to
        `self._update_mask`, which computes the new threshold (documented
        upstream as an exponential moving average using this value).

    Returns:
      new_threshold: The threshold returned by `self._update_mask` for the
        block-pooled absolute weights.
      new_mask: A tensor reshaped to `array_ops.shape(weights)`. It holds the
        per-block mask from `self._update_mask`, expanded to element
        granularity by a Kronecker product with a ones block and cropped to
        the squeezed weight shape.

    Raises:
      ValueError: if block pooling function is not AVG or MAX
    """
    squeezed_weights = array_ops.squeeze(weights)
    # list() lets a tuple block_dim such as (1, 1) also select elementwise
    # pruning; a tuple compared directly with [1, 1] is always False.
    if (squeezed_weights.get_shape().ndims != 2 or
        list(self._block_dim) == [1, 1]):
      return self._update_mask(weights, threshold)

    if self._block_pooling_function not in ['AVG', 'MAX']:
      raise ValueError('Unknown pooling function for block sparsity: %s' %
                       self._block_pooling_function)

    with ops.name_scope(weights.op.name + '_pruning_ops'):
      abs_weights = math_ops.abs(squeezed_weights)

      pool_window = [self._block_dim[0], self._block_dim[1]]
      pool_fn = pruning_utils.factorized_pool
      # Axes to drop from a non-2-D pooled result. None keeps the plain
      # "squeeze every size-1 axis" behavior for the TPU path.
      squeeze_axis = None

      if not self._spec.use_tpu:
        pool_fn = nn_ops.pool
        # Lay the 2-D weights out as a [1, height, width, 1] input for
        # nn_ops.pool; the pooled result keeps those extra leading/trailing
        # axes.
        abs_weights = array_ops.reshape(
            abs_weights,
            [1, abs_weights.get_shape()[0],
             abs_weights.get_shape()[1], 1])
        squeeze_axis = [0, 3]

      pooled_weights = pool_fn(
          abs_weights,
          window_shape=pool_window,
          pooling_type=self._block_pooling_function,
          strides=pool_window,
          padding='SAME',
          name=weights.op.name + '_pooled')

      if pooled_weights.get_shape().ndims != 2:
        # Remove only the axes added by the reshape above. A bare squeeze()
        # would also remove a pooled height or width of 1 (a weight
        # dimension no larger than the block), leaving a rank-1 tensor that
        # the 2-D Kronecker expansion below cannot handle.
        pooled_weights = array_ops.squeeze(pooled_weights, axis=squeeze_axis)

      smoothed_threshold, new_mask = self._update_mask(pooled_weights,
                                                       threshold)
      # Expand each per-block mask value over a full block, then crop the
      # overhang from 'SAME' pooling (stride == window) back to the weight
      # shape.
      updated_mask = pruning_utils.kronecker_product(
          new_mask, array_ops.ones(self._block_dim))
      sliced_mask = array_ops.slice(
          updated_mask, [0, 0],
          [squeezed_weights.get_shape()[0],
           squeezed_weights.get_shape()[1]])

    return smoothed_threshold, array_ops.reshape(sliced_mask,
                                                 array_ops.shape(weights))
Example #4
0
  def _maybe_update_block_mask(self, weights, threshold):
    """Performs block-granular masking of the weights.

    Block pruning occurs only if the block_height or block_width is > 1 and
    if the weight tensor, when squeezed, has ndims = 2. Otherwise, elementwise
    pruning occurs via `self._update_mask(weights, threshold)`.

    Args:
      weights: The weight tensor that needs to be masked.
      threshold: The current threshold value. It is forwarded to
        `self._update_mask`, which computes the new threshold (documented
        upstream as an exponential moving average using this value).

    Returns:
      new_threshold: The threshold returned by `self._update_mask` for the
        block-pooled absolute weights.
      new_mask: A tensor reshaped to `array_ops.shape(weights)`. It holds the
        per-block mask from `self._update_mask`, expanded to element
        granularity by a Kronecker product with a ones block and cropped to
        the squeezed weight shape.

    Raises:
      ValueError: if block pooling function is not AVG or MAX
    """
    squeezed_weights = array_ops.squeeze(weights)
    # list() lets a tuple block_dim such as (1, 1) also select elementwise
    # pruning; a tuple compared directly with [1, 1] is always False.
    if (squeezed_weights.get_shape().ndims != 2 or
        list(self._block_dim) == [1, 1]):
      return self._update_mask(weights, threshold)

    if self._block_pooling_function not in ['AVG', 'MAX']:
      raise ValueError('Unknown pooling function for block sparsity: %s' %
                       self._block_pooling_function)

    with ops.name_scope(weights.op.name + '_pruning_ops'):
      abs_weights = math_ops.abs(squeezed_weights)

      pool_window = [self._block_dim[0], self._block_dim[1]]
      pool_fn = pruning_utils.factorized_pool
      # Axes to drop from a non-2-D pooled result. None keeps the plain
      # "squeeze every size-1 axis" behavior for the TPU path.
      squeeze_axis = None

      if not self._spec.use_tpu:
        pool_fn = nn_ops.pool
        # Lay the 2-D weights out as a [1, height, width, 1] input for
        # nn_ops.pool; the pooled result keeps those extra leading/trailing
        # axes.
        abs_weights = array_ops.reshape(
            abs_weights,
            [1, abs_weights.get_shape()[0],
             abs_weights.get_shape()[1], 1])
        squeeze_axis = [0, 3]

      pooled_weights = pool_fn(
          abs_weights,
          window_shape=pool_window,
          pooling_type=self._block_pooling_function,
          strides=pool_window,
          padding='SAME',
          name=weights.op.name + '_pooled')

      if pooled_weights.get_shape().ndims != 2:
        # Remove only the axes added by the reshape above. A bare squeeze()
        # would also remove a pooled height or width of 1 (a weight
        # dimension no larger than the block), leaving a rank-1 tensor that
        # the 2-D Kronecker expansion below cannot handle.
        pooled_weights = array_ops.squeeze(pooled_weights, axis=squeeze_axis)

      smoothed_threshold, new_mask = self._update_mask(pooled_weights,
                                                       threshold)
      # Expand each per-block mask value over a full block, then crop the
      # overhang from 'SAME' pooling (stride == window) back to the weight
      # shape.
      updated_mask = pruning_utils.kronecker_product(
          new_mask, array_ops.ones(self._block_dim))
      sliced_mask = array_ops.slice(
          updated_mask, [0, 0],
          [squeezed_weights.get_shape()[0],
           squeezed_weights.get_shape()[1]])

    return smoothed_threshold, array_ops.reshape(sliced_mask,
                                                 array_ops.shape(weights))