示例#1
0
文件: sgnet.py 项目: laulampaul/sgnet
    def __init__(self, opt):
        """Build the green-channel guidance branch.

        Only ``opt.channels``, ``opt.act_type``, ``opt.bias`` and
        ``opt.norm_type`` are read.

        Submodules (the argument meanings of the ``common`` builders are
        assumed from their call sites, e.g. ConvBlock(in, out, kernel) --
        TODO confirm against ``common``):

        * ``head``   -- ConvBlock from 2 input channels to ``opt.channels``.
        * ``r1``, ``r2`` -- two RRDB blocks at ``opt.channels`` width.
        * ``final``  -- ConvBlock keeping ``opt.channels`` width.
        * ``up``     -- ``Upsampler(2, ...)``, a ConvBlock down to 1 channel,
          then ``LeakyReLU(0.2)``.
        """
        super(green_res, self).__init__()
        dm_n_feats = opt.channels
        act_type = opt.act_type
        bias = opt.bias
        norm_type = opt.norm_type

        self.head = common.ConvBlock(2,
                                     dm_n_feats,
                                     5,
                                     act_type=act_type,
                                     bias=True)
        # Two residual-in-residual dense blocks; attribute names r1/r2 are
        # part of the state_dict layout, so they are kept explicit.
        self.r1 = common.RRDB(dm_n_feats, dm_n_feats, 3, 1, bias, norm_type,
                              act_type, 0.2)
        self.r2 = common.RRDB(dm_n_feats, dm_n_feats, 3, 1, bias, norm_type,
                              act_type, 0.2)
        self.final = common.ConvBlock(dm_n_feats, dm_n_feats, 3, bias=bias)

        self.up = nn.Sequential(
            common.Upsampler(2, dm_n_feats, norm_type, act_type, bias=bias),
            common.ConvBlock(dm_n_feats, 1, 3, bias=True),
            nn.LeakyReLU(0.2, inplace=True))
示例#2
0
    def __init__(self, opt):
        """Build a single-branch network: stem, residual trunk, output conv.

        Reads ``opt.n_resblocks``, ``opt.channels``, ``opt.bias``,
        ``opt.norm_type``, ``opt.act_type`` and ``opt.block_type``
        (``'rrdb'`` or ``'res'``, case-insensitive).

        Layout of ``self.model``: a ConvBlock over 4 input channels, then a
        ``common.ShortcutBlock`` wrapping the trunk (``n_resblocks`` blocks
        plus one ConvBlock), then a ConvBlock to 3 channels.  Argument
        meanings of the ``common`` builders are assumed from their call
        sites -- TODO confirm.  Every ``nn.Conv2d`` is finally re-initialised
        with Xavier-normal weights and zero biases.

        Raises:
            RuntimeError: if ``opt.block_type`` is not a supported value.
        """
        super(NET, self).__init__()

        depth = opt.n_resblocks
        width = opt.channels
        use_bias = opt.bias
        norm = opt.norm_type
        act = opt.act_type
        block_type = opt.block_type

        stem = common.ConvBlock(4, width, 5, act_type=act, bias=True)

        kind = block_type.lower()
        if kind == 'rrdb':
            trunk = [
                common.RRDB(width, width, 3, 1, use_bias, norm, act, 0.2)
                for _ in range(depth)
            ]
        elif kind == 'res':
            trunk = [
                common.ResBlock(width, 3, norm, act, res_scale=1,
                                bias=use_bias)
                for _ in range(depth)
            ]
        else:
            raise RuntimeError('block_type is not supported')
        trunk.append(common.ConvBlock(width, width, 3, bias=True))

        projection = common.ConvBlock(width, 3, 3, bias=True)

        # Positional layout inside the Sequential determines state_dict keys.
        self.model = nn.Sequential(
            stem, common.ShortcutBlock(nn.Sequential(*trunk)), projection)

        convs = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in convs:
            nn.init.xavier_normal_(conv.weight)
            conv.weight.requires_grad_(True)
            if conv.bias is not None:
                conv.bias.data.zero_()
                conv.bias.requires_grad_(True)
示例#3
0
    def __init__(self, opt):
        """Build ``self.model_dm``: stem, residual trunk, upsampler, output.

        Reads ``opt.denoise``, ``opt.block_type`` (``'rrdb'`` or ``'res'``,
        case-insensitive), ``opt.channels``, ``opt.act_type``, ``opt.bias``,
        ``opt.norm_type`` and ``opt.n_resblocks``.

        The stem ConvBlock takes 5 input channels when ``opt.denoise`` is
        truthy and 4 otherwise.  The trunk (``n_resblocks`` blocks plus one
        ConvBlock) is wrapped in ``common.ShortcutBlock`` and followed by
        ``Upsampler(2, ...)`` and a ConvBlock to 3 channels.  Argument
        meanings of the ``common`` builders are assumed from their call
        sites -- TODO confirm.  Every ``nn.Conv2d`` is finally re-initialised
        with Xavier-normal weights and zero biases.

        Raises:
            RuntimeError: if ``opt.block_type`` is not a supported value.
        """
        super(NET, self).__init__()

        denoise = opt.denoise
        block_type = opt.block_type
        width = opt.channels
        act = opt.act_type
        use_bias = opt.bias
        norm = opt.norm_type
        depth = opt.n_resblocks

        # One extra input channel when denoising is enabled.
        in_channels = 5 if denoise else 4
        stem = common.ConvBlock(in_channels, width, 5, act_type=act, bias=True)

        kind = block_type.lower()
        if kind == 'rrdb':
            trunk = [
                common.RRDB(width, width, 3, 1, use_bias, norm, act, 0.2)
                for _ in range(depth)
            ]
        elif kind == 'res':
            trunk = [
                common.ResBlock(width, 3, norm, act, res_scale=1,
                                bias=use_bias)
                for _ in range(depth)
            ]
        else:
            raise RuntimeError('block_type is not supported')
        trunk.append(common.ConvBlock(width, width, 3, bias=True))

        upsample = common.Upsampler(2, width, norm, act, bias=use_bias)
        projection = common.ConvBlock(width, 3, 3, bias=True)

        # Positional layout inside the Sequential determines state_dict keys.
        self.model_dm = nn.Sequential(
            stem, common.ShortcutBlock(nn.Sequential(*trunk)), upsample,
            projection)

        convs = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in convs:
            nn.init.xavier_normal_(conv.weight)
            conv.weight.requires_grad_(True)
            if conv.bias is not None:
                conv.bias.data.zero_()
                conv.bias.requires_grad_(True)
示例#4
0
    def __init__(self, opt):
        """Build the two-stage network: ``model_sr`` then ``model_dm``.

        ``model_sr``: a stem ConvBlock over 4 input channels (5 when
        ``opt.denoise`` is truthy), a ``common.ShortcutBlock``-wrapped trunk,
        ``Upsampler(opt.scale, ...)`` and a ConvBlock to 4 channels.
        ``sr_output`` is ``nn.PixelShuffle(2)``.

        ``model_dm``: a stem ConvBlock over 4 channels, a
        ``common.ShortcutBlock``-wrapped trunk, ``Upsampler(2, ...)`` and a
        ConvBlock to 3 channels.

        Each trunk holds ``opt.sr_n_resblocks`` / ``opt.dm_n_resblocks``
        blocks of ``opt.block_type`` (``'rrdb'``, ``'dudb'`` or ``'res'``,
        case-insensitive) at width ``opt.channels`` plus a closing ConvBlock.
        Argument meanings of the ``common`` builders are assumed from their
        call sites -- TODO confirm.  Every ``nn.Conv2d`` is finally
        re-initialised with Xavier-normal weights and zero biases.

        Raises:
            RuntimeError: if ``opt.block_type`` is not a supported value.
        """
        super(NET, self).__init__()

        sr_depth = opt.sr_n_resblocks
        dm_depth = opt.dm_n_resblocks
        sr_width = opt.channels
        dm_width = opt.channels
        up_scale = opt.scale
        denoise = opt.denoise
        block_type = opt.block_type
        act = opt.act_type
        use_bias = opt.bias
        norm = opt.norm_type

        def build_trunk(width, depth):
            """Return ``depth`` body blocks of the configured type + a ConvBlock."""
            kind = block_type.lower()
            if kind == 'rrdb':
                blocks = [
                    common.RRDB(width, width, 3, 1, use_bias, norm, act, 0.2)
                    for _ in range(depth)
                ]
            elif kind == 'dudb':
                blocks = [
                    common.DUDB(width, 3, 1, use_bias, norm, act, 0.2)
                    for _ in range(depth)
                ]
            elif kind == 'res':
                blocks = [
                    common.ResBlock(width, 3, norm, act, res_scale=1,
                                    bias=use_bias)
                    for _ in range(depth)
                ]
            else:
                raise RuntimeError('block_type is not supported')
            blocks.append(common.ConvBlock(width, width, 3, bias=use_bias))
            return blocks

        # Super-resolution stage (modules are created in a fixed order so the
        # RNG stream consumed by their default init stays the same).
        sr_in = 5 if denoise else 4
        sr_stem = common.ConvBlock(sr_in, sr_width, 5, act_type=act, bias=True)
        sr_trunk = build_trunk(sr_width, sr_depth)
        sr_upsample = common.Upsampler(up_scale, sr_width, norm, act,
                                       bias=use_bias)
        sr_projection = common.ConvBlock(sr_width, 4, 3, bias=True)
        sr_shuffle = nn.PixelShuffle(2)

        # Demosaicking stage.
        dm_stem = common.ConvBlock(4, dm_width, 5, act_type=act, bias=True)
        dm_trunk = build_trunk(dm_width, dm_depth)
        dm_upsample = common.Upsampler(2, dm_width, norm, act, bias=use_bias)
        dm_projection = common.ConvBlock(dm_width, 3, 3, bias=True)

        # Positional layout inside each Sequential determines state_dict keys.
        self.model_sr = nn.Sequential(
            sr_stem, common.ShortcutBlock(nn.Sequential(*sr_trunk)),
            sr_upsample, sr_projection)
        self.sr_output = nn.Sequential(sr_shuffle)
        self.model_dm = nn.Sequential(
            dm_stem, common.ShortcutBlock(nn.Sequential(*dm_trunk)),
            dm_upsample, dm_projection)

        convs = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in convs:
            nn.init.xavier_normal_(conv.weight)
            conv.weight.requires_grad_(True)
            if conv.bias is not None:
                conv.bias.data.zero_()
                conv.bias.requires_grad_(True)
示例#5
0
    def __init__(self,
                 sr_n_resblocks=6,
                 dm_n_resblock=6,
                 sr_n_feats=64,
                 dm_n_feats=64,
                 scale=2,
                 denoise=True,
                 bias=True,
                 norm_type=False,
                 act_type='relu',
                 block_type='rrdb'):
        """Build the two-stage network from explicit keyword settings.

        Args:
            sr_n_resblocks: number of trunk blocks in ``model_sr``.
            dm_n_resblock: number of trunk blocks in ``model_dm``.
            sr_n_feats: feature width of ``model_sr``.
            dm_n_feats: feature width of ``model_dm``.
            scale: factor passed to the ``model_sr`` Upsampler.
            denoise: if truthy, the ``model_sr`` stem takes 5 input channels
                instead of 4.
            bias, norm_type, act_type: forwarded to the ``common`` builders.
            block_type: ``'rrdb'``, ``'dudb'`` or ``'res'`` (case-insensitive).

        ``model_sr`` ends with a ConvBlock to 4 channels and ``sr_output`` is
        ``nn.PixelShuffle(2)``; ``model_dm`` ends with ``Upsampler(2, ...)``
        and a ConvBlock to 3 channels.  Argument meanings of the ``common``
        builders are assumed from their call sites -- TODO confirm.  Every
        ``nn.Conv2d`` is finally re-initialised with Xavier-normal weights and
        zero biases.

        Raises:
            RuntimeError: if ``block_type`` is not a supported value.
        """
        super(NET, self).__init__()

        def build_trunk(width, depth):
            """Return ``depth`` body blocks of the configured type + a ConvBlock."""
            kind = block_type.lower()
            if kind == 'rrdb':
                blocks = [
                    common.RRDB(width, width, 3, 1, bias, norm_type, act_type,
                                0.2)
                    for _ in range(depth)
                ]
            elif kind == 'dudb':
                blocks = [
                    common.DUDB(width, 3, 1, bias, norm_type, act_type, 0.2)
                    for _ in range(depth)
                ]
            elif kind == 'res':
                blocks = [
                    common.ResBlock(width, 3, norm_type, act_type,
                                    res_scale=1, bias=bias)
                    for _ in range(depth)
                ]
            else:
                raise RuntimeError('block_type is not supported')
            blocks.append(common.ConvBlock(width, width, 3, bias=bias))
            return blocks

        # Super-resolution stage (modules are created in a fixed order so the
        # RNG stream consumed by their default init stays the same).
        sr_in = 5 if denoise else 4
        sr_stem = common.ConvBlock(sr_in, sr_n_feats, 5, act_type=act_type,
                                   bias=True)
        sr_trunk = build_trunk(sr_n_feats, sr_n_resblocks)
        sr_upsample = common.Upsampler(scale, sr_n_feats, norm_type, act_type,
                                       bias=bias)
        sr_projection = common.ConvBlock(sr_n_feats, 4, 3, bias=True)
        sr_shuffle = nn.PixelShuffle(2)

        # Demosaicking stage.
        dm_stem = common.ConvBlock(4, dm_n_feats, 5, act_type=act_type,
                                   bias=True)
        dm_trunk = build_trunk(dm_n_feats, dm_n_resblock)
        dm_upsample = common.Upsampler(2, dm_n_feats, norm_type, act_type,
                                       bias=bias)
        dm_projection = common.ConvBlock(dm_n_feats, 3, 3, bias=True)

        # Positional layout inside each Sequential determines state_dict keys.
        self.model_sr = nn.Sequential(
            sr_stem, common.ShortcutBlock(nn.Sequential(*sr_trunk)),
            sr_upsample, sr_projection)
        self.sr_output = nn.Sequential(sr_shuffle)
        self.model_dm = nn.Sequential(
            dm_stem, common.ShortcutBlock(nn.Sequential(*dm_trunk)),
            dm_upsample, dm_projection)

        convs = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in convs:
            nn.init.xavier_normal_(conv.weight)
            conv.weight.requires_grad_(True)
            if conv.bias is not None:
                conv.bias.data.zero_()
                conv.bias.requires_grad_(True)
示例#6
0
文件: sgnet.py 项目: laulampaul/sgnet
    def __init__(self, opt):
        """Build the SR branch, the green-guided demosaicking branch and the
        layers that fuse them.

        Reads ``opt.sr_n_resblocks``, ``opt.dm_n_resblocks``, ``opt.channels``,
        ``opt.scale``, ``opt.denoise``, ``opt.block_type`` (``'rrdb'``,
        ``'dudb'`` or ``'res'``, case-insensitive), ``opt.act_type``,
        ``opt.bias`` and ``opt.norm_type``.  ``opt`` is also passed whole to
        ``m_res`` (rrdb demosaicking trunk) and ``green_res``.

        Argument meanings of the ``common`` builders are assumed from their
        call sites, e.g. ConvBlock(in, out, kernel) -- TODO confirm.

        Raises:
            RuntimeError: if ``opt.block_type`` is not a supported value.
        """
        super(NET, self).__init__()

        sr_n_resblocks = opt.sr_n_resblocks
        dm_n_resblocks = opt.dm_n_resblocks
        sr_n_feats = opt.channels
        dm_n_feats = opt.channels
        scale = opt.scale

        denoise = opt.denoise
        block_type = opt.block_type
        act_type = opt.act_type
        bias = opt.bias
        norm_type = opt.norm_type

        # define sr module; the stem takes 6 input channels when denoise is
        # truthy, 4 otherwise
        if denoise:
            m_sr_head = [
                common.ConvBlock(6,
                                 sr_n_feats,
                                 5,
                                 act_type=act_type,
                                 bias=True)
            ]
        else:
            m_sr_head = [
                common.ConvBlock(4,
                                 sr_n_feats,
                                 5,
                                 act_type=act_type,
                                 bias=True)
            ]
        if block_type.lower() == 'rrdb':
            m_sr_resblock = [
                common.RRDB(sr_n_feats, sr_n_feats, 3, 1, bias, norm_type,
                            act_type, 0.2) for _ in range(sr_n_resblocks)
            ]
        elif block_type.lower() == 'dudb':
            m_sr_resblock = [
                common.DUDB(sr_n_feats, 3, 1, bias, norm_type, act_type, 0.2)
                for _ in range(sr_n_resblocks)
            ]
        elif block_type.lower() == 'res':
            m_sr_resblock = [
                common.ResBlock(sr_n_feats,
                                3,
                                norm_type,
                                act_type,
                                res_scale=1,
                                bias=bias) for _ in range(sr_n_resblocks)
            ]
        else:
            raise RuntimeError('block_type is not supported')

        m_sr_resblock += [
            common.ConvBlock(sr_n_feats, sr_n_feats, 3, bias=bias)
        ]
        m_sr_up = [
            common.Upsampler(scale, sr_n_feats, norm_type, act_type,
                             bias=bias),
            common.ConvBlock(sr_n_feats, 4, 3, bias=True)
        ]

        # branch for sr_raw output
        m_sr_tail = [nn.PixelShuffle(2)]

        # define demosaick module
        m_dm_head = [
            common.ConvBlock(4, dm_n_feats, 5, act_type=act_type, bias=True)
        ]

        if block_type.lower() == 'rrdb':
            # m_res builds the trunk from opt itself (presumably returning an
            # nn.Module -- confirm against its definition).
            m_dm_resblock = m_res(opt)
        elif block_type.lower() == 'dudb':
            m_dm_resblock = [
                common.DUDB(dm_n_feats, 3, 1, bias, norm_type, act_type, 0.2)
                for _ in range(dm_n_resblocks)
            ]
        elif block_type.lower() == 'res':
            m_dm_resblock = [
                common.ResBlock(dm_n_feats,
                                3,
                                norm_type,
                                act_type,
                                res_scale=1,
                                bias=bias) for _ in range(dm_n_resblocks)
            ]
        else:
            raise RuntimeError('block_type is not supported')

        # A plain Python list assigned to a Module attribute is NOT registered
        # as a submodule: its parameters would be missing from parameters() /
        # state_dict(), skipped by .to()/.cuda(), and never re-initialised by
        # the loop below.  nn.Sequential registers them and still supports
        # iteration, indexing and len().
        if isinstance(m_dm_resblock, (list, tuple)):
            m_dm_resblock = nn.Sequential(*m_dm_resblock)

        m_dm_up = [
            common.Upsampler(2, dm_n_feats, norm_type, act_type, bias=bias)
        ]

        self.model_sr = nn.Sequential(
            *m_sr_head, common.ShortcutBlock(nn.Sequential(*m_sr_resblock)),
            *m_sr_up)
        self.sr_output = nn.Sequential(*m_sr_tail)
        self.model_dm1 = nn.Sequential(*m_dm_head)
        self.model_dm2 = m_dm_resblock
        self.model_dm3 = nn.Sequential(*m_dm_up)

        greenresblock = green_res(opt)
        self.green = greenresblock

        # NOTE(review): this Xavier/zero re-initialisation runs before
        # combine/greenup/pac/final are created, so those layers keep their
        # default initialisation.  Position kept to preserve existing
        # behaviour -- confirm whether that is intended.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                m.weight.requires_grad = True
                if m.bias is not None:
                    m.bias.data.zero_()
                    m.bias.requires_grad = True

        # 1x1 fusion of the dm features with one extra (guidance) channel.
        self.combine = nn.Sequential(
            common.ConvBlock(dm_n_feats + 1, dm_n_feats, 1, bias=True),
            nn.LeakyReLU(0.2, inplace=True))
        self.greenup = nn.Sequential(common.ConvBlock(1, 4, 1, bias=True),
                                     nn.LeakyReLU(0.2, inplace=True),
                                     common.ConvBlock(4, 8, 1, bias=True),
                                     nn.LeakyReLU(0.2, inplace=True))

        # Width follows dm_n_feats rather than a literal 64 so a non-default
        # opt.channels does not cause a channel mismatch (presumably pac
        # operates on the dm_n_feats-wide features that self.final consumes
        # -- confirm in forward).  Identical when opt.channels == 64.
        self.pac = PacConvTranspose2d(dm_n_feats,
                                      dm_n_feats,
                                      kernel_size=5,
                                      stride=2,
                                      padding=2,
                                      output_padding=1)
        self.final = common.ConvBlock(dm_n_feats, 3, 3, bias=True)
        self.norm = nn.InstanceNorm2d(1)