def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
    """Assert that expand_tensor agrees with a Kronecker product by ones.

    ``pruning_utils.expand_tensor(tensor, block_dim)`` is compared against
    ``pruning_utils.kronecker_product(tensor, tf.ones(block_dim))``. Both are
    evaluated in one ``session.run`` call inside a cached session and must be
    element-wise equal.

    Args:
      tensor: The tensor to expand.
      block_dim: Block shape passed to ``expand_tensor`` and used as the
        shape of the all-ones right-hand operand of the Kronecker product.
    """
    with self.cached_session() as sess:
        # Initialize any global variables before evaluating the graph.
        tf.global_variables_initializer().run()
        expanded_op = pruning_utils.expand_tensor(tensor, block_dim)
        ones_block = tf.ones(block_dim)
        reference_op = pruning_utils.kronecker_product(tensor, ones_block)
        fetched = sess.run([expanded_op, reference_op])
        self.assertAllEqual(fetched[0], fetched[1])
예제 #2
0
    def _maybe_update_block_mask(self, weights):
        """Performs block-granular masking of the weights.

        If ``self._sparsity_m_by_n`` is set, the mask is produced by
        ``self._update_mask_sparsity_m_by_n`` for the configured m-by-n
        pattern (for example 2:4, which zeroes two out of every four
        elements). In that case the returned threshold is a placeholder.

        Block pruning occurs only if the block size is not [1, 1] and the
        weight tensor, when squeezed, has ndims == 2. Otherwise, elementwise
        pruning occurs via ``self._update_mask(weights)``.

        Args:
          weights: The weight tensor that needs to be masked.

        Returns:
          new_threshold: The new value of the threshold based on weights, and
            sparsity at the current global_step. In case of sparsity m_by_n,
            the returned threshold is the arbitrary constant 999.0.
          new_mask: A mask of the same shape as weights containing 0 or 1 to
            indicate which of the values in weights falls below the
            threshold. In the block path this is a tensor reshaped to
            ``tf.shape(weights)``.

        Raises:
          ValueError: if block pooling function is not AVG or MAX
        """
        if self._sparsity_m_by_n:
            mask = self._update_mask_sparsity_m_by_n(weights,
                                                     self._sparsity_m_by_n)
            # The (threshold, mask) return signature needs a threshold, but
            # this path does not compute one, so return a placeholder.
            return 999.0, mask

        # Normalise so that a list, tuple or array-like block size of 1x1 is
        # recognised as "no blocking".
        if tuple(self._block_size) == (1, 1):
            return self._update_mask(weights)

        # TODO(pulkitb): Check if squeeze operations should now be removed since
        # we are only accepting 2-D weights.
        squeezed_weights = tf.squeeze(weights)
        if squeezed_weights.get_shape().ndims != 2:
            # Squeezing drops unit dimensions (e.g. an (n, 1) kernel becomes
            # rank 1). The block path below (2-D pooling and a 2-D slice) only
            # handles rank-2 tensors, so fall back to elementwise pruning as
            # documented above.
            return self._update_mask(weights)

        abs_weights = tf.math.abs(squeezed_weights)
        pooled_weights = pruning_utils.factorized_pool(
            abs_weights,
            window_shape=self._block_size,
            pooling_type=self._block_pooling_type,
            strides=self._block_size,
            padding='SAME')

        if pooled_weights.get_shape().ndims != 2:
            pooled_weights = tf.squeeze(pooled_weights)

        new_threshold, new_mask = self._update_mask(pooled_weights)

        # Expand the per-block mask back to element granularity. With 'SAME'
        # padding the expanded mask can be larger than the weights when a
        # dimension is not a multiple of the block size, so trim it back.
        updated_mask = pruning_utils.expand_tensor(new_mask, self._block_size)
        weights_shape = squeezed_weights.get_shape()
        sliced_mask = tf.slice(
            updated_mask, [0, 0],
            [weights_shape[0], weights_shape[1]])
        return new_threshold, tf.reshape(sliced_mask, tf.shape(weights))