Example #1

import torch.nn as nn

from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule


class InvertedResidual(nn.Module):
    """MobileNetV2-style inverted residual block built on top of
    DepthwiseSeparableConvModule."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 dilation,
                 expand_ratio,
                 norm_cfg,
                 use_res_connect=False):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2], 'stride must be 1 or 2'

        self.use_res_connect = use_res_connect
        self.kernel_size = 3
        self.dilation = dilation

        if expand_ratio == 1:
            # no channel expansion: a single 3x3 depthwise-separable conv
            # (padding equals the dilation so the 3x3 conv keeps the spatial
            # size when stride == 1)
            self.conv = DepthwiseSeparableConvModule(
                in_channels,
                out_channels,
                3,
                padding=dilation,
                stride=stride,
                dilation=dilation,
                norm_cfg=norm_cfg,
                dw_act_cfg=dict(type='ReLU6'),
                pw_act_cfg=None)
        else:
            hidden_dim = round(in_channels * expand_ratio)
            # expand channels with a 1x1 conv, then apply the 3x3
            # depthwise-separable conv; the spatial padding belongs on the
            # 3x3 depthwise conv, not on the 1x1 expansion conv
            self.conv = nn.Sequential(
                ConvModule(
                    in_channels,
                    hidden_dim,
                    1,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU6')),
                DepthwiseSeparableConvModule(
                    hidden_dim,
                    out_channels,
                    3,
                    padding=dilation,
                    stride=stride,
                    dilation=dilation,
                    norm_cfg=norm_cfg,
                    dw_act_cfg=dict(type='ReLU6'),
                    pw_act_cfg=None))
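
Example #1 only shows the constructor, so use_res_connect is stored but never consumed there. A minimal forward pass, sketched on the assumption of the standard inverted-residual pattern (the method below is not part of the original example), could look like this:

    def forward(self, x):
        if self.use_res_connect:
            # identity shortcut; only valid when stride == 1 and
            # in_channels == out_channels, so the shapes match
            return x + self.conv(x)
        return self.conv(x)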
Example #2

import pytest
import torch
import torch.nn as nn

from mmcv.cnn import DepthwiseSeparableConvModule


def test_depthwise_separable_conv():
    with pytest.raises(AssertionError):
        # DepthwiseSeparableConvModule builds its own grouped conv,
        # so `groups` must not be specified explicitly
        DepthwiseSeparableConvModule(4, 8, 2, groups=2)

    # test default config
    conv = DepthwiseSeparableConvModule(3, 8, 2)
    assert conv.depthwise_conv.conv.groups == 3
    assert conv.pointwise_conv.conv.kernel_size == (1, 1)
    assert not conv.depthwise_conv.with_norm
    assert not conv.pointwise_conv.with_norm
    assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU'
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # test dw_norm_cfg
    conv = DepthwiseSeparableConvModule(3, 8, 2, dw_norm_cfg=dict(type='BN'))
    assert conv.depthwise_conv.norm_name == 'bn'
    assert not conv.pointwise_conv.with_norm
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # test pw_norm_cfg
    conv = DepthwiseSeparableConvModule(3, 8, 2, pw_norm_cfg=dict(type='BN'))
    assert not conv.depthwise_conv.with_norm
    assert conv.pointwise_conv.norm_name == 'bn'
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # test non-default order=('norm', 'conv', 'act')
    conv = DepthwiseSeparableConvModule(3, 8, 2, order=('norm', 'conv', 'act'))
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # test with_spectral_norm
    conv = DepthwiseSeparableConvModule(3,
                                        8,
                                        3,
                                        padding=1,
                                        with_spectral_norm=True)
    assert hasattr(conv.depthwise_conv.conv, 'weight_orig')
    assert hasattr(conv.pointwise_conv.conv, 'weight_orig')
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test padding_mode='reflect'
    conv = DepthwiseSeparableConvModule(3,
                                        8,
                                        3,
                                        padding=1,
                                        padding_mode='reflect')
    assert isinstance(conv.depthwise_conv.padding_layer, nn.ReflectionPad2d)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test dw_act_cfg
    conv = DepthwiseSeparableConvModule(3,
                                        8,
                                        3,
                                        padding=1,
                                        dw_act_cfg=dict(type='LeakyReLU'))
    assert conv.depthwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'ReLU'
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # test pw_act_cfg
    conv = DepthwiseSeparableConvModule(3,
                                        8,
                                        3,
                                        padding=1,
                                        pw_act_cfg=dict(type='LeakyReLU'))
    assert conv.depthwise_conv.activate.__class__.__name__ == 'ReLU'
    assert conv.pointwise_conv.activate.__class__.__name__ == 'LeakyReLU'
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)
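
For reference, a minimal standalone usage of DepthwiseSeparableConvModule outside of a test harness (the channel counts and input size below are arbitrary) looks like this:

import torch

from mmcv.cnn import DepthwiseSeparableConvModule

# a 3x3 depthwise conv applied per input channel, followed by a 1x1
# pointwise conv that mixes the 16 channels into 32 output channels
conv = DepthwiseSeparableConvModule(16, 32, 3, padding=1)
out = conv(torch.rand(1, 16, 64, 64))
assert out.shape == (1, 32, 64, 64)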