Example #1
    def generate_images(self, gen, dis, truncated_factor, prior, latent_op,
                        latent_op_step, latent_op_alpha, latent_op_beta,
                        batch_size):
        if isinstance(gen, DataParallel):
            z_dim = gen.module.z_dim
            num_classes = gen.module.num_classes
            conditional_strategy = dis.module.conditional_strategy
        else:
            z_dim = gen.z_dim
            num_classes = gen.num_classes
            conditional_strategy = dis.conditional_strategy

        zs, fake_labels = sample_latents(prior, batch_size, z_dim,
                                         truncated_factor, num_classes, None,
                                         self.device)

        if latent_op:
            zs = latent_optimise(zs, fake_labels, gen, dis,
                                 conditional_strategy, latent_op_step, 1.0,
                                 latent_op_alpha, latent_op_beta, False,
                                 self.device)

        with torch.no_grad():
            batch_images = gen(zs, fake_labels, evaluation=True)

        return batch_images
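
Every example in this collection delegates latent sampling to `sample_latents`, whose body is not shown here. As a rough illustration only, a minimal, self-contained sketch of Gaussian sampling with the truncation trick (the helper name and signature below are hypothetical, not the repository's) could look like this:

import torch

def sample_truncated_gaussian_latents(batch_size, z_dim, num_classes, truncated_factor, device):
    # Hypothetical sketch: draw z ~ N(0, I) and resample coordinates whose
    # magnitude exceeds truncated_factor (the "truncation trick"), plus random labels.
    zs = torch.randn(batch_size, z_dim, device=device)
    if truncated_factor > 0:
        mask = zs.abs() > truncated_factor
        while mask.any():
            zs[mask] = torch.randn(int(mask.sum()), device=device)
            mask = zs.abs() > truncated_factor
    fake_labels = torch.randint(num_classes, (batch_size,), device=device)
    return zs, fake_labels

With a non-positive truncation factor (as Example #8 below sets during training), the resampling branch is skipped and plain Gaussian noise is returned.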
Example #2
def generate_images_for_KNN(batch_size, real_label, gen_model, dis_model,
                            truncated_factor, prior, latent_op, latent_op_step,
                            latent_op_alpha, latent_op_beta, device):
    if isinstance(gen_model, DataParallel) or isinstance(
            gen_model, DistributedDataParallel):
        z_dim = gen_model.module.z_dim
        num_classes = gen_model.module.num_classes
        conditional_strategy = dis_model.module.conditional_strategy
    else:
        z_dim = gen_model.z_dim
        num_classes = gen_model.num_classes
        conditional_strategy = dis_model.conditional_strategy

    zs, fake_labels = sample_latents(prior, batch_size, z_dim,
                                     truncated_factor, num_classes, None,
                                     device, real_label)

    if latent_op:
        zs = latent_optimise(zs, fake_labels, gen_model, dis_model,
                             conditional_strategy, latent_op_step, 1.0,
                             latent_op_alpha, latent_op_beta, False, device)

    with torch.no_grad():
        batch_images = gen_model(zs, fake_labels, evaluation=True)

    return batch_images, list(fake_labels.detach().cpu().numpy())
Example #3
def generate_images(batch_size, gen, dis, truncated_factor, prior, latent_op, latent_op_step,
                    latent_op_alpha, latent_op_beta, device):
    if isinstance(gen, DataParallel):
        z_dim = gen.module.z_dim
        num_classes = gen.module.num_classes
    else:
        z_dim = gen.z_dim
        num_classes = gen.num_classes
    
    z, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, device)

    if latent_op:
        z = latent_optimise(z, fake_labels, gen, dis, latent_op_step, 1.0, latent_op_alpha, latent_op_beta, False, device)
    
    with torch.no_grad():
        batch_images = gen(z, fake_labels)

    return batch_images
Example #4
def generate_images_for_KNN(batch_size, real_label, gen_model, dis_model,
                            truncated_factor, prior, latent_op, latent_op_step,
                            latent_op_alpha, latent_op_beta, device):
    if isinstance(gen_model, DataParallel):
        z_dim = gen_model.module.z_dim
        num_classes = gen_model.module.num_classes
    else:
        z_dim = gen_model.z_dim
        num_classes = gen_model.num_classes

    z, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor,
                                    num_classes, None, device, real_label)

    if latent_op:
        z = latent_optimise(z, fake_labels, gen_model, dis_model,
                            latent_op_step, 1.0, latent_op_alpha,
                            latent_op_beta, False, device)

    with torch.no_grad():
        batch_images = gen_model(z, fake_labels)

    return batch_images, list(fake_labels.detach().cpu().numpy())
Example #5
def calculate_accuracy(dataloader, generator, discriminator, D_loss, num_evaluate, truncated_factor, prior, latent_op,
                       latent_op_step, latent_op_alpha, latent_op_beta, device, cr, logger, eval_generated_sample=False):
    data_iter = iter(dataloader)
    batch_size = dataloader.batch_size
    disable_tqdm = device != 0

    if isinstance(generator, DataParallel) or isinstance(generator, DistributedDataParallel):
        z_dim = generator.module.z_dim
        num_classes = generator.module.num_classes
        conditional_strategy = discriminator.module.conditional_strategy
    else:
        z_dim = generator.z_dim
        num_classes = generator.num_classes
        conditional_strategy = discriminator.conditional_strategy

    total_batch = num_evaluate//batch_size

    if D_loss.__name__ in ["loss_dcgan_dis", "loss_lsgan_dis"]:
        cutoff = 0.0
    elif D_loss.__name__ == "loss_hinge_dis":
        cutoff = 0.0
    elif D_loss.__name__ == "loss_wgan_dis":
        raise NotImplementedError

    if device == 0: logger.info("Calculate Accuracies....")

    if eval_generated_sample:
        for batch_id in tqdm(range(total_batch), disable=disable_tqdm):
            zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, device)
            if latent_op:
                zs = latent_optimise(zs, fake_labels, generator, discriminator, conditional_strategy, latent_op_step,
                                     1.0, latent_op_alpha, latent_op_beta, False, device)

            real_images, real_labels = next(data_iter)
            real_images, real_labels = real_images.to(device), real_labels.to(device)

            fake_images = generator(zs, fake_labels, evaluation=True)

            with torch.no_grad():
                if conditional_strategy in ["ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"]:
                    _, _, dis_out_fake = discriminator(fake_images, fake_labels)
                    _, _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ACGAN":
                    _, dis_out_fake = discriminator(fake_images, fake_labels)
                    _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ProjGAN" or conditional_strategy == "no":
                    dis_out_fake = discriminator(fake_images, fake_labels)
                    dis_out_real = discriminator(real_images, real_labels)
                else:
                    raise NotImplementedError

                dis_out_fake = dis_out_fake.detach().cpu().numpy()
                dis_out_real = dis_out_real.detach().cpu().numpy()

            if batch_id == 0:
                confid = np.concatenate((dis_out_fake, dis_out_real), axis=0)
                confid_label = np.concatenate(([0.0]*len(dis_out_fake), [1.0]*len(dis_out_real)), axis=0)
            else:
                confid = np.concatenate((confid, dis_out_fake, dis_out_real), axis=0)
                confid_label = np.concatenate((confid_label, [0.0]*len(dis_out_fake), [1.0]*len(dis_out_real)), axis=0)

        real_confid = confid[confid_label==1.0]
        fake_confid = confid[confid_label==0.0]

        true_positive = real_confid[np.where(real_confid>cutoff)]
        true_negative = fake_confid[np.where(fake_confid<cutoff)]

        only_real_acc = len(true_positive)/len(real_confid)
        only_fake_acc = len(true_negative)/len(fake_confid)

        return only_real_acc, only_fake_acc
    else:
        for batch_id in tqdm(range(total_batch), disable=disable_tqdm):
            real_images, real_labels = next(data_iter)
            real_images, real_labels = real_images.to(device), real_labels.to(device)

            with torch.no_grad():
                if conditional_strategy in ["ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"]:
                    _, _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ACGAN":
                    _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ProjGAN" or conditional_strategy == "no":
                    dis_out_real = discriminator(real_images, real_labels)
                else:
                    raise NotImplementedError

                dis_out_real = dis_out_real.detach().cpu().numpy()

            if batch_id == 0:
                confid = dis_out_real
                confid_label = np.asarray([1.0]*len(dis_out_real), np.float32)
            else:
                confid = np.concatenate((confid, dis_out_real), axis=0)
                confid_label = np.concatenate((confid_label, [1.0]*len(dis_out_real)), axis=0)

        real_confid = confid[confid_label==1.0]
        true_positive = real_confid[np.where(real_confid>cutoff)]
        only_real_acc = len(true_positive)/len(real_confid)

        return only_real_acc
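
The bookkeeping above ultimately reduces to thresholding discriminator outputs at a loss-dependent cutoff. For clarity, here is a self-contained NumPy sketch of just that final step (function and argument names are illustrative, not part of the repository):

import numpy as np

def cutoff_accuracy(dis_out_real, dis_out_fake, cutoff=0.0):
    # Fraction of real samples scored above the cutoff and of fake samples
    # scored below it, mirroring only_real_acc / only_fake_acc above.
    dis_out_real = np.asarray(dis_out_real)
    dis_out_fake = np.asarray(dis_out_fake)
    only_real_acc = float(np.mean(dis_out_real > cutoff))
    only_fake_acc = float(np.mean(dis_out_fake < cutoff))
    return only_real_acc, only_fake_acc

# Example: cutoff_accuracy([0.3, -0.1, 0.8], [-0.5, 0.2]) -> (0.666..., 0.5)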
Example #6
    def evaluation(self, step):
        with torch.no_grad():
            self.logger.info(
                "Start Evaluation ({step} Step): {run_name}".format(
                    step=step, run_name=self.run_name))
            is_best = False

            self.dis_model.eval()
            self.gen_model.eval()
            if self.Gen_copy is not None:
                self.Gen_copy.train()
                generator = self.Gen_copy
            else:
                generator = self.gen_model

            if self.dataset_name == "imagenet" or self.dataset_name == "tiny_imagenet":
                num_eval = {'train': 50000, 'valid': 50000}
            elif self.dataset_name == "cifar10":
                num_eval = {'train': 50000, 'test': 10000}

            if self.latent_op:
                self.fixed_noise = latent_optimise(
                    self.fixed_noise, self.fixed_fake_labels, generator,
                    self.dis_model, self.latent_op_step, self.latent_op_rate,
                    self.latent_op_alpha, self.latent_op_beta, False,
                    self.second_device)

            fid_score, self.m1, self.s1 = calculate_fid_score(
                self.eval_dataloader, generator, self.dis_model,
                self.inception_model, num_eval[self.type4eval_dataset],
                self.truncated_factor, self.prior, self.latent_op,
                self.latent_op_step4eval, self.latent_op_alpha,
                self.latent_op_beta, self.second_device, self.mu, self.sigma,
                self.run_name)

            ### Pre-calculate an inception score.
            ### Computing the inception score with the code below gives an underestimated value;
            ### please use the official TensorFlow implementation (inception_tensorflow.py) instead.
            kl_score, kl_std = calculate_incep_score(
                self.eval_dataloader, generator, self.dis_model,
                self.inception_model, num_eval[self.type4eval_dataset],
                self.truncated_factor, self.prior, self.latent_op,
                self.latent_op_step4eval, self.latent_op_alpha,
                self.latent_op_beta, 10, self.second_device)

            real_train_acc, fake_acc = calculate_accuracy(
                self.train_dataloader,
                generator,
                self.dis_model,
                self.D_loss,
                10000,
                self.truncated_factor,
                self.prior,
                self.latent_op,
                self.latent_op_step,
                self.latent_op_alpha,
                self.latent_op_beta,
                self.second_device,
                eval_generated_sample=True)

            if self.type4eval_dataset == 'train':
                acc_dict = {'real_train': real_train_acc, 'fake': fake_acc}
            else:
                real_eval_acc = calculate_accuracy(self.eval_dataloader,
                                                   generator,
                                                   self.dis_model,
                                                   self.D_loss,
                                                   10000,
                                                   self.truncated_factor,
                                                   self.prior,
                                                   self.latent_op,
                                                   self.latent_op_step,
                                                   self.latent_op_alpha,
                                                   self.latent_op_beta,
                                                   self.second_device,
                                                   eval_generated_sample=False)
                acc_dict = {
                    'real_train': real_train_acc,
                    'real_valid': real_eval_acc,
                    'fake': fake_acc
                }

            self.writer.add_scalars('Accuracy', acc_dict, step)

            if self.best_fid is None:
                self.best_fid, self.best_step, is_best = fid_score, step, True
            else:
                if fid_score <= self.best_fid:
                    self.best_fid, self.best_step, is_best = fid_score, step, True

            self.writer.add_scalars(
                'FID score', {
                    'using {type} moments'.format(type=self.type4eval_dataset):
                    fid_score
                }, step)
            self.writer.add_scalars(
                'IS score', {
                    '{num} generated images'.format(num=str(num_eval[self.type4eval_dataset])):
                    kl_score
                }, step)
            self.logger.info(
                'FID score (Step: {step}, Using {type} moments): {FID}'.format(
                    step=step, type=self.type4eval_dataset, FID=fid_score))
            self.logger.info(
                'Inception score (Step: {step}, {num} generated images): {IS}'.
                format(step=step,
                       num=str(num_eval[self.type4eval_dataset]),
                       IS=kl_score))
            self.logger.info(
                'Best FID score (Step: {step}, Using {type} moments): {FID}'.
                format(step=self.best_step,
                       type=self.type4eval_dataset,
                       FID=self.best_fid))

            self.dis_model.train()
            self.gen_model.train()
            if self.Gen_copy is not None:
                self.Gen_copy.train()

        return is_best
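
`calculate_fid_score` above returns the generated-sample moments `self.m1, self.s1`, which are compared against the reference moments `self.mu, self.sigma`. As a reminder of the quantity being logged, here is a minimal sketch of the standard Fréchet distance between two Gaussians fitted to Inception features; it follows the usual formula and is not the repository's implementation:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))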
Example #7
    def run(self, current_step, total_step):
        self.dis_model.train()
        self.gen_model.train()
        if self.Gen_copy is not None:
            self.Gen_copy.train()
        step_count = current_step
        train_iter = iter(self.train_dataloader)
        while step_count <= total_step:
            # ================== TRAIN D ================== #
            for step_index in range(self.d_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_index in range(self.accumulation_steps):
                    try:
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)
                    except StopIteration:
                        train_iter = iter(self.train_dataloader)
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)

                    images, real_labels = images.to(
                        self.second_device), real_labels.to(self.second_device)
                    if self.diff_aug:
                        images = DiffAugment(images, policy=self.policy)
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)

                    if self.latent_op:
                        z = latent_optimise(
                            z, fake_labels, self.gen_model, self.dis_model,
                            self.latent_op_step, self.latent_op_rate,
                            self.latent_op_alpha, self.latent_op_beta, False,
                            self.second_device)

                    fake_images = self.gen_model(z, fake_labels)
                    if self.diff_aug:
                        fake_images = DiffAugment(fake_images,
                                                  policy=self.policy)

                    if self.conditional_strategy == "ACGAN":
                        cls_out_real, dis_out_real = self.dis_model(
                            images, real_labels)
                        cls_out_fake, dis_out_fake = self.dis_model(
                            fake_images, fake_labels)
                    elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                        dis_out_real = self.dis_model(images, real_labels)
                        dis_out_fake = self.dis_model(fake_images, fake_labels)
                    else:
                        raise NotImplementedError

                    dis_acml_loss = self.D_loss(dis_out_real, dis_out_fake)

                    if self.conditional_strategy == "ACGAN":
                        dis_acml_loss += (
                            self.ce_loss(cls_out_real, real_labels) +
                            self.ce_loss(cls_out_fake, fake_labels))

                    if self.gradient_penalty_for_dis:
                        dis_acml_loss += self.gradient_penalty_lambda * calc_derv4gp(
                            self.dis_model, images, fake_images, real_labels,
                            self.second_device)

                    if self.consistency_reg:
                        images_aug = images_aug.to(self.second_device)
                        if self.conditional_strategy == "ACGAN":
                            cls_out_real_aug, dis_out_real_aug = self.dis_model(
                                images_aug, real_labels)
                        elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                            dis_out_real_aug = self.dis_model(
                                images_aug, real_labels)
                        else:
                            raise NotImplementedError
                        consistency_loss = self.l2_loss(
                            dis_out_real, dis_out_real_aug)
                        dis_acml_loss += self.consistency_lambda * consistency_loss

                    dis_acml_loss = dis_acml_loss / self.accumulation_steps

                    dis_acml_loss.backward()

                self.D_optimizer.step()

                if self.weight_clipping_for_dis:
                    for p in self.dis_model.parameters():
                        p.data.clamp_(-self.weight_clipping_bound,
                                      self.weight_clipping_bound)

            if step_count % self.print_every == 0 and step_count != 0 and self.logger:
                if self.d_spectral_norm:
                    dis_sigmas = calculate_all_sn(self.dis_model)
                    self.writer.add_scalars('SN_of_dis', dis_sigmas,
                                            step_count)

            # ================== TRAIN G ================== #
            for step_index in range(self.g_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_step in range(self.accumulation_steps):
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)

                    if self.latent_op:
                        z, transport_cost = latent_optimise(
                            z, fake_labels, self.gen_model, self.dis_model,
                            self.latent_op_step, self.latent_op_rate,
                            self.latent_op_alpha, self.latent_op_beta, True,
                            self.second_device)

                    fake_images = self.gen_model(z, fake_labels)
                    if self.diff_aug:
                        fake_images = DiffAugment(fake_images,
                                                  policy=self.policy)

                    if self.conditional_strategy == "ACGAN":
                        cls_out_fake, dis_out_fake = self.dis_model(
                            fake_images, fake_labels)
                    elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                        dis_out_fake = self.dis_model(fake_images, fake_labels)
                    else:
                        raise NotImplementedError

                    gen_acml_loss = self.G_loss(dis_out_fake)
                    if self.conditional_strategy == "ACGAN":
                        gen_acml_loss += self.ce_loss(cls_out_fake,
                                                      fake_labels)

                    if self.latent_op:
                        gen_acml_loss += transport_cost * self.latent_norm_reg_weight
                    gen_acml_loss = gen_acml_loss / self.accumulation_steps

                    gen_acml_loss.backward()

                self.G_optimizer.step()

                # If EMA is enabled, update the parameters of Gen_copy with an exponential moving average.
                if self.ema:
                    self.Gen_ema.update(step_count)

                step_count += 1

            if step_count % self.print_every == 0 and self.logger:
                log_message = LOG_FORMAT.format(
                    step=step_count,
                    progress=step_count / total_step,
                    elapsed=elapsed_time(self.start_time),
                    temperature='No',
                    dis_loss=dis_acml_loss.item(),
                    gen_loss=gen_acml_loss.item(),
                )
                self.logger.info(log_message)

                if self.g_spectral_norm:
                    gen_sigmas = calculate_all_sn(self.gen_model)
                    self.writer.add_scalars('SN_of_gen', gen_sigmas,
                                            step_count)

                self.writer.add_scalars(
                    'Losses', {
                        'discriminator': dis_acml_loss.item(),
                        'generator': gen_acml_loss.item()
                    }, step_count)
                with torch.no_grad():
                    if self.Gen_copy is not None:
                        self.Gen_copy.train()
                        generator = self.Gen_copy
                    else:
                        self.gen_model.eval()
                        generator = self.gen_model

                    generated_images = generator(self.fixed_noise,
                                                 self.fixed_fake_labels)
                    self.writer.add_images('Generated samples',
                                           (generated_images + 1) / 2,
                                           step_count)
                    self.gen_model.train()

            if step_count % self.save_every == 0 or step_count == total_step:
                if self.evaluate:
                    is_best = self.evaluation(step_count)
                    self.save(step_count, is_best)
                else:
                    self.save(step_count, False)
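
Both training loops in this collection keep an averaged copy of the generator and call `self.Gen_ema.update(step_count)` after every generator step, but the EMA helper itself is not listed. Below is a minimal sketch of such an updater, assuming a plain parameter-wise exponential moving average (the class name and decay schedule are hypothetical):

import torch

class GeneratorEmaSketch:
    # Hypothetical EMA helper: target <- decay * target + (1 - decay) * source.
    def __init__(self, source, target, decay=0.9999, start_step=0):
        self.source, self.target = source, target
        self.decay, self.start_step = decay, start_step

    def update(self, step):
        # Before start_step, copy the source weights verbatim; afterwards, blend.
        decay = 0.0 if step < self.start_step else self.decay
        with torch.no_grad():
            source_params = dict(self.source.named_parameters())
            for name, target_param in self.target.named_parameters():
                target_param.copy_(decay * target_param +
                                   (1.0 - decay) * source_params[name])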
Example #8
def generate_images(z_prior, truncation_factor, batch_size, z_dim, num_classes,
                    y_sampler, radius, generator, discriminator, is_train,
                    LOSS, RUN, MODEL, device, is_stylegan, generator_mapping,
                    generator_synthesis, style_mixing_p, stylegan_update_emas,
                    cal_trsp_cost):
    if is_train:
        truncation_factor = -1.0
        lo_steps = LOSS.lo_steps4train
        apply_langevin = False
    else:
        lo_steps = LOSS.lo_steps4eval
        if truncation_factor != -1:
            if is_stylegan:
                assert 0 <= truncation_factor <= 1, "Stylegan truncation_factor must lie between 0 (strong truncation) and 1 (no truncation)"
            else:
                assert 0 <= truncation_factor, "truncation_factor must lie between 0 (strong truncation) and inf (no truncation)"

    zs, fake_labels, zs_eps = sample_zy(
        z_prior=z_prior,
        batch_size=batch_size,
        z_dim=z_dim,
        num_classes=num_classes,
        truncation_factor=-1 if is_stylegan else truncation_factor,
        y_sampler=y_sampler,
        radius=radius,
        device=device)
    info_discrete_c, info_conti_c = None, None
    if MODEL.info_type in ["discrete", "both"]:
        info_discrete_c = torch.randint(
            MODEL.info_dim_discrete_c, (batch_size, MODEL.info_num_discrete_c),
            device=device)
        zs = torch.cat(
            (zs, F.one_hot(info_discrete_c, MODEL.info_dim_discrete_c).view(
                batch_size, -1)),
            dim=1)
    if MODEL.info_type in ["continuous", "both"]:
        info_conti_c = torch.rand(
            batch_size, MODEL.info_num_conti_c, device=device) * 2 - 1
        zs = torch.cat((zs, info_conti_c), dim=1)

    trsp_cost = None
    if LOSS.apply_lo:
        zs, trsp_cost = losses.latent_optimise(zs=zs,
                                               fake_labels=fake_labels,
                                               generator=generator,
                                               discriminator=discriminator,
                                               batch_size=batch_size,
                                               lo_rate=LOSS.lo_rate,
                                               lo_steps=lo_steps,
                                               lo_alpha=LOSS.lo_alpha,
                                               lo_beta=LOSS.lo_beta,
                                               eval=not is_train,
                                               cal_trsp_cost=cal_trsp_cost,
                                               device=device)
    if not is_train and RUN.langevin_sampling:
        zs = langevin_sampling(zs=zs,
                               z_dim=z_dim,
                               fake_labels=fake_labels,
                               generator=generator,
                               discriminator=discriminator,
                               batch_size=batch_size,
                               langevin_rate=RUN.langevin_rate,
                               langevin_noise_std=RUN.langevin_noise_std,
                               langevin_decay=RUN.langevin_decay,
                               langevin_decay_steps=RUN.langevin_decay_steps,
                               langevin_steps=RUN.langevin_steps,
                               device=device)
    if is_stylegan:
        ws, fake_images = stylegan_generate_images(
            zs=zs,
            fake_labels=fake_labels,
            num_classes=num_classes,
            style_mixing_p=style_mixing_p,
            update_emas=stylegan_update_emas,
            generator_mapping=generator_mapping,
            generator_synthesis=generator_synthesis,
            truncation_psi=truncation_factor,
            truncation_cutoff=RUN.truncation_cutoff)
    else:
        fake_images = generator(zs, fake_labels, eval=not is_train)
        ws = None

    if zs_eps is not None:
        if is_stylegan:
            ws_eps, fake_images_eps = stylegan_generate_images(
                zs=zs_eps,
                fake_labels=fake_labels,
                num_classes=num_classes,
                style_mixing_p=style_mixing_p,
                update_emas=stylegan_update_emas,
                generator_mapping=generator_mapping,
                generator_synthesis=generator_synthesis,
                truncation_psi=truncation_factor,
                truncation_cutoff=RUN.truncation_cutoff)
        else:
            _, fake_images_eps = generator(zs_eps,
                                           fake_labels,
                                           eval=not is_train)
    else:
        fake_images_eps = None
    return fake_images, fake_labels, fake_images_eps, trsp_cost, ws, info_discrete_c, info_conti_c
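
At evaluation time, Example #8 optionally refines the latents with `langevin_sampling`, which is not shown. The following is a simplified, self-contained sketch of latent-space Langevin updates driven by the discriminator score (SGLD-style); it ignores the decay schedule and assumes the discriminator returns one score per sample, so treat it as an assumption rather than the repository's code:

import torch

def langevin_sampling_sketch(zs, fake_labels, generator, discriminator,
                             langevin_rate=0.01, langevin_noise_std=0.05,
                             langevin_steps=10):
    # Repeatedly push the latents uphill on the discriminator score and add noise.
    for _ in range(langevin_steps):
        zs = zs.clone().detach().requires_grad_(True)
        score = discriminator(generator(zs, fake_labels), fake_labels).sum()
        grad = torch.autograd.grad(score, zs)[0]
        with torch.no_grad():
            zs = zs + 0.5 * langevin_rate * grad \
                 + langevin_noise_std * torch.randn_like(zs)
    return zs.detach()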
Example #9
    def valid_and_save(self, step):
        self.dis_model.eval()
        self.gen_model.eval()
        if self.Gen_copy is not None:
            self.Gen_copy.train()
            generator = self.Gen_copy
        else:
            generator = self.gen_model
        if self.evaluate:
            if self.latent_op:
                self.fixed_noise = latent_optimise(
                    self.fixed_noise, self.fixed_fake_labels, generator,
                    self.dis_model, self.latent_op_step, self.latent_op_rate,
                    self.latent_op_alpha, self.latent_op_beta, False,
                    self.second_device)

            fake_images = generator(self.fixed_noise,
                                    self.fixed_fake_labels).detach().cpu()
            plot_generated_samples_path = join(
                'figures', self.run_name,
                "[{}]generated_samples.png".format(step))
            plot_img_canvas(fake_images, plot_generated_samples_path,
                            self.logger)
            self.writer.add_images('Generated samples', (fake_images + 1) / 2,
                                   step)

            fid_score, self.m1, self.s1 = calculate_fid_score(
                self.evaluation_dataloader, generator,
                self.dis_model, self.inception_model,
                len(self.evaluation_dataloader.dataset), self.truncated_factor,
                self.prior, self.latent_op, self.latent_op_step4eval,
                self.latent_op_alpha, self.latent_op_beta, self.second_device,
                self.mu, self.sigma)

            ### Pre-calculate an inception score.
            ### Computing the inception score with the code below gives an underestimated value;
            ### please use the official TensorFlow implementation (inception_tensorflow.py) instead.
            kl_score, kl_std = calculate_incep_score(
                self.evaluation_dataloader, generator, self.dis_model,
                self.inception_model, 10000, self.truncated_factor, self.prior,
                self.latent_op, self.latent_op_step4eval, self.latent_op_alpha,
                self.latent_op_beta, 2, self.second_device)

            self.writer.add_scalars('FID_score',
                                    {'using_train_moments': fid_score}, step)
            self.writer.add_scalars('IS_score',
                                    {'50000_generated_images': kl_score}, step)
        else:
            kl_score = "xxx"
            kl_std = "xxx"
            fid_score = "xxx"

        checkpoint_name = SAVE_FORMAT.format(
            step=step,
            Inception_mean=kl_score,
            Inception_std=kl_std,
            FID=fid_score,
        )

        g_states = {
            'seed': self.config['seed'],
            'run_name': self.run_name,
            'step': step,
            'state_dict': self.gen_model.state_dict(),
            'optimizer': self.G_optimizer.state_dict(),
        }

        d_states = {
            'seed': self.config['seed'],
            'run_name': self.run_name,
            'step': step,
            'state_dict': self.dis_model.state_dict(),
            'optimizer': self.D_optimizer.state_dict(),
            'best_val_fid': self.best_val_fid,
            'best_checkpoint_fid_path': self.best_checkpoint_fid_path,
            'best_val_is': self.best_val_is,
            'best_checkpoint_is_path': self.best_checkpoint_is_path,
        }

        checkpoint_output_path = join(self.checkpoint_dir, checkpoint_name)
        g_checkpoint_output_path = join(self.checkpoint_dir,
                                        "model=G-" + checkpoint_name)
        d_checkpoint_output_path = join(self.checkpoint_dir,
                                        "model=D-" + checkpoint_name)

        torch.save(g_states, g_checkpoint_output_path)
        torch.save(d_states, d_checkpoint_output_path)

        if self.Gen_copy is not None:
            g_ema_states = {'state_dict': self.Gen_copy.state_dict()}
            g_ema_checkpoint_output_path = join(
                self.checkpoint_dir, "model=G_ema-" + checkpoint_name)
            torch.save(g_ema_states, g_ema_checkpoint_output_path)

        if self.evaluate:
            representative_val_fid = fid_score
            if self.best_val_fid is None or self.best_val_fid > representative_val_fid:
                self.best_val_fid = representative_val_fid
                self.best_checkpoint_fid_path = checkpoint_output_path

            representative_val_is = kl_score
            if self.best_val_is is None or self.best_val_is < representative_val_is:
                self.best_val_is = representative_val_is
                self.best_checkpoint_is_path = checkpoint_output_path

        if self.logger:
            self.logger.info(
                "Saved model to {}".format(checkpoint_output_path))

        if self.evaluate:
            self.logger.info("Current best model(FID) is {}".format(
                self.best_checkpoint_fid_path))
            self.logger.info("Current best model(IS) is {}".format(
                self.best_checkpoint_is_path))

        self.dis_model.train()
        self.gen_model.train()
        if self.Gen_copy is not None:
            self.Gen_copy.train()
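
`calculate_incep_score`, used in Examples #6 and #9, returns a mean and standard deviation over splits (`kl_score, kl_std`). Here is a minimal NumPy sketch of the underlying KL-based inception score, given an array of softmax probabilities from a classifier (illustrative only, not the repository's implementation):

import numpy as np

def inception_score_sketch(probs, splits=10):
    # IS = exp(mean_x KL(p(y|x) || p(y))), computed per split; return mean and std.
    scores = []
    for chunk in np.array_split(probs, splits):
        p_y = chunk.mean(axis=0, keepdims=True)
        kl = (chunk * (np.log(chunk + 1e-12) - np.log(p_y + 1e-12))).sum(axis=1)
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores)), float(np.std(scores))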
Example #10
    def run(self, current_step, total_step):
        self.dis_model.train()
        self.gen_model.train()
        if self.Gen_copy is not None:
            self.Gen_copy.train()
        step_count = current_step
        train_iter = iter(self.train_dataloader)
        while step_count <= total_step:
            # ================== TRAIN D ================== #
            for step_index in range(self.d_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_index in range(self.accumulation_steps):
                    try:
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)
                    except StopIteration:
                        train_iter = iter(self.train_dataloader)
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)

                    images, real_labels = images.to(
                        self.second_device), real_labels.to(self.second_device)
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)

                    if self.latent_op:
                        z = latent_optimise(
                            z, fake_labels, self.gen_model, self.dis_model,
                            self.latent_op_step, self.latent_op_rate,
                            self.latent_op_alpha, self.latent_op_beta, False,
                            self.second_device)

                    _, cls_out_real, dis_out_real = self.dis_model(
                        images, real_labels)
                    fake_images = self.gen_model(z, fake_labels)
                    _, cls_out_fake, dis_out_fake = self.dis_model(
                        fake_images, fake_labels)

                    dis_acml_loss = self.D_loss(dis_out_real, dis_out_fake)
                    if self.auxiliary_classifier:
                        dis_acml_loss += (
                            self.ce_loss(cls_out_real, real_labels) +
                            self.ce_loss(cls_out_fake, fake_labels))

                    if self.gradient_penalty_for_dis:
                        dis_acml_loss += self.lambda4gp * calc_derv4gp(
                            self.dis_model, images, fake_images, real_labels,
                            self.second_device)

                    if self.consistency_reg:
                        images_aug = images_aug.to(self.second_device)
                        _, _, dis_out_real_aug = self.dis_model(
                            images_aug, real_labels)
                        consistency_loss = self.l2_loss(
                            dis_out_real, dis_out_real_aug)
                        dis_acml_loss += self.consistency_lambda * consistency_loss

                    dis_acml_loss = dis_acml_loss / self.accumulation_steps

                    dis_acml_loss.backward()

                self.D_optimizer.step()

                if self.weight_clipping_for_dis:
                    for p in self.dis_model.parameters():
                        p.data.clamp_(-self.weight_clipping_bound,
                                      self.weight_clipping_bound)

            if step_count % self.print_every == 0 and step_count != 0 and self.logger:
                if self.config['d_spectral_norm']:
                    dis_sigmas = calculate_all_sn(self.dis_model)
                    self.writer.add_scalars('SN_of_dis', dis_sigmas,
                                            step_count)

                if self.config['calculate_z_grad']:
                    _, l2_norm_grad_z_aft_D_update = calc_derv(
                        self.fixed_noise, self.fixed_fake_labels,
                        self.dis_model, self.second_device, self.gen_model)
                    self.writer.add_scalars(
                        'L2_norm_grad', {
                            'z_aft_D_update':
                            l2_norm_grad_z_aft_D_update.mean().item()
                        }, step_count)

            # ================== TRAIN G ================== #
            for step_index in range(self.g_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_step in range(self.accumulation_steps):
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)

                    if self.latent_op:
                        z, transport_cost = latent_optimise(
                            z, fake_labels, self.gen_model, self.dis_model,
                            self.latent_op_step, self.latent_op_rate,
                            self.latent_op_alpha, self.latent_op_beta, True,
                            self.second_device)

                    fake_images = self.gen_model(z, fake_labels)
                    _, cls_out_fake, dis_out_fake = self.dis_model(
                        fake_images, fake_labels)

                    gen_acml_loss = self.G_loss(dis_out_fake)
                    if self.auxiliary_classifier:
                        gen_acml_loss += self.ce_loss(cls_out_fake,
                                                      fake_labels)

                    if self.latent_op:
                        gen_acml_loss += transport_cost * self.latent_norm_reg_weight
                    gen_acml_loss = gen_acml_loss / self.accumulation_steps

                    gen_acml_loss.backward()

                if isinstance(self.lambda4ortho, float) and self.lambda4ortho > 0 \
                        and self.config['ortho_reg']:
                    if isinstance(self.gen_model, DataParallel):
                        ortho(self.gen_model,
                              self.lambda4ortho,
                              blacklist=[
                                  param for param in
                                  self.gen_model.module.shared.parameters()
                              ])
                    else:
                        ortho(self.gen_model,
                              self.lambda4ortho,
                              blacklist=[
                                  param for param in
                                  self.gen_model.shared.parameters()
                              ])

                self.G_optimizer.step()

                # If EMA is enabled, update the parameters of Gen_copy with an exponential moving average.
                if self.config['ema']:
                    self.Gen_ema.update(step_count)

                step_count += 1

            if step_count % self.print_every == 0 and self.logger:
                log_message = LOG_FORMAT.format(
                    step=step_count,
                    progress=step_count / total_step,
                    elapsed=elapsed_time(self.start_time),
                    temperature='No',
                    dis_loss=dis_acml_loss.item(),
                    gen_loss=gen_acml_loss.item(),
                )
                self.logger.info(log_message)

                if self.config['g_spectral_norm']:
                    gen_sigmas = calculate_all_sn(self.gen_model)
                    self.writer.add_scalars('SN_of_gen', gen_sigmas,
                                            step_count)

                self.writer.add_scalars(
                    'Losses', {
                        'discriminator': dis_acml_loss.item(),
                        'generator': gen_acml_loss.item()
                    }, step_count)
                self.writer.add_images('Generated samples',
                                       (fake_images + 1) / 2, step_count)

                if self.config['calculate_z_grad']:
                    _, l2_norm_grad_z_aft_G_update = calc_derv(
                        self.fixed_noise, self.fixed_fake_labels,
                        self.dis_model, self.second_device, self.gen_model)
                    self.writer.add_scalars(
                        'L2_norm_grad', {
                            'z_aft_G_update':
                            l2_norm_grad_z_aft_G_update.mean().item()
                        }, step_count)

            if step_count % self.save_every == 0 or step_count == total_step:
                self.valid_and_save(step_count)
Example #11
def calculate_acc_confidence(dataloader, generator, discriminator, G_loss,
                             num_evaluate, truncated_factor, prior, latent_op,
                             latent_op_step, latent_op_alpha, latent_op_beta,
                             device):
    generator.eval()
    discriminator.eval()
    data_iter = iter(dataloader)
    batch_size = dataloader.batch_size

    if isinstance(generator, DataParallel):
        z_dim = generator.module.z_dim
        num_classes = generator.module.num_classes
    else:
        z_dim = generator.z_dim
        num_classes = generator.num_classes

    if num_evaluate % batch_size == 0:
        total_batch = num_evaluate // batch_size
    else:
        raise Exception("num_evaluate '%' batch4metrics should be 0!")

    if G_loss.__name__ == "loss_dcgan_gen":
        cutoff = 0.5
        fake_target = 0.0
    elif G_loss.__name__ == "loss_hinge_gen":
        cutoff = 0.0
        fake_target = -1.0
    elif G_loss.__name__ == "loss_wgan_gen":
        raise NotImplementedError

    for batch_id in tqdm(range(total_batch)):
        z, fake_labels = sample_latents(prior, batch_size, z_dim,
                                        truncated_factor, num_classes, None,
                                        device)
        if latent_op:
            z = latent_optimise(z, fake_labels, generator, discriminator,
                                latent_op_step, 1.0, latent_op_alpha,
                                latent_op_beta, False, device)

        images_real, real_labels = next(data_iter)
        images_real, real_labels = images_real.to(device), real_labels.to(
            device)

        with torch.no_grad():
            images_gen = generator(z, fake_labels)
            _, _, dis_out_fake = discriminator(images_gen, fake_labels)
            _, _, dis_out_real = discriminator(images_real, real_labels)
            dis_out_fake = dis_out_fake.detach().cpu().numpy()
            dis_out_real = dis_out_real.detach().cpu().numpy()

        if batch_id == 0:
            confid = np.concatenate((dis_out_fake, dis_out_real), axis=0)
            confid_label = np.concatenate(
                ([fake_target] * batch_size, [1.0] * batch_size), axis=0)
        else:
            confid = np.concatenate((confid, dis_out_fake, dis_out_real),
                                    axis=0)
            confid_label = np.concatenate(
                (confid_label, [fake_target] * batch_size, [1.0] * batch_size),
                axis=0)

    real_confid = confid[confid_label == 1.0]
    fake_confid = confid[confid_label == fake_target]

    true_positive = real_confid[np.where(real_confid > cutoff)]
    true_negative = fake_confid[np.where(fake_confid < cutoff)]

    only_real_acc = len(true_positive) / len(real_confid)
    only_fake_acc = len(true_negative) / len(fake_confid)
    mixed_acc = (len(true_positive) + len(true_negative)) / len(confid)

    generator.train()
    discriminator.train()

    return only_real_acc, only_fake_acc, mixed_acc, confid, confid_label