예제 #1
0
  def _get_assign_gradient_ops(self):
    """Builds the ops that refresh the stored gradient estimates.

    For each weight returned by `get_weights()`, an L2-normalized absolute
    weight delta is computed from the matching entries of `get_old_weights()`
    and `get_old_old_weights()`:

      * 'first_order_gradient':  |weight - old_weight|
      * 'second_order_gradient': |weight + old_old_weight - 2 * old_weight|

    The delta is blended with the current gradient entry as
    `decay * gradient + (1 - decay) * delta`, where `decay` is
    `self._spec.gradient_decay_rate`, and the op returned by
    `pruning_utils.variable_assign(gradient, new_gradient)` is appended to
    `self._assign_gradient_ops`.

    Raises:
      ValueError: If `self._assign_gradient_ops` is already populated, if the
        collections differ in length, if the tensors for one weight differ in
        shape, or if `self._spec.prune_option` is not one of the two
        gradient-based options.
    """
    # Make sure the assignment ops have not already been added to the list
    if self._assign_gradient_ops:
      raise ValueError(
          'Assign op list not empty. _get_assign_gradient_ops() called twice?')

    weights = get_weights()
    old_weights = get_old_weights()
    old_old_weights = get_old_old_weights()
    gradients = get_gradients()

    if len(weights) != len(old_weights):
      raise ValueError(
          'Number of weights %s and number of old_weights %s mismatch' %
          (len(weights), len(old_weights)))

    if len(weights) != len(gradients):
      raise ValueError(
          'Number of weights %s and number of gradients %s mismatch' %
          (len(weights), len(gradients)))

    # Without this check a short old_old_weights collection would surface as
    # a bare IndexError inside the loop instead of a descriptive error.
    if len(weights) != len(old_old_weights):
      raise ValueError(
          'Number of weights %s and number of old_old_weights %s mismatch' %
          (len(weights), len(old_old_weights)))

    # Both spec values are loop-invariant, so read them once.
    decay = self._spec.gradient_decay_rate
    prune_option = self._spec.prune_option

    # All four collections have equal length here, so zip drops nothing.
    for weight, old_weight, old_old_weight, gradient in zip(
        weights, old_weights, old_old_weights, gradients):
      weight_shape = weight.shape.as_list()

      if weight_shape != old_weight.shape.as_list():
        raise ValueError('weight tensor has different shape from old_weight')

      if weight_shape != gradient.shape.as_list():
        raise ValueError('weight tensor has different shape from gradient')

      if weight_shape != old_old_weight.shape.as_list():
        raise ValueError(
            'weight tensor has different shape from old_old_weight')

      # Partitioned variables are converted with as_tensor() before the
      # arithmetic below; the gradient entry is used as-is.
      if isinstance(weight, variables.PartitionedVariable):
        weight = weight.as_tensor()
        old_weight = old_weight.as_tensor()
        old_old_weight = old_old_weight.as_tensor()

      if prune_option == 'first_order_gradient':
        tf.logging.info('Applying first order gradient pruning')
        normalized_weight_delta = tf.nn.l2_normalize(
            tf.abs(weight - old_weight))
      elif prune_option == 'second_order_gradient':
        tf.logging.info('Applying second order gradient pruning')
        normalized_weight_delta = tf.nn.l2_normalize(
            tf.abs(weight + old_old_weight - 2 * old_weight))
      else:
        raise ValueError('Unknown prune option. Should not execute this code.')
      new_gradient = decay * gradient + (1 - decay) * normalized_weight_delta

      self._assign_gradient_ops.append(
          pruning_utils.variable_assign(gradient, new_gradient))
예제 #2
0
  def _get_mask_assign_ops(self):
    """Fills `self._assign_ops` with threshold and mask assignment ops.

    For every mask from `get_masks()`, the matching threshold and weight are
    passed to `self._maybe_update_block_mask`, together with the matching
    gradient when `self._spec.prune_option` is 'first_order_gradient' or
    'second_order_gradient' (otherwise `None`). Two ops are appended per mask:
    first the threshold assignment, then the mask assignment. When the weight
    is a `variables.PartitionedVariable` it is converted with `as_tensor()`
    and the mask is assigned via `pruning_utils.partitioned_variable_assign`.

    Raises:
      ValueError: If `self._assign_ops` is already populated, or if the
        numbers of masks and thresholds differ.
    """
    # Refuse to build the ops a second time.
    if self._assign_ops:
      raise ValueError(
          'Assign op list not empty. _get_mask_assign_ops() called twice?')

    all_masks = get_masks()
    all_weights = get_weights()
    all_thresholds = get_thresholds()
    all_gradients = get_gradients()

    num_masks, num_thresholds = len(all_masks), len(all_thresholds)
    if num_masks != num_thresholds:
      raise ValueError(
          'Number of masks %s and number of thresholds %s mismatch' %
          (num_masks, num_thresholds))

    gradient_options = ('first_order_gradient', 'second_order_gradient')
    for position, mask in enumerate(all_masks):
      threshold = all_thresholds[position]
      weight = all_weights[position]
      # The gradient collection is only consulted for gradient-based pruning.
      uses_gradient = self._spec.prune_option in gradient_options
      gradient = all_gradients[position] if uses_gradient else None

      partitioned = isinstance(weight, variables.PartitionedVariable)
      if partitioned:
        weight = weight.as_tensor()

      new_threshold, new_mask = self._maybe_update_block_mask(
          weight, threshold, gradient)

      threshold_op = pruning_utils.variable_assign(threshold, new_threshold)
      self._assign_ops.append(threshold_op)

      if partitioned:
        mask_op = pruning_utils.partitioned_variable_assign(mask, new_mask)
      else:
        mask_op = pruning_utils.variable_assign(mask, new_mask)
      self._assign_ops.append(mask_op)
예제 #3
0
    def _get_assign_old_old_weight_ops(self):
        """Builds the ops that copy each old weight into its old-old slot.

        Pairs the entries of `get_old_old_weights()` and `get_old_weights()`
        by position and appends the op returned by
        `pruning_utils.variable_assign(old_old_weight, old_weight)` for each
        pair to `self._assign_old_old_weight_ops`.

        Raises:
          ValueError: If `self._assign_old_old_weight_ops` is already
            populated, or if the two collections differ in length.
        """
        # Make sure the assignment ops have not already been added to the list
        if self._assign_old_old_weight_ops:
            raise ValueError(
                'Assign op list not empty. '
                '_get_assign_old_old_weight_ops() called twice?')

        old_old_weights = get_old_old_weights()
        old_weights = get_old_weights()

        if len(old_old_weights) != len(old_weights):
            raise ValueError(
                'Number of old_old_weights %s and number of old_weights %s mismatch'
                % (len(old_old_weights), len(old_weights)))

        # Lengths were verified equal above, so zip pairs every entry.
        for old_old_weight, old_weight in zip(old_old_weights, old_weights):
            self._assign_old_old_weight_ops.append(
                pruning_utils.variable_assign(old_old_weight, old_weight))