示例#1
0
文件: networks.py 项目: FVL2020/SRDRL
def define_R(opt):
    """Build the degradation-simulator network R from ``opt['network_R']``.

    Args:
        opt: Option dict. Reads ``opt['gpu_ids']``, ``opt['is_train']`` and
            the sub-dict ``opt['network_R']``. That sub-dict must have
            ``which_model_R == 'deg_net'``. It must also provide the
            ``arch.DegNet`` options ``in_nc``, ``out_nc``, ``nf``,
            ``n_deg_lr``, ``n_deg_hr``, ``n_rec``, ``scale``, ``is_train``
            and ``output``.

    Returns:
        The network. When ``opt['is_train']`` is truthy, its weights are
        initialised with ``init_weights`` (kaiming, scale 0.1). Otherwise it
        is switched to eval mode. It is wrapped in ``nn.DataParallel`` when
        ``opt['gpu_ids']`` is non-empty.

    Raises:
        NotImplementedError: If ``which_model_R`` is not ``'deg_net'``.
        RuntimeError: If ``gpu_ids`` is non-empty but CUDA is unavailable.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_R']
    which_model = opt_net['which_model_R']

    if which_model == 'deg_net':  # Degradation Simulator
        netR = arch.DegNet(in_nc=opt_net['in_nc'],
                           out_nc=opt_net['out_nc'],
                           nf=opt_net['nf'],
                           n_deg_lr=opt_net['n_deg_lr'],
                           n_deg_hr=opt_net['n_deg_hr'],
                           n_rec=opt_net['n_rec'],
                           upscale=opt_net['scale'],
                           is_train=opt_net['is_train'],
                           output=opt_net['output'])
    else:
        # Use '{}' rather than '{:s}'. With '{:s}', a null (None) or
        # non-string config value makes str.format raise TypeError/ValueError
        # instead of this informative error.
        raise NotImplementedError(
            'Degradation Simulator model [{}] not recognized'.format(
                which_model))

    if opt['is_train']:
        init_weights(netR, init_type='kaiming', scale=0.1)
    else:
        netR.eval()  # No need to train
    if gpu_ids:
        # Use an explicit check rather than `assert`: asserts are stripped
        # under `python -O`, which would let a GPU config proceed without CUDA.
        if not torch.cuda.is_available():
            raise RuntimeError(
                'gpu_ids {} requested but CUDA is not available'.format(
                    gpu_ids))
        netR = nn.DataParallel(netR)
    return netR
示例#2
0
文件: networks.py 项目: FVL2020/SRDRL
def define_G(opt):
    """Build the generator network G selected by ``opt['network_G']``.

    Args:
        opt: Option dict. Reads ``opt['gpu_ids']``, ``opt['is_train']`` and
            the sub-dict ``opt['network_G']``. In that sub-dict,
            ``'which_model_G'`` picks the architecture: one of ``'deg_net'``,
            ``'sr_resnet'``, ``'edsr'``, ``'rcan'``, ``'sr_resnet_lh'`` or
            ``'RRDB_net'``. The remaining keys read depend on that choice;
            see each branch below.

    Returns:
        The generator module. When ``opt['is_train']`` is truthy, its weights
        are initialised with ``init_weights`` (kaiming, scale 0.1). It is
        wrapped in ``nn.DataParallel`` when ``opt['gpu_ids']`` is non-empty.

    Raises:
        NotImplementedError: If ``which_model_G`` names an unknown model.
        RuntimeError: If ``gpu_ids`` is non-empty but CUDA is unavailable.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model == 'deg_net':  # Degradation Simulator
        netG = arch.DegNet(in_nc=opt_net['in_nc'],
                           out_nc=opt_net['out_nc'],
                           nf=opt_net['nf'],
                           n_deg_lr=opt_net['n_deg_lr'],
                           n_deg_hr=opt_net['n_deg_hr'],
                           n_rec=opt_net['n_rec'],
                           upscale=opt_net['scale'],
                           is_train=opt_net['is_train'],
                           output=opt_net['output'])
    elif which_model == 'sr_resnet':  # SRResNet
        netG = arch.SRResNet(in_nc=opt_net['in_nc'],
                             out_nc=opt_net['out_nc'],
                             nf=opt_net['nf'],
                             nb=opt_net['nb'],
                             upscale=opt_net['scale'],
                             norm_type=opt_net['norm_type'],
                             act_type=opt_net['act_type'],
                             mode=opt_net['mode'],
                             upsample_mode=opt_net['upsample_mode'])
    elif which_model == 'edsr':  # EDSR
        # Built from the SRResNet class; the only difference from the
        # 'sr_resnet' branch is the extra `res_scale` option.
        netG = arch.SRResNet(in_nc=opt_net['in_nc'],
                             out_nc=opt_net['out_nc'],
                             nf=opt_net['nf'],
                             res_scale=opt_net['res_scale'],
                             nb=opt_net['nb'],
                             upscale=opt_net['scale'],
                             norm_type=opt_net['norm_type'],
                             act_type=opt_net['act_type'],
                             mode=opt_net['mode'],
                             upsample_mode=opt_net['upsample_mode'])
    elif which_model == 'rcan':  # RCAN
        netG = arch.RCAN(in_nc=opt_net['in_nc'],
                         out_nc=opt_net['out_nc'],
                         nf=opt_net['nf'],
                         res_scale=opt_net['res_scale'],
                         upscale=opt_net['scale'],
                         n_resgroups=opt_net['n_resgroups'],
                         n_resblocks=opt_net['n_resblocks'])
    elif which_model == 'sr_resnet_lh':  # SRResNet_L6H10
        # Takes two block counts, `nb_lr` and `nb_hr`, instead of a single
        # `nb` (presumably the LR- and HR-resolution stages; confirm in arch).
        netG = arch.SRResNetLH(in_nc=opt_net['in_nc'],
                               out_nc=opt_net['out_nc'],
                               nf=opt_net['nf'],
                               nb_lr=opt_net['nb_lr'],
                               nb_hr=opt_net['nb_hr'],
                               upscale=opt_net['scale'],
                               norm_type=opt_net['norm_type'],
                               act_type=opt_net['act_type'],
                               mode=opt_net['mode'],
                               upsample_mode=opt_net['upsample_mode'])
    elif which_model == 'RRDB_net':  # RRDB
        # act_type and upsample_mode are fixed here, not read from opt_net.
        netG = arch.RRDBNet(in_nc=opt_net['in_nc'],
                            out_nc=opt_net['out_nc'],
                            nf=opt_net['nf'],
                            nb=opt_net['nb'],
                            gc=opt_net['gc'],
                            upscale=opt_net['scale'],
                            norm_type=opt_net['norm_type'],
                            act_type='leakyrelu',
                            mode=opt_net['mode'],
                            upsample_mode='upconv')
    else:
        # Use '{}' rather than '{:s}'. With '{:s}', a null (None) or
        # non-string config value makes str.format raise TypeError/ValueError
        # instead of this informative error.
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    if opt['is_train']:
        init_weights(netG, init_type='kaiming', scale=0.1)
    if gpu_ids:
        # Use an explicit check rather than `assert`: asserts are stripped
        # under `python -O`, which would let a GPU config proceed without CUDA.
        if not torch.cuda.is_available():
            raise RuntimeError(
                'gpu_ids {} requested but CUDA is not available'.format(
                    gpu_ids))
        netG = nn.DataParallel(netG)
    return netG
示例#3
0
def define_G2(opt):
    """Build an ``arch.DegNet`` generator from ``opt['network_G']``.

    This function does not look at ``which_model_G``; it always builds a
    DegNet.

    NOTE(review): this previously passed ``nb``, ``norm_type``,
    ``act_type='leakyrelu'``, ``mode`` and ``upsample_mode='upconv'`` to
    ``arch.DegNet``. Those are the RRDBNet options used in ``define_G``, not
    the DegNet options that ``define_R`` and ``define_G`` pass. The call is
    aligned with the latter here. If an RRDBNet was actually intended, switch
    the class instead. Confirm against ``arch``.

    Args:
        opt: Option dict. Reads ``opt['gpu_ids']``, ``opt['is_train']`` and
            the DegNet options in ``opt['network_G']``: ``in_nc``,
            ``out_nc``, ``nf``, ``n_deg_lr``, ``n_deg_hr``, ``n_rec``,
            ``scale``, ``is_train`` and ``output``.

    Returns:
        The generator module. When ``opt['is_train']`` is truthy, its weights
        are initialised with ``init_weights`` (kaiming, scale 0.1). It is
        wrapped in ``nn.DataParallel`` when ``opt['gpu_ids']`` is non-empty.

    Raises:
        RuntimeError: If ``gpu_ids`` is non-empty but CUDA is unavailable.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']

    netG = arch.DegNet(in_nc=opt_net['in_nc'],
                       out_nc=opt_net['out_nc'],
                       nf=opt_net['nf'],
                       n_deg_lr=opt_net['n_deg_lr'],
                       n_deg_hr=opt_net['n_deg_hr'],
                       n_rec=opt_net['n_rec'],
                       upscale=opt_net['scale'],
                       is_train=opt_net['is_train'],
                       output=opt_net['output'])

    if opt['is_train']:
        init_weights(netG, init_type='kaiming', scale=0.1)
    if gpu_ids:
        # Use an explicit check rather than `assert`: asserts are stripped
        # under `python -O`, which would let a GPU config proceed without CUDA.
        if not torch.cuda.is_available():
            raise RuntimeError(
                'gpu_ids {} requested but CUDA is not available'.format(
                    gpu_ids))
        netG = nn.DataParallel(netG)
    return netG