Example 1
    def __init__(self, in_channels, out_channels): 
        super(NF_Block, self).__init__()
        self.conv1 = WSConv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=2, dilation=2)
        self.relu1 = nn.ReLU()
        self.conv2 = WSConv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()

        self.shortcut_conv = WSConv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
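Only the constructor appears in the snippet; a minimal forward sketch, under the assumption that the block follows the usual residual pattern (hypothetical, not from the source):

    def forward(self, x):
        # main path: dilated 3x3 conv, then plain 3x3 conv, ReLU after each
        out = self.relu1(self.conv1(x))
        out = self.relu2(self.conv2(out))
        # 1x1 shortcut projects x to out_channels so the addition matches shapes
        return out + self.shortcut_conv(x)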
Example 2
    def __init__(self, in_feat, out_feat, kernel_size, stride=1, norm=False):
        super(ConvNorm, self).__init__()

        reflection_padding = kernel_size // 2
        # nn.ReflectionPad2d is not supported when converting to TensorRT,
        # so zero padding is used instead
        self.reflection_pad = torch.nn.ZeroPad2d(reflection_padding)

        if cfg['network_G']['conv'] == 'doconv':
            self.conv = DOConv2d(in_feat,
                                 out_feat,
                                 stride=stride,
                                 kernel_size=kernel_size,
                                 bias=True)
        elif cfg['network_G']['conv'] == 'gated':
            self.conv = GatedConv2dWithActivation(in_feat,
                                                  out_feat,
                                                  stride=stride,
                                                  kernel_size=kernel_size,
                                                  bias=True)
        elif cfg['network_G']['conv'] == 'TBC':
            self.conv = TiedBlockConv2d(in_feat,
                                        out_feat,
                                        stride=stride,
                                        kernel_size=kernel_size,
                                        bias=True)
        elif cfg['network_G']['conv'] == 'dynamic':
            self.conv = DynamicConvolution(nof_kernels_param,
                                           reduce_param,
                                           in_channels=in_feat,
                                           out_channels=out_feat,
                                           stride=stride,
                                           kernel_size=kernel_size,
                                           bias=True)
        elif cfg['network_G']['conv'] == 'CondConv':
            self.conv = CondConv(in_planes=in_feat,
                                 out_planes=out_feat,
                                 kernel_size=kernel_size,
                                 stride=1,
                                 padding=1,
                                 bias=False)
        elif cfg['network_G']['conv'] == 'MBConv':
            self.conv = MBConv(in_feat, out_feat, 1, 1, True)
        elif cfg['network_G']['conv'] == 'fft':
            self.conv = FourierUnit(in_feat, out_feat)
        elif cfg['network_G']['conv'] == 'WSConv':
            self.conv = WSConv2d(in_feat,
                                 out_feat,
                                 stride=stride,
                                 kernel_size=kernel_size,
                                 bias=True)
        elif cfg['network_G']['conv'] in ('conv2d', 'Involution'):
            self.conv = nn.Conv2d(in_feat,
                                  out_feat,
                                  stride=stride,
                                  kernel_size=kernel_size,
                                  bias=True)
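Every branch above dispatches on a module-level cfg dict; a minimal sketch of the shape it assumes (keys taken from the code, values illustrative):

cfg = {
    'network_G': {
        # one of: 'doconv', 'gated', 'TBC', 'dynamic', 'CondConv', 'MBConv',
        # 'fft', 'WSConv', 'conv2d', 'Involution'
        'conv': 'conv2d',
        'RG': 5,  # number of residual groups used by Interpolation below (illustrative value)
    }
}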
Example 3
def conv(in_planes,
         out_planes,
         kernel_size=3,
         stride=1,
         padding=1,
         dilation=1):
    if cfg['network_G']['conv'] == 'doconv':
        return nn.Sequential(
            DOConv2d(in_planes,
                     out_planes,
                     stride=stride,
                     kernel_size=1,
                     bias=True), nn.PReLU(out_planes))
    elif cfg['network_G']['conv'] == 'gated':
        return nn.Sequential(
            GatedConv2dWithActivation(in_planes,
                                      out_planes,
                                      stride=stride,
                                      kernel_size=1,
                                      bias=True), nn.PReLU(out_planes))
    elif cfg['network_G']['conv'] == 'TBC':
        return nn.Sequential(
            TiedBlockConv2d(in_planes,
                            out_planes,
                            stride=stride,
                            kernel_size=1,
                            bias=True), nn.PReLU(out_planes))
    elif cfg['network_G']['conv'] == 'dynamic':
        return nn.Sequential(
            DynamicConvolution(nof_kernels_param,
                               reduce_param,
                               in_channels=in_planes,
                               out_channels=out_planes,
                               stride=stride,
                               kernel_size=1,
                               bias=True), nn.PReLU(out_planes))
    elif cfg['network_G']['conv'] == 'MBConv':
        return nn.Sequential(MBConv(in_planes, out_planes, 1, 1, True),
                             nn.PReLU(out_planes))
    elif cfg['network_G']['conv'] == 'fft':
        return nn.Sequential(FourierUnit(in_planes, out_planes),
                             nn.PReLU(out_planes))
    elif cfg['network_G']['conv'] == 'WSConv':
        return nn.Sequential(
            WSConv2d(in_planes,
                     out_planes,
                     stride=stride,
                     kernel_size=1,
                     bias=True), nn.PReLU(out_planes))
    elif cfg['network_G']['conv'] == 'conv2d':
        return nn.Sequential(
            nn.Conv2d(in_planes,
                      out_planes,
                      kernel_size=kernel_size,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=True), nn.PReLU(out_planes))
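A usage sketch with the 'conv2d' branch selected; note that the other branches hard-code kernel_size=1 and ignore padding and dilation, so only this branch honors all arguments:

block = conv(64, 128, kernel_size=3, stride=1, padding=1)
y = block(torch.randn(1, 64, 32, 32))  # -> torch.Size([1, 128, 32, 32])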
Example 4
def replace_conv(module: nn.Module):
    """Recursively replaces every convolution with WSConv2d.

    Usage: replace_conv(model) #(In-line replacement)
    Args:
      module(nn.Module): target's model whose convolutions must be replaced.
    """
    for name, mod in module.named_children():
        target_mod = getattr(module, name)
        # exact type match, so Conv2d subclasses (including WSConv2d itself) are left untouched
        if type(mod) == torch.nn.Conv2d:
            setattr(module, name, WSConv2d(target_mod.in_channels, target_mod.out_channels, target_mod.kernel_size,
                                           target_mod.stride, target_mod.padding, target_mod.dilation, target_mod.groups,
                                           target_mod.bias is not None))  # pass a bool; the bias Parameter itself breaks Conv2d's `if bias:` check

        # normalizer-free setup: drop BatchNorm entirely
        if type(mod) == torch.nn.BatchNorm2d:
            setattr(module, name, torch.nn.Identity())

    for name, mod in module.named_children():
        replace_conv(mod)
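A minimal usage sketch on a toy model (names illustrative):

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
                      nn.BatchNorm2d(16),
                      nn.ReLU())
replace_conv(model)  # in-place: Conv2d -> WSConv2d, BatchNorm2d -> Identity
assert isinstance(model[0], WSConv2d)
assert isinstance(model[1], nn.Identity)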
Example 5
    def __init__(self,
                 n_resgroups,
                 n_resblocks,
                 n_feats,
                 reduction=16,
                 act=nn.LeakyReLU(0.2, True),
                 norm=False):
        super(Interpolation, self).__init__()

        if cfg['network_G']['conv'] == 'doconv':
            self.headConv = DOConv2d(n_feats * 2,
                                     n_feats,
                                     stride=1,
                                     padding=1,
                                     bias=False,
                                     groups=1,
                                     kernel_size=3)
        elif cfg['network_G']['conv'] == 'gated':
            self.headConv = GatedConv2dWithActivation(n_feats * 2,
                                                      n_feats,
                                                      stride=1,
                                                      padding=1,
                                                      bias=False,
                                                      groups=1,
                                                      kernel_size=3)
        elif cfg['network_G']['conv'] == 'TBC':
            self.headConv = TiedBlockConv2d(n_feats * 2,
                                            n_feats,
                                            stride=1,
                                            padding=1,
                                            bias=False,
                                            groups=1,
                                            kernel_size=3)
        elif cfg['network_G']['conv'] == 'dynamic':
            self.headConv = DynamicConvolution(nof_kernels_param,
                                               reduce_param,
                                               in_channels=n_feats * 2,
                                               out_channels=n_feats,
                                               stride=1,
                                               padding=1,
                                               bias=False,
                                               groups=1,
                                               kernel_size=3)
        elif cfg['network_G']['conv'] == 'MBConv':
            self.headConv = MBConv(n_feats * 2, n_feats, 1, 1, True)
        elif cfg['network_G']['conv'] == 'fft':
            self.headConv = FourierUnit(in_channels=n_feats * 2,
                                        out_channels=n_feats,
                                        groups=1,
                                        spatial_scale_factor=None,
                                        spatial_scale_mode='bilinear',
                                        spectral_pos_encoding=False,
                                        use_se=False,
                                        se_kwargs=None,
                                        ffc3d=False,
                                        fft_norm='ortho')
        elif cfg['network_G']['conv'] == 'WSConv':
            self.headConv = WSConv2d(n_feats * 2,
                                     n_feats,
                                     stride=1,
                                     padding=1,
                                     bias=False,
                                     groups=1,
                                     kernel_size=3)
        # Involution cannot change the channel count and CondConv raises a
        # shape error here, so both fall back to a plain Conv2d
        elif cfg['network_G']['conv'] in ('conv2d', 'Involution', 'CondConv'):
            self.headConv = nn.Conv2d(n_feats * 2,
                                      n_feats,
                                      stride=1,
                                      padding=1,
                                      bias=False,
                                      groups=1,
                                      kernel_size=3)

        modules_body = [
            ResidualGroup(RCAB,
                          n_resblocks=12,
                          n_feat=n_feats,
                          kernel_size=3,
                          reduction=reduction,
                          act=act,
                          norm=norm) for _ in range(cfg['network_G']['RG'])
        ]
        self.body = nn.Sequential(*modules_body)

        if cfg['network_G']['conv'] == 'doconv':
            self.tailConv = DOConv2d(n_feats,
                                     n_feats,
                                     stride=1,
                                     padding=1,
                                     bias=False,
                                     groups=1,
                                     kernel_size=3)
        elif cfg['network_G']['conv'] == 'gated':
            self.tailConv = GatedConv2dWithActivation(n_feats,
                                                      n_feats,
                                                      stride=1,
                                                      padding=1,
                                                      bias=False,
                                                      groups=1,
                                                      kernel_size=3)
        elif cfg['network_G']['conv'] == 'TBC':
            self.tailConv = TiedBlockConv2d(n_feats,
                                            n_feats,
                                            stride=1,
                                            padding=1,
                                            bias=False,
                                            groups=1,
                                            kernel_size=3)
        elif cfg['network_G']['conv'] == 'dynamic':
            self.tailConv = DynamicConvolution(nof_kernels_param,
                                               reduce_param,
                                               in_channels=n_feats,
                                               out_channels=n_feats,
                                               stride=1,
                                               padding=1,
                                               bias=False,
                                               groups=1,
                                               kernel_size=3)
        elif cfg['network_G']['conv'] == 'MBConv':
            self.tailConv = MBConv(n_feats, n_feats, 1, 1, True)
        elif cfg['network_G']['conv'] == 'Involution':
            self.tailConv = Involution(in_channel=n_feats,
                                       kernel_size=3,
                                       stride=1)
        elif cfg['network_G']['conv'] == 'CondConv':
            self.tailConv = CondConv(in_planes=n_feats,
                                     out_planes=n_feats,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0,
                                     bias=False)
        elif cfg['network_G']['conv'] == 'fft':
            self.tailConv = FourierUnit(in_channels=n_feats,
                                        out_channels=n_feats,
                                        groups=1,
                                        spatial_scale_factor=None,
                                        spatial_scale_mode='bilinear',
                                        spectral_pos_encoding=False,
                                        use_se=False,
                                        se_kwargs=None,
                                        ffc3d=False,
                                        fft_norm='ortho')
        elif cfg['network_G']['conv'] == 'WSConv':
            self.tailConv = WSConv2d(n_feats,
                                     n_feats,
                                     stride=1,
                                     padding=1,
                                     bias=False,
                                     groups=1,
                                     kernel_size=3)
        elif cfg['network_G']['conv'] == 'conv2d':
            self.tailConv = nn.Conv2d(n_feats,
                                      n_feats,
                                      stride=1,
                                      padding=1,
                                      bias=False,
                                      groups=1,
                                      kernel_size=3)
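The forward pass is omitted; CAIN-style interpolation networks, which this constructor mirrors, typically concatenate the two frame features and run head, residual body, and tail. A sketch under that assumption (the actual method may differ):

    def forward(self, x0, x1):
        x = self.headConv(torch.cat([x0, x1], dim=1))  # (B, 2*n_feats, H, W) -> (B, n_feats, H, W)
        res = self.body(x)
        return self.tailConv(res + x)  # global residual before the tail (assumed)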
Example 6
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight

        if cfg['network_G']['conv'] == 'doconv':
            self.conv_du = nn.Sequential(
                DOConv2d(channel,
                         channel // reduction,
                         1,
                         padding=0,
                         bias=False), nn.ReLU(inplace=True),
                DOConv2d(channel // reduction,
                         channel,
                         1,
                         padding=0,
                         bias=False), nn.Sigmoid())

        elif cfg['network_G']['conv'] == 'TBC':
            self.conv_du = nn.Sequential(
                TiedBlockConv2d(channel,
                                channel // reduction,
                                1,
                                padding=0,
                                bias=False), nn.ReLU(inplace=True),
                TiedBlockConv2d(channel // reduction,
                                channel,
                                1,
                                padding=0,
                                bias=False), nn.Sigmoid())

        elif cfg['network_G']['conv'] == 'dynamic':
            self.conv_du = nn.Sequential(
                DynamicConvolution(nof_kernels_param,
                                   reduce_param,
                                   in_channels=channel,
                                   out_channels=(channel // reduction),
                                   kernel_size=1,
                                   padding=0,
                                   bias=False), nn.ReLU(inplace=True),
                DynamicConvolution(nof_kernels_param,
                                   reduce_param,
                                   in_channels=(channel // reduction),
                                   out_channels=channel,
                                   kernel_size=1,
                                   padding=0,
                                   bias=False), nn.Sigmoid())

        elif cfg['network_G']['conv'] == 'CondConv':
            self.conv_du = nn.Sequential(
                CondConv(in_planes=channel,
                         out_planes=channel // reduction,
                         kernel_size=1,
                         stride=1,
                         padding=0,
                         bias=False), nn.ReLU(inplace=True),
                CondConv(in_planes=channel // reduction,
                         out_planes=channel,
                         kernel_size=1,
                         stride=1,
                         padding=0,
                         bias=False), nn.Sigmoid())

        elif cfg['network_G']['conv'] == 'WSConv':
            self.conv_du = nn.Sequential(
                WSConv2d(channel,
                         channel // reduction,
                         1,
                         padding=0,
                         bias=False), nn.ReLU(inplace=True),
                WSConv2d(channel // reduction,
                         channel,
                         1,
                         padding=0,
                         bias=False), nn.Sigmoid())

        # gated, MBConv, Involution and fft cause shape errors here,
        # so they fall back to a plain Conv2d
        elif cfg['network_G']['conv'] in ('conv2d', 'gated', 'MBConv',
                                          'Involution', 'fft'):
            self.conv_du = nn.Sequential(
                nn.Conv2d(channel,
                          channel // reduction,
                          1,
                          padding=0,
                          bias=False), nn.ReLU(inplace=True),
                nn.Conv2d(channel // reduction,
                          channel,
                          1,
                          padding=0,
                          bias=False), nn.Sigmoid())
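As with the other snippets, only __init__ is shown; the standard RCAN channel-attention forward that this layer presumably pairs with (a sketch, not taken from the source):

    def forward(self, x):
        y = self.avg_pool(x)  # (B, C, H, W) -> (B, C, 1, 1), one descriptor per channel
        y = self.conv_du(y)   # squeeze-and-excite: per-channel weights in (0, 1)
        return x * y          # rescale each channel of the input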
Example 7
    def __init__(self, rgbRange, rgbMean, sign, nChannel=3):
        super(meanShift, self).__init__()
        if nChannel == 1:
            l = rgbMean[0] * rgbRange * float(sign)

            if cfg['network_G']['conv'] == 'doconv':
                self.shifter = DOConv2d(1,
                                        1,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
            elif cfg['network_G']['conv'] == 'gated':
                self.shifter = GatedConv2dWithActivation(1,
                                                         1,
                                                         kernel_size=1,
                                                         stride=1,
                                                         padding=0)
            elif cfg['network_G']['conv'] == 'TBC':
                self.shifter = TiedBlockConv2d(1,
                                               1,
                                               kernel_size=1,
                                               stride=1,
                                               padding=0)
            elif cfg['network_G']['conv'] == 'dynamic':
                self.shifter = DynamicConvolution(nof_kernels_param,
                                                  reduce_param,
                                                  in_channels=1,
                                                  out_channels=1,
                                                  kernel_size=1,
                                                  stride=1,
                                                  padding=0)
            elif cfg['network_G']['conv'] == 'MBConv':
                self.shifter = MBConv(1, 1, 1, 2, True)
            elif cfg['network_G']['conv'] == 'Involution':
                self.shifter = Involution(in_channel=1,
                                          kernel_size=1,
                                          stride=1)
            elif cfg['network_G']['conv'] == 'CondConv':
                self.shifter = CondConv(in_planes=1,
                                        out_planes=1,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False)
            elif cfg['network_G']['conv'] == 'fft':
                self.shifter = FourierUnit(in_channels=1,
                                           out_channels=1,
                                           groups=1,
                                           spatial_scale_factor=None,
                                           spatial_scale_mode='bilinear',
                                           spectral_pos_encoding=False,
                                           use_se=False,
                                           se_kwargs=None,
                                           ffc3d=False,
                                           fft_norm='ortho')
            elif cfg['network_G']['conv'] == 'WSConv':
                self.shifter = WSConv2d(1, 1, kernel_size=1, stride=1, padding=0)
            elif cfg['network_G']['conv'] == 'conv2d':
                self.shifter = nn.Conv2d(1,
                                         1,
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)

            self.shifter.weight.data = torch.eye(1).view(1, 1, 1, 1)
            self.shifter.bias.data = torch.Tensor([l])
        elif nChannel == 3:
            r = rgbMean[0] * rgbRange * float(sign)
            g = rgbMean[1] * rgbRange * float(sign)
            b = rgbMean[2] * rgbRange * float(sign)

            if cfg['network_G']['conv'] == 'doconv':
                self.shifter = DOConv2d(3,
                                        3,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
            elif cfg['network_G']['conv'] == 'gated':
                self.shifter = GatedConv2dWithActivation(3,
                                                         3,
                                                         kernel_size=1,
                                                         stride=1,
                                                         padding=0)
            elif cfg['network_G']['conv'] == 'TBC':
                self.shifter = TiedBlockConv2d(3,
                                               3,
                                               kernel_size=1,
                                               stride=1,
                                               padding=0)
            elif cfg['network_G']['conv'] == 'dynamic':
                self.shifter = DynamicConvolution(nof_kernels_param,
                                                  reduce_param,
                                                  in_channels=3,
                                                  out_channels=3,
                                                  kernel_size=1,
                                                  stride=1,
                                                  padding=0)
            elif cfg['network_G']['conv'] == 'MBConv':
                self.shifter = MBConv(3, 3, 1, 2, True)
            elif cfg['network_G']['conv'] == 'Involution':
                self.shifter = Involution(in_channel=3,
                                          kernel_size=1,
                                          stride=1)
            elif cfg['network_G']['conv'] == 'CondConv':
                self.shifter = CondConv(in_planes=3,
                                        out_planes=3,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False)
            elif cfg['network_G']['conv'] == 'fft':
                self.shifter = FourierUnit(in_channels=3,
                                           out_channels=3,
                                           groups=1,
                                           spatial_scale_factor=None,
                                           spatial_scale_mode='bilinear',
                                           spectral_pos_encoding=False,
                                           use_se=False,
                                           se_kwargs=None,
                                           ffc3d=False,
                                           fft_norm='ortho')
            elif cfg['network_G']['conv'] == 'WSConv':
                self.shifter = WSConv2d(3, 3, kernel_size=1, stride=1, padding=0)
            elif cfg['network_G']['conv'] == 'conv2d':
                self.shifter = nn.Conv2d(3,
                                         3,
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)

            self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)
            self.shifter.bias.data = torch.Tensor([r, g, b])
        else:
            r = rgbMean[0] * rgbRange * float(sign)
            g = rgbMean[1] * rgbRange * float(sign)
            b = rgbMean[2] * rgbRange * float(sign)

            if cfg['network_G']['conv'] == 'doconv':
                self.shifter = DOConv2d(6,
                                        6,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
            elif cfg['network_G']['conv'] == 'gated':
                self.shifter = GatedConv2dWithActivation(6,
                                                         6,
                                                         kernel_size=1,
                                                         stride=1,
                                                         padding=0)
            elif cfg['network_G']['conv'] == 'TBC':
                self.shifter = TiedBlockConv2d(6,
                                               6,
                                               kernel_size=1,
                                               stride=1,
                                               padding=0)
            elif cfg['network_G']['conv'] == 'dynamic':
                self.shifter = DynamicConvolution(nof_kernels_param,
                                                  reduce_param,
                                                  in_channels=6,
                                                  out_channels=6,
                                                  kernel_size=1,
                                                  stride=1,
                                                  padding=0)
            elif cfg['network_G']['conv'] == 'MBConv':
                self.shifter = MBConv(6, 6, 1, 2, True)
            elif cfg['network_G']['conv'] == 'Involution':
                self.shifter = Involution(in_channel=6,
                                          kernel_size=1,
                                          stride=1)
            elif cfg['network_G']['conv'] == 'CondConv':
                self.shifter = CondConv(in_planes=6,
                                        out_planes=6,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False)
            elif cfg['network_G']['conv'] == 'fft':
                self.shifter = FourierUnit(in_channels=6,
                                           out_channels=6,
                                           groups=1,
                                           spatial_scale_factor=None,
                                           spatial_scale_mode='bilinear',
                                           spectral_pos_encoding=False,
                                           use_se=False,
                                           se_kwargs=None,
                                           ffc3d=False,
                                           fft_norm='ortho')
            elif cfg['network_G']['conv'] == 'WSConv':
                self.shifter = WSConv2d(6, 6, kernel_size=1, stride=1, padding=0)
            elif cfg['network_G']['conv'] == 'conv2d':
                self.shifter = nn.Conv2d(6,
                                         6,
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)

            self.shifter.weight.data = torch.eye(6).view(6, 6, 1, 1)
            self.shifter.bias.data = torch.Tensor([r, g, b, r, g, b])

        # Freeze the meanShift layer
        for params in self.shifter.parameters():
            params.requires_grad = False
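Typical EDSR-style usage pairs two frozen shifters, one subtracting the dataset mean before the network and one adding it back afterwards; a hedged example (mean values illustrative):

rgb_mean = (0.4488, 0.4371, 0.4040)      # e.g. DIV2K channel means
sub_mean = meanShift(255, rgb_mean, -1)  # x -> x - mean * 255
add_mean = meanShift(255, rgb_mean, 1)   # x -> x + mean * 255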
Example 8
def test_wsconv2d():
    c = WSConv2d(3, 6, 3)
    assert c(torch.randn(1, 3, 32, 32)) is not None, "Conv failed."
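A slightly stronger variant that also checks the output shape (same assumptions as the test above):

def test_wsconv2d_shape():
    c = WSConv2d(3, 6, 3, padding=1)  # padding=1 keeps the spatial size
    out = c(torch.randn(1, 3, 32, 32))
    assert out.shape == (1, 6, 32, 32), "Unexpected output shape."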