예제 #1
0
def define_G(opt, name='network_G'):
    """Construct the generator selected by ``opt[name]['which_model_G']``.

    Args:
        opt: Option mapping. ``opt[name]`` holds the generator settings; the
            ``RRDBNetSEG`` variant additionally reads
            ``opt['train']['segm_mask']``.
        name: Key of the generator section inside ``opt``.

    Returns:
        The instantiated generator network.

    Raises:
        NotImplementedError: If ``which_model_G`` is not a supported name.
    """
    net_cfg = opt[name]
    arch_name = net_cfg['which_model_G']

    if arch_name == 'RRDBNet':
        return RRDBNet_arch.RRDBNet(
            in_nc=net_cfg['in_nc'],
            out_nc=net_cfg['out_nc'],
            nf=net_cfg['nf'],
            nb=net_cfg['nb'],
        )

    if arch_name == 'RRDBNetSEG':
        # The segmentation-mask setting is read from the training section,
        # not from the network section.
        return RRDBNet_arch.RRDBNetSEG(
            in_nc=net_cfg['in_nc'],
            out_nc=net_cfg['out_nc'],
            nf=net_cfg['nf'],
            nb=net_cfg['nb'],
            segm_mask=opt['train']['segm_mask'],
            scale=net_cfg['scale'],
        )

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(arch_name))
예제 #2
0
def define_G(opt):
    """Construct the generator named by ``opt['network_G']['which_model_G']``.

    Every supported architecture reads ``in_nc``, ``out_nc``, ``nf``, ``nb``
    and ``diff`` from ``opt['network_G']``; ``MSRResNet`` also reads
    ``scale`` and ``RRDBNet`` also reads ``time_dependent``, ``adjoint`` and
    ``sb``.

    Args:
        opt: Option mapping containing a ``'network_G'`` section.

    Returns:
        The instantiated generator network.

    Raises:
        NotImplementedError: If ``which_model_G`` is not a supported name.
    """
    net_cfg = opt['network_G']
    arch_name = net_cfg['which_model_G']

    if arch_name == 'MSRResNet':
        return SRResNet_arch.MSRResNet(
            in_nc=net_cfg['in_nc'],
            out_nc=net_cfg['out_nc'],
            nf=net_cfg['nf'],
            nb=net_cfg['nb'],
            upscale=net_cfg['scale'],
            differential=net_cfg['diff'],
        )

    if arch_name == 'RRDBNet':
        return RRDBNet_arch.RRDBNet(
            in_nc=net_cfg['in_nc'],
            out_nc=net_cfg['out_nc'],
            nf=net_cfg['nf'],
            nb=net_cfg['nb'],
            differential=net_cfg['diff'],
            time_dependent=net_cfg['time_dependent'],
            adjoint=net_cfg['adjoint'],
            sb=net_cfg['sb'],
        )

    if arch_name == 'ReacDiff':
        return rd_net_arch.ReacDiff(
            in_nc=net_cfg['in_nc'],
            out_nc=net_cfg['out_nc'],
            nf=net_cfg['nf'],
            nb=net_cfg['nb'],
            differential=net_cfg['diff'],
        )

    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(arch_name))
예제 #3
0
파일: networks.py 프로젝트: zhuangzhong/IKC
def define_G(opt):
    """Construct the generator named by ``opt['network_G']['which_model_G']``.

    ``MSRResNet`` and ``RRDBNet`` are configured from ``opt['network_G']``;
    the ``sftmd_arch`` models listed below are instantiated without
    arguments.

    Args:
        opt: Option mapping containing a ``'network_G'`` section.

    Returns:
        The instantiated generator network.

    Raises:
        NotImplementedError: If ``which_model_G`` is not a supported name.
    """
    net_cfg = opt['network_G']
    arch_name = net_cfg['which_model_G']
    # sftmd_arch classes whose name equals the option value and which are
    # constructed here with no arguments.
    sftmd_models = ('Predictor', 'Corrector', 'SFTMD', 'SRResNet', 'SFTMD_DEMO')

    if arch_name == 'MSRResNet':
        generator = SRResNet_arch.MSRResNet(
            in_nc=net_cfg['in_nc'],
            out_nc=net_cfg['out_nc'],
            nf=net_cfg['nf'],
            nb=net_cfg['nb'],
            upscale=net_cfg['scale'],
        )
    elif arch_name == 'RRDBNet':
        generator = RRDBNet_arch.RRDBNet(
            in_nc=net_cfg['in_nc'],
            out_nc=net_cfg['out_nc'],
            nf=net_cfg['nf'],
            nb=net_cfg['nb'],
        )
    elif arch_name in sftmd_models:
        generator = getattr(sftmd_arch, arch_name)()
    else:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(arch_name))
    return generator
예제 #4
0
def define_G(opt):
    """Construct the generator named by ``opt["network_G"]["which_model_G"]``.

    Args:
        opt: Option mapping containing a ``"network_G"`` section with the
            keys read by the selected architecture (``in_nc``, ``out_nc``,
            ``nf``, ``nb``, ``scale``/``upscale``, ``code_length`` as
            applicable).

    Returns:
        The instantiated generator network.

    Raises:
        NotImplementedError: If ``which_model_G`` is not a supported name.
    """
    net_cfg = opt["network_G"]
    arch_name = net_cfg["which_model_G"]

    if arch_name == "MSRResNet":
        return SRResNet_arch.MSRResNet(
            in_nc=net_cfg["in_nc"],
            out_nc=net_cfg["out_nc"],
            nf=net_cfg["nf"],
            nb=net_cfg["nb"],
            upscale=net_cfg["scale"],
        )

    if arch_name == "RRDBNet":
        return RRDBNet_arch.RRDBNet(
            in_nc=net_cfg["in_nc"],
            out_nc=net_cfg["out_nc"],
            nf=net_cfg["nf"],
            nb=net_cfg["nb"],
        )

    if arch_name in ("Predictor", "Corrector"):
        # Both classes are built from the same three options.
        arch_cls = getattr(sftmd_arch, arch_name)
        return arch_cls(
            in_nc=net_cfg["in_nc"],
            nf=net_cfg["nf"],
            code_len=net_cfg["code_length"],
        )

    if arch_name == "SRResNet":
        return sftmd_arch.SRResNet()

    if arch_name in ("SFTMD", "SFTMD_DEMO"):
        # NOTE: these read the scale factor from "upscale", whereas
        # MSRResNet above reads it from "scale".
        arch_cls = getattr(sftmd_arch, arch_name)
        return arch_cls(
            in_nc=net_cfg["in_nc"],
            out_nc=net_cfg["out_nc"],
            nf=net_cfg["nf"],
            nb=net_cfg["nb"],
            scale=net_cfg["upscale"],
            input_para=net_cfg["code_length"],
        )

    raise NotImplementedError(
        "Generator model [{:s}] not recognized".format(arch_name)
    )
예제 #5
0
def define_G(opt):
    """Construct the generator named by ``opt.which_model_G``.

    Unlike dict-based configs, ``opt`` is read through attribute access
    (e.g. an argparse-style namespace); only ``RRDBNet`` is supported.

    Args:
        opt: Object exposing ``which_model_G``, ``G_in_nc``, ``out_nc``,
            ``G_nf`` and ``nb`` attributes.

    Returns:
        The instantiated generator network.

    Raises:
        NotImplementedError: If ``which_model_G`` is not ``'RRDBNet'``.
    """
    arch_name = opt.which_model_G
    if arch_name != 'RRDBNet':
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(arch_name))

    return RRDBNet_arch.RRDBNet(
        in_nc=opt.G_in_nc,
        out_nc=opt.out_nc,
        nf=opt.G_nf,
        nb=opt.nb,
    )
예제 #6
0
def define_G(opt):
    """Construct the generator named by ``opt['network_G']['which_model_G']``.

    Args:
        opt: Option mapping containing a ``'network_G'`` section.
            ``MSRResNet`` and ``RRDBNet`` read ``in_nc``, ``out_nc``, ``nf``
            and ``nb`` (plus ``scale`` for ``MSRResNet``); ``DWUNet`` and
            ``RFDNet`` are built without arguments.

    Returns:
        The instantiated generator network.

    Raises:
        NotImplementedError: If ``which_model_G`` is not a supported name.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model == 'MSRResNet':
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'], nb=opt_net['nb'])
    elif which_model == 'DWUNet':
        netG = DWUNet_arch.DWUNet()
    elif which_model == 'RFDNet':
        netG = RFDNet_arch.RFDN()
    else:
        # Unknown names must be rejected here; without this branch the raise
        # would belong to RFDNet and ``netG`` would be unbound below.
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(which_model))
    return netG
예제 #7
0
def define_G(opt):
    """Construct the generator named by ``opt['network_G']['which_model_G']``.

    ``RRDBNet`` is configured from ``opt['network_G']`` plus the top-level
    ``opt['scale']`` and the whole ``opt``. ``EDSRNet`` and ``rankSRGAN``
    are resolved through ``find_model_using_name`` and receive only the
    top-level scale, under the keyword each constructor expects.

    Args:
        opt: Option mapping with a ``'network_G'`` section and a ``'scale'``.

    Returns:
        The instantiated generator network.

    Raises:
        NotImplementedError: If ``which_model_G`` is not a supported name.
    """
    net_cfg = opt['network_G']
    arch_name = net_cfg['which_model_G']

    if arch_name == 'RRDBNet':
        return RRDBNet_arch.RRDBNet(
            in_nc=net_cfg['in_nc'],
            out_nc=net_cfg['out_nc'],
            nf=net_cfg['nf'],
            nb=net_cfg['nb'],
            scale=opt['scale'],
            opt=opt,
        )

    # Name-resolved architectures -> keyword their constructor uses for the
    # scale factor.
    scale_keyword = {'EDSRNet': 'scale', 'rankSRGAN': 'upscale'}.get(arch_name)
    if scale_keyword is None:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(arch_name))

    arch_cls = find_model_using_name(arch_name)
    return arch_cls(**{scale_keyword: opt['scale']})
예제 #8
0
def define_G(opt):
    """Construct the generator by forwarding ``opt['network_G']`` as kwargs.

    ``which_model_G`` selects the architecture and ``scale`` is consumed
    here (passed to ``MSRResNet`` as ``upscale``, dropped for the others);
    every remaining key is forwarded verbatim to the constructor.

    Args:
        opt: Option mapping containing a ``'network_G'`` section. The
            section itself is not modified.

    Returns:
        The instantiated generator network.

    Raises:
        KeyError: If ``MSRResNet`` is selected but ``scale`` is missing.
        NotImplementedError: If ``which_model_G`` is not a supported name.
    """
    # Shallow copy so the pops below do not mutate the caller's config.
    opt_net = copy(opt['network_G'])
    which_model = opt_net.pop('which_model_G')
    # Always remove ``scale`` so it is never forwarded; a sentinel keeps an
    # explicit ``scale: None`` distinguishable from an absent key.
    missing = object()
    scale = opt_net.pop('scale', missing)

    if which_model == 'MSRResNet':
        if scale is missing:
            raise KeyError("'scale' is required in opt['network_G'] for MSRResNet")
        netG = SRResNet_arch.MSRResNet(**opt_net, upscale=scale)
    elif which_model == 'RRDBNet':
        netG = RRDBNet_arch.RRDBNet(**opt_net)
    elif which_model == 'USRGAN':
        netG = USRGAN_arch.USRGAN(**opt_net)
    elif which_model == 'USRGANLarge':
        netG = USRGANLarge_arch.USRGANLarge(**opt_net)
    elif which_model == 'USRGAN_conns':
        netG = USRGAN_Connections_arch.USRGANLarge(**opt_net)
    elif which_model == 'BOWGAN':
        netG = BOWGAN_arch.BOWGAN(**opt_net)
    else:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
    return netG
예제 #9
0
def define_G(opt):
    """Construct the generator named by ``opt['network_G']['which_model_G']``.

    ``MSRResNet``, ``RRDBNet`` and ``MResNet`` are configured from
    ``opt['network_G']``; the ``ResNet_*`` variants of ``SRResNet_arch`` and
    the ``DynamicF`` models are built without reading further options.

    Args:
        opt: Option mapping containing a ``'network_G'`` section.

    Returns:
        The instantiated generator network.

    Raises:
        NotImplementedError: If ``which_model_G`` is not a supported name.
    """
    net_cfg = opt['network_G']
    arch_name = net_cfg['which_model_G']

    # Each entry is a zero-argument factory, so option keys and module
    # attributes are looked up only for the model actually selected.
    factories = {
        'MSRResNet': lambda: SRResNet_arch.MSRResNet(
            in_nc=net_cfg['in_nc'],
            out_nc=net_cfg['out_nc'],
            nf=net_cfg['nf'],
            nb=net_cfg['nb'],
            upscale=net_cfg['scale']),
        'RRDBNet': lambda: RRDBNet_arch.RRDBNet(
            in_nc=net_cfg['in_nc'],
            out_nc=net_cfg['out_nc'],
            nf=net_cfg['nf'],
            nb=net_cfg['nb']),
        'MResNet': lambda: SRResNet_arch.MResNet(
            in_nc=net_cfg['in_nc'],
            out_nc=net_cfg['out_nc'],
            nf=net_cfg['nf'],
            nb=net_cfg['nb']),
        'DFN': lambda: DynamicF.DFN_16L_2d(),
        'DFN_1x1': lambda: DynamicF.DFN_16L_2d_1x1(),
        'DFN_noRx': lambda: DynamicF.DFN_16L_2d(res=False),
        'DFN_alpha': lambda: DynamicF.DFN_16L_2d_alpha(),
        'DCCN': lambda: DynamicF.DCCN_16L_2d(),
        'DCCN_alpha': lambda: DynamicF.DCCN_16L_2d_alpha(),
    }
    # SRResNet_arch classes whose name equals the option value and which
    # take no constructor arguments.
    plain_variants = (
        'ResNet_alpha_beta',
        'ResNet_alpha_beta_sconv',
        'ResNet_alpha_beta_fc',
        'ResNet_alpha_beta_fc_statistics',
        'ResNet_alpha_beta_decoder_1x1',
        'ResNet_alpha_beta_decoder_3x3',
        'ResNet_alpha_beta_decoder_3x3_BN',
        'ResNet_alpha_beta_decoder_3x3_IN',
        'ResNet_alpha_beta_decoder_3x3_IN_encoder',
        'ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW',
        'ResNet_alpha_beta_decoder_3x3_IN_encoder_global2local',
        'ResNet_plain',
    )

    factory = factories.get(arch_name)
    if factory is not None:
        return factory()
    if arch_name in plain_variants:
        return getattr(SRResNet_arch, arch_name)()
    raise NotImplementedError(
        'Generator model [{:s}] not recognized'.format(arch_name))