def prune_bn2d(module: BatchNorm2d, keep_idxes):
    """Prune a BatchNorm2d in place, keeping only the channels listed in keep_idxes."""
    module.num_features = len(keep_idxes)
    module.weight.data = module.weight.data[keep_idxes]
    module.weight.grad = None
    module.bias.data = module.bias.data[keep_idxes]
    module.bias.grad = None
    module.running_mean = module.running_mean[keep_idxes]
    module.running_var = module.running_var[keep_idxes]
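
# Illustrative usage sketch (not part of the original code): prune a BatchNorm2d
# down to a subset of its channels with prune_bn2d and check the resulting shapes.
# The layer size and indices below are arbitrary examples.
def _example_prune_bn2d():
    bn = nn.BatchNorm2d(8)
    keep_idxes = [0, 2, 5]  # channels to keep
    prune_bn2d(bn, keep_idxes)
    assert bn.num_features == 3
    assert bn.weight.shape[0] == 3
    assert bn.running_mean.shape[0] == 3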
def prune_conv_layer(conv_layer: Union[nn.Conv2d, nn.Linear],
                     bn_layer: nn.BatchNorm2d,
                     sparse_layer: Union[nn.BatchNorm2d, SparseGate],
                     in_channel_mask: np.ndarray,
                     prune_output_mode: str,
                     pruner: Callable[[np.ndarray], float],
                     prune_mode: str,
                     sparse_layer_in: typing.Optional[SparseGate] = None,
                     prune_on="factor") -> typing.Tuple[np.ndarray, np.ndarray]:
    """
    Note: if the sparse_layer is a SparseGate, the gate will be replaced by the
    BatchNorm scaling factors. The value of the gate will be set to all ones.

    :param prune_output_mode: how to handle the output channels (case insensitive)
        "keep": keep the output channels intact
        "same": keep the output channels the same as the input channels
        "prune": prune the output channels according to bn_layer
    :param pruner: the method to determine the pruning threshold.
    :param in_channel_mask: a 0-1 vector indicating whether the corresponding channel
        should be pruned (0) or kept (1)
    :param bn_layer: the BatchNorm layer after the convolution layer
    :param conv_layer: the convolution layer to be pruned
    :param sparse_layer: the layer to determine the sparsity. Supports BatchNorm2d and SparseGate.
    :param prune_mode: pruning mode (`str`), case-insensitive:
        - `"multiply"`: the pruning threshold is determined by the multiplication of
          `sparse_layer` and `bn_layer`; only available when `sparse_layer` is `SparseGate`
        - `None` or `"default"`: default behaviour; the pruning threshold is determined by `sparse_layer`
    :param prune_on: 'factor' or 'weight'.
    :param sparse_layer_in: the layer to determine the input sparsity. Supports BatchNorm2d and SparseGate.
        When `sparse_layer_in` is None, there is no input sparse layer and the input channels
        are determined by `in_channel_mask`.
        Note: `in_channel_mask` CONFLICTS with `sparse_layer_in`!
    :return: (out_channel_mask, in_channel_mask)
    """
    assert isinstance(conv_layer, nn.Conv2d) or isinstance(conv_layer, nn.Linear), \
        f"conv_layer got {conv_layer}"
    assert isinstance(sparse_layer, nn.BatchNorm2d) or \
        isinstance(sparse_layer, nn.BatchNorm1d) or isinstance(sparse_layer, SparseGate), \
        f"sparse_layer got {sparse_layer}"
    if in_channel_mask is not None and sparse_layer_in is not None:
        raise ValueError("Conflict option: in_channel_mask and sparse_layer_in")

    # normalize to lowercase for case-insensitive comparison
    prune_mode = prune_mode.lower()
    prune_output_mode = str.lower(prune_output_mode)

    if prune_mode == 'multiply':
        if bn_layer is None:
            raise ValueError("Could not use multiply mode when bn is None")
        if not isinstance(sparse_layer, SparseGate):
            raise ValueError(
                f"Do not support prune_mode {prune_mode} when the sparse_layer is {sparse_layer}")

    with torch.no_grad():
        conv_weight: torch.Tensor = conv_layer.weight.data.clone()

        # prune the input channel of the conv layer
        # if sparse_layer_in and in_channel_mask are both None, the input dim will NOT be pruned
        if sparse_layer_in is not None:
            if in_channel_mask is not None:
                raise ValueError("Conflict option: in_channel_mask and sparse_layer_in")
            sparse_weight_in: np.ndarray = sparse_layer_in.weight.view(-1).data.cpu().numpy()
            # the in_channel_mask will be overwritten
            input_threshold = pruner(sparse_weight_in)
            in_channel_mask: np.ndarray = sparse_weight_in > input_threshold  # mask of kept input channels

        # convert mask to channel indexes
        # np.argwhere returns the indices of the non-zero (kept) entries
        idx_in = np.squeeze(np.argwhere(np.asarray(in_channel_mask)))
        if len(idx_in.shape) == 0:
            # expand the single scalar to array
            idx_in = np.expand_dims(idx_in, 0)

        # prune the input of the conv layer
        if isinstance(conv_layer, nn.Conv2d):
            if conv_layer.groups == 1:
                conv_weight = conv_weight[:, idx_in.tolist(), :, :]
            else:
                assert conv_weight.shape[1] == 1, "only works for groups == num_channels"
        elif isinstance(conv_layer, nn.Linear):
            conv_weight = conv_weight[:, idx_in.tolist()]
        else:
            raise ValueError(f"unsupported conv layer type: {conv_layer}")

        # prune the output channel of the conv layer
        if prune_output_mode == "prune":
            if prune_on == 'factor':
                # the sparse_layer.weight needs to be flattened, because the weight of SparseGate is not 1d
                sparse_weight: np.ndarray = sparse_layer.weight.view(-1).data.cpu().numpy()
                if prune_mode == 'multiply':
                    bn_weight = bn_layer.weight.data.cpu().numpy()
                    sparse_weight = sparse_weight * bn_weight  # element-wise multiplication
                elif prune_mode != 'default':
                    raise ValueError(f"Do not support prune_mode {prune_mode}")

                # prune according to the sparse layer (bn layer or gate)
                output_threshold = pruner(sparse_weight)
                out_channel_mask: np.ndarray = sparse_weight > output_threshold  # mask of kept output channels
            else:
                sparse_weight: np.ndarray = sparse_layer.weight.view(-1).data.cpu().numpy()
                # in this case, the sparse weight should be the conv or linear weight
                out_channel_mask: np.ndarray = pruner(conv_weight.data.cpu().numpy())

        elif prune_output_mode == "keep":
            # do not prune the output
            out_channel_mask = np.ones(conv_layer.out_channels)
        elif prune_output_mode == "same":
            # prune the output channel with the input mask
            # keep the conv layer in_channel == out_channel
            out_channel_mask = in_channel_mask
        else:
            raise ValueError(f"invalid prune_output_mode: {prune_output_mode}")

        if not np.any(out_channel_mask):
            # there is no channel left
            return out_channel_mask, in_channel_mask

        idx_out: np.ndarray = np.squeeze(np.argwhere(np.asarray(out_channel_mask)))
        if len(idx_out.shape) == 0:
            # 0-d scalar
            idx_out = np.expand_dims(idx_out, 0)

        # select the kept output weights
        if isinstance(conv_layer, nn.Conv2d):
            conv_weight = conv_weight[idx_out.tolist(), :, :, :]
        elif isinstance(conv_layer, nn.Linear):
            conv_weight = conv_weight[idx_out.tolist(), :]
            linear_bias = conv_layer.bias.clone()
            linear_bias = linear_bias[idx_out.tolist()]
        else:
            raise ValueError(f"unsupported conv layer type: {conv_layer}")

        # change the property of the conv layer
        if isinstance(conv_layer, nn.Conv2d):
            conv_layer.in_channels = len(idx_in)
            conv_layer.out_channels = len(idx_out)
        elif isinstance(conv_layer, nn.Linear):
            conv_layer.in_features = len(idx_in)
            conv_layer.out_features = len(idx_out)
        conv_layer.weight.data = conv_weight
        if isinstance(conv_layer, nn.Linear):
            conv_layer.bias.data = linear_bias
        if isinstance(conv_layer, nn.Conv2d) and conv_layer.groups != 1:
            # set the new groups for the depthwise layer (for MobileNet)
            conv_layer.groups = conv_layer.in_channels

        # prune the bn layer
        if bn_layer is not None:
            bn_layer.weight.data = bn_layer.weight.data[idx_out.tolist()].clone()
            bn_layer.bias.data = bn_layer.bias.data[idx_out.tolist()].clone()
            bn_layer.running_mean = bn_layer.running_mean[idx_out.tolist()].clone()
            bn_layer.running_var = bn_layer.running_var[idx_out.tolist()].clone()

            # set bn properties
            bn_layer.num_features = len(idx_out)

        # prune the gate
        if isinstance(sparse_layer, SparseGate):
            sparse_layer.prune(idx_out)
            # multiply the bn weight and SparseGate weight
            sparse_weight: torch.Tensor = sparse_layer.weight.view(-1)
            if bn_layer is not None:
                bn_layer.weight.data = (bn_layer.weight.data * sparse_weight).clone()
                bn_layer.bias.data = (bn_layer.bias.data * sparse_weight).clone()
            # the function of the SparseGate is now replaced by the bn layer
            # the SparseGate should be disabled
            sparse_layer.set_ones()

    return out_channel_mask, in_channel_mask
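
# Illustrative helper (not part of the original code). The `pruner` argument of
# the function above is a Callable[[np.ndarray], float]: it maps the scaling
# factors of the sparse layer to a single threshold, and channels whose factor
# does not exceed the threshold are pruned. A minimal percentile-based sketch:
def _example_threshold_pruner(sparse_weight: np.ndarray) -> float:
    """Use the median magnitude of the scaling factors as the pruning threshold."""
    return float(np.percentile(np.abs(sparse_weight), 50))

# It could be passed as, e.g. (assuming a Conv2d `conv` followed by a BatchNorm2d `bn`):
#   out_mask, in_mask = prune_conv_layer(
#       conv_layer=conv, bn_layer=bn, sparse_layer=bn,
#       in_channel_mask=np.ones(conv.in_channels, dtype=bool),
#       prune_output_mode="prune", pruner=_example_threshold_pruner,
#       prune_mode="default")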
def prune_conv_layer(conv_layer: nn.Conv2d,
                     bn_layer: nn.BatchNorm2d,
                     sparse_layer_out: typing.Optional[Union[nn.BatchNorm2d, SparseGate]],
                     in_channel_mask: typing.Optional[np.ndarray],
                     prune_output_mode: str,
                     pruner: Pruner,
                     prune_mode: str,
                     sparse_layer_in: typing.Optional[Union[nn.BatchNorm2d, SparseGate]],
                     ) -> typing.Tuple[np.ndarray, np.ndarray]:
    """
    Note: if the sparse_layer is a SparseGate, the gate will be replaced by the
    BatchNorm scaling factors. The value of the gate will be set to all ones.

    :param prune_output_mode: how to handle the output channels (case insensitive)
        "keep": keep the output channels intact
        "same": keep the output channels the same as the input channels
        "prune": prune the output channels according to bn_layer
    :param pruner: the method to determine the pruning mask.
    :param in_channel_mask: a 0-1 vector indicating whether the corresponding channel
        should be pruned (0) or kept (1)
    :param bn_layer: the BatchNorm layer after the convolution layer
    :param conv_layer: the convolution layer to be pruned
    :param sparse_layer_out: the layer to determine the output sparsity. Supports BatchNorm2d and SparseGate.
    :param sparse_layer_in: the layer to determine the input sparsity. Supports BatchNorm2d and SparseGate.
        When `sparse_layer_in` is None, there is no input sparse layer and the input channels
        are determined by `in_channel_mask`.
        Note: `in_channel_mask` CONFLICTS with `sparse_layer_in`!
    :param prune_mode: pruning mode (`str`):
        - `"multiply"`: the pruning threshold is determined by the multiplication of
          `sparse_layer` and `bn_layer`; only available when `sparse_layer` is `SparseGate`
        - `None` or `"default"`: default behaviour; the pruning threshold is determined by `sparse_layer`
    :return: (out_channel_mask, in_channel_mask)
    """
    assert isinstance(conv_layer, nn.Conv2d), f"conv_layer got {conv_layer}"
    assert isinstance(sparse_layer_out, nn.BatchNorm2d) or isinstance(sparse_layer_out, SparseGate), \
        f"sparse_layer got {sparse_layer_out}"
    if in_channel_mask is not None and sparse_layer_in is not None:
        raise ValueError("Conflict option: in_channel_mask and sparse_layer_in")

    # normalize to lowercase for case-insensitive comparison
    prune_mode = prune_mode.lower()
    prune_output_mode = str.lower(prune_output_mode)

    if prune_mode == 'multiply':
        if not isinstance(sparse_layer_out, SparseGate):
            raise ValueError(
                f"Do not support prune_mode {prune_mode} when the sparse_layer is {sparse_layer_out}")

    with torch.no_grad():
        conv_weight: torch.Tensor = conv_layer.weight.data.clone()

        # prune the input channel of the conv layer
        # if sparse_layer_in and in_channel_mask are both None, the input dim will NOT be pruned
        if sparse_layer_in is not None:
            if in_channel_mask is not None:
                raise ValueError("Conflict option: in_channel_mask and sparse_layer_in")
            sparse_weight_in: np.ndarray = sparse_layer_in.weight.view(-1).data.cpu().numpy()
            # the in_channel_mask will be overwritten
            in_channel_mask = pruner(sparse_weight_in)

        if in_channel_mask is not None:
            # prune the input channel according to the in_channel_mask
            # convert mask to channel indexes
            idx_in = np.squeeze(np.argwhere(np.asarray(in_channel_mask)))
            if len(idx_in.shape) == 0:
                # expand the single scalar to array
                idx_in = np.expand_dims(idx_in, 0)
            elif len(idx_in.shape) == 1 and idx_in.shape[0] == 0:
                # nothing left, prune the whole block
                out_channel_mask = np.full(conv_layer.out_channels, False)
                return out_channel_mask, in_channel_mask

            # prune the input of the conv layer
            if conv_layer.groups == 1:
                conv_weight = conv_weight[:, idx_in.tolist(), :, :]
            else:
                assert conv_weight.shape[1] == 1, "only works for groups == num_channels"

        # prune the output channel of the conv layer
        if prune_output_mode == "prune":
            # the sparse_layer.weight needs to be flattened, because the weight of SparseGate is not 1d
            sparse_weight_out: np.ndarray = sparse_layer_out.weight.view(-1).data.cpu().numpy()
            if prune_mode == 'multiply':
                bn_weight = bn_layer.weight.data.cpu().numpy()
                sparse_weight_out = sparse_weight_out * bn_weight  # element-wise multiplication
            elif prune_mode != 'default':
                raise ValueError(f"Do not support prune_mode {prune_mode}")

            # prune and get the pruned mask
            out_channel_mask = pruner(sparse_weight_out)
        elif prune_output_mode == "keep":
            # do not prune the output
            out_channel_mask = np.ones(conv_layer.out_channels)
        elif prune_output_mode == "same":
            # prune the output channel with the input mask
            # keep the conv layer in_channel == out_channel
            out_channel_mask = in_channel_mask
        else:
            raise ValueError(f"invalid prune_output_mode: {prune_output_mode}")

        idx_out: np.ndarray = np.squeeze(np.argwhere(np.asarray(out_channel_mask)))
        if len(idx_out.shape) == 0:
            # expand the single scalar to array
            idx_out = np.expand_dims(idx_out, 0)
        elif len(idx_out.shape) == 1 and idx_out.shape[0] == 0:
            # no channel left
            # return the masks directly
            # the block is supposed to be set as an identity mapping
            return out_channel_mask, in_channel_mask
        conv_weight = conv_weight[idx_out.tolist(), :, :, :]

        # change the property of the conv layer
        conv_layer.in_channels = len(idx_in)
        conv_layer.out_channels = len(idx_out)
        conv_layer.weight.data = conv_weight
        if conv_layer.groups != 1:
            # set the new groups for the depthwise layer
            conv_layer.groups = conv_layer.in_channels

        # prune the bn layer
        bn_layer.weight.data = bn_layer.weight.data[idx_out.tolist()].clone()
        bn_layer.bias.data = bn_layer.bias.data[idx_out.tolist()].clone()
        bn_layer.running_mean = bn_layer.running_mean[idx_out.tolist()].clone()
        bn_layer.running_var = bn_layer.running_var[idx_out.tolist()].clone()

        # set bn properties
        bn_layer.num_features = len(idx_out)

        # prune the gate
        if isinstance(sparse_layer_out, SparseGate):
            sparse_layer_out.prune(idx_out)
            # multiply the bn weight and SparseGate weight
            sparse_weight_out: torch.Tensor = sparse_layer_out.weight.view(-1)
            bn_layer.weight.data = (bn_layer.weight.data * sparse_weight_out).clone()
            bn_layer.bias.data = (bn_layer.bias.data * sparse_weight_out).clone()
            # the function of the SparseGate is now replaced by the bn layer
            # the SparseGate should be disabled
            sparse_layer_out.set_ones()

    return out_channel_mask, in_channel_mask
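
# Illustrative stand-in and usage sketch (not part of the original code). Here
# `pruner` is assumed to behave like the project's `Pruner`: a callable that maps
# the scaling factors to a boolean keep-mask, which is how it is used in the
# function above. The layer sizes and BN weights below are arbitrary examples.
def _example_mask_pruner(sparse_weight: np.ndarray) -> np.ndarray:
    """Keep the channels whose scaling-factor magnitude is above the median."""
    return np.abs(sparse_weight) > np.median(np.abs(sparse_weight))


def _example_prune_conv_bn_pair():
    conv = nn.Conv2d(4, 8, kernel_size=3, bias=False)
    bn = nn.BatchNorm2d(8)
    bn.weight.data = torch.linspace(0.1, 1.0, 8)  # distinct scaling factors
    out_mask, in_mask = prune_conv_layer(
        conv_layer=conv,
        bn_layer=bn,
        sparse_layer_out=bn,                     # BN scaling factors drive the pruning
        in_channel_mask=np.ones(4, dtype=bool),  # keep every input channel
        prune_output_mode="prune",
        pruner=_example_mask_pruner,
        prune_mode="default",
        sparse_layer_in=None,
    )
    # half of the output channels are above the median factor and survive
    assert conv.out_channels == int(np.sum(out_mask))
    assert bn.num_features == conv.out_channels
    assert conv.weight.shape == (conv.out_channels, 4, 3, 3)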