Example #1
0
    def __init__(self, in_channels, out_channels, num_features, bp_stages, upscale_factor=4, norm_type=None, act_type='prelu', mode='NAC', upsample_mode='upconv'):
        """Build the D-DBPN network as a single sequential pipeline.

        The pipeline is: a 3x3 feature conv (-> 256 channels), a 1x1 conv
        (-> ``num_features``), a ``B.DensebackprojBlock`` with ``bp_stages``
        stages, and a 3x3 reconstruction conv.

        Args:
            in_channels: Channels of the input image.
            out_channels: Channels of the reconstructed image.
            num_features: Channel width inside the back-projection block.
            bp_stages: Number of back-projection stages. The reconstruction
                conv takes ``num_features * bp_stages`` input channels, so the
                dense block presumably concatenates every stage's output
                (TODO confirm against ``B.DensebackprojBlock``).
            upscale_factor: 2, 4 or 8; selects the projection kernel size,
                stride and padding.
            norm_type: Normalization for the feature and projection blocks.
            act_type: Activation for the feature and projection blocks.
            mode: Accepted for interface compatibility; unused here.
            upsample_mode: Accepted for interface compatibility; unused here.

        Raises:
            NotImplementedError: If ``upscale_factor`` is not 2, 4 or 8.
        """
        super(D_DBPN, self).__init__()

        # (stride, padding, projection kernel size) for each supported scale.
        # This explicit check replaces the old if/elif chain. That chain let an
        # unsupported factor fall through and crash later with an
        # UnboundLocalError on ``stride``.
        projection_params = {
            2: (2, 2, 6),
            4: (4, 2, 8),
            8: (8, 2, 12),
        }
        if upscale_factor not in projection_params:
            raise NotImplementedError('Upscale factor [%s] is not supported!' % (upscale_factor,))
        stride, padding, projection_filter = projection_params[upscale_factor]

        feature_extract_1 = B.ConvBlock(in_channels, 256, kernel_size=3, norm_type=norm_type, act_type=act_type)
        feature_extract_2 = B.ConvBlock(256, num_features, kernel_size=1, norm_type=norm_type, act_type=act_type)

        bp_units = B.DensebackprojBlock(num_features, num_features, projection_filter, bp_stages, stride=stride, valid_padding=False,
                                        padding=padding, norm_type=norm_type, act_type=act_type)

        # No normalization or activation on the final reconstruction layer.
        conv_hr = B.ConvBlock(num_features*bp_stages, out_channels, kernel_size=3, norm_type=None, act_type=None)

        self.network = B.sequential(feature_extract_1, feature_extract_2, bp_units, conv_hr)
Example #2
0
    def __init__(self, in_channels, out_channels, num_features, bp_stages, upscale_factor=4, norm_type=None, act_type='prelu', mode='NAC', upsample_mode='upconv'):
        """Build the DBPN network as a single sequential pipeline.

        The pipeline is: a 3x3 feature conv (-> 128 channels), a 1x1 conv
        (-> ``num_features``), ``bp_stages - 1`` alternating up-/down-projection
        pairs, one final up-projection, and a 1x1 reconstruction conv.

        Args:
            in_channels: Channels of the input image.
            out_channels: Channels of the reconstructed image.
            num_features: Channel width of every projection block.
            bp_stages: Number of up-projection blocks (the last one has no
                matching down-projection).
            upscale_factor: 2, 4 or 8; selects the projection kernel size,
                stride and padding.
            norm_type: Normalization for the feature and projection blocks.
            act_type: Activation for the feature and projection blocks.
            mode: Accepted for interface compatibility; unused here.
            upsample_mode: Accepted for interface compatibility; unused here.

        Raises:
            NotImplementedError: If ``upscale_factor`` is not 2, 4 or 8.
        """
        super(DBPN, self).__init__()

        # (stride, padding, projection kernel size) for each supported scale.
        # This explicit check replaces the old if/elif chain. That chain let an
        # unsupported factor fall through and crash later with an
        # UnboundLocalError on ``stride``.
        projection_params = {
            2: (2, 2, 6),
            4: (4, 2, 8),
            8: (8, 2, 12),
        }
        if upscale_factor not in projection_params:
            raise NotImplementedError('Upscale factor [%s] is not supported!' % (upscale_factor,))
        stride, padding, projection_filter = projection_params[upscale_factor]

        feature_extract_1 = B.ConvBlock(in_channels, 128, kernel_size=3, norm_type=norm_type, act_type=act_type)
        feature_extract_2 = B.ConvBlock(128, num_features, kernel_size=1, norm_type=norm_type, act_type=act_type)

        # Keyword arguments shared by every up/down projection block.
        proj_kwargs = dict(stride=stride, valid_padding=False, padding=padding,
                           norm_type=norm_type, act_type=act_type)

        # Alternating Up/Down projection pairs, built in the same order as before.
        bp_units = []
        for _ in range(bp_stages-1):
            bp_units.extend([B.UpprojBlock(num_features, num_features, projection_filter, **proj_kwargs),
                             B.DownprojBlock(num_features, num_features, projection_filter, **proj_kwargs)])

        last_bp_unit = B.UpprojBlock(num_features, num_features, projection_filter, **proj_kwargs)
        # No normalization or activation on the final reconstruction layer.
        conv_hr = B.ConvBlock(num_features, out_channels, kernel_size=1, norm_type=None, act_type=None)

        self.network = B.sequential(feature_extract_1, feature_extract_2, *bp_units, last_bp_unit, conv_hr)
Example #3
0
    def __init__(self, in_channels, out_channels, num_branch):
        """Create the three SRCNN layers: extraction, mapping, reconstruction.

        Every layer is a ``B.ConvBlock`` without normalization, built with
        ``valid_padding=False`` and ``padding=0``.

        Args:
            in_channels: Channels of the input image.
            out_channels: Channels of the reconstructed image.
            num_branch: Stored as ``self.num_branch``; not otherwise used in
                this constructor.
        """
        super(srcnn, self).__init__()
        self.num_branch = num_branch

        # Options common to all three layers.
        no_pad = dict(norm_type=None, valid_padding=False, padding=0)

        # 9x9 conv: in_channels -> 64, ReLU.
        self.patch_extraction = B.ConvBlock(in_channels, 64, kernel_size=9, act_type='relu', **no_pad)
        # 1x1 conv: 64 -> 32, ReLU.
        self.mapping = B.ConvBlock(64, 32, kernel_size=1, act_type='relu', **no_pad)
        # 5x5 conv: 32 -> out_channels, no activation.
        self.reconstruct = B.ConvBlock(32, out_channels, kernel_size=5, act_type=None, **no_pad)
Example #4
0
    def __init__(self, in_channels, out_channels, num_branch):
        """Create a VDSR-style stack: input conv, 18 hidden convs, output conv.

        All convs are 3x3 ``B.ConvBlock``s with no normalization, no bias and
        ``valid_padding=True``.

        Args:
            in_channels: Channels of the input image.
            out_channels: Channels of the output.
            num_branch: Accepted for interface compatibility; unused here.
        """
        super(vdsr, self).__init__()
        # Registered before conv_in so the submodule order stays conv, conv_in, conv_out.
        self.conv = nn.ModuleList()

        shared = dict(kernel_size=3, norm_type=None, valid_padding=True, bias=False)

        self.conv_in = B.ConvBlock(in_channels, 64, act_type='relu', **shared)

        # Hidden 64 -> 64 ReLU layers, created one by one after conv_in.
        self.conv.extend(B.ConvBlock(64, 64, act_type='relu', **shared) for _ in range(18))

        self.conv_out = B.ConvBlock(64, out_channels, act_type=None, **shared)
Example #5
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_features,
                 num_blocks,
                 upscale_factor=4,
                 norm_type='bn',
                 act_type='relu',
                 mode='NAC',
                 upsample_mode='upconv'):
        """Build SRResNet as one sequential network.

        Layout: 9x9 feature conv (PReLU) -> shortcut-wrapped trunk of
        ``num_blocks`` residual blocks plus a 3x3 conv -> upsampling block ->
        9x9 reconstruction conv.

        Args:
            in_channels: Channels of the input image.
            out_channels: Channels of the reconstructed image.
            num_features: Channel width of the trunk.
            num_blocks: Number of ``B.ResBlock``s in the trunk.
            upscale_factor: Forwarded to ``B.UpsampleConvBlock`` in 'upconv' mode.
            norm_type: Normalization used inside the trunk.
            act_type: Activation used inside the trunk.
            mode: Block ordering mode forwarded to the trunk blocks.
            upsample_mode: 'upconv' or 'pixelshuffle'.

        Raises:
            NotImplementedError: For any other ``upsample_mode``.
        """
        super(SRResNet, self).__init__()

        # Options shared by the residual blocks and the trailing trunk conv.
        trunk_opts = dict(norm_type=norm_type, act_type=act_type, mode=mode)

        head = B.ConvBlock(in_channels, num_features, kernel_size=9,
                           norm_type=None, act_type='prelu')

        residuals = [
            B.ResBlock(num_features, num_features, num_features,
                       kernel_size=3, **trunk_opts)
            for _ in range(num_blocks)
        ]
        trunk_tail = B.ConvBlock(num_features, num_features, kernel_size=3, **trunk_opts)

        if upsample_mode == 'pixelshuffle':
            upsampler = B.PixelShuffleBlock()
        elif upsample_mode == 'upconv':
            upsampler = B.UpsampleConvBlock(upscale_factor=upscale_factor,
                                            in_channels=num_features,
                                            out_channels=num_features,
                                            kernel_size=3,
                                            stride=1)
        else:
            raise NotImplementedError('Upsample mode [%s] is not supported!' %
                                      upsample_mode)

        tail = B.ConvBlock(num_features, out_channels, kernel_size=9,
                           norm_type=None, act_type=None)

        # TODO: dense connection.
        # NOTE: the residual blocks are unpacked with '*' because a sequential
        # container takes modules as positional arguments, not a list.
        trunk = B.ShortcutBlock(B.sequential(*residuals, trunk_tail))
        self.network = B.sequential(head, trunk, upsampler, tail)
Example #6
0
    def __init__(self, in_channels, out_channels, num_branch):
        """Create a VDSR variant with a shared trunk and ``num_branch`` heads.

        The shared trunk ``self.conv`` is an input conv followed by
        ``18 - split_layer`` hidden convs. Each entry of ``self.conv_branch``
        holds ``split_layer`` hidden convs followed by an output conv. All convs
        are 3x3, bias-free, unnormalized, with ``valid_padding=True``.

        Args:
            in_channels: Channels of the input image.
            out_channels: Channels produced by each branch.
            num_branch: Number of parallel output branches.
        """
        super(vdsr_k, self).__init__()
        split_layer = 4
        self.num_branch = num_branch

        opts = dict(kernel_size=3, norm_type=None, valid_padding=True, bias=False)

        def make_hidden():
            # One 64 -> 64 ReLU conv layer.
            return B.ConvBlock(64, 64, act_type='relu', **opts)

        trunk_layers = [B.ConvBlock(in_channels, 64, act_type='relu', **opts)]
        trunk_layers += [make_hidden() for _ in range(18 - split_layer)]
        self.conv = B.sequential(*trunk_layers)

        self.conv_branch = nn.ModuleList()
        for _ in range(self.num_branch):
            branch_layers = [make_hidden() for _ in range(split_layer)]
            branch_layers.append(B.ConvBlock(64, out_channels, act_type=None, **opts))
            self.conv_branch.append(B.sequential(*branch_layers))
Example #7
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_features,
                 num_recurs,
                 upscale_factor,
                 norm_type=None,
                 act_type='prelu'):
        """Create the DRUDN layers.

        This constructor only instantiates the layers. Their wiring is not
        visible here.

        Args:
            in_channels: Channels of the input image.
            out_channels: Channels of the reconstructed image.
            num_features: Channel width of all internal layers.
            num_recurs: Stored as ``self.num_recurs``; presumably the number of
                recurrent passes (TODO confirm against ``forward``).
            upscale_factor: 2, 3 or 4; selects the projection kernel size,
                stride and padding.
            norm_type: Normalization for the projection blocks.
            act_type: Activation for the input, feature and projection blocks.

        Raises:
            NotImplementedError: If ``upscale_factor`` is not 2, 3 or 4.
        """
        super(DRUDN, self).__init__()

        # (stride, padding, projection kernel size) for each supported scale.
        # The old chain started a new ``if`` for factor 3 instead of ``elif``.
        # It also had no fallback, so an unsupported factor only failed later
        # with UnboundLocalError on ``stride``.
        projection_params = {
            2: (2, 2, 6),
            3: (3, 2, 7),
            4: (4, 2, 8),
        }
        if upscale_factor not in projection_params:
            raise NotImplementedError('Upscale factor [%s] is not supported!' % (upscale_factor,))
        stride, padding, projection_filter = projection_params[upscale_factor]

        self.num_recurs = num_recurs

        # Per-channel RGB statistics for B.MeanShift. The mean values are
        # presumably the training-set mean (TODO confirm).
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)

        self.sub_mean = B.MeanShift(rgb_mean, rgb_std)
        self.conv_in = B.ConvBlock(in_channels,
                                   num_features,
                                   kernel_size=3,
                                   act_type=act_type,
                                   norm_type=None)

        # Keyword arguments shared by every up/down/deconv projection layer.
        proj = dict(stride=stride, padding=padding, norm_type=norm_type)

        # Layers are created in the original order so submodule registration is unchanged.
        self.up1_1 = B.DeconvBlock(num_features, num_features, projection_filter,
                                   act_type=act_type, **proj)
        self.up1_2 = B.DeconvBlock(num_features, num_features, projection_filter,
                                   act_type=act_type, **proj)
        self.up1_3 = B.DeconvBlock(num_features, num_features, projection_filter,
                                   act_type=act_type, **proj)
        self.down1_1 = B.ConvBlock(num_features, num_features, projection_filter,
                                   act_type=act_type, **proj)
        self.down1_2 = B.ConvBlock(num_features, num_features, projection_filter,
                                   act_type=act_type, **proj)
        # Unlike the other projections, the last down-projection has no activation.
        self.down1_3 = B.ConvBlock(num_features, num_features, projection_filter,
                                   act_type=None, **proj)
        self.deconv1 = B.DeconvBlock(num_features, num_features, projection_filter,
                                     act_type=act_type, **proj)
        self.conv_feat = B.ConvBlock(num_features,
                                     num_features,
                                     kernel_size=3,
                                     act_type=act_type,
                                     norm_type=None)
        self.conv_out = B.ConvBlock(num_features,
                                    out_channels,
                                    kernel_size=3,
                                    act_type=None,
                                    norm_type=None)
        # The third argument (1) presumably switches MeanShift to adding the
        # mean back (TODO confirm against B.MeanShift).
        self.add_mean = B.MeanShift(rgb_mean, rgb_std, 1)