Exemplo n.º 1
0
    def _get_mask_assign_ops(self):
        """Create the ops that assign updated thresholds and masks.

        For each (mask, weight, threshold) triple, taken positionally from
        get_masks(), get_weights() and get_thresholds(), this method calls
        self._maybe_update_block_mask(weight, threshold). It then adds two
        ops to self._assign_ops, in order: one that assigns the new
        threshold, then one that assigns the new mask.

        A weight that is a variables.PartitionedVariable is converted with
        as_tensor() before the update. Its mask is assigned with
        pruning_utils.partitioned_variable_assign. All other masks use
        pruning_utils.variable_assign.

        Raises:
          ValueError: if self._assign_ops is already non-empty, or if the
            number of masks differs from the number of thresholds or of
            weights.
        """
        # Make sure the assignment ops have not already been added to the list
        if self._assign_ops:
            raise ValueError(
                'Assign op list not empty. _get_mask_assign_ops() called twice?'
            )

        masks = get_masks()
        weights = get_weights()
        thresholds = get_thresholds()

        if len(masks) != len(thresholds):
            raise ValueError(
                'Number of masks %s and number of thresholds %s mismatch' %
                (len(masks), len(thresholds)))
        # Weights are paired with masks by position. Validate the count up
        # front: a short list would otherwise raise IndexError partway through
        # the loop, and a long one would be silently truncated.
        if len(masks) != len(weights):
            raise ValueError(
                'Number of masks %s and number of weights %s mismatch' %
                (len(masks), len(weights)))

        # Build the ops locally and publish them only after all of them were
        # created. A failure partway through then leaves self._assign_ops
        # empty, instead of tripping the "called twice" guard above on retry.
        new_ops = []
        for mask, weight, threshold in zip(masks, weights, thresholds):
            is_partitioned = isinstance(weight, variables.PartitionedVariable)
            if is_partitioned:
                weight = weight.as_tensor()

            new_threshold, new_mask = self._maybe_update_block_mask(
                weight, threshold)
            new_ops.append(
                pruning_utils.variable_assign(threshold, new_threshold))

            if is_partitioned:
                new_ops.append(
                    pruning_utils.partitioned_variable_assign(mask, new_mask))
            else:
                new_ops.append(pruning_utils.variable_assign(mask, new_mask))

        self._assign_ops.extend(new_ops)
Exemplo n.º 2
0
  def _get_mask_assign_ops(self):
    """Create the ops that assign updated thresholds and masks.

    For each (mask, weight, threshold) triple, taken positionally from
    get_masks(), get_weights() and get_thresholds(), this method calls
    self._maybe_update_block_mask(weight, threshold). It then adds two ops
    to self._assign_ops, in order: one that assigns the new threshold, then
    one that assigns the new mask.

    A weight that is a variables.PartitionedVariable is converted with
    as_tensor() before the update. Its mask is assigned with
    pruning_utils.partitioned_variable_assign. All other masks use
    pruning_utils.variable_assign.

    Raises:
      ValueError: if self._assign_ops is already non-empty, or if the number
        of masks differs from the number of thresholds or of weights.
    """
    # Make sure the assignment ops have not already been added to the list
    if self._assign_ops:
      raise ValueError(
          'Assign op list not empty. _get_mask_assign_ops() called twice?')

    masks = get_masks()
    weights = get_weights()
    thresholds = get_thresholds()

    if len(masks) != len(thresholds):
      raise ValueError(
          'Number of masks %s and number of thresholds %s mismatch' %
          (len(masks), len(thresholds)))
    # Weights are paired with masks by position. Validate the count up
    # front: a short list would otherwise raise IndexError partway through
    # the loop, and a long one would be silently truncated.
    if len(masks) != len(weights):
      raise ValueError(
          'Number of masks %s and number of weights %s mismatch' %
          (len(masks), len(weights)))

    # Build the ops locally and publish them only after all of them were
    # created. A failure partway through then leaves self._assign_ops empty,
    # instead of tripping the "called twice" guard above on retry.
    new_ops = []
    for mask, weight, threshold in zip(masks, weights, thresholds):
      is_partitioned = isinstance(weight, variables.PartitionedVariable)
      if is_partitioned:
        weight = weight.as_tensor()

      new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold)
      new_ops.append(pruning_utils.variable_assign(threshold, new_threshold))

      if is_partitioned:
        new_ops.append(
            pruning_utils.partitioned_variable_assign(mask, new_mask))
      else:
        new_ops.append(pruning_utils.variable_assign(mask, new_mask))

    self._assign_ops.extend(new_ops)