Ejemplo n.º 1
0
    def __init__(self,
                 in_channels=3,
                 num_layers=7,
                 conv_cfg=dict(type='PConv', multi_channel=True),
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False):
        """Build the partial-convolution encoder stack.

        Args:
            in_channels (int): Channels of the input image. Default: 3.
            num_layers (int): Total number of encoder stages. Default: 7.
            conv_cfg (dict): Config for the (partial) convolution layer.
            norm_cfg (dict): Config for the normalization layer used from
                the second stage onwards (enc1 has no norm).
            norm_eval (bool): Stored flag controlling norm-layer eval mode.
        """
        super().__init__()
        self.num_layers = num_layers
        self.norm_eval = norm_eval

        # The first four stages shrink the kernel as channels grow:
        # (in_ch, out_ch, kernel_size, padding). enc1 skips normalization.
        stage_plan = [
            (in_channels, 64, 7, 3),
            (64, 128, 5, 2),
            (128, 256, 5, 2),
            (256, 512, 3, 1),
        ]
        for idx, (c_in, c_out, kernel, pad) in enumerate(stage_plan, start=1):
            self.add_module(
                f'enc{idx}',
                MaskConvModule(c_in,
                               c_out,
                               kernel_size=kernel,
                               stride=2,
                               padding=pad,
                               conv_cfg=conv_cfg,
                               norm_cfg=None if idx == 1 else norm_cfg,
                               act_cfg=dict(type='ReLU')))

        # Remaining stages (enc5..enc{num_layers}) keep 512 channels
        # and a 3x3 kernel, still downsampling by stride 2.
        for idx in range(5, num_layers + 1):
            self.add_module(
                f'enc{idx}',
                MaskConvModule(512,
                               512,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               conv_cfg=conv_cfg,
                               norm_cfg=norm_cfg,
                               act_cfg=dict(type='ReLU')))
Ejemplo n.º 2
0
    def __init__(self,
                 num_layers=7,
                 interpolation='nearest',
                 conv_cfg=dict(type='PConv', multi_channel=True),
                 norm_cfg=dict(type='BN')):
        """Build the partial-convolution decoder stack.

        Args:
            num_layers (int): Total number of decoder stages. Default: 7.
            interpolation (str): Stored upsampling mode. Default: 'nearest'.
            conv_cfg (dict): Config for the (partial) convolution layer.
            norm_cfg (dict): Config for the normalization layer used in all
                stages except the final dec1.
        """
        super(PConvDecoder, self).__init__()
        self.num_layers = num_layers
        self.interpolation = interpolation

        # Deep stages (dec5..dec{num_layers}) concatenate a 512-channel
        # skip connection and map 1024 -> 512 channels.
        for idx in range(5, num_layers + 1):
            self.add_module(
                f'dec{idx}',
                MaskConvModule(512 + 512,
                               512,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               conv_cfg=conv_cfg,
                               norm_cfg=norm_cfg,
                               act_cfg=dict(type='LeakyReLU',
                                            negative_slope=0.2)))

        # Shallow stages progressively halve channels; each input width is
        # the previous output plus the matching encoder skip connection.
        # A fresh act_cfg dict is built per module, mirroring the per-call
        # dicts the config system expects.
        shallow_plan = [
            (4, 512 + 256, 256),
            (3, 256 + 128, 128),
            (2, 128 + 64, 64),
        ]
        for idx, c_in, c_out in shallow_plan:
            self.add_module(
                f'dec{idx}',
                MaskConvModule(c_in,
                               c_out,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               conv_cfg=conv_cfg,
                               norm_cfg=norm_cfg,
                               act_cfg=dict(type='LeakyReLU',
                                            negative_slope=0.2)))

        # Final stage maps back to a 3-channel image; no norm, no activation.
        self.dec1 = MaskConvModule(64 + 3,
                                   3,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1,
                                   conv_cfg=conv_cfg,
                                   norm_cfg=None,
                                   act_cfg=None)
Ejemplo n.º 3
0
def test_mask_conv_module():
    """Exercise ``MaskConvModule``: config validation and forward passes."""
    # Invalid configurations must be rejected at construction time.
    with pytest.raises(KeyError):
        # conv_cfg must be a dict or None
        MaskConvModule(3, 8, 2, conv_cfg=dict(type='conv'))

    with pytest.raises(AssertionError):
        # norm_cfg must be a dict or None
        MaskConvModule(3, 8, 2, norm_cfg=['norm'])

    with pytest.raises(AssertionError):
        # order elements must be ('conv', 'norm', 'act')
        MaskConvModule(3, 8, 2, order=['conv', 'norm', 'act'])

    with pytest.raises(AssertionError):
        # order elements must be ('conv', 'norm', 'act')
        MaskConvModule(3, 8, 2, order=('conv', 'norm'))

    with pytest.raises(KeyError):
        # softmax is not supported
        MaskConvModule(3, 8, 2, act_cfg=dict(type='softmax'))

    # Forward with a partial-conv layer returns both features and mask.
    pconv_cfg = dict(type='PConv', multi_channel=True)
    module = MaskConvModule(3, 8, 2, conv_cfg=pconv_cfg)
    inp = torch.rand(1, 3, 256, 256)
    mask = torch.ones_like(inp)
    mask[..., 20:130, 120:150] = 0.
    out, updated_mask = module(inp, mask)
    assert out.shape == (1, 8, 255, 255)
    assert updated_mask.shape == (1, 8, 255, 255)

    # add test for ['norm', 'conv', 'act']
    module = MaskConvModule(3,
                            8,
                            2,
                            order=('norm', 'conv', 'act'),
                            conv_cfg=pconv_cfg)
    inp = torch.rand(1, 3, 256, 256)
    out = module(inp, mask, return_mask=False)
    assert out.shape == (1, 8, 255, 255)

    # Spectral norm wraps the conv weight as 'weight_orig'.
    module = MaskConvModule(3,
                            8,
                            3,
                            padding=1,
                            conv_cfg=pconv_cfg,
                            with_spectral_norm=True)
    assert hasattr(module.conv, 'weight_orig')
    out = module(inp, return_mask=False)
    assert out.shape == (1, 8, 256, 256)

    # 'reflect' padding_mode installs a ReflectionPad2d layer.
    module = MaskConvModule(3,
                            8,
                            3,
                            padding=1,
                            norm_cfg=dict(type='BN'),
                            padding_mode='reflect',
                            conv_cfg=pconv_cfg)
    assert isinstance(module.padding_layer, nn.ReflectionPad2d)
    out = module(inp, mask, return_mask=False)
    assert out.shape == (1, 8, 256, 256)

    # LeakyReLU activation keeps the same output shape.
    module = MaskConvModule(3,
                            8,
                            3,
                            padding=1,
                            act_cfg=dict(type='LeakyReLU'),
                            conv_cfg=pconv_cfg)
    out = module(inp, mask, return_mask=False)
    assert out.shape == (1, 8, 256, 256)

    with pytest.raises(KeyError):
        # unknown padding_mode is rejected
        MaskConvModule(3, 8, 3, padding=1, padding_mode='igccc')