Example #1
    def generate_images(self, gen, dis, truncated_factor, prior, latent_op,
                        latent_op_step, latent_op_alpha, latent_op_beta,
                        batch_size):
        if isinstance(gen, DataParallel):
            z_dim = gen.module.z_dim
            num_classes = gen.module.num_classes
            conditional_strategy = dis.module.conditional_strategy
        else:
            z_dim = gen.z_dim
            num_classes = gen.num_classes
            conditional_strategy = dis.conditional_strategy

        zs, fake_labels = sample_latents(prior, batch_size, z_dim,
                                         truncated_factor, num_classes, None,
                                         self.device)

        if latent_op:
            zs = latent_optimise(zs, fake_labels, gen, dis,
                                 conditional_strategy, latent_op_step, 1.0,
                                 latent_op_alpha, latent_op_beta, False,
                                 self.device)

        with torch.no_grad():
            batch_images = gen(zs, fake_labels, evaluation=True)

        return batch_images
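All of these examples delegate latent sampling to a shared `sample_latents` helper. As a rough, hypothetical illustration of its contract (not the library's actual implementation), a minimal Gaussian-prior version could look like this:

import torch

def sample_latents_sketch(batch_size, z_dim, num_classes, device):
    # Hypothetical stand-in for sample_latents: draw z from a standard
    # Gaussian prior and pick class labels uniformly at random.
    # Truncation, class-wise sampling, and alternative priors are omitted.
    zs = torch.randn(batch_size, z_dim, device=device)
    fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
    return zs, fake_labels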
Example #2
def generate_images_for_KNN(batch_size, real_label, gen_model, dis_model,
                            truncated_factor, prior, latent_op, latent_op_step,
                            latent_op_alpha, latent_op_beta, device):
    if isinstance(gen_model, (DataParallel, DistributedDataParallel)):
        z_dim = gen_model.module.z_dim
        num_classes = gen_model.module.num_classes
        conditional_strategy = dis_model.module.conditional_strategy
    else:
        z_dim = gen_model.z_dim
        num_classes = gen_model.num_classes
        conditional_strategy = dis_model.conditional_strategy

    zs, fake_labels = sample_latents(prior, batch_size, z_dim,
                                     truncated_factor, num_classes, None,
                                     device, real_label)

    if latent_op:
        zs = latent_optimise(zs, fake_labels, gen_model, dis_model,
                             conditional_strategy, latent_op_step, 1.0,
                             latent_op_alpha, latent_op_beta, False, device)

    with torch.no_grad():
        batch_images = gen_model(zs, fake_labels, evaluation=True)

    return batch_images, list(fake_labels.detach().cpu().numpy())
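Several examples thread a `truncated_factor` through to the sampler. Assuming it implements the BigGAN-style truncation trick, a sketch of truncated noise sampling using `scipy.stats.truncnorm` (an assumption about the semantics, not the verified sampler) might be:

import numpy as np
import torch
from scipy.stats import truncnorm

def truncated_noise_sketch(batch_size, z_dim, truncated_factor, device):
    # Illustrative truncation trick: sample z from a standard normal
    # restricted to [-truncated_factor, truncated_factor]. Smaller factors
    # trade sample diversity for fidelity.
    values = truncnorm.rvs(-truncated_factor, truncated_factor,
                           size=(batch_size, z_dim)).astype(np.float32)
    return torch.from_numpy(values).to(device)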
Example #3
def apply_accumulate_stat(generator, acml_step, prior, batch_size, z_dim, num_classes, device):
    generator.train()
    generator.apply(reset_bn_stat)
    for i in range(acml_step):
        new_batch_size = random.randint(1, batch_size)
        z, fake_labels = sample_latents(prior, new_batch_size, z_dim, 1, num_classes, None, device)
        generated_images = generator(z, fake_labels)  # forward pass only to refresh BN running statistics
    generator.eval()
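The `reset_bn_stat` hook applied above is not shown. A plausible minimal version (an assumption based on the accumulation loop's purpose) simply clears the running statistics of every BatchNorm layer so the random-batch forward passes can rebuild them:

import torch

def reset_bn_stat(module):
    # Zero the running mean/var (and batch counter) of BatchNorm layers;
    # generator.apply(reset_bn_stat) then visits every submodule.
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.reset_running_stats()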
Example #4
def generate_images(batch_size, gen, dis, truncated_factor, prior, latent_op, latent_op_step,
                    latent_op_alpha, latent_op_beta, device):
    if isinstance(gen, DataParallel):
        z_dim = gen.module.z_dim
        num_classes = gen.module.num_classes
    else:
        z_dim = gen.z_dim
        num_classes = gen.num_classes
    
    z, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, device)

    if latent_op:
        z = latent_optimise(z, fake_labels, gen, dis, latent_op_step, 1.0, latent_op_alpha, latent_op_beta, False, device)
    
    with torch.no_grad():
        batch_images = gen(z, fake_labels)

    return batch_images
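The repeated `isinstance(gen, DataParallel)` branching in Examples 1, 2, and 4 could be collapsed with a small unwrapping helper; `unwrap` here is a hypothetical name, not part of the original codebase:

def unwrap(model):
    # Return the underlying module of a (Distributed)DataParallel wrapper,
    # so attributes such as z_dim and num_classes can be read uniformly.
    return model.module if hasattr(model, "module") else model

# z_dim = unwrap(gen).z_dim
# num_classes = unwrap(gen).num_classes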
Example #5
def calculate_accuracy(dataloader, generator, discriminator, D_loss, num_evaluate, truncated_factor, prior, latent_op,
                       latent_op_step, latent_op_alpha, latent_op_beta, device, cr, logger, eval_generated_sample=False):
    data_iter = iter(dataloader)
    batch_size = dataloader.batch_size
    disable_tqdm = device != 0

    if isinstance(generator, (DataParallel, DistributedDataParallel)):
        z_dim = generator.module.z_dim
        num_classes = generator.module.num_classes
        conditional_strategy = discriminator.module.conditional_strategy
    else:
        z_dim = generator.z_dim
        num_classes = generator.num_classes
        conditional_strategy = discriminator.conditional_strategy

    total_batch = num_evaluate//batch_size

    if D_loss.__name__ in ["loss_dcgan_dis", "loss_lsgan_dis", "loss_hinge_dis"]:
        cutoff = 0.0
    elif D_loss.__name__ == "loss_wgan_dis":
        raise NotImplementedError

    if device == 0: logger.info("Calculating accuracies...")

    if eval_generated_sample:
        for batch_id in tqdm(range(total_batch), disable=disable_tqdm):
            zs, fake_labels = sample_latents(prior, batch_size, z_dim, truncated_factor, num_classes, None, device)
            if latent_op:
                zs = latent_optimise(zs, fake_labels, generator, discriminator, conditional_strategy, latent_op_step,
                                     1.0, latent_op_alpha, latent_op_beta, False, device)

            real_images, real_labels = next(data_iter)
            real_images, real_labels = real_images.to(device), real_labels.to(device)

            fake_images = generator(zs, fake_labels, evaluation=True)

            with torch.no_grad():
                if conditional_strategy in ["ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"]:
                    _, _, dis_out_fake = discriminator(fake_images, fake_labels)
                    _, _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ACGAN":
                    _, dis_out_fake = discriminator(fake_images, fake_labels)
                    _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ProjGAN" or conditional_strategy == "no":
                    dis_out_fake = discriminator(fake_images, fake_labels)
                    dis_out_real = discriminator(real_images, real_labels)
                else:
                    raise NotImplementedError

                dis_out_fake = dis_out_fake.detach().cpu().numpy()
                dis_out_real = dis_out_real.detach().cpu().numpy()

            if batch_id == 0:
                confid = np.concatenate((dis_out_fake, dis_out_real), axis=0)
                confid_label = np.concatenate(([0.0]*len(dis_out_fake), [1.0]*len(dis_out_real)), axis=0)
            else:
                confid = np.concatenate((confid, dis_out_fake, dis_out_real), axis=0)
                confid_label = np.concatenate((confid_label, [0.0]*len(dis_out_fake), [1.0]*len(dis_out_real)), axis=0)

        real_confid = confid[confid_label==1.0]
        fake_confid = confid[confid_label==0.0]

        true_positive = real_confid[np.where(real_confid>cutoff)]
        true_negative = fake_confid[np.where(fake_confid<cutoff)]

        only_real_acc = len(true_positive)/len(real_confid)
        only_fake_acc = len(true_negative)/len(fake_confid)

        return only_real_acc, only_fake_acc
    else:
        for batch_id in tqdm(range(total_batch), disable=disable_tqdm):
            real_images, real_labels = next(data_iter)
            real_images, real_labels = real_images.to(device), real_labels.to(device)

            with torch.no_grad():
                if conditional_strategy in ["ContraGAN", "Proxy_NCA_GAN", "NT_Xent_GAN"]:
                    _, _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ACGAN":
                    _, dis_out_real = discriminator(real_images, real_labels)
                elif conditional_strategy == "ProjGAN" or conditional_strategy == "no":
                    dis_out_real = discriminator(real_images, real_labels)
                else:
                    raise NotImplementedError

                dis_out_real = dis_out_real.detach().cpu().numpy()

            if batch_id == 0:
                confid = dis_out_real
                confid_label = np.asarray([1.0]*len(dis_out_real), np.float32)
            else:
                confid = np.concatenate((confid, dis_out_real), axis=0)
                confid_label = np.concatenate((confid_label, [1.0]*len(dis_out_real)), axis=0)

        real_confid = confid[confid_label==1.0]
        true_positive = real_confid[np.where(real_confid>cutoff)]
        only_real_acc = len(true_positive)/len(real_confid)

        return only_real_acc
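The batch-wise concatenation and thresholding above can be expressed more compactly. An equivalent vectorized form of the final accuracy computation, for reference:

import numpy as np

def split_accuracies(confid, confid_label, cutoff=0.0):
    # Reals are counted correct when their discriminator score exceeds the
    # cutoff, fakes when they fall below it, exactly as in the loop above.
    real_confid = confid[confid_label == 1.0]
    fake_confid = confid[confid_label == 0.0]
    return float((real_confid > cutoff).mean()), float((fake_confid < cutoff).mean())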
Example #6
    def __init__(
        self,
        run_name,
        best_step,
        dataset_name,
        type4eval_dataset,
        logger,
        writer,
        n_gpus,
        gen_model,
        dis_model,
        inception_model,
        Gen_copy,
        Gen_ema,
        train_dataloader,
        eval_dataloader,
        conditional_strategy,
        z_dim,
        num_classes,
        hypersphere_dim,
        d_spectral_norm,
        g_spectral_norm,
        G_optimizer,
        D_optimizer,
        batch_size,
        g_steps_per_iter,
        d_steps_per_iter,
        accumulation_steps,
        total_step,
        G_loss,
        D_loss,
        contrastive_lambda,
        tempering_type,
        tempering_step,
        start_temperature,
        end_temperature,
        gradient_penalty_for_dis,
        gradient_penelty_lambda,
        weight_clipping_for_dis,
        weight_clipping_bound,
        consistency_reg,
        consistency_lambda,
        diff_aug,
        prior,
        truncated_factor,
        ema,
        latent_op,
        latent_op_rate,
        latent_op_step,
        latent_op_step4eval,
        latent_op_alpha,
        latent_op_beta,
        latent_norm_reg_weight,
        default_device,
        second_device,
        print_every,
        save_every,
        checkpoint_dir,
        evaluate,
        mu,
        sigma,
        best_fid,
        best_fid_checkpoint_path,
        train_config,
        model_config,
    ):

        self.run_name = run_name
        self.best_step = best_step
        self.dataset_name = dataset_name
        self.type4eval_dataset = type4eval_dataset
        self.logger = logger
        self.writer = writer
        self.n_gpus = n_gpus

        self.gen_model = gen_model
        self.dis_model = dis_model
        self.inception_model = inception_model
        self.Gen_copy = Gen_copy
        self.Gen_ema = Gen_ema

        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        self.conditional_strategy = conditional_strategy
        self.z_dim = z_dim
        self.num_classes = num_classes
        self.hypersphere_dim = hypersphere_dim
        self.d_spectral_norm = d_spectral_norm
        self.g_spectral_norm = g_spectral_norm

        self.G_optimizer = G_optimizer
        self.D_optimizer = D_optimizer
        self.batch_size = batch_size
        self.g_steps_per_iter = g_steps_per_iter
        self.d_steps_per_iter = d_steps_per_iter
        self.accumulation_steps = accumulation_steps
        self.total_step = total_step

        self.G_loss = G_loss
        self.D_loss = D_loss
        self.contrastive_lambda = contrastive_lambda
        self.tempering_type = tempering_type
        self.tempering_step = tempering_step
        self.start_temperature = start_temperature
        self.end_temperature = end_temperature
        self.gradient_penalty_for_dis = gradient_penalty_for_dis
        self.gradient_penelty_lambda = gradient_penelty_lambda
        self.weight_clipping_for_dis = weight_clipping_for_dis
        self.weight_clipping_bound = weight_clipping_bound
        self.consistency_reg = consistency_reg
        self.consistency_lambda = consistency_lambda

        self.diff_aug = diff_aug
        self.prior = prior
        self.truncated_factor = truncated_factor
        self.ema = ema
        self.latent_op = latent_op
        self.latent_op_rate = latent_op_rate
        self.latent_op_step = latent_op_step
        self.latent_op_step4eval = latent_op_step4eval
        self.latent_op_alpha = latent_op_alpha
        self.latent_op_beta = latent_op_beta
        self.latent_norm_reg_weight = latent_norm_reg_weight

        self.default_device = default_device
        self.second_device = second_device
        self.print_every = print_every
        self.save_every = save_every
        self.checkpoint_dir = checkpoint_dir
        self.evaluate = evaluate
        self.mu = mu
        self.sigma = sigma
        self.best_fid = best_fid
        self.best_fid_checkpoint_path = best_fid_checkpoint_path
        self.train_config = train_config
        self.model_config = model_config

        self.start_time = datetime.now()
        self.l2_loss = torch.nn.MSELoss()
        self.ce_loss = torch.nn.CrossEntropyLoss()

        if self.conditional_strategy == 'ContraGAN':
            self.contrastive_criterion = Conditional_Embedding_Contrastive_loss(
                self.second_device, self.batch_size)
            self.tempering_range = self.end_temperature - self.start_temperature
            assert tempering_type == "constant" or tempering_type == "continuous" or tempering_type == "discrete", \
                "tempering_type should be one of constant, continuous, or discrete"
            if self.tempering_type == 'discrete':
                self.tempering_interval = self.total_step // (
                    self.tempering_step + 1)

        if self.conditional_strategy != "no":
            if self.dataset_name == "cifar10":
                cls_wise_sampling = "all"
            else:
                cls_wise_sampling = "some"
        else:
            cls_wise_sampling = "no"

        self.policy = "color,translation,cutout"

        self.fixed_noise, self.fixed_fake_labels = sample_latents(
            self.prior,
            self.batch_size,
            self.z_dim,
            1,
            self.num_classes,
            None,
            self.second_device,
            cls_wise_sampling=cls_wise_sampling)
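How `cls_wise_sampling` affects the fixed labels is not shown here. One plausible reading (an assumption, not the library's verified behavior) is that "all" spreads the fixed labels evenly across every class, so the periodically logged sample grid covers each one:

import torch

def class_wise_labels_sketch(batch_size, num_classes, device):
    # Hypothetical "all" mode: cycle deterministically through the classes
    # instead of sampling labels at random.
    return torch.arange(batch_size, device=device) % num_classes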
Example #7
    def run(self, current_step, total_step):
        self.dis_model.train()
        self.gen_model.train()
        if self.Gen_copy is not None:
            self.Gen_copy.train()
        step_count = current_step
        train_iter = iter(self.train_dataloader)
        while step_count <= total_step:
            # ================== TRAIN D ================== #
            for step_index in range(self.d_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_index in range(self.accumulation_steps):
                    try:
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)
                    except StopIteration:
                        train_iter = iter(self.train_dataloader)
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)

                    images, real_labels = images.to(
                        self.second_device), real_labels.to(self.second_device)
                    if self.diff_aug:
                        images = DiffAugment(images, policy=self.policy)
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)

                    if self.latent_op:
                        z = latent_optimise(
                            z, fake_labels, self.gen_model, self.dis_model,
                            self.latent_op_step, self.latent_op_rate,
                            self.latent_op_alpha, self.latent_op_beta, False,
                            self.second_device)

                    fake_images = self.gen_model(z, fake_labels)
                    if self.diff_aug:
                        fake_images = DiffAugment(fake_images,
                                                  policy=self.policy)

                    if self.conditional_strategy == "ACGAN":
                        cls_out_real, dis_out_real = self.dis_model(
                            images, real_labels)
                        cls_out_fake, dis_out_fake = self.dis_model(
                            fake_images, fake_labels)
                    elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                        dis_out_real = self.dis_model(images, real_labels)
                        dis_out_fake = self.dis_model(fake_images, fake_labels)
                    else:
                        raise NotImplementedError

                    dis_acml_loss = self.D_loss(dis_out_real, dis_out_fake)

                    if self.conditional_strategy == "ACGAN":
                        dis_acml_loss += (
                            self.ce_loss(cls_out_real, real_labels) +
                            self.ce_loss(cls_out_fake, fake_labels))

                    if self.gradient_penalty_for_dis:
                        dis_acml_loss += self.gradient_penelty_lambda * calc_derv4gp(
                            self.dis_model, images, fake_images, real_labels,
                            self.second_device)

                    if self.consistency_reg:
                        images_aug = images_aug.to(self.second_device)
                        if self.conditional_strategy == "ACGAN":
                            cls_out_real_aug, dis_out_real_aug = self.dis_model(
                                images_aug, real_labels)
                        elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                            dis_out_real_aug = self.dis_model(
                                images_aug, real_labels)
                        else:
                            raise NotImplementedError
                        consistency_loss = self.l2_loss(
                            dis_out_real, dis_out_real_aug)
                        dis_acml_loss += self.consistency_lambda * consistency_loss

                    dis_acml_loss = dis_acml_loss / self.accumulation_steps

                    dis_acml_loss.backward()

                self.D_optimizer.step()

                if self.weight_clipping_for_dis:
                    for p in self.dis_model.parameters():
                        p.data.clamp_(-self.weight_clipping_bound,
                                      self.weight_clipping_bound)

            if step_count % self.print_every == 0 and step_count != 0 and self.logger:
                if self.d_spectral_norm:
                    dis_sigmas = calculate_all_sn(self.dis_model)
                    self.writer.add_scalars('SN_of_dis', dis_sigmas,
                                            step_count)

            # ================== TRAIN G ================== #
            for step_index in range(self.g_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_step in range(self.accumulation_steps):
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)

                    if self.latent_op:
                        z, transport_cost = latent_optimise(
                            z, fake_labels, self.gen_model, self.dis_model,
                            self.latent_op_step, self.latent_op_rate,
                            self.latent_op_alpha, self.latent_op_beta, True,
                            self.second_device)

                    fake_images = self.gen_model(z, fake_labels)
                    if self.diff_aug:
                        fake_images = DiffAugment(fake_images,
                                                  policy=self.policy)

                    if self.conditional_strategy == "ACGAN":
                        cls_out_fake, dis_out_fake = self.dis_model(
                            fake_images, fake_labels)
                    elif self.conditional_strategy == "cGAN" or self.conditional_strategy == "no":
                        dis_out_fake = self.dis_model(fake_images, fake_labels)
                    else:
                        raise NotImplementedError

                    gen_acml_loss = self.G_loss(dis_out_fake)
                    if self.conditional_strategy == "ACGAN":
                        gen_acml_loss += self.ce_loss(cls_out_fake,
                                                      fake_labels)

                    if self.latent_op:
                        gen_acml_loss += transport_cost * self.latent_norm_reg_weight
                    gen_acml_loss = gen_acml_loss / self.accumulation_steps

                    gen_acml_loss.backward()

                self.G_optimizer.step()

                # If EMA is enabled, update the Gen_copy parameters with an exponential moving average.
                if self.ema:
                    self.Gen_ema.update(step_count)

                step_count += 1

            if step_count % self.print_every == 0 and self.logger:
                log_message = LOG_FORMAT.format(
                    step=step_count,
                    progress=step_count / total_step,
                    elapsed=elapsed_time(self.start_time),
                    temperature='No',
                    dis_loss=dis_acml_loss.item(),
                    gen_loss=gen_acml_loss.item(),
                )
                self.logger.info(log_message)

                if self.g_spectral_norm:
                    gen_sigmas = calculate_all_sn(self.gen_model)
                    self.writer.add_scalars('SN_of_gen', gen_sigmas,
                                            step_count)

                self.writer.add_scalars(
                    'Losses', {
                        'discriminator': dis_acml_loss.item(),
                        'generator': gen_acml_loss.item()
                    }, step_count)
                with torch.no_grad():
                    if self.Gen_copy is not None:
                        self.Gen_copy.train()
                        generator = self.Gen_copy
                    else:
                        self.gen_model.eval()
                        generator = self.gen_model

                    generated_images = generator(self.fixed_noise,
                                                 self.fixed_fake_labels)
                    self.writer.add_images('Generated samples',
                                           (generated_images + 1) / 2,
                                           step_count)
                    self.gen_model.train()

            if step_count % self.save_every == 0 or step_count == total_step:
                if self.evaluate:
                    is_best = self.evaluation(step_count)
                    self.save(step_count, is_best)
                else:
                    self.save(step_count, False)
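Both the D and G loops above follow the same gradient-accumulation pattern: each micro-batch loss is scaled by `1/accumulation_steps`, `backward()` is called per micro-batch, and a single optimizer step follows. A self-contained, minimal illustration with a stand-in model:

import torch

model = torch.nn.Linear(8, 1)                # stand-in network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

optimizer.zero_grad()
for _ in range(accumulation_steps):
    x = torch.randn(16, 8)                   # stand-in micro-batch
    loss = model(x).pow(2).mean() / accumulation_steps
    loss.backward()                          # gradients sum across micro-batches
optimizer.step()                             # one update for the full effective batch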
Example #8
    def run_ours(self, current_step, total_step):
        self.dis_model.train()
        self.gen_model.train()
        if self.Gen_copy is not None:
            self.Gen_copy.train()
        step_count = current_step
        train_iter = iter(self.train_dataloader)
        while step_count <= total_step:
            # ================== TRAIN D ================== #
            if self.tempering_type == 'continuous':
                t = self.start_temperature + step_count * (
                    self.end_temperature - self.start_temperature) / total_step
            elif self.tempering_type == 'discrete':
                t = self.start_temperature + \
                    (step_count//self.tempering_interval)*(self.end_temperature-self.start_temperature)/self.tempering_step
            else:
                t = self.start_temperature
            for step_index in range(self.d_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_step in range(self.accumulation_steps):
                    try:
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)
                    except StopIteration:
                        train_iter = iter(self.train_dataloader)
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)

                    images, real_labels = images.to(
                        self.second_device), real_labels.to(self.second_device)
                    if self.diff_aug:
                        images = DiffAugment(images, policy=self.policy)
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)
                    real_cls_mask = make_mask(real_labels, self.num_classes,
                                              self.second_device)

                    cls_real_proxies, cls_real_embed, dis_real_authen_out = self.dis_model(
                        images, real_labels)

                    fake_images = self.gen_model(z, fake_labels)
                    if self.diff_aug:
                        fake_images = DiffAugment(fake_images,
                                                  policy=self.policy)

                    cls_fake_proxies, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                        fake_images, fake_labels)

                    dis_acml_loss = self.D_loss(dis_real_authen_out,
                                                dis_fake_authen_out)
                    dis_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                        cls_real_embed, cls_real_proxies, real_cls_mask,
                        real_labels, t)
                    if self.consistency_reg:
                        images_aug = images_aug.to(self.second_device)
                        _, cls_real_aug_embed, dis_real_aug_authen_out = self.dis_model(
                            images_aug, real_labels)
                        consistency_loss = self.l2_loss(
                            dis_real_authen_out, dis_real_aug_authen_out)
                        dis_acml_loss += self.consistency_lambda * consistency_loss

                    dis_acml_loss = dis_acml_loss / self.accumulation_steps
                    dis_acml_loss.backward()

                self.D_optimizer.step()

                if self.weight_clipping_for_dis:
                    for p in self.dis_model.parameters():
                        p.data.clamp_(-self.weight_clipping_bound,
                                      self.weight_clipping_bound)

            if step_count % self.print_every == 0 and step_count != 0 and self.logger:
                if self.d_spectral_norm:
                    dis_sigmas = calculate_all_sn(self.dis_model)
                    self.writer.add_scalars('SN_of_dis', dis_sigmas,
                                            step_count)

            # ================== TRAIN G ================== #
            for step_index in range(self.g_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_step in range(self.accumulation_steps):
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)
                    fake_cls_mask = make_mask(fake_labels, self.num_classes,
                                              self.second_device)

                    fake_images = self.gen_model(z, fake_labels)
                    if self.diff_aug:
                        fake_images = DiffAugment(fake_images,
                                                  policy=self.policy)

                    cls_fake_proxies, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                        fake_images, fake_labels)

                    gen_acml_loss = self.G_loss(dis_fake_authen_out)
                    gen_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                        cls_fake_embed, cls_fake_proxies, fake_cls_mask,
                        fake_labels, t)
                    gen_acml_loss = gen_acml_loss / self.accumulation_steps
                    gen_acml_loss.backward()

                self.G_optimizer.step()

                # If EMA is enabled, update the Gen_copy parameters with an exponential moving average.
                if self.ema:
                    self.Gen_ema.update(step_count)

                step_count += 1
            if step_count % self.print_every == 0 and self.logger:
                log_message = LOG_FORMAT.format(
                    step=step_count,
                    progress=step_count / total_step,
                    elapsed=elapsed_time(self.start_time),
                    temperature=t,
                    dis_loss=dis_acml_loss.item(),
                    gen_loss=gen_acml_loss.item(),
                )
                self.logger.info(log_message)

                if self.g_spectral_norm:
                    gen_sigmas = calculate_all_sn(self.gen_model)
                    self.writer.add_scalars('SN_of_gen', gen_sigmas,
                                            step_count)

                self.writer.add_scalars(
                    'Losses', {
                        'discriminator': dis_acml_loss.item(),
                        'generator': gen_acml_loss.item()
                    }, step_count)

                with torch.no_grad():
                    if self.Gen_copy is not None:
                        self.Gen_copy.train()
                        generator = self.Gen_copy
                    else:
                        self.gen_model.eval()
                        generator = self.gen_model

                    generated_images = generator(self.fixed_noise,
                                                 self.fixed_fake_labels)
                    self.writer.add_images('Generated samples',
                                           (generated_images + 1) / 2,
                                           step_count)
                    self.gen_model.train()

            if step_count % self.save_every == 0 or step_count == total_step:
                if self.evaluate:
                    is_best = self.evaluation(step_count)
                    self.save(step_count, is_best)
                else:
                    self.save(step_count, False)
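The three temperature schedules at the top of `run_ours` can be read as one function of the step count; a standalone sketch using the same names as the code above:

def temperature_at(step_count, total_step, start_temperature, end_temperature,
                   tempering_type, tempering_step=1, tempering_interval=1):
    # "continuous": linear interpolation from start to end over training;
    # "discrete": the same ramp quantized into tempering_step jumps, one
    # every tempering_interval steps; anything else: constant.
    if tempering_type == 'continuous':
        return start_temperature + step_count * (
            end_temperature - start_temperature) / total_step
    if tempering_type == 'discrete':
        return start_temperature + (step_count // tempering_interval) * (
            end_temperature - start_temperature) / tempering_step
    return start_temperature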
Example #9
    def __init__(self, run_name, logger, writer, n_gpus, gen_model, dis_model,
                 inception_model, Gen_copy, Gen_ema, train_dataloader,
                 evaluation_dataloader, G_loss, D_loss, auxiliary_classifier,
                 contrastive_training, softmax_posterior, contrastive_softmax,
                 hyper_dim, contrastive_lambda, tempering, discrete_tempering,
                 tempering_times, start_temperature, end_temperature,
                 gradient_penalty_for_dis, lambda4lp, lambda4gp,
                 weight_clipping_for_dis, weight_clipping_bound, latent_op,
                 latent_op_rate, latent_op_step, latent_op_step4eval,
                 latent_op_alpha, latent_op_beta, latent_norm_reg_weight,
                 consistency_reg, consistency_lambda, make_positive_aug,
                 G_optimizer, D_optimizer, default_device, second_device,
                 batch_size, z_dim, num_classes, truncated_factor, prior,
                 g_steps_per_iter, d_steps_per_iter, accumulation_steps,
                 lambda4ortho, print_every, save_every, checkpoint_dir,
                 evaluate, mu, sigma, best_val_fid, best_checkpoint_fid_path,
                 best_val_is, best_checkpoint_is_path, config):

        self.run_name = run_name
        self.logger = logger
        self.writer = writer
        self.n_gpus = n_gpus
        self.gen_model = gen_model
        self.dis_model = dis_model
        self.inception_model = inception_model
        self.Gen_copy = Gen_copy
        self.Gen_ema = Gen_ema
        self.train_dataloader = train_dataloader
        self.evaluation_dataloader = evaluation_dataloader

        self.G_loss = G_loss
        self.D_loss = D_loss
        self.auxiliary_classifier = auxiliary_classifier
        self.contrastive_training = contrastive_training
        self.softmax_posterior = softmax_posterior
        self.contrastive_softmax = contrastive_softmax
        self.hyper_dim = hyper_dim
        self.contrastive_lambda = contrastive_lambda
        self.tempering = tempering
        self.discrete_tempering = discrete_tempering
        self.tempering_times = tempering_times
        self.start_temperature = start_temperature
        self.end_temperature = end_temperature
        self.gradient_penalty_for_dis = gradient_penalty_for_dis
        self.lambda4lp = lambda4lp
        self.lambda4gp = lambda4gp
        self.weight_clipping_for_dis = weight_clipping_for_dis
        self.weight_clipping_bound = weight_clipping_bound
        self.latent_op = latent_op
        self.latent_op_rate = latent_op_rate
        self.latent_op_step = latent_op_step
        self.latent_op_step4eval = latent_op_step4eval
        self.latent_op_alpha = latent_op_alpha
        self.latent_op_beta = latent_op_beta
        self.latent_norm_reg_weight = latent_norm_reg_weight
        self.consistency_reg = consistency_reg
        self.consistency_lambda = consistency_lambda
        self.make_positive_aug = make_positive_aug
        self.G_optimizer = G_optimizer
        self.D_optimizer = D_optimizer
        self.default_device = default_device
        self.second_device = second_device

        self.batch_size = batch_size
        self.z_dim = z_dim
        self.num_classes = num_classes
        self.truncated_factor = truncated_factor
        self.prior = prior
        self.g_steps_per_iter = g_steps_per_iter
        self.d_steps_per_iter = d_steps_per_iter
        self.accumulation_steps = accumulation_steps
        self.lambda4ortho = lambda4ortho

        self.print_every = print_every
        self.save_every = save_every
        self.checkpoint_dir = checkpoint_dir
        self.config = config

        self.start_time = datetime.now()
        self.l2_loss = torch.nn.MSELoss()
        self.ce_loss = torch.nn.CrossEntropyLoss()
        if self.softmax_posterior:
            self.ce_criterion = Cross_Entropy_loss(
                self.hyper_dim, self.num_classes,
                self.config['d_spectral_norm']).to(self.second_device)
        if self.contrastive_softmax:
            self.contrastive_criterion = Conditional_Embedding_Contrastive_loss(
                self.second_device, self.batch_size)

        fixed_feed = next(iter(self.train_dataloader))
        self.fixed_images, self.fixed_real_labels = fixed_feed[0].to(
            self.second_device), fixed_feed[1].to(self.second_device)
        self.fixed_noise, self.fixed_fake_labels = sample_latents(
            self.prior, self.batch_size, self.z_dim, 1, self.num_classes, None,
            self.second_device)

        self.evaluate = evaluate
        self.mu = mu
        self.sigma = sigma
        self.best_val_fid = best_val_fid
        self.best_val_is = best_val_is
        self.best_checkpoint_fid_path = best_checkpoint_fid_path
        self.best_checkpoint_is_path = best_checkpoint_is_path
Example #10
    def run(self, current_step, total_step):
        self.dis_model.train()
        self.gen_model.train()
        if self.Gen_copy is not None:
            self.Gen_copy.train()
        step_count = current_step
        train_iter = iter(self.train_dataloader)
        while step_count <= total_step:
            # ================== TRAIN D ================== #
            for step_index in range(self.d_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_index in range(self.accumulation_steps):
                    try:
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)
                    except StopIteration:
                        train_iter = iter(self.train_dataloader)
                        if self.consistency_reg:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)

                    images, real_labels = images.to(
                        self.second_device), real_labels.to(self.second_device)
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)

                    if self.latent_op:
                        z = latent_optimise(
                            z, fake_labels, self.gen_model, self.dis_model,
                            self.latent_op_step, self.latent_op_rate,
                            self.latent_op_alpha, self.latent_op_beta, False,
                            self.second_device)

                    _, cls_out_real, dis_out_real = self.dis_model(
                        images, real_labels)
                    fake_images = self.gen_model(z, fake_labels)
                    _, cls_out_fake, dis_out_fake = self.dis_model(
                        fake_images, fake_labels)

                    dis_acml_loss = self.D_loss(dis_out_real, dis_out_fake)
                    if self.auxiliary_classifier:
                        dis_acml_loss += (
                            self.ce_loss(cls_out_real, real_labels) +
                            self.ce_loss(cls_out_fake, fake_labels))

                    if self.gradient_penalty_for_dis:
                        dis_acml_loss += self.lambda4gp * calc_derv4gp(
                            self.dis_model, images, fake_images, real_labels,
                            self.second_device)

                    if self.consistency_reg:
                        images_aug = images_aug.to(self.second_device)
                        _, _, dis_out_real_aug = self.dis_model(
                            images_aug, real_labels)
                        consistency_loss = self.l2_loss(
                            dis_out_real, dis_out_real_aug)
                        dis_acml_loss += self.consistency_lambda * consistency_loss

                    dis_acml_loss = dis_acml_loss / self.accumulation_steps

                    dis_acml_loss.backward()

                self.D_optimizer.step()

                if self.weight_clipping_for_dis:
                    for p in self.dis_model.parameters():
                        p.data.clamp_(-self.weight_clipping_bound,
                                      self.weight_clipping_bound)

            if step_count % self.print_every == 0 and step_count != 0 and self.logger:
                if self.config['d_spectral_norm']:
                    dis_sigmas = calculate_all_sn(self.dis_model)
                    self.writer.add_scalars('SN_of_dis', dis_sigmas,
                                            step_count)

                if self.config['calculate_z_grad']:
                    _, l2_norm_grad_z_aft_D_update = calc_derv(
                        self.fixed_noise, self.fixed_fake_labels,
                        self.dis_model, self.second_device, self.gen_model)
                    self.writer.add_scalars(
                        'L2_norm_grad', {
                            'z_aft_D_update':
                            l2_norm_grad_z_aft_D_update.mean().item()
                        }, step_count)

            # ================== TRAIN G ================== #
            for step_index in range(self.g_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                for acml_step in range(self.accumulation_steps):
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)

                    if self.latent_op:
                        z, transport_cost = latent_optimise(
                            z, fake_labels, self.gen_model, self.dis_model,
                            self.latent_op_step, self.latent_op_rate,
                            self.latent_op_alpha, self.latent_op_beta, True,
                            self.second_device)

                    fake_images = self.gen_model(z, fake_labels)
                    _, cls_out_fake, dis_out_fake = self.dis_model(
                        fake_images, fake_labels)

                    gen_acml_loss = self.G_loss(dis_out_fake)
                    if self.auxiliary_classifier:
                        gen_acml_loss += self.ce_loss(cls_out_fake,
                                                      fake_labels)

                    if self.latent_op:
                        gen_acml_loss += transport_cost * self.latent_norm_reg_weight
                    gen_acml_loss = gen_acml_loss / self.accumulation_steps

                    gen_acml_loss.backward()

                if isinstance(self.lambda4ortho, float) and \
                        self.lambda4ortho > 0 and self.config['ortho_reg']:
                    if isinstance(self.gen_model, DataParallel):
                        ortho(self.gen_model,
                              self.lambda4ortho,
                              blacklist=[
                                  param for param in
                                  self.gen_model.module.shared.parameters()
                              ])
                    else:
                        ortho(self.gen_model,
                              self.lambda4ortho,
                              blacklist=[
                                  param for param in
                                  self.gen_model.shared.parameters()
                              ])

                self.G_optimizer.step()

                # If EMA is enabled, update the Gen_copy parameters with an exponential moving average.
                if self.config['ema']:
                    self.Gen_ema.update(step_count)

                step_count += 1

            if step_count % self.print_every == 0 and self.logger:
                log_message = LOG_FORMAT.format(
                    step=step_count,
                    progress=step_count / total_step,
                    elapsed=elapsed_time(self.start_time),
                    temperature='No',
                    dis_loss=dis_acml_loss.item(),
                    gen_loss=gen_acml_loss.item(),
                )
                self.logger.info(log_message)

                if self.config['g_spectral_norm']:
                    gen_sigmas = calculate_all_sn(self.gen_model)
                    self.writer.add_scalars('SN_of_gen', gen_sigmas,
                                            step_count)

                self.writer.add_scalars(
                    'Losses', {
                        'discriminator': dis_acml_loss.item(),
                        'generator': gen_acml_loss.item()
                    }, step_count)
                self.writer.add_images('Generated samples',
                                       (fake_images + 1) / 2, step_count)

                if self.config['calculate_z_grad']:
                    _, l2_norm_grad_z_aft_G_update = calc_derv(
                        self.fixed_noise, self.fixed_fake_labels,
                        self.dis_model, self.second_device, self.gen_model)
                    self.writer.add_scalars(
                        'L2_norm_grad', {
                            'z_aft_G_update':
                            l2_norm_grad_z_aft_G_update.mean().item()
                        }, step_count)

            if step_count % self.save_every == 0 or step_count == total_step:
                self.valid_and_save(step_count)
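The `ortho(...)` call above applies BigGAN-style orthogonal regularization directly to the gradients before the optimizer step. A sketch of that technique (an illustration of the idea, not necessarily this repo's exact implementation):

import torch

def ortho_sketch(model, strength, blacklist=()):
    # Add the gradient of strength * ||W W^T * (1 - I)||_F^2 to each weight
    # matrix's existing gradient, pushing rows toward orthogonality while
    # leaving blacklisted parameters (e.g. shared embeddings) untouched.
    blacklist_ids = {id(p) for p in blacklist}
    with torch.no_grad():
        for p in model.parameters():
            if p.ndim < 2 or p.grad is None or id(p) in blacklist_ids:
                continue
            w = p.view(p.shape[0], -1)
            eye = torch.eye(w.shape[0], device=w.device)
            grad = 2 * torch.mm(torch.mm(w, w.t()) * (1. - eye), w)
            p.grad += strength * grad.view(p.shape)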
Example #11
    def run_ours(self, current_step, total_step):
        if self.tempering and self.discrete_tempering:
            temperature_range = self.end_temperature - self.start_temperature
            temperature_change = temperature_range / self.tempering_times
            temperatures = [
                self.start_temperature + time * temperature_change
                for time in range(self.tempering_times + 1)
            ]
            temperatures += [
                self.start_temperature +
                self.tempering_times * temperature_change
            ]  # repeat the final temperature so step_count == total_step stays in range
            interval = total_step // (self.tempering_times + 1)

        self.dis_model.train()
        self.gen_model.train()
        if self.Gen_copy is not None:
            self.Gen_copy.train()
        step_count = current_step
        train_iter = iter(self.train_dataloader)
        cls_real_aug_embed = None
        while step_count <= total_step:
            # ================== TRAIN D ================== #
            if self.tempering and not self.discrete_tempering:
                t = self.start_temperature + step_count * (
                    self.end_temperature - self.start_temperature) / total_step
            elif self.tempering and self.discrete_tempering:
                t = temperatures[step_count // interval]
            else:
                t = self.start_temperature
            for step_index in range(self.d_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                dis_loss = 0
                for acml_step in range(self.accumulation_steps):
                    try:
                        if self.consistency_reg or self.make_positive_aug:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)
                    except StopIteration:
                        train_iter = iter(self.train_dataloader)
                        if self.consistency_reg or self.make_positive_aug:
                            images, real_labels, images_aug = next(train_iter)
                        else:
                            images, real_labels = next(train_iter)

                    images, real_labels = images.to(
                        self.second_device), real_labels.to(self.second_device)
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)
                    real_cls_mask = make_mask(real_labels, self.num_classes,
                                              self.second_device)

                    cls_real_anchor, cls_real_embed, dis_real_authen_out = self.dis_model(
                        images, real_labels)

                    fake_images = self.gen_model(z, fake_labels)
                    cls_fake_anchor, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                        fake_images, fake_labels)

                    dis_acml_loss = self.D_loss(dis_real_authen_out,
                                                dis_fake_authen_out)

                    if self.consistency_reg or self.make_positive_aug:
                        images_aug = images_aug.to(self.second_device)
                        _, cls_real_aug_embed, dis_real_aug_authen_out = self.dis_model(
                            images_aug, real_labels)
                        if self.consistency_reg:
                            consistency_loss = self.l2_loss(
                                dis_real_authen_out, dis_real_aug_authen_out)
                            dis_acml_loss += self.consistency_lambda * consistency_loss

                    if self.softmax_posterior:
                        dis_acml_loss += self.ce_criterion(
                            cls_real_embed, real_labels)
                    elif self.contrastive_softmax:
                        dis_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                            cls_real_embed, cls_real_anchor, real_cls_mask,
                            real_labels, t, cls_real_aug_embed)
                    dis_acml_loss = dis_acml_loss / self.accumulation_steps
                    dis_acml_loss.backward()
                    dis_loss += dis_acml_loss.item()

                self.D_optimizer.step()

                if self.weight_clipping_for_dis:
                    for p in self.dis_model.parameters():
                        p.data.clamp_(-self.weight_clipping_bound,
                                      self.weight_clipping_bound)

            if step_count % self.print_every == 0 and step_count != 0 and self.logger:
                if self.config['d_spectral_norm']:
                    dis_sigmas = calculate_all_sn(self.dis_model)
                    self.writer.add_scalars('SN_of_dis', dis_sigmas,
                                            step_count)

                if self.config['calculate_z_grad']:
                    _, l2_norm_grad_z_aft_D_update = calc_derv(
                        self.fixed_noise, self.fixed_fake_labels,
                        self.dis_model, self.second_device, self.gen_model)
                    self.writer.add_scalars(
                        'L2_norm_grad', {
                            'z_aft_D_update':
                            l2_norm_grad_z_aft_D_update.mean().item()
                        }, step_count)

            # ================== TRAIN G ================== #
            for step_index in range(self.g_steps_per_iter):
                self.D_optimizer.zero_grad()
                self.G_optimizer.zero_grad()
                gen_loss = 0
                for acml_step in range(self.accumulation_steps):
                    z, fake_labels = sample_latents(self.prior,
                                                    self.batch_size,
                                                    self.z_dim, 1,
                                                    self.num_classes, None,
                                                    self.second_device)
                    fake_cls_mask = make_mask(fake_labels, self.num_classes,
                                              self.second_device)

                    fake_images = self.gen_model(z, fake_labels)

                    cls_fake_anchor, cls_fake_embed, dis_fake_authen_out = self.dis_model(
                        fake_images, fake_labels)

                    gen_acml_loss = self.G_loss(dis_fake_authen_out)
                    if self.softmax_posterior:
                        gen_acml_loss += self.ce_criterion(
                            cls_fake_embed, fake_labels)
                    elif self.contrastive_softmax:
                        gen_acml_loss += self.contrastive_lambda * self.contrastive_criterion(
                            cls_fake_embed, cls_fake_anchor, fake_cls_mask,
                            fake_labels, t)
                    gen_acml_loss = gen_acml_loss / self.accumulation_steps
                    gen_acml_loss.backward()
                    gen_loss += gen_acml_loss.item()

                if isinstance(
                        self.lambda4ortho, float
                ) and self.lambda4ortho > 0 and self.config['ortho_reg']:
                    if isinstance(self.gen_model, DataParallel):
                        ortho(self.gen_model,
                              self.lambda4ortho,
                              blacklist=[
                                  param for param in
                                  self.gen_model.module.shared.parameters()
                              ])
                    else:
                        ortho(self.gen_model,
                              self.lambda4ortho,
                              blacklist=[
                                  param for param in
                                  self.gen_model.shared.parameters()
                              ])

                self.G_optimizer.step()

                # If EMA is enabled, update the Gen_copy parameters with an exponential moving average.
                if self.config['ema']:
                    self.Gen_ema.update(step_count)

                step_count += 1
            if step_count % self.print_every == 0 and self.logger:
                log_message = LOG_FORMAT.format(
                    step=step_count,
                    progress=step_count / total_step,
                    elapsed=elapsed_time(self.start_time),
                    temperature=t,
                    dis_loss=dis_loss,
                    gen_loss=gen_loss,
                )
                self.logger.info(log_message)

                if self.config['g_spectral_norm']:
                    gen_sigmas = calculate_all_sn(self.gen_model)
                    self.writer.add_scalars('SN_of_gen', gen_sigmas,
                                            step_count)

                self.writer.add_scalars(
                    'Losses', {
                        'discriminator': dis_acml_loss.item(),
                        'generator': gen_acml_loss.item()
                    }, step_count)
                self.writer.add_images('Generated samples',
                                       (fake_images + 1) / 2, step_count)

                if self.config['calculate_z_grad']:
                    _, l2_norm_grad_z_aft_G_update = calc_derv(
                        self.fixed_noise, self.fixed_fake_labels,
                        self.dis_model, self.second_device, self.gen_model)
                    self.writer.add_scalars(
                        'L2_norm_grad', {
                            'z_aft_G_update':
                            l2_norm_grad_z_aft_G_update.mean().item()
                        }, step_count)

            if step_count % self.save_every == 0 or step_count == total_step:
                self.valid_and_save(step_count)
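The weight-clipping step that follows each `D_optimizer.step()` above enforces a crude Lipschitz constraint, as in the original WGAN. In minimal, self-contained form:

import torch

dis_model = torch.nn.Linear(8, 1)   # stand-in discriminator
weight_clipping_bound = 0.01

for p in dis_model.parameters():
    # Clamp every parameter into [-bound, bound] in place after each
    # discriminator update.
    p.data.clamp_(-weight_clipping_bound, weight_clipping_bound)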
Example #12
def calculate_acc_confidence(dataloader, generator, discriminator, G_loss,
                             num_evaluate, truncated_factor, prior, latent_op,
                             latent_op_step, latent_op_alpha, latent_op_beta,
                             device):
    generator.eval()
    discriminator.eval()
    data_iter = iter(dataloader)
    batch_size = dataloader.batch_size

    if isinstance(generator, DataParallel):
        z_dim = generator.module.z_dim
        num_classes = generator.module.num_classes
    else:
        z_dim = generator.z_dim
        num_classes = generator.num_classes

    if num_evaluate % batch_size == 0:
        total_batch = num_evaluate // batch_size
    else:
        raise Exception("num_evaluate '%' batch4metrics should be 0!")

    if G_loss.__name__ == "loss_dcgan_gen":
        cutoff = 0.5
        fake_target = 0.0
    elif G_loss.__name__ == "loss_hinge_gen":
        cutoff = 0.0
        fake_target = -1.0
    elif G_loss.__name__ == "loss_wgan_gen":
        raise NotImplementedError

    for batch_id in tqdm(range(total_batch)):
        z, fake_labels = sample_latents(prior, batch_size, z_dim,
                                        truncated_factor, num_classes, None,
                                        device)
        if latent_op:
            z = latent_optimise(z, fake_labels, generator, discriminator,
                                latent_op_step, 1.0, latent_op_alpha,
                                latent_op_beta, False, device)

        images_real, real_labels = next(data_iter)
        images_real, real_labels = images_real.to(device), real_labels.to(
            device)

        with torch.no_grad():
            images_gen = generator(z, fake_labels)
            _, _, dis_out_fake = discriminator(images_gen, fake_labels)
            _, _, dis_out_real = discriminator(images_real, real_labels)
            dis_out_fake = dis_out_fake.detach().cpu().numpy()
            dis_out_real = dis_out_real.detach().cpu().numpy()

        if batch_id == 0:
            confid = np.concatenate((dis_out_fake, dis_out_real), axis=0)
            confid_label = np.concatenate(
                ([fake_target] * batch_size, [1.0] * batch_size), axis=0)
        else:
            confid = np.concatenate((confid, dis_out_fake, dis_out_real),
                                    axis=0)
            confid_label = np.concatenate(
                (confid_label, [fake_target] * batch_size, [1.0] * batch_size),
                axis=0)

    real_confid = confid[confid_label == 1.0]
    fake_confid = confid[confid_label == fake_target]

    true_positive = real_confid[np.where(real_confid > cutoff)]
    true_negative = fake_confid[np.where(fake_confid < cutoff)]

    only_real_acc = len(true_positive) / len(real_confid)
    only_fake_acc = len(true_negative) / len(fake_confid)
    mixed_acc = (len(true_positive) + len(true_negative)) / len(confid)

    generator.train()
    discriminator.train()

    return only_real_acc, only_fake_acc, mixed_acc, confid, confid_label
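A tiny synthetic check of the cutoff logic used in Examples 5 and 12, with made-up discriminator scores rather than real model outputs:

import numpy as np

rng = np.random.default_rng(0)
real_scores = rng.normal(loc=1.0, size=1000)    # stand-in D outputs on reals
fake_scores = rng.normal(loc=-1.0, size=1000)   # stand-in D outputs on fakes

cutoff = 0.0                                    # hinge-style threshold
only_real_acc = (real_scores > cutoff).mean()
only_fake_acc = (fake_scores < cutoff).mean()
mixed_acc = (only_real_acc + only_fake_acc) / 2 # equal-sized splits
print(only_real_acc, only_fake_acc, mixed_acc)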