Example #1
0
    def _maybe_update_block_mask(self, weights):
        """Performs block-granular masking of the weights.

    If sparsity_m_by_n is selected, then we return the relevant pruning mask,
    that nullify two out of four elements in the block.

    Block pruning occurs only if the block_height or block_width is > 1 and
    if the weight tensor, when squeezed, has ndims = 2. Otherwise, elementwise
    pruning occurs.
    Args:
      weights: The weight tensor that needs to be masked.

    Returns:
      new_threshold: The new value of the threshold based on weights, and
        sparsity at the current global_step. In case of sparsity m_by_n,
        the returned threshold is an arbitrary number.
      new_mask: A numpy array of the same size and shape as weights containing
        0 or 1 to indicate which of the values in weights falls below
        the threshold

    Raises:
      ValueError: if block pooling function is not AVG or MAX
    """
        if self._sparsity_m_by_n:
            mask = self._update_mask_sparsity_m_by_n(weights,
                                                     self._sparsity_m_by_n)
            # We need to return some numbers for threshold.
            return 999.0, mask

        if self._block_size == [1, 1]:
            return self._update_mask(weights)

        # TODO(pulkitb): Check if squeeze operations should now be removed since
        # we are only accepting 2-D weights.

        squeezed_weights = tf.squeeze(weights)
        abs_weights = tf.math.abs(squeezed_weights)
        pooled_weights = pruning_utils.factorized_pool(
            abs_weights,
            window_shape=self._block_size,
            pooling_type=self._block_pooling_type,
            strides=self._block_size,
            padding='SAME')

        if pooled_weights.get_shape().ndims != 2:
            pooled_weights = tf.squeeze(pooled_weights)

        new_threshold, new_mask = self._update_mask(pooled_weights)

        updated_mask = pruning_utils.expand_tensor(new_mask, self._block_size)
        sliced_mask = tf.slice(
            updated_mask, [0, 0],
            [squeezed_weights.get_shape()[0],
             squeezed_weights.get_shape()[1]])
        return new_threshold, tf.reshape(sliced_mask, tf.shape(weights))
 def _compare_pooling_methods(self, weights, pooling_kwargs):
   """Asserts that factorized_pool matches nn_ops.pool on `weights`.

   `weights` is reshaped to `[1, dim0, dim1, 1]` before being handed to
   `nn_ops.pool`, and the pooled result is squeezed again, so that it can be
   compared with the output of `pruning_utils.factorized_pool` applied to
   `weights` directly.

   Args:
     weights: Tensor to pool. Its first two static dimensions are used for
       the reshape, so it is presumably rank 2 (TODO confirm).
     pooling_kwargs: Keyword arguments forwarded unchanged to both
       `nn_ops.pool` and `pruning_utils.factorized_pool`.
   """
   with self.cached_session():
     # Use self.evaluate rather than Operation.run() / Tensor.eval(): the
     # latter need a default session and fail under eager execution.
     self.evaluate(variables.global_variables_initializer())
     pooled_weights_tf = array_ops.squeeze(
         nn_ops.pool(
             array_ops.reshape(
                 weights,
                 [1, weights.get_shape()[0],
                  weights.get_shape()[1], 1]), **pooling_kwargs))
     pooled_weights_factorized_pool = pruning_utils.factorized_pool(
         weights, **pooling_kwargs)
     self.assertAllClose(self.evaluate(pooled_weights_tf),
                         self.evaluate(pooled_weights_factorized_pool))
Example #3
0
 def _compare_pooling_methods(self, weights, pooling_kwargs):
   """Asserts that factorized_pool and tf.nn.pool agree on `weights`.

   The reference result is built by reshaping `weights` to
   `[1, dim0, dim1, 1]`, running `tf.nn.pool` on it and squeezing the output.
   It is then compared with `pruning_utils.factorized_pool` applied to
   `weights` directly.

   Args:
     weights: Tensor to pool. Its first two static dimensions are used for
       the reshape, so it is presumably rank 2 (TODO confirm).
     pooling_kwargs: Keyword arguments forwarded unchanged to both pooling
       implementations.
   """
   with self.cached_session():
     compat.initialize_variables(self)

     # Reference path: wrap the weights in two extra size-1 dimensions,
     # pool, then drop the size-1 dimensions again.
     static_shape = weights.get_shape()
     wrapped_weights = tf.reshape(
         weights, [1, static_shape[0], static_shape[1], 1])
     reference_op = tf.squeeze(tf.nn.pool(wrapped_weights, **pooling_kwargs))

     # Path under test: pool the unwrapped weights directly.
     factorized_op = pruning_utils.factorized_pool(weights, **pooling_kwargs)

     reference_values = self.evaluate(reference_op)
     factorized_values = self.evaluate(factorized_op)
     self.assertAllClose(reference_values, factorized_values)