示例#1
0
    def __init__(self, args, conv=basenet.default_conv):
        """Assemble EDSR as three sequential stages: head, residual body and tail.

        Args:
            args: options object; this constructor reads ``n_resblocks``,
                ``n_feats``, ``scale`` and ``n_colors`` from it.
            conv: convolution factory invoked as
                ``conv(in_channels, out_channels, kernel_size)``.
        """
        super(EDSR, self).__init__()

        width = args.n_feats
        ksize = 3
        # One ReLU instance shared by every residual block.
        relu = nn.ReLU(True)

        # Head: a single conv lifting n_colors input channels to `width` features.
        self.head = nn.Sequential(conv(args.n_colors, width, ksize))

        # Body: n_resblocks residual blocks followed by one closing conv.
        body_layers = []
        for _ in range(args.n_resblocks):
            body_layers.append(basenet.ResBlock(conv, width, ksize, act=relu))
        body_layers.append(conv(width, width, ksize))
        self.body = nn.Sequential(*body_layers)

        # Tail: upsample by `scale` (no activation), then project back to
        # n_colors channels.
        self.tail = nn.Sequential(
            basenet.Upsampler(conv, args.scale, width, act=False),
            conv(width, args.n_colors, ksize),
        )
示例#2
0
    def __init__(self, args, conv=basenet.default_conv):
        """Build a standalone upsampling stage: an Upsampler plus an output conv.

        Args:
            args: options object; reads ``n_feats``, ``scale`` and ``n_colors``.
            conv: convolution factory invoked as
                ``conv(in_channels, out_channels, kernel_size)``.
        """
        super(Upscalar, self).__init__()

        n_feats = args.n_feats
        kernel_size = 3
        scale = args.scale

        # The upsampler is built without an activation (act=False); the final
        # conv maps n_feats features to n_colors output channels.
        m_upscalar = [
            basenet.Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, args.n_colors, kernel_size),
        ]

        self.upscalar = nn.Sequential(*m_upscalar)
示例#3
0
    def __init__(self, args, conv=basenet.default_conv):
        """Assemble RCAN: a head conv, stacked residual groups and an upsampling tail.

        Args:
            args: options object; reads ``n_resgroups``, ``n_resblocks``,
                ``n_feats``, ``reduction``, ``scale`` and ``n_colors``.
            conv: convolution factory invoked as
                ``conv(in_channels, out_channels, kernel_size)``.
        """
        super(RCAN, self).__init__()

        width = args.n_feats
        ksize = 3
        # One ReLU instance shared by every residual group.
        relu = nn.ReLU(True)

        # NOTE: no RGB mean-shift layers (sub_mean / add_mean) are built here,
        # and ``args.scale`` is passed to the upsampler as-is (not indexed).

        # Head: one conv from n_colors channels to `width` features.
        self.head = nn.Sequential(conv(args.n_colors, width, ksize))

        # Body: n_resgroups residual groups (each configured with n_resblocks
        # blocks and res_scale=1), closed by a width-preserving conv.
        groups = []
        for _ in range(args.n_resgroups):
            groups.append(
                ResidualGroup(
                    conv, width, ksize, args.reduction,
                    act=relu, res_scale=1, n_resblocks=args.n_resblocks,
                )
            )
        groups.append(conv(width, width, ksize))
        self.body = nn.Sequential(*groups)

        # Tail: upsample (no activation), then project back to n_colors channels.
        self.tail = nn.Sequential(
            basenet.Upsampler(conv, args.scale, width, act=False),
            conv(width, args.n_colors, ksize),
        )
示例#4
0
    def __init__(self, args, conv=basenet.default_conv):
        """Build the multi-scale MDSR network.

        One ``pre_process`` branch and one ``upsample`` branch are created per
        scale; the head, residual body and tail are shared across scales.
        ``self.scale_idx`` starts at 0 and is presumably used by ``forward``
        (not shown here) to pick the active branch.

        Args:
            args: options object; reads ``n_resblocks``, ``n_feats``,
                ``rgb_range``, ``n_colors`` and ``scale``. ``scale`` may be a
                sequence of scales or a single scalar scale.
            conv: convolution factory invoked as
                ``conv(in_channels, out_channels, kernel_size)``.
        """
        super(MDSR, self).__init__()
        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        act = nn.ReLU(True)
        self.scale_idx = 0

        # Normalise the scale list once. A scalar ``args.scale`` (as the
        # single-scale models in this file use it) becomes a one-element list,
        # and a one-shot iterable is not exhausted by the first loop below
        # before the second one runs.
        try:
            scales = list(args.scale)
        except TypeError:
            scales = [args.scale]

        # Pretrained-weights URL for this (n_resblocks, n_feats) configuration.
        # Configurations without an entry get None instead of raising KeyError
        # at construction time.
        url_name = 'r{}f{}'.format(n_resblocks, n_feats)
        self.url = url[url_name] if url_name in url else None
        self.sub_mean = basenet.MeanShift(args.rgb_range)
        self.add_mean = basenet.MeanShift(args.rgb_range, sign=1)

        m_head = [conv(args.n_colors, n_feats, kernel_size)]

        # Per-scale pre-processing: two residual blocks with 5x5 kernels each.
        self.pre_process = nn.ModuleList([
            nn.Sequential(
                basenet.ResBlock(conv, n_feats, 5, act=act),
                basenet.ResBlock(conv, n_feats, 5, act=act)
            ) for _ in scales
        ])

        m_body = [
            basenet.ResBlock(
                conv, n_feats, kernel_size, act=act
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # Per-scale upsamplers, in the same order as ``pre_process``.
        self.upsample = nn.ModuleList([
            basenet.Upsampler(conv, s, n_feats, act=False) for s in scales
        ])

        m_tail = [conv(n_feats, args.n_colors, kernel_size)]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)
示例#5
0
    def __init__(self, args):
        """Build the DRLN network.

        Creates a head conv, twenty ``Block`` modules (``b1`` .. ``b20``),
        twenty ``basenet.BasicBlock`` fusion layers (``c1`` .. ``c20``), an
        upsampler and a tail conv. The numbered attribute names are preserved
        because they become the ``state_dict`` keys and are presumably
        referenced by ``forward`` (not shown here).

        Args:
            args: options object; reads ``scale`` and ``n_feats``, and
                ``n_colors`` when present (otherwise 3 image channels are used).
        """
        super(DRLN, self).__init__()

        self.scale = args.scale
        chs = args.n_feats
        # Image channel count: follow ``args.n_colors`` like the sibling models
        # in this file, falling back to the previously hard-coded value of 3.
        n_colors = getattr(args, 'n_colors', 3)

        # NOTE: no RGB mean-shift layers are built for this model.

        self.head = nn.Conv2d(n_colors, chs, 3, 1, 1)

        # Twenty identical blocks, registered in order as b1 .. b20.
        for i in range(1, 21):
            setattr(self, 'b{}'.format(i), Block(chs, chs))

        # Fusion layers c1 .. c20: each maps (k * chs) input channels back to
        # chs, where k presumably counts the feature maps concatenated in
        # forward. The pattern is four groups of (2, 3, 4) followed by two
        # groups of (2, 3, 4, 5).
        c_multipliers = (2, 3, 4) * 4 + (2, 3, 4, 5) * 2
        for i, k in enumerate(c_multipliers, start=1):
            setattr(self, 'c{}'.format(i), basenet.BasicBlock(chs * k, chs, 3, 1, 1))

        self.upsample = basenet.Upsampler(basenet.default_conv, self.scale, chs, act=False)
        self.tail = nn.Conv2d(chs, n_colors, 3, 1, 1)