Пример #1
0
    def prepare_DIPs(self):
        """Create the image and kernel DIP networks plus their fixed inputs.

        ``x`` stands for the sharp image and ``k`` for the kernel. The two
        networks and the two input tensors are all moved to the GPU and
        stored on ``self``.
        """
        options = self.opt

        # Networks are built image-first, then kernel, so any random
        # initialisation happens in the same order as before.
        image_net = ImageDIP(options["ImageDIP"])
        self.x_dip = image_net.cuda()
        kernel_net = KernelDIP(options["KernelDIP"])
        self.k_dip = kernel_net.cuda()

        # Fixed inputs fed to the DIPs. The leading integer (64 for the
        # kernel input, 8 for the image input) is presumably the depth of
        # the noise tensor -- TODO confirm against util.get_noise.
        kernel_input = util.get_noise(64, "noise", (64, 64))
        self.dip_zk = kernel_input.cuda()
        image_input = util.get_noise(8, "noise", options["img_size"])
        self.dip_zx = image_input.cuda()
Пример #2
0
def extract_feature_real_fake():
    """Collect Inception features for real and generated image batches.

    The model's ``fc`` layer is replaced by an identity so that a forward
    pass returns the features feeding that layer. Batches are processed
    until at least ``n_samples`` (512) real images have been seen.

    Returns:
        A pair ``(fake_features_list, real_features_list)``; each list holds
        one CPU tensor of features per processed batch.
    """
    inception_model, gen = get_model()
    inception_model.fc = torch.nn.Identity()

    gen.eval()
    n_samples = 512  # stop once this many real samples were processed
    dataloader = get_dataloader()

    fake_features_list, real_features_list = [], []
    seen = 0

    # Inference only: disabling gradient tracking saves memory.
    with torch.no_grad():
        progress = tqdm(dataloader, total=n_samples // batch_size)
        for real_batch, _ in progress:
            this_batch = len(real_batch)

            # Features of the real images, moved back to the CPU.
            real_out = inception_model(real_batch.to(device))
            real_features_list.append(real_out.detach().to('cpu'))

            # Same number of generated images, preprocessed before
            # being passed to the Inception model.
            latent = get_noise(this_batch, z_dim).to(device)
            generated = preprocess(gen(latent))
            fake_out = inception_model(generated.to(device))
            fake_features_list.append(fake_out.detach().to('cpu'))

            seen += this_batch
            if seen >= n_samples:
                break

    return fake_features_list, real_features_list
Пример #3
0
def run_conditional_gen_with_regularization():
    """Steer generated images towards one feature while penalising the rest.

    For ``grad_steps`` iterations the noise is run through the generator and
    classifier, scored with ``_get_score`` (penalty weight 0.1) against the
    classifications of the initial noise, back-propagated, and then replaced
    by the result of ``_calculate_updated_noise``. Every ``skip``-th
    intermediate batch is displayed at the end.
    """
    gen, classifier, opt = load_pretrained_models()

    ### Change me! ###
    # Any string from feature_names can be used as the target feature.
    target_indices = feature_names.index("Eyeglasses")
    # Boolean mask: True for every feature position except the target one.
    other_indices = [
        position != target_indices
        for position, _name in enumerate(feature_names)
    ]

    noise = get_noise(n_images, z_dim).to(device).requires_grad_()
    # Reference classifications of the unmodified noise, cut from the graph.
    original_classifications = classifier(gen(noise)).detach()

    fake_image_history = []
    for _ in range(grad_steps):
        opt.zero_grad()
        fake = gen(noise)
        fake_image_history.append(fake)
        predictions = classifier(fake)
        fake_score = _get_score(
            predictions,
            original_classifications,
            target_indices,
            other_indices,
            penalty_weight=0.1,
        )
        fake_score.backward()
        # Presumably a gradient step on the noise -- see _calculate_updated_noise.
        noise.data = _calculate_updated_noise(noise, 1 / grad_steps)

    plt.rcParams['figure.figsize'] = [n_images * 2, grad_steps * 2]
    stacked = torch.cat(fake_image_history[::skip], dim=2)
    show_tensor_images(stacked, num_images=n_images, nrow=n_images)
Пример #4
0
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    """
    Compute the discriminator loss for one batch.
    Parameters:
        gen: the generator model; maps z-dimensional noise to images
        disc: the discriminator model; returns a real/fake prediction
        criterion: the loss function comparing the discriminator's
               predictions with the targets (fake = 0, real = 1)
        real: a batch of real images
        num_images: how many fake images to generate,
                which is also the length of the real batch
        z_dim: the dimension of the noise vector, a scalar
        device: the device type
    Returns:
        disc_loss: a torch scalar loss value for the current batch
    """
    # Fake branch: sample noise, generate images and cut them from the
    # generator's graph so only the discriminator receives gradients.
    latent = get_noise(num_images, z_dim, device)
    fake_batch = gen(latent).detach()
    fake_pred = disc(fake_batch)
    fake_term = criterion(fake_pred, torch.zeros_like(fake_pred))

    # Real branch: real images are scored against an all-ones target.
    real_pred = disc(real)
    real_term = criterion(real_pred, torch.ones_like(real_pred))

    # The final loss is the average of the two terms.
    return (fake_term + real_term) / 2
Пример #5
0
def train_dcgan(gen, disc, dataloader, epochs, gen_opt, disc_opt, criterion,
                z_dim):
    """Train a generator/discriminator pair, alternating one step each.

    Parameters:
        gen: the generator model
        disc: the discriminator model
        dataloader: yields ``(real_images, _)`` batches; its ``dataset``
            length is used to average the epoch losses
        epochs: number of passes over the dataloader
        gen_opt: optimizer stepping the generator
        disc_opt: optimizer stepping the discriminator
        criterion: loss function forwarded to get_disc_loss / get_gen_loss
        z_dim: dimension of the noise vector

    After every epoch the sample-weighted mean losses are written to
    'discriminator_loss.txt' and 'generator_loss.txt', printed, and a batch
    of 64 generated images is saved under the name 'dcgan-<epoch>'.
    """
    gen, disc = gen.to(device), disc.to(device)
    n_total = len(dataloader.dataset)

    print()
    print(f'Start training on {device_name}')
    print(64 * '-')

    for epoch in range(epochs):
        # Running sums of (batch loss * batch size) for this epoch.
        gen_running, disc_running = 0., 0.

        batches = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs - 1}")
        for real, _ in batches:
            batch_len = len(real)
            real = real.to(device)

            # Discriminator step.
            disc_opt.zero_grad()
            disc_loss = get_disc_loss(
                gen, disc, criterion, real, batch_len, z_dim, device)
            disc_loss.backward()
            disc_opt.step()

            # Generator step.
            gen_opt.zero_grad()
            gen_loss = get_gen_loss(
                gen, disc, criterion, batch_len, z_dim, device)
            gen_loss.backward()
            gen_opt.step()

            # Weight by batch size so a smaller batch counts proportionally.
            disc_running += batch_len * disc_loss.item()
            gen_running += batch_len * gen_loss.item()

        epoch_disc = disc_running / n_total
        epoch_gen = gen_running / n_total

        write_loss_to_file(epoch_disc, 'discriminator_loss.txt')
        write_loss_to_file(epoch_gen, 'generator_loss.txt')

        print(
            f"Generator loss: {epoch_gen:.4f}     discriminator loss: {epoch_disc:.4f}"
        )

        # Visualization: save a batch of 64 freshly generated images.
        preview = gen(get_noise(64, z_dim, device=device))
        save_tensor_images_dcgan(preview, f'dcgan-{epoch}')
Пример #6
0
def run_conditional_gen():
    """Steer generated images towards a single classifier feature.

    For ``grad_steps`` iterations the noise is passed through the generator
    and classifier, the batch-mean score of the target feature is
    back-propagated, and the noise is replaced by the result of
    ``_calculate_updated_noise``. Every ``skip``-th intermediate batch is
    displayed at the end.
    """
    gen, classifier, opt = load_pretrained_models()

    ### Change me! ###
    # Any string from feature_names can be used as the target feature.
    target_indices = feature_names.index("Male")

    noise = get_noise(n_images, z_dim).to(device).requires_grad_()
    fake_image_history = []
    for _ in range(grad_steps):
        opt.zero_grad()
        fake = gen(noise)
        fake_image_history.append(fake)
        # Mean classifier output for the target feature over the batch.
        target_column = classifier(fake)[:, target_indices]
        fake_classes_score = target_column.mean()
        fake_classes_score.backward()
        # Presumably a gradient step on the noise -- see _calculate_updated_noise.
        noise.data = _calculate_updated_noise(noise, 1 / grad_steps)

    plt.rcParams['figure.figsize'] = [n_images * 2, grad_steps * 2 / skip]
    stacked = torch.cat(fake_image_history[::skip], dim=2)
    show_tensor_images(stacked, num_images=n_images, nrow=n_images)
Пример #7
0
def train_wgangp(gen, crit, dataloader, epochs, gen_opt, crit_opt, z_dim,
                 c_lambda):
    """Train a generator/critic pair with a gradient penalty.

    Parameters:
        gen: the generator model
        crit: the critic model
        dataloader: yields ``(real_images, _)`` batches; its ``dataset``
            length is used to average the epoch losses
        epochs: number of passes over the dataloader
        gen_opt: optimizer stepping the generator
        crit_opt: optimizer stepping the critic
        z_dim: dimension of the noise vector
        c_lambda: forwarded to get_crit_loss together with the gradient
            penalty (presumably the penalty weight -- confirm there)

    For every batch the critic is updated ``crit_repeats`` (5) times and the
    generator once. After every epoch the sample-weighted mean losses are
    written to 'discriminator_loss.txt' and 'generator_loss.txt', printed,
    and a batch of 64 generated images is saved as 'wgangp-<epoch>'.
    """
    gen = gen.to(device)
    crit = crit.to(device)
    crit_repeats = 5
    data_size = len(dataloader.dataset)

    print()
    print(f'Start training on {device_name}')
    print(64 * '-')

    for epoch in range(epochs):
        # Running sums of (batch loss * batch size) for this epoch.
        critic_losses = 0.
        generator_losses = 0.

        # Dataloader returns the batches
        for real, _ in tqdm(dataloader, desc=f"Epoch {epoch}/{epochs - 1}"):

            cur_batch_size = len(real)
            real = real.to(device)
            mean_iteration_critic_loss = 0

            # Train the critic several times per generator update.
            for _ in range(crit_repeats):
                crit_opt.zero_grad()
                fake_noise = get_noise(cur_batch_size, z_dim, device)
                fake = gen(fake_noise)
                # Detach so the critic update does not reach the generator.
                crit_fake_pred = crit(fake.detach())
                crit_real_pred = crit(real)

                # Per-sample mixing factor for the real/fake blend.
                epsilon = torch.rand(len(real),
                                     1,
                                     1,
                                     1,
                                     device=device,
                                     requires_grad=True)
                # Gradient of the critic with respect to the mixed images.
                gradient = get_gradient(crit, real, fake.detach(), epsilon)
                gp = gradient_penalty(gradient)
                # The gradient penalty is only used when training the critic.
                crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp,
                                          c_lambda)

                # Average the critic loss over the repeated updates.
                mean_iteration_critic_loss += crit_loss.item() / crit_repeats
                crit_loss.backward()
                crit_opt.step()

            # Weight by batch size: the epoch sum is later divided by the
            # number of samples, not the number of batches. This assumes
            # get_crit_loss returns a per-batch mean.
            critic_losses += mean_iteration_critic_loss * cur_batch_size

            # Update generator
            gen_opt.zero_grad()
            fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
            fake_2 = gen(fake_noise_2)
            crit_fake_pred = crit(fake_2)

            gen_loss = get_gen_loss(crit_fake_pred)
            gen_loss.backward()

            # Update the weights
            gen_opt.step()

            # Same sample weighting as for the critic loss above.
            generator_losses += gen_loss.item() * cur_batch_size

        mean_critic_loss = critic_losses / data_size
        mean_generator_loss = generator_losses / data_size

        write_loss_to_file(mean_critic_loss, 'discriminator_loss.txt')
        write_loss_to_file(mean_generator_loss, 'generator_loss.txt')

        print(
            f"Generator loss: {mean_generator_loss:.4f}     discriminator loss: {mean_critic_loss:.4f}"
        )

        # Visualization
        fake_noise = get_noise(64, z_dim, device=device)
        fake = gen(fake_noise)
        save_tensor_images_dcgan(fake, f'wgangp-{epoch}')
Пример #8
0
def test_generator():
    """Generate ``n_images`` samples with the pretrained generator.

    The generated batch is handed to ``save_tensor_images_dcgan`` with an
    empty name and ``show=True``.
    """
    generator, _, _ = load_pretrained_models()
    latent = get_noise(n_images, z_dim).to(device)
    save_tensor_images_dcgan(generator(latent), '', show=True)
Пример #9
0
def train_cgan(gen, disc, dataloader, epochs, gen_opt, disc_opt, criterion, z_dim, n_classes, mnist_shape):
    """Train a class-conditional generator/discriminator pair.

    Parameters:
        gen: the generator; receives noise concatenated with one-hot labels
        disc: the discriminator; receives images concatenated with the
            labels broadcast to image-sized channels
        dataloader: yields ``(real_images, labels)`` batches; its ``dataset``
            length is used to average the epoch losses
        epochs: number of passes over the dataloader
        gen_opt: optimizer stepping the generator
        disc_opt: optimizer stepping the discriminator
        criterion: loss comparing predictions with targets (fake = 0, real = 1)
        z_dim: dimension of the noise vector
        n_classes: number of classes used for the one-hot encoding
        mnist_shape: image shape; entries 1 and 2 are used as the spatial
            size the label channels are repeated to

    After every epoch the sample-weighted mean losses are written to
    'discriminator_loss.txt' and 'generator_loss.txt', printed, and 100
    generated images (labels cycling through the classes) are saved as
    'cgan-<epoch>'.
    """
    gen = gen.to(device)
    disc = disc.to(device)

    data_size = len(dataloader.dataset)

    print()
    print(f'Start training on {device_name}')
    print(64 * '-')

    for epoch in range(epochs):
        # Running sums of (batch loss * batch size) for this epoch.
        generator_loss = 0.
        discriminator_loss = 0.

        # The generator is switched to eval mode for the end-of-epoch
        # samples, so put it back into training mode here.
        gen.train()
        for real, labels in tqdm(dataloader, desc=f"Epoch {epoch}/{epochs - 1}"):
            cur_batch_size = len(real)

            real = real.to(device)
            # Shape examples below are for MNIST with a batch of 128.
            # one_hot_labels: 128 x 10
            one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
            image_one_hot_labels = one_hot_labels[:, :, None, None]
            # image_one_hot_labels: 128 x 10 x 28 x 28
            image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])

            ##########################
            # Update discriminator
            disc_opt.zero_grad()
            # fake_noise: 128 x 64
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            # noise_and_labels: 128 x 74
            noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
            # fake: 128 x 1 x 28 x 28
            fake = gen(noise_and_labels)
            # fake_image_and_labels / real_image_and_labels: 128 x 11 x 28 x 28
            fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
            real_image_and_labels = combine_vectors(real, image_one_hot_labels)
            # Detach the fake batch: the discriminator update must not
            # back-propagate into the generator. This also leaves the
            # generator's graph intact for the generator step below, so
            # retain_graph is not needed.
            disc_fake_pred = disc(fake_image_and_labels.detach())
            disc_real_pred = disc(real_image_and_labels)
            disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
            disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))

            disc_loss = (disc_fake_loss + disc_real_loss) / 2
            disc_loss.backward()
            disc_opt.step()

            # Keep track of the epoch sum discriminator loss
            discriminator_loss += disc_loss.item() * cur_batch_size
            ##########################

            # Update generator
            gen_opt.zero_grad()

            # Score the same fake batch (this time with its graph back to
            # the generator) using the freshly updated discriminator.
            disc_fake_pred_new = disc(fake_image_and_labels)
            gen_loss = criterion(disc_fake_pred_new, torch.ones_like(disc_fake_pred_new))
            # Gradients flow from gen_loss through the discriminator and
            # the fake images into the generator.
            gen_loss.backward()
            gen_opt.step()
            ##########################

            # Keep track of the epoch sum generator loss
            generator_loss += gen_loss.item() * cur_batch_size

        mean_discriminator_loss = discriminator_loss / data_size
        mean_generator_loss = generator_loss / data_size

        write_loss_to_file(mean_discriminator_loss, 'discriminator_loss.txt')
        write_loss_to_file(mean_generator_loss, 'generator_loss.txt')

        print(f"Generator loss: {mean_generator_loss:.4f}     discriminator loss: {mean_discriminator_loss:.4f}")

        # Visualization/validation: 100 samples whose labels cycle through
        # the available classes (0..9 repeated when n_classes == 10).
        sample_labels = torch.as_tensor([i % n_classes for i in range(100)], dtype=torch.long)
        gen.eval()
        # Sampling only: no autograd graph is needed here.
        with torch.no_grad():
            fake_noise = get_noise(n_samples=100, z_dim=z_dim, device=device)
            one_hot_labels = get_one_hot_labels(sample_labels.to(device), n_classes)
            noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
            fake = gen(noise_and_labels)
        save_tensor_images_cgan(fake, f'cgan-{epoch}', num_images=100)
Пример #10
0
 def initialize_dip(self):
     """Create the fixed kernel-DIP input and a fresh kernel DIP network.

     Both are moved to the GPU; the input tensor is additionally detached.
     """
     # The input is created before the network, as in the original order.
     kernel_input = util.get_noise(64, "noise", (64, 64))
     self.dip_zk = kernel_input.cuda().detach()

     kernel_options = self.opt["KernelDIP"]
     self.k_dip = KernelDIP(kernel_options).cuda()