コード例 #1
0
  def _check_2x4_sparsity(self, model):
    """Asserts that eligible pruned weights of `model` follow a 2x4 pattern.

    A weight is considered 2x4 sparse when, along its second-to-last axis,
    every run of 4 consecutive values contains at most 2 non-zeros. (For
    Keras Dense/Conv kernels that axis is presumably the input-channel axis;
    confirm against the pruning config.) Only weights whose second-to-last
    dimension is a multiple of 4 are checked; all others are skipped.

    Args:
      model: The model whose prunable layers, as returned by
        `img_cls_task.collect_prunable_layers`, are inspected.
    """

    def _is_pruned_2_by_4(weights):
      """Returns True if every group of 4 along axis -2 has <= 2 non-zeros.

      The last axis is moved to the front and all remaining leading axes are
      flattened, giving a 2-D matrix whose rows run along the original axis
      -2. For rank 2 this is a plain transpose; for rank 4 it is the
      [3, 0, 1, 2] permutation followed by a reshape. A trailing partial
      group (axis length not a multiple of 4) is checked as well.

      Args:
        weights: A weight tensor/variable of rank >= 2 with eager values.

      Returns:
        A Python bool.

      Raises:
        ValueError: If `weights` has unknown rank or rank < 2.
      """
      rank = weights.shape.rank
      if rank is None or rank < 2:
        raise ValueError(
            'Expected a weight of rank >= 2, got shape %s' % (weights.shape,))
      # Move the last axis first: [1, 0] for rank 2, [3, 0, 1, 2] for rank 4.
      perm = [rank - 1] + list(range(rank - 1))
      perm_weights = tf.transpose(weights, perm=perm)
      matrix = tf.reshape(perm_weights, [-1, perm_weights.shape[-1]]).numpy()

      num_rows, num_cols = matrix.shape
      # Zero-pad the columns up to a multiple of 4 so a trailing partial
      # group is counted the same way; padding zeros add no non-zeros.
      pad = -num_cols % 4
      if pad:
        matrix = np.pad(matrix, ((0, 0), (0, pad)), mode='constant')
      groups = matrix.reshape(num_rows, (num_cols + pad) // 4, 4)
      return bool(np.all(np.count_nonzero(groups, axis=-1) <= 2))

    prunable_layers = img_cls_task.collect_prunable_layers(model)
    for layer in prunable_layers:
      for weight, _, _ in layer.pruning_vars:
        # Only weights whose axis -2 length is a multiple of 4 are checked
        # (presumably the only ones 2x4 pruning applies to; confirm).
        if weight.shape[-2] % 4 != 0:
          continue
        self.assertTrue(
            _is_pruned_2_by_4(weight),
            msg='%s is not 2x4 sparse' % weight.name)
コード例 #2
0
  def _validate_model_pruned(self, model, config_name):
    """Checks the number of weights registered for pruning in `model`.

    Collects the name of every weight listed in the `pruning_vars` of each
    layer returned by `img_cls_task.collect_prunable_layers`. It then asserts
    that the total matches the known count for `config_name`. Config names
    without a known count are not checked.

    Args:
      model: The model to inspect.
      config_name: Experiment config name. Counts are known for
        'resnet_imagenet_pruning' and 'mobilenet_imagenet_pruning'.
    """
    names = [
        weight.name
        for layer in img_cls_task.collect_prunable_layers(model)
        for weight, _, _ in layer.pruning_vars
    ]

    expected_counts = {
        # Conv2D : 1
        # BottleneckBlockGroup : 4+3+3 = 10
        # BottleneckBlockGroup1 : 4+3+3+3 = 13
        # BottleneckBlockGroup2 : 4+3+3+3+3+3 = 19
        # BottleneckBlockGroup3 : 4+3+3 = 10
        # FullyConnected : 1
        # Total : 54
        'resnet_imagenet_pruning': 54,
        # Conv2DBN = 1
        # InvertedBottleneckBlockGroup = 2
        # InvertedBottleneckBlockGroup1~16 = 48
        # Conv2DBN = 1
        # FullyConnected : 1
        # Total : 53
        'mobilenet_imagenet_pruning': 53,
    }
    expected = expected_counts.get(config_name)
    if expected is not None:
      self.assertLen(names, expected)