Code Example #1
    def __init__(self, inplanes: int, outplanes: int, cfg: List[int], gate: bool, stride=1, downsample=None,
                 expand=False):
        """

        :param inplanes: the input dimension of the block
        :param outplanes: the output dimension of the block
        :param cfg: the output dimension of each convolution layer
            config format:
            [conv1_out, conv2_out, conv3_out, conv1_in]
        :param gate: whether to use a SparseGate between conv layers
        :param stride: the stride of the first convolution layer
        :param downsample: the optional downsample module applied on the residual connection
        :param expand: whether to use a ChannelExpand layer in the block
        """
        super(Bottleneck, self).__init__()

        conv_in = cfg[3]
        self.use_gate = gate
        self._identity = False
        if 0 in cfg or conv_in == 0:
            # the whole block is pruned
            self.identity = True
        else:
            # the main body of the block

            # add a SparseGate before the first conv
            # to enable the pruning of the input dimension for further reducing computational complexity
            self.select = ChannelSelect(inplanes)  # after select, the channel number of feature map is conv_in
            self.input_gate = SparseGate(conv_in)

            self.conv1 = nn.Conv2d(conv_in, cfg[0], kernel_size=1, bias=False)
            self.bn1 = nn.BatchNorm2d(cfg[0])
            self.gate1 = SparseGate(cfg[0]) if gate else Identity()

            self.conv2 = nn.Conv2d(cfg[0], cfg[1], kernel_size=3, stride=stride,
                                   padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(cfg[1])
            self.gate2 = SparseGate(cfg[1]) if gate else Identity()

            self.conv3 = nn.Conv2d(cfg[1], cfg[2], kernel_size=1, bias=False)
            self.bn3 = nn.BatchNorm2d(cfg[2])
            self.gate3 = SparseGate(cfg[2]) if gate else Identity()

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        self.expand = expand
        self.expand_layer = ChannelExpand(outplanes) if expand else None
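
The examples on this page rely on several helper modules (Identity, SparseGate, ChannelSelect, ChannelExpand) whose definitions are not shown here. The sketch below is only a guess at how such modules could behave, inferred from the way they are used in these examples; the repository's actual implementations differ (for instance, SparseGate also exposes a do_pruning method).

import torch
import torch.nn as nn


# Hypothetical minimal stand-ins, NOT the repository's implementations.
class Identity(nn.Module):
    """Pass-through module; returns its input unchanged."""
    def forward(self, x):
        return x


class SparseGate(nn.Module):
    """Per-channel learnable scaling factor; a sparsity penalty can drive
    individual channel scales toward zero so they can be pruned."""
    def __init__(self, channels: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels, 1, 1))

    def forward(self, x):
        return x * self.weight


class ChannelSelect(nn.Module):
    """Keep only the channels listed in `idx` (set during pruning)."""
    def __init__(self, channels: int):
        super().__init__()
        self.idx = torch.arange(channels)

    def forward(self, x):
        return x[:, self.idx, :, :]


class ChannelExpand(nn.Module):
    """Scatter a pruned feature map back to `channel_num` channels so it can
    be added to the residual branch."""
    def __init__(self, channel_num: int):
        super().__init__()
        self.channel_num = channel_num
        self.idx = torch.arange(channel_num)

    def forward(self, x):
        out = torch.zeros(x.size(0), self.channel_num, x.size(2), x.size(3),
                          device=x.device, dtype=x.dtype)
        out[:, self.idx, :, :] = x
        return out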
Code Example #2
    def __init__(self, in_planes, outplanes, cfg, stride=1, option='A', gate=False, use_input_mask=False):
        super(BasicBlock, self).__init__()
        self.gate = gate
        self.use_input_mask = use_input_mask
        if len(cfg) != 3:
            raise ValueError("cfg len should be 3, got {}".format(cfg))
        conv_in = cfg[2]

        self.is_empty_block = 0 in cfg

        if not self.is_empty_block:
            if self.use_input_mask:
                # input channel: in_planes, output_channel: conv_in
                self.input_channel_selector = ChannelSelect(in_planes)
                # input channel: conv_in, output_channel: conv_in
                self.input_mask = SparseGate(conv_in)
            else:
                self.input_channel_selector = None
                self.input_mask = None
            self.conv1 = nn.Conv2d(conv_in, cfg[0], kernel_size=3, stride=stride, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(cfg[0])
            self.gate1 = SparseGate(cfg[0]) if self.gate else Identity()

            self.conv2 = nn.Conv2d(cfg[0], cfg[1], kernel_size=3, stride=1, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(cfg[1])
            self.gate2 = SparseGate(cfg[1]) if self.gate else Identity()

            self.expand_layer = ChannelExpand(outplanes)
        else:
            self.conv1 = Identity()
            self.bn1 = Identity()
            self.conv2 = Identity()
            self.bn2 = Identity()

            self.expand_layer = Identity()

        self.shortcut = nn.Sequential()  # do nothing
        if stride != 1 or in_planes != outplanes:
            if option == 'A':
                """For CIFAR10 ResNet paper uses option A."""
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (
                                                0, 0, 0, 0, (outplanes - in_planes) // 2, (outplanes - in_planes) // 2),
                                                  "constant",
                                                  0))
            elif option == 'B':
                raise NotImplementedError("Option B is not implemented")
Code Example #3
File: vgg.py  Project: xue1234730/Prune
    def __init__(self, conv: nn.Conv2d, batch_norm: bool, output_channel: int,
                 gate: bool):
        super().__init__()
        self.conv = conv
        self.gate: bool = gate

        if batch_norm:
            if isinstance(self.conv, nn.Conv2d):
                self.batch_norm = nn.BatchNorm2d(output_channel)
            elif isinstance(self.conv, nn.Linear):
                self.batch_norm = nn.BatchNorm1d(output_channel)
        else:
            self.batch_norm = Identity()  # Identity returns its input unchanged

        if gate:
            self.sparse_gate = SparseGate(output_channel)

        self.relu = nn.ReLU(inplace=True)
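
The block above picks nn.BatchNorm2d or nn.BatchNorm1d depending on whether it wraps a convolution or a linear layer, since the two normalizations expect differently shaped inputs. A stand-alone illustration of that distinction (independent of the repository's classes):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
fc = nn.Linear(64, 10)

bn2d = nn.BatchNorm2d(64)  # normalizes (N, C, H, W) feature maps
bn1d = nn.BatchNorm1d(10)  # normalizes (N, C) activations

x = torch.randn(4, 3, 32, 32)
print(bn2d(conv(x)).shape)                 # torch.Size([4, 64, 32, 32])
print(bn1d(fc(torch.randn(4, 64))).shape)  # torch.Size([4, 10])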
Code Example #4
    def __init__(self, in_planes, out_planes, gate: bool, kernel_size=3, stride=1, groups=1):
        """
        A sequence of modules
            - Conv
            - BN
            - Gate (optional, if gate is True)
            - ReLU6
        """
        padding = (kernel_size - 1) // 2

        layers = [
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
        ]
        if gate:
            layers.append(SparseGate(out_planes))
        layers.append(nn.ReLU6(inplace=True))
        super(ConvBNReLU, self).__init__(*layers)
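
With padding set to (kernel_size - 1) // 2 the spatial size is preserved for stride 1, and because ConvBNReLU subclasses nn.Sequential the whole stack can be called directly. A stand-alone sketch of the gate-free case (layer sizes are made up):

import torch
import torch.nn as nn

def conv_bn_relu(in_planes, out_planes, kernel_size=3, stride=1, groups=1):
    # Equivalent structure for gate=False: Conv -> BN -> ReLU6 in one Sequential.
    padding = (kernel_size - 1) // 2
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding,
                  groups=groups, bias=False),
        nn.BatchNorm2d(out_planes),
        nn.ReLU6(inplace=True),
    )

block = conv_bn_relu(32, 32, kernel_size=3, stride=1, groups=32)  # depth-wise conv
print(block(torch.randn(2, 32, 56, 56)).shape)  # torch.Size([2, 32, 56, 56])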
Code Example #5
class BasicBlock(BuildingBlock):
    expansion = 1

    def __init__(self, in_planes, outplanes, cfg, stride=1, option='A', gate=False, use_input_mask=False):
        super(BasicBlock, self).__init__()
        self.gate = gate
        self.use_input_mask = use_input_mask
        if len(cfg) != 3:
            raise ValueError("cfg len should be 3, got {}".format(cfg))
        conv_in = cfg[2]

        self.is_empty_block = 0 in cfg

        if not self.is_empty_block:
            if self.use_input_mask:
                # input channel: in_planes, output_channel: conv_in
                self.input_channel_selector = ChannelSelect(in_planes)
                # input channel: conv_in, output_channel: conv_in
                self.input_mask = SparseGate(conv_in)
            else:
                self.input_channel_selector = None
                self.input_mask = None
            self.conv1 = nn.Conv2d(conv_in, cfg[0], kernel_size=3, stride=stride, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(cfg[0])
            self.gate1 = SparseGate(cfg[0]) if self.gate else Identity()

            self.conv2 = nn.Conv2d(cfg[0], cfg[1], kernel_size=3, stride=1, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(cfg[1])
            self.gate2 = SparseGate(cfg[1]) if self.gate else Identity()

            self.expand_layer = ChannelExpand(outplanes)
        else:
            self.conv1 = Identity()
            self.bn1 = Identity()
            self.conv2 = Identity()
            self.bn2 = Identity()

            self.expand_layer = Identity()

        self.shortcut = nn.Sequential()  # do nothing
        if stride != 1 or in_planes != outplanes:
            if option == 'A':
                """For CIFAR10 ResNet paper uses option A."""
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (
                                                0, 0, 0, 0, (outplanes - in_planes) // 2, (outplanes - in_planes) // 2),
                                                  "constant",
                                                  0))
            elif option == 'B':
                raise NotImplementedError("Option B is not implemented")
                # self.shortcut = nn.Sequential(
                #     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                #     nn.BatchNorm2d(self.expansion * planes)
                # )

    def forward(self, x):
        if self.is_empty_block:
            out = self.shortcut(x)
            out = F.relu(out)
            return out
        else:
            if self.use_input_mask:
                out = self.input_channel_selector(x)
                out = self.input_mask(out)
            else:
                out = x

            # relu-gate and gate-relu are the same
            out = F.relu(self.bn1(self.conv1(out)))
            out = self.gate1(out)

            out = self.bn2(self.conv2(out))
            out = self.gate2(out)

            out = self.expand_layer(out)

            out += self.shortcut(x)
            out = F.relu(out)
            return out

    def do_pruning(self, pruner: Callable[[np.ndarray], float], prune_mode: str,
                   in_channel_mask: np.ndarray = None, prune_on=None) -> None:
        """
        Prune the block in place.
        Note: there is no ChannelExpand layer at the end of the block; after pruning, the output dimension might
        change, causing a dimension conflict.

        :param pruner: the method to determine the pruning threshold.
        :param prune_mode: same as `models.common.prune_conv_layer`

        """
        if in_channel_mask is not None:
            raise ValueError("Do not set in_channel_mask")

        if self.is_empty_block:
            return

        # keep input dim and output dim unchanged
        in_channel_mask = np.ones(self.conv1.in_channels)

        # prune conv1

        in_channel_mask, conv1_input_channel_mask = prune_conv_layer(conv_layer=self.conv1,
                                                                     bn_layer=self.bn1,
                                                                     sparse_layer_in=self.input_mask,
                                                                     sparse_layer=self.gate1 if self.gate else self.bn1,
                                                                     in_channel_mask=None if self.use_input_mask else in_channel_mask,
                                                                     pruner=pruner,
                                                                     prune_output_mode="prune",
                                                                     prune_mode=prune_mode,
                                                                     prune_on=prune_on, )
        if not np.any(in_channel_mask) or not np.any(conv1_input_channel_mask):
            # prune the entire block
            self.is_empty_block = True
            return

        if self.use_input_mask:
            # prune the input dimension of the first conv layer (conv1)
            channel_select_idx = np.squeeze(np.argwhere(np.asarray(conv1_input_channel_mask)))
            if len(channel_select_idx.shape) == 0:
                # expand the single scalar to array
                channel_select_idx = np.expand_dims(channel_select_idx, 0)
            elif len(channel_select_idx.shape) == 1 and channel_select_idx.shape[0] == 0:
                # nothing left
                # this code should not be executed, if there is no channel left,
                # the identity will be set as True and return (see code above)
                raise NotImplementedError("No layer left in input channel")
            self.input_channel_selector.idx = channel_select_idx
            self.input_mask.do_pruning(conv1_input_channel_mask)

        # prune conv2
        out_channel_mask, _ = prune_conv_layer(conv_layer=self.conv2,
                                               bn_layer=self.bn2,
                                               sparse_layer=self.gate2 if self.gate else self.bn2,
                                               in_channel_mask=in_channel_mask,
                                               pruner=pruner,
                                               prune_output_mode="prune",
                                               prune_mode=prune_mode,
                                               prune_on=prune_on, )
        if not np.any(out_channel_mask):
            # prune the entire block
            self.is_empty_block = True
            return

        # do padding allowing adding with residual connection
        # the output dim is unchanged
        # note that the idx of the expander might be set in a pruned model
        original_expander_idx = self.expand_layer.idx
        assert len(original_expander_idx) == len(out_channel_mask), "the output channel should be consistent"
        pruned_expander_idx = original_expander_idx[out_channel_mask]
        idx = np.squeeze(pruned_expander_idx)
        if len(idx.shape) == 0:
            # expand 0-d idx
            idx = np.expand_dims(idx, 0)
            pass
        self.expand_layer.idx = idx

    def config(self) -> typing.Tuple[int, int, int]:
        if self.is_empty_block:
            return 0, 0, 0
        return self.conv1.out_channels, self.conv2.out_channels, self.conv1.in_channels

    @property
    def expand_idx(self):
        raise NotImplementedError()
        if self.is_empty_block:
            return None
        return self.expand_layer.idx

    @expand_idx.setter
    def expand_idx(self, value):
        if self.is_empty_block:
            if value is not None:
                raise ValueError(f"The expand_idx of the empty block is supposed to be None, got {value}")
            # do nothing for an empty block
            return
        self.expand_layer.idx = value

    def _compute_flops_weight(self, scaling):
        conv1_flops_weight = self.conv1.d_flops_out + self.conv2.d_flops_in
        conv2_flops_weight = self.conv2.d_flops_out

        def scale(raw_value):
            if raw_value is None:
                return None
            return (raw_value - self.raw_weight_min) / (self.raw_weight_max - self.raw_weight_min)

        def identity(raw_value):
            return raw_value

        if scaling:
            scaling_func = scale
        else:
            scaling_func = identity

        self.conv_flops_weight = (scaling_func(conv1_flops_weight),
                                  scaling_func(conv2_flops_weight))

    def get_conv_flops_weight(self, update: bool, scaling: bool):
        # force an update: recomputing the flops weight is very cheap
        self._compute_flops_weight(scaling=scaling)

        assert self._conv_flops_weight is not None
        return self._conv_flops_weight

    @property
    def conv_flops_weight(self) -> typing.Tuple[float, float]:
        """This method is supposed to used in forward pass.
        To use more argument, call `get_conv_flops_weight`."""
        return self.get_conv_flops_weight(update=True, scaling=True)

    @conv_flops_weight.setter
    def conv_flops_weight(self, weight: typing.Tuple[float, float]):
        assert len(weight) == 2, f"The length of the convolution FLOPs weight should be 2, got {len(weight)}"
        self._conv_flops_weight = weight

    def get_sparse_modules(self):
        if self.gate:
            return self.gate1, self.gate2
        else:
            return self.bn1, self.bn2
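
do_pruning above delegates the per-layer work to prune_conv_layer, whose implementation lives elsewhere in the repository (models.common) and is not shown on this page. The following self-contained sketch only illustrates the general idea, thresholding BatchNorm scaling factors to build an output-channel mask and rebuilding smaller layers; the real prune_conv_layer works in place and additionally handles sparse_layer_in, prune_output_mode and prune_mode.

import torch
import torch.nn as nn

def prune_by_bn_scale(conv: nn.Conv2d, bn: nn.BatchNorm2d, threshold: float):
    """Toy out-channel pruning driven by BN scaling factors.
    Illustrative only -- not the repository's prune_conv_layer."""
    out_mask = bn.weight.detach().abs() > threshold    # bool mask over output channels
    keep = out_mask.nonzero(as_tuple=True)[0]          # indices of surviving channels

    new_conv = nn.Conv2d(conv.in_channels, len(keep), conv.kernel_size,
                         conv.stride, conv.padding, bias=False)
    new_conv.weight.data = conv.weight.data[keep].clone()

    new_bn = nn.BatchNorm2d(len(keep))
    new_bn.weight.data = bn.weight.data[keep].clone()
    new_bn.bias.data = bn.bias.data[keep].clone()
    new_bn.running_mean = bn.running_mean[keep].clone()
    new_bn.running_var = bn.running_var[keep].clone()
    return new_conv, new_bn, out_mask

conv, bn = nn.Conv2d(16, 32, 3, padding=1, bias=False), nn.BatchNorm2d(32)
bn.weight.data.uniform_(0.0, 1.0)  # pretend training pushed some scales toward zero
new_conv, new_bn, mask = prune_by_bn_scale(conv, bn, threshold=0.5)
print(new_conv.weight.shape, int(mask.sum()))  # roughly half the channels survive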
Code Example #6
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes: int, outplanes: int, cfg: List[int], gate: bool, stride=1, downsample=None,
                 expand=False):
        """

        :param inplanes: the input dimension of the block
        :param outplanes: the output dimension of the block
        :param cfg: the output dimension of each convolution layer
            config format:
            [conv1_out, conv2_out, conv3_out, conv1_in]
        :param gate: whether to use a SparseGate between conv layers
        :param stride: the stride of the first convolution layer
        :param downsample: the optional downsample module applied on the residual connection
        :param expand: whether to use a ChannelExpand layer in the block
        """
        super(Bottleneck, self).__init__()

        conv_in = cfg[3]
        self.use_gate = gate
        self._identity = False
        if 0 in cfg or conv_in == 0:
            # the whole block is pruned
            self.identity = True
        else:
            # the main body of the block

            # add a SparseGate before the first conv
            # to enable the pruning of the input dimension for further reducing computational complexity
            self.select = ChannelSelect(inplanes)  # after select, the channel number of feature map is conv_in
            self.input_gate = SparseGate(conv_in)

            self.conv1 = nn.Conv2d(conv_in, cfg[0], kernel_size=1, bias=False)
            self.bn1 = nn.BatchNorm2d(cfg[0])
            self.gate1 = SparseGate(cfg[0]) if gate else Identity()

            self.conv2 = nn.Conv2d(cfg[0], cfg[1], kernel_size=3, stride=stride,
                                   padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(cfg[1])
            self.gate2 = SparseGate(cfg[1]) if gate else Identity()

            self.conv3 = nn.Conv2d(cfg[1], cfg[2], kernel_size=1, bias=False)
            self.bn3 = nn.BatchNorm2d(cfg[2])
            self.gate3 = SparseGate(cfg[2]) if gate else Identity()

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        self.expand = expand
        self.expand_layer = ChannelExpand(outplanes) if expand else None

    def forward(self, x):
        residual = x

        if not self.identity:
            out = self.select(x)
            out = self.input_gate(out)

            out = self.conv1(out)
            out = self.bn1(out)
            out = self.gate1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)
            out = self.gate2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.bn3(out)
            out = self.gate3(out)
        else:
            # the whole layer is pruned
            out = x

        if self.downsample:
            residual = self.downsample(x)

        if self.expand_layer:
            out = self.expand_layer(out)

        out += residual
        out = self.relu(out)

        return out

    def do_pruning(self, pruner: Pruner, prune_mode: str) -> None:
        """
        Prune the block in place.
        Note: there is no ChannelExpand layer at the end of the block; after pruning, the output dimension might
        change, causing a dimension conflict.

        :param pruner: the method to determine the pruning threshold.
        :param prune_mode: same as `models.common.prune_conv_layer`

        """

        # keep input dim and output dim unchanged
        # prune conv1
        in_channel_mask, input_gate_mask = models.common.prune_conv_layer(conv_layer=self.conv1,
                                                                          bn_layer=self.bn1,
                                                                          # prune the input gate
                                                                          sparse_layer_in=self.input_gate,
                                                                          sparse_layer_out=self.gate1 if self.use_gate else self.bn1,
                                                                          in_channel_mask=None,
                                                                          pruner=pruner,
                                                                          prune_output_mode="prune",
                                                                          prune_mode=prune_mode)
        # this layer has no channel left. the whole block is pruned
        if not np.any(in_channel_mask) or not np.any(input_gate_mask):
            self.identity = True
            return

        # prune the input dimension of the first conv layer (conv1)
        channel_select_idx = np.squeeze(np.argwhere(np.asarray(input_gate_mask)))
        if len(channel_select_idx.shape) == 0:
            # expand the single scalar to array
            channel_select_idx = np.expand_dims(channel_select_idx, 0)
        elif len(channel_select_idx.shape) == 1 and channel_select_idx.shape[0] == 0:
            # nothing left
            # this code should not be executed, if there is no channel left,
            # the identity will be set as True and return (see code above)
            raise NotImplementedError("No layer left in input channel")
        self.select.idx = channel_select_idx
        self.input_gate.do_pruning(input_gate_mask)

        # prune conv2
        in_channel_mask, _ = models.common.prune_conv_layer(conv_layer=self.conv2,
                                                            bn_layer=self.bn2,
                                                            sparse_layer_in=None,
                                                            sparse_layer_out=self.gate2 if self.use_gate else self.bn2,
                                                            in_channel_mask=in_channel_mask,
                                                            pruner=pruner,
                                                            prune_output_mode="prune",
                                                            prune_mode=prune_mode)

        remain_channel_num = np.sum(in_channel_mask == 1)
        if remain_channel_num == 0:
            # this layer has no channel left. the whole block is pruned
            self.identity = True
            return

        # prune conv3
        out_channel_mask, _ = models.common.prune_conv_layer(conv_layer=self.conv3,
                                                             bn_layer=self.bn3,
                                                             sparse_layer_in=None,
                                                             sparse_layer_out=self.gate3 if self.use_gate else self.bn3,
                                                             in_channel_mask=in_channel_mask,
                                                             pruner=pruner,
                                                             prune_output_mode="prune",
                                                             prune_mode=prune_mode)

        remain_channel_num = np.sum(out_channel_mask == 1)
        if remain_channel_num == 0:
            # this layer has no channel left. the whole block is pruned
            self.identity = True
            return

        # do not prune downsample layers
        # if need pruning downsample layers (especially with gate), remember to add gate to downsample layer

        # if self.use_res_connect:
        # do padding allowing adding with residual connection
        # the output dim is unchanged
        # note that the idx of the expander might be set in a pruned model
        original_expander_idx = self.expand_layer.idx
        assert len(original_expander_idx) == len(out_channel_mask), "the output channel should be consistent"
        pruned_expander_idx = original_expander_idx[out_channel_mask]
        idx = np.squeeze(pruned_expander_idx)
        if len(idx.shape) == 0:
            # expand 0-d idx to a 1-d array (same handling as in BasicBlock.do_pruning)
            idx = np.expand_dims(idx, 0)
        self.expand_layer.idx = idx

    def config(self):
        if self.identity:
            return [0] * 4
        else:
            return [self.conv1.out_channels, self.conv2.out_channels, self.conv3.out_channels, self.conv1.in_channels]

    @property
    def identity(self):
        """
        Whether the block is an identity block.
        Note: when there is a downsample module in the block, the downsample will NOT
        be pruned. In this case, the block will NOT be an identity mapping.
        """
        return self._identity

    @identity.setter
    def identity(self, value):
        """
        Set the block as an identity block, which is equivalent to pruning the whole block.
        Note: when there is a downsample module in the block, the downsample will NOT
        be pruned. In this case, the block will NOT be an identity mapping.
        """
        self._identity = value

    def _compute_flops_weight(self, scaling: bool):
        # check whether the FLOPs have been computed
        # the check might be time-consuming; if the FLOPs have not been computed, an error will be raised
        # for i in range(1, 4):
        #     conv_layer: nn.Conv2d = getattr(self, f"conv{i}")
        #     if not hasattr(conv_layer, "d_flops_in") or not hasattr(conv_layer, "d_flops_out"):
        #         raise AssertionError("Need compute FLOPs for each conv layer first!")

        conv1_flops_weight = self.conv1.d_flops_out + self.conv2.d_flops_in
        conv2_flops_weight = self.conv2.d_flops_out + self.conv3.d_flops_in
        conv3_flops_weight = self.conv3.d_flops_out

        def scale(raw_value):
            return (raw_value - self.raw_weight_min) / (self.raw_weight_max - self.raw_weight_min)

        def identity(raw_value):
            return raw_value

        if scaling:
            scaling_func = scale
        else:
            scaling_func = identity

        self.conv_flops_weight = (scaling_func(conv1_flops_weight),
                                  scaling_func(conv2_flops_weight),
                                  scaling_func(conv3_flops_weight))

    def get_conv_flops_weight(self, update: bool, scaling: bool) -> typing.Tuple[float, float, float]:
        if update:
            self._compute_flops_weight(scaling=scaling)

        assert self._conv_flops_weight is not None
        return self._conv_flops_weight

    @property
    def conv_flops_weight(self) -> typing.Tuple[float, float, float]:
        """This method is supposed to used in forward pass.
        To use more argument, call `get_conv_flops_weight`."""
        return self.get_conv_flops_weight(update=True, scaling=True)

    @conv_flops_weight.setter
    def conv_flops_weight(self, weight: typing.Tuple[int, int, int]):
        assert len(weight) == 3, f"The length of the convolution FLOPs weight should be 3, got {len(weight)}"
        self._conv_flops_weight = weight

    def layer_wise_collection(self) -> List[SparseLayerCollection]:
        layer_weight: typing.Tuple[float, float, float] = self.get_conv_flops_weight(update=True, scaling=True)
        collection: List[SparseLayerCollection] = [SparseLayerCollection(conv_layer=self.conv1,
                                                                         bn_layer=self.bn1,
                                                                         sparse_layer=self.gate1 if self.use_gate else self.bn1,
                                                                         layer_weight=layer_weight[0]),
                                                   SparseLayerCollection(conv_layer=self.conv2,
                                                                         bn_layer=self.bn2,
                                                                         sparse_layer=self.gate2 if self.use_gate else self.bn2,
                                                                         layer_weight=layer_weight[1]),
                                                   SparseLayerCollection(conv_layer=self.conv3,
                                                                         bn_layer=self.bn3,
                                                                         sparse_layer=self.gate3 if self.use_gate else self.bn3,
                                                                         layer_weight=layer_weight[2])]
        return collection
        pass
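
In do_pruning above, the new output-channel mask is composed with the expander's existing idx (original_expander_idx[out_channel_mask]), so repeated pruning passes stay expressed in the original output-channel space. A tiny NumPy illustration with made-up indices:

import numpy as np

# The expander currently scatters 4 surviving channels into positions 1, 3, 5, 6
# of the full-width output.
original_expander_idx = np.array([1, 3, 5, 6])

# A second pruning pass keeps only the 1st and 3rd of those surviving channels.
out_channel_mask = np.array([True, False, True, False])

pruned_expander_idx = original_expander_idx[out_channel_mask]
print(pruned_expander_idx)  # [1 5] -- still indices in the original output space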
Code Example #7
    def __init__(self, inp, oup, stride, conv_in, hidden_dim, use_shortcut_connection: bool,
                 use_gate: bool, input_mask: bool, pw: bool):
        """
        :param inp: the number of input channels of the block
        :param oup: the number of output channels of the block
        :param stride: the stride of the depth-wise conv layer
        :param conv_in: the input dimension of the conv layer
        :param hidden_dim: the inner dimension of the conv layers
        :param use_shortcut_connection: whether to use a shortcut connection
        :param use_gate: whether to use SparseGate layers between conv layers

        :param input_mask: whether to use a mask at the beginning of the model.
        The mask is supposed to replace the first gate before the input.
        """
        super(InvertedResidual, self).__init__()

        self.stride = stride
        self.output_channel = oup
        self.hidden_dim = hidden_dim

        self._gate = use_gate
        self._conv_flops_weight: typing.Optional[typing.Tuple[float, float]] = None
        assert stride in [1, 2]

        if hidden_dim != 0:

            # the ChannelSelect layer is supposed to select conv_in channels from inp channels
            self.select = ChannelSelect(inp)
            # this gate will not be affected by gate option
            # the gate should be kept in finetuning stage
            self.input_gate = None

            # this part is moved to flat_mobilenet_settings method, so comment it
            # hidden_dim = int(round(inp * expand_ratio))
            # self.use_res_connect = self.stride == 1 and inp == oup
            self.use_res_connect = use_shortcut_connection

            layers = []
            self.pw = pw  # whether there is a point-wise conv in the block
            # if hidden_dim != inp: # this part is moved to the config generation method
            if self.pw:
                if use_gate or input_mask:
                    self.input_gate = SparseGate(conv_in)

                # pw (use this bn to prune the hidden_dim)
                layers.append(ConvBNReLU(conv_in, hidden_dim, kernel_size=1, gate=use_gate))
                self.pw = True
            layers.extend([
                # dw (do not apply sparsity on this bn)
                ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, gate=False),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                # use this bn to prune the output dim
                nn.BatchNorm2d(oup),
            ])
            if use_gate:
                layers.append(SparseGate(oup))
            else:
                layers.append(Identity())
            layers.append(ChannelExpand(oup))
            self.conv = nn.Sequential(*layers)
        else:
            self.select = None
            self.input_gate = None
            self.conv = None
            self.use_res_connect = True
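
The do_pruning and pw_layer/linear_layer helpers in the next example address the layers built above by position: self.conv[0] is the point-wise ConvBNReLU, self.conv[-5] the depth-wise one, self.conv[-4] and self.conv[-3] the point-wise linear conv and its BN, self.conv[-2] the gate (or Identity), and self.conv[-1] the ChannelExpand. A sketch of that layout with plain torch modules standing in for the custom ones (channel numbers are made up):

import torch.nn as nn

conv_in, hidden_dim, oup = 16, 96, 24
conv = nn.Sequential(
    # [0] / [-6]: pw ConvBNReLU (1x1 expansion)
    nn.Sequential(nn.Conv2d(conv_in, hidden_dim, 1, bias=False),
                  nn.BatchNorm2d(hidden_dim), nn.ReLU6(inplace=True)),
    # [1] / [-5]: dw ConvBNReLU (3x3 depth-wise)
    nn.Sequential(nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim, bias=False),
                  nn.BatchNorm2d(hidden_dim), nn.ReLU6(inplace=True)),
    # [2] / [-4]: pw-linear 1x1 conv
    nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
    # [3] / [-3]: BN used to prune the output dim
    nn.BatchNorm2d(oup),
    # [4] / [-2]: SparseGate(oup) or Identity() -- stand-in here
    nn.Identity(),
    # [5] / [-1]: ChannelExpand(oup) -- stand-in here
    nn.Identity(),
)
print([type(m).__name__ for m in conv])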
Code Example #8
class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, conv_in, hidden_dim, use_shortcut_connection: bool,
                 use_gate: bool, input_mask: bool, pw: bool):
        """
        :param inp: the number of input channels of the block
        :param oup: the number of output channels of the block
        :param stride: the stride of the depth-wise conv layer
        :param conv_in: the input dimension of the conv layer
        :param hidden_dim: the inner dimension of the conv layers
        :param use_shortcut_connection: whether to use a shortcut connection
        :param use_gate: whether to use SparseGate layers between conv layers

        :param input_mask: whether to use a mask at the beginning of the model.
        The mask is supposed to replace the first gate before the input.
        """
        super(InvertedResidual, self).__init__()

        self.stride = stride
        self.output_channel = oup
        self.hidden_dim = hidden_dim

        self._gate = use_gate
        self._conv_flops_weight: typing.Optional[typing.Tuple[float, float]] = None
        assert stride in [1, 2]

        if hidden_dim != 0:

            # the ChannelSelect layer is supposed to select conv_in channels from inp channels
            self.select = ChannelSelect(inp)
            # this gate will not be affected by gate option
            # the gate should be kept in finetuning stage
            self.input_gate = None

            # this part is moved to flat_mobilenet_settings method, so comment it
            # hidden_dim = int(round(inp * expand_ratio))
            # self.use_res_connect = self.stride == 1 and inp == oup
            self.use_res_connect = use_shortcut_connection

            layers = []
            self.pw = pw  # whether there is a point-wise conv in the block
            # if hidden_dim != inp: # this part is moved to the config generation method
            if self.pw:
                if use_gate or input_mask:
                    self.input_gate = SparseGate(conv_in)

                # pw (use this bn to prune the hidden_dim)
                layers.append(ConvBNReLU(conv_in, hidden_dim, kernel_size=1, gate=use_gate))
                self.pw = True
            layers.extend([
                # dw (do not apply sparsity on this bn)
                ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, gate=False),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                # use this bn to prune the output dim
                nn.BatchNorm2d(oup),
            ])
            if use_gate:
                layers.append(SparseGate(oup))
            else:
                layers.append(Identity())
            layers.append(ChannelExpand(oup))
            self.conv = nn.Sequential(*layers)
        else:
            self.select = None
            self.input_gate = None
            self.conv = None
            self.use_res_connect = True

    def forward(self, x):
        original_input = x

        if self.conv is not None:
            # Select channels from input
            x = self.select(x)
            if self.input_gate is not None:
                x = self.input_gate(x)

        if self.use_res_connect:
            if self.conv is None:
                # the whole layer is pruned
                return original_input
            return original_input + self.conv(x)
        else:
            return self.conv(x)

    def _prune_whole_layer(self):
        """set the layer as a identity mapping"""
        if not self.use_res_connect:
            raise ValueError("The network will unrelated to the input if prune the whole block without the shortcut.")
        self.conv = None
        self.hidden_dim = 0
        self.pw = False
        pass

    def do_pruning(self, in_channel_mask: np.ndarray, pruner: Pruner):
        """
        Prune the block in place
        :param in_channel_mask: a 0-1 vector that indicates whether the corresponding channel should be pruned (0) or kept (1)
        :param pruner: the method to determine the pruning threshold.
        The pruner accepts a torch.Tensor as input and returns a threshold.
        """
        # prune the point-wise conv layer
        if self.pw:
            pw_layer = self.conv[0]
            in_channel_mask, input_gate_mask = prune_conv_layer(conv_layer=pw_layer[0],
                                                                bn_layer=pw_layer[1],
                                                                sparse_layer_in=self.input_gate if self.has_input_mask else None,
                                                                sparse_layer_out=pw_layer.sparse_layer,
                                                                in_channel_mask=None if self.has_input_mask else in_channel_mask,
                                                                pruner=pruner,
                                                                prune_output_mode="prune",
                                                                prune_mode='default')
            if not np.any(in_channel_mask) or not np.any(input_gate_mask):
                # no channel left
                self._prune_whole_layer()
                return self.output_channel

            channel_select_idx = np.squeeze(np.argwhere(np.asarray(input_gate_mask)))
            if len(channel_select_idx.shape) == 0:
                # expand the single scalar to array
                channel_select_idx = np.expand_dims(channel_select_idx, 0)
            elif len(channel_select_idx.shape) == 1 and channel_select_idx.shape[0] == 0:
                # nothing left
                raise NotImplementedError("No layer left in input channel")
            self.select.idx = channel_select_idx
            if self.has_input_mask:
                self.input_gate.do_pruning(input_gate_mask)

        # update the hidden dim
        self.hidden_dim = int(in_channel_mask.astype(int).sum())

        # prune the output of the dw layer
        # this in_channel_mask is supposed to stay unchanged
        dw_layer = self.conv[-5]
        in_channel_mask, _ = prune_conv_layer(conv_layer=dw_layer[0],
                                              bn_layer=dw_layer[1],
                                              sparse_layer_in=None,
                                              sparse_layer_out=dw_layer.sparse_layer,
                                              in_channel_mask=in_channel_mask,
                                              pruner=pruner,
                                              prune_output_mode="same",
                                              prune_mode='default')

        # prune the input of the pw-linear layer (the last conv layer)
        out_channel_mask, _ = prune_conv_layer(conv_layer=self.conv[-4],
                                               bn_layer=self.conv[-3],
                                               sparse_layer_in=None,
                                               sparse_layer_out=self.conv[-2] if isinstance(self.conv[-2],
                                                                                            SparseGate) else self.conv[
                                                   -3],
                                               in_channel_mask=in_channel_mask,
                                               pruner=pruner,
                                               prune_output_mode="prune",
                                               prune_mode='default')

        # update output_channel
        self.output_channel = int(out_channel_mask.astype(int).sum())

        # if self.use_res_connect:
        # do padding allowing adding with residual connection
        # the output dim is unchanged
        expander: ChannelExpand = self.conv[-1]
        # note that the idx of the expander might be set in a pruned model
        original_expander_idx = expander.idx
        assert len(original_expander_idx) == len(out_channel_mask), "the output channel should be consistent"
        pruned_expander_idx = original_expander_idx[out_channel_mask]
        idx = np.squeeze(pruned_expander_idx)
        expander.idx = idx
        pass

        # return the output dim
        # the output dim is kept unchanged
        return expander.channel_num

    def get_config(self):
        # format
        # conv_in, hidden_dim, output_channel, repeat_num, stride, use_shortcut
        if self.conv is not None:
            if self.pw:
                conv_in = self.pw_layer[0].in_channels
            else:
                conv_in = self.hidden_dim
        else:
            conv_in = 0

        return [conv_in, self.hidden_dim, self.output_channel, 1, self.stride, self.use_res_connect, self.pw]

    def _compute_flops_weight(self, scaling: bool):
        """
        compute the FLOPs weight for each layer according to `d_flops_in` and `d_flops_out`
        """
        # before compute the flops weight, need to
        # 1. set self.raw_weight_max and self.raw_weight_min
        # 2. compute d_flops_out and d_flops_in

        if self.pw:
            pw_flops_weight = self.conv[0][0].d_flops_out + self.conv[1][0].d_flops_in + self.conv[2].d_flops_in
            linear_flops_weight = self.conv[2].d_flops_in
        else:
            # there is no point-wise layer
            pw_flops_weight = None
            linear_flops_weight = self.conv[1].d_flops_in

        def scale(raw_value):
            if raw_value is None:
                return None
            return (raw_value - self.raw_weight_min) / (self.raw_weight_max - self.raw_weight_min)

        def identity(raw_value):
            return raw_value

        if scaling:
            scaling_func = scale
        else:
            scaling_func = identity

        self.conv_flops_weight = (scaling_func(pw_flops_weight),
                                  scaling_func(linear_flops_weight))

    def get_conv_flops_weight(self, update: bool, scaling: bool) -> typing.Tuple[float, float]:
        # always recompute the weight, because _compute_flops_weight is very fast
        # if update:
        #     self._compute_flops_weight(scaling=scaling)
        self._compute_flops_weight(scaling=scaling)

        assert self._conv_flops_weight is not None
        return self._conv_flops_weight

    @property
    def conv_flops_weight(self) -> typing.Tuple[float, float]:
        """This method is supposed to used in forward pass.
        To use more argument, call `get_conv_flops_weight`."""
        return self.get_conv_flops_weight(update=True, scaling=True)

    @conv_flops_weight.setter
    def conv_flops_weight(self, weight: typing.Tuple[float, float]):
        assert len(weight) == 2, f"The length of the convolution FLOPs weight should be 2, got {len(weight)}"
        self._conv_flops_weight = weight

    @property
    def pw_layer(self) -> typing.Tuple[nn.Conv2d, nn.BatchNorm2d, typing.Optional[SparseGate]]:
        """get the pixel-wise layer (conv, bn, gate)"""
        if not self.pw:
            return [None, None, None]
        if self.conv is None:
            return [None, None, None]
        pw_conv_bn_relu = self.conv[0]
        pw_conv = pw_conv_bn_relu[0]
        pw_bn = pw_conv_bn_relu[1]
        if self._gate:
            pw_gate = pw_conv_bn_relu[2]
        else:
            pw_gate = None

        return pw_conv, pw_bn, pw_gate

    @property
    def has_input_mask(self) -> bool:
        return self.input_gate is not None

    @property
    def linear_layer(self) -> typing.Tuple[nn.Conv2d, nn.BatchNorm2d, typing.Optional[SparseGate]]:
        """get the linear layer (conv, bn, gate)"""
        if self.pw:
            linear_conv_idx = 2
        else:
            linear_conv_idx = 1

        linear_conv = self.conv[linear_conv_idx]
        linear_bn = self.conv[linear_conv_idx + 1]
        if self._gate:
            linear_gate = self.conv[linear_conv_idx + 2]
        else:
            linear_gate = None

        return linear_conv, linear_bn, linear_gate
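
The pruner argument passed to do_pruning is documented above as a callable that takes the scaling factors of a sparse layer (a BatchNorm weight or a SparseGate weight) and returns a threshold. A hedged sketch of one possible pruner based on a simple percentile rule; the repository's actual Pruner interface may differ. Using torch.as_tensor also lets it accept the np.ndarray annotated in the BasicBlock example.

import torch

def percentile_pruner(ratio: float):
    """Build a pruner: given the scaling factors, return the threshold below
    which channels are cut. Illustrative only; the real Pruner may differ."""
    def pruner(scores) -> float:
        scores = torch.as_tensor(scores, dtype=torch.float32).detach()
        return torch.quantile(scores.abs().flatten(), ratio).item()
    return pruner

# Prune roughly the smallest 30% of channels according to their scaling factors.
pruner = percentile_pruner(0.3)
print(pruner(torch.tensor([0.01, 0.2, 0.5, 0.9, 1.3])))  # 0.26 (30th percentile)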