Exemplo n.º 1
0
def train_autoencoder(epochs):
    """
    Train a fully-connected autoencoder on MNIST and save its weights.

    Adapted from: https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1

    Inputs are rescaled with ``x * 2 - 1`` before entering the network and
    reconstructions are mapped back with ``(x + 1) * 0.5`` before saving, so
    the MNIST loaders are presumably expected to yield pixels in [0, 1]
    (TODO confirm against get_mnist_*_data).

    Args:
        epochs: number of passes over the MNIST training set.

    Side effects:
        - writes reconstruction grids ``AE-grid-0`` .. ``AE-grid-{epochs}``
          under ``RESULT_PATH`` via ``save_reconstructions``;
        - saves the model state dict to ``MODEL_PATH/ae{epochs}.pt``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # makedirs(exist_ok=True) also creates missing parents and does not fail
    # if another run created the directory in the meantime.
    os.makedirs(DATASET_PATH, exist_ok=True)
    os.makedirs(RESULT_PATH, exist_ok=True)

    mnist_train_loader = get_mnist_train_data(store_location=DATASET_PATH)
    mnist_test_loader = get_mnist_test_data(store_location=DATASET_PATH)

    # The first held-out batch is reused for every visual checkpoint.
    test_batch, _ = next(iter(mnist_test_loader))
    test_input = test_batch.view(-1, 784).to(device) * 2 - 1

    autoencoder = AE(input_shape=784, z_size=128).to(device)
    optimizer = optim.Adam(autoencoder.parameters(), lr=1e-3)

    def save_checkpoint_grid(step):
        """Reconstruct the held-out batch in eval mode and save the grid."""
        # eval() disables training-only behaviour (e.g. dropout) if AE has any;
        # train() restores the mode used by the optimisation loop.
        autoencoder.eval()
        with torch.no_grad():
            reconstructions = (autoencoder(test_input) + 1) * 0.5
        autoencoder.train()
        save_reconstructions(os.path.join(RESULT_PATH, f"AE-grid-{step}"),
                             test_batch, reconstructions)

    save_checkpoint_grid(0)

    num_batches = len(mnist_train_loader)
    print(f"Training for {epochs} epochs on MNIST digits")
    for epoch in range(epochs):
        epoch_loss = 0.0
        for image_tensors, _ in tqdm(mnist_train_loader):
            # Flatten the mini-batch to [N, 784] on the active device and
            # rescale to [-1, 1].
            input_features = image_tensors.view(-1, 784).to(device) * 2 - 1

            # PyTorch accumulates gradients across backward passes.
            optimizer.zero_grad()
            reconstructions = autoencoder(input_features)
            train_loss = F.mse_loss(reconstructions, input_features)
            train_loss.backward()
            optimizer.step()

            epoch_loss += train_loss.item()

        # Mean per-batch reconstruction loss for this epoch.
        epoch_loss /= num_batches
        print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs,
                                                    epoch_loss))

        save_checkpoint_grid(epoch + 1)

    os.makedirs(MODEL_PATH, exist_ok=True)
    torch.save(autoencoder.state_dict(),
               os.path.join(MODEL_PATH, f"ae{epochs}.pt"))
Exemplo n.º 2
0
def train_gan(epochs):
    """
    Train a pixel-space MLP GAN on MNIST and save both networks.

    Real images are rescaled with ``x * 2 - 1`` and generated samples are
    mapped back with ``(x + 1) * 0.5`` before saving, so the MNIST loaders are
    presumably expected to yield pixels in [0, 1] (TODO confirm). Latent codes
    are drawn from Uniform[-1, 1) with size ``z_size``.

    Args:
        epochs: number of passes over the MNIST training set.

    Side effects:
        - writes sample grids ``GAN-grid-0`` .. ``GAN-grid-{epochs}`` under
          ``RESULT_PATH`` via ``save_images``;
        - saves state dicts to ``MODEL_PATH/gen{epochs}.pt`` and
          ``MODEL_PATH/dis{epochs}.pt``.
    """
    # Prefer the second GPU as before, but "cuda:1" is an invalid device
    # ordinal on single-GPU machines, so fall back to the default GPU there.
    if torch.cuda.is_available():
        device = torch.device(
            "cuda:1" if torch.cuda.device_count() > 1 else "cuda")
    else:
        device = torch.device("cpu")

    os.makedirs(DATASET_PATH, exist_ok=True)
    os.makedirs(RESULT_PATH, exist_ok=True)

    mnist_train_loader = get_mnist_train_data(store_location=DATASET_PATH)
    mnist_test_loader = get_mnist_test_data(store_location=DATASET_PATH)
    # Only the size of the first held-out batch is used: it sets how many
    # samples go into each visual checkpoint.
    test_batch, _ = next(iter(mnist_test_loader))
    num_samples = len(test_batch)

    z_size = 128

    gen = PixelGeneratorMLP(z_size, 32, 784).to(device)
    gen_opt = optim.Adam(gen.parameters(), lr=1e-3)

    dis = PixelDiscriminatorMLP(784, 32, 1).to(device)
    dis_opt = optim.Adam(dis.parameters(), lr=1e-3)

    print(gen)
    print(dis)

    def sample_prior(n):
        """Draw ``n`` latent vectors from Uniform[-1, 1) on ``device``."""
        return (torch.rand(n, z_size).to(device) * 2) - 1

    def sample_grid():
        """Generate checkpoint samples in eval mode, mapped to [0, 1]."""
        # eval() disables training-only layers such as dropout while sampling;
        # train() restores the mode used by the optimisation loop.
        gen.eval()
        with torch.no_grad():
            samples = (gen(sample_prior(num_samples)) + 1) * 0.5
        gen.train()
        return samples

    save_images(os.path.join(RESULT_PATH, "GAN-grid-0"), sample_grid())

    num_batches = len(mnist_train_loader)
    print(f"Training GAN for {epochs} epochs on MNIST digits")
    for epoch in range(epochs):
        d_loss_sum = 0.0
        g_loss_sum = 0.0
        for image_tensors, _ in tqdm(mnist_train_loader):
            batch_size = len(image_tensors)
            ones_target = torch.ones(batch_size, 1).to(device)
            zeros_target = torch.zeros(batch_size, 1).to(device)

            # Discriminator step: real images vs. detached generator samples.
            real_images = image_tensors.view(-1, 784).to(device) * 2 - 1
            real_scores = dis(real_images)
            # One-sided label smoothing: the "real" target is 0.9, not 1.
            real_loss = F.binary_cross_entropy_with_logits(
                real_scores, 0.9 * ones_target)

            fake_images = gen(sample_prior(batch_size))
            # detach() keeps the discriminator loss out of gen's gradients.
            fake_scores = dis(fake_images.detach())
            fake_loss = F.binary_cross_entropy_with_logits(
                fake_scores, zeros_target)

            dis_loss = real_loss + fake_loss
            d_loss_sum += dis_loss.item()

            dis_opt.zero_grad()
            dis_loss.backward()
            dis_opt.step()

            # Generator step: fresh samples scored against the "real" target
            # so gen learns to fool dis. This backward also fills dis
            # gradients; dis_opt.zero_grad() clears them before the next
            # discriminator update.
            fake_scores = dis(gen(sample_prior(batch_size)))
            gen_loss = F.binary_cross_entropy_with_logits(
                fake_scores, ones_target)
            g_loss_sum += gen_loss.item()

            gen_opt.zero_grad()
            gen_loss.backward()
            gen_opt.step()

        d_loss = d_loss_sum / num_batches
        g_loss = g_loss_sum / num_batches

        testX = sample_grid()
        meanX, stdX = testX.mean(), testX.std()

        print("epoch : {}/{}, D-loss = {:.6f}, G-loss = {:.6f}, "
              "mean-X = {:.3f}, std-X = {:.3f}".format(
                  epoch + 1, epochs, d_loss, g_loss, meanX, stdX))

        save_images(os.path.join(RESULT_PATH, f"GAN-grid-{epoch+1}"), testX)

    os.makedirs(MODEL_PATH, exist_ok=True)
    torch.save(gen.state_dict(), os.path.join(MODEL_PATH, f"gen{epochs}.pt"))
    torch.save(dis.state_dict(), os.path.join(MODEL_PATH, f"dis{epochs}.pt"))
Exemplo n.º 3
0
def train_hpl_from_GAN(epochs, gen_sd_path):
    """
    Trains latent-space generator which attempts to map the prior Z distribution to
    match the unconditional latent distribution of a GAN generator.

    Starts with a pre-trained generator (which is frozen and unchanged).
    Learns pixel-space encoder for producing the "real" Z vectors for the latent discriminator's loss.

    Each training batch performs three updates:
      1. zdis: encoder(real images) is labelled real (smoothed to 0.9) and
         zgen(prior) is labelled fake;
      2. zgen: trained so that zdis scores zgen(prior) as real;
      3. encoder: trained with MSE to recover ``prior`` from ``xgen(prior)``.
    Priors are drawn from Uniform[-1, 1).

    Args:
        epochs: number of passes over the MNIST training set.
        gen_sd_path: file name, relative to ``MODEL_PATH``, of a
            ``PixelGeneratorMLP`` state dict. If falsy, the pixel generator
            keeps its random initialisation.

    Side effects:
        - writes sample grids ``HPL-grid-0`` .. ``HPL-grid-{epochs}`` under
          ``RESULT_PATH`` via ``save_images``;
        - saves ``zgen{epochs}.pt``, ``zdis{epochs}.pt`` and ``enc{epochs}.pt``
          under ``MODEL_PATH``.
    """
    # Prefer the second GPU as before, but "cuda:1" is an invalid device
    # ordinal on single-GPU machines, so fall back to the default GPU there.
    if torch.cuda.is_available():
        device = torch.device(
            "cuda:1" if torch.cuda.device_count() > 1 else "cuda")
    else:
        device = torch.device("cpu")

    os.makedirs(DATASET_PATH, exist_ok=True)
    os.makedirs(RESULT_PATH, exist_ok=True)

    mnist_train_loader = get_mnist_train_data(store_location=DATASET_PATH)
    mnist_test_loader = get_mnist_test_data(store_location=DATASET_PATH)
    # Only the size of the first held-out batch is used: it sets how many
    # samples go into each visual checkpoint.
    test_batch, _ = next(iter(mnist_test_loader))
    num_samples = len(test_batch)

    z_size = 128

    # Pixel-space GAN generator. It is never handed to an optimizer, so its
    # weights stay frozen.
    xgen = PixelGeneratorMLP(z_size, 32, 784).to(device)
    if gen_sd_path:
        xgen.load_state_dict(
            torch.load(os.path.join(MODEL_PATH, gen_sd_path),
                       map_location=device))
    # Stops dropout, otherwise it may be harder for the encoder to learn the
    # mapping from X back to Z.
    xgen.eval()

    encoder = Encoder(input_shape=784, z_size=z_size).to(device)
    encoder_opt = optim.Adam(encoder.parameters(), lr=1e-3)

    # Latent-space GAN
    zgen = LatentGeneratorMLP(z_size, 32).to(device)
    zgen_opt = optim.Adam(zgen.parameters(), lr=1e-3)

    zdis = LatentDiscriminatorMLP(z_size, 32).to(device)
    zdis_opt = optim.Adam(zdis.parameters(), lr=1e-3)

    print(zgen)
    print(zdis)

    def sample_prior(n):
        """Draw ``n`` latent vectors from Uniform[-1, 1) on ``device``."""
        return (torch.rand(n, z_size).to(device) * 2) - 1

    def sample_grid():
        """Map prior -> zgen -> xgen in eval mode; return (codes, images in [0, 1])."""
        zgen.eval()
        with torch.no_grad():
            codes = zgen(sample_prior(num_samples))
            images = (xgen(codes) + 1) * 0.5
        zgen.train()
        return codes, images

    _, testX = sample_grid()
    save_images(os.path.join(RESULT_PATH, "HPL-grid-0"), testX)

    num_batches = len(mnist_train_loader)
    print(f"Training HPL transfer mapping for {epochs} epochs on MNIST digits")
    for epoch in range(epochs):
        e_loss_sum = 0.0
        d_loss_sum = 0.0
        g_loss_sum = 0.0
        for image_tensors, _ in tqdm(mnist_train_loader):
            batch_size = len(image_tensors)
            ones_target = torch.ones(batch_size, 1).to(device)
            zeros_target = torch.zeros(batch_size, 1).to(device)

            # 1. Latent discriminator: encoded real images vs. zgen outputs.
            input_features = image_tensors.view(-1, 784).to(device) * 2 - 1
            # The encoder is only updated in step 3, so no graph is needed here.
            with torch.no_grad():
                real_codes = encoder(input_features)

            real_scores = zdis(real_codes)
            real_loss = F.binary_cross_entropy_with_logits(
                real_scores, 0.9 * ones_target)  # Smoothed "real" label

            fake_codes = zgen(sample_prior(batch_size))
            # detach() keeps the discriminator loss out of zgen's gradients.
            fake_scores = zdis(fake_codes.detach())
            fake_loss = F.binary_cross_entropy_with_logits(
                fake_scores, zeros_target)

            zdis_loss = real_loss + fake_loss
            d_loss_sum += zdis_loss.item()

            zdis_opt.zero_grad()
            zdis_loss.backward()
            zdis_opt.step()

            # 2. Latent generator: maximise fooling of zdis.
            fake_scores = zdis(zgen(sample_prior(batch_size)))
            zgen_loss = F.binary_cross_entropy_with_logits(
                fake_scores, ones_target)
            g_loss_sum += zgen_loss.item()

            zgen_opt.zero_grad()
            zgen_loss.backward()
            zgen_opt.step()

            # 3. Encoder: learn to invert the frozen pixel generator so its
            #    codes match the Z input that xgen expects.
            prior = sample_prior(batch_size)
            with torch.no_grad():
                fake_images = xgen(prior)
            prior_recon = encoder(fake_images)
            encoder_loss = F.mse_loss(prior_recon, prior)
            e_loss_sum += encoder_loss.item()

            encoder_opt.zero_grad()
            encoder_loss.backward()
            encoder_opt.step()

        d_loss = d_loss_sum / num_batches
        g_loss = g_loss_sum / num_batches
        e_loss = e_loss_sum / num_batches

        testZ, testX = sample_grid()
        meanZ, stdZ = testZ.mean(), testZ.std()

        print(
            "epoch : {}/{}, D-loss = {:.6f}, G-loss = {:.6f}, E-loss = {:.6f}, mean-Z = {:.3f}, std-Z = {:.3f}"
            .format(epoch + 1, epochs, d_loss, g_loss, e_loss, meanZ, stdZ))

        save_images(os.path.join(RESULT_PATH, f"HPL-grid-{epoch+1}"), testX)

    os.makedirs(MODEL_PATH, exist_ok=True)
    torch.save(zgen.state_dict(), os.path.join(MODEL_PATH, f"zgen{epochs}.pt"))
    torch.save(zdis.state_dict(), os.path.join(MODEL_PATH, f"zdis{epochs}.pt"))
    torch.save(encoder.state_dict(), os.path.join(MODEL_PATH,
                                                  f"enc{epochs}.pt"))
Exemplo n.º 4
0
def train_hpl(epochs, state_dict_path_AE):
    """
    Trains latent-space generator which attempts to map the prior Z distribution to
    match the latent distribution of an AutoEncoder.

    Starts with a pre-trained AutoEncoder (which is frozen and unchanged).

    Each training batch performs two updates:
      1. zdis: ``autoencoder.encode(real images)`` is labelled real (smoothed
         to 0.9) and ``zgen(prior)`` is labelled fake;
      2. zgen: trained so that zdis scores ``zgen(prior)`` as real.
    Priors are drawn from Uniform[-1, 1).

    Args:
        epochs: number of passes over the MNIST training set.
        state_dict_path_AE: file name, relative to ``MODEL_PATH``, of an ``AE``
            state dict. If falsy, the autoencoder keeps its random
            initialisation.

    Side effects:
        - writes sample grids ``HPL-grid-0`` .. ``HPL-grid-{epochs}`` under
          ``RESULT_PATH`` via ``save_images``;
        - saves ``zgen{epochs}.pt`` and ``zdis{epochs}.pt`` under ``MODEL_PATH``.
    """
    # Prefer the second GPU as before, but "cuda:1" is an invalid device
    # ordinal on single-GPU machines, so fall back to the default GPU there.
    if torch.cuda.is_available():
        device = torch.device(
            "cuda:1" if torch.cuda.device_count() > 1 else "cuda")
    else:
        device = torch.device("cpu")

    os.makedirs(DATASET_PATH, exist_ok=True)
    os.makedirs(RESULT_PATH, exist_ok=True)

    mnist_train_loader = get_mnist_train_data(store_location=DATASET_PATH)
    mnist_test_loader = get_mnist_test_data(store_location=DATASET_PATH)
    # Only the size of the first held-out batch is used: it sets how many
    # samples go into each visual checkpoint.
    test_batch, _ = next(iter(mnist_test_loader))
    num_samples = len(test_batch)

    z_size = 128

    # Frozen autoencoder: it is never handed to an optimizer. eval() makes its
    # encode/decode deterministic w.r.t. training-only layers (e.g. dropout),
    # matching how the frozen pixel generator is treated in train_hpl_from_GAN.
    autoencoder = AE(input_shape=784, z_size=z_size).to(device)
    if state_dict_path_AE:
        autoencoder.load_state_dict(
            torch.load(os.path.join(MODEL_PATH, state_dict_path_AE),
                       map_location=device))
    autoencoder.eval()

    zgen = LatentGeneratorMLP(z_size, 32).to(device)
    zgen_opt = optim.Adam(zgen.parameters(), lr=1e-3)

    zdis = LatentDiscriminatorMLP(z_size, 32).to(device)
    zdis_opt = optim.Adam(zdis.parameters(), lr=1e-3)

    print(zgen)
    print(zdis)

    def sample_prior(n):
        """Draw ``n`` latent vectors from Uniform[-1, 1) on ``device``."""
        return (torch.rand(n, z_size).to(device) * 2) - 1

    def sample_grid():
        """Map prior -> zgen -> AE decoder in eval mode; return (codes, images in [0, 1])."""
        zgen.eval()
        with torch.no_grad():
            codes = zgen(sample_prior(num_samples))
            images = (autoencoder.decode(codes) + 1) * 0.5
        zgen.train()
        return codes, images

    _, testX = sample_grid()
    save_images(os.path.join(RESULT_PATH, "HPL-grid-0"), testX)

    num_batches = len(mnist_train_loader)
    print(f"Training HPL transfer mapping for {epochs} epochs on MNIST digits")
    for epoch in range(epochs):
        d_loss_sum = 0.0
        g_loss_sum = 0.0
        for image_tensors, _ in tqdm(mnist_train_loader):
            batch_size = len(image_tensors)
            ones_target = torch.ones(batch_size, 1).to(device)
            zeros_target = torch.zeros(batch_size, 1).to(device)

            # 1. Latent discriminator: AE codes of real images vs. zgen outputs.
            input_features = image_tensors.view(-1, 784).to(device) * 2 - 1
            # The AE is frozen, so its forward pass needs no autograd graph.
            with torch.no_grad():
                real_codes = autoencoder.encode(input_features)

            real_scores = zdis(real_codes)
            real_loss = F.binary_cross_entropy_with_logits(
                real_scores, 0.9 * ones_target)  # Smoothed "real" label

            fake_codes = zgen(sample_prior(batch_size))
            # detach() keeps the discriminator loss out of zgen's gradients.
            fake_scores = zdis(fake_codes.detach())
            fake_loss = F.binary_cross_entropy_with_logits(
                fake_scores, zeros_target)

            zdis_loss = real_loss + fake_loss
            d_loss_sum += zdis_loss.item()

            zdis_opt.zero_grad()
            zdis_loss.backward()
            zdis_opt.step()

            # 2. Latent generator: maximise fooling of zdis.
            fake_scores = zdis(zgen(sample_prior(batch_size)))
            zgen_loss = F.binary_cross_entropy_with_logits(
                fake_scores, ones_target)
            g_loss_sum += zgen_loss.item()

            zgen_opt.zero_grad()
            zgen_loss.backward()
            zgen_opt.step()

        d_loss = d_loss_sum / num_batches
        g_loss = g_loss_sum / num_batches

        testZ, testX = sample_grid()
        meanZ, stdZ = testZ.mean(), testZ.std()

        print(
            "epoch : {}/{}, D-loss = {:.6f}, G-loss = {:.6f}, mean-Z = {:.3f}, std-Z = {:.3f}"
            .format(epoch + 1, epochs, d_loss, g_loss, meanZ, stdZ))

        save_images(os.path.join(RESULT_PATH, f"HPL-grid-{epoch+1}"), testX)

    os.makedirs(MODEL_PATH, exist_ok=True)
    torch.save(zgen.state_dict(), os.path.join(MODEL_PATH, f"zgen{epochs}.pt"))
    torch.save(zdis.state_dict(), os.path.join(MODEL_PATH, f"zdis{epochs}.pt"))
Exemplo n.º 5
0
def train_adversarial_autoencoder(epochs):
    """
    Train an adversarial autoencoder (AAE) on MNIST.

    Note "real_codes" and "fake_codes" are reversed compared to HPL
    real --> from prior ~ Uniform [0, 1)
    fake --> from AAE encoder (generator)

    Each training batch performs two updates:
      1. zdis: prior samples are labelled real (smoothed to 0.9) and encoder
         codes are labelled fake;
      2. autoencoder: minimises reconstruction MSE plus a generator loss that
         rewards zdis scoring its codes as real.

    Unlike the other trainers, pixels are fed to the network without the
    ``x * 2 - 1`` rescaling, and the discriminator losses use
    ``F.binary_cross_entropy`` rather than the ``_with_logits`` variant, so
    ``LatentDiscriminatorMLP`` is presumably expected to output probabilities
    here — TODO confirm.

    Args:
        epochs: number of passes over the MNIST training set.

    Side effects:
        - writes sample grids ``AAE-grid-0`` .. ``AAE-grid-{epochs}`` under
          ``RESULT_PATH`` via ``save_images``;
        - saves ``aae{epochs}.pt`` and ``pzdis{epochs}.pt`` under ``MODEL_PATH``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Size of the prior samples fed to zdis and the decoder. Assumed to match
    # AE's default latent size, since AE is built without z_size (TODO confirm).
    z_size = 128

    os.makedirs(DATASET_PATH, exist_ok=True)
    os.makedirs(RESULT_PATH, exist_ok=True)

    mnist_train_loader = get_mnist_train_data(store_location=DATASET_PATH)
    mnist_test_loader = get_mnist_test_data(store_location=DATASET_PATH)
    # First held-out batch, used for visual checkpoints and latent statistics.
    test_batch, _ = next(iter(mnist_test_loader))

    autoencoder = AE(input_shape=784).to(device)
    ae_opt = optim.Adam(autoencoder.parameters(), lr=1e-3)

    zdis = LatentDiscriminatorMLP(z_size).to(device)
    zdis_opt = optim.Adam(zdis.parameters(), lr=1e-3)

    def sample_prior(n):
        """Draw ``n`` latent vectors directly from Uniform[0, 1) on ``device``."""
        return torch.rand(n, z_size).to(device)

    # Checkpoints run in eval mode (disabling training-only layers such as
    # dropout, if AE has any); train() restores the optimisation mode.
    autoencoder.eval()
    with torch.no_grad():
        testX = autoencoder.decode(sample_prior(len(test_batch)))
    autoencoder.train()
    save_images(os.path.join(RESULT_PATH, "AAE-grid-0"), testX)

    num_batches = len(mnist_train_loader)
    print(f"Training AAE for {epochs} epochs on MNIST digits")
    for epoch in range(epochs):
        d_loss_sum = 0.0
        g_loss_sum = 0.0
        ae_loss_sum = 0.0
        for image_tensors, _ in tqdm(mnist_train_loader):
            batch_size = len(image_tensors)
            ones_target = torch.ones(batch_size, 1).to(device)
            zeros_target = torch.zeros(batch_size, 1).to(device)

            # 1. Latent discriminator: encoder codes are "fake", prior is "real".
            input_features = image_tensors.view(-1, 784).to(device)
            fake_codes = autoencoder.encode(input_features)

            # detach() keeps the discriminator loss out of the encoder's
            # gradients; fake_codes is reused with its graph in step 2.
            fake_scores = zdis(fake_codes.detach())
            fake_loss = F.binary_cross_entropy(fake_scores, zeros_target)

            real_scores = zdis(sample_prior(batch_size))
            # One-sided label smoothing: the "real" target is 0.9, not 1.
            real_loss = F.binary_cross_entropy(real_scores, 0.9 * ones_target)

            zdis_loss = real_loss + fake_loss
            d_loss_sum += zdis_loss.item()

            zdis_opt.zero_grad()
            zdis_loss.backward()
            zdis_opt.step()

            # 2. AAE loss = reconstruction loss + generator loss for maximising
            #    fooling of zdis. This backward also fills zdis gradients;
            #    zdis_opt.zero_grad() clears them before the next zdis update.
            fake_scores = zdis(fake_codes)
            gen_loss = F.binary_cross_entropy(fake_scores, ones_target)
            g_loss_sum += gen_loss.item()

            reconstructions = autoencoder.decode(fake_codes)
            recon_loss = F.mse_loss(reconstructions, input_features)
            ae_loss_sum += recon_loss.item()

            ae_loss = gen_loss + recon_loss

            ae_opt.zero_grad()
            ae_loss.backward()
            ae_opt.step()

        d_loss = d_loss_sum / num_batches
        g_loss = g_loss_sum / num_batches
        recon_loss_mean = ae_loss_sum / num_batches

        autoencoder.eval()
        with torch.no_grad():
            testX = autoencoder.decode(sample_prior(len(test_batch)))

            # Statistics of the encoder's codes on held-out images, to compare
            # against the Uniform[0, 1) prior.
            testZ = autoencoder.encode(test_batch.view(-1, 784).to(device))
            meanZ, stdZ = testZ.mean(), testZ.std()
        autoencoder.train()

        print("epoch : {}/{}, D-loss = {:.6f}, G-loss = {:.6f}, "
              "AE-loss = {:.6f}, mean-Z = {:.3f}, std-Z = {:.3f}".format(
                  epoch + 1, epochs, d_loss, g_loss, recon_loss_mean,
                  meanZ, stdZ))

        save_images(os.path.join(RESULT_PATH, f"AAE-grid-{epoch+1}"), testX)

    os.makedirs(MODEL_PATH, exist_ok=True)
    torch.save(autoencoder.state_dict(),
               os.path.join(MODEL_PATH, f"aae{epochs}.pt"))
    torch.save(zdis.state_dict(), os.path.join(MODEL_PATH, f"pzdis{epochs}.pt"))