예제 #1
0
파일: tssn_ir.py 프로젝트: LaineyHu/TSSN
    def __init__(self, args):
        """Build the TSSN image-restoration network (no up-sampling stage).

        Args:
            args: option namespace. Attributes read here: ``G0`` (channels of
                the shallow and fused features), ``TSSNkSize`` (conv kernel
                size), ``rgb_range`` and ``n_colors`` (image channels in/out).
        """
        super(TSSN, self).__init__()
        G0 = args.G0
        kSize = args.TSSNkSize
        # Padding that keeps the spatial size unchanged for odd kernel sizes.
        pad = (kSize - 1) // 2

        # NOTE(review): presumably the DIV2K RGB mean used by EDSR-style
        # models — confirm against the training data.
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        # C: conv layers per SRB; G: SRB growth rate (out channels per conv).
        C = 16
        G = 64

        # Shallow feature extraction net
        self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=pad, stride=1)
        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=pad, stride=1)

        # Multi-branch body: a 4-SRB branch and a 7-SRB branch. A new SRB is
        # constructed on every iteration, so no weights are shared.
        self.branch1 = nn.Sequential(
            *[SRB(growRate0=G0, growRate=G, nConvLayers=C) for _ in range(4)])
        self.branch2 = nn.Sequential(
            *[SRB(growRate0=G0, growRate=G, nConvLayers=C) for _ in range(7)])

        # Global Feature Fusion: consumes 2*G0 channels (presumably the two
        # G0-channel branch outputs concatenated in forward — confirm) and
        # emits G0 channels through its final conv.
        self.GFF = nn.Sequential(
            common.GELayer(),
            common.SELayer(2 * G0),
            nn.Conv2d(2 * G0, G0, kSize, padding=pad, stride=1),
        )

        # Reconstruction conv. The fused features are G0 channels wide (see
        # GFF's last conv), so the input width must be G0. It used to be the
        # hard-coded G = 64, which only matched when G0 == 64.
        self.restore = nn.Conv2d(G0, args.n_colors, kSize, padding=pad, stride=1)

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
예제 #2
0
    def __init__(self, args):
        """Build the TSSN super-resolution network.

        Args:
            args: option namespace. Attributes read here: ``scale`` (only
                ``scale[0]`` is used), ``G0`` (channels of the shallow and
                fused features), ``TSSNkSize`` (conv kernel size),
                ``rgb_range`` and ``n_colors`` (image channels in/out).

        Raises:
            ValueError: if ``args.scale[0]`` is not 2, 3, 4 or 8.
        """
        super(TSSN, self).__init__()
        r = args.scale[0]
        G0 = args.G0
        kSize = args.TSSNkSize
        # Padding that keeps the spatial size unchanged for odd kernel sizes.
        pad = (kSize - 1) // 2

        # Number of (conv -> x2 PixelShuffle) stages for each power-of-two
        # scale; scale 3 is handled separately with one x3 shuffle.
        x2_stages = {2: 1, 4: 2, 8: 3}
        # Validate before allocating the (large) SRB stacks.
        if r != 3 and r not in x2_stages:
            raise ValueError("scale must be 2 or 3 or 4 or 8.")

        # NOTE(review): presumably the DIV2K RGB mean used by EDSR-style
        # models — confirm against the training data.
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        # C: conv layers per SRB; G: SRB growth rate (out channels per conv),
        # also the feature width inside the up-sampler.
        C = 16
        G = 64

        # Shallow feature extraction net
        self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=pad, stride=1)
        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=pad, stride=1)

        # Multi-branch body: a 4-SRB branch and a 7-SRB branch. A new SRB is
        # constructed on every iteration, so no weights are shared.
        self.branch1 = nn.Sequential(
            *[SRB(growRate0=G0, growRate=G, nConvLayers=C) for _ in range(4)])
        self.branch2 = nn.Sequential(
            *[SRB(growRate0=G0, growRate=G, nConvLayers=C) for _ in range(7)])

        # Global Feature Fusion: consumes 2*G0 channels (presumably the two
        # G0-channel branch outputs concatenated in forward — confirm) and
        # emits G0 channels through its final conv.
        self.GFF = nn.Sequential(
            common.SELayer(2 * G0),
            nn.Conv2d(2 * G0, G0, kSize, padding=pad, stride=1),
        )

        # Up-sampling net
        if r == 3:
            # One conv to G*r*r channels, then a single x3 pixel shuffle
            # back down to G channels.
            up_layers = [
                nn.Conv2d(G0, G * r * r, kSize, padding=pad, stride=1),
                nn.PixelShuffle(r),
            ]
        else:
            # Cascade of conv(-> G*4) + PixelShuffle(2) stages. Each shuffle
            # turns G*4 channels into G channels at twice the resolution, so
            # only the first conv reads the G0-channel fused features.
            up_layers = []
            in_ch = G0
            for _ in range(x2_stages[r]):
                up_layers.append(
                    nn.Conv2d(in_ch, G * 4, kSize, padding=pad, stride=1))
                up_layers.append(nn.PixelShuffle(2))
                in_ch = G
        # Final conv maps the G up-sampled feature channels to image channels.
        up_layers.append(
            nn.Conv2d(G, args.n_colors, kSize, padding=pad, stride=1))
        self.UPNet = nn.Sequential(*up_layers)

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)