def generate_images(self, gen, dis, truncated_factor, prior, latent_op,
                    latent_op_step, latent_op_alpha, latent_op_beta, batch_size):
    if isinstance(gen, DataParallel):
        z_dim = gen.module.z_dim
        num_classes = gen.module.num_classes
        conditional_strategy = dis.module.conditional_strategy
    else:
        z_dim = gen.z_dim
        num_classes = gen.num_classes
        conditional_strategy = dis.conditional_strategy

    zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor,
                                     num_classes, None, self.device)

    if latent_op:
        zs = latent_optimise(zs, fake_labels, gen, dis, conditional_strategy,
                             latent_op_step, 1.0, latent_op_alpha, latent_op_beta,
                             False, self.device)

    with torch.no_grad():
        batch_images = gen(zs, fake_labels, evaluation=True)

    return batch_images
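# Usage sketch (not part of the original file): assuming `worker` is an instance of
# the class that owns generate_images above and `gen`/`dis` are its (possibly
# DataParallel-wrapped) generator/discriminator, one evaluation batch could be drawn
# as below. All hyperparameter values here are illustrative, not repository defaults.
def _example_generate_one_batch(worker, gen, dis):
    fake_batch = worker.generate_images(gen, dis,
                                        truncated_factor=-1.0,  # assumed "no truncation"
                                        prior="gaussian",
                                        latent_op=False,
                                        latent_op_step=0,
                                        latent_op_alpha=1.0,
                                        latent_op_beta=1.0,
                                        batch_size=64)
    assert fake_batch.shape[0] == 64
    return fake_batch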
def generate_images_for_KNN(batch_size, real_label, gen_model, dis_model,
                            truncated_factor, prior, latent_op, latent_op_step,
                            latent_op_alpha, latent_op_beta, device):
    if isinstance(gen_model, DataParallel) or isinstance(gen_model, DistributedDataParallel):
        z_dim = gen_model.module.z_dim
        num_classes = gen_model.module.num_classes
        conditional_strategy = dis_model.module.conditional_strategy
    else:
        z_dim = gen_model.z_dim
        num_classes = gen_model.num_classes
        conditional_strategy = dis_model.conditional_strategy

    zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor,
                                     num_classes, None, device, real_label)

    if latent_op:
        zs = latent_optimise(zs, fake_labels, gen_model, dis_model, conditional_strategy,
                             latent_op_step, 1.0, latent_op_alpha, latent_op_beta,
                             False, device)

    with torch.no_grad():
        batch_images = gen_model(zs, fake_labels, evaluation=True)

    return batch_images, list(fake_labels.detach().cpu().numpy())
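# Usage sketch (illustrative, not from the original file): collect class-conditional
# fakes with generate_images_for_KNN for a nearest-neighbour inspection against real
# images of the same class. The hyperparameter values below are assumptions.
import torch

def _example_collect_class_fakes(gen_model, dis_model, device, target_class=3, n_batches=4):
    images, labels = [], []
    for _ in range(n_batches):
        batch_images, batch_labels = generate_images_for_KNN(
            64, target_class, gen_model, dis_model, truncated_factor=-1.0,
            prior="gaussian", latent_op=False, latent_op_step=0,
            latent_op_alpha=1.0, latent_op_beta=1.0, device=device)
        images.append(batch_images.cpu())
        labels += batch_labels
    return torch.cat(images, dim=0), labels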
def generate_images(batch_size, gen, dis, truncated_factor, prior, latent_op,
                    latent_op_step, latent_op_alpha, latent_op_beta, device):
    if isinstance(gen, DataParallel):
        z_dim = gen.module.z_dim
        num_classes = gen.module.num_classes
    else:
        z_dim = gen.z_dim
        num_classes = gen.num_classes

    z, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor,
                                    num_classes, None, device)

    if latent_op:
        z = latent_optimise(z, fake_labels, gen, dis, latent_op_step, 1.0,
                            latent_op_alpha, latent_op_beta, False, device)

    with torch.no_grad():
        batch_images = gen(z, fake_labels)

    return batch_images
def generate_images_for_KNN(batch_size, real_label, gen_model, dis_model,
                            truncated_factor, prior, latent_op, latent_op_step,
                            latent_op_alpha, latent_op_beta, device):
    if isinstance(gen_model, DataParallel):
        z_dim = gen_model.module.z_dim
        num_classes = gen_model.module.num_classes
    else:
        z_dim = gen_model.z_dim
        num_classes = gen_model.num_classes

    z, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor,
                                    num_classes, None, device, real_label)

    if latent_op:
        z = latent_optimise(z, fake_labels, gen_model, dis_model, latent_op_step,
                            1.0, latent_op_alpha, latent_op_beta, False, device)

    with torch.no_grad():
        batch_images = gen_model(z, fake_labels)

    return batch_images, list(fake_labels.detach().cpu().numpy())
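# The repository's sample_latents is not shown in this file. A hypothetical
# stand-in (NOT the real implementation) illustrating the truncation trick that
# `truncated_factor` implies: sample z from a normal clipped to
# [-truncated_factor, truncated_factor], treating non-positive factors as
# "no truncation". This is only a sketch of the idea, under those assumptions.
import torch
from scipy.stats import truncnorm

def _toy_truncated_latents(batch_size, z_dim, truncated_factor, num_classes, device):
    if truncated_factor > 0:
        vals = truncnorm.rvs(-truncated_factor, truncated_factor, size=(batch_size, z_dim))
        zs = torch.tensor(vals, dtype=torch.float32, device=device)
    else:
        zs = torch.randn(batch_size, z_dim, device=device)
    fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
    return zs, fake_labels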
def calculate_accuracy(dataloader, generator, discriminator, D_loss, num_evaluate,
                       truncated_factor, prior, latent_op, latent_op_step,
                       latent_op_alpha, latent_op_beta, device, cr, logger,
                       eval_generated_sample=False):
    data_iter = iter(dataloader)
    batch_size = dataloader.batch_size
    disable_tqdm = device != 0

    if isinstance(generator, DataParallel) or isinstance(generator, DistributedDataParallel):
        z_dim = generator.module.z_dim
        num_classes = generator.module.num_classes
        conditional_strategy = discriminator.module.conditional_strategy
    else:
        z_dim = generator.z_dim
        num_classes = generator.num_classes
        conditional_strategy = discriminator.conditional_strategy

    total_batch = num_evaluate // batch_size

    # All supported losses operate on raw logits, so 0.0 separates real from fake.
    if D_loss.__name__ in ["loss_dcgan_dis", "loss_lsgan_dis", "loss_hinge_dis"]:
        cutoff = 0.0
    elif D_loss.__name__ == "loss_wgan_dis":
        raise NotImplementedError

    if device == 0:
        logger.info("Calculating accuracies....")

    if eval_generated_sample:
        for batch_id in tqdm(range(total_batch), disable=disable_tqdm):
            zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor,
                                             num_classes, None, device)
            if latent_op:
                zs = latent_optimise(zs, fake_labels, generator, discriminator,
                                     conditional_strategy, latent_op_step, 1.0,
                                     latent_op_alpha, latent_op_beta, False, device)

            real_images, real_labels = next(data_iter)
            real_images, real_labels = real_images.to(device), real_labels.to(device)

            fake_images = generator(zs, fake_labels, evaluation=True)

            with torch.no_grad():
                if conditional_strategy in ["ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"]:
                    _, _, dis_out_fake = discriminator(fake_images, fake_labels)
                    _, _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ACGAN":
                    _, dis_out_fake = discriminator(fake_images, fake_labels)
                    _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ProjGAN" or conditional_strategy == "no":
                    dis_out_fake = discriminator(fake_images, fake_labels)
                    dis_out_real = discriminator(real_images, real_labels)
                else:
                    raise NotImplementedError

                dis_out_fake = dis_out_fake.detach().cpu().numpy()
                dis_out_real = dis_out_real.detach().cpu().numpy()

            if batch_id == 0:
                confid = np.concatenate((dis_out_fake, dis_out_real), axis=0)
                confid_label = np.concatenate(([0.0] * len(dis_out_fake), [1.0] * len(dis_out_real)), axis=0)
            else:
                confid = np.concatenate((confid, dis_out_fake, dis_out_real), axis=0)
                confid_label = np.concatenate((confid_label, [0.0] * len(dis_out_fake), [1.0] * len(dis_out_real)), axis=0)

        real_confid = confid[confid_label == 1.0]
        fake_confid = confid[confid_label == 0.0]

        true_positive = real_confid[np.where(real_confid > cutoff)]
        true_negative = fake_confid[np.where(fake_confid < cutoff)]

        only_real_acc = len(true_positive) / len(real_confid)
        only_fake_acc = len(true_negative) / len(fake_confid)

        return only_real_acc, only_fake_acc
    else:
        for batch_id in tqdm(range(total_batch), disable=disable_tqdm):
            real_images, real_labels = next(data_iter)
            real_images, real_labels = real_images.to(device), real_labels.to(device)

            with torch.no_grad():
                if conditional_strategy in ["ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"]:
                    _, _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ACGAN":
                    _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ProjGAN" or conditional_strategy == "no":
                    dis_out_real = discriminator(real_images, real_labels)
                else:
                    raise NotImplementedError

                dis_out_real = dis_out_real.detach().cpu().numpy()

            if batch_id == 0:
                confid = dis_out_real
                confid_label = np.asarray([1.0] * len(dis_out_real), np.float32)
            else:
                confid = np.concatenate((confid, dis_out_real), axis=0)
                confid_label = np.concatenate((confid_label, [1.0] * len(dis_out_real)), axis=0)

        real_confid = confid[confid_label == 1.0]
        true_positive = real_confid[np.where(real_confid > cutoff)]
        only_real_acc = len(true_positive) / len(real_confid)

        return only_real_acc
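# Illustrative-only check of the cutoff logic used by calculate_accuracy: with
# logit outputs, a real sample counts as correct when its logit exceeds the
# cutoff, and a fake when its logit falls below it. The numbers are made up.
import numpy as np

def _toy_cutoff_accuracy():
    real_logits = np.array([1.2, 0.3, -0.1, 2.0])
    fake_logits = np.array([-0.5, -1.3, 0.4, -0.2])
    cutoff = 0.0
    only_real_acc = (real_logits > cutoff).mean()  # 0.75
    only_fake_acc = (fake_logits < cutoff).mean()  # 0.75
    return only_real_acc, only_fake_acc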
def evaluation(self, step):
    with torch.no_grad():
        self.logger.info("Start Evaluation ({step} Step): {run_name}".format(step=step, run_name=self.run_name))
        is_best = False

        self.dis_model.eval()
        self.gen_model.eval()
        if self.Gen_copy is not None:
            self.Gen_copy.train()
            generator = self.Gen_copy
        else:
            generator = self.gen_model

        if self.dataset_name == "imagenet" or self.dataset_name == "tiny_imagenet":
            num_eval = {'train': 50000, 'valid': 50000}
        elif self.dataset_name == "cifar10":
            num_eval = {'train': 50000, 'test': 10000}

        if self.latent_op:
            self.fixed_noise = latent_optimise(self.fixed_noise, self.fixed_fake_labels, generator,
                                               self.dis_model, self.latent_op_step, self.latent_op_rate,
                                               self.latent_op_alpha, self.latent_op_beta, False,
                                               self.second_device)

        fid_score, self.m1, self.s1 = calculate_fid_score(self.eval_dataloader, generator, self.dis_model,
                                                          self.inception_model, num_eval[self.type4eval_dataset],
                                                          self.truncated_factor, self.prior, self.latent_op,
                                                          self.latent_op_step4eval, self.latent_op_alpha,
                                                          self.latent_op_beta, self.second_device, self.mu,
                                                          self.sigma, self.run_name)

        ### pre-calculate an inception score
        ### calculating the inception score with the code below will give you an underestimated value.
        ### please use the official TensorFlow implementation (inception_tensorflow.py) instead.
        kl_score, kl_std = calculate_incep_score(self.eval_dataloader, generator, self.dis_model,
                                                 self.inception_model, num_eval[self.type4eval_dataset],
                                                 self.truncated_factor, self.prior, self.latent_op,
                                                 self.latent_op_step4eval, self.latent_op_alpha,
                                                 self.latent_op_beta, 10, self.second_device)

        real_train_acc, fake_acc = calculate_accuracy(self.train_dataloader, generator, self.dis_model,
                                                      self.D_loss, 10000, self.truncated_factor, self.prior,
                                                      self.latent_op, self.latent_op_step, self.latent_op_alpha,
                                                      self.latent_op_beta, self.second_device,
                                                      eval_generated_sample=True)

        if self.type4eval_dataset == 'train':
            acc_dict = {'real_train': real_train_acc, 'fake': fake_acc}
        else:
            real_eval_acc = calculate_accuracy(self.eval_dataloader, generator, self.dis_model, self.D_loss,
                                               10000, self.truncated_factor, self.prior, self.latent_op,
                                               self.latent_op_step, self.latent_op_alpha, self.latent_op_beta,
                                               self.second_device, eval_generated_sample=False)
            acc_dict = {'real_train': real_train_acc, 'real_valid': real_eval_acc, 'fake': fake_acc}

        self.writer.add_scalars('Accuracy', acc_dict, step)

        if self.best_fid is None:
            self.best_fid, self.best_step, is_best = fid_score, step, True
        else:
            if fid_score <= self.best_fid:
                self.best_fid, self.best_step, is_best = fid_score, step, True

        self.writer.add_scalars('FID score',
                                {'using {type} moments'.format(type=self.type4eval_dataset): fid_score},
                                step)
        self.writer.add_scalars('IS score',
                                {'{num} generated images'.format(num=str(num_eval[self.type4eval_dataset])): kl_score},
                                step)
        self.logger.info('FID score (Step: {step}, Using {type} moments): {FID}'.format(
            step=step, type=self.type4eval_dataset, FID=fid_score))
        self.logger.info('Inception score (Step: {step}, {num} generated images): {IS}'.format(
            step=step, num=str(num_eval[self.type4eval_dataset]), IS=kl_score))
        self.logger.info('Best FID score (Step: {step}, Using {type} moments): {FID}'.format(
            step=self.best_step, type=self.type4eval_dataset, FID=self.best_fid))

        self.dis_model.train()
        self.gen_model.train()
        if self.Gen_copy is not None:
            self.Gen_copy.train()

    return is_best
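# calculate_fid_score above compares Gaussian moments of Inception features
# (self.mu/self.sigma against the freshly computed m1/s1). A minimal numpy
# sketch of the Fréchet distance those moments feed into, assuming mu1/mu2 are
# 1-D mean vectors and sigma1/sigma2 their covariance matrices; this mirrors
# the standard FID formula, not necessarily this repository's exact code.
import numpy as np
from scipy import linalg

def _toy_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Regularise slightly if the matrix square root is numerically unstable.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)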
def run(self, current_step, total_step):
    self.dis_model.train()
    self.gen_model.train()
    if self.Gen_copy is not None:
        self.Gen_copy.train()

    step_count = current_step
    train_iter = iter(self.train_dataloader)

    while step_count <= total_step:
        # ================== TRAIN D ================== #
        for step_index in range(self.d_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_index in range(self.accumulation_steps):
                try:
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)
                except StopIteration:
                    train_iter = iter(self.train_dataloader)
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)

                images, real_labels = images.to(self.second_device), real_labels.to(self.second_device)
                if self.diff_aug:
                    images = DiffAugment(images, policy=self.policy)

                z, fake_labels = sample_latents(self.prior, self.batch_size, self.z_dim, 1,
                                                self.num_classes, None, self.second_device)
                if self.latent_op:
                    z = latent_optimise(z, fake_labels, self.gen_model, self.dis_model,
                                        self.latent_op_step, self.latent_op_rate,
                                        self.latent_op_alpha, self.latent_op_beta,
                                        False, self.second_device)

                fake_images = self.gen_model(z, fake_labels)
                if self.diff_aug:
                    fake_images = DiffAugment(fake_images, policy=self.policy)

                if self.conditional_strategy == "ACGAN":
                    cls_out_real, dis_out_real = self.dis_model(images, real_labels)
                    cls_out_fake, dis_out_fake = self.dis_model(fake_images, fake_labels)
                elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                    dis_out_real = self.dis_model(images, real_labels)
                    dis_out_fake = self.dis_model(fake_images, fake_labels)
                else:
                    raise NotImplementedError

                dis_acml_loss = self.D_loss(dis_out_real, dis_out_fake)
                if self.conditional_strategy == "ACGAN":
                    dis_acml_loss += (self.ce_loss(cls_out_real, real_labels) +
                                      self.ce_loss(cls_out_fake, fake_labels))
                if self.gradient_penalty_for_dis:
                    dis_acml_loss += self.gradient_penalty_lambda * calc_derv4gp(self.dis_model, images,
                                                                                 fake_images, real_labels,
                                                                                 self.second_device)
                if self.consistency_reg:
                    images_aug = images_aug.to(self.second_device)
                    if self.conditional_strategy == "ACGAN":
                        cls_out_real_aug, dis_out_real_aug = self.dis_model(images_aug, real_labels)
                    elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                        dis_out_real_aug = self.dis_model(images_aug, real_labels)
                    else:
                        raise NotImplementedError
                    consistency_loss = self.l2_loss(dis_out_real, dis_out_real_aug)
                    dis_acml_loss += self.consistency_lambda * consistency_loss

                dis_acml_loss = dis_acml_loss / self.accumulation_steps
                dis_acml_loss.backward()

            self.D_optimizer.step()

            if self.weight_clipping_for_dis:
                for p in self.dis_model.parameters():
                    p.data.clamp_(-self.weight_clipping_bound, self.weight_clipping_bound)

        if step_count % self.print_every == 0 and step_count != 0 and self.logger:
            if self.d_spectral_norm:
                dis_sigmas = calculate_all_sn(self.dis_model)
                self.writer.add_scalars('SN_of_dis', dis_sigmas, step_count)

        # ================== TRAIN G ================== #
        for step_index in range(self.g_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_step in range(self.accumulation_steps):
                z, fake_labels = sample_latents(self.prior, self.batch_size, self.z_dim, 1,
                                                self.num_classes, None, self.second_device)
                if self.latent_op:
                    z, transport_cost = latent_optimise(z, fake_labels, self.gen_model, self.dis_model,
                                                        self.latent_op_step, self.latent_op_rate,
                                                        self.latent_op_alpha, self.latent_op_beta,
                                                        True, self.second_device)

                fake_images = self.gen_model(z, fake_labels)
                if self.diff_aug:
                    fake_images = DiffAugment(fake_images, policy=self.policy)

                if self.conditional_strategy == "ACGAN":
                    cls_out_fake, dis_out_fake = self.dis_model(fake_images, fake_labels)
                elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                    dis_out_fake = self.dis_model(fake_images, fake_labels)
                else:
                    raise NotImplementedError

                gen_acml_loss = self.G_loss(dis_out_fake)
                if self.conditional_strategy == "ACGAN":
                    gen_acml_loss += self.ce_loss(cls_out_fake, fake_labels)
                if self.latent_op:
                    gen_acml_loss += transport_cost * self.latent_norm_reg_weight

                gen_acml_loss = gen_acml_loss / self.accumulation_steps
                gen_acml_loss.backward()

            self.G_optimizer.step()

            # if ema is True, we update the parameters of Gen_copy in an adaptive way.
            if self.ema:
                self.Gen_ema.update(step_count)

            step_count += 1

        if step_count % self.print_every == 0 and self.logger:
            log_message = LOG_FORMAT.format(step=step_count,
                                            progress=step_count / total_step,
                                            elapsed=elapsed_time(self.start_time),
                                            temperature='No',
                                            dis_loss=dis_acml_loss.item(),
                                            gen_loss=gen_acml_loss.item())
            self.logger.info(log_message)

            if self.g_spectral_norm:
                gen_sigmas = calculate_all_sn(self.gen_model)
                self.writer.add_scalars('SN_of_gen', gen_sigmas, step_count)

            self.writer.add_scalars('Losses', {'discriminator': dis_acml_loss.item(),
                                               'generator': gen_acml_loss.item()}, step_count)

            with torch.no_grad():
                if self.Gen_copy is not None:
                    self.Gen_copy.train()
                    generator = self.Gen_copy
                else:
                    self.gen_model.eval()
                    generator = self.gen_model

                generated_images = generator(self.fixed_noise, self.fixed_fake_labels)
                self.writer.add_images('Generated samples', (generated_images + 1) / 2, step_count)
                self.gen_model.train()

        if step_count % self.save_every == 0 or step_count == total_step:
            if self.evaluate:
                is_best = self.evaluation(step_count)
                self.save(step_count, is_best)
            else:
                self.save(step_count, False)
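# Hypothetical sketch of the exponential-moving-average update that
# `self.Gen_ema.update(step_count)` performs in the loop above. The real class
# in this repository may differ (e.g., warm-up before averaging kicks in), so
# treat `_ToyEma`, `decay`, and `start_iter` as assumptions for illustration.
import torch

class _ToyEma:
    def __init__(self, source, target, decay=0.9999, start_iter=0):
        self.source, self.target = source, target
        self.decay, self.start_iter = decay, start_iter

    def update(self, iteration):
        # Before start_iter, copy the source weights verbatim (decay = 0).
        decay = 0.0 if iteration < self.start_iter else self.decay
        with torch.no_grad():
            s_params = dict(self.source.named_parameters())
            for name, t_param in self.target.named_parameters():
                t_param.copy_(decay * t_param + (1.0 - decay) * s_params[name])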
def generate_images(z_prior, truncation_factor, batch_size, z_dim, num_classes, y_sampler,
                    radius, generator, discriminator, is_train, LOSS, RUN, MODEL, device,
                    is_stylegan, generator_mapping, generator_synthesis, style_mixing_p,
                    stylegan_update_emas, cal_trsp_cost):
    if is_train:
        truncation_factor = -1.0
        lo_steps = LOSS.lo_steps4train
        apply_langevin = False
    else:
        lo_steps = LOSS.lo_steps4eval

    if truncation_factor != -1:
        if is_stylegan:
            assert 0 <= truncation_factor <= 1, "StyleGAN truncation_factor must lie between 0 (strong truncation) and 1 (no truncation)."
        else:
            assert 0 <= truncation_factor, "truncation_factor must lie between 0 (strong truncation) and inf (no truncation)."

    zs, fake_labels, zs_eps = sample_zy(z_prior=z_prior,
                                        batch_size=batch_size,
                                        z_dim=z_dim,
                                        num_classes=num_classes,
                                        truncation_factor=-1 if is_stylegan else truncation_factor,
                                        y_sampler=y_sampler,
                                        radius=radius,
                                        device=device)

    info_discrete_c, info_conti_c = None, None
    if MODEL.info_type in ["discrete", "both"]:
        info_discrete_c = torch.randint(MODEL.info_dim_discrete_c,
                                        (batch_size, MODEL.info_num_discrete_c),
                                        device=device)
        zs = torch.cat((zs, F.one_hot(info_discrete_c, MODEL.info_dim_discrete_c).view(batch_size, -1)), dim=1)
    if MODEL.info_type in ["continuous", "both"]:
        info_conti_c = torch.rand(batch_size, MODEL.info_num_conti_c, device=device) * 2 - 1
        zs = torch.cat((zs, info_conti_c), dim=1)

    trsp_cost = None
    if LOSS.apply_lo:
        zs, trsp_cost = losses.latent_optimise(zs=zs,
                                               fake_labels=fake_labels,
                                               generator=generator,
                                               discriminator=discriminator,
                                               batch_size=batch_size,
                                               lo_rate=LOSS.lo_rate,
                                               lo_steps=lo_steps,
                                               lo_alpha=LOSS.lo_alpha,
                                               lo_beta=LOSS.lo_beta,
                                               eval=not is_train,
                                               cal_trsp_cost=cal_trsp_cost,
                                               device=device)

    if not is_train and RUN.langevin_sampling:
        zs = langevin_sampling(zs=zs,
                               z_dim=z_dim,
                               fake_labels=fake_labels,
                               generator=generator,
                               discriminator=discriminator,
                               batch_size=batch_size,
                               langevin_rate=RUN.langevin_rate,
                               langevin_noise_std=RUN.langevin_noise_std,
                               langevin_decay=RUN.langevin_decay,
                               langevin_decay_steps=RUN.langevin_decay_steps,
                               langevin_steps=RUN.langevin_steps,
                               device=device)

    if is_stylegan:
        ws, fake_images = stylegan_generate_images(zs=zs,
                                                   fake_labels=fake_labels,
                                                   num_classes=num_classes,
                                                   style_mixing_p=style_mixing_p,
                                                   update_emas=stylegan_update_emas,
                                                   generator_mapping=generator_mapping,
                                                   generator_synthesis=generator_synthesis,
                                                   truncation_psi=truncation_factor,
                                                   truncation_cutoff=RUN.truncation_cutoff)
    else:
        fake_images = generator(zs, fake_labels, eval=not is_train)
        ws = None

    if zs_eps is not None:
        if is_stylegan:
            ws_eps, fake_images_eps = stylegan_generate_images(zs=zs_eps,
                                                               fake_labels=fake_labels,
                                                               num_classes=num_classes,
                                                               style_mixing_p=style_mixing_p,
                                                               update_emas=stylegan_update_emas,
                                                               generator_mapping=generator_mapping,
                                                               generator_synthesis=generator_synthesis,
                                                               truncation_psi=truncation_factor,
                                                               truncation_cutoff=RUN.truncation_cutoff)
        else:
            fake_images_eps = generator(zs_eps, fake_labels, eval=not is_train)
    else:
        fake_images_eps = None

    return fake_images, fake_labels, fake_images_eps, trsp_cost, ws, info_discrete_c, info_conti_c
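# Hypothetical sketch of the Langevin-style latent refinement that
# RUN.langevin_sampling enables above: noisy gradient ascent on an assumed
# differentiable log-density proxy log_density_fn(z). The repository's actual
# langevin_sampling also supports rate decay (langevin_decay/_decay_steps),
# which this sketch deliberately omits.
import torch

def _toy_langevin_refine(zs, log_density_fn, steps=10, rate=0.01, noise_std=0.1):
    zs = zs.detach().clone().requires_grad_(True)
    for _ in range(steps):
        score = log_density_fn(zs).sum()
        grad = torch.autograd.grad(score, zs)[0]
        with torch.no_grad():
            # Ascend the density proxy and inject Gaussian exploration noise.
            zs += rate * grad + noise_std * torch.randn_like(zs)
    return zs.detach()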
def valid_and_save(self, step):
    self.dis_model.eval()
    self.gen_model.eval()
    if self.Gen_copy is not None:
        self.Gen_copy.train()
        generator = self.Gen_copy
    else:
        generator = self.gen_model

    if self.evaluate:
        if self.latent_op:
            self.fixed_noise = latent_optimise(self.fixed_noise, self.fixed_fake_labels, generator,
                                               self.dis_model, self.latent_op_step, self.latent_op_rate,
                                               self.latent_op_alpha, self.latent_op_beta, False,
                                               self.second_device)

        fake_images = generator(self.fixed_noise, self.fixed_fake_labels).detach().cpu()
        plot_generated_samples_path = join('figures', self.run_name, "[{}]generated_samples.png".format(step))
        plot_img_canvas(fake_images, plot_generated_samples_path, self.logger)
        self.writer.add_images('Generated samples', (fake_images + 1) / 2, step)

        fid_score, self.m1, self.s1 = calculate_fid_score(self.evaluation_dataloader, generator, self.dis_model,
                                                          self.inception_model, len(self.evaluation_dataloader.dataset),
                                                          self.truncated_factor, self.prior, self.latent_op,
                                                          self.latent_op_step4eval, self.latent_op_alpha,
                                                          self.latent_op_beta, self.second_device, self.mu, self.sigma)

        ### pre-calculate an inception score
        ### calculating the inception score with the code below will give you an underestimated value.
        ### please use the official TensorFlow implementation (inception_tensorflow.py) instead.
        kl_score, kl_std = calculate_incep_score(self.evaluation_dataloader, generator, self.dis_model,
                                                 self.inception_model, 10000, self.truncated_factor,
                                                 self.prior, self.latent_op, self.latent_op_step4eval,
                                                 self.latent_op_alpha, self.latent_op_beta, 2,
                                                 self.second_device)

        self.writer.add_scalars('FID_score', {'using_train_moments': fid_score}, step)
        self.writer.add_scalars('IS_score', {'50000_generated_images': kl_score}, step)
    else:
        kl_score = "xxx"
        kl_std = "xxx"
        fid_score = "xxx"

    checkpoint_name = SAVE_FORMAT.format(step=step,
                                         Inception_mean=kl_score,
                                         Inception_std=kl_std,
                                         FID=fid_score)

    g_states = {'seed': self.config['seed'],
                'run_name': self.run_name,
                'step': step,
                'state_dict': self.gen_model.state_dict(),
                'optimizer': self.G_optimizer.state_dict()}

    d_states = {'seed': self.config['seed'],
                'run_name': self.run_name,
                'step': step,
                'state_dict': self.dis_model.state_dict(),
                'optimizer': self.D_optimizer.state_dict(),
                'best_val_fid': self.best_val_fid,
                'best_checkpoint_fid_path': self.best_checkpoint_fid_path,
                'best_val_is': self.best_val_is,
                'best_checkpoint_is_path': self.best_checkpoint_is_path}

    checkpoint_output_path = join(self.checkpoint_dir, checkpoint_name)
    g_checkpoint_output_path = join(self.checkpoint_dir, "model=G-" + checkpoint_name)
    d_checkpoint_output_path = join(self.checkpoint_dir, "model=D-" + checkpoint_name)

    torch.save(g_states, g_checkpoint_output_path)
    torch.save(d_states, d_checkpoint_output_path)

    if self.Gen_copy is not None:
        g_ema_states = {'state_dict': self.Gen_copy.state_dict()}
        g_ema_checkpoint_output_path = join(self.checkpoint_dir, "model=G_ema-" + checkpoint_name)
        torch.save(g_ema_states, g_ema_checkpoint_output_path)

    if self.evaluate:
        representative_val_fid = fid_score
        if self.best_val_fid is None or self.best_val_fid > representative_val_fid:
            self.best_val_fid = representative_val_fid
            self.best_checkpoint_fid_path = checkpoint_output_path

        representative_val_is = kl_score
        if self.best_val_is is None or self.best_val_is < representative_val_is:
            self.best_val_is = representative_val_is
            self.best_checkpoint_is_path = checkpoint_output_path

    if self.logger:
        self.logger.info("Saved model to {}".format(checkpoint_output_path))
        if self.evaluate:
            self.logger.info("Current best model (FID) is {}".format(self.best_checkpoint_fid_path))
            self.logger.info("Current best model (IS) is {}".format(self.best_checkpoint_is_path))

    self.dis_model.train()
    self.gen_model.train()
    if self.Gen_copy is not None:
        self.Gen_copy.train()
def run(self, current_step, total_step):
    self.dis_model.train()
    self.gen_model.train()
    if self.Gen_copy is not None:
        self.Gen_copy.train()

    step_count = current_step
    train_iter = iter(self.train_dataloader)

    while step_count <= total_step:
        # ================== TRAIN D ================== #
        for step_index in range(self.d_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_index in range(self.accumulation_steps):
                try:
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)
                except StopIteration:
                    train_iter = iter(self.train_dataloader)
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)

                images, real_labels = images.to(self.second_device), real_labels.to(self.second_device)

                z, fake_labels = sample_latents(self.prior, self.batch_size, self.z_dim, 1,
                                                self.num_classes, None, self.second_device)
                if self.latent_op:
                    z = latent_optimise(z, fake_labels, self.gen_model, self.dis_model,
                                        self.latent_op_step, self.latent_op_rate,
                                        self.latent_op_alpha, self.latent_op_beta,
                                        False, self.second_device)

                _, cls_out_real, dis_out_real = self.dis_model(images, real_labels)
                fake_images = self.gen_model(z, fake_labels)
                _, cls_out_fake, dis_out_fake = self.dis_model(fake_images, fake_labels)

                dis_acml_loss = self.D_loss(dis_out_real, dis_out_fake)
                if self.auxiliary_classifier:
                    dis_acml_loss += (self.ce_loss(cls_out_real, real_labels) +
                                      self.ce_loss(cls_out_fake, fake_labels))
                if self.gradient_penalty_for_dis:
                    dis_acml_loss += self.lambda4gp * calc_derv4gp(self.dis_model, images, fake_images,
                                                                   real_labels, self.second_device)
                if self.consistency_reg:
                    images_aug = images_aug.to(self.second_device)
                    _, _, dis_out_real_aug = self.dis_model(images_aug, real_labels)
                    consistency_loss = self.l2_loss(dis_out_real, dis_out_real_aug)
                    dis_acml_loss += self.consistency_lambda * consistency_loss

                dis_acml_loss = dis_acml_loss / self.accumulation_steps
                dis_acml_loss.backward()

            self.D_optimizer.step()

            if self.weight_clipping_for_dis:
                for p in self.dis_model.parameters():
                    p.data.clamp_(-self.weight_clipping_bound, self.weight_clipping_bound)

        if step_count % self.print_every == 0 and step_count != 0 and self.logger:
            if self.config['d_spectral_norm']:
                dis_sigmas = calculate_all_sn(self.dis_model)
                self.writer.add_scalars('SN_of_dis', dis_sigmas, step_count)

            if self.config['calculate_z_grad']:
                _, l2_norm_grad_z_aft_D_update = calc_derv(self.fixed_noise, self.fixed_fake_labels,
                                                           self.dis_model, self.second_device, self.gen_model)
                self.writer.add_scalars('L2_norm_grad',
                                        {'z_aft_D_update': l2_norm_grad_z_aft_D_update.mean().item()},
                                        step_count)

        # ================== TRAIN G ================== #
        for step_index in range(self.g_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_step in range(self.accumulation_steps):
                z, fake_labels = sample_latents(self.prior, self.batch_size, self.z_dim, 1,
                                                self.num_classes, None, self.second_device)
                if self.latent_op:
                    z, transport_cost = latent_optimise(z, fake_labels, self.gen_model, self.dis_model,
                                                        self.latent_op_step, self.latent_op_rate,
                                                        self.latent_op_alpha, self.latent_op_beta,
                                                        True, self.second_device)

                fake_images = self.gen_model(z, fake_labels)
                _, cls_out_fake, dis_out_fake = self.dis_model(fake_images, fake_labels)

                gen_acml_loss = self.G_loss(dis_out_fake)
                if self.auxiliary_classifier:
                    gen_acml_loss += self.ce_loss(cls_out_fake, fake_labels)
                if self.latent_op:
                    gen_acml_loss += transport_cost * self.latent_norm_reg_weight

                gen_acml_loss = gen_acml_loss / self.accumulation_steps
                gen_acml_loss.backward()

            if isinstance(self.lambda4ortho, float) and self.lambda4ortho > 0 and self.config['ortho_reg']:
                if isinstance(self.gen_model, DataParallel):
                    ortho(self.gen_model, self.lambda4ortho,
                          blacklist=[param for param in self.gen_model.module.shared.parameters()])
                else:
                    ortho(self.gen_model, self.lambda4ortho,
                          blacklist=[param for param in self.gen_model.shared.parameters()])

            self.G_optimizer.step()

            # if ema is True, we update the parameters of Gen_copy in an adaptive way.
            if self.config['ema']:
                self.Gen_ema.update(step_count)

            step_count += 1

        if step_count % self.print_every == 0 and self.logger:
            log_message = LOG_FORMAT.format(step=step_count,
                                            progress=step_count / total_step,
                                            elapsed=elapsed_time(self.start_time),
                                            temperature='No',
                                            dis_loss=dis_acml_loss.item(),
                                            gen_loss=gen_acml_loss.item())
            self.logger.info(log_message)

            if self.config['g_spectral_norm']:
                gen_sigmas = calculate_all_sn(self.gen_model)
                self.writer.add_scalars('SN_of_gen', gen_sigmas, step_count)

            self.writer.add_scalars('Losses', {'discriminator': dis_acml_loss.item(),
                                               'generator': gen_acml_loss.item()}, step_count)
            self.writer.add_images('Generated samples', (fake_images + 1) / 2, step_count)

            if self.config['calculate_z_grad']:
                _, l2_norm_grad_z_aft_G_update = calc_derv(self.fixed_noise, self.fixed_fake_labels,
                                                           self.dis_model, self.second_device, self.gen_model)
                self.writer.add_scalars('L2_norm_grad',
                                        {'z_aft_G_update': l2_norm_grad_z_aft_G_update.mean().item()},
                                        step_count)

        if step_count % self.save_every == 0 or step_count == total_step:
            self.valid_and_save(step_count)
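# The `ortho` call above applies orthogonal regularisation to the generator
# before G_optimizer.step(), skipping the shared embedding. A hedged sketch of
# BigGAN-style orthogonal regularisation (the repository's actual `ortho` may
# differ in details): it pushes each weight matrix W towards W W^T having zero
# off-diagonal entries by adding the regulariser's gradient directly to
# param.grad rather than to the loss.
import torch

def _toy_ortho(model, strength=1e-4, blacklist=()):
    skip = {id(p) for p in blacklist}
    with torch.no_grad():
        for param in model.parameters():
            if len(param.shape) < 2 or id(param) in skip or param.grad is None:
                continue
            w = param.view(param.shape[0], -1)
            # Off-diagonal part of the Gram matrix W W^T.
            off_diag = torch.mm(w, w.t()) * (1.0 - torch.eye(w.shape[0], device=w.device))
            param.grad += 2.0 * strength * torch.mm(off_diag, w).view(param.shape)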
def calculate_acc_confidence(dataloader, generator, discriminator, G_loss, num_evaluate,
                             truncated_factor, prior, latent_op, latent_op_step,
                             latent_op_alpha, latent_op_beta, device):
    generator.eval()
    discriminator.eval()

    data_iter = iter(dataloader)
    batch_size = dataloader.batch_size

    if isinstance(generator, DataParallel):
        z_dim = generator.module.z_dim
        num_classes = generator.module.num_classes
    else:
        z_dim = generator.z_dim
        num_classes = generator.num_classes

    if num_evaluate % batch_size == 0:
        total_batch = num_evaluate // batch_size
    else:
        raise Exception("num_evaluate % batch_size should be 0!")

    if G_loss.__name__ == "loss_dcgan_gen":
        cutoff = 0.5
        fake_target = 0.0
    elif G_loss.__name__ == "loss_hinge_gen":
        cutoff = 0.0
        fake_target = -1.0
    elif G_loss.__name__ == "loss_wgan_gen":
        raise NotImplementedError

    for batch_id in tqdm(range(total_batch)):
        z, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor,
                                        num_classes, None, device)
        if latent_op:
            z = latent_optimise(z, fake_labels, generator, discriminator, latent_op_step,
                                1.0, latent_op_alpha, latent_op_beta, False, device)

        images_real, real_labels = next(data_iter)
        images_real, real_labels = images_real.to(device), real_labels.to(device)

        with torch.no_grad():
            images_gen = generator(z, fake_labels)
            _, _, dis_out_fake = discriminator(images_gen, fake_labels)
            _, _, dis_out_real = discriminator(images_real, real_labels)
            dis_out_fake = dis_out_fake.detach().cpu().numpy()
            dis_out_real = dis_out_real.detach().cpu().numpy()

        if batch_id == 0:
            confid = np.concatenate((dis_out_fake, dis_out_real), axis=0)
            confid_label = np.concatenate(([fake_target] * batch_size, [1.0] * batch_size), axis=0)
        else:
            confid = np.concatenate((confid, dis_out_fake, dis_out_real), axis=0)
            confid_label = np.concatenate((confid_label, [fake_target] * batch_size, [1.0] * batch_size), axis=0)

    real_confid = confid[confid_label == 1.0]
    fake_confid = confid[confid_label == fake_target]

    true_positive = real_confid[np.where(real_confid > cutoff)]
    true_negative = fake_confid[np.where(fake_confid < cutoff)]

    only_real_acc = len(true_positive) / len(real_confid)
    only_fake_acc = len(true_negative) / len(fake_confid)
    mixed_acc = (len(true_positive) + len(true_negative)) / len(confid)

    generator.train()
    discriminator.train()

    return only_real_acc, only_fake_acc, mixed_acc, confid, confid_label
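# Usage sketch (illustrative only; parameter values are assumptions): note that
# num_evaluate must be an exact multiple of dataloader.batch_size, or the
# function above raises. G_loss should be one of the loss functions whose
# __name__ the function switches on (e.g., loss_hinge_gen).
def _example_report_confidence(dataloader, generator, discriminator, G_loss, device):
    real_acc, fake_acc, mixed_acc, confid, confid_label = calculate_acc_confidence(
        dataloader, generator, discriminator, G_loss, num_evaluate=10000,
        truncated_factor=-1.0, prior="gaussian", latent_op=False,
        latent_op_step=0, latent_op_alpha=1.0, latent_op_beta=1.0, device=device)
    print("real: {:.3f} | fake: {:.3f} | mixed: {:.3f}".format(real_acc, fake_acc, mixed_acc))
    return confid, confid_label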