def apply_mask(x, scope=''):
  """Returns `x` multiplied element-wise by its pruning mask.

  The mask and threshold variables associated with `x` are looked up via
  `pruning_utils`. The first time a given mask is seen, the threshold, mask,
  masked weights and raw weights are each recorded in their respective graph
  collections.

  Args:
    x: Weight tensor to be masked.
    scope: Variable scope in which the mask/threshold variables live.
      Defaults to the empty string.

  Returns:
    The tensor produced by multiplying the mask with `x`.
  """
  mask_var = pruning_utils.weight_mask_variable(x, scope)
  threshold_var = pruning_utils.weight_threshold_variable(x, scope)

  # The product is created in the weights namescope (named
  # _MASKED_WEIGHT_NAME) so that a quantization library can easily locate it
  # and insert quant ops.
  masked = math_ops.multiply(mask_var, x, _MASKED_WEIGHT_NAME)

  # Guard against registering the same mask more than once, which matters
  # when the same weight variables are masked repeatedly (e.g. RNN weights).
  if mask_var in ops.get_collection_ref(_MASK_COLLECTION):
    return masked

  registrations = (
      (_THRESHOLD_COLLECTION, threshold_var),
      (_MASK_COLLECTION, mask_var),
      (_MASKED_WEIGHT_COLLECTION, masked),
      (_WEIGHT_COLLECTION, x),
  )
  for collection_name, item in registrations:
    ops.add_to_collection(collection_name, item)

  return masked