                        os.path.join(paths.log_path, "log.txt")),
                        logging.StreamHandler()
                    ])

# fetching device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.debug(f"{device}")

# training configuration
with open("train_config.json", "r") as f:
    train_config = json.load(f)
args = Namespace(**train_config)

# initializing networks and optimizers
if args.type == "DCGAN":
    G, D = utils.get_gan(GANType.DCGAN, device)
    G_optim, D_optim = utils.get_optimizers(G, D)
elif args.type == "SN_DCGAN":
    G, D = utils.get_gan(GANType.SN_DCGAN, device, args.n_power_iterations)
    G_optim, D_optim = utils.get_optimizers(G, D)

# initializing loader for data
data_loader = utils.get_data_loader(args.batch_size, args.img_size)

# setting up loss and GT
adversarial_loss = nn.BCELoss()
real_gt, fake_gt = utils.get_gt(args.batch_size, device)

# for logging
log_batch_size = 25
log_noise = utils.get_latent_batch(log_batch_size, device)
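# For reference, a minimal train_config.json covering every key read above
# (illustrative values -- the actual train_config.json may differ):
# {
#     "type": "SN_DCGAN",
#     "n_power_iterations": 1,
#     "batch_size": 128,
#     "img_size": 64
# }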
def generate_new_images(model_name, cgan_digit=None, generation_mode=GenerationMode.SINGLE_IMAGE, slerp=True, a=None, b=None, should_display=True):
    """ Generate imagery using a pre-trained generator (vanilla_generator_000000.pth by default).

    Args:
        model_name (str): model name you want to use (default lookup location is BINARIES_PATH).
        cgan_digit (int): if specified, generate that exact digit.
        generation_mode (enum): generate a single image from a random vector, interpolate between 2 chosen
            latent vectors, or perform arithmetic over latent vectors (note: not every mode is supported for
            every model type).
        slerp (bool): if True use spherical interpolation, otherwise use linear interpolation.
        a, b (numpy arrays): latent vectors; if set to None you'll be prompted to choose images you like,
            and the corresponding latent vectors will be used instead.
        should_display (bool): display the generated images before saving them.

    """
    model_path = os.path.join(BINARIES_PATH, model_name)
    assert os.path.exists(model_path), f'Could not find the model {model_path}. You first need to train your generator.'

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare the correct (vanilla, cGAN, DCGAN, ...) model, load the weights and put the model into evaluation mode
    model_state = torch.load(model_path)
    gan_type = model_state["gan_type"]
    print(f'Found {gan_type} GAN!')
    _, generator = utils.get_gan(device, gan_type)
    generator.load_state_dict(model_state["state_dict"], strict=True)
    generator.eval()

    # Generate a single image, save it and potentially display it
    if generation_mode == GenerationMode.SINGLE_IMAGE:
        generated_imgs_path = os.path.join(DATA_DIR_PATH, 'generated_imagery')
        os.makedirs(generated_imgs_path, exist_ok=True)

        generated_img, _ = generate_from_random_latent_vector(generator, cgan_digit if gan_type == GANType.CGAN.name else None)
        utils.save_and_maybe_display_image(generated_imgs_path, generated_img, should_display=should_display)

    # Pick 2 images you like between which you'd like to interpolate (by typing 'y' into the console)
    elif generation_mode == GenerationMode.INTERPOLATION:
        assert gan_type == GANType.VANILLA.name or gan_type == GANType.DCGAN.name, f'Got {gan_type} but only VANILLA/DCGAN are supported for the interpolation mode.'
        interpolation_name = "spherical" if slerp else "linear"
        interpolation_fn = spherical_interpolation if slerp else linear_interpolation

        grid_interpolated_imgs_path = os.path.join(DATA_DIR_PATH, 'interpolated_imagery')  # combined results dir
        decomposed_interpolated_imgs_path = os.path.join(grid_interpolated_imgs_path, f'tmp_{gan_type}_{interpolation_name}_dump')  # dump separate results
        if os.path.exists(decomposed_interpolated_imgs_path):
            shutil.rmtree(decomposed_interpolated_imgs_path)
        os.makedirs(grid_interpolated_imgs_path, exist_ok=True)
        os.makedirs(decomposed_interpolated_imgs_path, exist_ok=True)

        latent_vector_a, latent_vector_b = [None, None]

        # If a and b were not specified, loop until the user has picked the 2 images they like
        found_good_vectors_flag = False
        if a is None or b is None:
            while not found_good_vectors_flag:
                generated_img, latent_vector = generate_from_random_latent_vector(generator)
                plt.imshow(generated_img)
                plt.title('Do you like this image?')
                plt.show()
                user_input = input("Do you like this generated image? [y for yes]: ")
                if user_input == 'y':
                    if latent_vector_a is None:
                        latent_vector_a = latent_vector
                        print('Saved the first latent vector.')
                    elif latent_vector_b is None:
                        latent_vector_b = latent_vector
                        print('Saved the second latent vector.')
                        found_good_vectors_flag = True
                else:
                    print("Well, let's generate a new one!")
                    continue
        else:
            print('Skipping latent vector selection and using the cached ones.')
            latent_vector_a, latent_vector_b = [a, b]

        # Cache latent vectors
        if a is None or b is None:
            np.save(os.path.join(grid_interpolated_imgs_path, 'a.npy'), latent_vector_a)
            np.save(os.path.join(grid_interpolated_imgs_path, 'b.npy'), latent_vector_b)

        print(f"Let's do some {interpolation_name} interpolation!")
        interpolation_resolution = 47  # number of images between the vectors a and b
        num_interpolated_imgs = interpolation_resolution + 2  # + 2 so that we include a and b

        generated_imgs = []
        for i in range(num_interpolated_imgs):
            t = i / (num_interpolated_imgs - 1)  # goes from 0. to 1.
            current_latent_vector = interpolation_fn(t, latent_vector_a, latent_vector_b)
            generated_img = generate_from_specified_numpy_latent_vector(generator, current_latent_vector)

            print(f'Generated image [{i+1}/{num_interpolated_imgs}].')
            utils.save_and_maybe_display_image(decomposed_interpolated_imgs_path, generated_img, should_display=should_display)

            # Move from channel last to channel first (HWC -> CHW), PyTorch's save_image function expects BCHW format
            generated_imgs.append(torch.tensor(np.moveaxis(generated_img, 2, 0)))

        interpolated_block_img = torch.stack(generated_imgs)
        interpolated_block_img = nn.Upsample(scale_factor=2.5, mode='nearest')(interpolated_block_img)
        save_image(interpolated_block_img, os.path.join(grid_interpolated_imgs_path, utils.get_available_file_name(grid_interpolated_imgs_path)), nrow=int(np.sqrt(num_interpolated_imgs)))

    elif generation_mode == GenerationMode.VECTOR_ARITHMETIC:
        assert gan_type == GANType.DCGAN.name, f'Got {gan_type} but only DCGAN is supported for the arithmetic mode.'

        # Generate num_options face images and create a grid image from them
        num_options = 100
        generated_imgs = []
        latent_vectors = []
        padding = 2
        for i in range(num_options):
            generated_img, latent_vector = generate_from_random_latent_vector(generator)
            generated_imgs.append(torch.tensor(np.moveaxis(generated_img, 2, 0)))  # make_grid expects CHW format
            latent_vectors.append(latent_vector)
        stacked_tensor_imgs = torch.stack(generated_imgs)
        final_tensor_img = make_grid(stacked_tensor_imgs, nrow=int(np.sqrt(num_options)), padding=padding)
        display_img = np.moveaxis(final_tensor_img.numpy(), 0, 2)

        # For storing latent vectors
        num_of_vectors_per_category = 3
        happy_woman_latent_vectors = []
        neutral_woman_latent_vectors = []
        neutral_man_latent_vectors = []

        # Make it easy - by clicking on the plot you pick the image
        def onclick(event):
            if event.dblclick:
                pass
            else:  # single click
                if event.button == 1:  # left click
                    x_coord = event.xdata
                    y_coord = event.ydata
                    column = int(x_coord / (64 + padding))
                    row = int(y_coord / (64 + padding))

                    # Store the latent vector corresponding to the image that the user clicked on
                    if len(happy_woman_latent_vectors) < num_of_vectors_per_category:
                        happy_woman_latent_vectors.append(latent_vectors[10 * row + column])
                        print(f'Picked image row={row}, column={column} as {len(happy_woman_latent_vectors)}. happy woman.')
                    elif len(neutral_woman_latent_vectors) < num_of_vectors_per_category:
                        neutral_woman_latent_vectors.append(latent_vectors[10 * row + column])
                        print(f'Picked image row={row}, column={column} as {len(neutral_woman_latent_vectors)}. neutral woman.')
                    elif len(neutral_man_latent_vectors) < num_of_vectors_per_category:
                        neutral_man_latent_vectors.append(latent_vectors[10 * row + column])
                        print(f'Picked image row={row}, column={column} as {len(neutral_man_latent_vectors)}. neutral man.')
                    else:
                        plt.close()

        plt.figure(figsize=(10, 10))
        plt.imshow(display_img)
        # This is just an example - you could also pick, say, 3 neutral women images with sunglasses, etc.
        plt.title('Click on 3 happy women, 3 neutral women and \n 3 neutral men images (order matters!)')
        cid = plt.gcf().canvas.mpl_connect('button_press_event', onclick)
        plt.show()
        plt.gcf().canvas.mpl_disconnect(cid)
        print('Done choosing images.')

        # Calculate the average latent vector for every category (happy woman, neutral woman, neutral man)
        happy_woman_avg_latent_vector = np.mean(np.array(happy_woman_latent_vectors), axis=0)
        neutral_woman_avg_latent_vector = np.mean(np.array(neutral_woman_latent_vectors), axis=0)
        neutral_man_avg_latent_vector = np.mean(np.array(neutral_man_latent_vectors), axis=0)

        # By subtracting the neutral woman from the happy woman we capture the "vector of smiling". Adding that vector
        # to a neutral man we get a happy man's latent vector! Our latent space has amazingly beautiful structure!
        happy_man_latent_vector = neutral_man_avg_latent_vector + (happy_woman_avg_latent_vector - neutral_woman_avg_latent_vector)

        # Generate images from these latent vectors
        happy_women_imgs = np.hstack([generate_from_specified_numpy_latent_vector(generator, v) for v in happy_woman_latent_vectors])
        neutral_women_imgs = np.hstack([generate_from_specified_numpy_latent_vector(generator, v) for v in neutral_woman_latent_vectors])
        neutral_men_imgs = np.hstack([generate_from_specified_numpy_latent_vector(generator, v) for v in neutral_man_latent_vectors])

        happy_woman_avg_img = generate_from_specified_numpy_latent_vector(generator, happy_woman_avg_latent_vector)
        neutral_woman_avg_img = generate_from_specified_numpy_latent_vector(generator, neutral_woman_avg_latent_vector)
        neutral_man_avg_img = generate_from_specified_numpy_latent_vector(generator, neutral_man_avg_latent_vector)

        happy_man_img = generate_from_specified_numpy_latent_vector(generator, happy_man_latent_vector)

        display_vector_arithmetic_results([happy_women_imgs, happy_woman_avg_img, neutral_women_imgs, neutral_woman_avg_img, neutral_men_imgs, neutral_man_avg_img, happy_man_img])

    else:
        raise Exception(f'Generation mode {generation_mode} not yet supported.')
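# The interpolation helpers used above (linear_interpolation / spherical_interpolation) live elsewhere
# in the repo; below is a minimal sketch of what an interpolation_fn(t, a, b) over numpy latent vectors
# could look like. The "_sketch" suffix marks these as illustrations, not the repo's implementations.
def linear_interpolation_sketch(t, a, b):
    return (1 - t) * a + t * b  # straight line between a and b in latent space


def spherical_interpolation_sketch(t, a, b):
    # slerp: walk along the great arc between a and b instead of the chord; for Gaussian latent
    # spaces this keeps intermediate vectors at a plausible norm (see White, "Sampling Generative Networks")
    a_unit, b_unit = a / np.linalg.norm(a), b / np.linalg.norm(b)
    omega = np.arccos(np.clip(np.sum(a_unit * b_unit), -1.0, 1.0))  # angle between a and b
    if np.isclose(omega, 0.0):  # (nearly) parallel vectors -> slerp degenerates, fall back to lerp
        return linear_interpolation_sketch(t, a, b)
    return (np.sin((1 - t) * omega) * a + np.sin(t * omega) * b) / np.sin(omega)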
def train_gan(training_config):
    writer = SummaryWriter()
    device = torch.device("cpu")

    # Download MNIST dataset in the directory data
    mnist_data_loader = utils.get_mnist_data_loader(training_config['batch_size'])

    discriminator_net, generator_net = utils.get_gan(device, GANType.CLASSIC.name)
    discriminator_opt, generator_opt = utils.get_optimizers(discriminator_net, generator_net)

    adversarial_loss = nn.BCELoss()
    real_image_gt = torch.ones((training_config['batch_size'], 1), device=device)
    fake_image_gt = torch.zeros((training_config['batch_size'], 1), device=device)

    ref_batch_size = 16
    ref_noise_batch = utils.get_gaussian_latent_batch(ref_batch_size, device)
    discriminator_loss_values = []
    generator_loss_values = []
    img_cnt = 0

    ts = time.time()
    utils.print_training_info_to_console(training_config)
    for epoch in range(training_config['num_epochs']):
        for batch_idx, (real_images, _) in enumerate(mnist_data_loader):
            real_images = real_images.to(device)

            # Train discriminator
            discriminator_opt.zero_grad()
            real_discriminator_loss = adversarial_loss(discriminator_net(real_images), real_image_gt)
            fake_images = generator_net(utils.get_gaussian_latent_batch(training_config['batch_size'], device))
            fake_images_predictions = discriminator_net(fake_images.detach())
            fake_discriminator_loss = adversarial_loss(fake_images_predictions, fake_image_gt)
            discriminator_loss = real_discriminator_loss + fake_discriminator_loss
            discriminator_loss.backward()
            discriminator_opt.step()

            # Train generator
            generator_opt.zero_grad()
            generated_images_prediction = discriminator_net(generator_net(utils.get_gaussian_latent_batch(training_config['batch_size'], device)))
            generator_loss = adversarial_loss(generated_images_prediction, real_image_gt)
            generator_loss.backward()
            generator_opt.step()

            # Logging and checkpoint creation
            generator_loss_values.append(generator_loss.item())
            discriminator_loss_values.append(discriminator_loss.item())

            if training_config['enable_tensorboard']:
                writer.add_scalars('Losses/g-and-d', {'g': generator_loss.item(), 'd': discriminator_loss.item()}, len(mnist_data_loader) * epoch + batch_idx + 1)
                if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                    with torch.no_grad():
                        log_generated_images = generator_net(ref_noise_batch)
                        log_generated_images_resized = nn.Upsample(scale_factor=2, mode='nearest')(log_generated_images)
                        intermediate_imagery_grid = make_grid(log_generated_images_resized, nrow=int(np.sqrt(ref_batch_size)), normalize=True)
                        writer.add_image('intermediate generated imagery', intermediate_imagery_grid, len(mnist_data_loader) * epoch + batch_idx + 1)

            if training_config['console_log_freq'] is not None and batch_idx % training_config['console_log_freq'] == 0:
                print(f'GAN training: time elapsed = {(time.time() - ts):.2f} [s] | epoch={epoch + 1} | batch=[{batch_idx + 1}/{len(mnist_data_loader)}]')

            # Save intermediate generator images
            if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                with torch.no_grad():
                    log_generated_images = generator_net(ref_noise_batch)
                    log_generated_images_resized = nn.Upsample(scale_factor=2, mode='nearest')(log_generated_images)
                    save_image(log_generated_images_resized, os.path.join(training_config['debug_path'], f'{str(img_cnt).zfill(6)}.jpg'), nrow=int(np.sqrt(ref_batch_size)), normalize=True)
                    img_cnt += 1

            # Save generator checkpoint
            if training_config['checkpoint_freq'] is not None and (epoch + 1) % training_config['checkpoint_freq'] == 0 and batch_idx == 0:
                ckpt_model_name = f"Classic_ckpt_epoch_{epoch + 1}_batch_{batch_idx + 1}.pth"
                torch.save(utils.get_training_state(generator_net, GANType.CLASSIC.name), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

    torch.save(utils.get_training_state(generator_net, GANType.CLASSIC.name), os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
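# A minimal sketch of a training_config dict that satisfies every key train_gan() reads above
# (the values are illustrative, not the repo's defaults):
if __name__ == '__main__':
    training_config = {
        'num_epochs': 10,
        'batch_size': 128,
        'enable_tensorboard': True,
        'console_log_freq': 50,        # print a status line every 50 batches
        'debug_imagery_log_freq': 50,  # dump intermediate imagery every 50 batches
        'debug_path': os.path.join('data', 'debug_imagery'),
        'checkpoint_freq': 2,          # save a generator checkpoint every 2 epochs
    }
    os.makedirs(training_config['debug_path'], exist_ok=True)
    train_gan(training_config)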
                    type=int,
                    default=128,
                    help="Batch size for GAN inference.")
parser.add_argument("--num_images",
                    type=int,
                    required=True,
                    help="How many images to generate.")
args = parser.parse_args()

data_dir = "./data/fake/fake"
os.makedirs(data_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# since norm layers are frozen in eval mode, we can use DCGAN for both configs
G, _ = utils.get_gan(GANType.DCGAN, device, 3)
G.load_state_dict(torch.load(args.ckpt_path))
G.eval()

img_cnt = 0
with torch.no_grad():
    while img_cnt < args.num_images:
        noise = utils.get_latent_batch(args.batch_size, device)
        fake_imgs = G(noise).cpu()
        for idx, fake_img in enumerate(fake_imgs):
            path = os.path.join(data_dir, str(img_cnt + idx) + ".png")
            save_image(fake_img, path, normalize=True)
        img_cnt += args.batch_size
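# Example invocation (the script name is hypothetical -- use this file's actual name):
#   python generate_fake_dataset.py --ckpt_path models/checkpoints/dcgan_ckpt.pth --num_images 10000
# Note: img_cnt advances in whole batches, so the script may overshoot --num_images by up to batch_size - 1 images.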
def train_vanilla_gan(training_config):
    writer = SummaryWriter()  # (tensorboard) writer will output to ./runs/ directory by default
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # checking whether you have a GPU

    # Prepare MNIST data loader (it will download MNIST the first time you run it)
    mnist_data_loader = utils.get_mnist_data_loader(training_config['batch_size'])

    # Fetch feed-forward nets (place them on GPU if present) and optimizers which will tweak their weights
    discriminator_net, generator_net = utils.get_gan(device, GANType.VANILLA.name)
    discriminator_opt, generator_opt = utils.get_optimizers(discriminator_net, generator_net)

    # 1s will configure BCELoss into -log(x), whereas 0s will configure it to -log(1-x).
    # That means we can effectively use binary cross-entropy loss as our adversarial loss!
    adversarial_loss = nn.BCELoss()
    real_images_gt = torch.ones((training_config['batch_size'], 1), device=device)
    fake_images_gt = torch.zeros((training_config['batch_size'], 1), device=device)

    # For logging purposes
    ref_batch_size = 16
    ref_noise_batch = utils.get_gaussian_latent_batch(ref_batch_size, device)

    # Track G's quality during training
    discriminator_loss_values = []
    generator_loss_values = []
    img_cnt = 0

    ts = time.time()  # start measuring time

    # GAN training loop. It's always smart to first train the discriminator so as to avoid mode collapse!
    utils.print_training_info_to_console(training_config)
    for epoch in range(training_config['num_epochs']):
        for batch_idx, (real_images, _) in enumerate(mnist_data_loader):
            real_images = real_images.to(device)  # Place imagery on GPU (if present)

            #
            # Train discriminator: maximize V = log(D(x)) + log(1-D(G(z))) or equivalently minimize -V
            # Note: D = discriminator, x = real images, G = generator, z = latent Gaussian vectors, G(z) = fake images
            #

            # Zero out .grad variables in the discriminator network (otherwise we would have corrupt results)
            discriminator_opt.zero_grad()

            # -log(D(x)) <- we minimize this by making D(x)/discriminator_net(real_images) as close to 1 as possible
            real_discriminator_loss = adversarial_loss(discriminator_net(real_images), real_images_gt)

            # G(z) | G == generator_net and z == utils.get_gaussian_latent_batch(batch_size, device)
            fake_images = generator_net(utils.get_gaussian_latent_batch(training_config['batch_size'], device))
            # D(G(z)), we call detach() so that we don't calculate gradients for the generator during backward()
            fake_images_predictions = discriminator_net(fake_images.detach())
            # -log(1 - D(G(z))) <- we minimize this by making D(G(z)) as close to 0 as possible
            fake_discriminator_loss = adversarial_loss(fake_images_predictions, fake_images_gt)

            discriminator_loss = real_discriminator_loss + fake_discriminator_loss
            discriminator_loss.backward()  # this will populate .grad vars in the discriminator net
            discriminator_opt.step()  # perform D weights update according to optimizer's strategy

            #
            # Train generator: minimize V1 = log(1-D(G(z))) or equivalently maximize V2 = log(D(G(z))) (or min of -V2)
            # The original expression (V1) had problems with diminishing gradients for G when D is too good.
            #

            # If you want to cause mode collapse, probably the easiest way would be to add "for i in range(n)"
            # here (i.e. simply train G more frequently than D). n = 10 worked for me; other values will also work - experiment.
            # Zero out .grad variables in the generator network (otherwise we would have corrupt results)
            generator_opt.zero_grad()

            # D(G(z)) (see above for explanations)
            generated_images_predictions = discriminator_net(generator_net(utils.get_gaussian_latent_batch(training_config['batch_size'], device)))

            # By placing real_images_gt here we minimize -log(D(G(z))), which happens when D approaches 1,
            # i.e. we're tricking D into thinking that these generated images are real!
            generator_loss = adversarial_loss(generated_images_predictions, real_images_gt)

            generator_loss.backward()  # this will populate .grad vars in the G net (also in D, but we won't use those)
            generator_opt.step()  # perform G weights update according to optimizer's strategy

            #
            # Logging and checkpoint creation
            #

            generator_loss_values.append(generator_loss.item())
            discriminator_loss_values.append(discriminator_loss.item())

            if training_config['enable_tensorboard']:
                writer.add_scalars('losses/g-and-d', {'g': generator_loss.item(), 'd': discriminator_loss.item()}, len(mnist_data_loader) * epoch + batch_idx + 1)
                # Save debug imagery to tensorboard also (some redundancy, but it may be more beginner-friendly)
                if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                    with torch.no_grad():
                        log_generated_images = generator_net(ref_noise_batch)
                        log_generated_images_resized = nn.Upsample(scale_factor=2, mode='nearest')(log_generated_images)
                        intermediate_imagery_grid = make_grid(log_generated_images_resized, nrow=int(np.sqrt(ref_batch_size)), normalize=True)
                        writer.add_image('intermediate generated imagery', intermediate_imagery_grid, len(mnist_data_loader) * epoch + batch_idx + 1)

            if training_config['console_log_freq'] is not None and batch_idx % training_config['console_log_freq'] == 0:
                print(f'GAN training: time elapsed = {(time.time() - ts):.2f} [s] | epoch={epoch + 1} | batch=[{batch_idx + 1}/{len(mnist_data_loader)}]')

            # Save intermediate generator images (more convenient like this than through tensorboard)
            if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                with torch.no_grad():
                    log_generated_images = generator_net(ref_noise_batch)
                    log_generated_images_resized = nn.Upsample(scale_factor=2.5, mode='nearest')(log_generated_images)
                    save_image(log_generated_images_resized, os.path.join(training_config['debug_path'], f'{str(img_cnt).zfill(6)}.jpg'), nrow=int(np.sqrt(ref_batch_size)), normalize=True)
                    img_cnt += 1

            # Save generator checkpoint
            if training_config['checkpoint_freq'] is not None and (epoch + 1) % training_config['checkpoint_freq'] == 0 and batch_idx == 0:
                ckpt_model_name = f"vanilla_ckpt_epoch_{epoch + 1}_batch_{batch_idx + 1}.pth"
                torch.save(utils.get_training_state(generator_net, GANType.VANILLA.name), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

    # Save the latest generator in the binaries directory
    torch.save(utils.get_training_state(generator_net, GANType.VANILLA.name), os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
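# For reference, utils.get_gaussian_latent_batch (defined elsewhere in the repo) only needs to return
# a (batch_size, latent_dim) tensor of N(0, 1) samples on the right device. A minimal sketch, with a
# hypothetical latent_dim default (the repo's actual latent dimension may differ):
def get_gaussian_latent_batch_sketch(batch_size, device, latent_dim=100):
    return torch.randn((batch_size, latent_dim), device=device)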