def apply_mask(x, scope=''):
  """Masks a weight tensor and registers its pruning state in graph collections.

  Besides the mask and threshold, this unconditionally creates the auxiliary
  variables used by gradient-based pruning (gradient, old weight, old-old
  weight) through the `pruning_utils` helpers.

  Args:
    x: Input weight tensor.
    scope: The current variable scope, forwarded to every `pruning_utils`
      variable helper. Defaults to "".

  Returns:
    Tensor representing masked_weights, i.e. `tf.multiply(mask, x)`.
  """
  mask = pruning_utils.weight_mask_variable(x, scope)
  threshold = pruning_utils.weight_threshold_variable(x, scope)
  # Building the product here places masked_weights in the weights name scope,
  # which makes it easier for the quantization library to add quant ops.
  masked_weights = tf.multiply(mask, x, MASKED_WEIGHT_NAME)

  # Absolute value of gradients for gradient based pruning, plus two earlier
  # weight states (meaning inferred from helper names -- TODO confirm in
  # pruning_utils).
  gradient = pruning_utils.weight_gradient_variable(x, scope)
  old_weight = pruning_utils.old_weight_variable(x, scope)
  old_old_weight = pruning_utils.old_old_weight_variable(x, scope)

  # Guard clause: if this mask is already registered (notably when the same
  # RNN weight variable is masked more than once), do not add it again.
  if mask in tf.get_collection_ref(MASK_COLLECTION):
    return masked_weights

  # Ordered (collection, value) pairs; order matches the collection layout the
  # rest of the pruning code reads back.
  registrations = (
      (THRESHOLD_COLLECTION, threshold),
      (MASK_COLLECTION, mask),
      (MASKED_WEIGHT_COLLECTION, masked_weights),
      (WEIGHT_COLLECTION, x),
      (WEIGHT_GRADIENT_COLLECTION, gradient),
      (OLD_WEIGHT_COLLECTION, old_weight),
      (OLD_OLD_WEIGHT_COLLECTION, old_old_weight),
  )
  for collection_name, value in registrations:
    tf.add_to_collection(collection_name, value)
  return masked_weights
def apply_mask_and_return(x, scope='', prune_option='weight'):
  """Masks a weight tensor and returns both the masked weights and the mask.

  The gradient-pruning auxiliary variables (gradient, old weight, old-old
  weight) are only created and registered when `prune_option` selects a
  gradient-based criterion.

  Args:
    x: Input weight tensor.
    scope: The current variable scope, forwarded to every `pruning_utils`
      variable helper. Defaults to "".
    prune_option: Pruning criterion. Defaults to 'weight'.
      'first_order_gradient' prunes by |weight| * |first order gradient|;
      'second_order_gradient' prunes by |weight| * |second order gradient|.
      Any other value creates no gradient-related variables.

  Returns:
    A two-element list `[masked_weights, mask]`:
      masked_weights: a TensorFlow tensor representing masked weights.
      mask: a TensorFlow tensor representing the pruning mask.
  """
  gradient_based = prune_option in ('first_order_gradient',
                                    'second_order_gradient')

  mask = pruning_utils.weight_mask_variable(x, scope)
  threshold = pruning_utils.weight_threshold_variable(x, scope)
  # Building the product here places masked_weights in the weights name scope,
  # which makes it easier for the quantization library to add quant ops.
  masked_weights = tf.multiply(mask, x, MASKED_WEIGHT_NAME)

  if gradient_based:
    # Absolute value of gradients for gradient based pruning, plus two earlier
    # weight states (meaning inferred from helper names -- TODO confirm in
    # pruning_utils).
    gradient = pruning_utils.weight_gradient_variable(x, scope)
    old_weight = pruning_utils.old_weight_variable(x, scope)
    old_old_weight = pruning_utils.old_old_weight_variable(x, scope)

  # Guard clause: if this mask is already registered (notably when the same
  # RNN weight variable is masked more than once), do not add it again.
  if mask in tf.get_collection_ref(MASK_COLLECTION):
    return [masked_weights, mask]

  registrations = [
      (THRESHOLD_COLLECTION, threshold),
      (MASK_COLLECTION, mask),
      (MASKED_WEIGHT_COLLECTION, masked_weights),
      (WEIGHT_COLLECTION, x),
  ]
  if gradient_based:
    registrations += [
        (WEIGHT_GRADIENT_COLLECTION, gradient),
        (OLD_WEIGHT_COLLECTION, old_weight),
        (OLD_OLD_WEIGHT_COLLECTION, old_old_weight),
    ]
  for collection_name, value in registrations:
    tf.add_to_collection(collection_name, value)
  return [masked_weights, mask]