Esempio n. 1
0
    def initialize_networks(self, opt):
        """Build the generator, discriminator and optional encoder.

        The discriminator is only created when ``opt.isTrain`` is set and the
        encoder only when ``opt.use_vae`` is set; otherwise ``None`` fills its
        slot. When testing, or when resuming training (``opt.continue_train``),
        each existing network is passed through ``util.load_network`` for
        ``opt.which_epoch`` (presumably restoring checkpoint weights).

        Returns:
            tuple: ``(netG, netD, netE)``.
        """
        generator = networks.define_G(opt)
        discriminator = None
        if opt.isTrain:
            discriminator = networks.define_D(opt)
        encoder = None
        if opt.use_vae:
            encoder = networks.define_E(opt)

        # A fresh training run has nothing to restore.
        if opt.isTrain and not opt.continue_train:
            return generator, discriminator, encoder

        epoch = opt.which_epoch
        generator = util.load_network(generator, 'G', epoch, opt)
        if opt.isTrain:
            discriminator = util.load_network(discriminator, 'D', epoch, opt)
        if opt.use_vae:
            encoder = util.load_network(encoder, 'E', epoch, opt)
        return generator, discriminator, encoder
Esempio n. 2
0
    def initialize_networks(self, opt):
        """Build the CT/MR generator pair and their discriminators.

        Two generators (``G_for_CT``, ``G_for_MR``) are always built. The
        ``D_aligned`` / ``D_unaligned`` discriminators are only built when
        ``opt.isTrain`` is set (``None`` otherwise). When testing or resuming
        (``opt.continue_train``), the generators, and when training also the
        discriminators, are passed through ``util.load_network`` under their
        own labels for ``opt.which_epoch``.

        Returns:
            tuple: ``(netG_for_CT, netD_aligned, netG_for_MR, netD_unaligned)``.
        """
        # Dict literals evaluate left to right, so construction order is
        # G_for_CT, D_aligned, G_for_MR, D_unaligned.
        nets = {
            'G_for_CT': networks.define_G(opt),
            'D_aligned': networks.define_D(opt) if opt.isTrain else None,
            'G_for_MR': networks.define_G(opt),
            'D_unaligned': networks.define_D(opt) if opt.isTrain else None,
        }

        if not opt.isTrain or opt.continue_train:
            # Each key doubles as the checkpoint label.
            labels = ['G_for_CT', 'G_for_MR']
            if opt.isTrain:
                labels += ['D_aligned', 'D_unaligned']
            for label in labels:
                nets[label] = util.load_network(nets[label], label,
                                                opt.which_epoch, opt)

        return (nets['G_for_CT'], nets['D_aligned'],
                nets['G_for_MR'], nets['D_unaligned'])
Esempio n. 3
0
    def initialize_networks(self, opt):
        """Build G, D and E and restore G (plus D, when one was built).

        The discriminator is only built when training with one enabled
        (``opt.isTrain and not opt.no_disc``); otherwise it is ``None``. The
        encoder is always built. When testing, resuming
        (``opt.continue_train``) or running with ``opt.manipulation``, G is
        passed through ``util.load_network`` for ``opt.which_epoch``, and so
        is D if it exists.

        Returns:
            tuple: ``(netG, netD, netE)``; ``netD`` may be ``None``.
        """
        netG = networks.define_G(opt)
        netD = (networks.define_D(opt)
                if opt.isTrain and not opt.no_disc else None)
        # NOTE(review): netE is always built but never restored here --
        # confirm that is intended.
        netE = networks.define_E(opt)

        if not opt.isTrain or opt.continue_train or opt.manipulation:
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            # netD is None when training with opt.no_disc and whenever not
            # training; handing None to util.load_network would presumably
            # fail, so only restore a discriminator that was actually built.
            # NOTE(review): a test-time ``opt.needs_D`` request cannot be
            # served because D is never built outside training -- build it
            # above first if a test-time D is really needed.
            if netD is not None:
                netD = util.load_network(netD, 'D', opt.which_epoch, opt)
                print('network D loaded')

        return netG, netD, netE
Esempio n. 4
0
    def initialize_networks(self, opt):
        """Build the optional background generator, the generator and D.

        ``netGbg`` only exists when ``self.opt.two_step_model`` is set, and
        ``netD`` only when ``opt.isTrain`` is set; ``None`` otherwise. When
        testing or resuming (``opt.continue_train``), the existing networks
        are passed through ``util.load_network`` for ``opt.which_epoch``.

        Returns:
            tuple: ``(netGbg, netG, netD)``.
        """
        netGbg = networks.define_Gbg(opt) if self.opt.two_step_model else None
        netG = networks.define_G(opt)
        netD = None
        if opt.isTrain:
            netD = networks.define_D(opt)

        # Fresh training run: nothing to restore.
        if opt.isTrain and not opt.continue_train:
            return netGbg, netG, netD

        epoch = opt.which_epoch
        if self.opt.two_step_model:
            netGbg = util.load_network(netGbg, 'Gbg', epoch, opt)
        netG = util.load_network(netG, 'G', epoch, opt)
        if opt.isTrain:
            netD = util.load_network(netD, 'D', epoch, opt)
        return netGbg, netG, netD
Esempio n. 5
0
    def initialize_networks(self, opt):
        """Build the generator, discriminator and VAE encoder.

        Per the original notes, the generator comes from one of the classes in
        generator.py and the discriminator from discriminator.py. ``netD`` is
        only built when ``opt.isTrain`` is set and ``netE`` only when
        ``opt.use_vae`` is set. When testing or resuming
        (``opt.continue_train``), the existing networks are passed through
        ``util.load_network`` for ``opt.which_epoch``.

        Returns:
            tuple: ``(netG, netD, netE)``.
        """
        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if opt.isTrain else None
        netE = networks.define_E(opt) if opt.use_vae else None

        if opt.isTrain and not opt.continue_train:
            return netG, netD, netE

        def restore(net, label):
            # Same epoch and options for every network in this model.
            return util.load_network(net, label, opt.which_epoch, opt)

        netG = restore(netG, 'G')
        if opt.isTrain:
            netD = restore(netD, 'D')
        if opt.use_vae:
            netE = restore(netE, 'E')
        return netG, netD, netE
Esempio n. 6
0
    def initialize_networks(self, opt):
        """Build every network this model may use and restore what is available.

        Each network is only constructed when its option flag is set
        (``None`` otherwise):

        * netG   -- always
        * netD   -- ``opt.isTrain``
        * netD2  -- ``opt.isTrain and opt.unpairTrain``
        * netE   -- ``opt.use_vae``
        * netIG  -- ``opt.use_ig``
        * netSIG -- ``opt.use_stroke``
        * netFE  -- ``opt.use_instance_feat``
        * netB   -- ``opt.use_blender``

        When testing or resuming (``not opt.isTrain or opt.continue_train``)
        AND the file ``<which_epoch>_net_G.pth`` exists under
        ``<checkpoints_dir>/<name>``, G, B, D, D2 and E are restored (each
        subject to its flag); if that file is missing, none of them are.
        netIG / netSIG are restored through their own loaders and put in eval
        mode whenever enabled, independent of that check. netFE is never
        restored here.

        Returns:
            tuple: ``(netG, netD, netE, netIG, netFE, netB, netD2, netSIG)``.
        """
        generator = networks.define_G(opt)
        disc = networks.define_D(opt) if opt.isTrain else None
        disc_unpaired = (networks.define_D(opt)
                         if opt.isTrain and opt.unpairTrain else None)
        # Encoder of the original SPADE network.
        encoder = networks.define_E(opt) if opt.use_vae else None
        # Orientation inpainting network.
        orient_inpainter = networks.define_IG(opt) if opt.use_ig else None
        # Stroke orientation inpainting network.
        stroke_inpainter = networks.define_SIG(opt) if opt.use_stroke else None
        # Instance feature encoder (from pix2pixHD).
        feat_encoder = (networks.define_FE(opt)
                        if opt.use_instance_feat else None)
        blender = networks.define_B(opt) if opt.use_blender else None

        if not opt.isTrain or opt.continue_train:
            epoch = opt.which_epoch
            g_ckpt = os.path.join(opt.checkpoints_dir, opt.name,
                                  '%s_net_%s.pth' % (epoch, 'G'))
            # The whole group below is gated on the G checkpoint existing.
            if os.path.exists(g_ckpt):
                generator = util.load_network(generator, 'G', epoch, opt)
                if opt.fix_netG:
                    generator.eval()
                if opt.use_blender:
                    blender = util.load_blend_network(blender, 'B', epoch, opt)
                if opt.isTrain:
                    disc = util.load_network(disc, 'D', epoch, opt)
                    if opt.unpairTrain:
                        # NOTE(review): the unpaired discriminator is restored
                        # from the 'D' checkpoint, not a separate 'D2' one --
                        # confirm this is intended.
                        disc_unpaired = util.load_network(disc_unpaired, 'D',
                                                          epoch, opt)
                if opt.use_vae:
                    encoder = util.load_network(encoder, 'E', epoch, opt)

        # Inpainting networks: always restored and frozen when enabled.
        if opt.use_ig:
            orient_inpainter = util.load_inpainting_network(orient_inpainter,
                                                            opt)
            orient_inpainter.eval()
        if opt.use_stroke:
            stroke_inpainter = util.load_sinpainting_network(stroke_inpainter,
                                                             opt)
            stroke_inpainter.eval()

        return (generator, disc, encoder, orient_inpainter, feat_encoder,
                blender, disc_unpaired, stroke_inpainter)
Esempio n. 7
0
    def initialize_networks(self, opt):
        """Build G, two discriminators with different input widths, and E.

        The generator is built first with ``opt`` untouched. ``D1`` is then
        built with ``opt.input_nc = 2`` and ``D2`` with ``opt.input_nc = 7``
        (both only when ``opt.isTrain``). ``opt.input_nc`` is not restored, so
        it stays 7 after this call. ``netE`` only exists when ``opt.use_vae``
        is set. When testing or resuming (``opt.continue_train``), the
        existing networks are passed through ``util.load_network`` for
        ``opt.which_epoch``.

        Returns:
            tuple: ``(netG, netD1, netD2, netE)``.
        """
        netG = networks.define_G(opt)

        # Only the channel count set on opt differs between the two
        # discriminators; the order (2 then 7) matters for the value left
        # in opt.input_nc.
        discs = []
        for channels in (2, 7):
            opt.input_nc = channels
            discs.append(networks.define_D(opt) if opt.isTrain else None)
        netD1, netD2 = discs
        netE = networks.define_E(opt) if opt.use_vae else None

        if opt.isTrain and not opt.continue_train:
            return netG, netD1, netD2, netE

        which = opt.which_epoch
        netG = util.load_network(netG, 'G', which, opt)
        if opt.isTrain:
            netD1 = util.load_network(netD1, 'D1', which, opt)
            netD2 = util.load_network(netD2, 'D2', which, opt)
        if opt.use_vae:
            netE = util.load_network(netE, 'E', which, opt)
        return netG, netD1, netD2, netE
Esempio n. 8
0
    def initialize_fcn(self, opt):
        """Create the VGG16-FCN8s network stored under checkpoint label 'S'.

        The network is built with ``num_cls=opt.label_nc`` and no pretrained
        weights. It is passed through ``util.load_network`` when testing,
        resuming (``opt.continue_train``) or joint training
        (``opt.joint_train``). It is always returned in eval mode.
        """
        from models.fcn8 import VGG16_FCN8s

        fcn = VGG16_FCN8s(num_cls=opt.label_nc, pretrained=False)
        needs_restore = (not opt.isTrain) or opt.continue_train or opt.joint_train
        if needs_restore:
            fcn = util.load_network(fcn, 'S', opt.which_epoch, opt)
        # Eval mode is set unconditionally, even for a freshly built network.
        fcn.eval()
        return fcn
Esempio n. 9
0
    def initialize_networks(self, opt):
        """Build G, D and E, and report each checkpoint load on stdout.

        ``netD`` only exists when ``opt.isTrain`` is set and ``netE`` only when
        ``opt.use_vae`` is set. When testing or resuming
        (``opt.continue_train``), each existing network is passed through
        ``util.load_network`` for ``opt.which_epoch``, followed by a
        confirmation message.

        Returns:
            tuple: ``(netG, netD, netE)``.
        """
        netG = networks.define_G(opt)
        netD = None
        if opt.isTrain:
            netD = networks.define_D(opt)
        netE = None
        if opt.use_vae:
            netE = networks.define_E(opt)

        # Fresh training run: nothing to restore, nothing to report.
        if opt.isTrain and not opt.continue_train:
            return netG, netD, netE

        epoch = opt.which_epoch
        netG = util.load_network(netG, 'G', epoch, opt)
        print('net_G successfully loaded from {}.'.format(epoch))
        if opt.isTrain:
            netD = util.load_network(netD, 'D', epoch, opt)
            print('net_D successfully loaded from {}.'.format(epoch))
        if opt.use_vae:
            netE = util.load_network(netE, 'E', epoch, opt)
            print('net_E successfully loaded from {}.'.format(epoch))
        return netG, netD, netE
Esempio n. 10
0
    def initialize_networks(self, opt):
        """Build the model's networks into a dict keyed by attribute name.

        Keys: ``netG`` and ``netCorr`` (always), ``netD`` (only when
        ``opt.isTrain``), ``netDomainClassifier`` (only when
        ``opt.weight_domainC > 0 and opt.domain_rela``); absent networks are
        ``None``. When testing or resuming (``opt.continue_train``), the
        existing networks are passed through ``util.load_network`` for
        ``opt.which_epoch``. At test time with ``opt.use_ema``, G and Corr are
        additionally passed through the loader with their EMA labels, after
        the regular ones.

        Returns:
            dict: name -> network (or ``None``).
        """
        net = {
            'netG': networks.define_G(opt),
            'netD': networks.define_D(opt) if opt.isTrain else None,
            'netCorr': networks.define_Corr(opt),
            'netDomainClassifier': (
                networks.define_DomainClassifier(opt)
                if opt.weight_domainC > 0 and opt.domain_rela else None),
        }

        if opt.isTrain and not opt.continue_train:
            return net

        def restore(key, label):
            net[key] = util.load_network(net[key], label, opt.which_epoch, opt)

        restore('netG', 'G')
        if opt.isTrain:
            restore('netD', 'D')
        restore('netCorr', 'Corr')
        if opt.weight_domainC > 0 and opt.domain_rela:
            restore('netDomainClassifier', 'DomainClassifier')
        if (not opt.isTrain) and opt.use_ema:
            # Applied after the regular checkpoints; kept as a second pass
            # (not a replacement) in case the loader falls back to the
            # existing weights when an EMA file is unusable -- TODO confirm.
            restore('netG', 'G_ema')
            restore('netCorr', 'netCorr_ema')
        return net
Esempio n. 11
0
    def initialize_networks(self, opt):
        """Build and optionally restore two G/D/E sets.

        The secondary set (``netG2``/``netD2``/``netE2``) is built first with
        ``opt`` unchanged and is restored through ``util.load_network2``. The
        primary set is built with ``opt.label_nc`` and ``opt.semantic_nc``
        temporarily reduced by one when ``opt.edge_cat`` is set, and is
        restored through ``util.load_network``. The two counters are restored
        afterwards even if building or loading the primary set raises, so
        ``opt`` is never left modified.

        In each set, D only exists when ``opt.isTrain`` and E only when
        ``opt.use_vae``. When testing or resuming (``opt.continue_train``),
        every existing network is restored for ``opt.which_epoch``; otherwise,
        with ``opt.use_vae and opt.pretrain_vae``, only E is restored.

        Returns:
            tuple: ``(netG, netD, netE, netG2, netD2, netE2)``.
        """
        def _build_set(loader, announce_fixed_e=False):
            # Define one G/D/E set and restore it with ``loader``.
            netG = networks.define_G(opt)
            netD = networks.define_D(opt) if opt.isTrain else None
            netE = networks.define_E(opt) if opt.use_vae else None

            if not opt.isTrain or opt.continue_train:
                netG = loader(netG, 'G', opt.which_epoch, opt)
                if opt.isTrain:
                    netD = loader(netD, 'D', opt.which_epoch, opt)
                if opt.use_vae:
                    netE = loader(netE, 'E', opt.which_epoch, opt)
            elif opt.use_vae and opt.pretrain_vae:
                netE = loader(netE, 'E', opt.which_epoch, opt)
                if announce_fixed_e:
                    print('Load fixed netE.')
            return netG, netD, netE

        netG2, netD2, netE2 = _build_set(util.load_network2)

        if opt.edge_cat:
            opt.label_nc -= 1
            opt.semantic_nc -= 1
        try:
            netG, netD, netE = _build_set(util.load_network,
                                          announce_fixed_e=True)
        finally:
            # Undo the temporary channel change even on failure so callers
            # never see a corrupted opt.
            if opt.edge_cat:
                opt.label_nc += 1
                opt.semantic_nc += 1

        return netG, netD, netE, netG2, netD2, netE2
Esempio n. 12
0
    def initialize_networks(self, opt):
        """Build G, D, E and the auxiliary net; record their names; restore.

        ``netD`` only exists when ``opt.isTrain``, ``netE`` when
        ``opt.use_vae``, and ``netAux`` when ``opt.isTrain and opt.use_aux``;
        absent networks are ``None``. Sets ``self.model_names`` to the labels
        of the networks built here, in the order G, D, E, Aux. When testing
        or resuming (``opt.continue_train``), each existing network is passed
        through ``util.load_network`` for ``opt.which_epoch``.

        Returns:
            tuple: ``(netG, netD, netE, netAux)``.
        """
        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if opt.isTrain else None
        netE = networks.define_E(opt) if opt.use_vae else None
        use_aux = opt.isTrain and opt.use_aux
        netAux = networks.define_Aux(opt) if use_aux else None

        # Plain statements instead of conditional expressions evaluated only
        # for their side effects.
        self.model_names = ['G']
        if opt.isTrain:
            self.model_names.append('D')
        if opt.use_vae:
            self.model_names.append('E')
        if use_aux:
            self.model_names.append('Aux')

        if not opt.isTrain or opt.continue_train:
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            if opt.isTrain:
                netD = util.load_network(netD, 'D', opt.which_epoch, opt)
            if use_aux:
                netAux = util.load_network(netAux, 'Aux', opt.which_epoch,
                                           opt)
            if opt.use_vae:
                netE = util.load_network(netE, 'E', opt.which_epoch, opt)

        return netG, netD, netE, netAux
Esempio n. 13
0
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights, or load them from a checkpoint.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- initialization method: 'N02' (alias 'normal') |
                             'glorot' / 'xavier' | 'kaiming' |
                             'ortho' (alias 'orthogonal'). Any other value is
                             treated as a checkpoint label and passed to
                             ``load_network(net, init_type, 'latest')``.
        init_gain (float) -- std for 'N02', gain for 'glorot'/'xavier'/'ortho';
                             ignored by 'kaiming'.

    Only the ``weight`` of ``nn.Conv2d``, ``nn.Linear`` and ``nn.Embedding``
    modules is re-initialized; biases and every other layer type (e.g.
    normalization layers) keep their current values.

    Returns:
        The initialized ``net`` (same object), or whatever ``load_network``
        returns for the checkpoint branch.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and
    kaiming might work better for some applications. Feel free to try yourself.
    """
    # The documented names (and this function's own default, 'normal') map to
    # the short names; without this the default fell through to the
    # checkpoint-loading branch.
    method = {'normal': 'N02', 'orthogonal': 'ortho'}.get(init_type, init_type)
    if method not in ('N02', 'glorot', 'xavier', 'kaiming', 'ortho'):
        print('loading the model from %s' % init_type)
        return load_network(net, init_type, 'latest')

    def init_func(m):
        # Re-initialize the weight matrix of supported layers only.
        if not isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
            return
        if method == 'N02':
            init.normal_(m.weight.data, 0.0, init_gain)
        elif method in ('glorot', 'xavier'):
            init.xavier_normal_(m.weight.data, gain=init_gain)
        elif method == 'kaiming':
            init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        else:  # 'ortho' -- the only remaining validated method
            init.orthogonal_(m.weight.data, gain=init_gain)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # visits every submodule, including net itself
    return net
Esempio n. 14
0
 def initialize_networks(self, opt, verbose=True):
     """Build the generator and pass it through its checkpoint loader.

     Unlike the training variants, the generator is always handed to
     ``util.load_network`` for ``opt.which_epoch``, regardless of
     ``opt.isTrain`` or ``opt.continue_train``.

     Args:
         opt: options forwarded to ``networks.define_G`` and
             ``util.load_network``.
         verbose: forwarded to ``networks.define_G``.

     Returns:
         The generator as returned by ``util.load_network``.
     """
     generator = networks.define_G(opt, verbose)
     generator = util.load_network(generator, 'G', opt.which_epoch, opt)
     return generator
Esempio n. 15
0
        styleImg_root = opt.image_dir
        styleList = sorted(os.listdir(styleImg_root))
        styles, labels, path = load_style(len(styleList))  #

    result_file = {'styles': styles, 'labels': labels, 'path': path}
    torch.save(result_file, 'ffhq_styles.npy')
    print('Extract styles done.')
else:
    print('==> Loading styles.')
    styles = torch.load('ffhq_styles.npy')
    styles, labels, path = styles['styles'], styles['labels'], styles['path']

# Compute FID for each checkpoint epoch listed below: restore the generator
# for that epoch, extract features from 1000 generated samples and compare
# their mean/covariance through FID.calc_fid.
fid = FID()
epochs = list(range(800000, 900000, 100000))

# Output directory for generated samples. Created once, outside the loop;
# makedirs also creates a missing 'results/fid_sample' parent and tolerates
# an existing directory (no exists-then-mkdir race).
sample_dir = 'results/fid_sample/%s' % opt.name
os.makedirs(sample_dir, exist_ok=True)

for epoch in epochs:
    model.netG = util.load_network(model.netG, 'G', str(epoch), opt)

    features = fid.extract_feature_from_samples(
        model,
        n_sample=1000,
        batch_size=6,
        saveName='%s/%s' % (opt.name, str(epoch))).numpy()
    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    fid_score = fid.calc_fid(sample_mean, sample_cov)

    print('%d fid: %s' % (epoch, fid_score))