def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
  """Asserts that `expand_tensor` agrees with a Kronecker product of ones.

  Inside the cached test session this builds two ops from the same input:
  `pruning_utils.expand_tensor(tensor, block_dim)` and
  `pruning_utils.kronecker_product(tensor, tf.ones(block_dim))`. It evaluates
  both in a single `run` call and checks that the results are equal.

  Args:
    tensor: The tensor to expand.
    block_dim: Block shape. It is passed to `expand_tensor` and also used as
      the shape of the all-ones Kronecker factor.
  """
  with self.cached_session() as sess:
    # Run the global variable initializer before evaluating anything.
    tf.global_variables_initializer().run()
    via_expand = pruning_utils.expand_tensor(tensor, block_dim)
    ones_block = tf.ones(block_dim)
    via_kron = pruning_utils.kronecker_product(tensor, ones_block)
    expand_val, kron_val = sess.run((via_expand, via_kron))
    self.assertAllEqual(expand_val, kron_val)
def _maybe_update_block_mask(self, weights):
  """Performs block-granular masking of the weights.

  Three cases are handled, in this order:

  1. If `self._sparsity_m_by_n` is set, the mask comes from
     `self._update_mask_sparsity_m_by_n(weights, self._sparsity_m_by_n)`.
     A fixed placeholder threshold is returned with it.
  2. If the block size is 1x1, masking is delegated to `self._update_mask`,
     which works elementwise.
  3. Otherwise:
     - the absolute values of the squeezed weights are pooled over blocks
       of `self._block_size`;
     - a mask is computed on the pooled values;
     - that mask is expanded back to element granularity;
     - the expanded mask is cropped and reshaped to the shape of `weights`.

  Note that no rank check is performed here. The block path assumes that
  `weights`, once squeezed, is 2-D.

  Args:
    weights: The weight tensor that needs to be masked.

  Returns:
    A `(new_threshold, new_mask)` tuple:
      new_threshold: The threshold returned by `self._update_mask`. In the
        m-by-n case this is the arbitrary placeholder 999.0.
      new_mask: The mask for `weights`. In the block case it is reshaped to
        `tf.shape(weights)`.

  Raises:
    ValueError: If the block pooling function is not AVG or MAX. This is not
      checked directly in this method; it arises during pooling.
  """
  if self._sparsity_m_by_n:
    mask = self._update_mask_sparsity_m_by_n(weights, self._sparsity_m_by_n)
    # m-by-n masks are not threshold based. Callers still expect a
    # (threshold, mask) pair, so return a fixed placeholder number.
    placeholder_threshold = 999.0
    return placeholder_threshold, mask

  # list() lets tuple (or array-like) block sizes take the elementwise path
  # too; a tuple (1, 1) never compares equal to the list [1, 1].
  if list(self._block_size) == [1, 1]:
    return self._update_mask(weights)

  # TODO(pulkitb): Check if squeeze operations should now be removed since
  # we are only accepting 2-D weights.
  # NOTE(review): tf.squeeze removes every size-1 axis. A rank-2 weight with
  # a size-1 axis (e.g. shape [n, 1]) therefore becomes rank 1 here, and the
  # [0]/[1] shape indexing below would fail. Confirm whether such weights can
  # reach this path.
  squeezed_weights = tf.squeeze(weights)
  abs_weights = tf.math.abs(squeezed_weights)
  pooled_weights = pruning_utils.factorized_pool(
      abs_weights,
      window_shape=self._block_size,
      pooling_type=self._block_pooling_type,
      strides=self._block_size,
      padding='SAME')

  if pooled_weights.get_shape().ndims != 2:
    pooled_weights = tf.squeeze(pooled_weights)

  new_threshold, new_mask = self._update_mask(pooled_weights)

  # Expand the per-block mask back to element granularity. Then crop it to
  # the squeezed weight shape, because 'SAME' padding can round the number
  # of blocks up.
  updated_mask = pruning_utils.expand_tensor(new_mask, self._block_size)
  squeezed_shape = squeezed_weights.get_shape()
  sliced_mask = tf.slice(
      updated_mask, [0, 0], [squeezed_shape[0], squeezed_shape[1]])
  return new_threshold, tf.reshape(sliced_mask, tf.shape(weights))