def _get_mask_assign_ops(self):
    """Build the assignment ops that refresh every mask and its threshold.

    Masks, weights and thresholds are read from `get_masks()`,
    `get_weights()` and `get_thresholds()` and paired up by position.
    For each triple, `self._maybe_update_block_mask(weight, threshold)`
    produces a new threshold and a new mask. Two ops are then appended to
    `self._assign_ops`:
      1. an assignment of the new threshold to the threshold variable;
      2. an assignment of the new mask to the mask variable. This uses
         `pruning_utils.partitioned_variable_assign` when the weight is a
         `variables.PartitionedVariable`, and
         `pruning_utils.variable_assign` otherwise.

    Raises:
      ValueError: if `self._assign_ops` is already non-empty, or if the
        number of masks differs from the number of thresholds or from the
        number of weights.
    """
    # Make sure the assignment ops have not already been added to the list.
    if self._assign_ops:
        raise ValueError(
            'Assign op list not empty. _get_mask_assign_ops() called twice?')

    masks = get_masks()
    weights = get_weights()
    thresholds = get_thresholds()

    if len(masks) != len(thresholds):
        raise ValueError(
            'Number of masks %s and number of thresholds %s mismatch' %
            (len(masks), len(thresholds)))
    # The weights are paired with masks by position as well. Without this
    # check, a short weight list fails mid-loop with an opaque IndexError,
    # and a long one has its extra entries silently ignored.
    if len(masks) != len(weights):
        raise ValueError(
            'Number of masks %s and number of weights %s mismatch' %
            (len(masks), len(weights)))

    # All three lists are now known to have the same length, so zip() pairs
    # every element without truncation.
    for mask, weight, threshold in zip(masks, weights, thresholds):
        is_partitioned = isinstance(weight, variables.PartitionedVariable)
        if is_partitioned:
            # The mask update is computed on the weight as a single tensor.
            weight = weight.as_tensor()

        new_threshold, new_mask = self._maybe_update_block_mask(
            weight, threshold)
        self._assign_ops.append(
            pruning_utils.variable_assign(threshold, new_threshold))

        # A partitioned weight needs the partition-aware assignment helper
        # for its mask.
        if is_partitioned:
            mask_assign_op = pruning_utils.partitioned_variable_assign(
                mask, new_mask)
        else:
            mask_assign_op = pruning_utils.variable_assign(mask, new_mask)
        self._assign_ops.append(mask_assign_op)