Esempio n. 1
0
def define_G(opt):
    """Create the generator network requested in ``opt['network_G']``.

    The architecture is chosen by ``opt['network_G']['which_model_G']``:

    * ``'MSRResNet'`` / ``'RRDBNet'`` -- image restoration; built from
      ``in_nc``, ``out_nc``, ``nf``, ``nb`` and ``scale`` (passed as
      ``upscale``).
    * ``'EDVR'`` -- video restoration; built from ``nf``, ``nframes``,
      ``groups``, ``front_RBs``, ``back_RBs``, ``center``, ``predeblur``,
      ``HR_in`` and ``w_TSA``.

    Args:
        opt: option dict holding a ``'network_G'`` sub-dict.

    Returns:
        The constructed network instance.

    Raises:
        NotImplementedError: if ``which_model_G`` names an unknown architecture.
    """
    net_opt = opt['network_G']
    arch = net_opt['which_model_G']

    if arch == 'MSRResNet':
        image_net = SRResNet_arch.MSRResNet
    elif arch == 'RRDBNet':
        image_net = RRDBNet_arch.RRDBNet
    elif arch == 'EDVR':
        edvr_keys = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs',
                     'center', 'predeblur', 'HR_in', 'w_TSA')
        return EDVR_arch.EDVR(**{key: net_opt[key] for key in edvr_keys})
    else:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(arch))

    # Both image-restoration networks share one constructor signature.
    return image_net(in_nc=net_opt['in_nc'],
                     out_nc=net_opt['out_nc'],
                     nf=net_opt['nf'],
                     nb=net_opt['nb'],
                     upscale=net_opt['scale'])
Esempio n. 2
0
def define_G(opt):
    """Build the generator selected by ``opt['network_G']['which_model_G']``.

    Backbones and the option keys they read from ``opt['network_G']``:

    * ``'MSRResNet'``: in_nc, out_nc, nf, nb, scale (as ``upscale``)
    * ``'RCAN'``: n_resblocks, n_feats, res_scale, n_colors, rgb_range,
      scale, reduction, n_resgroups
    * ``'CARN_M'``: in_nc, out_nc, nf, scale, group
    * ``'fsrcnn'``: in_nc (as ``input_channels``), scale (as ``upscale``),
      d, s, m
    * ``'classSR_3class_fsrcnn_net'``, ``'classSR_3class_rcan'``,
      ``'classSR_3class_srresnet'``, ``'classSR_3class_carn'``: in_nc, out_nc

    Returns:
        The constructed network instance.

    Raises:
        NotImplementedError: if the model name is not listed above.
    """
    net_opt = opt['network_G']
    arch = net_opt['which_model_G']

    if arch == 'MSRResNet':
        return SRResNet_arch.MSRResNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'],
            nf=net_opt['nf'], nb=net_opt['nb'], upscale=net_opt['scale'])
    if arch == 'RCAN':
        return RCAN_arch.RCAN(
            n_resblocks=net_opt['n_resblocks'], n_feats=net_opt['n_feats'],
            res_scale=net_opt['res_scale'], n_colors=net_opt['n_colors'],
            rgb_range=net_opt['rgb_range'], scale=net_opt['scale'],
            reduction=net_opt['reduction'], n_resgroups=net_opt['n_resgroups'])
    if arch == 'CARN_M':
        return CARN_arch.CARN_M(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'],
            nf=net_opt['nf'], scale=net_opt['scale'], group=net_opt['group'])
    if arch == 'fsrcnn':
        return FSRCNN_arch.FSRCNN_net(
            input_channels=net_opt['in_nc'], upscale=net_opt['scale'],
            d=net_opt['d'], s=net_opt['s'], m=net_opt['m'])

    # The ClassSR wrappers share one (in_nc, out_nc) signature. Constructors
    # are resolved lazily so only the selected architecture module is touched.
    classsr_variants = (
        ('classSR_3class_fsrcnn_net',
         lambda: classSR_3class_arch.classSR_3class_fsrcnn_net),
        ('classSR_3class_rcan', lambda: classSR_rcan_arch.classSR_3class_rcan),
        ('classSR_3class_srresnet', lambda: classSR_srresnet_arch.ClassSR),
        ('classSR_3class_carn', lambda: classSR_carn_arch.ClassSR),
    )
    for variant, resolve in classsr_variants:
        if arch == variant:
            return resolve()(in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'])

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(arch))
Esempio n. 3
0
def define_G(opt):
    """Return the generator selected by ``opt['network_G']['which_model_G']``.

    Supported architectures and the option keys they consume, written as
    ``constructor argument <- option key`` where the names differ:

    * ``'MSRResNet'``: in_nc, out_nc, nf, nb, upscale <- scale
    * ``'RRDBNet'``:   in_nc, out_nc, nf, nb
    * ``'ORDSRNet'``:  in_nc, out_nc, N <- base_size, S <- stride,
      upscale <- scale

    Raises:
        NotImplementedError: for any other ``which_model_G`` value.
    """
    net_opt = opt['network_G']
    model_name = net_opt['which_model_G']

    def take(**arg_to_key):
        # Build constructor kwargs by looking each option key up in
        # ``net_opt``; keys are read in the order they are listed.
        return {arg: net_opt[key] for arg, key in arg_to_key.items()}

    if model_name == 'MSRResNet':
        return SRResNet_arch.MSRResNet(**take(
            in_nc='in_nc', out_nc='out_nc', nf='nf', nb='nb', upscale='scale'))
    if model_name == 'RRDBNet':
        return RRDBNet_arch.RRDBNet(**take(
            in_nc='in_nc', out_nc='out_nc', nf='nf', nb='nb'))
    if model_name == 'ORDSRNet':
        return ORDSRModel(**take(
            in_nc='in_nc', out_nc='out_nc', N='base_size', S='stride',
            upscale='scale'))

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(model_name))
Esempio n. 4
0
def define_G(opt):
    """Instantiate the generator named by ``opt['network_G']['which_model_G']``.

    Recognised names:

    * ``'MSRResNet'``, ``'RRDBNet'`` -- image restoration.
    * ``'EDVR'`` -- video restoration (also reads ``w_GCB``); its module is
      imported only when selected.
    * ``'MGANet'`` -- ``Gen_Guided_UNet`` built from ``input_size``.
    * ``'Unet'`` -- delegates to ``repo.CycleGAN.networks.define_G``, imported
      only when selected.

    Returns:
        The constructed network instance.

    Raises:
        NotImplementedError: if the name is not recognised.
    """
    net_opt = opt['network_G']
    name = net_opt['which_model_G']

    if name == 'MSRResNet':
        return SRResNet_arch.MSRResNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'],
            nf=net_opt['nf'], nb=net_opt['nb'], upscale=net_opt['scale'])

    if name == 'RRDBNet':
        return RRDBNet_arch.RRDBNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'],
            nf=net_opt['nf'], nb=net_opt['nb'])

    if name == 'EDVR':
        import models.archs.EDVR_arch as EDVR_arch
        return EDVR_arch.EDVR(**{
            key: net_opt[key]
            for key in ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs',
                        'center', 'predeblur', 'HR_in', 'w_TSA', 'w_GCB')})

    # NOTE: an 'EDVR_woDCN' variant (models.archs.EDVR_woDCN_arch, same
    # arguments as 'EDVR') used to be selectable here; it is disabled.

    if name == 'MGANet':
        return Gen_Guided_UNet(input_size=net_opt['input_size'])

    if name == 'Unet':
        import repo.CycleGAN.networks as unet_networks
        # Positional arguments: 2 * 3 and 1 (presumably input / output channel
        # counts -- confirm against repo.CycleGAN.networks.define_G), then the
        # options below in the order that function expects them.
        unet_keys = ('nf', 'G_type', 'norm', 'dropout', 'init_type',
                     'init_gain')
        return unet_networks.define_G(2 * 3, 1,
                                      *[net_opt[key] for key in unet_keys])

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(name))
Esempio n. 5
0
def define_G(opt):
    """Build the generator network described by ``opt['network_G']``.

    ``opt['network_G']['which_model_G']`` picks the architecture. Each
    architecture module is imported lazily, only when it is selected.

    * ``'MSRResNet'``, ``'RRDBNet'`` -- image restoration.
    * ``'EDVR'`` -- video restoration; the upscale factor is taken from the
      top-level ``opt['scale']``.
    * ``'DUF'`` -- ``opt['network_G']['layers']`` must be 16, 28 or 52; the
      upscale factor is taken from ``opt['scale']``.
    * ``'TOF'`` -- TOFlow; reads no options.

    Args:
        opt: option dict holding a ``'network_G'`` sub-dict (and ``'scale'``
            for 'EDVR' / 'DUF').

    Returns:
        The constructed network instance.

    Raises:
        NotImplementedError: if the architecture is unknown, or if 'DUF' is
            requested with an unsupported layer count.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    # image restoration
    if which_model == 'MSRResNet':
        import models.archs.SRResNet_arch as SRResNet_arch
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'],
                                       out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'],
                                       nb=opt_net['nb'],
                                       upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        import models.archs.RRDBNet_arch as RRDBNet_arch
        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'],
                                    out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'],
                                    nb=opt_net['nb'])
    # video restoration
    elif which_model == 'EDVR':
        import models.archs.EDVR_arch as EDVR_arch
        netG = EDVR_arch.EDVR(nf=opt_net['nf'],
                              nframes=opt_net['nframes'],
                              groups=opt_net['groups'],
                              front_RBs=opt_net['front_RBs'],
                              back_RBs=opt_net['back_RBs'],
                              center=opt_net['center'],
                              predeblur=opt_net['predeblur'],
                              HR_in=opt_net['HR_in'],
                              w_TSA=opt_net['w_TSA'],
                              scale=opt['scale'])
    elif which_model == 'DUF':
        layers = opt_net['layers']
        # Validate before importing so a mistyped depth fails fast with a
        # clear message; previously every value other than 16 or 28 silently
        # built the 52-layer network.
        if layers not in (16, 28, 52):
            raise NotImplementedError(
                'DUF with [{}] layers is not supported; expected 16, 28 or 52'
                .format(layers))
        import models.archs.DUF_arch as DUF_arch
        if layers == 16:
            duf_cls = DUF_arch.DUF_16L
        elif layers == 28:
            duf_cls = DUF_arch.DUF_28L
        else:
            duf_cls = DUF_arch.DUF_52L
        netG = duf_cls(scale=opt['scale'], adapt_official=True)

    elif which_model == 'TOF':
        import models.archs.TOF_arch as TOF_arch
        netG = TOF_arch.TOFlow(adapt_official=True)

    else:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(which_model))

    return netG
Esempio n. 6
0
import os.path as osp
import sys
import torch
# Convert a pretrained x4 MSRResNet checkpoint into an initialisation for an
# x3 MSRResNet (nf=64, nb=16, 3-channel input and output).
#
# Every parameter present in both networks is copied unchanged, except
# 'upconv1', whose output-channel count differs between the x4 and x3 models.
# Its x3 weight/bias are seeded by tiling the x4 ones, with each copy halved.

# The `models` package lives in the parent of this script's directory. The
# import is deliberately not wrapped in `try/except ImportError: pass`: that
# hid a missing package, and the script then failed with a confusing
# NameError on `SRResNet_arch`.
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
import models.archs.SRResNet_arch as SRResNet_arch  # noqa: E402

SRC_PATH = '../../experiments/pretrained_models/MSRResNetx4.pth'
DST_PATH = '../../experiments/pretrained_models/MSRResNetx3_ini.pth'

# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine, and keeps it on the same device as the new model.
pretrained_net = torch.load(SRC_PATH, map_location='cpu')
crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
crt_net = crt_model.state_dict()

# Copy all shared parameters except upconv1 (its shape differs between x4/x3).
for k in crt_net:
    if k in pretrained_net and 'upconv1' not in k:
        crt_net[k] = pretrained_net[k]
        print('replace ... ', k)

# x4 -> x3: fill the larger x3 upconv1 with two full copies of the x4 kernels,
# followed by the leading rows of a third copy, halving every copy. With nf=64
# this writes rows [0:256], [256:512] and [512:576].
src_w = pretrained_net['upconv1.weight']
src_b = pretrained_net['upconv1.bias']
dst_w = crt_net['upconv1.weight']
dst_b = crt_net['upconv1.bias']
n_src = src_w.shape[0]
n_tail = dst_w.shape[0] - 2 * n_src
dst_w[0:n_src] = src_w / 2
dst_w[n_src:2 * n_src] = src_w / 2
dst_w[2 * n_src:] = src_w[0:n_tail] / 2
dst_b[0:n_src] = src_b / 2
dst_b[n_src:2 * n_src] = src_b / 2
dst_b[2 * n_src:] = src_b[0:n_tail] / 2

torch.save(crt_net, DST_PATH)
Esempio n. 7
0
def define_G(opt):
    """Create the generator requested by ``opt['network_G']['which_model_G']``.

    Except for MSRResNet (whose ``upscale`` comes from ``scale``), every
    constructor argument is read from ``opt['network_G']`` under the same key
    name. Supported names:

    * image restoration: 'MSRResNet', 'RRDBNet'
    * video restoration: 'EDVR', 'EDVR2X', 'EDVRImg', 'EDVR3D', 'UPEDVR',
      'UPContEDVR', 'FlowUPContEDVR'
    * video SR for multiple target frames: 'MultiEDVR'
    * arbitrary-magnification video SR: 'MetaEDVR'

    Returns:
        The constructed network instance.

    Raises:
        NotImplementedError: for any other name.
    """
    net_opt = opt['network_G']
    which = net_opt['which_model_G']

    def opts(*keys):
        # Collect ``{key: net_opt[key]}`` for the given keys, in order, to be
        # passed on as keyword arguments of the same names.
        return {key: net_opt[key] for key in keys}

    # Argument lists shared by several EDVR flavours.
    edvr_args = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs', 'center',
                 'predeblur', 'HR_in', 'w_TSA')
    single_frame_args = ('nf', 'front_RBs', 'back_RBs', 'down_scale')
    up_edvr_args = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs',
                    'center', 'w_TSA', 'down_scale', 'align_target',
                    'ret_valid')

    if which == 'MSRResNet':
        return SRResNet_arch.MSRResNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'],
            nf=net_opt['nf'], nb=net_opt['nb'], upscale=net_opt['scale'])
    if which == 'RRDBNet':
        return RRDBNet_arch.RRDBNet(**opts('in_nc', 'out_nc', 'nf', 'nb'))
    if which == 'EDVR':
        return EDVR_arch.EDVR(**opts(*edvr_args))
    if which == 'EDVR2X':
        return EDVR_arch.EDVR2X(**opts(*edvr_args))
    if which == 'EDVRImg':
        return EDVR_arch.EDVRImage(**opts(*single_frame_args))
    if which == 'EDVR3D':
        return EDVR_arch.EDVR3D(**opts(*single_frame_args))
    if which == 'UPEDVR':
        return EDVR_arch.UPEDVR(**opts(*up_edvr_args))
    if which == 'UPContEDVR':
        return EDVR_arch.UPControlEDVR(
            **opts(*up_edvr_args, 'multi_scale_cont'))
    if which == 'FlowUPContEDVR':
        return EDVR_arch.FlowUPControlEDVR(
            **opts(*up_edvr_args, 'multi_scale_cont'))
    if which == 'MultiEDVR':
        return EDVR_arch.MultiEDVR(**opts(*edvr_args))
    if which == 'MetaEDVR':
        return EDVR_arch.MetaEDVR(**opts(*edvr_args, 'fix_edvr'))

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(which))
Esempio n. 8
0
def define_G(opt):
    """Instantiate the generator chosen by ``opt['network_G']['which_model_G']``.

    'MSRResNet' reads in_nc, out_nc, nf, nb and scale (as ``upscale``). All
    other supported models ('RRDBNet', 'EDVR', 'MY_EDVR_FusionDenoise',
    'MY_EDVR_RES', 'MY_EDVR_PreEnhance', 'Recurr_ResBlocks') receive their
    options from ``opt['network_G']`` as keyword arguments of the same names.

    Returns:
        The constructed network instance.

    Raises:
        NotImplementedError: if the model name is not recognised.
    """
    net_opt = opt['network_G']
    model_name = net_opt['which_model_G']

    if model_name == 'MSRResNet':
        return SRResNet_arch.MSRResNet(in_nc=net_opt['in_nc'],
                                       out_nc=net_opt['out_nc'],
                                       nf=net_opt['nf'],
                                       nb=net_opt['nb'],
                                       upscale=net_opt['scale'])

    edvr_keys = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs', 'center',
                 'predeblur', 'HR_in', 'w_TSA')
    # (which_model_G value, lazily resolved constructor, option keys that are
    # forwarded as same-named keyword arguments). Checked in order.
    registry = (
        ('RRDBNet', lambda: RRDBNet_arch.RRDBNet,
         ('in_nc', 'out_nc', 'nf', 'nb')),
        ('EDVR', lambda: EDVR_arch.EDVR, edvr_keys),
        ('MY_EDVR_FusionDenoise',
         lambda: my_EDVR_arch.MYEDVR_FusionDenoise, edvr_keys),
        ('MY_EDVR_RES', lambda: my_EDVR_arch.MYEDVR_RES, edvr_keys),
        ('MY_EDVR_PreEnhance',
         lambda: my_EDVR_arch.MYEDVR_PreEnhance, edvr_keys),
        ('Recurr_ResBlocks', lambda: Recurr_arch.Recurr_ResBlocks,
         ('nf', 'N_RBs', 'N_flow_lv', 'pretrain_flow')),
    )
    for name, resolve, keys in registry:
        if model_name == name:
            constructor = resolve()
            return constructor(**{key: net_opt[key] for key in keys})

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(model_name))