def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
  """Asserts expand_tensor matches a Kronecker product with an all-ones block.

  Builds `pruning_utils.expand_tensor(tensor, block_dim)` and
  `pruning_utils.kronecker_product(tensor, ones(block_dim))`, evaluates both
  in a single session run, and checks that the results are equal.

  Args:
    tensor: Input handed to both `expand_tensor` and `kronecker_product`.
    block_dim: Passed to `expand_tensor` and used as the shape of the
      all-ones operand of the Kronecker product.
  """
  with self.cached_session() as sess:
    variables.global_variables_initializer().run()

    via_expand = pruning_utils.expand_tensor(tensor, block_dim)
    ones_block = array_ops.ones(block_dim)
    via_kronecker = pruning_utils.kronecker_product(tensor, ones_block)

    # Fetch both tensors in one run so they are evaluated together.
    fetched = sess.run([via_expand, via_kronecker])
    expand_val, kronecker_val = fetched
    self.assertAllEqual(expand_val, kronecker_val)
def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
  """Checks that expand_tensor equals a Kronecker product with ones.

  Both `pruning_utils.expand_tensor(tensor, block_dim)` and
  `pruning_utils.kronecker_product(tensor, ones(block_dim))` are built and
  evaluated in the same session run, then compared with `assertAllEqual`.

  Args:
    tensor: Input handed to both `expand_tensor` and `kronecker_product`.
    block_dim: Passed to `expand_tensor` and used as the shape of the
      all-ones operand of the Kronecker product.
  """
  # `test_session()` is deprecated on tf.test.TestCase; `cached_session()` is
  # the documented replacement.
  with self.cached_session() as session:
    variables.global_variables_initializer().run()
    expanded_tensor = pruning_utils.expand_tensor(tensor, block_dim)
    kronecker_product = pruning_utils.kronecker_product(
        tensor, array_ops.ones(block_dim))
    expanded_tensor_val, kronecker_product_val = session.run(
        [expanded_tensor, kronecker_product])
    self.assertAllEqual(expanded_tensor_val, kronecker_product_val)
def _maybe_update_block_mask(self, weights, threshold):
  """Computes a block-granular mask for `weights` when block pruning applies.

  Block pruning is used only when the squeezed weight tensor has rank 2 and
  the configured block dimensions are not [1, 1]. In every other case the
  call is forwarded unchanged to `self._update_mask`, i.e. pruning is
  elementwise.

  On the block path, the absolute weights are pooled (AVG or MAX) over
  non-overlapping block_dim windows with 'SAME' padding, the pooled tensor is
  handed to `self._update_mask`, and the resulting mask is expanded by
  block_dim, cropped to the squeezed weight shape, and reshaped to the shape
  of `weights`.

  Args:
    weights: The weight tensor that needs to be masked.
    threshold: The current threshold value, forwarded to `self._update_mask`
      together with the (pooled) weights.

  Returns:
    A pair `(new_threshold, new_mask)`. On the elementwise path this is
    exactly what `self._update_mask(weights, threshold)` returns. On the block
    path `new_threshold` is the threshold returned by `self._update_mask` for
    the pooled weights, and `new_mask` is the block mask reshaped to the shape
    of `weights`.

  Raises:
    ValueError: if block pooling function is not AVG or MAX
  """
  weights_2d = array_ops.squeeze(weights)
  use_block_pruning = (
      weights_2d.get_shape().ndims == 2 and
      not (self._block_dim == [1, 1]))
  if not use_block_pruning:
    return self._update_mask(weights, threshold)

  pooling_type = self._block_pooling_function
  if pooling_type not in ('AVG', 'MAX'):
    raise ValueError(
        'Unknown pooling function for block sparsity: %s' % pooling_type)

  op_name = weights.op.name
  with ops.name_scope(op_name + '_pruning_ops'):
    magnitudes = math_ops.abs(weights_2d)
    # The same window is used as the stride, so blocks do not overlap.
    window = [self._block_dim[0], self._block_dim[1]]

    if self._spec.use_tpu:
      pool_fn = pruning_utils.factorized_pool
      squeeze_axis = None
    else:
      pool_fn = nn_ops.pool
      # Wrap the 2-D magnitudes in leading/trailing singleton dimensions for
      # nn_ops.pool; those two axes are squeezed away again below.
      shape_2d = magnitudes.get_shape()
      magnitudes = array_ops.reshape(
          magnitudes, [1, shape_2d[0], shape_2d[1], 1])
      squeeze_axis = [0, 3]

    pooled = pool_fn(
        magnitudes,
        window_shape=window,
        pooling_type=pooling_type,
        strides=window,
        padding='SAME',
        name=op_name + '_pooled')
    if pooled.get_shape().ndims != 2:
      pooled = array_ops.squeeze(pooled, axis=squeeze_axis)

    smoothed_threshold, block_mask = self._update_mask(pooled, threshold)

    # Blow the per-block mask back up, then crop to the squeezed weight shape
    # (with 'SAME' padding the expanded mask may be larger than the weights).
    expanded_mask = pruning_utils.expand_tensor(block_mask, self._block_dim)
    target_shape = weights_2d.get_shape()
    cropped_mask = array_ops.slice(
        expanded_mask, [0, 0], [target_shape[0], target_shape[1]])

  full_mask = array_ops.reshape(cropped_mask, array_ops.shape(weights))
  return smoothed_threshold, full_mask
def _maybe_update_block_mask(self, weights, threshold):
  """Masks `weights` at block granularity, falling back to elementwise.

  The block path is taken only when the squeezed weights have rank 2 and the
  block dimensions differ from [1, 1]; otherwise the result of
  `self._update_mask(weights, threshold)` is returned directly.

  For block pruning, absolute weight values are pooled with the configured
  pooling function (AVG or MAX) using block_dim as both window and stride
  ('SAME' padding). The pooled tensor is passed to `self._update_mask`; the
  mask it returns is expanded by block_dim, sliced down to the squeezed weight
  shape, and finally reshaped to match `weights`.

  Args:
    weights: The weight tensor that needs to be masked.
    threshold: The current threshold value, forwarded to `self._update_mask`.

  Returns:
    A `(new_threshold, new_mask)` pair. In the elementwise case it is the
    return value of `self._update_mask(weights, threshold)`. In the block case
    `new_threshold` comes from `self._update_mask` applied to the pooled
    weights and `new_mask` has the same shape as `weights`.

  Raises:
    ValueError: if block pooling function is not AVG or MAX
  """
  squeezed = array_ops.squeeze(weights)
  is_rank_two = squeezed.get_shape().ndims == 2
  if not is_rank_two or self._block_dim == [1, 1]:
    # Nothing to pool over: prune individual elements instead.
    return self._update_mask(weights, threshold)

  pooling_kind = self._block_pooling_function
  if pooling_kind not in ('AVG', 'MAX'):
    message = 'Unknown pooling function for block sparsity: %s' % pooling_kind
    raise ValueError(message)

  base_name = weights.op.name
  with ops.name_scope(base_name + '_pruning_ops'):
    pool_input = math_ops.abs(squeezed)
    block_window = [self._block_dim[0], self._block_dim[1]]

    if self._spec.use_tpu:
      pooler = pruning_utils.factorized_pool
      axes_to_squeeze = None
    else:
      pooler = nn_ops.pool
      # nn_ops.pool is fed a 4-D tensor here; the added singleton axes
      # (0 and 3) are removed after pooling.
      input_shape = pool_input.get_shape()
      four_d_shape = [1, input_shape[0], input_shape[1], 1]
      pool_input = array_ops.reshape(pool_input, four_d_shape)
      axes_to_squeeze = [0, 3]

    pooled_weights = pooler(
        pool_input,
        window_shape=block_window,
        pooling_type=pooling_kind,
        strides=block_window,
        padding='SAME',
        name=base_name + '_pooled')
    if pooled_weights.get_shape().ndims != 2:
      pooled_weights = array_ops.squeeze(pooled_weights, axis=axes_to_squeeze)

    update_result = self._update_mask(pooled_weights, threshold)
    smoothed_threshold, pooled_mask = update_result

    # Expand each block decision to cover its block, then trim any overhang
    # so the mask lines up with the squeezed weights.
    block_mask = pruning_utils.expand_tensor(pooled_mask, self._block_dim)
    squeezed_shape = squeezed.get_shape()
    slice_size = [squeezed_shape[0], squeezed_shape[1]]
    trimmed_mask = array_ops.slice(block_mask, [0, 0], slice_size)

  reshaped_mask = array_ops.reshape(trimmed_mask, array_ops.shape(weights))
  return smoothed_threshold, reshaped_mask