Example #1
0
    def run(self, current_step, total_step):
        """Adversarial training loop from ``current_step`` up to ``total_step``.

        Each outer iteration performs ``d_steps_per_iter`` discriminator
        updates followed by ``g_steps_per_iter`` generator updates, with
        gradients accumulated over ``accumulation_steps`` minibatches per
        optimizer step.  Losses, spectral norms and fixed-noise samples are
        logged every ``print_every`` steps; checkpoints are written every
        ``save_every`` steps and at the final step.
        """
        self.dis_model.train()
        self.gen_model.train()
        if self.Gen_copy is not None:
            self.Gen_copy.train()
        step_count = current_step
        train_iter = iter(self.train_dataloader)
        while step_count <= total_step:
            # ================== TRAIN D ================== #
            for step_index in range(self.d_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_index in range(self.accumulation_steps):
                    # Fetch a real batch; restart the dataloader once it is
                    # exhausted so the loop can run for arbitrarily many steps.
                    try:
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)
                    except StopIteration:
                        train_iter = iter(self.train_dataloader)
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)

                    images, real_labels = images.to(
                        self.second_device), real_labels.to(self.second_device)
                    if self.diff_aug:
                        images = DiffAugment(images, policy=self.policy)
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)

                    if self.latent_op:
                        z = latent_optimise(
                            z, fake_labels, self.gen_model, self.dis_model,
                            self.latent_op_step, self.latent_op_rate,
                            self.latent_op_alpha, self.latent_op_beta, False,
                            self.second_device)

                    fake_images = self.gen_model(z, fake_labels)
                    if self.diff_aug:
                        fake_images = DiffAugment(fake_images,
                                                  policy=self.policy)

                    # ACGAN discriminators additionally return class logits;
                    # plain cGAN / unconditional ones return only adversarial
                    # outputs.
                    if self.conditional_strategy == "ACGAN":
                        cls_out_real, dis_out_real = self.dis_model(
                            images, real_labels)
                        cls_out_fake, dis_out_fake = self.dis_model(
                            fake_images, fake_labels)
                    elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                        dis_out_real = self.dis_model(images, real_labels)
                        dis_out_fake = self.dis_model(fake_images, fake_labels)
                    else:
                        raise NotImplementedError

                    dis_acml_loss = self.D_loss(dis_out_real, dis_out_fake)

                    if self.conditional_strategy == "ACGAN":
                        # Auxiliary classification loss on both real and fake
                        # samples.
                        dis_acml_loss += (
                            self.ce_loss(cls_out_real, real_labels) +
                            self.ce_loss(cls_out_fake, fake_labels))

                    if self.gradient_penalty_for_dis:
                        # BUGFIX: the original referenced a bare, undefined
                        # name ``gradient_penelty_lambda`` which raises
                        # NameError as soon as this branch runs.  The
                        # coefficient is read from ``self`` like every other
                        # hyper-parameter in this method.
                        # NOTE(review): confirm the exact attribute name
                        # against the trainer's __init__.
                        dis_acml_loss += self.gradient_penelty_lambda * calc_derv4gp(
                            self.dis_model, images, fake_images, real_labels,
                            self.second_device)

                    if self.consistency_reg:
                        # Consistency regularization: penalize disagreement
                        # between the discriminator's outputs on an image and
                        # on its augmented view.
                        images_aug = images_aug.to(self.second_device)
                        if self.conditional_strategy == "ACGAN":
                            cls_out_real_aug, dis_out_real_aug = self.dis_model(
                                images_aug, real_labels)
                        elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                            dis_out_real_aug = self.dis_model(
                                images_aug, real_labels)
                        else:
                            raise NotImplementedError
                        consistency_loss = self.l2_loss(
                            dis_out_real, dis_out_real_aug)
                        dis_acml_loss += self.consistency_lambda * consistency_loss

                    # Scale so the accumulated gradients match one
                    # large-batch update.
                    dis_acml_loss = dis_acml_loss / self.accumulation_steps

                    dis_acml_loss.backward()

                self.D_optimizer.step()

                if self.weight_clipping_for_dis:
                    # WGAN-style weight clipping to bound the critic.
                    for p in self.dis_model.parameters():
                        p.data.clamp_(-self.weight_clipping_bound,
                                      self.weight_clipping_bound)

            if step_count % self.print_every == 0 and step_count != 0 and self.logger:
                if self.d_spectral_norm:
                    dis_sigmas = calculate_all_sn(self.dis_model)
                    self.writer.add_scalars('SN_of_dis', dis_sigmas,
                                            step_count)

            # ================== TRAIN G ================== #
            for step_index in range(self.g_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_step in range(self.accumulation_steps):
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)

                    if self.latent_op:
                        # With return_cost=True the optimiser also reports the
                        # transport cost, regularized below.
                        z, transport_cost = latent_optimise(
                            z, fake_labels, self.gen_model, self.dis_model,
                            self.latent_op_step, self.latent_op_rate,
                            self.latent_op_alpha, self.latent_op_beta, True,
                            self.second_device)

                    fake_images = self.gen_model(z, fake_labels)
                    if self.diff_aug:
                        fake_images = DiffAugment(fake_images,
                                                  policy=self.policy)

                    if self.conditional_strategy == "ACGAN":
                        cls_out_fake, dis_out_fake = self.dis_model(
                            fake_images, fake_labels)
                    elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                        dis_out_fake = self.dis_model(fake_images, fake_labels)
                    else:
                        raise NotImplementedError

                    gen_acml_loss = self.G_loss(dis_out_fake)
                    if self.conditional_strategy == "ACGAN":
                        gen_acml_loss += self.ce_loss(cls_out_fake,
                                                      fake_labels)

                    if self.latent_op:
                        gen_acml_loss += transport_cost * self.latent_norm_reg_weight
                    gen_acml_loss = gen_acml_loss / self.accumulation_steps

                    gen_acml_loss.backward()

                self.G_optimizer.step()

                # if ema is True: we update parameters of the Gen_copy in adaptive way.
                if self.ema:
                    self.Gen_ema.update(step_count)

                step_count += 1

            if step_count % self.print_every == 0 and self.logger:
                log_message = LOG_FORMAT.format(
                    step=step_count,
                    progress=step_count / total_step,
                    elapsed=elapsed_time(self.start_time),
                    temperature='No',
                    dis_loss=dis_acml_loss.item(),
                    gen_loss=gen_acml_loss.item(),
                )
                self.logger.info(log_message)

                if self.g_spectral_norm:
                    gen_sigmas = calculate_all_sn(self.gen_model)
                    self.writer.add_scalars('SN_of_gen', gen_sigmas,
                                            step_count)

                self.writer.add_scalars(
                    'Losses', {
                        'discriminator': dis_acml_loss.item(),
                        'generator': gen_acml_loss.item()
                    }, step_count)
                # Render a fixed-noise sample grid for visual progress
                # tracking; outputs are mapped from [-1, 1] to [0, 1].
                with torch.no_grad():
                    if self.Gen_copy is not None:
                        self.Gen_copy.train()
                        generator = self.Gen_copy
                    else:
                        self.gen_model.eval()
                        generator = self.gen_model

                    generated_images = generator(self.fixed_noise,
                                                 self.fixed_fake_labels)
                    self.writer.add_images('Generated samples',
                                           (generated_images + 1) / 2,
                                           step_count)
                    self.gen_model.train()

            if step_count % self.save_every == 0 or step_count == total_step:
                if self.evaluate:
                    is_best = self.evaluation(step_count)
                    self.save(step_count, is_best)
                else:
                    self.save(step_count, False)
Example #2
0
    def run(self, current_step, total_step):
        """Adversarial training loop from ``current_step`` up to ``total_step``.

        Alternates ``d_steps_per_iter`` discriminator updates with
        ``g_steps_per_iter`` generator updates, accumulating gradients over
        ``accumulation_steps`` minibatches per optimizer step.  Optionally
        applies gradient penalty, consistency regularization and orthogonal
        regularization; logs losses / spectral norms / latent gradient norms
        every ``print_every`` steps and validates + saves every
        ``save_every`` steps.
        """
        self.dis_model.train()
        self.gen_model.train()
        if self.Gen_copy is not None:
            self.Gen_copy.train()
        step_count = current_step
        train_iter = iter(self.train_dataloader)
        while step_count <= total_step:
            # ================== TRAIN D ================== #
            for step_index in range(self.d_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_index in range(self.accumulation_steps):
                    # Fetch a real batch; restart the dataloader once it is
                    # exhausted so the loop can run for arbitrarily many steps.
                    try:
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)
                    except StopIteration:
                        train_iter = iter(self.train_dataloader)
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)

                    images, real_labels = images.to(
                        self.second_device), real_labels.to(self.second_device)
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)

                    if self.latent_op:
                        z = latent_optimise(
                            z, fake_labels, self.gen_model, self.dis_model,
                            self.latent_op_step, self.latent_op_rate,
                            self.latent_op_alpha, self.latent_op_beta, False,
                            self.second_device)

                    # Discriminator returns (embedding, class logits,
                    # adversarial output); the first element is unused here.
                    _, cls_out_real, dis_out_real = self.dis_model(
                        images, real_labels)
                    fake_images = self.gen_model(z, fake_labels)
                    _, cls_out_fake, dis_out_fake = self.dis_model(
                        fake_images, fake_labels)

                    dis_acml_loss = self.D_loss(dis_out_real, dis_out_fake)
                    if self.auxiliary_classifier:
                        # ACGAN-style auxiliary classification loss.
                        dis_acml_loss += (
                            self.ce_loss(cls_out_real, real_labels) +
                            self.ce_loss(cls_out_fake, fake_labels))

                    if self.gradient_penalty_for_dis:
                        dis_acml_loss += self.lambda4gp * calc_derv4gp(
                            self.dis_model, images, fake_images, real_labels,
                            self.second_device)

                    if self.consistency_reg:
                        # Penalize disagreement between discriminator outputs
                        # on an image and its augmented view.
                        images_aug = images_aug.to(self.second_device)
                        _, _, dis_out_real_aug = self.dis_model(
                            images_aug, real_labels)
                        consistency_loss = self.l2_loss(
                            dis_out_real, dis_out_real_aug)
                        dis_acml_loss += self.consistency_lambda * consistency_loss

                    # Scale so the accumulated gradients match one
                    # large-batch update.
                    dis_acml_loss = dis_acml_loss / self.accumulation_steps

                    dis_acml_loss.backward()

                self.D_optimizer.step()

                if self.weight_clipping_for_dis:
                    # WGAN-style weight clipping to bound the critic.
                    for p in self.dis_model.parameters():
                        p.data.clamp_(-self.weight_clipping_bound,
                                      self.weight_clipping_bound)

            if step_count % self.print_every == 0 and step_count != 0 and self.logger:
                if self.config['d_spectral_norm']:
                    dis_sigmas = calculate_all_sn(self.dis_model)
                    self.writer.add_scalars('SN_of_dis', dis_sigmas,
                                            step_count)

                if self.config['calculate_z_grad']:
                    _, l2_norm_grad_z_aft_D_update = calc_derv(
                        self.fixed_noise, self.fixed_fake_labels,
                        self.dis_model, self.second_device, self.gen_model)
                    self.writer.add_scalars(
                        'L2_norm_grad', {
                            'z_aft_D_update':
                            l2_norm_grad_z_aft_D_update.mean().item()
                        }, step_count)

            # ================== TRAIN G ================== #
            for step_index in range(self.g_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_step in range(self.accumulation_steps):
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)

                    if self.latent_op:
                        # With return_cost=True the optimiser also reports the
                        # transport cost, regularized below.
                        z, transport_cost = latent_optimise(
                            z, fake_labels, self.gen_model, self.dis_model,
                            self.latent_op_step, self.latent_op_rate,
                            self.latent_op_alpha, self.latent_op_beta, True,
                            self.second_device)

                    fake_images = self.gen_model(z, fake_labels)
                    _, cls_out_fake, dis_out_fake = self.dis_model(
                        fake_images, fake_labels)

                    gen_acml_loss = self.G_loss(dis_out_fake)
                    if self.auxiliary_classifier:
                        gen_acml_loss += self.ce_loss(cls_out_fake,
                                                      fake_labels)

                    if self.latent_op:
                        gen_acml_loss += transport_cost * self.latent_norm_reg_weight
                    gen_acml_loss = gen_acml_loss / self.accumulation_steps

                    gen_acml_loss.backward()

                # BUGFIX: the original guard was ``self.lambda4ortho is
                # float``, an identity comparison against the ``float`` type
                # object that is always False, so orthogonal regularization
                # could never run.  ``isinstance`` is the intended check.
                if isinstance(self.lambda4ortho, float
                              ) and self.lambda4ortho > 0 and self.config[
                                  'ortho_reg']:
                    # Exclude the shared embedding from orthogonal
                    # regularization (BigGAN convention).
                    if isinstance(self.gen_model, DataParallel):
                        ortho(self.gen_model,
                              self.lambda4ortho,
                              blacklist=[
                                  param for param in
                                  self.gen_model.module.shared.parameters()
                              ])
                    else:
                        ortho(self.gen_model,
                              self.lambda4ortho,
                              blacklist=[
                                  param for param in
                                  self.gen_model.shared.parameters()
                              ])

                self.G_optimizer.step()

                # if ema is True: we update parameters of the Gen_copy in adaptive way.
                if self.config['ema']:
                    self.Gen_ema.update(step_count)

                step_count += 1

            if step_count % self.print_every == 0 and self.logger:
                log_message = LOG_FORMAT.format(
                    step=step_count,
                    progress=step_count / total_step,
                    elapsed=elapsed_time(self.start_time),
                    temperature='No',
                    dis_loss=dis_acml_loss.item(),
                    gen_loss=gen_acml_loss.item(),
                )
                self.logger.info(log_message)

                if self.config['g_spectral_norm']:
                    gen_sigmas = calculate_all_sn(self.gen_model)
                    self.writer.add_scalars('SN_of_gen', gen_sigmas,
                                            step_count)

                self.writer.add_scalars(
                    'Losses', {
                        'discriminator': dis_acml_loss.item(),
                        'generator': gen_acml_loss.item()
                    }, step_count)
                # Images are mapped from [-1, 1] to [0, 1] for TensorBoard.
                self.writer.add_images('Generated samples',
                                       (fake_images + 1) / 2, step_count)

                if self.config['calculate_z_grad']:
                    _, l2_norm_grad_z_aft_G_update = calc_derv(
                        self.fixed_noise, self.fixed_fake_labels,
                        self.dis_model, self.second_device, self.gen_model)
                    self.writer.add_scalars(
                        'L2_norm_grad', {
                            'z_aft_G_update':
                            l2_norm_grad_z_aft_G_update.mean().item()
                        }, step_count)

            if step_count % self.save_every == 0 or step_count == total_step:
                self.valid_and_save(step_count)
Example #3
0
    def run_ours(self, current_step, total_step):
        """Contrastive-conditional training loop (the "ours" variant).

        Like ``run``, alternates discriminator and generator updates with
        gradient accumulation, but the discriminator returns
        (proxies, embeddings, authenticity output) and both players add a
        temperature-scheduled contrastive loss on the embeddings.  Logs and
        checkpoints periodically.
        """
        self.dis_model.train()
        self.gen_model.train()
        if self.Gen_copy is not None:
            self.Gen_copy.train()
        step_count = current_step
        train_iter = iter(self.train_dataloader)
        while step_count <= total_step:
            # ================== TRAIN D ================== #
            # Temperature schedule for the contrastive loss: linear ramp
            # ('continuous'), stepped ramp ('discrete'), or constant.
            if self.tempering_type == 'continuous':
                t = self.start_temperature + step_count * (
                    self.end_temperature - self.start_temperature) / total_step
            elif self.tempering_type == 'discrete':
                t = self.start_temperature + \
                    (step_count//self.tempering_interval)*(self.end_temperature-self.start_temperature)/self.tempering_step
            else:
                t = self.start_temperature
            for step_index in range(self.d_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_step in range(self.accumulation_steps):
                    # Fetch a real batch; restart the dataloader once it is
                    # exhausted so the loop can run indefinitely.
                    try:
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)
                    except StopIteration:
                        train_iter = iter(self.train_dataloader)
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)

                    images, real_labels = images.to(
                        self.second_device), real_labels.to(self.second_device)
                    if self.diff_aug:
                        images = DiffAugment(images, policy=self.policy)
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)
                    # Same-class mask used by the contrastive criterion.
                    real_cls_mask = make_mask(real_labels, self.num_classes,
                                              self.second_device)

                    # Discriminator returns (class proxies, embeddings,
                    # authenticity output).
                    cls_real_proxies, cls_real_embed, dis_real_authen_out = self.dis_model(
                        images, real_labels)

                    fake_images = self.gen_model(z, fake_labels)
                    if self.diff_aug:
                        fake_images = DiffAugment(fake_images,
                                                  policy=self.policy)

                    # Fake proxies/embeddings are unused in the D loss below;
                    # only the authenticity output feeds D_loss.
                    cls_fake_proxies, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                        fake_images, fake_labels)

                    dis_acml_loss = self.D_loss(dis_real_authen_out,
                                                dis_fake_authen_out)
                    # Contrastive loss on real embeddings at temperature t.
                    dis_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                        cls_real_embed, cls_real_proxies, real_cls_mask,
                        real_labels, t)
                    if self.consistency_reg:
                        # Penalize disagreement between discriminator outputs
                        # on an image and its augmented view.
                        images_aug = images_aug.to(self.second_device)
                        _, cls_real_aug_embed, dis_real_aug_authen_out = self.dis_model(
                            images_aug, real_labels)
                        consistency_loss = self.l2_loss(
                            dis_real_authen_out, dis_real_aug_authen_out)
                        dis_acml_loss += self.consistency_lambda * consistency_loss

                    # Scale so accumulated gradients match one large-batch
                    # update.
                    dis_acml_loss = dis_acml_loss / self.accumulation_steps
                    dis_acml_loss.backward()

                self.D_optimizer.step()

                if self.weight_clipping_for_dis:
                    # WGAN-style weight clipping to bound the critic.
                    for p in self.dis_model.parameters():
                        p.data.clamp_(-self.weight_clipping_bound,
                                      self.weight_clipping_bound)

            if step_count % self.print_every == 0 and step_count != 0 and self.logger:
                if self.d_spectral_norm:
                    dis_sigmas = calculate_all_sn(self.dis_model)
                    self.writer.add_scalars('SN_of_dis', dis_sigmas,
                                            step_count)

            # ================== TRAIN G ================== #
            for step_index in range(self.g_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_step in range(self.accumulation_steps):
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)
                    fake_cls_mask = make_mask(fake_labels, self.num_classes,
                                              self.second_device)

                    fake_images = self.gen_model(z, fake_labels)
                    if self.diff_aug:
                        fake_images = DiffAugment(fake_images,
                                                  policy=self.policy)

                    cls_fake_proxies, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                        fake_images, fake_labels)

                    gen_acml_loss = self.G_loss(dis_fake_authen_out)
                    # Contrastive loss on fake embeddings at temperature t.
                    gen_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                        cls_fake_embed, cls_fake_proxies, fake_cls_mask,
                        fake_labels, t)
                    gen_acml_loss = gen_acml_loss / self.accumulation_steps
                    gen_acml_loss.backward()

                self.G_optimizer.step()

                # if ema is True: we update parameters of the Gen_copy in adaptive way.
                if self.ema:
                    self.Gen_ema.update(step_count)

                step_count += 1
            if step_count % self.print_every == 0 and self.logger:
                log_message = LOG_FORMAT.format(
                    step=step_count,
                    progress=step_count / total_step,
                    elapsed=elapsed_time(self.start_time),
                    temperature=t,
                    dis_loss=dis_acml_loss.item(),
                    gen_loss=gen_acml_loss.item(),
                )
                self.logger.info(log_message)

                if self.g_spectral_norm:
                    gen_sigmas = calculate_all_sn(self.gen_model)
                    self.writer.add_scalars('SN_of_gen', gen_sigmas,
                                            step_count)

                self.writer.add_scalars(
                    'Losses', {
                        'discriminator': dis_acml_loss.item(),
                        'generator': gen_acml_loss.item()
                    }, step_count)

                # Render a fixed-noise sample grid for visual progress
                # tracking; outputs are mapped from [-1, 1] to [0, 1].
                with torch.no_grad():
                    if self.Gen_copy is not None:
                        self.Gen_copy.train()
                        generator = self.Gen_copy
                    else:
                        self.gen_model.eval()
                        generator = self.gen_model

                    generated_images = generator(self.fixed_noise,
                                                 self.fixed_fake_labels)
                    self.writer.add_images('Generated samples',
                                           (generated_images + 1) / 2,
                                           step_count)
                    self.gen_model.train()

            if step_count % self.save_every == 0 or step_count == total_step:
                if self.evaluate:
                    is_best = self.evaluation(step_count)
                    self.save(step_count, is_best)
                else:
                    self.save(step_count, False)
Example #4
0
    def run_ours(self, current_step, total_step):
        if self.tempering and self.discrete_tempering:
            temperature_range = self.end_temperature - self.start_temperature
            temperature_change = temperature_range / self.tempering_times
            temperatures = [
                self.start_temperature + time * temperature_change
                for time in range(self.tempering_times + 1)
            ]
            temperatures += [
                self.start_temperature +
                self.tempering_times * temperature_change
            ]
            interval = total_step // (self.tempering_times + 1)

        self.dis_model.train()
        self.gen_model.train()
        if self.Gen_copy is not None:
            self.Gen_copy.train()
        step_count = current_step
        train_iter = iter(self.train_dataloader)
        cls_real_aug_embed = None
        while step_count <= total_step:
            # ================== TRAIN D ================== #
            if self.tempering and not self.discrete_tempering:
                t = self.start_temperature + step_count * (
                    self.end_temperature - self.start_temperature) / total_step
            elif self.tempering and self.discrete_tempering:
                t = temperatures[step_count // interval]
            else:
                t = self.start_temperature
            for step_index in range(self.d_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                dis_loss = 0
                for acml_step in range(self.accumulation_steps):
                    try:
                        if self.consistency_reg or self.make_positive_aug:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)
                    except StopIteration:
                        train_iter = iter(self.train_dataloader)
                        if self.consistency_reg or self.make_positive_aug:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)

                    images, real_labels = images.to(
                        self.second_device), real_labels.to(self.second_device)
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)
                    real_cls_mask = make_mask(real_labels, self.num_classes,
                                              self.second_device)

                    cls_real_anchor, cls_real_embed, dis_real_authen_out = self.dis_model(
                        images, real_labels)

                    fake_images = self.gen_model(z, fake_labels)
                    cls_fake_anchor, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                        fake_images, fake_labels)

                    dis_acml_loss = self.D_loss(dis_real_authen_out,
                                                dis_fake_authen_out)

                    if self.consistency_reg or self.make_positive_aug:
                        images_aug = images_aug.to(self.second_device)
                        _, cls_real_aug_embed, dis_real_aug_authen_out = self.dis_model(
                            images_aug, real_labels)
                        if self.consistency_reg:
                            consistency_loss = self.l2_loss(
                                dis_real_authen_out, dis_real_aug_authen_out)
                            dis_acml_loss += self.consistency_lambda * consistency_loss

                    if self.softmax_posterior:
                        dis_acml_loss += self.ce_criterion(
                            cls_real_embed, real_labels)
                    elif self.contrastive_softmax:
                        dis_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                            cls_real_embed, cls_real_anchor, real_cls_mask,
                            real_labels, t, cls_real_aug_embed)
                    dis_acml_loss = dis_acml_loss / self.accumulation_steps
                    dis_acml_loss.backward()
                    dis_loss += dis_acml_loss.item()

                self.D_optimizer.step()

                if self.weight_clipping_for_dis:
                    for p in self.dis_model.parameters():
                        p.data.clamp_(-self.weight_clipping_bound,
                                      self.weight_clipping_bound)

            # Periodic discriminator-side logging (skipped at step 0 and when
            # no logger is configured).
            if step_count % self.print_every == 0 and step_count != 0 and self.logger:
                if self.config['d_spectral_norm']:
                    # Record the spectral norm of each discriminator layer.
                    dis_sigmas = calculate_all_sn(self.dis_model)
                    self.writer.add_scalars('SN_of_dis', dis_sigmas,
                                            step_count)

                if self.config['calculate_z_grad']:
                    # L2 norm of the discriminator gradient w.r.t. the fixed
                    # latent codes, measured right after the D update.
                    _, l2_norm_grad_z_aft_D_update = calc_derv(
                        self.fixed_noise, self.fixed_fake_labels,
                        self.dis_model, self.second_device, self.gen_model)
                    self.writer.add_scalars(
                        'L2_norm_grad', {
                            'z_aft_D_update':
                            l2_norm_grad_z_aft_D_update.mean().item()
                        }, step_count)

            # ================== TRAIN G ================== #
            for step_index in range(self.g_steps_per_iter):
                # Zero both optimizers: the fake-image forward pass also
                # populates grads in D, which must not leak into the next
                # D update.
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                gen_loss = 0
                for acml_step in range(self.accumulation_steps):
                    # Fresh latent codes and fake labels per micro-batch.
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)
                    fake_cls_mask = make_mask(fake_labels, self.num_classes,
                                              self.second_device)

                    fake_images = self.gen_model(z, fake_labels)

                    cls_fake_anchor, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                        fake_images, fake_labels)

                    # Adversarial loss plus the same optional auxiliary
                    # classification losses used on the D side.
                    gen_acml_loss = self.G_loss(dis_fake_authen_out)
                    if self.softmax_posterior:
                        gen_acml_loss += self.ce_criterion(
                            cls_fake_embed, fake_labels)
                    elif self.contrastive_softmax:
                        gen_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                            cls_fake_embed, cls_fake_anchor, fake_cls_mask,
                            fake_labels, t)
                    # Normalize for gradient accumulation (mirrors the D loop).
                    gen_acml_loss = gen_acml_loss / self.accumulation_steps
                    gen_acml_loss.backward()
                    gen_loss += gen_acml_loss.item()

                # Optional orthogonal regularization on the generator, applied
                # to the gradients before the optimizer step; the shared class
                # embedding is excluded via the blacklist. The DataParallel
                # branch only changes how the shared module is reached
                # (`.module` indirection).
                if isinstance(
                        self.lambda4ortho, float
                ) and self.lambda4ortho > 0 and self.config['ortho_reg']:
                    if isinstance(self.gen_model, DataParallel):
                        ortho(self.gen_model,
                              self.lambda4ortho,
                              blacklist=[
                                  param for param in
                                  self.gen_model.module.shared.parameters()
                              ])
                    else:
                        ortho(self.gen_model,
                              self.lambda4ortho,
                              blacklist=[
                                  param for param in
                                  self.gen_model.shared.parameters()
                              ])

                self.G_optimizer.step()

                # If ema is enabled, update the Gen_copy parameters adaptively.
                if self.config['ema']:
                    self.Gen_ema.update(step_count)

                # One global step = one G update (D updates don't advance it).
                step_count += 1
            # Periodic generator-side logging. NOTE(review): step_count was
            # incremented inside the G loop above, so this block fires on a
            # different step value than the D-side logging block — confirm
            # this offset is intended.
            if step_count % self.print_every == 0 and self.logger:
                log_message = LOG_FORMAT.format(
                    step=step_count,
                    progress=step_count / total_step,
                    elapsed=elapsed_time(self.start_time),
                    temperature=t,
                    dis_loss=dis_loss,
                    gen_loss=gen_loss,
                )
                self.logger.info(log_message)

                if self.config['g_spectral_norm']:
                    # Record the spectral norm of each generator layer.
                    gen_sigmas = calculate_all_sn(self.gen_model)
                    self.writer.add_scalars('SN_of_gen', gen_sigmas,
                                            step_count)

                # These are the last micro-batch losses (already divided by
                # accumulation_steps), not the per-step accumulated sums
                # logged in the text message above.
                self.writer.add_scalars(
                    'Losses', {
                        'discriminator': dis_acml_loss.item(),
                        'generator': gen_acml_loss.item()
                    }, step_count)
                # (fake_images + 1) / 2 maps generator output to [0, 1] for
                # display — assumes outputs lie in [-1, 1] (tanh); confirm.
                self.writer.add_images('Generated samples',
                                       (fake_images + 1) / 2, step_count)

                if self.config['calculate_z_grad']:
                    # Same latent-gradient diagnostic as the D side, measured
                    # after the G update.
                    _, l2_norm_grad_z_aft_G_update = calc_derv(
                        self.fixed_noise, self.fixed_fake_labels,
                        self.dis_model, self.second_device, self.gen_model)
                    self.writer.add_scalars(
                        'L2_norm_grad', {
                            'z_aft_G_update':
                            l2_norm_grad_z_aft_G_update.mean().item()
                        }, step_count)

            # Periodic evaluation + checkpointing (also fires on the final step).
            if step_count % self.save_every == 0 or step_count == total_step:
                self.valid_and_save(step_count)