def define_G(opt):
    """Build the generator network described by ``opt['network_G']``.

    Args:
        opt: Option dict. Reads ``opt['gpu_ids']``, ``opt['is_train']`` and the
            sub-dict ``opt['network_G']``, whose ``'which_model_G'`` entry
            selects the architecture: ``'sr_resnet'``, ``'sft_arch'`` or
            ``'RRDB_net'``. The remaining ``network_G`` keys are forwarded to
            the chosen constructor.

    Returns:
        The generator module. When ``opt['is_train']`` is truthy it is passed
        through ``init_weights(..., init_type='kaiming', scale=0.1)``; when
        ``gpu_ids`` is non-empty it is wrapped in ``nn.DataParallel``.

    Raises:
        NotImplementedError: if ``which_model_G`` is not a known architecture.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model == 'sr_resnet':  # SRResNet
        netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                             nb=opt_net['nb'], upscale=opt_net['scale'],
                             norm_type=opt_net['norm_type'], act_type='relu',
                             mode=opt_net['mode'], upsample_mode='pixelshuffle')
    elif which_model == 'sft_arch':  # SFT-GAN
        netG = sft_arch.SFT_Net()
    elif which_model == 'RRDB_net':  # RRDB
        netG = arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'],
                            norm_type=opt_net['norm_type'], act_type='leakyrelu',
                            mode=opt_net['mode'], upsample_mode='upconv')
    else:
        # Plain '{}' instead of '{:s}': the 's' spec raises TypeError for
        # non-str values (e.g. None from an empty config entry), which would
        # mask this error with a confusing formatting failure.
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    if opt['is_train']:
        init_weights(netG, init_type='kaiming', scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG)
    return netG
def define_G(opt, CEM=None, num_latent_channels=None, **kwargs):
    """Build the generator network described by ``opt['network_G']``.

    Args:
        opt: Option dict. Reads ``opt['gpu_ids']``, ``opt['is_train']``,
            ``opt['scale']`` (DnCNN branch), ``opt['datasets']['train']['patch_size']``
            (CEM wrapping during training) and the sub-dict ``opt['network_G']``,
            whose ``'which_model_G'`` selects ``'sr_resnet'``, ``'RRDB_net'``,
            ``'DnCNN'`` or ``'MSRResNet'``.
        CEM: Object exposing ``WrapArchitecture_PyTorch``; required only when
            ``opt['network_G']['CEM_arch']`` is truthy.
        num_latent_channels: Forwarded to the RRDBNet and DnCNN constructors.
        **kwargs: Only ``chroma_mode`` (default False) is read, by the DnCNN branch.

    Returns:
        The generator module, optionally CEM-wrapped, passed through
        ``init_weights`` when training (except for MSRResNet), and wrapped in
        ``nn.DataParallel`` when ``gpu_ids`` is non-empty.

    Raises:
        NotImplementedError: if ``which_model_G`` is not a known architecture.
        ValueError: if ``CEM_arch`` is enabled but ``CEM`` is None.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']
    # Normalise the literal string "None" from the config to a real None.
    # This writes back into opt['network_G'], so later readers see None too.
    if opt_net['latent_input'] == "None":
        opt_net['latent_input'] = None
    latent_input = opt_net['latent_input']

    if which_model == 'sr_resnet':  # SRResNet
        netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                             nb=opt_net['nb'], upscale=opt_net['scale'],
                             norm_type=opt_net['norm_type'], act_type='relu',
                             mode=opt_net['mode'], upsample_mode='pixelshuffle',
                             range_correction=opt_net['range_correction'])
    elif which_model == 'RRDB_net':  # RRDB
        # The RRDB latent input identifier is "<latent_input>_<latent_input_domain>".
        rrdb_latent = (latent_input + '_' + opt_net['latent_input_domain']
                       if latent_input is not None else None)
        netG = arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'],
                            norm_type=opt_net['norm_type'], act_type='leakyrelu',
                            mode=opt_net['mode'], upsample_mode='upconv',
                            latent_input=rrdb_latent,
                            num_latent_channels=num_latent_channels)
    elif which_model == 'DnCNN':
        chroma_mode = kwargs.get('chroma_mode', False)
        assert opt_net['in_nc'] == 64 and opt_net['out_nc'] == 64
        # In chroma mode the channel counts depend on opt['scale']; otherwise 64.
        in_nc = (opt['scale']**2 + 2 * 64) if chroma_mode else 64
        out_nc = 2 * (opt['scale']**2) if chroma_mode else 64
        netG = arch.DnCNN(n_channels=opt_net['nf'], depth=opt_net['nb'], in_nc=in_nc,
                          out_nc=out_nc, norm_type=opt_net['norm_type'],
                          latent_input=latent_input,
                          num_latent_channels=num_latent_channels,
                          chroma_generator=chroma_mode)
    elif which_model == 'MSRResNet':  # SRResNet
        netG = arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                              nb=opt_net['nb'], upscale=opt_net['scale'])
    else:
        # '{}' rather than '{:s}': the 's' spec raises TypeError for non-str values.
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    if opt_net['CEM_arch']:
        if CEM is None:
            raise ValueError(
                "network_G['CEM_arch'] is enabled but no CEM object was passed to define_G")
        patch_size = opt['datasets']['train']['patch_size'] if opt['is_train'] else None
        netG = CEM.WrapArchitecture_PyTorch(netG, patch_size)

    if opt['is_train'] and which_model != 'MSRResNet':
        init_weights(netG, init_type='kaiming', scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG)
    return netG
def define_G(opt):
    """Build the generator network described by ``opt["network_G"]``.

    Args:
        opt: Option dict. Reads ``opt["gpu_ids"]``, ``opt["is_train"]`` and the
            sub-dict ``opt["network_G"]``, whose ``"which_model_G"`` entry
            selects ``"sr_resnet"``, ``"sft_arch"`` or ``"RRDB_net"``.

    Returns:
        The generator module. When ``opt["is_train"]`` is truthy it is passed
        through ``init_weights(..., init_type="kaiming", scale=0.1)``; when
        ``gpu_ids`` is non-empty it is wrapped in ``nn.DataParallel``.

    Raises:
        NotImplementedError: if ``which_model_G`` is not a known architecture.
    """
    gpu_ids = opt["gpu_ids"]
    opt_net = opt["network_G"]
    which_model = opt_net["which_model_G"]

    if which_model == "sr_resnet":  # SRResNet
        netG = arch.SRResNet(
            in_nc=opt_net["in_nc"],
            out_nc=opt_net["out_nc"],
            nf=opt_net["nf"],
            nb=opt_net["nb"],
            upscale=opt_net["scale"],
            norm_type=opt_net["norm_type"],
            act_type="relu",
            mode=opt_net["mode"],
            upsample_mode="pixelshuffle",
        )
    elif which_model == "sft_arch":  # SFT-GAN
        netG = sft_arch.SFT_Net()
    elif which_model == "RRDB_net":  # RRDB
        netG = arch.RRDBNet(
            in_nc=opt_net["in_nc"],
            out_nc=opt_net["out_nc"],
            nf=opt_net["nf"],
            nb=opt_net["nb"],
            gc=opt_net["gc"],
            upscale=opt_net["scale"],
            norm_type=opt_net["norm_type"],
            act_type="leakyrelu",
            mode=opt_net["mode"],
            upsample_mode="upconv",
        )
    else:
        # "{}" rather than "{:s}": the "s" spec raises TypeError for non-str
        # values (e.g. None), masking the intended NotImplementedError.
        raise NotImplementedError(
            "Generator model [{}] not recognized".format(which_model))

    if opt["is_train"]:
        init_weights(netG, init_type="kaiming", scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG)
    return netG
def define_G(opt):
    """Build the generator network described by ``opt['network_G']``.

    Args:
        opt: Option dict. Reads ``opt['gpu_ids']`` and the sub-dict
            ``opt['network_G']``, whose ``'which_model_G'`` entry selects
            ``'sr_resnet'`` or ``'sft_arch'``.

    Returns:
        The generator module; when ``gpu_ids`` is non-empty it is wrapped in
        ``nn.DataParallel`` and moved with ``.cuda()``. No weight
        initialisation is applied here.

    Raises:
        NotImplementedError: if ``which_model_G`` is not a known architecture
            (previously this fell through and failed with UnboundLocalError).
    """
    gpu_ids = opt['gpu_ids']
    # Separate name for the sub-dict instead of rebinding the ``opt`` parameter.
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model == 'sr_resnet':
        netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                             nb=opt_net['nb'], upscale=opt_net['scale'],
                             norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                             upsample_mode='pixelshuffle')
    elif which_model == 'sft_arch':
        netG = sft_arch.SFT_Net()
    else:
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    if gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG).cuda()
    return netG
def define_G(opt): gpu_ids = opt['gpu_ids'] opt_net = opt['network_G'] which_model = opt_net['which_model_G'] if which_model == 'sr_resnet': # SRResNet netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \ nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \ act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle') elif which_model == 'sft_arch': # SFT-GAN netG = sft_arch.SFT_Net() elif which_model == 'RRDB_net': # RRDB netG = arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv') #define the ex_G elif which_model == 'RRDBNet_G': netG = arch.RRDBNet(in_nc=opt_net['nf'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv') elif which_model == 'RCAN_G': netG = rcan_g.RCAN_G(n_resblocks=opt_net['n_resblocks'], n_resgroups=opt_net['n_resgroups'], n_feats=opt_net['n_feats'], reduction=16, scale=4, n_colors=3, rgb_range=255, res_scale=1) elif which_model == 'DualSR_RCAN': #def __init__(self, n_resblocks, n_resgroups_mask, n_resgroups_share, n_resgroups_high_1, n_resgroups_high_2, # n_resgroups_low1, n_resgroups_low2, n_feats, reduction, scale, n_colors, rgb_range, res_scale, # conv=common.default_conv): netG =DualSR.DualSR(n_resblocks=opt_net['n_resblocks'], n_resgroups_mask=opt_net['n_resgroups_mask'],\ n_resgroups_share=opt_net['n_resgroups_share'], n_resgroups_high_1=opt_net['n_resgroups_high_1'], \ n_resgroups_high_2=opt_net['n_resgroups_high_2'], n_resgroups_low_1=opt_net['n_resgroups_low_1'], \ n_resgroups_low_2=opt_net['n_resgroups_low_2'],\ n_feats=opt_net['n_feats'], reduction=16, scale=4, n_colors=3, rgb_range=255, res_scale=1) elif which_model == 
'DualSR_SR': # def __init__(self, n_resblocks, n_resgroups_mask, n_resgroups_share, n_resgroups_high_1, n_resgroups_high_2, # n_resgroups_low1, n_resgroups_low2, n_feats, reduction, scale, n_colors, rgb_range, res_scale, # conv=common.default_conv): netG = DualSR_SR.DualSR_SR(n_resblocks=opt_net['n_resblocks'], n_resgroups_mask=opt_net['n_resgroups_mask'], \ n_resgroups_share=opt_net['n_resgroups_share'], n_resgroups_high_1=opt_net['n_resgroups_high_1'], \ n_resgroups_high_2=opt_net['n_resgroups_high_2'], n_resgroups_low_1=opt_net['n_resgroups_low_1'], \ n_resgroups_low_2=opt_net['n_resgroups_low_2'], \ n_feats=opt_net['n_feats'], reduction=16, scale=4, n_colors=3, rgb_range=255, res_scale=1) elif which_model == 'DualSR_RRDB': netG = DualSR_RRDB.DualSR_RRDB(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb_l_1=opt_net['nb_l_1'], nb_l_2=opt_net['nb_l_2'], nb_h_1=opt_net['nb_h_1'], nb_e=opt_net['nb_e'], nb_h_2=opt_net['nb_h_2'], nb_m=opt_net['nb_m'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv') else: raise NotImplementedError( 'Generator model [{:s}] not recognized'.format(which_model)) if opt['is_train']: init_weights(netG, init_type='kaiming', scale=0.1) if gpu_ids: assert torch.cuda.is_available() netG = nn.DataParallel(netG) return netG
def define_G(opt): gpu_ids = opt['gpu_ids'] opt_net = opt['network_G'] which_model = opt_net['which_model_G'] if which_model == 'sr_resnet': # SRResNet netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \ nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \ act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle') elif which_model == 'Octave_SRResNet': # SRResNet based on octave netG = arch.Octave_SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \ nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \ act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle') elif which_model == 'M_NP_Octave_RRDBNet': # SRResNet based on octave netG = arch.M_NP_Octave_RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \ nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \ act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle') elif which_model == 'Octave_RRDBNet': # SRResNet based on octave netG = arch.Octave_RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \ nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \ act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle') elif which_model == 'DWT_Octave_RRDBNet': # SRResNet based on octave netG = arch.DWT_Octave_RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \ nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \ act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle') elif which_model == 'modified_resnet': # SRResNet based on octave netG = arch.Modified_SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \ nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \ act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle') elif which_model == 'Octave_CARN': # 
CARN based on octave netG = arch.Octave_CARN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \ nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \ act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle') elif which_model == 'carn': # CARN based on octave netG = arch.CARN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \ nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \ act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle') elif which_model == 'modulate_sr_resnet': netG = arch.ModulateSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], mode=opt_net['mode'], upsample_mode='pixelshuffle', ada_ksize=opt_net['ada_ksize'], gate_conv_bias=opt_net['gate_conv_bias']) elif which_model == 'arcnn': netG = arch.ARCNN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], norm_type=opt_net['norm_type'], mode=opt_net['mode'], ada_ksize=opt_net['ada_ksize']) elif which_model == 'srcnn': netG = arch.SRCNN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], norm_type=opt_net['norm_type'], mode=opt_net['mode'], ada_ksize=opt_net['ada_ksize']) elif which_model == 'noise_plainnet': netG = arch.NoisePlainNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], norm_type=opt_net['norm_type'], mode=opt_net['mode']) elif which_model == 'denoise_resnet': netG = arch.DenoiseResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], mode=opt_net['mode'], upsample_mode='pixelshuffle', ada_ksize=opt_net['ada_ksize'], down_scale=opt_net['down_scale'], fea_norm=opt_net['fea_norm'], upsample_norm=opt_net['upsample_norm']) elif which_model == 'modulate_denoise_resnet': netG = arch.ModulateDenoiseResNet( in_nc=opt_net['in_nc'], 
out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], mode=opt_net['mode'], upsample_mode='pixelshuffle', ada_ksize=opt_net['ada_ksize'], gate_conv_bias=opt_net['gate_conv_bias']) elif which_model == 'noise_subnet': netG = arch.NoiseSubNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], norm_type=opt_net['norm_type'], mode=opt_net['mode']) elif which_model == 'cond_denoise_resnet': netG = arch.CondDenoiseResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'], upsample_mode='pixelshuffle', ada_ksize=opt_net['ada_ksize'], down_scale=opt_net['down_scale'], num_classes=opt_net['num_classes'], norm_type=opt_net['norm_type']) elif which_model == 'adabn_denoise_resnet': netG = arch.AdaptiveDenoiseResNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'], down_scale=opt_net['down_scale']) elif which_model == 'sft_arch': # SFT-GAN netG = sft_arch.SFT_Net() elif which_model == 'RRDB_net': # RRDB netG = arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv') else: raise NotImplementedError( 'Generator model [{:s}] not recognized'.format(which_model)) if opt['init_type'] is not None: init_weights(netG, init_type=opt['init_type'], scale=0.1) if gpu_ids: assert torch.cuda.is_available() netG = nn.DataParallel(netG) return netG
def define_G(opt):
    """Build the generator network described by ``opt['network_G']``.

    Args:
        opt: Option dict. Reads ``opt['gpu_ids']``, ``opt['is_train']``,
            ``opt['distributed']`` and the sub-dict ``opt['network_G']``, whose
            ``'which_model_G'`` selects the architecture (RRDB, U-Net and style
            variants listed in the branches below).

    Returns:
        The generator module. When ``opt['is_train']`` is truthy it is passed
        through ``init_weights(..., init_type='kaiming', scale=0.1)``. It is
        then wrapped in ``nn.parallel.DistributedDataParallel`` when
        ``opt['distributed']`` is truthy, otherwise in ``nn.DataParallel`` when
        ``gpu_ids`` is non-empty.

    Raises:
        NotImplementedError: if ``which_model_G`` is not a known architecture.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model == 'sr_resnet':  # SRResNet
        netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                             nb=opt_net['nb'], upscale=opt_net['scale'],
                             norm_type=opt_net['norm_type'], act_type='relu',
                             mode=opt_net['mode'], upsample_mode='pixelshuffle')
    elif which_model == 'RRDB_net':  # RRDB
        netG = arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'],
                            norm_type=opt_net['norm_type'], act_type='leakyrelu',
                            mode=opt_net['mode'], upsample_mode='upconv')
    elif which_model == 'RRDB_adain':  # RRDB_AdaIN
        netG = arch.RRDBNet_AdaIn(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                  nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'],
                                  upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
                                  act_type='leakyrelu', mode=opt_net['mode'],
                                  upsample_mode='upconv')
    elif which_model == 'RRDB_fsr':
        netG = arch.RRDBNet_FSR(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'],
                                upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
                                act_type='leakyrelu', mode=opt_net['mode'],
                                upsample_mode='upconv', with_prior=opt_net['with_prior'])
    elif which_model == 'RRDB_style':
        netG = arch.StyleRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                 nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'],
                                 upscale=opt_net['scale'], act_type='leakyrelu',
                                 mode=opt_net['mode'], upsample_mode='upconv',
                                 with_style=opt_net['with_style'],
                                 with_noise=opt_net['with_noise'])
    elif which_model == 'RRDB_unet':
        netG = arch.RRDBUNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                             nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'],
                             act_type='leakyrelu', mode=opt_net['mode'],
                             upsample_mode='upconv', with_unet=opt_net['with_unet'])
    elif which_model == 'RRDB_unet2':
        netG = arch.RRDBUNet2(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                              nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'],
                              upscale=opt_net['scale'], act_type='leakyrelu',
                              mode=opt_net['mode'], upsample_mode='upconv')
    elif which_model == 'RRDB_unet3':
        netG = arch.RRDBUNet3(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                              nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'],
                              upscale=opt_net['scale'], act_type='leakyrelu',
                              mode=opt_net['mode'], upsample_mode='upconv')
    elif which_model == 'RRDB_unet4':
        netG = arch.RRDBUNet4(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                              nf=opt_net['nf'], act_type='leakyrelu')
    elif which_model == 'RRDB_unet5':
        netG = arch.RRDBUNet5(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                              nf=opt_net['nf'], na=opt_net['na'], nb=opt_net['nb'],
                              norm_type=opt_net['norm_type'], act_type='leakyrelu')
    elif which_model == 'RRDB_unet6':
        netG = arch.RRDBUNet6(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                              nf=opt_net['nf'], na=opt_net['na'], nb=opt_net['nb'],
                              norm_type=opt_net['norm_type'], act_type='leakyrelu')
    elif which_model == 'RRDB_unet7':
        netG = arch.RRDBUNet7(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                              nf=opt_net['nf'], na=opt_net['na'], nb=opt_net['nb'],
                              norm_type=opt_net['norm_type'], act_type='leakyrelu')
    elif which_model == 'ResNet_style':
        netG = arch.StyleResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'],
                                act_type='leakyrelu', mode=opt_net['mode'],
                                upsample_mode='upconv', with_style=opt_net['with_style'],
                                with_noise=opt_net['with_noise'])
    elif which_model == 'Unet':
        netG = arch.Unet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'])
    elif which_model == 'Unet_RDB':
        netG = arch.UnetRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                               nf=opt_net['nf'], act_type='leakyrelu')
    elif which_model == 'Unet_RRDB':
        netG = arch.UnetRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                nf=opt_net['nf'], act_type='leakyrelu')
    elif which_model == 'Unet_RRDB2':
        netG = arch.UnetRRDBNet2(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                 nf=opt_net['nf'], act_type='leakyrelu')
    elif which_model == 'Unet_style':
        netG = arch.UnetStyle(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                              nf=opt_net['nf'], with_style=opt_net['with_style'],
                              with_noise=opt_net['with_noise'])
    else:
        # '{}' rather than '{:s}': the 's' spec raises TypeError for non-str values.
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    if opt['is_train']:
        init_weights(netG, init_type='kaiming', scale=0.1)
    if opt['distributed']:
        assert torch.cuda.is_available()
        # NOTE(review): the module is not moved to a device and no device_ids
        # are passed before DDP wrapping — confirm callers handle placement.
        netG = nn.parallel.DistributedDataParallel(netG)
    elif gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG)
    return netG
def define_G(opt):
    """Build the generator network described by ``opt['network_G']``.

    Args:
        opt: Option dict. Reads ``opt['gpu_ids']``, ``opt['is_train']`` and the
            sub-dict ``opt['network_G']``, whose ``'which_model_G'`` selects
            ``'sr_resnet'``, ``'sft_arch'``, ``'RRDB_net'``, ``'DDDB_net'`` or
            ``'secordresnet'``. For ``'DDDB_net'`` the optional keys
            ``'self_attention'`` (``"include"`` enables it; default off) and
            ``'self_attention_normalise'`` (``'no'`` disables it; default on)
            are also read.

    Returns:
        The generator module. When ``opt['is_train']`` is truthy it is passed
        through ``init_weights(..., init_type='kaiming', scale=0.1)``; when
        ``gpu_ids`` is non-empty it is wrapped in ``nn.DataParallel``.

    Raises:
        NotImplementedError: if ``which_model_G`` is not a known architecture.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model == 'sr_resnet':  # SRResNet
        netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                             nb=opt_net['nb'], upscale=opt_net['scale'],
                             norm_type=opt_net['norm_type'], act_type='relu',
                             mode=opt_net['mode'], upsample_mode='pixelshuffle')
    elif which_model == 'sft_arch':  # SFT-GAN
        netG = sft_arch.SFT_Net()
    elif which_model == 'RRDB_net':  # RRDB
        netG = arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'],
                            norm_type=opt_net['norm_type'], act_type='leakyrelu',
                            mode=opt_net['mode'], upsample_mode='upconv')
    elif which_model == 'DDDB_net':  # DDDB
        # Absent keys fall back to: self-attention off, normalisation on.
        self_attention = opt_net.get('self_attention') == "include"
        self_attention_normalise = opt_net.get('self_attention_normalise') != 'no'
        # Lazy %-args: the message is only rendered if INFO is enabled.
        logger.info("SelfAttention %s normalise %s",
                    self_attention, self_attention_normalise)
        netG = arch.DDDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'],
                            norm_type=opt_net['norm_type'], act_type='leakyrelu',
                            mode=opt_net['mode'], upsample_mode='upconv',
                            self_attention=self_attention,
                            self_attention_normalise=self_attention_normalise)
    elif which_model == 'secordresnet':
        netG = arch.SecOrdResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                 nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'],
                                 norm_type=opt_net['norm_type'], act_type='leakyrelu',
                                 mode=opt_net['mode'], upsample_mode='upconv',
                                 merge_mode=opt_net['merge_mode'])
    else:
        # '{}' rather than '{:s}': the 's' spec raises TypeError for non-str values.
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    if opt['is_train']:
        init_weights(netG, init_type='kaiming', scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG)
    return netG