def define_R(opt):
    """Build the degradation-simulator network configured in ``opt['network_R']``.

    Args:
        opt: Options dict. Reads ``opt['gpu_ids']``, ``opt['is_train']`` and
            the ``opt['network_R']`` sub-dict (``which_model_R`` plus the
            constructor arguments of the selected architecture).

    Returns:
        The constructed network, wrapped in ``nn.DataParallel`` when
        ``opt['gpu_ids']`` is truthy.

    Raises:
        NotImplementedError: if ``which_model_R`` names an unsupported model.
        RuntimeError: if ``gpu_ids`` is set but CUDA is not available.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_R']
    which_model = opt_net['which_model_R']

    if which_model == 'deg_net':  # Degradation Simulator
        netR = arch.DegNet(
            in_nc=opt_net['in_nc'],
            out_nc=opt_net['out_nc'],
            nf=opt_net['nf'],
            n_deg_lr=opt_net['n_deg_lr'],
            n_deg_hr=opt_net['n_deg_hr'],
            n_rec=opt_net['n_rec'],
            upscale=opt_net['scale'],
            is_train=opt_net['is_train'],
            output=opt_net['output'],
        )
    else:
        # '{}' rather than '{:s}': a non-string value (e.g. None or an int
        # parsed from a config file) would make str.format itself raise and
        # mask this NotImplementedError.
        raise NotImplementedError(
            'Degradation Simulator model [{}] not recognized'.format(which_model))

    if opt['is_train']:
        init_weights(netR, init_type='kaiming', scale=0.1)
    else:
        netR.eval()  # No need to train: put the module in eval mode.

    if gpu_ids:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not torch.cuda.is_available():
            raise RuntimeError('gpu_ids is set but CUDA is not available')
        netR = nn.DataParallel(netR)
    return netR
def define_G(opt):
    """Build the generator network configured in ``opt['network_G']``.

    Supported values of ``opt['network_G']['which_model_G']``:
    ``'deg_net'``, ``'sr_resnet'``, ``'edsr'``, ``'rcan'``, ``'sr_resnet_lh'``
    and ``'RRDB_net'``. Each branch reads the constructor arguments it needs
    from ``opt['network_G']``.

    Args:
        opt: Options dict. Reads ``opt['gpu_ids']``, ``opt['is_train']`` and
            the ``opt['network_G']`` sub-dict.

    Returns:
        The constructed network, wrapped in ``nn.DataParallel`` when
        ``opt['gpu_ids']`` is truthy.

    Raises:
        NotImplementedError: if ``which_model_G`` names an unsupported model.
        RuntimeError: if ``gpu_ids`` is set but CUDA is not available.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model == 'deg_net':  # Degradation Simulator
        netG = arch.DegNet(
            in_nc=opt_net['in_nc'],
            out_nc=opt_net['out_nc'],
            nf=opt_net['nf'],
            n_deg_lr=opt_net['n_deg_lr'],
            n_deg_hr=opt_net['n_deg_hr'],
            n_rec=opt_net['n_rec'],
            upscale=opt_net['scale'],
            is_train=opt_net['is_train'],
            output=opt_net['output'],
        )
    elif which_model == 'sr_resnet':  # SRResNet
        netG = arch.SRResNet(
            in_nc=opt_net['in_nc'],
            out_nc=opt_net['out_nc'],
            nf=opt_net['nf'],
            nb=opt_net['nb'],
            upscale=opt_net['scale'],
            norm_type=opt_net['norm_type'],
            act_type=opt_net['act_type'],
            mode=opt_net['mode'],
            upsample_mode=opt_net['upsample_mode'],
        )
    elif which_model == 'edsr':  # EDSR: SRResNet class with an explicit res_scale
        netG = arch.SRResNet(
            in_nc=opt_net['in_nc'],
            out_nc=opt_net['out_nc'],
            nf=opt_net['nf'],
            res_scale=opt_net['res_scale'],
            nb=opt_net['nb'],
            upscale=opt_net['scale'],
            norm_type=opt_net['norm_type'],
            act_type=opt_net['act_type'],
            mode=opt_net['mode'],
            upsample_mode=opt_net['upsample_mode'],
        )
    elif which_model == 'rcan':  # RCAN
        netG = arch.RCAN(
            in_nc=opt_net['in_nc'],
            out_nc=opt_net['out_nc'],
            nf=opt_net['nf'],
            res_scale=opt_net['res_scale'],
            upscale=opt_net['scale'],
            n_resgroups=opt_net['n_resgroups'],
            n_resblocks=opt_net['n_resblocks'],
        )
    elif which_model == 'sr_resnet_lh':  # SRResNet_L6H10
        netG = arch.SRResNetLH(
            in_nc=opt_net['in_nc'],
            out_nc=opt_net['out_nc'],
            nf=opt_net['nf'],
            nb_lr=opt_net['nb_lr'],
            nb_hr=opt_net['nb_hr'],
            upscale=opt_net['scale'],
            norm_type=opt_net['norm_type'],
            act_type=opt_net['act_type'],
            mode=opt_net['mode'],
            upsample_mode=opt_net['upsample_mode'],
        )
    elif which_model == 'RRDB_net':  # RRDB
        # act_type and upsample_mode are fixed here, not read from opt_net.
        netG = arch.RRDBNet(
            in_nc=opt_net['in_nc'],
            out_nc=opt_net['out_nc'],
            nf=opt_net['nf'],
            nb=opt_net['nb'],
            gc=opt_net['gc'],
            upscale=opt_net['scale'],
            norm_type=opt_net['norm_type'],
            act_type='leakyrelu',
            mode=opt_net['mode'],
            upsample_mode='upconv',
        )
    else:
        # '{}' rather than '{:s}': a non-string value would make str.format
        # itself raise and mask this NotImplementedError.
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    if opt['is_train']:
        init_weights(netG, init_type='kaiming', scale=0.1)

    if gpu_ids:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not torch.cuda.is_available():
            raise RuntimeError('gpu_ids is set but CUDA is not available')
        netG = nn.DataParallel(netG)
    return netG
def define_G2(opt):
    """Build an ``arch.DegNet`` generator from ``opt['network_G']``.

    NOTE(review): the keyword arguments passed to ``arch.DegNet`` here
    (``nb``, ``norm_type``, ``act_type``, ``mode``, ``upsample_mode``) differ
    from the ones ``arch.DegNet`` receives in ``define_R`` / ``define_G``
    (``n_deg_lr``, ``n_deg_hr``, ``n_rec``, ``is_train``, ``output``). They
    match the ``RRDBNet`` call in ``define_G`` minus ``gc``. This may be a
    copy-paste slip; confirm that DegNet accepts these arguments.

    Args:
        opt: Options dict. Reads ``opt['gpu_ids']``, ``opt['is_train']`` and
            the ``opt['network_G']`` sub-dict.

    Returns:
        The constructed network, wrapped in ``nn.DataParallel`` when
        ``opt['gpu_ids']`` is truthy.

    Raises:
        RuntimeError: if ``gpu_ids`` is set but CUDA is not available.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']

    # act_type and upsample_mode are fixed here, not read from opt_net.
    netG = arch.DegNet(
        in_nc=opt_net['in_nc'],
        out_nc=opt_net['out_nc'],
        nf=opt_net['nf'],
        nb=opt_net['nb'],
        upscale=opt_net['scale'],
        norm_type=opt_net['norm_type'],
        act_type='leakyrelu',
        mode=opt_net['mode'],
        upsample_mode='upconv',
    )

    if opt['is_train']:
        init_weights(netG, init_type='kaiming', scale=0.1)

    if gpu_ids:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not torch.cuda.is_available():
            raise RuntimeError('gpu_ids is set but CUDA is not available')
        netG = nn.DataParallel(netG)
    return netG