Example 1
    def __init__(self, mid_channels=64, num_blocks=30, spynet_pretrained=None):

        super().__init__()

        self.mid_channels = mid_channels

        # optical flow network for feature alignment
        self.spynet = SPyNet(pretrained=spynet_pretrained)

        # propagation branches
        self.backward_resblocks = ResidualBlocksWithInputConv(
            mid_channels + 3, mid_channels, num_blocks)
        self.forward_resblocks = ResidualBlocksWithInputConv(
            mid_channels + 3, mid_channels, num_blocks)

        # upsample
        self.fusion = nn.Conv2d(
            mid_channels * 2, mid_channels, 1, 1, 0, bias=True)
        self.upsample1 = PixelShufflePack(
            mid_channels, mid_channels, 2, upsample_kernel=3)
        self.upsample2 = PixelShufflePack(
            mid_channels, 64, 2, upsample_kernel=3)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(
            scale_factor=4, mode='bilinear', align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
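
The snippet above only declares layers. Below is a minimal sketch of how such an upsampling tail is typically wired in a forward pass; the method name and the `feat`/`lr` arguments are illustrative assumptions, not part of the original model.

# Hypothetical upsampling tail for the layers declared above (sketch only).
def upsample_tail(self, feat, lr):
    # `feat` is assumed to be the channel-wise concatenation of the backward
    # and forward propagation features (2 * mid_channels); `lr` is the
    # current low-resolution frame.
    out = self.lrelu(self.fusion(feat))
    out = self.lrelu(self.upsample1(out))   # x2
    out = self.lrelu(self.upsample2(out))   # x2 again, x4 overall
    out = self.lrelu(self.conv_hr(out))
    out = self.conv_last(out)
    return out + self.img_upsample(lr)      # residual over the bilinear x4 base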
Example 2
    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_frames=5,
                 deform_groups=8,
                 num_blocks_extraction=5,
                 num_blocks_reconstruction=10,
                 center_frame_idx=2,
                 with_tsa=True):
        super(EDVRNet, self).__init__()
        self.center_frame_idx = center_frame_idx
        self.with_tsa = with_tsa
        act_cfg = dict(type='LeakyReLU', negative_slope=0.1)

        self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
        self.feature_extraction = make_layer(
            ResidualBlockNoBN,
            num_blocks_extraction,
            mid_channels=mid_channels)

        # generate pyramid features
        self.feat_l2_conv1 = ConvModule(
            mid_channels, mid_channels, 3, 2, 1, act_cfg=act_cfg)
        self.feat_l2_conv2 = ConvModule(
            mid_channels, mid_channels, 3, 1, 1, act_cfg=act_cfg)
        self.feat_l3_conv1 = ConvModule(
            mid_channels, mid_channels, 3, 2, 1, act_cfg=act_cfg)
        self.feat_l3_conv2 = ConvModule(
            mid_channels, mid_channels, 3, 1, 1, act_cfg=act_cfg)
        # pcd alignment
        self.pcd_alignment = PCDAlignment(
            mid_channels=mid_channels, deform_groups=deform_groups)
        # fusion
        if self.with_tsa:
            self.fusion = TSAFusion(
                mid_channels=mid_channels,
                num_frames=num_frames,
                center_frame_idx=self.center_frame_idx)
        else:
            self.fusion = nn.Conv2d(
                num_frames * mid_channels, mid_channels, 1, 1)

        # reconstruction
        self.reconstruction = make_layer(
            ResidualBlockNoBN,
            num_blocks_reconstruction,
            mid_channels=mid_channels)
        # upsample
        self.upsample1 = PixelShufflePack(
            mid_channels, mid_channels, 2, upsample_kernel=3)
        self.upsample2 = PixelShufflePack(
            mid_channels, 64, 2, upsample_kernel=3)
        # we fix the output channels in the last few layers to 64.
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(
            scale_factor=4, mode='bilinear', align_corners=False)
        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
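
A minimal sketch of how the pyramid branch declared above is usually applied, assuming `l1_feat` is the output of `self.feature_extraction`; the helper name is hypothetical.

# Hypothetical pyramid construction (sketch; each *_conv1 has stride 2 and
# therefore halves the spatial size).
def build_pyramid_sketch(self, l1_feat):
    l2_feat = self.feat_l2_conv2(self.feat_l2_conv1(l1_feat))  # H/2 x W/2
    l3_feat = self.feat_l3_conv2(self.feat_l3_conv1(l2_feat))  # H/4 x W/4
    return [l1_feat, l2_feat, l3_feat]  # fed to PCD alignment per frame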
Example 3
    def __init__(self,
                 mid_channels=64,
                 num_blocks=30,
                 keyframe_stride=5,
                 padding=2,
                 spynet_pretrained=None,
                 edvr_pretrained=None):

        super().__init__()

        self.mid_channels = mid_channels
        self.padding = padding
        self.keyframe_stride = keyframe_stride

        # optical flow network for alignment
        self.spynet = SPyNet(pretrained=spynet_pretrained)

        # information-refill
        self.edvr = EDVRFeatureExtractor(num_frames=padding * 2 + 1,
                                         center_frame_idx=padding,
                                         pretrained=edvr_pretrained)
        self.backward_fusion = nn.Conv2d(2 * mid_channels,
                                         mid_channels,
                                         3,
                                         1,
                                         1,
                                         bias=True)
        self.forward_fusion = nn.Conv2d(2 * mid_channels,
                                        mid_channels,
                                        3,
                                        1,
                                        1,
                                        bias=True)

        # propagation branches
        self.backward_resblocks = ResidualBlocksWithInputConv(
            mid_channels + 3, mid_channels, num_blocks)
        self.forward_resblocks = ResidualBlocksWithInputConv(
            2 * mid_channels + 3, mid_channels, num_blocks)

        # upsample
        self.upsample1 = PixelShufflePack(mid_channels,
                                          mid_channels,
                                          2,
                                          upsample_kernel=3)
        self.upsample2 = PixelShufflePack(mid_channels,
                                          64,
                                          2,
                                          upsample_kernel=3)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(scale_factor=4,
                                        mode='bilinear',
                                        align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
Example 4
def test_pixel_shuffle():

    # test on cpu
    model = PixelShufflePack(3, 3, 2, 3)
    model.init_weights()
    x = torch.rand(1, 3, 16, 16)
    y = model(x)
    assert y.shape == (1, 3, 32, 32)

    # test on gpu
    if torch.cuda.is_available():
        model = model.cuda()
        x = x.cuda()
        y = model(x)
        assert y.shape == (1, 3, 32, 32)
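
The test pins down the expected behaviour: a scale-2 PixelShufflePack maps (1, 3, 16, 16) to (1, 3, 32, 32). Below is a minimal re-implementation sketch with the same interface, assuming the usual conv-then-PixelShuffle design; this is not the library's actual implementation.

import torch
import torch.nn as nn


class SimplePixelShufflePack(nn.Module):
    """Sketch of a PixelShufflePack-style layer: a conv expands the channels
    by scale_factor**2, then nn.PixelShuffle rearranges them into space."""

    def __init__(self, in_channels, out_channels, scale_factor,
                 upsample_kernel):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * scale_factor**2,
            upsample_kernel,
            padding=(upsample_kernel - 1) // 2)
        self.shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        return self.shuffle(self.conv(x))


# Matches the shape contract asserted in the test above.
y = SimplePixelShufflePack(3, 3, 2, 3)(torch.rand(1, 3, 16, 16))
assert y.shape == (1, 3, 32, 32)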
Example 5
    def __init__(self, scale, mid_channels):
        modules = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                modules.append(
                    PixelShufflePack(
                        mid_channels, mid_channels, 2, upsample_kernel=3))
        elif scale == 3:
            modules.append(
                PixelShufflePack(
                    mid_channels, mid_channels, scale, upsample_kernel=3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')

        super(UpsampleModule, self).__init__(*modules)
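
Since the module subclasses nn.Sequential (see the `super().__init__(*modules)` call), it can be applied directly to a feature map. A small usage sketch, assuming PixelShufflePack scales the spatial size as in the test earlier:

import torch

# Illustrative use of the sequential upsampler above.
up4 = UpsampleModule(scale=4, mid_channels=64)   # two x2 PixelShufflePack stages
up3 = UpsampleModule(scale=3, mid_channels=64)   # one x3 PixelShufflePack stage
feat = torch.rand(1, 64, 24, 24)
assert up4(feat).shape == (1, 64, 96, 96)
assert up3(feat).shape == (1, 64, 72, 72)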
Example 6
    def __init__(self,
                 in_channels=3,
                 mid_channels=64,
                 out_channels=3,
                 num_blocks_before_align=5,
                 num_blocks_after_align=10):

        super().__init__()

        self.feat_extract = nn.Sequential(
            ConvModule(in_channels, mid_channels, 3, padding=1),
            make_layer(ResidualBlockNoBN,
                       num_blocks_before_align,
                       mid_channels=mid_channels))

        self.feat_aggregate = nn.Sequential(
            nn.Conv2d(mid_channels * 2, mid_channels, 3, padding=1, bias=True),
            DeformConv2dPack(mid_channels,
                             mid_channels,
                             3,
                             padding=1,
                             deform_groups=8),
            DeformConv2dPack(mid_channels,
                             mid_channels,
                             3,
                             padding=1,
                             deform_groups=8))
        self.align_1 = AugmentedDeformConv2dPack(mid_channels,
                                                 mid_channels,
                                                 3,
                                                 padding=1,
                                                 deform_groups=8)
        self.align_2 = DeformConv2dPack(mid_channels,
                                        mid_channels,
                                        3,
                                        padding=1,
                                        deform_groups=8)
        self.to_rgb = nn.Conv2d(mid_channels, 3, 3, padding=1, bias=True)

        self.reconstruct = nn.Sequential(
            ConvModule(in_channels * 5, mid_channels, 3, padding=1),
            make_layer(ResidualBlockNoBN,
                       num_blocks_after_align,
                       mid_channels=mid_channels),
            PixelShufflePack(mid_channels, mid_channels, 2, upsample_kernel=3),
            PixelShufflePack(mid_channels, mid_channels, 2, upsample_kernel=3),
            nn.Conv2d(mid_channels, out_channels, 3, 1, 1, bias=False))
Example 7
    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_blocks=16,
                 upscale_factor=4):

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.num_blocks = num_blocks
        self.upscale_factor = upscale_factor

        self.conv_first = nn.Conv2d(in_channels,
                                    mid_channels,
                                    3,
                                    1,
                                    1,
                                    bias=True)
        self.trunk_net = make_layer(ResidualBlockNoBN,
                                    num_blocks,
                                    mid_channels=mid_channels)

        # upsampling
        if self.upscale_factor in [2, 3]:
            self.upsample1 = PixelShufflePack(mid_channels,
                                              mid_channels,
                                              self.upscale_factor,
                                              upsample_kernel=3)
        elif self.upscale_factor == 4:
            self.upsample1 = PixelShufflePack(mid_channels,
                                              mid_channels,
                                              2,
                                              upsample_kernel=3)
            self.upsample2 = PixelShufflePack(mid_channels,
                                              mid_channels,
                                              2,
                                              upsample_kernel=3)
        else:
            raise ValueError(
                f'Unsupported scale factor {self.upscale_factor}. '
                f'Currently supported ones are '
                f'{self._supported_upscale_factors}.')

        self.conv_hr = nn.Conv2d(mid_channels,
                                 mid_channels,
                                 3,
                                 1,
                                 1,
                                 bias=True)
        self.conv_last = nn.Conv2d(mid_channels,
                                   out_channels,
                                   3,
                                   1,
                                   1,
                                   bias=True)

        self.img_upsampler = nn.Upsample(scale_factor=self.upscale_factor,
                                         mode='bilinear',
                                         align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
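
A hedged sketch of how the layers above typically compose in the forward pass: one upsampling stage for x2/x3, two for x4, plus a bilinear skip connection from the input. Names follow the snippet; the method itself is illustrative.

# Hypothetical forward pass for the network above (sketch only).
def forward_sketch(self, x):
    feat = self.lrelu(self.conv_first(x))
    out = self.trunk_net(feat)
    out = self.lrelu(self.upsample1(out))
    if self.upscale_factor == 4:
        out = self.lrelu(self.upsample2(out))
    out = self.conv_last(self.lrelu(self.conv_hr(out)))
    return out + self.img_upsampler(x)   # residual over the bilinear base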
Example 8
    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 texture_channels=64,
                 num_blocks=(16, 16, 8, 4),
                 res_scale=1.0):
        super().__init__()

        self.texture_channels = texture_channels

        self.sfe = SFE(in_channels, mid_channels, num_blocks[0], res_scale)

        # stage 1
        self.conv_first1 = _conv3x3_layer(4 * texture_channels + mid_channels,
                                          mid_channels)

        self.res_block1 = make_layer(ResidualBlockNoBN,
                                     num_blocks[1],
                                     mid_channels=mid_channels,
                                     res_scale=res_scale)

        self.conv_last1 = _conv3x3_layer(mid_channels, mid_channels)

        # up-sampling 1 -> 2
        self.up1 = PixelShufflePack(in_channels=mid_channels,
                                    out_channels=mid_channels,
                                    scale_factor=2,
                                    upsample_kernel=3)

        # stage 2
        self.conv_first2 = _conv3x3_layer(2 * texture_channels + mid_channels,
                                          mid_channels)

        self.csfi2 = CSFI2(mid_channels)

        self.res_block2_1 = make_layer(ResidualBlockNoBN,
                                       num_blocks[2],
                                       mid_channels=mid_channels,
                                       res_scale=res_scale)
        self.res_block2_2 = make_layer(ResidualBlockNoBN,
                                       num_blocks[2],
                                       mid_channels=mid_channels,
                                       res_scale=res_scale)

        self.conv_last2_1 = _conv3x3_layer(mid_channels, mid_channels)
        self.conv_last2_2 = _conv3x3_layer(mid_channels, mid_channels)

        # up-sampling 2 -> 3
        self.up2 = PixelShufflePack(in_channels=mid_channels,
                                    out_channels=mid_channels,
                                    scale_factor=2,
                                    upsample_kernel=3)

        # stage 3
        self.conv_first3 = _conv3x3_layer(texture_channels + mid_channels,
                                          mid_channels)

        self.csfi3 = CSFI3(mid_channels)

        self.res_block3_1 = make_layer(ResidualBlockNoBN,
                                       num_blocks[3],
                                       mid_channels=mid_channels,
                                       res_scale=res_scale)
        self.res_block3_2 = make_layer(ResidualBlockNoBN,
                                       num_blocks[3],
                                       mid_channels=mid_channels,
                                       res_scale=res_scale)
        self.res_block3_3 = make_layer(ResidualBlockNoBN,
                                       num_blocks[3],
                                       mid_channels=mid_channels,
                                       res_scale=res_scale)

        self.conv_last3_1 = _conv3x3_layer(mid_channels, mid_channels)
        self.conv_last3_2 = _conv3x3_layer(mid_channels, mid_channels)
        self.conv_last3_3 = _conv3x3_layer(mid_channels, mid_channels)

        # end, merge features
        self.merge_features = MergeFeatures(mid_channels, out_channels)
Example 9
    def __init__(self,
                 mid_channels=64,
                 num_blocks=7,
                 max_residue_magnitude=10,
                 is_low_res_input=True,
                 spynet_pretrained=None,
                 cpu_cache_length=100):

        super().__init__()
        self.mid_channels = mid_channels
        self.is_low_res_input = is_low_res_input
        self.cpu_cache_length = cpu_cache_length

        # optical flow
        self.spynet = SPyNet(pretrained=spynet_pretrained)

        # feature extraction module
        if is_low_res_input:
            self.feat_extract = ResidualBlocksWithInputConv(3, mid_channels, 5)
        else:
            self.feat_extract = nn.Sequential(
                nn.Conv2d(3, mid_channels, 3, 2, 1),
                nn.LeakyReLU(negative_slope=0.1, inplace=True),
                nn.Conv2d(mid_channels, mid_channels, 3, 2, 1),
                nn.LeakyReLU(negative_slope=0.1, inplace=True),
                ResidualBlocksWithInputConv(mid_channels, mid_channels, 5))

        # propagation branches
        self.deform_align = nn.ModuleDict()
        self.backbone = nn.ModuleDict()
        modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
        for i, module in enumerate(modules):
            if torch.cuda.is_available():
                self.deform_align[module] = SecondOrderDeformableAlignment(
                    2 * mid_channels,
                    mid_channels,
                    3,
                    padding=1,
                    deform_groups=16,
                    max_residue_magnitude=max_residue_magnitude)
            self.backbone[module] = ResidualBlocksWithInputConv(
                (2 + i) * mid_channels, mid_channels, num_blocks)

        # upsampling module
        self.reconstruction = ResidualBlocksWithInputConv(
            5 * mid_channels, mid_channels, 5)
        self.upsample1 = PixelShufflePack(mid_channels,
                                          mid_channels,
                                          2,
                                          upsample_kernel=3)
        self.upsample2 = PixelShufflePack(mid_channels,
                                          64,
                                          2,
                                          upsample_kernel=3)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(scale_factor=4,
                                        mode='bilinear',
                                        align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # check if the sequence is augmented by flipping
        self.is_mirror_extended = False

        if len(self.deform_align) > 0:
            self.is_with_alignment = True
        else:
            self.is_with_alignment = False
            warnings.warn(
                'Deformable alignment module is not added. '
                'Probably your CUDA is not configured correctly. DCN can only '
                'be used with CUDA enabled. Alignment is skipped now.')
Example 10
    def __init__(self,
                 in_size,
                 out_size,
                 img_channels=3,
                 rrdb_channels=64,
                 num_rrdbs=23,
                 style_channels=512,
                 num_mlps=8,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 lr_mlp=0.01,
                 default_style_mode='mix',
                 eval_style_mode='single',
                 mix_prob=0.9,
                 pretrained=None,
                 bgr2rgb=False):

        super().__init__()

        # input size must be strictly smaller than output size
        if in_size >= out_size:
            raise ValueError('in_size must be smaller than out_size, but got '
                             f'{in_size} and {out_size}.')

        # latent bank (StyleGANv2), with weights being fixed
        self.generator = build_component(
            dict(type='StyleGANv2Generator',
                 out_size=out_size,
                 style_channels=style_channels,
                 num_mlps=num_mlps,
                 channel_multiplier=channel_multiplier,
                 blur_kernel=blur_kernel,
                 lr_mlp=lr_mlp,
                 default_style_mode=default_style_mode,
                 eval_style_mode=eval_style_mode,
                 mix_prob=mix_prob,
                 pretrained=pretrained,
                 bgr2rgb=bgr2rgb))
        self.generator.requires_grad_(False)

        self.in_size = in_size
        self.style_channels = style_channels
        channels = self.generator.channels

        # encoder
        num_styles = int(np.log2(out_size)) * 2 - 2
        encoder_res = [2**i for i in range(int(np.log2(in_size)), 1, -1)]
        self.encoder = nn.ModuleList()
        self.encoder.append(
            nn.Sequential(
                RRDBFeatureExtractor(img_channels,
                                     rrdb_channels,
                                     num_blocks=num_rrdbs),
                nn.Conv2d(rrdb_channels, channels[in_size], 3, 1, 1,
                          bias=True),
                nn.LeakyReLU(negative_slope=0.2, inplace=True)))
        for res in encoder_res:
            in_channels = channels[res]
            if res > 4:
                out_channels = channels[res // 2]
                block = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 3, 2, 1, bias=True),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True),
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=True),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True))
            else:
                block = nn.Sequential(
                    nn.Conv2d(in_channels, in_channels, 3, 1, 1, bias=True),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True),
                    nn.Flatten(),
                    nn.Linear(16 * in_channels, num_styles * style_channels))
            self.encoder.append(block)

        # additional modules for StyleGANv2
        self.fusion_out = nn.ModuleList()
        self.fusion_skip = nn.ModuleList()
        for res in encoder_res[::-1]:
            num_channels = channels[res]
            self.fusion_out.append(
                nn.Conv2d(num_channels * 2, num_channels, 3, 1, 1, bias=True))
            self.fusion_skip.append(
                nn.Conv2d(num_channels + 3, 3, 3, 1, 1, bias=True))

        # decoder
        decoder_res = [
            2**i
            for i in range(int(np.log2(in_size)), int(np.log2(out_size) + 1))
        ]
        self.decoder = nn.ModuleList()
        for res in decoder_res:
            if res == in_size:
                in_channels = channels[res]
            else:
                in_channels = 2 * channels[res]

            if res < out_size:
                out_channels = channels[res * 2]
                self.decoder.append(
                    PixelShufflePack(in_channels,
                                     out_channels,
                                     2,
                                     upsample_kernel=3))
            else:
                self.decoder.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, 64, 3, 1, 1),
                        nn.LeakyReLU(negative_slope=0.2, inplace=True),
                        nn.Conv2d(64, img_channels, 3, 1, 1)))
Example 11
    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 mid_channels=64,
                 num_frames=5,
                 deform_groups=8,
                 num_blocks_extraction=5,
                 num_blocks_reconstruction=10,
                 center_frame_idx=2,
                 hr_in=False,
                 with_predeblur=False,
                 with_tsa=True,
                 with_nonlocal=False,
                 with_car=False):
        super(EDVRNet_WoPre, self).__init__()
        self.center_frame_idx = center_frame_idx
        self.hr_in = hr_in
        self.with_predeblur = with_predeblur
        self.with_tsa = with_tsa
        self.with_nonlocal = with_nonlocal
        self.with_car = with_car

        act_cfg = dict(type='LeakyReLU', negative_slope=0.1)

        # extract features for each frame
        if self.with_predeblur:
            self.predeblur = PredeblurModule(
                mid_channels=mid_channels, hr_in=self.hr_in)
            self.conv_1x1 = nn.Conv2d(mid_channels, mid_channels, 1, 1)
        else:
            self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)

        # extract pyramid features
        if self.with_car:
            self.feature_extraction = make_layer(
                CAResidualBlockNoBN,
                num_blocks_extraction,
                mid_channels=mid_channels)
        else:
            self.feature_extraction = make_layer(
                ResidualBlockNoBN,
                num_blocks_extraction,
                mid_channels=mid_channels)

        self.feat_l2_conv1 = ConvModule(
            mid_channels, mid_channels, 3, 2, 1, act_cfg=act_cfg)
        self.feat_l2_conv2 = ConvModule(
            mid_channels, mid_channels, 3, 1, 1, act_cfg=act_cfg)
        self.feat_l3_conv1 = ConvModule(
            mid_channels, mid_channels, 3, 2, 1, act_cfg=act_cfg)
        self.feat_l3_conv2 = ConvModule(
            mid_channels, mid_channels, 3, 1, 1, act_cfg=act_cfg)
        # pcd alignment
        self.pcd_alignment = PCDAlignment(
            mid_channels=mid_channels, deform_groups=deform_groups)
        # tsa fusion
        if self.with_tsa:
            self.fusion = TSAFusion(
                mid_channels=mid_channels,
                num_frames=num_frames,
                center_frame_idx=self.center_frame_idx)
        else:
            self.fusion = nn.Conv2d(
                num_frames * mid_channels, mid_channels, 1, 1)
        # non local module
        if self.with_nonlocal:
            self.non_local = NonLocalModule(
                mid_channels=mid_channels,
                num_frames=num_frames,
                center_frame_idx=self.center_frame_idx)

        # reconstruction
        if self.with_car:
            self.reconstruction = make_layer(
                CAResidualBlockNoBN,
                num_blocks_reconstruction,
                mid_channels=mid_channels)
        else:
            self.reconstruction = make_layer(
                ResidualBlockNoBN,
                num_blocks_reconstruction,
                mid_channels=mid_channels)
        # upsample
        self.upsample1 = PixelShufflePack(
            mid_channels, mid_channels, 2, upsample_kernel=3)
        self.upsample2 = PixelShufflePack(
            mid_channels, mid_channels, 2, upsample_kernel=3)
        # we fix the output channels in the last few layers to 64.
        self.conv_hr = nn.Conv2d(mid_channels, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(
            scale_factor=4, mode='bilinear', align_corners=False)
        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)