def run(self, current_step, total_step):
    """Main adversarial training loop (ACGAN / cGAN / unconditional).

    Alternates `d_steps_per_iter` discriminator updates with
    `g_steps_per_iter` generator updates, accumulating gradients over
    `accumulation_steps` micro-batches, until `step_count` passes
    `total_step`. Periodically logs losses/spectral norms, writes sample
    grids, and saves/evaluates checkpoints.

    Args:
        current_step: step index to resume from.
        total_step: last training step (inclusive).
    """
    self.dis_model.train()
    self.gen_model.train()
    if self.Gen_copy is not None:
        self.Gen_copy.train()

    step_count = current_step
    train_iter = iter(self.train_dataloader)
    while step_count <= total_step:
        # ================== TRAIN D ================== #
        for step_index in range(self.d_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_index in range(self.accumulation_steps):
                try:
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)
                except StopIteration:
                    # Dataloader exhausted: start a new epoch.
                    train_iter = iter(self.train_dataloader)
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)

                images, real_labels = images.to(
                    self.second_device), real_labels.to(self.second_device)
                if self.diff_aug:
                    images = DiffAugment(images, policy=self.policy)

                z, fake_labels = sample_latents(self.prior, self.batch_size,
                                                self.z_dim, 1,
                                                self.num_classes, None,
                                                self.second_device)
                if self.latent_op:
                    z = latent_optimise(z, fake_labels, self.gen_model,
                                        self.dis_model, self.latent_op_step,
                                        self.latent_op_rate,
                                        self.latent_op_alpha,
                                        self.latent_op_beta, False,
                                        self.second_device)

                fake_images = self.gen_model(z, fake_labels)
                if self.diff_aug:
                    fake_images = DiffAugment(fake_images, policy=self.policy)

                if self.conditional_strategy == "ACGAN":
                    cls_out_real, dis_out_real = self.dis_model(
                        images, real_labels)
                    cls_out_fake, dis_out_fake = self.dis_model(
                        fake_images, fake_labels)
                elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                    dis_out_real = self.dis_model(images, real_labels)
                    dis_out_fake = self.dis_model(fake_images, fake_labels)
                else:
                    raise NotImplementedError

                dis_acml_loss = self.D_loss(dis_out_real, dis_out_fake)
                if self.conditional_strategy == "ACGAN":
                    # Auxiliary classification loss on both real and fake.
                    dis_acml_loss += (
                        self.ce_loss(cls_out_real, real_labels) +
                        self.ce_loss(cls_out_fake, fake_labels))
                if self.gradient_penalty_for_dis:
                    # BUG FIX: the original referenced a bare, misspelled name
                    # `gradient_penelty_lambda`, which is undefined and raised
                    # NameError whenever this branch ran.  Use the instance
                    # attribute, following the `self.*_lambda` naming used by
                    # self.consistency_lambda below.
                    # TODO(review): confirm the attribute name in __init__.
                    dis_acml_loss += self.gradient_penalty_lambda * calc_derv4gp(
                        self.dis_model, images, fake_images, real_labels,
                        self.second_device)
                if self.consistency_reg:
                    images_aug = images_aug.to(self.second_device)
                    if self.conditional_strategy == "ACGAN":
                        cls_out_real_aug, dis_out_real_aug = self.dis_model(
                            images_aug, real_labels)
                    elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                        dis_out_real_aug = self.dis_model(
                            images_aug, real_labels)
                    else:
                        raise NotImplementedError
                    # Penalize disagreement between D's output on an image and
                    # on its augmented copy (consistency regularization).
                    consistency_loss = self.l2_loss(dis_out_real,
                                                    dis_out_real_aug)
                    dis_acml_loss += self.consistency_lambda * consistency_loss

                # Scale so accumulated gradients match a full-batch update.
                dis_acml_loss = dis_acml_loss / self.accumulation_steps
                dis_acml_loss.backward()

            self.D_optimizer.step()

            if self.weight_clipping_for_dis:
                for p in self.dis_model.parameters():
                    p.data.clamp_(-self.weight_clipping_bound,
                                  self.weight_clipping_bound)

        if step_count % self.print_every == 0 and step_count != 0 and self.logger:
            if self.d_spectral_norm:
                dis_sigmas = calculate_all_sn(self.dis_model)
                self.writer.add_scalars('SN_of_dis', dis_sigmas, step_count)

        # ================== TRAIN G ================== #
        for step_index in range(self.g_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_step in range(self.accumulation_steps):
                z, fake_labels = sample_latents(self.prior, self.batch_size,
                                                self.z_dim, 1,
                                                self.num_classes, None,
                                                self.second_device)
                if self.latent_op:
                    z, transport_cost = latent_optimise(
                        z, fake_labels, self.gen_model, self.dis_model,
                        self.latent_op_step, self.latent_op_rate,
                        self.latent_op_alpha, self.latent_op_beta, True,
                        self.second_device)

                fake_images = self.gen_model(z, fake_labels)
                if self.diff_aug:
                    fake_images = DiffAugment(fake_images, policy=self.policy)

                if self.conditional_strategy == "ACGAN":
                    cls_out_fake, dis_out_fake = self.dis_model(
                        fake_images, fake_labels)
                elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                    dis_out_fake = self.dis_model(fake_images, fake_labels)
                else:
                    raise NotImplementedError

                gen_acml_loss = self.G_loss(dis_out_fake)
                if self.conditional_strategy == "ACGAN":
                    gen_acml_loss += self.ce_loss(cls_out_fake, fake_labels)
                if self.latent_op:
                    gen_acml_loss += transport_cost * self.latent_norm_reg_weight
                gen_acml_loss = gen_acml_loss / self.accumulation_steps
                gen_acml_loss.backward()

            self.G_optimizer.step()

            # if ema is True: we update parameters of the Gen_copy in adaptive way.
            if self.ema:
                self.Gen_ema.update(step_count)

            step_count += 1

        if step_count % self.print_every == 0 and self.logger:
            log_message = LOG_FORMAT.format(
                step=step_count,
                progress=step_count / total_step,
                elapsed=elapsed_time(self.start_time),
                temperature='No',
                dis_loss=dis_acml_loss.item(),
                gen_loss=gen_acml_loss.item(),
            )
            self.logger.info(log_message)

            if self.g_spectral_norm:
                gen_sigmas = calculate_all_sn(self.gen_model)
                self.writer.add_scalars('SN_of_gen', gen_sigmas, step_count)
            self.writer.add_scalars(
                'Losses', {
                    'discriminator': dis_acml_loss.item(),
                    'generator': gen_acml_loss.item()
                }, step_count)

            with torch.no_grad():
                if self.Gen_copy is not None:
                    # NOTE(review): the EMA copy is kept in train mode here,
                    # matching the method prologue — presumably intentional
                    # (e.g. BN statistics); confirm.
                    self.Gen_copy.train()
                    generator = self.Gen_copy
                else:
                    self.gen_model.eval()
                    generator = self.gen_model
                generated_images = generator(self.fixed_noise,
                                             self.fixed_fake_labels)
                self.writer.add_images('Generated samples',
                                       (generated_images + 1) / 2, step_count)
                self.gen_model.train()

        if step_count % self.save_every == 0 or step_count == total_step:
            if self.evaluate:
                is_best = self.evaluation(step_count)
                self.save(step_count, is_best)
            else:
                self.save(step_count, False)
def run(self, current_step, total_step):
    """Adversarial training loop (config-dict variant).

    Differences from the sibling `run`: the discriminator returns a
    3-tuple, the auxiliary classifier is gated by
    `self.auxiliary_classifier`, optional orthogonal regularization is
    applied to the generator before each G step, and z-gradient L2 norms
    can be logged via `calc_derv`.

    Args:
        current_step: step index to resume from.
        total_step: last training step (inclusive).
    """
    self.dis_model.train()
    self.gen_model.train()
    if self.Gen_copy is not None:
        self.Gen_copy.train()

    step_count = current_step
    train_iter = iter(self.train_dataloader)
    while step_count <= total_step:
        # ================== TRAIN D ================== #
        for step_index in range(self.d_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_index in range(self.accumulation_steps):
                try:
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)
                except StopIteration:
                    # Dataloader exhausted: start a new epoch.
                    train_iter = iter(self.train_dataloader)
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)

                images, real_labels = images.to(
                    self.second_device), real_labels.to(self.second_device)
                z, fake_labels = sample_latents(self.prior, self.batch_size,
                                                self.z_dim, 1,
                                                self.num_classes, None,
                                                self.second_device)
                if self.latent_op:
                    z = latent_optimise(z, fake_labels, self.gen_model,
                                        self.dis_model, self.latent_op_step,
                                        self.latent_op_rate,
                                        self.latent_op_alpha,
                                        self.latent_op_beta, False,
                                        self.second_device)

                _, cls_out_real, dis_out_real = self.dis_model(
                    images, real_labels)
                fake_images = self.gen_model(z, fake_labels)
                _, cls_out_fake, dis_out_fake = self.dis_model(
                    fake_images, fake_labels)

                dis_acml_loss = self.D_loss(dis_out_real, dis_out_fake)
                if self.auxiliary_classifier:
                    # Auxiliary classification loss on both real and fake.
                    dis_acml_loss += (
                        self.ce_loss(cls_out_real, real_labels) +
                        self.ce_loss(cls_out_fake, fake_labels))
                if self.gradient_penalty_for_dis:
                    dis_acml_loss += self.lambda4gp * calc_derv4gp(
                        self.dis_model, images, fake_images, real_labels,
                        self.second_device)
                if self.consistency_reg:
                    images_aug = images_aug.to(self.second_device)
                    _, _, dis_out_real_aug = self.dis_model(
                        images_aug, real_labels)
                    # Penalize disagreement between D's output on an image and
                    # on its augmented copy (consistency regularization).
                    consistency_loss = self.l2_loss(dis_out_real,
                                                    dis_out_real_aug)
                    dis_acml_loss += self.consistency_lambda * consistency_loss

                # Scale so accumulated gradients match a full-batch update.
                dis_acml_loss = dis_acml_loss / self.accumulation_steps
                dis_acml_loss.backward()

            self.D_optimizer.step()

            if self.weight_clipping_for_dis:
                for p in self.dis_model.parameters():
                    p.data.clamp_(-self.weight_clipping_bound,
                                  self.weight_clipping_bound)

        if step_count % self.print_every == 0 and step_count != 0 and self.logger:
            if self.config['d_spectral_norm']:
                dis_sigmas = calculate_all_sn(self.dis_model)
                self.writer.add_scalars('SN_of_dis', dis_sigmas, step_count)
            if self.config['calculate_z_grad']:
                _, l2_norm_grad_z_aft_D_update = calc_derv(
                    self.fixed_noise, self.fixed_fake_labels, self.dis_model,
                    self.second_device, self.gen_model)
                self.writer.add_scalars(
                    'L2_norm_grad', {
                        'z_aft_D_update':
                        l2_norm_grad_z_aft_D_update.mean().item()
                    }, step_count)

        # ================== TRAIN G ================== #
        for step_index in range(self.g_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_step in range(self.accumulation_steps):
                z, fake_labels = sample_latents(self.prior, self.batch_size,
                                                self.z_dim, 1,
                                                self.num_classes, None,
                                                self.second_device)
                if self.latent_op:
                    z, transport_cost = latent_optimise(
                        z, fake_labels, self.gen_model, self.dis_model,
                        self.latent_op_step, self.latent_op_rate,
                        self.latent_op_alpha, self.latent_op_beta, True,
                        self.second_device)

                fake_images = self.gen_model(z, fake_labels)
                _, cls_out_fake, dis_out_fake = self.dis_model(
                    fake_images, fake_labels)

                gen_acml_loss = self.G_loss(dis_out_fake)
                if self.auxiliary_classifier:
                    gen_acml_loss += self.ce_loss(cls_out_fake, fake_labels)
                if self.latent_op:
                    gen_acml_loss += transport_cost * self.latent_norm_reg_weight
                gen_acml_loss = gen_acml_loss / self.accumulation_steps
                gen_acml_loss.backward()

            # BUG FIX: the original guard was `self.lambda4ortho is float`,
            # an identity comparison against the `float` type object that is
            # always False for a numeric value — so orthogonal regularization
            # never ran.  Use isinstance(), matching the sibling `run_ours`
            # implementation in this file.
            if isinstance(
                    self.lambda4ortho, float
            ) and self.lambda4ortho > 0 and self.config['ortho_reg']:
                if isinstance(self.gen_model, DataParallel):
                    # Exclude the shared embedding from ortho regularization.
                    ortho(self.gen_model,
                          self.lambda4ortho,
                          blacklist=[
                              param for param in
                              self.gen_model.module.shared.parameters()
                          ])
                else:
                    ortho(self.gen_model,
                          self.lambda4ortho,
                          blacklist=[
                              param
                              for param in self.gen_model.shared.parameters()
                          ])

            self.G_optimizer.step()

            # if ema is True: we update parameters of the Gen_copy in adaptive way.
            if self.config['ema']:
                self.Gen_ema.update(step_count)

            step_count += 1

        if step_count % self.print_every == 0 and self.logger:
            log_message = LOG_FORMAT.format(
                step=step_count,
                progress=step_count / total_step,
                elapsed=elapsed_time(self.start_time),
                temperature='No',
                dis_loss=dis_acml_loss.item(),
                gen_loss=gen_acml_loss.item(),
            )
            self.logger.info(log_message)

            if self.config['g_spectral_norm']:
                gen_sigmas = calculate_all_sn(self.gen_model)
                self.writer.add_scalars('SN_of_gen', gen_sigmas, step_count)
            self.writer.add_scalars(
                'Losses', {
                    'discriminator': dis_acml_loss.item(),
                    'generator': gen_acml_loss.item()
                }, step_count)
            self.writer.add_images('Generated samples', (fake_images + 1) / 2,
                                   step_count)
            if self.config['calculate_z_grad']:
                _, l2_norm_grad_z_aft_G_update = calc_derv(
                    self.fixed_noise, self.fixed_fake_labels, self.dis_model,
                    self.second_device, self.gen_model)
                self.writer.add_scalars(
                    'L2_norm_grad', {
                        'z_aft_G_update':
                        l2_norm_grad_z_aft_G_update.mean().item()
                    }, step_count)

        if step_count % self.save_every == 0 or step_count == total_step:
            self.valid_and_save(step_count)
def run_ours(self, current_step, total_step):
    """Contrastive-discriminator training loop with temperature tempering.

    The discriminator returns (class proxies, embedding, authenticity
    output); a contrastive loss weighted by `self.contrastive_lambda` is
    added to both D and G objectives, with temperature `t` annealed
    continuously or in discrete steps according to `self.tempering_type`.

    Args:
        current_step: step index to resume from.
        total_step: last training step (inclusive).
    """
    self.dis_model.train()
    self.gen_model.train()
    if self.Gen_copy is not None:
        self.Gen_copy.train()

    step_count = current_step
    train_iter = iter(self.train_dataloader)
    while step_count <= total_step:
        # ================== TRAIN D ================== #
        # Choose the contrastive temperature for this iteration.
        if self.tempering_type == 'continuous':
            # Linear anneal from start to end temperature over total_step.
            t = self.start_temperature + step_count * (
                self.end_temperature - self.start_temperature) / total_step
        elif self.tempering_type == 'discrete':
            # Staircase anneal: temperature jumps every tempering_interval steps.
            t = self.start_temperature + \
                (step_count//self.tempering_interval)*(self.end_temperature-self.start_temperature)/self.tempering_step
        else:
            t = self.start_temperature
        for step_index in range(self.d_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_step in range(self.accumulation_steps):
                try:
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)
                except StopIteration:
                    # Dataloader exhausted: start a new epoch.
                    train_iter = iter(self.train_dataloader)
                    if self.consistency_reg:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)

                images, real_labels = images.to(
                    self.second_device), real_labels.to(self.second_device)
                if self.diff_aug:
                    images = DiffAugment(images, policy=self.policy)
                z, fake_labels = sample_latents(self.prior, self.batch_size,
                                                self.z_dim, 1,
                                                self.num_classes, None,
                                                self.second_device)
                # Same-class mask used by the contrastive criterion.
                real_cls_mask = make_mask(real_labels, self.num_classes,
                                          self.second_device)
                cls_real_proxies, cls_real_embed, dis_real_authen_out = self.dis_model(
                    images, real_labels)
                fake_images = self.gen_model(z, fake_labels)
                if self.diff_aug:
                    fake_images = DiffAugment(fake_images, policy=self.policy)
                cls_fake_proxies, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                    fake_images, fake_labels)
                dis_acml_loss = self.D_loss(dis_real_authen_out,
                                            dis_fake_authen_out)
                # Contrastive loss on real embeddings vs. their class proxies.
                dis_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                    cls_real_embed, cls_real_proxies, real_cls_mask,
                    real_labels, t)
                if self.consistency_reg:
                    images_aug = images_aug.to(self.second_device)
                    _, cls_real_aug_embed, dis_real_aug_authen_out = self.dis_model(
                        images_aug, real_labels)
                    # Penalize disagreement between D's output on an image and
                    # on its augmented copy (consistency regularization).
                    consistency_loss = self.l2_loss(dis_real_authen_out,
                                                    dis_real_aug_authen_out)
                    dis_acml_loss += self.consistency_lambda * consistency_loss
                # Scale so accumulated gradients match a full-batch update.
                dis_acml_loss = dis_acml_loss / self.accumulation_steps
                dis_acml_loss.backward()

            self.D_optimizer.step()

            if self.weight_clipping_for_dis:
                for p in self.dis_model.parameters():
                    p.data.clamp_(-self.weight_clipping_bound,
                                  self.weight_clipping_bound)

        if step_count % self.print_every == 0 and step_count != 0 and self.logger:
            if self.d_spectral_norm:
                dis_sigmas = calculate_all_sn(self.dis_model)
                self.writer.add_scalars('SN_of_dis', dis_sigmas, step_count)

        # ================== TRAIN G ================== #
        for step_index in range(self.g_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            for acml_step in range(self.accumulation_steps):
                z, fake_labels = sample_latents(self.prior, self.batch_size,
                                                self.z_dim, 1,
                                                self.num_classes, None,
                                                self.second_device)
                fake_cls_mask = make_mask(fake_labels, self.num_classes,
                                          self.second_device)
                fake_images = self.gen_model(z, fake_labels)
                if self.diff_aug:
                    fake_images = DiffAugment(fake_images, policy=self.policy)
                cls_fake_proxies, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                    fake_images, fake_labels)
                gen_acml_loss = self.G_loss(dis_fake_authen_out)
                # Contrastive loss on fake embeddings vs. their class proxies.
                gen_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                    cls_fake_embed, cls_fake_proxies, fake_cls_mask,
                    fake_labels, t)
                gen_acml_loss = gen_acml_loss / self.accumulation_steps
                gen_acml_loss.backward()

            self.G_optimizer.step()

            # if ema is True: we update parameters of the Gen_copy in adaptive way.
            if self.ema:
                self.Gen_ema.update(step_count)

            step_count += 1

        if step_count % self.print_every == 0 and self.logger:
            log_message = LOG_FORMAT.format(
                step=step_count,
                progress=step_count / total_step,
                elapsed=elapsed_time(self.start_time),
                temperature=t,
                dis_loss=dis_acml_loss.item(),
                gen_loss=gen_acml_loss.item(),
            )
            self.logger.info(log_message)

            if self.g_spectral_norm:
                gen_sigmas = calculate_all_sn(self.gen_model)
                self.writer.add_scalars('SN_of_gen', gen_sigmas, step_count)
            self.writer.add_scalars(
                'Losses', {
                    'discriminator': dis_acml_loss.item(),
                    'generator': gen_acml_loss.item()
                }, step_count)

            with torch.no_grad():
                if self.Gen_copy is not None:
                    # NOTE(review): the EMA copy is kept in train mode here,
                    # matching the method prologue — presumably intentional;
                    # confirm.
                    self.Gen_copy.train()
                    generator = self.Gen_copy
                else:
                    self.gen_model.eval()
                    generator = self.gen_model
                generated_images = generator(self.fixed_noise,
                                             self.fixed_fake_labels)
                self.writer.add_images('Generated samples',
                                       (generated_images + 1) / 2, step_count)
                self.gen_model.train()

        if step_count % self.save_every == 0 or step_count == total_step:
            if self.evaluate:
                is_best = self.evaluation(step_count)
                self.save(step_count, is_best)
            else:
                self.save(step_count, False)
def run_ours(self, current_step, total_step):
    """Contrastive-discriminator training loop (config-dict variant).

    Like the sibling `run`, but the discriminator emits (anchor,
    embedding, authenticity) triples used by either a softmax-posterior
    classification loss or a contrastive loss with an optionally tempered
    temperature `t`; positive augmentations may feed the contrastive term.

    Args:
        current_step: step index to resume from.
        total_step: last training step (inclusive).
    """
    # Precompute the discrete temperature schedule, one value per interval.
    if self.tempering and self.discrete_tempering:
        temperature_range = self.end_temperature - self.start_temperature
        temperature_change = temperature_range / self.tempering_times
        temperatures = [
            self.start_temperature + time * temperature_change
            for time in range(self.tempering_times + 1)
        ]
        # Duplicate the final temperature so the index lookup below cannot
        # go out of range when step_count reaches total_step.
        temperatures += [
            self.start_temperature +
            self.tempering_times * temperature_change
        ]
        interval = total_step // (self.tempering_times + 1)

    self.dis_model.train()
    self.gen_model.train()
    if self.Gen_copy is not None:
        self.Gen_copy.train()

    step_count = current_step
    train_iter = iter(self.train_dataloader)
    # Stays None unless augmented embeddings are produced below; forwarded
    # as the optional last argument of contrastive_criterion.
    cls_real_aug_embed = None
    while step_count <= total_step:
        # ================== TRAIN D ================== #
        if self.tempering and not self.discrete_tempering:
            # Linear anneal from start to end temperature over total_step.
            t = self.start_temperature + step_count * (
                self.end_temperature - self.start_temperature) / total_step
        elif self.tempering and self.discrete_tempering:
            t = temperatures[step_count // interval]
        else:
            t = self.start_temperature

        for step_index in range(self.d_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            dis_loss = 0
            for acml_step in range(self.accumulation_steps):
                try:
                    if self.consistency_reg or self.make_positive_aug:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)
                except StopIteration:
                    # Dataloader exhausted: start a new epoch.
                    train_iter = iter(self.train_dataloader)
                    if self.consistency_reg or self.make_positive_aug:
                        images, real_labels, images_aug = next(train_iter)
                    else:
                        images, real_labels = next(train_iter)

                images, real_labels = images.to(
                    self.second_device), real_labels.to(self.second_device)
                z, fake_labels = sample_latents(self.prior, self.batch_size,
                                                self.z_dim, 1,
                                                self.num_classes, None,
                                                self.second_device)
                # Same-class mask used by the contrastive criterion.
                real_cls_mask = make_mask(real_labels, self.num_classes,
                                          self.second_device)
                cls_real_anchor, cls_real_embed, dis_real_authen_out = self.dis_model(
                    images, real_labels)
                fake_images = self.gen_model(z, fake_labels)
                cls_fake_anchor, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                    fake_images, fake_labels)

                dis_acml_loss = self.D_loss(dis_real_authen_out,
                                            dis_fake_authen_out)
                if self.consistency_reg or self.make_positive_aug:
                    images_aug = images_aug.to(self.second_device)
                    _, cls_real_aug_embed, dis_real_aug_authen_out = self.dis_model(
                        images_aug, real_labels)
                    if self.consistency_reg:
                        # Penalize disagreement between D's output on an image
                        # and on its augmented copy.
                        consistency_loss = self.l2_loss(
                            dis_real_authen_out, dis_real_aug_authen_out)
                        dis_acml_loss += self.consistency_lambda * consistency_loss
                if self.softmax_posterior:
                    dis_acml_loss += self.ce_criterion(cls_real_embed,
                                                       real_labels)
                elif self.contrastive_softmax:
                    dis_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                        cls_real_embed, cls_real_anchor, real_cls_mask,
                        real_labels, t, cls_real_aug_embed)

                # Scale so accumulated gradients match a full-batch update.
                dis_acml_loss = dis_acml_loss / self.accumulation_steps
                dis_acml_loss.backward()
                dis_loss += dis_acml_loss.item()

            self.D_optimizer.step()

            if self.weight_clipping_for_dis:
                for p in self.dis_model.parameters():
                    p.data.clamp_(-self.weight_clipping_bound,
                                  self.weight_clipping_bound)

        if step_count % self.print_every == 0 and step_count != 0 and self.logger:
            if self.config['d_spectral_norm']:
                dis_sigmas = calculate_all_sn(self.dis_model)
                self.writer.add_scalars('SN_of_dis', dis_sigmas, step_count)
            if self.config['calculate_z_grad']:
                _, l2_norm_grad_z_aft_D_update = calc_derv(
                    self.fixed_noise, self.fixed_fake_labels, self.dis_model,
                    self.second_device, self.gen_model)
                self.writer.add_scalars(
                    'L2_norm_grad', {
                        'z_aft_D_update':
                        l2_norm_grad_z_aft_D_update.mean().item()
                    }, step_count)

        # ================== TRAIN G ================== #
        for step_index in range(self.g_steps_per_iter):
            self.D_optimizer.zero_grad()
            self.G_optimizer.zero_grad()
            gen_loss = 0
            for acml_step in range(self.accumulation_steps):
                z, fake_labels = sample_latents(self.prior, self.batch_size,
                                                self.z_dim, 1,
                                                self.num_classes, None,
                                                self.second_device)
                fake_cls_mask = make_mask(fake_labels, self.num_classes,
                                          self.second_device)
                fake_images = self.gen_model(z, fake_labels)
                cls_fake_anchor, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                    fake_images, fake_labels)

                gen_acml_loss = self.G_loss(dis_fake_authen_out)
                if self.softmax_posterior:
                    gen_acml_loss += self.ce_criterion(cls_fake_embed,
                                                       fake_labels)
                elif self.contrastive_softmax:
                    gen_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                        cls_fake_embed, cls_fake_anchor, fake_cls_mask,
                        fake_labels, t)
                gen_acml_loss = gen_acml_loss / self.accumulation_steps
                gen_acml_loss.backward()
                gen_loss += gen_acml_loss.item()

            if isinstance(
                    self.lambda4ortho, float
            ) and self.lambda4ortho > 0 and self.config['ortho_reg']:
                if isinstance(self.gen_model, DataParallel):
                    # Exclude the shared embedding from ortho regularization.
                    ortho(self.gen_model,
                          self.lambda4ortho,
                          blacklist=[
                              param for param in
                              self.gen_model.module.shared.parameters()
                          ])
                else:
                    ortho(self.gen_model,
                          self.lambda4ortho,
                          blacklist=[
                              param
                              for param in self.gen_model.shared.parameters()
                          ])

            self.G_optimizer.step()

            # if ema is True: we update parameters of the Gen_copy in adaptive way.
            if self.config['ema']:
                self.Gen_ema.update(step_count)

            step_count += 1

        if step_count % self.print_every == 0 and self.logger:
            log_message = LOG_FORMAT.format(
                step=step_count,
                progress=step_count / total_step,
                elapsed=elapsed_time(self.start_time),
                temperature=t,
                dis_loss=dis_loss,
                gen_loss=gen_loss,
            )
            self.logger.info(log_message)

            if self.config['g_spectral_norm']:
                gen_sigmas = calculate_all_sn(self.gen_model)
                self.writer.add_scalars('SN_of_gen', gen_sigmas, step_count)
            # CONSISTENCY FIX: the original wrote the *last micro-batch*
            # losses (dis_acml_loss / gen_acml_loss, each already divided by
            # accumulation_steps) to TensorBoard while the text log above used
            # the accumulated dis_loss / gen_loss.  Log the accumulated values
            # in both places so the two views agree.
            self.writer.add_scalars(
                'Losses', {
                    'discriminator': dis_loss,
                    'generator': gen_loss
                }, step_count)
            self.writer.add_images('Generated samples', (fake_images + 1) / 2,
                                   step_count)
            if self.config['calculate_z_grad']:
                _, l2_norm_grad_z_aft_G_update = calc_derv(
                    self.fixed_noise, self.fixed_fake_labels, self.dis_model,
                    self.second_device, self.gen_model)
                self.writer.add_scalars(
                    'L2_norm_grad', {
                        'z_aft_G_update':
                        l2_norm_grad_z_aft_G_update.mean().item()
                    }, step_count)

        if step_count % self.save_every == 0 or step_count == total_step:
            self.valid_and_save(step_count)