def __init__(self, in_channels, out_channels):
    """Create the layers of an NF_Block.

    Layers are registered in this order: a dilated 3x3 WSConv2d
    (``conv1``), ReLU, a 3x3 WSConv2d (``conv2``), ReLU, and a 1x1
    WSConv2d stored as ``shortcut_conv`` that maps ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels of every conv.
    """
    super(NF_Block, self).__init__()
    three_by_three = dict(kernel_size=3, stride=1)

    # dilation=2 with padding=2 keeps H/W unchanged for a 3x3 kernel under
    # nn.Conv2d size arithmetic (assumed to hold for WSConv2d -- confirm).
    self.conv1 = WSConv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          padding=2,
                          dilation=2,
                          **three_by_three)
    self.relu1 = nn.ReLU()
    self.conv2 = WSConv2d(in_channels=out_channels,
                          out_channels=out_channels,
                          padding=1,
                          **three_by_three)
    self.relu2 = nn.ReLU()
    self.shortcut_conv = WSConv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=1,
                                  stride=1)
def __init__(self, in_feat, out_feat, kernel_size, stride=1, norm=False):
    """Build a padding layer and a convolution chosen by the config.

    The convolution operator is selected by ``cfg['network_G']['conv']``.

    Args:
        in_feat: number of input channels.
        out_feat: number of output channels.
        kernel_size: convolution kernel size; ``kernel_size // 2`` pixels of
            zero padding are prepared in ``self.reflection_pad``. Ignored by
            the 'MBConv' and 'fft' variants.
        stride: convolution stride. Ignored by the 'CondConv' (always 1),
            'MBConv' and 'fft' variants.
        norm: accepted for interface compatibility; this constructor does
            not use it.

    Raises:
        ValueError: if ``cfg['network_G']['conv']`` is not a supported
            variant (previously ``self.conv`` was silently left unset).
    """
    super(ConvNorm, self).__init__()
    reflection_padding = kernel_size // 2
    # Zero padding is used instead of nn.ReflectionPad2d because of TensorRT
    # (per the original note); the attribute name is kept for callers.
    self.reflection_pad = torch.nn.ZeroPad2d(reflection_padding)

    conv_type = cfg['network_G']['conv']
    if conv_type == 'doconv':
        self.conv = DOConv2d(in_feat, out_feat, stride=stride,
                             kernel_size=kernel_size, bias=True)
    elif conv_type == 'gated':
        self.conv = GatedConv2dWithActivation(in_feat, out_feat,
                                              stride=stride,
                                              kernel_size=kernel_size,
                                              bias=True)
    elif conv_type == 'TBC':
        self.conv = TiedBlockConv2d(in_feat, out_feat, stride=stride,
                                    kernel_size=kernel_size, bias=True)
    elif conv_type == 'dynamic':
        self.conv = DynamicConvolution(nof_kernels_param, reduce_param,
                                       in_channels=in_feat,
                                       out_channels=out_feat,
                                       stride=stride,
                                       kernel_size=kernel_size,
                                       bias=True)
    elif conv_type == 'CondConv':
        # NOTE(review): unlike the other variants this ignores `stride` and
        # adds padding=1; if forward() applies self.reflection_pad before
        # self.conv, the input is padded twice -- confirm this is intended.
        self.conv = CondConv(in_planes=in_feat, out_planes=out_feat,
                             kernel_size=kernel_size, stride=1, padding=1,
                             bias=False)
    elif conv_type == 'MBConv':
        self.conv = MBConv(in_feat, out_feat, 1, 1, True)
    elif conv_type == 'fft':
        self.conv = FourierUnit(in_feat, out_feat)
    elif conv_type == 'WSConv':
        self.conv = WSConv2d(in_feat, out_feat, stride=stride,
                             kernel_size=kernel_size, bias=True)
    elif conv_type in ('conv2d', 'Involution'):
        # 'Involution' falls back to a plain nn.Conv2d in this layer.
        self.conv = nn.Conv2d(in_feat, out_feat, stride=stride,
                              kernel_size=kernel_size, bias=True)
    else:
        raise ValueError(
            "Unsupported cfg['network_G']['conv'] value: {!r}".format(conv_type))
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return ``nn.Sequential(<conv>, nn.PReLU(out_planes))`` for the configured variant.

    The conv operator is chosen by ``cfg['network_G']['conv']``:

    * 'conv2d' builds an nn.Conv2d honouring kernel_size, stride, padding
      and dilation.
    * 'doconv', 'gated', 'TBC', 'dynamic' and 'WSConv' are built with
      ``kernel_size=1`` and the given stride (kernel_size, padding and
      dilation are ignored).
    * 'MBConv' and 'fft' use their own fixed arguments.

    Any other value (including 'CondConv' and 'Involution') yields ``None``.
    """
    # Builders are zero-argument callables so only the selected module is
    # ever constructed.
    builders = {
        'doconv': lambda: DOConv2d(in_planes, out_planes, stride=stride,
                                   kernel_size=1, bias=True),
        'gated': lambda: GatedConv2dWithActivation(in_planes, out_planes,
                                                   stride=stride,
                                                   kernel_size=1, bias=True),
        'TBC': lambda: TiedBlockConv2d(in_planes, out_planes, stride=stride,
                                       kernel_size=1, bias=True),
        'dynamic': lambda: DynamicConvolution(nof_kernels_param, reduce_param,
                                              in_channels=in_planes,
                                              out_channels=out_planes,
                                              stride=stride, kernel_size=1,
                                              bias=True),
        'MBConv': lambda: MBConv(in_planes, out_planes, 1, 1, True),
        'fft': lambda: FourierUnit(in_planes, out_planes),
        'WSConv': lambda: WSConv2d(in_planes, out_planes, stride=stride,
                                   kernel_size=1, bias=True),
        'conv2d': lambda: nn.Conv2d(in_planes, out_planes,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=True),
    }
    build = builders.get(cfg['network_G']['conv'])
    if build is None:
        return None
    return nn.Sequential(build(), nn.PReLU(out_planes))
def replace_conv(module: nn.Module) -> None:
    """Recursively replace every convolution with WSConv2d, in place.

    Every descendant whose type is exactly ``torch.nn.Conv2d`` is swapped
    for a newly constructed ``WSConv2d`` with the same channel counts,
    kernel size, stride, padding, dilation, groups and bias presence. Every
    descendant whose type is exactly ``torch.nn.BatchNorm2d`` is swapped for
    ``torch.nn.Identity``. Other children are descended into recursively.

    The new WSConv2d layers are freshly initialised: the original conv
    weights are not copied. ``padding_mode`` is not forwarded.

    Usage:
        replace_conv(model)  # (In-line replacement)

    Args:
        module (nn.Module): target's model whose convolutions must be replaced.
    """
    # Snapshot the children so setattr during iteration cannot disturb it.
    for name, child in list(module.named_children()):
        if type(child) is torch.nn.Conv2d:
            # `bias` is a boolean flag on the conv constructor. Passing the
            # bias Tensor itself (as before) makes `if bias:` raise for
            # multi-channel convs, or reads the tensor's value for 1-channel ones.
            setattr(module, name, WSConv2d(child.in_channels,
                                           child.out_channels,
                                           child.kernel_size,
                                           child.stride,
                                           child.padding,
                                           child.dilation,
                                           child.groups,
                                           bias=child.bias is not None))
        elif type(child) is torch.nn.BatchNorm2d:
            setattr(module, name, torch.nn.Identity())
        else:
            # Only recurse into untouched children; freshly built
            # replacements need no further processing.
            replace_conv(child)
def __init__(self, n_resgroups, n_resblocks, n_feats, reduction=16,
             act=nn.LeakyReLU(0.2, True), norm=False):
    """Build the head conv, the residual-group body and the tail conv.

    The head maps ``n_feats * 2`` channels to ``n_feats``. The body is
    ``cfg['network_G']['RG']`` ``ResidualGroup(RCAB, ...)`` modules, each
    built with 12 blocks. The tail maps ``n_feats`` to ``n_feats``. The
    operator used for head and tail is chosen by ``cfg['network_G']['conv']``.

    Args:
        n_resgroups: accepted but unused; the group count is read from
            ``cfg['network_G']['RG']``.
        n_resblocks: accepted but unused; every group uses ``n_resblocks=12``.
        n_feats: feature channel count.
        reduction: channel-reduction ratio forwarded to every ResidualGroup.
        act: activation module forwarded to every ResidualGroup. The default
            is a single LeakyReLU instance created once at definition time
            and shared by every call that relies on the default.
        norm: forwarded to every ResidualGroup.

    Raises:
        ValueError: if ``cfg['network_G']['conv']`` is not a supported
            variant (previously headConv/tailConv were silently left unset).
    """
    super(Interpolation, self).__init__()
    conv_type = cfg['network_G']['conv']

    # --- head: n_feats * 2 -> n_feats ---
    if conv_type == 'doconv':
        self.headConv = DOConv2d(n_feats * 2, n_feats, stride=1, padding=1,
                                 bias=False, groups=1, kernel_size=3)
    elif conv_type == 'gated':
        self.headConv = GatedConv2dWithActivation(n_feats * 2, n_feats,
                                                  stride=1, padding=1,
                                                  bias=False, groups=1,
                                                  kernel_size=3)
    elif conv_type == 'TBC':
        self.headConv = TiedBlockConv2d(n_feats * 2, n_feats, stride=1,
                                        padding=1, bias=False, groups=1,
                                        kernel_size=3)
    elif conv_type == 'dynamic':
        self.headConv = DynamicConvolution(nof_kernels_param, reduce_param,
                                           in_channels=n_feats * 2,
                                           out_channels=n_feats, stride=1,
                                           padding=1, bias=False, groups=1,
                                           kernel_size=3)
    elif conv_type == 'MBConv':
        self.headConv = MBConv(n_feats * 2, n_feats, 1, 1, True)
    elif conv_type == 'fft':
        self.headConv = FourierUnit(in_channels=n_feats * 2,
                                    out_channels=n_feats, groups=1,
                                    spatial_scale_factor=None,
                                    spatial_scale_mode='bilinear',
                                    spectral_pos_encoding=False,
                                    use_se=False, se_kwargs=None,
                                    ffc3d=False, fft_norm='ortho')
    elif conv_type == 'WSConv':
        self.headConv = WSConv2d(n_feats * 2, n_feats, stride=1, padding=1,
                                 bias=False, groups=1, kernel_size=3)
    elif conv_type in ('conv2d', 'Involution', 'CondConv'):
        # Per the original note: Involution has a fixed in/output dimension
        # (it cannot map n_feats * 2 -> n_feats) and CondConv results in a
        # shape error here, so both fall back to a plain nn.Conv2d.
        self.headConv = nn.Conv2d(n_feats * 2, n_feats, stride=1, padding=1,
                                  bias=False, groups=1, kernel_size=3)
    else:
        raise ValueError(
            "Unsupported cfg['network_G']['conv'] value: {!r}".format(conv_type))

    # --- body ---
    modules_body = [
        ResidualGroup(RCAB,
                      n_resblocks=12,
                      n_feat=n_feats,
                      kernel_size=3,
                      reduction=reduction,
                      act=act,
                      norm=norm)
        for _ in range(cfg['network_G']['RG'])
    ]
    self.body = nn.Sequential(*modules_body)

    # --- tail: n_feats -> n_feats ---
    # Every name accepted by the head chain above is handled here too.
    if conv_type == 'doconv':
        self.tailConv = DOConv2d(n_feats, n_feats, stride=1, padding=1,
                                 bias=False, groups=1, kernel_size=3)
    elif conv_type == 'gated':
        self.tailConv = GatedConv2dWithActivation(n_feats, n_feats, stride=1,
                                                  padding=1, bias=False,
                                                  groups=1, kernel_size=3)
    elif conv_type == 'TBC':
        self.tailConv = TiedBlockConv2d(n_feats, n_feats, stride=1, padding=1,
                                        bias=False, groups=1, kernel_size=3)
    elif conv_type == 'dynamic':
        self.tailConv = DynamicConvolution(nof_kernels_param, reduce_param,
                                           in_channels=n_feats,
                                           out_channels=n_feats, stride=1,
                                           padding=1, bias=False, groups=1,
                                           kernel_size=3)
    elif conv_type == 'MBConv':
        self.tailConv = MBConv(n_feats, n_feats, 1, 1, True)
    elif conv_type == 'Involution':
        self.tailConv = Involution(in_channel=n_feats, kernel_size=3, stride=1)
    elif conv_type == 'CondConv':
        self.tailConv = CondConv(in_planes=n_feats, out_planes=n_feats,
                                 kernel_size=1, stride=1, padding=0,
                                 bias=False)
    elif conv_type == 'fft':
        self.tailConv = FourierUnit(in_channels=n_feats, out_channels=n_feats,
                                    groups=1, spatial_scale_factor=None,
                                    spatial_scale_mode='bilinear',
                                    spectral_pos_encoding=False,
                                    use_se=False, se_kwargs=None,
                                    ffc3d=False, fft_norm='ortho')
    elif conv_type == 'WSConv':
        self.tailConv = WSConv2d(n_feats, n_feats, stride=1, padding=1,
                                 bias=False, groups=1, kernel_size=3)
    elif conv_type == 'conv2d':
        self.tailConv = nn.Conv2d(n_feats, n_feats, stride=1, padding=1,
                                  bias=False, groups=1, kernel_size=3)
def __init__(self, channel, reduction=16):
    """Build a channel-attention layer.

    ``avg_pool`` reduces every channel to a single value. ``conv_du`` is
    ``1x1 conv (channel -> channel // reduction) -> ReLU ->
    1x1 conv (channel // reduction -> channel) -> Sigmoid``, presumably
    yielding per-channel weights that forward() applies to the input
    (confirm). The 1x1 conv operator is chosen by ``cfg['network_G']['conv']``.

    Args:
        channel: number of input/output channels.
        reduction: channel-reduction ratio of the bottleneck.

    Raises:
        ValueError: if ``cfg['network_G']['conv']`` is not a supported
            variant (previously ``conv_du`` was silently left unset).
    """
    super(CALayer, self).__init__()
    # global average pooling: feature --> point
    self.avg_pool = nn.AdaptiveAvgPool2d(1)

    # feature channel downscale and upscale --> channel weight
    squeezed = channel // reduction
    builders = {
        'doconv': lambda cin, cout: DOConv2d(cin, cout, 1, padding=0,
                                             bias=False),
        'TBC': lambda cin, cout: TiedBlockConv2d(cin, cout, 1, padding=0,
                                                 bias=False),
        'dynamic': lambda cin, cout: DynamicConvolution(
            nof_kernels_param, reduce_param, in_channels=cin,
            out_channels=cout, kernel_size=1, padding=0, bias=False),
        'CondConv': lambda cin, cout: CondConv(in_planes=cin, out_planes=cout,
                                               kernel_size=1, stride=1,
                                               padding=0, bias=False),
        'WSConv': lambda cin, cout: WSConv2d(cin, cout, 1, padding=0,
                                             bias=False),
    }

    def plain(cin, cout):
        return nn.Conv2d(cin, cout, 1, padding=0, bias=False)

    # Per the original note, gated, MBConv, Involution and fft give a shape
    # error here, so they fall back to a plain nn.Conv2d like 'conv2d'.
    for fallback in ('conv2d', 'gated', 'MBConv', 'Involution', 'fft'):
        builders[fallback] = plain

    conv_type = cfg['network_G']['conv']
    make = builders.get(conv_type)
    if make is None:
        raise ValueError(
            "Unsupported cfg['network_G']['conv'] value: {!r}".format(conv_type))

    self.conv_du = nn.Sequential(
        make(channel, squeezed),
        nn.ReLU(inplace=True),
        make(squeezed, channel),
        nn.Sigmoid())
def __init__(self, rgbRange, rgbMean, sign, nChannel=3):
    """Build a frozen 1x1 layer that applies a per-channel mean shift.

    The shift for colour ``i`` is ``rgbMean[i] * rgbRange * float(sign)``.
    The layer's weight is set to the identity matrix and its bias to the
    shift vector, and every parameter is frozen (``requires_grad = False``).

    Args:
        rgbRange: scale applied to ``rgbMean``.
        rgbMean: per-channel means. Only ``rgbMean[0]`` is read when
            ``nChannel == 1``; otherwise indices 0-2 are read.
        sign: multiplier for the shift (presumably -1 to subtract the mean
            and +1 to add it back -- confirm with callers).
        nChannel: 1 or 3 build a layer with that many channels. Any other
            value builds a 6-channel layer whose shift is the RGB shift
            repeated twice.

    Raises:
        ValueError: if ``cfg['network_G']['conv']`` is not a supported variant.
    """
    super(meanShift, self).__init__()

    if nChannel == 1:
        channels = 1
        shift = [rgbMean[0] * rgbRange * float(sign)]
    else:
        r = rgbMean[0] * rgbRange * float(sign)
        g = rgbMean[1] * rgbRange * float(sign)
        b = rgbMean[2] * rgbRange * float(sign)
        if nChannel == 3:
            channels = 3
            shift = [r, g, b]
        else:
            channels = 6
            shift = [r, g, b, r, g, b]

    conv_type = cfg['network_G']['conv']
    if conv_type == 'doconv':
        self.shifter = DOConv2d(channels, channels, kernel_size=1, stride=1,
                                padding=0)
    elif conv_type == 'gated':
        self.shifter = GatedConv2dWithActivation(channels, channels,
                                                 kernel_size=1, stride=1,
                                                 padding=0)
    elif conv_type == 'TBC':
        self.shifter = TiedBlockConv2d(channels, channels, kernel_size=1,
                                       stride=1, padding=0)
    elif conv_type == 'dynamic':
        self.shifter = DynamicConvolution(nof_kernels_param, reduce_param,
                                          in_channels=channels,
                                          out_channels=channels,
                                          kernel_size=1, stride=1, padding=0)
    elif conv_type == 'MBConv':
        self.shifter = MBConv(channels, channels, 1, 2, True)
    elif conv_type == 'Involution':
        self.shifter = Involution(in_channel=channels, kernel_size=1, stride=1)
    elif conv_type == 'CondConv':
        self.shifter = CondConv(in_planes=channels, out_planes=channels,
                                kernel_size=1, stride=1, padding=0,
                                bias=False)
    elif conv_type == 'fft':
        self.shifter = FourierUnit(in_channels=channels, out_channels=channels,
                                   groups=1, spatial_scale_factor=None,
                                   spatial_scale_mode='bilinear',
                                   spectral_pos_encoding=False,
                                   use_se=False, se_kwargs=None,
                                   ffc3d=False, fft_norm='ortho')
    elif conv_type == 'WSConv':
        # Previously assigned to `self.conv`, which left `self.shifter`
        # undefined and crashed below with AttributeError.
        self.shifter = WSConv2d(channels, channels, kernel_size=1, stride=1,
                                padding=0)
    elif conv_type == 'conv2d':
        self.shifter = nn.Conv2d(channels, channels, kernel_size=1, stride=1,
                                 padding=0)
    else:
        raise ValueError(
            "Unsupported cfg['network_G']['conv'] value: {!r}".format(conv_type))

    # Identity weight + shift bias turns the 1x1 layer into "x + shift".
    # NOTE(review): this requires the selected module to expose `weight` and
    # a non-None `bias` like nn.Conv2d; variants that do not (e.g. CondConv
    # is built with bias=False) will fail here -- confirm which are supported.
    self.shifter.weight.data = torch.eye(channels).view(channels, channels, 1, 1)
    self.shifter.bias.data = torch.Tensor(shift)

    # Freeze the meanShift layer
    for params in self.shifter.parameters():
        params.requires_grad = False
def test_wsconv2d():
    """Smoke test: a 3->6 channel, 3x3 WSConv2d completes a forward pass."""
    layer = WSConv2d(3, 6, 3)
    sample = torch.randn(1, 3, 32, 32)
    output = layer(sample)
    assert output is not None, "Conv failed."