import typing
from typing import Dict, List, Text, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv2d

# Project-local helpers (get_layer_craig_subset, is_depthwise_conv2d, log_shape,
# SparseGate, Pruner) are assumed to be importable from the surrounding code.


def prune_conv2d_layer_with_craig(
        layer: nn.Conv2d,
        prune_percent_per_layer: float,
        similarity_metric: Union[Text, Dict] = "",
        prune_type: Text = "craig",
        **kwargs) -> Tuple[List[int], List[float]]:
    # Get CRAIG subset.
    subset_nodes: List
    subset_weights: List
    subset_nodes, subset_weights = get_layer_craig_subset(
        layer=layer,
        original_num_nodes=layer.out_channels,
        prune_percent_per_layer=prune_percent_per_layer,
        similarity_metric=similarity_metric,
        prune_type=prune_type,
        **kwargs)

    # Remove nodes+weights+biases, and adjust weights.
    num_nodes: int = len(subset_nodes)

    # Prune the current layer: keep the selected output channels and scale
    # their weights (and biases) by the corresponding subset weights.
    subset_weights_tensor = torch.tensor(subset_weights)
    layer.weight = nn.Parameter(
        layer.weight[subset_nodes] *
        subset_weights_tensor.reshape((num_nodes, 1, 1, 1)))
    if layer.bias is not None:
        layer.bias = nn.Parameter(layer.bias[subset_nodes] *
                                  subset_weights_tensor)
    layer.out_channels = num_nodes

    return subset_nodes, subset_weights
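# Minimal usage sketch for prune_conv2d_layer_with_craig, assuming
# get_layer_craig_subset is available from this project. The layer sizes and
# the "euclidean_distance" metric name are illustrative only.
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
nodes, node_weights = prune_conv2d_layer_with_craig(
    layer=conv,
    prune_percent_per_layer=0.5,
    similarity_metric="euclidean_distance",  # hypothetical metric name
    prune_type="craig")
# The layer metadata and tensors stay consistent with the selected subset.
assert conv.out_channels == len(nodes) == conv.weight.shape[0]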
def update_attributes(self, conv: nn.Conv2d):
    if conv.groups > 1:
        # Depthwise convolution: the weight shape alone determines the new
        # channel counts and group number.
        in_channels = int(conv.weight.shape[0])
        conv.groups = in_channels
        out_channels = conv.groups * int(conv.weight.shape[1])
        conv.in_channels = in_channels
        conv.out_channels = out_channels
    else:
        in_channels = int(conv.weight.shape[1])
        conv.in_channels = in_channels
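# Sketch of update_attributes re-deriving channel metadata after the weight of
# a depthwise conv has been sliced. The method never touches self, so passing
# None is enough for a standalone demonstration.
dw = nn.Conv2d(16, 16, 3, groups=16)
dw.weight.data = dw.weight.data[:8]  # keep the first 8 output channels
update_attributes(None, dw)
assert dw.groups == dw.in_channels == dw.out_channels == 8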
def active_rebuild(self, conv: nn.Conv2d):
    assert conv.groups == 1, 'Group Convolution is not supported.'
    # mask_model is not None => special case: prune the conv using the bn of a
    # conv-bn pair.
    # Drop output channels whose kernel weights sum to 0; keep all others.
    mask = conv.weight.data.sum(dim=(1, 2, 3)) != 0
    self.logger.debug(log_shape(conv.weight.data, mask))
    conv.weight.data = conv.weight.data[mask].clone()
    if conv.bias is not None:
        self.logger.debug(log_shape(conv.bias.data, mask))
        conv.bias.data = conv.bias.data[mask].clone()
    out_channels = int(conv.weight.shape[0])
    conv.out_channels = out_channels
    return mask
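# Sketch: zero out two output channels and let active_rebuild drop them. The
# host object and the log_shape helper are stand-ins; in the real code they
# come from the class that owns active_rebuild.
import logging


def log_shape(tensor, mask):  # stand-in for the project's log_shape helper
    return f"{tuple(tensor.shape)} -> keep {int(mask.sum())}"


class _Host:
    logger = logging.getLogger("rebuild")
    active_rebuild = active_rebuild


conv = nn.Conv2d(4, 6, 3, bias=True)
conv.weight.data[[1, 4]] = 0.0  # channels 1 and 4 become prunable
mask = _Host().active_rebuild(conv)
assert conv.out_channels == 4 and int(mask.sum()) == 4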
def prune_conv2d(module: Conv2d, in_keep_idxes=None, out_keep_idxes=None):
    if in_keep_idxes is None:
        in_keep_idxes = list(range(module.weight.shape[1]))
    if out_keep_idxes is None:
        out_keep_idxes = list(range(module.weight.shape[0]))
    assert len(in_keep_idxes) <= module.weight.shape[1]
    assert len(out_keep_idxes) <= module.weight.shape[0]

    module.out_channels = len(out_keep_idxes)
    module.in_channels = len(in_keep_idxes)

    # Slice output channels first, then input channels; drop stale gradients.
    module.weight.data = module.weight.data[out_keep_idxes, :, :, :]
    module.weight.data = module.weight.data[:, in_keep_idxes, :, :]
    module.weight.grad = None
    if module.bias is not None:
        module.bias.data = module.bias.data[out_keep_idxes]
        module.bias.grad = None
    return in_keep_idxes, out_keep_idxes
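# Usage sketch for the basic prune_conv2d above: keep 2 of 4 input channels and
# 3 of 8 output channels; only shape consistency is checked.
conv = Conv2d(4, 8, kernel_size=3, bias=True)
prune_conv2d(conv, in_keep_idxes=[0, 2], out_keep_idxes=[0, 3, 5])
assert conv.weight.shape == (3, 2, 3, 3) and conv.bias.shape == (3,)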
def prune_conv2d(module: Conv2d, in_keep_idxes=None, out_keep_idxes=None):
    if in_keep_idxes is None:
        in_keep_idxes = list(range(module.weight.shape[1]))
    if out_keep_idxes is None:
        out_keep_idxes = list(range(module.weight.shape[0]))
    is_depthwise = is_depthwise_conv2d(module)
    if is_depthwise:
        module.groups = len(in_keep_idxes)
        assert len(in_keep_idxes) == len(out_keep_idxes)
    else:
        assert (
            len(in_keep_idxes) <= module.weight.shape[1]
        ), f"len(in_keep_idxes): {len(in_keep_idxes)}, module.weight.shape[1]: {module.weight.shape[1]}"
        assert (
            len(out_keep_idxes) <= module.weight.shape[0]
        ), f"len(out_keep_idxes): {len(out_keep_idxes)}, module.weight.shape[0]: {module.weight.shape[0]}"

    module.out_channels = len(out_keep_idxes)
    module.in_channels = len(in_keep_idxes)

    # For a depthwise conv the weight's second dim is 1, so only the output
    # dim is sliced; otherwise slice both dims.
    module.weight = torch.nn.Parameter(module.weight.data[out_keep_idxes, :, :, :])
    if not is_depthwise:
        module.weight = torch.nn.Parameter(module.weight.data[:, in_keep_idxes, :, :])
    module.weight.grad = None

    if module.bias is not None:
        module.bias = torch.nn.Parameter(module.bias.data[out_keep_idxes])
        module.bias.grad = None
    return in_keep_idxes, out_keep_idxes
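# The depthwise-aware variant can be exercised the same way; the groups
# attribute must track the kept channels. is_depthwise_conv2d is assumed to be
# defined alongside prune_conv2d.
dw = Conv2d(8, 8, kernel_size=3, groups=8)
prune_conv2d(dw, in_keep_idxes=[0, 1, 4], out_keep_idxes=[0, 1, 4])
assert dw.groups == 3 and dw.weight.shape == (3, 1, 3, 3)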
def prune_conv_layer(
        conv_layer: nn.Conv2d,
        bn_layer: nn.BatchNorm2d,
        sparse_layer_out: typing.Optional[Union[nn.BatchNorm2d, SparseGate]],
        in_channel_mask: typing.Optional[np.ndarray],
        prune_output_mode: str,
        pruner: Pruner,
        prune_mode: str,
        sparse_layer_in: typing.Optional[Union[nn.BatchNorm2d, SparseGate]],
) -> typing.Tuple[np.ndarray, np.ndarray]:
    """
    Note: if the sparse_layer is a SparseGate, the gate will be replaced by the
    BatchNorm scaling factor and the value of the gate will be set to all ones.

    :param prune_output_mode: how to handle the output channels (case insensitive)
        "keep": keep the output channels intact
        "same": keep the output channels the same as the input channels
        "prune": prune the output channels according to bn_layer
    :param pruner: the method that determines the pruning mask
    :param in_channel_mask: a 0-1 vector indicating whether the corresponding
        channel should be pruned (0) or kept (1)
    :param bn_layer: the BatchNorm layer after the convolution layer
    :param conv_layer: the convolution layer to be pruned
    :param sparse_layer_out: the layer that determines the output sparsity.
        Supports BatchNorm2d and SparseGate.
    :param sparse_layer_in: the layer that determines the input sparsity.
        Supports BatchNorm2d and SparseGate. When `sparse_layer_in` is None,
        there is no input sparse layer and the input channels are determined
        by `in_channel_mask`.
        Note: `in_channel_mask` CONFLICTS with `sparse_layer_in`!
    :param prune_mode: pruning mode (`str`):
        - `"multiply"`: the pruning threshold is determined by the
          multiplication of `sparse_layer` and `bn_layer`
          (only available when `sparse_layer` is a `SparseGate`)
        - `None` or `"default"`: default behaviour; the pruning threshold is
          determined by `sparse_layer`
    :return: a tuple (out_channel_mask, in_channel_mask)
    """
    assert isinstance(conv_layer, nn.Conv2d), f"conv_layer got {conv_layer}"
    assert isinstance(sparse_layer_out, (nn.BatchNorm2d, SparseGate)), \
        f"sparse_layer got {sparse_layer_out}"
    if in_channel_mask is not None and sparse_layer_in is not None:
        raise ValueError("Conflict option: in_channel_mask and sparse_layer_in")

    prune_mode = (prune_mode or "default").lower()
    prune_output_mode = str.lower(prune_output_mode)

    if prune_mode == 'multiply':
        if not isinstance(sparse_layer_out, SparseGate):
            raise ValueError(
                f"Do not support prune_mode {prune_mode} when the sparse_layer is {sparse_layer_out}")

    with torch.no_grad():
        conv_weight: torch.Tensor = conv_layer.weight.data.clone()

        # by default keep every input channel; this covers the case where both
        # sparse_layer_in and in_channel_mask are None and the input dim is
        # NOT pruned
        idx_in = np.arange(conv_layer.in_channels)

        # prune the input channels of the conv layer
        if sparse_layer_in is not None:
            sparse_weight_in: np.ndarray = sparse_layer_in.weight.view(
                -1).data.cpu().numpy()
            # the in_channel_mask will be overwritten
            in_channel_mask = pruner(sparse_weight_in)

        if in_channel_mask is not None:
            # prune the input channels according to the in_channel_mask
            # convert the mask to channel indexes
            idx_in = np.squeeze(np.argwhere(np.asarray(in_channel_mask)))
            if len(idx_in.shape) == 0:
                # expand the single scalar to an array
                idx_in = np.expand_dims(idx_in, 0)
            elif len(idx_in.shape) == 1 and idx_in.shape[0] == 0:
                # nothing left, prune the whole block
                out_channel_mask = np.full(conv_layer.out_channels, False)
                return out_channel_mask, in_channel_mask

            # prune the input of the conv layer
            if conv_layer.groups == 1:
                conv_weight = conv_weight[:, idx_in.tolist(), :, :]
            else:
                assert conv_weight.shape[1] == 1, \
                    "only works for groups == num_channels"

        # prune the output channels of the conv layer
        if prune_output_mode == "prune":
            # the sparse_layer.weight must be flattened, because the weight of
            # a SparseGate is not 1d
            sparse_weight_out: np.ndarray = sparse_layer_out.weight.view(
                -1).data.cpu().numpy()
            if prune_mode == 'multiply':
                bn_weight = bn_layer.weight.data.cpu().numpy()
                sparse_weight_out = sparse_weight_out * bn_weight  # element-wise multiplication
            elif prune_mode != 'default':
                raise ValueError(f"Do not support prune_mode {prune_mode}")

            # prune and get the pruned mask
            out_channel_mask = pruner(sparse_weight_out)
        elif prune_output_mode == "keep":
            # do not prune the output
            out_channel_mask = np.ones(conv_layer.out_channels)
        elif prune_output_mode == "same":
            # prune the output channels with the input mask, so the conv layer
            # keeps in_channels == out_channels
            out_channel_mask = in_channel_mask
        else:
            raise ValueError(f"invalid prune_output_mode: {prune_output_mode}")

        idx_out: np.ndarray = np.squeeze(
            np.argwhere(np.asarray(out_channel_mask)))
        if len(idx_out.shape) == 0:
            # expand the single scalar to an array
            idx_out = np.expand_dims(idx_out, 0)
        elif len(idx_out.shape) == 1 and idx_out.shape[0] == 0:
            # no channels left; return the masks directly
            # the block is supposed to be set as an identity mapping
            return out_channel_mask, in_channel_mask
        conv_weight = conv_weight[idx_out.tolist(), :, :, :]

        # update the properties of the conv layer
        conv_layer.in_channels = len(idx_in)
        conv_layer.out_channels = len(idx_out)
        conv_layer.weight.data = conv_weight
        if conv_layer.groups != 1:
            # set the new groups for the depthwise layer
            conv_layer.groups = conv_layer.in_channels

        # prune the bn layer
        bn_layer.weight.data = bn_layer.weight.data[idx_out.tolist()].clone()
        bn_layer.bias.data = bn_layer.bias.data[idx_out.tolist()].clone()
        bn_layer.running_mean = bn_layer.running_mean[idx_out.tolist()].clone()
        bn_layer.running_var = bn_layer.running_var[idx_out.tolist()].clone()

        # set bn properties
        bn_layer.num_features = len(idx_out)

        # prune the gate
        if isinstance(sparse_layer_out, SparseGate):
            sparse_layer_out.prune(idx_out)
            # multiply the bn weight by the SparseGate weight
            sparse_weight_out: torch.Tensor = sparse_layer_out.weight.view(-1)
            bn_layer.weight.data = (bn_layer.weight.data *
                                    sparse_weight_out).clone()
            bn_layer.bias.data = (bn_layer.bias.data *
                                  sparse_weight_out).clone()
            # the function of the SparseGate is now replaced by the bn layer,
            # so the SparseGate should be disabled
            sparse_layer_out.set_ones()

    return out_channel_mask, in_channel_mask