Example 1
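This snippet builds a super-resolution network: a strided 3×3 convolution on a single-channel input, a PixelShuffle/ConvBlock stage, a stack of `n_resblock` residual blocks, log2(scale) PixelShuffle upsampling stages, and a final 3×3 convolution to three output channels. All `nn.Conv2d` weights are Xavier-initialized.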
    def __init__(self, n_resblock=24, n_feats=256, scale=2, bias=True, norm_type=False,
                 act_type='prelu'):
        super(NET, self).__init__()

        self.scale = scale
        m = [common.default_conv(1, n_feats, 3, stride=2)]
        m += [nn.PixelShuffle(2),
              common.ConvBlock(n_feats//4, n_feats, bias=True, act_type=act_type)
              ]

        m += [common.ResBlock(n_feats, 3, norm_type, act_type, res_scale=1, bias=bias)
                             for _ in range(n_resblock)]

        for _ in range(int(math.log(scale, 2))):
            m += [nn.PixelShuffle(2),
                  common.ConvBlock(n_feats//4, n_feats, bias=True, act_type=act_type)
                  ]

        m += [common.default_conv(n_feats, 3, 3)]

        self.model = nn.Sequential(*m)
        # Initialize all convolution layers (loop variable renamed so it does not
        # shadow the layer list `m` built above).
        for mod in self.modules():
            if isinstance(mod, nn.Conv2d):
                # Xavier initialization (Kaiming variant kept for reference)
                # nn.init.kaiming_normal_(mod.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_normal_(mod.weight)
                mod.weight.requires_grad = True
                if mod.bias is not None:
                    mod.bias.data.zero_()
                    mod.bias.requires_grad = True
Example 2
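An SRResNet-style constructor: a 9×9 convolution plus PReLU head, `n_resblocks` batch-normalized residual blocks followed by a conv/BatchNorm pair, and an upsampling tail. When the default convolution is used, weights are loaded from the checkpoint given in `args.pretrain_cluster`.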
    def __init__(self, args, conv3x3=common.default_conv, conv1x1=common.default_conv):
        super(SRRESNET_CLUSTER, self).__init__()

        n_resblock = args.n_resblocks
        n_feats = args.n_feats
        scale = args.scale[0]
        act_res_flag = args.act_res

        kernel_size = 3
        if act_res_flag == 'No':
            act_res = None
        else:
            act_res = 'prelu'

        head = [common.default_conv(args.n_colors, n_feats, kernel_size=9), nn.PReLU()]
        body = [common.ResBlock(conv3x3, n_feats, kernel_size, bn=True, act=act_res) for _ in range(n_resblock)]
        body.extend([conv3x3(n_feats, n_feats, kernel_size), nn.BatchNorm2d(n_feats)])

        tail = [
            common.Upsampler(conv3x3, scale, n_feats, act=act_res),
            conv3x3(n_feats, args.n_colors, kernel_size)
        ]

        self.head = nn.Sequential(*head)
        self.body = nn.Sequential(*body)
        self.tail = nn.Sequential(*tail)

        if conv3x3 is common.default_conv:
            # With the default convolution, initialize the whole model from the
            # pretrained clustering checkpoint in args.pretrain_cluster.
            self.load_state_dict(torch.load(args.pretrain_cluster))
Example 3
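An ODE-function module built from 20 residual channel-attention blocks (RCAB) and a final 3×3 convolution; `self.nfe` presumably counts function evaluations during ODE solving.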
    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        # 20 residual channel-attention blocks (RCAB) followed by a final 3x3 conv.
        # The width is hard-coded to 64 channels; `dim` is not used in this constructor.
        modules_body = [
            RCAB(common.default_conv, 64, 3, 16, bias=True, bn=False,
                 act=nn.ReLU(True), res_scale=1)
            for _ in range(20)]
        modules_body.append(common.default_conv(64, 64, 3))
        self.body = nn.Sequential(*modules_body)

        self.nfe = 0
Example 4
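A grouped-convolution EDSR variant (EMSR): mean-shift normalization, a convolutional head, a body of `n_resblocks` grouped residual blocks, and a tail made of a group upsampler plus an output convolution.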
    def __init__(self, args, conv=common.groups_conv):
        super(EMSR, self).__init__()

        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        scale = args.scale[0]
        act = nn.ReLU(True)
        groups = args.n_groups
        self.sub_mean = common.MeanShift(args.rgb_range)
        self.add_mean = common.MeanShift(args.rgb_range, sign=1)

        # define head module
        m_head = [common.default_conv(args.n_colors, n_feats, kernel_size)]

        # define body module
        m_body = [
            common.group_ResBlock(
                conv,
                n_feats,
                kernel_size,
                act=act,
                res_scale=args.res_scale,
                groups=groups,
            ) for _ in range(n_resblocks)
        ]
        m_body.append(common.default_conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [
            common.group_Upsampler(common.default_conv,
                                   scale,
                                   n_feats,
                                   act=False),
            common.default_conv(n_feats, args.n_colors, kernel_size)
        ]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)
Example 5
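A degradation-aware convolution block: a small MLP maps a 64-dimensional embedding (presumably a learned degradation representation) to per-sample kernel weights of size `kernel_size`×`kernel_size`, combined with a 1×1 convolution and a channel-attention branch.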
    def __init__(self, channels_in, channels_out, kernel_size, reduction):
        super(DA_conv, self).__init__()
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.kernel_size = kernel_size

        self.kernel = nn.Sequential(
            nn.Linear(64, 64, bias=False), nn.LeakyReLU(0.1, True),
            nn.Linear(64, 64 * self.kernel_size * self.kernel_size,
                      bias=False))
        self.conv = common.default_conv(channels_in, channels_out, 1)
        self.ca = CA_layer(channels_in, channels_out, reduction)

        self.relu = nn.LeakyReLU(0.1, True)
Example 6
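An EDSR constructor that swaps the standard convolutions for separable convolutions in the body and upsampler, while keeping the usual RGB mean-shift, convolutional head, residual body, and upsampling tail.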
    def __init__(self, args, conv=common.SeparableConv):
        super(EDSR, self).__init__()

        n_resblock = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        scale = args.scale[0]
        act = nn.ReLU(True)

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        # define head module
        m_head = [common.default_conv(args.n_colors, n_feats, kernel_size)]

        # define body module
        m_body = [
            common.ResBlock(conv,
                            n_feats,
                            kernel_size,
                            act=act,
                            res_scale=args.res_scale)
            for _ in range(n_resblock)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [
            common.Upsampler(conv, scale, n_feats, act=False),
            nn.Conv2d(n_feats,
                      args.n_colors,
                      kernel_size,
                      padding=(kernel_size // 2))
        ]

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)
Example 7
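A multi-level wavelet network (MWPDO): DWT/IWT transforms (presumably applied in forward()) with encoder stages d_l0–d_l2, a bottleneck pro_l3, decoder stages i_l2–i_l0, and a convolutional tail; channel counts follow the ×4 expansion of each wavelet decomposition.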
    def __init__(self, args):
        super(MWPDO, self).__init__()
        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        self.scale_idx = 0
        nColor = args.n_colors

        act = nn.ReLU(True)

        self.DWT = common.DWT()
        self.IWT = common.IWT()

        m_head = [mwpdo_layers.BBlock(nColor, n_feats, p=8, act=act, bn=False)]
        d_l0 = []
        d_l0.append(
            mwpdo_layers.DBlock_com1(n_feats, n_feats, p=8, act=act, bn=False))

        d_l1 = [
            mwpdo_layers.BBlock1(n_feats * 4,
                                 n_feats * 2,
                                 p=8,
                                 act=act,
                                 bn=False)
        ]
        d_l1.append(
            mwpdo_layers.DBlock_com1(n_feats * 2,
                                     n_feats * 2,
                                     p=8,
                                     act=act,
                                     bn=False))

        d_l2 = []
        d_l2.append(
            mwpdo_layers.BBlock1(n_feats * 8,
                                 n_feats * 4,
                                 p=8,
                                 act=act,
                                 bn=False))
        d_l2.append(
            mwpdo_layers.DBlock_com1(n_feats * 4,
                                     n_feats * 4,
                                     p=8,
                                     act=act,
                                     bn=False))
        pro_l3 = []
        pro_l3.append(
            mwpdo_layers.BBlock1(n_feats * 16,
                                 n_feats * 8,
                                 p=8,
                                 act=act,
                                 bn=False))
        pro_l3.append(
            mwpdo_layers.DBlock_com(n_feats * 8,
                                    n_feats * 8,
                                    p=8,
                                    act=act,
                                    bn=False))
        pro_l3.append(
            mwpdo_layers.DBlock_inv(n_feats * 8,
                                    n_feats * 8,
                                    p=8,
                                    act=act,
                                    bn=False))
        pro_l3.append(
            mwpdo_layers.BBlock1(n_feats * 8,
                                 n_feats * 16,
                                 p=8,
                                 act=act,
                                 bn=False))

        i_l2 = [
            mwpdo_layers.DBlock_inv1(n_feats * 4,
                                     n_feats * 4,
                                     p=8,
                                     act=act,
                                     bn=False)
        ]
        i_l2.append(
            mwpdo_layers.BBlock1(n_feats * 4,
                                 n_feats * 8,
                                 p=8,
                                 act=act,
                                 bn=False))

        i_l1 = [
            mwpdo_layers.DBlock_inv1(n_feats * 2,
                                     n_feats * 2,
                                     p=8,
                                     act=act,
                                     bn=False)
        ]
        i_l1.append(
            mwpdo_layers.BBlock1(n_feats * 2,
                                 n_feats * 4,
                                 p=8,
                                 act=act,
                                 bn=False))

        i_l0 = [
            mwpdo_layers.DBlock_inv1(n_feats, n_feats, p=8, act=act, bn=False)
        ]

        #m_tail = [mwpdo_layers.g_conv2d(n_feats, nColor, p=8,partial=mwpdo_layers.partial_dict_0,tran=mwpdo_layers.tran_to_partial_coef_0)]
        m_tail = [common.default_conv(n_feats * 8, nColor, kernel_size)]

        self.head = nn.Sequential(*m_head)
        self.d_l2 = nn.Sequential(*d_l2)
        self.d_l1 = nn.Sequential(*d_l1)
        self.d_l0 = nn.Sequential(*d_l0)
        self.pro_l3 = nn.Sequential(*pro_l3)
        self.i_l2 = nn.Sequential(*i_l2)
        self.i_l1 = nn.Sequential(*i_l1)
        self.i_l0 = nn.Sequential(*i_l0)
        self.tail = nn.Sequential(*m_tail)
Example 8
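A transfer/fine-tuning wrapper: it reuses the body of a pretrained network, builds a new head and upsampling tail for the requested channel count and scale, and defines MeanShift normalizers for what appear to be several climate variables (pr, dem, psl, zg, tasmax, tasmin).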
    def __init__(self, net, args, checkpoint, all_params=False):
        super().__init__()
        if not args.cpu and args.n_GPUs > 1:
            # The original wrapped self.model before any self.model attribute existed;
            # wrapping the pretrained net's model instead (assumed intent).
            self.model = nn.DataParallel(net.model, range(args.n_GPUs))
        kernel_size = 3
        self.args = args

        self.head = net.model.head  # NOTE: overridden below by the newly built modules_head

        #         self.tail=net.model.tail
        modules_head = [
            common.default_conv(args.channels, args.n_feats, kernel_size)
        ]

        rgb_mean_pr = [0.00216697]
        rgb_std_pr = [1.0]
        self.sub_mean_pr = common.MeanShift(993.9646, rgb_mean_pr, rgb_std_pr,
                                            1)

        if all_params:
            modules_tail = [
                common.Upsampler(common.default_conv,
                                 args.scale[0],
                                 args.n_feats,
                                 act=False),
                #             common.default_conv(args.n_feats, 1, kernel_size)
                common.default_conv(args.n_feats, args.channels, kernel_size)
            ]
            self.add_mean = common.MeanShift(993.9646, rgb_mean_pr, rgb_std_pr,
                                             args.channels, 1)
        else:
            modules_tail = [
                common.Upsampler(common.default_conv,
                                 args.scale[0],
                                 args.n_feats,
                                 act=False),
                common.default_conv(args.n_feats, 1, kernel_size)
                #                 common.default_conv(args.n_feats, args.channels, kernel_size)
            ]
            self.add_mean = common.MeanShift(600, rgb_mean_pr, rgb_std_pr, 1,
                                             1)

#         rgb_mean = (0.0020388064770,0.0020388064770,0.0020388064770)
#         if args.channels==1:
#             rgb_mean = [0.0020388064770]
#             rgb_std = [1.0]
#         if args.channels==2:
#             rgb_mean = [0.0020388064770,0.0020388064770]
#             rgb_std = [1.0,1.0]
#         if args.channels==3:
#             rgb_mean = [0.0020388064770,0.0020388064770,0.0020388064770]
#             rgb_std = [1.0,1.0,1.0]

        rgb_mean_dem = [0.05986051]
        rgb_std_dem = [1.0]
        self.sub_mean_dem = common.MeanShift(2228.3303, rgb_mean_dem,
                                             rgb_std_dem, 1)

        rgb_mean_psl = [0.980945]
        rgb_std_psl = [1.0]
        self.sub_mean_psl = common.MeanShift(103005.8125, rgb_mean_psl,
                                             rgb_std_psl, 1)
        ########################################################################################

        rgb_mean_zg = [0.88989586]
        rgb_std_zg = [1.]
        self.sub_mean_zg = common.MeanShift(1693.5935, rgb_mean_zg, rgb_std_zg,
                                            1)

        rgb_mean_tasmax = [0.8674892]
        rgb_std_tasmax = [1.0]
        self.sub_mean_tasmax = common.MeanShift(41.89, rgb_mean_tasmax,
                                                rgb_std_tasmax, 1)

        rgb_mean_tasmin = [0.964896]
        rgb_std_tasmin = [1.0]
        self.sub_mean_tasmin = common.MeanShift(308.69238, rgb_mean_tasmin,
                                                rgb_std_tasmin, 1)

        self.body = net.model.body
        self.head = nn.Sequential(*modules_head)
        self.tail = nn.Sequential(*modules_tail)