Example #1
0
def define_G(opt):
    """Build the generator network selected by ``opt['network_G']['which_model_G']``.

    Args:
        opt: Options dict. Reads ``opt['gpu_ids']`` and the ``opt['network_G']``
            sub-dict (model name and constructor hyper-parameters).

    Returns:
        The generator module, kaiming-initialised (scale 0.1) when training and
        wrapped in ``nn.DataParallel`` when ``gpu_ids`` is non-empty.

    Raises:
        NotImplementedError: if ``which_model_G`` names an unknown generator.
    """
    gpu_ids = opt['gpu_ids']
    # Keep the top-level options reachable instead of shadowing ``opt`` with
    # the network sub-dict.
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model == 'sr_resnet':
        netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                             nb=opt_net['nb'], upscale=opt_net['scale'],
                             norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                             upsample_mode='pixelshuffle')

    elif which_model == 'sft_arch':
        netG = sft_arch.SFT_Net()

    elif which_model == 'RRDB_Net':
        netG = arch.RRDB_Net(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                             nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'],
                             norm_type=opt_net['norm_type'], act_type='leakyrelu',
                             mode=opt_net['mode'], res_scale=1, upsample_mode='upconv')
    else:
        # Without this branch an unknown name fell through and only failed
        # later with a confusing UnboundLocalError on ``netG``.
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(which_model))

    # NOTE: an orthogonal init for non-SRResNet models was once tried here and
    # left "to investigate"; kaiming init is what is applied.
    # The flag used to be read from the network sub-dict (because ``opt`` was
    # shadowed). Keep that precedence for existing configs, but fall back to
    # the top-level ``is_train`` used by other versions of this factory.
    is_train = opt_net['is_train'] if 'is_train' in opt_net else opt['is_train']
    if is_train:
        init_weights(netG, init_type='kaiming', scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG)
    return netG
Example #2
0
def define_G(opt):
    """Construct the generator named by ``opt['network_G']['which_model_G']``.

    Supported names are ``'sr_resnet'`` (SRResNet), ``'sft_arch'`` (SFT-GAN)
    and ``'RRDB_net'`` (RRDBNet). Weights are kaiming-initialised (scale 0.1)
    when ``opt['is_train']`` is truthy, and the network is wrapped in
    ``nn.DataParallel`` when ``opt['gpu_ids']`` is non-empty.

    Raises:
        NotImplementedError: for an unrecognised generator name.
    """
    gpu_ids = opt['gpu_ids']
    net_cfg = opt['network_G']
    model_name = net_cfg['which_model_G']

    if model_name == 'sr_resnet':  # SRResNet
        generator = arch.SRResNet(
            in_nc=net_cfg['in_nc'],
            out_nc=net_cfg['out_nc'],
            nf=net_cfg['nf'],
            nb=net_cfg['nb'],
            upscale=net_cfg['scale'],
            norm_type=net_cfg['norm_type'],
            act_type='relu',
            mode=net_cfg['mode'],
            upsample_mode='pixelshuffle',
        )
    elif model_name == 'sft_arch':  # SFT-GAN
        generator = sft_arch.SFT_Net()
    elif model_name == 'RRDB_net':  # RRDB
        generator = arch.RRDBNet(
            in_nc=net_cfg['in_nc'],
            out_nc=net_cfg['out_nc'],
            nf=net_cfg['nf'],
            nb=net_cfg['nb'],
            gc=net_cfg['gc'],
            upscale=net_cfg['scale'],
            norm_type=net_cfg['norm_type'],
            act_type='leakyrelu',
            mode=net_cfg['mode'],
            upsample_mode='upconv',
        )
    else:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(model_name))

    if opt['is_train']:
        init_weights(generator, init_type='kaiming', scale=0.1)

    # Single-device case: hand back the bare module.
    if not gpu_ids:
        return generator
    assert torch.cuda.is_available()
    return nn.DataParallel(generator)
Example #3
0
def define_G(opt):
    """Return the generator described by the ``network_G`` options.

    ``opt["network_G"]["which_model_G"]`` selects one of ``"sr_resnet"``,
    ``"sft_arch"`` or ``"RRDB_net"``; the remaining ``network_G`` entries are
    passed to the chosen constructor. During training (``opt["is_train"]``)
    the weights are kaiming-initialised with scale 0.1, and a non-empty
    ``opt["gpu_ids"]`` wraps the result in ``nn.DataParallel``.

    Raises:
        NotImplementedError: if the generator name is unknown.
    """
    gpu_ids = opt["gpu_ids"]
    cfg = opt["network_G"]
    name = cfg["which_model_G"]

    if name == "sr_resnet":  # SRResNet
        model = arch.SRResNet(
            in_nc=cfg["in_nc"], out_nc=cfg["out_nc"], nf=cfg["nf"], nb=cfg["nb"],
            upscale=cfg["scale"], norm_type=cfg["norm_type"], act_type="relu",
            mode=cfg["mode"], upsample_mode="pixelshuffle",
        )
    elif name == "sft_arch":  # SFT-GAN
        model = sft_arch.SFT_Net()
    elif name == "RRDB_net":  # RRDB
        model = arch.RRDBNet(
            in_nc=cfg["in_nc"], out_nc=cfg["out_nc"], nf=cfg["nf"], nb=cfg["nb"],
            gc=cfg["gc"], upscale=cfg["scale"], norm_type=cfg["norm_type"],
            act_type="leakyrelu", mode=cfg["mode"], upsample_mode="upconv",
        )
    else:
        raise NotImplementedError(
            "Generator model [{:s}] not recognized".format(name))

    if opt["is_train"]:
        init_weights(model, init_type="kaiming", scale=0.1)

    if gpu_ids:
        # Multi-GPU training requires CUDA to be present.
        assert torch.cuda.is_available()
        model = nn.DataParallel(model)
    return model
def define_G(opt):
    """Build the generator named by ``opt['network_G']['which_model_G']``.

    In addition to the stock ``'sr_resnet'``, ``'sft_arch'`` and
    ``'RRDB_net'`` generators, this version accepts ``'RRDBNet_G'``,
    ``'RCAN_G'``, ``'DualSR_RCAN'``, ``'DualSR_SR'`` and ``'DualSR_RRDB'``.

    Args:
        opt: options dict; reads ``'gpu_ids'``, ``'is_train'`` and the
            ``'network_G'`` sub-dict of constructor hyper-parameters.

    Returns:
        The generator. It is kaiming-initialised (scale 0.1) when
        ``opt['is_train']`` is truthy and wrapped in ``nn.DataParallel`` when
        ``gpu_ids`` is non-empty.

    Raises:
        NotImplementedError: if the generator name is not recognised.
    """
    gpu_ids = opt['gpu_ids']
    cfg = opt['network_G']
    kind = cfg['which_model_G']

    # Fixed constructor arguments shared by the RCAN-based generators.
    # NOTE(review): scale is hard-coded to 4 for these models and
    # cfg['scale'] is ignored; confirm this is intended.
    rcan_fixed = dict(reduction=16, scale=4, n_colors=3, rgb_range=255,
                      res_scale=1)
    # Config entries forwarded unchanged (same names) to both DualSR variants.
    dual_keys = ('n_resblocks', 'n_resgroups_mask', 'n_resgroups_share',
                 'n_resgroups_high_1', 'n_resgroups_high_2',
                 'n_resgroups_low_1', 'n_resgroups_low_2', 'n_feats')

    if kind == 'sr_resnet':  # SRResNet
        net = arch.SRResNet(
            in_nc=cfg['in_nc'], out_nc=cfg['out_nc'], nf=cfg['nf'],
            nb=cfg['nb'], upscale=cfg['scale'], norm_type=cfg['norm_type'],
            act_type='relu', mode=cfg['mode'], upsample_mode='pixelshuffle')
    elif kind == 'sft_arch':  # SFT-GAN
        net = sft_arch.SFT_Net()
    elif kind in ('RRDB_net', 'RRDBNet_G'):
        # 'RRDBNet_G' (the "ex_G") is the same RRDBNet, except that its input
        # channel count is taken from 'nf' instead of 'in_nc'.
        net = arch.RRDBNet(
            in_nc=cfg['nf'] if kind == 'RRDBNet_G' else cfg['in_nc'],
            out_nc=cfg['out_nc'],
            nf=cfg['nf'],
            nb=cfg['nb'],
            gc=cfg['gc'],
            upscale=cfg['scale'],
            norm_type=cfg['norm_type'],
            act_type='leakyrelu',
            mode=cfg['mode'],
            upsample_mode='upconv')
    elif kind == 'RCAN_G':
        net = rcan_g.RCAN_G(n_resblocks=cfg['n_resblocks'],
                            n_resgroups=cfg['n_resgroups'],
                            n_feats=cfg['n_feats'],
                            **rcan_fixed)
    elif kind == 'DualSR_RCAN':
        net = DualSR.DualSR(**{key: cfg[key] for key in dual_keys},
                            **rcan_fixed)
    elif kind == 'DualSR_SR':
        net = DualSR_SR.DualSR_SR(**{key: cfg[key] for key in dual_keys},
                                  **rcan_fixed)
    elif kind == 'DualSR_RRDB':
        net = DualSR_RRDB.DualSR_RRDB(
            in_nc=cfg['in_nc'],
            out_nc=cfg['out_nc'],
            nf=cfg['nf'],
            nb_l_1=cfg['nb_l_1'],
            nb_l_2=cfg['nb_l_2'],
            nb_h_1=cfg['nb_h_1'],
            nb_e=cfg['nb_e'],
            nb_h_2=cfg['nb_h_2'],
            nb_m=cfg['nb_m'],
            gc=cfg['gc'],
            upscale=cfg['scale'],
            norm_type=cfg['norm_type'],
            act_type='leakyrelu',
            mode=cfg['mode'],
            upsample_mode='upconv')
    else:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(kind))

    if opt['is_train']:
        init_weights(net, init_type='kaiming', scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        net = nn.DataParallel(net)
    return net
Example #5
0
def define_G(opt):
    """Build the generator network named by ``opt['network_G']['which_model_G']``.

    Args:
        opt: options dict. Reads ``'gpu_ids'``, the optional ``'init_type'``
            and the ``'network_G'`` sub-dict holding the model name and its
            constructor hyper-parameters.

    Returns:
        The generator. Its weights are initialised with ``opt['init_type']``
        (scale 0.1) unless that entry is missing or ``None``. It is wrapped in
        ``nn.DataParallel`` when ``gpu_ids`` is non-empty.

    Raises:
        NotImplementedError: if ``which_model_G`` names an unknown generator.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    # Generators that all take exactly the same SRResNet-style constructor
    # arguments: config name -> class name inside ``arch``.
    srresnet_like = {
        'sr_resnet': 'SRResNet',
        'Octave_SRResNet': 'Octave_SRResNet',
        'M_NP_Octave_RRDBNet': 'M_NP_Octave_RRDBNet',
        'Octave_RRDBNet': 'Octave_RRDBNet',
        'DWT_Octave_RRDBNet': 'DWT_Octave_RRDBNet',
        'modified_resnet': 'Modified_SRResNet',
        'Octave_CARN': 'Octave_CARN',
        'carn': 'CARN',
    }

    if which_model in srresnet_like:
        netG = getattr(arch, srresnet_like[which_model])(
            in_nc=opt_net['in_nc'],
            out_nc=opt_net['out_nc'],
            nf=opt_net['nf'],
            nb=opt_net['nb'],
            upscale=opt_net['scale'],
            norm_type=opt_net['norm_type'],
            act_type='relu',
            mode=opt_net['mode'],
            upsample_mode='pixelshuffle')

    elif which_model == 'modulate_sr_resnet':
        netG = arch.ModulateSRResNet(in_nc=opt_net['in_nc'],
                                     out_nc=opt_net['out_nc'],
                                     nf=opt_net['nf'],
                                     nb=opt_net['nb'],
                                     upscale=opt_net['scale'],
                                     norm_type=opt_net['norm_type'],
                                     mode=opt_net['mode'],
                                     upsample_mode='pixelshuffle',
                                     ada_ksize=opt_net['ada_ksize'],
                                     gate_conv_bias=opt_net['gate_conv_bias'])

    elif which_model == 'arcnn':
        netG = arch.ARCNN(in_nc=opt_net['in_nc'],
                          out_nc=opt_net['out_nc'],
                          nf=opt_net['nf'],
                          norm_type=opt_net['norm_type'],
                          mode=opt_net['mode'],
                          ada_ksize=opt_net['ada_ksize'])

    elif which_model == 'srcnn':
        netG = arch.SRCNN(in_nc=opt_net['in_nc'],
                          out_nc=opt_net['out_nc'],
                          nf=opt_net['nf'],
                          norm_type=opt_net['norm_type'],
                          mode=opt_net['mode'],
                          ada_ksize=opt_net['ada_ksize'])

    elif which_model == 'noise_plainnet':
        netG = arch.NoisePlainNet(in_nc=opt_net['in_nc'],
                                  out_nc=opt_net['out_nc'],
                                  nf=opt_net['nf'],
                                  norm_type=opt_net['norm_type'],
                                  mode=opt_net['mode'])

    elif which_model == 'denoise_resnet':
        netG = arch.DenoiseResNet(in_nc=opt_net['in_nc'],
                                  out_nc=opt_net['out_nc'],
                                  nf=opt_net['nf'],
                                  nb=opt_net['nb'],
                                  upscale=opt_net['scale'],
                                  norm_type=opt_net['norm_type'],
                                  mode=opt_net['mode'],
                                  upsample_mode='pixelshuffle',
                                  ada_ksize=opt_net['ada_ksize'],
                                  down_scale=opt_net['down_scale'],
                                  fea_norm=opt_net['fea_norm'],
                                  upsample_norm=opt_net['upsample_norm'])

    elif which_model == 'modulate_denoise_resnet':
        netG = arch.ModulateDenoiseResNet(
            in_nc=opt_net['in_nc'],
            out_nc=opt_net['out_nc'],
            nf=opt_net['nf'],
            nb=opt_net['nb'],
            upscale=opt_net['scale'],
            norm_type=opt_net['norm_type'],
            mode=opt_net['mode'],
            upsample_mode='pixelshuffle',
            ada_ksize=opt_net['ada_ksize'],
            gate_conv_bias=opt_net['gate_conv_bias'])

    elif which_model == 'noise_subnet':
        netG = arch.NoiseSubNet(in_nc=opt_net['in_nc'],
                                out_nc=opt_net['out_nc'],
                                nf=opt_net['nf'],
                                nb=opt_net['nb'],
                                norm_type=opt_net['norm_type'],
                                mode=opt_net['mode'])

    elif which_model == 'cond_denoise_resnet':
        netG = arch.CondDenoiseResNet(in_nc=opt_net['in_nc'],
                                      out_nc=opt_net['out_nc'],
                                      nf=opt_net['nf'],
                                      nb=opt_net['nb'],
                                      upscale=opt_net['scale'],
                                      upsample_mode='pixelshuffle',
                                      ada_ksize=opt_net['ada_ksize'],
                                      down_scale=opt_net['down_scale'],
                                      num_classes=opt_net['num_classes'],
                                      norm_type=opt_net['norm_type'])

    elif which_model == 'adabn_denoise_resnet':
        netG = arch.AdaptiveDenoiseResNet(in_nc=opt_net['in_nc'],
                                          nf=opt_net['nf'],
                                          nb=opt_net['nb'],
                                          upscale=opt_net['scale'],
                                          down_scale=opt_net['down_scale'])

    elif which_model == 'sft_arch':  # SFT-GAN
        netG = sft_arch.SFT_Net()

    elif which_model == 'RRDB_net':  # RRDB
        netG = arch.RRDBNet(in_nc=opt_net['in_nc'],
                            out_nc=opt_net['out_nc'],
                            nf=opt_net['nf'],
                            nb=opt_net['nb'],
                            gc=opt_net['gc'],
                            upscale=opt_net['scale'],
                            norm_type=opt_net['norm_type'],
                            act_type='leakyrelu',
                            mode=opt_net['mode'],
                            upsample_mode='upconv')
    else:
        raise NotImplementedError(
            'Generator model [{:s}] not recognized'.format(which_model))

    # A missing 'init_type' entry is treated like an explicit None (keep the
    # constructor's own initialisation) instead of raising KeyError.
    init_type = opt.get('init_type')
    if init_type is not None:
        init_weights(netG, init_type=init_type, scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG)
    return netG
Example #6
0
# SFTGAN test script: load a pretrained SFT-Net and iterate over a folder of images.
# model_path = '../experiments/pretrained_models/sft_net_torch.pth' # torch version
model_path = '../experiments/pretrained_models/SFTGAN_bicx4_noBN_OST_bg.pth'  # pytorch training

test_img_folder_name = 'samples'  # image folder name
test_img_folder = '../data/' + test_img_folder_name  # HR images
test_prob_path = '../data/' + test_img_folder_name + '_segprob'  # probability maps
save_result_path = '../data/' + test_img_folder_name + '_result'  # results

# make dirs
util.mkdirs([save_result_path])

# Choose the architecture from the checkpoint's file name only. Testing the
# whole path would let a 'torch' substring in a directory (e.g. 'pytorch/')
# select the wrong model class.
if 'torch' in os.path.basename(model_path):  # torch version
    model = sft.SFT_Net_torch()
else:
    model = sft.SFT_Net()
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.cuda()

print('sftgan testing...')

idx = 0
for path in glob.glob(test_img_folder + '/*'):
    idx += 1
    basename = os.path.basename(path)
    base = os.path.splitext(basename)[0]
    print(idx, base)
    # read image
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    # NOTE(review): the visible lines stop here. test_prob_path and the model
    # are not yet used, so the rest of the per-image processing appears to be
    # missing from this excerpt.
    img = modcrop(img, 8)
def define_G(opt):
    """Instantiate the generator selected by ``opt['network_G']['which_model_G']``.

    Recognised names are ``'sr_resnet'``, ``'sft_arch'``, ``'RRDB_net'``,
    ``'DDDB_net'`` and ``'secordresnet'``.

    For ``'DDDB_net'``, two optional ``network_G`` entries become boolean
    flags, which are then logged:

    * ``self_attention`` set to ``"include"`` enables self-attention.
    * ``self_attention_normalise`` set to ``'no'`` disables normalisation.
      This entry is only consulted when ``self_attention`` is present.

    Training runs (``opt['is_train']``) kaiming-initialise the weights with
    scale 0.1. A non-empty ``opt['gpu_ids']`` wraps the network in
    ``nn.DataParallel``.

    Raises:
        NotImplementedError: for an unrecognised generator name.
    """
    gpu_ids = opt['gpu_ids']
    net_opt = opt['network_G']
    model_name = net_opt['which_model_G']

    if model_name == 'sr_resnet':  # SRResNet
        generator = arch.SRResNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'],
            nf=net_opt['nf'], nb=net_opt['nb'],
            upscale=net_opt['scale'], norm_type=net_opt['norm_type'],
            act_type='relu', mode=net_opt['mode'],
            upsample_mode='pixelshuffle',
        )
    elif model_name == 'sft_arch':  # SFT-GAN
        generator = sft_arch.SFT_Net()
    elif model_name == 'RRDB_net':  # RRDB
        generator = arch.RRDBNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'],
            nf=net_opt['nf'], nb=net_opt['nb'], gc=net_opt['gc'],
            upscale=net_opt['scale'], norm_type=net_opt['norm_type'],
            act_type='leakyrelu', mode=net_opt['mode'],
            upsample_mode='upconv',
        )
    elif model_name == 'DDDB_net':  # DDDB
        has_attention_cfg = 'self_attention' in net_opt
        self_attention = bool(
            has_attention_cfg and net_opt['self_attention'] == "include")
        # Normalisation stays on unless explicitly disabled alongside a
        # self_attention entry.
        self_attention_normalise = not (
            has_attention_cfg
            and 'self_attention_normalise' in net_opt
            and net_opt['self_attention_normalise'] == 'no')

        logger.info(f"SelfAttention {self_attention} normalise {self_attention_normalise}")
        generator = arch.DDDBNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'],
            nf=net_opt['nf'], nb=net_opt['nb'], gc=net_opt['gc'],
            upscale=net_opt['scale'], norm_type=net_opt['norm_type'],
            act_type='leakyrelu', mode=net_opt['mode'],
            upsample_mode='upconv',
            self_attention=self_attention,
            self_attention_normalise=self_attention_normalise,
        )
    elif model_name == 'secordresnet':
        generator = arch.SecOrdResNet(
            in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'],
            nf=net_opt['nf'], nb=net_opt['nb'],
            upscale=net_opt['scale'], norm_type=net_opt['norm_type'],
            act_type='leakyrelu', mode=net_opt['mode'],
            upsample_mode='upconv', merge_mode=net_opt['merge_mode'],
        )
    else:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(model_name))

    if opt['is_train']:
        init_weights(generator, init_type='kaiming', scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        generator = nn.DataParallel(generator)
    return generator