Exemplo n.º 1
0
def define_G(opt):
    """Build the generator network named by ``opt['network_G']['which_model_G']``.

    Supported names: ``'MSRResNet'``, ``'RRDBNet'``, ``'RRDBNet_16x'``,
    ``'RRDBNetTRConv_16x'`` and ``'EDVR'``. Constructor arguments are read
    from ``opt['network_G']`` only for the model that is selected.

    Args:
        opt: Options mapping containing a ``'network_G'`` sub-mapping.

    Returns:
        The constructed generator object.

    Raises:
        NotImplementedError: If the model name is not one of the names above.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    # image restoration
    if which_model == 'MSRResNet':
        return SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'],
                                       out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'],
                                       nb=opt_net['nb'],
                                       upscale=opt_net['scale'])
    if which_model == 'RRDBNet':
        return RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'],
                                    out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'],
                                    nb=opt_net['nb'])
    # 16x superresolution
    if which_model == 'RRDBNet_16x':
        return RRDBNet_arch.RRDBNet_16x(in_nc=opt_net['in_nc'],
                                        out_nc=opt_net['out_nc'],
                                        nf=opt_net['nf'],
                                        nb=opt_net['nb'])
    # 16x superresolution with transposed conv
    if which_model == 'RRDBNetTRConv_16x':
        return RRDBNet_arch.RRDBNetTRConv_16x(in_nc=opt_net['in_nc'],
                                              out_nc=opt_net['out_nc'],
                                              nf=opt_net['nf'],
                                              nb=opt_net['nb'])
    # video restoration
    if which_model == 'EDVR':
        return EDVR_arch.EDVR(nf=opt_net['nf'],
                              nframes=opt_net['nframes'],
                              groups=opt_net['groups'],
                              front_RBs=opt_net['front_RBs'],
                              back_RBs=opt_net['back_RBs'],
                              center=opt_net['center'],
                              predeblur=opt_net['predeblur'],
                              HR_in=opt_net['HR_in'],
                              w_TSA=opt_net['w_TSA'])

    # '{}' rather than '{:s}': the 's' format spec raises TypeError for
    # non-string values such as None, which would hide this error.
    raise NotImplementedError(
        'Generator model [{}] not recognized'.format(which_model))
Exemplo n.º 2
0
def define_G(opt):
    """Build the generator network named by ``opt['network_G']['which_model_G']``.

    Supported names: ``'MSRResNet'``, ``'RRDBNet'`` and ``'ORDSRNet'``.
    Constructor arguments are read from ``opt['network_G']`` only for the
    model that is selected.

    Args:
        opt: Options mapping containing a ``'network_G'`` sub-mapping.

    Returns:
        The constructed generator object.

    Raises:
        NotImplementedError: If the model name is not one of the names above.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    # image restoration
    if which_model == 'MSRResNet':
        netG = SRResNet_arch.MSRResNet(
            in_nc=opt_net['in_nc'],
            out_nc=opt_net['out_nc'],
            nf=opt_net['nf'],
            nb=opt_net['nb'],
            upscale=opt_net['scale'],
        )
    elif which_model == 'RRDBNet':
        netG = RRDBNet_arch.RRDBNet(
            in_nc=opt_net['in_nc'],
            out_nc=opt_net['out_nc'],
            nf=opt_net['nf'],
            nb=opt_net['nb'],
        )
    elif which_model == 'ORDSRNet':
        # N and S are fed from the 'base_size' and 'stride' options.
        netG = ORDSRModel(
            in_nc=opt_net['in_nc'],
            out_nc=opt_net['out_nc'],
            N=opt_net['base_size'],
            S=opt_net['stride'],
            upscale=opt_net['scale'],
        )
    else:
        # '{}' rather than '{:s}': the 's' format spec raises TypeError for
        # non-string values such as None, which would hide this error.
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    return netG
Exemplo n.º 3
0
def define_G(opt):
    """Build the generator network named by ``opt['network_G']['which_model_G']``.

    Supported names: ``'MSRResNet'``, ``'RRDBNet'``, ``'EDVR'``, ``'DUF'``
    and ``'TOF'``. Constructor arguments are read from ``opt['network_G']``
    only for the selected model; EDVR and DUF additionally take ``scale``
    from the top-level ``opt['scale']``.

    Args:
        opt: Options mapping containing a ``'network_G'`` sub-mapping.

    Returns:
        The constructed generator object.

    Raises:
        NotImplementedError: If the model name is not one of the names above.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    # image restoration
    if which_model == 'MSRResNet':
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'],
                                       out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'],
                                       nb=opt_net['nb'],
                                       upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'],
                                    out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'],
                                    nb=opt_net['nb'])
    # video restoration
    elif which_model == 'EDVR':
        netG = EDVR_arch.EDVR(nf=opt_net['nf'],
                              nframes=opt_net['nframes'],
                              groups=opt_net['groups'],
                              front_RBs=opt_net['front_RBs'],
                              back_RBs=opt_net['back_RBs'],
                              center=opt_net['center'],
                              predeblur=opt_net['predeblur'],
                              HR_in=opt_net['HR_in'],
                              w_TSA=opt_net['w_TSA'],
                              scale=opt['scale'])
    elif which_model == 'DUF':
        # Any 'layers' value other than 16 or 28 selects the 52-layer variant.
        if opt_net['layers'] == 16:
            netG = DUF_arch.DUF_16L(scale=opt['scale'], adapt_official=True)
        elif opt_net['layers'] == 28:
            netG = DUF_arch.DUF_28L(scale=opt['scale'], adapt_official=True)
        else:
            netG = DUF_arch.DUF_52L(scale=opt['scale'], adapt_official=True)
    elif which_model == 'TOF':
        netG = TOF_arch.TOFlow(adapt_official=True)
    else:
        # '{}' rather than '{:s}': the 's' format spec raises TypeError for
        # non-string values such as None, which would hide this error.
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    return netG
Exemplo n.º 4
0
def define_G(opt):
    """Build the generator network named by ``opt['network_G']['which_model_G']``.

    Supported names: ``'MSRResNet'``, ``'RRDBNet'``, ``'AdaFMNet'``,
    ``'CResMDNet'``, ``'BaseNet'``, ``'CondNet'`` and ``'EDVR'``.
    Constructor arguments are read from ``opt['network_G']`` only for the
    model that is selected.

    Args:
        opt: Options mapping containing a ``'network_G'`` sub-mapping.

    Returns:
        The constructed generator object.

    Raises:
        NotImplementedError: If the model name is not one of the names above.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    # image restoration
    if which_model == 'MSRResNet':
        return SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'],
                                       out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'],
                                       nb=opt_net['nb'],
                                       upscale=opt_net['scale'])
    if which_model == 'RRDBNet':
        return RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'],
                                    out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'],
                                    nb=opt_net['nb'])
    if which_model == 'AdaFMNet':
        return AdaFMNet_arch.AdaFMNet(in_nc=opt_net['in_nc'],
                                      out_nc=opt_net['out_nc'],
                                      nf=opt_net['nf'],
                                      nb=opt_net['nb'],
                                      adafm_ksize=opt_net['adafm_ksize'])
    # CResMDNet, BaseNet and CondNet all live in CResMDNet_arch.
    if which_model == 'CResMDNet':
        return CResMDNet_arch.CResMDNet(in_nc=opt_net['in_nc'],
                                        out_nc=opt_net['out_nc'],
                                        nf=opt_net['nf'],
                                        nb=opt_net['nb'],
                                        cond_dim=opt_net['cond_dim'])
    if which_model == 'BaseNet':
        return CResMDNet_arch.BaseNet(in_nc=opt_net['in_nc'],
                                      out_nc=opt_net['out_nc'],
                                      nf=opt_net['nf'],
                                      nb=opt_net['nb'])
    if which_model == 'CondNet':
        return CResMDNet_arch.CondNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])

    # video restoration
    if which_model == 'EDVR':
        return EDVR_arch.EDVR(nf=opt_net['nf'],
                              nframes=opt_net['nframes'],
                              groups=opt_net['groups'],
                              front_RBs=opt_net['front_RBs'],
                              back_RBs=opt_net['back_RBs'],
                              center=opt_net['center'],
                              predeblur=opt_net['predeblur'],
                              HR_in=opt_net['HR_in'],
                              w_TSA=opt_net['w_TSA'])

    # '{}' rather than '{:s}': the 's' format spec raises TypeError for
    # non-string values such as None, which would hide this error.
    raise NotImplementedError(
        'Generator model [{}] not recognized'.format(which_model))
Exemplo n.º 5
0
def define_G(opt):
    """Build the generator network named by ``opt['network_G']['which_model_G']``.

    Supported names: ``'MSRResNet'``, ``'RRDBNet'``, ``'EDVR'``, ``'MGANet'``
    and ``'Unet'``. Constructor arguments are read from ``opt['network_G']``
    only for the model that is selected. The EDVR and Unet branches import
    their implementation modules on demand.

    Args:
        opt: Options mapping containing a ``'network_G'`` sub-mapping.

    Returns:
        The constructed generator object.

    Raises:
        NotImplementedError: If the model name is not one of the names above.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    # image restoration
    if which_model == 'MSRResNet':
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'],
                                       out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'],
                                       nb=opt_net['nb'],
                                       upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'],
                                    out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'],
                                    nb=opt_net['nb'])
    # video restoration
    elif which_model == 'EDVR':
        # Imported here so the module is only loaded when EDVR is requested.
        import models.archs.EDVR_arch as EDVR_arch
        netG = EDVR_arch.EDVR(nf=opt_net['nf'],
                              nframes=opt_net['nframes'],
                              groups=opt_net['groups'],
                              front_RBs=opt_net['front_RBs'],
                              back_RBs=opt_net['back_RBs'],
                              center=opt_net['center'],
                              predeblur=opt_net['predeblur'],
                              HR_in=opt_net['HR_in'],
                              w_TSA=opt_net['w_TSA'],
                              w_GCB=opt_net['w_GCB'])
    # NOTE: an 'EDVR_woDCN' branch (same arguments, built from
    # models.archs.EDVR_woDCN_arch) used to sit here and is disabled.
    elif which_model == 'MGANet':
        netG = Gen_Guided_UNet(input_size=opt_net['input_size'])
    elif which_model == 'Unet':
        import repo.CycleGAN.networks as unet_networks
        # The first two positional arguments are fixed at 2 * 3 and 1.
        netG = unet_networks.define_G(2 * 3, 1, opt_net['nf'],
                                      opt_net['G_type'], opt_net['norm'],
                                      opt_net['dropout'], opt_net['init_type'],
                                      opt_net['init_gain'])
    else:
        # '{}' rather than '{:s}': the 's' format spec raises TypeError for
        # non-string values such as None, which would hide this error.
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    return netG
Exemplo n.º 6
0
 def _torch_infer(self, img):
     """Run the RRDBNet checkpoint stored at ``self._model_path`` on ``img``.

     A fresh ``arch.RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, upscale=2)`` is
     built and loaded on every call, then evaluated without gradients on
     CUDA when available and on the CPU otherwise.

     Args:
         img: NumPy array passed to ``torch.from_numpy``. Its layout and
             dtype are not validated here (presumably a float NCHW batch;
             confirm against callers).

     Returns:
         The network output tensor, located on the selected device.
     """
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     torch_input = torch.from_numpy(img).to(device)
     model = arch.RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, upscale=2)
     # map_location remaps the stored tensors to the selected device, so a
     # checkpoint saved on a GPU still loads on the CPU fallback.
     # NOTE(review): torch.load unpickles the file; only load trusted files.
     state_dict = torch.load(self._model_path, map_location=device)
     # strict=False: keys that do not match the model do not raise.
     model.load_state_dict(state_dict, strict=False)
     model.eval()
     model = model.to(device)
     with torch.no_grad():
         pred = model(torch_input)
     return pred
Exemplo n.º 7
0
def define_G(opt):
    """Build the generator network named by ``opt['network_G']['which_model_G']``.

    Supported names: ``'CNLRN'``, ``'PreDeblur'``, ``'RRDBNet'`` and
    ``'EDVR'``. Constructor arguments are read from ``opt['network_G']``
    only for the model that is selected.

    Args:
        opt: Options mapping containing a ``'network_G'`` sub-mapping.

    Returns:
        The constructed generator object.

    Raises:
        NotImplementedError: If the model name is not one of the names above.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    # image restoration
    if which_model == 'CNLRN':
        netG = CNLRN_arch.CNLRN(
            n_colors=opt_net['n_colors'],
            n_deblur_blocks=opt_net['n_deblur_blocks'],
            n_nlrgs_body=opt_net['n_nlrgs_body'],
            n_nlrgs_up1=opt_net['n_nlrgs_up1'],
            n_nlrgs_up2=opt_net['n_nlrgs_up2'],
            n_subgroups=opt_net['n_subgroups'],
            n_rcabs=opt_net['n_rcabs'],
            n_feats=opt_net['n_feats'],
            nonlocal_psize=opt_net['nonlocal_psize'],
            scale=opt_net['scale'],
        )
    elif which_model == 'PreDeblur':
        netG = PreDeblur_arch.PreDeblur(
            n_colors=opt_net['n_colors'],
            n_deblur_blocks=opt_net['n_deblur_blocks'],
            n_feats=opt_net['n_feats'],
        )
    elif which_model == 'RRDBNet':
        netG = RRDBNet_arch.RRDBNet(
            in_nc=opt_net['in_nc'],
            out_nc=opt_net['out_nc'],
            nf=opt_net['nf'],
            nb=opt_net['nb'],
        )
    # video restoration
    elif which_model == 'EDVR':
        netG = EDVR_arch.EDVR(
            nf=opt_net['nf'],
            nframes=opt_net['nframes'],
            groups=opt_net['groups'],
            front_RBs=opt_net['front_RBs'],
            back_RBs=opt_net['back_RBs'],
            center=opt_net['center'],
            predeblur=opt_net['predeblur'],
            HR_in=opt_net['HR_in'],
            w_TSA=opt_net['w_TSA'],
        )
    else:
        # '{}' rather than '{:s}': the 's' format spec raises TypeError for
        # non-string values such as None, which would hide this error.
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    return netG
Exemplo n.º 8
0
def define_G(opt):
    """Build the generator network named by ``opt['network_G']['which_model_G']``.

    Supported names: ``'MSRResNet'``, ``'RRDBNet'``, ``'EDVR'`` and
    ``'JASRNet'``. Constructor arguments are read from ``opt['network_G']``
    only for the model that is selected.

    Args:
        opt: Options mapping containing a ``'network_G'`` sub-mapping.

    Returns:
        The constructed generator object.

    Raises:
        NotImplementedError: If the model name is not one of the names above.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    # image restoration
    if which_model == 'MSRResNet':
        return SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'],
                                       out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'],
                                       nb=opt_net['nb'],
                                       upscale=opt_net['scale'])
    if which_model == 'RRDBNet':
        return RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'],
                                    out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'],
                                    nb=opt_net['nb'])
    # video restoration
    if which_model == 'EDVR':
        return EDVR_arch.EDVR(nf=opt_net['nf'],
                              nframes=opt_net['nframes'],
                              groups=opt_net['groups'],
                              front_RBs=opt_net['front_RBs'],
                              back_RBs=opt_net['back_RBs'],
                              center=opt_net['center'],
                              predeblur=opt_net['predeblur'],
                              HR_in=opt_net['HR_in'],
                              w_TSA=opt_net['w_TSA'])
    if which_model == 'JASRNet':
        return JASRNet_arch.JASR(n_Parts=opt_net['n_Parts'],
                                 n_resblocks=opt_net['n_resblocks'],
                                 n_feats=opt_net['n_feats'],
                                 scale=opt_net['scale'],
                                 rgb_range=opt_net['rgb_range'],
                                 n_colors=opt_net['n_colors'])

    # '{}' rather than '{:s}': the 's' format spec raises TypeError for
    # non-string values such as None, which would hide this error.
    raise NotImplementedError(
        'Generator model [{}] not recognized'.format(which_model))
Exemplo n.º 9
0
def define_G(opt):
    """Build the generator network named by ``opt['network_G']['which_model_G']``.

    Supported names: ``'MSRResNet'``, ``'RRDBNet'``, ``'RCAN'``, ``'EDVR'``,
    ``'EDVR_DN'``, ``'EDVR_pyramid'`` and ``'PFNL'``. Constructor arguments
    are read from ``opt['network_G']`` only for the model that is selected.

    Args:
        opt: Options mapping containing a ``'network_G'`` sub-mapping.

    Returns:
        The constructed generator object.

    Raises:
        NotImplementedError: If the model name is not one of the names above.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    def pick(*keys):
        # Read the options lazily (in the listed order) so that only the
        # selected model's entries have to be present.
        return {key: opt_net[key] for key in keys}

    # Options shared by every EDVR variant; keyword name == option name.
    edvr_keys = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs',
                 'center', 'predeblur', 'HR_in', 'w_TSA')

    # image restoration
    if which_model == 'MSRResNet':
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'],
                                       out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'],
                                       nb=opt_net['nb'],
                                       upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        netG = RRDBNet_arch.RRDBNet(**pick('in_nc', 'out_nc', 'nf', 'nb'))
    elif which_model == 'RCAN':
        # Keyword names differ from the option names here, so spell them out.
        netG = RCAN_arch.RCAN(in_nc=opt_net['in_nc'],
                              out_nc=opt_net['out_nc'],
                              n_features=opt_net['nf'],
                              n_resgroups=opt_net['ng'],
                              n_resblocks=opt_net['nb'],
                              reduction=opt_net['reduction'],
                              scale=opt_net['scale'],
                              res_scale=opt_net['res_scale'])
    # video restoration
    elif which_model == 'EDVR':
        netG = EDVR_arch.EDVR(**pick(*edvr_keys))
    elif which_model == 'EDVR_DN':
        netG = EDVR_arch.EDVR_DN(**pick(*edvr_keys))
    elif which_model == 'EDVR_pyramid':
        netG = EDVR_arch.EDVR_pyramid(**pick(*edvr_keys))
    elif which_model == 'PFNL':
        netG = PFNL_arch.PFNL(**pick('nf', 'nc', 'nt', 'r', 'scale'))
    else:
        # '{}' rather than '{:s}': the 's' format spec raises TypeError for
        # non-string values such as None, which would hide this error.
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    return netG
Exemplo n.º 10
0
def define_G(opt):
    """Build the generator network named by ``opt['network_G']['which_model_G']``.

    Supported names: ``'MSRResNet'``, ``'RRDBNet'``, ``'EDVR'``,
    ``'EDVR2X'``, ``'EDVRImg'``, ``'EDVR3D'``, ``'UPEDVR'``,
    ``'UPContEDVR'``, ``'FlowUPContEDVR'``, ``'MultiEDVR'`` and
    ``'MetaEDVR'``. Constructor arguments are read from
    ``opt['network_G']`` only for the model that is selected.

    Args:
        opt: Options mapping containing a ``'network_G'`` sub-mapping.

    Returns:
        The constructed generator object.

    Raises:
        NotImplementedError: If the model name is not one of the names above.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    def pick(*keys):
        # Read the options lazily (in the listed order) so that only the
        # selected model's entries have to be present.
        return {key: opt_net[key] for key in keys}

    # Option groups shared between variants; keyword name == option name.
    edvr_keys = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs',
                 'center', 'predeblur', 'HR_in', 'w_TSA')
    single_keys = ('nf', 'front_RBs', 'back_RBs', 'down_scale')
    up_keys = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs', 'center',
               'w_TSA', 'down_scale', 'align_target', 'ret_valid')
    up_cont_keys = up_keys + ('multi_scale_cont',)

    # image restoration
    if which_model == 'MSRResNet':
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'],
                                       out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'],
                                       nb=opt_net['nb'],
                                       upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        netG = RRDBNet_arch.RRDBNet(**pick('in_nc', 'out_nc', 'nf', 'nb'))
    # video restoration
    elif which_model == 'EDVR':
        netG = EDVR_arch.EDVR(**pick(*edvr_keys))
    elif which_model == 'EDVR2X':
        netG = EDVR_arch.EDVR2X(**pick(*edvr_keys))
    elif which_model == 'EDVRImg':
        # Model name 'EDVRImg' maps to the class EDVRImage.
        netG = EDVR_arch.EDVRImage(**pick(*single_keys))
    elif which_model == 'EDVR3D':
        netG = EDVR_arch.EDVR3D(**pick(*single_keys))
    elif which_model == 'UPEDVR':
        netG = EDVR_arch.UPEDVR(**pick(*up_keys))
    elif which_model == 'UPContEDVR':
        netG = EDVR_arch.UPControlEDVR(**pick(*up_cont_keys))
    elif which_model == 'FlowUPContEDVR':
        netG = EDVR_arch.FlowUPControlEDVR(**pick(*up_cont_keys))
    # video SR for multiple target frames
    elif which_model == 'MultiEDVR':
        netG = EDVR_arch.MultiEDVR(**pick(*edvr_keys))
    # arbitrary magnification video super-resolution
    elif which_model == 'MetaEDVR':
        netG = EDVR_arch.MetaEDVR(**pick(*(edvr_keys + ('fix_edvr',))))
    else:
        # '{}' rather than '{:s}': the 's' format spec raises TypeError for
        # non-string values such as None, which would hide this error.
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    return netG
Exemplo n.º 11
0
def define_G(opt):
    """Build the generator network named by ``opt['network_G']['which_model_G']``.

    Supported names: ``'MSRResNet'``, ``'RRDBNet'``, ``'EDVR'``,
    ``'MY_EDVR_FusionDenoise'``, ``'MY_EDVR_RES'``,
    ``'MY_EDVR_PreEnhance'`` and ``'Recurr_ResBlocks'``. Constructor
    arguments are read from ``opt['network_G']`` only for the model that is
    selected.

    Args:
        opt: Options mapping containing a ``'network_G'`` sub-mapping.

    Returns:
        The constructed generator object.

    Raises:
        NotImplementedError: If the model name is not one of the names above.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    def pick(*keys):
        # Read the options lazily (in the listed order) so that only the
        # selected model's entries have to be present.
        return {key: opt_net[key] for key in keys}

    # Options shared by EDVR and the my_EDVR_arch variants;
    # keyword name == option name.
    edvr_keys = ('nf', 'nframes', 'groups', 'front_RBs', 'back_RBs',
                 'center', 'predeblur', 'HR_in', 'w_TSA')

    # image restoration
    if which_model == 'MSRResNet':
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'],
                                       out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'],
                                       nb=opt_net['nb'],
                                       upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        netG = RRDBNet_arch.RRDBNet(**pick('in_nc', 'out_nc', 'nf', 'nb'))
    # video restoration
    elif which_model == 'EDVR':
        netG = EDVR_arch.EDVR(**pick(*edvr_keys))
    elif which_model == 'MY_EDVR_FusionDenoise':
        netG = my_EDVR_arch.MYEDVR_FusionDenoise(**pick(*edvr_keys))
    elif which_model == 'MY_EDVR_RES':
        netG = my_EDVR_arch.MYEDVR_RES(**pick(*edvr_keys))
    elif which_model == 'MY_EDVR_PreEnhance':
        netG = my_EDVR_arch.MYEDVR_PreEnhance(**pick(*edvr_keys))
    elif which_model == 'Recurr_ResBlocks':
        netG = Recurr_arch.Recurr_ResBlocks(
            **pick('nf', 'N_RBs', 'N_flow_lv', 'pretrain_flow'))
    else:
        # '{}' rather than '{:s}': the 's' format spec raises TypeError for
        # non-string values such as None, which would hide this error.
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    return netG