def test_gca_module():
    """Check the output shape of ``GCAModule`` in two configurations.

    The module is built with 128 input and 128 output channels and is fed
    random image and alpha features of shape ``(1, 128, 64, 64)``:

    * ``rate=1`` with ``unknown`` set to ``None``;
    * ``rate=2`` with a random ``(1, 1, 64, 64)`` ``unknown`` tensor.

    In both cases the output must have the same shape as the input features.
    """
    feat_shape = (1, 128, 64, 64)
    unknown_shape = (1, 1, 64, 64)

    # (rate, whether an ``unknown`` tensor is passed to the module)
    scenarios = (
        (1, False),
        (2, True),
    )

    for rate, pass_unknown in scenarios:
        # Inputs are drawn before the module is constructed so the order of
        # random draws stays the same as when the cases are written out
        # one after another.
        img_feat = torch.rand(*feat_shape)
        alpha_feat = torch.rand(*feat_shape)
        unknown = torch.rand(*unknown_shape) if pass_unknown else None

        module = GCAModule(128, 128, rate=rate)
        result = module(img_feat, alpha_feat, unknown)

        assert result.shape == feat_shape
def __init__(self,
             block,
             layers,
             in_channels,
             kernel_size=3,
             conv_cfg=None,
             norm_cfg=dict(type='BN'),
             act_cfg=dict(
                 type='LeakyReLU', negative_slope=0.2, inplace=True),
             with_spectral_norm=False,
             late_downsample=False):
    """Initialise the parent class and attach a ``GCAModule``.

    Every argument is forwarded positionally, unchanged and in declaration
    order, to the parent ``__init__``; their meaning is defined there.
    After the parent is set up, a ``GCAModule(128, 128)`` is stored as
    ``self.gca``.

    NOTE(review): ``norm_cfg`` and ``act_cfg`` default to dict objects that
    are created once and shared between calls. This method only passes them
    on; whether the parent mutates them is not visible here -- confirm.
    """
    # Collected in the exact positional order the parent expects.
    parent_args = (
        block,
        layers,
        in_channels,
        kernel_size,
        conv_cfg,
        norm_cfg,
        act_cfg,
        with_spectral_norm,
        late_downsample,
    )
    super().__init__(*parent_args)

    self.gca = GCAModule(128, 128)
def __init__(self,
             block,
             layers,
             in_channels,
             conv_cfg=None,
             norm_cfg=dict(type='BN'),
             act_cfg=dict(type='ReLU'),
             with_spectral_norm=False,
             late_downsample=False,
             order=('conv', 'act', 'norm')):
    """Initialise the parent class, then add a guidance head and a GCA module.

    All arguments are first forwarded positionally, in declaration order,
    to the parent ``__init__``. Afterwards this constructor:

    * asserts that ``in_channels`` is 4 or 6;
    * stores ``in_channels - 3`` as ``self.trimap_channels``;
    * builds ``self.guidance_head``, an ``nn.Sequential`` of three
      ``ConvModule`` layers with channels 3 -> 16 -> 32 -> 128, each
      created with ``3, stride=2, padding=1, padding_mode='reflect'`` and
      the given ``norm_cfg``, ``act_cfg``, ``with_spectral_norm`` and
      ``order``;
    * stores a ``GCAModule(128, 128)`` as ``self.gca``.

    Raises:
        AssertionError: If ``in_channels`` is neither 4 nor 6.
    """
    super(ResGCAEncoder, self).__init__(block, layers, in_channels,
                                        conv_cfg, norm_cfg, act_cfg,
                                        with_spectral_norm, late_downsample,
                                        order)

    assert in_channels in (4, 6), (
        f'in_channels must be 4 or 6, but got {in_channels}')
    self.trimap_channels = in_channels - 3

    # (input channels, output channels) for each guidance-head stage.
    stage_channels = ((3, 16), (16, 32), (32, 128))

    # Settings shared by every stage of the guidance head.
    stage_kwargs = dict(
        stride=2,
        padding=1,
        norm_cfg=norm_cfg,
        act_cfg=act_cfg,
        with_spectral_norm=with_spectral_norm,
        padding_mode='reflect',
        order=order)

    stages = [
        ConvModule(stage_in, stage_out, 3, **stage_kwargs)
        for stage_in, stage_out in stage_channels
    ]
    self.guidance_head = nn.Sequential(*stages)

    self.gca = GCAModule(128, 128)