コード例 #1
0
    def forward(self, batch_size=None):
        """Sample a latent batch and run the generator on it.

        Parameters:
            batch_size -- number of samples to generate; when None,
                          ``self.opt.batch_size`` is used.

        Side effects:
            Sets ``self.gen_imgs`` to the generator output. When
            ``self.opt.cgan`` is set, also sets ``self.y_`` to the one-hot
            labels that were fed to the generator.

        Raises:
            ValueError -- if ``self.opt.z_type`` is neither 'Gaussian' nor
                          'Uniform' (otherwise ``z`` would be left unbound).
        """
        bs = self.opt.batch_size if batch_size is None else batch_size
        if self.opt.z_type == 'Gaussian':
            z = torch.randn(bs, self.opt.z_dim, 1, 1, device=self.device)
        elif self.opt.z_type == 'Uniform':
            # map U[0, 1) to U[-1, 1)
            z = torch.rand(bs, self.opt.z_dim, 1, 1,
                           device=self.device) * 2. - 1.
        else:
            raise ValueError('unsupported z_type %r' % (self.opt.z_type,))

        if not self.opt.cgan:
            self.gen_imgs = self.netG(z)
        else:
            # draw class labels from the categorical prior and one-hot encode them
            y = self.CatDis.sample([bs])
            self.y_ = one_hot(y, [bs, self.opt.cat_num])
            self.gen_imgs = self.netG(z, self.y_)
コード例 #2
0
 def forward(self) -> dict:
     """Run the generator in the mode selected by ``self.opt.gan_mode``.

     Modes:
         'unconditional'   -- feed ``self.inputs`` directly to netG.
         'unconditional-z' -- draw z via ``get_prior`` and feed it to netG.
         'conditional'     -- draw z via ``get_prior`` plus one-hot class
                              labels from ``self.CatDis``; feed both to netG.

     In every mode the generator output is passed to ``self.set_output``.

     Returns:
         dict with 'data' (generator output); conditional mode also adds
         'condition' (the one-hot labels).

     Raises:
         ValueError -- for any other ``gan_mode``.
     """
     n = self.opt.batch_size
     mode = self.opt.gan_mode

     if mode == 'unconditional':
         fake = self.netG(self.inputs)
         self.set_output(fake)
         return {'data': fake}

     if mode not in ('conditional', 'unconditional-z'):
         raise ValueError(f'unsupported gan_mode {self.opt.gan_mode}')

     # both remaining modes start from a latent batch drawn from the prior
     z = get_prior(n, self.opt.z_dim, self.opt.z_type, self.device)

     if mode == 'unconditional-z':
         fake = self.netG(z)
         self.set_output(fake)
         return {'data': fake}

     labels = one_hot(self.CatDis.sample([n]), [n, self.opt.cat_num])
     fake = self.netG(z, labels)
     self.set_output(fake)
     return {'data': fake, 'condition': labels}
コード例 #3
0
    def __init__(self, opt):
        """Initialize this model class.

        Parameters:
            opt -- training/test options

        What is set up here:
        - the BaseModel state (via BaseModel.__init__)
        - loss / visual / model names
        - the uniform categorical label prior ``self.CatDis`` (cGAN only)
        - netG, plus (training only) netD, the G mutation losses, the D loss
          and the Adam optimizers
        - ``opt.candi_num`` evolutionary candidates, each a deep copy of
          netG's state_dict (and, training only, of optimizer_G's state_dict)
        - a fixed latent batch ``self.z_fixed`` (and fixed labels
          ``self.y_fixed`` for cGANs) for visualization
        - ``self.eval_size``, the number of images for each evaluation

        Raises:
            NotImplementedError -- if the wgan D loss is requested without use_gp.
            ValueError -- if ``opt.z_type`` is neither 'Gaussian' nor 'Uniform'.
        """
        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel

        self.opt = opt
        if opt.d_loss_mode == 'wgan' and not opt.use_gp:
            raise NotImplementedError('using wgan on D must be with use_gp = True.')

        self.loss_names = ['G_real', 'G_fake', 'D_real', 'D_fake', 'D_gp', 'G', 'D']
        self.visual_names = ['real_visual', 'gen_visual']

        if self.isTrain:  # only defined during training time
            self.model_names = ['G', 'D']
        else:
            self.model_names = ['G']

        if self.opt.cgan:
            # uniform categorical prior over the cat_num classes
            probs = np.ones(self.opt.cat_num) / self.opt.cat_num
            self.CatDis = Categorical(torch.tensor(probs))

        # define networks
        self.netG = networks.define_G(opt.z_dim, opt.output_nc, opt.ngf, opt.netG,
                opt.g_norm, opt.cgan, opt.cat_num, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:  # D, losses and optimizers are only defined during training time
            # define a discriminator; opt.cgan / opt.cat_num are forwarded for the conditional case
            self.netD = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                          opt.d_norm, opt.cgan, opt.cat_num, opt.init_type, opt.init_gain, self.gpu_ids)
            # define G mutations: one G-side GANLoss per entry of opt.g_loss_mode
            self.G_mutations = [
                networks.GANLoss(g_loss, 'G', opt.which_D).to(self.device)
                for g_loss in opt.g_loss_mode
            ]
            # define loss functions
            self.criterionD = networks.GANLoss(opt.d_loss_mode, 'D', opt.which_D).to(self.device)
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

        # Evolutionary candidates setting (init): every candidate starts as a copy of netG.
        # optimizer_G only exists during training, so its state is only snapshotted then;
        # outside training optG_candis stays empty instead of raising AttributeError.
        self.G_candis = []
        self.optG_candis = []
        for _ in range(opt.candi_num):
            self.G_candis.append(copy.deepcopy(self.netG.state_dict()))
            if self.isTrain:
                self.optG_candis.append(copy.deepcopy(self.optimizer_G.state_dict()))

        # visualization settings: N*N fixed samples with N*N <= min(batch_size, 64)
        self.N = int(np.trunc(np.sqrt(min(opt.batch_size, 64))))
        n_vis = self.N * self.N
        if self.opt.z_type == 'Gaussian':
            self.z_fixed = torch.randn(n_vis, opt.z_dim, 1, 1, device=self.device)
        elif self.opt.z_type == 'Uniform':
            # map U[0, 1) to U[-1, 1)
            self.z_fixed = torch.rand(n_vis, opt.z_dim, 1, 1, device=self.device) * 2. - 1.
        else:
            raise ValueError('unsupported z_type %r' % (self.opt.z_type,))
        if self.opt.cgan:
            yf = self.CatDis.sample([n_vis])
            self.y_fixed = one_hot(yf, [n_vis, self.opt.cat_num])

        # the number of images for each evaluation
        self.eval_size = max(math.ceil((opt.batch_size * opt.D_iters) / opt.candi_num), opt.eval_size)
コード例 #4
0
    def get_current_scores(self):
        """Generate evaluation samples with netG and compute the requested scores.

        When ``opt.model == 'egan'``, the first candidate with the highest
        fitness (column 2 of ``self.Fitness``) is loaded into ``self.netG``
        before sampling. Then ``opt.evaluation_size`` samples are generated in
        batches of at most ``opt.fid_batch_size``; the last batch may be
        smaller. They are scored according to ``opt.score_name``:
        - through ``self.get_inception_metrics`` when
          ``opt.use_pytorch_scores`` is set;
        - otherwise through the ``fid`` module / ``get_inception_score``
          helpers.
        The scores are also stored on ``self.FID``, ``self.IS_mean`` and
        ``self.IS_var``.

        Returns:
            OrderedDict mapping the computed score names ('FID', 'IS_mean',
            'IS_var') to floats.

        Raises:
            ValueError -- if ``opt.z_type`` is neither 'Gaussian' nor 'Uniform'.
        """
        # validate up front so a bad option neither leaves z unbound nor
        # swaps netG's weights before failing
        if self.opt.z_type not in ('Gaussian', 'Uniform'):
            raise ValueError('unsupported z_type %r' % (self.opt.z_type,))

        if self.opt.model == 'egan':
            # load current best G
            F = self.Fitness[:, 2]
            idx = np.where(F == max(F))[0][0]
            self.netG.load_state_dict(self.G_candis[idx])

        scores_ret = OrderedDict()

        n_total = self.opt.evaluation_size
        fid_bs = self.opt.fid_batch_size
        samples = torch.zeros((n_total, 3, self.opt.crop_size, self.opt.crop_size),
                              device=self.device)
        # Ceil division so a final partial batch is generated too. With floor
        # division the tail of `samples` stayed all-zero and skewed FID/IS
        # whenever evaluation_size was not a multiple of fid_batch_size.
        n_fid_batches = -(-n_total // fid_bs)

        with torch.no_grad():  # evaluation only; no autograd graph is needed
            for i in range(n_fid_batches):
                frm = i * fid_bs
                to = min(frm + fid_bs, n_total)
                bs = to - frm

                if self.opt.z_type == 'Gaussian':
                    z = torch.randn(bs, self.opt.z_dim, 1, 1, device=self.device)
                else:  # 'Uniform': map U[0, 1) to U[-1, 1)
                    z = torch.rand(bs, self.opt.z_dim, 1, 1, device=self.device) * 2. - 1.

                if not self.opt.cgan:
                    gen_s = self.netG(z)
                else:
                    y = self.CatDis.sample([bs])
                    # one_hot needs the full [batch, cat_num] shape, as at the
                    # other one_hot call sites
                    y = one_hot(y, [bs, self.opt.cat_num])
                    gen_s = self.netG(z, y)
                samples[frm:to] = gen_s.detach()
                print("\rgenerate fid sample batch %d/%d " %
                      (i + 1, n_fid_batches))

        print("%d samples generating done" % n_total)

        if self.opt.use_pytorch_scores:
            self.IS_mean, self.IS_var, self.FID = self.get_inception_metrics(
                samples, n_total, num_splits=10)
            if 'FID' in self.opt.score_name:
                print(self.FID)
                scores_ret['FID'] = float(self.FID)
            if 'IS' in self.opt.score_name:
                print(self.IS_mean, self.IS_var)
                scores_ret['IS_mean'] = float(self.IS_mean)
                scores_ret['IS_var'] = float(self.IS_var)

        else:
            # Cast to uint8 and transpose BCHW -> BHWC.
            samples = samples.cpu().numpy()
            # clip before the cast so out-of-range generator outputs saturate
            # instead of wrapping around in uint8
            samples = np.clip((samples + 1.0) * 127.5, 0, 255).astype('uint8')
            samples = samples.transpose(0, 2, 3, 1)
            for name in self.opt.score_name:
                if name == 'FID':
                    mu_gen, sigma_gen = fid.calculate_activation_statistics(
                        samples,
                        self.sess,
                        batch_size=self.opt.fid_batch_size,
                        verbose=True)
                    print("calculate FID:")
                    try:
                        self.FID = fid.calculate_frechet_distance(
                            mu_gen, sigma_gen, self.mu_real, self.sigma_real)
                    except Exception as e:
                        # best effort: report the error and fall back to a sentinel score
                        print(e)
                        self.FID = 500
                    print(self.FID)
                    scores_ret[name] = float(self.FID)
                if name == 'IS':
                    Imlist = list(samples)  # one HWC uint8 array per sample
                    print(samples.shape)
                    self.IS_mean, self.IS_var = get_inception_score(Imlist)

                    scores_ret['IS_mean'] = float(self.IS_mean)
                    scores_ret['IS_var'] = float(self.IS_var)
                    print(self.IS_mean, self.IS_var)

        return scores_ret