# Example #1
# 0
def run_gan():
    """Train a conditional GAN (``CGAN``) on MNIST and plot its loss curves.

    Loads the MNIST training split, scales pixels to [-1, 1], one-hot encodes
    the labels and feeds shuffled, batched ``(image, label)`` pairs to
    ``CGAN.train``. When the module-level ``model_test`` flag is truthy, a
    single forward pass through the freshly built generator and discriminator
    is run first as a smoke test.

    Relies on module-level settings defined elsewhere in the file:
    ``BUFFER_SIZE``, ``BATCH_SIZE``, ``NOISE_DIM``, ``EPOCHS``, ``gen_lr``,
    ``disc_lr``, ``model_test``, ``num_examples_to_generate`` and
    ``save_ckpt_path``.
    """
    (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

    # Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
    train_images = train_images.reshape(
        train_images.shape[0], 28, 28, 1).astype('float32')
    train_images = (train_images - 127.5) / 127.5  # Normalize images to [-1,1]
    print(train_images.shape)

    train_labels = to_categorical(train_labels)
    print(train_labels.shape)
    # Width of the one-hot condition vector, taken from the encoded labels
    # instead of hard-coding the number of classes.
    num_classes = train_labels.shape[1]

    # Batch and shuffle the data
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (train_images, train_labels)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

    gan = CGAN(gen_lr, disc_lr, noise_dim=NOISE_DIM)
    gan.create_generator()
    gan.create_discriminator()

    if model_test:
        # Smoke-test the untrained generator with an all-zero condition.
        random_noise = tf.random.normal([1, NOISE_DIM])
        condition = tf.zeros(shape=(1, num_classes))
        generated_image = gan.generator([random_noise, condition])
        plt.imshow(generated_image[0, :, :, 0], cmap='gray')
        plt.show()
        # Smoke-test the discriminator. Its raw output is passed through
        # sigmoid before printing, so it is presumably a logit.
        prob = gan.discriminator([generated_image, condition])
        print("Probability of image being real: {}".format(sigmoid(prob)))

    gan.set_noise_seed(num_examples_to_generate)
    print(gan.label_seed.shape)
    gan.set_checkpoint(path=save_ckpt_path)
    gen_loss_array, disc_loss_array = gan.train(train_dataset, epochs=EPOCHS)

    # Plot generator and discriminator loss on one figure. The x-axis comes
    # from each array's own length, so the plot stays valid even if train()
    # records losses at a granularity other than once per epoch.
    plt.plot(range(len(gen_loss_array)), gen_loss_array, label='generator')
    plt.plot(range(len(disc_loss_array)), disc_loss_array,
             label='discriminator')
    plt.ylabel('loss')
    plt.legend()
    plt.show()
# Example #2
# 0
# Generate one image per entry in `conditions`. Latent vectors are reused from
# a small pool, so images that share a pool slot differ only in condition.
generated_imgs = []

# Pool of latent vectors drawn once and cycled through below. No RNG seed is
# set in this snippet, so the pool differs from run to run.
fixed_noise = [torch.randn(1, 100, 1, 1) for _ in range(5)]

for i, c in enumerate(conditions):
    # The modulo keeps the index within the pool, so this lookup cannot
    # raise IndexError and needs no exception handling.
    noise = fixed_noise[i % len(fixed_noise)]

    c = vocab.encode_feature(c)
    # Add a leading batch axis, then affinely rescale the encoding
    # (0 -> 0.05, 1 -> 0.95). This presumably softens a 0/1 feature
    # encoding; confirm against vocab.encode_feature.
    c = np.expand_dims(c, 0) * 0.9 + 0.05
    noise, c = Variable(torch.Tensor(noise)), Variable(torch.Tensor(c))

    img_v = model.generator(noise, c)
    img = cvt_output(img_v)
    generated_imgs.append(img)

imgs = np.array(generated_imgs)
save_imgs(imgs)

# Alternative output path, kept for reference only (not executed):
#     imgs = Variable(torch.Tensor(imgs))
#     torchvision.utils.save_image(imgs.data, 'test.jpg', nrow=5)