Example #1
    # NOTE: the snippet begins mid-call; the head of this logging setup is
    # reconstructed (the level argument is an assumption) so the example parses.
    logging.basicConfig(level=logging.DEBUG,
                        handlers=[
                            logging.FileHandler(
                                os.path.join(paths.log_path, "log.txt")),
                            logging.StreamHandler()
                        ])

    # fetching device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.debug(f"{device}")

    # training configuration
    with open("train_config.json", "r") as f:
        train_config = json.load(f)
        args = Namespace(**train_config)

    # initializing networks and optimizers
    if args.type == "DCGAN":
        G, D = utils.get_gan(GANType.DCGAN, device)
        G_optim, D_optim = utils.get_optimizers(G, D)
    elif args.type == "SN_DCGAN":
        G, D = utils.get_gan(GANType.SN_DCGAN, device, args.n_power_iterations)
        G_optim, D_optim = utils.get_optimizers(G, D)
    else:
        raise ValueError(f"Unknown GAN type in train_config.json: {args.type}")

    # initializing loader for data
    data_loader = utils.get_data_loader(args.batch_size, args.img_size)

    # setting up loss and GT
    adversarial_loss = nn.BCELoss()
    real_gt, fake_gt = utils.get_gt(args.batch_size, device)

    # for logging
    log_batch_size = 25
    log_noise = utils.get_latent_batch(log_batch_size, device)
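
A minimal sketch of the latent/ground-truth helpers this example calls (utils.get_latent_batch, utils.get_gt); the bodies below are assumptions inferred from how the return values are used here and in Examples #3/#5, not the repo's exact code.

import torch

LATENT_DIM = 100  # assumed latent dimensionality

def get_latent_batch(batch_size, device):
    # standard-normal latent vectors, placed on the training device
    return torch.randn((batch_size, LATENT_DIM), device=device)

def get_gt(batch_size, device):
    # BCE targets: 1s for real images, 0s for fake images
    real_gt = torch.ones((batch_size, 1), device=device)
    fake_gt = torch.zeros((batch_size, 1), device=device)
    return real_gt, fake_gt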
Example #2
def generate_new_images(model_name,
                        cgan_digit=None,
                        generation_mode=GenerationMode.SINGLE_IMAGE,
                        slerp=True,
                        a=None,
                        b=None,
                        should_display=True):
    """ Generate imagery using pre-trained generator (using vanilla_generator_000000.pth by default)

    Args:
        model_name (str): model name you want to use (default lookup location is BINARIES_PATH).
        cgan_digit (int): if specified generate that exact digit.
        generation_mode (enum):  generate a single image from a random vector, interpolate between the 2 chosen latent
         vectors, or perform arithmetic over latent vectors (note: not every mode is supported for every model type)
        slerp (bool): if True use spherical interpolation otherwise use linear interpolation.
        a, b (numpy arrays): latent vectors, if set to None you'll be prompted to choose images you like,
         and use corresponding latent vectors instead.
        should_display (bool): Display the generated images before saving them.

    """

    model_path = os.path.join(BINARIES_PATH, model_name)
    assert os.path.exists(
        model_path
    ), f'Could not find the model {model_path}. You first need to train your generator.'

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare the correct (vanilla, cGAN, DCGAN, ...) model, load the weights and put the model into evaluation mode
    model_state = torch.load(model_path)
    gan_type = model_state["gan_type"]
    print(f'Found {gan_type} GAN!')
    _, generator = utils.get_gan(device, gan_type)
    generator.load_state_dict(model_state["state_dict"], strict=True)
    generator.eval()

    # Generate a single image, save it and potentially display it
    if generation_mode == GenerationMode.SINGLE_IMAGE:
        generated_imgs_path = os.path.join(DATA_DIR_PATH, 'generated_imagery')
        os.makedirs(generated_imgs_path, exist_ok=True)

        generated_img, _ = generate_from_random_latent_vector(
            generator, cgan_digit if gan_type == GANType.CGAN.name else None)
        utils.save_and_maybe_display_image(generated_imgs_path,
                                           generated_img,
                                           should_display=should_display)

    # Pick 2 images you like between which you'd like to interpolate (by typing 'y' into console)
    elif generation_mode == GenerationMode.INTERPOLATION:
        assert gan_type == GANType.VANILLA.name or gan_type == GANType.DCGAN.name, f'Got {gan_type} but only VANILLA/DCGAN are supported for the interpolation mode.'

        interpolation_name = "spherical" if slerp else "linear"
        interpolation_fn = spherical_interpolation if slerp else linear_interpolation

        grid_interpolated_imgs_path = os.path.join(
            DATA_DIR_PATH, 'interpolated_imagery')  # combined results dir
        decomposed_interpolated_imgs_path = os.path.join(
            grid_interpolated_imgs_path,
            f'tmp_{gan_type}_{interpolation_name}_dump'
        )  # dump separate results
        if os.path.exists(decomposed_interpolated_imgs_path):
            shutil.rmtree(decomposed_interpolated_imgs_path)
        os.makedirs(grid_interpolated_imgs_path, exist_ok=True)
        os.makedirs(decomposed_interpolated_imgs_path, exist_ok=True)

        latent_vector_a, latent_vector_b = None, None

        # If a and b were not specified, loop until the user has picked the 2 images they like.
        found_good_vectors_flag = False
        if a is None or b is None:
            while not found_good_vectors_flag:
                generated_img, latent_vector = generate_from_random_latent_vector(
                    generator)
                plt.imshow(generated_img)
                plt.title('Do you like this image?')
                plt.show()
                user_input = input(
                    "Do you like this generated image? [y for yes]:")
                if user_input == 'y':
                    if latent_vector_a is None:
                        latent_vector_a = latent_vector
                        print('Saved the first latent vector.')
                    elif latent_vector_b is None:
                        latent_vector_b = latent_vector
                        print('Saved the second latent vector.')
                        found_good_vectors_flag = True
                else:
                    print("Well, let's generate a new one!")
        else:
            print('Skipping latent vector selection and using the provided vectors.')
            latent_vector_a, latent_vector_b = a, b

        # Cache latent vectors
        if a is None or b is None:
            np.save(os.path.join(grid_interpolated_imgs_path, 'a.npy'),
                    latent_vector_a)
            np.save(os.path.join(grid_interpolated_imgs_path, 'b.npy'),
                    latent_vector_b)

        print(f"Let's do some {interpolation_name} interpolation!")
        interpolation_resolution = 47  # number of images between the vectors a and b
        num_interpolated_imgs = interpolation_resolution + 2  # + 2 so that we include a and b

        generated_imgs = []
        for i in range(num_interpolated_imgs):
            t = i / (num_interpolated_imgs - 1)  # goes from 0. to 1.
            current_latent_vector = interpolation_fn(t, latent_vector_a,
                                                     latent_vector_b)
            generated_img = generate_from_specified_numpy_latent_vector(
                generator, current_latent_vector)

            print(f'Generated image [{i+1}/{num_interpolated_imgs}].')
            utils.save_and_maybe_display_image(
                decomposed_interpolated_imgs_path,
                generated_img,
                should_display=should_display)

            # Move from channel-last to channel-first (HWC->CHW); PyTorch's save_image function expects BCHW format
            generated_imgs.append(
                torch.tensor(np.moveaxis(generated_img, 2, 0)))

        interpolated_block_img = torch.stack(generated_imgs)
        interpolated_block_img = nn.Upsample(
            scale_factor=2.5, mode='nearest')(interpolated_block_img)
        save_image(
            interpolated_block_img,
            os.path.join(
                grid_interpolated_imgs_path,
                utils.get_available_file_name(grid_interpolated_imgs_path)),
            nrow=int(np.sqrt(num_interpolated_imgs)))

    elif generation_mode == GenerationMode.VECTOR_ARITHMETIC:
        assert gan_type == GANType.DCGAN.name, f'Got {gan_type} but only DCGAN is supported for arithmetic mode.'

        # Generate num_options face images and create a grid image from them
        num_options = 100
        generated_imgs = []
        latent_vectors = []
        padding = 2
        for i in range(num_options):
            generated_img, latent_vector = generate_from_random_latent_vector(
                generator)
            generated_imgs.append(
                torch.tensor(np.moveaxis(generated_img, 2,
                                         0)))  # make_grid expects CHW format
            latent_vectors.append(latent_vector)
        stacked_tensor_imgs = torch.stack(generated_imgs)
        final_tensor_img = make_grid(stacked_tensor_imgs,
                                     nrow=int(np.sqrt(num_options)),
                                     padding=padding)
        display_img = np.moveaxis(final_tensor_img.numpy(), 0, 2)

        # For storing latent vectors
        num_of_vectors_per_category = 3
        happy_woman_latent_vectors = []
        neutral_woman_latent_vectors = []
        neutral_man_latent_vectors = []

        # Make it easy - by clicking on the plot you pick the image.
        def onclick(event):
            if event.dblclick:
                pass
            else:  # single click
                if event.button == 1:  # left click
                    x_coord = event.xdata
                    y_coord = event.ydata
                    column = int(x_coord / (64 + padding))  # grid cells are 64x64 images plus padding
                    row = int(y_coord / (64 + padding))

                    # Store latent vector corresponding to the image that the user clicked on.
                    if len(happy_woman_latent_vectors
                           ) < num_of_vectors_per_category:
                        happy_woman_latent_vectors.append(
                            latent_vectors[10 * row + column])
                        print(
                            f'Picked image row={row}, column={column} as {len(happy_woman_latent_vectors)}. happy woman.'
                        )
                    elif len(neutral_woman_latent_vectors
                             ) < num_of_vectors_per_category:
                        neutral_woman_latent_vectors.append(
                            latent_vectors[10 * row + column])
                        print(
                            f'Picked image row={row}, column={column} as {len(neutral_woman_latent_vectors)}. neutral woman.'
                        )
                    elif len(neutral_man_latent_vectors
                             ) < num_of_vectors_per_category:
                        neutral_man_latent_vectors.append(
                            latent_vectors[10 * row + column])
                        print(
                            f'Picked image row={row}, column={column} as {len(neutral_man_latent_vectors)}. neutral man.'
                        )
                    else:
                        plt.close()

        plt.figure(figsize=(10, 10))
        plt.imshow(display_img)
        # This is just an example; you could also pick, e.g., 3 neutral women with sunglasses instead.
        plt.title(
            'Click on 3 happy women, 3 neutral women and \n 3 neutral men images (order matters!)'
        )
        cid = plt.gcf().canvas.mpl_connect('button_press_event', onclick)
        plt.show()
        plt.gcf().canvas.mpl_disconnect(cid)
        print('Done choosing images.')

        # Calculate the average latent vector for every category (happy woman, neutral woman, neutral man)
        happy_woman_avg_latent_vector = np.mean(
            np.array(happy_woman_latent_vectors), axis=0)
        neutral_woman_avg_latent_vector = np.mean(
            np.array(neutral_woman_latent_vectors), axis=0)
        neutral_man_avg_latent_vector = np.mean(
            np.array(neutral_man_latent_vectors), axis=0)

        # By subtracting the neutral woman vector from the happy woman vector we capture the "vector of smiling".
        # Adding that vector to a neutral man gives us a happy man's latent vector! Our latent space has amazingly beautiful structure!
        happy_man_latent_vector = neutral_man_avg_latent_vector + (
            happy_woman_avg_latent_vector - neutral_woman_avg_latent_vector)
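        # Illustrative intuition (not in the original): if "smiling" shifts some latent
        # coordinates by roughly +d on average, then happy_woman - neutral_woman ~ d,
        # and neutral_man + d lands near a happy man in latent space.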

        # Generate images from these latent vectors
        happy_women_imgs = np.hstack([
            generate_from_specified_numpy_latent_vector(generator, v)
            for v in happy_woman_latent_vectors
        ])
        neutral_women_imgs = np.hstack([
            generate_from_specified_numpy_latent_vector(generator, v)
            for v in neutral_woman_latent_vectors
        ])
        neutral_men_imgs = np.hstack([
            generate_from_specified_numpy_latent_vector(generator, v)
            for v in neutral_man_latent_vectors
        ])

        happy_woman_avg_img = generate_from_specified_numpy_latent_vector(
            generator, happy_woman_avg_latent_vector)
        neutral_woman_avg_img = generate_from_specified_numpy_latent_vector(
            generator, neutral_woman_avg_latent_vector)
        neutral_man_avg_img = generate_from_specified_numpy_latent_vector(
            generator, neutral_man_avg_latent_vector)

        happy_man_img = generate_from_specified_numpy_latent_vector(
            generator, happy_man_latent_vector)

        display_vector_arithmetic_results([
            happy_women_imgs, happy_woman_avg_img, neutral_women_imgs,
            neutral_woman_avg_img, neutral_men_imgs, neutral_man_avg_img,
            happy_man_img
        ])
    else:
        raise Exception(f'Generation mode {generation_mode} is not yet supported.')
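
The snippet imports spherical_interpolation and linear_interpolation without defining them. Below is a hedged sketch of both, plus the assumed shape of the GenerationMode enum; the signatures (t, a, b) are inferred from the calls above and the slerp formula is the standard one, but none of this is guaranteed to match the repo's exact code.

import numpy as np
from enum import Enum

class GenerationMode(Enum):  # assumed member names, taken from the comparisons above
    SINGLE_IMAGE = 0
    INTERPOLATION = 1
    VECTOR_ARITHMETIC = 2

def linear_interpolation(t, a, b):
    # straight line in latent space: (1 - t) * a + t * b
    return (1 - t) * a + t * b

def spherical_interpolation(t, a, b):
    # slerp: interpolate along the great circle between a and b
    omega = np.arccos(np.clip(
        np.dot(a / np.linalg.norm(a), b / np.linalg.norm(b)), -1.0, 1.0))
    if np.isclose(omega, 0.0):  # (nearly) parallel vectors -> fall back to lerp
        return linear_interpolation(t, a, b)
    return (np.sin((1 - t) * omega) * a + np.sin(t * omega) * b) / np.sin(omega)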
Example #3
def train_gan(training_config):
    writer = SummaryWriter()
    device = torch.device("cpu")

    # Download MNIST dataset in the directory data
    mnist_data_loader = utils.get_mnist_data_loader(
        training_config['batch_size'])

    discriminator_net, generator_net = utils.get_gan(device,
                                                     GANType.CLASSIC.name)
    discriminator_opt, generator_opt = utils.get_optimizers(
        discriminator_net, generator_net)

    adversarial_loss = nn.BCELoss()
    real_image_gt = torch.ones((training_config['batch_size'], 1),
                               device=device)
    fake_image_gt = torch.zeros((training_config['batch_size'], 1),
                                device=device)

    ref_batch_size = 16
    ref_noise_batch = utils.get_gaussian_latent_batch(ref_batch_size, device)
    discriminator_loss_values = []
    generator_loss_values = []
    img_cnt = 0

    ts = time.time()

    utils.print_training_info_to_console(training_config)
    for epoch in range(training_config['num_epochs']):
        for batch_idx, (real_images, _) in enumerate(mnist_data_loader):
            real_images = real_images.to(device)

            # Train discriminator
            discriminator_opt.zero_grad()

            real_discriminator_loss = adversarial_loss(
                discriminator_net(real_images), real_image_gt)

            fake_images = generator_net(
                utils.get_gaussian_latent_batch(training_config['batch_size'],
                                                device))
            fake_images_predictions = discriminator_net(fake_images.detach())
            fake_discriminator_loss = adversarial_loss(fake_images_predictions,
                                                       fake_image_gt)

            discriminator_loss = real_discriminator_loss + fake_discriminator_loss
            discriminator_loss.backward()
            discriminator_opt.step()

            # Train generator
            generator_opt.zero_grad()

            generated_images_prediction = discriminator_net(
                generator_net(
                    utils.get_gaussian_latent_batch(
                        training_config['batch_size'], device)))

            generator_loss = adversarial_loss(generated_images_prediction,
                                              real_image_gt)

            generator_loss.backward()
            generator_opt.step()

            # Logging and checkpoint creation
            generator_loss_values.append(generator_loss.item())
            discriminator_loss_values.append(discriminator_loss.item())

            if training_config['enable_tensorboard']:
                writer.add_scalars(
                    'Losses/g-and-d', {
                        'g': generator_loss.item(),
                        'd': discriminator_loss.item()
                    },
                    len(mnist_data_loader) * epoch + batch_idx + 1)

                if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                    with torch.no_grad():
                        log_generated_images = generator_net(ref_noise_batch)
                        log_generated_images_resized = nn.Upsample(
                            scale_factor=2,
                            mode='nearest')(log_generated_images)
                        intermediate_imagery_grid = make_grid(
                            log_generated_images_resized,
                            nrow=int(np.sqrt(ref_batch_size)),
                            normalize=True)
                        writer.add_image(
                            'intermediate generated imagery',
                            intermediate_imagery_grid,
                            len(mnist_data_loader) * epoch + batch_idx + 1)

            if training_config['console_log_freq'] is not None and batch_idx % training_config['console_log_freq'] == 0:
                print(
                    f'GAN training: time elapsed = {(time.time() - ts):.2f} [s] | epoch={epoch + 1} | batch= [{batch_idx + 1}/{len(mnist_data_loader)}]'
                )

            # Save intermediate generator images
            if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                with torch.no_grad():
                    log_generated_images = generator_net(ref_noise_batch)
                    log_generated_images_resized = nn.Upsample(
                        scale_factor=2, mode='nearest')(log_generated_images)
                    save_image(log_generated_images_resized,
                               os.path.join(training_config['debug_path'],
                                            f'{str(img_cnt).zfill(6)}.jpg'),
                               nrow=int(np.sqrt(ref_batch_size)),
                               normalize=True)
                    img_cnt += 1

            # Save generator checkpoint
            if training_config['checkpoint_freq'] is not None and (epoch + 1) % training_config['checkpoint_freq'] == 0 and batch_idx == 0:
                ckpt_model_name = f"Classic_ckpt_epoch_{epoch + 1}_batch_{batch_idx + 1}.pth"
                torch.save(
                    utils.get_training_state(generator_net,
                                             GANType.CLASSIC.name),
                    os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

    torch.save(utils.get_training_state(generator_net, GANType.CLASSIC.name),
               os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
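
A plausible sketch of utils.get_mnist_data_loader, assuming the standard torchvision MNIST pipeline normalized to [-1, 1]; the transform details and the drop_last choice are assumptions. drop_last matters here because real_image_gt/fake_image_gt are built with a fixed batch_size, so a smaller final batch would break the BCELoss shape match.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def get_mnist_data_loader(batch_size):
    transform = transforms.Compose([
        transforms.ToTensor(),                 # [0, 255] -> [0., 1.]
        transforms.Normalize((0.5,), (0.5,)),  # [0., 1.] -> [-1., 1.]
    ])
    dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)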
Example #4
    # NOTE: the snippet begins mid-parser; the lines above --num_images are
    # reconstructed from the args.ckpt_path / args.batch_size usage below.
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path",
                        type=str,
                        required=True,
                        help="Path to the trained generator checkpoint.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=128,
                        help="Size of batch for GAN inference.")
    parser.add_argument("--num_images",
                        type=int,
                        required=True,
                        help="How many images to generate.")
    args = parser.parse_args()

    data_dir = "./data/fake/fake"  # nested dir, likely so ImageFolder-style loaders see a single class folder
    os.makedirs(data_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # since norm layers are frozen in eval mode, we can use the DCGAN architecture for both configs
    G, _ = utils.get_gan(GANType.DCGAN, device, 3)
    G.load_state_dict(torch.load(args.ckpt_path))
    G.eval()

    img_cnt = 0
    with torch.no_grad():
        while img_cnt < args.num_images:
            noise = utils.get_latent_batch(args.batch_size, device)
            fake_imgs = G(noise).cpu()

            # trim the final batch so we never write more than num_images files
            fake_imgs = fake_imgs[:args.num_images - img_cnt]
            for idx, fake_img in enumerate(fake_imgs):
                path = os.path.join(data_dir, str(img_cnt + idx) + ".png")
                save_image(fake_img, path, normalize=True)

            img_cnt += len(fake_imgs)
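
A hypothetical invocation of this script (the file name generate_fake_images.py and the checkpoint path are assumed; only the flags themselves appear in the snippet):

python generate_fake_images.py --ckpt_path checkpoints/sn_dcgan.pth --num_images 1000 --batch_size 128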
Example #5
def train_vanilla_gan(training_config):
    writer = SummaryWriter()  # (tensorboard) writer will output to ./runs/ directory by default
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # checking whether you have a GPU

    # Prepare MNIST data loader (it will download MNIST the first time you run it)
    mnist_data_loader = utils.get_mnist_data_loader(training_config['batch_size'])

    # Fetch feed-forward nets (place them on GPU if present) and optimizers which will tweak their weights
    discriminator_net, generator_net = utils.get_gan(device, GANType.VANILLA.name)
    discriminator_opt, generator_opt = utils.get_optimizers(discriminator_net, generator_net)

    # Targets of 1 configure BCELoss to -log(x) whereas targets of 0 configure it to -log(1-x)
    # So that means we can effectively use binary cross-entropy loss to achieve adversarial loss!
    adversarial_loss = nn.BCELoss()
    real_images_gt = torch.ones((training_config['batch_size'], 1), device=device)
    fake_images_gt = torch.zeros((training_config['batch_size'], 1), device=device)
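    # Quick sanity check (illustrative, not in the original):
    #   nn.BCELoss()(torch.tensor([0.9]), torch.tensor([1.]))  # ~0.105 = -log(0.9)
    #   nn.BCELoss()(torch.tensor([0.9]), torch.tensor([0.]))  # ~2.303 = -log(1 - 0.9)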

    # For logging purposes
    ref_batch_size = 16
    ref_noise_batch = utils.get_gaussian_latent_batch(ref_batch_size, device)  # Track G's quality during training
    discriminator_loss_values = []
    generator_loss_values = []
    img_cnt = 0

    ts = time.time()  # start measuring time

    # GAN training loop; it's always smart to train the discriminator first, so as to avoid mode collapse!
    utils.print_training_info_to_console(training_config)
    for epoch in range(training_config['num_epochs']):
        for batch_idx, (real_images, _) in enumerate(mnist_data_loader):

            real_images = real_images.to(device)  # Place imagery on GPU (if present)

            #
            # Train discriminator: maximize V = log(D(x)) + log(1-D(G(z))) or equivalently minimize -V
            # Note: D = discriminator, x = real images, G = generator, z = latent Gaussian vectors, G(z) = fake images
            #

            # Zero out .grad variables in the discriminator network (otherwise we would have corrupt results)
            discriminator_opt.zero_grad()

            # -log(D(x)) <- we minimize this by making D(x)/discriminator_net(real_images) as close to 1 as possible
            real_discriminator_loss = adversarial_loss(discriminator_net(real_images), real_images_gt)

            # G(z) | G == generator_net and z == utils.get_gaussian_latent_batch(batch_size, device)
            fake_images = generator_net(utils.get_gaussian_latent_batch(training_config['batch_size'], device))
            # D(G(z)), we call detach() so that we don't calculate gradients for the generator during backward()
            fake_images_predictions = discriminator_net(fake_images.detach())
            # -log(1 - D(G(z))) <- we minimize this by making D(G(z)) as close to 0 as possible
            fake_discriminator_loss = adversarial_loss(fake_images_predictions, fake_images_gt)

            discriminator_loss = real_discriminator_loss + fake_discriminator_loss
            discriminator_loss.backward()  # this will populate .grad vars in the discriminator net
            discriminator_opt.step()  # perform D weights update according to optimizer's strategy

            #
            # Train generator: minimize V1 = log(1-D(G(z))) or equivalently maximize V2 = log(D(G(z))) (or min of -V2)
            # The original expression (V1) had problems with diminishing gradients for G when D is too good.
            #

            # If you want to cause mode collapse, probably the easiest way to do that would be to add "for i in range(n)"
            # here (simply train G more frequently than D). n = 10 worked for me; other values will also work - experiment.
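            # A hedged sketch of that experiment (left commented out on purpose; expect G to collapse):
            #   for _ in range(10):  # update G 10x for every single D update
            #       generator_opt.zero_grad()
            #       preds = discriminator_net(generator_net(utils.get_gaussian_latent_batch(training_config['batch_size'], device)))
            #       adversarial_loss(preds, real_images_gt).backward()
            #       generator_opt.step()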

            # Zero out .grad variables in the generator network (otherwise we would have corrupt results)
            generator_opt.zero_grad()

            # D(G(z)) (see above for explanations)
            generated_images_predictions = discriminator_net(generator_net(utils.get_gaussian_latent_batch(training_config['batch_size'], device)))
            # By placing real_images_gt here we minimize -log(D(G(z))) which happens when D approaches 1
            # i.e. we're tricking D into thinking that these generated images are real!
            generator_loss = adversarial_loss(generated_images_predictions, real_images_gt)

            generator_loss.backward()  # this will populate .grad vars in the G net (also in D but we won't use those)
            generator_opt.step()  # perform G weights update according to optimizer's strategy

            #
            # Logging and checkpoint creation
            #

            generator_loss_values.append(generator_loss.item())
            discriminator_loss_values.append(discriminator_loss.item())

            if training_config['enable_tensorboard']:
                writer.add_scalars('losses/g-and-d', {'g': generator_loss.item(), 'd': discriminator_loss.item()}, len(mnist_data_loader) * epoch + batch_idx + 1)
                # Save debug imagery to tensorboard also (some redundancy but it may be more beginner-friendly)
                if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                    with torch.no_grad():
                        log_generated_images = generator_net(ref_noise_batch)
                        log_generated_images_resized = nn.Upsample(scale_factor=2, mode='nearest')(log_generated_images)
                        intermediate_imagery_grid = make_grid(log_generated_images_resized, nrow=int(np.sqrt(ref_batch_size)), normalize=True)
                        writer.add_image('intermediate generated imagery', intermediate_imagery_grid, len(mnist_data_loader) * epoch + batch_idx + 1)

            if training_config['console_log_freq'] is not None and batch_idx % training_config['console_log_freq'] == 0:
                print(f'GAN training: time elapsed = {(time.time() - ts):.2f} [s] | epoch={epoch + 1} | batch= [{batch_idx + 1}/{len(mnist_data_loader)}]')

            # Save intermediate generator images (more convenient like this than through tensorboard)
            if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                with torch.no_grad():
                    log_generated_images = generator_net(ref_noise_batch)
                    log_generated_images_resized = nn.Upsample(scale_factor=2.5, mode='nearest')(log_generated_images)
                    save_image(log_generated_images_resized, os.path.join(training_config['debug_path'], f'{str(img_cnt).zfill(6)}.jpg'), nrow=int(np.sqrt(ref_batch_size)), normalize=True)
                    img_cnt += 1

            # Save generator checkpoint
            if training_config['checkpoint_freq'] is not None and (epoch + 1) % training_config['checkpoint_freq'] == 0 and batch_idx == 0:
                ckpt_model_name = f"vanilla_ckpt_epoch_{epoch + 1}_batch_{batch_idx + 1}.pth"
                torch.save(utils.get_training_state(generator_net, GANType.VANILLA.name), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

    # Save the latest generator in the binaries directory
    torch.save(utils.get_training_state(generator_net, GANType.VANILLA.name), os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
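
A sketch of utils.get_training_state, inferred from how Example #2 unpacks the checkpoint (model_state["gan_type"] and model_state["state_dict"]); any keys beyond those two would be assumptions, so only the two observed ones are included.

def get_training_state(generator_net, gan_type_name):
    return {
        "gan_type": gan_type_name,                 # read back in Example #2
        "state_dict": generator_net.state_dict(),  # generator weights
    }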