def eval_hpl(ae_sd_path, zgen_sd_path, zdis_sd_path):
    """
    Evaluate on the held-out set.
    Setting *_sd_path=None will evaluate a random initialization for that net.
    """
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    if not os.path.isdir(DATASET_PATH):
        os.mkdir(DATASET_PATH)
    mnist_test_loader = get_mnist_test_data(store_location=DATASET_PATH)
    test_batch, _ = next(iter(mnist_test_loader))  # Used for visual checkpoints of progress

    z_size = 128
    autoencoder = AE(input_shape=784, z_size=z_size).to(device)
    if ae_sd_path:
        autoencoder.load_state_dict(
            torch.load(os.path.join(MODEL_PATH, ae_sd_path), map_location=device))
    zgen = LatentGeneratorMLP(z_size, 32).to(device)
    if zgen_sd_path:
        zgen.load_state_dict(
            torch.load(os.path.join(MODEL_PATH, zgen_sd_path), map_location=device))
    zdis = LatentDiscriminatorMLP(z_size, 32).to(device)
    if zdis_sd_path:
        zdis.load_state_dict(
            torch.load(os.path.join(MODEL_PATH, zdis_sd_path), map_location=device))

    print("Evaluating HPL on MNIST digits held-out set")
    autoencoder.eval()
    zgen.eval()
    zdis.eval()
    Dloss = 0
    Gloss = 0
    with torch.no_grad():  # Evaluation only, no gradients needed
        for image_tensors, _ in tqdm(mnist_test_loader):
            ones_target = torch.ones(len(image_tensors), 1).to(device)
            zeros_target = torch.zeros(len(image_tensors), 1).to(device)

            # Compute discriminator loss with real and fake latent codes
            input_features = image_tensors.view(-1, 784).to(device)
            input_features = (input_features * 2) - 1
            real_codes = autoencoder.encode(input_features)
            real_codes = real_codes.detach()
            real_scores = zdis(real_codes)
            real_loss = F.binary_cross_entropy_with_logits(
                real_scores, 0.9 * ones_target)  # Smoothed "real" label
            prior = (torch.rand(len(image_tensors), z_size).to(device) * 2) - 1
            fake_codes = zgen(prior)
            fake_scores = zdis(fake_codes.detach())
            fake_loss = F.binary_cross_entropy_with_logits(fake_scores, zeros_target)
            zdis_loss = real_loss + fake_loss
            Dloss += zdis_loss.item()

            # Compute generator loss for maximizing fooling of zdis
            prior = (torch.rand(len(image_tensors), z_size).to(device) * 2) - 1
            fake_codes = zgen(prior)
            fake_scores = zdis(fake_codes)
            zgen_loss = F.binary_cross_entropy_with_logits(fake_scores, ones_target)
            Gloss += zgen_loss.item()
    Dloss /= len(mnist_test_loader)
    Gloss /= len(mnist_test_loader)
    print(f"Held-out D-loss = {Dloss:.6f}, G-loss = {Gloss:.6f}")

    # Save a grid of generated samples: prior -> zgen -> decoder (one batch's worth)
    with torch.no_grad():
        prior = (torch.rand(len(test_batch), z_size).to(device) * 2) - 1
        testZ = zgen(prior)
        testX = (autoencoder.decode(testZ) + 1) * 0.5
    save_images(os.path.join(RESULT_PATH, "HPL-grid-eval"), testX)

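# Usage sketch (an addition, not part of the original code): compare a trained latent
# generator against a random initialization via eval_hpl. The checkpoint names are
# assumptions: "zgen50.pt"/"zdis50.pt" follow the f"zgen{epochs}.pt"/f"zdis{epochs}.pt"
# pattern saved by train_hpl, and "ae50.pt" stands in for whatever AE checkpoint exists
# in MODEL_PATH.
def eval_hpl_with_random_baseline(ae_ckpt="ae50.pt", zgen_ckpt="zgen50.pt", zdis_ckpt="zdis50.pt"):
    # Trained zgen/zdis
    eval_hpl(ae_ckpt, zgen_ckpt, zdis_ckpt)
    # Randomly initialized zgen as a baseline (AE and zdis still loaded from disk)
    eval_hpl(ae_ckpt, None, zdis_ckpt)
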
def train_gan(epochs):
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    if not os.path.isdir(DATASET_PATH):
        os.mkdir(DATASET_PATH)
    mnist_train_loader = get_mnist_train_data(store_location=DATASET_PATH)
    mnist_test_loader = get_mnist_test_data(store_location=DATASET_PATH)
    test_batch, _ = next(iter(mnist_test_loader))  # Used for visual checkpoints of progress

    z_size = 128
    gen = PixelGeneratorMLP(z_size, 32, 784).to(device)
    gen_opt = optim.Adam(gen.parameters(), lr=1e-3)
    dis = PixelDiscriminatorMLP(784, 32, 1).to(device)
    dis_opt = optim.Adam(dis.parameters(), lr=1e-3)
    print(gen)
    print(dis)

    with torch.no_grad():
        testZ = (torch.rand(len(test_batch), z_size).to(device) * 2) - 1
        testX = (gen(testZ) + 1) * 0.5
        save_images(os.path.join(RESULT_PATH, "GAN-grid-0"), testX)

    print(f"Training GAN for {epochs} epochs on MNIST digits")
    for epoch in range(epochs):
        Dloss = 0
        Gloss = 0
        for image_tensors, _ in tqdm(mnist_train_loader):
            ones_target = torch.ones(len(image_tensors), 1).to(device)
            zeros_target = torch.zeros(len(image_tensors), 1).to(device)

            # Compute discriminator loss with real and fake images
            real_images = image_tensors.view(-1, 784).to(device)
            real_images = (real_images * 2) - 1
            # real_images += torch.normal(torch.zeros(real_images.shape), torch.ones(real_images.shape) * (1. / (4. * 255.))).to(device)  # add a bit of noise
            # real_images = torch.clamp(real_images, -1, 1)
            real_scores = dis(real_images)
            real_loss = F.binary_cross_entropy_with_logits(
                real_scores, 0.9 * ones_target)  # Smoothed "real" label
            prior = (torch.rand(len(image_tensors), z_size).to(device) * 2) - 1
            fake_images = gen(prior)
            fake_scores = dis(fake_images.detach())
            fake_loss = F.binary_cross_entropy_with_logits(fake_scores, zeros_target)
            dis_loss = real_loss + fake_loss
            Dloss += dis_loss.item()
            dis_opt.zero_grad()
            dis_loss.backward()
            dis_opt.step()

            # Compute generator loss for maximizing fooling of dis
            prior = (torch.rand(len(image_tensors), z_size).to(device) * 2) - 1
            fake_images = gen(prior)
            fake_scores = dis(fake_images)
            gen_loss = F.binary_cross_entropy_with_logits(fake_scores, ones_target)
            Gloss += gen_loss.item()
            # print(gen_loss.item())
            gen_opt.zero_grad()
            gen_loss.backward()
            gen_opt.step()

        Dloss /= len(mnist_train_loader)
        Gloss /= len(mnist_train_loader)
        with torch.no_grad():
            testZ = (torch.rand(len(test_batch), z_size).to(device) * 2) - 1
            testX = (gen(testZ) + 1) * 0.5
            meanX, stdX = testX.mean(), testX.std()
            print("epoch : {}/{}, D-loss = {:.6f}, G-loss = {:.6f}, mean-X = {:.3f}, std-X = {:.3f}".format(
                epoch + 1, epochs, Dloss, Gloss, meanX, stdX))
            save_images(os.path.join(RESULT_PATH, f"GAN-grid-{epoch+1}"), testX)

    if not os.path.isdir(MODEL_PATH):
        os.mkdir(MODEL_PATH)
    torch.save(gen.state_dict(), os.path.join(MODEL_PATH, f"gen{epochs}.pt"))
    torch.save(dis.state_dict(), os.path.join(MODEL_PATH, f"dis{epochs}.pt"))

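# Usage sketch (an addition, not part of the original training code): reload a generator
# checkpoint written by train_gan and save a grid of samples. The checkpoint name follows
# the f"gen{epochs}.pt" pattern used above; "GAN-grid-samples" is a hypothetical output name.
def sample_from_trained_gan(gen_sd_path, n_samples=64, z_size=128):
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    gen = PixelGeneratorMLP(z_size, 32, 784).to(device)
    gen.load_state_dict(
        torch.load(os.path.join(MODEL_PATH, gen_sd_path), map_location=device))
    gen.eval()
    with torch.no_grad():
        # Same Uniform[-1, 1) prior and [0, 1] rescaling as in train_gan
        prior = (torch.rand(n_samples, z_size).to(device) * 2) - 1
        samples = (gen(prior) + 1) * 0.5
    save_images(os.path.join(RESULT_PATH, "GAN-grid-samples"), samples)
    return samples
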
def train_hpl_from_GAN(epochs, gen_sd_path):
    """
    Trains a latent-space generator which attempts to map the prior Z distribution
    to match the unconditional latent distribution of a GAN generator.
    Starts with a pre-trained generator (which is frozen and unchanged).
    Learns a pixel-space encoder for producing the "real" Z vectors for the latent
    discriminator's loss.
    """
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    if not os.path.isdir(DATASET_PATH):
        os.mkdir(DATASET_PATH)
    mnist_train_loader = get_mnist_train_data(store_location=DATASET_PATH)
    mnist_test_loader = get_mnist_test_data(store_location=DATASET_PATH)
    test_batch, _ = next(iter(mnist_test_loader))  # Used for visual checkpoints of progress

    z_size = 128
    # Pixel-space GAN generator (frozen)
    xgen = PixelGeneratorMLP(z_size, 32, 784).to(device)
    if gen_sd_path:
        xgen.load_state_dict(
            torch.load(os.path.join(MODEL_PATH, gen_sd_path), map_location=device))
    xgen.eval()  # stops dropout, otherwise it may be harder for encoder to learn mapping from X back to Z

    encoder = Encoder(input_shape=784, z_size=z_size).to(device)
    encoder_opt = optim.Adam(encoder.parameters(), lr=1e-3)

    # Latent-space GAN
    zgen = LatentGeneratorMLP(z_size, 32).to(device)
    zgen_opt = optim.Adam(zgen.parameters(), lr=1e-3)
    zdis = LatentDiscriminatorMLP(z_size, 32).to(device)
    zdis_opt = optim.Adam(zdis.parameters(), lr=1e-3)
    print(zgen)
    print(zdis)

    with torch.no_grad():
        prior = (torch.rand(len(test_batch), z_size).to(device) * 2) - 1
        testZ = zgen(prior)
        testX = (xgen(testZ) + 1) * 0.5
        save_images(os.path.join(RESULT_PATH, "HPL-grid-0"), testX)

    print(f"Training HPL transfer mapping for {epochs} epochs on MNIST digits")
    for epoch in range(epochs):
        Eloss = 0
        Dloss = 0
        Gloss = 0
        for image_tensors, _ in tqdm(mnist_train_loader):
            ones_target = torch.ones(len(image_tensors), 1).to(device)
            zeros_target = torch.zeros(len(image_tensors), 1).to(device)

            # Compute discriminator loss with real and fake latent codes
            input_features = image_tensors.view(-1, 784).to(device)
            input_features = (input_features * 2) - 1
            real_codes = encoder(input_features)
            real_codes = real_codes.detach()
            real_scores = zdis(real_codes)
            real_loss = F.binary_cross_entropy_with_logits(
                real_scores, 0.9 * ones_target)  # Smoothed "real" label
            prior = (torch.rand(len(image_tensors), z_size).to(device) * 2) - 1
            fake_codes = zgen(prior)
            fake_scores = zdis(fake_codes.detach())
            fake_loss = F.binary_cross_entropy_with_logits(fake_scores, zeros_target)
            zdis_loss = real_loss + fake_loss
            Dloss += zdis_loss.item()
            zdis_opt.zero_grad()
            zdis_loss.backward()
            zdis_opt.step()

            # Compute generator loss for maximizing fooling of zdis
            prior = (torch.rand(len(image_tensors), z_size).to(device) * 2) - 1
            fake_codes = zgen(prior)
            fake_scores = zdis(fake_codes)
            zgen_loss = F.binary_cross_entropy_with_logits(fake_scores, ones_target)
            Gloss += zgen_loss.item()
            zgen_opt.zero_grad()
            zgen_loss.backward()
            zgen_opt.step()

            # Compute encoder loss for matching Z input to pixel-space generator
            prior = (torch.rand(len(image_tensors), z_size).to(device) * 2) - 1
            fake_images = xgen(prior)
            prior_recon = encoder(fake_images.detach())
            encoder_loss = F.mse_loss(prior_recon, prior)
            Eloss += encoder_loss.item()
            encoder_opt.zero_grad()
            encoder_loss.backward()
            encoder_opt.step()

        Dloss /= len(mnist_train_loader)
        Gloss /= len(mnist_train_loader)
        Eloss /= len(mnist_train_loader)
        with torch.no_grad():
            prior = (torch.rand(len(test_batch), z_size).to(device) * 2) - 1
            testZ = zgen(prior)
            testX = (xgen(testZ) + 1) * 0.5
            meanZ, stdZ = testZ.mean(), testZ.std()
            print("epoch : {}/{}, D-loss = {:.6f}, G-loss = {:.6f}, E-loss = {:.6f}, mean-Z = {:.3f}, std-Z = {:.3f}".format(
                epoch + 1, epochs, Dloss, Gloss, Eloss, meanZ, stdZ))
            save_images(os.path.join(RESULT_PATH, f"HPL-grid-{epoch+1}"), testX)

    if not os.path.isdir(MODEL_PATH):
        os.mkdir(MODEL_PATH)
    torch.save(zgen.state_dict(), os.path.join(MODEL_PATH, f"zgen{epochs}.pt"))
    torch.save(zdis.state_dict(), os.path.join(MODEL_PATH, f"zdis{epochs}.pt"))
    torch.save(encoder.state_dict(), os.path.join(MODEL_PATH, f"enc{epochs}.pt"))

def train_hpl(epochs, state_dict_path_AE):
    """
    Trains a latent-space generator which attempts to map the prior Z distribution
    to match the latent distribution of an AutoEncoder.
    Starts with a pre-trained AutoEncoder (which is frozen and unchanged).
    """
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    if not os.path.isdir(DATASET_PATH):
        os.mkdir(DATASET_PATH)
    mnist_train_loader = get_mnist_train_data(store_location=DATASET_PATH)
    mnist_test_loader = get_mnist_test_data(store_location=DATASET_PATH)
    test_batch, _ = next(iter(mnist_test_loader))  # Used for visual checkpoints of progress

    z_size = 128
    autoencoder = AE(input_shape=784, z_size=z_size).to(device)
    if state_dict_path_AE:
        autoencoder.load_state_dict(
            torch.load(os.path.join(MODEL_PATH, state_dict_path_AE), map_location=device))

    zgen = LatentGeneratorMLP(z_size, 32).to(device)
    zgen_opt = optim.Adam(zgen.parameters(), lr=1e-3)
    zdis = LatentDiscriminatorMLP(z_size, 32).to(device)
    zdis_opt = optim.Adam(zdis.parameters(), lr=1e-3)
    print(zgen)
    print(zdis)

    with torch.no_grad():
        prior = (torch.rand(len(test_batch), z_size).to(device) * 2) - 1
        testZ = zgen(prior)
        testX = (autoencoder.decode(testZ) + 1) * 0.5
        save_images(os.path.join(RESULT_PATH, "HPL-grid-0"), testX)

    print(f"Training HPL transfer mapping for {epochs} epochs on MNIST digits")
    for epoch in range(epochs):
        Dloss = 0
        Gloss = 0
        for image_tensors, _ in tqdm(mnist_train_loader):
            ones_target = torch.ones(len(image_tensors), 1).to(device)
            zeros_target = torch.zeros(len(image_tensors), 1).to(device)

            # Compute discriminator loss with real and fake latent codes
            input_features = image_tensors.view(-1, 784).to(device)
            input_features = (input_features * 2) - 1
            real_codes = autoencoder.encode(input_features)
            real_codes = real_codes.detach()
            real_scores = zdis(real_codes)
            real_loss = F.binary_cross_entropy_with_logits(
                real_scores, 0.9 * ones_target)  # Smoothed "real" label
            prior = (torch.rand(len(image_tensors), z_size).to(device) * 2) - 1
            fake_codes = zgen(prior)
            fake_scores = zdis(fake_codes.detach())
            fake_loss = F.binary_cross_entropy_with_logits(fake_scores, zeros_target)
            zdis_loss = real_loss + fake_loss
            Dloss += zdis_loss.item()
            zdis_opt.zero_grad()
            zdis_loss.backward()
            zdis_opt.step()

            # Compute generator loss for maximizing fooling of zdis
            prior = (torch.rand(len(image_tensors), z_size).to(device) * 2) - 1
            fake_codes = zgen(prior)
            fake_scores = zdis(fake_codes)
            zgen_loss = F.binary_cross_entropy_with_logits(fake_scores, ones_target)
            Gloss += zgen_loss.item()
            zgen_opt.zero_grad()
            zgen_loss.backward()
            zgen_opt.step()

        Dloss /= len(mnist_train_loader)
        Gloss /= len(mnist_train_loader)
        with torch.no_grad():
            prior = (torch.rand(len(test_batch), z_size).to(device) * 2) - 1
            testZ = zgen(prior)
            testX = (autoencoder.decode(testZ) + 1) * 0.5
            meanZ, stdZ = testZ.mean(), testZ.std()
            print("epoch : {}/{}, D-loss = {:.6f}, G-loss = {:.6f}, mean-Z = {:.3f}, std-Z = {:.3f}".format(
                epoch + 1, epochs, Dloss, Gloss, meanZ, stdZ))
            save_images(os.path.join(RESULT_PATH, f"HPL-grid-{epoch+1}"), testX)

    if not os.path.isdir(MODEL_PATH):
        os.mkdir(MODEL_PATH)
    torch.save(zgen.state_dict(), os.path.join(MODEL_PATH, f"zgen{epochs}.pt"))
    torch.save(zdis.state_dict(), os.path.join(MODEL_PATH, f"zdis{epochs}.pt"))

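# Usage sketch (an addition, not part of the original code): the autoencoder-based HPL path.
# Assumes a pre-trained AE checkpoint already exists in MODEL_PATH; its filename ("ae50.pt")
# is a hypothetical stand-in since the AE training step is not shown in this section.
def run_hpl_from_ae(ae_ckpt="ae50.pt", hpl_epochs=50):
    train_hpl(hpl_epochs, ae_ckpt)
    eval_hpl(ae_ckpt, f"zgen{hpl_epochs}.pt", f"zdis{hpl_epochs}.pt")
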
def train_adversarial_autoencoder(epochs):
    """
    Note "real_codes" and "fake_codes" are reversed compared to HPL:
        real --> from prior ~ Uniform [0, 1)
        fake --> from AAE encoder (generator)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not os.path.isdir(DATASET_PATH):
        os.mkdir(DATASET_PATH)
    mnist_train_loader = get_mnist_train_data(store_location=DATASET_PATH)
    mnist_test_loader = get_mnist_test_data(store_location=DATASET_PATH)
    test_batch, _ = next(iter(mnist_test_loader))  # Used for visual checkpoints of progress

    autoencoder = AE(input_shape=784).to(device)
    ae_opt = optim.Adam(autoencoder.parameters(), lr=1e-3)
    zdis = LatentDiscriminatorMLP(128).to(device)
    zdis_opt = optim.Adam(zdis.parameters(), lr=1e-3)

    with torch.no_grad():
        testZ = torch.rand(len(test_batch), 128).to(device)  # Directly from random uniform distribution [0, 1)
        testX = autoencoder.decode(testZ)
        save_images(os.path.join(RESULT_PATH, "AAE-grid-0"), testX)

    print(f"Training AAE for {epochs} epochs on MNIST digits")
    for epoch in range(epochs):
        Dloss = 0
        Gloss = 0
        AEloss = 0
        for image_tensors, _ in tqdm(mnist_train_loader):
            ones_target = torch.ones(len(image_tensors), 1).to(device)
            zeros_target = torch.zeros(len(image_tensors), 1).to(device)

            # Compute discriminator loss with real and fake latent codes
            input_features = image_tensors.view(-1, 784).to(device)
            fake_codes = autoencoder.encode(input_features)
            fake_scores = zdis(fake_codes.detach())
            fake_loss = F.binary_cross_entropy(fake_scores, zeros_target)
            real_codes = torch.rand(len(image_tensors), 128).to(device)
            real_scores = zdis(real_codes.detach())
            real_loss = F.binary_cross_entropy(real_scores, 0.9 * ones_target)  # Smoothed "real" label
            zdis_loss = real_loss + fake_loss
            Dloss += zdis_loss.item()
            zdis_opt.zero_grad()
            zdis_loss.backward()
            zdis_opt.step()

            # Compute AAE loss = reconstruction loss + generator loss for maximizing fooling of zdis
            fake_scores = zdis(fake_codes)
            gen_loss = F.binary_cross_entropy(fake_scores, ones_target)
            Gloss += gen_loss.item()
            reconstructions = autoencoder.decode(fake_codes)
            recon_loss = F.mse_loss(reconstructions, input_features)
            AEloss += recon_loss.item()
            ae_loss = gen_loss + recon_loss
            ae_opt.zero_grad()
            ae_loss.backward()
            ae_opt.step()

        num_batch = len(mnist_train_loader)
        Dloss /= num_batch
        Gloss /= num_batch
        AEloss /= num_batch
        with torch.no_grad():
            prior = torch.rand(len(test_batch), 128).to(device)  # Directly from random uniform distribution [0, 1)
            testX = autoencoder.decode(prior)
            test_features = test_batch.view(-1, 784).to(device)
            testZ = autoencoder.encode(test_features)
            meanZ, stdZ = testZ.mean(), testZ.std()
            print("epoch : {}/{}, D-loss = {:.6f}, G-loss = {:.6f}, AE-loss = {:.6f}, mean-Z = {:.3f}, std-Z = {:.3f}".format(
                epoch + 1, epochs, Dloss, Gloss, AEloss, meanZ, stdZ))
            save_images(os.path.join(RESULT_PATH, f"AAE-grid-{epoch+1}"), testX)

    if not os.path.isdir(MODEL_PATH):
        os.mkdir(MODEL_PATH)
    torch.save(autoencoder.state_dict(), os.path.join(MODEL_PATH, f"aae{epochs}.pt"))
    torch.save(zdis.state_dict(), os.path.join(MODEL_PATH, f"pzdis{epochs}.pt"))

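# Usage sketch (an addition, not part of the original code): sample from a trained AAE by
# decoding draws from the same Uniform[0, 1) prior its discriminator was matched to. The
# checkpoint name follows the f"aae{epochs}.pt" pattern saved above; "AAE-grid-samples" is
# a hypothetical output name.
def sample_from_trained_aae(aae_sd_path, n_samples=64, z_size=128):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    autoencoder = AE(input_shape=784).to(device)
    autoencoder.load_state_dict(
        torch.load(os.path.join(MODEL_PATH, aae_sd_path), map_location=device))
    autoencoder.eval()
    with torch.no_grad():
        prior = torch.rand(n_samples, z_size).to(device)
        samples = autoencoder.decode(prior)
    save_images(os.path.join(RESULT_PATH, "AAE-grid-samples"), samples)
    return samples
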