Ejemplo n.º 1
0
def test_gca_module():
    """GCAModule must return a feature map with the same shape as its inputs.

    Two configurations are exercised, in this order:
      * ``rate=1`` with no unknown-region map (``unknown=None``);
      * ``rate=2`` with a random single-channel unknown map.
    """
    feat_shape = (1, 128, 64, 64)
    # (rate, shape of the unknown map or None when no map is supplied)
    scenarios = ((1, None), (2, (1, 1, 64, 64)))

    for rate, unknown_shape in scenarios:
        # Draw tensors in the same order as the module construction below so
        # the random stream matches a straight-line version of this test.
        img_feat = torch.rand(*feat_shape)
        alpha_feat = torch.rand(*feat_shape)
        if unknown_shape is None:
            unknown = None
        else:
            unknown = torch.rand(*unknown_shape)

        gca = GCAModule(128, 128, rate=rate)
        result = gca(img_feat, alpha_feat, unknown)
        assert result.shape == (1, 128, 64, 64)
Ejemplo n.º 2
0
 def __init__(self, block, layers, in_channels, kernel_size=3,
              conv_cfg=None, norm_cfg=dict(type='BN'),
              act_cfg=dict(type='LeakyReLU', negative_slope=0.2,
                           inplace=True),
              with_spectral_norm=False, late_downsample=False):
     """Build the parent network, then attach a GCA module.

     Every argument is forwarded unchanged, in the same positional order, to
     the parent class constructor (not visible here).

     Args:
         block: Block type, forwarded to the parent.
         layers: Layer specification, forwarded to the parent.
         in_channels (int): Input channel count, forwarded to the parent.
         kernel_size (int): Convolution kernel size. Default: 3.
         conv_cfg (dict | None): Convolution config. Default: None.
         norm_cfg (dict): Normalization config. Default: ``dict(type='BN')``.
         act_cfg (dict): Activation config. Default: LeakyReLU with
             ``negative_slope=0.2`` and ``inplace=True``.
         with_spectral_norm (bool): Forwarded to the parent. Default: False.
         late_downsample (bool): Forwarded to the parent. Default: False.
     """
     # NOTE(review): the dict defaults are evaluated once and shared by every
     # instance; this is safe only if the parent never mutates them — confirm.
     parent_args = (block, layers, in_channels, kernel_size, conv_cfg,
                    norm_cfg, act_cfg, with_spectral_norm, late_downsample)
     super().__init__(*parent_args)
     # Guided contextual attention over 128-channel features (128 in, 128 out).
     self.gca = GCAModule(128, 128)
Ejemplo n.º 3
0
    def __init__(self,
                 block,
                 layers,
                 in_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_spectral_norm=False,
                 late_downsample=False,
                 order=('conv', 'act', 'norm')):
        """Build the residual encoder plus a guidance head and a GCA module.

        Args:
            block: Residual block type, forwarded to the parent constructor.
            layers: Layer specification, forwarded to the parent constructor.
            in_channels (int): Number of input channels; must be 4 or 6.
                ``in_channels - 3`` is stored as ``self.trimap_channels``
                (presumably 3 image channels plus a 1- or 3-channel trimap —
                confirm against the forward pass).
            conv_cfg (dict | None): Convolution config, used by the parent
                backbone and by the guidance-head convolutions. Default: None.
            norm_cfg (dict): Normalization config. Default: ``dict(type='BN')``.
            act_cfg (dict): Activation config. Default: ``dict(type='ReLU')``.
            with_spectral_norm (bool): Apply spectral norm to convolutions.
                Default: False.
            late_downsample (bool): Forwarded to the parent. Default: False.
            order (tuple[str]): Layer order inside each ConvModule.
                Default: ``('conv', 'act', 'norm')``.

        Raises:
            AssertionError: If ``in_channels`` is not 4 or 6.
        """
        super().__init__(block, layers, in_channels, conv_cfg, norm_cfg,
                         act_cfg, with_spectral_norm, late_downsample, order)

        assert in_channels in (4, 6), (
            f'in_channels must be 4 or 6, but got {in_channels}')

        self.trimap_channels = in_channels - 3

        # Guidance head: three 3x3, stride-2 ConvModules (3 -> 16 -> 32 -> 128
        # channels). With reflect padding of 1, each one halves the spatial
        # size (rounded up). The final 128 channels match the GCAModule below.
        guidance_in_channels = [3, 16, 32]
        guidance_out_channels = [16, 32, 128]

        # Loop variables deliberately do NOT reuse the name `in_channels`: that
        # would overwrite the constructor argument with the last stage width.
        self.guidance_head = nn.Sequential(*[
            ConvModule(stage_in,
                       stage_out,
                       3,
                       stride=2,
                       padding=1,
                       conv_cfg=conv_cfg,
                       norm_cfg=norm_cfg,
                       act_cfg=act_cfg,
                       with_spectral_norm=with_spectral_norm,
                       padding_mode='reflect',
                       order=order)
            for stage_in, stage_out in zip(guidance_in_channels,
                                           guidance_out_channels)
        ])

        self.gca = GCAModule(128, 128)