def __init__(self, conv: nn.Conv2d, batch_norm: bool, output_channel: int, gate: bool):
    super().__init__()
    self.conv = conv
    self.gate: bool = gate

    if batch_norm:
        if isinstance(self.conv, nn.Conv2d):
            self.batch_norm = nn.BatchNorm2d(output_channel)
        elif isinstance(self.conv, nn.Linear):
            self.batch_norm = nn.BatchNorm1d(output_channel)
    else:
        self.batch_norm = Identity()  # returns the input unchanged

    if gate:
        self.sparse_gate = SparseGate(output_channel)

    self.relu = nn.ReLU(inplace=True)
def __init__(self, in_planes, out_planes, gate: bool, kernel_size=3, stride=1, groups=1):
    """
    A sequence of modules
        - Conv
        - BN
        - Gate (optional, if gate is True)
        - ReLU
    """
    padding = (kernel_size - 1) // 2
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
        nn.BatchNorm2d(out_planes),
    ]
    if gate:
        layers.append(SparseGate(out_planes))
    layers.append(nn.ReLU6(inplace=True))
    super(ConvBNReLU, self).__init__(*layers)
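
# Usage sketch (not part of the original file): ConvBNReLU subclasses nn.Sequential, so it
# can be called directly on a tensor. The channel numbers below are illustrative assumptions;
# with gate=True a SparseGate (assumed to be an element-wise channel scaling) sits between
# the BatchNorm and the ReLU6.
def _example_conv_bn_relu():
    import torch
    block = ConvBNReLU(in_planes=32, out_planes=64, gate=True, kernel_size=3, stride=2)
    x = torch.randn(1, 32, 112, 112)
    y = block(x)  # stride=2 halves the spatial size -> (1, 64, 56, 56)
    return y.shape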
class BasicBlock(BuildingBlock):
    expansion = 1

    def __init__(self, in_planes, outplanes, cfg, stride=1, option='A', gate=False, use_input_mask=False):
        super(BasicBlock, self).__init__()
        self.gate = gate
        self.use_input_mask = use_input_mask
        conv_in = cfg[2]

        if len(cfg) != 3:
            raise ValueError("cfg len should be 3, got {}".format(cfg))

        self.is_empty_block = 0 in cfg

        if not self.is_empty_block:
            if self.use_input_mask:
                # input channel: in_planes, output channel: conv_in
                self.input_channel_selector = ChannelSelect(in_planes)
                # input channel: conv_in, output channel: conv_in
                self.input_mask = SparseGate(conv_in)
            else:
                self.input_channel_selector = None
                self.input_mask = None

            self.conv1 = nn.Conv2d(conv_in, cfg[0], kernel_size=3, stride=stride, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(cfg[0])
            self.gate1 = SparseGate(cfg[0]) if self.gate else Identity()

            self.conv2 = nn.Conv2d(cfg[0], cfg[1], kernel_size=3, stride=1, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(cfg[1])
            self.gate2 = SparseGate(cfg[1]) if self.gate else Identity()

            self.expand_layer = ChannelExpand(outplanes)
        else:
            self.conv1 = Identity()
            self.bn1 = Identity()
            self.conv2 = Identity()
            self.bn2 = Identity()
            self.expand_layer = Identity()

        self.shortcut = nn.Sequential()  # do nothing
        if stride != 1 or in_planes != outplanes:
            if option == 'A':
                """For CIFAR10, the ResNet paper uses option A."""
                self.shortcut = LambdaLayer(
                    lambda x: F.pad(x[:, :, ::2, ::2],
                                    (0, 0, 0, 0,
                                     (outplanes - in_planes) // 2, (outplanes - in_planes) // 2),
                                    "constant", 0))
            elif option == 'B':
                raise NotImplementedError("Option B is not implemented")
                # self.shortcut = nn.Sequential(
                #     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                #     nn.BatchNorm2d(self.expansion * planes)
                # )

    def forward(self, x):
        if self.is_empty_block:
            out = self.shortcut(x)
            out = F.relu(out)
            return out
        else:
            if self.use_input_mask:
                out = self.input_channel_selector(x)
                out = self.input_mask(out)
            else:
                out = x

            # relu-gate and gate-relu are equivalent
            out = F.relu(self.bn1(self.conv1(out)))
            out = self.gate1(out)

            out = self.bn2(self.conv2(out))
            out = self.gate2(out)
            out = self.expand_layer(out)

            out += self.shortcut(x)
            out = F.relu(out)
            return out

    def do_pruning(self, pruner: Callable[[np.ndarray], float], prune_mode: str,
                   in_channel_mask: np.ndarray = None, prune_on=None) -> None:
        """
        Prune the block in place.
        Note: there is no ChannelExpand layer at the end of the block. After pruning,
        the output dimension might change, which can cause a dimension conflict.

        :param pruner: the method to determine the pruning threshold.
        :param prune_mode: same as `models.common.prune_conv_layer`
        """
        if in_channel_mask is not None:
            raise ValueError("Do not set in_channel_mask")

        if self.is_empty_block:
            return

        # keep input dim and output dim unchanged
        in_channel_mask = np.ones(self.conv1.in_channels)

        # prune conv1
        in_channel_mask, conv1_input_channel_mask = prune_conv_layer(
            conv_layer=self.conv1,
            bn_layer=self.bn1,
            sparse_layer_in=self.input_mask,
            sparse_layer=self.gate1 if self.gate else self.bn1,
            in_channel_mask=None if self.use_input_mask else in_channel_mask,
            pruner=pruner,
            prune_output_mode="prune",
            prune_mode=prune_mode,
            prune_on=prune_on,
        )
        if not np.any(in_channel_mask) or not np.any(conv1_input_channel_mask):
            # prune the entire block
            self.is_empty_block = True
            return

        if self.use_input_mask:
            # prune the input dimension of the first conv layer (conv1)
            channel_select_idx = np.squeeze(np.argwhere(np.asarray(conv1_input_channel_mask)))
            if len(channel_select_idx.shape) == 0:
                # expand the single scalar to an array
                channel_select_idx = np.expand_dims(channel_select_idx, 0)
            elif len(channel_select_idx.shape) == 1 and channel_select_idx.shape[0] == 0:
                # nothing left
                # this code should not be executed: if there is no channel left,
                # is_empty_block is set to True and the method returns (see above)
                raise NotImplementedError("No layer left in input channel")
            self.input_channel_selector.idx = channel_select_idx
            self.input_mask.do_pruning(conv1_input_channel_mask)

        # prune conv2
        out_channel_mask, _ = prune_conv_layer(
            conv_layer=self.conv2,
            bn_layer=self.bn2,
            sparse_layer=self.gate2 if self.gate else self.bn2,
            in_channel_mask=in_channel_mask,
            pruner=pruner,
            prune_output_mode="prune",
            prune_mode=prune_mode,
            prune_on=prune_on,
        )
        if not np.any(out_channel_mask):
            # prune the entire block
            self.is_empty_block = True
            return

        # do padding to allow adding with the residual connection
        # the output dim is unchanged
        # note that the idx of the expander might already be set in a pruned model
        original_expander_idx = self.expand_layer.idx
        assert len(original_expander_idx) == len(out_channel_mask), "the output channel should be consistent"
        pruned_expander_idx = original_expander_idx[out_channel_mask]
        idx = np.squeeze(pruned_expander_idx)
        if len(idx.shape) == 0:
            # expand 0-d idx
            idx = np.expand_dims(idx, 0)
        self.expand_layer.idx = idx

    def config(self) -> typing.Tuple[int, int, int]:
        if self.is_empty_block:
            return 0, 0, 0
        return self.conv1.out_channels, self.conv2.out_channels, self.conv1.in_channels

    @property
    def expand_idx(self):
        raise NotImplementedError()
        if self.is_empty_block:
            return None
        return self.expand_layer.idx

    @expand_idx.setter
    def expand_idx(self, value):
        if self.is_empty_block:
            if value is not None:
                raise ValueError(f"The expand_idx of the empty block is supposed to be None, got {value}")
            # do nothing for empty block
            return
        self.expand_layer.idx = value

    def _compute_flops_weight(self, scaling):
        conv1_flops_weight = self.conv1.d_flops_out + self.conv2.d_flops_in
        conv2_flops_weight = self.conv2.d_flops_out

        def scale(raw_value):
            if raw_value is None:
                return None
            return (raw_value - self.raw_weight_min) / (self.raw_weight_max - self.raw_weight_min)

        def identity(raw_value):
            return raw_value

        if scaling:
            scaling_func = scale
        else:
            scaling_func = identity

        self.conv_flops_weight = (scaling_func(conv1_flops_weight),
                                  scaling_func(conv2_flops_weight))

    def get_conv_flops_weight(self, update: bool, scaling: bool):
        # force the update
        # updating the flops weight is very cheap
        self._compute_flops_weight(scaling=scaling)

        assert self._conv_flops_weight is not None
        return self._conv_flops_weight

    @property
    def conv_flops_weight(self) -> typing.Tuple[float, float]:
        """This property is supposed to be used in the forward pass.
        To pass more arguments, call `get_conv_flops_weight`."""
        return self.get_conv_flops_weight(update=True, scaling=True)

    @conv_flops_weight.setter
    def conv_flops_weight(self, weight: typing.Tuple[float, float]):
        assert len(weight) == 2, f"The length of the convolution FLOPs weight should be 2, got {len(weight)}"
        self._conv_flops_weight = weight

    def get_sparse_modules(self):
        if self.gate:
            return self.gate1, self.gate2
        else:
            return self.bn1, self.bn2
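
# Construction sketch (not part of the original file): a CIFAR-style BasicBlock with
# cfg = [conv1_out, conv2_out, conv1_in], matching __init__ and config() above. The concrete
# channel numbers are illustrative assumptions, and the example assumes ChannelSelect /
# ChannelExpand default to keeping all channels. A zero anywhere in cfg builds an empty
# (fully pruned) block that reduces to the shortcut.
def _example_basic_block():
    import torch
    block = BasicBlock(in_planes=16, outplanes=16, cfg=[16, 16, 16],
                       stride=1, option='A', gate=True, use_input_mask=True)
    y = block(torch.randn(2, 16, 32, 32))   # shape preserved: (2, 16, 32, 32)

    empty = BasicBlock(in_planes=16, outplanes=32, cfg=[0, 0, 16], stride=2, option='A')
    z = empty(torch.randn(2, 16, 32, 32))   # only the strided, zero-padded shortcut: (2, 32, 16, 16)
    return y.shape, z.shape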
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes: int, outplanes: int, cfg: List[int], gate: bool,
                 stride=1, downsample=None, expand=False):
        """
        :param inplanes: the input dimension of the block
        :param outplanes: the output dimension of the block
        :param cfg: the output dimension of each convolution layer,
            in the format [conv1_out, conv2_out, conv3_out, conv1_in]
        :param gate: whether to use gates between conv layers
        :param stride: the stride of the first convolution layer
        :param downsample: the downsample convolution in the residual connection (None to disable)
        :param expand: whether to use a ChannelExpand layer in the block
        """
        super(Bottleneck, self).__init__()
        conv_in = cfg[3]
        self.use_gate = gate
        self._identity = False

        if 0 in cfg or conv_in == 0:
            # the whole block is pruned
            self.identity = True
        else:
            # the main body of the block

            # add a SparseGate before the first conv to enable pruning of the input
            # dimension, further reducing the computational complexity
            self.select = ChannelSelect(inplanes)
            # after select, the channel number of the feature map is conv_in
            self.input_gate = SparseGate(conv_in)

            self.conv1 = nn.Conv2d(conv_in, cfg[0], kernel_size=1, bias=False)
            self.bn1 = nn.BatchNorm2d(cfg[0])
            self.gate1 = SparseGate(cfg[0]) if gate else Identity()

            self.conv2 = nn.Conv2d(cfg[0], cfg[1], kernel_size=3, stride=stride, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(cfg[1])
            self.gate2 = SparseGate(cfg[1]) if gate else Identity()

            self.conv3 = nn.Conv2d(cfg[1], cfg[2], kernel_size=1, bias=False)
            self.bn3 = nn.BatchNorm2d(cfg[2])
            self.gate3 = SparseGate(cfg[2]) if gate else Identity()

            self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample
        self.stride = stride
        self.expand = expand
        self.expand_layer = ChannelExpand(outplanes) if expand else None

    def forward(self, x):
        residual = x

        if not self.identity:
            out = self.select(x)
            out = self.input_gate(out)

            out = self.conv1(out)
            out = self.bn1(out)
            out = self.gate1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)
            out = self.gate2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.bn3(out)
            out = self.gate3(out)
        else:
            # the whole layer is pruned
            out = x

        if self.downsample:
            residual = self.downsample(x)

        if self.expand_layer:
            out = self.expand_layer(out)

        out += residual
        out = self.relu(out)

        return out

    def do_pruning(self, pruner: Pruner, prune_mode: str) -> None:
        """
        Prune the block in place.
        Note: there is no ChannelExpand layer at the end of the block. After pruning,
        the output dimension might change, which can cause a dimension conflict.

        :param pruner: the method to determine the pruning threshold.
        :param prune_mode: same as `models.common.prune_conv_layer`
        """
        # keep input dim and output dim unchanged

        # prune conv1
        in_channel_mask, input_gate_mask = models.common.prune_conv_layer(
            conv_layer=self.conv1,
            bn_layer=self.bn1,
            # prune the input gate
            sparse_layer_in=self.input_gate,
            sparse_layer_out=self.gate1 if self.use_gate else self.bn1,
            in_channel_mask=None,
            pruner=pruner,
            prune_output_mode="prune",
            prune_mode=prune_mode)

        if not np.any(in_channel_mask) or not np.any(input_gate_mask):
            # this layer has no channel left; the whole block is pruned
            self.identity = True
            return

        # prune the input dimension of the first conv layer (conv1)
        channel_select_idx = np.squeeze(np.argwhere(np.asarray(input_gate_mask)))
        if len(channel_select_idx.shape) == 0:
            # expand the single scalar to an array
            channel_select_idx = np.expand_dims(channel_select_idx, 0)
        elif len(channel_select_idx.shape) == 1 and channel_select_idx.shape[0] == 0:
            # nothing left
            # this code should not be executed: if there is no channel left,
            # identity is set to True and the method returns (see above)
            raise NotImplementedError("No layer left in input channel")
        self.select.idx = channel_select_idx
        self.input_gate.do_pruning(input_gate_mask)

        # prune conv2
        in_channel_mask, _ = models.common.prune_conv_layer(
            conv_layer=self.conv2,
            bn_layer=self.bn2,
            sparse_layer_in=None,
            sparse_layer_out=self.gate2 if self.use_gate else self.bn2,
            in_channel_mask=in_channel_mask,
            pruner=pruner,
            prune_output_mode="prune",
            prune_mode=prune_mode)

        remain_channel_num = np.sum(in_channel_mask == 1)
        if remain_channel_num == 0:
            # this layer has no channel left; the whole block is pruned
            self.identity = True
            return

        # prune conv3
        out_channel_mask, _ = models.common.prune_conv_layer(
            conv_layer=self.conv3,
            bn_layer=self.bn3,
            sparse_layer_in=None,
            sparse_layer_out=self.gate3 if self.use_gate else self.bn3,
            in_channel_mask=in_channel_mask,
            pruner=pruner,
            prune_output_mode="prune",
            prune_mode=prune_mode)

        remain_channel_num = np.sum(out_channel_mask == 1)
        if remain_channel_num == 0:
            # this layer has no channel left; the whole block is pruned
            self.identity = True
            return

        # do not prune downsample layers
        # if the downsample layers need pruning (especially with gate), remember to add
        # a gate to the downsample layer

        # if self.use_res_connect:
        # do padding to allow adding with the residual connection
        # the output dim is unchanged
        # note that the idx of the expander might already be set in a pruned model
        original_expander_idx = self.expand_layer.idx
        assert len(original_expander_idx) == len(out_channel_mask), "the output channel should be consistent"
        pruned_expander_idx = original_expander_idx[out_channel_mask]
        idx = np.squeeze(pruned_expander_idx)
        self.expand_layer.idx = idx

    def config(self):
        if self.identity:
            return [0] * 4
        else:
            return [self.conv1.out_channels, self.conv2.out_channels,
                    self.conv3.out_channels, self.conv1.in_channels]

    @property
    def identity(self):
        """
        Whether the block is an identity block.
        Note: when there is a downsample module in the block, the downsample will NOT be
        pruned. In that case, the block will NOT be an identity mapping.
        """
        return self._identity

    @identity.setter
    def identity(self, value):
        """
        Set the block as an identity block, i.e., the whole block is pruned.
        Note: when there is a downsample module in the block, the downsample will NOT be
        pruned. In that case, the block will NOT be an identity mapping.
        """
        self._identity = value

    def _compute_flops_weight(self, scaling: bool):
        # check if the FLOPs are computed
        # the check might be time-consuming; if the FLOPs are not computed, there will be an error
        # for i in range(1, 4):
        #     conv_layer: nn.Conv2d = getattr(self, f"conv{i}")
        #     if not hasattr(conv_layer, "d_flops_in") or not hasattr(conv_layer, "d_flops_out"):
        #         raise AssertionError("Need compute FLOPs for each conv layer first!")

        conv1_flops_weight = self.conv1.d_flops_out + self.conv2.d_flops_in
        conv2_flops_weight = self.conv2.d_flops_out + self.conv3.d_flops_in
        conv3_flops_weight = self.conv3.d_flops_out

        def scale(raw_value):
            return (raw_value - self.raw_weight_min) / (self.raw_weight_max - self.raw_weight_min)

        def identity(raw_value):
            return raw_value

        if scaling:
            scaling_func = scale
        else:
            scaling_func = identity

        self.conv_flops_weight = (scaling_func(conv1_flops_weight),
                                  scaling_func(conv2_flops_weight),
                                  scaling_func(conv3_flops_weight))

    def get_conv_flops_weight(self, update: bool, scaling: bool) -> typing.Tuple[float, float, float]:
        if update:
            self._compute_flops_weight(scaling=scaling)

        assert self._conv_flops_weight is not None
        return self._conv_flops_weight

    @property
    def conv_flops_weight(self) -> typing.Tuple[float, float, float]:
        """This property is supposed to be used in the forward pass.
        To pass more arguments, call `get_conv_flops_weight`."""
        return self.get_conv_flops_weight(update=True, scaling=True)

    @conv_flops_weight.setter
    def conv_flops_weight(self, weight: typing.Tuple[float, float, float]):
        assert len(weight) == 3, f"The length of the convolution FLOPs weight should be 3, got {len(weight)}"
        self._conv_flops_weight = weight

    def layer_wise_collection(self) -> List[SparseLayerCollection]:
        layer_weight: typing.Tuple[float, float, float] = self.get_conv_flops_weight(update=True, scaling=True)
        collection: List[SparseLayerCollection] = [
            SparseLayerCollection(conv_layer=self.conv1,
                                  bn_layer=self.bn1,
                                  sparse_layer=self.gate1 if self.use_gate else self.bn1,
                                  layer_weight=layer_weight[0]),
            SparseLayerCollection(conv_layer=self.conv2,
                                  bn_layer=self.bn2,
                                  sparse_layer=self.gate2 if self.use_gate else self.bn2,
                                  layer_weight=layer_weight[1]),
            SparseLayerCollection(conv_layer=self.conv3,
                                  bn_layer=self.bn3,
                                  sparse_layer=self.gate3 if self.use_gate else self.bn3,
                                  layer_weight=layer_weight[2])]
        return collection
class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, conv_in, hidden_dim,
                 use_shortcut_connection: bool, use_gate: bool, input_mask: bool, pw: bool):
        """
        :param inp: the number of input channels of the block
        :param oup: the number of output channels of the block
        :param stride: the stride of the depth-wise conv layer
        :param conv_in: the input dimension of the conv layer
        :param hidden_dim: the inner dimension of the conv layers
        :param use_shortcut_connection: whether to use the shortcut connection
        :param use_gate: whether to use SparseGate layers between conv layers
        :param input_mask: whether to use a mask at the beginning of the model.
            The mask is supposed to replace the first gate before the input.
        """
        super(InvertedResidual, self).__init__()
        self.stride = stride
        self.output_channel = oup
        self.hidden_dim = hidden_dim
        self._gate = use_gate
        self._conv_flops_weight: typing.Optional[typing.Tuple[float, float]] = None
        assert stride in [1, 2]

        if hidden_dim != 0:
            # the ChannelSelect layer is supposed to select conv_in channels from inp channels
            self.select = ChannelSelect(inp)
            # this gate is not affected by the gate option;
            # the gate should be kept in the fine-tuning stage
            self.input_gate = None

            # this part is moved to the flat_mobilenet_settings method, so it is commented out
            # hidden_dim = int(round(inp * expand_ratio))
            # self.use_res_connect = self.stride == 1 and inp == oup
            self.use_res_connect = use_shortcut_connection

            layers = []
            self.pw = pw  # whether there is a point-wise conv in the block
            # if hidden_dim != inp:  # this part is moved to the config generation method
            if self.pw:
                if use_gate or input_mask:
                    self.input_gate = SparseGate(conv_in)
                # pw (use this bn to prune the hidden_dim)
                layers.append(ConvBNReLU(conv_in, hidden_dim, kernel_size=1, gate=use_gate))
                self.pw = True
            layers.extend([
                # dw (do not apply sparsity on this bn)
                ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, gate=False),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                # use this bn to prune the output dim
                nn.BatchNorm2d(oup),
            ])
            if use_gate:
                layers.append(SparseGate(oup))
            else:
                layers.append(Identity())
            layers.append(ChannelExpand(oup))
            self.conv = nn.Sequential(*layers)
        else:
            self.select = None
            self.input_gate = None
            self.conv = None
            self.use_res_connect = True

    def forward(self, x):
        original_input = x
        if self.conv is not None:
            # select channels from the input
            x = self.select(x)
            if self.input_gate is not None:
                x = self.input_gate(x)

        if self.use_res_connect:
            if self.conv is None:
                # the whole layer is pruned
                return original_input
            return original_input + self.conv(x)
        else:
            return self.conv(x)

    def _prune_whole_layer(self):
        """Set the layer as an identity mapping."""
        if not self.use_res_connect:
            raise ValueError("The network output will be unrelated to the input "
                             "if the whole block is pruned without the shortcut.")
        self.conv = None
        self.hidden_dim = 0
        self.pw = False

    def do_pruning(self, in_channel_mask: np.ndarray, pruner: Pruner):
        """
        Prune the block in place.

        :param in_channel_mask: a 0-1 vector indicating whether the corresponding channel
            should be pruned (0) or kept (1)
        :param pruner: the method to determine the pruning threshold.
            The pruner accepts a torch.Tensor as input and returns a threshold.
        """
        # prune the point-wise conv layer
        if self.pw:
            pw_layer = self.conv[0]
            in_channel_mask, input_gate_mask = prune_conv_layer(
                conv_layer=pw_layer[0],
                bn_layer=pw_layer[1],
                sparse_layer_in=self.input_gate if self.has_input_mask else None,
                sparse_layer_out=pw_layer.sparse_layer,
                in_channel_mask=None if self.has_input_mask else in_channel_mask,
                pruner=pruner,
                prune_output_mode="prune",
                prune_mode='default')

            if not np.any(in_channel_mask) or not np.any(input_gate_mask):
                # no channel left
                self._prune_whole_layer()
                return self.output_channel

            channel_select_idx = np.squeeze(np.argwhere(np.asarray(input_gate_mask)))
            if len(channel_select_idx.shape) == 0:
                # expand the single scalar to an array
                channel_select_idx = np.expand_dims(channel_select_idx, 0)
            elif len(channel_select_idx.shape) == 1 and channel_select_idx.shape[0] == 0:
                # nothing left
                raise NotImplementedError("No layer left in input channel")
            self.select.idx = channel_select_idx
            if self.has_input_mask:
                self.input_gate.do_pruning(input_gate_mask)

        # update the hidden dim
        self.hidden_dim = int(in_channel_mask.astype(int).sum())

        # prune the output of the dw layer
        # this in_channel_mask is supposed to be unchanged
        dw_layer = self.conv[-5]
        in_channel_mask, _ = prune_conv_layer(
            conv_layer=dw_layer[0],
            bn_layer=dw_layer[1],
            sparse_layer_in=None,
            sparse_layer_out=dw_layer.sparse_layer,
            in_channel_mask=in_channel_mask,
            pruner=pruner,
            prune_output_mode="same",
            prune_mode='default')

        # prune the pw-linear layer (the last conv layer)
        out_channel_mask, _ = prune_conv_layer(
            conv_layer=self.conv[-4],
            bn_layer=self.conv[-3],
            sparse_layer_in=None,
            sparse_layer_out=self.conv[-2] if isinstance(self.conv[-2], SparseGate) else self.conv[-3],
            in_channel_mask=in_channel_mask,
            pruner=pruner,
            prune_output_mode="prune",
            prune_mode='default')

        # update output_channel
        self.output_channel = int(out_channel_mask.astype(int).sum())

        # if self.use_res_connect:
        # do padding to allow adding with the residual connection
        # the output dim is unchanged
        expander: ChannelExpand = self.conv[-1]
        # note that the idx of the expander might already be set in a pruned model
        original_expander_idx = expander.idx
        assert len(original_expander_idx) == len(out_channel_mask), "the output channel should be consistent"
        pruned_expander_idx = original_expander_idx[out_channel_mask]
        idx = np.squeeze(pruned_expander_idx)
        expander.idx = idx

        # return the output dim
        # the output dim is kept unchanged
        return expander.channel_num

    def get_config(self):
        # format:
        # conv_in, hidden_dim, output_channel, repeat_num, stride, use_shortcut, pw
        if self.conv is not None:
            if self.pw:
                conv_in = self.pw_layer[0].in_channels
            else:
                conv_in = self.hidden_dim
        else:
            conv_in = 0
        return [conv_in, self.hidden_dim, self.output_channel, 1, self.stride, self.use_res_connect, self.pw]

    def _compute_flops_weight(self, scaling: bool):
        """
        Compute the FLOPs weight for each layer according to `d_flops_in` and `d_flops_out`.
        """
        # before computing the flops weight, need to
        # 1. set self.raw_weight_max and self.raw_weight_min
        # 2. compute d_flops_out and d_flops_in
        if self.pw:
            pw_flops_weight = self.conv[0][0].d_flops_out + self.conv[1][0].d_flops_in + self.conv[2].d_flops_in
            linear_flops_weight = self.conv[2].d_flops_in
        else:
            # there is no point-wise layer
            pw_flops_weight = None
            linear_flops_weight = self.conv[1].d_flops_in

        def scale(raw_value):
            if raw_value is None:
                return None
            return (raw_value - self.raw_weight_min) / (self.raw_weight_max - self.raw_weight_min)

        def identity(raw_value):
            return raw_value

        if scaling:
            scaling_func = scale
        else:
            scaling_func = identity

        self.conv_flops_weight = (scaling_func(pw_flops_weight),
                                  scaling_func(linear_flops_weight))

    def get_conv_flops_weight(self, update: bool, scaling: bool) -> typing.Tuple[float, float]:
        # force the update, because _compute_flops_weight is very fast
        # if update:
        #     self._compute_flops_weight(scaling=scaling)
        self._compute_flops_weight(scaling=scaling)

        assert self._conv_flops_weight is not None
        return self._conv_flops_weight

    @property
    def conv_flops_weight(self) -> typing.Tuple[float, float]:
        """This property is supposed to be used in the forward pass.
        To pass more arguments, call `get_conv_flops_weight`."""
        return self.get_conv_flops_weight(update=True, scaling=True)

    @conv_flops_weight.setter
    def conv_flops_weight(self, weight: typing.Tuple[float, float]):
        assert len(weight) == 2, f"The length of the convolution FLOPs weight should be 2, got {len(weight)}"
        self._conv_flops_weight = weight

    @property
    def pw_layer(self) -> typing.Tuple[nn.Conv2d, nn.BatchNorm2d, typing.Optional[SparseGate]]:
        """Get the point-wise layer (conv, bn, gate)."""
        if not self.pw:
            return [None, None, None]
        if self.conv is None:
            return [None, None, None]

        pw_conv_bn_relu = self.conv[0]
        pw_conv = pw_conv_bn_relu[0]
        pw_bn = pw_conv_bn_relu[1]
        if self._gate:
            pw_gate = pw_conv_bn_relu[2]
        else:
            pw_gate = None

        return pw_conv, pw_bn, pw_gate

    @property
    def has_input_mask(self) -> bool:
        return self.input_gate is not None

    @property
    def linear_layer(self) -> typing.Tuple[nn.Conv2d, nn.BatchNorm2d, typing.Optional[SparseGate]]:
        """Get the linear layer (conv, bn, gate)."""
        if self.pw:
            linear_conv_idx = 2
        else:
            linear_conv_idx = 1

        linear_conv = self.conv[linear_conv_idx]
        linear_bn = self.conv[linear_conv_idx + 1]
        if self._gate:
            linear_gate = self.conv[linear_conv_idx + 2]
        else:
            linear_gate = None

        return linear_conv, linear_bn, linear_gate
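
# Construction sketch (not part of the original file): an InvertedResidual block in an
# unpruned configuration where conv_in == inp and hidden_dim is the expanded width
# (here 6 * inp). The numbers (inp=32, expand ratio 6) are illustrative assumptions, and
# the example assumes ChannelSelect / ChannelExpand default to keeping all channels.
def _example_inverted_residual():
    import torch
    block = InvertedResidual(inp=32, oup=32, stride=1, conv_in=32, hidden_dim=192,
                             use_shortcut_connection=True, use_gate=True,
                             input_mask=False, pw=True)
    y = block(torch.randn(1, 32, 28, 28))   # shortcut + conv path: (1, 32, 28, 28)
    # get_config() format: [conv_in, hidden_dim, output_channel, repeat_num, stride, use_shortcut, pw]
    return y.shape, block.get_config()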