def _get_assign_gradient_ops(self):
  """Build ops that blend normalized weight deltas into the stored gradients.

  For 'first_order_gradient' pruning the delta is |w - w_old|; for
  'second_order_gradient' it is |w + w_old_old - 2 * w_old| (a finite
  difference of successive weight snapshots). The delta is l2-normalized
  and combined with the stored gradient via an exponential moving average
  controlled by `self._spec.gradient_decay_rate`. The resulting assign ops
  are appended to `self._assign_gradient_ops`.

  Raises:
    ValueError: If the op list is already populated (called twice), if the
      weight/old_weight/old_old_weight/gradient list lengths mismatch, if
      corresponding tensor shapes mismatch, or if the prune option is not
      a gradient-based one.
  """
  # Make sure the assignment ops have not already been added to the list.
  if self._assign_gradient_ops:
    # Bug fix: the original message named _get_mask_assign_ops().
    raise ValueError(
        'Assign op list not empty. _get_assign_gradient_ops() called twice?')

  weights = get_weights()
  old_weights = get_old_weights()
  old_old_weights = get_old_old_weights()
  gradients = get_gradients()

  if len(weights) != len(old_weights):
    raise ValueError(
        'Number of weights %s and number of old_weights %s mismatch' %
        (len(weights), len(old_weights)))
  # Robustness: old_old_weights is indexed below, so check its length too.
  if len(weights) != len(old_old_weights):
    raise ValueError(
        'Number of weights %s and number of old_old_weights %s mismatch' %
        (len(weights), len(old_old_weights)))
  if len(weights) != len(gradients):
    raise ValueError(
        'Number of weights %s and number of gradients %s mismatch' %
        (len(weights), len(gradients)))

  # Loop-invariant: read the decay rate once.
  decay = self._spec.gradient_decay_rate

  for index, weight in enumerate(weights):
    old_weight = old_weights[index]
    old_old_weight = old_old_weights[index]
    gradient = gradients[index]

    if weight.shape.as_list() != old_weight.shape.as_list():
      raise ValueError('weight tensor has different shape from old_weight')
    if weight.shape.as_list() != gradient.shape.as_list():
      raise ValueError('weight tensor has different shape from gradient')
    # Bug fix: this branch checks old_old_weight, but the original
    # message said 'old_weight'.
    if weight.shape.as_list() != old_old_weight.shape.as_list():
      raise ValueError(
          'weight tensor has different shape from old_old_weight')

    # Partitioned variables must be materialized as a single tensor
    # before any elementwise arithmetic below.
    if isinstance(weight, variables.PartitionedVariable):
      weight = weight.as_tensor()
      old_weight = old_weight.as_tensor()
      old_old_weight = old_old_weight.as_tensor()

    if self._spec.prune_option == 'first_order_gradient':
      tf.logging.info('Applying first order gradient pruning')
      normalized_weight_delta = tf.nn.l2_normalize(
          tf.abs(weight - old_weight))
    elif self._spec.prune_option == 'second_order_gradient':
      tf.logging.info('Applying second order gradient pruning')
      normalized_weight_delta = tf.nn.l2_normalize(
          tf.abs(weight + old_old_weight - 2 * old_weight))
    else:
      # Bug fix: this string literal was broken across a physical line in
      # the original, which is a syntax error.
      raise ValueError('Unknown prune option. Should not execute this code.')

    # Exponential moving average of the normalized delta into the gradient.
    new_gradient = decay * gradient + (1 - decay) * normalized_weight_delta

    self._assign_gradient_ops.append(
        pruning_utils.variable_assign(gradient, new_gradient))
def _get_mask_assign_ops(self):
  """Build ops that refresh each mask and its associated threshold.

  For every mask, recomputes (threshold, mask) via
  `self._maybe_update_block_mask` and appends the corresponding assign ops
  to `self._assign_ops`. Gradients are only consulted when the prune
  option is gradient-based.

  Raises:
    ValueError: If the op list is already populated (called twice), or if
      the mask and threshold list lengths mismatch.
  """
  # Guard against building the assignment ops more than once.
  if self._assign_ops:
    raise ValueError(
        'Assign op list not empty. _get_mask_assign_ops() called twice?')

  masks = get_masks()
  weights = get_weights()
  thresholds = get_thresholds()
  gradients = get_gradients()

  if len(masks) != len(thresholds):
    raise ValueError(
        'Number of masks %s and number of thresholds %s mismatch' %
        (len(masks), len(thresholds)))

  # Loop-invariant: gradient-based pruning is decided by the spec once.
  needs_gradient = self._spec.prune_option in ('first_order_gradient',
                                               'second_order_gradient')

  for index, mask in enumerate(masks):
    threshold = thresholds[index]
    weight = weights[index]
    gradient = gradients[index] if needs_gradient else None

    # Partitioned variables are flattened to one tensor for masking, but
    # the assignment back must go through the partitioned path.
    is_partitioned = isinstance(weight, variables.PartitionedVariable)
    if is_partitioned:
      weight = weight.as_tensor()

    new_threshold, new_mask = self._maybe_update_block_mask(
        weight, threshold, gradient)

    self._assign_ops.append(
        pruning_utils.variable_assign(threshold, new_threshold))
    if is_partitioned:
      self._assign_ops.append(
          pruning_utils.partitioned_variable_assign(mask, new_mask))
    else:
      self._assign_ops.append(
          pruning_utils.variable_assign(mask, new_mask))
def _get_assign_old_old_weight_ops(self):
  """Build ops copying each old_weight into its old_old_weight slot.

  Appends one assign op per (old_old_weight, old_weight) pair to
  `self._assign_old_old_weight_ops`, shifting the weight history back by
  one snapshot.

  Raises:
    ValueError: If the op list is already populated (called twice), or if
      the two weight-list lengths mismatch.
  """
  # Guard against building the assignment ops more than once.
  if self._assign_old_old_weight_ops:
    raise ValueError(
        'Assign op list not empty. _get_old_old_weight_assign_ops() called twice?'
    )

  old_old_weights = get_old_old_weights()
  old_weights = get_old_weights()

  if len(old_old_weights) != len(old_weights):
    raise ValueError(
        'Number of old_old_weights %s and number of old_weights %s mismatch' %
        (len(old_old_weights), len(old_weights)))

  for old_old_weight, old_weight in zip(old_old_weights, old_weights):
    self._assign_old_old_weight_ops.append(
        pruning_utils.variable_assign(old_old_weight, old_weight))