def generate_images(self, gen, dis, truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, batch_size):
    """Sample a batch of fake images from ``gen``.

    Draws latents/labels via ``sample_latents``, optionally refines the
    latents with latent optimisation, then runs the generator under
    ``torch.no_grad`` in evaluation mode.

    Returns the generated image batch (a tensor produced by ``gen``).
    """
    # If the models are wrapped in DataParallel, the real modules live
    # under ``.module``; assumes gen and dis are wrapped identically.
    if isinstance(gen, DataParallel):
        gen_access, dis_access = gen.module, dis.module
    else:
        gen_access, dis_access = gen, dis
    z_dim = gen_access.z_dim
    num_classes = gen_access.num_classes
    conditional_strategy = dis_access.conditional_strategy

    zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor,
                                     num_classes, None, self.device)

    if latent_op:
        zs = latent_optimise(zs, fake_labels, gen, dis, conditional_strategy,
                             latent_op_step, 1.0, latent_op_alpha,
                             latent_op_beta, False, self.device)

    with torch.no_grad():
        batch_images = gen(zs, fake_labels, evaluation=True)

    return batch_images
def generate_images_for_KNN(batch_size, real_label, gen_model, dis_model, truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device):
    """Generate fake images conditioned toward ``real_label`` for KNN analysis.

    Same pipeline as ``generate_images`` but forwards ``real_label`` to
    ``sample_latents`` and also returns the sampled fake labels as a
    plain Python list.

    Returns:
        (batch_images, fake_labels): generated images and their labels
        as a list of numpy scalars.
    """
    # Unwrap (D)DP containers to reach model attributes; gen and dis are
    # assumed to be wrapped the same way.
    wrapped = isinstance(gen_model, DataParallel) or isinstance(gen_model, DistributedDataParallel)
    gen_access = gen_model.module if wrapped else gen_model
    dis_access = dis_model.module if wrapped else dis_model

    z_dim = gen_access.z_dim
    num_classes = gen_access.num_classes
    conditional_strategy = dis_access.conditional_strategy

    zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor,
                                     num_classes, None, device, real_label)

    if latent_op:
        zs = latent_optimise(zs, fake_labels, gen_model, dis_model,
                             conditional_strategy, latent_op_step, 1.0,
                             latent_op_alpha, latent_op_beta, False, device)

    with torch.no_grad():
        batch_images = gen_model(zs, fake_labels, evaluation=True)

    return batch_images, list(fake_labels.detach().cpu().numpy())
def apply_accumulate_stat(generator, acml_step, prior, batch_size, z_dim, num_classes, device):
    """Re-accumulate BatchNorm running statistics of ``generator``.

    Resets BN stats (via ``reset_bn_stat``), then runs ``acml_step``
    forward passes in train mode with randomly sized latent batches so
    the running statistics are refreshed, and finally switches the
    generator to eval mode.
    """
    generator.train()
    generator.apply(reset_bn_stat)
    for _ in range(acml_step):
        # Vary the batch size per pass; presumably to diversify the
        # accumulated statistics — TODO confirm intent.
        acml_batch = random.randint(1, batch_size)
        z, fake_labels = sample_latents(prior, acml_batch, z_dim, 1,
                                        num_classes, None, device)
        generator(z, fake_labels)  # forward only for its BN side effect
    generator.eval()
def generate_images(batch_size, gen, dis, truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha, latent_op_beta, device):
    """Sample a batch of fake images from ``gen`` (module-level variant).

    Fix: also unwrap ``DistributedDataParallel`` — the sibling helpers
    (``generate_images_for_KNN``, ``calculate_accuracy``) handle both
    DP and DDP wrappers, but this function previously only handled
    ``DataParallel`` and would read missing attributes off a DDP wrapper.

    Returns the generated image batch.
    """
    if isinstance(gen, (DataParallel, DistributedDataParallel)):
        z_dim = gen.module.z_dim
        num_classes = gen.module.num_classes
    else:
        z_dim = gen.z_dim
        num_classes = gen.num_classes

    z, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor,
                                    num_classes, None, device)

    if latent_op:
        # NOTE(review): this call omits conditional_strategy, unlike the
        # latent_optimise calls in generate_images_for_KNN/calculate_accuracy
        # (though it matches the arity used in the first trainer's run()).
        # Confirm which latent_optimise signature this module imports.
        z = latent_optimise(z, fake_labels, gen, dis, latent_op_step, 1.0,
                            latent_op_alpha, latent_op_beta, False, device)

    with torch.no_grad():
        batch_images = gen(z, fake_labels)

    return batch_images
def calculate_accuracy(dataloader, generator, discriminator, D_loss, num_evaluate, truncated_factor, prior, latent_op,
                       latent_op_step, latent_op_alpha, latent_op_beta, device, cr, logger, eval_generated_sample=False):
    """Estimate discriminator accuracy on real (and optionally fake) samples.

    Feeds ``num_evaluate`` samples through the discriminator and counts
    how many real samples score above ``cutoff`` and (optionally) how
    many generated samples score below it.

    Fixes vs. original:
      * an unknown ``D_loss.__name__`` previously fell through the
        dispatch chain leaving ``cutoff`` unbound (NameError later);
        now raises ``NotImplementedError`` immediately.
      * the generator forward pass is moved inside ``torch.no_grad()``:
        its output is only ever detached, so building an autograd graph
        was pure memory overhead.

    Returns:
        (only_real_acc, only_fake_acc) when ``eval_generated_sample``,
        otherwise ``only_real_acc`` alone.
    """
    data_iter = iter(dataloader)
    batch_size = dataloader.batch_size
    disable_tqdm = device != 0  # only rank 0 shows progress

    # Unwrap DP/DDP containers; gen and dis assumed wrapped identically.
    if isinstance(generator, DataParallel) or isinstance(generator, DistributedDataParallel):
        z_dim = generator.module.z_dim
        num_classes = generator.module.num_classes
        conditional_strategy = discriminator.module.conditional_strategy
    else:
        z_dim = generator.z_dim
        num_classes = generator.num_classes
        conditional_strategy = discriminator.conditional_strategy

    total_batch = num_evaluate // batch_size

    # Decision threshold on the raw discriminator output per loss type.
    if D_loss.__name__ in ["loss_dcgan_dis", "loss_lsgan_dis"]:
        cutoff = 0.0
    elif D_loss.__name__ == "loss_hinge_dis":
        cutoff = 0.0
    elif D_loss.__name__ == "loss_wgan_dis":
        raise NotImplementedError
    else:
        # Fix: previously no else-branch, so cutoff stayed unbound.
        raise NotImplementedError

    if device == 0:
        logger.info("Calculate Accuracies....")

    if eval_generated_sample:
        for batch_id in tqdm(range(total_batch), disable=disable_tqdm):
            zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor,
                                             num_classes, None, device)
            if latent_op:
                zs = latent_optimise(zs, fake_labels, generator, discriminator,
                                     conditional_strategy, latent_op_step, 1.0,
                                     latent_op_alpha, latent_op_beta, False, device)

            real_images, real_labels = next(data_iter)
            real_images, real_labels = real_images.to(device), real_labels.to(device)

            with torch.no_grad():
                # Fix: generation moved under no_grad (output only used detached).
                fake_images = generator(zs, fake_labels, evaluation=True)
                # Discriminator output arity depends on the conditioning scheme.
                if conditional_strategy in ["ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"]:
                    _, _, dis_out_fake = discriminator(fake_images, fake_labels)
                    _, _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ACGAN":
                    _, dis_out_fake = discriminator(fake_images, fake_labels)
                    _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ProjGAN" or conditional_strategy == "no":
                    dis_out_fake = discriminator(fake_images, fake_labels)
                    dis_out_real = discriminator(real_images, real_labels)
                else:
                    raise NotImplementedError

            dis_out_fake = dis_out_fake.detach().cpu().numpy()
            dis_out_real = dis_out_real.detach().cpu().numpy()

            # Accumulate confidences with labels: 0.0 = fake, 1.0 = real.
            if batch_id == 0:
                confid = np.concatenate((dis_out_fake, dis_out_real), axis=0)
                confid_label = np.concatenate(([0.0]*len(dis_out_fake), [1.0]*len(dis_out_real)), axis=0)
            else:
                confid = np.concatenate((confid, dis_out_fake, dis_out_real), axis=0)
                confid_label = np.concatenate((confid_label, [0.0]*len(dis_out_fake), [1.0]*len(dis_out_real)), axis=0)

        real_confid = confid[confid_label == 1.0]
        fake_confid = confid[confid_label == 0.0]
        true_positive = real_confid[np.where(real_confid > cutoff)]
        true_negative = fake_confid[np.where(fake_confid < cutoff)]
        only_real_acc = len(true_positive)/len(real_confid)
        only_fake_acc = len(true_negative)/len(fake_confid)
        return only_real_acc, only_fake_acc
    else:
        for batch_id in tqdm(range(total_batch), disable=disable_tqdm):
            real_images, real_labels = next(data_iter)
            real_images, real_labels = real_images.to(device), real_labels.to(device)

            with torch.no_grad():
                if conditional_strategy in ["ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"]:
                    _, _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ACGAN":
                    _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ProjGAN" or conditional_strategy == "no":
                    dis_out_real = discriminator(real_images, real_labels)
                else:
                    raise NotImplementedError

            dis_out_real = dis_out_real.detach().cpu().numpy()

            if batch_id == 0:
                confid = dis_out_real
                confid_label = np.asarray([1.0]*len(dis_out_real), np.float32)
            else:
                confid = np.concatenate((confid, dis_out_real), axis=0)
                confid_label = np.concatenate((confid_label, [1.0]*len(dis_out_real)), axis=0)

        real_confid = confid[confid_label == 1.0]
        true_positive = real_confid[np.where(real_confid > cutoff)]
        only_real_acc = len(true_positive)/len(real_confid)
        return only_real_acc
def __init__(
        self, run_name, best_step, dataset_name, type4eval_dataset, logger, writer, n_gpus, gen_model, dis_model,
        inception_model, Gen_copy, Gen_ema, train_dataloader, eval_dataloader, conditional_strategy, z_dim, num_classes,
        hypersphere_dim, d_spectral_norm, g_spectral_norm, G_optimizer, D_optimizer, batch_size, g_steps_per_iter,
        d_steps_per_iter, accumulation_steps, total_step, G_loss, D_loss, contrastive_lambda, tempering_type,
        tempering_step, start_temperature, end_temperature, gradient_penalty_for_dis, gradient_penelty_lambda,
        weight_clipping_for_dis, weight_clipping_bound, consistency_reg, consistency_lambda, diff_aug, prior,
        truncated_factor, ema, latent_op, latent_op_rate, latent_op_step, latent_op_step4eval, latent_op_alpha,
        latent_op_beta, latent_norm_reg_weight, default_device, second_device, print_every, save_every, checkpoint_dir,
        evaluate, mu, sigma, best_fid, best_fid_checkpoint_path, train_config, model_config,
):
    """Store training configuration, build loss criteria, and prepare
    fixed latents/labels used for periodic sample visualisation.

    The long parameter list is stored verbatim on ``self``; only the
    last section below does any real work (criteria, tempering setup,
    fixed-noise sampling).
    """
    # --- bookkeeping / experiment identity ---
    self.run_name = run_name
    self.best_step = best_step
    self.dataset_name = dataset_name
    self.type4eval_dataset = type4eval_dataset
    self.logger = logger
    self.writer = writer
    self.n_gpus = n_gpus

    # --- models ---
    self.gen_model = gen_model
    self.dis_model = dis_model
    self.inception_model = inception_model
    self.Gen_copy = Gen_copy          # EMA/copy generator used for sampling
    self.Gen_ema = Gen_ema

    # --- data ---
    self.train_dataloader = train_dataloader
    self.eval_dataloader = eval_dataloader

    # --- architecture / conditioning options ---
    self.conditional_strategy = conditional_strategy
    self.z_dim = z_dim
    self.num_classes = num_classes
    self.hypersphere_dim = hypersphere_dim
    self.d_spectral_norm = d_spectral_norm
    self.g_spectral_norm = g_spectral_norm

    # --- optimisation schedule ---
    self.G_optimizer = G_optimizer
    self.D_optimizer = D_optimizer
    self.batch_size = batch_size
    self.g_steps_per_iter = g_steps_per_iter
    self.d_steps_per_iter = d_steps_per_iter
    self.accumulation_steps = accumulation_steps
    self.total_step = total_step

    # --- losses and regularisers ---
    self.G_loss = G_loss
    self.D_loss = D_loss
    self.contrastive_lambda = contrastive_lambda
    self.tempering_type = tempering_type
    self.tempering_step = tempering_step
    self.start_temperature = start_temperature
    self.end_temperature = end_temperature
    self.gradient_penalty_for_dis = gradient_penalty_for_dis
    self.gradient_penelty_lambda = gradient_penelty_lambda  # (sic: "penelty" spelling kept — referenced elsewhere)
    self.weight_clipping_for_dis = weight_clipping_for_dis
    self.weight_clipping_bound = weight_clipping_bound
    self.consistency_reg = consistency_reg
    self.consistency_lambda = consistency_lambda
    self.diff_aug = diff_aug

    # --- latent sampling / optimisation ---
    self.prior = prior
    self.truncated_factor = truncated_factor
    self.ema = ema
    self.latent_op = latent_op
    self.latent_op_rate = latent_op_rate
    self.latent_op_step = latent_op_step
    self.latent_op_step4eval = latent_op_step4eval
    self.latent_op_alpha = latent_op_alpha
    self.latent_op_beta = latent_op_beta
    self.latent_norm_reg_weight = latent_norm_reg_weight

    # --- devices / checkpointing / evaluation ---
    self.default_device = default_device
    self.second_device = second_device
    self.print_every = print_every
    self.save_every = save_every
    self.checkpoint_dir = checkpoint_dir
    self.evaluate = evaluate
    self.mu = mu
    self.sigma = sigma
    self.best_fid = best_fid
    self.best_fid_checkpoint_path = best_fid_checkpoint_path
    self.train_config = train_config
    self.model_config = model_config

    self.start_time = datetime.now()
    self.l2_loss = torch.nn.MSELoss()
    self.ce_loss = torch.nn.CrossEntropyLoss()

    # Contrastive criterion is only needed for the ContraGAN strategy.
    if self.conditional_strategy == 'ContraGAN':
        self.contrastive_criterion = Conditional_Embedding_Contrastive_loss(
            self.second_device, self.batch_size)

    self.tempering_range = self.end_temperature - self.start_temperature
    assert tempering_type == "constant" or tempering_type == "continuous" or tempering_type == "discrete", \
        "tempering_type should be one of constant, continuous, or discrete"
    if self.tempering_type == 'discrete':
        # Number of steps between discrete temperature changes.
        self.tempering_interval = self.total_step // (
            self.tempering_step + 1)

    # Class-wise sampling mode for the fixed visualisation batch:
    # "all" classes when the label space is small (cifar10), otherwise "some".
    if self.conditional_strategy != "no":
        if self.dataset_name == "cifar10":
            cls_wise_sampling = "all"
        else:
            cls_wise_sampling = "some"
    else:
        cls_wise_sampling = "no"

    self.policy = "color,translation,cutout"  # DiffAugment policy string

    # Fixed latents/labels reused at every logging step so samples are comparable.
    self.fixed_noise, self.fixed_fake_labels = sample_latents(
        self.prior, self.batch_size, self.z_dim, 1, self.num_classes, None,
        self.second_device, cls_wise_sampling=cls_wise_sampling)
def run(self, current_step, total_step):
    """Standard adversarial training loop (non-contrastive strategies).

    Alternates ``d_steps_per_iter`` discriminator updates and
    ``g_steps_per_iter`` generator updates with gradient accumulation,
    then handles EMA, logging, and checkpointing.

    Fix: the gradient-penalty term previously referenced the bare name
    ``gradient_penelty_lambda`` — not a parameter of this method — which
    raised ``NameError`` whenever ``gradient_penalty_for_dis`` was
    enabled. It now reads ``self.gradient_penelty_lambda`` as stored by
    ``__init__``.
    """
    self.dis_model.train()
    self.gen_model.train()
    if self.Gen_copy is not None:
        self.Gen_copy.train()

    step_count = current_step
    train_iter = iter(self.train_dataloader)

    while step_count <= total_step:
        # ================== TRAIN D ================== #
        for step_index in range(self.d_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_index in range(self.accumulation_steps):
                # Fetch a real batch; restart the iterator when exhausted.
                try:
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)
                except StopIteration:
                    train_iter = iter(self.train_dataloader)
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)

                images, real_labels = images.to(self.second_device), real_labels.to(self.second_device)
                if self.diff_aug:
                    images = DiffAugment(images, policy=self.policy)

                z, fake_labels = sample_latents(self.prior, self.batch_size, self.z_dim,
                                                1, self.num_classes, None, self.second_device)
                if self.latent_op:
                    z = latent_optimise(z, fake_labels, self.gen_model, self.dis_model,
                                        self.latent_op_step, self.latent_op_rate,
                                        self.latent_op_alpha, self.latent_op_beta,
                                        False, self.second_device)

                fake_images = self.gen_model(z, fake_labels)
                if self.diff_aug:
                    fake_images = DiffAugment(fake_images, policy=self.policy)

                # Discriminator output arity depends on conditioning strategy.
                if self.conditional_strategy == "ACGAN":
                    cls_out_real, dis_out_real = self.dis_model(images, real_labels)
                    cls_out_fake, dis_out_fake = self.dis_model(fake_images, fake_labels)
                elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                    dis_out_real = self.dis_model(images, real_labels)
                    dis_out_fake = self.dis_model(fake_images, fake_labels)
                else:
                    raise NotImplementedError

                dis_acml_loss = self.D_loss(dis_out_real, dis_out_fake)
                if self.conditional_strategy == "ACGAN":
                    # Auxiliary classification loss on both real and fake batches.
                    dis_acml_loss += (self.ce_loss(cls_out_real, real_labels) +
                                      self.ce_loss(cls_out_fake, fake_labels))
                if self.gradient_penalty_for_dis:
                    # Fix: use the lambda stored on self (was a bare NameError).
                    dis_acml_loss += self.gradient_penelty_lambda * calc_derv4gp(
                        self.dis_model, images, fake_images, real_labels, self.second_device)
                if self.consistency_reg:
                    # Consistency regularisation: D output should match on augmented reals.
                    images_aug = images_aug.to(self.second_device)
                    if self.conditional_strategy == "ACGAN":
                        cls_out_real_aug, dis_out_real_aug = self.dis_model(images_aug, real_labels)
                    elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                        dis_out_real_aug = self.dis_model(images_aug, real_labels)
                    else:
                        raise NotImplementedError
                    consistency_loss = self.l2_loss(dis_out_real, dis_out_real_aug)
                    dis_acml_loss += self.consistency_lambda * consistency_loss

                # Scale for gradient accumulation before backward.
                dis_acml_loss = dis_acml_loss / self.accumulation_steps
                dis_acml_loss.backward()

            self.D_optimizer.step()

            if self.weight_clipping_for_dis:
                for p in self.dis_model.parameters():
                    p.data.clamp_(-self.weight_clipping_bound, self.weight_clipping_bound)

        if step_count % self.print_every == 0 and step_count != 0 and self.logger:
            if self.d_spectral_norm:
                dis_sigmas = calculate_all_sn(self.dis_model)
                self.writer.add_scalars('SN_of_dis', dis_sigmas, step_count)

        # ================== TRAIN G ================== #
        for step_index in range(self.g_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_step in range(self.accumulation_steps):
                z, fake_labels = sample_latents(self.prior, self.batch_size, self.z_dim,
                                                1, self.num_classes, None, self.second_device)
                if self.latent_op:
                    # True → latent_optimise also returns the transport cost
                    # used below as a latent-norm regulariser.
                    z, transport_cost = latent_optimise(z, fake_labels, self.gen_model,
                                                        self.dis_model, self.latent_op_step,
                                                        self.latent_op_rate, self.latent_op_alpha,
                                                        self.latent_op_beta, True, self.second_device)

                fake_images = self.gen_model(z, fake_labels)
                if self.diff_aug:
                    fake_images = DiffAugment(fake_images, policy=self.policy)

                if self.conditional_strategy == "ACGAN":
                    cls_out_fake, dis_out_fake = self.dis_model(fake_images, fake_labels)
                elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                    dis_out_fake = self.dis_model(fake_images, fake_labels)
                else:
                    raise NotImplementedError

                gen_acml_loss = self.G_loss(dis_out_fake)
                if self.conditional_strategy == "ACGAN":
                    gen_acml_loss += self.ce_loss(cls_out_fake, fake_labels)
                if self.latent_op:
                    gen_acml_loss += transport_cost * self.latent_norm_reg_weight

                gen_acml_loss = gen_acml_loss / self.accumulation_steps
                gen_acml_loss.backward()

            self.G_optimizer.step()

            # if ema is True: we update parameters of the Gen_copy in adaptive way.
            if self.ema:
                self.Gen_ema.update(step_count)

            step_count += 1

        if step_count % self.print_every == 0 and self.logger:
            log_message = LOG_FORMAT.format(
                step=step_count,
                progress=step_count / total_step,
                elapsed=elapsed_time(self.start_time),
                temperature='No',
                dis_loss=dis_acml_loss.item(),
                gen_loss=gen_acml_loss.item(),
            )
            self.logger.info(log_message)

            if self.g_spectral_norm:
                gen_sigmas = calculate_all_sn(self.gen_model)
                self.writer.add_scalars('SN_of_gen', gen_sigmas, step_count)

            self.writer.add_scalars(
                'Losses', {
                    'discriminator': dis_acml_loss.item(),
                    'generator': gen_acml_loss.item()
                }, step_count)

            # Log fixed-noise samples; prefer the EMA copy when available.
            with torch.no_grad():
                if self.Gen_copy is not None:
                    self.Gen_copy.train()
                    generator = self.Gen_copy
                else:
                    self.gen_model.eval()
                    generator = self.gen_model

                generated_images = generator(self.fixed_noise, self.fixed_fake_labels)
                self.writer.add_images('Generated samples',
                                       (generated_images + 1) / 2, step_count)
                self.gen_model.train()

        if step_count % self.save_every == 0 or step_count == total_step:
            if self.evaluate:
                is_best = self.evaluation(step_count)
                self.save(step_count, is_best)
            else:
                self.save(step_count, False)
def run_ours(self, current_step, total_step):
    """Contrastive (ContraGAN-style) training loop with temperature tempering.

    Same alternating D/G schedule as ``run`` but adds a per-step
    temperature ``t`` (constant / continuous / discrete schedule) fed to
    ``self.contrastive_criterion`` for both D and G updates.
    """
    self.dis_model.train()
    self.gen_model.train()
    if self.Gen_copy is not None:
        self.Gen_copy.train()

    step_count = current_step
    train_iter = iter(self.train_dataloader)

    while step_count <= total_step:
        # ================== TRAIN D ================== #
        # Temperature schedule: linear in steps (continuous), stepwise
        # via tempering_interval (discrete), or fixed (constant).
        if self.tempering_type == 'continuous':
            t = self.start_temperature + step_count * (
                self.end_temperature - self.start_temperature) / total_step
        elif self.tempering_type == 'discrete':
            t = self.start_temperature + \
                (step_count//self.tempering_interval)*(self.end_temperature-self.start_temperature)/self.tempering_step
        else:
            t = self.start_temperature

        for step_index in range(self.d_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_step in range(self.accumulation_steps):
                # Fetch a real batch; restart the iterator when exhausted.
                try:
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)
                except StopIteration:
                    train_iter = iter(self.train_dataloader)
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)

                images, real_labels = images.to(
                    self.second_device), real_labels.to(self.second_device)
                if self.diff_aug:
                    images = DiffAugment(images, policy=self.policy)

                z, fake_labels = sample_latents(self.prior, self.batch_size,
                                                self.z_dim, 1, self.num_classes,
                                                None, self.second_device)

                # Mask marking same-class pairs for the contrastive loss.
                real_cls_mask = make_mask(real_labels, self.num_classes,
                                          self.second_device)

                # D returns (class proxies, embeddings, authenticity logit).
                cls_real_proxies, cls_real_embed, dis_real_authen_out = self.dis_model(
                    images, real_labels)

                fake_images = self.gen_model(z, fake_labels)
                if self.diff_aug:
                    fake_images = DiffAugment(fake_images, policy=self.policy)

                cls_fake_proxies, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                    fake_images, fake_labels)

                dis_acml_loss = self.D_loss(dis_real_authen_out, dis_fake_authen_out)
                # Contrastive term uses only the REAL embeddings for D.
                dis_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                    cls_real_embed, cls_real_proxies, real_cls_mask, real_labels, t)

                if self.consistency_reg:
                    # Consistency regularisation on augmented real images.
                    images_aug = images_aug.to(self.second_device)
                    _, cls_real_aug_embed, dis_real_aug_authen_out = self.dis_model(
                        images_aug, real_labels)
                    consistency_loss = self.l2_loss(
                        dis_real_authen_out, dis_real_aug_authen_out)
                    dis_acml_loss += self.consistency_lambda * consistency_loss

                # Scale for gradient accumulation before backward.
                dis_acml_loss = dis_acml_loss / self.accumulation_steps
                dis_acml_loss.backward()

            self.D_optimizer.step()

            if self.weight_clipping_for_dis:
                for p in self.dis_model.parameters():
                    p.data.clamp_(-self.weight_clipping_bound,
                                  self.weight_clipping_bound)

        if step_count % self.print_every == 0 and step_count != 0 and self.logger:
            if self.d_spectral_norm:
                dis_sigmas = calculate_all_sn(self.dis_model)
                self.writer.add_scalars('SN_of_dis', dis_sigmas, step_count)

        # ================== TRAIN G ================== #
        for step_index in range(self.g_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_step in range(self.accumulation_steps):
                z, fake_labels = sample_latents(self.prior, self.batch_size,
                                                self.z_dim, 1, self.num_classes,
                                                None, self.second_device)
                fake_cls_mask = make_mask(fake_labels, self.num_classes,
                                          self.second_device)

                fake_images = self.gen_model(z, fake_labels)
                if self.diff_aug:
                    fake_images = DiffAugment(fake_images, policy=self.policy)

                cls_fake_proxies, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                    fake_images, fake_labels)

                gen_acml_loss = self.G_loss(dis_fake_authen_out)
                # Contrastive term uses the FAKE embeddings for G.
                gen_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                    cls_fake_embed, cls_fake_proxies, fake_cls_mask, fake_labels, t)

                gen_acml_loss = gen_acml_loss / self.accumulation_steps
                gen_acml_loss.backward()

            self.G_optimizer.step()

            # if ema is True: we update parameters of the Gen_copy in adaptive way.
            if self.ema:
                self.Gen_ema.update(step_count)

            step_count += 1

        if step_count % self.print_every == 0 and self.logger:
            log_message = LOG_FORMAT.format(
                step=step_count,
                progress=step_count / total_step,
                elapsed=elapsed_time(self.start_time),
                temperature=t,
                dis_loss=dis_acml_loss.item(),
                gen_loss=gen_acml_loss.item(),
            )
            self.logger.info(log_message)

            if self.g_spectral_norm:
                gen_sigmas = calculate_all_sn(self.gen_model)
                self.writer.add_scalars('SN_of_gen', gen_sigmas, step_count)

            self.writer.add_scalars(
                'Losses', {
                    'discriminator': dis_acml_loss.item(),
                    'generator': gen_acml_loss.item()
                }, step_count)

            # Log fixed-noise samples; prefer the EMA copy when available.
            with torch.no_grad():
                if self.Gen_copy is not None:
                    self.Gen_copy.train()
                    generator = self.Gen_copy
                else:
                    self.gen_model.eval()
                    generator = self.gen_model

                generated_images = generator(self.fixed_noise,
                                             self.fixed_fake_labels)
                self.writer.add_images('Generated samples',
                                       (generated_images + 1) / 2, step_count)
                self.gen_model.train()

        if step_count % self.save_every == 0 or step_count == total_step:
            if self.evaluate:
                is_best = self.evaluation(step_count)
                self.save(step_count, is_best)
            else:
                self.save(step_count, False)
def __init__(self, run_name, logger, writer, n_gpus, gen_model, dis_model, inception_model, Gen_copy, Gen_ema,
             train_dataloader, evaluation_dataloader, G_loss, D_loss, auxiliary_classifier, contrastive_training,
             softmax_posterior, contrastive_softmax, hyper_dim, contrastive_lambda, tempering, discrete_tempering,
             tempering_times, start_temperature, end_temperature, gradient_penalty_for_dis, lambda4lp, lambda4gp,
             weight_clipping_for_dis, weight_clipping_bound, latent_op, latent_op_rate, latent_op_step,
             latent_op_step4eval, latent_op_alpha, latent_op_beta, latent_norm_reg_weight, consistency_reg,
             consistency_lambda, make_positive_aug, G_optimizer, D_optimizer, default_device, second_device,
             batch_size, z_dim, num_classes, truncated_factor, prior, g_steps_per_iter, d_steps_per_iter,
             accumulation_steps, lambda4ortho, print_every, save_every, checkpoint_dir, evaluate, mu, sigma,
             best_val_fid, best_checkpoint_fid_path, best_val_is, best_checkpoint_is_path, config):
    """Store training configuration (legacy trainer variant), build loss
    criteria, and capture fixed real/fake batches for logging.

    Unlike the other trainer's ``__init__``, this variant reads several
    flags from a ``config`` dict (``d_spectral_norm``, ``ema``, …) and
    tracks both FID and IS bests.
    """
    # --- bookkeeping / experiment identity ---
    self.run_name = run_name
    self.logger = logger
    self.writer = writer
    self.n_gpus = n_gpus

    # --- models ---
    self.gen_model = gen_model
    self.dis_model = dis_model
    self.inception_model = inception_model
    self.Gen_copy = Gen_copy          # EMA/copy generator used for sampling
    self.Gen_ema = Gen_ema

    # --- data ---
    self.train_dataloader = train_dataloader
    self.evaluation_dataloader = evaluation_dataloader

    # --- losses and conditioning options ---
    self.G_loss = G_loss
    self.D_loss = D_loss
    self.auxiliary_classifier = auxiliary_classifier
    self.contrastive_training = contrastive_training
    self.softmax_posterior = softmax_posterior
    self.contrastive_softmax = contrastive_softmax
    self.hyper_dim = hyper_dim
    self.contrastive_lambda = contrastive_lambda

    # --- temperature tempering schedule ---
    self.tempering = tempering
    self.discrete_tempering = discrete_tempering
    self.tempering_times = tempering_times
    self.start_temperature = start_temperature
    self.end_temperature = end_temperature

    # --- regularisers ---
    self.gradient_penalty_for_dis = gradient_penalty_for_dis
    self.lambda4lp = lambda4lp
    self.lambda4gp = lambda4gp
    self.weight_clipping_for_dis = weight_clipping_for_dis
    self.weight_clipping_bound = weight_clipping_bound

    # --- latent sampling / optimisation ---
    self.latent_op = latent_op
    self.latent_op_rate = latent_op_rate
    self.latent_op_step = latent_op_step
    self.latent_op_step4eval = latent_op_step4eval
    self.latent_op_alpha = latent_op_alpha
    self.latent_op_beta = latent_op_beta
    self.latent_norm_reg_weight = latent_norm_reg_weight
    self.consistency_reg = consistency_reg
    self.consistency_lambda = consistency_lambda
    self.make_positive_aug = make_positive_aug

    # --- optimisation schedule / devices ---
    self.G_optimizer = G_optimizer
    self.D_optimizer = D_optimizer
    self.default_device = default_device
    self.second_device = second_device
    self.batch_size = batch_size
    self.z_dim = z_dim
    self.num_classes = num_classes
    self.truncated_factor = truncated_factor
    self.prior = prior
    self.g_steps_per_iter = g_steps_per_iter
    self.d_steps_per_iter = d_steps_per_iter
    self.accumulation_steps = accumulation_steps
    self.lambda4ortho = lambda4ortho

    # --- logging / checkpointing ---
    self.print_every = print_every
    self.save_every = save_every
    self.checkpoint_dir = checkpoint_dir
    self.config = config

    self.start_time = datetime.now()
    self.l2_loss = torch.nn.MSELoss()
    self.ce_loss = torch.nn.CrossEntropyLoss()

    # Classification head criterion, only when posterior softmax is on.
    if self.softmax_posterior:
        self.ce_criterion = Cross_Entropy_loss(
            self.hyper_dim, self.num_classes,
            self.config['d_spectral_norm']).to(self.second_device)

    # Contrastive criterion, only when contrastive softmax is on.
    if self.contrastive_softmax:
        self.contrastive_criterion = Conditional_Embedding_Contrastive_loss(
            self.second_device, self.batch_size)

    # Fixed real batch and fixed latents, reused for comparable logging.
    fixed_feed = next(iter(self.train_dataloader))
    self.fixed_images, self.fixed_real_labels = fixed_feed[0].to(
        self.second_device), fixed_feed[1].to(self.second_device)
    self.fixed_noise, self.fixed_fake_labels = sample_latents(
        self.prior, self.batch_size, self.z_dim, 1, self.num_classes, None,
        self.second_device)

    # --- evaluation state / best-metric tracking ---
    self.evaluate = evaluate
    self.mu = mu
    self.sigma = sigma
    self.best_val_fid = best_val_fid
    self.best_val_is = best_val_is
    self.best_checkpoint_fid_path = best_checkpoint_fid_path
    self.best_checkpoint_is_path = best_checkpoint_is_path
def run(self, current_step, total_step):
    """Adversarial training loop (legacy trainer variant).

    Alternates D and G updates with gradient accumulation, optional
    gradient penalty, consistency regularisation, latent optimisation,
    orthogonal regularisation, and z-gradient logging.

    Fix: the orthogonal-regularisation gate previously tested
    ``self.lambda4ortho is float`` — an identity comparison against the
    *type object* ``float``, which is effectively always False and
    silently disabled ortho regularisation. It now uses
    ``isinstance(self.lambda4ortho, float)``.
    """
    self.dis_model.train()
    self.gen_model.train()
    if self.Gen_copy is not None:
        self.Gen_copy.train()

    step_count = current_step
    train_iter = iter(self.train_dataloader)

    while step_count <= total_step:
        # ================== TRAIN D ================== #
        for step_index in range(self.d_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_index in range(self.accumulation_steps):
                # Fetch a real batch; restart the iterator when exhausted.
                try:
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)
                except StopIteration:
                    train_iter = iter(self.train_dataloader)
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)

                images, real_labels = images.to(
                    self.second_device), real_labels.to(self.second_device)

                z, fake_labels = sample_latents(self.prior, self.batch_size, self.z_dim,
                                                1, self.num_classes, None, self.second_device)
                if self.latent_op:
                    z = latent_optimise(z, fake_labels, self.gen_model, self.dis_model,
                                        self.latent_op_step, self.latent_op_rate,
                                        self.latent_op_alpha, self.latent_op_beta,
                                        False, self.second_device)

                # D returns (embedding, class logits, authenticity logit);
                # the first element is unused here.
                _, cls_out_real, dis_out_real = self.dis_model(images, real_labels)
                fake_images = self.gen_model(z, fake_labels)
                _, cls_out_fake, dis_out_fake = self.dis_model(fake_images, fake_labels)

                dis_acml_loss = self.D_loss(dis_out_real, dis_out_fake)
                if self.auxiliary_classifier:
                    dis_acml_loss += (self.ce_loss(cls_out_real, real_labels) +
                                      self.ce_loss(cls_out_fake, fake_labels))
                if self.gradient_penalty_for_dis:
                    dis_acml_loss += self.lambda4gp * calc_derv4gp(
                        self.dis_model, images, fake_images, real_labels, self.second_device)
                if self.consistency_reg:
                    images_aug = images_aug.to(self.second_device)
                    _, _, dis_out_real_aug = self.dis_model(images_aug, real_labels)
                    consistency_loss = self.l2_loss(dis_out_real, dis_out_real_aug)
                    dis_acml_loss += self.consistency_lambda * consistency_loss

                # Scale for gradient accumulation before backward.
                dis_acml_loss = dis_acml_loss / self.accumulation_steps
                dis_acml_loss.backward()

            self.D_optimizer.step()

            if self.weight_clipping_for_dis:
                for p in self.dis_model.parameters():
                    p.data.clamp_(-self.weight_clipping_bound, self.weight_clipping_bound)

        if step_count % self.print_every == 0 and step_count != 0 and self.logger:
            if self.config['d_spectral_norm']:
                dis_sigmas = calculate_all_sn(self.dis_model)
                self.writer.add_scalars('SN_of_dis', dis_sigmas, step_count)
            if self.config['calculate_z_grad']:
                _, l2_norm_grad_z_aft_D_update = calc_derv(
                    self.fixed_noise, self.fixed_fake_labels, self.dis_model,
                    self.second_device, self.gen_model)
                self.writer.add_scalars(
                    'L2_norm_grad', {
                        'z_aft_D_update': l2_norm_grad_z_aft_D_update.mean().item()
                    }, step_count)

        # ================== TRAIN G ================== #
        for step_index in range(self.g_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_step in range(self.accumulation_steps):
                z, fake_labels = sample_latents(self.prior, self.batch_size, self.z_dim,
                                                1, self.num_classes, None, self.second_device)
                if self.latent_op:
                    # True → also returns the transport cost used as a
                    # latent-norm regulariser below.
                    z, transport_cost = latent_optimise(z, fake_labels, self.gen_model,
                                                        self.dis_model, self.latent_op_step,
                                                        self.latent_op_rate, self.latent_op_alpha,
                                                        self.latent_op_beta, True, self.second_device)

                fake_images = self.gen_model(z, fake_labels)
                _, cls_out_fake, dis_out_fake = self.dis_model(fake_images, fake_labels)

                gen_acml_loss = self.G_loss(dis_out_fake)
                if self.auxiliary_classifier:
                    gen_acml_loss += self.ce_loss(cls_out_fake, fake_labels)
                if self.latent_op:
                    gen_acml_loss += transport_cost * self.latent_norm_reg_weight

                gen_acml_loss = gen_acml_loss / self.accumulation_steps
                gen_acml_loss.backward()

            # Orthogonal regularisation on G's gradients before the step,
            # excluding the shared embedding parameters.
            # Fix: was `self.lambda4ortho is float` (always False).
            if isinstance(self.lambda4ortho, float) and self.lambda4ortho > 0 and self.config['ortho_reg']:
                if isinstance(self.gen_model, DataParallel):
                    ortho(self.gen_model,
                          self.lambda4ortho,
                          blacklist=[param for param in self.gen_model.module.shared.parameters()])
                else:
                    ortho(self.gen_model,
                          self.lambda4ortho,
                          blacklist=[param for param in self.gen_model.shared.parameters()])

            self.G_optimizer.step()

            # if ema is True: we update parameters of the Gen_copy in adaptive way.
            if self.config['ema']:
                self.Gen_ema.update(step_count)

            step_count += 1

        if step_count % self.print_every == 0 and self.logger:
            log_message = LOG_FORMAT.format(
                step=step_count,
                progress=step_count / total_step,
                elapsed=elapsed_time(self.start_time),
                temperature='No',
                dis_loss=dis_acml_loss.item(),
                gen_loss=gen_acml_loss.item(),
            )
            self.logger.info(log_message)

            if self.config['g_spectral_norm']:
                gen_sigmas = calculate_all_sn(self.gen_model)
                self.writer.add_scalars('SN_of_gen', gen_sigmas, step_count)

            self.writer.add_scalars(
                'Losses', {
                    'discriminator': dis_acml_loss.item(),
                    'generator': gen_acml_loss.item()
                }, step_count)
            # Logs the last generated batch (not fixed noise) in this variant.
            self.writer.add_images('Generated samples',
                                   (fake_images + 1) / 2, step_count)

            if self.config['calculate_z_grad']:
                _, l2_norm_grad_z_aft_G_update = calc_derv(
                    self.fixed_noise, self.fixed_fake_labels, self.dis_model,
                    self.second_device, self.gen_model)
                self.writer.add_scalars(
                    'L2_norm_grad', {
                        'z_aft_G_update': l2_norm_grad_z_aft_G_update.mean().item()
                    }, step_count)

        if step_count % self.save_every == 0 or step_count == total_step:
            self.valid_and_save(step_count)
def run_ours(self, current_step, total_step):
    """Adversarial training loop for the 'ours' (contrastive) objective.

    Alternates D and G updates from `current_step` up to `total_step`,
    with gradient accumulation over `self.accumulation_steps` micro-batches,
    an optional temperature schedule for the contrastive loss, optional
    consistency regularization on augmented reals, optional orthogonal
    regularization on G, optional EMA of G's weights, and periodic
    logging / checkpointing via `self.valid_and_save`.

    Args:
        current_step: step index to resume from.
        total_step: last training step (inclusive).
    """
    # Pre-compute the discrete temperature schedule: one temperature per
    # interval of `total_step // (tempering_times + 1)` steps.
    if self.tempering and self.discrete_tempering:
        temperature_range = self.end_temperature - self.start_temperature
        temperature_change = temperature_range / self.tempering_times
        temperatures = [
            self.start_temperature + time * temperature_change
            for time in range(self.tempering_times + 1)
        ]
        # Duplicate the final temperature so the `step_count // interval`
        # lookup cannot index past the end on the last interval.
        temperatures += [
            self.start_temperature +
            self.tempering_times * temperature_change
        ]
        interval = total_step // (self.tempering_times + 1)
    self.dis_model.train()
    self.gen_model.train()
    if self.Gen_copy is not None:
        self.Gen_copy.train()
    step_count = current_step
    train_iter = iter(self.train_dataloader)
    # Stays None unless positive augmentation is enabled; passed as the
    # optional last argument of the contrastive criterion for D.
    cls_real_aug_embed = None
    while step_count <= total_step:
        # Select the temperature for this step: linear ramp, discrete
        # schedule, or a constant.
        if self.tempering and not self.discrete_tempering:
            t = self.start_temperature + step_count * (
                self.end_temperature - self.start_temperature) / total_step
        elif self.tempering and self.discrete_tempering:
            t = temperatures[step_count // interval]
        else:
            t = self.start_temperature

        # ================== TRAIN D ================== #
        for step_index in range(self.d_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            dis_loss = 0
            for acml_step in range(self.accumulation_steps):
                # Draw the next real batch; restart the iterator when the
                # dataloader is exhausted. With consistency regularization
                # or positive augmentation the loader also yields an
                # augmented view of each image.
                try:
                    if self.consistency_reg or self.make_positive_aug:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)
                except StopIteration:
                    train_iter = iter(self.train_dataloader)
                    if self.consistency_reg or self.make_positive_aug:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)
                images, real_labels = images.to(
                    self.second_device), real_labels.to(self.second_device)
                z, fake_labels = sample_latents(self.prior, self.batch_size,
                                                self.z_dim, 1,
                                                self.num_classes, None,
                                                self.second_device)
                real_cls_mask = make_mask(real_labels, self.num_classes,
                                          self.second_device)
                # D returns (class anchor, class embedding, authenticity
                # output) for both real and generated images.
                cls_real_anchor, cls_real_embed, dis_real_authen_out = self.dis_model(
                    images, real_labels)
                fake_images = self.gen_model(z, fake_labels)
                cls_fake_anchor, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                    fake_images, fake_labels)
                dis_acml_loss = self.D_loss(dis_real_authen_out,
                                            dis_fake_authen_out)
                if self.consistency_reg or self.make_positive_aug:
                    images_aug = images_aug.to(self.second_device)
                    _, cls_real_aug_embed, dis_real_aug_authen_out = self.dis_model(
                        images_aug, real_labels)
                    if self.consistency_reg:
                        # Penalize D for scoring an image and its
                        # augmentation differently.
                        consistency_loss = self.l2_loss(
                            dis_real_authen_out, dis_real_aug_authen_out)
                        dis_acml_loss += self.consistency_lambda * consistency_loss
                # Auxiliary classification term: either a softmax
                # cross-entropy posterior or the contrastive criterion
                # (temperature `t`, optional augmented embeddings).
                if self.softmax_posterior:
                    dis_acml_loss += self.ce_criterion(cls_real_embed,
                                                       real_labels)
                elif self.contrastive_softmax:
                    dis_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                        cls_real_embed, cls_real_anchor, real_cls_mask,
                        real_labels, t, cls_real_aug_embed)
                # Scale so gradients accumulated over the micro-batches
                # average to one effective batch.
                dis_acml_loss = dis_acml_loss / self.accumulation_steps
                dis_acml_loss.backward()
                dis_loss += dis_acml_loss.item()
            self.D_optimizer.step()
            # WGAN-style weight clipping, if enabled.
            if self.weight_clipping_for_dis:
                for p in self.dis_model.parameters():
                    p.data.clamp_(-self.weight_clipping_bound,
                                  self.weight_clipping_bound)
        if step_count % self.print_every == 0 and step_count != 0 and self.logger:
            if self.config['d_spectral_norm']:
                dis_sigmas = calculate_all_sn(self.dis_model)
                self.writer.add_scalars('SN_of_dis', dis_sigmas, step_count)
            if self.config['calculate_z_grad']:
                _, l2_norm_grad_z_aft_D_update = calc_derv(
                    self.fixed_noise, self.fixed_fake_labels, self.dis_model,
                    self.second_device, self.gen_model)
                self.writer.add_scalars(
                    'L2_norm_grad', {
                        'z_aft_D_update':
                        l2_norm_grad_z_aft_D_update.mean().item()
                    }, step_count)

        # ================== TRAIN G ================== #
        for step_index in range(self.g_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            gen_loss = 0
            for acml_step in range(self.accumulation_steps):
                z, fake_labels = sample_latents(self.prior, self.batch_size,
                                                self.z_dim, 1,
                                                self.num_classes, None,
                                                self.second_device)
                fake_cls_mask = make_mask(fake_labels, self.num_classes,
                                          self.second_device)
                fake_images = self.gen_model(z, fake_labels)
                cls_fake_anchor, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                    fake_images, fake_labels)
                gen_acml_loss = self.G_loss(dis_fake_authen_out)
                if self.softmax_posterior:
                    gen_acml_loss += self.ce_criterion(cls_fake_embed,
                                                       fake_labels)
                elif self.contrastive_softmax:
                    gen_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                        cls_fake_embed, cls_fake_anchor, fake_cls_mask,
                        fake_labels, t)
                gen_acml_loss = gen_acml_loss / self.accumulation_steps
                gen_acml_loss.backward()
                gen_loss += gen_acml_loss.item()
            # Orthogonal regularization on G's weights, excluding the
            # shared embedding parameters.
            if isinstance(
                    self.lambda4ortho, float
            ) and self.lambda4ortho > 0 and self.config['ortho_reg']:
                if isinstance(self.gen_model, DataParallel):
                    ortho(self.gen_model,
                          self.lambda4ortho,
                          blacklist=[
                              param for param
                              in self.gen_model.module.shared.parameters()
                          ])
                else:
                    ortho(self.gen_model,
                          self.lambda4ortho,
                          blacklist=[
                              param for param
                              in self.gen_model.shared.parameters()
                          ])
            self.G_optimizer.step()

            # if ema is True: we update parameters of the Gen_copy in adaptive way.
            if self.config['ema']:
                self.Gen_ema.update(step_count)

            step_count += 1
        if step_count % self.print_every == 0 and self.logger:
            log_message = LOG_FORMAT.format(
                step=step_count,
                progress=step_count / total_step,
                elapsed=elapsed_time(self.start_time),
                temperature=t,
                dis_loss=dis_loss,
                gen_loss=gen_loss,
            )
            self.logger.info(log_message)
            if self.config['g_spectral_norm']:
                gen_sigmas = calculate_all_sn(self.gen_model)
                self.writer.add_scalars('SN_of_gen', gen_sigmas, step_count)
            # NOTE(review): the scalar log uses the last micro-batch losses,
            # not the accumulated dis_loss/gen_loss used in the text log.
            self.writer.add_scalars(
                'Losses', {
                    'discriminator': dis_acml_loss.item(),
                    'generator': gen_acml_loss.item()
                }, step_count)
            # Map images from [-1, 1] back to [0, 1] for visualization.
            self.writer.add_images('Generated samples', (fake_images + 1) / 2,
                                   step_count)
            if self.config['calculate_z_grad']:
                _, l2_norm_grad_z_aft_G_update = calc_derv(
                    self.fixed_noise, self.fixed_fake_labels, self.dis_model,
                    self.second_device, self.gen_model)
                self.writer.add_scalars(
                    'L2_norm_grad', {
                        'z_aft_G_update':
                        l2_norm_grad_z_aft_G_update.mean().item()
                    }, step_count)
        if step_count % self.save_every == 0 or step_count == total_step:
            self.valid_and_save(step_count)
def calculate_acc_confidence(dataloader, generator, discriminator, G_loss,
                             num_evaluate, truncated_factor, prior, latent_op,
                             latent_op_step, latent_op_alpha, latent_op_beta,
                             device):
    """Evaluate how confidently the discriminator separates real from fake.

    Generates `num_evaluate` fake images, scores them and an equal number of
    real images with the discriminator, and computes accuracies against a
    loss-dependent decision threshold.

    Args:
        dataloader: loader of (image, label) real batches.
        generator / discriminator: models (possibly DataParallel-wrapped);
            discriminator returns (anchor, embedding, authenticity score).
        G_loss: generator loss function; its __name__ selects the cutoff
            and the label used for fake samples.
        num_evaluate: number of samples per side; must be divisible by the
            dataloader's batch size.
        truncated_factor, prior, latent_op, latent_op_step, latent_op_alpha,
            latent_op_beta: latent sampling / optimization settings.
        device: torch device for inputs.

    Returns:
        (only_real_acc, only_fake_acc, mixed_acc, confid, confid_label)
        where `confid` holds raw discriminator outputs and `confid_label`
        is 1.0 for real entries and `fake_target` for fake entries.

    Raises:
        Exception: if num_evaluate is not a multiple of the batch size.
        NotImplementedError: for the WGAN generator loss.
        ValueError: for any other unrecognized generator loss.
    """
    generator.eval()
    discriminator.eval()
    data_iter = iter(dataloader)
    batch_size = dataloader.batch_size

    if isinstance(generator, DataParallel):
        z_dim = generator.module.z_dim
        num_classes = generator.module.num_classes
    else:
        z_dim = generator.z_dim
        num_classes = generator.num_classes

    if num_evaluate % batch_size == 0:
        total_batch = num_evaluate // batch_size
    else:
        raise Exception("num_evaluate % batch_size should be 0!")

    # Decision threshold and fake-side target depend on the adversarial loss
    # the discriminator outputs were trained against.
    if G_loss.__name__ == "loss_dcgan_gen":
        cutoff = 0.5
        fake_target = 0.0
    elif G_loss.__name__ == "loss_hinge_gen":
        cutoff = 0.0
        fake_target = -1.0
    elif G_loss.__name__ == "loss_wgan_gen":
        raise NotImplementedError
    else:
        # Fail fast: previously an unknown loss fell through and caused a
        # NameError on the undefined `cutoff` below.
        raise ValueError("unsupported generator loss: %s" % G_loss.__name__)

    # Accumulate per-batch chunks and concatenate once after the loop;
    # np.concatenate inside the loop is quadratic in the number of batches.
    confid_chunks = []
    label_chunks = []
    for batch_id in tqdm(range(total_batch)):
        z, fake_labels = sample_latents(prior, batch_size, z_dim,
                                        truncated_factor, num_classes, None,
                                        device)
        if latent_op:
            z = latent_optimise(z, fake_labels, generator, discriminator,
                                latent_op_step, 1.0, latent_op_alpha,
                                latent_op_beta, False, device)

        images_real, real_labels = next(data_iter)
        images_real, real_labels = images_real.to(device), real_labels.to(
            device)

        with torch.no_grad():
            images_gen = generator(z, fake_labels)
            _, _, dis_out_fake = discriminator(images_gen, fake_labels)
            _, _, dis_out_real = discriminator(images_real, real_labels)
            dis_out_fake = dis_out_fake.detach().cpu().numpy()
            dis_out_real = dis_out_real.detach().cpu().numpy()

        # Keep the original per-batch ordering: fake scores, then real.
        confid_chunks.append(dis_out_fake)
        confid_chunks.append(dis_out_real)
        label_chunks.append([fake_target] * batch_size)
        label_chunks.append([1.0] * batch_size)

    confid = np.concatenate(confid_chunks, axis=0)
    confid_label = np.concatenate(label_chunks, axis=0)

    real_confid = confid[confid_label == 1.0]
    fake_confid = confid[confid_label == fake_target]
    # Reals above the cutoff and fakes below it count as correct decisions.
    true_positive = real_confid[np.where(real_confid > cutoff)]
    true_negative = fake_confid[np.where(fake_confid < cutoff)]
    only_real_acc = len(true_positive) / len(real_confid)
    only_fake_acc = len(true_negative) / len(fake_confid)
    mixed_acc = (len(true_positive) + len(true_negative)) / len(confid)

    generator.train()
    discriminator.train()
    return only_real_acc, only_fake_acc, mixed_acc, confid, confid_label