def __init__(self,
             in_channels=3,
             num_layers=7,
             conv_cfg=dict(type='PConv', multi_channel=True),
             norm_cfg=dict(type='BN', requires_grad=True),
             norm_eval=False):
    """Build the stack of stride-2 masked-convolution encoder layers.

    Layers are registered as ``enc1`` .. ``enc{num_layers}`` in ascending
    order. The first four layers are always created; layers 5 and up are
    extra 512 -> 512 layers.

    Args:
        in_channels (int): Number of channels fed to ``enc1``.
        num_layers (int): Total number of encoder layers to register.
        conv_cfg (dict): Conv config forwarded to every ``MaskConvModule``.
        norm_cfg (dict | None): Norm config for ``enc2`` onwards; ``enc1``
            is always built without normalization.
        norm_eval (bool): Stored as ``self.norm_eval``.
    """
    super().__init__()
    self.num_layers = num_layers
    self.norm_eval = norm_eval

    # (in_ch, out_ch, kernel_size, padding, layer_norm_cfg) per layer,
    # listed in the exact order the layers must be created.
    fixed_stages = [
        (in_channels, 64, 7, 3, None),
        (64, 128, 5, 2, norm_cfg),
        (128, 256, 5, 2, norm_cfg),
        (256, 512, 3, 1, norm_cfg),
    ]
    extra_stages = [(512, 512, 3, 1, norm_cfg)] * max(num_layers - 4, 0)

    for layer_idx, (c_in, c_out, ksize, pad, layer_norm) in enumerate(
            fixed_stages + extra_stages, start=1):
        self.add_module(
            f'enc{layer_idx}',
            MaskConvModule(
                c_in,
                c_out,
                kernel_size=ksize,
                stride=2,
                padding=pad,
                conv_cfg=conv_cfg,
                norm_cfg=layer_norm,
                act_cfg=dict(type='ReLU')))
def __init__(self,
             num_layers=7,
             interpolation='nearest',
             conv_cfg=dict(type='PConv', multi_channel=True),
             norm_cfg=dict(type='BN')):
    """Build the masked-convolution decoder layers.

    Layers ``dec5`` .. ``dec{num_layers}`` are created first (ascending),
    followed by ``dec4``, ``dec3``, ``dec2`` and ``dec1``. Every layer uses a
    3x3 kernel with stride 1 and padding 1. All layers except ``dec1`` use
    ``norm_cfg`` and LeakyReLU(0.2); ``dec1`` has neither norm nor activation.

    Args:
        num_layers (int): Index of the deepest decoder layer to register.
        interpolation (str): Stored as ``self.interpolation``.
        conv_cfg (dict): Conv config forwarded to every ``MaskConvModule``.
        norm_cfg (dict | None): Norm config for every layer but ``dec1``.
    """
    super().__init__()
    self.num_layers = num_layers
    self.interpolation = interpolation

    # (name, in_ch, out_ch, layer_norm_cfg, use_leaky_relu), in creation order.
    layer_specs = [(f'dec{idx + 1}', 512 + 512, 512, norm_cfg, True)
                   for idx in range(4, num_layers)]
    layer_specs += [
        ('dec4', 512 + 256, 256, norm_cfg, True),
        ('dec3', 256 + 128, 128, norm_cfg, True),
        ('dec2', 128 + 64, 64, norm_cfg, True),
        # NOTE(review): the hard-coded 3 presumably matches the image
        # channels concatenated as the last skip input - confirm it stays in
        # sync with the encoder's in_channels.
        ('dec1', 64 + 3, 3, None, False),
    ]

    for layer_name, c_in, c_out, layer_norm, use_act in layer_specs:
        if use_act:
            layer_act = dict(type='LeakyReLU', negative_slope=0.2)
        else:
            layer_act = None
        self.add_module(
            layer_name,
            MaskConvModule(
                c_in,
                c_out,
                kernel_size=3,
                stride=1,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=layer_norm,
                act_cfg=layer_act))
def test_mask_conv_module():
    """Check MaskConvModule argument validation and output shapes."""
    # One valid partial-conv config, reused so that every negative case below
    # differs from a valid construction in exactly one argument.
    pconv_cfg = dict(type='PConv', multi_channel=True)

    with pytest.raises(KeyError):
        # unknown conv layer type is rejected
        MaskConvModule(3, 8, 2, conv_cfg=dict(type='conv'))
    with pytest.raises(AssertionError):
        # norm_cfg must be a dict or None
        MaskConvModule(3, 8, 2, conv_cfg=pconv_cfg, norm_cfg=['norm'])
    with pytest.raises(AssertionError):
        # order must be a tuple; a list with valid elements is rejected
        MaskConvModule(
            3, 8, 2, conv_cfg=pconv_cfg, order=['conv', 'norm', 'act'])
    with pytest.raises(AssertionError):
        # order must contain all of 'conv', 'norm' and 'act'
        MaskConvModule(3, 8, 2, conv_cfg=pconv_cfg, order=('conv', 'norm'))
    with pytest.raises(KeyError):
        # softmax is not supported
        MaskConvModule(
            3, 8, 2, conv_cfg=pconv_cfg, act_cfg=dict(type='softmax'))

    conv = MaskConvModule(3, 8, 2, conv_cfg=pconv_cfg)
    x = torch.rand(1, 3, 256, 256)
    mask_in = torch.ones_like(x)
    mask_in[..., 20:130, 120:150] = 0.
    output, mask_update = conv(x, mask_in)
    assert output.shape == (1, 8, 255, 255)
    assert mask_update.shape == (1, 8, 255, 255)

    # order ('norm', 'conv', 'act')
    conv = MaskConvModule(
        3, 8, 2, order=('norm', 'conv', 'act'), conv_cfg=pconv_cfg)
    output = conv(x, mask_in, return_mask=False)
    assert output.shape == (1, 8, 255, 255)

    # spectral norm; forward is called without an explicit mask
    conv = MaskConvModule(
        3, 8, 3, padding=1, conv_cfg=pconv_cfg, with_spectral_norm=True)
    assert hasattr(conv.conv, 'weight_orig')
    output = conv(x, return_mask=False)
    assert output.shape == (1, 8, 256, 256)

    # explicit reflect padding
    conv = MaskConvModule(
        3,
        8,
        3,
        padding=1,
        norm_cfg=dict(type='BN'),
        padding_mode='reflect',
        conv_cfg=pconv_cfg)
    assert isinstance(conv.padding_layer, nn.ReflectionPad2d)
    output = conv(x, mask_in, return_mask=False)
    assert output.shape == (1, 8, 256, 256)

    conv = MaskConvModule(
        3, 8, 3, padding=1, act_cfg=dict(type='LeakyReLU'), conv_cfg=pconv_cfg)
    output = conv(x, mask_in, return_mask=False)
    assert output.shape == (1, 8, 256, 256)

    with pytest.raises(KeyError):
        # unknown padding mode; conv_cfg is valid, so only padding_mode is
        # responsible for the failure
        MaskConvModule(
            3, 8, 3, padding=1, conv_cfg=pconv_cfg, padding_mode='igccc')